"""
Performance Test API Server
Server giả lập để test tải với Locust - KHÔNG gọi OpenAI, KHÔNG gọi Postgres.
"""

import logging
import sys
import time
from pathlib import Path

import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Put the parent of this file's directory at the front of sys.path so the
# `common.*` and `agent.*` imports further down can be resolved.
# NOTE(review): assumes that directory is the backend package root — confirm.
sys.path.insert(0, str(Path(__file__).parent.parent))

# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ============================================================
# 1. MOCK LLM - Trả lời siêu tốc, không gọi OpenAI
# ============================================================
from typing import Any

from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult


class MockHighSpeedLLM(BaseChatModel):
    """
    Chat-model stub for performance testing.

    Every generation call returns the same canned AI message right away;
    no request is made to OpenAI.
    """

    @property
    def _llm_type(self) -> str:
        """Identifier string reported for this model type."""
        return "mock-high-speed"

    def bind_tools(self, tools: Any, **kwargs: Any) -> Any:
        """Ignore the supplied tools and hand back this same instance."""
        return self

    def _generate(
        self, messages: list[Any], stop: list[str] | None = None, run_manager: Any | None = None, **kwargs: Any
    ) -> ChatResult:
        """Return a fixed single-generation result; all arguments are ignored."""
        canned_reply = AIMessage(content="[MOCK] Xin chào! Đây là bot test, không phải OpenAI thật.")
        only_generation = ChatGeneration(message=canned_reply)
        return ChatResult(generations=[only_generation])


# ============================================================
# 2. PATCH create_llm TRƯỚC KHI IMPORT GRAPH
# ============================================================
def mock_create_llm(*args, **kwargs):
    """Stand-in LLM factory: accept any arguments, log the call, return a new MockHighSpeedLLM."""
    logger.info("🎭 MockLLM được gọi thay vì OpenAI!")
    stub_llm = MockHighSpeedLLM()
    return stub_llm


# Swap the factory attribute on the module object BEFORE agent.graph is
# imported (section 5 below), so later lookups of
# common.llm_factory.create_llm resolve to the mock.
# NOTE(review): a module that already bound the original via
# `from common.llm_factory import create_llm` would not see this patch — confirm
# nothing imports it earlier.
import common.llm_factory

common.llm_factory.create_llm = mock_create_llm
logger.info("🎭 PATCHED common.llm_factory.create_llm")

# ============================================================
# 3. PATCH Checkpointer - Dùng MemorySaver thay Postgres
# ============================================================
from langgraph.checkpoint.memory import MemorySaver

# Single MemorySaver instance; both the mock getter and the direct singleton
# patch below hand out this same object.
_shared_memory_saver = MemorySaver()


def mock_get_checkpointer():
    """Replacement for get_checkpointer: log the call and return the shared MemorySaver."""
    logger.info("🧠 Using MemorySaver (RAM) instead of Postgres")
    saver = _shared_memory_saver
    return saver


import agent.checkpointer

# Replace the module-level accessor with the in-memory mock.
agent.checkpointer.get_checkpointer = mock_get_checkpointer
# IMPORTANT: also overwrite the value cached on the module's manager object
# directly with the shared MemorySaver (this assigns the saver; it does not
# clear the cache).
# NOTE(review): presumably this covers code that reads the manager's cached
# checkpointer instead of calling get_checkpointer — confirm.
agent.checkpointer._checkpointer_manager._checkpointer = _shared_memory_saver
logger.info("🔧 PATCHED agent.checkpointer (Reset singleton)")


# ============================================================
# 4. PATCH Langfuse - Tắt để tránh Rate Limit
# ============================================================
def mock_get_callback_handler():
    """Replacement callback-handler getter: always yields None, so no trace handler is supplied."""
    return None


import common.langfuse_client

# Replace the handler getter on the module so callers that look it up there
# receive None instead of a Langfuse callback handler.
common.langfuse_client.get_callback_handler = mock_get_callback_handler
logger.info("🔇 PATCHED common.langfuse_client.get_callback_handler")

# ============================================================
# 5. GIỜ MỚI IMPORT GRAPH (Sau khi patch xong)
# ============================================================
from agent.config import get_config
from langchain_core.messages import HumanMessage

from agent.graph import build_graph
from agent.tools.product_search_helpers import build_starrocks_query

# ============================================================
# 6. IMPORT CHO DB TEST
# ============================================================
from common.starrocks_connection import StarRocksConnection

# ============================================================
# FASTAPI APP
# ============================================================
# Application object on which the startup hook and the routes below are registered.
app = FastAPI(title="Performance Test API")


# Request Models
class SearchRequest(BaseModel):
    """Request body used by the /test/db-search and /test/graph-mock-chat endpoints."""

    # Free-text query sent by the client.
    query: str
    # Declared with a default of 10; the endpoints in this file never read it.
    limit: int = 10


class MockParams(BaseModel):
    """
    Search-parameter object handed to build_starrocks_query by the DB test endpoint.

    Only `query_text` is set by the code in this file; every other field keeps
    its default. The filter fields are optional and default to None, so they
    are annotated `X | None` — a bare `str`/`float` annotation with a None
    default misstates the type and makes an explicit `field=None` fail
    validation on Pydantic v2.
    """

    query_text: str = ""
    limit: int = 10
    sku: str | None = None
    gender: str | None = None
    season: str | None = None
    color: str | None = None
    product_line: str | None = None
    price_min: float | None = None
    price_max: float | None = None


# Module-level slot for the built graph; stays None until startup_event assigns it.
mock_graph = None


