Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
C
chatbot canifa
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
1
Merge Requests
1
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Vũ Hoàng Anh
chatbot canifa
Commits
28274420
Commit
28274420
authored
Jan 09, 2026
by
Hoanganhvu123
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
feat: Migrate from LangGraph to Agno framework
parent
f057ad1e
Changes
18
Show whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
675 additions
and
1823 deletions
+675
-1823
__init__.py
backend/agent/__init__.py
+8
-7
agno_agent.py
backend/agent/agno_agent.py
+129
-0
agno_controller.py
backend/agent/agno_controller.py
+134
-0
controller.py
backend/agent/controller.py
+0
-200
graph.py
backend/agent/graph.py
+0
-149
agno_tools.py
backend/agent/tools/agno_tools.py
+24
-0
brand_knowledge_tool.py
backend/agent/tools/brand_knowledge_tool.py
+4
-2
customer_info_tool.py
backend/agent/tools/customer_info_tool.py
+2
-4
data_retrieval_tool.py
backend/agent/tools/data_retrieval_tool.py
+8
-3
data_retrieval_tool.save.py
backend/agent/tools/data_retrieval_tool.save.py
+0
-187
product_search_helpers_save.py
backend/agent/tools/product_search_helpers_save.py
+0
-579
chatbot_route.py
backend/api/chatbot_route.py
+1
-1
conversation_manager.py
backend/common/conversation_manager.py
+234
-37
langfuse_client.py
backend/common/langfuse_client.py
+103
-19
law_database.py
backend/common/law_database.py
+0
-470
llm_factory.py
backend/common/llm_factory.py
+0
-148
starrocks_connection.py
backend/common/starrocks_connection.py
+26
-17
requirements.txt
backend/requirements.txt
+2
-0
No files found.
backend/agent/__init__.py
View file @
28274420
"""
Fashion Q&A Agent Package
Fashion Q&A Agent Package
- Agno Framework
"""
from
.graph
import
build_graph
from
.models
import
AgentConfig
,
AgentState
,
get_config
# Only export what's needed for Agno
from
.agno_agent
import
get_agno_agent
from
.agno_controller
import
chat_controller
from
.models
import
QueryRequest
__all__
=
[
"AgentConfig"
,
"AgentState"
,
"build_graph"
,
"get_config"
,
"get_agno_agent"
,
"chat_controller"
,
"QueryRequest"
,
]
backend/agent/agno_agent.py
0 → 100644
View file @
28274420
"""
CANIFA Agent với Agno Framework
Thay thế LangGraph bằng Agno
"""
import
logging
from
typing
import
TYPE_CHECKING
,
Any
,
cast
# Type checking imports (only used for type hints)
if
TYPE_CHECKING
:
from
agno.agent
import
Agent
as
AgentType
from
agno.db.base
import
BaseDb
as
BaseDbType
from
agno.models.openai
import
OpenAIChat
as
OpenAIChatType
else
:
AgentType
=
Any
# type: ignore
BaseDbType
=
Any
# type: ignore
OpenAIChatType
=
Any
# type: ignore
# Runtime imports with fallback
try
:
from
agno.agent
import
Agent
from
agno.db.base
import
BaseDb
from
agno.models.openai
import
OpenAIChat
except
ImportError
:
# Fallback nếu chưa install agno
Agent
=
None
BaseDb
=
Any
# type: ignore
OpenAIChat
=
None
from
common.conversation_manager
import
get_conversation_manager
from
config
import
DEFAULT_MODEL
,
OPENAI_API_KEY
from
.prompt
import
get_system_prompt
from
.tools.agno_tools
import
get_agno_tools
logger
=
logging
.
getLogger
(
__name__
)
def create_agno_model(model_name: str = DEFAULT_MODEL, json_mode: bool = False):
    """Build an Agno OpenAI chat model from the settings in config.py.

    Args:
        model_name: OpenAI model identifier from config.py.
        json_mode: Accepted for interface parity; Agno handles JSON output
            itself when needed.

    Raises:
        ImportError: When the agno package is not installed.
    """
    if OpenAIChat is None:
        raise ImportError("Agno not installed. Run: pip install agno")

    # Agno will handle json_mode itself if required.
    chat_model = OpenAIChat(
        id=model_name,
        api_key=OPENAI_API_KEY,
    )
    return chat_model
async def create_agno_agent(
    model_name: str = DEFAULT_MODEL,
    json_mode: bool = False,
) -> AgentType:  # type: ignore
    """Create an Agno Agent backed by the ConversationManager (with memory).

    Args:
        model_name: Model name from config.py.
        json_mode: Enable JSON output.

    Returns:
        Configured Agno Agent.

    Raises:
        ImportError: When the agno package is not installed.
    """
    # Fail fast when agno is missing (Agent and OpenAIChat come from the same
    # import block, so this is equivalent to the model-side check).
    if Agent is None:
        raise ImportError("Agno not installed. Run: pip install agno")

    chat_model = create_agno_model(model_name, json_mode)
    agno_tools = get_agno_tools()
    instructions_text = get_system_prompt()

    # ConversationManager implements the BaseDb interface by duck typing, so
    # the cast is only for the type checker; runtime behavior is unaffected.
    conversation_db = cast(
        BaseDbType, await get_conversation_manager()
    )  # type: ignore[assignment]

    agent = Agent(
        name="CANIFA Agent",
        model=chat_model,
        db=conversation_db,  # ConversationManager acting as the BaseDb backend
        tools=agno_tools,
        instructions=instructions_text,  # Agno uses `instructions`, not `system_prompt`
        add_history_to_context=True,  # enable chat history
        num_history_runs=20,  # load the 20 most recent runs
        markdown=True,
    )

    logger.info(f"✅ Agno Agent created with model: {model_name} (WITH MEMORY)")
    return agent
# Singleton instance — shared process-wide so every request reuses one agent.
_agno_agent_instance: AgentType | None = None  # type: ignore


async def get_agno_agent(
    model_name: str = DEFAULT_MODEL,
    json_mode: bool = False,
) -> AgentType:  # type: ignore
    """
    Get or create the Agno Agent singleton (with memory).

    NOTE(review): after the first call, `model_name` and `json_mode` are
    ignored — the cached instance keeps whatever configuration it was first
    created with. Confirm that is the intended behavior.
    """
    global _agno_agent_instance
    if _agno_agent_instance is None:
        # Create the agent with ConversationManager (memory enabled).
        _agno_agent_instance = await create_agno_agent(
            model_name=model_name,
            json_mode=json_mode,
        )
    return _agno_agent_instance
def reset_agno_agent():
    """Clear the cached agent singleton so tests can start from a clean slate."""
    global _agno_agent_instance
    _agno_agent_instance = None
backend/agent/agno_controller.py
0 → 100644
View file @
28274420
"""
CANIFA Agent Controller với Agno Framework
"""
import
json
import
logging
from
typing
import
Any
from
fastapi
import
BackgroundTasks
from
common.langfuse_client
import
langfuse_trace_context
from
config
import
DEFAULT_MODEL
from
.agno_agent
import
get_agno_agent
logger
=
logging
.
getLogger
(
__name__
)
async def chat_controller(
    query: str,
    user_id: str,
    background_tasks: BackgroundTasks,
    model_name: str = DEFAULT_MODEL,
    images: list[str] | None = None,
) -> dict:
    """
    Controller using the Agno Agent (automatic memory).
    Agno loads/saves history automatically through ConversationManager.

    Args:
        query: The user's message.
        user_id: Caller identity; also reused as the Agno/Langfuse session id.
        background_tasks: FastAPI background task queue.
        model_name: Model to run the agent with.
        images: Optional image payloads.

    Returns:
        Dict with "ai_response" (text) and "product_ids" (list of product dicts).

    NOTE(review): `background_tasks` and `images` are not used in this body —
    presumably kept for signature parity with the legacy controller; confirm.
    """
    logger.info(f"▶️ Agno chat_controller | User: {user_id} | Model: {model_name}")
    try:
        agent = await get_agno_agent(model_name=model_name, json_mode=True)
        with langfuse_trace_context(user_id=user_id, session_id=user_id):
            # Agno automatically loads history and saves after responding
            # (memory enabled).
            # NOTE(review): `agent.run(...)` looks like a synchronous call
            # inside an async handler — confirm whether Agno's async variant
            # should be used here to avoid blocking the event loop.
            result = agent.run(query, session_id=user_id)
        # Extract response text; fall back to the stringified result object
        # when there is no (truthy) `content` attribute.
        ai_content = str(
            result.content
            if hasattr(result, "content") and result.content
            else str(result)
        )
        logger.info(f"💾 AI Response: {ai_content[:200]}...")
        # Parse the response and extract products surfaced by tool calls.
        ai_text, product_ids = _parse_agno_response(result, ai_content)
        return {
            "ai_response": ai_text,
            "product_ids": product_ids,
        }
    except Exception as e:
        logger.error(f"💥 Agno chat error for user {user_id}: {e}", exc_info=True)
        raise
def
_parse_agno_response
(
result
:
Any
,
ai_content
:
str
)
->
tuple
[
str
,
list
[
dict
]]:
"""
Parse Agno response và extract AI text + product IDs.
Returns: (ai_text_response, product_ids)
"""
ai_text
=
ai_content
product_ids
=
[]
# Try parse JSON response
try
:
ai_json
=
json
.
loads
(
ai_content
)
ai_text
=
ai_json
.
get
(
"ai_response"
,
ai_content
)
product_ids
=
ai_json
.
get
(
"product_ids"
,
[])
or
[]
except
(
json
.
JSONDecodeError
,
Exception
)
as
e
:
logger
.
debug
(
f
"Response is not JSON, using raw text: {e}"
)
# Extract products từ tool results
if
hasattr
(
result
,
"messages"
):
tool_products
=
_extract_products_from_messages
(
result
.
messages
)
# Merge và deduplicate
seen_skus
=
{
p
.
get
(
"sku"
)
for
p
in
product_ids
if
isinstance
(
p
,
dict
)
and
"sku"
in
p
}
for
product
in
tool_products
:
if
isinstance
(
product
,
dict
)
and
product
.
get
(
"sku"
)
not
in
seen_skus
:
product_ids
.
append
(
product
)
seen_skus
.
add
(
product
.
get
(
"sku"
))
return
ai_text
,
product_ids
def
_extract_products_from_messages
(
messages
:
list
)
->
list
[
dict
]:
"""Extract products từ Agno tool messages."""
products
=
[]
seen_skus
=
set
()
for
msg
in
messages
:
if
not
(
hasattr
(
msg
,
"content"
)
and
isinstance
(
msg
.
content
,
str
)):
continue
try
:
tool_result
=
json
.
loads
(
msg
.
content
)
if
tool_result
.
get
(
"status"
)
!=
"success"
:
continue
# Handle multi-search format
if
"results"
in
tool_result
:
for
result_item
in
tool_result
[
"results"
]:
products
.
extend
(
_parse_products
(
result_item
.
get
(
"products"
,
[]),
seen_skus
))
# Handle single search format
elif
"products"
in
tool_result
:
products
.
extend
(
_parse_products
(
tool_result
[
"products"
],
seen_skus
))
except
(
json
.
JSONDecodeError
,
KeyError
,
TypeError
)
as
e
:
logger
.
debug
(
f
"Skip invalid tool message: {e}"
)
continue
return
products
def
_parse_products
(
products
:
list
[
dict
],
seen_skus
:
set
[
str
])
->
list
[
dict
]:
"""Parse và format products, skip duplicates."""
parsed
=
[]
for
product
in
products
:
if
not
isinstance
(
product
,
dict
):
continue
sku
=
product
.
get
(
"internal_ref_code"
)
if
not
sku
or
sku
in
seen_skus
:
continue
seen_skus
.
add
(
sku
)
parsed
.
append
({
"sku"
:
sku
,
"name"
:
product
.
get
(
"magento_product_name"
,
""
),
"price"
:
product
.
get
(
"price_vnd"
,
0
),
"sale_price"
:
product
.
get
(
"sale_price_vnd"
),
"url"
:
product
.
get
(
"magento_url_key"
,
""
),
"thumbnail_image_url"
:
product
.
get
(
"thumbnail_image_url"
,
""
),
})
return
parsed
backend/agent/controller.py
deleted
100644 → 0
View file @
f057ad1e
"""
Fashion Q&A Agent Controller
Langfuse will auto-trace via LangChain integration (no code changes needed).
"""
import
json
import
logging
import
uuid
from
fastapi
import
BackgroundTasks
from
langchain_core.messages
import
AIMessage
,
HumanMessage
,
ToolMessage
from
langchain_core.runnables
import
RunnableConfig
from
common.conversation_manager
import
ConversationManager
,
get_conversation_manager
from
common.langfuse_client
import
get_callback_handler
,
langfuse_trace_context
from
common.llm_factory
import
create_llm
from
config
import
DEFAULT_MODEL
from
.graph
import
build_graph
from
.models
import
AgentState
,
get_config
from
.tools.get_tools
import
get_all_tools
logger
=
logging
.
getLogger
(
__name__
)
async def chat_controller(
    query: str,
    user_id: str,
    background_tasks: BackgroundTasks,
    model_name: str = DEFAULT_MODEL,
    images: list[str] | None = None,
) -> dict:
    """
    Controller main logic for non-streaming chat requests.
    Langfuse will automatically trace all LangChain operations.

    Args:
        query: The user's message.
        user_id: Caller identity; also used as the Langfuse session id.
        background_tasks: FastAPI queue used to persist history after replying.
        model_name: LLM to run the graph with.
        images: Optional transient image payloads passed to the graph config.

    Returns:
        Dict with "ai_response" (plain text) and "product_ids" (product dicts).
    """
    logger.info(f"▶️ Starting chat_controller with model: {model_name} for user: {user_id}")
    config = get_config()
    config.model_name = model_name
    # Enable JSON mode to ensure structured output.
    llm = create_llm(model_name=model_name, streaming=False, json_mode=True)
    tools = get_all_tools()
    graph = build_graph(config, llm=llm, tools=tools)
    # Init ConversationManager (singleton).
    memory = await get_conversation_manager()
    # Load history and prepare state. History rows come newest-first, so they
    # are reversed into chronological order before building message objects.
    history_dicts = await memory.get_chat_history(user_id, limit=20)
    history = []
    for h in reversed(history_dicts):
        msg_cls = HumanMessage if h["is_human"] else AIMessage
        history.append(msg_cls(content=h["message"]))
    initial_state, exec_config = _prepare_execution_context(
        query=query, user_id=user_id, history=history, images=images
    )
    try:
        # Wrap graph execution in langfuse_trace_context so user_id is set on
        # all observations.
        with langfuse_trace_context(user_id=user_id, session_id=user_id):
            # Run the graph.
            result = await graph.ainvoke(initial_state, config=exec_config)
        # Extract product IDs from tool messages in a single pass.
        all_product_ids = _extract_product_ids(result.get("messages", []))
        # Process the AI response.
        ai_raw_content = result.get("ai_response").content if result.get("ai_response") else ""
        logger.info(f"💾 [RAW AI OUTPUT]:\n{ai_raw_content}")
        # Parse JSON to get the text response and product_ids from the AI.
        ai_text_response = ai_raw_content
        try:
            # With json_mode=True, OpenAI emits raw JSON.
            ai_json = json.loads(ai_raw_content)
            # Extract the text response from the JSON.
            ai_text_response = ai_json.get("ai_response", ai_raw_content)
            # Merge product_ids from the AI JSON (if any) — no set() here
            # because dicts are unhashable.
            explicit_ids = ai_json.get("product_ids", [])
            if explicit_ids and isinstance(explicit_ids, list):
                # Merge and deduplicate by SKU.
                seen_skus = {p["sku"] for p in all_product_ids if "sku" in p}
                for product in explicit_ids:
                    if isinstance(product, dict) and product.get("sku") not in seen_skus:
                        all_product_ids.append(product)
                        seen_skus.add(product.get("sku"))
        # NOTE(review): `Exception` already subsumes JSONDecodeError, so this
        # tuple is redundant — it deliberately swallows any parse failure.
        except (json.JSONDecodeError, Exception) as e:
            # If the AI returned plain text (rare in JSON mode), just ignore.
            logger.warning(f"Could not parse AI response as JSON: {e}")
            pass
        # Background task: persist the conversation turn after responding.
        background_tasks.add_task(
            _handle_post_chat_async,
            memory=memory,
            user_id=user_id,
            human_query=query,
            ai_msg=AIMessage(content=ai_text_response),
        )
        return {
            "ai_response": ai_text_response,  # text only, not the JSON envelope
            "product_ids": all_product_ids,  # array of product objects
        }
    except Exception as e:
        logger.error(f"💥 Chat error for user {user_id}: {e}", exc_info=True)
        raise
