"""
Model Fallback Manager
Tự động fallback sang model khác khi gặp lỗi (rate limit, quota, context length)
"""

import logging
from enum import Enum
from typing import Any

from common.mongo import get_mongo_db
from langchain_core.language_models import BaseChatModel

from common.llm_factory import LLMFactory

logger = logging.getLogger(__name__)


class ErrorType(Enum):
    """Các loại lỗi có thể fallback"""

    RATE_LIMIT = "rate_limit"
    QUOTA_EXCEEDED = "quota_exceeded"
    CONTEXT_LENGTH = "context_length"
    NETWORK_ERROR = "network_error"
    UNKNOWN = "unknown"


class ModelFallbackManager:
    """
    Quản lý auto fallback models khi gặp lỗi.

    Features:
    - Tự động detect loại lỗi
    - Chọn model fallback phù hợp dựa trên error type
    - Quản lý danh sách models user có key
    - Hỗ trợ model="auto" để tự động chọn model tốt nhất
    """

    # Model capabilities (context size, cost tier)
    MODEL_CAPABILITIES: dict[str, dict[str, Any]] = {
        # OpenAI
        "openai/gpt-5-nano": {"context": 128000, "tier": "cheap", "speed": "fast"},
        "openai/gpt-5-mini": {"context": 128000, "tier": "cheap", "speed": "fast"},
        "openai/gpt-4o": {"context": 128000, "tier": "medium", "speed": "medium"},
        "openai/gpt-4o-mini": {"context": 128000, "tier": "cheap", "speed": "fast"},
        # Gemini
        "google_genai/gemini-2.5-flash": {"context": 1000000, "tier": "cheap", "speed": "fast"},
        "google_genai/gemini-2.5-pro": {"context": 2000000, "tier": "medium", "speed": "medium"},
        "google_genai/gemini-2.0-flash": {"context": 1000000, "tier": "cheap", "speed": "fast"},
        "google_genai/gemini-2.0-pro-exp": {"context": 2000000, "tier": "medium", "speed": "medium"},
        "google_genai/gemini-1.5-flash": {"context": 1000000, "tier": "cheap", "speed": "fast"},
        "google_genai/gemini-1.5-pro": {"context": 2000000, "tier": "medium", "speed": "medium"},
        # Groq
        "groq/meta-llama/llama-4-maverick-17b-128e-instruct": {
            "context": 128000,
            "tier": "cheap",
            "speed": "very_fast",
        },
        "groq/meta-llama/llama-4-scout-17b-16e-instruct": {"context": 128000, "tier": "cheap", "speed": "very_fast"},
        "groq/openai/gpt-oss-120b": {"context": 128000, "tier": "medium", "speed": "very_fast"},
        "groq/openai/gpt-oss-20b": {"context": 128000, "tier": "cheap", "speed": "very_fast"},
    }

    def __init__(self):
        self.llm_factory = LLMFactory()

    async def _get_api_key_for_provider(self, user_id: str, provider: str) -> str | None:
        """
        Lấy API key của user cho provider (openai, google_genai, groq, anthropic, ...)
        """
        db = get_mongo_db()
        doc = await db["user_api_keys"].find_one(
            {"user_id": user_id, "model": {"$regex": f"^{provider}", "$options": "i"}}
        )
        if doc:
            api_key = doc.get("key")
            # Trim whitespace để tránh lỗi "Illegal header value"
            if api_key:
                return api_key.strip()
            return None
        return None

    def detect_error_type(self, error: Exception) -> ErrorType:
        """
        Detect loại lỗi từ exception.

        Returns:
            ErrorType: Loại lỗi được detect
        """
        error_str = str(error).lower()
        type(error).__name__

        # Rate limit
        if "rate limit" in error_str or "429" in error_str or "too many requests" in error_str:
            return ErrorType.RATE_LIMIT

        # Quota exceeded
        if "quota" in error_str or "insufficient" in error_str or "billing" in error_str:
            return ErrorType.QUOTA_EXCEEDED

        # Context length
        if "context" in error_str and ("length" in error_str or "exceeded" in error_str or "too long" in error_str):
            return ErrorType.CONTEXT_LENGTH
        if "maximum context length" in error_str or "token limit" in error_str:
            return ErrorType.CONTEXT_LENGTH

        # Network errors
        if "timeout" in error_str or "connection" in error_str or "network" in error_str:
            return ErrorType.NETWORK_ERROR

        return ErrorType.UNKNOWN

    async def get_user_available_models(self, user_id: str) -> list[str]:
        """
        Lấy danh sách models mà user đã có API key.

        Args:
            user_id: User ID

        Returns:
            List[str]: Danh sách model names user có key
        """
        try:
            db = get_mongo_db()
            collection = db["user_api_keys"]

            # Lấy tất cả API keys của user
            keys = await collection.find({"user_id": user_id}).to_list(length=100)

            # Map model types sang model names
            available_models = []
            for key in keys:
                model_type = key.get("model", "").lower()

                # Map model type to actual model names
                if model_type == "openai":
                    available_models.extend(
                        ["openai/gpt-5-nano", "openai/gpt-5-mini", "openai/gpt-4o", "openai/gpt-4o-mini"]
                    )
                elif model_type == "gemini":
                    available_models.extend(
                        [
                            "google_genai/gemini-2.5-flash",
                            "google_genai/gemini-2.5-pro",
                            "google_genai/gemini-2.0-flash",
                            "google_genai/gemini-1.5-flash",
                        ]
                    )
                elif model_type == "groq":
                    available_models.extend(
                        [
                            "groq/meta-llama/llama-4-maverick-17b-128e-instruct",
                            "groq/meta-llama/llama-4-scout-17b-16e-instruct",
                            "groq/openai/gpt-oss-120b",
                        ]
                    )
                elif model_type == "claude":
                    available_models.extend(
                        ["anthropic/claude-3-5-sonnet-20241022", "anthropic/claude-3-opus-20240229"]
                    )

            # Remove duplicates và sort theo priority
            unique_models = list(dict.fromkeys(available_models))
            return self._sort_models_by_priority(unique_models)

        except Exception as e:
            logger.error(f"Error getting user available models: {e}")
            # Fallback: return common models
            return [
                "openai/gpt-5-nano",
                "google_genai/gemini-2.5-flash",
                "groq/meta-llama/llama-4-maverick-17b-128e-instruct",
            ]

    def _sort_models_by_priority(self, models: list[str]) -> list[str]:
        """
        Sort models theo priority: cheap + fast trước, sau đó đến medium.
        """

        def get_priority(model: str) -> int:
            caps = self.MODEL_CAPABILITIES.get(model, {})
            tier = caps.get("tier", "medium")
            speed = caps.get("speed", "medium")

            # Priority: cheap + fast = 1, cheap + medium = 2, medium = 3
            if tier == "cheap" and speed in ["fast", "very_fast"]:
                return 1
            if tier == "cheap":
                return 2
            return 3

        return sorted(models, key=get_priority)

    async def select_best_model(self, user_id: str, preferred_model: str | None = None) -> str:
        """
        Chọn model tốt nhất từ danh sách user có key.

        Nếu preferred_model được chỉ định và user có key → dùng preferred_model
        Nếu không → chọn model rẻ + nhanh nhất

        Args:
            user_id: User ID
            preferred_model: Model user muốn dùng (optional)

        Returns:
            str: Model name được chọn
        """
        available = await self.get_user_available_models(user_id)

        if not available:
            # Fallback to default
            logger.warning(f"No available models for user {user_id}, using default")
            return "openai/gpt-5-nano"

        # Nếu có preferred_model và user có key → dùng preferred_model
        if preferred_model and preferred_model in available:
            return preferred_model

        # Chọn model đầu tiên trong danh sách đã sort (tốt nhất)
        return available[0]

    def get_fallback_model(self, current_model: str, error_type: ErrorType, available_models: list[str]) -> str | None:
        """
        Chọn fallback model dựa trên error type và current model.

        Args:
            current_model: Model hiện tại bị lỗi
            error_type: Loại lỗi
            available_models: Danh sách models user có key

        Returns:
            Optional[str]: Model fallback, None nếu không có
        """
        if not available_models:
            return None

        current_caps = self.MODEL_CAPABILITIES.get(current_model, {})
        current_context = current_caps.get("context", 128000)
        current_provider = current_model.split("/")[0] if "/" in current_model else ""

        # Loại bỏ current_model khỏi danh sách
        candidates = [m for m in available_models if m != current_model]

        if not candidates:
            return None

        if error_type == ErrorType.CONTEXT_LENGTH:
            # Chọn model có context lớn hơn
            for model in candidates:
                caps = self.MODEL_CAPABILITIES.get(model, {})
                if caps.get("context", 0) > current_context:
                    logger.info(f"🔄 Context length fallback: {current_model} → {model}")
                    return model

            # Nếu không có model nào có context lớn hơn, chọn model có context lớn nhất
            best = max(candidates, key=lambda m: self.MODEL_CAPABILITIES.get(m, {}).get("context", 0))
            logger.info(f"🔄 Context length fallback (best available): {current_model} → {best}")
            return best

        if error_type == ErrorType.RATE_LIMIT:
            # Chọn model khác provider (tránh cùng rate limit)
            for model in candidates:
                provider = model.split("/")[0] if "/" in model else ""
                if provider != current_provider:
                    logger.info(f"🔄 Rate limit fallback: {current_model} → {model} (different provider)")
                    return model

            # Nếu không có provider khác, chọn model tiếp theo
            logger.info(f"🔄 Rate limit fallback: {current_model} → {candidates[0]} (next model)")
            return candidates[0]

        if error_type == ErrorType.QUOTA_EXCEEDED:
            # Chọn model rẻ hơn hoặc provider khác
            current_caps.get("tier", "medium")
            for model in candidates:
                caps = self.MODEL_CAPABILITIES.get(model, {})
                provider = model.split("/")[0] if "/" in model else ""
                # Ưu tiên: provider khác hoặc tier rẻ hơn
                if provider != current_provider or caps.get("tier") == "cheap":
                    logger.info(f"🔄 Quota fallback: {current_model} → {model}")
                    return model

            # Fallback to first available
            logger.info(f"🔄 Quota fallback: {current_model} → {candidates[0]}")
            return candidates[0]

        # Unknown/Network error: chọn model tiếp theo
        logger.info(f"🔄 Generic fallback: {current_model} → {candidates[0]}")
        return candidates[0]

    async def get_model_with_fallback(
        self,
        user_id: str,
        model_name: str,
        streaming: bool = True,
        json_mode: bool = False,
        max_retries: int = 3,
        allow_fallback: bool = True,  # nếu False và model_name != auto: không fallback, báo lỗi thiếu key
    ) -> tuple[BaseChatModel, str]:
        """
        Lấy LLM model với auto fallback khi gặp lỗi.

        Args:
            user_id: User ID
            model_name: Model name (có thể là "auto")
            streaming: Streaming mode
            json_mode: JSON mode
            max_retries: Số lần retry tối đa

        Returns:
            Tuple[BaseChatModel, str]: (LLM instance, actual model name used)

        Raises:
            Exception: Nếu tất cả models đều fail
        """
        # 1. Get available models for fallback (cần có trước để check)
        available_models = await self.get_user_available_models(user_id)

        if not available_models:
            logger.warning(f"No available models for user {user_id}, using default")
            available_models = ["openai/gpt-5-nano"]  # Fallback default

        # 2. Resolve model name
        if model_name == "auto":
            actual_model = await self.select_best_model(user_id, preferred_model=None)
            logger.info(f"🤖 Auto model selection: {actual_model}")
        else:
            actual_model = model_name
            if not allow_fallback:
                # Strict: không fallback. Nếu model không nằm trong danh sách gợi ý → báo lỗi ngay.
                if actual_model not in available_models:
                    logger.error(
                        f"❌ Strict mode: Model {actual_model} không có trong available list và không được phép fallback."
                    )
                    raise Exception(
                        f"Không tìm thấy model {actual_model} trong danh sách được phép, "
                        f"và chế độ strict không cho phép fallback. Vui lòng cấu hình API key cho provider tương ứng."
                    )
            # Cho phép fallback: nếu model không trong available list, chọn best available
            elif actual_model not in available_models:
                logger.warning(f"Model {actual_model} not in available list, using best available")
                actual_model = await self.select_best_model(user_id, preferred_model=None)

        # 3. Try to get model (with fallback on error)
        tried_models = []
        current_model = actual_model

        for attempt in range(max_retries):
            try:
                # Lấy api_key theo provider
                provider = current_model.split("/")[0] if "/" in current_model else current_model
                api_key = await self._get_api_key_for_provider(user_id, provider)
                if not api_key:
                    logger.error(f"❌ No API key found for provider '{provider}' (user_id: {user_id})")
                    logger.error(f"💡 User needs to configure {provider} API key in settings")
                    # Nếu strict (allow_fallback=False và model_name != auto) → báo lỗi ngay
                    if not allow_fallback and model_name != "auto":
                        raise Exception(
                            f"Không tìm thấy API key cho provider '{provider}'. Vui lòng cấu hình API key cho {provider} trong phần Cài đặt."
                        )
                    raise Exception(
                        f"No API key found for provider '{provider}'. Please configure your {provider} API key in settings."
                    )

                # Get model instance với api_key của user
                llm = self.llm_factory.get_model(
                    current_model, streaming=streaming, json_mode=json_mode, api_key=api_key
                )
                logger.info(f"✅ Using model: {current_model}")
                return llm, current_model

            except Exception as e:
                error_type = self.detect_error_type(e)
                logger.warning(f"❌ Model {current_model} failed ({error_type.value}): {e}")

                tried_models.append(current_model)

                # Nếu strict (không cho phép fallback) thì ném lỗi ngay lập tức
                if not allow_fallback and model_name != "auto":
                    logger.error("⛔ Strict mode: không fallback sang model khác. Ném lỗi lên FE.")
                    raise

                # Get fallback model
                fallback = self.get_fallback_model(current_model, error_type, available_models)

                if not fallback or fallback in tried_models:
                    # No more fallback options
                    error_msg = str(e).lower()
                    if "api key" in error_msg or "no api key" in error_msg:
                        # Lỗi API key - message rõ ràng hơn
                        logger.error(f"❌ All models exhausted due to missing API keys. Tried: {tried_models}")
                        logger.error(f"💡 User {user_id} needs to configure API keys in settings")
                        raise Exception(
                            f"Không tìm thấy API key cho bất kỳ model nào. Vui lòng cấu hình API key trong phần Cài đặt. Lỗi cuối: {e}"
                        )
                    # Lỗi khác
                    logger.error(f"❌ All models exhausted. Tried: {tried_models}")
                    raise Exception(f"All available models failed. Last error: {e}")

                current_model = fallback
                logger.info(f"🔄 Retrying with fallback model: {current_model} (attempt {attempt + 2}/{max_retries})")

        # Should not reach here
        raise Exception(f"Failed to get model after {max_retries} attempts")


# Global instance
_fallback_manager: ModelFallbackManager | None = None


def get_fallback_manager() -> ModelFallbackManager:
    """Get singleton ModelFallbackManager instance"""
    global _fallback_manager
    if _fallback_manager is None:
        _fallback_manager = ModelFallbackManager()
    return _fallback_manager
