import json
import logging
import asyncio
from datetime import datetime, date
from typing import Any

import psycopg
from psycopg_pool import AsyncConnectionPool

from config import CHECKPOINT_POSTGRES_URL

logger = logging.getLogger(__name__)


class ConversationManager:
    """Persist and query per-identity chat history in a PostgreSQL table.

    Every row holds one message (human or AI) tagged with an ``identity_key``.
    All "today" filters in this class are evaluated against the wall clock of
    the Asia/Ho_Chi_Minh time zone, and rows are written with a timestamp
    produced by that same SQL expression so reads and writes always agree.

    Connections come from a lazily created ``AsyncConnectionPool`` whose
    connections run with ``autocommit=True``: every statement is committed as
    soon as it completes, so no explicit ``commit()`` is issued anywhere.

    ``table_name`` is interpolated directly into SQL text, so it must be a
    trusted identifier (never user input).
    """

    # Wall-clock "now" in the zone used to decide which day a message belongs to.
    _LOCAL_NOW_SQL = "CURRENT_TIMESTAMP AT TIME ZONE 'Asia/Ho_Chi_Minh'"
    # Shared predicate so every method uses the same definition of "today".
    _TODAY_FILTER_SQL = f"DATE(timestamp) = DATE({_LOCAL_NOW_SQL})"

    # Retry policy for transient connection failures (psycopg.OperationalError).
    _MAX_RETRIES = 3
    _RETRY_DELAY_SECONDS = 0.5

    def __init__(
        self,
        connection_url: str = CHECKPOINT_POSTGRES_URL,
        table_name: str = "langgraph_chat_histories",
    ):
        """Remember the connection settings; no I/O happens here.

        Args:
            connection_url: PostgreSQL connection string handed to the pool.
            table_name: Name of the chat history table (trusted identifier).
        """
        self.connection_url = connection_url
        self.table_name = table_name
        # Created on first use by _get_pool().
        self._pool: AsyncConnectionPool | None = None

    async def _get_pool(self) -> AsyncConnectionPool:
        """Return the shared connection pool, creating and opening it on first use."""
        if self._pool is None:
            pool = AsyncConnectionPool(
                self.connection_url,
                min_size=1,
                max_size=20,
                max_lifetime=600,  # Recycle connections every 10 mins
                max_idle=300,  # Close idle connections after 5 mins
                open=False,
                kwargs={"autocommit": True},
            )
            # Open before publishing: a pool that failed to open (or is still
            # opening) must never be handed out or cached on the instance.
            await pool.open()
            if self._pool is None:
                self._pool = pool
            else:
                # Another coroutine published its pool while this one was
                # opening; keep the published one and discard the duplicate.
                await pool.close()
        return self._pool

    async def initialize_table(self):
        """Create the chat history table and its lookup index if they don't exist.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn, conn.cursor() as cursor:
                await cursor.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.table_name} (
                        id SERIAL PRIMARY KEY,
                        identity_key VARCHAR(255) NOT NULL,
                        message TEXT NOT NULL,
                        is_human BOOLEAN NOT NULL,
                        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                # Supports the per-identity, per-day lookups used by the readers.
                await cursor.execute(f"""
                    CREATE INDEX IF NOT EXISTS idx_{self.table_name}_identity_timestamp
                    ON {self.table_name} (identity_key, timestamp)
                """)
            logger.info(f"Table {self.table_name} initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing table: {e}")
            raise

    async def save_conversation_turn(self, identity_key: str, human_message: str, ai_message: str):
        """Save the human message and the AI reply of one turn atomically.

        Both rows are written by a single multi-row ``INSERT``; a single
        statement either succeeds or fails as a whole, so the pair is atomic
        even though the connection is in autocommit mode. The timestamp is
        computed by the database with the same time-zone expression the
        readers use, so a freshly saved turn is always part of "today".

        Args:
            identity_key: Identity (user or device) the turn belongs to.
            human_message: Plain text sent by the user.
            ai_message: Stored AI reply (get_chat_history parses it as JSON
                when possible).

        Raises:
            psycopg.OperationalError: After all retry attempts failed.
            Exception: Any other error is logged and re-raised immediately.
        """
        # The human row is listed first so it is inserted before the AI row.
        insert_sql = f"""INSERT INTO {self.table_name} (identity_key, message, is_human, timestamp)
                         VALUES (%s, %s, %s, {self._LOCAL_NOW_SQL}), (%s, %s, %s, {self._LOCAL_NOW_SQL})"""
        params = (identity_key, human_message, True, identity_key, ai_message, False)

        max_retries = self._MAX_RETRIES
        for attempt in range(max_retries):
            try:
                pool = await self._get_pool()
                async with pool.connection() as conn, conn.cursor() as cursor:
                    await cursor.execute(insert_sql, params)

                logger.debug(f"Saved conversation turn for identity_key {identity_key}")
                return  # Success

            except psycopg.OperationalError as e:
                # Connection-level failures are assumed transient: retry.
                logger.warning(f"Database connection error (attempt {attempt+1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    logger.error(f"Failed to save conversation after {max_retries} attempts: {e}")
                    raise
                await asyncio.sleep(self._RETRY_DELAY_SECONDS)

            except Exception as e:
                logger.error(f"Failed to save conversation for identity_key {identity_key}: {e}", exc_info=True)
                raise

    @staticmethod
    def _row_to_entry(row: tuple[Any, ...]) -> dict[str, Any]:
        """Convert one ``(message, is_human, timestamp, id)`` row into a history entry."""
        message_content, is_human, timestamp, row_id = row
        entry: dict[str, Any] = {
            "is_human": is_human,
            "timestamp": timestamp,
            "id": row_id,
        }

        if is_human:
            # User message: stored as plain text.
            entry["message"] = message_content
            return entry

        # AI message: normally a JSON object carrying ai_response + product_ids.
        try:
            parsed = json.loads(message_content)
        except (json.JSONDecodeError, TypeError):
            parsed = None  # Not JSON at all (legacy data).

        if isinstance(parsed, dict):
            entry["message"] = parsed.get("ai_response", message_content)
            entry["product_ids"] = parsed.get("product_ids", [])
        else:
            # Text that merely happens to be valid JSON (e.g. "123", "null",
            # "[1]") is not a structured reply; return it verbatim.
            entry["message"] = message_content
            entry["product_ids"] = []
        return entry

    async def get_chat_history(
        self, identity_key: str, limit: int | None = None, before_id: int | None = None
    ) -> list[dict[str, Any]]:
        """Retrieve today's chat history for an identity, newest first.

        Uses cursor-based pagination on the row id. AI messages are parsed
        from their JSON string to expose ``product_ids``.

        Args:
            identity_key: Identity (user_id or device_id) to read.
            limit: Maximum number of rows; a falsy value (``None`` or ``0``)
                means no limit.
            before_id: Only return rows with a smaller id; a falsy value
                (``None`` or ``0``) disables the cursor.

        Returns:
            A list of dicts with keys ``is_human``, ``timestamp``, ``id`` and
            ``message``; AI entries additionally carry ``product_ids``.
            Returns ``[]`` when a non-connection error occurs (it is logged).

        Raises:
            psycopg.OperationalError: After all retry attempts failed.
        """
        query = f"""
            SELECT message, is_human, timestamp, id
            FROM {self.table_name}
            WHERE identity_key = %s
            AND {self._TODAY_FILTER_SQL}
        """
        params: list[Any] = [identity_key]

        if before_id:
            query += " AND id < %s"
            params.append(before_id)

        query += " ORDER BY id DESC"

        if limit:
            query += " LIMIT %s"
            params.append(limit)

        max_retries = self._MAX_RETRIES
        for attempt in range(max_retries):
            try:
                pool = await self._get_pool()
                async with pool.connection() as conn, conn.cursor() as cursor:
                    await cursor.execute(query, tuple(params))
                    results = await cursor.fetchall()

                return [self._row_to_entry(row) for row in results]

            except psycopg.OperationalError as e:
                logger.warning(f"Database connection error in get_chat_history (attempt {attempt+1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    logger.error(f"Failed to get chat history after {max_retries} attempts: {e}")
                    raise
                await asyncio.sleep(self._RETRY_DELAY_SECONDS)

            except Exception as e:
                # Best-effort read: report the failure but give the caller an
                # empty history instead of an exception.
                logger.error(f"Error retrieving chat history: {e}", exc_info=True)
                return []

    async def archive_history(self, identity_key: str) -> str:
        """Archive today's messages of an identity by renaming their key.

        Only messages from today (the ones get_chat_history can return) are
        moved; they are re-keyed to ``<identity_key>_archived_<YYYYmmdd_HHMMSS>``.

        Args:
            identity_key: Identity whose visible history should be archived.

        Returns:
            The new archived key.

        Raises:
            Exception: Any database error is logged and re-raised.
        """
        try:
            # The suffix only has to make the archived key unique and readable.
            timestamp_suffix = datetime.now().strftime("%Y%m%d_%H%M%S")
            new_key = f"{identity_key}_archived_{timestamp_suffix}"

            pool = await self._get_pool()
            async with pool.connection() as conn, conn.cursor() as cursor:
                # Rename identity_key for today's messages
                await cursor.execute(
                    f"""
                    UPDATE {self.table_name}
                    SET identity_key = %s
                    WHERE identity_key = %s
                    AND {self._TODAY_FILTER_SQL}
                    """,
                    (new_key, identity_key),
                )

            logger.info(f"Archived history for {identity_key} to {new_key}")
            return new_key

        except Exception as e:
            logger.error(f"Error archiving history: {e}")
            raise

    async def clear_history(self, identity_key: str):
        """Delete all chat history rows (any day) of an identity.

        Best-effort: failures are logged and not raised.
        """
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn, conn.cursor() as cursor:
                await cursor.execute(f"DELETE FROM {self.table_name} WHERE identity_key = %s", (identity_key,))
            logger.info(f"Cleared chat history for identity_key {identity_key}")
        except Exception as e:
            logger.error(f"Error clearing chat history: {e}", exc_info=True)

    async def get_user_count(self) -> int:
        """Return the number of distinct ``identity_key`` values in the table.

        Keys produced by archive_history are distinct values too, so archived
        conversations are included in the count. Returns 0 on any error
        (which is logged).
        """
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn, conn.cursor() as cursor:
                await cursor.execute(f"SELECT COUNT(DISTINCT identity_key) FROM {self.table_name}")
                result = await cursor.fetchone()
                return result[0] if result else 0
        except Exception as e:
            logger.error(f"Error getting user count: {e}", exc_info=True)
            return 0

    async def get_message_count_today(self, identity_key: str) -> int:
        """Count the identity's human messages sent today (for rate limiting).

        Only rows with ``is_human = true`` are counted, and "today" uses the
        same time-zone based predicate as get_chat_history. Messages renamed
        by archive_history no longer match ``identity_key`` and are therefore
        not counted.

        Returns:
            The number of matching messages, or 0 on any error.
            NOTE(review): returning 0 on error lets a rate limiter fail open;
            confirm this is the intended policy.
        """
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn, conn.cursor() as cursor:
                await cursor.execute(
                    f"""
                    SELECT COUNT(*) FROM {self.table_name}
                    WHERE identity_key = %s
                    AND is_human = true
                    AND {self._TODAY_FILTER_SQL}
                    """,
                    (identity_key,),
                )
                result = await cursor.fetchone()
                return result[0] if result else 0
        except Exception as e:
            logger.error(f"Error counting messages for {identity_key}: {e}", exc_info=True)
            return 0

    async def close(self):
        """Close the connection pool; a later call transparently opens a new one."""
        # Detach first so no caller can obtain the pool while it is closing,
        # and so _get_pool() builds a fresh pool after close() instead of
        # returning a closed one.
        pool, self._pool = self._pool, None
        if pool is not None:
            await pool.close()


# --- Singleton ---
# Process-wide manager; published only after its table has been initialized.
_instance: ConversationManager | None = None


async def get_conversation_manager() -> ConversationManager:
    """Return the shared ConversationManager, creating it on first use.

    The new manager is stored in the module-level singleton only after
    ``initialize_table()`` succeeded, so a failed initialization is retried
    on the next call and no caller can receive a manager whose table has not
    been created yet.

    Raises:
        Exception: Whatever ``initialize_table()`` raised on first use.
    """
    global _instance
    if _instance is None:
        manager = ConversationManager()
        try:
            await manager.initialize_table()
        except Exception:
            # Don't leak the pool of a manager that will never be published.
            await manager.close()
            raise

        if _instance is None:
            _instance = manager
        else:
            # A concurrent caller finished first; keep its instance.
            await manager.close()
    return _instance
