import onnxruntime as ort
import numpy as np
import cv2
from PIL import Image
from loguru import logger

class ClothingDetector:
    """Detect clothing items with a YOLO ONNX model and crop them out of an image.

    If the model file is missing at construction time, the detector degrades
    gracefully: ``detect_and_crop`` then returns the input image unchanged as the
    single "crop".
    """

    # Side length of the square image fed to the model; boxes are predicted in
    # this coordinate space and scaled back to the original image size.
    INPUT_SIZE = 640

    def __init__(self, model_path="models/yolow-l_0_05_nms_0_3_v2.onnx"):
        """Load the ONNX model at *model_path* using the CPU execution provider.

        Args:
            model_path: Path to the ``.onnx`` file. If it does not exist, a warning
                is logged and ``self.session`` is left as ``None``.
        """
        self.model_path = model_path
        if not os.path.exists(model_path):
            logger.warning(f"YOLO model not found at {model_path}. Please place it there.")
            self.session = None
        else:
            logger.info(f"Loading YOLO model from {model_path}...")
            self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
            logger.info("YOLO loaded.")

    def detect_and_crop(self, pil_image: Image.Image, score_threshold=0.2):
        """
        Detect clothing items and return a list of cropped PIL images.

        Args:
            pil_image: Source image. It is converted to RGB first, so grayscale
                ("L"), palette ("P") and RGBA inputs are accepted.
            score_threshold: A detection is kept only if its score is strictly
                greater than this value.

        Returns:
            One RGB crop per kept detection, in model output order. If no model
            is loaded, or no detection yields a non-empty crop, returns
            ``[pil_image]``.
        """
        if self.session is None:
            return [pil_image]  # Fallback to original if no model

        # 1. Preprocess. Normalize the mode first: np.array() of an "L"/"P"
        # image is 2-D and would make the 3-channel colour conversion fail.
        orig_image = cv2.cvtColor(np.array(pil_image.convert("RGB")), cv2.COLOR_RGB2BGR)
        orig_h, orig_w = orig_image.shape[:2]

        size = self.INPUT_SIZE
        resized_image = cv2.resize(orig_image, (size, size))
        # HWC uint8 in [0, 255] -> NCHW float32 in [0, 1] with a batch dim of 1.
        inference_input = resized_image.astype(np.float32) / 255.0
        inference_input = np.transpose(inference_input, (2, 0, 1))
        inference_input = np.expand_dims(inference_input, axis=0)

        # 2. Inference
        input_name = self.session.get_inputs()[0].name
        output_names = [o.name for o in self.session.get_outputs()]
        outputs = self.session.run(output_names, {input_name: inference_input})

        # NOTE(review): assumed output layout for this specific export - confirm:
        # outputs[1][0] -> boxes [x_min, y_min, x_max, y_max] in INPUT_SIZE space,
        # outputs[2][0] -> scores (class ids are not used here).
        bboxes = outputs[1][0]
        scores = outputs[2][0]

        crops = []
        for box, score in zip(bboxes, scores):
            if score > score_threshold:
                x1, y1, x2, y2 = self._to_pixel_box(box, orig_w, orig_h)
                crop_cv2 = orig_image[y1:y2, x1:x2]
                # Degenerate boxes (x2 <= x1 or y2 <= y1) give an empty slice.
                if crop_cv2.size > 0:
                    crop_pil = Image.fromarray(cv2.cvtColor(crop_cv2, cv2.COLOR_BGR2RGB))
                    crops.append(crop_pil)

        # If nothing detected, return original image as a single crop
        return crops if crops else [pil_image]

    def _to_pixel_box(self, box, orig_w, orig_h):
        """Scale a model-space box to integer pixel bounds clipped to the image.

        Each coordinate is scaled from ``INPUT_SIZE`` space to the original size
        and truncated with ``int()``. It is then clamped to ``[0, orig_w]`` or
        ``[0, orig_h]``. Without the clamp, a slightly negative coordinate would
        be read by NumPy slicing as an offset from the far edge, producing a
        wrong (usually empty) crop.
        """
        size = self.INPUT_SIZE
        x1, y1, x2, y2 = box
        x1 = min(max(int(x1 * orig_w / size), 0), orig_w)
        y1 = min(max(int(y1 * orig_h / size), 0), orig_h)
        x2 = min(max(int(x2 * orig_w / size), 0), orig_w)
        y2 = min(max(int(y2 * orig_h / size), 0), orig_h)
        return x1, y1, x2, y2
import os
