"""
🔍 BOTTLENECK PROFILER - Tìm chỗ code chậm nhất
============================================================
Tool này sẽ:
1. Profile TOÀN BỘ API call (từ request → response)
2. Breakdown time cho từng bước: LLM, DB, Tool calls, etc.
3. Export chi tiết ra file để phân tích

CÁCH DÙNG:
----------
1. Chạy script này:
   python locust/profiler_bottleneck.py

2. Xem kết quả trong:
   - Terminal: Summary report
   - File: profiler_results.json (chi tiết)

3. Phân tích bottleneck:
   - LLM calls > 1000ms? → Cân nhắc streaming hoặc cache
   - DB queries > 500ms? → Optimize index hoặc query
   - Tool execution > 2000ms? → Check logic tool

============================================================
"""

import asyncio
import json
import logging
import sys
import time
from pathlib import Path

# Setup path
sys.path.insert(0, str(Path(__file__).parent.parent))

from agent.graph import build_graph
from agent.config import get_config
from langchain_core.messages import HumanMessage
from common.starrocks_connection import StarRocksConnection

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ============================================================
# PERFORMANCE PROFILER (metrics collector)
# ============================================================
class PerformanceProfiler:
    """Collect timing metrics for one profiled run and report on them.

    The profiler does not measure anything by itself: callers pass durations
    (in milliseconds) to the ``record_*`` methods and assign the end-to-end
    duration to ``metrics["total_time"]``. ``get_summary`` aggregates those
    numbers and ``print_report`` renders them to the terminal.

    Attributes:
        metrics: Dict holding ``total_time`` (ms) plus the recorded entries
            under ``llm_calls``, ``db_queries``, ``tool_executions`` and
            ``graph_steps``.
        current_step_start: ``time.time()`` value captured by the last call
            to ``start_timer``, or ``None`` if it was never called.
    """

    # Tuning hints printed by print_report(), keyed by component label.
    _RECOMMENDATIONS = {
        "LLM": (
            "Consider streaming response instead of waiting full completion",
            "Cache common queries/responses",
            "Use lighter model for simple tasks",
        ),
        "Database": (
            "Check if vector index is being used (EXPLAIN query)",
            "Optimize WHERE clauses and add proper indexes",
            "Consider connection pooling",
            "Cache frequent queries",
        ),
        "Tools": (
            "Profile individual tool execution",
            "Add caching for deterministic tools",
            "Parallelize independent tool calls",
        ),
    }

    def __init__(self) -> None:
        self.metrics = {
            "total_time": 0,
            "llm_calls": [],
            "db_queries": [],
            "tool_executions": [],
            "graph_steps": [],
        }
        self.current_step_start = None

    def start_timer(self) -> None:
        """Store the current wall-clock time in ``current_step_start``."""
        self.current_step_start = time.time()

    def record_llm_call(self, duration_ms: float, tokens: int = 0) -> None:
        """Record one LLM call.

        Args:
            duration_ms: Duration of the call in milliseconds.
            tokens: Token count to store with the call (defaults to 0).
        """
        self.metrics["llm_calls"].append({
            "duration_ms": duration_ms,
            "tokens": tokens,
        })

    def record_db_query(self, duration_ms: float, query_type: str) -> None:
        """Record one database query.

        Args:
            duration_ms: Duration of the query in milliseconds.
            query_type: Free-form label stored under the ``type`` key.
        """
        self.metrics["db_queries"].append({
            "duration_ms": duration_ms,
            "type": query_type,
        })

    def record_tool_execution(self, tool_name: str, duration_ms: float) -> None:
        """Record one tool execution.

        Args:
            tool_name: Name stored under the ``tool`` key.
            duration_ms: Duration of the execution in milliseconds.
        """
        self.metrics["tool_executions"].append({
            "tool": tool_name,
            "duration_ms": duration_ms,
        })

    def record_graph_step(self, step_name: str, duration_ms: float) -> None:
        """Record one graph step.

        Args:
            step_name: Name stored under the ``step`` key.
            duration_ms: Duration of the step in milliseconds.
        """
        self.metrics["graph_steps"].append({
            "step": step_name,
            "duration_ms": duration_ms,
        })

    @staticmethod
    def _total_and_average(entries: list) -> tuple:
        """Return ``(sum, mean)`` of ``duration_ms`` over ``entries``.

        The mean is 0 when ``entries`` is empty, which avoids dividing by zero.
        """
        total_ms = sum(entry["duration_ms"] for entry in entries)
        average_ms = total_ms / len(entries) if entries else 0
        return total_ms, average_ms

    def _share_of_total(self, part_ms: float) -> float:
        """Return ``part_ms`` as a percentage of ``metrics["total_time"]``.

        Returns 0 while the total time is still unset (not greater than 0).
        """
        total_ms = self.metrics["total_time"]
        return (part_ms / total_ms * 100) if total_ms > 0 else 0

    def get_summary(self) -> dict:
        """Aggregate the recorded metrics.

        Returns:
            A dict with ``total_time_ms``; ``llm`` and ``db`` sections holding
            ``total_ms``, ``avg_ms``, ``count`` and ``percentage``; a ``tools``
            section holding ``total_ms``, ``count`` and ``percentage``; and the
            raw ``graph_steps`` list. Percentages are relative to
            ``metrics["total_time"]``.
        """
        llm_calls = self.metrics["llm_calls"]
        db_queries = self.metrics["db_queries"]
        tool_runs = self.metrics["tool_executions"]

        llm_total, llm_avg = self._total_and_average(llm_calls)
        db_total, db_avg = self._total_and_average(db_queries)
        tool_total, _ = self._total_and_average(tool_runs)

        return {
            "total_time_ms": self.metrics["total_time"],
            "llm": {
                "total_ms": llm_total,
                "avg_ms": llm_avg,
                "count": len(llm_calls),
                "percentage": self._share_of_total(llm_total),
            },
            "db": {
                "total_ms": db_total,
                "avg_ms": db_avg,
                "count": len(db_queries),
                "percentage": self._share_of_total(db_total),
            },
            "tools": {
                "total_ms": tool_total,
                "count": len(tool_runs),
                "percentage": self._share_of_total(tool_total),
            },
            "graph_steps": self.metrics["graph_steps"],
        }

    def print_report(self, bottleneck_threshold: float = 40.0) -> None:
        """Print a formatted report of the summary to stdout.

        Args:
            bottleneck_threshold: A component is reported as the primary
                bottleneck only when its share of the total time is strictly
                greater than this percentage. Defaults to 40.
        """
        summary = self.get_summary()
        llm, db, tools = summary["llm"], summary["db"], summary["tools"]
        separator = "="*60

        print("\n" + separator)
        print("🔍 BOTTLENECK PROFILING REPORT")
        print(separator)
        print(f"⏱️  TOTAL TIME: {summary['total_time_ms']:.2f}ms\n")

        print("📊 BREAKDOWN BY COMPONENT:")
        print("  🤖 LLM Calls:")
        print(f"     - Time: {llm['total_ms']:.2f}ms ({llm['percentage']:.1f}%)")
        print(f"     - Count: {llm['count']}")
        print(f"     - Avg: {llm['avg_ms']:.2f}ms/call\n")

        print("  🗄️  Database Queries:")
        print(f"     - Time: {db['total_ms']:.2f}ms ({db['percentage']:.1f}%)")
        print(f"     - Count: {db['count']}")
        print(f"     - Avg: {db['avg_ms']:.2f}ms/query\n")

        print("  🔧 Tool Executions:")
        print(f"     - Time: {tools['total_ms']:.2f}ms ({tools['percentage']:.1f}%)")
        print(f"     - Count: {tools['count']}\n")

        print("📈 GRAPH EXECUTION STEPS:")
        for step in summary["graph_steps"]:
            print(f"  ├─ {step['step']}: {step['duration_ms']:.2f}ms")

        print("\n" + separator)
        print("🎯 BOTTLENECK ANALYSIS:")

        # The component with the largest share of the total time; on a tie
        # max() keeps the earliest entry, i.e. LLM, then Database, then Tools.
        components = [
            ("LLM", llm["percentage"]),
            ("Database", db["percentage"]),
            ("Tools", tools["percentage"]),
        ]
        name, share = max(components, key=lambda item: item[1])

        if share > bottleneck_threshold:
            print(f"⚠️  PRIMARY BOTTLENECK: {name} ({share:.1f}%)")
            tips = self._RECOMMENDATIONS.get(name, ())
            if tips:
                print("💡 RECOMMENDATIONS:")
                for tip in tips:
                    print(f"   - {tip}")
        else:
            print("✅ No single dominant bottleneck - well balanced!")

        print(separator + "\n")


# ============================================================
# TEST QUERIES
# ============================================================
# Sample user queries (written in Vietnamese) that
# profile_chatbot_performance() sends through the agent graph one at a time.
TEST_QUERIES = [
    "tìm áo phông nam màu đen",
    "Canifa có cửa hàng ở Hà Nội không?",
    "quần jean nữ giá dưới 500k",
    "chính sách đổi trả như thế nào?",
    "áo khoác mùa đông",
]


# ============================================================
# MAIN PROFILING FUNCTION
# ============================================================
async def profile_chatbot_performance():
    """Run every query in TEST_QUERIES through the agent graph and profile it.

    For each query a fresh PerformanceProfiler is created, the graph is
    streamed with ``graph.astream`` and one graph step is recorded per
    streamed event. A step's duration is the time elapsed since the previous
    event (or since the run started, for the first event), which approximates
    how long the graph worked to produce that event. The end-to-end duration
    is stored in ``profiler.metrics["total_time"]``.

    After all queries ran, a report is printed per query and the summaries are
    written to ``profiler_results.json`` next to this file. A query that raises
    is reported on stdout and left out of the results.

    NOTE: the LLM and DB metrics recorded below are hard-coded example values,
    not measurements, so the LLM/Database shares in the report are not
    meaningful until real callbacks feed the profiler.
    """
    print("🚀 Starting Performance Profiling...")
    print("="*60)

    # Build the graph once and reuse it for every query.
    config = get_config()
    graph = build_graph(config)

    all_profiles = []
    query_count = len(TEST_QUERIES)

    for i, query in enumerate(TEST_QUERIES, 1):
        print(f"\n[{i}/{query_count}] Testing: '{query}'")
        profiler = PerformanceProfiler()

        # Millisecond timestamp plus the query index: the index keeps the id
        # unique even when two iterations start within the same millisecond.
        thread_id = f"profile_test_{int(time.time() * 1000)}_{i}"

        # perf_counter() is monotonic, so durations cannot be skewed by
        # wall-clock adjustments.
        start_time = time.perf_counter()

        try:
            input_state = {
                "messages": [HumanMessage(content=query)],
                "user_id": "profiler_user"
            }
            config_runnable = {"configurable": {"thread_id": thread_id}}

            step_count = 0
            step_started = start_time
            async for event in graph.astream(input_state, config=config_runnable):
                step_finished = time.perf_counter()
                step_count += 1
                # The first key of the event names the step; fall back to a
                # positional name when the event is empty.
                step_name = next(iter(event)) if event else f"step_{step_count}"
                profiler.record_graph_step(
                    step_name, (step_finished - step_started) * 1000
                )
                step_started = step_finished

            total_time = (time.perf_counter() - start_time) * 1000
            profiler.metrics["total_time"] = total_time

            # Placeholder metrics (NOT measured): real numbers require hooking
            # into the actual LLM/DB calls, e.g. through callbacks.
            profiler.record_llm_call(800, 150)  # Example
            profiler.record_db_query(300, "vector_search")  # Example

            all_profiles.append({
                "query": query,
                "profiler": profiler,
                "summary": profiler.get_summary()
            })

            print(f"  ✅ Completed in {total_time:.2f}ms")

        except Exception as e:
            # Best effort: report the failure and continue with the next query.
            print(f"  ❌ Error: {e}")

    # Print one report per successfully profiled query.
    print("\n" + "="*60)
    print("📊 AGGREGATE REPORT")
    print("="*60)

    for profile in all_profiles:
        print(f"\nQuery: '{profile['query']}'")
        profile['profiler'].print_report()

    # Persist only the JSON-serializable parts (query + summary); the
    # profiler objects themselves are not written.
    results_file = Path(__file__).parent / "profiler_results.json"
    with open(results_file, "w", encoding="utf-8") as f:
        json.dump(
            {
                "timestamp": time.time(),
                "results": [
                    {
                        "query": p["query"],
                        "summary": p["summary"]
                    }
                    for p in all_profiles
                ]
            },
            f,
            indent=2,
            ensure_ascii=False
        )

    print(f"\n💾 Detailed results saved to: {results_file}")
    print("\n✅ Profiling complete!")


