"""
Batch processor for running batch tests against the chatbot API.
"""

import asyncio
import logging
import statistics
import uuid
from datetime import datetime
from typing import Any, Callable

from services.api_client import ChatbotAPIClient
from services.langfuse_client import LangfuseClient

logger = logging.getLogger(__name__)


class BatchProcessor:
    """Processor để xử lý batch testing"""

    def __init__(self, max_concurrent: int = 5, api_url: str | None = None):
        """
        Set up the clients and the concurrency limit used for batch runs.

        Args:
            max_concurrent: Stored as ``self.max_concurrent``; ``process_batch``
                uses it as the size of the semaphore that guards API calls.
            api_url: Forwarded unchanged to ``ChatbotAPIClient``.
        """
        self.max_concurrent = max_concurrent
        # Clients are created in the same order as before: API client first.
        self.api_client = ChatbotAPIClient(api_url=api_url)
        self.langfuse_client = LangfuseClient()

    async def process_batch(
        self,
        questions: list[dict[str, Any]],
        num_tests_per_question: int = 1,
        progress_callback: Callable[[dict[str, Any]], None] | None = None,
    ) -> dict[str, Any]:
        """
        Xử lý batch testing cho danh sách câu hỏi

        Args:
            questions: List câu hỏi [{"id": 1, "question": "...", "row_data": {...}}, ...]
            num_tests_per_question: Số lần test mỗi câu hỏi
            progress_callback: Callback để update progress (optional)

        Returns:
            Dict chứa:
            - summary: Tổng kết
            - detailed_results: Chi tiết từng lần test
            - aggregated_results: Tổng hợp theo câu hỏi
        """
        total_questions = len(questions)
        total_tests = total_questions * num_tests_per_question

        detailed_results: list[dict[str, Any]] = []
        aggregated_results: list[dict[str, Any]] = []
        start_time = datetime.now()

        processed = 0
        successful = 0
        failed = 0

        semaphore = asyncio.Semaphore(self.max_concurrent)
        lock = asyncio.Lock()

        async def run_test(question_data: dict[str, Any], test_attempt: int):
            nonlocal processed, successful, failed
            question_id = question_data["id"]
            question = question_data["question"]
            row_data = question_data.get("row_data", {})
            test_id = f"q{question_id}_t{test_attempt}"

            try:
                async with semaphore:
                    result = await self.api_client.chat(
                        query=question,
                        user_id=str(uuid.uuid4()),
                        test_id=test_id,
                        test_attempt=test_attempt,
                    )

                result["question_id"] = question_id
                result["question"] = question
                result["test_attempt"] = test_attempt
                result["product_count"] = len(result.get("product_ids", []))

                # Lấy thêm metrics từ Langfuse (optional)
                if self.langfuse_client.enabled:
                    await asyncio.sleep(1)
                    langfuse_metrics = await self.langfuse_client.get_trace_metrics(
                        trace_id=test_id
                    )
                    if langfuse_metrics:
                        result.update(langfuse_metrics)

            except Exception as e:
                result = {
                    "status": "error",
                    "ai_response": "",
                    "product_ids": [],
                    "latency_ms": 0,
                    "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                    "error_message": str(e),
                    "question_id": question_id,
                    "question": question,
                    "test_attempt": test_attempt,
                    "product_count": 0,
                }

            async with lock:
                detailed_results.append(result)
                processed += 1
                if result.get("status") == "success":
                    successful += 1
                else:
                    failed += 1

                if progress_callback:
                    progress_callback(
                        {
                            "processed": processed,
                            "total": total_tests,
                            "current_question": result.get("question_id"),
                            "current_attempt": result.get("test_attempt"),
                            "successful": successful,
                            "failed": failed,
                        }
                    )

        # Tạo và chạy tasks cho tất cả tests (các question được xử lý song song)
        tasks = []
        for q in questions:
            for attempt in range(1, num_tests_per_question + 1):
                tasks.append(asyncio.create_task(run_test(q, attempt)))

        if tasks:
            await asyncio.gather(*tasks)

        # Tính aggregated results per question
        for q in questions:
            qid = q["id"]
            q_results = [r for r in detailed_results if r.get("question_id") == qid]
            aggregated = self._aggregate_test_results(
                qid, q.get("question", ""), q_results, q.get("row_data", {})
            )
            aggregated_results.append(aggregated)

        # Tính tổng chi phí từ detailed results
        total_cost = round(sum(r.get("cost", 0) for r in detailed_results), 4)

        # Tính tổng kết
        end_time = datetime.now()
        duration_seconds = (end_time - start_time).total_seconds()

        summary = {
            "total_questions": total_questions,
            "num_tests_per_question": num_tests_per_question,
            "total_tests": total_tests,
            "successful": successful,
            "failed": failed,
            "success_rate": round(
                (successful / total_tests * 100) if total_tests > 0 else 0, 2
            ),
            "total_cost_usd": round(total_cost, 4),
            "avg_cost_per_test": round(
                total_cost / total_tests if total_tests > 0 else 0, 4
            ),
            "duration_seconds": round(duration_seconds, 2),
            "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
            "end_time": end_time.strftime("%Y-%m-%d %H:%M:%S"),
        }

        return {
            "summary": summary,
            "detailed_results": detailed_results,
            "aggregated_results": aggregated_results,
        }

    def _aggregate_test_results(
        self,
        question_id: int,
        question: str,
        test_results: list[dict[str, Any]],
        row_data: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Tính toán aggregated metrics cho một câu hỏi sau N lần test

        Args:
            question_id: ID câu hỏi
            question: Nội dung câu hỏi
            test_results: List kết quả từ N lần test
            row_data: Dữ liệu gốc từ Excel

        Returns:
            Dict aggregated metrics
        """
        successful_results = [r for r in test_results if r["status"] == "success"]
        success_count = len(successful_results)
        total_count = len(test_results)

        # Tính average metrics
        latencies = [r["latency_ms"] for r in test_results if r.get("latency_ms")]
        costs = [r.get("cost", 0) for r in test_results if r.get("cost")]
        product_counts = [r.get("product_count", 0) for r in test_results]

        # Response consistency (so sánh các responses)
        responses = [r.get("ai_response", "") for r in successful_results]
        consistency = self._calculate_consistency(responses)

        aggregated = {
            "question_id": question_id,
            "question": question,
            "total_tests": total_count,
            "successful_tests": success_count,
            "failed_tests": total_count - success_count,
            "success_rate": round(
                (success_count / total_count * 100) if total_count > 0 else 0, 2
            ),
            "avg_latency_ms": round(statistics.mean(latencies), 2) if latencies else 0,
            "min_latency_ms": round(min(latencies), 2) if latencies else 0,
            "max_latency_ms": round(max(latencies), 2) if latencies else 0,
            "avg_cost_usd": round(statistics.mean(costs), 4) if costs else 0,
            "total_cost_usd": round(sum(costs), 4),
            "avg_product_count": round(statistics.mean(product_counts), 2)
            if product_counts
            else 0,
            "response_consistency": consistency,
            "sample_response": successful_results[0].get("ai_response", "")
            if successful_results
            else "",
        }

        # Thêm các cột gốc từ Excel
        for key, value in row_data.items():
            if key not in aggregated:
                aggregated[f"original_{key}"] = value

        return aggregated

    def _calculate_consistency(self, responses: list[str]) -> str:
        """
        Tính toán độ nhất quán của responses

        Args:
            responses: List các responses

        Returns:
            "High" | "Medium" | "Low"
        """
        if len(responses) <= 1:
            return "N/A"

        # So sánh độ dài
        lengths = [len(r) for r in responses]
        length_variance = statistics.variance(lengths) if len(lengths) > 1 else 0

        # So sánh nội dung (simple similarity)
        if len(responses) == 2:
            similarity = self._simple_similarity(responses[0], responses[1])
        else:
            # Tính average similarity
            similarities = []
            for i in range(len(responses)):
                for j in range(i + 1, len(responses)):
                    similarities.append(
                        self._simple_similarity(responses[i], responses[j])
                    )
            similarity = statistics.mean(similarities) if similarities else 0

        # Đánh giá
        if similarity > 0.8 and length_variance < 100:
            return "High"
        elif similarity > 0.5:
            return "Medium"
        else:
            return "Low"

    def _simple_similarity(self, text1: str, text2: str) -> float:
        """Tính similarity đơn giản giữa 2 texts"""
        if not text1 or not text2:
            return 0.0

        # Simple word overlap
        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())

        if not words1 or not words2:
            return 0.0

        intersection = len(words1 & words2)
        union = len(words1 | words2)

        return intersection / union if union > 0 else 0.0
