import faiss
import numpy as np
import pickle
import os
from loguru import logger

class FaissManager:
    """Flat L2 FAISS index paired with a positional list of SKUs.

    The i-th vector added to the index corresponds to ``sku_mapping[i]``;
    the two are persisted side by side (index file + pickled list) and must
    always be kept the same length.
    """

    def __init__(self, dimension=768, index_path="models/vector_db.index", mapping_path="models/sku_mapping.pkl"):
        """
        Load the index from ``index_path`` if that file exists, otherwise
        create an empty ``IndexFlatL2`` of the given ``dimension``.
        """
        self.dimension = dimension
        self.index_path = index_path
        self.mapping_path = mapping_path
        self.index = None
        self.sku_mapping = []  # List of SKUs, index matches FAISS ID

        if os.path.exists(index_path):
            self.load()
        else:
            logger.info("Initializing new FAISS index...")
            self.index = faiss.IndexFlatL2(dimension)

    def add_vectors(self, vectors: np.ndarray, skus: list):
        """
        Add image vectors and their corresponding SKUs to the index.

        ``vectors`` must be a 2-D array of shape ``(n, self.dimension)`` and
        ``skus`` must contain exactly ``n`` entries, one per row.

        Raises:
            ValueError: if ``vectors`` is not 2-D, has the wrong width, or
                the number of SKUs does not match the number of rows.
        """
        if vectors.ndim != 2:
            raise ValueError(
                f"Expected a 2-D array of shape (n, {self.dimension}), got shape {vectors.shape}"
            )
        if vectors.shape[1] != self.dimension:
            raise ValueError(f"Vector dimension mismatch. Expected {self.dimension}, got {vectors.shape[1]}")
        # Validate before touching the index: a count mismatch would shift
        # every later SKU relative to its FAISS id.
        if len(skus) != vectors.shape[0]:
            raise ValueError(
                f"Number of SKUs ({len(skus)}) does not match number of vectors ({vectors.shape[0]})"
            )

        # Hand FAISS a C-contiguous float32 buffer regardless of input layout.
        self.index.add(np.ascontiguousarray(vectors, dtype='float32'))
        self.sku_mapping.extend(skus)
        logger.info(f"Added {len(skus)} items. Total items: {self.index.ntotal}")

    def search(self, query_vector: np.ndarray, top_k=5):
        """
        Search for the most similar vectors in the index.

        Returns: distances, results (list of dicts with sku and distance).
        ``distances`` is the raw distance row returned by FAISS (``top_k``
        entries, including slots for which no neighbour was found), while
        ``results`` only contains the slots that matched a stored vector.
        For an empty index both values are empty lists.

        Raises:
            ValueError: if the query does not have the index's dimension.
        """
        if self.index.ntotal == 0:
            return [], []

        query = np.ascontiguousarray(query_vector.reshape(1, -1), dtype='float32')
        if query.shape[1] != self.index.d:
            raise ValueError(f"Vector dimension mismatch. Expected {self.index.d}, got {query.shape[1]}")

        distances, indices = self.index.search(query, top_k)

        results = []
        for dist, idx in zip(distances[0], indices[0]):
            if idx == -1:  # -1 means not found
                continue
            results.append({
                "sku": self.sku_mapping[int(idx)],
                "distance": float(dist)
            })

        return distances[0], results

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory containing ``path``, if the path names one."""
        parent = os.path.dirname(path)
        # dirname is "" for a bare filename; os.makedirs("") would raise.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def save(self):
        """Write the FAISS index to ``index_path`` and the SKU list to ``mapping_path``."""
        self._ensure_parent_dir(self.index_path)
        self._ensure_parent_dir(self.mapping_path)
        faiss.write_index(self.index, self.index_path)
        with open(self.mapping_path, "wb") as f:
            pickle.dump(self.sku_mapping, f)
        logger.info(f"Index saved to {self.index_path}")

    def load(self):
        """
        Read the FAISS index from ``index_path`` and the SKU list from
        ``mapping_path``, adopting the loaded index's dimension.
        """
        logger.info(f"Loading FAISS index from {self.index_path}...")
        self.index = faiss.read_index(self.index_path)
        # SECURITY: pickle.load can execute arbitrary code; only point
        # mapping_path at files this application wrote itself.
        with open(self.mapping_path, "rb") as f:
            self.sku_mapping = pickle.load(f)

        # The file on disk is the source of truth for dimensionality;
        # keep self.dimension in sync so add_vectors validates correctly.
        if self.index.d != self.dimension:
            logger.warning(
                f"Loaded index has dimension {self.index.d}, expected {self.dimension}; using {self.index.d}"
            )
            self.dimension = self.index.d
        if len(self.sku_mapping) != self.index.ntotal:
            logger.warning(
                f"SKU mapping has {len(self.sku_mapping)} entries but index holds {self.index.ntotal} vectors"
            )
        logger.info(f"Loaded {self.index.ntotal} items.")

if __name__ == "__main__":
    # Smoke test: index two random 4-d vectors, then query with the first one.
    # NOTE: the manager is built with its default file paths, so an index file
    # already present at the default location is loaded instead of starting empty.
    manager = FaissManager(dimension=4)
    sample_vectors = np.random.random((2, 4)).astype('float32')
    sample_skus = ["SKU1", "SKU2"]
    manager.add_vectors(sample_vectors, sample_skus)
    _, res = manager.search(sample_vectors[0], top_k=1)
    print(f"Search result: {res}")
