import asyncio
import csv
import os
import sys
from typing import Any

# Make the backend root (the parent of this script's directory) importable so
# the project imports below (`common.starrocks_connection`, `config`) resolve
# when this file is executed directly as a script.
# NOTE(review): append() puts backend_root *last* on sys.path, so an installed
# package that is also named `common` or `config` would shadow the local one.
current_dir = os.path.dirname(os.path.abspath(__file__))
backend_root = os.path.dirname(current_dir)
sys.path.append(backend_root)

from common.starrocks_connection import StarRocksConnection
from config import STARROCKS_DB, STARROCKS_HOST, STARROCKS_PASSWORD, STARROCKS_USER

# Table queried for SKU rows (presumably `database.table` form — confirm).
TABLE_NAME = "shared_source.magento_product_dimension_with_text_embedding"
# Sizes searched for when NUMERIC_SIZES is unset or yields no valid integer.
DEFAULT_SIZES = [90, 92, 98, 104, 110, 116, 122, 128, 134, 140, 152, 164]
# CSV report destination, written next to this script.
OUTPUT_CSV = os.path.join(current_dir, "numeric_size_skus.csv")


def _get_missing_env() -> list[str]:
    """Return the names of required StarRocks settings that are empty or unset.

    A setting counts as missing when the value imported from ``config`` is
    falsy (for example ``None`` or ``""``). Names are reported in the fixed
    order host, database, user, password.
    """
    required = (
        ("STARROCKS_HOST", STARROCKS_HOST),
        ("STARROCKS_DB", STARROCKS_DB),
        ("STARROCKS_USER", STARROCKS_USER),
        ("STARROCKS_PASSWORD", STARROCKS_PASSWORD),
    )
    return [name for name, value in required if not value]


def _skip_or_warn_if_missing_env() -> bool:
    """Report missing StarRocks settings and tell the caller whether to stop.

    When running under pytest (``PYTEST_CURRENT_TEST`` is set), the current
    test is skipped via ``pytest.skip``, which aborts the test. Otherwise a
    ``[SKIP]`` line naming the missing settings is printed.

    Returns:
        ``True`` if at least one setting is missing (the caller should bail
        out), ``False`` if all settings are present.
    """
    missing = _get_missing_env()
    if missing:
        reason = "Missing StarRocks env vars: " + ", ".join(missing)
        if "PYTEST_CURRENT_TEST" in os.environ:
            # Imported lazily so plain script runs do not require pytest.
            import pytest

            pytest.skip(reason)
        print("[SKIP] " + reason)
        return True
    return False


def _parse_sizes_env() -> list[int]:
    """Read the size filter from the ``NUMERIC_SIZES`` environment variable.

    The variable is a comma-separated list of integers, e.g. ``"140,152,164"``.
    Blank entries are ignored. Entries that are not integers are reported on
    stderr and skipped. If the variable is unset, empty, or yields no valid
    integer, the defaults from ``DEFAULT_SIZES`` are used.

    Returns:
        A new list of sizes. Callers may mutate it without affecting
        ``DEFAULT_SIZES``.
    """
    raw = os.getenv("NUMERIC_SIZES")
    if not raw:
        # Copy so a caller mutating the result cannot corrupt the module default.
        return list(DEFAULT_SIZES)

    sizes: list[int] = []
    for token in raw.split(","):
        token = token.strip()
        if not token:
            continue
        try:
            sizes.append(int(token))
        except ValueError:
            # Best-effort parsing: skip the bad entry, but say so instead of
            # silently narrowing the filter the user asked for.
            print(
                f"[WARN] Ignoring non-integer size in NUMERIC_SIZES: {token!r}",
                file=sys.stderr,
            )

    return sizes or list(DEFAULT_SIZES)


def _build_regex_pattern(sizes: list[int]) -> str:
    """Build a regex matching any of ``sizes`` as a whole pipe-delimited token.

    A token may carry an optional ``cm`` suffix. For ``[140]`` the pattern
    matches ``140``, ``140cm`` and ``s|140|m``, but not ``1400`` or ``1140``.
    Sizes are de-duplicated and sorted numerically.

    Args:
        sizes: Integer sizes to match.

    Returns:
        A regex string, used by ``fetch_numeric_size_rows`` with SQL
        ``REGEXP`` against ``LOWER(size_scale)``.

    Raises:
        ValueError: If ``sizes`` is empty. The alternation would otherwise be
            empty and match unintended values such as ``||``.
    """
    if not sizes:
        raise ValueError("sizes must contain at least one value")
    alternatives = "|".join(str(s) for s in sorted(set(sizes)))
    # The literal pipe is written as the bracket expression [|], not \|.
    # The previous rf"...\\|..." produced two backslashes. If those reach the
    # regex engine intact, it reads an escaped backslash followed by an
    # alternation with an empty branch, so the delimiter groups matched
    # anywhere (e.g. "1400" matched 140). [|] contains no backslash, so it
    # means a literal pipe whatever escaping the Python string, DB driver and
    # SQL literal parsing apply.
    return f"(^|[|])({alternatives})(cm)?([|]|$)"


async def fetch_numeric_size_rows(sizes: list[int]) -> list[dict[str, Any]]:
    """Query StarRocks for distinct SKU rows whose size scale lists one of ``sizes``.

    The lower-cased ``size_scale`` column is matched against the regex from
    ``_build_regex_pattern``. Rows are grouped on
    (internal_ref_code, magento_ref_code, size_scale), which de-duplicates
    them, and ordered by the two ref codes.

    Args:
        sizes: Numeric sizes to search for.

    Returns:
        The result of ``StarRocksConnection.execute_query_async``. Presumably
        one dict per row keyed by column name (TODO confirm).
    """
    connection = StarRocksConnection()
    size_regex = _build_regex_pattern(sizes)

    # Only the module constant TABLE_NAME is interpolated. The regex is sent
    # as a bound parameter rather than spliced into the SQL text.
    query = f"""
    SELECT
        internal_ref_code,
        magento_ref_code,
        size_scale
    FROM {TABLE_NAME}
    WHERE LOWER(size_scale) REGEXP %s
    GROUP BY internal_ref_code, magento_ref_code, size_scale
    ORDER BY internal_ref_code, magento_ref_code
    """

    result = await connection.execute_query_async(query, params=(size_regex,))
    return result


def _write_csv(rows: list[dict[str, Any]], path: str) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 CSV file with a header row.

    Only ``internal_ref_code``, ``magento_ref_code`` and ``size_scale`` are
    written, in that order. A missing key becomes an empty cell. Missing parent
    directories are created, and an existing file is overwritten.

    Args:
        rows: Result rows, e.g. from ``fetch_numeric_size_rows``.
        path: Destination file path. It may be a bare filename.
    """
    columns = ("internal_ref_code", "magento_ref_code", "size_scale")
    parent = os.path.dirname(path)
    # A bare filename has no directory part, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    # newline="" lets the csv module control line endings, as the csv docs require.
    with open(path, "w", newline="", encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(columns)
        writer.writerows([row.get(col) for col in columns] for row in rows)


def _print_summary(rows: list[dict[str, Any]], sizes: list[int]) -> None:
    """Print a banner with the table name, the active size filter and match counts.

    The distinct counts are over ``internal_ref_code`` and
    ``magento_ref_code`` values. Rows lacking a key contribute a single
    ``None`` value to that count.
    """
    distinct_internal = len({r.get("internal_ref_code") for r in rows})
    distinct_magento = len({r.get("magento_ref_code") for r in rows})
    rule = "=" * 80

    print("\n" + rule, "NUMERIC SIZE SKUS", rule, f"Table: {TABLE_NAME}", sep="\n")
    print(
        "Sizes filter: " + ", ".join(map(str, sorted(set(sizes)))),
        f"Matched rows: {len(rows)}",
        f"Distinct internal_ref_code: {distinct_internal}",
        f"Distinct magento_ref_code: {distinct_magento}",
        sep="\n",
    )


def _print_sample(rows: list[dict[str, Any]], limit: int = 30) -> None:
    """Print up to ``limit`` rows, one ``- internal | magento | size_scale`` line each.

    Args:
        rows: Result rows. A missing key prints as ``None``.
        limit: Maximum number of rows to show. Negative values are treated as 0.
    """
    # Clamp so a negative limit cannot turn the slice into "all but the last N".
    count = max(limit, 0)
    # The header reports the actual limit instead of a hard-coded 30.
    print(f"\nSample (first {count} rows):")
    for row in rows[:count]:
        print(
            f"- {row.get('internal_ref_code')} | {row.get('magento_ref_code')} | {row.get('size_scale')}"
        )


async def _run() -> None:
    """Run the report: query matching SKUs, print a summary and sample, write the CSV.

    Returns early if StarRocks settings are missing. Once a query is
    attempted, ``StarRocksConnection.clear_pool()`` is awaited on both the
    success and the failure path, so an exception while querying, printing
    or writing no longer skips the cleanup.
    """
    if _skip_or_warn_if_missing_env():
        return

    sizes = _parse_sizes_env()
    try:
        rows = await fetch_numeric_size_rows(sizes)

        _print_summary(rows, sizes)
        _print_sample(rows, limit=30)
        _write_csv(rows, OUTPUT_CSV)
        print(f"\nCSV written to: {OUTPUT_CSV}")
    finally:
        # NOTE(review): assumes clear_pool() is safe to call even when the
        # query failed (pool possibly never created) - confirm.
        await StarRocksConnection.clear_pool()


# Script entry point: run the async report to completion with asyncio.run().
if __name__ == "__main__":
    asyncio.run(_run())
