import torch
from open_clip import create_model_and_transforms, get_tokenizer
from PIL import Image
import numpy as np
from loguru import logger

class ImageEncoder:
    """Encode PIL images into unit-length embedding vectors with an open_clip model.

    The model and its preprocessing transform are loaded once in ``__init__``
    and reused by every call to :meth:`encode_image`.
    """

    # Length of the embedding vector. Unknown (None) until an image has been
    # encoded, because it depends on which model was loaded.
    _embedding_dim = None

    def __init__(self, model_name='hf-hub:Marqo/marqo-fashionSigLIP', device='cpu'):
        """Load the model and put it in evaluation mode.

        Args:
            model_name: Identifier passed to ``create_model_and_transforms``.
            device: Device string the model is created on and inputs are moved to.
        """
        self.device = device
        logger.info(f"Loading SigLIP model: {model_name} on {device}...")
        # create_model_and_transforms returns a 3-tuple; only the first
        # (model) and last (stored as the preprocessing transform) are used.
        self.model, _, self.preprocess = create_model_and_transforms(model_name, device=device)
        self.model.eval()
        logger.info("Model loaded successfully.")

    def encode_image(self, pil_image: "Image.Image") -> np.ndarray:
        """Return the L2-normalized embedding of a single image.

        Args:
            pil_image: Image to embed; it is run through ``self.preprocess``.

        Returns:
            A flattened 1-D NumPy array holding the output of the model's
            ``encode_image`` scaled to unit length. If the model returns an
            all-zero vector, zeros are returned rather than NaNs.

        Raises:
            Exception: Any error raised during preprocessing or inference is
                logged and re-raised unchanged.
        """
        try:
            # Add a leading batch axis of size 1 before moving to the device.
            batch = self.preprocess(pil_image).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.model.encode_image(batch)
                # Unit length so dot products act as cosine similarity.
                # normalize() clamps the norm away from zero, so a zero
                # vector does not turn into NaNs the way a plain division would.
                features = torch.nn.functional.normalize(features, dim=-1)
            embedding = features.cpu().numpy().flatten()
        except Exception as e:
            logger.error(f"Error encoding image: {e}")
            # Bare raise keeps the original traceback intact.
            raise
        # Remember the observed size so get_embedding_dimension() is accurate
        # for whatever model was loaded.
        self._embedding_dim = int(embedding.shape[0])
        return embedding

    def get_embedding_dimension(self) -> int:
        """Return the length of the vectors produced by :meth:`encode_image`.

        The value is taken from an actual embedding rather than assumed, so it
        stays correct when a non-default ``model_name`` is used. If no image
        has been encoded yet, a blank image is encoded once to find out; that
        first call therefore costs one forward pass.
        """
        if self._embedding_dim is None:
            # The preprocessing transform receives this blank image just like
            # any caller-supplied one; encode_image records the dimension.
            self.encode_image(Image.new('RGB', (224, 224)))
        return self._embedding_dim

if __name__ == "__main__":
    # Smoke test: embed a solid red 224x224 image with the default encoder
    # and print the resulting vector's shape and leading values.
    image_encoder = ImageEncoder()
    red_square = Image.new('RGB', (224, 224), color='red')
    vector = image_encoder.encode_image(red_square)
    head = vector[:5]
    print(f"Embedding shape: {vector.shape}")
    print(f"First 5 values: {head}")
