"""
CANIFA Data Retrieval Tool - Tối giản cho Agentic Workflow.
Hỗ trợ Hybrid Search: Semantic (Vector) + Metadata Filter.
"""

import asyncio
import json
import logging
import re
import time
import unicodedata

from langchain_core.tools import tool
from pydantic import BaseModel, Field

from agent.tools.data_retrieval_filter import (
    COLOR_MAP,
    filter_by_age,
    filter_by_gender,
    filter_by_product_name,
    filter_with_priority,
    format_product_results,
)
from agent.tools.product_search_helpers import build_starrocks_query
from agent.tools.stock_helpers import fetch_stock_for_skus
from common.starrocks_connection import get_db_connection

# Setup Logger
logger = logging.getLogger(__name__)

from agent.prompt_utils import read_tool_prompt

PRODUCT_NAME_KEYWORDS = [
    ("áo sơ mi", "Áo sơ mi"),
    ("ao so mi", "Áo sơ mi"),
    ("sơ mi", "Áo sơ mi"),
    ("so mi", "Áo sơ mi"),
    ("chân váy", "Chân váy"),
    ("chan vay", "Chân váy"),
    ("váy liền thân", "Váy liền thân"),
    ("vay lien than", "Váy liền thân"),
    ("váy liền", "Váy liền thân"),
    ("vay lien", "Váy liền thân"),
    ("váy đầm", "Váy liền thân"),
    ("vay dam", "Váy liền thân"),
    ("đầm", "Váy liền thân"),
    ("dam", "Váy liền thân"),
    ("váy", "Váy"),
    ("vay", "Váy"),
    ("áo khoác", "Áo khoác"),
    ("ao khoac", "Áo khoác"),
    ("áo len", "Áo len"),
    ("ao len", "Áo len"),
    ("áo thun", "Áo thun"),
    ("ao thun", "Áo thun"),
    ("áo polo", "Áo polo"),
    ("ao polo", "Áo polo"),
    ("hoodie", "Áo hoodie"),
    ("áo hoodie", "Áo hoodie"),
    ("ao hoodie", "Áo hoodie"),
    ("quần jeans", "Quần jeans"),
    ("quan jeans", "Quần jeans"),
    ("quần short", "Quần short"),
    ("quan short", "Quần short"),
    ("quần dài", "Quần dài"),
    ("quan dai", "Quần dài"),
    ("quần", "Quần"),
    ("quan", "Quần"),
    ("áo", "Áo"),
    ("ao", "Áo"),
    ("phụ kiện", "Phụ kiện"),
    ("phu kien", "Phụ kiện"),
    ("túi", "Túi"),
    ("tui", "Túi"),
    ("mũ", "Mũ"),
    ("mu", "Mũ"),
    ("khăn", "Khăn"),
    ("khan", "Khăn"),
    ("tất", "Tất"),
    ("tat", "Tất"),
]


def infer_product_name_from_description(description: str | None) -> str | None:
    if not description:
        return None
    desc_lower = description.lower()
    for keyword, product_name in PRODUCT_NAME_KEYWORDS:
        if keyword in desc_lower:
            return product_name
    return None


class SearchItem(BaseModel):
    # One search request: a free-text description for semantic search plus
    # optional metadata filters.
    # NOTE(review): documented with comments instead of a class docstring on
    # purpose - pydantic presumably surfaces a model docstring in the generated
    # schema, which would change what the LLM sees; confirm before converting.
    model_config = {"extra": "forbid"}  # Strict mode: unknown fields are rejected.

    description: str = Field(
        description="Mô tả sản phẩm cần tìm (semantic search trong description_text). VD: 'váy tiểu thư', 'áo thun basic', 'đầm dự tiệc sang chảnh'"
    )
    # STRICT MODE REQUIREMENT: every field must appear in the schema's
    # 'required' list. So optional fields are declared WITHOUT a default
    # (or with '...'/Ellipsis) while keeping the `| None` type: callers must
    # pass the key explicitly, using null when it does not apply.
    product_name: str | None = Field(
        description="CHỈ tên loại sản phẩm cơ bản: Áo, Váy, Quần, Chân váy, Áo khoác... KHÔNG bao gồm style/mô tả."
    )
    magento_ref_code: str | None = Field(description="Mã sản phẩm chính xác (SKU)")
    price_min: int | None = Field(description="Giá thấp nhất (VND)")
    price_max: int | None = Field(description="Giá cao nhất (VND)")

    # Metadata filters. In this file only product_name, gender_by_product,
    # age_by_product and master_color are applied as post-filters; whether the
    # remaining fields are used by build_starrocks_query is not visible here.
    gender_by_product: str | None = Field(description="Giới tính (Nam/Nữ/Bé trai/Bé gái)")
    age_by_product: str | None = Field(description="Độ tuổi (Người lớn/Trẻ em)")
    master_color: str | None = Field(description="Màu sắc chính")
    form_sleeve: str | None = Field(description="Dáng tay áo")
    style: str | None = Field(
        description="Phong cách (CHỈ dùng: minimalist, classic, basic, sporty, elegant, casual, feminine). KHÔNG dùng cho từ mô tả như 'tiểu thư', 'sang chảnh'."
    )
    fitting: str | None = Field(description="Dáng đồ (Slim/Regular/Loose)")
    form_neckline: str | None = Field(description="Dáng cổ áo")
    material_group: str | None = Field(description="Chất liệu")
    season: str | None = Field(description="Mùa")

    # Extra field for SQL-side matching, if needed.
    product_line_vn: str | None = Field(description="Dòng sản phẩm (VN) cho lọc SQL")


class MultiSearchParams(BaseModel):
    # Argument schema passed to @tool for data_retrieval_tool: a list of
    # SearchItem queries submitted in a single tool call.
    model_config = {"extra": "forbid"}  # Strict mode: unknown fields are rejected.
    searches: list[SearchItem] = Field(description="Danh sách các truy vấn tìm kiếm")


async def _execute_single_search(
    db, item: SearchItem, query_vector: list[float] | None = None
) -> tuple[list[dict], dict]:
    """
    Execute a single search query and post-filter its rows (async).

    Steps: infer ``product_name`` from the description when it is missing,
    build the SQL with ``build_starrocks_query``, run it through
    ``db.execute_query_async``, apply the hard filters (product name, gender,
    age) and finally the soft colour filter.

    Args:
        db: Connection object exposing ``execute_query_async(sql, params=...)``.
        item: The search request.
        query_vector: Optional precomputed vector, forwarded unchanged to
            ``build_starrocks_query``.

    Returns:
        Tuple of ``(products, filter_info)``. ``products`` is the output of
        ``format_product_results``. ``filter_info`` always carries
        ``fallback_used``; it carries ``filters_applied`` on the normal path,
        ``recommendation_message`` when a filter supplied a message, and
        ``error`` when an exception was caught. Exceptions raised while
        searching are logged and reported as ``([], {...})`` instead of
        propagating.
    """
    try:
        # Truncated description, used only for logging.
        short_query = (
            (item.description[:60] + "...") if item.description and len(item.description) > 60 else item.description
        )
        logger.debug(
            "_execute_single_search started, query=%r, code=%r",
            short_query,
            item.magento_ref_code,
        )

        # Infer product_name if missing (avoids results from the wrong category).
        if not item.product_name:
            inferred_name = infer_product_name_from_description(item.description)
            if inferred_name:
                # pydantic v2 exposes model_copy(); v1 only has copy(). Resolve
                # the method up front so an AttributeError raised *inside* the
                # copy is not mistaken for "method missing".
                copy_item = getattr(item, "model_copy", None) or item.copy
                item = copy_item(update={"product_name": inferred_name})
                logger.warning("🧭 Inferred product_name='%s' from description", inferred_name)

        # Timer: build query (reuses the supplied vector, if any).
        # perf_counter() is monotonic, unlike the wall clock.
        query_build_start = time.perf_counter()
        sql, params = await build_starrocks_query(item, query_vector=query_vector)
        query_build_time = (time.perf_counter() - query_build_start) * 1000  # Convert to ms

        # Guard BEFORE touching len(sql), and keep the tuple contract: the
        # caller unpacks (products, filter_info) for every result.
        if not sql:
            logger.debug("No SQL built, build_time_ms=%.2f", query_build_time)
            return [], {"fallback_used": False, "filters_applied": {}}

        logger.debug("SQL built, length=%s, build_time_ms=%.2f", len(sql), query_build_time)

        # Timer: execute DB query
        db_start = time.perf_counter()
        products = await db.execute_query_async(sql, params=params)
        db_time = (time.perf_counter() - db_start) * 1000  # Convert to ms
        logger.info(
            "_execute_single_search done, products=%s, build_ms=%.2f, db_ms=%.2f, total_ms=%.2f",
            len(products),
            query_build_time,
            db_time,
            query_build_time + db_time,
        )

        # Debug: log the first row to see which fields came back.
        if products:
            first_p = products[0]
            logger.info("🔍 [DEBUG] First product keys: %s", list(first_p.keys()))
            logger.info(
                "🔍 [DEBUG] First product price: %s, sale_price: %s",
                first_p.get("original_price"),
                first_p.get("sale_price"),
            )

        # ====== POST-FILTERS: narrow the rows by the requested criteria ======
        original_count = len(products)
        all_filter_info = {}  # Per-filter info (fallback or exact match), keyed by filter name

        # NOTE(review): routine progress below is logged at WARNING level;
        # kept as-is in case log configuration depends on it.
        logger.warning(
            "🔍 [POST-FILTER] Starting with %d products from DB. SearchItem params: product_name=%r, gender=%r, age=%r, color=%r",
            original_count,
            item.product_name,
            item.gender_by_product,
            item.age_by_product,
            item.master_color,
        )

        # ====== LAYER 1: HARD FILTERS (no fallback) ======
        # Product type must match.
        if item.product_name and products:
            before = len(products)
            products = filter_by_product_name(products, item.product_name)
            logger.warning(
                "📦 Product name filter (HARD): %s → %d→%d products", item.product_name, before, len(products)
            )

        # Gender
        if item.gender_by_product and products:
            before = len(products)
            products = filter_by_gender(products, item.gender_by_product)
            logger.warning(
                "👤 Gender filter (HARD): %s → %d→%d products", item.gender_by_product, before, len(products)
            )

        # Age
        if item.age_by_product and products:
            before = len(products)
            products = filter_by_age(products, item.age_by_product)
            logger.warning("🎂 Age filter (HARD): %s → %d→%d products", item.age_by_product, before, len(products))

        # ====== LAYER 2: SOFT FILTERS (priority-based fallback) ======
        # Only applied while the hard filters left something to work with.

        # 1. COLOR (with automatic fallback)
        if item.master_color and products:
            before_count = len(products)
            products, info = filter_with_priority(products, item.master_color, "master_color", COLOR_MAP, "Màu")
            # Keep the info so the agent can tell the customer what happened.
            if info.get("fallback_used") or info.get("matched_value"):
                all_filter_info["color"] = info
                if info.get("fallback_used"):
                    logger.warning(
                        "🎨 COLOR FALLBACK: Requested '%s' → Found '%s' (%d products)",
                        info.get("requested_value"),
                        info.get("matched_value"),
                        len(products),
                    )
                else:
                    logger.info(
                        "🎨 Color filter: Matched '%s' exactly (%d products)", info.get("matched_value"), len(products)
                    )
            else:
                logger.warning("🎨 Color filter: NO MATCH - %d → %d products", before_count, len(products))

        # === DISABLED FILTERS (only name, gender, age and color are active) ===
        # 2. SLEEVE - DISABLED
        # 3. STYLE - DISABLED
        # 4. FITTING - DISABLED
        # 5. NECKLINE - DISABLED
        # 6. MATERIAL - DISABLED
        # 7. SEASON - DISABLED

        # Combine filter info into an explicit structure for the agent.
        filter_info = {
            "fallback_used": any(info.get("fallback_used") for info in all_filter_info.values()),
            "filters_applied": all_filter_info,  # Per-filter detail (fallback or not)
        }

        # Recommendation message the agent relays to the customer when a
        # fallback (or any filter message) is present.
        fallback_messages = [info.get("message") for info in all_filter_info.values() if info.get("message")]
        if fallback_messages:
            filter_info["recommendation_message"] = " ".join(fallback_messages)

        # Summary log, only when post-filtering actually removed rows.
        if original_count != len(products):
            logger.info(
                "📊 Post-filter summary: %d → %d products. Fallback used: %s",
                original_count,
                len(products),
                filter_info.get("fallback_used"),
            )

        return format_product_results(products), filter_info
    except Exception as e:
        # Best-effort: one failing search must not break the whole batch.
        logger.exception("Single search error for item %r: %s", item, e)
        return [], {"fallback_used": False, "error": str(e)}


