Commit 446d8f46 authored by Vũ Hoàng Anh

feat: add hybrid search with keyword filtering, update port to 5003

parent c9a437ef
Pipeline #3326 failed with stages
# Use Python 3.11 slim to keep the image small
FROM python:3.11-slim

# Working directory
WORKDIR /app

# Install required system dependencies (if any, e.g. build tools)
# RUN apt-get update && apt-get install -y gcc libpq-dev && rm -rf /var/lib/apt/lists/*
ENV PYTHONUNBUFFERED=1 PYTHONDONTWRITEBYTECODE=1

# Copy requirements.txt first to leverage the Docker layer cache
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy the full source code into the image
COPY . .

# Copy the entrypoint script
COPY entrypoint.sh /app/entrypoint.sh
RUN chmod +x /app/entrypoint.sh

# Expose port 5003 (server port)
EXPOSE 5003

# Run the server via the entrypoint
CMD ["/app/entrypoint.sh"]
"""
Fashion Q&A Agent Package
"""
from .graph import build_graph
from .models import AgentConfig, AgentState, get_config
__all__ = [
"AgentConfig",
"AgentState",
"build_graph",
"get_config",
]
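# Hedged usage sketch (not part of the commit): how the package API above is
# meant to be consumed, assuming the imports resolve at runtime. The model
# name is illustrative.
#
#   from agent import build_graph, get_config
#
#   config = get_config()
#   config.model_name = "gpt-4o-mini"   # illustrative override
#   graph = build_graph(config)         # compiled LangGraph, cached as a singleton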
"""
Fashion Q&A Agent Controller
Coordinates the Agent's execution flow and integrates ConversationManager (Postgres memory).
Switched to LangSmith for tracing (configured via environment variables).
"""
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from common.llm_factory import create_llm
from common.conversation_manager import get_conversation_manager, ConversationManager
from config import DEFAULT_MODEL
from .graph import build_graph
from .models import AgentState, get_config
from .tools.get_tools import get_all_tools
logger = logging.getLogger(__name__)
async def chat_controller(
query: str, user_id: str, model_name: str = DEFAULT_MODEL, conversation_id: str | None = None, images: list[str] | None = None
) -> AsyncGenerator[str, None]:
    # 1. Initialization & preparation (dependency injection)
logger.info(f"▶️ Starting chat_controller with model: {model_name} for user: {user_id}")
config = get_config()
config.model_name = model_name
    # Initialize resources - the factory picks the provider based on the model name
llm = create_llm(model_name=model_name, streaming=True)
tools = get_all_tools()
graph = build_graph(config, llm=llm, tools=tools)
# Init ConversationManager (Singleton)
memory = get_conversation_manager()
actual_conv_id = conversation_id or str(uuid.uuid4())
# LOAD HISTORY & Prepare State
# Get history from Postgres (returns list of dicts)
history_dicts = memory.get_chat_history(user_id, limit=10)
    # Convert to BaseMessage objects (the API returns newest first; reverse to chronological)
    history = []
    for h in reversed(history_dicts):
        if h["is_human"]:
            history.append(HumanMessage(content=h["message"]))
        else:
            history.append(AIMessage(content=h["message"]))
initial_state, exec_config = _prepare_execution_context(
query=query, user_id=user_id, actual_conv_id=actual_conv_id, history=history, images=images
)
final_ai_message = None
# 3. Stream Engine
try:
async for event in graph.astream(initial_state, config=exec_config, stream_mode="values"):
final_ai_message = _extract_last_ai_message(event) or final_ai_message
# Serialize messages to dicts to avoid "content='...'" string representation
if "messages" in event:
event["messages"] = [
m.dict() if hasattr(m, "dict") else m
for m in event["messages"]
]
yield f"data: {json.dumps(event, default=str, ensure_ascii=False)}\n\n"
        # 4. Post-processing (persist to DB)
_handle_post_chat(
memory=memory,
user_id=user_id,
human_query=query,
ai_msg=final_ai_message,
)
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"💥 Stream error: {e}", exc_info=True)
yield f"data: {json.dumps({'error': str(e)})}\n\n"
finally:
logger.info(f"✅ Request completed for conversation {actual_conv_id}")
def _prepare_execution_context(query: str, user_id: str, actual_conv_id: str, history: list, images: list | None):
"""Tách logic chuẩn bị state và config để giảm độ phức tạp."""
initial_state: AgentState = {
"messages": [HumanMessage(content=query)],
"history": history,
"user_id": user_id,
"images": [],
"thread_id": actual_conv_id,
"image_analysis": None,
}
run_id = str(uuid.uuid4())
# Metadata for LangSmith
metadata = {
"conversation_id": actual_conv_id,
"user_id": user_id,
"run_id": run_id
}
exec_config = RunnableConfig(
configurable={
"conversation_id": actual_conv_id,
"user_id": user_id,
"transient_images": images or [],
"run_id": run_id,
},
run_id=run_id,
metadata=metadata, # Attach metadata for LangSmith
)
return initial_state, exec_config
def _extract_last_ai_message(event: dict) -> AIMessage | None:
"""Trích xuất tin nhắn AI cuối cùng từ event stream."""
if event.get("messages"):
last_msg = event["messages"][-1]
if isinstance(last_msg, AIMessage):
return last_msg
return None
def _handle_post_chat(memory: ConversationManager, user_id: str, human_query: str, ai_msg: AIMessage | None):
"""Xử lý lưu history sau khi kết thúc stream. LangSmith tự động trace nên không cần update thủ công."""
if ai_msg:
# Save User Message
memory.save_message(user_id, human_query, True)
# Save AI Message
memory.save_message(user_id, ai_msg.content, False)
logger.info(f"💾 Saved conversation for user {user_id} to Postgres")
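# Hedged usage sketch (not part of the commit): draining the SSE generator
# directly, e.g. from a smoke test. Assumes Postgres and the LLM provider are
# reachable; the user id and query below are illustrative.
if __name__ == "__main__":
    import asyncio

    async def _demo_drain_stream() -> None:
        async for sse_line in chat_controller(query="Tìm áo polo nam", user_id="demo-user"):
            print(sse_line, end="")

    asyncio.run(_demo_drain_stream())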
"""
Fashion Q&A Agent Controller
Coordinates the Agent's execution flow and integrates ConversationManager (Postgres memory).
Using Langfuse @observe() decorator for automatic trace creation.
"""
import json
import logging
import uuid
from collections.abc import AsyncGenerator
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import RunnableConfig
from langfuse import get_client, observe, propagate_attributes
from langfuse.langchain import CallbackHandler
from common.llm_factory import create_llm
from common.conversation_manager import get_conversation_manager, ConversationManager
from .graph import build_graph
from .models import AgentState, get_config
from .tools.get_tools import get_all_tools
logger = logging.getLogger(__name__)
@observe(capture_input=False, capture_output=False)
async def chat_controller(
query: str, user_id: str, model_name: str, conversation_id: str | None = None, images: list[str] | None = None
) -> AsyncGenerator[str, None]:
    # 1. Initialization & preparation (dependency injection)
logger.info(f"▶️ Starting chat_controller for user: {user_id}")
config = get_config()
config.model_name = model_name
    # Initialize resources outside the graph so they are easy to test/mock
llm = create_llm(model_name=model_name, api_key=config.openai_api_key, streaming=True)
tools = get_all_tools()
graph = build_graph(config, llm=llm, tools=tools)
# Init ConversationManager (Singleton)
memory = get_conversation_manager()
actual_conv_id = conversation_id or str(uuid.uuid4())
    # 2. Run the main logic inside the trace context
with propagate_attributes(session_id=actual_conv_id, user_id=user_id, tags=["canifa", "chatbot"]):
# LOAD HISTORY & Prepare State
# Get history from Postgres (returns list of dicts)
history_dicts = memory.get_chat_history(user_id, limit=10)
        # Convert to BaseMessage objects (the API returns newest first; reverse to chronological)
        history = []
        for h in reversed(history_dicts):
            if h["is_human"]:
                history.append(HumanMessage(content=h["message"]))
            else:
                history.append(AIMessage(content=h["message"]))
initial_state, exec_config = _prepare_execution_context(
query=query, user_id=user_id, actual_conv_id=actual_conv_id, history=history, images=images
)
final_ai_message = None
# 3. Stream Engine
try:
async for event in graph.astream(initial_state, config=exec_config, stream_mode="values"):
final_ai_message = _extract_last_ai_message(event) or final_ai_message
# Serialize messages to dicts to avoid "content='...'" string representation
if "messages" in event:
event["messages"] = [
m.dict() if hasattr(m, "dict") else m
for m in event["messages"]
]
yield f"data: {json.dumps(event, default=str, ensure_ascii=False)}\n\n"
            # 4. Post-processing (persist to DB & trace)
_handle_post_chat(
memory=memory,
user_id=user_id,
human_query=query,
ai_msg=final_ai_message,
query=query,
images_count=len(images) if images else 0,
)
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"💥 Stream error: {e}", exc_info=True)
yield f"data: {json.dumps({'error': str(e)})}\n\n"
finally:
logger.info(f"✅ Request completed for conversation {actual_conv_id}")
def _prepare_execution_context(query: str, user_id: str, actual_conv_id: str, history: list, images: list | None):
"""Tách logic chuẩn bị state và config để giảm độ phức tạp."""
initial_state: AgentState = {
"messages": [HumanMessage(content=query)],
"history": history,
"user_id": user_id,
"images": [],
"thread_id": actual_conv_id,
"image_analysis": None,
}
run_id = str(uuid.uuid4())
exec_config = RunnableConfig(
configurable={
"conversation_id": actual_conv_id,
"user_id": user_id,
"transient_images": images or [],
"run_id": run_id,
},
run_id=run_id,
callbacks=[CallbackHandler()],
)
return initial_state, exec_config
def _extract_last_ai_message(event: dict) -> AIMessage | None:
"""Trích xuất tin nhắn AI cuối cùng từ event stream."""
if event.get("messages"):
last_msg = event["messages"][-1]
if isinstance(last_msg, AIMessage):
return last_msg
return None
def _handle_post_chat(memory: ConversationManager, user_id: str, human_query: str, ai_msg: AIMessage | None, query: str, images_count: int):
"""Xử lý lưu history và update trace sau khi kết thúc stream."""
if ai_msg:
# Save User Message
memory.save_message(user_id, human_query, True)
# Save AI Message
memory.save_message(user_id, ai_msg.content, False)
logger.info(f"💾 Saved conversation for user {user_id} to Postgres")
# Update trace
try:
langfuse = get_client()
langfuse.update_current_trace(
name="canifa-chatbot-query",
input={"query": query, "images_count": images_count},
output={"response": ai_msg.content if ai_msg else None},
)
except Exception as e:
logger.warning(f"Failed to update trace: {e}")
"""
Fashion Q&A Agent Graph
LangGraph workflow with a clean architecture.
All resources (LLM, tools) are initialized in __init__.
Uses ConversationManager (Postgres) to persist history instead of checkpoints.
"""
import logging
from typing import Any
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableConfig
from langgraph.cache.memory import InMemoryCache
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolNode
from langgraph.types import CachePolicy
from common.llm_factory import create_llm
from .models import AgentConfig, AgentState, get_config
from .prompt import get_system_prompt
from .tools.get_tools import get_all_tools, get_collection_tools
logger = logging.getLogger(__name__)
class CANIFAGraph:
"""
Fashion Q&A Agent Graph Manager.
    Initializes all resources in __init__; nodes access them via self.xxx.
"""
def __init__(
self,
config: AgentConfig | None = None,
llm: BaseChatModel | None = None,
tools: list | None = None,
):
self.config = config or get_config()
self._compiled_graph: Any | None = None
        # Dependency injection: prefer the llm/tools passed in
self.llm: BaseChatModel = llm or create_llm(
model_name=self.config.model_name, api_key=self.config.openai_api_key, streaming=True
)
        # Categorize tools
        self.all_tools = tools or get_all_tools()
        self.collection_tools = get_collection_tools()  # still needed to route by tool name
        # Retrieval tools are all tools that are not collection tools; binding the
        # full tool list to the retrieval ToolNode keeps tool execution robust.
        self.retrieval_tools = self.all_tools
        self.llm_with_tools = self.llm.bind_tools(self.all_tools)
self.system_prompt = get_system_prompt()
self.cache = InMemoryCache()
async def _agent_node(self, state: AgentState, config: RunnableConfig) -> dict:
"""Agent node - LLM reasoning với tools và history sạch."""
messages = state.get("messages", [])
history = state.get("history", [])
prompt = ChatPromptTemplate.from_messages(
[
("system", self.system_prompt),
MessagesPlaceholder(variable_name="history"), # Long-term clean history
MessagesPlaceholder(variable_name="messages"), # Current turn technical messages
]
)
        # 2. Image hint handling (taken from this run's config)
transient_images = config.get("configurable", {}).get("transient_images", [])
if transient_images and messages:
            # Image processing logic removed (per request)
pass
# Invoke LLM
chain = prompt | self.llm_with_tools
response = await chain.ainvoke({"messages": messages, "history": history})
return {"messages": [response]}
def _should_continue(self, state: AgentState) -> str:
"""Routing: tool nodes hoặc end."""
last_message = state["messages"][-1]
if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
logger.info("🏁 Agent finished")
return "end"
tool_names = [tc["name"] for tc in last_message.tool_calls]
collection_names = [t.name for t in self.collection_tools]
if any(name in collection_names for name in tool_names):
logger.info(f"🔄 → collect_tools: {tool_names}")
return "collect_tools"
logger.info(f"🔄 → retrieve_tools: {tool_names}")
return "retrieve_tools"
def build(self) -> Any:
"""Build và compile LangGraph workflow (Không dùng Checkpointer)."""
if self._compiled_graph is not None:
return self._compiled_graph
logger.info("🔨 Building LangGraph workflow (No Checkpointer)...")
workflow = StateGraph(AgentState)
# Nodes
workflow.add_node("agent", self._agent_node)
workflow.add_node("retrieve_tools", ToolNode(self.retrieval_tools), cache_policy=CachePolicy(ttl=3600))
workflow.add_node("collect_tools", ToolNode(self.collection_tools))
# Edges
workflow.set_entry_point("agent")
workflow.add_conditional_edges(
"agent",
self._should_continue,
{"retrieve_tools": "retrieve_tools", "collect_tools": "collect_tools", "end": END},
)
workflow.add_edge("retrieve_tools", "agent")
workflow.add_edge("collect_tools", "agent")
# Compile WITHOUT checkpointer
self._compiled_graph = workflow.compile(cache=self.cache)
        # ❌ Do NOT attach the Langfuse callback to the compiled graph
        # ✅ Pass the callback in the runtime config of each run instead
        logger.info("✅ Graph compiled (Langfuse callback will be per-run)")
return self._compiled_graph
@property
def graph(self) -> Any:
return self.build()
# --- Singleton & Public API ---
_instance: list[CANIFAGraph | None] = [None]
def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None) -> Any:
"""Get compiled graph (singleton)."""
if _instance[0] is None:
_instance[0] = CANIFAGraph(config, llm, tools)
return _instance[0].build()
def get_graph_manager(
config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None
) -> CANIFAGraph:
"""Get CANIFAGraph instance."""
if _instance[0] is None:
_instance[0] = CANIFAGraph(config, llm, tools)
return _instance[0]
def reset_graph() -> None:
"""Reset singleton for testing."""
_instance[0] = None
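# Hedged test sketch (not part of the commit): reset_graph() keeps the singleton
# from leaking between test cases. Assumes create_llm can build a model from the
# environment configuration.
if __name__ == "__main__":
    g1 = build_graph()
    assert g1 is build_graph()      # repeated calls return the same compiled graph
    reset_graph()
    assert build_graph() is not g1  # a fresh instance after the reset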
# """
# Simple Memory Manager for Fashion Q&A Agent
# Tối giản hóa: Chỉ sử dụng 1 Collection 'conversations' trong MongoDB
# """
# logger = logging.getLogger(__name__)
class SimpleMemoryManager:
"""
    Minimal memory manager: saves/loads messages from a single collection.
    Uses conversation_id as the primary identifier.
    LOGIC TEMPORARILY COMMENTED OUT FOR DRY-RUN TESTING
"""
def __init__(self):
# self.client = get_mongo_client()
# self.db = self.client[MONGODB_DB_NAME or "ai_law"]
        # self.collection = self.db["conversations"]  # minimal collection name
# self._indexes_created = False
pass
async def _ensure_indexes(self):
# if not self._indexes_created:
        #     # Index by conversation ID and update time
# await self.collection.create_index("_id")
# await self.collection.create_index("updated_at")
# self._indexes_created = True
pass
async def save_messages(
self,
conversation_id: str,
messages: list, # List[BaseMessage],
user_id: str | None = None, # Optional[str] = None
):
"""Lưu toàn bộ danh sách tin nhắn vào cuộc hội thoại."""
# try:
# await self._ensure_indexes()
# messages_dict = [self._message_to_dict(msg) for msg in messages]
# await self.collection.update_one(
# {"_id": conversation_id},
# {
# "$set": {
# "user_id": user_id,
# "messages": messages_dict,
# "updated_at": datetime.utcnow(),
# },
# "$setOnInsert": {
# "created_at": datetime.utcnow(),
# }
# },
# upsert=True
# )
# except Exception as e:
# # logger.error(f"❌ Save memory error: {e}")
# raise
pass
async def load_messages(self, conversation_id: str, limit: int = 20) -> list: # List[BaseMessage]:
"""Load N tin nhắn gần nhất của cuộc hội thoại."""
# try:
# await self._ensure_indexes()
# doc = await self.collection.find_one({"_id": conversation_id})
# if not doc or "messages" not in doc:
# return []
# msgs_dict = doc["messages"]
        #     # Only keep a limited number of messages to save tokens
# if limit:
# msgs_dict = msgs_dict[-limit:]
# return [self._dict_to_message(m) for m in msgs_dict]
# except Exception as e:
# # logger.error(f"❌ Load memory error: {e}")
# return []
return []
# def _message_to_dict(self, msg: BaseMessage) -> dict:
# return {
# "type": msg.__class__.__name__,
# "content": msg.content,
# "timestamp": datetime.utcnow().isoformat(),
# }
# def _dict_to_message(self, msg_dict: dict) -> BaseMessage:
# m_type = msg_dict.get("type", "HumanMessage")
# content = msg_dict.get("content", "")
# if m_type == "AIMessage": return AIMessage(content=content)
# if m_type == "SystemMessage": return SystemMessage(content=content)
# return HumanMessage(content=content)
# Singleton
_memory_manager = None
def get_memory_manager():
global _memory_manager
if _memory_manager is None:
_memory_manager = SimpleMemoryManager()
return _memory_manager
from typing import Annotated, Any, TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
from pydantic import BaseModel
import config as global_config
class QueryRequest(BaseModel):
"""API Request model cho Fashion Q&A Chat"""
query: str
history: list[BaseMessage] | None = None
model_name: str = global_config.DEFAULT_MODEL
user_id: str | None = None
images: list[str] | None = None
image_analysis: dict[str, Any] | None = None
conversation_id: str | None = None
class AgentState(TypedDict):
"""Trạng thái của Agent trong LangGraph."""
messages: Annotated[list[BaseMessage], add_messages]
    history: list[BaseMessage]  # Clean conversation history (Human + AI)
user_id: str | None
images: list[str] | None
image_analysis: dict[str, Any] | None
thread_id: str | None
class AgentConfig:
"""Class chứa cấu hình runtime cho Agent."""
def __init__(self, **kwargs):
self.model_name = kwargs.get("model_name") or global_config.DEFAULT_MODEL
self.openai_api_key = kwargs.get("openai_api_key")
self.google_api_key = kwargs.get("google_api_key")
self.groq_api_key = kwargs.get("groq_api_key")
self.supabase_url = kwargs.get("supabase_url")
self.supabase_key = kwargs.get("supabase_key")
self.langfuse_public_key = kwargs.get("langfuse_public_key")
self.langfuse_secret_key = kwargs.get("langfuse_secret_key")
self.langfuse_base_url = kwargs.get("langfuse_base_url")
def get_config() -> AgentConfig:
"""Khởi tạo cấu hình Agent từ các biến môi trường."""
return AgentConfig(
model_name=global_config.DEFAULT_MODEL,
openai_api_key=global_config.OPENAI_API_KEY,
google_api_key=global_config.GOOGLE_API_KEY,
groq_api_key=global_config.GROQ_API_KEY,
supabase_url=global_config.AI_SUPABASE_URL,
supabase_key=global_config.AI_SUPABASE_KEY,
langfuse_public_key=global_config.LANGFUSE_PUBLIC_KEY,
langfuse_secret_key=global_config.LANGFUSE_SECRET_KEY,
langfuse_base_url=global_config.LANGFUSE_BASE_URL,
)
"""
Agent Nodes Package
"""
from .agent import agent_node
__all__ = ["agent_node"]
"""
Tools Package
Exports tools and factory functions
"""
from .data_retrieval_tool import data_retrieval_tool
from .get_tools import get_all_tools
__all__ = ["data_retrieval_tool", "get_all_tools"]
"""
Tool for collecting customer information (name, phone number, email).
Used to push the data to a CRM or other customer storage system.
"""
import json
import logging
from langchain_core.tools import tool
logger = logging.getLogger(__name__)
@tool
async def collect_customer_info(name: str, phone: str, email: str | None = None) -> str:
"""
    Use this tool to record customer information when a customer wants deeper
    consultation, a promotion, or to place an order.
    Args:
        name: The customer's name
        phone: The customer's phone number
        email: The customer's email (optional)
"""
try:
print(f"\n[TOOL] --- 📝 Thu thập thông tin khách hàng: {name} - {phone} ---")
logger.info(f"📝 Collecting customer info: {name}, {phone}, {email}")
        # Simulate pushing the data out (CRM/Sheet)
        # In production you would call an API here
db_record = {
"customer_name": name,
"phone_number": phone,
"email_address": email,
"status": "pending_consultation",
}
        # Return a success result
return json.dumps(
{
"status": "success",
"message": (
f"Cảm ơn anh/chị {name}. CiCi đã ghi nhận thông tin và sẽ có nhân viên "
f"liên hệ tư vấn qua số điện thoại {phone} sớm nhất ạ!"
),
"data_captured": db_record,
},
ensure_ascii=False,
)
except Exception as e:
logger.error(f"❌ Lỗi khi thu thập thông tin: {e}")
return json.dumps(
{
"status": "error",
"message": f"Xin lỗi, CiCi gặp sự cố khi lưu thông tin. Anh/chị vui lòng thử lại sau ạ. Lỗi: {e!s}",
},
ensure_ascii=False,
)
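# Hedged smoke test (not part of the commit): LangChain @tool objects are invoked
# with a dict payload via .ainvoke(); the sample name/phone below are made up.
if __name__ == "__main__":
    import asyncio

    print(asyncio.run(collect_customer_info.ainvoke({"name": "Lan", "phone": "0900000000"})))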
"""
CANIFA Data Retrieval Tool - kept minimal for the agentic workflow.
Supports hybrid search: semantic (vector) search + metadata filters.
"""
import json
import logging
from decimal import Decimal
from langchain_core.tools import tool
from langsmith import traceable
from pydantic import BaseModel, Field
from agent.tools.product_search_helpers import build_starrocks_query
from common.starrocks_connection import StarRocksConnection
logger = logging.getLogger(__name__)
class DecimalEncoder(json.JSONEncoder):
"""Xử lý kiểu Decimal từ Database khi convert sang JSON."""
def default(self, obj):
if isinstance(obj, Decimal):
return float(obj)
return super().default(obj)
class SearchParams(BaseModel):
"""Cấu trúc tham số tìm kiếm mà Agent phải cung cấp, map trực tiếp với Database."""
query: str | None = Field(
None,
description="Câu hỏi/mục đích tự do của user (đi chơi, dự tiệc, phỏng vấn,...) - dùng cho Semantic Search",
)
keywords: str | None = Field(
None, description="Từ khóa kỹ thuật cụ thể (áo polo, quần jean,...) - dùng cho LIKE search"
)
internal_ref_code: str | None = Field(None, description="Mã sản phẩm (ví dụ: 1TS23S012)")
product_color_code: str | None = Field(None, description="Mã màu sản phẩm (ví dụ: 1TS23S012-SK010)")
product_line_vn: str | None = Field(None, description="Dòng sản phẩm (Áo phông, Quần short,...)")
gender_by_product: str | None = Field(None, description="Giới tính: male, female")
age_by_product: str | None = Field(None, description="Độ tuổi: adult, kids, baby, others")
master_color: str | None = Field(None, description="Màu sắc chính (Đen/ Black, Trắng/ White,...)")
material_group: str | None = Field(None, description="Nhóm chất liệu (Knit - Dệt Kim, Woven - Dệt Thoi,...)")
season: str | None = Field(None, description="Mùa (Spring Summer, Autumn Winter)")
style: str | None = Field(None, description="Phong cách (Basic Update, Fashion,...)")
fitting: str | None = Field(None, description="Form dáng (Regular, Slim, Loose,...)")
form_neckline: str | None = Field(None, description="Kiểu cổ (Crew Neck, V-neck,...)")
form_sleeve: str | None = Field(None, description="Kiểu tay (Short Sleeve, Long Sleeve,...)")
price_min: float | None = Field(None, description="Giá thấp nhất")
price_max: float | None = Field(None, description="Giá cao nhất")
action: str = Field("search", description="Hành động: 'search' (tìm kiếm) hoặc 'visual_search' (phân tích ảnh)")
@tool(args_schema=SearchParams)
@traceable(run_type="tool", name="data_retrieval_tool")
async def data_retrieval_tool(
action: str = "search",
query: str | None = None,
keywords: str | None = None,
internal_ref_code: str | None = None,
product_color_code: str | None = None,
product_line_vn: str | None = None,
gender_by_product: str | None = None,
age_by_product: str | None = None,
master_color: str | None = None,
material_group: str | None = None,
season: str | None = None,
style: str | None = None,
fitting: str | None = None,
form_neckline: str | None = None,
form_sleeve: str | None = None,
price_min: float | None = None,
price_max: float | None = None,
) -> str:
"""
    Search for products in the CANIFA database using semantic (vector) search, keywords, or attribute filters.
    How it works (hybrid search):
    - If 'query' is set: the system creates a vector embedding and searches by semantic similarity.
    - If 'keywords' or other attributes are set: the system builds SQL WHERE clauses to filter results exactly.
    - Combining both yields the best results.
    Examples:
    1. Search by general intent:
       User: "Tìm cho mình một bộ đồ đi biển mát mẻ"
       Tool call: data_retrieval_tool(query="bộ đồ đi biển mát mẻ", gender_by_product="female")
    2. Exact search by product type and price:
       User: "Áo polo nam dưới 400k"
       Tool call: data_retrieval_tool(keywords="áo polo", gender_by_product="male", price_max=400000)
    3. Look up a specific product code:
       User: "Check sản phẩm 8TS24W001"
       Tool call: data_retrieval_tool(internal_ref_code="8TS24W001")
    4. Combined deep search:
       User: "Áo khoác len mùa đông cho bé trai từ 200k đến 500k"
       Tool call: data_retrieval_tool(query="áo khoác len ấm áp", material_group="Len", age_by_product="kids", price_min=200000, price_max=500000)
"""
try:
        # 1. Log & prepare params
params = SearchParams(
action=action,
query=query,
keywords=keywords,
internal_ref_code=internal_ref_code,
product_color_code=product_color_code,
product_line_vn=product_line_vn,
gender_by_product=gender_by_product,
age_by_product=age_by_product,
master_color=master_color,
material_group=material_group,
season=season,
style=style,
fitting=fitting,
form_neckline=form_neckline,
form_sleeve=form_sleeve,
price_min=price_min,
price_max=price_max,
)
_log_agent_call(params)
# 2. Prepare Vector (Async) if needed
query_vector = None
if query:
from common.embedding_service import create_embedding_async
query_vector = await create_embedding_async(query)
# 3. Execute Search (Async)
sql = build_starrocks_query(params, query_vector=query_vector)
db = StarRocksConnection()
products = await db.execute_query_async(sql)
if not products:
return _handle_no_results(query, keywords)
# 4. Format Results
clean_products = _format_product_results(products)
return json.dumps(
{"status": "success", "count": len(clean_products), "products": clean_products},
ensure_ascii=False,
cls=DecimalEncoder,
)
except Exception as e:
logger.error(f"Error in data_retrieval_tool: {e}")
return json.dumps({"status": "error", "message": str(e)})
def _log_agent_call(params: SearchParams):
"""Log parameters for debugging."""
filtered_params = {k: v for k, v in params.dict().items() if v is not None}
logger.info(f"📋 Agent Tool Call - data_retrieval_tool: {json.dumps(filtered_params, ensure_ascii=False)}")
def _handle_no_results(query: str | None, keywords: str | None) -> str:
"""Return standardized no-results message."""
logger.warning(f"No products found for search: query={query}, keywords={keywords}")
return json.dumps(
{
"status": "no_results",
"message": "Không tìm thấy sản phẩm nào phù hợp với yêu cầu. Vui lòng thử lại với từ khóa hoặc bộ lọc khác.",
},
ensure_ascii=False,
)
def _format_product_results(products: list[dict]) -> list[dict]:
"""Filter and format product fields for the agent."""
allowed_fields = {
"internal_ref_code",
"magento_ref_code",
"product_color_code",
"product_name",
"color_code",
"master_color",
"product_color_name",
"season_sale",
"season",
"style",
"fitting",
"size_scale",
"graphic",
"pattern",
"weaving",
"shape_detail",
"form_neckline",
"form_sleeve",
"form_length",
"form_waistline",
"form_shoulderline",
"material",
"product_group",
"product_line_vn",
"unit_of_measure",
"sale_price",
"original_price",
"material_group",
"product_line_en",
"age_by_product",
"gender_by_product",
"product_image_url",
"description_text",
"product_image_url_thumbnail",
"product_web_url",
"product_web_material",
}
return [{k: v for k, v in p.items() if k in allowed_fields} for p in products[:5]]
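# Hedged smoke test (not part of the commit): exercises the hybrid-search tool end
# to end. Requires a reachable StarRocks instance and embedding service; the
# filter values below are illustrative.
if __name__ == "__main__":
    import asyncio

    payload = {"keywords": "áo polo", "gender_by_product": "male", "price_max": 400000}
    print(asyncio.run(data_retrieval_tool.ainvoke(payload)))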
"""
Tools Factory
Groups and returns the tool lists exposed to the Agent
"""
from langchain_core.tools import Tool
from .customer_info_tool import collect_customer_info
from .data_retrieval_tool import data_retrieval_tool
def get_retrieval_tools() -> list[Tool]:
"""Các tool chỉ dùng để đọc/truy vấn dữ liệu (Có thể cache)"""
return [data_retrieval_tool]
def get_collection_tools() -> list[Tool]:
"""Các tool dùng để ghi/thu thập dữ liệu (KHÔNG cache)"""
return [collect_customer_info]
def get_all_tools() -> list[Tool]:
"""Return toàn bộ list tools cho Agent"""
return get_retrieval_tools() + get_collection_tools()
import logging
from common.embedding_service import create_embedding
logger = logging.getLogger(__name__)
def _escape(val: str) -> str:
"""Thoát dấu nháy đơn để tránh SQL Injection cơ bản."""
return val.replace("'", "''")
def _get_where_clauses(params) -> list[str]:
"""Xây dựng danh sách các điều kiện lọc từ params."""
clauses = []
clauses.extend(_get_price_clauses(params))
clauses.extend(_get_exact_match_clauses(params))
clauses.extend(_get_special_clauses(params))
return clauses
def _get_price_clauses(params) -> list[str]:
"""Lọc theo giá."""
clauses = []
p_min = getattr(params, "price_min", None)
if p_min is not None:
clauses.append(f"sale_price >= {p_min}")
p_max = getattr(params, "price_max", None)
if p_max is not None:
clauses.append(f"sale_price <= {p_max}")
return clauses
def _get_exact_match_clauses(params) -> list[str]:
"""Các trường lọc chính xác (Exact match)."""
clauses = []
exact_filters = [
("gender_by_product", "gender_by_product"),
("age_by_product", "age_by_product"),
("season", "season"),
("material_group", "material_group"),
("product_line_vn", "product_line_vn"),
("style", "style"),
("fitting", "fitting"),
("form_neckline", "form_neckline"),
("form_sleeve", "form_sleeve"),
]
for param_name, col_name in exact_filters:
val = getattr(params, param_name, None)
if val:
clauses.append(f"{col_name} = '{_escape(val)}'")
return clauses
def _get_special_clauses(params) -> list[str]:
"""Các trường hợp đặc biệt: Mã sản phẩm, Màu sắc."""
clauses = []
# Mã sản phẩm
ref_code = getattr(params, "internal_ref_code", None)
if ref_code:
r = _escape(ref_code)
clauses.append(f"(internal_ref_code = '{r}' OR magento_ref_code = '{r}')")
    # Color
color = getattr(params, "master_color", None)
if color:
c = _escape(color).lower()
clauses.append(f"(LOWER(master_color) LIKE '%{c}%' OR LOWER(product_color_name) LIKE '%{c}%')")
return clauses
def build_starrocks_query(params, query_vector: list[float] | None = None) -> str:
"""
    Build an optimized hybrid SQL query:
    1. Pre-filtering (metadata)
    2. Vector search (HNSW index)
    3. Grouping (group colors by style)
"""
    # --- Embed the query text if no vector was supplied ---
query_text = getattr(params, "query", None)
if query_text and query_vector is None:
query_vector = create_embedding(query_text)
# --- Build filter clauses ---
where_clauses = _get_where_clauses(params)
where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
# --- Build SQL ---
if query_vector and len(query_vector) > 0:
v_str = "[" + ",".join(str(v) for v in query_vector) + "]"
# OPTIMIZED: Only SELECT necessary fields in CTE, not SELECT *
sql = f"""
WITH top_sku_candidates AS (
SELECT
internal_ref_code,
product_name,
sale_price,
original_price,
master_color,
product_image_url,
product_image_url_thumbnail,
product_web_url,
description_text,
material,
material_group,
gender_by_product,
age_by_product,
season,
style,
fitting,
form_neckline,
form_sleeve,
product_line_vn,
product_color_name,
cosine_similarity(vector, {v_str}) as similarity_score
FROM shared_source.magento_product_dimension_with_text_embedding
WHERE {where_sql} AND vector IS NOT NULL
ORDER BY similarity_score DESC
LIMIT 50
)
SELECT
internal_ref_code,
ANY_VALUE(product_name) as product_name,
ANY_VALUE(sale_price) as sale_price,
ANY_VALUE(original_price) as original_price,
GROUP_CONCAT(DISTINCT master_color ORDER BY master_color SEPARATOR ', ') as available_colors,
ANY_VALUE(product_image_url) as product_image_url,
ANY_VALUE(product_image_url_thumbnail) as product_image_url_thumbnail,
ANY_VALUE(product_web_url) as product_web_url,
ANY_VALUE(description_text) as description_text,
ANY_VALUE(material) as material,
ANY_VALUE(material_group) as material_group,
ANY_VALUE(gender_by_product) as gender_by_product,
ANY_VALUE(age_by_product) as age_by_product,
ANY_VALUE(season) as season,
ANY_VALUE(style) as style,
ANY_VALUE(fitting) as fitting,
ANY_VALUE(form_neckline) as form_neckline,
ANY_VALUE(form_sleeve) as form_sleeve,
ANY_VALUE(product_line_vn) as product_line_vn,
MAX(similarity_score) as max_score
FROM top_sku_candidates
GROUP BY internal_ref_code
ORDER BY max_score DESC
LIMIT 5
""" # noqa: S608
else:
# FALLBACK: Keyword search - MAXIMALLY OPTIMIZED (No CTE overhead)
keywords = getattr(params, "keywords", None)
keyword_filter = ""
if keywords:
k = _escape(keywords).lower()
keyword_filter = f" AND LOWER(product_name) LIKE '%{k}%'"
# Direct query - No CTE needed, StarRocks optimizes GROUP BY internally
sql = f"""
SELECT
internal_ref_code,
ANY_VALUE(product_name) as product_name,
ANY_VALUE(sale_price) as sale_price,
ANY_VALUE(original_price) as original_price,
GROUP_CONCAT(DISTINCT master_color ORDER BY master_color SEPARATOR ', ') as available_colors,
ANY_VALUE(product_image_url) as product_image_url,
ANY_VALUE(product_image_url_thumbnail) as product_image_url_thumbnail,
ANY_VALUE(product_web_url) as product_web_url,
ANY_VALUE(description_text) as description_text,
ANY_VALUE(material) as material,
ANY_VALUE(material_group) as material_group,
ANY_VALUE(gender_by_product) as gender_by_product,
ANY_VALUE(age_by_product) as age_by_product,
ANY_VALUE(season) as season,
ANY_VALUE(style) as style,
ANY_VALUE(fitting) as fitting,
ANY_VALUE(form_neckline) as form_neckline,
ANY_VALUE(form_sleeve) as form_sleeve,
ANY_VALUE(product_line_vn) as product_line_vn
FROM shared_source.magento_product_dimension_with_text_embedding
WHERE {where_sql} {keyword_filter}
GROUP BY internal_ref_code
HAVING COUNT(*) > 0
ORDER BY sale_price ASC
LIMIT 5
""" # noqa: S608
logger.info(f"📊 Query Mode: {'Vector' if query_vector else 'Keyword'}")
return sql
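# Hedged offline sketch (not part of the commit): keyword-only params take the
# fallback branch, so the SQL builder can be inspected without the embedding
# service. SimpleNamespace stands in for SearchParams here.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo = SimpleNamespace(
        query=None, keywords="áo polo", internal_ref_code=None, master_color=None,
        gender_by_product="male", age_by_product=None, season=None, material_group=None,
        product_line_vn=None, style=None, fitting=None, form_neckline=None,
        form_sleeve=None, price_min=None, price_max=400000,
    )
    print(build_starrocks_query(demo))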
"""
Fashion Q&A Agent Router
FastAPI endpoints for the Fashion Q&A Agent service.
The router only defines the API; the logic lives in the controller.
"""
import json
import logging
from collections.abc import AsyncGenerator
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import StreamingResponse
from agent.controller import chat_controller
from agent.models import QueryRequest
from config import DEFAULT_MODEL
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/chat", summary="Fashion Q&A Chat (Non-streaming)")
async def fashion_qa_chat(req: QueryRequest, request: Request):
"""
    Non-streaming chat endpoint - returns the full JSON response in one shot.
"""
    # Extract user_id from the request (auth middleware)
user_id = getattr(request.state, "user_id", None) or req.user_id or "default_user"
logger.info(f"📥 [Incoming Query - NonStream] User: {user_id} | Query: {req.query}")
try:
        # Call the controller and receive the stream generator.
        # Note: chat_controller is an async generator - iterate it directly, no await needed.
generator: AsyncGenerator[str, None] = chat_controller(
query=req.query,
user_id=user_id,
model_name=DEFAULT_MODEL,
conversation_id=req.conversation_id,
images=req.images,
)
        # Collect all events from the generator
final_response = None
async for chunk in generator:
# Parse SSE data format
if chunk.startswith("data: "):
data_str = chunk[6:].strip()
if data_str != "[DONE]":
final_response = json.loads(data_str)
        # Return the final response
if final_response and "messages" in final_response:
last_message = final_response["messages"][-1]
response_text = last_message.get("content", "") if isinstance(last_message, dict) else str(last_message)
logger.info(f"📤 [Outgoing Response - NonStream] User: {user_id} | Response: {response_text}")
return {
"status": "success",
"response": response_text,
"conversation_id": req.conversation_id,
}
return {"status": "error", "message": "No response generated"}
except Exception as e:
logger.error(f"Error in fashion_qa_chat: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) from e
# ====================== FASHION Q&A CHAT API ======================
@router.post("/stream/chat", summary="Fashion Q&A Chat with Streaming Response")
async def fashion_qa_chat_stream(req: QueryRequest, request: Request):
"""
    The single endpoint for chatting with the Fashion Agent (streaming).
"""
    # Extract user_id from the request (auth middleware)
user_id = getattr(request.state, "user_id", None) or req.user_id or "default_user"
logger.info(f"📥 [Incoming Query] User: {user_id} | Query: {req.query}")
try:
        # Call the controller and receive the stream generator.
        # Note: chat_controller is an async generator - iterate it directly, no await needed.
generator: AsyncGenerator[str, None] = chat_controller(
query=req.query,
user_id=user_id,
model_name=DEFAULT_MODEL,
conversation_id=req.conversation_id,
images=req.images,
)
async def logging_generator(gen: AsyncGenerator[str, None]):
full_response_log = ""
first_chunk = True
try:
async for chunk in gen:
if first_chunk:
logger.info("🚀 [Stream Started] First chunk received")
first_chunk = False
full_response_log += chunk
yield chunk
except Exception as e:
logger.error(f"❌ [Stream Error] {e}")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
logger.info(f"📤 [Outgoing Response Stream Finished] Total Chunks Length: {len(full_response_log)}")
return StreamingResponse(
logging_generator(generator),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
except Exception as e:
logger.error(f"Error in fashion_qa_chat: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) from e
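# Hedged client sketch (not part of the commit): consuming the SSE stream with
# httpx. The host/port, route prefix, and payload are illustrative assumptions.
if __name__ == "__main__":
    import httpx

    body = {"query": "Tìm áo polo nam", "user_id": "demo-user"}
    with httpx.stream("POST", "http://localhost:5003/stream/chat", json=body, timeout=60) as resp:
        for line in resp.iter_lines():
            if line:
                print(line)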
from fastapi import APIRouter, HTTPException
from typing import List, Dict, Any, Optional
from pydantic import BaseModel
import logging
from common.conversation_manager import get_conversation_manager
router = APIRouter(tags=["Conservation"])
logger = logging.getLogger(__name__)
class ChatMessage(BaseModel):
id: int
    user_id: str | None = None  # Usually not needed in list responses, but kept for consistency
message: str
is_human: bool
timestamp: Any
class ChatHistoryResponse(BaseModel):
data: List[Dict[str, Any]]
next_cursor: Optional[int] = None
@router.get("/history/{user_id}", summary="Get Chat History by User ID", response_model=ChatHistoryResponse)
async def get_chat_history(user_id: str, limit: Optional[int] = 20, before_id: Optional[int] = None):
"""
    Fetch the user's chat history from the Postgres database.
    Returns an object with `data` (the message list) and `next_cursor` for the next page.
"""
try:
        # Use the ConversationManager singleton
manager = get_conversation_manager()
        # Fetch history from the DB with pagination
history = manager.get_chat_history(user_id, limit=limit, before_id=before_id)
next_cursor = None
if history and len(history) > 0:
            # The cursor for the next page is the ID of the last message (the oldest in this batch)
next_cursor = history[-1]['id']
return {
"data": history,
"next_cursor": next_cursor
}
except Exception as e:
logger.error(f"Error fetching chat history for user {user_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
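# Hedged pagination sketch (not part of the commit): walking the history with the
# before_id cursor until it is exhausted. Host/port, route prefix, and the user id
# are illustrative assumptions.
if __name__ == "__main__":
    import httpx

    cursor = None
    while True:
        params = {"limit": 20}
        if cursor:
            params["before_id"] = cursor
        page = httpx.get("http://localhost:5003/history/demo-user", params=params).json()
        for msg in page["data"]:
            print(msg["id"], msg["message"])
        cursor = page.get("next_cursor")
        if not cursor or not page["data"]:
            break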
import os
import sys
import io
import time
import numpy as np
from fastapi import APIRouter, UploadFile, File, Form, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from loguru import logger
# Path to the image_process module
BASE_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
IMAGE_PROCESS_DIR = os.path.join(BASE_DIR, "image_process")
if IMAGE_PROCESS_DIR not in sys.path:
sys.path.append(IMAGE_PROCESS_DIR)
from core.encoder import ImageEncoder
from common.starrocks_connection import StarRocksConnection
router = APIRouter(prefix="/api/recommend", tags=["Recommendation StarRocks"])
# Singleton
_encoder = None
_db = StarRocksConnection()
def get_encoder():
global _encoder
if _encoder is None:
_encoder = ImageEncoder()
return _encoder
@router.post("/image")
async def recommend_by_image(
file: UploadFile = File(...),
):
"""
    Image-based product search API - hardcoded StarRocks SQL logic
"""
start_time = time.time()
encoder = get_encoder()
try:
        # 1. Process the image and extract its vector
contents = await file.read()
img = Image.open(io.BytesIO(contents)).convert("RGB")
        # Extract the 768-dim vector
vector = encoder.encode_image(img)
        # Convert the vector to a "[v1, v2, ...]" string to inline in the SQL
v_str = "[" + ",".join(map(str, vector.tolist())) + "]"
        # 2. Hardcoded StarRocks SQL (CTE + grouping)
        # Target table: magento_product_dimension_with_image_embedding
target_table = "shared_source.magento_product_dimension_with_image_embedding"
sql = f"""
WITH top_sku_candidates AS (
SELECT
internal_ref_code,
product_name,
sale_price,
original_price,
master_color,
product_image_url,
product_image_url_thumbnail,
product_web_url,
description_text,
material,
material_group,
gender_by_product,
age_by_product,
season,
style,
fitting,
form_neckline,
form_sleeve,
product_line_vn,
product_color_name,
cosine_similarity(vector, {v_str}) as similarity_score
FROM {target_table}
WHERE vector IS NOT NULL
ORDER BY similarity_score DESC
LIMIT 50
)
SELECT
internal_ref_code,
MAX_BY(product_name, similarity_score) as product_name,
MAX_BY(sale_price, similarity_score) as sale_price,
MAX_BY(original_price, similarity_score) as original_price,
GROUP_CONCAT(DISTINCT master_color ORDER BY master_color SEPARATOR ', ') as available_colors,
MAX_BY(product_image_url, similarity_score) as product_image_url,
MAX_BY(product_image_url_thumbnail, similarity_score) as product_image_url_thumbnail,
MAX_BY(product_web_url, similarity_score) as product_web_url,
MAX_BY(description_text, similarity_score) as description_text,
MAX_BY(material, similarity_score) as material,
MAX_BY(material_group, similarity_score) as material_group,
MAX_BY(gender_by_product, similarity_score) as gender_by_product,
MAX_BY(age_by_product, similarity_score) as age_by_product,
MAX_BY(season, similarity_score) as season,
MAX_BY(style, similarity_score) as style,
MAX_BY(fitting, similarity_score) as fitting,
MAX_BY(form_neckline, similarity_score) as form_neckline,
MAX_BY(form_sleeve, similarity_score) as form_sleeve,
MAX_BY(product_line_vn, similarity_score) as product_line_vn,
MAX(similarity_score) as similarity_score
FROM top_sku_candidates
GROUP BY internal_ref_code
ORDER BY similarity_score DESC
LIMIT 10
"""
        # 3. Execute the query
results = await _db.execute_query_async(sql)
process_time = time.time() - start_time
logger.info(f"Visual search done in {process_time:.3f}s. Found {len(results)} groups.")
return {
"status": "success",
"process_time": f"{process_time:.3f}s",
"results": results
}
except Exception as e:
logger.error(f"Recommend Error: {e}")
return JSONResponse(
status_code=500,
content={"status": "error", "message": str(e)}
)
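# Hedged client sketch (not part of the commit): multipart upload against the
# endpoint above. Host/port and the sample file path are illustrative assumptions.
if __name__ == "__main__":
    import httpx

    with open("sample.jpg", "rb") as f:
        resp = httpx.post(
            "http://localhost:5003/api/recommend/image",
            files={"file": ("sample.jpg", f, "image/jpeg")},
            timeout=60,
        )
    print(resp.json())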
"""
RECOMMEND TEXT API
═══════════════════
API endpoint for text-based product recommendation
Uses hybrid search: semantic (vector) search → keyword filtering
"""
import logging
from fastapi import APIRouter
from pydantic import BaseModel, Field
from typing import Optional
from search.hybrid_search import hybrid_search
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/recommend", tags=["Recommendation - Text"])
class SearchParams(BaseModel):
"""Request body cho search API"""
query: str = Field(..., description="Search query (VD: 'áo mùa đông cho bé gái')")
limit: int = Field(default=20, ge=1, le=100, description="Số kết quả trả về (1-100)")
price_min: Optional[float] = Field(default=None, ge=0, description="Giá tối thiểu")
price_max: Optional[float] = Field(default=None, ge=0, description="Giá tối đa")
@router.post("/text")
async def recommend_by_text(params: SearchParams):
"""
🔍 HYBRID SEARCH: Semantic Search → Keyword Filtering
Flow:
    1. Prepend the prefix "description_text: " to the query
    2. Embed the prefixed query → a vector
    3. Vector search (HNSW) → top 200 candidates
    4. Parse keywords → filter (season, gender, age, product_type, color, ...)
    5. Return the top N results
Example queries:
- "áo mùa đông cho bé gái"
- "quần jean nam"
- "váy đỏ người lớn"
- "áo khoác phao trẻ em"
"""
logger.info(f"Hybrid search: query='{params.query}', limit={params.limit}")
result = await hybrid_search(
query=params.query,
limit=params.limit,
price_min=params.price_min,
price_max=params.price_max,
)
return result
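# Hedged client sketch (not part of the commit): exercising the hybrid text search
# with a price band. Host/port are illustrative assumptions.
if __name__ == "__main__":
    import httpx

    body = {"query": "áo khoác phao trẻ em", "limit": 5, "price_max": 500000}
    print(httpx.post("http://localhost:5003/api/recommend/text", json=body, timeout=30).json())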
import pymysql
conn = pymysql.connect(
host='172.16.2.100',
port=9030,
user='anhvh',
password='v0WYGeyLRCckXotT',
database='shared_source',
cursorclass=pymysql.cursors.DictCursor
)
try:
with conn.cursor() as cursor:
        # Sample products with gender='male' in an 'Áo' product line
print("=== Products với gender='male' và product_line LIKE '%Áo%' ===\n")
cursor.execute("""
SELECT product_name, gender_by_product, product_line_vn
FROM magento_product_dimension_with_text_embedding
WHERE gender_by_product = 'male' AND product_line_vn LIKE '%Áo%'
LIMIT 10
""")
for row in cursor.fetchall():
print(f" {row['product_name']}")
print(f" → gender: {row['gender_by_product']}, line: {row['product_line_vn']}\n")
print("\n=== Check products có tên 'Áo body nữ' ===\n")
cursor.execute("""
SELECT product_name, gender_by_product, product_line_vn
FROM magento_product_dimension_with_text_embedding
WHERE product_name LIKE '%Áo body nữ%'
LIMIT 5
""")
for row in cursor.fetchall():
print(f" {row['product_name']}")
print(f" → gender: {row['gender_by_product']}, line: {row['product_line_vn']}\n")
finally:
conn.close()
product_name: Pack 3 đôi tất bé gái cổ thấp. master_color: Xanh da trời/ Blue. product_image_url: https://2885371169.e.cdneverest.net/pub/media/catalog/product/1/a/1ax24a002-pb449-7-9-ghep-u.jpg. product_image_url_thumbnail: https://2885371169.e.cdneverest.net/pub/media/catalog/product/1/a/1ax24a002-pb449-7-9-ghep-u.jpg. product_web_url: https://canifa.com/tat-be-gai-1ax24a002?color=PB449&utm_source=chatbot&utm_medium=rsa&utm_campaign=testver25.9. description_text: Pack 3 đôi tất cổ thấp, họa tiết basic.
Chất liệu cotton pha mềm mại, co giãn tốt, màu sắc đơn giản dễ phối mix đồ.. material: None. material_group: Yarn - Sợi. gender_by_product: female. age_by_product: others. season: Year. style: Feminine. fitting: Slim. size_scale: 4/6. form_neckline: None. form_sleeve: None. product_line_vn: Tất. product_color_name: Blue Strip 449
version: "3.8"
services:
# --- Backend Service ---
backend:
build: .
container_name: canifa_backend
env_file: .env
ports:
- "5000:5000"
- "5003:5003"
volumes:
      - .:/app  # Mount the code for hot-reload during dev (optional)
environment:
- CHECKPOINT_POSTGRES_URL=postgresql://postgres:password@postgres_db:5432/canifa_chat
      # Other environment variables can go here or in the .env file
- OPENAI_API_KEY=${OPENAI_API_KEY}
- LANGFUSE_PUBLIC_KEY=${LANGFUSE_PUBLIC_KEY}
- LANGFUSE_SECRET_KEY=${LANGFUSE_SECRET_KEY}
- LANGFUSE_HOST=${LANGFUSE_HOST}
- STARROCKS_HOST=${STARROCKS_HOST}
- STARROCKS_port=${STARROCKS_port}
- STARROCKS_USER=${STARROCKS_USER}
- STARROCKS_PASSWORD=${STARROCKS_PASSWORD}
- STARROCKS_DB=${STARROCKS_DB}
      - PORT=5003
    depends_on:
      - postgres_db
    restart: unless-stopped
    deploy:
      resources:
        limits:
          memory: 8g
    logging:
      driver: "json-file"
      options:
        tag: "{{.Name}}"
# --- Database Service (Postgres) ---
postgres_db:
image: postgres:15-alpine
container_name: canifa_postgres
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: password
POSTGRES_DB: canifa_chat
ports:
- "5433:5432" # Map ra host port 5433 để tránh đụng độ nếu máy bro đang chạy Postgres khác
volumes:
- postgres_data:/var/lib/postgresql/data
restart: unless-stopped
volumes:
postgres_data:
#!/bin/bash
NUM_CORES=$(nproc)
WORKERS=$((2 * NUM_CORES + 1))
echo "🔧 [STARTUP] CPU cores: $NUM_CORES"
echo "🔧 [STARTUP] Gunicorn workers: $WORKERS"
exec gunicorn \
server:app \
--workers "$WORKERS" \
--worker-class uvicorn.workers.UvicornWorker \
--worker-connections 1000 \
--max-requests 1000 \
--max-requests-jitter 100 \
--timeout 30 \
--access-logfile - \
--error-logfile - \
    --bind "0.0.0.0:${PORT:-5003}" \
--log-level info
"""
Performance Test API Server
Mock server for load testing with Locust - does NOT call OpenAI or Postgres.
"""
import logging
import sys
import time
from pathlib import Path
import uvicorn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Setup path to import backend modules
sys.path.insert(0, str(Path(__file__).parent.parent))
# Setup Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ============================================================
# 1. MOCK LLM - responds instantly, never calls OpenAI
# ============================================================
from typing import Any
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult
class MockHighSpeedLLM(BaseChatModel):
"""
    Ultra-fast mock LLM for performance testing.
    Never calls OpenAI - returns a response immediately.
"""
def _generate(
self, messages: list[Any], stop: list[str] | None = None, run_manager: Any | None = None, **kwargs: Any
) -> ChatResult:
return ChatResult(
generations=[
ChatGeneration(message=AIMessage(content="[MOCK] Xin chào! Đây là bot test, không phải OpenAI thật."))
]
)
@property
def _llm_type(self) -> str:
return "mock-high-speed"
def bind_tools(self, tools: Any, **kwargs: Any) -> Any:
"""Bypass tool binding - trả về self."""
return self
# ============================================================
# 2. PATCH create_llm BEFORE IMPORTING THE GRAPH
# ============================================================
def mock_create_llm(*args, **kwargs):
"""Factory function giả - trả về MockLLM."""
logger.info("🎭 MockLLM được gọi thay vì OpenAI!")
return MockHighSpeedLLM()
# Patch BEFORE importing agent.graph
import common.llm_factory
common.llm_factory.create_llm = mock_create_llm
logger.info("🎭 PATCHED common.llm_factory.create_llm")
# ============================================================
# 3. PATCH the checkpointer - use MemorySaver instead of Postgres
# ============================================================
from langgraph.checkpoint.memory import MemorySaver
# A single shared MemorySaver instance
_shared_memory_saver = MemorySaver()
def mock_get_checkpointer():
"""Trả về MemorySaver thay vì Postgres."""
logger.info("🧠 Using MemorySaver (RAM) instead of Postgres")
return _shared_memory_saver
import agent.checkpointer
# Patch function
agent.checkpointer.get_checkpointer = mock_get_checkpointer
# IMPORTANT: reset the singleton so it picks up the mock
agent.checkpointer._checkpointer_manager._checkpointer = _shared_memory_saver
logger.info("🔧 PATCHED agent.checkpointer (Reset singleton)")
# ============================================================
# 4. PATCH Langfuse - disabled to avoid rate limits
# ============================================================
def mock_get_callback_handler():
"""Trả về None - không gửi trace."""
return
import common.langfuse_client
common.langfuse_client.get_callback_handler = mock_get_callback_handler
logger.info("🔇 PATCHED common.langfuse_client.get_callback_handler")
# ============================================================
# 5. ONLY NOW import the graph (after all patches are in place)
# ============================================================
from agent.config import get_config
from langchain_core.messages import HumanMessage
from agent.graph import build_graph
from agent.tools.product_search_helpers import build_starrocks_query
# ============================================================
# 6. IMPORTS FOR DB TESTS
# ============================================================
from common.starrocks_connection import StarRocksConnection
# ============================================================
# FASTAPI APP
# ============================================================
app = FastAPI(title="Performance Test API")
# Request Models
class SearchRequest(BaseModel):
query: str
limit: int = 10
class MockParams(BaseModel):
    query_text: str = ""
    limit: int = 10
    sku: str | None = None
    gender: str | None = None
    season: str | None = None
    color: str | None = None
    product_line: str | None = None
    price_min: float | None = None
    price_max: float | None = None
# Global variables
mock_graph = None
@app.on_event("startup")
async def startup_event():
global mock_graph
logger.info("🚀 Performance Test API Server Starting...")
# 1. Pre-warm DB connection
conn = StarRocksConnection()
conn.connect()
logger.info("✅ StarRocks Connection initialized")
    # 2. Build the mock graph (the LLM was patched at the top of the file)
config = get_config()
mock_graph = build_graph(config)
logger.info("✅ Mock Graph built successfully (No OpenAI, No Postgres, No Langfuse)")
# ============================================================
# ENDPOINTS
# ============================================================
@app.post("/test/db-search")
async def test_db_search(request: SearchRequest):
"""
    Test StarRocks DB search.
    Does NOT call OpenAI - only tests the DB.
"""
start_time = time.time()
try:
params = MockParams(query_text=request.query, limit=request.limit)
sql = build_starrocks_query(params)
db = StarRocksConnection()
products = await db.execute_query_async(sql)
# Filter fields
limited_products = products[:5]
ALLOWED_FIELDS = {
"product_name",
"sale_price",
"original_price",
"product_image_url_thumbnail",
"product_web_url",
"master_color",
"product_color_name",
"material",
"internal_ref_code",
}
clean_products = [{k: v for k, v in p.items() if k in ALLOWED_FIELDS} for p in limited_products]
process_time = time.time() - start_time
return {
"status": "success",
"count": len(clean_products),
"process_time_seconds": round(process_time, 4),
"products": clean_products,
}
except Exception as e:
logger.error(f"DB Search Error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/test/db-ping")
async def test_db_ping():
"""
    Pure DB connectivity test (SELECT 1).
"""
start_time = time.time()
try:
db = StarRocksConnection()
await db.execute_query_async("SELECT 1")
process_time = time.time() - start_time
return {"status": "success", "process_time_seconds": round(process_time, 4)}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.post("/test/graph-mock-chat")
async def test_graph_mock_chat(request: SearchRequest):
"""
    Test the full graph flow with the MockLLM.
    Does NOT call OpenAI - tests the Python + LangGraph logic.
"""
if not mock_graph:
raise HTTPException(500, "Mock Graph not initialized")
start_time = time.time()
try:
        # Create a unique thread_id per request
thread_id = f"perf_test_{int(time.time() * 1000)}"
input_state = {
"messages": [HumanMessage(content=request.query)],
"user_id": "perf_user",
}
config_runnable = {"configurable": {"thread_id": thread_id}}
        # Run the graph
        async for _event in mock_graph.astream(input_state, config=config_runnable):
            pass  # just drive the flow to completion
process_time = time.time() - start_time
return {
"status": "success",
"mode": "mock_llm",
"process_time_seconds": round(process_time, 4),
"message": "Graph Flow executed successfully",
}
except Exception as e:
logger.error(f"Graph Error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/")
async def root():
return {"message": "Performance Test API is running!", "mode": "MOCK (No OpenAI)"}
# ============================================================
# MAIN
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("🚀 PERFORMANCE TEST SERVER")
print("=" * 60)
print("🎭 LLM: MockHighSpeedLLM (No OpenAI)")
print("🧠 Checkpointer: MemorySaver (No Postgres)")
print("🔇 Langfuse: Disabled")
print("=" * 60)
uvicorn.run(app, host="0.0.0.0", port=8000)
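# Quick manual smoke test from another process (a sketch; assumes the server
# is listening on localhost:8000):
# import requests
# print(requests.post("http://localhost:8000/test/db-ping", json={}).json())
# print(requests.post("http://localhost:8000/test/db-search",
#                     json={"query": "áo phông nam", "limit": 10}).json())
# print(requests.post("http://localhost:8000/test/graph-mock-chat",
#                     json={"query": "tìm áo"}).json())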
import random
from locust import HttpUser, between, task
# Keyword pools for random query generation
PRODUCTS = [
"áo phông",
"áo khoác",
"áo len",
"áo polo",
"áo sơ mi",
"quần jean",
"quần âu",
"quần short",
"váy liền",
"chân váy",
]
COLORS = ["", "màu đen", "màu trắng", "màu đỏ", "màu xanh", "màu vàng", "màu be"]
GENDERS = ["", "nam", "nữ", "bé trai", "bé gái"]
class FashionUser(HttpUser):
# Think time between requests (simulates a real user: 1-3s)
wait_time = between(1, 3)
# --- DB TESTS (already benchmarked, temporarily disabled) ---
# @task(3)
# def search_random_product(self):
# """
# Simulates a user searching for an arbitrary product.
# Random combination: [product] + [gender] + [color]
# Example: "áo phông nam màu đen"
# """
# # Random query builder
# p = random.choice(PRODUCTS)
# g = random.choice(GENDERS)
# c = random.choice(COLORS)
# query = f"{p} {g} {c}".strip()
# # Collapse extra whitespace
# query = " ".join(query.split())
# self.client.post("/test/db-search", json={"query": query, "limit": 10}, name="/test/db-search (Dynamic)")
# @task(1)
# def search_hot_items(self):
# """
# Simulates trending items that many users search for identically (cache test)
# """
# hot_item = "áo giữ nhiệt"
# self.client.post("/test/db-search", json={"query": hot_item, "limit": 10}, name="/test/db-search (Hero Item)")
# @task(3)
# def fast_ping_check(self):
# """
# Measures the raw cost of acquiring a connection + SELECT 1.
# If this is < 100ms while search is > 3000ms --> the query itself is heavy.
# If this is also > 3000ms --> the pool is saturated (server overloaded).
# """
# self.client.post("/test/db-ping", json={}, name="/test/db-ping (Isolate Latency)")
# --- GRAPH LOGIC TEST (ENABLED) ---
@task(1)
def check_graph_overhead(self):
"""
Test the logic flow (Mock LLM + MemorySaver).
⚠️ Run with 300-500 users only; do not exceed 1000.
"""
query = random.choice(["hello", "tìm áo", "giá bao nhiêu", "có màu gì"])
self.client.post("/test/graph-mock-chat", json={"query": query}, name="/test/graph Logic (Mock LLM)")
locust -f locustfile.py --host=http://localhost:8000
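# Headless example using standard Locust CLI flags (users, spawn rate, run time):
# locust -f locustfile.py --host=http://localhost:8000 --headless -u 300 -r 20 -t 5m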
.\.venv\Scripts\activate
.\.venv\Scripts\activate
uvicorn server:app --host 0.0.0.0 --port 8000 --reload
uvicorn server:app --host 0.0.0.0 --port 5003
docker restart chatbot-backend
docker restart chatbot-backend && docker logs -f chatbot-backend
docker logs -f chatbot-backend
"""
Script: Check products với data gender bị sai
Tìm các sản phẩm có tên chứa "nữ" nhưng gender_by_product = 'male' (và ngược lại)
"""
import pymysql
conn = pymysql.connect(
host='172.16.2.100',
port=9030,
user='anhvh',
password='v0WYGeyLRCckXotT',
database='shared_source',
cursorclass=pymysql.cursors.DictCursor
)
try:
with conn.cursor() as cursor:
# Products tên có "nữ" nhưng gender = 'male'
print("=" * 80)
print("PRODUCTS TÊN CÓ 'NỮ' NHƯNG gender = 'male' (DATA SAI!)")
print("=" * 80)
cursor.execute("""
SELECT
magento_ref_code,
internal_ref_code,
product_name,
gender_by_product,
product_line_vn
FROM magento_product_dimension_with_text_embedding
WHERE gender_by_product = 'male'
AND (product_name LIKE '%nữ%' OR product_name LIKE '%Nữ%')
LIMIT 50
""")
rows = cursor.fetchall()
print(f"\nTìm thấy {len(rows)} sản phẩm:\n")
for row in rows:
print(f"magento_ref_code: {row['magento_ref_code']}")
print(f" internal_ref_code: {row['internal_ref_code']}")
print(f" product_name: {row['product_name']}")
print(f" gender: {row['gender_by_product']} (SAI! Phải là female)")
print(f" product_line: {row['product_line_vn']}")
print()
# Products tên có "nam/trai" nhưng gender = 'female'
print("\n" + "=" * 80)
print("PRODUCTS TÊN CÓ 'NAM/TRAI' NHƯNG gender = 'female' (DATA SAI!)")
print("=" * 80)
cursor.execute("""
SELECT
magento_ref_code,
internal_ref_code,
product_name,
gender_by_product,
product_line_vn
FROM magento_product_dimension_with_text_embedding
WHERE gender_by_product = 'female'
AND (product_name LIKE '%nam%' OR product_name LIKE '%trai%'
OR product_name LIKE '%Nam%' OR product_name LIKE '%Trai%')
LIMIT 50
""")
rows = cursor.fetchall()
print(f"\nTìm thấy {len(rows)} sản phẩm:\n")
for row in rows:
print(f"magento_ref_code: {row['magento_ref_code']}")
print(f" internal_ref_code: {row['internal_ref_code']}")
print(f" product_name: {row['product_name']}")
print(f" gender: {row['gender_by_product']} (có thể sai)")
print(f" product_line: {row['product_line_vn']}")
print()
finally:
conn.close()
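# If the mismatches above are confirmed, a corrective update could look like
# this hypothetical sketch. Whether UPDATE is allowed depends on the StarRocks
# table's key model; verify on a copy before running anything.
FIX_FEMALE_SQL = """
UPDATE magento_product_dimension_with_text_embedding
SET gender_by_product = 'female'
WHERE gender_by_product = 'male'
  AND (product_name LIKE '%nữ%' OR product_name LIKE '%Nữ%')
"""
# with conn.cursor() as cursor:
#     cursor.execute(FIX_FEMALE_SQL)
# conn.commit()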
# Search module for Hybrid Search
# Combines Semantic (Vector) Search + Keyword Filtering
from .hybrid_search import hybrid_search
from .query_parser import parse_query, build_where_clause
from .keyword_mappings import KEYWORD_MAPPINGS
__all__ = [
"hybrid_search",
"parse_query",
"build_where_clause",
"KEYWORD_MAPPINGS",
]
"""
QUERY PARSER
════════════
Parse user query để extract keywords và build SQL WHERE clause
Functions:
- parse_query(): Extract filters từ query text
- build_where_clause(): Convert filters thành SQL WHERE
"""
import re
import logging
from .keyword_mappings import KEYWORD_MAPPINGS
logger = logging.getLogger(__name__)
def is_word_match(keyword: str, query: str) -> bool:
"""
Check keyword có phải là từ riêng biệt trong query không
Tránh case "việt nam" match "nam"
VD:
- "nam" trong "áo nam" → True (từ riêng biệt)
- "nam" trong "việt nam" → False (substring của từ khác)
- "bé gái" trong "áo bé gái" → True
"""
# Pattern: keyword phải có word boundary (space, start, end) ở 2 đầu
# Dùng \b cho word boundary với tiếng Việt cần xử lý đặc biệt
# Escape special regex characters trong keyword
escaped_keyword = re.escape(keyword)
# Pattern: từ phải đứng riêng (có space hoặc đầu/cuối string)
pattern = r'(?:^|[\s,\.!?])' + escaped_keyword + r'(?:[\s,\.!?]|$)'
return bool(re.search(pattern, query))
def parse_query(query: str) -> dict:
"""
Parse the query text and extract filters based on KEYWORD_MAPPINGS.
Args:
query: User query string (e.g., "áo mùa đông cho bé gái màu hồng")
Returns:
dict: Filters grouped by field
{
"season": {"values": ["Fall Winter"], "op": "IN"},
"gender_by_product": {"values": ["female"], "op": "="},
...
}
Note:
- If no keyword matches → returns {} → no filtering at all (fetch everything)
- Keywords must be standalone words (word boundaries), so "việt nam" does not trigger "nam"
"""
query_lower = query.lower().strip()
filters = {}
matched_keywords = set()
# Sort keywords by length (longest first) so longer phrases match first
# e.g., "áo khoác" must match before "áo"
# e.g., "bé gái" must match before "bé"
sorted_keywords = sorted(KEYWORD_MAPPINGS.keys(), key=len, reverse=True)
for keyword in sorted_keywords:
# Skip unless the keyword appears as a standalone word
if not is_word_match(keyword, query_lower):
continue
# Skip if this keyword is a substring of an already matched keyword
# e.g., once "bé gái" has matched, do not also match "bé"
is_substring_of_matched = False
for matched in matched_keywords:
if keyword in matched and keyword != matched:
is_substring_of_matched = True
break
if is_substring_of_matched:
continue
mapping = KEYWORD_MAPPINGS[keyword]
field = mapping["field"]
# Keep only the first match per field (do not override)
if field not in filters:
filters[field] = {
"values": mapping["values"],
"op": mapping["op"]
}
matched_keywords.add(keyword)
logger.debug(f"Matched keyword: '{keyword}' → {field} = {mapping['values']}")
logger.info(f"Query '{query}' → Extracted filters: {filters}")
return filters
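# Quick illustration of the expected shape (the mapped values are assumptions
# about KEYWORD_MAPPINGS, shown for illustration only):
# parse_query("áo mùa đông cho bé gái màu hồng")
# → {"season": {"values": ["Fall Winter"], "op": "IN"},
#    "gender_by_product": {"values": ["female"], "op": "="}}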
def build_where_clause(filters: dict) -> str:
"""
Build a SQL WHERE clause from the extracted filters.
Args:
filters: Dict of filters from parse_query()
Returns:
str: SQL WHERE clause (e.g., "WHERE season IN ('Fall Winter') AND ...")
Returns empty string if no filters
"""
if not filters:
return ""
clauses = []
for field, config in filters.items():
values = config["values"]
op = config["op"]
if op == "=":
# Single-value equality
# e.g., gender_by_product = 'female'
clauses.append(f"{field} = '{values[0]}'")
elif op == "IN":
# Multiple values
# e.g., season IN ('Fall Winter')
vals_str = ", ".join([f"'{v}'" for v in values])
clauses.append(f"{field} IN ({vals_str})")
elif op == "LIKE":
# Text search with LIKE
# e.g., (product_line_vn LIKE '%Áo%' OR product_line_vn LIKE '%Áo khoác%')
like_parts = [f"{field} LIKE '%{v}%'" for v in values]
if len(like_parts) == 1:
clauses.append(like_parts[0])
else:
clauses.append(f"({' OR '.join(like_parts)})")
where_clause = "WHERE " + " AND ".join(clauses)
logger.debug(f"Built WHERE clause: {where_clause}")
return where_clause
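# The values above are interpolated directly into the SQL string. That is
# tolerable only because they come from the static KEYWORD_MAPPINGS dict and
# never from raw user input. If filters ever become user-supplied, a
# parameterized sketch could look like this (%s placeholders assume a
# pymysql-style driver):
def build_where_clause_params(filters: dict) -> tuple[str, list]:
    clauses, params = [], []
    for field, config in filters.items():
        values, op = config["values"], config["op"]
        if op == "=":
            clauses.append(f"{field} = %s")
            params.append(values[0])
        elif op == "IN":
            clauses.append(f"{field} IN ({', '.join(['%s'] * len(values))})")
            params.extend(values)
        elif op == "LIKE":
            clauses.append("(" + " OR ".join([f"{field} LIKE %s"] * len(values)) + ")")
            params.extend([f"%{v}%" for v in values])
    return ("WHERE " + " AND ".join(clauses)) if clauses else "", params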
def get_matched_keywords(query: str) -> list[str]:
"""
Get list of keywords that were matched in the query
Useful for debugging/logging
"""
query_lower = query.lower().strip()
matched = []
sorted_keywords = sorted(KEYWORD_MAPPINGS.keys(), key=len, reverse=True)
for keyword in sorted_keywords:
# Check word boundary
if not is_word_match(keyword, query_lower):
continue
# Check not substring of already matched
is_sub = any(keyword in m and keyword != m for m in matched)
if not is_sub:
matched.append(keyword)
return matched
@@ -36,9 +36,8 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles # Import StaticFiles
# Updated APIs
from api.chatbot_route import router as chatbot_router
from api.conservation_route import router as conservation_router
# from api.recommend_image import router as recommend_image_router
from api.recommend_text import router as recommend_text_router
from common.middleware import ClerkAuthMiddleware
from config import PORT
@@ -49,17 +48,6 @@ app = FastAPI(
version="1.0.0",
)
# Clerk Auth is disabled for now. ClerkAuthMiddleware applies to every route,
# so with it enabled the static test page cannot load without a valid token
# (static files either need to be mounted before auth or whitelisted as public).
# Re-enable once public paths are excluded from the middleware.
# app.add_middleware(ClerkAuthMiddleware)
print("✅ Clerk Authentication middleware DISABLED (for testing)")
# Add CORS middleware
@@ -71,15 +59,9 @@ app.add_middleware(
allow_headers=["*"],
)
# ========== REST API Routes ==========
# Conversation history (new)
app.include_router(conservation_router)
# app.include_router(recommend_image_router)
app.include_router(recommend_text_router)
# Chatbot agent (new)
app.include_router(chatbot_router, prefix="/api/agent")
# Mount Static Files
# Mount this LAST to avoid conflicts with API routes
try:
static_dir = os.path.join(os.path.dirname(__file__), "static")
if not os.path.exists(static_dir):
......
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Canifa StarRocks Visual Search</title>
<style>
:root {
--primary: #d32f2f;
--dark: #121212;
--card: #1e1e1e;
--text: #e0e0e0;
--accent: #2196f3;
}
body {
font-family: 'Inter', -apple-system, sans-serif;
background-color: var(--dark);
color: var(--text);
margin: 0;
padding: 20px;
display: flex;
flex-direction: column;
align-items: center;
}
.container {
width: 100%;
max-width: 1100px;
}
.header {
text-align: center;
margin-bottom: 30px;
}
.header h1 {
color: var(--primary);
margin-bottom: 5px;
}
.upload-section {
background: var(--card);
border: 2px dashed #444;
border-radius: 15px;
padding: 30px;
text-align: center;
cursor: pointer;
transition: all 0.3s;
}
.upload-section:hover {
border-color: var(--primary);
background: #252525;
}
#imagePreview {
max-width: 200px;
border-radius: 8px;
margin-top: 15px;
display: none;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.5);
}
.results-grid {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(260px, 1fr));
gap: 20px;
margin-top: 30px;
}
.product-card {
background: var(--card);
border-radius: 12px;
overflow: hidden;
border: 1px solid #333;
transition: transform 0.2s;
position: relative;
}
.product-card:hover {
transform: translateY(-5px);
border-color: var(--accent);
}
.match-score {
position: absolute;
top: 10px;
right: 10px;
background: rgba(33, 150, 243, 0.9);
color: white;
padding: 4px 10px;
border-radius: 15px;
font-size: 0.85em;
font-weight: bold;
}
.product-image {
width: 100%;
aspect-ratio: 2/3;
object-fit: cover;
background: #252525;
}
.product-info {
padding: 15px;
}
.product-name {
font-weight: bold;
font-size: 1em;
margin-bottom: 8px;
height: 2.4em;
overflow: hidden;
display: -webkit-box;
-webkit-line-clamp: 2;
-webkit-box-orient: vertical;
}
.price-row {
display: flex;
align-items: center;
gap: 10px;
margin-bottom: 8px;
}
.sale-price {
color: var(--primary);
font-weight: bold;
font-size: 1.1em;
}
.original-price {
text-decoration: line-through;
color: #888;
font-size: 0.9em;
}
.sku {
font-size: 0.8em;
color: #666;
margin-top: 10px;
}
.btn-view {
display: block;
width: 100%;
padding: 8px;
background: var(--accent);
color: white;
text-align: center;
text-decoration: none;
border-radius: 5px;
margin-top: 10px;
font-size: 0.9em;
}
.loader {
display: none;
width: 30px;
height: 30px;
border: 3px solid #333;
border-top: 3px solid var(--primary);
border-radius: 50%;
animation: spin 1s linear infinite;
margin: 20px auto;
}
@keyframes spin {
100% {
transform: rotate(360deg);
}
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🚀 StarRocks Visual Recommend</h1>
<p>Tìm kiếm sản phẩm bằng ảnh trực tiếp từ database StarRocks</p>
</div>
<div class="upload-section" onclick="document.getElementById('fileInput').click()">
<p>Nhấn để chọn ảnh sản phẩm</p>
<input type="file" id="fileInput" hidden accept="image/*" onchange="handleFile(this)">
<img id="imagePreview" alt="Preview">
</div>
<div id="loader" class="loader"></div>
<div id="resultsArea" class="results-grid"></div>
</div>
<script>
async function handleFile(input) {
const file = input.files[0];
if (!file) return;
// Preview
const reader = new FileReader();
reader.onload = (e) => {
document.getElementById('imagePreview').src = e.target.result;
document.getElementById('imagePreview').style.display = 'inline-block';
};
reader.readAsDataURL(file);
// Upload
const loader = document.getElementById('loader');
const resultsArea = document.getElementById('resultsArea');
loader.style.display = 'block';
resultsArea.innerHTML = '';
const formData = new FormData();
formData.append('file', file);
try {
const response = await fetch('/api/recommend/image', {
method: 'POST',
body: formData
});
const data = await response.json();
if (data.status === 'success') {
renderResults(data.results);
} else {
resultsArea.innerHTML = `<p style="color:red">Lỗi: ${data.message}</p>`;
}
} catch (err) {
resultsArea.innerHTML = `<p style="color:red">Lỗi: ${err.message}</p>`;
} finally {
loader.style.display = 'none';
}
}
function renderResults(results) {
const resultsArea = document.getElementById('resultsArea');
if (results.length === 0) {
resultsArea.innerHTML = '<p>Không tìm thấy sản phẩm nào phù hợp.</p>';
return;
}
results.forEach(item => {
const score = (item.similarity_score * 100).toFixed(1);
const card = document.createElement('div');
card.className = 'product-card';
card.innerHTML = `
<div class="match-score">${score}%</div>
<img src="${item.product_image_url_thumbnail || item.product_image_url}" class="product-image" onerror="this.src='https://placehold.co/400x600?text=Canifa'">
<div class="product-info">
<div class="product-name">${item.product_name}</div>
<div class="price-row">
<span class="sale-price">${item.sale_price.toLocaleString()}đ</span>
${item.original_price > item.sale_price ? `<span class="original-price">${item.original_price.toLocaleString()}đ</span>` : ''}
</div>
<div style="font-size:0.85em; color:#aaa">Màu sắc: ${item.available_colors}</div>
<div class="sku">#${item.internal_ref_code}</div>
<a href="${item.product_web_url}" target="_blank" class="btn-view">Xem trên Canifa.com</a>
</div>
`;
resultsArea.appendChild(card);
});
}
</script>
</body>
</html>
import os
# Model Paths
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")
YOLO_MODEL_PATH = os.path.join(MODELS_DIR, "yolow-l_0_05_nms_0_3_v2.onnx")
SIGLIP_MODEL_NAME = "hf-hub:Marqo/marqo-fashionSigLIP"
# FAISS config
INDEX_PATH = os.path.join(MODELS_DIR, "vector_db.index")
MAPPING_PATH = os.path.join(MODELS_DIR, "sku_mapping.pkl")
# Thresholds
DEFAULT_CONFIDENCE_THRESHOLD = 0.5
MARGIN_HIGH_CONFIDENCE = 0.1
# Device
DEVICE = "cpu"
import os
import onnxruntime as ort
import numpy as np
import cv2
from PIL import Image
from loguru import logger
class ClothingDetector:
def __init__(self, model_path="models/yolow-l_0_05_nms_0_3_v2.onnx"):
self.model_path = model_path
if not os.path.exists(model_path):
logger.warning(f"YOLO model not found at {model_path}. Please place it there.")
self.session = None
else:
logger.info(f"Loading YOLO model from {model_path}...")
self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
logger.info("YOLO loaded.")
def detect_and_crop(self, pil_image: Image.Image, score_threshold=0.2):
"""
Detect clothing items and return a list of cropped PIL images.
"""
if self.session is None:
return [pil_image] # Fallback to original if no model
# 1. Preprocess
orig_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
orig_h, orig_w = orig_image.shape[:2]
resized_image = cv2.resize(orig_image, (640, 640))
inference_input = resized_image.astype(np.float32) / 255.0
inference_input = np.transpose(inference_input, (2, 0, 1))
inference_input = np.expand_dims(inference_input, axis=0)
# 2. Inference
input_name = self.session.get_inputs()[0].name
output_names = [o.name for o in self.session.get_outputs()]
outputs = self.session.run(output_names, {input_name: inference_input})
# [x_min, y_min, x_max, y_max], scores, class_ids
bboxes = outputs[1][0]
scores = outputs[2][0]
crops = []
for i, score in enumerate(scores):
if score > score_threshold:
x1, y1, x2, y2 = bboxes[i]
# Scale back to original
x1 = int(x1 * orig_w / 640)
y1 = int(y1 * orig_h / 640)
x2 = int(x2 * orig_w / 640)
y2 = int(y2 * orig_h / 640)
# Crop
crop_cv2 = orig_image[y1:y2, x1:x2]
if crop_cv2.size > 0:
crop_pil = Image.fromarray(cv2.cvtColor(crop_cv2, cv2.COLOR_BGR2RGB))
crops.append(crop_pil)
# If nothing detected, return original image as a single crop
return crops if crops else [pil_image]
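# Minimal usage sketch ("sample.jpg" is a placeholder path, not a repo file):
if __name__ == "__main__":
    detector = ClothingDetector()
    img = Image.open("sample.jpg").convert("RGB")
    crops = detector.detect_and_crop(img, score_threshold=0.2)
    print(f"Got {len(crops)} candidate crops")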
import os
import torch
from open_clip import create_model_and_transforms, get_tokenizer
from PIL import Image
import numpy as np
from loguru import logger
class ImageEncoder:
def __init__(self, model_name='hf-hub:Marqo/marqo-fashionSigLIP', device='cpu'):
self.device = device
logger.info(f"Loading SigLIP model: {model_name} on {device}...")
self.model, _, self.preprocess = create_model_and_transforms(model_name, device=device)
self.model.eval()
logger.info("Model loaded successfully.")
def encode_image(self, pil_image: Image.Image) -> np.ndarray:
"""
Extracts the feature vector (embedding) from an image.
This is the 'penultimate layer' representation.
"""
try:
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
with torch.no_grad():
image_features = self.model.encode_image(image)
# Normalize the features to unit length for better similarity matching
image_features /= image_features.norm(dim=-1, keepdim=True)
return image_features.cpu().numpy().flatten()
except Exception as e:
logger.error(f"Error encoding image: {e}")
raise e
def get_embedding_dimension(self):
# SigLIP Marqo usually produces 768 dimensions
return 768
if __name__ == "__main__":
# Quick test
encoder = ImageEncoder()
mock_image = Image.new('RGB', (224, 224), color='red')
embedding = encoder.encode_image(mock_image)
print(f"Embedding shape: {embedding.shape}")
print(f"First 5 values: {embedding[:5]}")
from loguru import logger
class MatchResult:
def __init__(self, top_k_results: list):
self.results = top_k_results
self.confidence_level = "low"
self.message = ""
def analyze_confidence(self, score_threshold=0.5, margin_threshold=0.1):
"""
Final-decision logic based on score and margin.
Note: results contain 'distance' (Euclidean); smaller is better.
For the UI we often use 'similarity' (1 - distance, normalized).
Input results are assumed to be sorted best-first.
"""
if not self.results:
self.confidence_level = "none"
self.message = "No matches found."
return
# Simple conversion from Euclidean distance to a 'similarity' score for logic
# This is high-level, actual metric depends on the embedding space.
def dist_to_sim(d):
return max(0, 1 - d)
s1 = dist_to_sim(self.results[0]['distance'])
if len(self.results) > 1:
s2 = dist_to_sim(self.results[1]['distance'])
margin = s1 - s2
else:
margin = 1.0 # Only one result, high margin
logger.info(f"Analysis: Top1 Similarity: {s1:.3f}, Margin: {margin:.3f}")
if s1 > score_threshold and margin > margin_threshold:
self.confidence_level = "high"
self.message = "Exact product identified."
# Only keep the best one for high confidence
self.results = [self.results[0]]
elif s1 > (score_threshold - 0.1) and margin < margin_threshold:
self.confidence_level = "medium"
self.message = "Multiple similar products found, clarification needed."
# Keep top 2-3
self.results = self.results[:3]
else:
self.confidence_level = "low"
self.message = "No exact match found. Showing similar items."
def to_dict(self):
return {
"confidence": self.confidence_level,
"message": self.message,
"items": self.results
}
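# Side note: the ingested vectors are unit-normalized and FAISS IndexFlatL2
# returns SQUARED L2 distances, so an exact conversion (instead of the rough
# 1 - d heuristic above) follows from ||u - v||^2 = 2 - 2*cos(u, v):
def l2sq_to_cosine(d2: float) -> float:
    """Exact cosine similarity for unit vectors, given a squared L2 distance."""
    return 1.0 - d2 / 2.0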
import faiss
import numpy as np
import pickle
import os
from loguru import logger
class FaissManager:
def __init__(self, dimension=768, index_path="models/vector_db.index", mapping_path="models/sku_mapping.pkl"):
self.dimension = dimension
self.index_path = index_path
self.mapping_path = mapping_path
self.index = None
self.sku_mapping = [] # List of SKUs, index matches FAISS ID
if os.path.exists(index_path):
self.load()
else:
logger.info("Initializing new FAISS index...")
self.index = faiss.IndexFlatL2(dimension)
self.sku_mapping = []
def add_vectors(self, vectors: np.ndarray, skus: list):
"""
Add image vectors and their corresponding SKUs to the index.
"""
if vectors.shape[1] != self.dimension:
raise ValueError(f"Vector dimension mismatch. Expected {self.dimension}, got {vectors.shape[1]}")
self.index.add(vectors.astype('float32'))
self.sku_mapping.extend(skus)
logger.info(f"Added {len(skus)} items. Total items: {self.index.ntotal}")
def search(self, query_vector: np.ndarray, top_k=5):
"""
Search for the most similar vectors in the index.
Returns: distances, results (list of dicts with sku and distance)
"""
if self.index.ntotal == 0:
return [], []
query_vector = query_vector.reshape(1, -1).astype('float32')
distances, indices = self.index.search(query_vector, top_k)
results = []
for dist, idx in zip(distances[0], indices[0]):
if idx != -1: # -1 means not found
results.append({
"sku": self.sku_mapping[idx],
"distance": float(dist)
})
return distances[0], results
def save(self):
os.makedirs(os.path.dirname(self.index_path), exist_ok=True)
faiss.write_index(self.index, self.index_path)
with open(self.mapping_path, "wb") as f:
pickle.dump(self.sku_mapping, f)
logger.info(f"Index saved to {self.index_path}")
def load(self):
logger.info(f"Loading FAISS index from {self.index_path}...")
self.index = faiss.read_index(self.index_path)
with open(self.mapping_path, "rb") as f:
self.sku_mapping = pickle.load(f)
logger.info(f"Loaded {self.index.ntotal} items.")
if __name__ == "__main__":
# Quick test
mgr = FaissManager(dimension=4)
v = np.random.random((2, 4)).astype('float32')
mgr.add_vectors(v, ["SKU1", "SKU2"])
dist, res = mgr.search(v[0], top_k=1)
print(f"Search result: {res}")
import json
import numpy as np
import os
from loguru import logger
import ast
# Import our FaissManager
from db.faiss_mgr import FaissManager
from config import INDEX_PATH, MAPPING_PATH
def ingest_from_json(json_path):
if not os.path.exists(json_path):
logger.error(f"JSON file not found at {json_path}")
return
logger.info(f"Reading data from {json_path}...")
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# Process items
all_items = []
if isinstance(data, dict):
for key in data:
if isinstance(data[key], list):
all_items.extend(data[key])
elif isinstance(data, list):
all_items = data
logger.info(f"Found {len(all_items)} products to index.")
faiss_mgr = FaissManager(dimension=768, index_path=INDEX_PATH, mapping_path=MAPPING_PATH)
vectors = []
skus = []
for item in all_items:
sku = item.get('internal_ref_code')
vector_str = item.get('vector')
if not sku or not vector_str:
continue
try:
# Vector is stored as a string "[0.1, 0.2, ...]"
vector = np.array(ast.literal_eval(vector_str)).astype('float32')
# Normalize for similarity
norm = np.linalg.norm(vector)
if norm > 0:
vector = vector / norm
vectors.append(vector)
skus.append(sku)
except Exception as e:
logger.error(f"Error parsing vector for SKU {sku}: {e}")
if vectors:
vectors_np = np.stack(vectors)
faiss_mgr.add_vectors(vectors_np, skus)
faiss_mgr.save()
logger.info(f"Ingestion complete. Total items in FAISS: {faiss_mgr.index.ntotal}")
else:
logger.warning("No valid vectors found to index.")
if __name__ == "__main__":
target_json = r"d:\cnf\chatbot_canifa_image\data\magento_product_dimension_with_image_embedding_202512241158.json"
ingest_from_json(target_json)
import os
import time
import io
from fastapi import FastAPI, Request, File, UploadFile, Form
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from PIL import Image
from loguru import logger
from core.detector import ClothingDetector
from core.encoder import ImageEncoder
from core.matcher import MatchResult
from db.faiss_mgr import FaissManager
app = FastAPI(title="Image Identification System")
# Initialize components
templates = Jinja2Templates(directory="templates")
detector = ClothingDetector()
encoder = ImageEncoder()
faiss_mgr = FaissManager()
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
return templates.TemplateResponse("index.html", {"request": request})
@app.post("/match")
async def match_image(
file: UploadFile = File(...),
score_threshold: float = Form(0.5)
):
start_all = time.time()
try:
# 1. Read Image
contents = await file.read()
image = Image.open(io.BytesIO(contents)).convert("RGB")
# 2. Detector: Remove noise (crop products)
logger.info("Step 1: Detecting and cropping...")
crops = detector.detect_and_crop(image)
logger.info(f"Detected {len(crops)} items.")
final_response_results = []
# 3. Process each crop
for i, crop in enumerate(crops):
logger.info(f"Processing crop {i+1}...")
# Encoder: Get vector
embedding = encoder.encode_image(crop)
# FAISS: Search in DB
distances, raw_results = faiss_mgr.search(embedding, top_k=5)
# Matcher: Analyze confidence
matcher = MatchResult(raw_results)
matcher.analyze_confidence(score_threshold=score_threshold)
result_data = matcher.to_dict()
final_response_results.append(result_data)
total_time = time.time() - start_all
return JSONResponse({
"status": "success",
"process_time": f"{total_time:.3f}s",
"detections": final_response_results
})
except Exception as e:
logger.error(f"Processing failed: {e}")
return JSONResponse({"status": "error", "message": str(e)}, status_code=500)
if __name__ == "__main__":
import uvicorn
uvicorn.run("main:app", host="0.0.0.0", port=8010, reload=True)
import os
import argparse
from PIL import Image
from tqdm import tqdm
import numpy as np
from loguru import logger
from core.detector import ClothingDetector
from core.encoder import ImageEncoder
from db.faiss_mgr import FaissManager
def index_images(image_dir, index_path="models/vector_db.index"):
# 1. Init components
detector = ClothingDetector()
encoder = ImageEncoder()
faiss_mgr = FaissManager(dimension=768, index_path=index_path)
# 2. Get file list
extensions = ('.jpg', '.jpeg', '.png', '.webp')
image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(extensions)]
logger.info(f"Found {len(image_files)} images in {image_dir}")
# 3. Process and Index
for filename in tqdm(image_files):
try:
path = os.path.join(image_dir, filename)
image = Image.open(path).convert("RGB")
# Detect and crop (using the first detection as the main product for indexing)
crops = detector.detect_and_crop(image)
if not crops:
continue
# Embed the crop
embedding = encoder.encode_image(crops[0])
# SKU is usually the filename without extension
sku = os.path.splitext(filename)[0]
# Add to FAISS
faiss_mgr.add_vectors(np.array([embedding]), [sku])
except Exception as e:
logger.error(f"Error processing {filename}: {e}")
# 4. Save index
faiss_mgr.save()
logger.info("Indexing complete.")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Index images into FAISS database.")
parser.add_argument("--dir", type=str, required=True, help="Directory containing images to index.")
parser.add_argument("--out", type=str, default="models/vector_db.index", help="Output index path.")
args = parser.parse_args()
index_images(args.dir, args.out)
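# Example invocation (the script filename is assumed; adjust to the real name):
# python index_images.py --dir d:\cnf\product_images --out models/vector_db.index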
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Image Identification Lab</title>
<style>
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background-color: #f4f7f6;
margin: 0;
padding: 20px;
color: #333;
}
.container {
max-width: 800px;
margin: auto;
background: white;
padding: 30px;
border-radius: 12px;
box-shadow: 0 4px 20px rgba(0, 0, 0, 0.08);
}
h1 {
text-align: center;
color: #2c3e50;
margin-bottom: 30px;
}
.upload-section {
border: 2px dashed #3498db;
padding: 40px;
text-align: center;
border-radius: 8px;
cursor: pointer;
transition: background 0.3s;
margin-bottom: 20px;
}
.upload-section:hover {
background: #ebf5fb;
}
#preview {
max-width: 100%;
max-height: 400px;
display: none;
margin: 20px auto;
border-radius: 8px;
}
.btn {
display: block;
width: 100%;
padding: 12px;
background: #3498db;
color: white;
border: none;
border-radius: 6px;
font-size: 16px;
cursor: pointer;
font-weight: bold;
}
.btn:hover {
background: #2980b9;
}
#results {
margin-top: 30px;
border-top: 2px solid #eee;
padding-top: 20px;
}
.result-item {
background: #f9f9f9;
padding: 15px;
border-radius: 8px;
margin-bottom: 10px;
border-left: 5px solid #2ecc71;
}
.match-score {
float: right;
font-weight: bold;
color: #27ae60;
}
.loader {
border: 4px solid #f3f3f3;
border-top: 4px solid #3498db;
border-radius: 50%;
width: 30px;
height: 30px;
animation: spin 2s linear infinite;
display: none;
margin: 20px auto;
}
@keyframes spin {
0% {
transform: rotate(0deg);
}
100% {
transform: rotate(360deg);
}
}
</style>
</head>
<body>
<div class="container">
<h1>🔍 Image Identification Lab</h1>
<div class="upload-section" onclick="document.getElementById('fileInput').click()">
<p>Click to Upload or Drag & Drop Product Image</p>
<input type="file" id="fileInput" hidden accept="image/*" onchange="previewImage(event)">
</div>
<img id="preview" alt="Preview">
<button class="btn" onclick="uploadAndMatch()">IDENTIFY PRODUCT</button>
<div id="loader" class="loader"></div>
<div id="results">
<!-- Results will appear here -->
</div>
</div>
<script>
function previewImage(event) {
const preview = document.getElementById('preview');
const file = event.target.files[0];
const reader = new FileReader();
reader.onload = function () {
preview.src = reader.result;
preview.style.display = 'block';
}
if (file) reader.readAsDataURL(file);
}
async function uploadAndMatch() {
const fileInput = document.getElementById('fileInput');
if (fileInput.files.length === 0) {
alert("Please select an image first!");
return;
}
const resultsDiv = document.getElementById('results');
const loader = document.getElementById('loader');
resultsDiv.innerHTML = '';
loader.style.display = 'block';
const formData = new FormData();
formData.append('file', fileInput.files[0]);
try {
const response = await fetch('/match', {
method: 'POST',
body: formData
});
const data = await response.json();
loader.style.display = 'none';
if (data.status === 'success') {
resultsDiv.innerHTML = `<h3>Process Time: ${data.process_time}</h3>`;
data.detections.forEach((detection, detIndex) => {
const detHtml = `
<div style="margin-top:20px; padding:15px; background:#f0f3f5; border-radius:8px;">
<h4 style="margin-top:0;">Detection #${detIndex + 1}: ${detection.message}</h4>
<div style="color: #2980b9; font-weight:bold; margin-bottom:10px;">Confidence: ${detection.confidence.toUpperCase()}</div>
<div id="items-${detIndex}"></div>
</div>
`;
resultsDiv.insertAdjacentHTML('beforeend', detHtml);
const itemsDiv = document.getElementById(`items-${detIndex}`);
detection.items.forEach(res => {
// Convert distance to a mock similarity % for display
const sim = Math.max(0, (1 - res.distance) * 100).toFixed(1);
itemsDiv.innerHTML += `
<div class="result-item">
<span class="match-score">${sim}% Match</span>
<div><strong>SKU:</strong> ${res.sku}</div>
</div>
`;
});
});
} else {
resultsDiv.innerHTML = `<p style="color:red">Error: ${data.message}</p>`;
}
} catch (error) {
loader.style.display = 'none';
alert("Upload failed: " + error);
}
}
</script>
</body>
</html>
import numpy as np
from db.faiss_mgr import FaissManager
from config import INDEX_PATH, MAPPING_PATH
def test_search():
mgr = FaissManager(dimension=768, index_path=INDEX_PATH, mapping_path=MAPPING_PATH)
if mgr.index.ntotal == 0:
print("Index is empty!")
return
print(f"Index total: {mgr.index.ntotal}")
# Create a random query vector
query = np.random.random((768,)).astype('float32')
norm = np.linalg.norm(query)
query /= norm
dist, res = mgr.search(query, top_k=3)
print("Top 3 matches for a random vector:")
for r in res:
print(f"SKU: {r['sku']}, Distance: {r['distance']:.4f}")
if __name__ == "__main__":
test_search()
import json
import numpy as np
import os
from loguru import logger
import ast
# Import our FaissManager
from db.faiss_mgr import FaissManager
from config import INDEX_PATH, MAPPING_PATH
def ingest_from_json(json_path):
if not os.path.exists(json_path):
logger.error(f"JSON file not found at {json_path}")
return
logger.info(f"Reading data from {json_path}...")
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
# The user provided a structure where the key is the SQL query
# and the value is a list of results.
all_items = []
for key in data:
if isinstance(data[key], list):
all_items.extend(data[key])
logger.info(f"Found {len(all_items)} products to index.")
faiss_mgr = FaissManager(dimension=768, index_path=INDEX_PATH, mapping_path=MAPPING_PATH)
vectors = []
skus = []
for item in all_items:
sku = item.get('internal_ref_code')
vector_str = item.get('vector')
if not sku or not vector_str:
continue
try:
# Vector is stored as a string "[0.1, 0.2, ...]"
vector = np.array(ast.literal_eval(vector_str)).astype('float32')
# Normalize to unit length (standard for cosine similarity in FAISS IndexFlatL2)
norm = np.linalg.norm(vector)
if norm > 0:
vector = vector / norm
vectors.append(vector)
skus.append(sku)
except Exception as e:
logger.error(f"Error parsing vector for SKU {sku}: {e}")
if vectors:
vectors_np = np.stack(vectors)
faiss_mgr.add_vectors(vectors_np, skus)
faiss_mgr.save()
logger.info(f"Ingestion complete. Total items in FAISS: {faiss_mgr.index.ntotal}")
else:
logger.warning("No valid vectors found to index.")
if __name__ == "__main__":
# Path to the JSON file provided by the user
target_json = r"d:\cnf\chatbot_canifa_image\data\magento_product_dimension_with_image_embedding_202512241158.json"
ingest_from_json(target_json)
cdcdccdc @ f5deb28b
Subproject commit f5deb28b0f9f83c6c3ea53cdd9e68f5c1a5fcbe7
chatbot-rsa @ d6b45f42
Subproject commit d6b45f42c45f8f1c5957894201bff23f140da1a2