"""
Custom Streaming Callback để bắt tokens từ LLM real-time
Không cần đợi graph.astream() emit event!
"""

import asyncio
import json
import logging
import re
from typing import Any

from langchain_core.callbacks.base import AsyncCallbackHandler

# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)


class ProductIDStreamingCallback(AsyncCallbackHandler):
    """
    Capture LLM tokens in real time and detect the ``product_ids`` field early.

    Tokens are accumulated as they stream in. As soon as a complete
    ``"product_ids": [...]`` array appears in the accumulated text, the SKUs and
    the ``ai_response`` string that directly precedes it are extracted, and
    ``product_found_event`` is set. A waiting controller can therefore respond
    without waiting for the rest of the output (e.g. ``user_insight``). The LLM
    stream itself keeps running.

    Public state (read by callers):
        accumulated_content: all token text received since construction/reset().
        product_ids_found: True once the product_ids array has been parsed.
        ai_response_text: decoded ``ai_response`` value ("" if not found).
        product_skus: non-empty quoted entries of the product_ids array.
        should_stop: set to True together with product_ids_found.
        product_found_event: asyncio.Event set when product_ids is found.
    """

    # Patterns are compiled once rather than looked up on every token.
    # A complete product_ids array; the lazy group stops at the first "]".
    _PRODUCT_IDS_RE = re.compile(r'"product_ids"\s*:\s*\[(.*?)\]', re.DOTALL)
    # The ai_response string body, only when it directly precedes product_ids
    # (the lookahead lets escaped quotes inside the value be skipped over).
    _AI_RESPONSE_RE = re.compile(r'"ai_response"\s*:\s*"(.*?)"(?=\s*,\s*"product_ids")', re.DOTALL)
    # Non-empty double-quoted entries inside the product_ids array.
    _SKU_RE = re.compile(r'"([^"]+)"')
    # Emit the progress debug log each time this many more chars have streamed.
    _DEBUG_LOG_INTERVAL = 100

    def __init__(self) -> None:
        super().__init__()
        self.accumulated_content: str = ""
        self.product_ids_found: bool = False
        self.ai_response_text: str = ""
        self.product_skus: list[str] = []
        self.should_stop: bool = False
        # An Event rather than polling: the controller can await it directly.
        # NOTE(review): on Python < 3.10 an asyncio.Event binds to the loop that is
        # current at construction, so create this callback inside the running loop there.
        self.product_found_event = asyncio.Event()

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """
        Accumulate a newly streamed token and check for a complete product_ids array.

        Args:
            token: Text fragment emitted by the LLM.
            **kwargs: Extra callback metadata from LangChain (unused).
        """
        if not token:
            # Empty tokens (presumably chunks with no text content) add nothing and
            # cannot complete a match; skipping them also avoids re-logging progress.
            return

        previous_len = len(self.accumulated_content)
        self.accumulated_content += token
        current_len = len(self.accumulated_content)

        # Log whenever the total crosses a multiple of the interval. A plain
        # `len % 100 == 0` check only fires if a multi-char token happens to end
        # exactly on a multiple, so it rarely triggered.
        interval = self._DEBUG_LOG_INTERVAL
        if current_len // interval > previous_len // interval:
            logger.debug("📡 Streamed %d chars...", current_len)

        # The pattern has no lookaround, so any match lying entirely in the earlier
        # text would already have been found. A new match must therefore end with
        # a "]" from this token. Skipping tokens without "]" avoids rescanning the
        # whole buffer on every token.
        if self.product_ids_found or "]" not in token:
            return

        product_match = self._PRODUCT_IDS_RE.search(self.accumulated_content)
        if product_match is None:
            return

        logger.warning("🎯 FOUND product_ids at %d chars!", current_len)
        self.product_ids_found = True

        ai_text_match = self._AI_RESPONSE_RE.search(self.accumulated_content)
        if ai_text_match:
            self.ai_response_text = self._decode_json_string(ai_text_match.group(1))

        self.product_skus = self._SKU_RE.findall(product_match.group(1))

        logger.warning("✅ Extracted %d SKUs: %s", len(self.product_skus), self.product_skus)
        logger.info("✅ product_ids found → response can return early (stream continues)")

        # Wake up the waiting controller immediately.
        self.should_stop = True
        self.product_found_event.set()

    @staticmethod
    def _decode_json_string(raw: str) -> str:
        """
        Decode the body of a JSON string literal (the text between its quotes).

        All JSON escapes are decoded in a single pass: escaped quote, escaped
        backslash, newline, tab and unicode escapes. ``strict=False`` tolerates
        raw control characters such as literal newlines. If ``raw`` is not a
        valid JSON string body, fall back to unescaping only the escaped-quote
        and newline sequences, as earlier versions did.
        """
        try:
            return json.loads('"' + raw + '"', strict=False)
        except json.JSONDecodeError:
            return raw.replace('\\"', '"').replace("\\n", "\n")

    async def on_llm_end(self, response: Any, **kwargs: Any) -> None:
        """Log when an LLM call finishes without having produced product_ids."""
        if not self.product_ids_found:
            logger.info("ℹ️ LLM turn ended without product_ids (may appear after tool calls)")

    async def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
        """Log an error raised by the LLM call."""
        logger.error("❌ LLM Error: %s", error)

    def reset(self) -> None:
        """Clear all accumulated state and the event so the instance can be reused."""
        self.accumulated_content = ""
        self.product_ids_found = False
        self.ai_response_text = ""
        self.product_skus = []
        self.should_stop = False
        self.product_found_event.clear()
