"""
Agent Helper Functions
Các hàm tiện ích cho chat controller.
"""

import json
import logging
import uuid
from decimal import Decimal

from langchain_core.messages import HumanMessage, ToolMessage
from langchain_core.runnables import RunnableConfig

from common.conversation_manager import ConversationManager
from common.langfuse_client import get_callback_handler
from common.langfuse_client import get_callback_handler
from common.starrocks_connection import get_db_connection
from agent.tools.data_retrieval_filter import format_product_results
from .models import AgentState

logger = logging.getLogger(__name__)


def decimal_default(obj):
    """
    Fallback hook for ``json.dumps(..., default=decimal_default)``.

    ``Decimal`` values (e.g. numeric columns coming back from the DB) are
    converted to ``float``; any other type is rejected.

    Raises:
        TypeError: if ``obj`` is not a ``Decimal``, mirroring json's own error.
    """
    if not isinstance(obj, Decimal):
        raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
    return float(obj)


def _collect_tool_products(tool_result: dict) -> list[dict]:
    """Return the raw product dicts from one successful tool payload, in any known format."""
    if "results" in tool_result:
        results_data = tool_result["results"]
        if not isinstance(results_data, list) or not results_data:
            return []
        first_item = results_data[0]
        if isinstance(first_item, dict) and "products" in first_item:
            # Nested format: {"results": [{"products": [...]}, ...]}
            raw_products = []
            for item in results_data:
                nested = item.get("products") if isinstance(item, dict) else None
                if isinstance(nested, list):
                    raw_products.extend(nested)
        else:
            # Flat format: {"results": [product1, product2]} (current)
            raw_products = results_data
    elif isinstance(tool_result.get("products"), list):
        # Legacy format: {"products": [...]}
        raw_products = tool_result["products"]
    else:
        return []
    # Drop non-object entries so one malformed item cannot abort the whole message.
    return [p for p in raw_products if isinstance(p, dict)]


def extract_product_ids(messages: list) -> list[dict]:
    """
    Collect unique product records from tool messages (data_retrieval_tool results).

    Only ``ToolMessage`` instances whose content is a JSON object with
    ``"status": "success"`` are inspected. Products may be laid out as:

    - flat:   ``{"results": [product, ...]}``
    - nested: ``{"results": [{"products": [product, ...]}, ...]}``
    - legacy: ``{"products": [product, ...]}``

    Args:
        messages: Conversation messages; non-tool messages are ignored.

    Returns:
        One dict per distinct SKU (first occurrence wins, encounter order) with keys
        ``sku``, ``name``, ``price``, ``sale_price``, ``url``, ``thumbnail_image_url``.
        The SKU is taken from ``sku`` or, failing that, ``internal_ref_code``.
        Messages that cannot be parsed are skipped (logged at debug level).
    """
    products: list[dict] = []
    seen_skus = set()

    for msg in messages:
        if not isinstance(msg, ToolMessage):
            continue
        try:
            # Tool result is a JSON string
            tool_result = json.loads(msg.content)

            # Valid JSON need not be an object (list, bare string, number): skip those.
            if not isinstance(tool_result, dict) or tool_result.get("status") != "success":
                continue

            product_list = _collect_tool_products(tool_result)
            logger.debug(f"🛠️ [EXTRACT] Extracted {len(product_list)} products")

            for product in product_list:
                sku = product.get("sku") or product.get("internal_ref_code")
                if not sku or sku in seen_skus:
                    continue
                seen_skus.add(sku)
                # Fields are already parsed by the tool; only pick what the client needs.
                products.append(
                    {
                        "sku": sku,
                        "name": product.get("name", ""),
                        "price": product.get("price", 0),
                        "sale_price": product.get("sale_price"),
                        "url": product.get("url", ""),
                        "thumbnail_image_url": product.get("thumbnail_image_url", ""),
                    }
                )
        except (json.JSONDecodeError, KeyError, TypeError) as e:
            logger.debug(f"Could not parse tool message for products: {e}")
            continue

    return products