@tool(args_schema=MultiSearchParams)
async def data_retrieval_tool(searches: list[SearchItem]) -> str:
    """
    Công cụ tìm kiếm sản phẩm CANIFA.
    Hỗ trợ tìm kiếm Semantic và lọc theo Metadata.
    """
    # NOTE: the docstring above is passed to @tool and is therefore runtime
    # text, so it is left untouched.
    #
    # Flow: run every SearchItem concurrently, concatenate the products,
    # enrich them with stock info (best-effort) and return a JSON string with
    # keys "status", "results", "filter_info" and "stock_enriched". When no DB
    # connection is available, a JSON error payload is returned instead.
    logger.info("🔧 data_retrieval_tool called with %d items", len(searches))

    # Get DB connection
    db = get_db_connection()
    if not db:
        return json.dumps({"status": "error", "message": "Database connection failed"})

    # _execute_single_search reports failures in its return value, so gather()
    # does not need return_exceptions here.
    results_list = await asyncio.gather(*(_execute_single_search(db, item) for item in searches))

    combined_results: list[dict] = []
    all_filter_infos: list[dict] = []
    for products, filter_info in results_list:
        combined_results.extend(products)
        if filter_info:
            all_filter_infos.append(filter_info)

    # ============================================================
    # STOCK ENRICHMENT: fetch stock info for all products
    # ============================================================
    # dict.fromkeys() de-duplicates while preserving first-seen order; empty
    # or None SKUs are dropped so they are never sent to the stock lookup.
    skus_to_check = list(
        dict.fromkeys(sku for product in combined_results for sku in _product_skus(product) if sku)
    )

    stock_map: dict = {}
    if skus_to_check:
        logger.info("🔍 Checking stock for %d SKUs: %s...", len(skus_to_check), skus_to_check[:5])
        try:
            # "or {}" keeps the merge loop below safe if the helper returns None.
            stock_map = (await fetch_stock_for_skus(skus_to_check)) or {}
            logger.info("📦 [STOCK] Enriched %d products with stock info", len(stock_map))
        except Exception as e:
            # Stock enrichment is best-effort: results are still returned.
            logger.exception("❌ Error fetching stock in data_retrieval: %s", e)

    # Merge stock info into each product
    for product in combined_results:
        _attach_stock_info(product, stock_map)

    final_info = _merge_filter_infos(all_filter_infos)

    # The agent checks filter_info.fallback_used to know whether a fallback
    # happened and, if so, relays filter_info.recommendation_message.
    output = {
        "status": "success",
        "results": combined_results,
        "filter_info": final_info,
        "stock_enriched": len(stock_map) > 0,
    }

    logger.info(
        "🎁 Final result: %d products. Fallback used: %s. Stock enriched: %s",
        len(combined_results),
        final_info.get("fallback_used", False),
        len(stock_map) > 0,
    )

    return json.dumps(output, ensure_ascii=False, default=str)


def _product_skus(product: dict) -> list:
    """Return the SKU(s) to stock-check for one flat, grouped or raw product dict."""
    # Flat result (has 'sku')
    if "sku" in product:
        return [product["sku"]]
    # Grouped result (has 'all_skus')
    if "all_skus" in product:
        return list(product["all_skus"] or [])
    # Fallback for raw results (legacy or defensive)
    sku = product.get("product_color_code") or product.get("magento_ref_code")
    return [sku] if sku else []


def _attach_stock_info(product: dict, stock_map: dict) -> None:
    """Set ``product['stock_info']`` in place when ``stock_map`` knows its SKU(s)."""
    if "all_skus" in product and "sku" not in product:
        # Grouped result: store a {sku: stock_info} map for the known variants.
        group_stock = {sku: stock_map[sku] for sku in _product_skus(product) if sku in stock_map}
        if group_stock:
            product["stock_info"] = group_stock
        return

    # Flat or raw result: exactly one candidate SKU (or none).
    for sku in _product_skus(product):
        if sku in stock_map:
            product["stock_info"] = stock_map[sku]


def _merge_filter_infos(filter_infos: list[dict]) -> dict:
    """Fold per-search filter infos into the single dict returned to the agent.

    The first search's info is the base, so a single search yields exactly
    that dict's content. With several searches, ``fallback_used`` becomes true
    if ANY search fell back, and ``recommendation_message`` joins the distinct
    messages of all searches, so a fallback in a later search is not lost.
    """
    if not filter_infos:
        return {}
    merged = dict(filter_infos[0])
    if len(filter_infos) > 1:
        merged["fallback_used"] = any(info.get("fallback_used") for info in filter_infos)
        messages = list(
            dict.fromkeys(
                info["recommendation_message"] for info in filter_infos if info.get("recommendation_message")
            )
        )
        if messages:
            merged["recommendation_message"] = " ".join(messages)
    return merged


# Load the dynamic docstring: when read_tool_prompt() returns a non-empty
# prompt for this tool, it replaces both the docstring and the description.
if dynamic_prompt := read_tool_prompt("data_retrieval_tool"):
    data_retrieval_tool.__doc__ = dynamic_prompt
    data_retrieval_tool.description = dynamic_prompt