@app.on_event("startup")
async def startup_event():
    """Startup hook: connect to StarRocks once, then build the graph into the module-level `mock_graph`."""
    global mock_graph

    logger.info("🚀 Performance Test API Server Starting...")

    # Step 1: open the StarRocks connection up front, before any request arrives.
    starrocks = StarRocksConnection()
    starrocks.connect()
    logger.info("✅ StarRocks Connection initialized")

    # Step 2: build the graph; the LLM factory was already patched at import time.
    mock_graph = build_graph(get_config())
    logger.info("✅ Mock Graph built successfully (No OpenAI, No Postgres, No Langfuse)")


# ============================================================
# ENDPOINTS
# ============================================================


@app.post("/test/db-search")
async def test_db_search(request: SearchRequest):
    """
    Exercise the StarRocks multi-search path with two concurrent queries.

    Two parameter sets are derived from `request.query` (the query itself and
    the query with " nam" appended). Their SQL is built concurrently, both
    statements are executed concurrently, and the first 5 rows of each result
    are reduced to a fixed set of fields.

    NOTE: `request.limit` is not read here; the parameter objects keep their
    default limit and each result is capped at 5 rows.

    Returns:
        dict with `status`, `count`, `process_time_seconds`, `products`
        and `_queries_run`.

    Raises:
        HTTPException: status 500 carrying the error text if anything fails.
    """
    import asyncio

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments the way time.time() differences can.
    started = time.perf_counter()

    try:
        # Simulate a multi-search: the original query plus one derived variant.
        base_params = MockParams(query_text=request.query)
        derived_params = MockParams(query_text=request.query + " nam")

        # Build both SQL statements concurrently.
        sqls = await asyncio.gather(
            build_starrocks_query(base_params),
            build_starrocks_query(derived_params),
        )

        db = StarRocksConnection()

        # Run both statements concurrently.
        results = await asyncio.gather(*[db.execute_query_async(sql) for sql in sqls])

        # Keep only the whitelisted fields from the first 5 rows of each result.
        ALLOWED_FIELDS = {"product_name", "sale_price", "internal_ref_code", "product_image_url_thumbnail"}
        all_products = []
        for products in results:
            all_products.extend(
                {k: v for k, v in product.items() if k in ALLOWED_FIELDS} for product in products[:5]
            )

        process_time = time.perf_counter() - started
        return {
            "status": "success",
            "count": len(all_products),
            "process_time_seconds": round(process_time, 4),
            "products": all_products,
            "_queries_run": len(sqls),
        }

    except Exception as e:
        # logger.exception records the traceback; `from e` keeps the cause chained.
        logger.exception("DB Multi-Search Error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/test/db-ping")
async def test_db_ping():
    """
    Measure a bare DB round trip by executing `SELECT 1`.

    Returns:
        dict with `status` and `process_time_seconds`.

    Raises:
        HTTPException: status 500 carrying the error text if the query fails.
    """
    # perf_counter is monotonic, unlike time.time(), so the elapsed value is
    # not affected by wall-clock adjustments.
    started = time.perf_counter()
    try:
        db = StarRocksConnection()
        await db.execute_query_async("SELECT 1")
    except Exception as e:
        # Log with traceback (as the other endpoints log their failures) and
        # keep the original exception chained as the cause.
        logger.exception("DB Ping Error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    process_time = time.perf_counter() - started
    return {"status": "success", "process_time_seconds": round(process_time, 4)}


@app.post("/test/graph-mock-chat")
async def test_graph_mock_chat(request: SearchRequest):
    """
    Run the whole graph flow once with the mock LLM.

    Streams the graph to completion for a single human message built from
    `request.query`, discarding the streamed events; only the elapsed time is
    reported.

    Returns:
        dict with `status`, `mode`, `process_time_seconds` and `message`.

    Raises:
        HTTPException: status 500 if the graph has not been built yet, or
            carrying the error text if the run fails.
    """
    import uuid

    if mock_graph is None:
        raise HTTPException(500, "Mock Graph not initialized")

    # perf_counter is monotonic; wall-clock differences are not.
    started = time.perf_counter()
    try:
        # A millisecond timestamp alone is not unique under concurrent load:
        # requests landing in the same millisecond would share one thread_id.
        # The random uuid suffix makes every request's thread distinct.
        thread_id = f"perf_test_{int(time.time() * 1000)}_{uuid.uuid4().hex}"

        input_state = {
            "messages": [HumanMessage(content=request.query)],
            "user_id": "perf_user",
        }
        config_runnable = {"configurable": {"thread_id": thread_id}}

        # Drain the stream; the events themselves are not needed.
        async for _event in mock_graph.astream(input_state, config=config_runnable):
            pass

        process_time = time.perf_counter() - started
        return {
            "status": "success",
            "mode": "mock_llm",
            "process_time_seconds": round(process_time, 4),
            "message": "Graph Flow executed successfully",
        }
    except Exception as e:
        # logger.exception records the traceback; `from e` keeps the cause chained.
        logger.exception("Graph Error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/")
async def root():
    """Landing route: report that the server is up and running in mock mode."""
    payload = {
        "message": "Performance Test API is running!",
        "mode": "MOCK (No OpenAI)",
    }
    return payload


# ============================================================
# MAIN
# ============================================================
if __name__ == "__main__":
    # Startup banner: a 60-character rule around the title and the mocked components.
    rule = "=" * 60
    banner = (
        rule,
        "🚀 PERFORMANCE TEST SERVER",
        rule,
        "🎭 LLM: MockHighSpeedLLM (No OpenAI)",
        "🧠 Checkpointer: MemorySaver (No Postgres)",
        "🔇 Langfuse: Disabled",
        rule,
    )
    for banner_line in banner:
        print(banner_line)

    # Serve the app on all interfaces, port 8000.
    uvicorn.run(app, host="0.0.0.0", port=8000)
