"""
Test Script: Kiểm tra Vector Index - SIÊU GỌN 
Chỉ đo thời gian từ lúc có full query đến khi nhận kết quả
"""
import sys
import os
import asyncio
import time

# Add parent dir to path
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from common.embedding_service import create_embedding_async
from common.starrocks_connection import StarRocksConnection


def _build_vector_query(query_vector):
    """Return the ANN similarity SQL for ``query_vector``.

    The vector is inlined as a ``[v1,v2,...]`` literal (each element via
    ``str``). The ``SET_VAR`` hint sets ``ann_params`` to ``{"ef_search":64}``.
    Rows are ordered by ``approx_cosine_similarity`` score, highest first, and
    capped at 50.

    NOTE(review): values are inlined, not bound as parameters. This assumes the
    vector comes from the embedding service (numbers), never from user text.
    """
    v_str = "[" + ",".join(str(v) for v in query_vector) + "]"
    return f"""
    SELECT /*+ SET_VAR(ann_params='{{"ef_search":64}}') */
        internal_ref_code,
        approx_cosine_similarity(vector, {v_str}) as score
    FROM shared_source.magento_product_dimension_with_text_embedding__tmp
    ORDER BY score DESC
    LIMIT 50
    """


def _plan_uses_vector_index(plan_rows) -> bool:
    """Return True if any EXPLAIN row's text contains ``VECTORINDEX: ON``."""
    return any("VECTORINDEX: ON" in str(row) for row in plan_rows)


def _format_summary(durations_ms) -> str:
    """Format avg/min/max of a non-empty sequence of timings in milliseconds."""
    avg_ms = sum(durations_ms) / len(durations_ms)
    return (
        f"📈 AVG: {avg_ms:.2f}ms | MIN: {min(durations_ms):.2f}ms"
        f" | MAX: {max(durations_ms):.2f}ms"
    )


async def main(test_query: str = "áo sơ mi", iterations: int = 10):
    """Benchmark one vector-similarity query against StarRocks.

    Steps:
    1. Embed ``test_query``.
    2. Print the query's EXPLAIN plan and whether it reports ``VECTORINDEX: ON``.
    3. Time a ``SELECT 1`` round trip as a network baseline.
    4. Run the ANN query ``iterations`` times, printing each run's wall-clock
       time (execute plus fetching every row) and an avg/min/max summary.

    Database errors are printed rather than raised, because this is a
    diagnostic script. The connection is always closed afterwards.

    Args:
        test_query: Text to embed and search for.
        iterations: Number of timed query executions; must be >= 1.

    Raises:
        ValueError: If ``iterations`` is less than 1.
    """
    if iterations < 1:
        raise ValueError(f"iterations must be >= 1, got {iterations}")

    print("="*60)
    print("🔍 TEST VECTOR INDEX - PERFORMANCE CHECK")
    print("="*60)

    # 1. Prepare the data: embed the query text.
    print(f" Creating embedding for: '{test_query}'...")
    query_vector = await create_embedding_async(test_query)

    # 2. Build the full SQL string, with the vector inlined and the ANN hint.
    sql = _build_vector_query(query_vector)

    # 3. Connect to the DB.
    db = StarRocksConnection()
    conn = db.connect()

    try:
        # Inside the try so a failed ping is reported and the connection still
        # gets closed. NOTE(review): `reconnect=True` assumes a PyMySQL-style
        # connection object.
        conn.ping(reconnect=True)
        with conn.cursor() as cursor:
            # --- Inspect the query plan ---
            print("🔬 Đang kiểm tra Query Plan (XEM FULL)...")
            cursor.execute(f"EXPLAIN {sql}")
            plan = cursor.fetchall()
            for row in plan:
                print(f"  > {row}")  # print the full plan for manual inspection
            is_vector_on = _plan_uses_vector_index(plan)

            print(f"\n📊 Vector Index Status: {'ON ✅' if is_vector_on else 'OFF ❌'}")

            # Baseline: round-trip latency of a trivial query.
            t_base_0 = time.perf_counter()
            cursor.execute("SELECT 1")
            cursor.fetchone()
            baseline_ms = (time.perf_counter() - t_base_0) * 1000
            print(f"📡 Network Baseline: {baseline_ms:.2f} ms")

            print(f"\n🚀 ĐANG GỬI TRUY VẤN ({iterations} vòng lặp)...")

            durations = []
            for i in range(1, iterations + 1):
                start_time = time.perf_counter()
                cursor.execute(sql)
                # Drain the result set so transfer time is part of the measurement.
                cursor.fetchall()
                duration_ms = (time.perf_counter() - start_time) * 1000
                durations.append(duration_ms)

                # 80 ms is this script's fixed "fast" marker threshold.
                status_icon = "⚡" if duration_ms < 80 else "🐢"
                print(f"🔄 Lần {i:02d}: {duration_ms:7.2f} ms {status_icon}")

            print("-" * 30)
            print(_format_summary(durations))
            print("-" * 30 + "\n")

    except Exception as e:
        # Best-effort diagnostics: report the failure (with its type, since
        # str(e) can be empty) instead of crashing.
        print(f"\n❌ Lỗi: {type(e).__name__}: {e}\n")
    finally:
        # Always release the connection. Closing can itself raise (for example
        # if it is already closed), so cleanup stays best-effort.
        try:
            conn.close()
        except Exception:
            pass

    print("="*60 + "\n")


if __name__ == "__main__":
    # Script entry point: build the benchmark coroutine and drive it to
    # completion on a fresh event loop.
    _coro = main()
    asyncio.run(_coro)