async def fetch_products_by_skus(skus: list[str]) -> list[dict]:
    """
    Look up product rows in the DB for SKUs that are missing from the current context.

    Each SKU is matched against both ``internal_ref_code`` and ``magento_ref_code``
    using bound query parameters, and the rows are shaped by ``format_product_results``.

    Args:
        skus: SKU strings to look up.

    Returns:
        Formatted product dicts. ``[]`` when ``skus`` is empty, when no DB connection
        is available, or when the query fails (the error is logged, not raised).
    """
    if not skus:
        return []

    conn = get_db_connection()
    if not conn:
        logger.error("❌ DB Connection failed in fetch_products_by_skus")
        return []

    # One "%s" per SKU; the same list is bound twice (internal + magento ref code).
    in_clause = ",".join("%s" for _ in skus)
    sql = f"""
    SELECT 
        internal_ref_code,
        description_text_full,
        sale_price,
        original_price,
        discount_amount,
        product_line_vn,
        product_line_en,
        1.0 as max_score
    FROM shared_source.magento_product_dimension_with_text_embedding
    WHERE internal_ref_code IN ({in_clause}) OR magento_ref_code IN ({in_clause})
    """
    bind_values = skus * 2

    try:
        rows = await conn.execute_query_async(sql, params=bind_values)
        logger.info(f"🔄 Fetched {len(rows)} fallback products from DB for SKUs: {skus}")
        return format_product_results(rows)
    except Exception as e:
        logger.error(f"❌ Error fetching fallback products: {e}")
        return []


def _skus_mentioned_in_text(text) -> list[str]:
    """Return SKUs tagged as ``[SKU]`` in *text*, de-duplicated in order of first mention."""
    import re

    if not isinstance(text, str):
        return []
    return list(dict.fromkeys(re.findall(r'\[([A-Z0-9]+)\]', text)))


def _select_target_skus(explicit_skus, mentioned_skus: list[str]) -> list[str]:
    """
    Choose which SKUs to return, preserving the LLM's ordering.

    An explicit ``product_ids`` list is preferred. When the text also tags SKUs, only
    explicit SKUs that are tagged in the text are kept, to reduce hallucinated IDs. If
    none of them overlap, the text mentions are used instead. Without a non-empty
    explicit list, the text mentions alone are used.
    """
    if explicit_skus and isinstance(explicit_skus, list):
        explicit = list(dict.fromkeys(str(s) for s in explicit_skus))
        if not mentioned_skus:
            return explicit
        mentioned_set = set(mentioned_skus)
        confirmed = [s for s in explicit if s in mentioned_set]
        return confirmed or list(mentioned_skus)
    return list(mentioned_skus)


async def parse_ai_response_async(ai_raw_content: str, all_products: list) -> tuple[str, list, str | None]:
    """
    Parse the LLM's JSON output and resolve the SKUs it mentions to product dicts.

    Async variant of parse_ai_response with a DB fallback: SKUs that are not present
    in ``all_products`` (the current context) are fetched via ``fetch_products_by_skus``.

    Flow:
    - LLM returns: {"ai_response": "...", "product_ids": ["SKU1"], "user_insight": ...}
    - Validate with ``ChatResponse``; if validation fails, read the keys manually.
    - Pick target SKUs (see ``_select_target_skus``), keeping the LLM's order.
    - Map SKUs -> products from context; query the DB for the rest.

    Args:
        ai_raw_content: Raw LLM output, expected to be a JSON object string.
        all_products: Product dicts already available this turn, keyed by ``"sku"``.

    Returns:
        ``(ai_text_response, final_products, user_insight)``. If the content is not a
        JSON object, the raw content is returned as the text with no products and no
        insight. ``user_insight`` is a JSON/plain string or ``None``.
    """
    from .structured_models import ChatResponse

    ai_text_response = ai_raw_content
    final_products: list = []
    user_insight = None

    logger.info(f"🤖 Raw AI JSON: {ai_raw_content}")

    try:
        ai_json = json.loads(ai_raw_content)
    except (json.JSONDecodeError, TypeError) as e:
        logger.warning(f"⚠️ Failed to parse AI response as JSON: {e}")
        return ai_text_response, final_products, user_insight

    if not isinstance(ai_json, dict):
        # Valid JSON but not an object (bare string, number, list): treat as plain text.
        logger.warning(f"⚠️ AI response JSON is not an object: {type(ai_json).__name__}")
        return ai_text_response, final_products, user_insight

    # === PYDANTIC VALIDATION ===
    try:
        parsed_response = ChatResponse.model_validate(ai_json)
        ai_text_response = parsed_response.ai_response
        explicit_skus = parsed_response.product_ids

        # Serialize user_insight to a string for storage
        if parsed_response.user_insight:
            user_insight = parsed_response.user_insight.model_dump_json(indent=2)

        logger.info("✅ Pydantic validation passed for ChatResponse")
    except Exception as validation_error:
        # Deliberately broad: any validation problem falls back to lenient manual parsing.
        logger.warning(f"⚠️ Pydantic validation failed, using fallback: {validation_error}")
        ai_text_response = ai_json.get("ai_response", ai_raw_content)
        explicit_skus = ai_json.get("product_ids", [])
        raw_insight = ai_json.get("user_insight")

        if raw_insight:
            if isinstance(raw_insight, dict):
                user_insight = json.dumps(raw_insight, ensure_ascii=False, indent=2)
            elif isinstance(raw_insight, str):
                user_insight = raw_insight

    # === CRITICAL: Filter/Fetch products ===
    mentioned_skus = _skus_mentioned_in_text(ai_text_response)
    logger.info(f"📝 SKUs mentioned in ai_response: {mentioned_skus}")

    target_skus = _select_target_skus(explicit_skus, mentioned_skus)
    logger.info(f"🎯 Target SKUs to return: {target_skus}")

    if not target_skus:
        return ai_text_response, final_products, user_insight

    # Build lookup from current context (ignore malformed entries)
    product_lookup = {
        p["sku"]: p for p in (all_products or []) if isinstance(p, dict) and p.get("sku")
    }

    final_products = [product_lookup[sku] for sku in target_skus if sku in product_lookup]
    missing_skus = [sku for sku in target_skus if sku not in product_lookup]

    if missing_skus:
        logger.info(f"🕵️ Missing SKUs in context, fetching from DB: {missing_skus}")
        final_products.extend(await fetch_products_by_skus(missing_skus))

    return ai_text_response, final_products, user_insight



