"""
Fashion Q&A Agent Controller
Langfuse will auto-trace via LangChain integration (no code changes needed).
"""

import asyncio
import json
import logging
import re
import time
import uuid

from fastapi import BackgroundTasks
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langfuse import propagate_attributes

from common.cache import redis_cache
from common.conversation_manager import get_conversation_manager
from common.langfuse_client import get_callback_handler
from config import DEFAULT_MODEL, REDIS_CACHE_TURN_ON

from .graph import build_graph
from .helper import handle_post_chat_async
from .models import AgentState, get_config
from .streaming_callback import ProductIDStreamingCallback

logger = logging.getLogger(__name__)


async def save_user_insight_to_redis(identity_key: str, insight: str):
    """Store ``insight`` in Redis under the per-identity insight key.

    Best effort: if no Redis client is available nothing is written, and any
    exception raised while writing is logged instead of propagated.

    Args:
        identity_key: Identity whose insight is being stored; used to build
            the ``identity_key_insight:<identity_key>`` Redis key.
        insight: The insight text to store as the key's value.
    """
    try:
        redis_client = redis_cache.get_client()
        if not redis_client:
            # No client available: skip silently, as there is nowhere to write.
            return
        await redis_client.set(f"identity_key_insight:{identity_key}", insight)
        logger.info(f"💾 Updated User Insight for {identity_key}: {insight}")
    except Exception as e:
        logger.error(f"❌ Failed to save user insight: {e}")


async def extract_and_save_user_insight(json_content: str, identity_key: str):
    """Extract the ``user_insight`` object from partial/full JSON text and save it to Redis.

    The text is scanned for a ``"user_insight": {`` key and the object that
    follows is parsed with ``json.JSONDecoder.raw_decode``. Unlike an
    end-anchored regex, this reads exactly one balanced JSON object, so it
    works when the insight contains nested objects, when the outer closing
    brace is missing (truncated stream), and when other text follows the
    object (further keys, a closing code fence, ...).

    Args:
        json_content: Raw text that may contain a ``"user_insight"`` object.
        identity_key: Identity the insight belongs to; forwarded to
            ``save_user_insight_to_redis``.

    Returns:
        dict | None: The parsed ``user_insight`` dict when extraction succeeds,
        ``None`` when no parsable object is found or any error occurs (errors
        are logged, never raised).
    """
    start_time = time.time()
    logger.info(f"🔄 [Background] Starting user_insight extraction for {identity_key}")
    try:
        insight_dict = None
        decoder = json.JSONDecoder()

        # The lookahead guarantees each candidate starts exactly at "{", which
        # raw_decode requires (it does not skip leading whitespace).
        for key_match in re.finditer(r'"user_insight"\s*:\s*(?=\{)', json_content or ""):
            try:
                candidate, _ = decoder.raw_decode(json_content[key_match.end():])
            except json.JSONDecodeError:
                # Truncated or malformed object at this occurrence; try the next one.
                continue
            if isinstance(candidate, dict):
                insight_dict = candidate
                break

        if insight_dict is not None:
            # Re-serialize so the stored value is normalized, human-readable JSON.
            insight_str = json.dumps(insight_dict, ensure_ascii=False, indent=2)

            # Save to Redis
            await save_user_insight_to_redis(identity_key, insight_str)
            elapsed = time.time() - start_time
            logger.warning(f"✅ [user_insight] Extracted + saved in {elapsed:.2f}s | Key: {identity_key}")
            return insight_dict
        logger.warning(f"⚠️ [Background] No user_insight found in JSON for {identity_key}")
        return None

    except Exception as e:
        logger.error(f"❌ [Background] Failed to extract user_insight for {identity_key}: {e}")
        return None


async def _load_user_insight(identity_key: str | None):
    """Fetch the stored user insight for ``identity_key`` from Redis (best effort).

    Returns whatever the Redis client returns for the insight key, or ``None``
    when the key is empty/missing, no client is available, or the lookup fails
    (failures are logged, never raised).
    """
    user_insight = None
    if not identity_key:
        return user_insight
    try:
        client = redis_cache.get_client()
        if client:
            insight_key = f"identity_key_insight:{identity_key}"
            user_insight = await client.get(insight_key)
            if user_insight:
                logger.info(f"🧠 Loaded User Insight for {identity_key}: {user_insight[:50]}...")
    except Exception as e:
        logger.error(f"❌ Error fetching user insight: {e}")
    return user_insight