def
_extract_product_ids
(
messages
:
list
)
->
list
[
dict
]:
"""
Extract full product info from tool messages (data_retrieval_tool results).
Returns list of product objects with: sku, name, price, sale_price, url, thumbnail_image_url.
"""
products
=
[]
seen_skus
=
set
()
for
msg
in
messages
:
if
isinstance
(
msg
,
ToolMessage
):
try
:
# Tool result is JSON string
tool_result
=
json
.
loads
(
msg
.
content
)
# Check if tool returned products
if
tool_result
.
get
(
"status"
)
==
"success"
and
"products"
in
tool_result
:
for
product
in
tool_result
[
"products"
]:
sku
=
product
.
get
(
"internal_ref_code"
)
if
sku
and
sku
not
in
seen_skus
:
seen_skus
.
add
(
sku
)
# Extract full product info
product_obj
=
{
"sku"
:
sku
,
"name"
:
product
.
get
(
"magento_product_name"
,
""
),
"price"
:
product
.
get
(
"price_vnd"
,
0
),
"sale_price"
:
product
.
get
(
"sale_price_vnd"
),
# null nếu không sale
"url"
:
product
.
get
(
"magento_url_key"
,
""
),
"thumbnail_image_url"
:
product
.
get
(
"thumbnail_image_url"
,
""
),
}
products
.
append
(
product_obj
)
except
(
json
.
JSONDecodeError
,
KeyError
,
TypeError
)
as
e
:
logger
.
debug
(
f
"Could not parse tool message for products: {e}"
)
continue
return
products
def _prepare_execution_context(query: str, user_id: str, history: list, images: list | None):
    """Prepare the initial state and execution config for the graph run.

    Args:
        query: The user's message (becomes both `user_query` and the first
            entry in `messages`).
        user_id: Caller identity propagated through the runnable config.
        history: Prior conversation as LangChain message objects.
        images: Optional transient image payloads for this run only.

    Returns:
        Tuple of (initial_state, exec_config).
    """
    initial_state: AgentState = {
        "user_query": HumanMessage(content=query),
        "messages": [HumanMessage(content=query)],
        "history": history,
        "user_id": user_id,
        "images_embedding": [],
        "ai_response": None,
    }
    run_id = str(uuid.uuid4())
    # Metadata for LangChain (tags for logging/filtering).
    metadata = {
        "run_id": run_id,
        "tags": "chatbot,production",
    }
    # 🔥 CallbackHandler — wrapped in langfuse_trace_context to set user_id.
    # Per Langfuse docs: propagate_attributes() handles user_id propagation.
    langfuse_handler = get_callback_handler()
    exec_config = RunnableConfig(
        configurable={
            "user_id": user_id,
            "transient_images": images or [],
            "run_id": run_id,
        },
        run_id=run_id,
        metadata=metadata,
        # May be None when Langfuse is disabled — fall back to no callbacks.
        callbacks=[langfuse_handler] if langfuse_handler else [],
    )
    return initial_state, exec_config
async def _handle_post_chat_async(
    memory: ConversationManager,
    user_id: str,
    human_query: str,
    ai_msg: AIMessage | None,
):
    """Persist one conversation turn in the background after the response is sent."""
    if not ai_msg:
        return
    try:
        await memory.save_conversation_turn(user_id, human_query, ai_msg.content)
        logger.debug(f"Saved conversation for user {user_id}")
    except Exception as e:
        # Best-effort persistence: a failed save must never crash the
        # background task.
        logger.error(f"Failed to save conversation for user {user_id}: {e}", exc_info=True)
backend/agent/graph.py
deleted
100644 → 0
View file @
f057ad1e
"""
Fashion Q&A Agent Graph
LangGraph workflow với clean architecture.
Tất cả resources (LLM, Tools) khởi tạo trong __init__.
Sử dụng ConversationManager (Postgres) để lưu history thay vì checkpoint.
"""
import
logging
from
typing
import
Any
from
langchain_core.language_models
import
BaseChatModel
from
langchain_core.prompts
import
ChatPromptTemplate
,
MessagesPlaceholder
from
langchain_core.runnables
import
RunnableConfig
from
langgraph.cache.memory
import
InMemoryCache
from
langgraph.graph
import
END
,
StateGraph
from
langgraph.prebuilt
import
ToolNode
from
langgraph.types
import
CachePolicy
from
common.llm_factory
import
create_llm
from
.models
import
AgentConfig
,
AgentState
,
get_config
from
.prompt
import
get_system_prompt
from
.tools.get_tools
import
get_all_tools
,
get_collection_tools
logger
=
logging
.
getLogger
(
__name__
)
class CANIFAGraph:
    """
    Fashion Q&A Agent Graph Manager.

    Owns the LLM, tools, prompt template, and the compiled LangGraph workflow.
    All resources are initialized in __init__; `build()` compiles lazily and
    caches the compiled graph.
    """

    def __init__(
        self,
        config: AgentConfig | None = None,
        llm: BaseChatModel | None = None,
        tools: list | None = None,
    ):
        # Fall back to module-level config/LLM/tools when none are injected.
        self.config = config or get_config()
        self._compiled_graph: Any | None = None
        self.llm: BaseChatModel = llm or create_llm(
            model_name=self.config.model_name,
            api_key=self.config.openai_api_key,
            streaming=True,
        )
        self.all_tools = tools or get_all_tools()
        # Still fetch the collection-tool list so routing can match by name.
        self.collection_tools = get_collection_tools()
        self.retrieval_tools = self.all_tools
        self.llm_with_tools = self.llm.bind_tools(self.all_tools, strict=True)
        self.system_prompt = get_system_prompt()
        # Prompt layout: system, then history, then the current user query,
        # then the in-flight agent/tool message scratchpad.
        self.prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", self.system_prompt),
                MessagesPlaceholder(variable_name="history"),
                MessagesPlaceholder(variable_name="user_query"),
                MessagesPlaceholder(variable_name="messages"),
            ]
        )
        self.chain = self.prompt_template | self.llm_with_tools
        self.cache = InMemoryCache()

    async def _agent_node(self, state: AgentState, config: RunnableConfig) -> dict:
        """Agent node — pour the per-request data into the prebuilt prompt template."""
        messages = state.get("messages", [])
        history = state.get("history", [])
        user_query = state.get("user_query")
        transient_images = config.get("configurable", {}).get("transient_images", [])
        # NOTE(review): transient_images is read but unused — presumably a
        # placeholder for image handling; confirm before relying on it.
        if transient_images and messages:
            pass
        # Invoke the chain with user_query, history, and messages.
        response = await self.chain.ainvoke({
            "user_query": [user_query] if user_query else [],
            "history": history,
            "messages": messages,
        })
        return {"messages": [response], "ai_response": response}

    def _should_continue(self, state: AgentState) -> str:
        """Routing: send tool calls to a tool node, otherwise end the run."""
        last_message = state["messages"][-1]
        # No tool calls on the last AI message means the agent is done.
        if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
            logger.info("🏁 Agent finished")
            return "end"
        tool_names = [tc["name"] for tc in last_message.tool_calls]
        collection_names = [t.name for t in self.collection_tools]
        # Any collection tool in the batch routes the whole batch to
        # collect_tools; everything else goes to retrieve_tools.
        if any(name in collection_names for name in tool_names):
            logger.info(f"🔄 → collect_tools: {tool_names}")
            return "collect_tools"
        logger.info(f"🔄 → retrieve_tools: {tool_names}")
        return "retrieve_tools"

    def build(self) -> Any:
        """Build and compile the LangGraph workflow (cached after first call)."""
        if self._compiled_graph is not None:
            return self._compiled_graph
        workflow = StateGraph(AgentState)
        # Nodes — retrieval results are cached for an hour per CachePolicy.
        workflow.add_node("agent", self._agent_node)
        workflow.add_node(
            "retrieve_tools",
            ToolNode(self.retrieval_tools),
            cache_policy=CachePolicy(ttl=3600),
        )
        workflow.add_node("collect_tools", ToolNode(self.collection_tools))
        # Edges — both tool nodes loop back into the agent.
        workflow.set_entry_point("agent")
        workflow.add_conditional_edges(
            "agent",
            self._should_continue,
            {"retrieve_tools": "retrieve_tools", "collect_tools": "collect_tools", "end": END},
        )
        workflow.add_edge("retrieve_tools", "agent")
        workflow.add_edge("collect_tools", "agent")
        # No checkpointer — history is persisted externally by the controller.
        self._compiled_graph = workflow.compile(cache=self.cache)
        logger.info("✅ Graph compiled (Langfuse callback will be per-run)")
        return self._compiled_graph

    @property
    def graph(self) -> Any:
        # Convenience accessor: compiling is idempotent thanks to the cache.
        return self.build()
# --- Singleton & Public API ---
# One-element list used as a mutable cell so the module-level functions below
# can rebind the singleton without `global` statements.
_instance: list[CANIFAGraph | None] = [None]
def build_graph(
    config: AgentConfig | None = None,
    llm: BaseChatModel | None = None,
    tools: list | None = None,
) -> Any:
    """Return the compiled graph, creating the singleton manager on first use.

    The config/llm/tools arguments only take effect on the very first call;
    later calls reuse the cached CANIFAGraph.
    """
    manager = _instance[0]
    if manager is None:
        manager = CANIFAGraph(config, llm, tools)
        _instance[0] = manager
    return manager.build()
def get_graph_manager(
    config: AgentConfig | None = None,
    llm: BaseChatModel | None = None,
    tools: list | None = None,
) -> CANIFAGraph:
    """Return the CANIFAGraph singleton, creating it on first use.

    Like `build_graph`, the arguments are only honored on the call that
    actually constructs the instance.
    """
    manager = _instance[0]
    if manager is None:
        manager = CANIFAGraph(config, llm, tools)
        _instance[0] = manager
    return manager
def reset_graph() -> None:
    """Drop the cached CANIFAGraph singleton (test isolation helper)."""
    _instance[0] = None
backend/agent/tools/agno_tools.py
0 → 100644
View file @
28274420
"""
Agno Tools - Pure Python functions cho Agno Agent
Đã convert từ LangChain @tool decorator sang Agno format
"""
from
.data_retrieval_tool
import
data_retrieval_tool
from
.brand_knowledge_tool
import
canifa_knowledge_search
from
.customer_info_tool
import
collect_customer_info
def get_agno_tools():
    """
    Return the tool functions for the Agno Agent.

    Agno converts plain Python functions into tool definitions automatically,
    so this is simply the list of callables.

    Returns:
        List of Python functions (Agno tools).
    """
    tools = [
        data_retrieval_tool,
        canifa_knowledge_search,
        collect_customer_info,
    ]
    return tools