# ============================================================
# BONUS: Quick DB Query Profiler
# ============================================================
async def profile_db_queries(embedding_dim: int = 768):
    """Time a fixed set of sample SQL statements against StarRocks.

    Each statement is executed once through
    ``StarRocksConnection.execute_query_async``; its duration is printed and a
    warning is shown when it exceeds 500 ms. A failing statement is reported
    on stdout and the run continues with the next one.

    Args:
        embedding_dim: Number of components in the dummy query vector used by
            the vector-search statement. The default of 768 follows the
            "2 values + 766 more" placeholder this function originally
            carried -- TODO confirm against the real ``embedding`` column.

    Raises:
        ValueError: If ``embedding_dim`` is smaller than 1.
    """
    if embedding_dim < 1:
        raise ValueError("embedding_dim must be at least 1")

    print("\n🗄️  DATABASE QUERY PROFILER")
    print("="*60)

    # Build a complete array literal for the query vector. The previous
    # hand-written literal was only a placeholder (two values, a comment and a
    # trailing comma), so the vector-search statement could not run as written.
    dummy_vector = ", ".join(["0.1"] * embedding_dim)

    test_queries = [
        ("Simple text search", """
            SELECT * FROM shared_source.canifa_products_by_sku
            WHERE LOWER(product_name) LIKE '%áo phông%'
            LIMIT 10
        """),
        ("Vector search (no filters)", f"""
            WITH top_candidates AS (
                SELECT internal_ref_code,
                       approx_cosine_similarity(embedding, [{dummy_vector}]) as score
                FROM shared_source.canifa_products_by_sku
                ORDER BY score DESC
                LIMIT 100
            )
            SELECT * FROM top_candidates WHERE score > 0.7
        """),
    ]

    db = StarRocksConnection()

    for name, sql in test_queries:
        print(f"\nTesting: {name}")
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        start = time.perf_counter()

        try:
            await db.execute_query_async(sql)
            duration = (time.perf_counter() - start) * 1000

            print(f"  ⏱️  Duration: {duration:.2f}ms")

            if duration > 500:
                print("  ⚠️  SLOW QUERY DETECTED!")
                print("  💡 Recommendation: Run EXPLAIN to check index usage")

        except Exception as e:
            # Best effort: report the failure and continue with the next query.
            print(f"  ❌ Error: {e}")

    print("\n" + "="*60)


# ============================================================
# ENTRY POINT
# ============================================================
if __name__ == "__main__":
    # Show the usage notes from the module docstring before profiling starts.
    print(__doc__)
    
    # Run the main end-to-end profiler.
    asyncio.run(profile_chatbot_performance())
    
    # Optional: database-only profiler (disabled; uncomment to run it too).
    # asyncio.run(profile_db_queries())
