import os
import argparse
from PIL import Image
from tqdm import tqdm
import numpy as np
from loguru import logger

from core.detector import ClothingDetector
from core.encoder import ImageEncoder
from db.faiss_mgr import FaissManager

_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')


def _list_image_files(image_dir, extensions=_IMAGE_EXTENSIONS):
    """Return sorted names of regular files in ``image_dir`` ending in one of ``extensions`` (case-insensitive)."""
    with os.scandir(image_dir) as entries:
        # is_file() excludes directories that merely look like images
        # (e.g. a folder named "foo.jpg"); sorting makes runs reproducible
        # because directory listing order is arbitrary.
        return sorted(
            entry.name
            for entry in entries
            if entry.is_file() and entry.name.lower().endswith(extensions)
        )


def index_images(image_dir, index_path="models/vector_db.index", dimension=768):
    """Detect, embed and add every image in ``image_dir`` to a FAISS index.

    For each image file (matched by extension, non-recursive), the first crop
    returned by ``ClothingDetector.detect_and_crop`` is embedded with
    ``ImageEncoder.encode_image``. The result is added to the index under the
    filename stem, which is used as the SKU. Images with no detections are
    skipped. A failure on one image is logged with its traceback and does not
    abort the run. The index is saved once, after all images are processed.

    Args:
        image_dir: Directory to scan for images.
        index_path: Path passed to ``FaissManager`` and used for saving.
        dimension: Vector dimension passed to ``FaissManager``. Presumably it
            must equal the encoder's output size; 768 matches the previous
            hard-coded value.

    Raises:
        OSError: If ``image_dir`` cannot be listed (e.g. it does not exist).
    """
    # 1. Init components
    detector = ClothingDetector()
    encoder = ImageEncoder()
    faiss_mgr = FaissManager(dimension=dimension, index_path=index_path)

    # Make sure the output directory exists, so the final save does not fail
    # after a long indexing run.
    out_dir = os.path.dirname(index_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # 2. Get file list
    image_files = _list_image_files(image_dir)
    logger.info(f"Found {len(image_files)} images in {image_dir}")

    # 3. Process and Index
    indexed = skipped = failed = 0
    for filename in tqdm(image_files):
        path = os.path.join(image_dir, filename)
        try:
            # The context manager closes the underlying file handle.
            # convert() returns a new, fully loaded image that remains
            # usable after the `with` block exits.
            with Image.open(path) as img:
                image = img.convert("RGB")

            # Only the first detection is indexed, treated as the main product.
            # NOTE(review): whether crops[0] is the most confident detection
            # depends on ClothingDetector - confirm.
            crops = detector.detect_and_crop(image)
            if not crops:
                skipped += 1
                logger.debug(f"No detections in {filename}; skipping")
                continue

            embedding = encoder.encode_image(crops[0])

            # SKU is the filename without its extension.
            sku = os.path.splitext(filename)[0]

            # Wrap the single embedding in a batch of one.
            # NOTE(review): FAISS requires float32 - presumably the encoder
            # or FaissManager ensures that; confirm.
            faiss_mgr.add_vectors(np.array([embedding]), [sku])
            indexed += 1
        except Exception as e:
            # Best-effort per image: record the full traceback and move on.
            failed += 1
            logger.exception(f"Error processing {filename}: {e}")

    # 4. Save index
    faiss_mgr.save()
    logger.info(
        f"Indexed {indexed} images "
        f"({skipped} without detections, {failed} failed)"
    )
    logger.info("Indexing complete.")

if __name__ == "__main__":
    # Command-line entry point: read the image directory (required) and the
    # optional output path, then build and save the index.
    cli = argparse.ArgumentParser(description="Index images into FAISS database.")
    cli.add_argument(
        "--dir",
        type=str,
        required=True,
        help="Directory containing images to index.",
    )
    cli.add_argument(
        "--out",
        type=str,
        default="models/vector_db.index",
        help="Output index path.",
    )

    options = cli.parse_args()
    index_images(image_dir=options.dir, index_path=options.out)