backend/agent/tools/brand_knowledge_tool.py
View file @
28274420
import
logging
from
langchain_core.tools
import
tool
from
pydantic
import
BaseModel
,
Field
from
common.embedding_service
import
create_embedding_async
...
...
@@ -15,7 +14,6 @@ class KnowledgeSearchInput(BaseModel):
)
@
tool
(
"canifa_knowledge_search"
,
args_schema
=
KnowledgeSearchInput
)
async
def
canifa_knowledge_search
(
query
:
str
)
->
str
:
"""
Tra cứu TOÀN BỘ thông tin về thương hiệu và dịch vụ của Canifa.
...
...
@@ -35,6 +33,10 @@ async def canifa_knowledge_search(query: str) -> str:
- 'Cho mình xem bảng size áo nam.'
- 'Phí vận chuyển đi tỉnh là bao nhiêu?'
- 'Canifa thành lập năm nào?'
Args:
query: Câu hỏi hoặc nhu cầu tìm kiếm thông tin phi sản phẩm của khách hàng
(ví dụ: tìm cửa hàng, hỏi chính sách, tra bảng size...)
"""
logger
.
info
(
f
"🔍 [Semantic Search] Brand Knowledge query: {query}"
)
...
...
backend/agent/tools/customer_info_tool.py
View file @
28274420
...
...
@@ -6,16 +6,14 @@ Dùng để đẩy data về CRM hoặc hệ thống lưu trữ khách hàng.
import
json
import
logging
from
langchain_core.tools
import
tool
logger
=
logging
.
getLogger
(
__name__
)
@
tool
async
def
collect_customer_info
(
name
:
str
,
phone
:
str
,
email
:
str
|
None
)
->
str
:
async
def
collect_customer_info
(
name
:
str
,
phone
:
str
,
email
:
str
|
None
=
None
)
->
str
:
"""
Sử dụng tool này để ghi lại thông tin khách hàng khi họ muốn tư vấn sâu hơn,
nhận khuyến mãi hoặc đăng ký mua hàng.
Args:
name: Tên của khách hàng
phone: Số điện thoại của khách hàng
...
...
backend/agent/tools/data_retrieval_tool.py
View file @
28274420
...
...
@@ -9,7 +9,6 @@ import logging
import
time
from
decimal
import
Decimal
from
langchain_core.tools
import
tool
from
pydantic
import
BaseModel
,
Field
from
agent.tools.product_search_helpers
import
build_starrocks_query
,
save_preview_to_log
...
...
@@ -50,8 +49,6 @@ class MultiSearchParams(BaseModel):
searches
:
list
[
SearchItem
]
=
Field
(
...
,
description
=
"Danh sách các truy vấn tìm kiếm chạy song song"
)
@
tool
(
args_schema
=
MultiSearchParams
)
# @traceable(run_type="tool", name="data_retrieval_tool")
async
def
data_retrieval_tool
(
searches
:
list
[
SearchItem
])
->
str
:
"""
Siêu công cụ tìm kiếm sản phẩm CANIFA - Hỗ trợ Parallel Multi-Search (Chạy song song nhiều query).
...
...
@@ -86,6 +83,14 @@ async def data_retrieval_tool(searches: list[SearchItem]) -> str:
{"query": "Quần jean nam slim fit năng động"},
{"query": "Áo khoác nam thể thao trẻ trung"}
]
Args:
searches: Danh sách các truy vấn tìm kiếm chạy song song. Mỗi item là SearchItem với:
- query: Mô tả sản phẩm chi tiết (bắt buộc)
- magento_ref_code: Mã sản phẩm cụ thể (nếu có)
- price_min: Giá thấp nhất (nếu có)
- price_max: Giá cao nhất (nếu có)
- action: 'search' hoặc 'visual_search'
"""
logger
.
info
(
"🔧 [DEBUG] data_retrieval_tool STARTED"
)
try
:
...
...
backend/agent/tools/data_retrieval_tool.save.py
deleted
100644 → 0
View file @
f057ad1e
"""
CANIFA Data Retrieval Tool - Tối giản cho Agentic Workflow.
Hỗ trợ Hybrid Search: Semantic (Vector) + Metadata Filter.
"""
import
asyncio
import
json
import
logging
import
time
from
decimal
import
Decimal
from
langchain_core.tools
import
tool
from
pydantic
import
BaseModel
,
Field
from
agent.tools.product_search_helpers
import
build_starrocks_query
from
common.starrocks_connection
import
StarRocksConnection
# from langsmith import traceable
logger
=
logging
.
getLogger
(
__name__
)
class DecimalEncoder(json.JSONEncoder):
    """JSON encoder that serializes database ``Decimal`` values as floats."""

    def default(self, obj):
        # Decimal columns (e.g. prices) are not natively JSON-serializable;
        # everything else defers to the stock encoder (which raises TypeError).
        if not isinstance(obj, Decimal):
            return super().default(obj)
        return float(obj)
class SearchItem(BaseModel):
    """Structure of a single search entry in a Multi-Search."""
    # NOTE(review): every nullable (`| None`) field still uses Field(...),
    # which marks it as *required* in the generated schema — presumably
    # intentional for OpenAI strict tool schemas (required-but-nullable);
    # confirm before changing.

    # Free-text intent, always used for semantic (vector) search.
    query: str = Field(
        ...,
        description="Câu hỏi/mục đích tự do của user (đi chơi, dự tiệc, phỏng vấn,...) - dùng cho Semantic Search",
    )
    # Concrete product keywords, used for LIKE matching.
    keywords: str | None = Field(
        ...,
        description="Từ khóa sản phẩm cụ thể (áo polo, quần jean,...) - dùng cho LIKE search",
    )
    # Product or color/SKU reference code.
    magento_ref_code: str | None = Field(
        ...,
        description="Mã sản phẩm hoặc mã màu/SKU (Ví dụ: 8TS24W001 hoặc 8TS24W001-SK010).",
    )
    # Metadata filters (exact/partial match) below.
    product_line_vn: str | None = Field(
        ...,
        description="Dòng sản phẩm (Áo phông, Quần short,...)",
    )
    gender_by_product: str | None = Field(
        ...,
        description="Giới tính: male, female",
    )
    age_by_product: str | None = Field(
        ...,
        description="Độ tuổi: adult, kids, baby, others",
    )
    master_color: str | None = Field(
        ...,
        description="Màu sắc chính (Đen/ Black, Trắng/ White,...)",
    )
    # Material group: restricted to the four canonical values in the description.
    material_group: str | None = Field(
        ...,
        description="Nhóm chất liệu. BẮT BUỘC dùng đúng: 'Yarn - Sợi', 'Knit - Dệt Kim', 'Woven - Dệt Thoi', 'Knit/Woven - Dệt Kim/Dệt Thoi'.",
    )
    season: str | None = Field(
        ...,
        description="Mùa (Spring Summer, Autumn Winter)",
    )
    style: str | None = Field(
        ...,
        description="Phong cách (Basic Update, Fashion,...)",
    )
    fitting: str | None = Field(
        ...,
        description="Form dáng (Regular, Slim, Loose,...)",
    )
    form_neckline: str | None = Field(
        ...,
        description="Kiểu cổ (Crew Neck, V-neck,...)",
    )
    form_sleeve: str | None = Field(
        ...,
        description="Kiểu tay (Short Sleeve, Long Sleeve,...)",
    )
    # Price range filter in VND.
    price_min: float | None = Field(
        ...,
        description="Giá thấp nhất",
    )
    price_max: float | None = Field(
        ...,
        description="Giá cao nhất",
    )
    # Which operation to run for this item.
    action: str = Field(
        ...,
        description="Hành động: 'search' (tìm kiếm) hoặc 'visual_search' (phân tích ảnh)",
    )
class MultiSearchParams(BaseModel):
    """Parameters for Parallel Multi-Search."""
    # Each SearchItem is executed as an independent query; the tool runs them
    # in parallel and aggregates the results.
    searches: list[SearchItem] = Field(
        ...,
        description="Danh sách các truy vấn tìm kiếm chạy song song",
    )
@tool(args_schema=MultiSearchParams)
# @traceable(run_type="tool", name="data_retrieval_tool")
async def data_retrieval_tool(searches: list[SearchItem]) -> str:
    """
    Siêu công cụ tìm kiếm sản phẩm CANIFA - Hỗ trợ Parallel Multi-Search (Chạy song song nhiều query).
    💡 ĐIỂM ĐẶC BIỆT:
    Công cụ này cho phép thực hiện NHIỀU truy vấn tìm kiếm CÙNG LÚC.
    Hãy dùng nó khi cần SO SÁNH sản phẩm hoặc tìm trọn bộ OUTFIT (mix & match).
    ⚠️ QUAN TRỌNG - KHI NÀO DÙNG GÌ:
    1️⃣ DÙNG 'query' (Semantic Search - BUỘC PHẢI CÓ):
    - Áp dụng cho mọi lượt search để cung cấp bối cảnh (context).
    - Ví dụ: "áo thun nam đi biển", "quần tây công sở", "đồ cho bé màu xanh"...
    2️⃣ DÙNG METADATA FILTERS (Exact/Partial Match):
    - Khi khách nói rõ THUỘC TÍNH: Màu sắc, giá, giới tính, độ tuổi, mã sản phẩm.
    - **QUY TẮC MÃ SẢN PHẨM:** Mọi loại mã (VD: `8TS...` hoặc `8TS...-SK...`) → Điền vào `magento_ref_code`.
    - **QUY TẮC CHẤT LIÊU (material_group):** Chỉ dùng: `Yarn - Sợi`, `Knit - Dệt Kim`, `Woven - Dệt Thoi`, `Knit/Woven - Dệt Kim/Dệt Thoi`.
    📝 VÍ DỤ CHI TIẾT (Single Search):
    - Example 1: searches=[{"query": "áo polo nam giá dưới 400k", "keywords": "áo polo", "gender_by_product": "male", "price_max": 400000}]
    - Example 2: searches=[{"query": "sản phẩm mã 8TS24W001", "magento_ref_code": "8TS24W001"}]
    🚀 VÍ DỤ CẤP CAO (Multi-Search Parallel):
    - Example 3 - So sánh: "So sánh áo thun nam đen và áo sơ mi trắng dưới 500k"
      Tool Call: searches=[
        {"query": "áo thun nam màu đen dưới 500k", "keywords": "áo thun", "master_color": "Đen", "gender_by_product": "male", "price_max": 500000},
        {"query": "áo sơ mi nam trắng dưới 500k", "keywords": "áo sơ mi", "master_color": "Trắng", "gender_by_product": "male", "price_max": 500000}
      ]
    - Example 4 - Phối đồ: "Tìm cho mình một cái quần jean và một cái áo khoác để đi chơi"
      Tool Call: searches=[
        {"query": "quần jean đi chơi năng động", "keywords": "quần jean"},
        {"query": "áo khoác đi chơi năng động", "keywords": "áo khoác"}
      ]
    - Example 5 - Cả gia đình: "Tìm áo phông màu xanh cho bố, mẹ và bé trai"
      Tool Call: searches=[
        {"query": "áo phông nam người lớn màu xanh", "keywords": "áo phông", "master_color": "Xanh", "gender_by_product": "male", "age_by_product": "adult"},
        {"query": "áo phông nữ người lớn màu xanh", "keywords": "áo phông", "master_color": "Xanh", "gender_by_product": "female", "age_by_product": "adult"},
        {"query": "áo phông bé trai màu xanh", "keywords": "áo phông", "master_color": "Xanh", "gender_by_product": "male", "age_by_product": "others"}
      ]
    """
    logger.info("🔧 [DEBUG] data_retrieval_tool STARTED")
    try:
        logger.info("🔧 [DEBUG] Creating StarRocksConnection instance")
        db = StarRocksConnection()
        logger.info("🔧 [DEBUG] StarRocksConnection created successfully")
        # 0. Log the incoming search criteria for traceability.
        # NOTE(review): .dict() is the pydantic v1 API (deprecated in v2 in
        # favor of .model_dump()) — confirm the pinned pydantic version.
        logger.info(f"📥 [Tool Input] data_retrieval_tool received {len(searches)} items:")
        for idx, item in enumerate(searches):
            logger.info(f"  🔹 Item [{idx}]: {item.dict(exclude_none=True)}")
        # 1. Fan out one coroutine per search item and run them concurrently.
        logger.info("🔧 [DEBUG] Creating parallel tasks")
        tasks = [_execute_single_search(db, item) for item in searches]
        logger.info(f"🚀 [Parallel Search] Executing {len(searches)} queries simultaneously...")
        logger.info("🔧 [DEBUG] About to call asyncio.gather()")
        results = await asyncio.gather(*tasks)
        logger.info(f"🔧 [DEBUG] asyncio.gather() completed with {len(results)} results")
        # 2. Pair each result list back with the criteria that produced it so
        #    the agent can tell which products answer which sub-query.
        combined_results = [
            {
                "search_index": i,
                "search_criteria": searches[i].dict(exclude_none=True),
                "count": len(products),
                "products": products,
            }
            for i, products in enumerate(results)
        ]
        return json.dumps(
            {"status": "success", "results": combined_results},
            ensure_ascii=False,
            cls=DecimalEncoder,
        )
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error in Multi-Search data_retrieval_tool: {e}")
        # FIX: error payload now also uses ensure_ascii=False so Vietnamese
        # error text is not \uXXXX-escaped (consistent with the success path).
        return json.dumps({"status": "error", "message": str(e)}, ensure_ascii=False)
async def _execute_single_search(db: StarRocksConnection, item: SearchItem) -> list[dict]:
    """Thực thi một search query đơn lẻ (Async).

    Builds the SQL for one SearchItem, runs it against StarRocks, and returns
    the trimmed product list. Never raises: any failure is logged and an empty
    list is returned, so one failed search cannot break the whole
    asyncio.gather() batch in data_retrieval_tool.
    """
    try:
        logger.info(
            f"🔧 [DEBUG] _execute_single_search STARTED for query: {item.query[:50] if item.query else 'None'}"
        )
        # ⏱️ Timer: build the query (includes embedding generation if needed).
        # FIX: perf_counter() is monotonic — time.time() is wall-clock and can
        # jump (NTP adjustments), producing wrong or negative durations.
        query_build_start = time.perf_counter()
        logger.info("🔧 [DEBUG] Calling build_starrocks_query()")
        sql = await build_starrocks_query(item)
        query_build_time = (time.perf_counter() - query_build_start) * 1000  # Convert to ms
        logger.info(f"🔧 [DEBUG] SQL query built, length: {len(sql)}")
        logger.info(f"⏱️ [TIMER] Query Build Time (bao gồm embedding): {query_build_time:.2f}ms")
        # ⏱️ Timer: execute the query against StarRocks.
        db_start = time.perf_counter()
        logger.info("🔧 [DEBUG] Calling db.execute_query_async()")
        products = await db.execute_query_async(sql)
        db_time = (time.perf_counter() - db_start) * 1000  # Convert to ms
        logger.info(f"🔧 [DEBUG] Query executed, got {len(products)} products")
        logger.info(f"⏱️ [TIMER] DB Query Execution Time: {db_time:.2f}ms")
        logger.info(f"⏱️ [TIMER] Total Time (Build + DB): {query_build_time + db_time:.2f}ms")
        return _format_product_results(products)
    except Exception as e:
        # logger.exception preserves the traceback for debugging.
        logger.exception(f"Single search error for item {item}: {e}")
        return []
def
_format_product_results
(
products
:
list
[
dict
])
->
list
[
dict
]:
"""Lọc và format kết quả trả về cho Agent."""
allowed_fields
=
{
"internal_ref_code"
,
"description_text_full"
,
}
return
[{
k
:
v
for
k
,
v
in
p
.
items
()
if
k
in
allowed_fields
}
for
p
in
products
[:
5
]]
backend/agent/tools/product_search_helpers_save.py
deleted
100644 → 0
View file @
f057ad1e
# import logging
# from common.embedding_service import create_embedding_async
# logger = logging.getLogger(__name__)
# def _escape(val: str) -> str:
# """Thoát dấu nháy đơn để tránh SQL Injection cơ bản."""
# return val.replace("'", "''")
# def _get_where_clauses(params) -> list[str]:
# """
# Xây dựng WHERE clauses theo thứ tự ưu tiên dựa trên selectivity thực tế
# FILTER PRIORITY (Based on Canifa catalog analysis):
# 🔥 TIER 1 (99% selectivity):
# 1. SKU Code → 1-5 records
# 🎯 TIER 2 (50-70% selectivity):
# 2. Gender → Splits catalog in half
# 3. Age → Kids vs Adults split
# 4. Product Category → 10-15 categories
# 💎 TIER 3 (30-50% selectivity):
# 5. Material Group → Knit vs Woven (2 groups)
# 6. Price Range → Numeric filtering
# 🎨 TIER 4 (10-30% selectivity):
# 7. Season → 4 seasons
# 8. Style/Fitting → Multiple options
# ⚠️ TIER 5 (<10% selectivity):
# 9. Form details → Granular attributes
# 10. Color → LOWEST selectivity (many SKUs share colors)
# Early return: If SKU exists, skip low-selectivity filters
# """
# clauses = []
# # 🔥 TIER 1: SKU/Product Code (Unique identifier)
# # Selectivity: ~99% → 1 SKU = 1 style (3-5 colors max)
# sku_clause = _get_sku_clause(params)
# if sku_clause:
# clauses.append(sku_clause)
# # Early return optimization: SKU đã xác định product rõ ràng
# # CHỈ GIỮ LẠI price filter (nếu có) để verify budget constraint
# # BỎ QUA: gender, color, style, fitting... vì SKU đã unique
# price_clauses = _get_price_clauses(params)
# if price_clauses:
# clauses.extend(price_clauses)
# return clauses # ⚡ STOP - Không thêm filter khác!
# # 🎯 TIER 2: High-level categorization (50-70% reduction)
# # Gender + Age + Category có selectivity cao nhất trong non-SKU filters
# clauses.extend(_get_high_selectivity_clauses(params))
# # 💎 TIER 3: Material & Price (30-50% reduction)
# material_clause = _get_material_clause(params)
# if material_clause:
# clauses.append(material_clause)
# clauses.extend(_get_price_clauses(params))
# # 🎨 TIER 4: Attributes (10-30% reduction)
# clauses.extend(_get_attribute_clauses(params))
# # ⚠️ TIER 5: Granular details & Color (LAST - lowest selectivity)
# clauses.extend(_get_form_detail_clauses(params))
# color_clause = _get_color_clause(params)
# if color_clause:
# clauses.append(color_clause) # Color ALWAYS LAST!
# return clauses
# def _get_sku_clause(params) -> str | None:
# """
# TIER 1: SKU/Product Code (Highest selectivity - 99%)
# 1 SKU code = 1 product style (may have 3-5 color variants)
# WHY SKU is always priority #1:
# - 1 code = 1 unique product design
# - Adding other filters (color, style, gender) is redundant
# - Only price filter may be kept for budget validation
# Example queries:
# - "Mã 6OT25W010" → Only SKU needed
# - "Mã 6OT25W010 màu xám" → Only SKU (color is for display/selection, not filtering)
# - "Mã 6OT25W010 dưới 500k" → SKU + price (validate budget)
# """
# m_code = getattr(params, "magento_ref_code", None)
# if m_code:
# m = _escape(m_code)
# return f"(magento_ref_code = '{m}' OR internal_ref_code = '{m}')"
# return None
# def _get_color_clause(params) -> str | None:
# """
# TIER 5: Color (LOWEST selectivity - 5-10%)
# Multiple SKUs share the same color (e.g., 50+ gray products)
# ALWAYS filter color LAST after other constraints
# """
# color = getattr(params, "master_color", None)
# if color:
# c = _escape(color).lower()
# return f"(LOWER(master_color) LIKE '%{c}%' OR LOWER(product_color_name) LIKE '%{c}%')"
# return None
# def _get_high_selectivity_clauses(params) -> list[str]:
# """
# TIER 2: High-level categorization (50-70% reduction per filter)
# Order: Gender → Age → Product Category
# """
# clauses = []
# # Gender: Male/Female/Unisex split (50-70% reduction)
# gender = getattr(params, "gender_by_product", None)
# if gender:
# clauses.append(f"gender_by_product = '{_escape(gender)}'")
# # Age: Kids/Adults split (50% reduction of remaining)
# age = getattr(params, "age_by_product", None)
# if age:
# clauses.append(f"age_by_product = '{_escape(age)}'")
# # Product Category: Váy/Áo/Quần... (30-50% reduction)
# product_line = getattr(params, "product_line_vn", None)
# if product_line:
# p = _escape(product_line).lower()
# clauses.append(f"LOWER(product_line_vn) LIKE '%{p}%'")
# return clauses
# def _get_material_clause(params) -> str | None:
# """TIER 3: Material Group - Knit vs Woven (50% split)"""
# material = getattr(params, "material_group", None)
# if material:
# m = _escape(material).lower()
# return f"LOWER(material_group) LIKE '%{m}%'"
# return None
# def _get_price_clauses(params) -> list[str]:
# """TIER 3: Price Range - Numeric filtering (30-40% reduction)"""
# clauses = []
# p_min = getattr(params, "price_min", None)
# if p_min is not None:
# clauses.append(f"sale_price >= {p_min}")
# p_max = getattr(params, "price_max", None)
# if p_max is not None:
# clauses.append(f"sale_price <= {p_max}")
# return clauses
# def _get_attribute_clauses(params) -> list[str]:
# """
# TIER 4: Attributes (10-30% reduction)
# Season, Style, Fitting
# """
# clauses = []
# # Season: 4 seasons (~25% each)
# season = getattr(params, "season", None)
# if season:
# s = _escape(season).lower()
# clauses.append(f"LOWER(season) LIKE '%{s}%'")
# # Style: Basic/Feminine/Sporty... (~15-20% reduction)
# style = getattr(params, "style", None)
# if style:
# st = _escape(style).lower()
# clauses.append(f"LOWER(style) LIKE '%{st}%'")
# # Fitting: Regular/Slim/Loose (~15% reduction)
# fitting = getattr(params, "fitting", None)
# if fitting:
# f = _escape(fitting).lower()
# clauses.append(f"LOWER(fitting) LIKE '%{f}%'")
# # Size Scale: S, M, L, 29, 30... (Specific filtering)
# size = getattr(params, "size_scale", None)
# if size:
# sz = _escape(size).lower()
# clauses.append(f"LOWER(size_scale) LIKE '%{sz}%'")
# return clauses
# def _get_form_detail_clauses(params) -> list[str]:
# """
# TIER 5: Granular form details (<10% reduction each)
# Neckline, Sleeve type
# """
# clauses = []
# form_fields = [
# ("form_neckline", "form_neckline"),
# ("form_sleeve", "form_sleeve"),
# ]
# for param_name, col_name in form_fields:
# val = getattr(params, param_name, None)
# if val:
# v = _escape(val).lower()
# clauses.append(f"LOWER({col_name}) LIKE '%{v}%'")
# return clauses
# async def build_starrocks_query(params, query_vector: list[float] | None = None) -> str:
# """
# Build SQL Hybrid tối ưu với Filter Priority:
# 1. Pre-filtering theo độ ưu tiên (SKU → Exact → Price → Partial)
# 2. Vector Search (HNSW Index) - Semantic understanding
# 3. Flexible Keyword Search (OR + Scoring) - Fuzzy matching fallback
# 4. Grouping (Gom màu theo style)
# """
# # --- Process vector in query field ---
# query_text = getattr(params, "query", None)
# # if query_text and query_vector is None:
# # query_vector = await create_embedding_async(query_text)
# # --- Build filter clauses (OPTIMIZED ORDER) ---
# where_clauses = _get_where_clauses(params)
# where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
# # --- Build SQL ---
# if query_vector and len(query_vector) > 0:
# v_str = "[" + ",".join(str(v) for v in query_vector) + "]"
# sql = f"""
# WITH top_sku_candidates AS (
# SELECT
# approx_cosine_similarity(vector, {v_str}) as similarity_score,
# internal_ref_code,
# product_name,
# sale_price,
# original_price,
# master_color,
# product_image_url,
# product_image_url_thumbnail,
# product_web_url,
# description_text,
# material,
# material_group,
# gender_by_product,
# age_by_product,
# season,
# style,
# fitting,
# form_neckline,
# form_sleeve,
# product_line_vn,
# product_color_name
# FROM shared_source.magento_product_dimension_with_text_embedding
# WHERE {where_sql} AND vector IS NOT NULL
# ORDER BY similarity_score DESC
# LIMIT 50
# )
# SELECT
# internal_ref_code,
# ANY_VALUE(product_name) as product_name,
# ANY_VALUE(sale_price) as sale_price,
# ANY_VALUE(original_price) as original_price,
# GROUP_CONCAT(DISTINCT master_color ORDER BY master_color SEPARATOR ', ') as available_colors,
# ANY_VALUE(product_image_url) as product_image_url,
# ANY_VALUE(product_image_url_thumbnail) as product_image_url_thumbnail,
# ANY_VALUE(product_web_url) as product_web_url,
# ANY_VALUE(description_text) as description_text,
# ANY_VALUE(material) as material,
# ANY_VALUE(material_group) as material_group,
# ANY_VALUE(gender_by_product) as gender_by_product,
# ANY_VALUE(age_by_product) as age_by_product,
# ANY_VALUE(season) as season,
# ANY_VALUE(style) as style,
# ANY_VALUE(fitting) as fitting,
# ANY_VALUE(form_neckline) as form_neckline,
# ANY_VALUE(form_sleeve) as form_sleeve,
# ANY_VALUE(product_line_vn) as product_line_vn,
# MAX(similarity_score) as max_score
# FROM top_sku_candidates
# GROUP BY internal_ref_code
# ORDER BY max_score DESC
# LIMIT 10
# """
# else:
# # ⚡ FALLBACK: FLEXIBLE KEYWORD SEARCH (OR + SCORING)
# # Giải quyết case: User search "áo khoác nỉ" → DB có "Áo nỉ nam"
# keywords = getattr(params, "keywords", None)
# keyword_score_sql = ""
# keyword_filter = ""
# if keywords:
# k_clean = _escape(keywords).lower().strip()
# if k_clean:
# words = k_clean.split()
# # Build scoring expression: Each matched word = +1 point
# # Example: "áo khoác nỉ" (3 words)
# # - "Áo nỉ nam" matches 2/3 → Score = 2
# # - "Áo khoác nỉ hoodie" matches 3/3 → Score = 3
# score_terms = [
# f"(CASE WHEN LOWER(product_name) LIKE '%{w}%' THEN 1 ELSE 0 END)"
# for w in words
# ]
# keyword_score_sql = f"({' + '.join(score_terms)}) as keyword_match_score"
# # Minimum threshold: At least 50% of words must match
# # Example: 3 words → need at least 2 matches (66%)
# # 2 words → need at least 1 match (50%)
# min_matches = max(1, len(words) // 2)
# keyword_filter = f" AND ({' + '.join(score_terms)}) >= {min_matches}"
# # Select clause with optional scoring
# select_score = f", {keyword_score_sql}" if keyword_score_sql else ""
# order_by = "keyword_match_score DESC, sale_price ASC" if keyword_score_sql else "sale_price ASC"
# sql = f"""
# SELECT
# internal_ref_code,
# ANY_VALUE(product_name) as product_name,
# ANY_VALUE(sale_price) as sale_price,
# ANY_VALUE(original_price) as original_price,
# GROUP_CONCAT(DISTINCT master_color ORDER BY master_color SEPARATOR ', ') as available_colors,
# ANY_VALUE(product_image_url) as product_image_url,
# ANY_VALUE(product_image_url_thumbnail) as product_image_url_thumbnail,
# ANY_VALUE(product_web_url) as product_web_url,
# ANY_VALUE(description_text) as description_text,
# ANY_VALUE(material) as material,
# ANY_VALUE(material_group) as material_group,
# ANY_VALUE(gender_by_product) as gender_by_product,
# ANY_VALUE(age_by_product) as age_by_product,
# ANY_VALUE(season) as season,
# ANY_VALUE(style) as style,
# ANY_VALUE(fitting) as fitting,
# ANY_VALUE(form_neckline) as form_neckline,
# ANY_VALUE(form_sleeve) as form_sleeve,
# ANY_VALUE(product_line_vn) as product_line_vn
# {select_score}
# FROM shared_source.magento_product_dimension_with_text_embedding
# WHERE {where_sql} {keyword_filter}
# GROUP BY internal_ref_code
# HAVING COUNT(*) > 0
# ORDER BY {order_by}
# LIMIT 10
# """
# # Log filter statistics
# filter_info = f"Mode: {'Vector' if query_vector else 'Keyword'}, Filters: {len(where_clauses)}"
# if where_clauses:
# # Identify high-priority filters used
# has_sku = any('internal_ref_code' in c or 'magento_ref_code' in c for c in where_clauses)
# has_gender = any('gender_by_product' in c for c in where_clauses)
# has_category = any('product_line_vn' in c for c in where_clauses)
# priority_info = []
# if has_sku:
# priority_info.append("SKU")
# if has_gender:
# priority_info.append("Gender")
# if has_category:
# priority_info.append("Category")
# if priority_info:
# filter_info += f", Priority: {'+'.join(priority_info)}"
# logger.info(f"📊 {filter_info}")
# # Write SQL to file for debugging
# try:
# with open(r"d:\cnf\chatbot_canifa\backend\embedding.txt", "w", encoding="utf-8") as f:
# f.write(sql)
# except Exception as e:
# logger.error(f"Failed to write SQL to embedding.txt: {e}")
# return sql
import
logging
import
time
from
common.embedding_service
import
create_embedding_async
logger
=
logging
.
getLogger
(
__name__
)
def
_escape
(
val
:
str
)
->
str
:
"""Thoát dấu nháy đơn để tránh SQL Injection cơ bản."""
return
val
.
replace
(
"'"
,
"''"
)
def _get_where_clauses(params) -> list[str]:
    """Assemble every WHERE-clause fragment derived from the search params.

    Delegates to the specialized builders in a fixed order: price range,
    metadata (exact + LIKE), then special cases (SKU code, color).
    """
    clause_builders = (
        _get_price_clauses,
        _get_metadata_clauses,
        _get_special_clauses,
    )
    return [clause for builder in clause_builders for clause in builder(params)]
