Commit 26d10212 authored by Steven's avatar Steven

refactor: consolidate duplicated auth logic into auth package

Add ApplyToContext and AuthenticateToUser helpers to the auth package,
then remove the duplicated auth code spread across the MCP middleware,
file server, Connect interceptor, and gRPC-Gateway middleware.

- auth.ApplyToContext: single place to set claims/user into context after Authenticate()
- auth.AuthenticateToUser: resolves any credential (bearer token or refresh cookie) to a *store.User
- MCP middleware: replaced manual PAT DB lookup + expiry check with Authenticator.AuthenticateByPAT
- File server: replaced authenticateByBearerToken/authenticateByRefreshToken with AuthenticateToUser
- Connect interceptor + Gateway middleware: replaced duplicated context-setting block with ApplyToContext
- MCPService now accepts secret to construct its own Authenticator
parent 47d94147
...@@ -130,6 +130,40 @@ type AuthResult struct { ...@@ -130,6 +130,40 @@ type AuthResult struct {
AccessToken string // Non-empty if authenticated via JWT AccessToken string // Non-empty if authenticated via JWT
} }
// AuthenticateToUser resolves the current request to a *store.User, checking the
// Authorization header first (access token or PAT), then falling back to the
// refresh token cookie. Returns (nil, nil) when no credentials are present.
func (a *Authenticator) AuthenticateToUser(ctx context.Context, authHeader, cookieHeader string) (*store.User, error) {
// Try Bearer token first.
if authHeader != "" {
token := ExtractBearerToken(authHeader)
if token != "" {
if !strings.HasPrefix(token, PersonalAccessTokenPrefix) {
claims, err := a.AuthenticateByAccessTokenV2(token)
if err == nil && claims != nil {
return a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
}
} else {
user, _, err := a.AuthenticateByPAT(ctx, token)
if err == nil {
return user, nil
}
}
}
}
// Fallback: refresh token cookie.
if cookieHeader != "" {
refreshToken := ExtractRefreshTokenFromCookie(cookieHeader)
if refreshToken != "" {
user, _, err := a.AuthenticateByRefreshToken(ctx, refreshToken)
return user, err
}
}
return nil, nil
}
// Authenticate tries to authenticate using the provided credentials. // Authenticate tries to authenticate using the provided credentials.
// Priority: 1. Access Token V2, 2. PAT // Priority: 1. Access Token V2, 2. PAT
// Returns nil if no valid credentials are provided. // Returns nil if no valid credentials are provided.
......
...@@ -81,3 +81,19 @@ func GetUserClaims(ctx context.Context) *UserClaims { ...@@ -81,3 +81,19 @@ func GetUserClaims(ctx context.Context) *UserClaims {
func SetUserClaimsInContext(ctx context.Context, claims *UserClaims) context.Context { func SetUserClaimsInContext(ctx context.Context, claims *UserClaims) context.Context {
return context.WithValue(ctx, UserClaimsContextKey, claims) return context.WithValue(ctx, UserClaimsContextKey, claims)
} }
// ApplyToContext sets the authenticated identity from an AuthResult into the context.
// This is the canonical way to propagate auth state after a successful Authenticate call.
// Safe to call with a nil result (no-op).
func ApplyToContext(ctx context.Context, result *AuthResult) context.Context {
if result == nil {
return ctx
}
if result.Claims != nil {
ctx = SetUserClaimsInContext(ctx, result.Claims)
ctx = context.WithValue(ctx, UserIDContextKey, result.Claims.UserID)
} else if result.User != nil {
ctx = SetUserInContext(ctx, result.User, result.AccessToken)
}
return ctx
}
...@@ -222,17 +222,7 @@ func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { ...@@ -222,17 +222,7 @@ func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required")) return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required"))
} }
// Set context based on auth result ctx = auth.ApplyToContext(ctx, result)
if result != nil {
if result.Claims != nil {
// Access Token V2 - stateless, use claims
ctx = auth.SetUserClaimsInContext(ctx, result.Claims)
ctx = context.WithValue(ctx, auth.UserIDContextKey, result.Claims.UserID)
} else if result.User != nil {
// PAT - have full user
ctx = auth.SetUserInContext(ctx, result.User, result.AccessToken)
}
}
return next(ctx, req) return next(ctx, req)
} }
......
...@@ -73,16 +73,9 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech ...@@ -73,16 +73,9 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
return return
} }
// Set context based on auth result (may be nil for public endpoints) // Apply auth result to context (no-op when result is nil for public endpoints)
if result != nil { if result != nil {
if result.Claims != nil { ctx = auth.ApplyToContext(ctx, result)
// Access Token V2 - stateless, use claims
ctx = auth.SetUserClaimsInContext(ctx, result.Claims)
ctx = context.WithValue(ctx, auth.UserIDContextKey, result.Claims.UserID)
} else if result.User != nil {
// PAT - have full user
ctx = auth.SetUserInContext(ctx, result.User, result.AccessToken)
}
r = r.WithContext(ctx) r = r.WithContext(ctx)
} }
......
...@@ -515,58 +515,9 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c *ec ...@@ -515,58 +515,9 @@ func (s *FileServerService) checkAttachmentPermission(ctx context.Context, c *ec
// getCurrentUser retrieves the current authenticated user from the request. // getCurrentUser retrieves the current authenticated user from the request.
// Authentication priority: Bearer token (Access Token V2 or PAT) > Refresh token cookie. // Authentication priority: Bearer token (Access Token V2 or PAT) > Refresh token cookie.
func (s *FileServerService) getCurrentUser(ctx context.Context, c *echo.Context) (*store.User, error) { func (s *FileServerService) getCurrentUser(ctx context.Context, c *echo.Context) (*store.User, error) {
// Try Bearer token authentication. authHeader := c.Request().Header.Get(echo.HeaderAuthorization)
if authHeader := c.Request().Header.Get(echo.HeaderAuthorization); authHeader != "" { cookieHeader := c.Request().Header.Get("Cookie")
if user, err := s.authenticateByBearerToken(ctx, authHeader); err == nil && user != nil { return s.authenticator.AuthenticateToUser(ctx, authHeader, cookieHeader)
return user, nil
}
}
// Fallback: Try refresh token cookie.
if cookieHeader := c.Request().Header.Get("Cookie"); cookieHeader != "" {
if user, err := s.authenticateByRefreshToken(ctx, cookieHeader); err == nil && user != nil {
return user, nil
}
}
return nil, nil
}
// authenticateByBearerToken authenticates using Authorization header.
func (s *FileServerService) authenticateByBearerToken(ctx context.Context, authHeader string) (*store.User, error) {
token := auth.ExtractBearerToken(authHeader)
if token == "" {
return nil, nil
}
// Try Access Token V2 (stateless JWT).
if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
claims, err := s.authenticator.AuthenticateByAccessTokenV2(token)
if err == nil && claims != nil {
return s.Store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
}
}
// Try Personal Access Token (stateful).
if strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) {
user, _, err := s.authenticator.AuthenticateByPAT(ctx, token)
if err == nil {
return user, nil
}
}
return nil, nil
}
// authenticateByRefreshToken authenticates using refresh token cookie.
func (s *FileServerService) authenticateByRefreshToken(ctx context.Context, cookieHeader string) (*store.User, error) {
refreshToken := auth.ExtractRefreshTokenFromCookie(cookieHeader)
if refreshToken == "" {
return nil, nil
}
user, _, err := s.authenticator.AuthenticateByRefreshToken(ctx, refreshToken)
return user, err
} }
// getUserByIdentifier finds a user by either ID or username. // getUserByIdentifier finds a user by either ID or username.
......
...@@ -2,8 +2,6 @@ package mcp ...@@ -2,8 +2,6 @@ package mcp
import ( import (
"net/http" "net/http"
"strings"
"time"
"github.com/labstack/echo/v5" "github.com/labstack/echo/v5"
...@@ -11,23 +9,21 @@ import ( ...@@ -11,23 +9,21 @@ import (
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func newAuthMiddleware(s *store.Store) echo.MiddlewareFunc { func newAuthMiddleware(s *store.Store, secret string) echo.MiddlewareFunc {
authenticator := auth.NewAuthenticator(s, secret)
return func(next echo.HandlerFunc) echo.HandlerFunc { return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error { return func(c *echo.Context) error {
token := auth.ExtractBearerToken(c.Request().Header.Get("Authorization")) token := auth.ExtractBearerToken(c.Request().Header.Get("Authorization"))
if !strings.HasPrefix(token, auth.PersonalAccessTokenPrefix) { if token == "" {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "a personal access token is required"}) return c.JSON(http.StatusUnauthorized, map[string]string{"message": "a personal access token is required"})
} }
result, err := s.GetUserByPATHash(c.Request().Context(), auth.HashPersonalAccessToken(token)) user, pat, err := authenticator.AuthenticateByPAT(c.Request().Context(), token)
if err != nil || result == nil { if err != nil || user == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"})
}
if result.PAT.ExpiresAt != nil && result.PAT.ExpiresAt.AsTime().Before(time.Now()) {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"}) return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"})
} }
ctx := auth.SetUserInContext(c.Request().Context(), result.User, result.PAT.GetTokenId()) ctx := auth.SetUserInContext(c.Request().Context(), user, pat.GetTokenId())
c.SetRequest(c.Request().WithContext(ctx)) c.SetRequest(c.Request().WithContext(ctx))
return next(c) return next(c)
} }
......
...@@ -9,11 +9,12 @@ import ( ...@@ -9,11 +9,12 @@ import (
) )
type MCPService struct { type MCPService struct {
store *store.Store store *store.Store
secret string
} }
func NewMCPService(store *store.Store) *MCPService { func NewMCPService(store *store.Store, secret string) *MCPService {
return &MCPService{store: store} return &MCPService{store: store, secret: secret}
} }
func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
...@@ -26,6 +27,6 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { ...@@ -26,6 +27,6 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{ mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"}, AllowOrigins: []string{"*"},
})) }))
mcpGroup.Use(newAuthMiddleware(s.store)) mcpGroup.Use(newAuthMiddleware(s.store, s.secret))
mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler)) mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler))
} }
...@@ -81,7 +81,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store ...@@ -81,7 +81,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
} }
// Register MCP server. // Register MCP server.
mcpService := mcprouter.NewMCPService(s.Store) mcpService := mcprouter.NewMCPService(s.Store, s.Secret)
mcpService.RegisterRoutes(echoServer) mcpService.RegisterRoutes(echoServer)
return s, nil return s, nil
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment