import asyncio
import logging
import os
import sys
from collections import Counter
from typing import Any

# Make the backend root (the parent of this file's directory) importable so the
# `common` and `config` imports that follow resolve no matter which directory
# the file is launched from.
current_dir = os.path.dirname(os.path.abspath(__file__))
backend_root = os.path.dirname(current_dir)
# NOTE: appended (not prepended), so any same-named module already on sys.path
# takes precedence over the copy under backend_root.
sys.path.append(backend_root)

from common.starrocks_connection import StarRocksConnection
from config import STARROCKS_DB, STARROCKS_HOST, STARROCKS_PASSWORD, STARROCKS_USER

logger = logging.getLogger(__name__)

# Qualified StarRocks table whose `size_scale` column this script profiles.
TABLE_NAME = "shared_source.magento_product_dimension_with_text_embedding"

# Letter-style size tokens, used to classify `size_scale` tokens. A frozenset
# keeps this module-level constant immutable while keeping O(1) membership tests.
LETTER_SIZES = frozenset(
    {
        "XXXS",
        "XXS",
        "XS",
        "S",
        "M",
        "L",
        "XL",
        "XXL",
        "XXXL",
    }
)


def _get_missing_env() -> list[str]:
    """List the StarRocks config values (imported from ``config``) that are unset or empty.

    Names are reported in a fixed order: host, database, user, password.
    """
    required_settings = (
        ("STARROCKS_HOST", STARROCKS_HOST),
        ("STARROCKS_DB", STARROCKS_DB),
        ("STARROCKS_USER", STARROCKS_USER),
        ("STARROCKS_PASSWORD", STARROCKS_PASSWORD),
    )
    return [setting for setting, value in required_settings if not value]


def _skip_or_warn_if_missing_env() -> bool:
    """Check StarRocks configuration before touching the database.

    Returns:
        False when every required value is present. When something is
        missing, the current test is skipped via ``pytest.skip`` if running
        under pytest (detected through ``PYTEST_CURRENT_TEST``); otherwise a
        ``[SKIP]`` line is printed and True is returned so the caller can
        return early.
    """
    missing = _get_missing_env()
    if missing:
        message = f"Missing StarRocks env vars: {', '.join(missing)}"
        under_pytest = "PYTEST_CURRENT_TEST" in os.environ
        if under_pytest:
            # Deferred import: pytest is only touched when running under it.
            import pytest

            pytest.skip(message)  # raises, so nothing below runs under pytest
        print(f"[SKIP] {message}")
    return bool(missing)


def _get_limit_from_env() -> int | None:
    raw = os.getenv("SIZE_SCALE_LIMIT")
    if not raw:
        return None
    try:
        value = int(raw)
    except ValueError:
        return None
    return max(1, value)


async def fetch_size_scale_list(limit: int | None = None) -> list[dict[str, Any]]:
    """Return one row per distinct ``size_scale`` value with its row count.

    Rows are ordered by ``size_scale`` ascending and selected with the column
    aliases ``size_scale`` and ``row_count``.

    Args:
        limit: Optional cap on the number of groups returned. ``None`` or 0
            means no cap.

    Raises:
        ValueError: if ``limit`` is negative or a non-numeric string.
    """
    limit_clause = ""
    if limit:
        # The value is interpolated into the SQL text, so only a validated
        # integer literal may ever reach the query.
        limit_value = int(limit)
        if limit_value < 0:
            raise ValueError(f"limit must be non-negative, got {limit!r}")
        limit_clause = f" LIMIT {limit_value}"

    db = StarRocksConnection()
    sql = f"""
    SELECT
        size_scale,
        COUNT(*) AS row_count
    FROM {TABLE_NAME}
    GROUP BY size_scale
    ORDER BY size_scale ASC{limit_clause}
    """

    return await db.execute_query_async(sql)


async def fetch_size_scale_summary() -> dict[str, int]:
    """Return headline counts for the ``size_scale`` column of ``TABLE_NAME``.

    All three figures come from one aggregate query, so they describe the same
    snapshot of the table and cost a single scan / round trip instead of three.

    Returns:
        A dict with keys:
            total_rows: number of rows in the table.
            distinct_size_scale: number of distinct non-NULL ``size_scale``
                values (COUNT(DISTINCT) skips NULL; '' counts as a value).
            null_size_scale: rows whose ``size_scale`` is NULL or ''.
        Any figure the query does not return (or returns as NULL) is 0.
    """
    db = StarRocksConnection()

    sql = f"""
    SELECT
        COUNT(*) AS total_rows,
        COUNT(DISTINCT size_scale) AS distinct_size_scale,
        SUM(CASE WHEN size_scale IS NULL OR size_scale = '' THEN 1 ELSE 0 END)
            AS null_size_scale
    FROM {TABLE_NAME}
    """

    rows = await db.execute_query_async(sql)
    first_row = rows[0] if rows else {}

    # SUM() yields NULL on an empty table, hence the `or 0` fallback.
    return {
        key: int(first_row.get(key) or 0)
        for key in ("total_rows", "distinct_size_scale", "null_size_scale")
    }


def _normalize_token(token: str) -> str:
    """Return ``token`` with surrounding whitespace removed (case is kept)."""
    cleaned = token.strip()
    return cleaned


def _split_size_scale(size_scale: str | None) -> list[str]:
    """Split a ``|``-delimited size_scale value into cleaned, non-empty tokens.

    ``None`` or ``""`` yields an empty list. Each piece is passed through
    ``_normalize_token``, and pieces that end up empty are dropped, so stray or
    doubled separators are ignored.
    """
    if not size_scale:
        return []
    return list(filter(None, map(_normalize_token, size_scale.split("|"))))


