Commit 3baf7736 authored by Hoanganhvu123

feat: integrate Claude API via zunef gateway with JSON parsing

Cherry-pick commit 88341cc4 with conflict resolution in graph.py.
Replace with_structured_output with manual JSON parsing for Claude API.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
parent 7659d548
...@@ -269,13 +269,15 @@ class LeadStageGraph: ...@@ -269,13 +269,15 @@ class LeadStageGraph:
"add_to_cart": add_to_cart, "add_to_cart": add_to_cart,
} }
# AI 1: Classifier — structured output, NHẸ (streaming=False) # AI 1: Classifier — structured output bypass for Claude API (zunef gateway)
_classifier_base = create_llm(model_name=self.model_name, streaming=False) _classifier_base = create_llm(model_name=self.model_name, streaming=False)
self.classifier_llm = _classifier_base.with_structured_output(ClassifierOutput, method="function_calling") self.classifier_llm = _classifier_base
# Note: Claude models bypass structured output, parse JSON manually
# AI 2: Stylist — structured output, NẶNG (streaming=False vì structured) # AI 2: Stylist — structured output bypass for Claude API (zunef gateway)
_stylist_base = create_llm(model_name=self.model_name, streaming=False) _stylist_base = create_llm(model_name=self.model_name, streaming=False)
self.stylist_llm = _stylist_base.with_structured_output(StylistOutput, method="function_calling") self.stylist_llm = _stylist_base
# Note: Claude models bypass structured output, parse JSON manually
self._compiled = None self._compiled = None
logger.info(f"✅ LeadStageGraph initialized | model: {self.model_name}") logger.info(f"✅ LeadStageGraph initialized | model: {self.model_name}")
...@@ -386,7 +388,33 @@ class LeadStageGraph: ...@@ -386,7 +388,33 @@ class LeadStageGraph:
start = time.time() start = time.time()
try: try:
output: ClassifierOutput = await self.classifier_llm.ainvoke(classifier_messages, config=config) response = await self.classifier_llm.ainvoke(
classifier_messages, config=config
)
# Parse JSON from text response (Claude models return text, not structured output)
import json as json_lib
text = response.content if hasattr(response, "content") else str(response)
logger.info(f"[CLASSIFIER] Raw response repr: {repr(text[:500])}")
# Extract JSON from text (may have extra text)
text = text.strip()
if text.startswith('```json'):
text = text[7:]
if text.startswith('```'):
text = text[3:]
if text.endswith('```'):
text = text[:-3]
text = text.strip()
try:
output = json_lib.loads(text)
except json_lib.JSONDecodeError as je:
logger.error(f"[CLASSIFIER] JSON parse error: {je}. Text was: {repr(text[:500])}")
raise
# Convert dict to object-like access
class DictToObj:
def __init__(self, data):
self.__dict__.update(data)
output = DictToObj(output)
except Exception as e: except Exception as e:
logger.error(f"❌ Classifier error: {e}") logger.error(f"❌ Classifier error: {e}")
return { return {
...@@ -403,7 +431,13 @@ class LeadStageGraph: ...@@ -403,7 +431,13 @@ class LeadStageGraph:
# ── Resolve tool_args: typed lead_search_args hoac raw dict ── # ── Resolve tool_args: typed lead_search_args hoac raw dict ──
if tool_name == "lead_search_tool" and output.lead_search_args: if tool_name == "lead_search_tool" and output.lead_search_args:
# Handle both dict and Pydantic model cases
if isinstance(output.lead_search_args, dict):
tool_args = output.lead_search_args
elif hasattr(output.lead_search_args, 'model_dump'):
tool_args = output.lead_search_args.model_dump(exclude_none=False) tool_args = output.lead_search_args.model_dump(exclude_none=False)
else:
tool_args = dict(output.lead_search_args) if output.lead_search_args else {}
elif tool_name == "lead_search_tool" and output.tool_args: elif tool_name == "lead_search_tool" and output.tool_args:
# fallback: model dung tool_args thay vi lead_search_args # fallback: model dung tool_args thay vi lead_search_args
tool_args = output.tool_args tool_args = output.tool_args
...@@ -557,13 +591,39 @@ class LeadStageGraph: ...@@ -557,13 +591,39 @@ class LeadStageGraph:
HumanMessage(content="=== KHÔNG CÓ TOOL RESULT === (Khách chào hỏi hoặc câu hỏi chung)") HumanMessage(content="=== KHÔNG CÓ TOOL RESULT === (Khách chào hỏi hoặc câu hỏi chung)")
) )
# ── LLM Call: Structured Output ── # ── LLM Call: Parse JSON from text (Claude returns text, not structured) ──
# ── LLM Call: Parse JSON from text (Claude returns text, not structured) ──
start = time.time() start = time.time()
try: try:
output: StylistOutput = await self.stylist_llm.ainvoke(stylist_messages, config=config) response = await self.stylist_llm.ainvoke(
stylist_messages, config=config
)
# Parse JSON from text response
text = response.content if hasattr(response, "content") else str(response)
logger.info(f"[STYLIST] Raw response text (first 500 chars): {repr(text[:500])}")
text = text.strip()
if text.startswith('```json'):
text = text[7:]
if text.startswith('```'):
text = text[3:]
if text.endswith('```'):
text = text[:-3]
text = text.strip()
output = json.loads(text)
# Convert dict to object-like access
class DictToObj:
def __init__(self, data):
self.__dict__.update(data)
def __getattr__(self, name):
return self.__dict__.get(name)
output = DictToObj(output)
# Handle nested user_insight dict
if hasattr(output, "user_insight") and isinstance(output.user_insight, dict):
output.user_insight = DictToObj(output.user_insight)
except Exception as e: except Exception as e:
logger.error(f"❌ Stylist structured output error: {e}") logger.error(f"❌ Stylist JSON parse error: {e}")
# Fallback: trả response text thuần
fallback = "Dạ bạn cho mình hỏi thêm để tư vấn chính xác hơn nhé!" fallback = "Dạ bạn cho mình hỏi thêm để tư vấn chính xác hơn nhé!"
return { return {
"messages": [AIMessage(content=fallback)], "messages": [AIMessage(content=fallback)],
...@@ -582,8 +642,18 @@ class LeadStageGraph: ...@@ -582,8 +642,18 @@ class LeadStageGraph:
from_text = _extract_skus_from_text(output.ai_response) from_text = _extract_skus_from_text(output.ai_response)
merged_product_ids = _dedupe_preserve_order((output.product_ids or []) + from_text) merged_product_ids = _dedupe_preserve_order((output.product_ids or []) + from_text)
# Extract insight # Extract insight - handle both dict and object cases
insight_dict = output.user_insight.model_dump() merged_product_ids = _dedupe_preserve_order((output.product_ids or []) + from_text)
# Extract insight - handle both dict and object cases
insight_raw = getattr(output, "user_insight", None)
if isinstance(insight_raw, dict):
insight_dict = insight_raw
elif hasattr(insight_raw, "__dict__"):
insight_dict = insight_raw.__dict__
else:
insight_dict = {}
lead_stage = { lead_stage = {
"stage": insight_dict.get("STAGE_NUM", 1), "stage": insight_dict.get("STAGE_NUM", 1),
"stage_name": insight_dict.get("STAGE", "BROWSE"), "stage_name": insight_dict.get("STAGE", "BROWSE"),
...@@ -592,21 +662,24 @@ class LeadStageGraph: ...@@ -592,21 +662,24 @@ class LeadStageGraph:
} }
stage_injection = format_insight_injection(insight_dict) stage_injection = format_insight_injection(insight_dict)
ai_response = getattr(output, "ai_response", "Dạ mình chưa hiểu, bạn nói lại nhé!")
product_ids = getattr(output, "product_ids", []) or []
logger.info( logger.info(
f"💬 Stylist: stage={insight_dict.get('STAGE')} | goal={insight_dict.get('GOAL')} | " f"💬 Stylist: stage={insight_dict.get('STAGE')} | goal={insight_dict.get('GOAL')} | "
f"response_len={len(output.ai_response)} | products={merged_product_ids} | {elapsed_ms:.0f}ms" f"response_len={len(ai_response)} | products={merged_product_ids} | {elapsed_ms:.0f}ms"
) )
diag = { diag = {
"step": "stylist", "step": "stylist",
"label": "💬 Stylist (Response + Insight)", "label": "💬 Stylist (Response + Insight)",
"content": output.ai_response[:500], "content": ai_response[:500],
"elapsed_ms": round(elapsed_ms), "elapsed_ms": round(elapsed_ms),
"raw_json": json.dumps(output.model_dump(), ensure_ascii=False, indent=2), "raw_json": json.dumps(insight_dict, ensure_ascii=False, indent=2),
} }
return { return {
"messages": [AIMessage(content=output.ai_response)], "messages": [AIMessage(content=ai_response)],
"updated_insight": insight_dict, "updated_insight": insight_dict,
"lead_stage": lead_stage, "lead_stage": lead_stage,
"stage_injection": stage_injection, "stage_injection": stage_injection,
......
...@@ -225,6 +225,18 @@ Query: "combo gì cho ngày 2/9" ...@@ -225,6 +225,18 @@ Query: "combo gì cho ngày 2/9"
→ inferred.tags: ["Màu đỏ", "Đi chơi / dạo phố"] → inferred.tags: ["Màu đỏ", "Đi chơi / dạo phố"]
→ inferred.keywords: ["quốc khánh", "2/9", "lễ lớn"] → inferred.keywords: ["quốc khánh", "2/9", "lễ lớn"]
``` ```
**QUAN TRỌNG: BẮT BUỘC OUTPUT FORMAT**
Bạn PHẢI trả về DUY NHẤT một JSON object hợp lệ, KHÔNG text thêm thắt.
JSON schema:
{
"reasoning": "lý luận ngắn gọn",
"tool_name": null hoặc "lead_search_tool",
"lead_search_args": {"literal": {"raw_text": "..."}, "inferred": {"product_line_vn": [], "gender_by_product": null, ...}, "magento_ref_code": null, "reasoning": "..."},
"tool_args": null,
"ai_response": null hoặc "câu chào hỏi",
"product_ids": []
}
""" """
# ═══════════════════════════════════════════════ # ═══════════════════════════════════════════════
...@@ -420,6 +432,15 @@ Cập nhật 12 trường dưới đây sau mỗi turn. Cộng dồn thông tin. ...@@ -420,6 +432,15 @@ Cập nhật 12 trường dưới đây sau mỗi turn. Cộng dồn thông tin.
- KHÔNG dán link URL - KHÔNG dán link URL
- KHÔNG tự bịa tính năng không có trong tool results - KHÔNG tự bịa tính năng không có trong tool results
- Mở đầu bằng 1 câu đồng cảm dựa trên GOAL - Mở đầu bằng 1 câu đồng cảm dựa trên GOAL
**QUAN TRỌNG: BẮT BUỘC OUTPUT FORMAT**
Bạn PHẢI trả về DUY NHẤT một JSON object hợp lệ, KHÔNG text thêm thắt.
JSON schema:
{
"ai_response": "câu trả lời cho khách (max 250 từ)",
"product_ids": [],
"user_insight": {"USER": "...", "TARGET": "...", "GOAL": "...", ...}
}
""" """
# ═══════════════════════════════════════════════ # ═══════════════════════════════════════════════
......
...@@ -6,11 +6,12 @@ Supports OpenAI, Groq, and Google Gemini models. ...@@ -6,11 +6,12 @@ Supports OpenAI, Groq, and Google Gemini models.
import contextlib import contextlib
import logging import logging
from langchain_core.language_models import BaseChatModel
from langchain_core.language_models import BaseChatModel from langchain_core.language_models import BaseChatModel
from langchain_google_genai import ChatGoogleGenerativeAI from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI, OpenAIEmbeddings from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from config import GOOGLE_API_KEY, GROQ_API_KEY, OPENAI_API_KEY from config import CLAUDE_BASE_URL, DS2API_API_KEY, DS2API_BASE_URL, GOOGLE_API_KEY, GROQ_API_KEY, OPENAI_API_KEY
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -64,6 +65,12 @@ class LLMFactory: ...@@ -64,6 +65,12 @@ class LLMFactory:
) -> BaseChatModel: ) -> BaseChatModel:
"""Create and cache a new LLM instance (auto-detect provider).""" """Create and cache a new LLM instance (auto-detect provider)."""
try: try:
# Auto-detect: Claude models
is_claude = "claude" in model_name.lower()
if is_claude and CLAUDE_BASE_URL:
llm = self._create_claude(model_name, streaming, json_mode, api_key)
else:
# Auto-detect: Gemini models use Google API # Auto-detect: Gemini models use Google API
is_gemini = "gemini" in model_name.lower() is_gemini = "gemini" in model_name.lower()
...@@ -82,6 +89,25 @@ class LLMFactory: ...@@ -82,6 +89,25 @@ class LLMFactory:
def _create_gemini(self, model_name: str, streaming: bool, json_mode: bool, api_key: str | None) -> BaseChatModel: def _create_gemini(self, model_name: str, streaming: bool, json_mode: bool, api_key: str | None) -> BaseChatModel:
"""Create Google Gemini model instance. Uses GOOGLE_API_KEY if available, otherwise falls back to Gemini CLI (OAuth).""" """Create Google Gemini model instance. Uses GOOGLE_API_KEY if available, otherwise falls back to Gemini CLI (OAuth)."""
# Override: use DS2API if configured
if DS2API_BASE_URL:
logger.info(f"🔄 Gemini model {model_name} -> using DS2API wrapper (OpenAI-compatible)")
# Treat as OpenAI-compatible via DS2API
key = DS2API_API_KEY or api_key or "dummy-key"
llm_kwargs = {
"model": model_name,
"streaming": streaming,
"api_key": key,
"temperature": 0,
"max_tokens": 1500,
"base_url": DS2API_BASE_URL,
}
if json_mode:
llm_kwargs["model_kwargs"] = {"response_mime_type": "application/json"}
llm = ChatOpenAI(**llm_kwargs)
logger.info(f"✅ Created DS2API (from Gemini): {model_name} | Streaming: {streaming}")
return llm
key = GOOGLE_API_KEY key = GOOGLE_API_KEY
if not key: if not key:
# Fallback: dùng Gemini CLI (OAuth local) thay vì crash # Fallback: dùng Gemini CLI (OAuth local) thay vì crash
...@@ -113,10 +139,33 @@ class LLMFactory: ...@@ -113,10 +139,33 @@ class LLMFactory:
return llm return llm
def _create_openai(self, model_name: str, streaming: bool, json_mode: bool, api_key: str | None) -> BaseChatModel: def _create_openai(self, model_name: str, streaming: bool, json_mode: bool, api_key: str | None) -> BaseChatModel:
"""Create OpenAI-compatible model instance (OpenAI or Groq).""" """Create OpenAI-compatible model instance (OpenAI, Groq, or DS2API)."""
# --- Check DS2API first (override) ---
if DS2API_BASE_URL:
# Use DS2API as OpenAI-compatible endpoint
key = DS2API_API_KEY or api_key or OPENAI_API_KEY or "dummy-key"
base_url = DS2API_BASE_URL
logger.info(f"🔄 Using DS2API wrapper: {base_url} | model={model_name}")
llm_kwargs = {
"model": model_name,
"streaming": streaming,
"api_key": key,
"temperature": 0,
"max_tokens": 1500,
"base_url": base_url,
}
if json_mode:
llm_kwargs["model_kwargs"] = {"response_format": {"type": "json_object"}}
logger.info(f"⚙️ DS2API in JSON mode: {model_name}")
if "codex" in model_name.lower():
llm_kwargs["use_responses_api"] = True
llm = ChatOpenAI(**llm_kwargs)
logger.info(f"✅ DS2API created: {model_name} | Streaming: {streaming}")
return llm
# --- Auto-detect provider --- # --- Auto-detect provider ---
is_groq = any(kw in model_name.lower() for kw in ("gpt-oss", "llama", "mixtral", "gemma", "qwen", "deepseek")) is_groq = any(kw in model_name.lower() for kw in ("gpt-oss", "llama", "mixtral", "gemma", "qwen"))
# Also detect openai/ prefix used by Groq (e.g. "openai/gpt-oss-120b") # Also detect openai/ prefix used by Groq (e.g. "openai/gpt-oss-120b")
if model_name.startswith("openai/"): if model_name.startswith("openai/"):
is_groq = True is_groq = True
...@@ -161,6 +210,113 @@ class LLMFactory: ...@@ -161,6 +210,113 @@ class LLMFactory:
logger.info(f"✅ Created {provider}: {model_name} | Streaming: {streaming}") logger.info(f"✅ Created {provider}: {model_name} | Streaming: {streaming}")
return llm return llm
def _create_claude(self, model_name: str, streaming: bool, json_mode: bool, api_key: str | None) -> BaseChatModel:
    """Create a Claude chat model served through the zunef gateway.

    Calls the gateway with the ``requests`` library directly (bypassing the
    LangChain/anthropic SDKs to avoid Cloudflare blocks) and returns a small
    wrapper that duck-types the parts of LangChain's ``BaseChatModel``
    interface this project uses: ``invoke``, ``ainvoke`` and
    ``with_structured_output``.

    Args:
        model_name: Claude model identifier forwarded to the gateway.
        streaming: Stored on the wrapper for interface parity; the gateway
            response is always consumed as an SSE stream regardless.
        json_mode: Unused here — JSON-only output is enforced via a
            system-prompt instruction inside ``_call_api`` instead.
        api_key: Explicit key; falls back to ``OPENAI_API_KEY``.

    Returns:
        An ``AnthropicWrapper`` instance (duck-typed as ``BaseChatModel``).
    """
    import asyncio
    import json as json_lib

    import requests
    from langchain_core.messages import AIMessage

    key = api_key or OPENAI_API_KEY
    base_url = CLAUDE_BASE_URL
    logger.info(f"🔄 Using Claude API (requests): {base_url} | model={model_name}")

    class AnthropicWrapper:
        """Minimal stand-in for BaseChatModel backed by raw HTTP calls."""

        def __init__(self, key, base_url, model_name, streaming):
            self.key = key
            self.base_url = base_url
            self.model_name = model_name
            self.streaming = streaming
            self._output_schema = None

        def with_structured_output(self, output_schema, **kwargs):
            # Only record the schema so _call_api can describe it in the
            # system prompt — the caller still receives plain text to parse.
            self._output_schema = output_schema
            return self

        def _convert_messages(self, messages):
            """Map LangChain/dict/str messages to Anthropic's wire format.

            Returns ``(messages_list, system_prompt)``: Anthropic takes the
            system prompt as a top-level request field, not as a message.
            """
            anthropic_messages = []
            system_prompt = None
            for msg in (messages if isinstance(messages, list) else [messages]):
                if hasattr(msg, 'content') and hasattr(msg, 'type'):
                    # LangChain message object
                    if msg.type == 'system':
                        system_prompt = msg.content
                    else:
                        role = 'user' if msg.type == 'human' else 'assistant'
                        anthropic_messages.append({'role': role, 'content': msg.content})
                elif isinstance(msg, dict):
                    # Assume already in Anthropic {'role': ..., 'content': ...} shape
                    anthropic_messages.append(msg)
                elif isinstance(msg, str):
                    anthropic_messages.append({'role': 'user', 'content': msg})
            return anthropic_messages, system_prompt

        def _call_api(self, messages_list, system_prompt=None):
            """POST to the gateway and accumulate the SSE text deltas.

            Blocking call; ``ainvoke`` offloads it to an executor thread.
            """
            url = f"{self.base_url}/messages"
            headers = {
                'x-api-key': self.key,
                'Content-Type': 'application/json'
            }
            payload = {
                'model': self.model_name,
                'max_tokens': 1500,
                'temperature': 0,
                'messages': messages_list
            }
            # Always inject a JSON-only instruction: downstream nodes parse
            # the returned text as JSON manually.
            json_instruction = "\n\nIMPORTANT: You MUST output ONLY valid JSON. No text before or after the JSON object."
            if system_prompt:
                payload['system'] = system_prompt + json_instruction
                logger.info(f"[CLAUDE API] System prompt set (first 200 chars): {payload['system'][:200]}")
            elif self._output_schema:
                schema_str = str(self._output_schema.model_json_schema() if hasattr(self._output_schema, 'model_json_schema') else self._output_schema)
                payload['system'] = f"You must output valid JSON matching this schema: {schema_str}. Output ONLY the JSON, no other text."
                logger.info(f"[CLAUDE API] Output schema prompt set (first 200 chars): {payload['system'][:200]}")
            logger.info(f"[CLAUDE API] Payload messages count: {len(messages_list)}")

            full_text = ""
            # Context manager closes the streamed response/connection
            # (the original leaked it on every call).
            with requests.post(url, headers=headers, json=payload, timeout=30, stream=True) as resp:
                resp.raise_for_status()
                # Gateway replies as SSE; collect 'text_delta' chunks.
                for raw_line in resp.iter_lines():
                    if not raw_line:
                        continue
                    line_str = raw_line.decode('utf-8') if isinstance(raw_line, bytes) else raw_line
                    if not line_str.startswith('data: '):
                        continue
                    data_str = line_str[6:].strip()
                    if data_str == '[DONE]':
                        break
                    try:
                        data = json_lib.loads(data_str)
                    except json_lib.JSONDecodeError:
                        # Best-effort: skip malformed SSE frames.
                        continue
                    if isinstance(data, dict) and data.get('type') == 'content_block_delta':
                        delta = data.get('delta', {})
                        if delta.get('type') == 'text_delta':
                            full_text += delta.get('text', '')
            return full_text

        def invoke(self, messages, **kwargs):
            messages_list, system_prompt = self._convert_messages(messages)
            text = self._call_api(messages_list, system_prompt)
            return AIMessage(content=text)

        async def ainvoke(self, messages, **kwargs):
            # Run the blocking HTTP call in a worker thread so the event
            # loop stays responsive. get_running_loop() is the supported
            # API inside a coroutine (get_event_loop() is deprecated here).
            messages_list, system_prompt = self._convert_messages(messages)
            loop = asyncio.get_running_loop()
            text = await loop.run_in_executor(None, self._call_api, messages_list, system_prompt)
            return AIMessage(content=text)

    llm = AnthropicWrapper(key, base_url, model_name, streaming)
    logger.info(f"✅ Claude API (requests) created: {model_name} | Streaming: {streaming}")
    return llm
def _enable_json_mode(self, llm: BaseChatModel, model_name: str) -> BaseChatModel: def _enable_json_mode(self, llm: BaseChatModel, model_name: str) -> BaseChatModel:
"""Enable JSON mode for the LLM.""" """Enable JSON mode for the LLM."""
try: try:
......
...@@ -81,7 +81,14 @@ OPENAI_API_KEY: str | None = os.getenv("OPENAI_API_KEY") ...@@ -81,7 +81,14 @@ OPENAI_API_KEY: str | None = os.getenv("OPENAI_API_KEY")
GOOGLE_API_KEY: str | None = os.getenv("GOOGLE_API_KEY") GOOGLE_API_KEY: str | None = os.getenv("GOOGLE_API_KEY")
GROQ_API_KEY: str | None = os.getenv("GROQ_API_KEY") GROQ_API_KEY: str | None = os.getenv("GROQ_API_KEY")
DEFAULT_MODEL: str = os.getenv("DEFAULT_MODEL", "gemini-3.1-flash-lite-preview") # DS2API (DeepSeek via OpenAI-compatible wrapper)
DS2API_BASE_URL: str | None = os.getenv("DS2API_BASE_URL")
DS2API_API_KEY: str | None = os.getenv("DS2API_API_KEY")
# Claude API (via zunef gateway)
CLAUDE_BASE_URL: str | None = os.getenv("CLAUDE_BASE_URL", "https://claude-api.zunef.com/v1/ai")
DEFAULT_MODEL: str = os.getenv("DEFAULT_MODEL", "claude-sonnet-4-6")
# DEFAULT_MODEL: str = os.getenv("DEFAULT_MODEL", "gpt-5.1-codex-mini") # DEFAULT_MODEL: str = os.getenv("DEFAULT_MODEL", "gpt-5.1-codex-mini")
# ====================== JWT CONFIGURATION ====================== # ====================== JWT CONFIGURATION ======================
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment