Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
C
chatbot canifa
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
1
Merge Requests
1
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Vũ Hoàng Anh
chatbot canifa
Commits
49f43a45
Commit
49f43a45
authored
Jan 27, 2026
by
Vũ Hoàng Anh
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Fix mock API routing and retriever alias
parent
566ee233
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
899 additions
and
899 deletions
+899
-899
mock_api_route.py
backend/api/mock_api_route.py
+201
-202
middleware.py
backend/common/middleware.py
+335
-334
rate_limit.py
backend/common/rate_limit.py
+238
-238
server.py
backend/server.py
+125
-125
No files found.
backend/api/mock_api_route.py
View file @
49f43a45
import
asyncio
import
json
import
logging
import
time
from
fastapi
import
APIRouter
,
BackgroundTasks
,
HTTPException
from
pydantic
import
BaseModel
from
agent.tools.data_retrieval_tool
import
SearchItem
,
data_retrieval_tool
logger
=
logging
.
getLogger
(
__name__
)
router
=
APIRouter
()
# --- HELPERS ---
async def retry_with_backoff(coro_fn, max_retries=3, backoff_factor=2):
    """Await ``coro_fn()`` and retry with exponential backoff when it raises.

    Args:
        coro_fn: Zero-argument callable that returns a fresh awaitable on
            every call (an awaitable cannot be awaited twice).
        max_retries: Total number of attempts, including the first one.
            Must be at least 1.
        backoff_factor: Base of the delay. The pause after failed attempt
            ``k`` (0-based) is ``backoff_factor ** k`` seconds.

    Returns:
        The result of the first ``coro_fn()`` call that succeeds.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        Exception: The error raised by the final attempt, re-raised as is.
    """
    if max_retries < 1:
        # With zero attempts the loop below would never run and the helper
        # would return None without calling coro_fn at all.
        raise ValueError("max_retries must be at least 1")

    for attempt in range(max_retries):
        try:
            return await coro_fn()
        except Exception as e:
            # Last attempt: surface the original error to the caller.
            if attempt == max_retries - 1:
                raise
            wait_time = backoff_factor**attempt
            logger.warning(f"⚠️ Attempt {attempt + 1} failed: {e!s}, retrying in {wait_time}s...")
            await asyncio.sleep(wait_time)
# --- MODELS ---
class MockQueryRequest(BaseModel):
    # Request body of the mock agent chat endpoint (`mock_chat`).
    # `#` comments are used instead of a docstring so the generated schema
    # description is left untouched.

    # Required free-text query; forwarded to the chat controller as `query`.
    user_query: str
    # Caller identity; the handler falls back to "test_user" when it is empty.
    user_id: str | None = "test_user"
    # Optional session id; the handler only echoes it back in the response.
    session_id: str | None = None
class MockDBRequest(BaseModel):
    # Request body of the DB search endpoint (`mock_db_search`).

    # Search text; the handler substitutes "sản phẩm" when it is missing/empty.
    query: str | None = None
    # Optional product reference code passed straight through to SearchItem.
    magento_ref_code: str | None = None
    # Optional price bounds passed straight through to SearchItem.
    price_min: float | None = None
    price_max: float | None = None
    # NOTE(review): accepted here, but `mock_db_search` never reads it —
    # confirm whether the retrieval tool should receive a result limit.
    top_k: int = 10
class MockRetrieverRequest(BaseModel):
    # Request body for the retriever test endpoint.
    # NOTE(review): the handler consuming this model is not fully visible in
    # this part of the file — field usage below is presumed from the names.

    # Required free-text query.
    user_query: str
    # Optional price bounds for the search.
    price_min: float | None = None
    price_max: float | None = None
    # Optional product reference code.
    magento_ref_code: str | None = None
    # Caller identity; defaults to the placeholder "test_user".
    user_id: str | None = "test_user"
    # Optional session id.
    session_id: str | None = None
# --- MOCK LLM RESPONSES (no OpenAI call) ---
# Canned Vietnamese assistant replies. Nothing in the visible part of this
# module reads the list — presumably it is consumed elsewhere; verify before
# removing.
MOCK_AI_RESPONSES = [
    "Dựa trên tìm kiếm của bạn, tôi tìm thấy các sản phẩm phù hợp với nhu cầu của bạn. Những mặt hàng này có chất lượng tốt và giá cả phải chăng.",
    "Tôi gợi ý cho bạn những sản phẩm sau. Chúng đều là những lựa chọn phổ biến và nhận được đánh giá cao từ khách hàng.",
    "Dựa trên tiêu chí tìm kiếm của bạn, đây là những sản phẩm tốt nhất mà tôi có thể giới thiệu.",
    "Những sản phẩm này hoàn toàn phù hợp với yêu cầu của bạn. Hãy xem chi tiết để chọn sản phẩm yêu thích nhất.",
    "Tôi đã tìm được các mặt hàng tuyệt vời cho bạn. Hãy kiểm tra chúng để tìm ra lựa chọn tốt nhất.",
]
# --- ENDPOINTS ---
from
agent.mock_controller
import
mock_chat_controller
@router.post("/mock/agent/chat", summary="Mock Agent Chat (Real Tools + Fake LLM)")
async def mock_chat(req: MockQueryRequest, background_tasks: BackgroundTasks):
    """
    Mock Agent Chat using mock_chat_controller (as described by the author):
    - ✅ Real embedding + vector search (the real data_retrieval_tool)
    - ✅ Real products from StarRocks
    - ❌ Fake LLM response (no OpenAI cost)
    - Perfect for stress testing + end-to-end testing

    Returns a dict with ``status``, the echoed ``user_query`` / ``user_id`` /
    ``session_id``, merged with whatever the controller returns (controller
    keys win on collision, including ``status``).

    Raises:
        HTTPException: Passed through unchanged when the controller raises
            one; any other failure becomes a 500 with the error text.
    """
    try:
        logger.info(f"🚀 [Mock Agent Chat] Starting with query: {req.user_query}")

        result = await mock_chat_controller(
            query=req.user_query,
            user_id=req.user_id or "test_user",
            background_tasks=background_tasks,
        )

        return {
            "status": "success",
            "user_query": req.user_query,
            "user_id": req.user_id,
            "session_id": req.session_id,
            **result,  # Include status, ai_response, product_ids, etc.
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as 500.
        raise
    except Exception as e:
        logger.error(f"❌ Error in mock agent chat: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Mock Agent Chat Error: {e!s}") from e
@router.post("/mock/db/search", summary="Real Data Retrieval Tool (Agent Tool)")
async def mock_db_search(req: MockDBRequest):
    """
    Run the agent's real `data_retrieval_tool` (behaviour notes from the author):
    - With magento_ref_code → CODE SEARCH (no embedding needed)
    - With query → HYDE SEMANTIC SEARCH (embedding + vector search)
    - Filters by price when price_min/price_max are given
    - Returns real products from StarRocks

    The input mirrors the agent tool's SearchItem.

    Returns a dict with ``status``, ``search_params``, ``total_results``,
    ``products``, ``processing_time_ms`` and the tool's ``raw_result``.

    Raises:
        HTTPException: Passed through unchanged if raised downstream; any
            other failure becomes a 500 with the error text.
    """
    try:
        logger.info("📍 Data Retrieval Tool called")
        # perf_counter is monotonic, so the duration cannot go negative or
        # jump when the system clock is adjusted.
        start_time = time.perf_counter()

        # Build the SearchItem from the request.
        search_item = SearchItem(
            query=req.query or "sản phẩm",
            magento_ref_code=req.magento_ref_code,
            price_min=req.price_min,
            price_max=req.price_max,
            action="search",
        )
        search_params = search_item.dict(exclude_none=True)

        logger.info(f"🔧 Search params: {search_params}")

        # Call the real data_retrieval_tool, with retries.
        result_json = await retry_with_backoff(
            lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}),
            max_retries=3,
        )
        result = json.loads(result_json)

        elapsed_time = time.perf_counter() - start_time
        logger.info(f"✅ Data Retrieval completed in {elapsed_time:.3f}s")

        # "results" may be absent, null or an empty list; indexing [0] on an
        # empty list used to turn a "no hits" answer into a 500.
        first_result = (result.get("results") or [{}])[0]
        products = first_result.get("products") or []

        return {
            "status": result.get("status", "success"),
            "search_params": search_params,
            "total_results": len(products),
            "products": products,
            "processing_time_ms": round(elapsed_time * 1000, 2),
            "raw_result": result,
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as 500.
        raise
    except Exception as e:
        logger.error(f"❌ Error in DB search: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"DB Search Error: {e!s}") from e
@
router
.
post
(
"/mock/retriverdb"
,
summary
=
"Real Embedding + Real DB Vector Search"
)
import
asyncio
import
json
import
logging
import
time
from
fastapi
import
APIRouter
,
BackgroundTasks
,
HTTPException
from
pydantic
import
BaseModel
from
agent.tools.data_retrieval_tool
import
SearchItem
,
data_retrieval_tool
logger
=
logging
.
getLogger
(
__name__
)
router
=
APIRouter
()
# --- HELPERS ---
async def retry_with_backoff(coro_fn, max_retries=3, backoff_factor=2):
    """Await ``coro_fn()`` and retry with exponential backoff when it raises.

    Args:
        coro_fn: Zero-argument callable that returns a fresh awaitable on
            every call (an awaitable cannot be awaited twice).
        max_retries: Total number of attempts, including the first one.
            Must be at least 1.
        backoff_factor: Base of the delay. The pause after failed attempt
            ``k`` (0-based) is ``backoff_factor ** k`` seconds.

    Returns:
        The result of the first ``coro_fn()`` call that succeeds.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        Exception: The error raised by the final attempt, re-raised as is.
    """
    if max_retries < 1:
        # With zero attempts the loop below would never run and the helper
        # would return None without calling coro_fn at all.
        raise ValueError("max_retries must be at least 1")

    for attempt in range(max_retries):
        try:
            return await coro_fn()
        except Exception as e:
            # Last attempt: surface the original error to the caller.
            if attempt == max_retries - 1:
                raise
            wait_time = backoff_factor**attempt
            logger.warning(f"⚠️ Attempt {attempt + 1} failed: {e!s}, retrying in {wait_time}s...")
            await asyncio.sleep(wait_time)
# --- MODELS ---
class MockQueryRequest(BaseModel):
    # Request body of the mock agent chat endpoint (`mock_chat`).
    # `#` comments are used instead of a docstring so the generated schema
    # description is left untouched.

    # Required free-text query; forwarded to the chat controller as `query`.
    user_query: str
    # Caller identity; the handler falls back to "test_user" when it is empty.
    user_id: str | None = "test_user"
    # Optional session id; the handler only echoes it back in the response.
    session_id: str | None = None
class MockDBRequest(BaseModel):
    # Request body of the DB search endpoint (`mock_db_search`).

    # Search text; the handler substitutes "sản phẩm" when it is missing/empty.
    query: str | None = None
    # Optional product reference code passed straight through to SearchItem.
    magento_ref_code: str | None = None
    # Optional price bounds passed straight through to SearchItem.
    price_min: float | None = None
    price_max: float | None = None
    # NOTE(review): accepted here, but `mock_db_search` never reads it —
    # confirm whether the retrieval tool should receive a result limit.
    top_k: int = 10
class MockRetrieverRequest(BaseModel):
    # Request body of the retriever test endpoint (`mock_retriever_db`).

    # Required free-text query; used as the SearchItem query.
    user_query: str
    # Optional price bounds passed straight through to SearchItem.
    price_min: float | None = None
    price_max: float | None = None
    # Optional product reference code passed straight through to SearchItem.
    magento_ref_code: str | None = None
    # Caller identity; only echoed back in the response.
    user_id: str | None = "test_user"
    # Optional session id; only echoed back in the response.
    session_id: str | None = None
# --- MOCK LLM RESPONSES (no OpenAI call) ---
# Canned Vietnamese assistant replies. Nothing in the visible part of this
# module reads the list — presumably it is consumed elsewhere; verify before
# removing.
MOCK_AI_RESPONSES = [
    "Dựa trên tìm kiếm của bạn, tôi tìm thấy các sản phẩm phù hợp với nhu cầu của bạn. Những mặt hàng này có chất lượng tốt và giá cả phải chăng.",
    "Tôi gợi ý cho bạn những sản phẩm sau. Chúng đều là những lựa chọn phổ biến và nhận được đánh giá cao từ khách hàng.",
    "Dựa trên tiêu chí tìm kiếm của bạn, đây là những sản phẩm tốt nhất mà tôi có thể giới thiệu.",
    "Những sản phẩm này hoàn toàn phù hợp với yêu cầu của bạn. Hãy xem chi tiết để chọn sản phẩm yêu thích nhất.",
    "Tôi đã tìm được các mặt hàng tuyệt vời cho bạn. Hãy kiểm tra chúng để tìm ra lựa chọn tốt nhất.",
]
# --- ENDPOINTS ---
from
agent.mock_controller
import
mock_chat_controller
@router.post("/api/mock/agent/chat", summary="Mock Agent Chat (Real Tools + Fake LLM)")
async def mock_chat(req: MockQueryRequest, background_tasks: BackgroundTasks):
    """
    Mock Agent Chat using mock_chat_controller (as described by the author):
    - ✅ Real embedding + vector search (the real data_retrieval_tool)
    - ✅ Real products from StarRocks
    - ❌ Fake LLM response (no OpenAI cost)
    - Perfect for stress testing + end-to-end testing

    Returns a dict with ``status``, the echoed ``user_query`` / ``user_id`` /
    ``session_id``, merged with whatever the controller returns (controller
    keys win on collision, including ``status``).

    Raises:
        HTTPException: Passed through unchanged when the controller raises
            one; any other failure becomes a 500 with the error text.
    """
    try:
        logger.info(f"🚀 [Mock Agent Chat] Starting with query: {req.user_query}")

        result = await mock_chat_controller(
            query=req.user_query,
            user_id=req.user_id or "test_user",
            background_tasks=background_tasks,
        )

        return {
            "status": "success",
            "user_query": req.user_query,
            "user_id": req.user_id,
            "session_id": req.session_id,
            **result,  # Include status, ai_response, product_ids, etc.
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as 500.
        raise
    except Exception as e:
        logger.error(f"❌ Error in mock agent chat: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Mock Agent Chat Error: {e!s}") from e
@router.post("/api/mock/db/search", summary="Real Data Retrieval Tool (Agent Tool)")
async def mock_db_search(req: MockDBRequest):
    """
    Run the agent's real `data_retrieval_tool` (behaviour notes from the author):
    - With magento_ref_code → CODE SEARCH (no embedding needed)
    - With query → HYDE SEMANTIC SEARCH (embedding + vector search)
    - Filters by price when price_min/price_max are given
    - Returns real products from StarRocks

    The input mirrors the agent tool's SearchItem.

    Returns a dict with ``status``, ``search_params``, ``total_results``,
    ``products``, ``processing_time_ms`` and the tool's ``raw_result``.

    Raises:
        HTTPException: Passed through unchanged if raised downstream; any
            other failure becomes a 500 with the error text.
    """
    try:
        logger.info("📍 Data Retrieval Tool called")
        # perf_counter is monotonic, so the duration cannot go negative or
        # jump when the system clock is adjusted.
        start_time = time.perf_counter()

        # Build the SearchItem from the request.
        search_item = SearchItem(
            query=req.query or "sản phẩm",
            magento_ref_code=req.magento_ref_code,
            price_min=req.price_min,
            price_max=req.price_max,
            action="search",
        )
        search_params = search_item.dict(exclude_none=True)

        logger.info(f"🔧 Search params: {search_params}")

        # Call the real data_retrieval_tool, with retries.
        result_json = await retry_with_backoff(
            lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}),
            max_retries=3,
        )
        result = json.loads(result_json)

        elapsed_time = time.perf_counter() - start_time
        logger.info(f"✅ Data Retrieval completed in {elapsed_time:.3f}s")

        # "results" may be absent, null or an empty list; indexing [0] on an
        # empty list used to turn a "no hits" answer into a 500.
        first_result = (result.get("results") or [{}])[0]
        products = first_result.get("products") or []

        return {
            "status": result.get("status", "success"),
            "search_params": search_params,
            "total_results": len(products),
            "products": products,
            "processing_time_ms": round(elapsed_time * 1000, 2),
            "raw_result": result,
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as 500.
        raise
    except Exception as e:
        logger.error(f"❌ Error in DB search: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"DB Search Error: {e!s}") from e
@router.post("/api/mock/retrieverdb", summary="Real Embedding + Real DB Vector Search")
@router.post("/api/mock/retriverdb", summary="Real Embedding + Real DB Vector Search (Legacy)")
async def mock_retriever_db(req: MockRetrieverRequest):
    """
    Endpoint for testing Retriever + DB search through the agent tool
    (behaviour notes from the author):
    - Takes the query from the user
    - Real embedding (the tool calls the OpenAI embedding API)
    - Real vector search in StarRocks
    - Returns real product results (the LLM is skipped)

    Meant for measuring embedding + vector search performance in isolation.
    Registered under both the corrected path and the legacy misspelled one.

    Returns a dict with ``status``, the echoed ``user_query`` / ``user_id`` /
    ``session_id``, ``search_params``, ``total_results``, ``products`` and
    ``processing_time_ms``.

    Raises:
        HTTPException: Passed through unchanged if raised downstream; any
            other failure becomes a 500 with the error text.
    """
    try:
        logger.info(f"📍 Retriever DB started: {req.user_query}")
        # perf_counter is monotonic, so the duration cannot go negative or
        # jump when the system clock is adjusted.
        start_time = time.perf_counter()

        # Build the SearchItem from the request.
        search_item = SearchItem(
            query=req.user_query,
            magento_ref_code=req.magento_ref_code,
            price_min=req.price_min,
            price_max=req.price_max,
            action="search",
        )
        search_params = search_item.dict(exclude_none=True)

        logger.info(f"🔧 Retriever params: {search_params}")

        # Call the real data_retrieval_tool (embedding + vector search), with retries.
        result_json = await retry_with_backoff(
            lambda: data_retrieval_tool.ainvoke({"searches": [search_item]}),
            max_retries=3,
        )
        result = json.loads(result_json)

        elapsed_time = time.perf_counter() - start_time
        logger.info(f"✅ Retriever completed in {elapsed_time:.3f}s")

        # "results" may be absent, null or an empty list; indexing [0] on an
        # empty list used to turn a "no hits" answer into a 500.
        search_results = (result.get("results") or [{}])[0]
        products = search_results.get("products") or []

        return {
            "status": result.get("status", "success"),
            "user_query": req.user_query,
            "user_id": req.user_id,
            "session_id": req.session_id,
            "search_params": search_params,
            "total_results": len(products),
            "products": products,
            "processing_time_ms": round(elapsed_time * 1000, 2),
        }
    except HTTPException:
        # Keep the status code chosen downstream instead of masking it as 500.
        raise
    except Exception as e:
        logger.error(f"❌ Error in retriever DB: {e!s}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Retriever DB Error: {e!s}") from e
"""
API thực tế để test Retriever + DB Search (dùng agent tool):
- Lấy query từ user
- Embedding THẬT (gọi OpenAI embedding trong tool)
- Vector search THẬT trong StarRocks
- Trả về kết quả sản phẩm thực (bỏ qua LLM)
Dùng để test performance của embedding + vector search riêng biệt.
"""
try
:
logger
.
info
(
f
"📍 Retriever DB started: {req.user_query}"
)
start_time
=
time
.
time
()
# Xây dựng SearchItem từ request
search_item
=
SearchItem
(
query
=
req
.
user_query
,
magento_ref_code
=
req
.
magento_ref_code
,
price_min
=
req
.
price_min
,
price_max
=
req
.
price_max
,
action
=
"search"
,
)
logger
.
info
(
f
"🔧 Retriever params: {search_item.dict(exclude_none=True)}"
)
# Gọi data_retrieval_tool THẬT (embedding + vector search) với retry
result_json
=
await
retry_with_backoff
(
lambda
:
data_retrieval_tool
.
ainvoke
({
"searches"
:
[
search_item
]}),
max_retries
=
3
)
result
=
json
.
loads
(
result_json
)
elapsed_time
=
time
.
time
()
-
start_time
logger
.
info
(
f
"✅ Retriever completed in {elapsed_time:.3f}s"
)
# Parse kết quả
search_results
=
result
.
get
(
"results"
,
[{}])[
0
]
products
=
search_results
.
get
(
"products"
,
[])
return
{
"status"
:
result
.
get
(
"status"
,
"success"
),
"user_query"
:
req
.
user_query
,
"user_id"
:
req
.
user_id
,
"session_id"
:
req
.
session_id
,
"search_params"
:
search_item
.
dict
(
exclude_none
=
True
),
"total_results"
:
len
(
products
),
"products"
:
products
,
"processing_time_ms"
:
round
(
elapsed_time
*
1000
,
2
),
}
except
Exception
as
e
:
logger
.
error
(
f
"❌ Error in retriever DB: {e!s}"
,
exc_info
=
True
)
raise
HTTPException
(
status_code
=
500
,
detail
=
f
"Retriever DB Error: {e!s}"
)
backend/common/middleware.py
View file @
49f43a45
"""
Middleware Module - bundles Auth + Rate Limit
Singleton pattern for both services
"""
from
__future__
import
annotations
import
json
import
logging
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
from
fastapi
import
HTTPException
,
Request
,
status
from
starlette.middleware.base
import
BaseHTTPMiddleware
if
TYPE_CHECKING
:
from
fastapi
import
FastAPI
logger
=
logging
.
getLogger
(
__name__
)
# =============================================================================
# CONFIGURATION
# =============================================================================
# Public endpoints - no auth required
PUBLIC_PATHS
=
{
"/"
,
"/health"
,
"/docs"
,
"/openapi.json"
,
"/redoc"
,
}
# Public path prefixes
PUBLIC_PATH_PREFIXES
=
[
"/static"
,
"/mock"
,
]
# =============================================================================
# AUTH + RATE LIMIT MIDDLEWARE CLASS
# =============================================================================
# Paths that need rate limit check
RATE_LIMITED_PATHS
=
[
"/api/agent/chat"
,
]
class CanifaAuthMiddleware(BaseHTTPMiddleware):
    """
    Canifa authentication + message rate-limit middleware.

    Flow:
    1. The frontend sends a request with ``Authorization: Bearer <canifa_token>``.
    2. The token is verified through ``common.canifa_api`` and a user id is
       extracted from the response.
    3. For paths listed in ``RATE_LIMITED_PATHS`` the message limit is checked
       (author's note: Guest: 10, User: 100 — the numbers live in
       ``common.message_limit``, not in this class).
    4. The resolved identity is attached to ``request.state``
       (``user``, ``user_id``, ``is_authenticated``, ``device_id`` and, for a
       verified token, ``token``).
    5. Routes read it directly from ``request.state``.

    OPTIONS requests, ``PUBLIC_PATHS`` and anything under
    ``PUBLIC_PATH_PREFIXES`` bypass both steps. Authentication never rejects a
    request: every failure degrades to guest mode.
    """

    async def dispatch(self, request: Request, call_next: Callable):
        """Resolve the caller identity, enforce the message limit, then forward.

        Returns the downstream response, or a 429 JSON response when the
        message limit for the caller is exhausted.
        """
        path = request.url.path
        method = request.method

        # OPTIONS is let through for CORS preflight; public paths need no auth.
        if (
            method == "OPTIONS"
            or path in PUBLIC_PATHS
            or any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES)
        ):
            return await call_next(request)

        # STEP 1: AUTHENTICATION (Canifa API)
        dev_bypass = False
        try:
            dev_bypass = await self._resolve_identity(request, method)
        except Exception as e:
            logger.error(f"❌ Middleware Auth Error: {e}")
            self._mark_guest(request, "")

        if dev_bypass:
            # Forwarded OUTSIDE the try above: previously a downstream error in
            # dev mode was logged as an auth error and the request was
            # dispatched a second time.
            return await call_next(request)

        # STEP 2: RATE LIMIT CHECK (only for the paths that need it)
        if path in RATE_LIMITED_PATHS:
            limited_response = await self._enforce_message_limit(request)
            if limited_response is not None:
                return limited_response

        return await call_next(request)

    @staticmethod
    def _mark_guest(request: Request, device_id: str) -> None:
        """Record an unauthenticated (guest) identity on ``request.state``."""
        request.state.user = None
        request.state.user_id = None
        request.state.is_authenticated = False
        request.state.device_id = device_id

    @staticmethod
    async def _read_device_id(request: Request, method: str) -> str:
        """Find the device id: JSON body → ``device_id`` header → client IP."""
        device_id = ""

        if method in ["POST", "PUT", "PATCH"]:
            try:
                body_bytes = await request.body()

                # Replay the consumed body so downstream handlers can read it.
                async def receive_wrapper():
                    return {"type": "http.request", "body": body_bytes}

                request._receive = receive_wrapper

                if body_bytes:
                    try:
                        body_json = json.loads(body_bytes)
                    except (json.JSONDecodeError, UnicodeDecodeError):
                        body_json = None
                    # A JSON array/scalar body has no device_id to offer.
                    if isinstance(body_json, dict):
                        device_id = body_json.get("device_id", "")
            except Exception as e:
                logger.warning(f"Error reading device_id from body: {e}")

        # Fallback: not in the body → look in the header → then the client IP.
        if not device_id:
            device_id = request.headers.get("device_id", "")
        if not device_id:
            device_id = f"unknown_{request.client.host}" if request.client else "unknown"
        return device_id

    async def _resolve_identity(self, request: Request, method: str) -> bool:
        """Populate ``request.state`` with the caller identity.

        Returns True when the ``X-Dev-User-Id`` bypass was used (the caller
        must then skip rate limiting), False otherwise.
        """
        auth_header = request.headers.get("Authorization")
        device_id = await self._read_device_id(request, method)

        # ========== DEV MODE: bypass auth ==========
        # SECURITY NOTE(review): nothing in this class restricts the header to
        # development environments — any client sending it is treated as that
        # user and skips the message limit. Gate it behind a dev-only setting.
        dev_user_id = request.headers.get("X-Dev-User-Id")
        if dev_user_id:
            logger.warning(f"⚠️ DEV MODE: Using X-Dev-User-Id={dev_user_id}")
            request.state.user = {"customer_id": dev_user_id}
            request.state.user_id = dev_user_id
            request.state.is_authenticated = True
            request.state.device_id = device_id or dev_user_id
            return True

        # --- CASE 1: NO TOKEN -> GUEST ---
        if not auth_header or not auth_header.startswith("Bearer "):
            self._mark_guest(request, device_id)
            return False

        # --- CASE 2: TOKEN PRESENT -> VERIFY WITH CANIFA ---
        # Strip only the leading scheme; str.replace would also remove any
        # later "Bearer " occurrence inside the credential.
        token = auth_header[len("Bearer "):]

        from common.canifa_api import verify_canifa_token, extract_user_id_from_canifa_response

        try:
            user_data = await verify_canifa_token(token)
            user_id = await extract_user_id_from_canifa_response(user_data)

            if user_id:
                request.state.user = user_data
                request.state.user_id = user_id
                request.state.token = token
                request.state.is_authenticated = True
                request.state.device_id = device_id
                logger.debug(f"✅ Canifa Auth Success: User {user_id}")
            else:
                logger.warning("⚠️ Invalid Canifa Token -> Guest Mode")
                self._mark_guest(request, device_id)
        except Exception as e:
            logger.error(f"❌ Canifa Auth Error: {e} -> Guest Mode")
            self._mark_guest(request, device_id)
        return False

    @staticmethod
    async def _enforce_message_limit(request: Request):
        """Check the message limit for the resolved identity.

        Returns a 429 ``JSONResponse`` when the limit is exhausted, otherwise
        ``None``. Any error inside the check is logged and the request is
        allowed to continue (deliberate fail-open).
        """
        try:
            from common.message_limit import message_limit_service
            from fastapi.responses import JSONResponse

            # identity_key is the rate-limit key:
            #   guest -> device_id, authenticated user -> user_id
            is_authenticated = request.state.is_authenticated
            if is_authenticated and request.state.user_id:
                rate_limit_key = request.state.user_id
            else:
                rate_limit_key = request.state.device_id

            if not rate_limit_key:
                logger.warning("⚠️ No identity_key for rate limiting")
                return None

            can_send, limit_info = await message_limit_service.check_limit(
                identity_key=rate_limit_key,
                is_authenticated=is_authenticated,
            )

            # Stored on request.state so routes can use it.
            request.state.limit_info = limit_info

            if not can_send:
                logger.warning(
                    f"⚠️ Rate Limit Exceeded: {rate_limit_key} | "
                    f"used={limit_info['used']}/{limit_info['limit']}"
                )
                return JSONResponse(
                    status_code=429,
                    content={
                        "status": "error",
                        "error_code": limit_info.get("error_code") or "MESSAGE_LIMIT_EXCEEDED",
                        "message": limit_info["message"],
                        "require_login": limit_info["require_login"],
                        "limit_info": {
                            "limit": limit_info["limit"],
                            "used": limit_info["used"],
                            "remaining": limit_info["remaining"],
                            "reset_seconds": limit_info["reset_seconds"],
                        },
                    },
                )
        except Exception as e:
            logger.error(f"❌ Rate Limit Check Error: {e}")
            # Let the request continue when the limit check itself fails.
        return None
# =============================================================================
# MIDDLEWARE MANAGER - Singleton to manage all middlewares
# =============================================================================
class MiddlewareManager:
    """
    Middleware Manager - Singleton Pattern.

    Manages and sets up all middlewares for a FastAPI app.

    Usage:
        from common.middleware import middleware_manager

        app = FastAPI()
        middleware_manager.setup(app, enable_auth=True, enable_rate_limit=True)
    """

    # The single shared instance, created on first construction.
    _instance: MiddlewareManager | None = None
    # Guards __init__ so repeated construction does not reset the flags.
    _initialized: bool = False

    def __new__(cls) -> MiddlewareManager:
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        if MiddlewareManager._initialized:
            return

        self._auth_enabled = False
        self._rate_limit_enabled = False
        MiddlewareManager._initialized = True
        logger.info("✅ MiddlewareManager initialized")

    def setup(
        self,
        app: FastAPI,
        *,
        enable_auth: bool = True,
        enable_rate_limit: bool = True,
        enable_cors: bool = True,
        cors_origins: list[str] | None = None,
    ) -> None:
        """
        Set up all middlewares for the FastAPI app.

        Args:
            app: FastAPI application
            enable_auth: Enable the Canifa authentication middleware
            enable_rate_limit: Enable rate limiting
            enable_cors: Enable the CORS middleware
            cors_origins: List of origins for CORS (default: ["*"])

        Note:
            Middleware order matters! A middleware added LATER runs FIRST.
            Intended run order: CORS → Auth → RateLimit → SlowAPI, so the
            registration below goes in the reverse direction. CORS used to be
            registered first, which placed it innermost: responses produced
            directly by the auth middleware (e.g. the 429 rate-limit answer)
            never passed through it.
        """
        # 1. Rate limit (registered first → innermost of the three)
        if enable_rate_limit:
            self._setup_rate_limit(app)

        # 2. Auth middleware
        if enable_auth:
            self._setup_auth(app)

        # 3. CORS middleware (registered last so it runs first)
        if enable_cors:
            self._setup_cors(app, cors_origins or ["*"])

        logger.info(
            f"✅ Middlewares configured: "
            f"CORS={enable_cors}, Auth={enable_auth}, RateLimit={enable_rate_limit}"
        )

    def _setup_cors(self, app: FastAPI, origins: list[str]) -> None:
        """Setup CORS middleware."""
        from fastapi.middleware.cors import CORSMiddleware

        # NOTE(review): credentials are allowed together with whatever
        # origins are passed (default ["*"]) — confirm this is intended
        # outside development.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        logger.info(f"✅ CORS middleware enabled (origins: {origins})")

    def _setup_auth(self, app: FastAPI) -> None:
        """Setup Canifa auth middleware."""
        app.add_middleware(CanifaAuthMiddleware)
        self._auth_enabled = True
        logger.info("✅ Canifa Auth middleware enabled")

    def _setup_rate_limit(self, app: FastAPI) -> None:
        """Setup rate limiting."""
        from common.rate_limit import rate_limit_service

        rate_limit_service.setup(app)
        self._rate_limit_enabled = True
        logger.info("✅ Rate Limit middleware enabled")

    @property
    def is_auth_enabled(self) -> bool:
        """Whether the auth middleware has been registered via ``setup``."""
        return self._auth_enabled

    @property
    def is_rate_limit_enabled(self) -> bool:
        """Whether rate limiting has been registered via ``setup``."""
        return self._rate_limit_enabled
# =============================================================================
# SINGLETON INSTANCE
# =============================================================================
middleware_manager
=
MiddlewareManager
()
"""
Middleware Module - bundles Auth + Rate Limit
Singleton pattern for both services
"""
from
__future__
import
annotations
import
json
import
logging
from
collections.abc
import
Callable
from
typing
import
TYPE_CHECKING
from
fastapi
import
HTTPException
,
Request
,
status
from
starlette.middleware.base
import
BaseHTTPMiddleware
if
TYPE_CHECKING
:
from
fastapi
import
FastAPI
logger
=
logging
.
getLogger
(
__name__
)
# =============================================================================
# CONFIGURATION
# =============================================================================
# Public endpoints - no auth required
PUBLIC_PATHS
=
{
"/"
,
"/health"
,
"/docs"
,
"/openapi.json"
,
"/redoc"
,
}
# Public path prefixes
PUBLIC_PATH_PREFIXES
=
[
"/static"
,
"/mock"
,
"/api/mock"
,
]
# =============================================================================
# AUTH + RATE LIMIT MIDDLEWARE CLASS
# =============================================================================
# Paths that need rate limit check
RATE_LIMITED_PATHS
=
[
"/api/agent/chat"
,
]
class CanifaAuthMiddleware(BaseHTTPMiddleware):
    """
    Canifa authentication + message rate-limit middleware.

    Flow:
    1. The frontend sends a request with ``Authorization: Bearer <canifa_token>``.
    2. The token is verified through ``common.canifa_api`` and a user id is
       extracted from the response.
    3. For paths listed in ``RATE_LIMITED_PATHS`` the message limit is checked
       (author's note: Guest: 10, User: 100 — the numbers live in
       ``common.message_limit``, not in this class).
    4. The resolved identity is attached to ``request.state``
       (``user``, ``user_id``, ``is_authenticated``, ``device_id`` and, for a
       verified token, ``token``).
    5. Routes read it directly from ``request.state``.

    OPTIONS requests, ``PUBLIC_PATHS`` and anything under
    ``PUBLIC_PATH_PREFIXES`` bypass both steps. Authentication never rejects a
    request: every failure degrades to guest mode.
    """

    async def dispatch(self, request: Request, call_next: Callable):
        """Resolve the caller identity, enforce the message limit, then forward.

        Returns the downstream response, or a 429 JSON response when the
        message limit for the caller is exhausted.
        """
        path = request.url.path
        method = request.method

        # OPTIONS is let through for CORS preflight; public paths need no auth.
        if (
            method == "OPTIONS"
            or path in PUBLIC_PATHS
            or any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES)
        ):
            return await call_next(request)

        # STEP 1: AUTHENTICATION (Canifa API)
        dev_bypass = False
        try:
            dev_bypass = await self._resolve_identity(request, method)
        except Exception as e:
            logger.error(f"❌ Middleware Auth Error: {e}")
            self._mark_guest(request, "")

        if dev_bypass:
            # Forwarded OUTSIDE the try above: previously a downstream error in
            # dev mode was logged as an auth error and the request was
            # dispatched a second time.
            return await call_next(request)

        # STEP 2: RATE LIMIT CHECK (only for the paths that need it)
        if path in RATE_LIMITED_PATHS:
            limited_response = await self._enforce_message_limit(request)
            if limited_response is not None:
                return limited_response

        return await call_next(request)

    @staticmethod
    def _mark_guest(request: Request, device_id: str) -> None:
        """Record an unauthenticated (guest) identity on ``request.state``."""
        request.state.user = None
        request.state.user_id = None
        request.state.is_authenticated = False
        request.state.device_id = device_id

    @staticmethod
    async def _read_device_id(request: Request, method: str) -> str:
        """Find the device id: JSON body → ``device_id`` header → client IP."""
        device_id = ""

        if method in ["POST", "PUT", "PATCH"]:
            try:
                body_bytes = await request.body()

                # Replay the consumed body so downstream handlers can read it.
                async def receive_wrapper():
                    return {"type": "http.request", "body": body_bytes}

                request._receive = receive_wrapper

                if body_bytes:
                    try:
                        body_json = json.loads(body_bytes)
                    except (json.JSONDecodeError, UnicodeDecodeError):
                        body_json = None
                    # A JSON array/scalar body has no device_id to offer.
                    if isinstance(body_json, dict):
                        device_id = body_json.get("device_id", "")
            except Exception as e:
                logger.warning(f"Error reading device_id from body: {e}")

        # Fallback: not in the body → look in the header → then the client IP.
        if not device_id:
            device_id = request.headers.get("device_id", "")
        if not device_id:
            device_id = f"unknown_{request.client.host}" if request.client else "unknown"
        return device_id

    async def _resolve_identity(self, request: Request, method: str) -> bool:
        """Populate ``request.state`` with the caller identity.

        Returns True when the ``X-Dev-User-Id`` bypass was used (the caller
        must then skip rate limiting), False otherwise.
        """
        auth_header = request.headers.get("Authorization")
        device_id = await self._read_device_id(request, method)

        # ========== DEV MODE: bypass auth ==========
        # SECURITY NOTE(review): nothing in this class restricts the header to
        # development environments — any client sending it is treated as that
        # user and skips the message limit. Gate it behind a dev-only setting.
        dev_user_id = request.headers.get("X-Dev-User-Id")
        if dev_user_id:
            logger.warning(f"⚠️ DEV MODE: Using X-Dev-User-Id={dev_user_id}")
            request.state.user = {"customer_id": dev_user_id}
            request.state.user_id = dev_user_id
            request.state.is_authenticated = True
            request.state.device_id = device_id or dev_user_id
            return True

        # --- CASE 1: NO TOKEN -> GUEST ---
        if not auth_header or not auth_header.startswith("Bearer "):
            self._mark_guest(request, device_id)
            return False

        # --- CASE 2: TOKEN PRESENT -> VERIFY WITH CANIFA ---
        # Strip only the leading scheme; str.replace would also remove any
        # later "Bearer " occurrence inside the credential.
        token = auth_header[len("Bearer "):]

        from common.canifa_api import verify_canifa_token, extract_user_id_from_canifa_response

        try:
            user_data = await verify_canifa_token(token)
            user_id = await extract_user_id_from_canifa_response(user_data)

            if user_id:
                request.state.user = user_data
                request.state.user_id = user_id
                request.state.token = token
                request.state.is_authenticated = True
                request.state.device_id = device_id
                logger.debug(f"✅ Canifa Auth Success: User {user_id}")
            else:
                logger.warning("⚠️ Invalid Canifa Token -> Guest Mode")
                self._mark_guest(request, device_id)
        except Exception as e:
            logger.error(f"❌ Canifa Auth Error: {e} -> Guest Mode")
            self._mark_guest(request, device_id)
        return False

    @staticmethod
    async def _enforce_message_limit(request: Request):
        """Check the message limit for the resolved identity.

        Returns a 429 ``JSONResponse`` when the limit is exhausted, otherwise
        ``None``. Any error inside the check is logged and the request is
        allowed to continue (deliberate fail-open).
        """
        try:
            from common.message_limit import message_limit_service
            from fastapi.responses import JSONResponse

            # identity_key is the rate-limit key:
            #   guest -> device_id, authenticated user -> user_id
            is_authenticated = request.state.is_authenticated
            if is_authenticated and request.state.user_id:
                rate_limit_key = request.state.user_id
            else:
                rate_limit_key = request.state.device_id

            if not rate_limit_key:
                logger.warning("⚠️ No identity_key for rate limiting")
                return None

            can_send, limit_info = await message_limit_service.check_limit(
                identity_key=rate_limit_key,
                is_authenticated=is_authenticated,
            )

            # Stored on request.state so routes can use it.
            request.state.limit_info = limit_info

            if not can_send:
                logger.warning(
                    f"⚠️ Rate Limit Exceeded: {rate_limit_key} | "
                    f"used={limit_info['used']}/{limit_info['limit']}"
                )
                return JSONResponse(
                    status_code=429,
                    content={
                        "status": "error",
                        "error_code": limit_info.get("error_code") or "MESSAGE_LIMIT_EXCEEDED",
                        "message": limit_info["message"],
                        "require_login": limit_info["require_login"],
                        "limit_info": {
                            "limit": limit_info["limit"],
                            "used": limit_info["used"],
                            "remaining": limit_info["remaining"],
                            "reset_seconds": limit_info["reset_seconds"],
                        },
                    },
                )
        except Exception as e:
            logger.error(f"❌ Rate Limit Check Error: {e}")
            # Let the request continue when the limit check itself fails.
        return None
# =============================================================================
# MIDDLEWARE MANAGER - Singleton to manage all middlewares
# =============================================================================
class MiddlewareManager:
    """
    Middleware manager (singleton).

    Registers every middleware of the FastAPI app from a single place.

    Usage:
        from common.middleware import middleware_manager

        app = FastAPI()
        middleware_manager.setup(app, enable_auth=True, enable_rate_limit=True)
    """

    _instance: MiddlewareManager | None = None
    _initialized: bool = False

    def __new__(cls) -> MiddlewareManager:
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Initialise the enabled-flags once; later constructions are no-ops."""
        if MiddlewareManager._initialized:
            return

        self._auth_enabled = False
        self._rate_limit_enabled = False
        MiddlewareManager._initialized = True
        logger.info("✅ MiddlewareManager initialized")

    def setup(
        self,
        app: FastAPI,
        *,
        enable_auth: bool = True,
        enable_rate_limit: bool = True,
        enable_cors: bool = True,
        cors_origins: list[str] | None = None,
    ) -> None:
        """
        Register all middlewares on the FastAPI app.

        Args:
            app: FastAPI application.
            enable_auth: Enable the Canifa authentication middleware.
            enable_rate_limit: Enable rate limiting.
            enable_cors: Enable the CORS middleware.
            cors_origins: Allowed CORS origins (default: ["*"]).

        Note:
            Order matters: a middleware added LATER runs EARLIER (it wraps
            everything added before it). The intended request order is
            CORS → Auth → RateLimit, so registration happens in the reverse
            order: RateLimit, then Auth, then CORS.
        """
        # 1. Rate limiting: registered first so it is the innermost layer and
        #    runs after auth has had a chance to populate request.state.
        if enable_rate_limit:
            self._setup_rate_limit(app)

        # 2. Auth: wraps rate limiting.
        if enable_auth:
            self._setup_auth(app)

        # 3. CORS: registered last so it is the outermost layer; it then also
        #    decorates responses short-circuited by the auth / rate-limit
        #    layers and answers preflight requests before they reach auth.
        if enable_cors:
            self._setup_cors(app, cors_origins or ["*"])

        logger.info(
            f"✅ Middlewares configured: "
            f"CORS={enable_cors}, Auth={enable_auth}, RateLimit={enable_rate_limit}"
        )

    def _setup_cors(self, app: FastAPI, origins: list[str]) -> None:
        """Register the CORS middleware for the given origins."""
        from fastapi.middleware.cors import CORSMiddleware

        # NOTE(review): allow_credentials=True combined with origins ["*"] is
        # discouraged by the FastAPI CORS documentation — confirm this is
        # intended before production use.
        app.add_middleware(
            CORSMiddleware,
            allow_origins=origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        logger.info(f"✅ CORS middleware enabled (origins: {origins})")

    def _setup_auth(self, app: FastAPI) -> None:
        """Register the Canifa auth middleware."""
        app.add_middleware(CanifaAuthMiddleware)
        self._auth_enabled = True
        logger.info("✅ Canifa Auth middleware enabled")

    def _setup_rate_limit(self, app: FastAPI) -> None:
        """Delegate rate-limit registration to the rate limit service."""
        # Imported lazily, as in the original code.
        from common.rate_limit import rate_limit_service

        rate_limit_service.setup(app)
        self._rate_limit_enabled = True
        logger.info("✅ Rate Limit middleware enabled")

    @property
    def is_auth_enabled(self) -> bool:
        """True once the auth middleware has been registered."""
        return self._auth_enabled

    @property
    def is_rate_limit_enabled(self) -> bool:
        """True once rate limiting has been registered."""
        return self._rate_limit_enabled
# =============================================================================
# SINGLETON INSTANCE
# =============================================================================
# Shared module-level instance; MiddlewareManager.__new__ hands back this same
# object on every later construction.
middleware_manager = MiddlewareManager()
backend/common/rate_limit.py
View file @
49f43a45
"""
Rate Limiting Service - Singleton Pattern
Sử dụng SlowAPI với Redis backend (production) hoặc Memory (dev)
"""
from
__future__
import
annotations
import
logging
import
os
from
datetime
import
datetime
,
timedelta
from
typing
import
TYPE_CHECKING
from
fastapi
import
Request
from
fastapi.responses
import
JSONResponse
from
slowapi
import
Limiter
from
slowapi.errors
import
RateLimitExceeded
from
slowapi.middleware
import
SlowAPIMiddleware
from
slowapi.util
import
get_remote_address
if
TYPE_CHECKING
:
from
fastapi
import
FastAPI
logger
=
logging
.
getLogger
(
__name__
)
class RateLimitService:
    """
    Rate limiting service (singleton).

    Wraps a SlowAPI ``Limiter`` and keeps an in-memory blocklist: a client
    that exceeds a SlowAPI limit is blocked for ``block_duration_minutes``.

    Usage:
        # In server.py
        from common.rate_limit import RateLimitService
        rate_limiter = RateLimitService()
        rate_limiter.setup(app)

        # In a route
        from common.rate_limit import RateLimitService
        @router.post("/chat")
        @RateLimitService().limiter.limit("10/minute")
        async def chat(request: Request):
            ...
    """

    _instance: RateLimitService | None = None
    _initialized: bool = False

    # =========================================================================
    # SINGLETON PATTERN
    # =========================================================================
    def __new__(cls) -> RateLimitService:
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Read configuration and build the limiter; runs only once."""
        if RateLimitService._initialized:
            return

        # Configuration (storage backend and block duration come from env vars)
        self.storage_uri = os.getenv("RATE_STORAGE_URI", "memory://")
        self.default_limits = ["100/hour", "30/minute"]
        self.block_duration_minutes = int(os.getenv("RATE_LIMIT_BLOCK_MINUTES", "5"))

        # Paths that are never rate limited
        self.exempt_paths = {
            "/",
            "/health",
            "/docs",
            "/openapi.json",
            "/redoc",
        }
        self.exempt_prefixes = ["/static", "/mock"]

        # In-memory blocklist: client key -> aware UTC time the block ends
        # (could be moved to Redis)
        self._blocklist: dict[str, datetime] = {}

        self.limiter = Limiter(
            key_func=self._get_client_identifier,
            storage_uri=self.storage_uri,
            default_limits=self.default_limits,
        )

        RateLimitService._initialized = True
        logger.info(f"✅ RateLimitService initialized (storage: {self.storage_uri})")

    # =========================================================================
    # CLIENT IDENTIFIER
    # =========================================================================
    @staticmethod
    def _get_client_identifier(request: Request) -> str:
        """
        Build the rate-limit key for a request.

        Priority: user_id (authenticated) > device_id header > IP address.
        """
        # 1. Authenticated: request.state.user_id is set and truthy
        if hasattr(request.state, "user_id") and request.state.user_id:
            return f"user:{request.state.user_id}"

        # 2. A device_id header is present
        device_id = request.headers.get("device_id")
        if device_id:
            return f"device:{device_id}"

        # 3. Fallback: IP address
        try:
            return f"ip:{get_remote_address(request)}"
        except Exception:
            if request.client:
                return f"ip:{request.client.host}"
            return "unknown"

    # =========================================================================
    # BLOCKLIST MANAGEMENT
    # =========================================================================
    @staticmethod
    def _now() -> datetime:
        """Return the current time as an aware UTC datetime."""
        # Local import: only `datetime` and `timedelta` are imported at file level.
        from datetime import timezone

        return datetime.now(timezone.utc)

    def is_blocked(self, key: str) -> tuple[bool, int]:
        """
        Check whether a client is currently blocked.

        Returns:
            (is_blocked, retry_after_seconds). While blocked,
            retry_after_seconds is at least 1 so a blocked client is never
            told to retry immediately.
        """
        blocked_until = self._blocklist.get(key)
        if blocked_until is None:
            return False, 0

        remaining = (blocked_until - self._now()).total_seconds()
        if remaining > 0:
            return True, max(1, int(remaining))

        # Block expired
        self._blocklist.pop(key, None)
        return False, 0

    def block_client(self, key: str) -> int:
        """
        Block a client for the configured duration.

        Returns:
            retry_after_seconds
        """
        now = self._now()
        # Drop expired entries first so keys that never come back cannot
        # accumulate in memory.
        for stale_key in [k for k, until in self._blocklist.items() if until <= now]:
            del self._blocklist[stale_key]

        self._blocklist[key] = now + timedelta(minutes=self.block_duration_minutes)
        return self.block_duration_minutes * 60

    def unblock_client(self, key: str) -> None:
        """Unblock a client manually."""
        self._blocklist.pop(key, None)

    # =========================================================================
    # PATH CHECKING
    # =========================================================================
    def is_exempt(self, path: str) -> bool:
        """Return True when the path is exempt from rate limiting."""
        return path in self.exempt_paths or path.startswith(tuple(self.exempt_prefixes))

    # =========================================================================
    # SETUP FOR FASTAPI APP
    # =========================================================================
    def setup(self, app: FastAPI) -> None:
        """
        Install rate limiting on the FastAPI app.

        Call from server.py after the app has been created.
        """
        # Attach limiter to app state
        app.state.limiter = self.limiter
        app.state.rate_limit_service = self

        # Register middleware
        self._register_block_middleware(app)
        self._register_exception_handler(app)

        # SlowAPI middleware MUST be added AFTER the custom middlewares
        app.add_middleware(SlowAPIMiddleware)

        logger.info("✅ Rate limiting middleware registered")

    @staticmethod
    def _too_many_requests(retry_after: int) -> JSONResponse:
        """Build the 429 response shared by the blocklist check and the limit handler."""
        return JSONResponse(
            status_code=429,
            content={
                "detail": "Quá số lượt cho phép. Vui lòng thử lại sau.",
                "retry_after_seconds": retry_after,
            },
            headers={"Retry-After": str(retry_after)},
        )

    def _register_block_middleware(self, app: FastAPI) -> None:
        """Register the HTTP middleware that rejects blocklisted clients."""

        @app.middleware("http")
        async def rate_limit_block_middleware(request: Request, call_next):
            # Skip exempt paths
            if self.is_exempt(request.url.path):
                return await call_next(request)

            # Bypass header for testing.
            # SECURITY NOTE(review): any client can send this header and skip
            # the blocklist check — confirm it is stripped or disabled in
            # production.
            if request.headers.get("X-Bypass-RateLimit") == "1":
                return await call_next(request)

            # Check blocklist
            key = self._get_client_identifier(request)
            blocked, retry_after = self.is_blocked(key)
            if blocked:
                return self._too_many_requests(retry_after)

            return await call_next(request)

    def _register_exception_handler(self, app: FastAPI) -> None:
        """Register the handler that blocks a client once a SlowAPI limit is exceeded."""

        @app.exception_handler(RateLimitExceeded)
        async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
            key = self._get_client_identifier(request)
            retry_after = self.block_client(key)

            logger.warning(
                f"⚠️ Rate limit exceeded for {key}, blocked for {self.block_duration_minutes} minutes"
            )
            return self._too_many_requests(retry_after)
# =============================================================================
# SINGLETON INSTANCE - import this directly to use the service
# =============================================================================
rate_limit_service = RateLimitService()
"""
Rate Limiting Service - Singleton Pattern
Sử dụng SlowAPI với Redis backend (production) hoặc Memory (dev)
"""
from
__future__
import
annotations
import
logging
import
os
from
datetime
import
datetime
,
timedelta
from
typing
import
TYPE_CHECKING
from
fastapi
import
Request
from
fastapi.responses
import
JSONResponse
from
slowapi
import
Limiter
from
slowapi.errors
import
RateLimitExceeded
from
slowapi.middleware
import
SlowAPIMiddleware
from
slowapi.util
import
get_remote_address
if
TYPE_CHECKING
:
from
fastapi
import
FastAPI
logger
=
logging
.
getLogger
(
__name__
)
class RateLimitService:
    """
    Rate limiting service (singleton).

    Wraps a SlowAPI ``Limiter`` and keeps an in-memory blocklist: a client
    that exceeds a SlowAPI limit is blocked for ``block_duration_minutes``.

    Usage:
        # In server.py
        from common.rate_limit import RateLimitService
        rate_limiter = RateLimitService()
        rate_limiter.setup(app)

        # In a route
        from common.rate_limit import RateLimitService
        @router.post("/chat")
        @RateLimitService().limiter.limit("10/minute")
        async def chat(request: Request):
            ...
    """

    _instance: RateLimitService | None = None
    _initialized: bool = False

    # =========================================================================
    # SINGLETON PATTERN
    # =========================================================================
    def __new__(cls) -> RateLimitService:
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        """Read configuration and build the limiter; runs only once."""
        if RateLimitService._initialized:
            return

        # Configuration (storage backend and block duration come from env vars)
        self.storage_uri = os.getenv("RATE_STORAGE_URI", "memory://")
        self.default_limits = ["100/hour", "30/minute"]
        self.block_duration_minutes = int(os.getenv("RATE_LIMIT_BLOCK_MINUTES", "5"))

        # Paths that are never rate limited
        self.exempt_paths = {
            "/",
            "/health",
            "/docs",
            "/openapi.json",
            "/redoc",
        }
        self.exempt_prefixes = ["/static", "/mock", "/api/mock"]

        # In-memory blocklist: client key -> aware UTC time the block ends
        # (could be moved to Redis)
        self._blocklist: dict[str, datetime] = {}

        self.limiter = Limiter(
            key_func=self._get_client_identifier,
            storage_uri=self.storage_uri,
            default_limits=self.default_limits,
        )

        RateLimitService._initialized = True
        logger.info(f"✅ RateLimitService initialized (storage: {self.storage_uri})")

    # =========================================================================
    # CLIENT IDENTIFIER
    # =========================================================================
    @staticmethod
    def _get_client_identifier(request: Request) -> str:
        """
        Build the rate-limit key for a request.

        Priority: user_id (authenticated) > device_id header > IP address.
        """
        # 1. Authenticated: request.state.user_id is set and truthy
        if hasattr(request.state, "user_id") and request.state.user_id:
            return f"user:{request.state.user_id}"

        # 2. A device_id header is present
        device_id = request.headers.get("device_id")
        if device_id:
            return f"device:{device_id}"

        # 3. Fallback: IP address
        try:
            return f"ip:{get_remote_address(request)}"
        except Exception:
            if request.client:
                return f"ip:{request.client.host}"
            return "unknown"

    # =========================================================================
    # BLOCKLIST MANAGEMENT
    # =========================================================================
    @staticmethod
    def _now() -> datetime:
        """Return the current time as an aware UTC datetime."""
        # Local import: only `datetime` and `timedelta` are imported at file level.
        from datetime import timezone

        return datetime.now(timezone.utc)

    def is_blocked(self, key: str) -> tuple[bool, int]:
        """
        Check whether a client is currently blocked.

        Returns:
            (is_blocked, retry_after_seconds). While blocked,
            retry_after_seconds is at least 1 so a blocked client is never
            told to retry immediately.
        """
        blocked_until = self._blocklist.get(key)
        if blocked_until is None:
            return False, 0

        remaining = (blocked_until - self._now()).total_seconds()
        if remaining > 0:
            return True, max(1, int(remaining))

        # Block expired
        self._blocklist.pop(key, None)
        return False, 0

    def block_client(self, key: str) -> int:
        """
        Block a client for the configured duration.

        Returns:
            retry_after_seconds
        """
        now = self._now()
        # Drop expired entries first so keys that never come back cannot
        # accumulate in memory.
        for stale_key in [k for k, until in self._blocklist.items() if until <= now]:
            del self._blocklist[stale_key]

        self._blocklist[key] = now + timedelta(minutes=self.block_duration_minutes)
        return self.block_duration_minutes * 60

    def unblock_client(self, key: str) -> None:
        """Unblock a client manually."""
        self._blocklist.pop(key, None)

    # =========================================================================
    # PATH CHECKING
    # =========================================================================
    def is_exempt(self, path: str) -> bool:
        """Return True when the path is exempt from rate limiting."""
        return path in self.exempt_paths or path.startswith(tuple(self.exempt_prefixes))

    # =========================================================================
    # SETUP FOR FASTAPI APP
    # =========================================================================
    def setup(self, app: FastAPI) -> None:
        """
        Install rate limiting on the FastAPI app.

        Call from server.py after the app has been created.
        """
        # Attach limiter to app state
        app.state.limiter = self.limiter
        app.state.rate_limit_service = self

        # Register middleware
        self._register_block_middleware(app)
        self._register_exception_handler(app)

        # SlowAPI middleware MUST be added AFTER the custom middlewares
        app.add_middleware(SlowAPIMiddleware)

        logger.info("✅ Rate limiting middleware registered")

    @staticmethod
    def _too_many_requests(retry_after: int) -> JSONResponse:
        """Build the 429 response shared by the blocklist check and the limit handler."""
        return JSONResponse(
            status_code=429,
            content={
                "detail": "Quá số lượt cho phép. Vui lòng thử lại sau.",
                "retry_after_seconds": retry_after,
            },
            headers={"Retry-After": str(retry_after)},
        )

    def _register_block_middleware(self, app: FastAPI) -> None:
        """Register the HTTP middleware that rejects blocklisted clients."""

        @app.middleware("http")
        async def rate_limit_block_middleware(request: Request, call_next):
            # Skip exempt paths
            if self.is_exempt(request.url.path):
                return await call_next(request)

            # Bypass header for testing.
            # SECURITY NOTE(review): any client can send this header and skip
            # the blocklist check — confirm it is stripped or disabled in
            # production.
            if request.headers.get("X-Bypass-RateLimit") == "1":
                return await call_next(request)

            # Check blocklist
            key = self._get_client_identifier(request)
            blocked, retry_after = self.is_blocked(key)
            if blocked:
                return self._too_many_requests(retry_after)

            return await call_next(request)

    def _register_exception_handler(self, app: FastAPI) -> None:
        """Register the handler that blocks a client once a SlowAPI limit is exceeded."""

        @app.exception_handler(RateLimitExceeded)
        async def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded):
            key = self._get_client_identifier(request)
            retry_after = self.block_client(key)

            logger.warning(
                f"⚠️ Rate limit exceeded for {key}, blocked for {self.block_duration_minutes} minutes"
            )
            return self._too_many_requests(retry_after)
# =============================================================================
# SINGLETON INSTANCE - import this directly to use the service
# =============================================================================
rate_limit_service = RateLimitService()
backend/server.py
View file @
49f43a45
import
asyncio
import
os
import
platform
# On Windows, switch asyncio to the selector event loop policy.
# NOTE(review): presumably required by an async dependency that does not work
# with the default Windows event loop — confirm which one.
if platform.system() == "Windows":
    print("🔧 Windows detected: Applying SelectorEventLoopPolicy globally...")
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import
logging
import
uvicorn
from
fastapi
import
FastAPI
from
fastapi.staticfiles
import
StaticFiles
from
api.chatbot_route
import
router
as
chatbot_router
from
api.conservation_route
import
router
as
conservation_router
from
api.prompt_route
import
router
as
prompt_router
from
common.cache
import
redis_cache
from
common.langfuse_client
import
get_langfuse_client
from
common.middleware
import
middleware_manager
from
config
import
PORT
# Configure Logging
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
"
%(asctime)
s [
%(levelname)
s]
%(name)
s:
%(message)
s"
,
handlers
=
[
logging
.
StreamHandler
()],
)
logger = logging.getLogger(__name__)

# Resolve the Langfuse client at import time and report whether it is usable.
langfuse_client = get_langfuse_client()
if not langfuse_client:
    logger.warning("⚠️ Langfuse client not available (missing keys or disabled)")
