Commit 3ebbd3e4 authored by Vũ Hoàng Anh's avatar Vũ Hoàng Anh

feat: Implement Dynamic Message Limit (Guest/User) and Canifa Auth Integration

parent ec2525d2
...@@ -150,13 +150,13 @@ async def chat_controller( ...@@ -150,13 +150,13 @@ async def chat_controller(
# Cache for 5 minutes (300s) - Short enough for stock safety # Cache for 5 minutes (300s) - Short enough for stock safety
# await redis_cache.set_response(user_id=user_id, query=query, response_data=response_payload, ttl=300) # await redis_cache.set_response(user_id=user_id, query=query, response_data=response_payload, ttl=300)
# Add to history in background # Add to history in background - lưu nguyên response JSON
background_tasks.add_task( background_tasks.add_task(
_handle_post_chat_async, _handle_post_chat_async,
memory=memory, memory=memory,
user_id=user_id, user_id=user_id,
human_query=query, human_query=query,
ai_msg=AIMessage(content=ai_text_response), ai_response=response_payload, # dict: {ai_response, product_ids}
) )
logger.info("chat_controller finished in %.2fs", duration) logger.info("chat_controller finished in %.2fs", duration)
...@@ -235,12 +235,17 @@ def _prepare_execution_context(query: str, user_id: str, history: list, images: ...@@ -235,12 +235,17 @@ def _prepare_execution_context(query: str, user_id: str, history: list, images:
async def _handle_post_chat_async( async def _handle_post_chat_async(
memory: ConversationManager, user_id: str, human_query: str, ai_msg: AIMessage | None memory: ConversationManager, user_id: str, human_query: str, ai_response: dict | None
): ):
"""Save chat history in background task after response is sent.""" """
if ai_msg: Save chat history in background task after response is sent.
Lưu AI response dưới dạng JSON string.
"""
if ai_response:
try: try:
await memory.save_conversation_turn(user_id, human_query, ai_msg.content) # Convert dict thành JSON string để lưu vào TEXT field
ai_response_json = json.dumps(ai_response, ensure_ascii=False)
await memory.save_conversation_turn(user_id, human_query, ai_response_json)
logger.debug(f"Saved conversation for user {user_id}") logger.debug(f"Saved conversation for user {user_id}")
except Exception as e: except Exception as e:
logger.error(f"Failed to save conversation for user {user_id}: {e}", exc_info=True) logger.error(f"Failed to save conversation for user {user_id}: {e}", exc_info=True)
...@@ -2,15 +2,21 @@ ...@@ -2,15 +2,21 @@
Fashion Q&A Agent Router Fashion Q&A Agent Router
FastAPI endpoints cho Fashion Q&A Agent service. FastAPI endpoints cho Fashion Q&A Agent service.
Router chỉ chứa định nghĩa API, logic nằm ở controller. Router chỉ chứa định nghĩa API, logic nằm ở controller.
Message Limit:
- Guest (không login): 10 tin/ngày theo device_id
- User đã login: 100 tin/ngày theo user_id
""" """
import logging import logging
from fastapi import APIRouter, BackgroundTasks, HTTPException from fastapi import APIRouter, BackgroundTasks, HTTPException, Request
from opentelemetry import trace from opentelemetry import trace
from agent.controller import chat_controller from agent.controller import chat_controller
from agent.models import QueryRequest from agent.models import QueryRequest
from common.message_limit import message_limit_service
from common.user_identity import get_user_identity
from config import DEFAULT_MODEL from config import DEFAULT_MODEL
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -19,11 +25,41 @@ router = APIRouter() ...@@ -19,11 +25,41 @@ router = APIRouter()
@router.post("/api/agent/chat", summary="Fashion Q&A Chat (Non-streaming)") @router.post("/api/agent/chat", summary="Fashion Q&A Chat (Non-streaming)")
async def fashion_qa_chat(req: QueryRequest, background_tasks: BackgroundTasks): async def fashion_qa_chat(request: Request, req: QueryRequest, background_tasks: BackgroundTasks):
""" """
Endpoint chat không stream - trả về response JSON đầy đủ một lần. Endpoint chat không stream - trả về response JSON đầy đủ một lần.
Message Limit:
- Guest: 10 tin nhắn/ngày (theo device_id)
- User đã login: 100 tin nhắn/ngày (theo user_id)
""" """
user_id = req.user_id or "default_user" # 1. Xác định user identity
identity = get_user_identity(request)
user_id = identity.primary_id
# 2. Check message limit TRƯỚC khi xử lý
can_send, limit_info = await message_limit_service.check_limit(
identity_key=identity.rate_limit_key,
is_authenticated=identity.is_authenticated,
)
if not can_send:
logger.warning(
f"⚠️ Message limit exceeded: {identity.rate_limit_key} | "
f"used={limit_info['used']}/{limit_info['limit']}"
)
return {
"status": "error",
"error_code": "MESSAGE_LIMIT_EXCEEDED",
"message": limit_info["message"],
"require_login": limit_info["require_login"],
"limit_info": {
"limit": limit_info["limit"],
"used": limit_info["used"],
"remaining": limit_info["remaining"],
"reset_seconds": limit_info["reset_seconds"],
},
}
logger.info(f"📥 [Incoming Query - NonStream] User: {user_id} | Query: {req.user_query}") logger.info(f"📥 [Incoming Query - NonStream] User: {user_id} | Query: {req.user_query}")
...@@ -62,11 +98,23 @@ async def fashion_qa_chat(req: QueryRequest, background_tasks: BackgroundTasks): ...@@ -62,11 +98,23 @@ async def fashion_qa_chat(req: QueryRequest, background_tasks: BackgroundTasks):
}, },
) )
# 3. Increment message count SAU KHI chat thành công
usage_info = await message_limit_service.increment(
identity_key=identity.rate_limit_key,
is_authenticated=identity.is_authenticated,
)
return { return {
"status": "success", "status": "success",
"ai_response": result["ai_response"], "ai_response": result["ai_response"],
"product_ids": result.get("product_ids", []), "product_ids": result.get("product_ids", []),
"limit_info": {
"limit": usage_info["limit"],
"used": usage_info["used"],
"remaining": usage_info["remaining"],
},
} }
except Exception as e: except Exception as e:
logger.error(f"Error in fashion_qa_chat: {e}", exc_info=True) logger.error(f"Error in fashion_qa_chat: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) from e raise HTTPException(status_code=500, detail=str(e)) from e
"""
Chat History API Routes
- GET /api/history/{user_id} - Lấy lịch sử chat (có product_ids)
- DELETE /api/history/{user_id} - Xóa lịch sử chat
"""
import logging import logging
from typing import Any from typing import Any
from fastapi import APIRouter from fastapi import APIRouter, HTTPException
from pydantic import BaseModel from pydantic import BaseModel
from common.conversation_manager import get_conversation_manager from common.conversation_manager import get_conversation_manager
router = APIRouter(tags=["Conservation"]) router = APIRouter(tags=["Chat History"])
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ChatMessage(BaseModel):
id: int
user_id: str | None = None # Optional usually not needed in list but good for consistency
message: str
is_human: bool
timestamp: Any
class ChatHistoryResponse(BaseModel): class ChatHistoryResponse(BaseModel):
data: list[dict[str, Any]] data: list[dict[str, Any]]
next_cursor: int | None = None next_cursor: int | None = None
@router.get("/api/history/{user_id}", summary="Get Chat History by User ID", response_model=ChatHistoryResponse) class ClearHistoryResponse(BaseModel):
success: bool
message: str
@router.get("/api/history/{user_id}", summary="Get Chat History", response_model=ChatHistoryResponse)
async def get_chat_history(user_id: str, limit: int | None = 50, before_id: int | None = None): async def get_chat_history(user_id: str, limit: int | None = 50, before_id: int | None = None):
""" """
Lấy lịch sử chat của user từ Postgres database. Lấy lịch sử chat của user.
Trả về object chứa `data` (list messages) và `next_cursor` để dùng cho trang tiếp theo.
Response bao gồm:
- message: Nội dung tin nhắn
- is_human: True nếu là user, False nếu là AI
- product_ids: List sản phẩm liên quan (chỉ có với AI messages)
- timestamp: Thời gian
- id: ID tin nhắn (dùng cho pagination)
""" """
try: try:
# Sử dụng ConversationManager Singleton
manager = await get_conversation_manager() manager = await get_conversation_manager()
# Lấy history từ DB
history = await manager.get_chat_history(user_id, limit=limit, before_id=before_id) history = await manager.get_chat_history(user_id, limit=limit, before_id=before_id)
next_cursor = None next_cursor = None
...@@ -43,3 +49,19 @@ async def get_chat_history(user_id: str, limit: int | None = 50, before_id: int ...@@ -43,3 +49,19 @@ async def get_chat_history(user_id: str, limit: int | None = 50, before_id: int
return {"data": history, "next_cursor": next_cursor} return {"data": history, "next_cursor": next_cursor}
except Exception as e: except Exception as e:
logger.error(f"Error fetching chat history for user {user_id}: {e}") logger.error(f"Error fetching chat history for user {user_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to fetch chat history")
@router.delete("/api/history/{user_id}", summary="Clear Chat History", response_model=ClearHistoryResponse)
async def clear_chat_history(user_id: str):
"""
Xóa toàn bộ lịch sử chat của user.
"""
try:
manager = await get_conversation_manager()
await manager.clear_history(user_id)
logger.info(f"✅ Cleared chat history for user {user_id}")
return {"success": True, "message": f"Đã xóa lịch sử chat của user {user_id}"}
except Exception as e:
logger.error(f"Error clearing chat history for user {user_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to clear chat history")
...@@ -2,7 +2,7 @@ import hashlib ...@@ -2,7 +2,7 @@ import hashlib
import json import json
import logging import logging
import aioredis import redis.asyncio as aioredis # redis package với async support (thay thế aioredis deprecated)
from config import ( from config import (
REDIS_CACHE_DB, REDIS_CACHE_DB,
......
"""
Canifa API Service
Xử lý các logic liên quan đến API của Canifa (Magento)
"""
import logging
import httpx
from typing import Optional, Dict, Any
logger = logging.getLogger(__name__)
# URL API Canifa
CANIFA_CUSTOMER_API = "https://canifa.com/v1/magento/customer"
# GraphQL Query Body giả lập (để lấy User Info)
CANIFA_QUERY_BODY = [
{
"customer": "customer-custom-query",
"metadata": {
"fields": "\n customer {\n gender\n customer_id\n phone_number\n date_of_birth\n default_billing\n default_shipping\n email\n firstname\n is_subscribed\n lastname\n middlename\n prefix\n suffix\n taxvat\n addresses {\n city\n country_code\n default_billing\n default_shipping\n extension_attributes {\n attribute_code\n value\n }\n custom_attributes {\n attribute_code\n value\n }\n firstname\n id\n lastname\n postcode\n prefix\n region {\n region_code\n region_id\n region\n }\n street\n suffix\n telephone\n vat_id\n }\n is_subscribed\n }\n "
}
},
{}
]
async def verify_canifa_token(token: str) -> Optional[Dict[str, Any]]:
"""
Verify token với API Canifa (Magento).
Dùng token làm cookie `vsf-customer` để gọi API lấy thông tin customer.
Args:
token: Giá trị của cookie vsf-customer (lấy từ Header Authorization)
Returns:
Dict info user hoặc None nếu lỗi
"""
if not token:
return None
headers = {
"accept": "application/json, text/plain, */*",
"content-type": "application/json",
"Cookie": f"vsf-customer={token}" # Quan trọng: Gửi token dưới dạng Cookie
}
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.post(
CANIFA_CUSTOMER_API,
json=CANIFA_QUERY_BODY,
headers=headers
)
if response.status_code == 200:
data = response.json()
# Check nếu response là lỗi (Magento thường trả 200 kèm body lỗi đôi khi)
if isinstance(data, dict):
if data.get("code") != 200:
logger.warning(f"Canifa API Business Error: {data.get('code')} - {data.get('result')}")
return None
return data.get("result", {})
# Nếu Canifa trả list (đôi khi batch request trả về list)
return data
else:
logger.warning(f"Canifa API Failed: {response.status_code} - {response.text}")
return None
except Exception as e:
logger.error(f"Error calling Canifa API: {e}")
return None
async def extract_user_id_from_canifa_response(data: Any) -> Optional[str]:
"""
Bóc customer_id từ response data của Canifa.
"""
if not data:
return None
try:
# Dự phòng các format data trả về khác nhau
customer = None
# Format 1: data['customer']
if isinstance(data, dict):
customer = data.get('customer') or data.get('data', {}).get('customer')
# Format 2: data là list (nếu query batch)
elif isinstance(data, list) and len(data) > 0:
item = data[0]
if isinstance(item, dict):
customer = item.get('result', {}).get('customer') or item.get('data', {}).get('customer')
if customer and isinstance(customer, dict):
user_id = customer.get('customer_id') or customer.get('id')
if user_id:
return str(user_id)
return None
except Exception as e:
logger.error(f"Error parsing user_id from Canifa response: {e}")
return None
import json
import logging import logging
from datetime import datetime from datetime import datetime, date
from typing import Any from typing import Any
from psycopg_pool import AsyncConnectionPool from psycopg_pool import AsyncConnectionPool
...@@ -82,7 +83,10 @@ class ConversationManager: ...@@ -82,7 +83,10 @@ class ConversationManager:
async def get_chat_history( async def get_chat_history(
self, user_id: str, limit: int | None = None, before_id: int | None = None self, user_id: str, limit: int | None = None, before_id: int | None = None
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
"""Retrieve chat history for a user using cursor-based pagination.""" """
Retrieve chat history for a user using cursor-based pagination.
AI messages được parse từ JSON string để lấy product_ids.
"""
try: try:
query = f""" query = f"""
SELECT message, is_human, timestamp, id SELECT message, is_human, timestamp, id
...@@ -106,15 +110,34 @@ class ConversationManager: ...@@ -106,15 +110,34 @@ class ConversationManager:
await cursor.execute(query, tuple(params)) await cursor.execute(query, tuple(params))
results = await cursor.fetchall() results = await cursor.fetchall()
return [ history = []
{ for row in results:
"message": row[0], message_content = row[0]
"is_human": row[1], is_human = row[1]
entry = {
"is_human": is_human,
"timestamp": row[2], "timestamp": row[2],
"id": row[3], "id": row[3],
} }
for row in results
] if is_human:
# User message - text thuần
entry["message"] = message_content
else:
# AI message - parse JSON để lấy ai_response + product_ids
try:
parsed = json.loads(message_content)
entry["message"] = parsed.get("ai_response", message_content)
entry["product_ids"] = parsed.get("product_ids", [])
except (json.JSONDecodeError, TypeError):
# Fallback nếu không phải JSON (data cũ)
entry["message"] = message_content
entry["product_ids"] = []
history.append(entry)
return history
except Exception as e: except Exception as e:
logger.error(f"Error retrieving chat history: {e}") logger.error(f"Error retrieving chat history: {e}")
return [] return []
...@@ -143,6 +166,29 @@ class ConversationManager: ...@@ -143,6 +166,29 @@ class ConversationManager:
logger.error(f"Error getting user count: {e}") logger.error(f"Error getting user count: {e}")
return 0 return 0
async def get_message_count_today(self, user_id: str) -> int:
"""
Đếm số tin nhắn của user trong ngày hôm nay (cho rate limiting).
Chỉ đếm human messages (is_human = true).
"""
try:
pool = await self._get_pool()
async with pool.connection() as conn, conn.cursor() as cursor:
await cursor.execute(
f"""
SELECT COUNT(*) FROM {self.table_name}
WHERE user_id = %s
AND is_human = true
AND DATE(timestamp) = CURRENT_DATE
""",
(user_id,),
)
result = await cursor.fetchone()
return result[0] if result else 0
except Exception as e:
logger.error(f"Error counting messages for {user_id}: {e}")
return 0
async def close(self): async def close(self):
"""Close the connection pool""" """Close the connection pool"""
if self._pool: if self._pool:
......
"""
Message Limit Service
Giới hạn số tin nhắn theo ngày:
- Guest (không login): 10 tin/ngày theo device_id
- User đã login: 100 tin/ngày theo user_id
Lưu trữ: Redis (dùng chung với cache.py)
"""
from __future__ import annotations
import logging
import os
from datetime import datetime
from common.cache import redis_cache
logger = logging.getLogger(__name__)
# =============================================================================
# CONFIGURATION
# =============================================================================
GUEST_LIMIT_PER_DAY = int(os.getenv("MESSAGE_LIMIT_GUEST", "3")) # Tạm set 3 để test
USER_LIMIT_PER_DAY = int(os.getenv("MESSAGE_LIMIT_USER", "100"))
# Redis key prefix
MESSAGE_COUNT_PREFIX = "msg_limit:"
class MessageLimitService:
"""
Service quản lý giới hạn tin nhắn theo ngày.
Dùng Redis để lưu trữ, tự động reset mỗi ngày.
Usage:
from common.message_limit import message_limit_service
# Check trước khi chat
can_send, info = await message_limit_service.check_limit(
identity_key="device:abc123", # hoặc "user:123"
is_authenticated=False
)
if not can_send:
return {"error": info["message"], ...}
# Sau khi chat thành công
await message_limit_service.increment(identity_key, is_authenticated)
"""
_instance: MessageLimitService | None = None
_initialized: bool = False
def __new__(cls) -> MessageLimitService:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if MessageLimitService._initialized:
return
# Fallback in-memory storage: { "device_id": {"guest": 0, "user": 0} }
self._memory_storage: dict[str, dict[str, int]] = {}
self._memory_date: str = ""
# Limits
self.guest_limit = 3 # Test limit
self.total_limit = 5 # Test limit
MessageLimitService._initialized = True
logger.info(
f"✅ MessageLimitService initialized "
f"(Guest Limit: {self.guest_limit}, Total Limit: {self.total_limit})"
)
# =========================================================================
# HELPER METHODS
# =========================================================================
def _get_today_key(self) -> str:
"""Get today's date key (format: YYYY-MM-DD)."""
return datetime.now().strftime("%Y-%m-%d")
def _get_redis_key(self, identity_key: str) -> str:
"""
Build Redis key.
Format: msg_limit:2026-01-17:device_id
Structure: Hash { "guest": int, "user": int }
"""
today = self._get_today_key()
return f"{MESSAGE_COUNT_PREFIX}{today}:{identity_key}"
def _get_seconds_until_midnight(self) -> int:
"""
Get seconds until next midnight (00:00).
"""
from datetime import timedelta
now = datetime.now()
tomorrow = now.date() + timedelta(days=1)
midnight = datetime.combine(tomorrow, datetime.min.time())
return int((midnight - now).total_seconds())
def _reset_memory_if_new_day(self) -> None:
"""Reset in-memory storage nếu qua ngày mới."""
today = self._get_today_key()
if self._memory_date != today:
self._memory_storage.clear()
self._memory_date = today
logger.debug(f"🔄 Memory storage reset for new day: {today}")
# =========================================================================
# REDIS OPERATIONS
# =========================================================================
async def _get_counts_from_redis(self, identity_key: str) -> dict[str, int] | None:
"""
Get all counts (guest, user) từ Redis Hash.
Returns: {"guest": int, "user": int} hoặc None nếu lỗi Redis.
"""
try:
client = redis_cache.get_client()
if not client:
return None
redis_key = self._get_redis_key(identity_key)
# HGETALL trả về dict {b'guest': b'1', ...}
data = await client.hgetall(redis_key)
# Parse data
counts = {"guest": 0, "user": 0}
if data:
# Redis trả về bytes trong dict keys/values
counts["guest"] = int(data.get("guest") or data.get(b"guest") or 0)
counts["user"] = int(data.get("user") or data.get(b"user") or 0)
return counts
except Exception as e:
logger.warning(f"Redis get counts error: {e}")
return None
async def _increment_in_redis(self, identity_key: str, field: str) -> int | None:
"""
Increment specific field ('guest' or 'user') trong Redis Hash.
"""
try:
client = redis_cache.get_client()
if not client:
return None
redis_key = self._get_redis_key(identity_key)
# HINCRBY field 1
new_val = await client.hincrby(redis_key, field, 1)
# Set TTL message nếu key mới tạo (coi như mới nếu chỉ có 1 field vừa set = 1)
# Tuy nhiên để chắc chắn, cứ set expire với nx=True hoặc check ttl
# Đơn giản nhất: Nếu new_val == 1 -> Có thể là key mới, set TTL
if new_val == 1:
ttl = await client.ttl(redis_key)
if ttl < 0: # Chưa có TTL
await client.expire(redis_key, 48 * 3600)
return new_val
except Exception as e:
logger.warning(f"Redis increment error: {e}")
return None
# =========================================================================
# PUBLIC METHODS
# =========================================================================
async def check_limit(
self,
identity_key: str,
is_authenticated: bool,
) -> tuple[bool, dict]:
"""
Check logic:
- Total (guest + user) < 100
- Nếu Guest: thêm điều kiện guest < 10
"""
reset_seconds = self._get_seconds_until_midnight()
# 1. Get counts
counts = await self._get_counts_from_redis(identity_key)
# Fallback memory
if counts is None:
self._reset_memory_if_new_day()
counts = self._memory_storage.get(identity_key, {"guest": 0, "user": 0})
guest_used = counts.get("guest", 0)
user_used = counts.get("user", 0)
total_used = guest_used + user_used
# 2. Logic Checking
can_send = True
limit_display = self.total_limit
message = ""
require_login = False
# Check Total Limit (Hard limit cho device)
if total_used >= self.total_limit:
can_send = False
# Thông báo khi hết tổng quota (dù là user hay guest)
if is_authenticated:
message = f"Bạn đã sử dụng hết {self.total_limit} tin nhắn hôm nay. Quay lại vào ngày mai nhé!"
else:
# Guest dùng hết 100 tin (hiếm, vì guest bị chặn ở 10 rồi, trừ khi login rồi logout)
message = f"Thiết bị này đã đạt giới hạn {self.total_limit} tin nhắn hôm nay."
# Check Guest Limit (nếu chưa login và chưa bị chặn bởi total)
elif not is_authenticated:
limit_display = self.guest_limit
if guest_used >= self.guest_limit:
can_send = False
require_login = True
message = (
f"Bạn đã dùng hết {self.guest_limit} tin nhắn miễn phí. "
f"Đăng nhập ngay để dùng tiếp (tối đa {self.total_limit} tin/ngày)!"
)
# 3. Build Remaining Info
# Nếu là guest: remaining = min(guest_remaining, total_remaining)
# Thực ra guest chỉ care guest_remaining vì guest < total
if is_authenticated:
remaining = max(0, self.total_limit - total_used)
else:
# Guest bị chặn bởi guest_limit hoặc total_limit (trường hợp login rồi logout)
guest_remaining = max(0, self.guest_limit - guest_used)
total_remaining = max(0, self.total_limit - total_used)
remaining = min(guest_remaining, total_remaining)
info = {
"limit": limit_display,
"used": total_used if is_authenticated else guest_used, # Show cái user quan tâm
"total_used": total_used, # Info thêm để debug/admin
"guest_used": guest_used,
"user_used": user_used,
"remaining": remaining,
"reset_seconds": reset_seconds,
"is_authenticated": is_authenticated,
"require_login": require_login,
"message": message,
}
return can_send, info
async def increment(self, identity_key: str, is_authenticated: bool) -> dict:
"""
Increment field tương ứng (guest hoặc user).
"""
field = "user" if is_authenticated else "guest"
# Redis Increment
new_val = await self._increment_in_redis(identity_key, field)
# Memory Fallback
if new_val is None:
self._reset_memory_if_new_day()
if identity_key not in self._memory_storage:
self._memory_storage[identity_key] = {"guest": 0, "user": 0}
self._memory_storage[identity_key][field] += 1
# Trả về info mới nhất (gọi lại check_limit để đồng bộ logic tính toán)
# Tuy nhiên để tối ưu performance, ta tự tính lại nhanh cũng được.
# Nhưng gọi check_limit an toàn hơn cho đồng nhất output structure.
_, info = await self.check_limit(identity_key, is_authenticated)
logger.debug(
f"📈 Incr {field}: {identity_key} | "
f"Guest:{info['guest_used']} User:{info['user_used']} Total:{info['total_used']}"
)
return info
async def get_usage(self, identity_key: str, is_authenticated: bool) -> dict:
"""Wrapper gọi check_limit để lấy info (nhưng bỏ qua bool result)"""
_, info = await self.check_limit(identity_key, is_authenticated)
return info
async def reset(self, identity_key: str) -> bool:
"""Manually reset count (delete key)."""
try:
client = redis_cache.get_client()
if client:
redis_key = self._get_redis_key(identity_key)
await client.delete(redis_key)
self._memory_storage.pop(identity_key, None)
logger.info(f"🔄 Manual reset for {identity_key}")
return True
except Exception as e:
logger.error(f"Reset error: {e}")
return False
# =============================================================================
# SINGLETON INSTANCE
# =============================================================================
message_limit_service = MessageLimitService()
"""
Middleware Module - Gom Auth + Rate Limit
Singleton Pattern cho cả 2 services
"""
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from typing import TYPE_CHECKING
from fastapi import HTTPException, Request, status from fastapi import HTTPException, Request, status
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from common.clerk_auth import verify_clerk_token if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Public endpoints that don't require authentication
# =============================================================================
# CONFIGURATION
# =============================================================================
# Public endpoints - không cần auth
PUBLIC_PATHS = { PUBLIC_PATHS = {
"/", "/",
"/health", "/health",
...@@ -19,17 +30,20 @@ PUBLIC_PATHS = { ...@@ -19,17 +30,20 @@ PUBLIC_PATHS = {
"/redoc", "/redoc",
} }
# Paths that start with these prefixes are public # Public path prefixes
PUBLIC_PATH_PREFIXES = [ PUBLIC_PATH_PREFIXES = [
# Socket.IO removed - using SSE instead "/static",
"/mock",
] ]
# =============================================================================
# AUTH MIDDLEWARE CLASS
# =============================================================================
class ClerkAuthMiddleware(BaseHTTPMiddleware): class ClerkAuthMiddleware(BaseHTTPMiddleware):
""" """
Clerk Authentication Middleware Clerk Authentication Middleware
Tự động verify Clerk JWT tokens cho protected endpoints.
Attach user info vào request.state để routes có thể sử dụng.
Flow: Flow:
1. Frontend gửi request với Authorization: Bearer <clerk_token> 1. Frontend gửi request với Authorization: Bearer <clerk_token>
...@@ -42,63 +56,195 @@ class ClerkAuthMiddleware(BaseHTTPMiddleware): ...@@ -42,63 +56,195 @@ class ClerkAuthMiddleware(BaseHTTPMiddleware):
path = request.url.path path = request.url.path
method = request.method method = request.method
# ✅ Allow OPTIONS requests (CORS preflight) - không cần auth # ✅ Allow OPTIONS requests (CORS preflight)
if method == "OPTIONS": if method == "OPTIONS":
return await call_next(request) return await call_next(request)
# Skip public endpoints - không cần auth # Skip public endpoints
if path in PUBLIC_PATHS: if path in PUBLIC_PATHS:
return await call_next(request) return await call_next(request)
# Skip paths with public prefixes # Skip public path prefixes
if any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES): if any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES):
return await call_next(request) return await call_next(request)
# ✅ For protected endpoints, verify Clerk token và raise 401 nếu không có # ✅ Authentication Process
try: try:
# Get token from Authorization header
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
# ========== DEV MODE: Bypass auth nếu có X-Dev-User-Id header ========== # ========== DEV MODE: Bypass auth ==========
# ⚠️ CHỈ DÙNG CHO TEST/DEV - XÓA TRƯỚC KHI DEPLOY PRODUCTION
dev_user_id = request.headers.get("X-Dev-User-Id") dev_user_id = request.headers.get("X-Dev-User-Id")
if dev_user_id: if dev_user_id:
logger.warning(f"⚠️ DEV MODE: Using X-Dev-User-Id={dev_user_id} (bypassing Clerk auth)") logger.warning(f"⚠️ DEV MODE: Using X-Dev-User-Id={dev_user_id}")
request.state.user = {"user_id": dev_user_id, "clerk_user_id": dev_user_id} request.state.user = {"customer_id": dev_user_id}
request.state.user_id = dev_user_id request.state.user_id = dev_user_id
request.state.clerk_token = None request.state.is_authenticated = True
return await call_next(request) return await call_next(request)
# --- TRƯỜNG HỢP 1: KHÔNG CÓ TOKEN -> GUEST ---
if not auth_header or not auth_header.startswith("Bearer "): if not auth_header or not auth_header.startswith("Bearer "):
# Log để debug # Guest Mode (Không User ID, Không Auth)
logger.warning(f"⚠️ No Authorization header for {method} {path}. Headers: {dict(request.headers)}") # logger.debug(f"ℹ️ Guest access (no token) for {path}")
# Không có token → raise 401 request.state.user = None
raise HTTPException( request.state.user_id = None
status_code=status.HTTP_401_UNAUTHORIZED, request.state.is_authenticated = False
detail="Authentication required. Please provide a valid Clerk token in Authorization header.", return await call_next(request)
)
# Extract token # --- TRƯỜNG HỢP 2: CÓ TOKEN -> GỌI CANIFA VERIFY ---
token = auth_header.replace("Bearer ", "") token = auth_header.replace("Bearer ", "")
from fastapi.security import HTTPAuthorizationCredentials
credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=token) # Import Lazy để tránh circular import nếu có
from common.canifa_api import verify_canifa_token, extract_user_id_from_canifa_response
# ✅ Verify token và extract user info try:
user = await verify_clerk_token(request, credentials) # 1. Gọi API Canifa
user_data = await verify_canifa_token(token)
# ✅ Attach user info vào request.state để routes sử dụng
request.state.user = user # 2. Lấy User ID
request.state.user_id = user.get("user_id") or user.get("clerk_user_id") user_id = await extract_user_id_from_canifa_response(user_data)
request.state.clerk_token = token # Lưu token để forward cho Supabase RLS
if user_id:
# ✅ VERIFY THÀNH CÔNG -> USER VIP
request.state.user = user_data
request.state.user_id = user_id
request.state.token = token
request.state.is_authenticated = True
logger.debug(f"✅ Auth Success: User {user_id}")
else:
# ❌ VERIFY FAILED -> GUEST
logger.warning(f"⚠️ Invalid Canifa Token (No ID found) -> Guest Mode")
request.state.user = None
request.state.user_id = None
request.state.is_authenticated = False
logger.debug(f"✅ Authenticated user {request.state.user_id} for {path}") except Exception as e:
logger.error(f"❌ Canifa Auth Error: {e} -> Guest Mode")
request.state.user = None
request.state.user_id = None
request.state.is_authenticated = False
except HTTPException:
# Re-raise HTTPException (401, etc.)
raise
except Exception as e: except Exception as e:
logger.error(f"❌ Clerk auth verification failed for {path}: {e}") logger.error(f"❌ Middleware Unexpected Error: {e}")
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or expired token") # Fallback an toàn: Guest mode
request.state.user = None
request.state.user_id = None
request.state.is_authenticated = False
return await call_next(request) return await call_next(request)
# =============================================================================
# MIDDLEWARE MANAGER - Singleton to manage all middlewares
# =============================================================================
class MiddlewareManager:
"""
Middleware Manager - Singleton Pattern
Quản lý và setup tất cả middlewares cho FastAPI app
Usage:
from common.middleware import middleware_manager
app = FastAPI()
middleware_manager.setup(app, enable_auth=True, enable_rate_limit=True)
"""
_instance: MiddlewareManager | None = None
_initialized: bool = False
def __new__(cls) -> MiddlewareManager:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
if MiddlewareManager._initialized:
return
self._auth_enabled = False
self._rate_limit_enabled = False
MiddlewareManager._initialized = True
logger.info("✅ MiddlewareManager initialized")
def setup(
self,
app: FastAPI,
*,
enable_auth: bool = True,
enable_rate_limit: bool = True,
enable_cors: bool = True,
cors_origins: list[str] | None = None,
) -> None:
"""
Setup tất cả middlewares cho FastAPI app.
Args:
app: FastAPI application
enable_auth: Bật Clerk authentication middleware
enable_rate_limit: Bật rate limiting
enable_cors: Bật CORS middleware
cors_origins: List origins cho CORS (default: ["*"])
Note:
Thứ tự middleware quan trọng! Middleware thêm sau sẽ chạy TRƯỚC.
Order: CORS → Auth → RateLimit → SlowAPI
"""
# 1. CORS Middleware (thêm cuối cùng để chạy đầu tiên)
if enable_cors:
self._setup_cors(app, cors_origins or ["*"])
# 2. Auth Middleware
if enable_auth:
self._setup_auth(app)
# 3. Rate Limit Middleware
if enable_rate_limit:
self._setup_rate_limit(app)
logger.info(
f"✅ Middlewares configured: "
f"CORS={enable_cors}, Auth={enable_auth}, RateLimit={enable_rate_limit}"
)
def _setup_cors(self, app: FastAPI, origins: list[str]) -> None:
"""Setup CORS middleware."""
from fastapi.middleware.cors import CORSMiddleware
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
logger.info(f"✅ CORS middleware enabled (origins: {origins})")
def _setup_auth(self, app: FastAPI) -> None:
"""Setup Clerk auth middleware."""
app.add_middleware(ClerkAuthMiddleware)
self._auth_enabled = True
logger.info("✅ Clerk Auth middleware enabled")
def _setup_rate_limit(self, app: FastAPI) -> None:
"""Setup rate limiting."""
from common.rate_limit import rate_limit_service
rate_limit_service.setup(app)
self._rate_limit_enabled = True
logger.info("✅ Rate Limit middleware enabled")
@property
def is_auth_enabled(self) -> bool:
return self._auth_enabled
@property
def is_rate_limit_enabled(self) -> bool:
return self._rate_limit_enabled
# =============================================================================
# SINGLETON INSTANCE
# =============================================================================
middleware_manager = MiddlewareManager()
"""
Rate Limiting Service - Singleton Pattern
Sử dụng SlowAPI với Redis backend (production) hoặc Memory (dev)
"""
from __future__ import annotations
import logging
import os
from datetime import datetime, timedelta
from typing import TYPE_CHECKING
from fastapi import Request
from fastapi.responses import JSONResponse
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
if TYPE_CHECKING:
from fastapi import FastAPI
logger = logging.getLogger(__name__)
class RateLimitService:
"""
Rate Limiting Service - Singleton Pattern
Usage:
# Trong server.py
from common.rate_limit import RateLimitService
rate_limiter = RateLimitService()
rate_limiter.setup(app)
# Trong route
from common.rate_limit import RateLimitService
@router.post("/chat")
@RateLimitService().limiter.limit("10/minute")
async def chat(request: Request):
...
"""
_instance: RateLimitService | None = None
_initialized: bool = False
# =========================================================================
# SINGLETON PATTERN
# =========================================================================
def __new__(cls) -> RateLimitService:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self) -> None:
# Chỉ init một lần
if RateLimitService._initialized:
return
# Configuration
self.storage_uri = os.getenv("RATE_STORAGE_URI", "memory://")
self.default_limits = ["100/hour", "30/minute"]
self.block_duration_minutes = int(os.getenv("RATE_LIMIT_BLOCK_MINUTES", "5"))
# Paths không áp dụng rate limit
self.exempt_paths = {
"/",
"/health",
"/docs",
"/openapi.json",
"/redoc",
}
self.exempt_prefixes = ["/static", "/mock"]
# In-memory blocklist (có thể chuyển sang Redis)
self._blocklist: dict[str, datetime] = {}
# Create limiter instance
self.limiter = Limiter(
key_func=self._get_client_identifier,
storage_uri=self.storage_uri,
default_limits=self.default_limits,
)
RateLimitService._initialized = True
logger.info(f"✅ RateLimitService initialized (storage: {self.storage_uri})")
# =========================================================================
# CLIENT IDENTIFIER
# =========================================================================
@staticmethod
def _get_client_identifier(request: Request) -> str:
"""
Lấy client identifier cho rate limiting.
Ưu tiên: user_id (authenticated) > device_id > IP address
"""
# 1. Nếu đã authenticated → dùng user_id
if hasattr(request.state, "user_id") and request.state.user_id:
return f"user:{request.state.user_id}"
# 2. Nếu có device_id trong header → dùng device_id
device_id = request.headers.get("device_id")
if device_id:
return f"device:{device_id}"
# 3. Fallback → IP address
try:
return f"ip:{get_remote_address(request)}"
except Exception:
if request.client:
return f"ip:{request.client.host}"
return "unknown"
# =========================================================================
# BLOCKLIST MANAGEMENT
# =========================================================================
def is_blocked(self, key: str) -> tuple[bool, int]:
"""
Check if client is blocked.
Returns: (is_blocked, retry_after_seconds)
"""
now = datetime.utcnow()
blocked_until = self._blocklist.get(key)
if blocked_until:
if blocked_until > now:
retry_after = int((blocked_until - now).total_seconds())
return True, retry_after
else:
# Block expired
self._blocklist.pop(key, None)
return False, 0
def block_client(self, key: str) -> int:
"""
Block client for configured duration.
Returns: retry_after_seconds
"""
self._blocklist[key] = datetime.utcnow() + timedelta(minutes=self.block_duration_minutes)
return self.block_duration_minutes * 60
def unblock_client(self, key: str) -> None:
"""Unblock client manually."""
self._blocklist.pop(key, None)
# =========================================================================
# PATH CHECKING
# =========================================================================
def is_exempt(self, path: str) -> bool:
"""Check if path is exempt from rate limiting."""
if path in self.exempt_paths:
return True
return any(path.startswith(prefix) for prefix in self.exempt_prefixes)
# =========================================================================
# SETUP FOR FASTAPI APP
# =========================================================================
def setup(self, app: FastAPI) -> None:
"""
Setup rate limiting cho FastAPI app.
Gọi trong server.py sau khi tạo app.
"""
# Attach limiter to app state
app.state.limiter = self.limiter
app.state.rate_limit_service = self
# Register middleware
self._register_block_middleware(app)
self._register_exception_handler(app)
# Add SlowAPI middleware (PHẢI thêm SAU custom middlewares)
app.add_middleware(SlowAPIMiddleware)
logger.info("✅ Rate limiting middleware registered")
def _register_block_middleware(self, app: FastAPI) -> None:
"""Register middleware to check blocklist."""
@app.middleware("http")
async def rate_limit_block_middleware(request: Request, call_next):
path = request.url.path
# Skip exempt paths
if self.is_exempt(path):
return await call_next(request)
# Bypass header cho testing
if request.headers.get("X-Bypass-RateLimit") == "1":
return await call_next(request)
# Check blocklist
key = self._get_client_identifier(request)
is_blocked, retry_after = self.is_blocked(key)
if is_blocked:
return JSONResponse(
status_code=429,
content={
"detail": "Quá số lượt cho phép. Vui lòng thử lại sau.",
"retry_after_seconds": retry_after,
},
headers={"Retry-After": str(retry_after)},
)
return await call_next(request)
def _register_exception_handler(self, app: FastAPI) -> None:
"""Register exception handler for rate limit exceeded."""
@app.exception_handler(RateLimitExceeded)
async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
key = self._get_client_identifier(request)
retry_after = self.block_client(key)
logger.warning(f"⚠️ Rate limit exceeded for {key}, blocked for {self.block_duration_minutes} minutes")
return JSONResponse(
status_code=429,
content={
"detail": "Quá số lượt cho phép. Vui lòng thử lại sau.",
"retry_after_seconds": retry_after,
},
headers={"Retry-After": str(retry_after)},
)
# =============================================================================
# SINGLETON INSTANCE - Import trực tiếp để dùng
# =============================================================================
rate_limit_service = RateLimitService()
"""
User Identity Helper
Xác định user identity từ request
Design:
- Có user_id: Langfuse User ID = user_id, metadata = {device_id: "xxx", is_authenticated: true}
- Không user_id: Langfuse User ID = device_id, metadata = {device_id: "xxx", is_authenticated: false}
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from datetime import datetime
from fastapi import Request
logger = logging.getLogger(__name__)
@dataclass
class UserIdentity:
"""User identity với các thông tin cần thiết"""
# ID chính dùng cho Langfuse, history, rate limit
primary_id: str
# Device ID (luôn có)
device_id: str
# User ID từ token (chỉ có khi đã login)
user_id: str | None
# Đã login hay chưa
is_authenticated: bool
@property
def langfuse_user_id(self) -> str:
"""User ID cho Langfuse tracking"""
return self.primary_id
@property
def langfuse_session_id(self) -> str:
"""Session ID cho Langfuse (theo device + ngày)"""
today = datetime.now().strftime("%Y%m%d")
return f"{self.device_id}-{today}"
@property
def langfuse_metadata(self) -> dict:
"""Metadata cho Langfuse"""
return {
"device_id": self.device_id,
"is_authenticated": self.is_authenticated,
}
@property
def langfuse_tags(self) -> list[str]:
"""Tags cho Langfuse"""
tags = ["chatbot", "production"]
tags.append("customer" if self.is_authenticated else "guest")
return tags
@property
def history_key(self) -> str:
"""Key để lưu/load chat history (theo device_id)"""
return self.device_id
@property
def rate_limit_key(self) -> str:
"""Key cho rate limiting (luôn theo device_id, limit tùy login status)"""
return self.device_id
def get_user_identity(request: Request) -> UserIdentity:
"""
Extract user identity từ request.
Logic:
- Có user_id (từ token) → primary_id = user_id
- Không có → primary_id = device_id
Args:
request: FastAPI Request object
Returns:
UserIdentity object
"""
# 1. Lấy device_id từ header (luôn có)
device_id = request.headers.get("device_id", "")
if not device_id:
device_id = f"unknown_{request.client.host}" if request.client else "unknown"
# 2. Lấy user_id từ token (middleware đã parse)
user_id = None
is_authenticated = False
if hasattr(request.state, "user_id") and request.state.user_id:
user_id = request.state.user_id
is_authenticated = True
# 3. Primary ID
primary_id = user_id if user_id else device_id
identity = UserIdentity(
primary_id=primary_id,
device_id=device_id,
user_id=user_id,
is_authenticated=is_authenticated,
)
logger.debug(
f"UserIdentity: langfuse_user_id={identity.langfuse_user_id}, "
f"metadata={identity.langfuse_metadata}"
)
return identity
server {
listen 80;
server_name _; #bot ip server
# Log files
access_log /var/log/nginx/chatbot_access.log;
error_log /var/log/nginx/chatbot_error.log;
location /chat {
# allow 1.2.3.4;
# deny all;
proxy_pass http://127.0.0.1:5000;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_read_timeout 300s;
proxy_connect_timeout 300s;
proxy_send_timeout 300s;
}
# endpoit for history
location /history {
# allow 1.2.3.4;
# deny all;
proxy_pass http://127.0.0.1:5000;
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
}
location / {
proxy_pass http://127.0.0.1:8000;
}
}
# /etc/nginx/sites-available/your-api
# Rate limit zones
limit_req_zone $binary_remote_addr zone=ip_limit:10m rate=100r/h;
# Upstream backend servers
upstream backend {
server localhost:8000;
# Nếu có nhiều backend servers:
# server localhost:8001;
# server localhost:8002;
}
# Redirect HTTP to HTTPS
server {
listen 80;
server_name api.yourdomain.com;
# Redirect to HTTPS
return 301 https://$server_name$request_uri;
}
# Main HTTPS server
server {
listen 443 ssl http2;
server_name api.yourdomain.com;
# SSL certificates (Let's Encrypt)
ssl_certificate /etc/letsencrypt/live/api.yourdomain.com/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/api.yourdomain.com/privkey.pem;
# SSL settings
ssl_protocols TLSv1.2 TLSv1.3;
ssl_ciphers HIGH:!aNULL:!MD5;
ssl_prefer_server_ciphers on;
# Security headers
add_header Strict-Transport-Security "max-age=31536000" always;
add_header X-Frame-Options "SAMEORIGIN" always;
add_header X-Content-Type-Options "nosniff" always;
add_header X-XSS-Protection "1; mode=block" always;
# Logging
access_log /var/log/nginx/api_access.log;
error_log /var/log/nginx/api_error.log;
# Main API endpoint
location /api/ {
# Rate limiting (100 requests/hour per IP)
limit_req zone=ip_limit burst=20 nodelay;
limit_req_status 429;
# CORS headers (if needed)
add_header 'Access-Control-Allow-Origin' '*' always;
add_header 'Access-Control-Allow-Methods' 'GET, POST, OPTIONS' always;
add_header 'Access-Control-Allow-Headers' 'Authorization, Content-Type, X-Anonymous-ID' always;
# Handle preflight
if ($request_method = 'OPTIONS') {
return 204;
}
# Proxy to backend
proxy_pass http://backend;
# Pass headers
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# Pass auth headers
proxy_set_header Authorization $http_authorization;
proxy_set_header X-Anonymous-ID $http_x_anonymous_id;
# Timeouts
proxy_connect_timeout 60s;
proxy_send_timeout 60s;
proxy_read_timeout 60s;
# Buffer settings
proxy_buffering off;
proxy_request_buffering off;
}
# Health check endpoint (không rate limit)
location /health {
access_log off;
proxy_pass http://backend/health;
}
# Custom error pages
error_page 429 /429.json;
location = /429.json {
internal;
return 429 '{"error":"Too many requests. Please try again later.","retry_after":3600}';
add_header Content-Type application/json always;
add_header Retry-After 3600 always;
}
error_page 502 503 504 /50x.json;
location = /50x.json {
internal;
return 502 '{"error":"Service temporarily unavailable"}';
add_header Content-Type application/json always;
}
}
\ No newline at end of file
...@@ -10,12 +10,13 @@ import logging ...@@ -10,12 +10,13 @@ import logging
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from api.chatbot_route import router as chatbot_router from api.chatbot_route import router as chatbot_router
from api.conservation_route import router as conservation_router from api.conservation_route import router as conservation_router
from common.cache import redis_cache
from common.langfuse_client import get_langfuse_client from common.langfuse_client import get_langfuse_client
from common.middleware import middleware_manager
from config import PORT from config import PORT
# Configure Logging # Configure Logging
...@@ -39,13 +40,26 @@ app = FastAPI( ...@@ -39,13 +40,26 @@ app = FastAPI(
version="1.0.0", version="1.0.0",
) )
logger.info("✅ Clerk Authentication middleware DISABLED (for testing)")
app.add_middleware( # =============================================================================
CORSMiddleware, # STARTUP EVENT - Initialize Redis Cache
allow_origins=["*"], # =============================================================================
allow_credentials=True, @app.on_event("startup")
allow_methods=["*"], async def startup_event():
allow_headers=["*"], """Initialize Redis cache on startup."""
await redis_cache.initialize()
logger.info("✅ Redis cache initialized for message limit")
# =============================================================================
# MIDDLEWARE SETUP - Gom Auth + RateLimit + CORS vào một chỗ
# =============================================================================
middleware_manager.setup(
app,
enable_auth=True, # 👈 Bật lại Auth để test logic Guest/User
enable_rate_limit=True, # 👈 Bật rate limiting
enable_cors=True, # 👈 Bật CORS
cors_origins=["*"], # 👈 Trong production nên limit origins
) )
app.include_router(conservation_router) app.include_router(conservation_router)
...@@ -105,6 +119,6 @@ if __name__ == "__main__": ...@@ -105,6 +119,6 @@ if __name__ == "__main__":
host="0.0.0.0", host="0.0.0.0",
port=PORT, port=PORT,
reload=ENABLE_RELOAD, reload=ENABLE_RELOAD,
# reload_dirs=reload_dirs, reload_dirs=reload_dirs,
log_level="info", log_level="info",
) )
...@@ -441,7 +441,8 @@ ...@@ -441,7 +441,8 @@
<div class="header"> <div class="header">
<h2>🤖 Canifa AI Chat</h2> <h2>🤖 Canifa AI Chat</h2>
<div class="config-area"> <div class="config-area">
<input type="text" id="userId" placeholder="Enter User ID" value="" onblur="saveUserId()" onchange="saveUserId()"> <input type="text" id="userId" placeholder="Enter User ID" value="" onblur="saveUserId()"
onchange="saveUserId()">
<button onclick="loadHistory(true)">↻ History</button> <button onclick="loadHistory(true)">↻ History</button>
<button onclick="clearUI()" style="background: #d32f2f;">✗ Clear UI</button> <button onclick="clearUI()" style="background: #d32f2f;">✗ Clear UI</button>
</div> </div>
...@@ -515,20 +516,28 @@ ...@@ -515,20 +516,28 @@
if (Array.isArray(messages) && messages.length > 0) { if (Array.isArray(messages) && messages.length > 0) {
currentCursor = cursor; currentCursor = cursor;
const batch = [...messages].reverse();
if (isRefresh) { if (isRefresh) {
// Refresh: reverse để oldest ở trên, newest ở dưới
const batch = [...messages].reverse();
batch.forEach(msg => appendMessage(msg, 'bottom')); batch.forEach(msg => appendMessage(msg, 'bottom'));
setTimeout(() => { setTimeout(() => {
const chatBox = document.getElementById('chatBox'); const chatBox = document.getElementById('chatBox');
chatBox.scrollTop = chatBox.scrollHeight; chatBox.scrollTop = chatBox.scrollHeight;
}, 100); }, 100);
} else { } else {
// Keep scroll position relative to bottom content // Load more: messages từ API theo DESC (newest first của batch cũ)
// Ví dụ: [AI 95, User 95, AI 94, User 94, ...]
// Prepend từ index 0: mỗi lần prepend sẽ đẩy cái trước xuống
// Kết quả: User 94 → AI 94 → User 95 → AI 95 (oldest ở trên)
const chatBox = document.getElementById('chatBox'); const chatBox = document.getElementById('chatBox');
const oldHeight = chatBox.scrollHeight; const oldHeight = chatBox.scrollHeight;
batch.forEach(msg => appendMessage(msg, 'top')); // Loop thuận: prepend từng message từ đầu mảng
// Element đầu (newest của batch) sẽ bị đẩy xuống bởi các element sau
for (let i = 0; i < messages.length; i++) {
appendMessage(messages[i], 'top');
}
// Adjust scroll to keep view stable // Adjust scroll to keep view stable
chatBox.scrollTop = chatBox.scrollHeight - oldHeight; chatBox.scrollTop = chatBox.scrollHeight - oldHeight;
...@@ -843,7 +852,7 @@ ...@@ -843,7 +852,7 @@
} }
// Load user ID from localStorage on page load and auto-load history // Load user ID from localStorage on page load and auto-load history
window.onload = function() { window.onload = function () {
const savedUserId = localStorage.getItem('canifa_user_id'); const savedUserId = localStorage.getItem('canifa_user_id');
if (savedUserId) { if (savedUserId) {
document.getElementById('userId').value = savedUserId; document.getElementById('userId').value = savedUserId;
......
"""Test message limit - Guest limit = 3"""
import requests
DEVICE_ID = "limit-test-002"
API_URL = "http://localhost:5000/api/agent/chat"
print("=" * 50)
print("TEST MESSAGE LIMIT (Guest = 3 tin/ngày)")
print("=" * 50)
print(f"Device ID: {DEVICE_ID}")
print()
for i in range(5): # Gửi 5 tin để thấy bị chặn
print(f"--- Tin nhắn #{i+1} ---")
response = requests.post(
API_URL,
json={"user_query": f"test message {i+1}"},
headers={"device_id": DEVICE_ID}
)
data = response.json()
if data.get("status") == "success":
limit_info = data.get("limit_info", {})
print(f"✅ Thành công!")
print(f" Used: {limit_info.get('used')}/{limit_info.get('limit')}")
print(f" Remaining: {limit_info.get('remaining')}")
else:
print(f"❌ Bị chặn!")
print(f" Error: {data.get('error_code')}")
print(f" Message: {data.get('message')}")
print(f" Require login: {data.get('require_login')}")
limit_info = data.get("limit_info", {})
if limit_info:
print(f" Used: {limit_info.get('used')}/{limit_info.get('limit')}")
print()
print("=" * 50)
print("TEST HOÀN TẤT!")
"""
TEST SCRIPT FOR MESSAGE LIMIT V2
Logic:
- Guest Limit: 10
- Total Limit (Guest + User): 100
- Support Memory & Redis
"""
import requests
import time
API_URL = "http://localhost:5000/api/agent/chat"
DEVICE_ID = "v2_test_device_003"
USER_TOKEN = "Bearer test_token_123" # Mock token (nếu dev mode support)
print(f"🚀 START TEST V2 - Device: {DEVICE_ID} 🚀")
print("=" * 60)
def send_msg(i, is_login=False):
headers = {"device_id": DEVICE_ID}
if is_login:
headers["X-Dev-User-Id"] = "user_123" # Bypass auth middleware
user_status = "USER "
else:
user_status = "GUEST"
print(f"📩 [{user_status}] Msg #{i} sending...", end=" ")
try:
resp = requests.post(
API_URL,
json={"user_query": f"test msg {i}"},
headers=headers
)
data = resp.json()
if data.get("status") == "success":
limit = data['limit_info']
print(f"✅ OK! Used: {limit['used']}/{limit['limit']} | Remaining: {limit['remaining']}")
else:
print(f"❌ BLOCKED! {data.get('message')}")
if 'limit_info' in data:
print(f" Info: {data['limit_info']}")
except Exception as e:
print(f"ERROR: {e}")
# 1. Gửi 3 tin Guest
print("\n--- PHASE 1: GUEST (3 msgs) ---")
for i in range(1, 4):
send_msg(i, is_login=False)
# 2. Login và gửi tiếp
print("\n--- PHASE 2: LOGIN (USER) ---")
send_msg(4, is_login=True)
# 3. Check info
print("\n--- CHECK INFO ---")
try:
resp = requests.get(
"http://localhost:5000/api/message-limit",
headers={"device_id": DEVICE_ID} # Check as guest
)
print("Guest View:", resp.json())
resp = requests.get(
"http://localhost:5000/api/message-limit",
headers={"device_id": DEVICE_ID, "X-Dev-User-Id": "user_123"} # Check as user
)
print("User View:", resp.json())
except:
pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment