import logging
import psycopg2
from typing import List, Dict, Optional, Any
from datetime import datetime
from config import CHECKPOINT_POSTGRES_URL

logger = logging.getLogger(__name__)

class DatabaseConnection:
    """Lazily opened, reusable wrapper around a single psycopg2 connection.

    At most one connection is cached on ``self.conn``. ``connect`` hands back
    that cached connection while it is usable and calls ``psycopg2.connect``
    again when there is none yet or its ``closed`` attribute is truthy.
    """

    def __init__(self, connection_url: str):
        # Connection string passed unchanged to psycopg2.connect.
        self.connection_url = connection_url
        # Cached connection; stays None until the first connect() call.
        self.conn = None

    def connect(self):
        """Return the cached connection, opening a new one if absent or closed."""
        if self.conn and not self.conn.closed:
            return self.conn
        self.conn = psycopg2.connect(self.connection_url)
        return self.conn

    def close(self):
        """Close the cached connection if it is open; otherwise do nothing."""
        current = self.conn
        if not current or current.closed:
            return
        current.close()

class ConversationManager:
    """Store and query per-user chat messages in a PostgreSQL table.

    Every public method runs its statements on the single connection owned by
    ``self.db`` inside ``with conn``. That block commits when it completes
    normally and rolls back when it raises, so a failed statement never
    leaves the shared connection stuck in an aborted transaction. It also
    means read-only calls do not leave a transaction open.
    """

    def __init__(self, connection_url: str = CHECKPOINT_POSTGRES_URL, table_name: str = "langgraph_chat_histories"):
        """
        Args:
            connection_url: psycopg2 connection string; defaults to
                ``config.CHECKPOINT_POSTGRES_URL``.
            table_name: Chat-history table name. It is interpolated directly
                into SQL text (identifiers cannot be bound as query
                parameters), so it must come from trusted configuration.
        """
        self.db = DatabaseConnection(connection_url)
        self.table_name = table_name

    def initialize_table(self):
        """Create the chat history table and its (user_id, timestamp) index if missing.

        Raises:
            Exception: any connection or SQL error, after logging it and
                rolling back the transaction.
        """
        try:
            conn = self.db.connect()
            # `with conn` commits on success / rolls back on error.
            with conn, conn.cursor() as cursor:
                cursor.execute(f"""
                    CREATE TABLE IF NOT EXISTS {self.table_name} (
                        id SERIAL PRIMARY KEY,
                        user_id VARCHAR(255) NOT NULL,
                        message TEXT NOT NULL,
                        is_human BOOLEAN NOT NULL,
                        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    )
                """)

                # Supports the per-user lookups below.
                cursor.execute(f"""
                    CREATE INDEX IF NOT EXISTS idx_{self.table_name}_user_timestamp 
                    ON {self.table_name} (user_id, timestamp)
                """)
            logger.info("✅ Table %s initialized successfully", self.table_name)
        except Exception as e:
            logger.error("Error initializing table: %s", e)
            raise

    def save_message(self, user_id: str, message: str, is_human: bool):
        """Insert one chat message for ``user_id``.

        The row's ``timestamp`` is set explicitly from ``datetime.now()``
        (a naive local time). NOTE(review): this differs from the column
        default ``CURRENT_TIMESTAMP``; confirm which clock is intended.

        Best-effort: any error is logged and swallowed. The failed transaction
        is rolled back, so later calls on the shared connection still work.
        """
        try:
            conn = self.db.connect()
            with conn, conn.cursor() as cursor:
                cursor.execute(
                    f"""INSERT INTO {self.table_name} (user_id, message, is_human, timestamp) 
                       VALUES (%s, %s, %s, %s)""",
                    (user_id, message, is_human, datetime.now())
                )
            logger.debug("💾 Saved message for user %s: %s...", user_id, message[:50])
        except Exception as e:
            logger.error("Error saving message: %s", e)

    def get_chat_history(self, user_id: str, limit: Optional[int] = None, before_id: Optional[int] = None) -> List[Dict[str, Any]]:
        """Return a page of ``user_id``'s messages, newest (highest id) first.

        Uses keyset pagination on ``id``: pass the smallest ``id`` of the
        previous page as ``before_id`` to fetch the next, older page.

        Args:
            user_id: Owner of the messages.
            limit: Maximum number of rows; ``None`` or ``0`` means no limit.
            before_id: Only rows with ``id < before_id`` are returned;
                ``None`` or ``0`` disables this filter.

        Returns:
            A list of dicts with keys ``message``, ``is_human``, ``timestamp``
            and ``id``, ordered by ``id`` descending. On any database error
            the error is logged and ``[]`` is returned.
        """
        query = f"""
                SELECT message, is_human, timestamp, id
                FROM {self.table_name} 
                WHERE user_id = %s
            """
        params: List[Any] = [user_id]

        # Truthiness (not `is not None`) is kept on purpose: some callers
        # presumably pass 0 to mean "no cursor" / "no limit". Verify callers
        # before tightening these checks.
        if before_id:
            query += " AND id < %s"
            params.append(before_id)

        # Ordering by the same column used as the cursor keeps pages consistent.
        query += " ORDER BY id DESC"

        if limit:
            query += " LIMIT %s"
            params.append(limit)

        try:
            conn = self.db.connect()
            # Even a SELECT opens a transaction in psycopg2. `with conn` ends it
            # instead of leaving the connection idle in transaction.
            with conn, conn.cursor() as cursor:
                cursor.execute(query, tuple(params))
                rows = cursor.fetchall()
        except Exception as e:
            logger.error("Error retrieving chat history: %s", e)
            return []

        return [
            {
                'message': message,
                'is_human': is_human,
                'timestamp': timestamp,
                'id': row_id,
            }
            for message, is_human, timestamp, row_id in rows
        ]

    def clear_history(self, user_id: str):
        """Delete every stored message for ``user_id``.

        Best-effort: errors are logged and swallowed after the transaction is
        rolled back.
        """
        try:
            conn = self.db.connect()
            with conn, conn.cursor() as cursor:
                cursor.execute(f"DELETE FROM {self.table_name} WHERE user_id = %s", (user_id,))
            logger.info("🗑️ Cleared chat history for user %s", user_id)
        except Exception as e:
            logger.error("Error clearing chat history: %s", e)

    def get_user_count(self) -> int:
        """Return the number of distinct ``user_id`` values in the table.

        Returns 0 if the query yields no row or if any database error occurs
        (the error is logged).
        """
        try:
            conn = self.db.connect()
            with conn, conn.cursor() as cursor:
                cursor.execute(f"SELECT COUNT(DISTINCT user_id) FROM {self.table_name}")
                result = cursor.fetchone()
        except Exception as e:
            logger.error("Error getting user count: %s", e)
            return 0
        return result[0] if result else 0

# --- Singleton ---
_instance: Optional[ConversationManager] = None

def get_conversation_manager() -> ConversationManager:
    """Return the process-wide ConversationManager, creating it on first use.

    On first call this constructs a ``ConversationManager`` with its default
    arguments and runs ``initialize_table()``. The instance is cached only
    after initialization succeeds. If initialization raises, the exception
    propagates and the next call retries instead of returning a manager whose
    table setup failed.

    NOTE(review): no locking is done; concurrent first calls could each
    build a manager. Confirm whether callers are single-threaded.
    """
    global _instance
    if _instance is None:
        manager = ConversationManager()
        manager.initialize_table()
        # Publish only after successful initialization (see docstring).
        _instance = manager
    return _instance