else:
    logger.info("✅ Langfuse client ready (lazy loading)")

# The ASGI application object; middlewares and routers are attached below.
app = FastAPI(title="Contract AI Service", description="API for Contract AI Service", version="1.0.0")
# =============================================================================
# STARTUP EVENT - Initialize Redis Cache
# =============================================================================
@app.on_event("startup")
async def startup_event() -> None:
    """Initialize the Redis cache when the application starts."""
    await redis_cache.initialize()
    # The log text suggests the cache backs the message-limit feature — TODO confirm.
    logger.info("✅ Redis cache initialized for message limit")
# =============================================================================
# MIDDLEWARE SETUP - Auth, rate limiting and CORS are configured in one place
# =============================================================================
middleware_manager.setup(
    app,
    enable_auth=True,  # 👈 Auth re-enabled to exercise the guest/user logic
    enable_rate_limit=True,  # 👈 SlowAPI re-enabled as requested
    enable_cors=True,  # 👈 CORS enabled
    cors_origins=["*"],  # 👈 In production the allowed origins should be restricted
)

# Register the API routers.
app.include_router(conservation_router)
app.include_router(chatbot_router)
app.include_router(prompt_router)
# --- MOCK API FOR LOAD TESTING ---
try
:
from
api.mock_api_route
import
router
as
mock_router
import
asyncio
import
os
import
platform
# On Windows, switch asyncio to the selector event loop policy.
# NOTE(review): presumably required by an async dependency that does not work
# with the default Windows event loop — confirm which one.
if platform.system() == "Windows":
    print("🔧 Windows detected: Applying SelectorEventLoopPolicy globally...")
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import
logging
import
uvicorn
from
fastapi
import
FastAPI
from
fastapi.staticfiles
import
StaticFiles
from
api.chatbot_route
import
router
as
chatbot_router
from
api.conservation_route
import
router
as
conservation_router
from
api.prompt_route
import
router
as
prompt_router
from
common.cache
import
redis_cache
from
common.langfuse_client
import
get_langfuse_client
from
common.middleware
import
middleware_manager
from
config
import
PORT
# Configure Logging
logging
.
basicConfig
(
level
=
logging
.
INFO
,
format
=
"
%(asctime)
s [
%(levelname)
s]
%(name)
s:
%(message)
s"
,
handlers
=
[
logging
.
StreamHandler
()],
)
logger = logging.getLogger(__name__)

# Resolve the Langfuse client at import time and report whether it is usable.
langfuse_client = get_langfuse_client()
if not langfuse_client:
    logger.warning("⚠️ Langfuse client not available (missing keys or disabled)")
else:
    logger.info("✅ Langfuse client ready (lazy loading)")

# The ASGI application object; middlewares and routers are attached below.
app = FastAPI(title="Contract AI Service", description="API for Contract AI Service", version="1.0.0")
# =============================================================================
# STARTUP EVENT - Initialize Redis Cache
# =============================================================================
@app.on_event("startup")
async def startup_event() -> None:
    """Initialize the Redis cache when the application starts."""
    await redis_cache.initialize()
    # The log text suggests the cache backs the message-limit feature — TODO confirm.
    logger.info("✅ Redis cache initialized for message limit")
# =============================================================================
# MIDDLEWARE SETUP - Auth, rate limiting and CORS are configured in one place
# =============================================================================
middleware_manager.setup(
    app,
    enable_auth=True,  # 👈 Auth re-enabled to exercise the guest/user logic
    enable_rate_limit=True,  # 👈 SlowAPI re-enabled as requested
    enable_cors=True,  # 👈 CORS enabled
    cors_origins=["*"],  # 👈 In production the allowed origins should be restricted
)

# Register the API routers.
app.include_router(conservation_router)
app.include_router(chatbot_router)
app.include_router(prompt_router)
# --- MOCK API FOR LOAD TESTING ---
try
:
from
api.mock_api_route
import
router
as
mock_router
app
.
include_router
(
mock_router
)
print
(
"✅ Mock API Router mounted at /mock"
)
except
ImportError
:
print
(
"⚠️ Mock Router not found, skipping..."
)
# ==========================================
# 🟢 STATIC HTML MOUNT 🟢
# ==========================================
try:
    static_dir = os.path.join(os.path.dirname(__file__), "static")
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(static_dir, exist_ok=True)

    # Mount the static directory so index.html can be served
    app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
    print(f"✅ Static files mounted at /static (Dir: {static_dir})")
except Exception as e:
    # Best effort: the API keeps working even when the static mount fails.
    print(f"⚠️ Failed to mount static files: {e}")
from
fastapi.responses
import
RedirectResponse
@app.get("/")
async def root():
    # Redirect the bare root URL to the index page served from /static.
    return RedirectResponse(url="/static/index.html")
if __name__ == "__main__":
    # Startup banner listing the local entry points.
    banner_rule = "=" * 60
    print(
        "\n".join(
            (
                banner_rule,
                "🚀 Contract AI Service Starting...",
                banner_rule,
                f"📡 REST API: http://localhost:{PORT}",
                f"📡 Test Chatbot: http://localhost:{PORT}/static/index.html",
                f"📚 API Docs: http://localhost:{PORT}/docs",
                banner_rule,
            )
        )
    )

    ENABLE_RELOAD = False
    print(f"⚠️ Hot reload: {ENABLE_RELOAD}")

    # Directories watched by uvicorn when hot reload is switched on.
    reload_dirs = ["common", "api", "agent"]

    if ENABLE_RELOAD:
        os.environ["PYTHONUNBUFFERED"] = "1"

    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=PORT,
        reload=ENABLE_RELOAD,
        reload_dirs=reload_dirs,
        log_level="info",
    )
print
(
"✅ Mock API Router mounted at /
api/
mock"
)
except
ImportError
:
print
(
"⚠️ Mock Router not found, skipping..."
)
# ==========================================
# 🟢 STATIC HTML MOUNT 🟢
# ==========================================
try:
    static_dir = os.path.join(os.path.dirname(__file__), "static")
    # exist_ok avoids the check-then-create race of a separate os.path.exists test.
    os.makedirs(static_dir, exist_ok=True)

    # Mount the static directory so index.html can be served
    app.mount("/static", StaticFiles(directory=static_dir, html=True), name="static")
    print(f"✅ Static files mounted at /static (Dir: {static_dir})")
except Exception as e:
    # Best effort: the API keeps working even when the static mount fails.
    print(f"⚠️ Failed to mount static files: {e}")
from
fastapi.responses
import
RedirectResponse
@app.get("/")
async def root():
    # Redirect the bare root URL to the index page served from /static.
    return RedirectResponse(url="/static/index.html")
if __name__ == "__main__":
    # Startup banner listing the local entry points.
    banner_rule = "=" * 60
    print(
        "\n".join(
            (
                banner_rule,
                "🚀 Contract AI Service Starting...",
                banner_rule,
                f"📡 REST API: http://localhost:{PORT}",
                f"📡 Test Chatbot: http://localhost:{PORT}/static/index.html",
                f"📚 API Docs: http://localhost:{PORT}/docs",
                banner_rule,
            )
        )
    )

    ENABLE_RELOAD = False
    print(f"⚠️ Hot reload: {ENABLE_RELOAD}")

    # Directories watched by uvicorn when hot reload is switched on.
    reload_dirs = ["common", "api", "agent"]

    if ENABLE_RELOAD:
        os.environ["PYTHONUNBUFFERED"] = "1"

    uvicorn.run(
        "server:app",
        host="0.0.0.0",
        port=PORT,
        reload=ENABLE_RELOAD,
        reload_dirs=reload_dirs,
        log_level="info",
    )
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment