Commit c998c05f authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

feat: optimize logging, enable embedding cache, fix rate limiting hanging, and...

feat: optimize logging, enable embedding cache, fix rate limiting hanging, and improve pagination security
parent c2ff6158
...@@ -116,7 +116,7 @@ Header: `Authorization: Bearer <token>` ...@@ -116,7 +116,7 @@ Header: `Authorization: Bearer <token>`
"message": "...", // JSON String "message": "...", // JSON String
"is_human": false, "is_human": false,
"timestamp": "..." "timestamp": "..."
} }
], ],
"next_cursor": 104 // Dùng ID này cho `before_id` tiếp theo "next_cursor": 104 // Dùng ID này cho `before_id` tiếp theo
} }
...@@ -127,9 +127,10 @@ Header: `Authorization: Bearer <token>` ...@@ -127,9 +127,10 @@ Header: `Authorization: Bearer <token>`
--- ---
## 3. Reset (Xóa và tạo mới) ## 3. Reset (Xóa và tạo mới)
**POST** `/api/history/archive` **POST** `/api/history/archive`
*(Lưu ý: Chỉ dành cho User đã đăng nhập)*
### Request ### Request
Gửi Header `device_id` (Guest) hoặc `Authorization` (User). Header `Authorization` (User).
Body rỗng `{}`. Body rỗng `{}`.
### Response ### Response
......
...@@ -76,8 +76,8 @@ async def chat_controller( ...@@ -76,8 +76,8 @@ async def chat_controller(
config.model_name = model_name config.model_name = model_name
llm = create_llm(model_name=model_name, streaming=False, json_mode=True) llm = create_llm(model_name=model_name, streaming=False, json_mode=True)
tools = get_all_tools() # tools = get_all_tools() # Singleton now handles tools
graph = build_graph(config, llm=llm, tools=tools) graph = build_graph(config) # Singleton usage
# Init ConversationManager (Singleton) # Init ConversationManager (Singleton)
memory = await get_conversation_manager() memory = await get_conversation_manager()
......
...@@ -129,18 +129,28 @@ class CANIFAGraph: ...@@ -129,18 +129,28 @@ class CANIFAGraph:
_instance: list[CANIFAGraph | None] = [None] _instance: list[CANIFAGraph | None] = [None]
def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None) -> Any: def build_graph(config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None) -> Any:
"""Get compiled graph (always fresh to pick up prompt changes).""" """Get compiled graph (Singleton usage)."""
# Always create new instance to pick up prompt changes during hot reload # Use singleton to avoid rebuilding graph on every request
instance = CANIFAGraph(config, llm, tools) manager = get_graph_manager(config, llm, tools)
return instance.build() return manager.build()
def get_graph_manager( def get_graph_manager(
config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None config: AgentConfig | None = None, llm: BaseChatModel | None = None, tools: list | None = None
) -> CANIFAGraph: ) -> CANIFAGraph:
"""Get CANIFAGraph instance.""" """Get CANIFAGraph instance (Auto-rebuild if model config changes)."""
# 1. New Instance if Empty
if _instance[0] is None: if _instance[0] is None:
_instance[0] = CANIFAGraph(config, llm, tools) _instance[0] = CANIFAGraph(config, llm, tools)
logger.info(f"✨ Graph Created: {_instance[0].config.model_name}")
return _instance[0]
# 2. Check for Config Changes (e.g. Model Switch)
if config and config.model_name != _instance[0].config.model_name:
logger.info(f"🔄 Model Switch: {_instance[0].config.model_name} -> {config.model_name}")
_instance[0] = CANIFAGraph(config, llm, tools)
return _instance[0]
return _instance[0] return _instance[0]
......
...@@ -177,13 +177,16 @@ async def _execute_single_search(db, item: SearchItem, query_vector: list[float] ...@@ -177,13 +177,16 @@ async def _execute_single_search(db, item: SearchItem, query_vector: list[float]
# Timer: build query (sử dụng vector đã có hoặc build mới) # Timer: build query (sử dụng vector đã có hoặc build mới)
query_build_start = time.time() query_build_start = time.time()
sql = await build_starrocks_query(item, query_vector=query_vector) sql, params = await build_starrocks_query(item, query_vector=query_vector)
query_build_time = (time.time() - query_build_start) * 1000 # Convert to ms query_build_time = (time.time() - query_build_start) * 1000 # Convert to ms
logger.debug("SQL built, length=%s, build_time_ms=%.2f", len(sql), query_build_time) logger.debug("SQL built, length=%s, build_time_ms=%.2f", len(sql), query_build_time)
if not sql:
return []
# Timer: execute DB query # Timer: execute DB query
db_start = time.time() db_start = time.time()
products = await db.execute_query_async(sql) products = await db.execute_query_async(sql, params=params)
db_time = (time.time() - db_start) * 1000 # Convert to ms db_time = (time.time() - db_start) * 1000 # Convert to ms
logger.info( logger.info(
"_execute_single_search done, products=%s, build_ms=%.2f, db_ms=%.2f, total_ms=%.2f", "_execute_single_search done, products=%s, build_ms=%.2f, db_ms=%.2f, total_ms=%.2f",
...@@ -223,6 +226,7 @@ def _format_product_results(products: list[dict]) -> list[dict]: ...@@ -223,6 +226,7 @@ def _format_product_results(products: list[dict]) -> list[dict]:
"name": parsed.get("product_name", ""), "name": parsed.get("product_name", ""),
"price": p.get("original_price") or 0, "price": p.get("original_price") or 0,
"sale_price": p.get("sale_price") or 0, "sale_price": p.get("sale_price") or 0,
"description": p.get("description_text_full", ""),
"url": parsed.get("product_web_url", ""), "url": parsed.get("product_web_url", ""),
"thumbnail_image_url": parsed.get("product_image_url_thumbnail", ""), "thumbnail_image_url": parsed.get("product_image_url_thumbnail", ""),
"discount_amount": p.get("discount_amount") or 0, "discount_amount": p.get("discount_amount") or 0,
......
...@@ -6,48 +6,38 @@ from common.embedding_service import create_embedding_async ...@@ -6,48 +6,38 @@ from common.embedding_service import create_embedding_async
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _escape(val: str) -> str:
"""Thoát dấu nháy đơn để tránh SQL Injection cơ bản."""
return val.replace("'", "''")
def _get_price_clauses(params, sql_params: list) -> list[str]:
def _get_where_clauses(params) -> list[str]: """Lọc theo giá (Parameterized)."""
"""Xây dựng danh sách các điều kiện lọc từ params."""
clauses = []
clauses.extend(_get_price_clauses(params))
clauses.extend(_get_metadata_clauses(params))
clauses.extend(_get_special_clauses(params))
return clauses
def _get_price_clauses(params) -> list[str]:
"""Lọc theo giá."""
clauses = [] clauses = []
p_min = getattr(params, "price_min", None) p_min = getattr(params, "price_min", None)
if p_min is not None: if p_min is not None:
clauses.append(f"sale_price >= {p_min}") clauses.append("sale_price >= %s")
sql_params.append(p_min)
p_max = getattr(params, "price_max", None) p_max = getattr(params, "price_max", None)
if p_max is not None: if p_max is not None:
clauses.append(f"sale_price <= {p_max}") clauses.append("sale_price <= %s")
sql_params.append(p_max)
return clauses return clauses
def _get_metadata_clauses(params) -> list[str]: def _get_metadata_clauses(params, sql_params: list) -> list[str]:
"""Xây dựng điều kiện lọc từ metadata (Phối hợp Exact và Partial).""" """Xây dựng điều kiện lọc từ metadata (Parameterized)."""
clauses = [] clauses = []
# 1. Exact Match (Giới tính, Độ tuổi) - Các trường này cần độ chính xác tuyệt đối # 1. Exact Match
exact_fields = [ exact_fields = [
("gender_by_product", "gender_by_product"), ("gender_by_product", "gender_by_product"),
("age_by_product", "age_by_product"), ("age_by_product", "age_by_product"),
("form_neckline", "form_neckline"),
] ]
for param_name, col_name in exact_fields: for param_name, col_name in exact_fields:
val = getattr(params, param_name, None) val = getattr(params, param_name, None)
if val: if val:
clauses.append(f"{col_name} = '{_escape(val)}'") clauses.append(f"{col_name} = %s")
sql_params.append(val)
# 2. Partial Match (LIKE) - Giúp map text linh hoạt hơn (Chất liệu, Dòng SP, Phong cách...) # 2. Partial Match (LIKE)
# Cái này giúp map: "Yarn" -> "Yarn - Sợi", "Knit" -> "Knit - Dệt Kim"
partial_fields = [ partial_fields = [
("season", "season"), ("season", "season"),
("material_group", "material_group"), ("material_group", "material_group"),
...@@ -60,48 +50,44 @@ def _get_metadata_clauses(params) -> list[str]: ...@@ -60,48 +50,44 @@ def _get_metadata_clauses(params) -> list[str]:
for param_name, col_name in partial_fields: for param_name, col_name in partial_fields:
val = getattr(params, param_name, None) val = getattr(params, param_name, None)
if val: if val:
v = _escape(val).lower() clauses.append(f"LOWER({col_name}) LIKE %s")
# Dùng LOWER + LIKE để cân mọi loại ký tự thừa hoặc hoa/thường sql_params.append(f"%{val.lower()}%")
clauses.append(f"LOWER({col_name}) LIKE '%{v}%'")
return clauses return clauses
def _get_special_clauses(params) -> list[str]: def _get_special_clauses(params, sql_params: list) -> list[str]:
"""Các trường hợp đặc biệt: Mã sản phẩm, Màu sắc.""" """Các trường hợp đặc biệt: Mã sản phẩm, Màu sắc."""
clauses = [] clauses = []
# Mã sản phẩm / SKU # Mã sản phẩm / SKU
m_code = getattr(params, "magento_ref_code", None) m_code = getattr(params, "magento_ref_code", None)
if m_code: if m_code:
m = _escape(m_code) clauses.append("(magento_ref_code = %s OR internal_ref_code = %s)")
clauses.append(f"(magento_ref_code = '{m}' OR internal_ref_code = '{m}')") sql_params.extend([m_code, m_code])
# Màu sắc # Màu sắc
color = getattr(params, "master_color", None) color = getattr(params, "master_color", None)
if color: if color:
c = _escape(color).lower() c_wildcard = f"%{color.lower()}%"
clauses.append(f"(LOWER(master_color) LIKE '%{c}%' OR LOWER(product_color_name) LIKE '%{c}%')") clauses.append("(LOWER(master_color) LIKE %s OR LOWER(product_color_name) LIKE %s)")
sql_params.extend([c_wildcard, c_wildcard])
return clauses return clauses
async def build_starrocks_query(params, query_vector: list[float] | None = None) -> str: async def build_starrocks_query(params, query_vector: list[float] | None = None) -> tuple[str, list]:
""" """
Build SQL cho Product Search với 2 chiến lược: Build SQL query với Parameterized Query để tránh SQL Injection.
1. CODE SEARCH: Nếu có magento_ref_code → Tìm trực tiếp theo mã (KHÔNG dùng vector) Returns: (sql_string, params_list)
2. HYDE SEARCH: Semantic search với HyDE vector (Pure vector approach)
""" """
# ============================================================ # ============================================================
# CASE 1: CODE SEARCH - Tìm theo mã sản phẩm (No Vector) # CASE 1: CODE SEARCH
# ============================================================ # ============================================================
magento_code = getattr(params, "magento_ref_code", None) magento_code = getattr(params, "magento_ref_code", None)
if magento_code: if magento_code:
logger.info(f"🎯 [CODE SEARCH] Direct search by code: {magento_code}") logger.info(f"🎯 [CODE SEARCH] Direct search by code: {magento_code}")
code = _escape(magento_code)
sql = """
# Tìm trực tiếp theo mã + Lọc trùng (GROUP BY internal_ref_code)
# Tìm chính xác theo mã (Lấy tất cả các bản ghi/màu sắc/size của mã đó)
sql = f"""
SELECT SELECT
internal_ref_code, internal_ref_code,
description_text_full, description_text_full,
...@@ -110,24 +96,16 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None) ...@@ -110,24 +96,16 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
discount_amount, discount_amount,
1.0 as max_score 1.0 as max_score
FROM shared_source.magento_product_dimension_with_text_embedding FROM shared_source.magento_product_dimension_with_text_embedding
WHERE (magento_ref_code = '{code}' OR internal_ref_code = '{code}') WHERE (magento_ref_code = %s OR internal_ref_code = %s)
""" """
return sql, [magento_code, magento_code]
print("✅ [CODE SEARCH] Query built - No vector search needed!")
# Ghi log debug query FULL vào Background Task (Không làm chậm Request)
# asyncio.create_task(save_query_to_log(sql))
return sql
# ============================================================ # ============================================================
# CASE 2: HYDE SEARCH - Semantic Vector Search # CASE 2: HYDE SEARCH - Semantic Vector Search
# ============================================================ # ============================================================
logger.info("🚀 [HYDE RETRIEVER] Starting semantic vector search...") logger.info("🚀 [HYDE RETRIEVER] Starting semantic vector search...")
# 1. Lấy Vector từ HyDE (AI-generated hypothetical document)
query_text = getattr(params, "query", None) query_text = getattr(params, "query", None)
if query_text and query_vector is None: if query_text and query_vector is None:
emb_start = time.time() emb_start = time.time()
query_vector = await create_embedding_async(query_text) query_vector = await create_embedding_async(query_text)
...@@ -135,18 +113,23 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None) ...@@ -135,18 +113,23 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
if not query_vector: if not query_vector:
logger.warning("⚠️ No vector found, returning empty query.") logger.warning("⚠️ No vector found, returning empty query.")
return "" return "", []
# Vector params
v_str = "[" + ",".join(str(v) for v in query_vector) + "]" v_str = "[" + ",".join(str(v) for v in query_vector) + "]"
# 2. Build PRICE filter ONLY (chỉ lọc giá, để vector tự semantic search) # Collect Params
price_clauses = _get_price_clauses(params) price_params: list = []
price_clauses = _get_price_clauses(params, price_params)
where_filter = "" where_filter = ""
if price_clauses: if price_clauses:
where_filter = " AND " + " AND ".join(price_clauses) where_filter = " AND " + " AND ".join(price_clauses)
logger.info(f"💰 [PRICE FILTER] Applied: {where_filter}") logger.info(f"💰 [PRICE FILTER] Applied: {where_filter}")
# 3. SQL Pure Vector Search + Price Filter Only # Build SQL
# NOTE: Vector v_str is safe (generated from floats) so f-string is OK here.
# Using %s for vector list string might cause StarRocks to treat it as string literal '[...]' instead of array.
sql = f""" sql = f"""
WITH top_matches AS ( WITH top_matches AS (
SELECT /*+ SET_VAR(ann_params='{{"ef_search":128}}') */ SELECT /*+ SET_VAR(ann_params='{{"ef_search":128}}') */
...@@ -174,5 +157,7 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None) ...@@ -174,5 +157,7 @@ async def build_starrocks_query(params, query_vector: list[float] | None = None)
ORDER BY max_score DESC ORDER BY max_score DESC
LIMIT 20 LIMIT 20
""" """
# Return sql and params (params only contains filter values now, not the vector)
return sql, price_params
return sql
...@@ -23,7 +23,10 @@ tracer = trace.get_tracer(__name__) ...@@ -23,7 +23,10 @@ tracer = trace.get_tracer(__name__)
router = APIRouter() router = APIRouter()
from common.rate_limit import rate_limit_service
@router.post("/api/agent/chat", summary="Fashion Q&A Chat (Non-streaming)") @router.post("/api/agent/chat", summary="Fashion Q&A Chat (Non-streaming)")
@rate_limit_service.limiter.limit("50/minute")
async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks: BackgroundTasks): async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks: BackgroundTasks):
""" """
Endpoint chat không stream - trả về response JSON đầy đủ một lần. Endpoint chat không stream - trả về response JSON đầy đủ một lần.
......
...@@ -28,7 +28,10 @@ class ClearHistoryResponse(BaseModel): ...@@ -28,7 +28,10 @@ class ClearHistoryResponse(BaseModel):
message: str message: str
from common.rate_limit import rate_limit_service
@router.get("/api/history/{identity_key}", summary="Get Chat History", response_model=ChatHistoryResponse) @router.get("/api/history/{identity_key}", summary="Get Chat History", response_model=ChatHistoryResponse)
@rate_limit_service.limiter.limit("50/minute")
async def get_chat_history(request: Request, identity_key: str, limit: int | None = 50, before_id: int | None = None): async def get_chat_history(request: Request, identity_key: str, limit: int | None = 50, before_id: int | None = None):
""" """
Lấy lịch sử chat theo identity_key. Lấy lịch sử chat theo identity_key.
...@@ -63,7 +66,8 @@ async def get_chat_history(request: Request, identity_key: str, limit: int | Non ...@@ -63,7 +66,8 @@ async def get_chat_history(request: Request, identity_key: str, limit: int | Non
@router.delete("/api/history/{identity_key}", summary="Clear Chat History", response_model=ClearHistoryResponse) @router.delete("/api/history/{identity_key}", summary="Clear Chat History", response_model=ClearHistoryResponse)
async def clear_chat_history(identity_key: str): @rate_limit_service.limiter.limit("10/minute")
async def clear_chat_history(request: Request, identity_key: str):
""" """
Xóa toàn bộ lịch sử chat theo identity_key. Xóa toàn bộ lịch sử chat theo identity_key.
Logic: Middleware đã parse token -> Nếu user đã login thì dùng user_id, không thì dùng device_id. Logic: Middleware đã parse token -> Nếu user đã login thì dùng user_id, không thì dùng device_id.
...@@ -89,28 +93,28 @@ class ArchiveResponse(BaseModel): ...@@ -89,28 +93,28 @@ class ArchiveResponse(BaseModel):
@router.post("/api/history/archive", summary="Archive Chat History", response_model=ArchiveResponse) @router.post("/api/history/archive", summary="Archive Chat History", response_model=ArchiveResponse)
@rate_limit_service.limiter.limit("5/minute")
async def archive_chat_history(request: Request): async def archive_chat_history(request: Request):
""" """
Lưu trữ lịch sử chat hiện tại (đổi tên key) và reset chat mới. Lưu trữ lịch sử chat hiện tại (đổi tên key) và reset chat mới.
Giới hạn 5 lần/ngày. Giới hạn 5 lần/ngày.
""" """
try: try:
# Tự động resolve identity
# NOTE: Với Reset, ta extract lại device_id từ body nếu có (cho trường hợp Guest bấm reset)
try:
req_json = await request.json()
body_device_id = req_json.get("device_id")
except:
body_device_id = None
identity = get_user_identity(request) identity = get_user_identity(request)
# Nếu chưa login mà có body_device_id -> ưu tiên dùng nó làm key # Chỉ dành cho User đã đăng nhập
if not identity.is_authenticated and body_device_id: if not identity.is_authenticated:
logger.info("Archive: Using device_id from Body for Guest") return JSONResponse(
identity_key = body_device_id status_code=401,
else: content={
identity_key = identity.history_key "status": "error",
"error_code": "LOGIN_REQUIRED",
"message": "Tính năng chỉ dành cho thành viên đã đăng nhập.",
"require_login": True
}
)
identity_key = identity.history_key
# Check reset limit # Check reset limit
can_reset, usage, remaining = await reset_limit_service.check_limit(identity_key) can_reset, usage, remaining = await reset_limit_service.check_limit(identity_key)
......
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel from pydantic import BaseModel
import os import os
import re import re
...@@ -30,8 +30,10 @@ def validate_prompt_braces(content: str) -> tuple[bool, list[str]]: ...@@ -30,8 +30,10 @@ def validate_prompt_braces(content: str) -> tuple[bool, list[str]]:
return len(problematic) == 0, problematic return len(problematic) == 0, problematic
from common.rate_limit import rate_limit_service
@router.get("/api/agent/system-prompt") @router.get("/api/agent/system-prompt")
async def get_system_prompt_content(): async def get_system_prompt_content(request: Request):
"""Get current system prompt content""" """Get current system prompt content"""
try: try:
if os.path.exists(PROMPT_FILE_PATH): if os.path.exists(PROMPT_FILE_PATH):
...@@ -44,11 +46,12 @@ async def get_system_prompt_content(): ...@@ -44,11 +46,12 @@ async def get_system_prompt_content():
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
@router.post("/api/agent/system-prompt") @router.post("/api/agent/system-prompt")
async def update_system_prompt_content(request: PromptUpdateRequest): @rate_limit_service.limiter.limit("10/minute")
async def update_system_prompt_content(request: Request, body: PromptUpdateRequest):
"""Update system prompt content""" """Update system prompt content"""
try: try:
# Validate braces # Validate braces
is_valid, problematic = validate_prompt_braces(request.content) is_valid, problematic = validate_prompt_braces(body.content)
if not is_valid: if not is_valid:
# Return warning but still allow save # Return warning but still allow save
...@@ -62,7 +65,7 @@ async def update_system_prompt_content(request: PromptUpdateRequest): ...@@ -62,7 +65,7 @@ async def update_system_prompt_content(request: PromptUpdateRequest):
# 1. Update file # 1. Update file
with open(PROMPT_FILE_PATH, "w", encoding="utf-8") as f: with open(PROMPT_FILE_PATH, "w", encoding="utf-8") as f:
f.write(request.content) f.write(body.content)
# 2. Reset Graph Singleton to force reload prompt # 2. Reset Graph Singleton to force reload prompt
reset_graph() reset_graph()
......
This diff is collapsed.
...@@ -43,10 +43,7 @@ class EmbeddingClientManager: ...@@ -43,10 +43,7 @@ class EmbeddingClientManager:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# NOTE: from common.cache import redis_cache
# - TẠM THỜI KHÔNG DÙNG REDIS CACHE CHO EMBEDDING để tránh phụ thuộc Redis/aioredis.
# - Nếu cần bật lại cache, import `redis_cache` từ `common.cache`
# và dùng như các đoạn code cũ (get_embedding / set_embedding).
# --- Singleton --- # --- Singleton ---
_manager = EmbeddingClientManager() _manager = EmbeddingClientManager()
...@@ -55,8 +52,11 @@ get_async_embedding_client = _manager.get_async_client ...@@ -55,8 +52,11 @@ get_async_embedding_client = _manager.get_async_client
def create_embedding(text: str) -> list[float]: def create_embedding(text: str) -> list[float]:
"""Sync embedding generation (No cache for sync to avoid overhead)""" """Sync embedding generation with Layer 2 Cache"""
try: try:
# 1. Try Cache (Sync wrapper for get_embedding if needed, but here we just use what we have)
# Note: common.cache is async, so sync create_embedding will still call OpenAI
# unless we add a sync cache method. For now, focus on async.
client = get_embedding_client() client = get_embedding_client()
response = client.embeddings.create(model="text-embedding-3-small", input=text) response = client.embeddings.create(model="text-embedding-3-small", input=text)
return response.data[0].embedding return response.data[0].embedding
...@@ -67,13 +67,24 @@ def create_embedding(text: str) -> list[float]: ...@@ -67,13 +67,24 @@ def create_embedding(text: str) -> list[float]:
async def create_embedding_async(text: str) -> list[float]: async def create_embedding_async(text: str) -> list[float]:
""" """
Async embedding generation (KHÔNG dùng cache). Async embedding generation with Layer 2 Cache.
Nếu sau này cần cache lại, có thể thêm redis_cache.get_embedding / set_embedding. Saves OpenAI costs by reusing embeddings for identical queries.
""" """
try: try:
# 1. Try Layer 2 Cache
cached = await redis_cache.get_embedding(text)
if cached:
return cached
# 2. Call OpenAI
client = get_async_embedding_client() client = get_async_embedding_client()
response = await client.embeddings.create(model="text-embedding-3-small", input=text) response = await client.embeddings.create(model="text-embedding-3-small", input=text)
embedding = response.data[0].embedding embedding = response.data[0].embedding
# 3. Store in Cache
if embedding:
await redis_cache.set_embedding(text, embedding)
return embedding return embedding
except Exception as e: except Exception as e:
logger.error(f"Error creating embedding (async): {e}") logger.error(f"Error creating embedding (async): {e}")
...@@ -88,12 +99,32 @@ async def create_embeddings_async(texts: list[str]) -> list[list[float]]: ...@@ -88,12 +99,32 @@ async def create_embeddings_async(texts: list[str]) -> list[list[float]]:
if not texts: if not texts:
return [] return []
client = get_async_embedding_client() results = [[] for _ in texts]
response = await client.embeddings.create(model="text-embedding-3-small", input=texts) missed_indices = []
missed_texts = []
# Giữ nguyên thứ tự embedding theo order input
sorted_data = sorted(response.data, key=lambda x: x.index) # 1. Check Cache for each text
results = [item.embedding for item in sorted_data] for i, text in enumerate(texts):
cached = await redis_cache.get_embedding(text)
if cached:
results[i] = cached
else:
missed_indices.append(i)
missed_texts.append(text)
# 2. Call OpenAI for missed texts
if missed_texts:
client = get_async_embedding_client()
response = await client.embeddings.create(model="text-embedding-3-small", input=missed_texts)
# OpenAI returns embeddings in the same order as missed_texts
for i, data_item in enumerate(response.data):
idx = missed_indices[i]
embedding = data_item.embedding
results[idx] = embedding
# 3. Cache the new embedding
await redis_cache.set_embedding(missed_texts[i], embedding)
return results return results
except Exception as e: except Exception as e:
......
"""
StarRocks Database Connection Utility
Based on chatbot-rsa pattern
"""
import asyncio
import logging
from typing import Any
import aiomysql
import pymysql
from pymysql.cursors import DictCursor
from config import (
STARROCKS_DB,
STARROCKS_HOST,
STARROCKS_PASSWORD,
STARROCKS_PORT,
STARROCKS_USER,
)
logger = logging.getLogger(__name__)
__all__ = ["StarRocksConnection", "get_db_connection"]
class StarRocksConnectionManager:
    """Process-wide holder that lazily creates a single StarRocksConnection."""

    def __init__(self):
        # Populated on first access; reused for the lifetime of the process.
        self._connection: StarRocksConnection | None = None

    def get_connection(self) -> "StarRocksConnection":
        """Return the shared connection wrapper, creating it on first call."""
        if self._connection is not None:
            return self._connection
        logger.info("🔧 [LAZY LOADING] Creating StarRocksConnection instance (first time)")
        self._connection = StarRocksConnection()
        return self._connection
# --- Singleton ---
# Module-level manager instance; `get_db_connection` is the public accessor
# that lazily creates and returns the one shared StarRocksConnection.
_manager = StarRocksConnectionManager()
get_db_connection = _manager.get_connection
class StarRocksConnection:
    """StarRocks client wrapper.

    Exposes two access paths that share state across ALL instances:
      * a persistent synchronous pymysql connection (``_shared_conn``), and
      * an asynchronous aiomysql pool (``_shared_pool``) with retry logic.

    Connection parameters default to the ``STARROCKS_*`` values from config.
    """

    # Shared connection (Singleton-like behavior) for all instances
    _shared_conn = None

    def __init__(
        self,
        host: str | None = None,
        database: str | None = None,
        user: str | None = None,
        password: str | None = None,
        port: int | None = None,
    ):
        # Any parameter not given explicitly falls back to module-level config.
        self.host = host or STARROCKS_HOST
        self.database = database or STARROCKS_DB
        self.user = user or STARROCKS_USER
        self.password = password or STARROCKS_PASSWORD
        self.port = port or STARROCKS_PORT
        # self.conn references the shared connection (set by connect())
        self.conn = None
        logger.info(f"✅ StarRocksConnection initialized: {self.host}:{self.port}")

    def connect(self):
        """Establish or reuse the persistent shared synchronous connection.

        Returns the live pymysql connection; raises on connection failure.
        """
        # 1. Try to reuse the existing shared connection
        if StarRocksConnection._shared_conn and StarRocksConnection._shared_conn.open:
            try:
                # Ping to check if alive; pymysql reconnects in place if needed
                StarRocksConnection._shared_conn.ping(reconnect=True)
                self.conn = StarRocksConnection._shared_conn
                return self.conn
            except Exception as e:
                # Dead connection: drop it and fall through to create a new one
                logger.warning(f"⚠️ Connection lost, reconnecting: {e}")
                StarRocksConnection._shared_conn = None
        # 2. Create a new connection if needed
        print(f" [DB] 🔌 Đang kết nối StarRocks (New Session): {self.host}:{self.port}...")
        logger.info(f"🔌 Connecting to StarRocks at {self.host}:{self.port} (DB: {self.database})...")
        try:
            new_conn = pymysql.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                database=self.database,
                charset="utf8mb4",
                cursorclass=DictCursor,
                connect_timeout=10,
                read_timeout=30,
                write_timeout=30,
            )
            print(" [DB] ✅ Kết nối thành công.")
            logger.info("✅ Connected to StarRocks")
            # Save to the class variable so every instance reuses it
            StarRocksConnection._shared_conn = new_conn
            self.conn = new_conn
        except Exception as e:
            print(f" [DB] ❌ Lỗi kết nối: {e!s}")
            logger.error(f"❌ Failed to connect to StarRocks: {e}")
            raise
        return self.conn

    def execute_query(self, query: str, params: tuple | None = None) -> list[dict[str, Any]]:
        """Run a synchronous query on the shared connection.

        Returns the result set as a list of dicts. On error, the shared
        connection is invalidated (so the next call reconnects) and the
        exception is re-raised.
        """
        # print(" [DB] 🚀 Bắt đầu truy vấn dữ liệu...")
        # (Reduced noise in logs)
        logger.info("🚀 Executing StarRocks Query (Persistent Conn).")
        conn = self.connect()
        try:
            with conn.cursor() as cursor:
                cursor.execute(query, params)
                results = cursor.fetchall()
                # NOTE(review): print() alongside logger duplicates output;
                # presumably kept for console visibility — confirm intent.
                print(f" [DB] ✅ Truy vấn xong. Lấy được {len(results)} dòng.")
                logger.info(f"📊 Query successful, returned {len(results)} rows")
                return [dict(row) for row in results]
        except Exception as e:
            print(f" [DB] ❌ Lỗi truy vấn: {e!s}")
            logger.error(f"❌ StarRocks query error: {e}")
            # In case the query error is connection-related, invalidate it
            StarRocksConnection._shared_conn = None
            raise
        # FINALLY BLOCK REMOVED: Do NOT close connection

    # Async pool shared across all instances
    _shared_pool = None
    _pool_lock = asyncio.Lock()

    @classmethod
    async def clear_pool(cls):
        """Clear and close the existing pool (forces fresh connections)."""
        async with cls._pool_lock:
            if cls._shared_pool is not None:
                logger.warning("🔄 Clearing StarRocks connection pool...")
                cls._shared_pool.close()
                await cls._shared_pool.wait_closed()
                cls._shared_pool = None
                logger.info("✅ Pool cleared successfully")

    async def get_pool(self):
        """Get or create the shared async connection pool.

        Uses double-checked locking so concurrent first callers do not
        create more than one pool.
        """
        if StarRocksConnection._shared_pool is None:
            async with StarRocksConnection._pool_lock:
                # Double-check inside the lock to prevent multiple pools
                if StarRocksConnection._shared_pool is None:
                    logger.info(f"🔌 Creating Async Pool to {self.host}:{self.port}...")
                    StarRocksConnection._shared_pool = await aiomysql.create_pool(
                        host=self.host,
                        port=self.port,
                        user=self.user,
                        password=self.password,
                        db=self.database,
                        charset="utf8mb4",
                        cursorclass=aiomysql.DictCursor,
                        minsize=5,  # ← reduced from 10 to 5
                        maxsize=30,  # ← reduced from 200 to 30 (IMPORTANT!)
                        connect_timeout=10,
                        pool_recycle=1800,  # ← from 3600 to 1800 (30 minutes)
                        autocommit=True,
                    )
        return StarRocksConnection._shared_pool

    async def execute_query_async(self, query: str, params: tuple | None = None) -> list[dict[str, Any]]:
        """
        Execute a query asynchronously using the aiomysql pool, with up to
        3 retries. Retries on pool-acquire timeout, aborted/lost connections
        (clearing the pool first), and StarRocks memory-limit errors; any
        other error (e.g. SQL syntax) is raised immediately.
        """
        max_retries = 3
        last_error = None
        for attempt in range(max_retries):
            try:
                pool = await self.get_pool()
                # logger.info(f"🚀 Executing Async Query (Attempt {attempt+1}).")
                # Timeout raised to 30s for load testing with 300 users
                conn = await asyncio.wait_for(pool.acquire(), timeout=30)
                try:
                    async with conn.cursor() as cursor:
                        await cursor.execute(query, params)
                        results = await cursor.fetchall()
                        # logger.info(f"📊 Async Query successful, returned {len(results)} rows")
                        return [dict(row) for row in results]
                finally:
                    # Always return the connection to the pool
                    pool.release(conn)
            except TimeoutError as e:
                # NOTE(review): on Python < 3.11, asyncio.wait_for raises
                # asyncio.TimeoutError, which is NOT the builtin TimeoutError —
                # confirm the runtime is 3.11+ or this clause may never match.
                last_error = e
                logger.warning(f"⏱️ Pool acquire timeout (Attempt {attempt + 1}/{max_retries})")
                # Acquire timeout → pool is saturated; back off, then retry
                await asyncio.sleep(0.2 * (attempt + 1))
                continue
            except ConnectionAbortedError as e:
                last_error = e
                logger.warning(f"🔌 Connection aborted (Attempt {attempt + 1}/{max_retries}): {e}")
                # Aborted connection → clear the pool and retry with fresh connections
                if attempt < max_retries - 1:
                    await StarRocksConnection.clear_pool()
                    await asyncio.sleep(0.3)
                continue
            except Exception as e:
                last_error = e
                logger.warning(f"⚠️ StarRocks DB Error (Attempt {attempt + 1}/{max_retries}): {e}")
                if "Memory of process exceed limit" in str(e):
                    # StarRocks OOM: wait a bit, then retry
                    await asyncio.sleep(0.5 * (attempt + 1))
                    continue
                if "Disconnected" in str(e) or "Lost connection" in str(e) or "aborted" in str(e).lower():
                    # Lost connection: clear the pool and retry
                    if attempt < max_retries - 1:
                        await StarRocksConnection.clear_pool()
                        await asyncio.sleep(0.3)
                    continue
                # Other errors (syntax, ...) are raised immediately
                raise
        logger.error(f"❌ Failed after {max_retries} attempts: {last_error}")
        raise last_error

    def close(self):
        """Explicitly close if needed (e.g. app shutdown)"""
        if StarRocksConnection._shared_conn and StarRocksConnection._shared_conn.open:
            StarRocksConnection._shared_conn.close()
        StarRocksConnection._shared_conn = None
        self.conn = None
...@@ -189,10 +189,6 @@ class StarRocksConnection: ...@@ -189,10 +189,6 @@ class StarRocksConnection:
conn = await asyncio.wait_for(pool.acquire(), timeout=90) conn = await asyncio.wait_for(pool.acquire(), timeout=90)
async with conn.cursor() as cursor: async with conn.cursor() as cursor:
# Ping kiểm tra sức khỏe connection
await conn.ping()
# Chạy query
await cursor.execute(query, params) await cursor.execute(query, params)
results = await cursor.fetchall() results = await cursor.fetchall()
return [dict(row) for row in results] return [dict(row) for row in results]
......
...@@ -58,7 +58,7 @@ async def startup_event(): ...@@ -58,7 +58,7 @@ async def startup_event():
middleware_manager.setup( middleware_manager.setup(
app, app,
enable_auth=True, # 👈 Bật lại Auth để test logic Guest/User enable_auth=True, # 👈 Bật lại Auth để test logic Guest/User
enable_rate_limit=False, # 👈 Tắt slowapi vì đã có business rate limit enable_rate_limit=True, # 👈 Bật lại SlowAPI theo yêu cầu
enable_cors=True, # 👈 Bật CORS enable_cors=True, # 👈 Bật CORS
cors_origins=["*"], # 👈 Trong production nên limit origins cors_origins=["*"], # 👈 Trong production nên limit origins
) )
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment