Commit 17f9182c authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

fix(stock): refactor stock api, strict schema, dedup variants, fix mapping

parent 37643d07
# Backend Singleton Lazy Loading Rule
All backend services, managers, and graph facilitators in the `backend/` directory MUST follow the Singleton pattern with Lazy Loading.
## Pattern Requirements
### 1. Class Structure
- Use a class attribute `_instance` initialized to `None`.
- Define a class method `get_instance(cls, ...)` to handle the lazy initialization.
- **Do not use threading locks** unless explicitly required for high-concurrency external resources (keep it simple by default).
### 2. Implementation Template
```python
class MyManager:
_instance = None
def __init__(self, *args, **kwargs):
if MyManager._instance is not None:
raise RuntimeError("Use get_instance() instead")
# Heavy initialization here
@classmethod
def get_instance(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = cls(*args, **kwargs)
return cls._instance
```
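A quick usage sketch of the template above — note that calling `MyManager()` directly raises, by design:
```python
manager = MyManager.get_instance()   # first call: lazy, heavy initialization runs here
same = MyManager.get_instance()      # later calls: return the cached instance
assert manager is same
# MyManager()  # would raise RuntimeError("Use get_instance() instead")
```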
## Activation
- **Always On** for files matching `backend/**/*.py`.
- Referenced by: `@python-singleton-lazy` skill.
## Reasons
- **Consistency**: Unified way to access core services.
- **Performance**: Lazy loading reduces startup time and memory footprint.
- **Simplicity**: No complex locking logic for standard backend components.
---
name: python-singleton-lazy
description: Implement Python classes using a simple Singleton pattern with Lazy Loading (No Locks).
author: Antigravity
version: 1.0.1
---
# Python Simple Singleton with Lazy Loading
Use this pattern for a lightweight, lazy-loaded Singleton. This version does **not** use Thread Locking, making it faster and simpler for single-threaded applications or when thread safety is handled elsewhere.
## Implementation
We use a simple class method `get_instance` to check if the instance exists. If not, it creates it.
```python
from typing import Optional, Any
class ServiceName:
"""
Singleton class for [Service Description] with simple lazy loading.
"""
_instance: Optional['ServiceName'] = None
def __init__(self):
"""
Private constructor. Do not call directly.
Use ServiceName.get_instance() instead.
"""
if ServiceName._instance is not None:
raise RuntimeError("Call get_instance() instead")
# --- Initialization Logic Here ---
print("Initializing ServiceName...")
# ---------------------------------
@classmethod
def get_instance(cls) -> 'ServiceName':
"""
Static access method.
Creates the instance only when first called.
"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
def some_method(self) -> Any:
"""
Example business logic method.
"""
return "Task Completed"
# Usage Example:
# service = ServiceName.get_instance() # Initializes here
# service2 = ServiceName.get_instance() # Returns existing instance
```
## Checklist for AI
When applying this skill:
1. [ ] Rename `ServiceName` to a meaningful class name.
2. [ ] Define `_instance` as a class attribute acting as the cache.
3. [ ] Implement `get_instance` with a simple `if is None` check.
4. [ ] Prevent direct instantiation in `__init__`.
---
description: Ralph Wiggum "Infinite" Loop - A continuous cycle for perfection.
---
# Ralph Wiggum "Infinite" Loop
**CONCEPT:** This is an **INFINITE LOOP**. You do not exit this loop until the task is perfectly verified or the user forcibly stops you.
## THE LOOP
### 🔄 Phase 1: The "Dumb" Questions (Preparation)
**ENTRY POINT.** Always start here.
1. **"What exactly am I trying to do?"** (Explain it to a 5-year-old).
2. **"Do I have all the files and info I need?"** (If not, STOP and read them first).
3. **"What is the stupidest mistake I could make here?"** (Example: wiping the database).
4. **"Is there a simpler way?"** (Don't over-engineer).
*Decision:* If you are confused -> Stay in Phase 1. If clear -> Go to Phase 2.
### 🔄 Phase 2: Micro-Planning
Plan for **ONE** step only.
1. Define the **Single Next Action** (e.g., "Create file X").
2. Define **Success Criteria** for this specific action.
*Action:* Go to Phase 3.
### 🔄 Phase 3: Execution
1. **EXECUTE** the single tool call.
2. **STOP.** Do not do anything else.
*Action:* Go to Phase 4.
### 🔄 Phase 4: Verification
1. **VERIFY** the immediate result (Run code/Read file).
2. **ASK:** "Did it work 100%?"
- **YES:** Loop back to **Phase 1** (To prepare for the *next* micro-step).
- **NO:** Loop back to **Phase 1** (To re-evaluate why it failed).
> **CRITICAL RULE:** NEVER BREAK THE LOOP. Even if you think you are done, loop back to Phase 1 one last time to ask: "Is there absolutely nothing left to do?" Only then can you stop.
# 🔬 VERIFICATION: LangGraph Streaming Behavior
## 🎯 PURPOSE
Determine whether LangGraph `astream()` streams **incrementally** (emitting partial content) or only emits events **after a node completes**.
---
## 📊 EXPECTED RESULTS
### **Scenario 1: Incremental Streaming (Ideal)** ✅
If LangGraph streams incrementally, the backend logs will show:
```
🌊 Starting LLM streaming...
📦 Event #1 at t=2.50s | Keys: ['messages']
📦 Event #2 at t=3.20s | Keys: ['ai_response']
📡 Event #2 (t=3.20s): ai_response with 150 chars
Preview: {"ai_response": "Anh chọn áo thun th...
📦 Event #3 at t=4.10s | Keys: ['ai_response']
📡 Event #3 (t=4.10s): ai_response with 380 chars
Preview: {"ai_response": "Anh chọn áo thun thể thao nam chuẩn luôn! Em tìm...
📦 Event #4 at t=5.50s | Keys: ['ai_response']
📡 Event #4 (t=5.50s): ai_response with 620 chars
Preview: {"ai_response": "...", "product_ids": ["SKU1", "SKU2"]...
🎯 Event #4 (t=5.50s): Regex matched product_ids!
✅ Extracted 3 SKUs: ['SKU1', 'SKU2', 'SKU3']
🚨 BREAKING at Event #4 (t=5.50s) - NOT waiting for user_insight!
```
**→ Content grows incrementally (150 → 380 → 620 chars)**
**→ Early break once product_ids appears (t=5.5s instead of t=12s)**
---
### **Scenario 2: Event-based (After Completion)** ❌
If LangGraph only emits after a node finishes, the logs will be:
```
🌊 Starting LLM streaming...
📦 Event #1 at t=2.30s | Keys: ['messages'] ← Tool execution
📦 Event #2 at t=11.80s | Keys: ['ai_response'] ← LLM node completed
📡 Event #2 (t=11.80s): ai_response with 1250 chars ← THE ENTIRE RESPONSE
Preview: {"ai_response": "Anh chọn áo thun thể thao nam chuẩn luôn!...", "product_ids": ["SKU1", "SKU2", "SKU3"], "user_insight": {...}}
🎯 Event #2 (t=11.80s): Regex matched product_ids!
✅ Extracted 3 SKUs: ['SKU1', 'SKU2', 'SKU3']
🚨 BREAKING at Event #2 (t=11.80s) - NOT waiting for user_insight!
```
**→ ONLY 1 event, containing the full content**
**→ Emitted only after the LLM fully finishes (t=11.8s)**
**→ Impossible to break any earlier!**
---
## 🔍 ANALYSIS
### **If Scenario 2 (Event-based):**
**Explanation:**
- The LLM **is streaming tokens internally** from t=2s → t=12s
- LangGraph **waits for the node to finish** before emitting the event
- The event therefore always contains the **full response**
- The regex matches immediately because everything is already present
**Conclusion:**
- ✅ The code is correct and streaming is enabled
- ❌ But we cannot break any earlier, because no event exists yet
- ⏱️ Latency cannot be reduced (~12s)
---
## 💡 SOLUTIONS
If the result is Scenario 2, true streaming requires one of the following:
### **Option A: Custom Streaming Callback**
```python
from langchain_core.callbacks.base import AsyncCallbackHandler

class StreamingCallback(AsyncCallbackHandler):
    def __init__(self):
        self.accumulated = ""

    async def on_llm_new_token(self, token: str, **kwargs):
        # Accumulate tokens and run the regex check on each one
        self.accumulated += token
        if '"product_ids"' in self.accumulated:
            # Trigger the early break somehow
            pass
```
### **Option B: SSE Endpoint**
Stream events directly to the client and let the client parse them itself.
### **Option C: Keep As-Is**
The code is already optimal within its constraints; accept the latency.
---
## 📝 NOTES
- **streaming=True** on the LLM → LangChain streams tokens internally
- **graph.astream()** → streams events, not tokens
- **Early break** is only meaningful if events are emitted incrementally
**Check the backend logs to determine which scenario applies!**
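A minimal probe of the kind that would produce the logs above — a sketch, assuming a compiled `graph`, an `initial_state`, and the JSON response shape used in this document (`ai_response`, `product_ids`):
```python
import re
import time

async def probe_stream(graph, initial_state, config) -> None:
    """Log every astream event with elapsed time to see whether content grows incrementally."""
    start = time.monotonic()
    buffer = ""
    event_no = 0
    async for event in graph.astream(initial_state, config=config):
        event_no += 1
        t = time.monotonic() - start
        print(f"📦 Event #{event_no} at t={t:.2f}s | Keys: {list(event.keys())}")
        msg = event.get("ai_response")
        if msg is not None:
            buffer = getattr(msg, "content", str(msg))
            print(f"📡 Event #{event_no} (t={t:.2f}s): ai_response with {len(buffer)} chars")
        if re.search(r'"product_ids"\s*:\s*\[', buffer):
            print(f"🚨 BREAKING at Event #{event_no} (t={t:.2f}s) - not waiting for user_insight!")
            break
```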
# 🌊 STREAMING BEHAVIOR EXPLAINED
## ❓ WHY IS IT STILL SLOW EVEN THOUGH STREAMING IS ENABLED?
### ✅ CURRENT STATE - STREAMING IS ON:
1. **LLM Factory** ([llm_factory.py](../common/llm_factory.py#L87)):
```python
llm = create_llm(model_name=..., streaming=True)  # ✅ ENABLED
```
2. **Controller** ([controller.py](controller.py#L167)):
```python
async for event in graph.astream(initial_state, config=exec_config):  # ✅ USES ASTREAM
```
3. **Regex Early Break** ([controller.py](controller.py#L186)):
```python
if product_match:
    # product_ids captured → BREAK immediately!
    break  # ✅ does NOT wait for user_insight
```
---
## 🔍 THE REAL PROBLEM:
### **LangGraph astream() ≠ Token Streaming**
`graph.astream()` streams **EVENTS** (node completions), NOT **TOKENS**:
```
graph.astream() produces events:
├─ Event 1: Tool node completes → {"messages": [...]}
├─ Event 2: LLM node completes → {"ai_response": AIMessage(...)} ← THE ENTIRE RESPONSE AT ONCE!
└─ Event 3: Agent node completes → {"messages": [...]}
```
**The LLM response IS streamed INSIDE the node**, but `graph.astream()` only emits the event AFTER the node finishes!
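As an aside, recent LangGraph releases also expose token-level streaming through `stream_mode="messages"`, which sidesteps this limitation; a sketch (verify availability against your installed langgraph version):
```python
# Sketch: token-level streaming via stream_mode="messages" (recent langgraph versions).
async for chunk, metadata in graph.astream(initial_state, config=exec_config, stream_mode="messages"):
    # `chunk` is an LLM message chunk emitted while tokens are generated,
    # not a node-completion event, so early regex matching becomes possible here.
    print(getattr(chunk, "content", ""), end="", flush=True)
```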
---
## ⏱️ ACTUAL TIMELINE:
```
t=0s:    Client sends the request
t=0-2s:  Tool execution (DB query)
t=2-12s: LLM streams tokens (INTERNAL) ← STREAMING, BUT NOT VISIBLE!
t=12s:   LLM node completes → graph.astream() emits the event
t=12s:   Regex matches product_ids → BREAK
t=12s:   Response returned to the client
```
**Latency**: ~12s
---
## 💡 WHY CAN'T WE BREAK EARLIER?
Because:
1. **The event is only emitted AFTER the LLM node finishes**
2. The LLM node produces the ENTIRE JSON response (ai_response + product_ids + user_insight) in one go
3. We cannot break mid-generation, because no event has been emitted yet
---
## 🎯 SOLUTIONS FOR TRUE STREAMING:
### **Option 1: Custom Streaming Callback** (Complex)
```python
import asyncio

from langchain_core.callbacks.base import AsyncCallbackHandler

class TokenStreamCallback(AsyncCallbackHandler):
    def __init__(self):
        self.queue: asyncio.Queue[str] = asyncio.Queue()

    async def on_llm_new_token(self, token: str, **kwargs):
        # Push each token onto a queue that the HTTP response generator drains
        await self.queue.put(token)
```
### **Option 2: SSE Endpoint** (The standard approach)
```python
@router.get("/api/agent/chat-stream")
async def chat_stream(request: Request):
async def event_generator():
async for event in graph.astream(...):
if "ai_response" in event:
yield f"data: {json.dumps(event)}\n\n"
return StreamingResponse(event_generator(), media_type="text/event-stream")
```
### **Option 3: WebSocket** (Real-time)
```python
@router.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket):
await websocket.accept()
async for event in graph.astream(...):
await websocket.send_json(event)
```
---
## ✅ THE CURRENT CODE IS ALREADY OPTIMAL (WITHIN ITS LIMITS)
**Streaming + Early Break is CORRECT:**
- ✅ Breaks as soon as product_ids is available
- ✅ user_insight is processed in the background
- ✅ Does not wait for the full response
**BUT:**
- ❌ It cannot break earlier, because the event has not been emitted yet
- ❌ The client still waits for the LLM to finish (~12s)
---
## 📊 LATENCY COMPARISON:
| Method | Latency | Complexity |
|--------|---------|------------|
| **Current (RESTful + Internal Stream)** | ~12s | ⭐ Simple |
| **SSE Streaming** | ~8s (stream chunks) | ⭐⭐ Medium |
| **WebSocket** | ~5s (real-time) | ⭐⭐⭐ Complex |
---
## 🚀 CONCLUSION:
**The code is ALREADY as optimized as the RESTful context allows!**
To reduce latency further:
1. **Switch to SSE/WebSocket** (requires client changes)
2. **Faster LLM model** (gpt-4o-mini instead of gpt-5-nano)
3. **Cache hit** (< 100ms)
**Current implementation: ~12s → Optimized: 8-10s (SSE) or 5-7s (WebSocket)**
......@@ -146,31 +146,48 @@ def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = N
def get_graph_manager(
config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None
) -> CANIFAGraph:
"""Get CANIFAGraph instance (Auto-rebuild if model config changes)."""
from .prompt import get_last_modified
"""Get CANIFAGraph instance (Auto-rebuild if model config changes OR prompt version changed)."""
import asyncio
from common.cache import get_prompt_version
# Get current prompt version from Redis (shared across all workers)
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# Inside async context - use nest_asyncio or just check later
# Fallback to file mtime check if we can't get Redis version synchronously
from .prompt import get_last_modified
current_prompt_version = get_last_modified()
else:
current_prompt_version = loop.run_until_complete(get_prompt_version())
except RuntimeError:
# No event loop, create one
current_prompt_version = asyncio.run(get_prompt_version())
except Exception as e:
logger.warning(f"Failed to get prompt version: {e}, using mtime fallback")
from .prompt import get_last_modified
current_prompt_version = get_last_modified()
# 1. New Instance if Empty
if _instance[0] is None:
_instance[0] = CANIFAGraph(config, llm, tools)
_instance[0].prompt_version = current_prompt_version
logger.info(f"✨ Graph Created: {_instance[0].config.model_name}, prompt_version={current_prompt_version}")
return _instance[0]
# 2. Check for Config Changes (Model Switch OR Prompt Version Change)
is_model_changed = config and config.model_name != _instance[0].config.model_name
is_prompt_changed = current_prompt_version != getattr(_instance[0], "prompt_version", 0)
if is_model_changed or is_prompt_changed:
change_reason = []
if is_model_changed: change_reason.append(f"Model ({_instance[0].config.model_name}->{config.model_name})")
if is_prompt_changed: change_reason.append("Prompt File Updated")
if is_prompt_changed: change_reason.append(f"Prompt Version ({getattr(_instance[0], 'prompt_version', 0)}->{current_prompt_version})")
logger.info(f"🔄 Rebuilding Graph due to: {', '.join(change_reason)}")
_instance[0] = CANIFAGraph(config, llm, tools)
_instance[0].prompt_version = current_prompt_version
return _instance[0]
return _instance[0]
......
"""
Custom streaming callback that captures tokens from the LLM in real time.
No need to wait for graph.astream() to emit an event!
"""
import asyncio
import logging
import re
from typing import Any
from langchain_core.callbacks.base import AsyncCallbackHandler
logger = logging.getLogger(__name__)
class ProductIDStreamingCallback(AsyncCallbackHandler):
"""
    Callback that captures LLM tokens in real time and checks for product_ids.
    Once product_ids appears → trigger the break immediately; don't wait for user_insight!
"""
def __init__(self):
self.accumulated_content = ""
self.product_ids_found = False
self.ai_response_text = ""
self.product_skus = []
self.should_stop = False
        self.product_found_event = asyncio.Event()  # ✅ Event instead of polling!
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""
        Called whenever the LLM generates a new token.
        Accumulate it and run the regex check right away!
"""
self.accumulated_content += token
        # Debug log every 100 chars
if len(self.accumulated_content) % 100 == 0:
logger.debug(f"📡 Streamed {len(self.accumulated_content)} chars...")
        # Check whether product_ids has appeared yet
if not self.product_ids_found:
product_match = re.search(r'"product_ids"\s*:\s*\[(.*?)\]', self.accumulated_content, re.DOTALL)
if product_match:
logger.warning(f"🎯 FOUND product_ids at {len(self.accumulated_content)} chars!")
self.product_ids_found = True
# Extract ai_response
ai_text_match = re.search(
r'"ai_response"\s*:\s*"(.*?)"(?=\s*,\s*"product_ids")', self.accumulated_content, re.DOTALL
)
if ai_text_match:
self.ai_response_text = ai_text_match.group(1)
self.ai_response_text = self.ai_response_text.replace('\\"', '"').replace("\\n", "\n")
# Extract SKUs
skus_text = product_match.group(1)
self.product_skus = re.findall(r'"([^"]+)"', skus_text)
logger.warning(f"✅ Extracted {len(self.product_skus)} SKUs: {self.product_skus}")
logger.info("✅ product_ids found → response can return early (stream continues)")
                # ✅ Set the event → wake the controller up IMMEDIATELY!
self.should_stop = True
self.product_found_event.set()
async def on_llm_end(self, response, **kwargs: Any) -> None:
"""Called when LLM finishes."""
if not self.product_ids_found:
logger.info("ℹ️ LLM turn ended without product_ids (may appear after tool calls)")
async def on_llm_error(self, error: Exception, **kwargs: Any) -> None:
"""Called when LLM errors."""
logger.error(f"❌ LLM Error: {error}")
def reset(self):
"""Reset callback state."""
self.accumulated_content = ""
self.product_ids_found = False
self.ai_response_text = ""
self.product_skus = []
self.should_stop = False
self.product_found_event.clear()
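A sketch of how a controller might wire this callback in — `llm`, `messages`, and the 30s timeout are assumptions, not part of this file:
```python
import asyncio

async def run_with_early_break(llm, messages) -> tuple[str, list[str]]:
    # Hypothetical wiring sketch: `llm` is any LangChain chat model, `messages` its input.
    cb = ProductIDStreamingCallback()
    task = asyncio.create_task(llm.ainvoke(messages, config={"callbacks": [cb]}))
    try:
        # Wake up as soon as product_ids shows up in the token stream;
        # the LLM call keeps streaming in the background (user_insight arrives later).
        await asyncio.wait_for(cb.product_found_event.wait(), timeout=30.0)
    except asyncio.TimeoutError:
        await task  # pattern never appeared: fall back to the full response
    return cb.ai_response_text, cb.product_skus
```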
Look up ANY information about the Canifa brand and its services.
CRITICAL RULES WHEN CALLING THIS TOOL:
- Once you decide to call the tool, NEVER generate ai_response first.
- Only produce the tool_call with the correct arguments; do NOT answer the user in that same message.
- Generate ai_response only after the tool returns its result.
Use this tool when the customer asks about:
1. BRAND & INTRODUCTION: founding history, core values, mission.
......
CANIFA product super-search tool - supports Parallel Multi-Search (runs multiple queries concurrently).
CRITICAL RULES WHEN CALLING THIS TOOL:
- Once you decide to call the tool, NEVER generate ai_response first.
- Only produce the tool_call with the correct arguments; do NOT answer the user in that same message.
- Generate ai_response only after the tool returns its result.
RULES FOR GENERATING SEARCH QUERIES:
- If the customer asks for one specific item -> generate 1 query.
- If the customer asks for an outfit, a coordinated look, or a vague need (beach trip, party) -> generate 2-3 queries to find the related items.
......
Look up the list of promotions (CTKM) currently running, by date.
CRITICAL RULES WHEN CALLING THIS TOOL:
- Once you decide to call the tool, NEVER generate ai_response first.
- Only produce the tool_call with the correct arguments; do NOT answer the user in that same message.
- Generate ai_response only after the tool returns its result.
Use this tool when the customer asks things like:
- "Any promotions today?"
- "What hot campaigns are running?"
......
import json
import logging
from typing import Any
import httpx
from langchain_core.tools import tool
from pydantic import BaseModel, Field
from agent.prompt_utils import read_tool_prompt
from config import INTERNAL_STOCK_API
logger = logging.getLogger(__name__)
DEFAULT_MAX_SKUS = 200
DEFAULT_CHUNK_SIZE = 50
class StockCheckInput(BaseModel):
skus: str = Field(
description="Danh sách mã SKU sản phẩm cần kiểm tra tồn kho, phân cách bằng dấu phẩy. Ví dụ: '6ST25W005-SE091-L,6ST25W005-SE091-M'"
description=(
"Danh sách mã sản phẩm cần kiểm tra tồn kho (có thể là mã base, mã màu, "
"hoặc SKU đầy đủ), phân cách bằng dấu phẩy. "
"Ví dụ: '6ST25W005,6ST25W005-SE091,6ST25W005-SE091-L'"
)
)
sizes: str | None = Field(
default=None,
description="Optional: lọc theo size (S,M,L,XL,140...)",
)
max_skus: int = Field(default=DEFAULT_MAX_SKUS, ge=1)
chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, ge=1)
timeout_sec: float = Field(default=10.0, gt=0)
def _split_csv(value: str | None) -> list[str]:
if not value:
return []
return [token.strip() for token in value.split(",") if token.strip()]
def _normalize_size(token: str) -> str:
normalized = token.strip().upper()
if normalized.endswith("CM"):
normalized = normalized[:-2]
return normalized
def _is_full_sku(token: str) -> bool:
return token.count("-") >= 2
async def _fetch_variants(codes: list[str]) -> list[dict[str, Any]]:
if not codes:
return []
placeholders = ",".join(["%s"] * len(codes))
sql = f"""
SELECT
internal_ref_code,
magento_ref_code,
product_color_code,
size_scale
FROM {TABLE_NAME}
WHERE internal_ref_code IN ({placeholders})
OR magento_ref_code IN ({placeholders})
OR product_color_code IN ({placeholders})
GROUP BY internal_ref_code, magento_ref_code, product_color_code, size_scale
"""
params = codes * 3
db = StarRocksConnection()
return await db.execute_query_async(sql, params=tuple(params))
@tool("check_is_stock", args_schema=StockCheckInput)
async def check_is_stock(
skus: str,
sizes: str | None = None,
max_skus: int = DEFAULT_MAX_SKUS,
chunk_size: int = DEFAULT_CHUNK_SIZE,
timeout_sec: float = 10.0,
) -> str:
"""
Kiểm tra tình trạng tồn kho của các mã sản phẩm (SKU) thực tế từ hệ thống Canifa.
Sử dụng tool này khi người dùng hỏi về tình trạng còn hàng, hết hàng của sản phẩm cụ thể.
Input nhận vào là chuỗi các SKU phân cách bởi dấu phẩy.
Kiểm tra tồn kho theo mã sản phẩm.
- Hỗ trợ mã base / mã màu / SKU đầy đủ.
- Nếu thiếu màu/size thì tự expand từ DB, kết hợp màu + size (kể cả size số).
- Gọi API tồn kho theo batch và trả về JSON tổng hợp.
"""
logger.info(f"🔍 [Stock Check] Checking stock for SKUs: {skus}")
url = "https://canifa.com/v1/middleware/stock_get_stock_list"
params = {"skus": skus}
if not skus:
return "Lỗi: thiếu mã sản phẩm để kiểm tra tồn kho."
api_url = f"{INTERNAL_STOCK_API}"
payload = {
"codes": skus,
"sizes": sizes,
"max_skus": max_skus,
"chunk_size": chunk_size,
"truncate": True,
"expand_only": False
}
    try:
        async with httpx.AsyncClient(timeout=timeout_sec) as client:
            resp = await client.post(api_url, json=payload)
            resp.raise_for_status()
            return json.dumps(resp.json(), ensure_ascii=False)
except httpx.RequestError as exc:
logger.error(f"Network error checking stock: {exc}")
return f"Lỗi kết nối khi kiểm tra tồn kho: {exc}"
except httpx.HTTPStatusError as exc:
logger.error(f"HTTP error {exc.response.status_code}: {exc}")
return f"Lỗi server khi kiểm tra tồn kho (Status {exc.response.status_code})"
except Exception as exc:
logger.error(f"Unexpected error in check_is_stock: {exc}")
return f"Lỗi không xác định khi kiểm tra tồn kho: {exc}"
# Load dynamic docstring from file
dynamic_prompt = read_tool_prompt("check_is_stock")
if dynamic_prompt:
check_is_stock.__doc__ = dynamic_prompt
check_is_stock.description = dynamic_prompt
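A quick invocation sketch (LangChain tools take a dict of their arguments via `ainvoke`):
```python
import asyncio

async def demo():
    raw = await check_is_stock.ainvoke(
        {"skus": "6ST25W005,6ST25W005-SE091-L", "sizes": "L,XL"}
    )
    print(raw)  # JSON string aggregated from /api/stock/check

asyncio.run(demo())
```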
Client Request
[Cache Check] → HIT? → Return immediately
↓ MISS
[Load History + User Insight]
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
🌊 STREAMING BEGINS
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
async for event in graph.astream():
├─ LLM calls tools (if needed)
├─ Tools return data
├─ LLM starts generating the JSON response (streaming tokens)
└─ Capture event["ai_response"] with streaming content:
Accumulate tokens: '{"ai_response": "text...", "product_ids": ["SK'
⚡ REGEX: match as soon as the pattern "product_ids": [...] is detected
├─ Regex match: "ai_response": "..." ✅
├─ Regex match: "product_ids": ["SKU1", "SKU2"] ✅
└─ user_insight: {...} ← STILL STREAMING, NOT AVAILABLE YET!
🚨 BREAK NOW! RETURN THE RESPONSE, DON'T WAIT FOR user_insight!
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
⚡ RESPONSE RETURNED IMMEDIATELY:
{
"ai_response": "...",
"product_ids": ["SKU1", "SKU2"]
}
↓ (Background tasks)
├─ 💾 Save user_insight to Redis
├─ 💾 Cache response
└─ 📝 Save history
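The "Background tasks" step maps onto FastAPI's `BackgroundTasks`; a sketch under assumed helper names (`save_insight`, `cache_response`, `save_history` are illustrative, not functions from this repo):
```python
from fastapi import BackgroundTasks

# Illustrative stand-ins for the real persistence helpers (assumed names).
async def save_insight(identity_id: str, insight) -> None: ...
async def cache_response(identity_id: str, result: dict) -> None: ...
async def save_history(identity_id: str, result: dict) -> None: ...

def schedule_post_response_work(tasks: BackgroundTasks, identity_id: str, result: dict) -> None:
    # These run after the HTTP response has been sent, so the client never waits on them.
    tasks.add_task(save_insight, identity_id, result.get("user_insight"))
    tasks.add_task(cache_response, identity_id, result)
    tasks.add_task(save_history, identity_id, result)
```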
......@@ -6,7 +6,6 @@ from common.embedding_service import create_embedding_async
logger = logging.getLogger(__name__)
def _get_price_clauses(params, sql_params: list) -> list[str]:
"""Lọc theo giá (Parameterized)."""
clauses = []
......@@ -29,7 +28,7 @@ def _get_metadata_clauses(params, sql_params: list) -> list[str]:
exact_fields = [
("gender_by_product", "gender_by_product"),
("age_by_product", "age_by_product"),
("form_neckline", "form_neckline"),
("form_neckline", "form_neckline"),
]
for param_name, col_name in exact_fields:
val = getattr(params, param_name, None)
......@@ -86,11 +85,12 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
magento_code = getattr(params, "magento_ref_code", None)
if magento_code:
logger.info(f"🎯 [CODE SEARCH] Direct search by code: {magento_code}")
sql = """
SELECT
internal_ref_code,
magento_ref_code,
product_color_code,
description_text_full,
sale_price,
original_price,
......@@ -107,7 +107,7 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
# CASE 2: HYDE SEARCH - Semantic Vector Search
# ============================================================
logger.info("🚀 [HYDE RETRIEVER] Starting semantic vector search...")
query_text = getattr(params, "description", None)
if query_text and query_vector is None:
emb_start = time.time()
......@@ -120,11 +120,11 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
# Vector params
v_str = "[" + ",".join(str(v) for v in query_vector) + "]"
# Collect Params
price_params: list = []
price_clauses = _get_price_clauses(params, price_params)
where_filter = ""
if price_clauses:
where_filter = " AND " + " AND ".join(price_clauses)
......@@ -137,6 +137,7 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
WITH top_matches AS (
SELECT /*+ SET_VAR(ann_params='{{"ef_search":128}}') */
internal_ref_code,
magento_ref_code,
product_color_code,
description_text_full,
sale_price,
......@@ -149,6 +150,8 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
)
SELECT
internal_ref_code,
MAX_BY(magento_ref_code, similarity_score) as magento_ref_code,
MAX_BY(product_color_code, similarity_score) as product_color_code,
MAX_BY(description_text_full, similarity_score) as description_text_full,
MAX_BY(sale_price, similarity_score) as sale_price,
MAX_BY(original_price, similarity_score) as original_price,
......@@ -160,7 +163,6 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
ORDER BY max_score DESC
LIMIT 70
"""
# Return sql and params (params only contains filter values now, not the vector)
return sql, price_params
"""
Stock Helpers - Shared stock fetching logic.
Used by both `data_retrieval_tool` and the `check_is_stock` tool.
Calls the internal API /api/stock/check to reuse the SKU expansion logic.
"""
import logging
from typing import Any
import httpx
from config import INTERNAL_STOCK_API
logger = logging.getLogger(__name__)
DEFAULT_TIMEOUT_SEC = 15.0
async def fetch_stock_for_skus(
skus: list[str],
timeout_sec: float = DEFAULT_TIMEOUT_SEC,
) -> dict[str, dict[str, Any]]:
"""
Fetch stock info for a list of SKUs via internal API.
Supports base codes, product_color_code, and full SKUs.
API will automatically expand short codes to full SKUs.
Args:
skus: List of SKU codes (any format: base, color, or full)
timeout_sec: HTTP timeout in seconds (default 15)
Returns:
Dict mapping SKU -> stock data from API response.
Example: {"6ST25W005-SE091-L": {"qty": 10, "is_in_stock": true, ...}}
"""
if not skus:
return {}
# Deduplicate while preserving order
seen = set()
unique_skus: list[str] = []
for sku in skus:
if sku and sku not in seen:
seen.add(sku)
unique_skus.append(sku)
if not unique_skus:
return {}
stock_map: dict[str, dict[str, Any]] = {}
try:
async with httpx.AsyncClient(timeout=timeout_sec) as client:
# Call internal API with POST
payload = {
"codes": ",".join(unique_skus),
"truncate": True,
"max_skus": 200,
}
resp = await client.post(INTERNAL_STOCK_API, json=payload)
resp.raise_for_status()
data = resp.json()
# Parse response from /api/stock/check
# Format: {"stock_responses": [{"code": 200, "result": [...]}]}
stock_responses = data.get("stock_responses", [])
for stock_resp in stock_responses:
results = stock_resp.get("result", [])
for item in results:
sku_key = item.get("sku")
if sku_key:
stock_map[sku_key] = {
"is_in_stock": item.get("is_in_stock", False),
"qty": item.get("qty", 0),
}
logger.info(f"📦 [STOCK] Fetched stock for {len(stock_map)} SKUs (input: {len(unique_skus)})")
return stock_map
except httpx.RequestError as exc:
logger.error(f"Network error fetching stock: {exc}")
return {}
except httpx.HTTPStatusError as exc:
logger.error(f"HTTP error {exc.response.status_code} fetching stock: {exc}")
return {}
except Exception as exc:
logger.error(f"Unexpected error fetching stock: {exc}")
return {}
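Usage sketch:
```python
import asyncio

async def demo():
    stock = await fetch_stock_for_skus(["6ST25W005-SE091-L", "6ST25W005-SE091-M"])
    for sku, info in stock.items():
        print(sku, "in stock" if info["is_in_stock"] else "out of stock", info["qty"])

asyncio.run(demo())
```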
......@@ -8,25 +8,26 @@ Note: Rate limit check đã được xử lý trong middleware (CanifaAuthMiddle
import logging
from fastapi import APIRouter, BackgroundTasks, Request
from fastapi.responses import JSONResponse
from agent.controller import chat_controller
from agent.models import QueryRequest
from common.cache import redis_cache
from common.message_limit import message_limit_service
from common.rate_limit import rate_limit_service
from config import DEFAULT_MODEL
logger = logging.getLogger(__name__)
router = APIRouter()
@router.post("/api/agent/chat", summary="Fashion Q&A Chat (Non-streaming)")
@rate_limit_service.limiter.limit("50/minute")
async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks: BackgroundTasks):
"""
    Non-streaming chat endpoint - returns the full JSON response in one shot.
    Note: rate limiting has already been checked in the middleware.
"""
    # 1. Get the user identity from the middleware (request.state)
......@@ -34,20 +35,21 @@ async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks:
user_id = getattr(request.state, "user_id", None)
device_id = getattr(request.state, "device_id", "unknown")
is_authenticated = getattr(request.state, "is_authenticated", False)
    # Unique identifier for this request (logging, history, rate limiting, Langfuse)
identity_id = user_id if is_authenticated else device_id
    # Rate limit was already checked in the middleware; read limit_info from request.state
limit_info = getattr(request.state, "limit_info", None)
print(f"\n🔥🔥🔥 REQUEST ARRIVED! User: {identity_id} | Query: {req.user_query} 🔥🔥🔥\n")
logger.info(f"📥 [Incoming Query - NonStream] User: {identity_id} | Query: {req.user_query}")
try:
        # Call the controller to handle the logic (non-streaming)
result = await chat_controller(
query=req.user_query,
user_id=str(identity_id), # Langfuse User ID
background_tasks=background_tasks,
model_name=DEFAULT_MODEL,
images=req.images,
......@@ -84,8 +86,8 @@ async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks:
content={
"status": "error",
"error_code": "SYSTEM_ERROR",
"message": "Oops 😥 Hiện Canifa-AI chưa thể xử lý yêu cầu của bạn ngay lúc này, vui lòng quay lại trong giây lát."
}
"message": "Oops 😥 Hiện Canifa-AI chưa thể xử lý yêu cầu của bạn ngay lúc này, vui lòng quay lại trong giây lát.",
},
)
......@@ -94,18 +96,19 @@ async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks:
async def fashion_qa_chat_dev(request: Request, req: QueryRequest, background_tasks: BackgroundTasks):
"""
    Dev-only chat endpoint - returns the full user_insight.
    Note: rate limiting has already been checked in the middleware.
"""
user_id = getattr(request.state, "user_id", None)
device_id = getattr(request.state, "device_id", "unknown")
is_authenticated = getattr(request.state, "is_authenticated", False)
identity_id = user_id if is_authenticated else device_id
limit_info = getattr(request.state, "limit_info", None)
logger.info(f"📥 [Incoming Query - Dev] User: {identity_id} | Query: {req.user_query}")
try:
# DEV MODE: Return ai_response + products immediately, user_insight via polling
result = await chat_controller(
query=req.user_query,
user_id=str(identity_id),
......@@ -113,6 +116,7 @@ async def fashion_qa_chat_dev(request: Request, req: QueryRequest, background_ta
model_name=DEFAULT_MODEL,
images=req.images,
identity_key=str(identity_id),
return_user_insight=False,
)
usage_info = await message_limit_service.increment(
......@@ -120,14 +124,11 @@ async def fashion_qa_chat_dev(request: Request, req: QueryRequest, background_ta
is_authenticated=is_authenticated,
)
    # user_insight as returned by the controller
user_insight = result.get("user_insight")
return {
"status": "success",
"ai_response": result["ai_response"],
"product_ids": result.get("product_ids", []),
"user_insight": user_insight,
"insight_status": "pending",
"limit_info": {
"limit": usage_info["limit"],
"used": usage_info["used"],
......@@ -141,7 +142,41 @@ async def fashion_qa_chat_dev(request: Request, req: QueryRequest, background_ta
content={
"status": "error",
"error_code": "SYSTEM_ERROR",
"message": "Oops 😥 Hiện Canifa-AI chưa thể xử lý yêu cầu của bạn ngay lúc này, vui lòng quay lại trong giây lát."
}
"message": "Oops 😥 Hiện Canifa-AI chưa thể xử lý yêu cầu của bạn ngay lúc này, vui lòng quay lại trong giây lát.",
},
)
@router.get("/api/agent/user-insight", summary="Get latest user_insight (Dev)")
@rate_limit_service.limiter.limit("120/minute")
async def get_user_insight(request: Request):
"""
Polling endpoint for dev UI to fetch latest user_insight from Redis.
"""
user_id = getattr(request.state, "user_id", None)
device_id = getattr(request.state, "device_id", "unknown")
is_authenticated = getattr(request.state, "is_authenticated", False)
identity_id = user_id if is_authenticated else device_id
try:
client = redis_cache.get_client()
if not client:
return {"status": "pending", "user_insight": None}
insight_key = f"identity_key_insight:{identity_id}"
user_insight = await client.get(insight_key)
if user_insight:
return {"status": "success", "user_insight": user_insight}
return {"status": "pending", "user_insight": None}
except Exception as e:
logger.error(f"Error in get_user_insight: {e}", exc_info=True)
return JSONResponse(
status_code=500,
content={
"status": "error",
"error_code": "SYSTEM_ERROR",
"message": "Không thể tải user_insight lúc này.",
},
)
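A client-side polling sketch for this endpoint (base URL and auth handling are assumptions):
```python
import asyncio
import httpx

async def poll_user_insight(base_url: str = "http://localhost:5000", attempts: int = 10) -> str | None:
    # Poll /api/agent/user-insight until the background task has written the insight to Redis.
    async with httpx.AsyncClient(timeout=5.0) as client:
        for _ in range(attempts):
            resp = await client.get(f"{base_url}/api/agent/user-insight")
            data = resp.json()
            if data.get("status") == "success":
                return data["user_insight"]
            await asyncio.sleep(1.0)  # status stays "pending" until Redis is populated
    return None
```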
......@@ -3,6 +3,7 @@ from pydantic import BaseModel
import os
import re
from agent.graph import reset_graph
from common.cache import bump_prompt_version
router = APIRouter()
......@@ -68,12 +69,16 @@ async def update_system_prompt_content(request: Request, body: PromptUpdateReque
with open(PROMPT_FILE_PATH, "w", encoding="utf-8") as f:
f.write(body.content)
# 2. Bump prompt version in Redis (ALL workers will detect this)
new_version = await bump_prompt_version()
# 3. Reset local worker's Graph Singleton (immediate effect for this worker)
reset_graph()
response = {
"status": "success",
"message": "System prompt updated successfully. Graph reloaded."
"message": f"System prompt updated. Version: {new_version}. All workers will reload on next request.",
"prompt_version": new_version
}
if warning:
response["warning"] = warning
......@@ -82,3 +87,4 @@ async def update_system_prompt_content(request: Request, body: PromptUpdateReque
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
import logging
from typing import Any
import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from common.starrocks_connection import StarRocksConnection
logger = logging.getLogger(__name__)
router = APIRouter()
STOCK_API_URL = "https://canifa.com/v1/middleware/stock_get_stock_list"
DEFAULT_MAX_SKUS = 200
DEFAULT_CHUNK_SIZE = 50
TABLE_NAME = "shared_source.magento_product_dimension_with_text_embedding"
class StockExpandRequest(BaseModel):
codes: str = Field(
description=(
"Comma-separated product codes. Supports base codes, product_color_code, "
"or full SKU (code-color-size). Example: '6ST25W005,6ST25W005-SE091-L'"
)
)
sizes: str | None = Field(
default=None,
description="Optional comma-separated sizes to filter (e.g. 'S,M,L,XL,140').",
)
max_skus: int = Field(default=DEFAULT_MAX_SKUS, ge=1)
chunk_size: int = Field(default=DEFAULT_CHUNK_SIZE, ge=1)
expand_only: bool = Field(default=False)
truncate: bool = Field(default=True)
timeout_sec: float = Field(default=10.0, gt=0)
def _split_csv(value: str | None) -> list[str]:
if not value:
return []
return [token.strip() for token in value.split(",") if token.strip()]
def _normalize_size(token: str) -> str:
normalized = token.strip().upper()
if normalized.endswith("CM"):
normalized = normalized[:-2]
return normalized
def _is_full_sku(token: str) -> bool:
return token.count("-") >= 2
def _chunked(items: list[str], size: int) -> list[list[str]]:
return [items[i : i + size] for i in range(0, len(items), size)]
async def _fetch_variants(codes: list[str]) -> list[dict[str, Any]]:
if not codes:
return []
placeholders = ",".join(["%s"] * len(codes))
sql = f"""
SELECT
internal_ref_code,
magento_ref_code,
product_color_code,
size_scale
FROM {TABLE_NAME}
WHERE internal_ref_code IN ({placeholders})
OR magento_ref_code IN ({placeholders})
OR product_color_code IN ({placeholders})
GROUP BY internal_ref_code, magento_ref_code, product_color_code, size_scale
"""
params = codes * 3
db = StarRocksConnection()
return await db.execute_query_async(sql, params=tuple(params))
@router.post("/api/stock/check", summary="Expand product codes and check stock")
async def check_stock(req: StockExpandRequest):
"""
Expand base codes to full SKUs using StarRocks, then call Canifa stock API.
"""
input_codes = _split_csv(req.codes)
if not input_codes:
raise HTTPException(status_code=400, detail="codes is required")
size_filter = {_normalize_size(s) for s in _split_csv(req.sizes)} if req.sizes else None
full_skus: list[str] = []
lookup_codes: list[str] = []
for token in input_codes:
if _is_full_sku(token):
full_skus.append(token)
else:
lookup_codes.append(token)
variant_rows = await _fetch_variants(lookup_codes)
expanded_skus: list[str] = []
missing_size_color_codes: list[str] = []
for row in variant_rows:
product_color_code = row.get("product_color_code")
size_scale = row.get("size_scale")
if not product_color_code:
continue
if not size_scale:
missing_size_color_codes.append(product_color_code)
continue
for raw_token in str(size_scale).split("|"):
token = raw_token.strip()
if not token:
continue
normalized = _normalize_size(token)
if size_filter and normalized not in size_filter:
continue
expanded_skus.append(f"{product_color_code}-{normalized}")
# Deduplicate while preserving order
seen = set()
ordered_skus: list[str] = []
for sku in full_skus + expanded_skus:
if sku not in seen:
seen.add(sku)
ordered_skus.append(sku)
truncated = False
if len(ordered_skus) > req.max_skus:
if req.truncate:
ordered_skus = ordered_skus[: req.max_skus]
truncated = True
else:
raise HTTPException(
status_code=400,
detail=f"Expanded SKU count {len(ordered_skus)} exceeds max_skus {req.max_skus}",
)
response_payload = {
"status": "success",
"input_codes": input_codes,
"lookup_codes": lookup_codes,
"input_full_skus": full_skus,
"expanded_skus_count": len(expanded_skus),
"requested_skus_count": len(ordered_skus),
"requested_skus": ordered_skus,
"missing_size_color_codes": missing_size_color_codes,
"truncated": truncated,
}
if req.expand_only:
return response_payload
if not ordered_skus:
response_payload["stock_responses"] = []
return response_payload
try:
stock_responses: list[dict[str, Any]] = []
async with httpx.AsyncClient(timeout=req.timeout_sec) as client:
for chunk in _chunked(ordered_skus, req.chunk_size):
resp = await client.get(STOCK_API_URL, params={"skus": ",".join(chunk)})
resp.raise_for_status()
stock_responses.append(resp.json())
response_payload["stock_responses"] = stock_responses
return response_payload
except httpx.RequestError as exc:
logger.error(f"Network error checking stock: {exc}")
raise HTTPException(status_code=502, detail=f"Network error: {exc}") from exc
except httpx.HTTPStatusError as exc:
logger.error(f"HTTP error checking stock: {exc}")
raise HTTPException(status_code=502, detail=f"Stock API error: {exc}") from exc
except Exception as exc:
logger.error(f"Unexpected error checking stock: {exc}")
raise HTTPException(status_code=500, detail=f"Unexpected error: {exc}") from exc
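An example call against this endpoint — payload fields follow `StockExpandRequest`; the host is an assumption:
```python
import httpx

payload = {
    "codes": "6ST25W005,6ST25W005-SE091-L",  # base code + full SKU mixed
    "sizes": "S,M,L",
    "expand_only": False,  # set True to inspect the SKU expansion without hitting the stock API
}
resp = httpx.post("http://localhost:5000/api/stock/check", json=payload, timeout=15.0)
resp.raise_for_status()
print(resp.json()["requested_skus"])
```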
......@@ -159,3 +159,30 @@ redis_cache = RedisClient()
def get_redis_cache() -> RedisClient:
return redis_cache
# --- Prompt Version Sync (Multi-Worker) ---
PROMPT_VERSION_KEY = "system:prompt_version"
async def get_prompt_version() -> int:
"""Get current prompt version from Redis (shared across all workers)."""
try:
client = redis_cache.get_client()
if client:
version = await client.get(PROMPT_VERSION_KEY)
return int(version) if version else 0
except Exception as e:
logger.warning(f"Failed to get prompt version: {e}")
return 0
async def bump_prompt_version() -> int:
"""Increment prompt version in Redis (call when prompt is updated)."""
try:
client = redis_cache.get_client()
if client:
new_version = await client.incr(PROMPT_VERSION_KEY)
logger.info(f"🔄 Prompt version bumped to: {new_version}")
return new_version
except Exception as e:
logger.warning(f"Failed to bump prompt version: {e}")
return 0
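The intended worker-side pattern, as a small sketch (this mirrors what `get_graph_manager` does with its cached `prompt_version`):
```python
async def ensure_fresh_prompt(cached_version: int) -> tuple[bool, int]:
    # A higher version in Redis means some worker called bump_prompt_version() after an update.
    current = await get_prompt_version()
    return current != cached_version, current
```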
......@@ -10,10 +10,10 @@ import httpx
logger = logging.getLogger(__name__)
# CANIFA_CUSTOMER_API = "https://vsf2.canifa.com/v1/magento/customer"
CANIFA_CUSTOMER_API = "https://vsf2.canifa.com/v1/magento/customer"
CANIFA_CUSTOMER_API = "https://canifa.com/v1/magento/customer"
# CANIFA_CUSTOMER_API = "https://canifa.com/v1/magento/customer"
_http_client: httpx.AsyncClient | None = None
......
......@@ -67,7 +67,7 @@ class LLMFactory:
"""Create and cache a new OpenAI LLM instance."""
try:
llm = self._create_openai(model_name, streaming, json_mode, api_key)
cache_key = (model_name, streaming, json_mode, api_key)
self._cache[cache_key] = llm
return llm
......@@ -85,19 +85,18 @@ class LLMFactory:
llm_kwargs = {
"model": model_name,
"streaming": streaming,
"streaming": streaming, # ← STREAMING CONFIG
"api_key": key,
"temperature": 0,
"max_tokens": 1500,
}
        # If json_mode is enabled, inject it directly into the constructor
if json_mode:
llm_kwargs["model_kwargs"] = {"response_format": {"type": "json_object"}}
logger.info(f"⚙️ Initializing OpenAI in JSON mode: {model_name}")
llm = ChatOpenAI(**llm_kwargs)
logger.info(f"✅ Created OpenAI: {model_name}")
logger.info(f"✅ Created OpenAI: {model_name} | Streaming: {streaming}")
return llm
def _enable_json_mode(self, llm: BaseChatModel, model_name: str) -> BaseChatModel:
......
......@@ -25,6 +25,7 @@ __all__ = [
"FIRECRAWL_API_KEY",
"GOOGLE_API_KEY",
"GROQ_API_KEY",
"INTERNAL_STOCK_API",
"JWT_ALGORITHM",
"JWT_SECRET",
"LANGFUSE_BASE_URL",
......@@ -43,6 +44,8 @@ __all__ = [
"OTEL_SERVICE_NAME",
"OTEL_TRACES_EXPORTER",
"PORT",
"RATE_LIMIT_GUEST",
"RATE_LIMIT_USER",
"REDIS_HOST",
"REDIS_PASSWORD",
"REDIS_PORT",
......@@ -52,9 +55,8 @@ __all__ = [
"STARROCKS_PASSWORD",
"STARROCKS_PORT",
"STARROCKS_USER",
"STOCK_API_URL",
"USE_MONGO_CONVERSATION",
"RATE_LIMIT_GUEST",
"RATE_LIMIT_USER",
]
# ====================== SUPABASE CONFIGURATION ======================
......@@ -140,4 +142,8 @@ OTEL_EXPORTER_JAEGER_AGENT_SPLIT_OVERSIZED_BATCHES = os.getenv("OTEL_EXPORTER_JA
RATE_LIMIT_GUEST: int = int(os.getenv("RATE_LIMIT_GUEST", "10"))
RATE_LIMIT_USER: int = int(os.getenv("RATE_LIMIT_USER", "100"))
# ====================== STOCK API CONFIGURATION ======================
# External Canifa stock API (use directly if needed)
STOCK_API_URL: str = os.getenv("STOCK_API_URL", "https://canifa.com/v1/middleware/stock_get_stock_list")
# Internal stock API (includes the SKU-expansion logic from base codes)
INTERNAL_STOCK_API: str = os.getenv("INTERNAL_STOCK_API", "http://localhost:5000/api/stock/check")
......@@ -3,7 +3,6 @@
uvicorn server:app --host 0.0.0.0 --port 5000
docker restart chatbot-backend
......@@ -15,3 +14,7 @@ docker logs -f chatbot-backend
docker restart canifa_backend
sudo docker compose -f docker-compose.prod.yml up -d --build
Get-NetTCPConnection -LocalPort 5000 | ForEach-Object { Stop-Process -Id $_.OwningProcess -Force }
taskkill /F /IM python.exe
......@@ -13,6 +13,7 @@ from api.conservation_route import router as conservation_router
from api.tool_prompt_route import router as tool_prompt_router
from api.prompt_route import router as prompt_router
from api.mock_api_route import router as mock_router
from api.stock_route import router as stock_router
from common.cache import redis_cache
from common.langfuse_client import get_langfuse_client
......@@ -65,6 +66,7 @@ app.include_router(chatbot_router)
app.include_router(prompt_router)
app.include_router(tool_prompt_router) # Register new router
app.include_router(mock_router)
app.include_router(stock_router)
try:
......
# TEST STREAMING + BACKGROUND USER_INSIGHT
Write-Host "`n==== STREAMING TEST ====`n" -ForegroundColor Cyan
$query = "Ao khoac nam mua dong"
$deviceId = "test_stream_verify"
Write-Host "Sending request..." -ForegroundColor Green
$timing = Measure-Command {
$body = '{"user_query":"' + $query + '","device_id":"' + $deviceId + '"}'
$result = $body | curl.exe -s -X POST "http://localhost:5000/api/agent/chat" -H "Content-Type: application/json" --data-binary "@-"
$result | Out-Null
}
Write-Host "`nResponse Time: $($timing.TotalMilliseconds) ms" -ForegroundColor Green
Write-Host "`nCheck backend logs for:" -ForegroundColor Yellow
Write-Host " - Starting LLM streaming" -ForegroundColor Gray
Write-Host " - Regex matched product_ids" -ForegroundColor Gray
Write-Host " - BREAKING STREAM NOW" -ForegroundColor Gray
Write-Host " - Background task extraction" -ForegroundColor Gray
Write-Host "`nDone!" -ForegroundColor Green
import asyncio
import csv
import os
import sys
from typing import Any
# Ensure we can import from backend root
current_dir = os.path.dirname(os.path.abspath(__file__))
backend_root = os.path.dirname(current_dir)
sys.path.append(backend_root)
from common.starrocks_connection import StarRocksConnection
from config import STARROCKS_DB, STARROCKS_HOST, STARROCKS_PASSWORD, STARROCKS_USER
TABLE_NAME = "shared_source.magento_product_dimension_with_text_embedding"
DEFAULT_SIZES = [90, 92, 98, 104, 110, 116, 122, 128, 134, 140, 152, 164]
OUTPUT_CSV = os.path.join(current_dir, "numeric_size_skus.csv")
def _get_missing_env() -> list[str]:
missing = []
if not STARROCKS_HOST:
missing.append("STARROCKS_HOST")
if not STARROCKS_DB:
missing.append("STARROCKS_DB")
if not STARROCKS_USER:
missing.append("STARROCKS_USER")
if not STARROCKS_PASSWORD:
missing.append("STARROCKS_PASSWORD")
return missing
def _skip_or_warn_if_missing_env() -> bool:
missing = _get_missing_env()
if not missing:
return False
message = f"Missing StarRocks env vars: {', '.join(missing)}"
if "PYTEST_CURRENT_TEST" in os.environ:
import pytest
pytest.skip(message)
print(f"[SKIP] {message}")
return True
def _parse_sizes_env() -> list[int]:
raw = os.getenv("NUMERIC_SIZES")
if not raw:
return DEFAULT_SIZES
sizes: list[int] = []
for token in raw.split(","):
token = token.strip()
if not token:
continue
try:
sizes.append(int(token))
except ValueError:
continue
return sizes or DEFAULT_SIZES
def _build_regex_pattern(sizes: list[int]) -> str:
sizes_str = "|".join(str(s) for s in sorted(set(sizes)))
    # Match tokens like 140 or 140cm inside pipe-delimited lists; escape the pipe once
    # so the regex matches the literal '|' delimiter (double-escaping introduces an
    # empty alternative and defeats the boundary check).
    return rf"(^|\|)({sizes_str})(cm)?(\||$)"
async def fetch_numeric_size_rows(sizes: list[int]) -> list[dict[str, Any]]:
db = StarRocksConnection()
pattern = _build_regex_pattern(sizes)
sql = f"""
SELECT
internal_ref_code,
magento_ref_code,
size_scale
FROM {TABLE_NAME}
WHERE LOWER(size_scale) REGEXP %s
GROUP BY internal_ref_code, magento_ref_code, size_scale
ORDER BY internal_ref_code, magento_ref_code
"""
return await db.execute_query_async(sql, params=(pattern,))
def _write_csv(rows: list[dict[str, Any]], path: str) -> None:
os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w", newline="", encoding="utf-8") as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["internal_ref_code", "magento_ref_code", "size_scale"])
for row in rows:
writer.writerow(
[row.get("internal_ref_code"), row.get("magento_ref_code"), row.get("size_scale")]
)
def _print_summary(rows: list[dict[str, Any]], sizes: list[int]) -> None:
internal_codes = {row.get("internal_ref_code") for row in rows}
magento_codes = {row.get("magento_ref_code") for row in rows}
print("\n" + "=" * 80)
print("NUMERIC SIZE SKUS")
print("=" * 80)
print(f"Table: {TABLE_NAME}")
print(f"Sizes filter: {', '.join(str(s) for s in sorted(set(sizes)))}")
print(f"Matched rows: {len(rows)}")
print(f"Distinct internal_ref_code: {len(internal_codes)}")
print(f"Distinct magento_ref_code: {len(magento_codes)}")
def _print_sample(rows: list[dict[str, Any]], limit: int = 30) -> None:
print("\nSample (first 30 rows):")
for row in rows[:limit]:
print(
f"- {row.get('internal_ref_code')} | {row.get('magento_ref_code')} | {row.get('size_scale')}"
)
async def _run() -> None:
if _skip_or_warn_if_missing_env():
return
sizes = _parse_sizes_env()
rows = await fetch_numeric_size_rows(sizes)
_print_summary(rows, sizes)
_print_sample(rows, limit=30)
_write_csv(rows, OUTPUT_CSV)
print(f"\nCSV written to: {OUTPUT_CSV}")
await StarRocksConnection.clear_pool()
if __name__ == "__main__":
asyncio.run(_run())
import asyncio
import os
import re
import sys
from collections import Counter
from typing import Any
# Ensure we can import from backend root
current_dir = os.path.dirname(os.path.abspath(__file__))
backend_root = os.path.dirname(current_dir)
sys.path.append(backend_root)
from common.starrocks_connection import StarRocksConnection
from config import STARROCKS_DB, STARROCKS_HOST, STARROCKS_PASSWORD, STARROCKS_USER
TABLE_NAME = "shared_source.magento_product_dimension_with_text_embedding"
def _get_missing_env() -> list[str]:
missing = []
if not STARROCKS_HOST:
missing.append("STARROCKS_HOST")
if not STARROCKS_DB:
missing.append("STARROCKS_DB")
if not STARROCKS_USER:
missing.append("STARROCKS_USER")
if not STARROCKS_PASSWORD:
missing.append("STARROCKS_PASSWORD")
return missing
def _skip_or_warn_if_missing_env() -> bool:
missing = _get_missing_env()
if not missing:
return False
message = f"Missing StarRocks env vars: {', '.join(missing)}"
if "PYTEST_CURRENT_TEST" in os.environ:
import pytest
pytest.skip(message)
print(f"[SKIP] {message}")
return True
def _split_size_scale(size_scale: str | None) -> list[str]:
if not size_scale:
return []
return [token.strip() for token in size_scale.split("|") if token.strip()]
def _normalize_numeric_token(token: str) -> str | None:
if not token:
return None
token = token.strip().lower()
token = re.sub(r"cm$", "", token)
if re.fullmatch(r"\d+(\.\d+)?", token):
return token
return None
async def fetch_size_scale_rows() -> list[dict[str, Any]]:
db = StarRocksConnection()
sql = f"""
SELECT
size_scale,
COUNT(*) AS row_count
FROM {TABLE_NAME}
GROUP BY size_scale
"""
return await db.execute_query_async(sql)
def _build_numeric_summary(rows: list[dict[str, Any]]) -> Counter[str]:
counter: Counter[str] = Counter()
for row in rows:
size_scale = row.get("size_scale")
row_count = int(row.get("row_count") or 0)
for token in _split_size_scale(size_scale):
numeric_token = _normalize_numeric_token(token)
if numeric_token:
counter[numeric_token] += row_count
return counter
def _print_summary(counter: Counter[str]) -> None:
def _sort_key(val: str) -> float:
try:
return float(val)
except ValueError:
return float("inf")
tokens_sorted = sorted(counter.keys(), key=_sort_key)
print("\n" + "=" * 80)
print("NUMERIC SIZE TOKENS")
print("=" * 80)
print(f"Total unique numeric sizes: {len(tokens_sorted)}")
print("\nAll numeric sizes (sorted):")
print(", ".join(tokens_sorted))
print("\nCounts (descending):")
for token, count in counter.most_common():
print(f"- {token}: {count}")
async def _run() -> None:
if _skip_or_warn_if_missing_env():
return
rows = await fetch_size_scale_rows()
numeric_counter = _build_numeric_summary(rows)
_print_summary(numeric_counter)
await StarRocksConnection.clear_pool()
if __name__ == "__main__":
asyncio.run(_run())
......@@ -2,37 +2,54 @@ import requests
import json
import time
url = "http://localhost:5000/api/agent/chat"
# Use the DEV endpoint as per user logs
url = "http://localhost:5000/api/agent/chat-dev"
# Token can be anything for dev if middleware allows, or use the valid one
token = "071w198x23ict4hs1i6bl889fit5p3f7"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {token}"
}
queries = [
"tìm cho mình chân váy màu đỏ",
"tìm quần màu đỏ"
]
print(f"Target URL: {url}")
for query in queries:
payload = {
"user_query": query
}
print("\n" + "-"*50)
print(f"Testing Query: '{query}'")
start = time.time()
try:
response = requests.post(url, json=payload, headers=headers, timeout=120)
duration = time.time() - start
print(f"Status Code: {response.status_code}")
print(f"Time taken: {duration:.2f}s")
if response.status_code == 200:
data = response.json()
# print("Response JSON:", json.dumps(data, indent=2, ensure_ascii=False))
# Extract key info
ai_response = data.get("ai_response", "")
user_insight = data.get("user_insight", {})
product_ids = data.get("product_ids", [])
print(f"🤖 AI Response: {ai_response}")
print(f"📦 Product IDs: {product_ids}")
# print(f"🧠 User Insight: {json.dumps(user_insight, ensure_ascii=False)}")
else:
print("Error Response:")
print(response.text)
print("❌ Error Response:")
print(response.text)
except Exception as e:
print(f"❌ Exception: {e}")
import asyncio
import logging
import sys
import os
import json
import warnings
# Ensure we can import from backend root
current_dir = os.path.dirname(os.path.abspath(__file__))
backend_root = os.path.dirname(current_dir)
sys.path.append(backend_root)
# Setup logging
logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
from agent.tools.data_retrieval_tool import data_retrieval_tool, SearchItem
from common.starrocks_connection import StarRocksConnection
# Suppress ResourceWarning for unclosed sockets/loops
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=ResourceWarning)
async def test_search_cases():
"""
Run specific search cases to verify filtering logic.
"""
test_cases = [
{
"name": "Red Skirt (Chân váy đỏ)",
"query": "chân váy màu đỏ",
"search_item": SearchItem(
query="chân váy đỏ",
product_name="chân váy",
master_color="đỏ",
magento_ref_code=None, price_min=None, price_max=None,
gender_by_product=None, age_by_product=None, form_sleeve=None, style=None,
fitting=None, form_neckline=None, material_group=None, season=None, product_line_vn=None
),
"expect": {
"product_matches": ["skirt", "chân váy"],
"color_matches": ["red", "đỏ"]
}
},
{
"name": "Red Pants (Quần đỏ)",
"query": "quần đỏ",
"search_item": SearchItem(
query="quần đỏ",
product_name="quần",
master_color="đỏ",
magento_ref_code=None, price_min=None, price_max=None,
gender_by_product=None, age_by_product=None, form_sleeve=None, style=None,
fitting=None, form_neckline=None, material_group=None, season=None, product_line_vn=None
),
"expect": {
"product_matches": ["pants", "quần", "trousers"],
"color_matches": ["red", "đỏ"]
}
},
{
"name": "Wool Material (Vải len)",
"query": "đồ len",
"search_item": SearchItem(
query="đồ len",
material_group="len",
product_name=None, magento_ref_code=None, price_min=None, price_max=None,
gender_by_product=None, age_by_product=None, master_color=None, form_sleeve=None, style=None,
fitting=None, form_neckline=None, season=None, product_line_vn=None
),
"expect": {
"material_matches": ["wool", "len", "cashmere"]
}
}
]
print("\n" + "="*80)
print("🚀 STARTING DYNAMIC SEARCH VERIFICATION")
print("="*80 + "\n")
try:
for case in test_cases:
print(f"🔍 Testing Case: {case['name']}")
print(f" Query: {case['query']}")
try:
case['result_status'] = "FAIL"
# We call the tool. Format: {"searches": [item]}
result_json = await data_retrieval_tool.ainvoke({"searches": [case["search_item"]]})
result = json.loads(result_json)
if result["status"] != "success":
case['result_detail'] = f"Tool Error: {result.get('message')}"
print(f" ❌ FAILED: {case['result_detail']}")
continue
products = result["results"]
filter_info = result["filter_info"]
print(f" Found {len(products)} products.")
if not products:
case['result_status'] = "NO_RESULTS"
detail = "0 products found"
if filter_info.get("message"):
detail += f" [Msg: {filter_info.get('message')}]"
case['result_detail'] = detail
print(f" ⚠️ {detail}")
continue
# Verify first few products
match_count = 0
check_limit = min(5, len(products))
for i in range(check_limit):
p = products[i]
desc = p.get("description_text_full", "") or p.get("description", "")
desc = desc.lower()
is_valid = True
# Check Color
if "color_matches" in case["expect"]:
found_color = False
for c in case["expect"]["color_matches"]:
if c in desc:
found_color = True
break
if not found_color:
is_valid = False
# print(f" ❌ Product {i}: Color matches not found")
# Check Material
if "material_matches" in case["expect"]:
found_mat = False
for m in case["expect"]["material_matches"]:
if m in desc:
found_mat = True
break
if not found_mat:
is_valid = False
# print(f" ❌ Product {i}: Material matches not found")
if is_valid:
match_count += 1
if match_count == check_limit:
case['result_status'] = "PASS"
case['result_detail'] = f"Top {check_limit} products match criteria"
print(f" ✅ VERIFIED: {case['result_detail']}")
else:
case['result_status'] = "PARTIAL"
case['result_detail'] = f"{match_count}/{check_limit} matched criteria"
print(f" ⚠️ PARTIAL: {case['result_detail']}")
except Exception as e:
case['result_status'] = "ERROR"
case['result_detail'] = str(e)
print(f" ❌ EXCEPTION: {e}")
print("-" * 50)
finally:
print("\n" + "="*80)
print("📊 TEST SUMMARY")
print("="*80)
for case in test_cases:
status = case.get('result_status', 'UNKNOWN')
detail = case.get('result_detail', '')
print(f"🔹 {case['name']:<30} | {status:<10} | {detail}")
print("="*80)
# CLEANUP
# print("\n🧹 Cleaning up connections...")
await StarRocksConnection.clear_pool()
if __name__ == "__main__":
asyncio.run(test_search_cases())