def _order_products_by_requested_skus(requested_skus: list, products: list[dict]) -> list[dict]:
    """Return ``products`` de-duplicated and ordered to follow ``requested_skus``.

    A lookup table is built so a requested id can match:
      * a flat product by its full ``sku`` or by the part of the sku before the
        first ``-`` (first product seen wins for that base code), or
      * a grouped product by its ``product_id`` or any entry of ``all_skus``.

    Requested ids without a match are skipped (logged at debug level). A
    product is emitted once, keyed by ``product_id`` when present, else ``sku``.
    """
    product_lookup: dict = {}
    for product in products:
        if "sku" in product:
            # Case 1: flat product
            sku = product["sku"]
            product_lookup[sku] = product
            # Base code, e.g. 6DS25S010 from 6DS25S010-Blue-S; never overrides an existing entry.
            product_lookup.setdefault(sku.split("-")[0], product)
        elif "product_id" in product:
            # Case 2: grouped product; every variant SKU maps to the parent.
            product_lookup[product["product_id"]] = product
            if "all_skus" in product:
                for variant_sku in product["all_skus"]:
                    product_lookup[variant_sku] = product

    ordered_products: list[dict] = []
    seen_ids: set = set()
    for sku in requested_skus:
        product = product_lookup.get(sku)
        if not product:
            logger.debug(f"⚠️ SKU {sku} not found in fetched products")
            continue
        uid = product.get("product_id") or product.get("sku")
        if uid and uid not in seen_ids:
            ordered_products.append(product)
            seen_ids.add(uid)
    return ordered_products


