import json
import logging
import asyncio
from datetime import datetime, date
from typing import Any

import psycopg
from psycopg import sql
from psycopg_pool import AsyncConnectionPool

from config import CHECKPOINT_POSTGRES_URL

logger = logging.getLogger(__name__)


class ConversationManager:
    """Persist and query per-identity chat history in a PostgreSQL table.

    Each conversation turn is stored as two rows (human + AI) that share one
    timestamp. Timestamps are written as naive local ``datetime.now()`` values
    into a ``TIMESTAMP`` (without time zone) column, so every "today"-scoped
    query in this class filters with naive local bounds as well.
    """

    def __init__(
        self,
        connection_url: str = CHECKPOINT_POSTGRES_URL,
        table_name: str = "langgraph_chat_histories",
    ):
        """
        Args:
            connection_url: PostgreSQL connection string used by the pool.
            table_name: Chat-history table name; always quoted via ``sql.Identifier``.
        """
        self.connection_url = connection_url
        self.table_name = table_name
        self._pool: AsyncConnectionPool | None = None
        # Serializes lazy pool creation so concurrent first callers share one pool.
        self._pool_lock = asyncio.Lock()

    async def _get_pool(self) -> AsyncConnectionPool:
        """Return the shared async connection pool, creating and opening it on first use.

        The pool is only stored on ``self._pool`` after ``open()`` succeeds, so a
        failed open is retried on the next call instead of caching a broken pool,
        and concurrent callers never receive a pool that has not been opened yet.
        """
        if self._pool is None:
            async with self._pool_lock:
                # Re-check: another task may have created the pool while we waited.
                if self._pool is None:
                    pool = AsyncConnectionPool(
                        self.connection_url,
                        min_size=1,
                        max_size=20,
                        max_lifetime=600,  # Recycle connections every 10 mins
                        max_idle=300,  # Close idle connections after 5 mins
                        open=False,
                        # autocommit is deliberately left off so multi-row writes
                        # are committed atomically via an explicit conn.commit().
                    )
                    await pool.open()
                    self._pool = pool
        return self._pool

    @staticmethod
    def _day_bounds(now: datetime | None = None) -> tuple[datetime, datetime]:
        """Return naive local ``(start, end)`` datetimes covering the day of ``now``.

        Rows are stored with naive local timestamps in a ``TIMESTAMP`` column.
        Timezone-aware bounds would be compared after PostgreSQL reinterprets the
        column in the session TimeZone, which need not match this process's zone.

        Args:
            now: Reference moment; defaults to ``datetime.now()``.
        """
        if now is None:
            now = datetime.now()
        start_of_day = now.replace(hour=0, minute=0, second=0, microsecond=0)
        end_of_day = now.replace(hour=23, minute=59, second=59, microsecond=999999)
        return start_of_day, end_of_day

    async def initialize_table(self):
        """Create the chat history table and its (identity_key, timestamp) index if missing.

        Raises:
            Exception: any database error, after it has been logged.
        """
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    # Use sql.SQL for safe identifier quoting
                    create_table_query = sql.SQL("""
                        CREATE TABLE IF NOT EXISTS {table} (
                            id SERIAL PRIMARY KEY,
                            identity_key VARCHAR(255) NOT NULL,
                            message TEXT NOT NULL,
                            is_human BOOLEAN NOT NULL,
                            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                        )
                    """).format(table=sql.Identifier(self.table_name))
                    await cursor.execute(create_table_query)

                    # Composite index backing the per-identity day-range queries below.
                    create_index_query = sql.SQL("""
                        CREATE INDEX IF NOT EXISTS {index_name} 
                        ON {table} (identity_key, timestamp)
                    """).format(
                        index_name=sql.Identifier(f"idx_{self.table_name}_identity_timestamp"),
                        table=sql.Identifier(self.table_name),
                    )
                    await cursor.execute(create_index_query)
                await conn.commit()
            logger.info(f"Table {self.table_name} initialized successfully")
        except Exception as e:
            logger.error(f"Error initializing table: {e}")
            raise

    async def save_conversation_turn(self, identity_key: str, human_message: str, ai_message: str):
        """Save a human/AI message pair in one atomic transaction, retrying connection errors.

        Both rows share a single naive local timestamp and are written by one
        multi-row INSERT.

        Raises:
            psycopg.OperationalError: if every one of the ``max_retries`` attempts fails.
            Exception: any other error is logged and re-raised immediately.
        """
        max_retries = 3
        insert_query = sql.SQL("""
            INSERT INTO {table} (identity_key, message, is_human, timestamp) 
            VALUES (%s, %s, %s, %s), (%s, %s, %s, %s)
        """).format(table=sql.Identifier(self.table_name))

        for attempt in range(max_retries):
            try:
                pool = await self._get_pool()
                timestamp = datetime.now()
                # Transaction block: atomic insert
                async with pool.connection() as conn:
                    async with conn.cursor() as cursor:
                        await cursor.execute(
                            insert_query,
                            (
                                identity_key, human_message, True, timestamp,
                                identity_key, ai_message, False, timestamp,
                            ),
                        )
                    await conn.commit()

                logger.debug(f"Saved conversation turn for identity_key {identity_key}")
                return  # Success

            except psycopg.OperationalError as e:
                # NOTE(review): if the connection drops after the server has already
                # committed, this retry inserts the turn twice; an idempotency key
                # would be needed to rule that out.
                logger.warning(f"Database connection error (attempt {attempt+1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    logger.error(f"Failed to save conversation after {max_retries} attempts: {e}")
                    raise
                await asyncio.sleep(0.5)

            except Exception as e:
                logger.error(f"Failed to save conversation for identity_key {identity_key}: {e}", exc_info=True)
                raise

    @staticmethod
    def _row_to_entry(row: tuple) -> dict[str, Any]:
        """Convert a ``(message, is_human, timestamp, id)`` row into a history entry dict.

        Human messages are plain text. AI messages are expected to be a JSON object
        with ``ai_response`` and ``product_ids``; anything else (legacy plain text,
        or JSON that is not an object) is returned verbatim with no product ids.
        """
        message_content, is_human, timestamp, row_id = row
        entry: dict[str, Any] = {
            "is_human": is_human,
            "timestamp": timestamp,
            "id": row_id,
        }

        if is_human:
            # User message - plain text.
            entry["message"] = message_content
            return entry

        try:
            parsed = json.loads(message_content)
        except (json.JSONDecodeError, TypeError):
            parsed = None

        # Valid JSON that is not an object (e.g. "123", "true", a list) has no .get,
        # so it must take the fallback path instead of raising.
        if isinstance(parsed, dict):
            entry["message"] = parsed.get("ai_response", message_content)
            entry["product_ids"] = parsed.get("product_ids", [])
        else:
            # Fallback for non-JSON (legacy) data.
            entry["message"] = message_content
            entry["product_ids"] = []
        return entry

    async def get_chat_history(
        self, identity_key: str, limit: int | None = None, before_id: int | None = None
    ) -> list[dict[str, Any]]:
        """
        Retrieve today's chat history for an identity (user_id or device_id), newest first.

        Uses cursor-based pagination on ``id``. AI messages are parsed from their
        stored JSON string to extract ``ai_response`` and ``product_ids``.

        Args:
            identity_key: Identity whose messages to fetch.
            limit: Maximum number of rows to return; ``None`` means no limit.
            before_id: Only return rows with ``id < before_id``; ``None`` disables the cursor.

        Returns:
            Dicts with ``message``, ``is_human``, ``timestamp`` and ``id`` keys, plus
            ``product_ids`` for AI messages. Returns ``[]`` on non-connection errors.

        Raises:
            psycopg.OperationalError: if every one of the ``max_retries`` attempts fails.
        """
        start_of_day, end_of_day = self._day_bounds()

        # Range query on (identity_key, timestamp) so the composite index is usable.
        query_parts = [
            sql.SQL("""
                SELECT message, is_human, timestamp, id
                FROM {table} 
                WHERE identity_key = %s
                AND timestamp >= %s AND timestamp <= %s
            """).format(table=sql.Identifier(self.table_name))
        ]
        params: list[Any] = [identity_key, start_of_day, end_of_day]

        # Compare against None so 0 is honored as a real cursor / limit value.
        if before_id is not None:
            query_parts.append(sql.SQL("AND id < %s"))
            params.append(before_id)

        query_parts.append(sql.SQL("ORDER BY id DESC"))

        if limit is not None:
            query_parts.append(sql.SQL("LIMIT %s"))
            params.append(limit)

        final_query = sql.SQL(" ").join(query_parts)

        max_retries = 3
        for attempt in range(max_retries):
            try:
                pool = await self._get_pool()
                async with pool.connection() as conn:
                    async with conn.cursor() as cursor:
                        await cursor.execute(final_query, tuple(params))
                        results = await cursor.fetchall()

                return [self._row_to_entry(row) for row in results]

            except psycopg.OperationalError as e:
                logger.warning(f"Database connection error in get_chat_history (attempt {attempt+1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    logger.error(f"Failed to get chat history after {max_retries} attempts: {e}")
                    raise
                await asyncio.sleep(0.5)

            except Exception as e:
                # Best-effort read: callers get an empty history rather than an error.
                logger.error(f"Error retrieving chat history: {e}", exc_info=True)
                return []

    async def archive_history(self, identity_key: str) -> str:
        """
        Archive today's chat history for ``identity_key`` by renaming its key in the DB.

        Only messages from today (the ones ``get_chat_history`` returns) are moved.

        Returns:
            The new archived key, formatted ``<identity_key>_archived_<YYYYmmdd_HHMMSS>``.

        Raises:
            Exception: any database error, after it has been logged.
        """
        try:
            # A single reference moment keeps the key suffix and the day window consistent.
            now = datetime.now()
            timestamp_suffix = now.strftime("%Y%m%d_%H%M%S")
            # Format: user123_archived_20231027_103045
            new_key = f"{identity_key}_archived_{timestamp_suffix}"

            start_of_day, end_of_day = self._day_bounds(now)

            pool = await self._get_pool()
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    query = sql.SQL("""
                        UPDATE {table}
                        SET identity_key = %s
                        WHERE identity_key = %s
                        AND timestamp >= %s AND timestamp <= %s
                    """).format(table=sql.Identifier(self.table_name))

                    await cursor.execute(
                        query,
                        (new_key, identity_key, start_of_day, end_of_day),
                    )
                await conn.commit()

            logger.info(f"Archived history for {identity_key} to {new_key}")
            return new_key

        except Exception as e:
            logger.error(f"Error archiving history: {e}")
            raise

    async def clear_history(self, identity_key: str):
        """Delete every stored message (all days) for an identity.

        Best-effort: errors are logged with a traceback but not raised.
        """
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    query = sql.SQL("DELETE FROM {table} WHERE identity_key = %s").format(
                        table=sql.Identifier(self.table_name)
                    )
                    await cursor.execute(query, (identity_key,))
                await conn.commit()
            logger.info(f"Cleared chat history for identity_key {identity_key}")
        except Exception as e:
            logger.error(f"Error clearing chat history: {e}", exc_info=True)

    async def get_user_count(self) -> int:
        """Return the number of distinct ``identity_key`` values in the table.

        Archived keys (renamed by ``archive_history``) count as separate identities.
        Returns 0 if the query fails.
        """
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn, conn.cursor() as cursor:
                query = sql.SQL("SELECT COUNT(DISTINCT identity_key) FROM {table}").format(
                    table=sql.Identifier(self.table_name)
                )
                await cursor.execute(query)
                result = await cursor.fetchone()
                return result[0] if result else 0
        except Exception as e:
            logger.error(f"Error getting user count: {e}")
            return 0

    async def get_message_count_today(self, identity_key: str) -> int:
        """
        Count the identity's messages sent today (used for rate limiting).

        Only human messages (``is_human = true``) are counted. Returns 0 if the query fails.
        """
        try:
            start_of_day, end_of_day = self._day_bounds()

            pool = await self._get_pool()
            async with pool.connection() as conn, conn.cursor() as cursor:
                query = sql.SQL("""
                    SELECT COUNT(*) FROM {table} 
                    WHERE identity_key = %s 
                    AND is_human = true 
                    AND timestamp >= %s AND timestamp <= %s
                """).format(table=sql.Identifier(self.table_name))

                await cursor.execute(
                    query,
                    (identity_key, start_of_day, end_of_day),
                )
                result = await cursor.fetchone()
                return result[0] if result else 0
        except Exception as e:
            logger.error(f"Error counting messages for {identity_key}: {e}")
            return 0

    async def close(self):
        """Close the connection pool; a later call will lazily open a fresh one."""
        # Detach first so no caller picks up the pool while it is shutting down.
        pool, self._pool = self._pool, None
        if pool is not None:
            await pool.close()


# --- Singleton ---
_instance: ConversationManager | None = None


async def get_conversation_manager() -> ConversationManager:
    """Get or create the process-wide ConversationManager singleton.

    The manager is only published after ``initialize_table()`` succeeds. Callers
    therefore never receive a manager whose table setup is still in flight, and a
    failed initialization is retried on the next call instead of being cached.

    Raises:
        Exception: whatever ``initialize_table()`` raises; the half-built manager's
            pool is closed first.
    """
    global _instance
    if _instance is None:
        manager = ConversationManager()
        try:
            await manager.initialize_table()
        except Exception:
            # Don't leak the pool opened during the failed initialization.
            await manager.close()
            raise
        if _instance is None:
            _instance = manager
        else:
            # Another task finished initializing first while we awaited; keep theirs.
            await manager.close()
    return _instance
