import logging
from datetime import datetime
from typing import Any

from psycopg_pool import AsyncConnectionPool

from config import CHECKPOINT_POSTGRES_URL

logger = logging.getLogger(__name__)


class ConversationManager:
    """Persist and query per-user chat messages in a PostgreSQL table.

    Database access goes through an ``AsyncConnectionPool`` that is created and
    opened lazily on first use and released by :meth:`close`.

    NOTE: ``table_name`` is interpolated directly into the SQL text (identifiers
    cannot be bound as ``%s`` parameters), so it must come from trusted
    configuration and never from user input.
    """

    def __init__(
        self,
        connection_url: str = CHECKPOINT_POSTGRES_URL,
        table_name: str = "langgraph_chat_histories",
    ):
        """Store connection settings; no database I/O happens here.

        Args:
            connection_url: Connection string passed to ``AsyncConnectionPool``.
            table_name: Name of the chat-history table (trusted input only).
        """
        self.connection_url = connection_url
        self.table_name = table_name
        # Created on the first _get_pool() call; cleared again by close().
        self._pool: AsyncConnectionPool | None = None

    async def _get_pool(self) -> AsyncConnectionPool:
        """Return the connection pool, creating and opening it on first use."""
        if self._pool is None:
            self._pool = AsyncConnectionPool(self.connection_url, open=False)
            await self._pool.open()
        return self._pool

    async def initialize_table(self) -> None:
        """Create the chat-history table and its (user_id, timestamp) index if missing.

        Raises:
            Exception: any database error is logged (with traceback) and re-raised.
        """
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(f"""
                        CREATE TABLE IF NOT EXISTS {self.table_name} (
                            id SERIAL PRIMARY KEY,
                            user_id VARCHAR(255) NOT NULL,
                            message TEXT NOT NULL,
                            is_human BOOLEAN NOT NULL,
                            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                        )
                    """)

                    # NOTE(review): the index name embeds table_name verbatim, so a
                    # schema-qualified name (e.g. "schema.table") would likely
                    # produce invalid SQL here — confirm if such names are used.
                    await cursor.execute(f"""
                        CREATE INDEX IF NOT EXISTS idx_{self.table_name}_user_timestamp
                        ON {self.table_name} (user_id, timestamp)
                    """)
                await conn.commit()
            logger.info("Table %s initialized successfully", self.table_name)
        except Exception as e:
            logger.exception("Error initializing table: %s", e)
            raise

    async def save_conversation_turn(self, user_id: str, human_message: str, ai_message: str) -> None:
        """Insert the human message and the AI reply as two rows in one transaction.

        Both rows are written by a single INSERT statement and share the same
        timestamp.

        Raises:
            Exception: any database error is logged (with traceback) and re-raised.
        """
        try:
            pool = await self._get_pool()
            # NOTE(review): naive local time of the application host, whereas the
            # column default CURRENT_TIMESTAMP uses the DB server clock — consider
            # a single, explicit timezone (e.g. UTC).
            timestamp = datetime.now()
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(
                        f"""INSERT INTO {self.table_name} (user_id, message, is_human, timestamp)
                           VALUES (%s, %s, %s, %s), (%s, %s, %s, %s)""",
                        (
                            user_id,
                            human_message,
                            True,
                            timestamp,
                            user_id,
                            ai_message,
                            False,
                            timestamp,
                        ),
                    )
                await conn.commit()
            logger.debug("Saved conversation turn for user %s", user_id)
        except Exception as e:
            logger.exception("Failed to save conversation for user %s: %s", user_id, e)
            raise

    async def get_chat_history(
        self, user_id: str, limit: int | None = None, before_id: int | None = None
    ) -> list[dict[str, Any]]:
        """Return a user's messages ordered by ``id`` descending, with keyset pagination.

        Args:
            user_id: Owner of the history.
            limit: Maximum number of rows; ``None`` means unlimited. ``0`` returns
                no rows, matching SQL ``LIMIT 0``.
            before_id: If not ``None``, only rows with ``id < before_id`` are
                returned. Pass the smallest ``id`` of the previous page to fetch
                the next (older) page.

        Returns:
            Dicts with keys ``message``, ``is_human``, ``timestamp`` and ``id``.
            On a database error the error is logged and ``[]`` is returned.
        """
        query = f"""
            SELECT message, is_human, timestamp, id
            FROM {self.table_name}
            WHERE user_id = %s
        """
        params: list[Any] = [user_id]

        # Compare with None explicitly: 0 is a meaningful value for both filters
        # and must not silently disable them (e.g. limit=0 returning everything).
        if before_id is not None:
            query += " AND id < %s"
            params.append(before_id)

        query += " ORDER BY id DESC"

        if limit is not None:
            query += " LIMIT %s"
            params.append(limit)

        try:
            pool = await self._get_pool()
            async with pool.connection() as conn, conn.cursor() as cursor:
                await cursor.execute(query, tuple(params))
                rows = await cursor.fetchall()
        except Exception as e:
            logger.exception("Error retrieving chat history: %s", e)
            return []

        return [
            {
                "message": row[0],
                "is_human": row[1],
                "timestamp": row[2],
                "id": row[3],
            }
            for row in rows
        ]

    async def clear_history(self, user_id: str) -> None:
        """Delete every stored message for ``user_id``.

        Best-effort: a database error is logged (with traceback) but not raised.
        """
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(f"DELETE FROM {self.table_name} WHERE user_id = %s", (user_id,))
                await conn.commit()
            logger.info("Cleared chat history for user %s", user_id)
        except Exception as e:
            logger.exception("Error clearing chat history: %s", e)

    async def get_user_count(self) -> int:
        """Return the number of distinct ``user_id`` values in the table.

        On a database error the error is logged and ``0`` is returned.
        """
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn, conn.cursor() as cursor:
                await cursor.execute(f"SELECT COUNT(DISTINCT user_id) FROM {self.table_name}")
                result = await cursor.fetchone()
                return result[0] if result else 0
        except Exception as e:
            logger.exception("Error getting user count: %s", e)
            return 0

    async def close(self) -> None:
        """Close the connection pool, if one was created.

        The pool reference is cleared (before awaiting the close) so a later
        call re-creates a fresh pool instead of reusing the closed one.
        """
        if self._pool is not None:
            pool, self._pool = self._pool, None
            await pool.close()


# --- Singleton ---
_instance: ConversationManager | None = None


async def get_conversation_manager() -> ConversationManager:
    """Return the process-wide ConversationManager, creating it on first use.

    The manager is published to ``_instance`` only after ``initialize_table()``
    succeeds. A failed initialization is therefore not cached; its pool is
    closed, and the next call retries.

    Raises:
        Exception: whatever ``initialize_table()`` raised on a failed first use.
    """
    global _instance
    if _instance is None:
        manager = ConversationManager()
        try:
            await manager.initialize_table()
        except Exception:
            # Do not cache a half-initialized manager; release any pool it opened.
            await manager.close()
            raise
        if _instance is None:
            _instance = manager
        else:
            # Another task finished initializing while we were awaiting; keep
            # the published instance and release this duplicate's pool.
            await manager.close()
    return _instance
