import logging
from typing import Any

import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from common.starrocks_connection import StarRocksConnection

logger = logging.getLogger(__name__)
# Endpoints in this module are registered on this router (presumably included by the app elsewhere).
router = APIRouter()

# Canifa middleware endpoint; queried with a comma-joined ``skus`` query parameter.
STOCK_API_URL = "https://canifa.com/v1/middleware/stock_get_stock_list"
# Default for StockExpandRequest.max_skus: cap on SKUs sent to the stock API per endpoint call.
DEFAULT_MAX_SKUS = 200
# Default for StockExpandRequest.chunk_size: SKUs per individual stock API request.
DEFAULT_CHUNK_SIZE = 50
# StarRocks table read by _fetch_variants to map base/color codes to their size_scale.
TABLE_NAME = "shared_source.magento_product_dimension_with_text_embedding"


class StockExpandRequest(BaseModel):
    """Parameters for expanding product codes into SKUs and checking their stock."""

    # Split on commas by the endpoint; tokens containing two or more "-" are
    # treated as full SKUs, everything else is looked up in StarRocks.
    codes: str = Field(
        description=(
            "Comma-separated product codes. Supports base codes, product_color_code, "
            "or full SKU (code-color-size). Example: '6ST25W005,6ST25W005-SE091-L'"
        )
    )
    # Normalized like size_scale tokens (upper-cased, trailing "CM" dropped);
    # only applied to SKUs expanded from lookup codes. Blank means no filter.
    sizes: str | None = Field(
        default=None,
        description="Optional comma-separated sizes to filter (e.g. 'S,M,L,XL,140').",
    )
    # Maximum number of SKUs sent to the stock API; see ``truncate`` for overflow handling.
    # NOTE(review): max_skus, chunk_size and timeout_sec have no upper (``le=``) bound,
    # so a client can request very large upstream fan-outs or long waits — consider capping.
    max_skus: int = Field(default=DEFAULT_MAX_SKUS, ge=1)
    # Number of SKUs per stock API request.
    chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, ge=1)
    # When True, return the expanded SKU list without calling the stock API.
    expand_only: bool = Field(default=False)
    # When True, cut the SKU list to max_skus; when False, reject the request with HTTP 400.
    truncate: bool = Field(default=True)
    # Timeout in seconds passed to the httpx client used for stock API calls.
    timeout_sec: float = Field(default=10.0, gt=0)


def _split_csv(value: str | None) -> list[str]:
    if not value:
        return []
    return [token.strip() for token in value.split(",") if token.strip()]


def _normalize_size(token: str) -> str:
    """Canonicalize a size label for SKU building and size-filter matching.

    Upper-cases the token, drops a trailing ``CM`` unit suffix, and trims
    whitespace on both sides of the result, so ``"140cm"``, ``"140 CM"`` and
    ``" 140 "`` all become ``"140"``. A token that is only the unit (``"cm"``)
    or only whitespace normalizes to ``""``.
    """
    # Strip again after removing the suffix: "140 CM" would otherwise leave a
    # trailing space ("140 ") and build a malformed SKU / miss the filter.
    return token.strip().upper().removesuffix("CM").strip()


def _is_full_sku(token: str) -> bool:
    """Return True when *token* looks like a complete ``code-color-size`` SKU.

    The check is purely structural: the token must split into at least three
    hyphen-separated parts (i.e. contain two or more hyphens).
    """
    return len(token.split("-")) > 2


def _chunked(items: list[str], size: int) -> list[list[str]]:
    """Split *items* into consecutive slices of at most *size* elements.

    The final slice holds any remainder; an empty input yields no slices.
    """
    batches: list[list[str]] = []
    for start in range(0, len(items), size):
        batches.append(items[start : start + size])
    return batches


async def _fetch_variants(codes: list[str]) -> list[dict[str, Any]]:
    """Fetch distinct variant rows from StarRocks matching any of *codes*.

    A code matches a row when it equals the row's ``internal_ref_code``,
    ``magento_ref_code`` or ``product_color_code``. Rows carry those three
    columns plus ``size_scale``; GROUP BY over every selected column makes
    them distinct. An empty *codes* list short-circuits to ``[]`` without
    touching the database.
    """
    if len(codes) == 0:
        return []

    # One "%s" per code; the same codes are bound once for each of the three IN clauses.
    marks = ",".join("%s" for _ in codes)
    sql = f"""
    SELECT
        internal_ref_code,
        magento_ref_code,
        product_color_code,
        size_scale
    FROM {TABLE_NAME}
    WHERE internal_ref_code IN ({marks})
       OR magento_ref_code IN ({marks})
       OR product_color_code IN ({marks})
    GROUP BY internal_ref_code, magento_ref_code, product_color_code, size_scale
    """

    bound = tuple(codes) * 3
    connection = StarRocksConnection()
    return await connection.execute_query_async(sql, params=bound)


def _expand_variant_rows(
    rows: list[dict[str, Any]], size_filter: set[str] | None
) -> tuple[list[str], list[str]]:
    """Turn StarRocks variant rows into ``{product_color_code}-{size}`` SKUs.

    Sizes come from each row's pipe-separated ``size_scale`` and are
    normalized with ``_normalize_size``; when *size_filter* is non-empty only
    sizes contained in it are kept. Rows without a ``product_color_code`` are
    ignored.

    Returns:
        ``(expanded_skus, missing_size_color_codes)`` — the SKUs in row/size
        order (duplicates included), and the color codes whose ``size_scale``
        was empty.
    """
    expanded: list[str] = []
    missing: list[str] = []
    for row in rows:
        color_code = row.get("product_color_code")
        size_scale = row.get("size_scale")
        if not color_code:
            continue
        if not size_scale:
            missing.append(color_code)
            continue
        for raw_size in str(size_scale).split("|"):
            size = _normalize_size(raw_size)
            # Blank tokens, or a bare unit like "CM", normalize to "" and would
            # otherwise build a malformed SKU ending in "-".
            if not size:
                continue
            if size_filter and size not in size_filter:
                continue
            expanded.append(f"{color_code}-{size}")
    return expanded, missing


