Commit 304590ec authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

Add model selection and lab updates

parent 0983572f
...@@ -48,7 +48,7 @@ class CANIFAGraph: ...@@ -48,7 +48,7 @@ class CANIFAGraph:
self.collection_tools = get_collection_tools() # Vẫn lấy list name để routing self.collection_tools = get_collection_tools() # Vẫn lấy list name để routing
self.retrieval_tools = self.all_tools self.retrieval_tools = self.all_tools
self.llm_with_tools = self.llm.bind_tools(self.all_tools, strict=True) self.llm_with_tools = self.llm.bind_tools(self.all_tools) # No strict: compat with both OpenAI & Gemini schemas
self.cache = InMemoryCache() self.cache = InMemoryCache()
# Chain caching: avoid rebuilding ChatPromptTemplate every turn # Chain caching: avoid rebuilding ChatPromptTemplate every turn
...@@ -185,13 +185,12 @@ class CANIFAGraph: ...@@ -185,13 +185,12 @@ class CANIFAGraph:
return self.build() return self.build()
# --- Singleton & Public API --- # --- Per-model Instance Cache & Public API ---
_instance: list[CANIFAGraph | None] = [None] _instances: dict[str, CANIFAGraph] = {}
def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None) -> Any: def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None) -> Any:
"""Get compiled graph (Singleton usage).""" """Get compiled graph (cached per model)."""
# Use singleton to avoid rebuilding graph on every request
manager = get_graph_manager(config, llm, tools) manager = get_graph_manager(config, llm, tools)
return manager.build() return manager.build()
...@@ -199,38 +198,33 @@ def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = N ...@@ -199,38 +198,33 @@ def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = N
def get_graph_manager( def get_graph_manager(
config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None
) -> CANIFAGraph: ) -> CANIFAGraph:
"""Get CANIFAGraph instance (Auto-rebuild if model config changes). """Get CANIFAGraph instance per model_name.
Prompt is now fetched dynamically per request from Langfuse, Each model gets its own cached graph instance, enabling true parallel
so no need to rebuild graph when prompt changes. execution when Lab sends requests to different models simultaneously.
""" """
# 1. New Instance if Empty effective_config = config or get_config()
if _instance[0] is None: model_key = effective_config.model_name
_instance[0] = CANIFAGraph(config, llm, tools)
logger.info(f"✨ Graph Created: {_instance[0].config.model_name} (prompts from Langfuse)")
return _instance[0]
# 2. Check for Model Config Changes only if model_key not in _instances:
is_model_changed = config and config.model_name != _instance[0].config.model_name _instances[model_key] = CANIFAGraph(effective_config, llm, tools)
logger.info(f"✨ Graph Created: {model_key} (prompts from Langfuse) | Total cached: {len(_instances)}")
if is_model_changed: return _instances[model_key]
logger.info(f"🔄 Rebuilding Graph: Model ({_instance[0].config.model_name}->{config.model_name})")
_instance[0] = CANIFAGraph(config, llm, tools)
return _instance[0]
return _instance[0]
def reset_graph() -> None: def reset_graph() -> None:
"""Reset singleton for testing.""" """Reset all cached instances (for testing)."""
_instance[0] = None _instances.clear()
def reset_chain_cache() -> None: def reset_chain_cache() -> None:
"""Reset only the cached chain (when prompt changes). """Reset only the cached chain for all instances (when prompt changes).
Keeps the graph/LLM/tools intact, only forces chain rebuild on next request. Keeps the graph/LLM/tools intact, only forces chain rebuild on next request.
""" """
if _instance[0] is not None: for model_key, inst in _instances.items():
_instance[0]._cached_chain = None inst._cached_chain = None
_instance[0]._cached_prompt_hash = None inst._cached_prompt_hash = None
logger.info("🔄 Chain cache cleared — will rebuild on next request") if _instances:
logger.info(f"🔄 Chain cache cleared for {len(_instances)} model(s) — will rebuild on next request")
...@@ -14,6 +14,7 @@ class QueryRequest(BaseModel): ...@@ -14,6 +14,7 @@ class QueryRequest(BaseModel):
user_query: str user_query: str
images: list[str] | None = None images: list[str] | None = None
image_analysis: dict[str, Any] | None = None image_analysis: dict[str, Any] | None = None
model_name: str | None = None # Override model per-request (Lab mode)
class AgentState(TypedDict): class AgentState(TypedDict):
......
...@@ -31,7 +31,7 @@ from agent.prompt_utils import read_tool_prompt ...@@ -31,7 +31,7 @@ from agent.prompt_utils import read_tool_prompt
class SearchItem(BaseModel): class SearchItem(BaseModel):
model_config = {"extra": "ignore"} # Gemini may send extra fields model_config = {"extra": "ignore", "json_schema_extra": {"additionalProperties": False}} # ignore for Gemini compat; additionalProperties for OpenAI strict mode
# ====== SEARCH TEXT (optional fallback) ====== # ====== SEARCH TEXT (optional fallback) ======
description: str | None = Field( description: str | None = Field(
...@@ -121,7 +121,7 @@ class SearchItem(BaseModel): ...@@ -121,7 +121,7 @@ class SearchItem(BaseModel):
class MultiSearchParams(BaseModel): class MultiSearchParams(BaseModel):
model_config = {"extra": "ignore"} # Gemini may send extra fields model_config = {"extra": "ignore", "json_schema_extra": {"additionalProperties": False}} # ignore for Gemini compat; additionalProperties for OpenAI strict mode
searches: list[SearchItem] = Field(description="Danh sách các truy vấn tìm kiếm") searches: list[SearchItem] = Field(description="Danh sách các truy vấn tìm kiếm")
......
...@@ -106,10 +106,12 @@ async def fashion_qa_chat_dev(request: Request, req: QueryRequest, background_ta ...@@ -106,10 +106,12 @@ async def fashion_qa_chat_dev(request: Request, req: QueryRequest, background_ta
try: try:
# DEV MODE: Return ai_response + products immediately # DEV MODE: Return ai_response + products immediately
# Lab mode: allow model override from request body
effective_model = req.model_name or DEFAULT_MODEL
result = await chat_controller( result = await chat_controller(
query=req.user_query, query=req.user_query,
background_tasks=background_tasks, background_tasks=background_tasks,
model_name=DEFAULT_MODEL, model_name=effective_model,
images=req.images, images=req.images,
identity_key=str(identity_id), identity_key=str(identity_id),
return_user_insight=False, return_user_insight=False,
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment