import asyncio
import logging
from datetime import datetime
from typing import Any

from psycopg import sql
from psycopg_pool import AsyncConnectionPool

from config import CHECKPOINT_POSTGRES_URL

# Agno is an optional runtime dependency: use the real classes when importable,
# otherwise fall back to minimal stand-ins so this module still imports.
# NOTE(review): any ImportError here (including only the second import failing)
# replaces BaseDb too; confirm `Message` is exported from `agno.models`.
try:
    from agno.db.base import BaseDb
    from agno.models import Message  # type: ignore[import-untyped]
except ImportError:

    class BaseDbStub:  # type: ignore
        """Empty placeholder standing in for agno's BaseDb."""

    class MessageStub:  # type: ignore
        """Message stand-in carrying only role, content and created_at."""

        def __init__(self, role: str, content: str, created_at: Any = None):
            self.role, self.content, self.created_at = role, content, created_at

    BaseDb, Message = BaseDbStub, MessageStub  # type: ignore

logger = logging.getLogger(__name__)


class SessionData:
    """Lightweight session record exposing the ``.metadata`` and ``.session_data``
    attributes that the Agno framework reads."""

    def __init__(self, session_id: str, metadata: Any = None, session_data: Any = None, manager: Any = None):
        # Public attributes accessed directly by Agno.
        self.session_id = session_id
        self.metadata = metadata
        self.session_data = session_data
        # Back-reference to the owning ConversationManager (for async operations).
        self._manager = manager

    def get_messages(self, *_args, **_kwargs) -> list[Any]:
        """Synchronous hook Agno calls to obtain the session's messages.

        The stored history is only reachable asynchronously (it is loaded through
        ``load_history``), so this hook always yields an empty list.
        """
        return list()

    def upsert_run(self, run: Any = None) -> bool:
        """Synchronous hook Agno calls to persist run data; acknowledged as a no-op.

        Messages are actually written through ``save_message`` / ``save_session``.
        """
        del run  # intentionally unused
        return True


