"""
Simple Langfuse Client Wrapper
Minimal setup using langfuse.langchain module
With propagate_attributes for proper user_id tracking
"""

import asyncio
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager

from langfuse import Langfuse, get_client, propagate_attributes
from langfuse.langchain import CallbackHandler

from config import (
    LANGFUSE_BASE_URL,
    LANGFUSE_PUBLIC_KEY,
    LANGFUSE_SECRET_KEY,
)

logger = logging.getLogger(__name__)

# ⚡ Global state for async batch export.
# _langfuse_client / _export_executor are assigned by initialize_langfuse() and
# read by async_flush_langfuse() and get_callback_handler(); both stay None
# until initialize_langfuse() gets past its API-key check.
_langfuse_client: Langfuse | None = None
_export_executor: ThreadPoolExecutor | None = None
# NOTE(review): _pending_traces and _export_task are not read or written by any
# function in this section — confirm whether other code still relies on them.
_pending_traces: list = []
_export_task: asyncio.Task | None = None
# NOTE(review): this binds the asyncio.Lock *class*, not a Lock instance (there
# are no call parentheses), and hasattr(asyncio, "Lock") is always true, so the
# `else None` branch is dead. If a shared lock was intended this presumably
# should be `asyncio.Lock()` — verify against any users of _batch_lock first.
_batch_lock = asyncio.Lock if hasattr(asyncio, "Lock") else None


def initialize_langfuse() -> bool:
    """
    Configure, authenticate and publish the process-wide Langfuse client.

    1. Set the environment variables ``get_client()`` reads (credentials,
       base URL, request timeout, background flush interval).
    2. Obtain the client and verify the credentials with ``auth_check()``.
    3. Create (once) the single-worker thread pool that
       ``async_flush_langfuse`` uses to flush off the event loop.

    The module globals are only updated after ``auth_check()`` succeeds, so a
    failed initialization leaves ``get_callback_handler`` and
    ``async_flush_langfuse`` treating tracing as disabled instead of running
    against a client whose credentials were rejected.

    Returns:
        True if the client authenticated successfully; False if the keys are
        missing, authentication failed, or the SDK raised during setup.
    """
    global _langfuse_client, _export_executor

    if not LANGFUSE_PUBLIC_KEY or not LANGFUSE_SECRET_KEY:
        logger.warning("⚠️ LANGFUSE KEYS MISSING. Tracing disabled.")
        return False

    # Set environment
    os.environ["LANGFUSE_PUBLIC_KEY"] = LANGFUSE_PUBLIC_KEY
    os.environ["LANGFUSE_SECRET_KEY"] = LANGFUSE_SECRET_KEY
    os.environ["LANGFUSE_BASE_URL"] = LANGFUSE_BASE_URL or "https://cloud.langfuse.com"
    os.environ["LANGFUSE_TIMEOUT"] = "10"  # 10s per-request timeout

    # Background flush only every 5 minutes; explicit flushes go through
    # async_flush_langfuse(). Note the underscore: LANGFUSE_FLUSH_INTERVAL is
    # the variable name the SDK actually reads.
    os.environ["LANGFUSE_FLUSH_INTERVAL"] = "300"

    try:
        client = get_client()
        authenticated = client.auth_check()
    except Exception as e:
        logger.error("❌ Langfuse init error: %s", e)
        return False

    if not authenticated:
        logger.error("❌ Langfuse auth failed")
        return False

    _langfuse_client = client
    if _export_executor is None:
        # Reuse the pool on re-initialization rather than leaking a new
        # thread pool on every call.
        _export_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="langfuse_export")

    logger.info("✅ Langfuse Ready! (async batch export)")
    return True


async def async_flush_langfuse():
    """
    Flush buffered Langfuse events without blocking the event loop.

    The SDK's ``flush()`` is synchronous, so it is submitted to the module's
    export thread pool and awaited from there. Returns immediately when the
    client or the executor has not been set up. Any exception is logged at
    WARNING level and swallowed, so a tracing failure never breaks the caller.
    """
    # Snapshot the globals so the values checked are the values used.
    client = _langfuse_client
    executor = _export_executor
    if not client or not executor:
        return

    try:
        # Inside a coroutine, get_running_loop() is the documented, preferred
        # way to reach the current loop (over get_event_loop()).
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(executor, client.flush)
        logger.debug("📤 Langfuse flushed (async)")
    except Exception as e:
        logger.warning("⚠️ Async flush failed: %s", e)


def get_callback_handler(
    trace_id: str | None = None,
    user_id: str | None = None,
    session_id: str | None = None,
    tags: list[str] | None = None,
    **trace_kwargs,
) -> CallbackHandler | None:
    """
    Create a Langfuse LangChain ``CallbackHandler``.

    The handler is always created with no arguments. The trace-related
    parameters are accepted for call-site compatibility only and are
    currently NOT applied to the handler or the trace.

    Args:
        trace_id: Accepted but currently ignored.
        user_id: Accepted but currently ignored. To attach a user id, wrap the
            invocation in ``langfuse_trace_context(user_id=...)``.
        session_id: Accepted but currently ignored. To attach a session id,
            wrap the invocation in ``langfuse_trace_context(session_id=...)``.
        tags: Accepted but currently ignored.
        **trace_kwargs: Accepted but currently ignored.

    Returns:
        A new ``CallbackHandler``, or None if the Langfuse client has not been
        initialized or the handler could not be constructed (the error is
        logged at WARNING level).
    """
    if not _langfuse_client:
        logger.warning("⚠️ Langfuse client not initialized")
        return None

    # Only the constructor can fail; keep the try block around just that.
    try:
        handler = CallbackHandler()
    except Exception as e:
        logger.warning("⚠️ CallbackHandler error: %s", e)
        return None

    logger.debug("✅ Langfuse CallbackHandler created")
    return handler


@contextmanager
def langfuse_trace_context(user_id: str | None = None, session_id: str | None = None, tags: list[str] | None = None):
    """
    Propagate user_id, session_id and tags to all observations in the block.

    Wraps ``propagate_attributes``. Falsy arguments are omitted.

    Args:
        user_id: User id to attach to every observation created in the block.
        session_id: Session id to attach to every observation in the block.
        tags: Tags to attach. They are forwarded only when the installed
            Langfuse SDK's ``propagate_attributes`` accepts a ``tags``
            argument. Otherwise they are dropped, as in earlier versions of
            this helper.

    Usage:
        with langfuse_trace_context(user_id="user_123", session_id="session_456"):
            # All observations created here will have these attributes
            await invoke_chain()
    """
    attrs = {}
    if user_id:
        attrs["user_id"] = user_id
    if session_id:
        attrs["session_id"] = session_id

    propagation = None
    if tags:
        try:
            propagation = propagate_attributes(**attrs, tags=list(tags))
        except TypeError:
            # Older SDKs reject the `tags` keyword; fall back to user/session only.
            logger.debug("propagate_attributes() does not accept tags; tags not propagated")
    if propagation is None:
        propagation = propagate_attributes(**attrs)

    with propagation:
        yield