def
_get_price_clauses
(
params
)
->
list
[
str
]:
"""Lọc theo giá."""
clauses
=
[]
p_min
=
getattr
(
params
,
"price_min"
,
None
)
if
p_min
is
not
None
:
clauses
.
append
(
f
"sale_price >= {p_min}"
)
p_max
=
getattr
(
params
,
"price_max"
,
None
)
if
p_max
is
not
None
:
clauses
.
append
(
f
"sale_price <= {p_max}"
)
return
clauses
def _get_metadata_clauses(params) -> list[str]:
    """Build metadata filter clauses, combining exact equality and fuzzy LIKE.

    Gender and age require absolute precision (equality); the remaining text
    attributes use case-insensitive LIKE so short labels still map onto the
    full DB values, e.g. "Yarn" -> "Yarn - Sợi", "Knit" -> "Knit - Dệt Kim".
    """
    clauses: list[str] = []

    # 1. Exact match — these fields demand strict equality.
    for column in ("gender_by_product", "age_by_product"):
        value = getattr(params, column, None)
        if value:
            clauses.append(f"{column} = '{_escape(value)}'")

    # 2. Partial match (LOWER + LIKE) — tolerant of extra words and casing.
    like_columns = (
        "season",
        "material_group",
        "product_line_vn",
        "style",
        "fitting",
        "form_neckline",
        "form_sleeve",
    )
    for column in like_columns:
        value = getattr(params, column, None)
        if value:
            needle = _escape(value).lower()
            clauses.append(f"LOWER({column}) LIKE '%{needle}%'")

    return clauses
def _get_special_clauses(params) -> list[str]:
    """Build special-case filter clauses: product/SKU code and color."""
    conditions: list[str] = []

    # Product code / SKU: match either the Magento code or the internal code.
    ref_code = getattr(params, "magento_ref_code", None)
    if ref_code:
        code = _escape(ref_code)
        conditions.append(
            f"(magento_ref_code = '{code}' OR internal_ref_code = '{code}')"
        )

    # Color: case-insensitive substring match against both color columns.
    requested_color = getattr(params, "master_color", None)
    if requested_color:
        needle = _escape(requested_color).lower()
        conditions.append(
            f"(LOWER(master_color) LIKE '%{needle}%' OR LOWER(product_color_name) LIKE '%{needle}%')"
        )

    return conditions
async def build_starrocks_query(params, query_vector: list[float] | None = None) -> str:
    """Build an optimized hybrid StarRocks SQL query (POST-FILTERING, anti-duplication).

    Strategy:
        1. Vector search FIRST (LIMIT 100) to leverage the HNSW index (~50ms).
        2. JOIN back on (code + color) to avoid row explosion across color variants.
        3. MAX_BY picks the description belonging to the highest-scoring row
           (data integrity).

    Args:
        params: Search parameters (SearchItem-like; fields read via getattr).
        query_vector: Optional precomputed embedding. When None and
            ``params.query`` is set, an embedding is generated here.

    Returns:
        The SQL string: vector-search variant when an embedding is available,
        otherwise a keyword/LIKE fallback ordered by price.
    """
    logger.info("🔧 [DEBUG] build_starrocks_query STARTED")
    # --- 1. Embedding: derive a vector from the free-text query if needed ---
    query_text = getattr(params, "query", None)
    if query_text and query_vector is None:
        emb_start = time.time()
        query_vector = await create_embedding_async(query_text)
        emb_time = (time.time() - emb_start) * 1000
        logger.info(f"⏱️ [TIMER] Embedding Generation: {emb_time:.2f}ms")
    # --- 2. Build the filter clauses applied as a POST-filter ---
    where_clauses = _get_where_clauses(params)
    post_filter_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
    # --- 3. Build SQL ---
    if query_vector and len(query_vector) > 0:
        v_str = "[" + ",".join(str(v) for v in query_vector) + "]"
        # Prefix filtered columns with the t2 alias to avoid ambiguous-column
        # errors in the JOIN below.
        # NOTE(review): plain str.replace() also rewrites occurrences of these
        # column names inside quoted filter VALUES (e.g. a color value that
        # happens to contain "style") — confirm filter values never collide
        # with column names, or switch to word-boundary-aware rewriting.
        post_filter_aliased = post_filter_sql
        fields_to_alias = [
            "sale_price",
            "gender_by_product",
            "age_by_product",
            "material_group",
            "season",
            "style",
            "fitting",
            "form_neckline",
            "form_sleeve",
            "product_line_vn",
            "magento_ref_code",
            "internal_ref_code",
            "master_color",
            "product_color_name",
        ]
        for field in fields_to_alias:
            post_filter_aliased = post_filter_aliased.replace(field, f"t2.{field}")
        # Vector variant: HNSW candidate scan first (cheap), then join back for
        # the payload and apply the user filters post-hoc.
        sql = f"""
        WITH top_candidates AS (
            SELECT /*+ SET_VAR(ann_params='{{"ef_search":64}}') */
                internal_ref_code,
                product_color_code,
                approx_cosine_similarity(vector, {v_str}) as similarity_score
            FROM shared_source.magento_product_dimension_with_text_embedding
            WHERE vector IS NOT NULL
            ORDER BY similarity_score DESC
            LIMIT 100
        )
        SELECT
            t1.internal_ref_code,
            -- MAX_BY đảm bảo mô tả đi kèm đúng với thằng cao điểm nhất (Data Integrity)
            MAX_BY(t2.description_text_full, t1.similarity_score) as description_text_full,
            MAX(t1.similarity_score) as max_score
        FROM top_candidates t1
        JOIN shared_source.magento_product_dimension_with_text_embedding t2
            ON t1.internal_ref_code = t2.internal_ref_code
            AND t1.product_color_code = t2.product_color_code -- QUAN TRỌNG: Tránh nhân bản dòng theo màu
        WHERE {post_filter_aliased}
        GROUP BY t1.internal_ref_code
        ORDER BY max_score DESC
        LIMIT 10
        """
    else:
        # FALLBACK: keyword (LIKE) search when no embedding is available.
        keywords = getattr(params, "keywords", None)
        k_filter = ""
        if keywords:
            k = _escape(keywords).lower()
            k_filter = f" AND LOWER(product_name) LIKE '%{k}%'"
        where_sql = " AND ".join(where_clauses) if where_clauses else "1=1"
        sql = f"""
        SELECT
            internal_ref_code,
            -- Lấy đại diện 1 mô tả cho keyword search
            MAX(description_text_full) as description_text_full,
            MIN(sale_price) as min_price
        FROM shared_source.magento_product_dimension_with_text_embedding
        WHERE {where_sql} {k_filter}
        GROUP BY internal_ref_code
        ORDER BY min_price ASC
        LIMIT 10
        """
    # --- 4. Debug logging: dump the generated SQL to a local file ---
    # NOTE(review): hard-coded Windows dev path; the except below makes this
    # best-effort on other machines. Consider gating behind a DEBUG flag.
    try:
        debug_path = r"d:\cnf\chatbot_canifa\backend\query.txt"
        with open(debug_path, "w", encoding="utf-8") as f:
            f.write(sql)
        logger.info(f"💾 SQL saved to: {debug_path}")
    except Exception as e:
        logger.error(f"Save log failed: {e}")
    return sql
backend/api/chatbot_route.py
View file @
28274420
...
...
@@ -9,7 +9,7 @@ import logging
from
fastapi
import
APIRouter
,
BackgroundTasks
,
HTTPException
from
opentelemetry
import
trace
from
agent.controller
import
chat_controller
from
agent.
agno_
controller
import
chat_controller
from
agent.models
import
QueryRequest
from
config
import
DEFAULT_MODEL
...
...
backend/common/conversation_manager.py
View file @
28274420
...
...
@@ -2,54 +2,106 @@ import logging
from
datetime
import
datetime
from
typing
import
Any
from
psycopg
import
sql
from
psycopg_pool
import
AsyncConnectionPool
from
config
import
CHECKPOINT_POSTGRES_URL
# Runtime imports with fallback
try
:
from
agno.db.base
import
BaseDb
from
agno.models
import
Message
# type: ignore[import-untyped]
except
ImportError
:
# Create stub class if agno not installed
class
BaseDbStub
:
# type: ignore
pass
# Create a simple Message-like class for when Agno is not available
class
MessageStub
:
# type: ignore
def
__init__
(
self
,
role
:
str
,
content
:
str
,
created_at
:
Any
=
None
):
self
.
role
=
role
self
.
content
=
content
self
.
created_at
=
created_at
BaseDb
=
BaseDbStub
# type: ignore
Message
=
MessageStub
# type: ignore
logger
=
logging
.
getLogger
(
__name__
)
class
ConversationManager
:
# Use composition instead of inheritance to avoid implementing all BaseDb methods
class
ConversationManager
:
# Don't inherit BaseDb directly
"""
Conversation Manager với Agno BaseDb interface.
Hỗ trợ cả legacy methods và Agno Agent.
"""
def
__init__
(
self
,
connection_url
:
str
=
CHECKPOINT_POSTGRES_URL
,
connection_url
:
str
|
None
=
None
,
table_name
:
str
=
"langgraph_chat_histories"
,
):
self
.
connection_url
=
connection_url
self
.
connection_url
:
str
=
connection_url
or
CHECKPOINT_POSTGRES_URL
or
""
if
not
self
.
connection_url
:
raise
ValueError
(
"connection_url is required"
)
self
.
table_name
=
table_name
self
.
_pool
:
AsyncConnectionPool
|
None
=
None
async
def
_get_pool
(
self
)
->
AsyncConnectionPool
:
"""Get or create async connection pool."""
"""Get or create async connection pool
với config hợp lý
."""
if
self
.
_pool
is
None
:
self
.
_pool
=
AsyncConnectionPool
(
self
.
connection_url
,
open
=
False
)
# Pool config: min_size=1, max_size=5, timeout=10s
self
.
_pool
=
AsyncConnectionPool
(
self
.
connection_url
,
min_size
=
1
,
max_size
=
5
,
timeout
=
10.0
,
# 10s timeout thay vì default 30s
open
=
False
,
)
try
:
await
self
.
_pool
.
open
()
logger
.
info
(
f
"✅ PostgreSQL connection pool opened: {self.connection_url.split('@')[-1] if '@' in self.connection_url else '***'}"
)
except
Exception
as
e
:
logger
.
error
(
f
"❌ Failed to open PostgreSQL pool: {e}"
)
self
.
_pool
=
None
raise
return
self
.
_pool
async
def
initialize_table
(
self
):
"""Create the chat history table if it doesn't exist"""
try
:
logger
.
info
(
f
"🔌 Initializing PostgreSQL table: {self.table_name}"
)
pool
=
await
self
.
_get_pool
()
async
with
pool
.
connection
()
as
conn
:
# Use connection với timeout ngắn hơn
async
with
pool
.
connection
(
timeout
=
5.0
)
as
conn
:
# 5s timeout cho connection
async
with
conn
.
cursor
()
as
cursor
:
await
cursor
.
execute
(
f
"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
await
cursor
.
execute
(
sql
.
SQL
(
"""
CREATE TABLE IF NOT EXISTS {} (
id SERIAL PRIMARY KEY,
user_id VARCHAR(255) NOT NULL,
message TEXT NOT NULL,
is_human BOOLEAN NOT NULL,
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
"""
)
.
format
(
sql
.
Identifier
(
self
.
table_name
))
)
await
cursor
.
execute
(
f
"""
CREATE INDEX IF NOT EXISTS idx_{self.table_name}_user_timestamp
ON {self.table_name} (user_id, timestamp)
"""
)
await
cursor
.
execute
(
sql
.
SQL
(
"""
CREATE INDEX IF NOT EXISTS idx_{}_user_timestamp
ON {} (user_id, timestamp)
"""
)
.
format
(
sql
.
Identifier
(
self
.
table_name
),
sql
.
Identifier
(
self
.
table_name
),
)
)
await
conn
.
commit
()
logger
.
info
(
f
"Table {self.table_name} initialized successfully"
)
logger
.
info
(
f
"
✅
Table {self.table_name} initialized successfully"
)
except
Exception
as
e
:
logger
.
error
(
f
"Error initializing table: {e}"
)
logger
.
error
(
f
"❌ Error initializing table: {e}"
)
logger
.
error
(
f
" Connection URL: {self.connection_url.split('@')[-1] if '@' in self.connection_url else '***'}"
)
raise
async
def
save_conversation_turn
(
self
,
user_id
:
str
,
human_message
:
str
,
ai_message
:
str
):
...
...
@@ -60,8 +112,10 @@ class ConversationManager:
async
with
pool
.
connection
()
as
conn
:
async
with
conn
.
cursor
()
as
cursor
:
await
cursor
.
execute
(
f
"""INSERT INTO {self.table_name} (user_id, message, is_human, timestamp)
VALUES (
%
s,
%
s,
%
s,
%
s), (
%
s,
%
s,
%
s,
%
s)"""
,
sql
.
SQL
(
"""
INSERT INTO {} (user_id, message, is_human, timestamp)
VALUES (
%
s,
%
s,
%
s,
%
s), (
%
s,
%
s,
%
s,
%
s)
"""
)
.
format
(
sql
.
Identifier
(
self
.
table_name
)),
(
user_id
,
human_message
,
...
...
@@ -84,23 +138,25 @@ class ConversationManager:
)
->
list
[
dict
[
str
,
Any
]]:
"""Retrieve chat history for a user using cursor-based pagination."""
try
:
query
=
f
"""
SELECT message, is_human, timestamp, id
FROM {self.table_name}
WHERE user_id =
%
s
"""
params
=
[
user_id
]
base_query
=
sql
.
SQL
(
"SELECT message, is_human, timestamp, id FROM {} WHERE user_id =
%
s"
)
.
format
(
sql
.
Identifier
(
self
.
table_name
)
)
params
:
list
[
Any
]
=
[
user_id
]
query_parts
:
list
[
sql
.
Composable
]
=
[
base_query
]
if
before_id
:
query
+=
" AND id <
%
s"
query
_parts
.
append
(
sql
.
SQL
(
" AND id <
%
s"
))
params
.
append
(
before_id
)
query
+=
" ORDER BY id DESC"
query
_parts
.
append
(
sql
.
SQL
(
" ORDER BY id DESC"
))
if
limit
:
query
+=
" LIMIT
%
s"
query
_parts
.
append
(
sql
.
SQL
(
" LIMIT
%
s"
))
params
.
append
(
limit
)
query
=
sql
.
Composed
(
query_parts
)
pool
=
await
self
.
_get_pool
()
async
with
pool
.
connection
()
as
conn
,
conn
.
cursor
()
as
cursor
:
await
cursor
.
execute
(
query
,
tuple
(
params
))
...
...
@@ -125,7 +181,12 @@ class ConversationManager:
pool
=
await
self
.
_get_pool
()
async
with
pool
.
connection
()
as
conn
:
async
with
conn
.
cursor
()
as
cursor
:
await
cursor
.
execute
(
f
"DELETE FROM {self.table_name} WHERE user_id =
%
s"
,
(
user_id
,))
await
cursor
.
execute
(
sql
.
SQL
(
"DELETE FROM {} WHERE user_id =
%
s"
)
.
format
(
sql
.
Identifier
(
self
.
table_name
)
),
(
user_id
,),
)
await
conn
.
commit
()
logger
.
info
(
f
"Cleared chat history for user {user_id}"
)
except
Exception
as
e
:
...
...
@@ -136,7 +197,11 @@ class ConversationManager:
try
:
pool
=
await
self
.
_get_pool
()
async
with
pool
.
connection
()
as
conn
,
conn
.
cursor
()
as
cursor
:
await
cursor
.
execute
(
f
"SELECT COUNT(DISTINCT user_id) FROM {self.table_name}"
)
await
cursor
.
execute
(
sql
.
SQL
(
"SELECT COUNT(DISTINCT user_id) FROM {}"
)
.
format
(
sql
.
Identifier
(
self
.
table_name
)
)
)
result
=
await
cursor
.
fetchone
()
return
result
[
0
]
if
result
else
0
except
Exception
as
e
:
...
...
@@ -147,6 +212,132 @@ class ConversationManager:
"""Close the connection pool"""
if
self
.
_pool
:
await
self
.
_pool
.
close
()
self
.
_pool
=
None
# ========== Agno BaseDb Interface Methods ==========
# Giữ nguyên methods cũ ở trên để backward compatible
async def initialize(self):
    """Agno interface: initialize the backing table (alias for initialize_table)."""
    return await self.initialize_table()
async def load_history(self, session_id: str, limit: int = 20) -> list[Any]:
    """Agno interface: load persisted chat history as Agno Message objects.

    Reuses get_chat_history(); rows arrive newest-first and are reversed
    into chronological order before conversion.

    Args:
        session_id: User ID (Agno's session_id maps onto our user_id).
        limit: Maximum number of messages to load.

    Returns:
        List of Agno Message objects; empty list on any failure.
    """
    try:
        rows = await self.get_chat_history(user_id=session_id, limit=limit)
        messages = [
            Message(
                role="user" if row["is_human"] else "assistant",
                content=row["message"],
                created_at=row["timestamp"],
            )
            for row in reversed(rows)  # reverse -> chronological order
        ]
        logger.debug(f"📥 [Agno] Loaded {len(messages)} messages for session {session_id}")
        return messages
    except Exception as e:
        logger.error(f"❌ [Agno] Error loading history for {session_id}: {e}")
        return []
async def save_message(self, session_id: str, message: Any):
    """Agno interface: persist a single Agno Message.

    Args:
        session_id: User ID.
        message: Agno Message object (role, content, created_at are read).

    Raises:
        Exception: re-raised after logging if the insert fails.
    """
    try:
        pool = await self._get_pool()
        human = message.role == "user"
        stmt = sql.SQL(
            """
            INSERT INTO {} (user_id, message, is_human, timestamp)
            VALUES (%s, %s, %s, %s)
            """
        ).format(sql.Identifier(self.table_name))
        async with pool.connection() as conn, conn.cursor() as cur:
            await cur.execute(
                stmt,
                (
                    session_id,
                    message.content,
                    human,
                    message.created_at or datetime.now(),  # fall back to "now"
                ),
            )
            await conn.commit()
        logger.debug(f"💾 [Agno] Saved message for session {session_id}")
    except Exception as e:
        logger.error(f"❌ [Agno] Error saving message for {session_id}: {e}", exc_info=True)
        raise
async def save_session(self, session_id: str, messages: list[Any]):
    """Agno interface: persist multiple Agno Messages in one batch.

    Args:
        session_id: User ID.
        messages: List of Agno Message objects.

    Raises:
        Exception: re-raised after logging if the batch insert fails.
    """
    # Guard: skip the connection round-trip entirely when there is
    # nothing to persist (executemany with an empty batch is wasteful).
    if not messages:
        logger.debug(f"💾 [Agno] Saved 0 messages for session {session_id}")
        return
    try:
        pool = await self._get_pool()
        # One shared fallback timestamp for messages without created_at.
        timestamp = datetime.now()
        values = [
            (session_id, msg.content, msg.role == "user", msg.created_at or timestamp)
            for msg in messages
        ]
        async with pool.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.executemany(
                    sql.SQL(
                        """
                        INSERT INTO {} (user_id, message, is_human, timestamp)
                        VALUES (%s, %s, %s, %s)
                        """
                    ).format(sql.Identifier(self.table_name)),
                    values,
                )
            await conn.commit()
        logger.debug(f"💾 [Agno] Saved {len(messages)} messages for session {session_id}")
    except Exception as e:
        logger.error(f"❌ [Agno] Error saving session for {session_id}: {e}", exc_info=True)
        raise
async def get_session_messages(self, session_id: str) -> list[Any]:
    """Agno interface: return every stored message for a session.

    Delegates to load_history with a generous cap of 1000 messages.
    """
    return await self.load_history(session_id, limit=1000)
async def clear_session(self, session_id: str):
    """Agno interface: delete a session's history (alias for clear_history)."""
    return await self.clear_history(session_id)
# ConversationManager implements BaseDb interface methods
# but doesn't inherit BaseDb to avoid implementing all abstract methods
# Agno will accept it as long as it has the required methods
# --- Singleton ---
...
...
async def get_conversation_manager() -> ConversationManager:
    """Get or create the async ConversationManager singleton.

    On first call the instance is built and its table is created; if that
    fails the singleton slot is cleared so the next call can retry.
    """
    global _instance
    if _instance is not None:
        return _instance
    try:
        _instance = ConversationManager()
        await _instance.initialize_table()
    except Exception as e:
        logger.error(f"❌ Failed to initialize ConversationManager: {e}")
        # Reset so a later call retries initialization from scratch.
        _instance = None
        raise
    return _instance
backend/common/langfuse_client.py
View file @
28274420
"""
Simple Langfuse Client Wrapper
Minimal setup using langfuse.langchain module
With propagate_attributes for proper user_id tracking
Langfuse Client với OpenInference instrumentation cho Agno
Tự động trace tất cả Agno calls (LLM, tools, agent runs)
"""
import
asyncio
import
base64
import
logging
import
os
from
concurrent.futures
import
ThreadPoolExecutor
...
...
@@ -19,6 +19,21 @@ from config import (
LANGFUSE_SECRET_KEY
,
)
# OpenInference / OpenTelemetry imports are optional: when any of these
# packages is missing we leave the flag False and the code falls back to
# the plain Langfuse SDK instead of OTLP-based instrumentation.
_OPENINFERENCE_AVAILABLE = False
AgnoInstrumentor = None  # type: ignore
try:
    from openinference.instrumentation.agno import AgnoInstrumentor  # type: ignore[import-untyped]
    from opentelemetry import trace as trace_api
    from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import SimpleSpanProcessor

    _OPENINFERENCE_AVAILABLE = True
except ImportError:
    pass
logger
=
logging
.
getLogger
(
__name__
)
# ⚡ Global state for async batch export
...
...
@@ -31,9 +46,10 @@ _batch_lock = asyncio.Lock if hasattr(asyncio, "Lock") else None
def
initialize_langfuse
()
->
bool
:
"""
1. Set environment variables
2. Initialize Langfuse client
3. Setup thread pool for async batch export
1. Setup OpenInference instrumentation cho Agno (nếu available)
2. Configure OTLP exporter để gửi traces đến Langfuse
3. Initialize Langfuse client (fallback)
4. Register shutdown handler
"""
global
_langfuse_client
,
_export_executor
...
...
@@ -44,27 +60,95 @@ def initialize_langfuse() -> bool:
# Set environment
os
.
environ
[
"LANGFUSE_PUBLIC_KEY"
]
=
LANGFUSE_PUBLIC_KEY
os
.
environ
[
"LANGFUSE_SECRET_KEY"
]
=
LANGFUSE_SECRET_KEY
os
.
environ
[
"LANGFUSE_BASE_URL"
]
=
LANGFUSE_BASE_URL
or
"https://cloud.langfuse.com"
os
.
environ
[
"LANGFUSE_TIMEOUT"
]
=
"10"
# 10s timeout, not blocking
# Disable default flush to prevent blocking
os
.
environ
[
"LANGFUSE_FLUSHINTERVAL"
]
=
"300"
# 5 min, very infrequent
base_url
=
LANGFUSE_BASE_URL
or
"https://cloud.langfuse.com"
os
.
environ
[
"LANGFUSE_BASE_URL"
]
=
base_url
os
.
environ
[
"LANGFUSE_TIMEOUT"
]
=
"10"
os
.
environ
[
"LANGFUSE_FLUSHINTERVAL"
]
=
"300"
try
:
# ========== Setup OpenInference cho Agno ==========
global
_OPENINFERENCE_AVAILABLE
if
_OPENINFERENCE_AVAILABLE
:
try
:
# Determine Langfuse OTLP endpoint
if
"localhost"
in
base_url
or
"127.0.0.1"
in
base_url
:
otlp_endpoint
=
f
"{base_url}/api/public/otel"
elif
"us.cloud"
in
base_url
:
otlp_endpoint
=
"https://us.cloud.langfuse.com/api/public/otel"
elif
"eu.cloud"
in
base_url
:
otlp_endpoint
=
"https://eu.cloud.langfuse.com/api/public/otel"
else
:
# Custom deployment
otlp_endpoint
=
f
"{base_url}/api/public/otel"
# Create auth header
langfuse_auth
=
base64
.
b64encode
(
f
"{LANGFUSE_PUBLIC_KEY}:{LANGFUSE_SECRET_KEY}"
.
encode
()
)
.
decode
()
# Set OTLP environment variables
os
.
environ
[
"OTEL_EXPORTER_OTLP_ENDPOINT"
]
=
otlp_endpoint
os
.
environ
[
"OTEL_EXPORTER_OTLP_HEADERS"
]
=
f
"Authorization=Basic {langfuse_auth}"
# Configure TracerProvider
tracer_provider
=
TracerProvider
()
tracer_provider
.
add_span_processor
(
SimpleSpanProcessor
(
OTLPSpanExporter
()))
trace_api
.
set_tracer_provider
(
tracer_provider
=
tracer_provider
)
# Instrument Agno
if
AgnoInstrumentor
:
AgnoInstrumentor
()
.
instrument
()
logger
.
info
(
f
"✅ OpenInference instrumentation enabled for Agno"
)
logger
.
info
(
f
" → Sending traces to: {otlp_endpoint}"
)
except
Exception
as
e
:
logger
.
warning
(
f
"⚠️ Failed to setup OpenInference: {e}. Falling back to Langfuse SDK."
)
_OPENINFERENCE_AVAILABLE
=
False
# ========== Fallback: Langfuse SDK ==========
if
not
_OPENINFERENCE_AVAILABLE
:
_langfuse_client
=
get_client
()
_export_executor
=
ThreadPoolExecutor
(
max_workers
=
1
,
thread_name_prefix
=
"langfuse_export"
)
if
_langfuse_client
.
auth_check
():
logger
.
info
(
"✅ Langfuse Ready! (async batch export)"
)
# Register shutdown handler
import
atexit
atexit
.
register
(
shutdown_langfuse
)
logger
.
info
(
f
"✅ Langfuse initialized (BASE_URL: {base_url})"
)
return
True
logger
.
error
(
"❌ Langfuse auth failed"
)
return
False
except
Exception
as
e
:
logger
.
error
(
f
"❌ Langfuse init error: {e}"
)
return
False
def shutdown_langfuse():
    """Gracefully shut down the Langfuse client so process exit is not blocked."""
    global _langfuse_client, _export_executor
    try:
        client = _langfuse_client
        if client:
            # Best effort: push any buffered traces before tearing down.
            try:
                client.flush()
            except Exception as e:
                logger.debug(f"Langfuse flush error during shutdown: {e}")
            # Shut the client down if it exposes a shutdown hook.
            try:
                shutdown_fn = getattr(client, "shutdown", None)
                if shutdown_fn is not None:
                    shutdown_fn()
            except Exception as e:
                logger.debug(f"Langfuse shutdown error: {e}")
        if _export_executor:
            # Non-blocking: do not wait for in-flight export work.
            _export_executor.shutdown(wait=False)
        logger.debug("🔒 Langfuse client shutdown completed")
    except Exception as e:
        logger.debug(f"Error during Langfuse shutdown: {e}")
async
def
async_flush_langfuse
():
"""
Async wrapper to flush Langfuse without blocking event loop.
...
...
backend/common/law_database.py
deleted
100644 → 0
View file @
f057ad1e
# services/law_db.py
import
asyncio
import
os
from
typing
import
Any
import
httpx
# Support absolute import when run as module, and fallback when run as script
try
:
from
common.supabase_client
import
(
close_supabase_client
,
init_supabase_client
,
supabase_rpc_call
,
)
from
common.openai_client
import
get_openai_client
except
ImportError
:
import
os
as
_os
import
sys
_ROOT
=
_os
.
path
.
dirname
(
_os
.
path
.
dirname
(
_os
.
path
.
dirname
(
_os
.
path
.
abspath
(
__file__
))))
if
_ROOT
not
in
sys
.
path
:
sys
.
path
.
append
(
_ROOT
)
from
common.supabase_client
import
(
init_supabase_client
,
supabase_rpc_call
,
)
from
common.openai_client
import
get_openai_client
# ====================== CONFIG ======================
def get_supabase_config():
    """Build the Supabase RPC endpoint config lazily.

    Environment variables are read at call time (not import time) so the
    module can be imported before the env vars exist.

    Returns:
        Dict with the RPC "url" and the auth "headers" for the
        hoi_phap_luat_all_in_one endpoint.
    """
    base_url = os.environ["SUPABASE_URL"]
    anon_key = os.environ["SUPABASE_ANON_KEY"]
    headers = {
        "apikey": anon_key,
        "Authorization": f"Bearer {anon_key}",
        "Content-Type": "application/json; charset=utf-8",
        "Accept": "application/json",
    }
    return {
        "url": f"{base_url}/rest/v1/rpc/hoi_phap_luat_all_in_one",
        "headers": headers,
    }
# ====================== HTTP HELPERS ======================
async def _post_with_retry(client: httpx.AsyncClient, url: str, **kw) -> httpx.Response:
    """POST with up to 3 attempts and exponential backoff (0.4s, 0.8s).

    The last exception is re-raised after the final attempt; non-2xx
    responses raise via raise_for_status and are retried too.
    """
    last_exc: Exception | None = None
    for attempt in range(3):
        try:
            resp = await client.post(url, **kw)
            resp.raise_for_status()
            return resp
        except Exception as e:
            last_exc = e
            if attempt == 2:
                raise
            await asyncio.sleep(0.4 * (2**attempt))
    raise last_exc  # theoretically unreachable
# ====================== LABEL HELPER ======================
def
_label_for_call
(
call
:
dict
[
str
,
Any
],
index
:
int
)
->
str
:
"""Tạo nhãn hiển thị cho mỗi lệnh gọi dựa trên tham số (để phân biệt kết quả)."""
params
=
call
.
get
(
"params"
)
or
{}
if
params
.
get
(
"p_so_hieu"
):
return
str
(
params
[
"p_so_hieu"
])
if
params
.
get
(
"p_vb_pattern"
):
return
str
(
params
[
"p_vb_pattern"
])
return
f
"Truy vấn {index}"
# ====================== RAW FETCHERS ======================
async def _get_embedding(text: str) -> list[float]:
    """Return the OpenAI embedding vector for *text*; empty list on failure.

    The text is stripped, lower-cased, and truncated to 8000 characters
    before being sent to the text-embedding-3-small model.
    """
    if not text:
        return []
    client = get_openai_client()
    payload = (text or "").strip().lower()[:8000]  # length guard for the API
    try:
        resp = await client.embeddings.create(
            model="text-embedding-3-small",
            input=payload,
        )
        return resp.data[0].embedding
    except Exception as e:
        print(f"❌ Lỗi OpenAI embedding: {e}")
        return []
async def law_db_fetch_one(params: dict[str, Any]) -> list[dict[str, Any]]:
    """
    Run one hoi_phap_luat_all_in_one RPC call against Supabase.

    Pipeline: (1) if p_vector_text is present, convert it to an embedding
    and default the mode to "semantic"; (2) whitelist/validate parameters
    into the RPC's expected signature; (3) call the RPC via the shared
    Supabase client, printing debug info along the way.

    Args:
        params: Raw query parameters (p_so_hieu, p_vb_pattern,
            p_vector_text, p_limit, ...). None values are dropped.

    Returns:
        List of result rows; empty list on any failure.
    """
    # Turn p_vector_text into an embedding vector (p_vector).
    processed_params = dict(params)
    if processed_params.get("p_vector_text"):
        vector_text = processed_params.pop("p_vector_text")
        embedding = await _get_embedding(vector_text)
        if embedding and len(embedding) > 0:
            processed_params["p_vector"] = embedding
            # Default the mode to semantic when the caller gave no mode.
            if "p_mode" not in processed_params and "_mode" not in processed_params:
                processed_params["p_mode"] = "semantic"
    # Map parameters onto the RPC function's exact signature, validating
    # types as we go; unknown keys are silently dropped.
    mapped_params = {}
    for key, value in processed_params.items():
        # Skip None values entirely.
        if value is None:
            continue
        # Validate p_vector format (must be a non-empty list of finite numbers).
        if key == "p_vector":
            if isinstance(value, list) and len(value) > 0:
                # Reject vectors containing NaN or non-numeric entries.
                import math
                valid_vector = [v for v in value if isinstance(v, (int, float)) and not math.isnan(v)]
                if len(valid_vector) == len(value):
                    mapped_params["p_vector"] = value
                else:
                    print("⚠️ Warning: p_vector contains invalid values (NaN/inf), skipping")
            elif isinstance(value, list) and len(value) == 0:
                print("⚠️ Warning: p_vector is empty, skipping")
            else:
                mapped_params["p_vector"] = value
        elif key in {"_mode", "mode"}:
            # Both spellings normalize to p_mode.
            mapped_params["p_mode"] = value
        elif key == "p_vb_pattern":
            mapped_params["p_vb_pattern"] = value
        elif key == "p_so_hieu":
            mapped_params["p_so_hieu"] = value
        elif key == "p_trang_thai":
            mapped_params["p_trang_thai"] = value
        elif key == "p_co_quan":
            mapped_params["p_co_quan"] = value
        elif key == "p_loai_vb":
            mapped_params["p_loai_vb"] = value
        elif key == "p_nam_from":
            # Must be an integer (or an integer-valued float).
            if isinstance(value, (int, float)) and not (isinstance(value, float) and value != int(value)):
                mapped_params["p_nam_from"] = int(value)
            elif value is not None:
                print(f"⚠️ Warning: p_nam_from has invalid type/value: {value}, skipping")
        elif key == "p_nam_to":
            # Must be an integer (or an integer-valued float).
            if isinstance(value, (int, float)) and not (isinstance(value, float) and value != int(value)):
                mapped_params["p_nam_to"] = int(value)
            elif value is not None:
                print(f"⚠️ Warning: p_nam_to has invalid type/value: {value}, skipping")
        elif key == "p_only_source":
            mapped_params["p_only_source"] = value
        elif key == "p_chapter":
            mapped_params["p_chapter"] = value
        elif key == "p_article":
            mapped_params["p_article"] = value
        elif key == "p_phu_luc":
            mapped_params["p_phu_luc"] = value
        elif key == "p_limit":
            # Must be an integer (or an integer-valued float).
            if isinstance(value, (int, float)) and not (isinstance(value, float) and value != int(value)):
                mapped_params["p_limit"] = int(value)
            elif value is not None:
                print(f"⚠️ Warning: p_limit has invalid type/value: {value}, skipping")
        elif key == "p_ef_search":
            # Must be an integer (or an integer-valued float).
            if isinstance(value, (int, float)) and not (isinstance(value, float) and value != int(value)):
                mapped_params["p_ef_search"] = int(value)
            elif value is not None:
                print(f"⚠️ Warning: p_ef_search has invalid type/value: {value}, skipping")
    # Call via the shared Supabase client (ensuring it is initialized).
    try:
        await init_supabase_client()
        # DEBUG: print the JSON payload that will be sent (first 500 chars).
        print("📤 GỬI JSON PAYLOAD:")
        import json
        debug_payload = json.dumps(mapped_params, ensure_ascii=False, indent=2)
        print(debug_payload[:500])
        rows = await supabase_rpc_call("hoi_phap_luat_all_in_one", mapped_params)
        print(f"✅ NHẬN RESULT: {len(rows)} rows")
        if rows:
            print(f"First row keys: {list(rows[0].keys())}")
            # Sample the first row's content field for debug sizing.
            _nd = rows[0].get("NoiDung") or rows[0].get("NoiDungDieu") or rows[0].get("NoiDungPhuLuc") or ""
            try:
                print(f"NoiDung length: {len(_nd)}")
            except Exception:
                print("NoiDung length: (unavailable)")
        return rows or []
    except Exception as e:
        print(f"❌ HTTPX call failed: {e}")
        import traceback
        traceback.print_exc()
        return []
async def law_db_fetch_plan(calls: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Run every call (each carrying its own .params) concurrently.

    Returns one ``{"label": ..., "rows": [...]}`` entry per call, in the
    same order as *calls*; a failed call yields an empty "rows" list.
    """
    if not calls:
        return []

    async def _execute(call: dict[str, Any], position: int) -> dict[str, Any]:
        # Per-call worker: resolve the label, thread the mode through, fetch.
        label = _label_for_call(call, position)
        params = call.get("params", {})
        mode_value = call.get("mode")
        print(f"DEBUG: call.mode = {mode_value}")
        if mode_value is not None and "_mode" not in params and "mode" not in params:
            # Prefer the "_mode" key so we never clobber another field.
            params = dict(params)
            params["_mode"] = mode_value
            print(f"DEBUG: added _mode = {mode_value} to params")
        print(f"DEBUG: final params keys = {list(params.keys())}")
        try:
            return {"label": label, "rows": await law_db_fetch_one(params)}
        except Exception:
            return {"label": label, "rows": []}

    return await asyncio.gather(
        *(_execute(call, idx) for idx, call in enumerate(calls, start=1))
    )
# ====================== PREVIEW BUILDERS ======================
def build_db_preview(rows: list[dict[str, Any]]) -> str:
    """
    Build a textual preview from result rows that carry document content.

    Rows are grouped per legal document (by its number), then per unit
    inside the document (article "Dieu", appendix "PhuLuc", or other);
    chunks inside a unit are ordered by their chunk index.

    Args:
        rows: Result rows from the law RPC.

    Returns:
        Multi-line preview string; empty string when rows is empty.
    """
    if not rows:
        return ""
    # docs: document number -> {"title": ..., "groups": {group_key: ...}}
    docs: dict[str, dict[str, Any]] = {}
    for r in rows:
        so_hieu = (r.get("SoHieu") or "").strip() or "(Không rõ số hiệu)"
        # Fall back through the known title columns, then the number itself.
        title = (r.get("TieuDe") or r.get("TieuDeVanBan") or r.get("TieuDeDieu") or so_hieu).strip()
        docs.setdefault(so_hieu, {"title": title, "groups": {}})
        # Content can live in any of three columns depending on the unit kind.
        content = (r.get("NoiDung") or r.get("NoiDungDieu") or r.get("NoiDungPhuLuc") or "").strip()
        if not content:
            continue
        # Chunk index column name varies; default to 1.
        chunk_idx = int(r.get("chunk_idx") or r.get("ChunkIdx") or r.get("ChunkIndex") or 1)
        phu_luc = r.get("PhuLuc") or r.get("Phu_Luc") or r.get("Phu_luc")
        dieu = r.get("Dieu")
        chuong = r.get("Chuong")
        tieu_de_dieu = r.get("TieuDeDieu") or ""
        tieu_de_pl = r.get("TieuDePhuLuc") or r.get("TenPhuLuc") or ""
        # Classify the row into appendix (PL), article (DIEU) or other (KHAC).
        if phu_luc is not None:
            group_key = ("PL", str(phu_luc))
            group_title = f"Phụ lục {phu_luc}"
            subtitle = tieu_de_pl
        elif dieu is not None:
            group_key = ("DIEU", str(dieu))
            group_title = f"Điều {dieu}"
            subtitle = tieu_de_dieu
        else:
            group_key = ("KHAC", f"Chương {chuong}" if chuong is not None else "Khác")
            group_title = group_key[1]
            subtitle = ""
        groups = docs[so_hieu]["groups"]
        groups.setdefault(group_key, {"title": group_title, "subtitle": subtitle, "segs": []})
        groups[group_key]["segs"].append((chunk_idx, content))
    parts: list[str] = []
    for so_hieu, doc in docs.items():
        header = f"=== {doc['title']} ({so_hieu}) ==="
        parts.append(header)
        # Order groups: appendices, then articles, then others; numeric
        # identifiers sort numerically, non-numeric ones go last.
        for group_key, group in sorted(
            doc["groups"].items(),
            key=lambda x: (
                {"PL": 0, "DIEU": 1, "KHAC": 2}.get(x[0][0], 3),
                int(x[0][1]) if x[0][1].isdigit() else float("inf"),
            ),
        ):
            title_line = group["title"] + (f" — {group['subtitle']}" if group["subtitle"] else "")
            parts.append(title_line)
            # Chunks within a group are ordered by chunk index.
            for _, text in sorted(group["segs"], key=lambda seg: seg[0]):
                parts.append(text)
            parts.append("")  # blank line between groups
        parts.append("")  # blank line between documents
    return "\n".join(parts).strip()
def build_meta_preview(rows: list[dict[str, Any]]) -> str:
    """Render a numbered listing of documents (metadata only, no body text).

    At most the first 50 rows are listed, one line each in the form
    "<n>. <title> — <number> (<type • agency • year • status>)".
    """
    if not rows:
        return ""

    def _render(idx: int, row: dict[str, Any]) -> str:
        # One listing line for a single document row.
        title = row.get("TieuDe") or row.get("TieuDeVanBan") or "(Không rõ tiêu đề)"
        number = row.get("SoHieu") or "—"
        year = row.get("NamBanHanh") or ""
        details = " • ".join(
            part
            for part in (
                row.get("LoaiVanBan") or "",
                row.get("CoQuanBanHanh") or "",
                f"Năm {year}" if year else "",
                row.get("TrangThaiVB") or row.get("TrangThai") or "",
            )
            if part
        )
        suffix = f" ({details})" if details else ""
        return f"{idx}. {title} — {number}{suffix}"

    return "\n".join(_render(i, r) for i, r in enumerate(rows[:50], start=1))
def _build_multi_preview(labeled_results: list[dict[str, Any]]) -> str:
    """Stitch several query results into one string.

    Each result becomes a "### <label>" section. When any row carries body
    text, the full-content preview is used; otherwise a metadata listing.
    Empty results render a placeholder.
    """
    sections: list[str] = []
    for item in labeled_results:
        label = item.get("label", "Truy vấn")
        rows = item.get("rows") or []
        has_content = any(
            r.get("NoiDung") or r.get("NoiDungDieu") or r.get("NoiDungPhuLuc")
            for r in rows
        )
        preview_text = build_db_preview(rows) if has_content else build_meta_preview(rows)
        if not preview_text:
            preview_text = "(Không có dữ liệu)"
        sections.append(f"### {label}\n{preview_text}")
    return "\n\n".join(sections).strip()
# ====================== PUBLIC APIS ======================
async def fetch_data_db_law(calls: list[dict[str, Any]]) -> str:
    """Main entry point: run the query plan and render a preview per call.

    Step 1 fetches all calls concurrently; step 2 renders each call's
    result with the builder matching its "mode" ("meta" lists documents
    only; every other mode, including "content" and "semantic", renders
    the full content preview).

    Args:
        calls: Query descriptors, each ``{"mode": ..., "params": {...}}``.

    Returns:
        All sections joined by blank lines ("### <label>\\n<preview>").
    """
    calls = calls or []
    # Step 1: fetch raw data concurrently; law_db_fetch_plan preserves
    # the order of *calls* in its results (asyncio.gather guarantee).
    labeled = await law_db_fetch_plan(calls)
    # Step 2: pair each result with its originating call by position.
    # (The previous implementation matched labels by substring, which
    # could pick the wrong call when one label contained another.)
    sections: list[str] = []
    for call, item in zip(calls, labeled):
        label = item.get("label", "Truy vấn")
        rows = item.get("rows") or []
        call_mode = call.get("mode", "content")
        if call_mode == "meta":
            preview_text = build_meta_preview(rows)
        else:
            # "content", "semantic", and unknown modes all render content.
            preview_text = build_db_preview(rows)
        if not preview_text:
            preview_text = "(Không có dữ liệu)"
        sections.append(f"### {label}\n{preview_text}")
    return "\n\n".join(sections).strip()
# Public API of this module.
__all__ = [
    "build_db_preview",
    "build_meta_preview",
    "fetch_data_db_law",
    "law_db_fetch_one",
    "law_db_fetch_plan",
]
# if __name__ == "__main__":
# import argparse
# import json as _json
# async def main():
# parser = argparse.ArgumentParser(description="Test fetch_data_db_law nhanh qua CLI")
# parser.add_argument("--mode", "-m", type=str, default="content", help="content|semantic|meta (có thể phân tách bằng dấu phẩy)")
# parser.add_argument("--so_hieu", type=str, default=None, help="Giá trị cho p_so_hieu")
# parser.add_argument("--vb_pattern", type=str, default=None, help="Regex/từ khóa cho p_vb_pattern")
# parser.add_argument("--vector_text", type=str, default=None, help="Văn bản để embedding semantic (p_vector_text)")
# parser.add_argument("--loai_vb", type=str, default=None)
# parser.add_argument("--co_quan", type=str, default=None)
# parser.add_argument("--nam_from", type=int, default=None)
# parser.add_argument("--nam_to", type=int, default=None)
# parser.add_argument("--only_source", type=str, default=None, help="chinh_thong|dia_phuong")
# parser.add_argument("--limit", type=int, default=50)
# parser.add_argument("--article", type=int, default=None)
# parser.add_argument("--chapter", type=int, default=None)
# parser.add_argument("--phu_luc", type=str, default=None)
# parser.add_argument("--multi", action="store_true", help="Nếu bật, tạo nhiều calls mẫu để so sánh")
# args = parser.parse_args()
# modes = [m.strip() for m in (args.mode or "content").split(",") if m.strip()]
# def _build_params():
# return {
# "p_so_hieu": args.so_hieu,
# "p_vb_pattern": args.vb_pattern,
# "p_co_quan": args.co_quan,
# "p_loai_vb": args.loai_vb,
# "p_nam_from": args.nam_from,
# "p_nam_to": args.nam_to,
# "p_only_source": args.only_source,
# "p_article": args.article,
# "p_chapter": args.chapter,
# "p_phu_luc": args.phu_luc,
# "p_limit": args.limit,
# "p_vector_text": args.vector_text,
# }
# calls = []
# if args.multi:
# # Tạo một số tổ hợp mẫu để tiện so sánh
# sample_targets = []
# if args.so_hieu:
# sample_targets.append({"p_so_hieu": args.so_hieu})
# if args.vb_pattern:
# sample_targets.append({"p_vb_pattern": args.vb_pattern})
# if args.vector_text:
# sample_targets.append({"p_vector_text": args.vector_text})
# if not sample_targets:
# # nếu không có gì, dùng một pattern mặc định
# sample_targets = [
# {"p_vb_pattern": "BVMT|bảo vệ môi trường|môi trường"},
# {"p_vector_text": "nội dung nghị định chưa được xác định theo yêu cầu của người hỏi"},
# ]
# base_params = _build_params()
# for mode in modes:
# for target in sample_targets:
# params = dict(base_params)
# params.update({k: v for k, v in target.items() if v is not None})
# calls.append({"mode": mode, "params": {k: v for k, v in params.items() if v is not None}})
# else:
# params = {k: v for k, v in _build_params().items() if v is not None}
# if not params:
# # nếu không truyền gì, tạo ví dụ tối thiểu
# params = {"p_vb_pattern": "BVMT|bảo vệ môi trường|môi trường", "p_limit": 30}
# for mode in modes:
# calls.append({"mode": mode, "params": params})
# print("CLI calls:")
# print(_json.dumps(calls, ensure_ascii=False, indent=2))
# # Init Supabase client trước khi gọi
# await init_supabase_client()
# try:
# preview = await fetch_data_db_law(calls)
# print("\n===== PREVIEW =====\n")
# print(preview)
# finally:
# await close_supabase_client()
# asyncio.run(main())
backend/common/llm_factory.py
deleted
100644 → 0
View file @
f057ad1e
"""
LLM Factory - OpenAI LLM creation with caching.
Manages initialization and caching of OpenAI models.
"""
import
contextlib
import
logging
from
langchain_core.language_models
import
BaseChatModel
from
langchain_openai
import
ChatOpenAI
,
OpenAIEmbeddings
from
config
import
OPENAI_API_KEY
logger
=
logging
.
getLogger
(
__name__
)
class LLMFactory:
    """Singleton factory that creates and caches OpenAI chat models."""

    # Models eligible for pre-creation during warmup.
    COMMON_MODELS: list[str] = [
        "gpt-4o-mini",
        "gpt-4o",
        "gpt-5-nano",
        "gpt-5-mini",
    ]

    def __init__(self):
        """Start with an empty instance cache."""
        # Keyed by (model, streaming, json_mode, api_key).
        self._cache: dict[tuple[str, bool, bool, str | None], BaseChatModel] = {}

    def get_model(
        self,
        model_name: str,
        streaming: bool = True,
        json_mode: bool = False,
        api_key: str | None = None,
    ) -> BaseChatModel:
        """Return an LLM instance from the cache, creating one if needed.

        Args:
            model_name: Model identifier, optionally provider-prefixed
                (e.g. "openai/gpt-4o-mini" — only the last segment is used).
            streaming: Enable streaming responses.
            json_mode: Enable JSON output format.
            api_key: Optional API key override.

        Returns:
            Configured LLM instance.
        """
        # Strip any "provider/" prefix from the model name.
        clean_model = model_name.rsplit("/", 1)[-1]
        cache_key = (clean_model, streaming, json_mode, api_key)
        cached = self._cache.get(cache_key)
        if cached is not None:
            logger.debug(f"♻️ Using cached model: {clean_model}")
            return cached
        logger.info(f"Creating new LLM instance: {clean_model}")
        return self._create_instance(clean_model, streaming, json_mode, api_key)

    def _create_instance(
        self,
        model_name: str,
        streaming: bool = False,
        json_mode: bool = False,
        api_key: str | None = None,
    ) -> BaseChatModel:
        """Create an OpenAI LLM, store it in the cache, and return it."""
        try:
            instance = self._create_openai(model_name, streaming, json_mode, api_key)
            self._cache[(model_name, streaming, json_mode, api_key)] = instance
            return instance
        except Exception as e:
            logger.error(f"❌ Failed to create model {model_name}: {e}")
            raise

    def _create_openai(
        self,
        model_name: str,
        streaming: bool,
        json_mode: bool,
        api_key: str | None,
    ) -> BaseChatModel:
        """Build a ChatOpenAI instance (temperature 0, optional JSON mode).

        Raises:
            ValueError: when no API key is available.
        """
        key = api_key or OPENAI_API_KEY
        if not key:
            raise ValueError("OPENAI_API_KEY is required")
        kwargs = {
            "model": model_name,
            "streaming": streaming,
            "api_key": key,
            "temperature": 0,
        }
        if json_mode:
            # Inject the response format straight into the constructor.
            kwargs["model_kwargs"] = {"response_format": {"type": "json_object"}}
            logger.info(f"⚙️ Initializing OpenAI in JSON mode: {model_name}")
        instance = ChatOpenAI(**kwargs)
        logger.info(f"✅ Created OpenAI: {model_name}")
        return instance

    def _enable_json_mode(self, llm: BaseChatModel, model_name: str) -> BaseChatModel:
        """Bind JSON response format onto an existing LLM (best effort)."""
        try:
            llm = llm.bind(response_format={"type": "json_object"})
            logger.debug(f"⚙️ JSON mode enabled for {model_name}")
        except Exception as e:
            logger.warning(f"⚠️ JSON mode not supported: {e}")
        return llm

    def initialize(self, skip_warmup: bool = True) -> None:
        """Optionally pre-create the common models.

        Args:
            skip_warmup: When True (or when the cache is already
                populated), do nothing.
        """
        if skip_warmup or self._cache:
            return
        logger.info("🔥 Warming up LLM Factory...")
        for candidate in self.COMMON_MODELS:
            with contextlib.suppress(Exception):
                self.get_model(candidate, streaming=True)
# --- Singleton Instance & Public API ---
_factory
=
LLMFactory
()
def create_llm(
    model_name: str,
    streaming: bool = True,
    json_mode: bool = False,
    api_key: str | None = None,
) -> BaseChatModel:
    """Create or fetch a cached LLM instance via the module-level factory."""
    return _factory.get_model(
        model_name,
        streaming=streaming,
        json_mode=json_mode,
        api_key=api_key,
    )
def init_llm_factory(skip_warmup: bool = True) -> None:
    """Warm up the module-level LLM factory (no-op when skip_warmup is True)."""
    _factory.initialize(skip_warmup)
def create_embedding_model() -> OpenAIEmbeddings:
    """Return an OpenAI embeddings model (text-embedding-3-small)."""
    return OpenAIEmbeddings(
        model="text-embedding-3-small",
        api_key=OPENAI_API_KEY,
    )
backend/common/starrocks_connection.py
View file @
28274420
...
...
@@ -3,8 +3,8 @@ StarRocks Database Connection Utility
Based on chatbot-rsa pattern
"""
import
logging
import
asyncio
import
logging
from
typing
import
Any
import
aiomysql
...
...
@@ -34,11 +34,11 @@ class StarRocksConnection:
password
:
str
|
None
=
None
,
port
:
int
|
None
=
None
,
):
self
.
host
=
host
or
STARROCKS_HOST
self
.
database
=
database
or
STARROCKS_DB
self
.
user
=
user
or
STARROCKS_USER
self
.
password
=
password
or
STARROCKS_PASSWORD
self
.
port
=
port
or
STARROCKS_PORT
self
.
host
=
host
or
STARROCKS_HOST
or
""
self
.
database
=
database
or
STARROCKS_DB
or
""
self
.
user
=
user
or
STARROCKS_USER
or
""
self
.
password
=
password
or
STARROCKS_PASSWORD
or
""
self
.
port
=
port
or
STARROCKS_PORT
or
3306
# self.conn references the shared connection
self
.
conn
=
None
...
...
@@ -61,11 +61,15 @@ class StarRocksConnection:
print
(
f
" [DB] 🔌 Đang kết nối StarRocks (New Session): {self.host}:{self.port}..."
)
logger
.
info
(
f
"🔌 Connecting to StarRocks at {self.host}:{self.port} (DB: {self.database})..."
)
try
:
# Ensure all required parameters are strings (not None)
if
not
all
([
self
.
host
,
self
.
user
,
self
.
password
,
self
.
database
]):
raise
ValueError
(
"Missing required StarRocks connection parameters"
)
new_conn
=
pymysql
.
connect
(
host
=
self
.
host
,
port
=
self
.
port
,
user
=
self
.
user
,
password
=
self
.
password
,
password
=
self
.
password
,
# Now guaranteed to be str, not None
database
=
self
.
database
,
charset
=
"utf8mb4"
,
cursorclass
=
DictCursor
,
...
...
@@ -121,11 +125,15 @@ class StarRocksConnection:
# Double-check inside lock to prevent multiple pools
if
StarRocksConnection
.
_shared_pool
is
None
:
logger
.
info
(
f
"🔌 Creating Async Pool to {self.host}:{self.port}..."
)
# Ensure all required parameters are strings (not None)
if
not
all
([
self
.
host
,
self
.
user
,
self
.
password
,
self
.
database
]):
raise
ValueError
(
"Missing required StarRocks connection parameters"
)
StarRocksConnection
.
_shared_pool
=
await
aiomysql
.
create_pool
(
host
=
self
.
host
,
port
=
self
.
port
,
user
=
self
.
user
,
password
=
self
.
password
,
password
=
self
.
password
,
# Now guaranteed to be str, not None
db
=
self
.
database
,
charset
=
"utf8mb4"
,
cursorclass
=
aiomysql
.
DictCursor
,
...
...
@@ -160,15 +168,16 @@ class StarRocksConnection:
# Nếu StarRocks OOM, đợi một chút rồi thử lại
await
asyncio
.
sleep
(
0.5
*
(
attempt
+
1
))
continue
el
if
"Disconnected"
in
str
(
e
)
or
"Lost connection"
in
str
(
e
):
if
"Disconnected"
in
str
(
e
)
or
"Lost connection"
in
str
(
e
):
# Nếu mất kết nối, có thể pool bị stale, thử lại ngay
continue
else
:
# Các lỗi khác (cú pháp,...) thì raise luôn
raise
logger
.
error
(
f
"❌ Failed after {max_retries} attempts: {last_error}"
)
if
last_error
:
raise
last_error
raise
RuntimeError
(
"Failed to execute query after multiple attempts"
)
def
close
(
self
):
"""Explicitly close if needed (e.g. app shutdown)"""
...
...
backend/requirements.txt
View file @
28274420
...
...
@@ -49,7 +49,9 @@ langchain==1.2.0
langchain-core==1.2.3
langchain-google-genai==4.1.2
langchain-openai==1.1.6
agno==2.3.24
langfuse==3.11.0
openinference-instrumentation-agno==1.0.0
langgraph==1.0.5
langgraph-checkpoint==3.0.1
langgraph-checkpoint-postgres==3.0.2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment