import hashlib
import json
import logging

import redis.asyncio as aioredis  # redis package với async support (thay thế aioredis deprecated)

from config import (
    REDIS_CACHE_DB,
    REDIS_CACHE_PORT,
    REDIS_CACHE_TURN_ON,
    REDIS_CACHE_URL,
    REDIS_PASSWORD,
    REDIS_USERNAME,
)

logger = logging.getLogger(__name__)

# ====================== CACHE CONFIGURATION ======================
# Layer 1 -- full response cache. The TTL is kept short so that a cached
# answer cannot keep serving stale stock/price information for long.
RESPONSE_KEY_PREFIX = "resp_cache:"
DEFAULT_RESPONSE_TTL = 5 * 60  # seconds (5 minutes)

# Layer 2 -- embedding cache. Vectors for a given text are treated as
# static, so they can live much longer than responses.
EMBEDDING_KEY_PREFIX = "emb_cache:"
EMBEDDING_CACHE_TTL = 24 * 60 * 60  # seconds (24 hours)

class RedisClient:
    """
    Hybrid Cache Client for Canifa Chatbot.

    Layer 1: Exact Response Cache (Short TTL) -- full response dicts keyed by
             (user_id, normalized query).
    Layer 2: Embedding Cache (Long TTL) -- embedding vectors keyed by the
             normalized input text.

    All cache operations are best-effort: Redis/serialization errors are
    logged as warnings and surface to callers as a miss (``None``) or a
    no-op, never as an exception.
    """

    def __init__(self):
        self._client: aioredis.Redis | None = None
        # Master switch; initialize() also turns it off if connecting fails.
        self._enabled = REDIS_CACHE_TURN_ON
        # Plain in-memory counters for this instance. "misses" counts lookups
        # (either layer) that reached Redis and found nothing.
        self._stats = {
            "resp_hits": 0,
            "emb_hits": 0,
            "misses": 0,
        }

    # --- Key helpers ---

    @staticmethod
    def _hash(text: str) -> str:
        """Return the hex MD5 digest of ``text`` (used only as a key suffix, not for security)."""
        return hashlib.md5(text.encode(), usedforsecurity=False).hexdigest()

    @classmethod
    def _response_key(cls, user_id: str, query: str) -> str:
        """Build the Layer 1 key for (user_id, query); the query is stripped and lower-cased."""
        uid = str(user_id)
        # Length-prefix the user id so that ("a:b", "c") and ("a", "b:c")
        # produce different keys; a plain "user:query" join is ambiguous and
        # could let one user receive another user's cached response.
        raw = f"{len(uid)}:{uid}:{query.strip().lower()}"
        return f"{RESPONSE_KEY_PREFIX}{cls._hash(raw)}"

    @classmethod
    def _embedding_key(cls, text: str) -> str:
        """Build the Layer 2 key for ``text``; the text is stripped and lower-cased."""
        return f"{EMBEDDING_KEY_PREFIX}{cls._hash(text.strip().lower())}"

    # --- Connection management ---

    async def initialize(self) -> aioredis.Redis | None:
        """Connect to Redis (once) and return the client.

        Returns the already-connected client on repeated calls. Returns
        ``None`` when the cache is disabled, or when creating the client or
        the initial ``PING`` fails -- in that case the cache is disabled for
        the rest of this instance's lifetime and the failed client is closed.
        """
        if not self._enabled:
            logger.info("🚫 Redis Cache is DISABLED via REDIS_CACHE_TURN_ON")
            return None

        if self._client is not None:
            return self._client

        connection_kwargs = {
            "host": REDIS_CACHE_URL,
            "port": REDIS_CACHE_PORT,
            "db": REDIS_CACHE_DB,
            "decode_responses": True,
            "socket_connect_timeout": 5,
        }
        if REDIS_PASSWORD:
            connection_kwargs["password"] = REDIS_PASSWORD
        if REDIS_USERNAME:
            connection_kwargs["username"] = REDIS_USERNAME

        client = None
        try:
            client = aioredis.Redis(**connection_kwargs)
            await client.ping()
        except Exception as e:
            logger.error(f"❌ Failed to connect to Redis: {e}")
            self._enabled = False
            # Don't leak the half-open client's connections.
            await self._close_quietly(client)
            return None

        # Publish the client only after a successful ping.
        self._client = client
        logger.info(f"✅ Redis Hybrid Cache connected: {REDIS_CACHE_URL}:{REDIS_CACHE_PORT} (db={REDIS_CACHE_DB})")
        return self._client

    @staticmethod
    async def _close_quietly(client) -> None:
        """Best-effort close of a Redis client; any error is logged at debug level and ignored."""
        if client is None:
            return
        # NOTE(review): newer redis-py releases expose aclose() and deprecate
        # close(); fall back to close() for older versions.
        close = getattr(client, "aclose", None) or getattr(client, "close", None)
        if close is None:
            return
        try:
            await close()
        except Exception as e:
            logger.debug(f"Ignoring error while closing Redis client: {e}")

    def get_client(self) -> aioredis.Redis | None:
        """Return the connected client, or ``None`` if the cache is disabled or not initialized."""
        if not self._enabled:
            return None
        return self._client

    # --- Layer 1: Exact Response Cache (Short TTL) ---

    async def get_response(self, user_id: str, query: str) -> dict | None:
        """Return the cached response for this exact (user_id, query), or ``None`` on miss/error."""
        if not self._enabled:
            return None
        try:
            client = self.get_client()
            if client is None:
                return None

            cached = await client.get(self._response_key(user_id, query))
            if not cached:
                self._stats["misses"] += 1
                return None

            # Parse before counting so a corrupt entry is not reported as a hit.
            data = json.loads(cached)
            self._stats["resp_hits"] += 1
            logger.info(f"⚡ LAYER 1 HIT (Response) | User: {user_id}")
            return data
        except Exception as e:
            logger.warning(f"Redis get_response error: {e}")
            return None

    async def set_response(self, user_id: str, query: str, response_data: dict, ttl: int = DEFAULT_RESPONSE_TTL) -> None:
        """Store ``response_data`` as JSON for ``ttl`` seconds; empty/falsy data is not cached."""
        if not self._enabled or not response_data:
            return
        try:
            client = self.get_client()
            if client is None:
                return

            await client.setex(self._response_key(user_id, query), ttl, json.dumps(response_data))
            logger.debug(f"💾 LAYER 1 STORED (Response) | TTL: {ttl}s")
        except Exception as e:
            logger.warning(f"Redis set_response error: {e}")

    # --- Layer 2: Embedding Cache (Long TTL) ---

    async def get_embedding(self, text: str) -> list[float] | None:
        """Return the cached embedding for ``text`` (case/whitespace-insensitive), or ``None``."""
        if not self._enabled:
            return None
        try:
            client = self.get_client()
            if client is None:
                return None

            cached = await client.get(self._embedding_key(text))
            if not cached:
                self._stats["misses"] += 1
                return None

            embedding = json.loads(cached)
            self._stats["emb_hits"] += 1
            logger.info(f"🔵 LAYER 2 HIT (Embedding) | Query: {text[:20]}...")
            return embedding
        except Exception as e:
            logger.warning(f"Redis get_embedding error: {e}")
            return None

    async def set_embedding(self, text: str, embedding: list[float], ttl: int = EMBEDDING_CACHE_TTL) -> None:
        """Store ``embedding`` as JSON for ``ttl`` seconds; an empty embedding is not cached."""
        if not self._enabled or not embedding:
            return
        try:
            client = self.get_client()
            if client is None:
                return

            await client.setex(self._embedding_key(text), ttl, json.dumps(embedding))
            logger.debug(f"💾 LAYER 2 STORED (Embedding) | TTL: {ttl}s")
        except Exception as e:
            logger.warning(f"Redis set_embedding error: {e}")

# --- Singleton Export ---
# Module-level shared instance; get_redis_cache() is the accessor for it.
redis_cache = RedisClient()


def get_redis_cache() -> RedisClient:
    """Return the module-level shared ``RedisClient`` instance."""
    shared = redis_cache
    return shared