async def chat_controller(
    query: str,
    user_id: str,
    background_tasks: BackgroundTasks,
    model_name: str = DEFAULT_MODEL,
    images: list[str] | None = None,
    identity_key: str | None = None,
    return_user_insight: bool = False,
) -> dict:
    """
    Controller main logic for non-streaming chat requests.

    Flow:
    1. Check cache (if enabled) → HIT: return cached response
    2. MISS: Call LLM → Save to cache → Return response
    3. user_insight extraction:
       - return_user_insight=False (prod): background task (non-blocking)
       - return_user_insight=True (dev): extract immediately and return it

    Args:
        query: The user's message for this turn.
        user_id: Caller's user id; also the fallback history key.
        background_tasks: FastAPI background task queue used for history
            persistence and deferred user_insight extraction.
        model_name: Model name written onto the graph config.
        images: Optional images passed to the graph as ``transient_images``.
        identity_key: Key for saving/loading history (identity.history_key)
                      Guest: device_id, User: user_id
        return_user_insight: If True, wait for the full stream, then extract and
            return user_insight immediately (dev mode)

    Returns:
        dict with ``ai_response``, ``product_ids`` (full product objects),
        timing fields ``response_ready_s`` / ``response_ready_stream_s``,
        ``cached`` and, in dev mode when extraction succeeds, ``user_insight``.
        A cache hit returns the cached payload plus ``cached=True``.
    """
    effective_identity_key = identity_key or user_id
    request_start_time = time.time()

    logger.info(
        "chat_controller start: model=%s, user_id=%s, identity_key=%s", model_name, user_id, effective_identity_key
    )

    # ====================== CACHE LAYER ======================
    if REDIS_CACHE_TURN_ON:
        cached_response = await redis_cache.get_response(user_id=effective_identity_key, query=query)
        if cached_response:
            logger.info(f"⚡ CACHE HIT for identity_key={effective_identity_key}")
            memory = await get_conversation_manager()
            # Even on a cache hit the turn is recorded in history.
            background_tasks.add_task(
                handle_post_chat_async,
                memory=memory,
                identity_key=effective_identity_key,
                human_query=query,
                ai_response=cached_response,
            )
            return {**cached_response, "cached": True}

    # ====================== NORMAL LLM FLOW ======================
    logger.info("chat_controller: proceed with live LLM call")

    config = get_config()
    config.model_name = model_name

    graph = build_graph(config)

    # Init ConversationManager (Singleton)
    memory = await get_conversation_manager()

    # Load History (only text, no product_ids for AI context)
    history_dicts = await memory.get_chat_history(effective_identity_key, limit=15, include_product_ids=False)
    messages = [
        HumanMessage(content=m["message"]) if m["is_human"] else AIMessage(content=m["message"]) for m in history_dicts
    ][::-1]  # Reverse to chronological order (Oldest -> Newest)

    # Prepare State
    # Fetch User Insight from Redis (best effort; None when unavailable)
    user_insight = await _load_user_insight(effective_identity_key)

    initial_state: AgentState = {
        "user_query": HumanMessage(content=query),
        "messages": [],  # Start empty, acting as scratchpad. History & Query are separate.
        "history": messages,
        "user_id": user_id,
        "user_insight": user_insight,
        "images_embedding": [],
        "ai_response": None,
    }

    run_id = str(uuid.uuid4())
    langfuse_handler = get_callback_handler()

    # Streaming callback: captures tokens as they are produced.
    streaming_callback = ProductIDStreamingCallback()

    exec_config = RunnableConfig(
        configurable={
            "user_id": user_id,
            "transient_images": images or [],
            "run_id": run_id,
        },
        run_id=run_id,
        metadata={"run_id": run_id, "tags": "chatbot,production"},
        callbacks=[langfuse_handler, streaming_callback] if langfuse_handler else [streaming_callback],
    )

    # Execute the graph with streaming: respond as soon as ai_response + product_ids exist.
    session_id = f"{user_id}-{run_id[:8]}"

    ai_text_response = ""
    final_product_ids = []
    # NOTE(review): written by consume_events() but never read afterwards; the
    # fallback below reads streaming_callback.accumulated_content instead.
    accumulated_content = ""

    logger.info("🌊 Starting LLM streaming...")

    event_count = 0
    start_time = time.time()

    async def consume_events():
        """Drain the graph's event stream, logging each event as it arrives."""
        nonlocal event_count, accumulated_content
        async for event in graph.astream(initial_state, config=exec_config):
            event_count += 1
            elapsed = time.time() - start_time
            logger.info(f"📦 Event #{event_count} at t={elapsed:.2f}s | Keys: {list(event.keys())}")

            # Capture events carrying ai_response (the LLM's text response).
            if "ai_response" in event:
                ai_message = event["ai_response"]
                if ai_message and not getattr(ai_message, "tool_calls", None):
                    ai_raw_content = ai_message.content if ai_message else ""
                    accumulated_content = ai_raw_content

    # Set when the prod path sees the stream task end with an exception.
    stream_failed = False
    wait_task = None

    # Event-based waiting (no polling).
    with propagate_attributes(user_id=user_id, session_id=session_id):
        stream_task = asyncio.create_task(consume_events())

        try:
            if return_user_insight:
                # Dev mode: wait for the stream to finish so the full user_insight is available.
                # A stream exception propagates to the caller from this await.
                logger.info("🔍 [DEV MODE] Waiting for full stream to get user_insight...")
                await stream_task
                ai_text_response = streaming_callback.ai_response_text
                final_product_ids = streaming_callback.product_skus
                logger.info("✅ Stream completed in dev mode")
            else:
                # Prod mode: stop waiting as soon as product_ids are found, but let the
                # stream keep running so user_insight can be extracted afterwards.
                wait_task = asyncio.create_task(streaming_callback.product_found_event.wait())
                await asyncio.wait(
                    [stream_task, wait_task],
                    return_when=asyncio.FIRST_COMPLETED,
                )

                if streaming_callback.product_ids_found:
                    elapsed = time.time() - start_time
                    logger.warning(
                        f"🎯 CALLBACK EVENT fired at t={elapsed:.2f}s | Returning response early (stream continues)"
                    )
                    ai_text_response = streaming_callback.ai_response_text
                    final_product_ids = streaming_callback.product_skus
                else:
                    # The stream ended without the callback firing. Check whether it
                    # ended with an exception: asyncio.wait() never raises a task's
                    # exception itself, so without this check a failed LLM call would
                    # look like a normal (empty) completion.
                    stream_error = None
                    if stream_task.done() and not stream_task.cancelled():
                        stream_error = stream_task.exception()

                    if stream_error is not None:
                        stream_failed = True
                        logger.error(
                            "❌ LLM stream failed before a response was produced: %s",
                            stream_error,
                            exc_info=stream_error,
                        )
                    else:
                        logger.info("✅ Stream completed without callback trigger")

                    # Best-effort fallback: parse whatever the callback accumulated.
                    full_content = streaming_callback.accumulated_content
                    if full_content:
                        from .helper import parse_ai_response_fast

                        parsed_text, parsed_skus, _ = parse_ai_response_fast(full_content)
                        if parsed_text:
                            ai_text_response = parsed_text
                        if parsed_skus:
                            final_product_ids = parsed_skus
                        logger.info(
                            "🧩 Fallback parse after stream: ai_response=%s chars, skus=%s",
                            len(ai_text_response),
                            len(final_product_ids),
                        )

        except asyncio.CancelledError:
            # NOTE(review): nothing in this function cancels stream_task, so this most
            # likely fires when the request itself is cancelled; the cancellation is
            # swallowed here (original behavior, kept) — confirm this is intended.
            logger.warning("⚡ Stream task cancelled by callback event")
        finally:
            # Only the helper wait task is cancelled; stream_task keeps running so the
            # background user_insight extraction can await it. cancel() on a finished
            # task is a no-op.
            if wait_task is not None:
                wait_task.cancel()

    # LOG TIME 1: ai_response + product_ids are available (callback break)
    elapsed_response = time.time() - start_time
    response_ready_s = time.time() - request_start_time
    logger.warning(f"⏱️ [1] RESPONSE READY: {elapsed_response:.2f}s | ai_response + product_ids")

    # MAP SKUs → Full Product Details
    from .helper import fetch_products_by_skus

    enriched_products = []
    if final_product_ids:
        logger.info(f"🔍 Mapping {len(final_product_ids)} SKUs to full product details...")
        fetched_products = await fetch_products_by_skus(final_product_ids)
        # Keep only products the AI referenced, in the AI's order, without duplicates.
        enriched_products = _order_products_by_requested_skus(final_product_ids, fetched_products)
        logger.info(f"✅ Mapped {len(enriched_products)} products (matched AI IDs)")

    # user_insight handling
    user_insight_dict = None

    if return_user_insight:
        # Dev mode: the stream has already finished, so extract right away.
        callback_accumulated_content = streaming_callback.accumulated_content
        if callback_accumulated_content and effective_identity_key:
            logger.info("🔍 [DEV] Extracting user_insight synchronously...")
            user_insight_dict = await extract_and_save_user_insight(
                callback_accumulated_content, effective_identity_key
            )

            # LOG TIME 2: after user_insight extraction
            elapsed_total = time.time() - start_time
            logger.warning(f"⏱️ [2] TOTAL TIME (with user_insight): {elapsed_total:.2f}s | Returning all data!")
    elif background_tasks:
        # Prod / dev fast mode: let the stream finish, then extract in the background.
        async def finalize_user_insight_after_stream():
            try:
                # Re-raises here if the stream failed; logged below.
                await stream_task
                full_content = streaming_callback.accumulated_content
                if full_content and effective_identity_key:
                    await extract_and_save_user_insight(full_content, effective_identity_key)
            except Exception as exc:
                logger.error(f"❌ [Background] user_insight finalize failed: {exc}")

        logger.info("💾 [PROD] Scheduling background task for user_insight extraction (post-stream)")
        background_tasks.add_task(finalize_user_insight_after_stream)

    response_payload = {
        "ai_response": ai_text_response,
        "product_ids": enriched_products,  # Full product objects, not just SKUs
        "response_ready_s": round(response_ready_s, 2),
        "response_ready_stream_s": round(elapsed_response, 2),
    }

    if user_insight_dict is not None:
        response_payload["user_insight"] = user_insight_dict

    # ====================== SAVE TO CACHE ======================
    if REDIS_CACHE_TURN_ON:
        if stream_failed or not (ai_text_response or enriched_products):
            # Caching a failed/empty reply would replay it to the user for the
            # whole TTL instead of retrying the LLM on the next request.
            logger.warning(
                "⏭️ Not caching response for identity_key=%s (stream failed or response empty)",
                effective_identity_key,
            )
        else:
            await redis_cache.set_response(
                user_id=effective_identity_key, query=query, response_data=response_payload, ttl=300
            )
            logger.debug(f"💾 Cached response for identity_key={effective_identity_key}")

    # Save to History (Background).
    # NOTE(review): the full payload (including user_insight in dev mode) is handed to
    # handle_post_chat_async; presumably it drops the insight from the history text — verify.
    background_tasks.add_task(
        handle_post_chat_async,
        memory=memory,
        identity_key=effective_identity_key,
        human_query=query,
        ai_response=response_payload,
    )

    logger.info("chat_controller finished")
    return {**response_payload, "cached": False}