def _build_token_summary(rows: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-token row counts from grouped ``size_scale`` rows.

    Each row is read via ``row.get("size_scale")`` (a ``|``-separated string or
    None) and ``row.get("row_count")`` (number of table rows sharing that
    value; missing/None counts as 0). Every distinct token in a row's
    ``size_scale`` is credited with that row's ``row_count`` exactly once.

    Returns:
        A dict of ``Counter`` objects keyed by token:
            all_tokens: every token.
            letter_tokens: tokens contained in ``LETTER_SIZES``.
            numeric_tokens: decimal digits with at most one ``.`` (e.g. "42", "10.5").
            other_tokens: everything else.
    """
    token_counter: Counter[str] = Counter()
    letter_counter: Counter[str] = Counter()
    numeric_counter: Counter[str] = Counter()
    other_counter: Counter[str] = Counter()

    for row in rows:
        row_count = int(row.get("row_count") or 0)
        # dict.fromkeys de-duplicates while preserving order, so a value such as
        # "S|M|S" credits "S" once instead of double-counting its rows.
        for token in dict.fromkeys(_split_size_scale(row.get("size_scale"))):
            token_counter[token] += row_count
            if token in LETTER_SIZES:
                letter_counter[token] += row_count
            # isdecimal() (unlike isdigit()) rejects characters such as "²"
            # that int()/float() cannot parse.
            elif token.replace(".", "", 1).isdecimal():
                numeric_counter[token] += row_count
            else:
                other_counter[token] += row_count

    return {
        "all_tokens": token_counter,
        "letter_tokens": letter_counter,
        "numeric_tokens": numeric_counter,
        "other_tokens": other_counter,
    }


def test_starrocks_size_scale_list():
    """Smoke-test ``fetch_size_scale_list`` against a live StarRocks instance.

    Skipped under pytest (or a no-op when called directly) if StarRocks
    configuration is missing.
    """
    if _skip_or_warn_if_missing_env():
        return

    async def _fetch() -> list[dict[str, Any]]:
        try:
            return await fetch_size_scale_list(limit=10)
        finally:
            # Release pooled connections before asyncio.run() closes this
            # event loop; the pool appears to be class-level (cleared via the
            # class), so later tests would presumably inherit it otherwise.
            await StarRocksConnection.clear_pool()

    rows = asyncio.run(_fetch())
    assert isinstance(rows, list)
    assert len(rows) <= 10


# Smallest-to-largest display order for letter sizes. Any size in LETTER_SIZES
# that is not listed here is printed after these, alphabetically.
_LETTER_SIZE_ORDER = ("XXXS", "XXS", "XS", "S", "M", "L", "XL", "XXL", "XXXL")


def _print_top_tokens(title: str, counter: Counter[str]) -> None:
    """Print ``title`` after a blank line, then the 50 most common entries."""
    print(f"\n{title}")
    for token, count in counter.most_common(50):
        print(f"- {token}: {count}")


def _print_report(
    summary: dict[str, int],
    rows: list[dict[str, Any]],
    token_summary: dict[str, Any],
    print_limit: int | None,
) -> None:
    """Print the size_scale report from already-fetched data.

    ``print_limit`` (if truthy) caps only the final raw ``size_scale`` listing.
    """
    print("\n" + "=" * 80)
    print("STARROCKS SIZE SCALE LIST")
    print("=" * 80)
    print(f"Table: {TABLE_NAME}")
    print(f"Total rows: {summary['total_rows']}")
    print(f"Distinct size_scale: {summary['distinct_size_scale']}")
    print(f"Null/empty size_scale: {summary['null_size_scale']}")
    if print_limit:
        print(f"Print limit: {print_limit}")

    _print_top_tokens("Normalized token counts (top 50):", token_summary["all_tokens"])

    print("\nLetter sizes:")
    rank = {size: index for index, size in enumerate(_LETTER_SIZE_ORDER)}
    # Sort by garment size rather than alphabetically (which gives L, M, S, XL, XS...).
    for token in sorted(LETTER_SIZES, key=lambda size: (rank.get(size, len(rank)), size)):
        count = token_summary["letter_tokens"].get(token, 0)
        if count:
            print(f"- {token}: {count}")

    _print_top_tokens("Numeric sizes (top 50):", token_summary["numeric_tokens"])
    _print_top_tokens("Other tokens (top 50):", token_summary["other_tokens"])

    print("\nsize_scale\trow_count")
    rows_to_print = rows[:print_limit] if print_limit else rows
    for row in rows_to_print:
        size_scale = row.get("size_scale")
        if size_scale in (None, ""):
            size_scale = "<NULL/EMPTY>"
        print(f"{size_scale}\t{row.get('row_count')}")


async def _run():
    """Script entry point: fetch size_scale statistics and print a report.

    Prints only a ``[SKIP]`` notice when StarRocks configuration is missing.
    ``SIZE_SCALE_LIMIT`` (if set) caps only the raw listing at the end; the
    summaries are computed over every group.
    """
    if _skip_or_warn_if_missing_env():
        return

    try:
        print_limit = _get_limit_from_env()
        summary = await fetch_size_scale_summary()
        rows = await fetch_size_scale_list()
        token_summary = _build_token_summary(rows)
        _print_report(summary, rows, token_summary, print_limit)
    finally:
        # Release pooled connections even if a query or the report fails.
        await StarRocksConnection.clear_pool()


if __name__ == "__main__":
    # Direct invocation: build the report coroutine and drive it to completion.
    report = _run()
    asyncio.run(report)