async def _query_stock_api(
    skus: list[str], chunk_size: int, timeout_sec: float
) -> list[dict[str, Any]]:
    """Query the Canifa stock API for *skus*, one sequential request per chunk.

    Each request sends at most *chunk_size* SKUs as a comma-joined ``skus``
    query parameter; decoded JSON bodies are returned in request order.

    Raises:
        HTTPException: 502 on network errors, non-2xx responses, or a body
            that is not valid JSON; 500 on any other unexpected failure.
    """
    responses: list[dict[str, Any]] = []
    try:
        async with httpx.AsyncClient(timeout=timeout_sec) as client:
            for chunk in _chunked(skus, chunk_size):
                resp = await client.get(STOCK_API_URL, params={"skus": ",".join(chunk)})
                resp.raise_for_status()
                try:
                    responses.append(resp.json())
                except ValueError as exc:
                    # resp.json() raises json.JSONDecodeError (a ValueError) on a
                    # non-JSON body such as an HTML error page: an upstream fault.
                    logger.error("Invalid JSON from stock API: %s", exc)
                    raise HTTPException(
                        status_code=502,
                        detail=f"Stock API returned invalid JSON: {exc}",
                    ) from exc
    except HTTPException:
        # Already mapped to a status above; keep the catch-all from turning it into a 500.
        raise
    except httpx.RequestError as exc:
        logger.error("Network error checking stock: %s", exc)
        raise HTTPException(status_code=502, detail=f"Network error: {exc}") from exc
    except httpx.HTTPStatusError as exc:
        logger.error("HTTP error checking stock: %s", exc)
        raise HTTPException(status_code=502, detail=f"Stock API error: {exc}") from exc
    except Exception as exc:
        # logger.exception keeps the traceback, which is what makes this branch debuggable.
        logger.exception("Unexpected error checking stock: %s", exc)
        raise HTTPException(status_code=500, detail=f"Unexpected error: {exc}") from exc
    return responses


@router.post("/api/stock/check", summary="Expand product codes and check stock")
async def check_stock(req: StockExpandRequest):
    """
    Expand product codes to full SKUs using StarRocks, then call the Canifa stock API.

    Tokens in ``codes`` with two or more hyphens are treated as full SKUs and
    passed through unchanged; all others are looked up in StarRocks and
    expanded into ``{product_color_code}-{size}`` SKUs from each row's
    ``size_scale``, optionally restricted to ``sizes``. The combined list is
    de-duplicated (first occurrence wins), capped at ``max_skus`` (truncated,
    or rejected with 400 when ``truncate`` is false), and — unless
    ``expand_only`` is set — sent to the stock API in chunks of ``chunk_size``.

    Raises:
        HTTPException: 400 when ``codes`` is empty or the SKU cap is exceeded
            without ``truncate``; 502 when the stock API is unreachable,
            answers with a non-2xx status, or returns a non-JSON body; 500 on
            any other failure while querying the stock API.
    """
    input_codes = _split_csv(req.codes)
    if not input_codes:
        raise HTTPException(status_code=400, detail="codes is required")

    # An absent or blank ``sizes`` value gives an empty set, which means "no filter".
    size_filter = {_normalize_size(s) for s in _split_csv(req.sizes)} or None

    # Tokens that already look like code-color-size SKUs skip the StarRocks lookup.
    full_skus = [code for code in input_codes if _is_full_sku(code)]
    lookup_codes = [code for code in input_codes if not _is_full_sku(code)]

    variant_rows = await _fetch_variants(lookup_codes)
    expanded_skus, missing_size_color_codes = _expand_variant_rows(variant_rows, size_filter)

    # dict.fromkeys drops duplicates while keeping first-seen order; caller-supplied
    # full SKUs come first, so they are the last to be cut by truncation.
    ordered_skus = list(dict.fromkeys(full_skus + expanded_skus))

    truncated = False
    if len(ordered_skus) > req.max_skus:
        if not req.truncate:
            raise HTTPException(
                status_code=400,
                detail=f"Expanded SKU count {len(ordered_skus)} exceeds max_skus {req.max_skus}",
            )
        ordered_skus = ordered_skus[: req.max_skus]
        truncated = True

    response_payload: dict[str, Any] = {
        "status": "success",
        "input_codes": input_codes,
        "lookup_codes": lookup_codes,
        "input_full_skus": full_skus,
        "expanded_skus_count": len(expanded_skus),
        "requested_skus_count": len(ordered_skus),
        "requested_skus": ordered_skus,
        "missing_size_color_codes": missing_size_color_codes,
        "truncated": truncated,
    }

    if req.expand_only:
        return response_payload

    # Skip the HTTP client entirely when nothing survived expansion/filtering.
    if ordered_skus:
        response_payload["stock_responses"] = await _query_stock_api(
            ordered_skus, req.chunk_size, req.timeout_sec
        )
    else:
        response_payload["stock_responses"] = []
    return response_payload