# Duck-typed rather than inheriting BaseDb, to avoid implementing all of BaseDb's abstract methods
class ConversationManager:  # Don't inherit BaseDb directly
    """
    PostgreSQL-backed conversation store exposing an Agno BaseDb-style interface.

    Supports both the legacy methods (``save_conversation_turn``,
    ``get_chat_history``, ``clear_history``, ...) and the methods the Agno Agent
    calls (``load_history``, ``save_message``, ``get_session``, ...). Agno's
    ``session_id`` is stored in the table's ``user_id`` column.
    """

    def __init__(
        self,
        connection_url: str | None = None,
        table_name: str = "langgraph_chat_histories",
    ):
        """
        Args:
            connection_url: PostgreSQL connection string; falls back to
                ``CHECKPOINT_POSTGRES_URL`` when omitted or empty.
            table_name: Table that holds the chat history rows.

        Raises:
            ValueError: If no connection URL is available.
        """
        self.connection_url: str = connection_url or CHECKPOINT_POSTGRES_URL or ""
        if not self.connection_url:
            raise ValueError("connection_url is required")
        self.table_name = table_name
        self._pool: AsyncConnectionPool | None = None
        # Serializes lazy pool creation so concurrent first callers share one opened pool.
        self._pool_lock = asyncio.Lock()

    # ========== Internal helpers ==========

    def _redacted_url(self) -> str:
        """Return the connection URL for logging, hiding everything up to the last '@'."""
        if "@" in self.connection_url:
            return self.connection_url.split("@")[-1]
        return "***"

    @staticmethod
    def _message_timestamp(created_at: Any, fallback: datetime) -> Any:
        """Return the value to store in the ``timestamp`` column for a message.

        Falsy values are replaced by *fallback*. Numeric epoch seconds are converted
        to a naive local ``datetime`` (like the ``datetime.now()`` values used
        elsewhere) because a number cannot be bound to a TIMESTAMP column. Any other
        value is passed through unchanged.
        """
        if not created_at:
            return fallback
        # NOTE(review): Agno's Message presumably stores created_at as epoch seconds — confirm.
        if isinstance(created_at, (int, float)) and not isinstance(created_at, bool):
            return datetime.fromtimestamp(created_at)
        return created_at

    async def _open_pool(self) -> AsyncConnectionPool:
        """Create a new connection pool and open it; errors are logged and re-raised."""
        # Pool config: min_size=1, max_size=5, timeout=10s (instead of the 30s default)
        pool = AsyncConnectionPool(
            self.connection_url,
            min_size=1,
            max_size=5,
            timeout=10.0,
            open=False,
        )
        try:
            await pool.open()
        except Exception as e:
            logger.error(f"❌ Failed to open PostgreSQL pool: {e}")
            raise
        logger.info(f"✅ PostgreSQL connection pool opened: {self._redacted_url()}")
        return pool

    async def _get_pool(self) -> AsyncConnectionPool:
        """Return the shared connection pool, creating and opening it on first use.

        Creation runs under ``self._pool_lock`` and ``self._pool`` is assigned only
        after ``open()`` returns, so concurrent callers never receive a
        constructed-but-unopened pool. If opening fails, the next call retries.
        """
        if self._pool is None:
            async with self._pool_lock:
                # Re-check after acquiring the lock: another coroutine may have opened it.
                if self._pool is None:
                    self._pool = await self._open_pool()
        return self._pool

    # ========== Legacy methods ==========

    async def initialize_table(self):
        """Create the chat history table and its (user_id, timestamp) index if missing.

        Raises:
            Exception: Pool or database errors are logged (with a redacted URL) and re-raised.
        """
        try:
            logger.info(f"🔌 Initializing PostgreSQL table: {self.table_name}")
            pool = await self._get_pool()

            # Use a shorter (5s) connection checkout timeout than the pool's 10s.
            async with pool.connection(timeout=5.0) as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(
                        sql.SQL("""
                            CREATE TABLE IF NOT EXISTS {} (
                                id SERIAL PRIMARY KEY,
                                user_id VARCHAR(255) NOT NULL,
                                message TEXT NOT NULL,
                                is_human BOOLEAN NOT NULL,
                                timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                            )
                        """).format(sql.Identifier(self.table_name))
                    )

                    await cursor.execute(
                        sql.SQL("""
                            CREATE INDEX IF NOT EXISTS {} 
                            ON {} (user_id, timestamp)
                        """).format(
                            sql.Identifier(f"idx_{self.table_name}_user_timestamp"),
                            sql.Identifier(self.table_name),
                        )
                    )
                await conn.commit()
            logger.info(f"✅ Table {self.table_name} initialized successfully")
        except Exception as e:
            logger.error(f"❌ Error initializing table: {e}")
            logger.error(f"   Connection URL: {self._redacted_url()}")
            raise

    async def save_conversation_turn(self, user_id: str, human_message: str, ai_message: str):
        """Save the human and AI messages of one turn in a single atomic INSERT.

        Both rows share the same timestamp; their insertion order (id) keeps the
        human message before the AI reply.

        Raises:
            Exception: Database errors are logged and re-raised.
        """
        try:
            pool = await self._get_pool()
            timestamp = datetime.now()
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(
                        sql.SQL("""
                            INSERT INTO {} (user_id, message, is_human, timestamp)
                            VALUES (%s, %s, %s, %s), (%s, %s, %s, %s)
                        """).format(sql.Identifier(self.table_name)),
                        (
                            user_id,
                            human_message,
                            True,
                            timestamp,
                            user_id,
                            ai_message,
                            False,
                            timestamp,
                        ),
                    )
                await conn.commit()
            logger.debug(f"Saved conversation turn for user {user_id}")
        except Exception as e:
            logger.error(f"Failed to save conversation for user {user_id}: {e}", exc_info=True)
            raise

    async def get_chat_history(
        self, user_id: str, limit: int | None = None, before_id: int | None = None
    ) -> list[dict[str, Any]]:
        """
        Retrieve a user's chat history, newest first, with cursor-based pagination.

        Args:
            user_id: Whose messages to fetch.
            limit: Maximum number of rows; a falsy value (None or 0) means no limit.
            before_id: Only return rows whose ``id`` is below this cursor; a falsy
                value (None or 0) starts from the newest message.

        Returns:
            Dicts with ``message``, ``is_human``, ``timestamp`` and ``id`` keys.
            Errors are logged and yield an empty list (best-effort read).
        """
        try:
            base_query = sql.SQL("SELECT message, is_human, timestamp, id FROM {} WHERE user_id = %s").format(
                sql.Identifier(self.table_name)
            )
            params: list[Any] = [user_id]
            query_parts: list[sql.Composable] = [base_query]

            if before_id:
                query_parts.append(sql.SQL(" AND id < %s"))
                params.append(before_id)

            query_parts.append(sql.SQL(" ORDER BY id DESC"))

            if limit:
                query_parts.append(sql.SQL(" LIMIT %s"))
                params.append(limit)

            query = sql.Composed(query_parts)

            pool = await self._get_pool()
            async with pool.connection() as conn, conn.cursor() as cursor:
                await cursor.execute(query, tuple(params))
                results = await cursor.fetchall()

            return [
                {
                    "message": row[0],
                    "is_human": row[1],
                    "timestamp": row[2],
                    "id": row[3],
                }
                for row in results
            ]
        except Exception as e:
            logger.error(f"Error retrieving chat history: {e}", exc_info=True)
            return []

    async def clear_history(self, user_id: str):
        """Delete all chat history rows for a user.

        Best-effort: errors are logged, not raised.
        """
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(
                        sql.SQL("DELETE FROM {} WHERE user_id = %s").format(sql.Identifier(self.table_name)),
                        (user_id,),
                    )
                await conn.commit()
            logger.info(f"Cleared chat history for user {user_id}")
        except Exception as e:
            logger.error(f"Error clearing chat history: {e}", exc_info=True)

    async def get_user_count(self) -> int:
        """Return the number of distinct user_ids in the table (0 on error)."""
        try:
            pool = await self._get_pool()
            async with pool.connection() as conn, conn.cursor() as cursor:
                await cursor.execute(
                    sql.SQL("SELECT COUNT(DISTINCT user_id) FROM {}").format(sql.Identifier(self.table_name))
                )
                result = await cursor.fetchone()
                return result[0] if result else 0
        except Exception as e:
            logger.error(f"Error getting user count: {e}", exc_info=True)
            return 0

    async def close(self):
        """Close the connection pool, if one is open.

        The pool is detached before awaiting ``close()`` so no caller is handed a
        closing pool; a later call to ``_get_pool`` lazily opens a fresh one.
        """
        pool, self._pool = self._pool, None
        if pool is not None:
            await pool.close()

    # ========== Agno BaseDb Interface Methods ==========
    # The legacy methods above are kept unchanged for backward compatibility.

    async def initialize(self):
        """Agno interface: initialize the table (alias of ``initialize_table``)."""
        return await self.initialize_table()

    async def load_history(self, session_id: str, limit: int = 20) -> list[Any]:
        """
        Agno interface: load the latest messages as Message objects, oldest first.

        Reuses ``get_chat_history`` and maps ``is_human`` to role "user"/"assistant".

        Args:
            session_id: Stored in the ``user_id`` column.
            limit: Maximum number of (most recent) messages to load.

        Returns:
            List of Message objects; empty on error.
        """
        try:
            history_dicts = await self.get_chat_history(user_id=session_id, limit=limit)

            # Rows arrive newest first; reverse for chronological order.
            # NOTE(review): created_at receives the DB datetime; if the real Agno Message
            # expects epoch seconds, construction may fail here — confirm.
            messages = [
                Message(
                    role="user" if h["is_human"] else "assistant",
                    content=h["message"],
                    created_at=h["timestamp"],
                )
                for h in reversed(history_dicts)
            ]

            logger.debug(f"📥 [Agno] Loaded {len(messages)} messages for session {session_id}")
            return messages
        except Exception as e:
            logger.error(f"❌ [Agno] Error loading history for {session_id}: {e}", exc_info=True)
            return []

    async def save_message(self, session_id: str, message: Any):
        """
        Agno interface: save a single message.

        Args:
            session_id: Stored in the ``user_id`` column.
            message: Message-like object with ``role``, ``content`` and ``created_at``.
                Only role "user" is stored as human; a missing ``created_at`` uses now.

        Raises:
            Exception: Database errors are logged and re-raised.
        """
        try:
            pool = await self._get_pool()
            row = (
                session_id,
                message.content,
                message.role == "user",
                self._message_timestamp(message.created_at, datetime.now()),
            )

            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute(
                        sql.SQL("""
                            INSERT INTO {} (user_id, message, is_human, timestamp)
                            VALUES (%s, %s, %s, %s)
                        """).format(sql.Identifier(self.table_name)),
                        row,
                    )
                await conn.commit()
            logger.debug(f"💾 [Agno] Saved message for session {session_id}")
        except Exception as e:
            logger.error(f"❌ [Agno] Error saving message for {session_id}: {e}", exc_info=True)
            raise

    async def save_session(self, session_id: str, messages: list[Any]):
        """
        Agno interface: save several messages in one transaction.

        Args:
            session_id: Stored in the ``user_id`` column.
            messages: Message-like objects (``role``, ``content``, ``created_at``).
                Messages without ``created_at`` share one timestamp taken at call time.
                An empty list is a no-op.

        Raises:
            Exception: Database errors are logged and re-raised.
        """
        if not messages:
            # Nothing to insert; don't check out a connection.
            return
        try:
            now = datetime.now()
            # Build the rows before acquiring a connection to keep it held briefly.
            rows = [
                (
                    session_id,
                    msg.content,
                    msg.role == "user",
                    self._message_timestamp(msg.created_at, now),
                )
                for msg in messages
            ]

            pool = await self._get_pool()
            async with pool.connection() as conn:
                async with conn.cursor() as cursor:
                    await cursor.executemany(
                        sql.SQL("""
                            INSERT INTO {} (user_id, message, is_human, timestamp)
                            VALUES (%s, %s, %s, %s)
                        """).format(sql.Identifier(self.table_name)),
                        rows,
                    )
                await conn.commit()
            logger.debug(f"💾 [Agno] Saved {len(messages)} messages for session {session_id}")
        except Exception as e:
            logger.error(f"❌ [Agno] Error saving session for {session_id}: {e}", exc_info=True)
            raise

    async def get_session_messages(self, session_id: str) -> list[Any]:
        """Agno interface: get a session's messages, oldest first (capped at the latest 1000)."""
        return await self.load_history(session_id, limit=1000)

    async def clear_session(self, session_id: str):
        """Agno interface: clear a session (alias of ``clear_history``)."""
        return await self.clear_history(session_id)

    def get_session(self, session_id: str, session_type: str = "default"):
        """
        Agno interface: get session data (sync method, called synchronously by Agno).

        Returns a fresh SessionData carrying the ``metadata`` and ``session_data``
        attributes Agno reads, or None if construction fails.
        """
        try:
            session = SessionData(
                session_id=session_id,
                metadata=None,
                session_data={"session_type": session_type, "created_at": datetime.now()},
                manager=self,
            )
            logger.debug(f"📋 [Agno] Get session: {session_id}")
            return session
        except Exception as e:
            logger.error(f"❌ [Agno] Error getting session {session_id}: {e}")
            return None

    def upsert_session(self, session: Any):
        """
        Agno interface: save/update a session (sync method, called synchronously by Agno).

        Placeholder: nothing is persisted here, because messages are written via
        ``save_message`` / ``save_session``. Only validates that a session_id exists.

        Args:
            session: SessionData object or dict with a 'session_id' key.

        Returns:
            True if a session_id was found, False otherwise.
        """
        try:
            if isinstance(session, SessionData):
                session_id = session.session_id
            else:
                session_id = session.get("session_id") if isinstance(session, dict) else None

            if not session_id:
                logger.error("❌ [Agno] upsert_session: session_id is required")
                return False

            logger.debug(f"💾 [Agno] Upserted session {session_id}")
            return True
        except Exception as e:
            logger.error(f"❌ [Agno] Error upserting session: {e}", exc_info=True)
            return False


# ConversationManager implements BaseDb interface methods
# but doesn't inherit BaseDb to avoid implementing all abstract methods
# Agno will accept it as long as it has the required methods


# --- Singleton ---
_instance: ConversationManager | None = None
# Guards first-time creation so concurrent callers share one fully initialized manager.
_instance_lock = asyncio.Lock()


async def get_conversation_manager() -> ConversationManager:
    """Return the module-level ConversationManager, creating it on first use.

    The manager is published only after ``initialize_table()`` succeeds, so no
    caller receives a half-initialized instance. On failure, the pool opened
    during initialization is closed, the error is logged and re-raised, and the
    next call retries from scratch.
    """
    global _instance
    if _instance is not None:
        return _instance

    async with _instance_lock:
        # Re-check: another coroutine may have finished initialization while we waited.
        if _instance is None:
            manager: ConversationManager | None = None
            try:
                manager = ConversationManager()
                await manager.initialize_table()
            except Exception as e:
                logger.error(f"❌ Failed to initialize ConversationManager: {e}")
                if manager is not None:
                    # initialize_table() may already have opened the pool; release it.
                    try:
                        await manager.close()
                    except Exception:
                        logger.warning("Failed to close ConversationManager after init error", exc_info=True)
                raise
            _instance = manager
        return _instance