def prepare_execution_context(query: str, user_id: str, history: list, images: list | None):
    """
    Prepare initial state and execution config for the graph run.

    Args:
        query: The user's message text; wrapped in a ``HumanMessage`` for both
            ``user_query`` and the initial ``messages`` list.
        user_id: Identifier stored in the state and in ``configurable``.
        history: Prior conversation, passed into the state unchanged.
        images: Optional images for this turn, exposed to the graph as
            ``configurable["transient_images"]`` (an empty list when ``None``).

    Returns:
        tuple: (initial_state, exec_config). A fresh UUID4 string is generated per call
        and used as the config ``run_id`` and repeated in ``metadata`` and
        ``configurable``.
    """
    initial_state: AgentState = {
        "user_query": HumanMessage(content=query),
        "messages": [HumanMessage(content=query)],
        "history": history,
        "user_id": user_id,
        "images_embedding": [],
        "ai_response": None,
    }
    run_id = str(uuid.uuid4())

    # Free-form run metadata. The "tags" string is kept for existing consumers, but
    # LangChain only treats the config-level ``tags`` list below as run tags.
    metadata = {
        "run_id": run_id,
        "tags": "chatbot,production",
    }

    langfuse_handler = get_callback_handler()

    # NOTE(review): RunnableConfig declares run_id as uuid.UUID; a str is passed here —
    # confirm the callback handlers in use accept it before changing the type.
    exec_config = RunnableConfig(
        configurable={
            "user_id": user_id,
            "transient_images": images or [],
            "run_id": run_id,
        },
        run_id=run_id,
        tags=["chatbot", "production"],
        metadata=metadata,
        callbacks=[langfuse_handler] if langfuse_handler else [],
    )
    return initial_state, exec_config


async def handle_post_chat_async(
    memory: ConversationManager,
    identity_key: str,
    human_query: str,
    ai_response: dict | None
):
    """
    Persist one conversation turn; meant to run as a background task after the reply is sent.

    The AI response dict is serialized to a JSON string (Decimal values become floats
    via ``decimal_default``) and stored with ``memory.save_conversation_turn``.
    Nothing is saved when ``ai_response`` is falsy. Any failure is logged with its
    traceback and swallowed.
    """
    if not ai_response:
        return

    try:
        # The storage column is TEXT, so the dict is stored as a JSON string.
        serialized = json.dumps(ai_response, ensure_ascii=False, default=decimal_default)
        await memory.save_conversation_turn(identity_key, human_query, serialized)
        logger.debug(f"Saved conversation for identity_key {identity_key}")
    except Exception as exc:
        logger.error(f"Failed to save conversation for identity_key {identity_key}: {exc}", exc_info=True)
