import hashlib
import json
import logging

from openai import AsyncOpenAI, OpenAI

from config import OPENAI_API_KEY

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Public API of this module: the names exported by a star-import.
__all__ = [
    "create_embedding",
    "create_embedding_async",
    "create_embeddings_async",
    "get_async_embedding_client",
    "get_embedding_client",
]


class EmbeddingClientManager:
    """
    Holder for the OpenAI embedding clients (sync and async).

    Each client is built on first request and then reused by this instance.
    The class does not enforce a single instance by itself; sharing happens
    through whichever instance the caller keeps around.
    """

    def __init__(self):
        # Both clients start unset and are created lazily by the getters.
        self._client: OpenAI | None = None
        self._async_client: AsyncOpenAI | None = None

    @staticmethod
    def _require_api_key():
        """Return OPENAI_API_KEY, raising RuntimeError when it is falsy."""
        if OPENAI_API_KEY:
            return OPENAI_API_KEY
        raise RuntimeError("CRITICAL: OPENAI_API_KEY chưa được thiết lập")

    def get_client(self) -> OpenAI:
        """Return the sync client, creating it on first use.

        Raises:
            RuntimeError: if no client exists yet and OPENAI_API_KEY is falsy.
        """
        if self._client is not None:
            return self._client
        self._client = OpenAI(api_key=self._require_api_key())
        return self._client

    def get_async_client(self) -> AsyncOpenAI:
        """Return the async client, creating it on first use.

        Raises:
            RuntimeError: if no client exists yet and OPENAI_API_KEY is falsy.
        """
        if self._async_client is not None:
            return self._async_client
        self._async_client = AsyncOpenAI(api_key=self._require_api_key())
        return self._async_client


# NOTE(review): this import sits below the class rather than in the
# top-of-file import block, as in the original layout. If the placement is
# not working around a circular import, it can be moved to the top.
from common.cache import redis_cache

# --- Singleton ---
# One shared manager for the module; the public getters are its bound methods,
# so every caller reuses the same lazily created clients.
_manager = EmbeddingClientManager()
get_embedding_client = _manager.get_client
get_async_embedding_client = _manager.get_async_client


def create_embedding(text: str) -> list[float]:
    """Generate an embedding for *text* synchronously.

    This path always calls OpenAI (model ``text-embedding-3-small``); it does
    not consult the cache, because the cache in ``common.cache`` is async.

    Returns:
        The embedding of the first item in the response, or an empty list if
        anything goes wrong (the error is logged, not raised).
    """
    try:
        # No cache lookup here: a sync cache accessor would be needed first.
        result = get_embedding_client().embeddings.create(
            input=text,
            model="text-embedding-3-small",
        )
        first_item = result.data[0]
        return first_item.embedding
    except Exception as exc:
        logger.error(f"Error creating embedding (sync): {exc}")
        return []


async def create_embedding_async(text: str) -> list[float]:
    """
    Generate an embedding for *text* asynchronously, using the Layer 2 cache.

    The cache is treated as best-effort: a failed cache read falls through to
    OpenAI, and a failed cache write does not discard an embedding that was
    already obtained.

    Returns:
        The cached or freshly generated embedding, or an empty list when the
        OpenAI call fails (the error is logged, not raised).
    """
    # 1. Try Layer 2 Cache. A cache failure must not block generation.
    try:
        cached = await redis_cache.get_embedding(text)
    except Exception as exc:
        logger.warning("Embedding cache read failed, falling back to OpenAI: %s", exc)
        cached = None
    if cached:
        return cached

    # 2. Call OpenAI. This is the only step whose failure yields [].
    try:
        client = get_async_embedding_client()
        response = await client.embeddings.create(model="text-embedding-3-small", input=text)
        embedding = response.data[0].embedding
    except Exception as e:
        logger.error(f"Error creating embedding (async): {e}")
        return []

    # 3. Store in cache. Failure here only costs a future cache hit.
    if embedding:
        try:
            await redis_cache.set_embedding(text, embedding)
        except Exception as exc:
            logger.warning("Embedding cache write failed, result not cached: %s", exc)

    return embedding


async def create_embeddings_async(texts: list[str]) -> list[list[float]]:
    """
    Generate embeddings for a batch of texts, with a per-item Layer 2 cache.

    Cached items are read with one ``MGET``; only the misses are sent to
    OpenAI in a single request. The cache is best-effort: a failed lookup or
    an undecodable entry is treated as a miss, and a failed cache write does
    not discard embeddings that were already generated.

    Returns:
        One embedding per input text, in input order. If the OpenAI call
        fails, every entry is an empty list (all-or-nothing, error logged).
        An empty input returns an empty list.
    """
    if not texts:
        return []

    results: list[list[float]] = [[] for _ in texts]

    # 1. Cache lookup. Any failure here degrades to "everything missed".
    cached_values: list = [None] * len(texts)
    try:
        redis_client = redis_cache.get_client()
        if redis_client:
            # Presumably mirrors the key scheme redis_cache.set_embedding
            # writes with -- TODO confirm against common.cache.
            keys = [
                f"emb_cache:{hashlib.md5(text.strip().lower().encode()).hexdigest()}"
                for text in texts
            ]
            cached_values = await redis_client.mget(keys)
    except Exception as exc:
        logger.warning("Embedding cache lookup failed, treating all texts as misses: %s", exc)
        cached_values = [None] * len(texts)

    missed_indices: list[int] = []
    for i, cached in enumerate(cached_values):
        if cached:
            try:
                results[i] = json.loads(cached)
            except (TypeError, ValueError) as exc:
                # One corrupt entry should cost one re-fetch, not the batch.
                logger.warning("Ignoring undecodable cached embedding at index %d: %s", i, exc)
            else:
                continue
        missed_indices.append(i)

    if not missed_indices:
        return results
    missed_texts = [texts[i] for i in missed_indices]

    # 2. Call OpenAI for missed texts.
    try:
        openai_client = get_async_embedding_client()
        response = await openai_client.embeddings.create(
            model="text-embedding-3-small", input=missed_texts
        )
        # Relies on response items arriving in the same order as missed_texts.
        new_embeddings = [item.embedding for item in response.data]
    except Exception as e:
        logger.error(f"Error creating batch embeddings (async): {e}")
        return [[] for _ in texts]

    # 3. Fill results and cache the new embeddings.
    cache_writable = True
    for idx, text, embedding in zip(missed_indices, missed_texts, new_embeddings):
        results[idx] = embedding
        # Skip empty embeddings, matching create_embedding_async.
        if not (cache_writable and embedding):
            continue
        try:
            await redis_cache.set_embedding(text, embedding)
        except Exception as exc:
            # Stop after the first failure to avoid one warning per item.
            logger.warning("Embedding cache write failed, skipping remaining writes: %s", exc)
            cache_writable = False

    return results
