Commit 09afa579 authored by Johnny's avatar Johnny

chore: implement session sliding expiration and JWT authentication

- Added UpdateSessionLastAccessed method to update session access time.
- Enhanced Authenticate method to support both session cookie and JWT token authentication.
- Introduced AuthResult struct to encapsulate authentication results.
- Added SetUserInContext function to simplify context management for authenticated users.

refactor(auth): streamline gRPC and HTTP authentication

- Removed gRPC authentication interceptor and replaced it with a unified approach using GatewayAuthMiddleware for HTTP requests.
- Updated Connect interceptors to utilize the new authentication logic.
- Consolidated public and admin-only method checks into service layer for better maintainability.

chore(api): clean up unused code and improve documentation

- Removed deprecated logger interceptor and unused gRPC server code.
- Updated ACL configuration documentation for clarity on public and admin-only methods.
- Enhanced metadata handling in Connect RPC to ensure consistent header access.

fix(server): simplify server startup and shutdown process

- Eliminated cmux dependency for handling HTTP and gRPC traffic.
- Streamlined server initialization and shutdown logic for better performance and readability.
parent 65a19df4
...@@ -13,7 +13,6 @@ require ( ...@@ -13,7 +13,6 @@ require (
github.com/google/cel-go v0.26.1 github.com/google/cel-go v0.26.1
github.com/google/uuid v1.6.0 github.com/google/uuid v1.6.0
github.com/gorilla/feeds v1.2.0 github.com/gorilla/feeds v1.2.0
github.com/grpc-ecosystem/go-grpc-middleware v1.4.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/labstack/echo/v4 v4.13.4 github.com/labstack/echo/v4 v4.13.4
...@@ -85,7 +84,6 @@ require ( ...@@ -85,7 +84,6 @@ require (
github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/soheilhy/cmux v0.1.5
github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect
github.com/valyala/fasttemplate v1.2.2 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect
golang.org/x/sys v0.36.0 // indirect golang.org/x/sys v0.36.0 // indirect
......
This diff is collapsed.
...@@ -195,3 +195,46 @@ func validateAccessToken(token string, tokens []*storepb.AccessTokensUserSetting ...@@ -195,3 +195,46 @@ func validateAccessToken(token string, tokens []*storepb.AccessTokensUserSetting
} }
return false return false
} }
// UpdateSessionLastAccessed updates the last accessed time for a session.
// This implements sliding expiration - sessions remain valid as long as they're used.
// Should be called after successful session-based authentication.
func (a *Authenticator) UpdateSessionLastAccessed(ctx context.Context, userID int32, sessionID string) {
// Fire-and-forget update; failures are logged but don't block the request
_ = a.store.UpdateUserSessionLastAccessed(ctx, userID, sessionID, timestamppb.Now())
}
// AuthResult contains the result of an authentication attempt.
type AuthResult struct {
User *store.User
SessionID string // Non-empty if authenticated via session cookie
AccessToken string // Non-empty if authenticated via JWT
}
// Authenticate tries to authenticate using the provided credentials.
// It tries session cookie first, then JWT token.
// Returns nil if no valid credentials are provided.
// On successful session auth, it also updates the session sliding expiration.
func (a *Authenticator) Authenticate(ctx context.Context, sessionCookie, authHeader string) *AuthResult {
// Try session cookie authentication first
if sessionCookie != "" {
user, err := a.AuthenticateBySession(ctx, sessionCookie)
if err == nil && user != nil {
_, sessionID, parseErr := ParseSessionCookieValue(sessionCookie)
if parseErr == nil && sessionID != "" {
a.UpdateSessionLastAccessed(ctx, user.ID, sessionID)
}
return &AuthResult{User: user, SessionID: sessionID}
}
}
// Try JWT token authentication
if token := ExtractBearerToken(authHeader); token != "" {
user, err := a.AuthenticateByJWT(ctx, token)
if err == nil && user != nil {
return &AuthResult{User: user, AccessToken: token}
}
}
return nil
}
package auth package auth
import "context" import (
"context"
"github.com/usememos/memos/store"
)
// ContextKey is the key type for context values. // ContextKey is the key type for context values.
// Using a custom type prevents collisions with other packages. // Using a custom type prevents collisions with other packages.
...@@ -47,3 +51,22 @@ func GetAccessToken(ctx context.Context) string { ...@@ -47,3 +51,22 @@ func GetAccessToken(ctx context.Context) string {
} }
return "" return ""
} }
// SetUserInContext sets the authenticated user's information in the context.
// This is a simpler alternative to AuthorizeAndSetContext for cases where
// authorization is handled separately (e.g., HTTP middleware).
//
// Parameters:
// - user: The authenticated user
// - sessionID: Set if authenticated via session cookie (empty string otherwise)
// - accessToken: Set if authenticated via JWT token (empty string otherwise)
func SetUserInContext(ctx context.Context, user *store.User, sessionID, accessToken string) context.Context {
ctx = context.WithValue(ctx, UserIDContextKey, user.ID)
if sessionID != "" {
ctx = context.WithValue(ctx, SessionIDContextKey, sessionID)
}
if accessToken != "" {
ctx = context.WithValue(ctx, AccessTokenContextKey, accessToken)
}
return ctx
}
package v1
// gRPC Authentication Interceptor
//
// This file implements the authentication interceptor for gRPC requests.
// It extracts credentials from gRPC metadata and delegates to the shared Authenticator.
//
// Authentication flow:
// 1. Extract session cookie or bearer token from metadata
// 2. Validate credentials using Authenticator
// 3. Check authorization (admin-only methods)
// 4. Set user context and proceed with request
//
// For public methods (defined in acl_config.go), authentication is skipped.
import (
"context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// GRPCAuthInterceptor is the authentication interceptor for gRPC server.
// It validates incoming requests and sets user context for authenticated requests.
type GRPCAuthInterceptor struct {
authenticator *auth.Authenticator
}
// NewGRPCAuthInterceptor creates a new gRPC authentication interceptor.
func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
return &GRPCAuthInterceptor{
authenticator: auth.NewAuthenticator(store, secret),
}
}
// AuthenticationInterceptor is the unary interceptor for gRPC API.
//
// Authentication strategy (in priority order):
// 1. Session Cookie: "user_session" cookie with format "{userID}-{sessionID}"
// 2. Bearer Token: "Authorization: Bearer {jwt_token}" header
// 3. Public Methods: Allow without auth if method is in public allowlist
// 4. Reject: Return Unauthenticated error
//
// On successful authentication, context values are set:
// - auth.UserIDContextKey: The authenticated user's ID
// - auth.SessionIDContextKey: Session ID (cookie auth only)
// - auth.AccessTokenContextKey: JWT token (bearer auth only).
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
// If metadata is missing, only allow public methods
if IsPublicMethod(serverInfo.FullMethod) {
return handler(ctx, request)
}
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
}
// Try session cookie authentication
if sessionCookie := extractSessionCookieFromMetadata(md); sessionCookie != "" {
user, err := in.authenticator.AuthenticateBySession(ctx, sessionCookie)
if err == nil && user != nil {
_, sessionID, err := auth.ParseSessionCookieValue(sessionCookie)
if err != nil {
// This should not happen since AuthenticateBySession already validated the cookie
// but handle it gracefully anyway
sessionID = ""
}
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, serverInfo.FullMethod, user, sessionID, "", IsAdminOnlyMethod)
if err != nil {
return nil, toGRPCError(err, codes.PermissionDenied)
}
return handler(ctx, request)
}
}
// Try bearer token authentication
if token := extractBearerTokenFromMetadata(md); token != "" {
user, err := in.authenticator.AuthenticateByJWT(ctx, token)
if err == nil && user != nil {
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, serverInfo.FullMethod, user, "", token, IsAdminOnlyMethod)
if err != nil {
return nil, toGRPCError(err, codes.PermissionDenied)
}
return handler(ctx, request)
}
}
// Allow public methods without authentication
if IsPublicMethod(serverInfo.FullMethod) {
return handler(ctx, request)
}
return nil, status.Errorf(codes.Unauthenticated, "authentication required")
}
// toGRPCError converts an error to a gRPC status error with the given code.
// If the error is already a gRPC status error, it is returned as-is.
func toGRPCError(err error, code codes.Code) error {
if err == nil {
return nil
}
if _, ok := status.FromError(err); ok {
return err
}
return status.Errorf(code, "%v", err)
}
// extractSessionCookieFromMetadata extracts the session cookie value from gRPC metadata.
// Checks both "grpcgateway-cookie" (from gRPC-Gateway) and "cookie" (native gRPC).
// Returns empty string if no session cookie is found.
func extractSessionCookieFromMetadata(md metadata.MD) string {
// gRPC-Gateway puts cookies in "grpcgateway-cookie", native gRPC uses "cookie"
for _, cookieHeader := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
if cookie := auth.ExtractSessionCookieFromHeader(cookieHeader); cookie != "" {
return cookie
}
}
return ""
}
// extractBearerTokenFromMetadata extracts JWT token from Authorization header in gRPC metadata.
// Returns empty string if no valid bearer token is found.
func extractBearerTokenFromMetadata(md metadata.MD) string {
authHeaders := md.Get("Authorization")
if len(authHeaders) == 0 {
return ""
}
return auth.ExtractBearerToken(authHeaders[0])
}
package v1 package v1
// Access Control List (ACL) Configuration // PublicMethods defines API endpoints that don't require authentication.
// All other endpoints require a valid session or access token.
// //
// This file defines which API methods require authentication and which require admin privileges. // This is the SINGLE SOURCE OF TRUTH for public endpoints.
// Used by both gRPC and Connect interceptors to enforce access control. // Both Connect interceptor and gRPC-Gateway interceptor use this map.
// //
// Method names follow the gRPC full method format: "/{package}.{service}/{method}" // Format: Full gRPC procedure path as returned by req.Spec().Procedure (Connect)
// Example: "/memos.api.v1.MemoService/CreateMemo" // or info.FullMethod (gRPC interceptor).
var PublicMethods = map[string]struct{}{
// publicMethods lists methods that can be called without authentication. // Auth Service - login flow must be accessible without auth
// These are typically read-only endpoints for public content or login-related endpoints. "/memos.api.v1.AuthService/CreateSession": {},
var publicMethods = map[string]bool{ "/memos.api.v1.AuthService/GetCurrentSession": {},
// Instance info - needed before login
"/memos.api.v1.InstanceService/GetInstanceProfile": true, // Instance Service - needed before login to show instance info
"/memos.api.v1.InstanceService/GetInstanceSetting": true, "/memos.api.v1.InstanceService/GetInstanceProfile": {},
"/memos.api.v1.InstanceService/GetInstanceSetting": {},
// Auth - login/session endpoints
"/memos.api.v1.AuthService/CreateSession": true, // User Service - public user profiles and stats
"/memos.api.v1.AuthService/GetCurrentSession": true, "/memos.api.v1.UserService/GetUser": {},
"/memos.api.v1.UserService/GetUserAvatar": {},
// User - public user info and registration "/memos.api.v1.UserService/GetUserStats": {},
"/memos.api.v1.UserService/CreateUser": true, // Registration (also admin-only when not first user) "/memos.api.v1.UserService/ListAllUserStats": {},
"/memos.api.v1.UserService/GetUser": true, "/memos.api.v1.UserService/SearchUsers": {},
"/memos.api.v1.UserService/GetUserAvatar": true,
"/memos.api.v1.UserService/GetUserStats": true, // Identity Provider Service - SSO buttons on login page
"/memos.api.v1.UserService/ListAllUserStats": true, "/memos.api.v1.IdentityProviderService/ListIdentityProviders": {},
"/memos.api.v1.UserService/SearchUsers": true,
// Memo Service - public memos (visibility filtering done in service layer)
// Identity providers - needed for SSO login "/memos.api.v1.MemoService/GetMemo": {},
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": true, "/memos.api.v1.MemoService/ListMemos": {},
// Memo - public memo access
"/memos.api.v1.MemoService/GetMemo": true,
"/memos.api.v1.MemoService/ListMemos": true,
// Attachment - public attachment access
"/memos.api.v1.AttachmentService/GetAttachmentBinary": true,
}
// adminOnlyMethods lists methods that require admin (Host or Admin role) privileges.
// Regular users cannot call these methods even if authenticated.
var adminOnlyMethods = map[string]bool{
"/memos.api.v1.UserService/CreateUser": true, // Admin creates users (except first user registration)
"/memos.api.v1.InstanceService/UpdateInstanceSetting": true,
}
// IsPublicMethod returns true if the method can be called without authentication.
func IsPublicMethod(fullMethodName string) bool {
return publicMethods[fullMethodName]
} }
// IsAdminOnlyMethod returns true if the method requires admin privileges. // IsPublicMethod checks if a procedure path is public (no authentication required).
func IsAdminOnlyMethod(fullMethodName string) bool { // Returns true for public methods, false for protected methods.
return adminOnlyMethods[fullMethodName] func IsPublicMethod(procedure string) bool {
_, ok := PublicMethods[procedure]
return ok
} }
package v1
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestPublicMethodsArePublic verifies that methods in PublicMethods are recognized as public.
func TestPublicMethodsArePublic(t *testing.T) {
publicMethods := []string{
// Auth Service
"/memos.api.v1.AuthService/CreateSession",
"/memos.api.v1.AuthService/GetCurrentSession",
// Instance Service
"/memos.api.v1.InstanceService/GetInstanceProfile",
"/memos.api.v1.InstanceService/GetInstanceSetting",
// User Service
"/memos.api.v1.UserService/GetUser",
"/memos.api.v1.UserService/GetUserAvatar",
"/memos.api.v1.UserService/GetUserStats",
"/memos.api.v1.UserService/ListAllUserStats",
"/memos.api.v1.UserService/SearchUsers",
// Identity Provider Service
"/memos.api.v1.IdentityProviderService/ListIdentityProviders",
// Memo Service
"/memos.api.v1.MemoService/GetMemo",
"/memos.api.v1.MemoService/ListMemos",
}
for _, method := range publicMethods {
t.Run(method, func(t *testing.T) {
assert.True(t, IsPublicMethod(method), "Expected %s to be public", method)
})
}
}
// TestProtectedMethodsRequireAuth verifies that non-public methods are recognized as protected.
func TestProtectedMethodsRequireAuth(t *testing.T) {
protectedMethods := []string{
// Auth Service - logout requires auth
"/memos.api.v1.AuthService/DeleteSession",
// Instance Service - admin operations
"/memos.api.v1.InstanceService/UpdateInstanceSetting",
// User Service - modification operations
"/memos.api.v1.UserService/ListUsers",
"/memos.api.v1.UserService/UpdateUser",
"/memos.api.v1.UserService/DeleteUser",
// Memo Service - write operations
"/memos.api.v1.MemoService/CreateMemo",
"/memos.api.v1.MemoService/UpdateMemo",
"/memos.api.v1.MemoService/DeleteMemo",
// Attachment Service - write operations
"/memos.api.v1.AttachmentService/CreateAttachment",
"/memos.api.v1.AttachmentService/DeleteAttachment",
// Shortcut Service
"/memos.api.v1.ShortcutService/CreateShortcut",
"/memos.api.v1.ShortcutService/ListShortcuts",
"/memos.api.v1.ShortcutService/UpdateShortcut",
"/memos.api.v1.ShortcutService/DeleteShortcut",
// Activity Service
"/memos.api.v1.ActivityService/GetActivity",
}
for _, method := range protectedMethods {
t.Run(method, func(t *testing.T) {
assert.False(t, IsPublicMethod(method), "Expected %s to require auth", method)
})
}
}
// TestUnknownMethodsRequireAuth verifies that unknown methods default to requiring auth.
func TestUnknownMethodsRequireAuth(t *testing.T) {
unknownMethods := []string{
"/unknown.Service/Method",
"/memos.api.v1.UnknownService/Method",
"",
"invalid",
}
for _, method := range unknownMethods {
t.Run(method, func(t *testing.T) {
assert.False(t, IsPublicMethod(method), "Unknown method %q should require auth", method)
})
}
}
// TestPublicMethodsMapConsistency verifies that PublicMethods map matches test expectations.
func TestPublicMethodsMapConsistency(t *testing.T) {
// Ensure the PublicMethods map has the expected number of entries
expectedCount := 13
actualCount := len(PublicMethods)
assert.Equal(t, expectedCount, actualCount,
"PublicMethods map has %d entries, expected %d. Update this test if public methods changed intentionally.",
actualCount, expectedCount)
}
...@@ -2,17 +2,62 @@ package v1 ...@@ -2,17 +2,62 @@ package v1
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"log/slog" "log/slog"
"runtime/debug" "runtime/debug"
"connectrpc.com/connect" "connectrpc.com/connect"
"github.com/pkg/errors" pkgerrors "github.com/pkg/errors"
"google.golang.org/grpc/metadata"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
// MetadataInterceptor converts Connect HTTP headers to gRPC metadata.
//
// This ensures service methods can use metadata.FromIncomingContext() to access
// headers like User-Agent, X-Forwarded-For, etc., regardless of whether the
// request came via Connect RPC or gRPC-Gateway.
type MetadataInterceptor struct{}
// NewMetadataInterceptor creates a new metadata interceptor.
func NewMetadataInterceptor() *MetadataInterceptor {
return &MetadataInterceptor{}
}
func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
// Convert HTTP headers to gRPC metadata
header := req.Header()
md := metadata.MD{}
// Copy important headers for client info extraction
if ua := header.Get("User-Agent"); ua != "" {
md.Set("user-agent", ua)
}
if xff := header.Get("X-Forwarded-For"); xff != "" {
md.Set("x-forwarded-for", xff)
}
if xri := header.Get("X-Real-Ip"); xri != "" {
md.Set("x-real-ip", xri)
}
// Set metadata in context so services can use metadata.FromIncomingContext()
ctx = metadata.NewIncomingContext(ctx, md)
return next(ctx, req)
}
}
func (*MetadataInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) connect.StreamingClientFunc {
return next
}
func (*MetadataInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next
}
// LoggingInterceptor logs Connect RPC requests with appropriate log levels. // LoggingInterceptor logs Connect RPC requests with appropriate log levels.
// //
// Log levels: // Log levels:
...@@ -61,7 +106,7 @@ func (*LoggingInterceptor) classifyError(err error) (slog.Level, string) { ...@@ -61,7 +106,7 @@ func (*LoggingInterceptor) classifyError(err error) (slog.Level, string) {
} }
var connectErr *connect.Error var connectErr *connect.Error
if !errors.As(err, &connectErr) { if !pkgerrors.As(err, &connectErr) {
return slog.LevelError, "unknown error" return slog.LevelError, "unknown error"
} }
...@@ -99,7 +144,7 @@ func (in *RecoveryInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFu ...@@ -99,7 +144,7 @@ func (in *RecoveryInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFu
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
in.logPanic(req.Spec().Procedure, r) in.logPanic(req.Spec().Procedure, r)
err = connect.NewError(connect.CodeInternal, errors.New("internal server error")) err = connect.NewError(connect.CodeInternal, pkgerrors.New("internal server error"))
} }
}() }()
return next(ctx, req) return next(ctx, req)
...@@ -127,9 +172,8 @@ func (in *RecoveryInterceptor) logPanic(procedure string, panicValue any) { ...@@ -127,9 +172,8 @@ func (in *RecoveryInterceptor) logPanic(procedure string, panicValue any) {
// AuthInterceptor handles authentication for Connect handlers. // AuthInterceptor handles authentication for Connect handlers.
// //
// It reuses the same authentication logic as GRPCAuthInterceptor by delegating // It enforces authentication for all endpoints except those listed in PublicMethods.
// to a shared Authenticator instance. This ensures consistent authentication // Role-based authorization (admin checks) remains in the service layer.
// behavior across both gRPC and Connect protocols.
type AuthInterceptor struct { type AuthInterceptor struct {
authenticator *auth.Authenticator authenticator *auth.Authenticator
} }
...@@ -143,45 +187,23 @@ func NewAuthInterceptor(store *store.Store, secret string) *AuthInterceptor { ...@@ -143,45 +187,23 @@ func NewAuthInterceptor(store *store.Store, secret string) *AuthInterceptor {
func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc { func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) { return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
procedure := req.Spec().Procedure
header := req.Header() header := req.Header()
sessionCookie := auth.ExtractSessionCookieFromHeader(header.Get("Cookie"))
authHeader := header.Get("Authorization")
// Try session cookie authentication first result := in.authenticator.Authenticate(ctx, sessionCookie, authHeader)
if sessionCookie := auth.ExtractSessionCookieFromHeader(header.Get("Cookie")); sessionCookie != "" {
user, err := in.authenticator.AuthenticateBySession(ctx, sessionCookie)
if err == nil && user != nil {
_, sessionID, err := auth.ParseSessionCookieValue(sessionCookie)
if err != nil {
// This should not happen since AuthenticateBySession already validated the cookie
// but handle it gracefully anyway
sessionID = ""
}
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, procedure, user, sessionID, "", IsAdminOnlyMethod)
if err != nil {
return nil, convertAuthError(err)
}
return next(ctx, req)
}
}
// Try JWT token authentication // Enforce authentication for non-public methods
if accessToken := auth.ExtractBearerToken(header.Get("Authorization")); accessToken != "" { if result == nil && !IsPublicMethod(req.Spec().Procedure) {
user, err := in.authenticator.AuthenticateByJWT(ctx, accessToken) return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required"))
if err == nil && user != nil {
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, procedure, user, "", accessToken, IsAdminOnlyMethod)
if err != nil {
return nil, convertAuthError(err)
}
return next(ctx, req)
}
} }
// Allow public methods without authentication // Set user in context (may be nil for public endpoints)
if IsPublicMethod(procedure) { if result != nil {
return next(ctx, req) ctx = auth.SetUserInContext(ctx, result.User, result.SessionID, result.AccessToken)
} }
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required")) return next(ctx, req)
} }
} }
...@@ -192,17 +214,3 @@ func (*AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) co ...@@ -192,17 +214,3 @@ func (*AuthInterceptor) WrapStreamingClient(next connect.StreamingClientFunc) co
func (*AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc { func (*AuthInterceptor) WrapStreamingHandler(next connect.StreamingHandlerFunc) connect.StreamingHandlerFunc {
return next return next
} }
// convertAuthError converts authentication/authorization errors to Connect errors.
func convertAuthError(err error) error {
if err == nil {
return nil
}
// Check if it's already a Connect error
var connectErr *connect.Error
if errors.As(err, &connectErr) {
return err
}
// Default to permission denied for auth errors
return connect.NewError(connect.CodePermissionDenied, err)
}
package v1
import (
"context"
"fmt"
"log/slog"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type LoggerInterceptor struct {
logStacktrace bool
}
func NewLoggerInterceptor(logStacktrace bool) *LoggerInterceptor {
return &LoggerInterceptor{logStacktrace: logStacktrace}
}
func (in *LoggerInterceptor) LoggerInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
resp, err := handler(ctx, request)
in.loggerInterceptorDo(ctx, serverInfo.FullMethod, err)
return resp, err
}
func (in *LoggerInterceptor) loggerInterceptorDo(ctx context.Context, fullMethod string, err error) {
st := status.Convert(err)
var logLevel slog.Level
var logMsg string
switch st.Code() {
case codes.OK:
logLevel = slog.LevelInfo
logMsg = "OK"
case codes.Unauthenticated, codes.OutOfRange, codes.PermissionDenied, codes.NotFound:
logLevel = slog.LevelInfo
logMsg = "client error"
case codes.Internal, codes.Unknown, codes.DataLoss, codes.Unavailable, codes.DeadlineExceeded:
logLevel = slog.LevelError
logMsg = "server error"
default:
logLevel = slog.LevelError
logMsg = "unknown error"
}
logAttrs := []slog.Attr{slog.String("method", fullMethod)}
if err != nil {
logAttrs = append(logAttrs, slog.String("error", err.Error()))
if in.logStacktrace {
logAttrs = append(logAttrs, slog.String("stacktrace", fmt.Sprintf("%v", err)))
}
}
slog.LogAttrs(ctx, logLevel, logMsg, logAttrs...)
}
...@@ -2,8 +2,6 @@ package v1 ...@@ -2,8 +2,6 @@ package v1
import ( import (
"context" "context"
"fmt"
"math"
"net/http" "net/http"
"connectrpc.com/connect" "connectrpc.com/connect"
...@@ -11,20 +9,15 @@ import ( ...@@ -11,20 +9,15 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware" "github.com/labstack/echo/v4/middleware"
"golang.org/x/sync/semaphore" "golang.org/x/sync/semaphore"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/reflection"
"github.com/usememos/memos/internal/profile" "github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/plugin/markdown" "github.com/usememos/memos/plugin/markdown"
v1pb "github.com/usememos/memos/proto/gen/api/v1" v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
type APIV1Service struct { type APIV1Service struct {
grpc_health_v1.UnimplementedHealthServer
v1pb.UnimplementedInstanceServiceServer v1pb.UnimplementedInstanceServiceServer
v1pb.UnimplementedAuthServiceServer v1pb.UnimplementedAuthServiceServer
v1pb.UnimplementedUserServiceServer v1pb.UnimplementedUserServiceServer
...@@ -39,82 +32,86 @@ type APIV1Service struct { ...@@ -39,82 +32,86 @@ type APIV1Service struct {
Store *store.Store Store *store.Store
MarkdownService markdown.Service MarkdownService markdown.Service
grpcServer *grpc.Server
// thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion // thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion
thumbnailSemaphore *semaphore.Weighted thumbnailSemaphore *semaphore.Weighted
} }
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store, grpcServer *grpc.Server) *APIV1Service { func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store) *APIV1Service {
grpc.EnableTracing = true
markdownService := markdown.NewService( markdownService := markdown.NewService(
markdown.WithTagExtension(), markdown.WithTagExtension(),
) )
apiv1Service := &APIV1Service{ return &APIV1Service{
Secret: secret, Secret: secret,
Profile: profile, Profile: profile,
Store: store, Store: store,
MarkdownService: markdownService, MarkdownService: markdownService,
grpcServer: grpcServer,
thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations
} }
grpc_health_v1.RegisterHealthServer(grpcServer, apiv1Service)
v1pb.RegisterInstanceServiceServer(grpcServer, apiv1Service)
v1pb.RegisterAuthServiceServer(grpcServer, apiv1Service)
v1pb.RegisterUserServiceServer(grpcServer, apiv1Service)
v1pb.RegisterMemoServiceServer(grpcServer, apiv1Service)
v1pb.RegisterAttachmentServiceServer(grpcServer, apiv1Service)
v1pb.RegisterShortcutServiceServer(grpcServer, apiv1Service)
v1pb.RegisterActivityServiceServer(grpcServer, apiv1Service)
v1pb.RegisterIdentityProviderServiceServer(grpcServer, apiv1Service)
reflection.Register(grpcServer)
return apiv1Service
} }
// RegisterGateway registers the gRPC-Gateway with the given Echo instance. // RegisterGateway registers the gRPC-Gateway and Connect handlers with the given Echo instance.
func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error { func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Echo) error {
var target string // Auth middleware for gRPC-Gateway - runs after routing, has access to method name.
if len(s.Profile.UNIXSock) == 0 { // Uses the same PublicMethods config as the Connect AuthInterceptor.
addr := s.Profile.Addr authenticator := auth.NewAuthenticator(s.Store, s.Secret)
if addr == "" { gatewayAuthMiddleware := func(next runtime.HandlerFunc) runtime.HandlerFunc {
addr = "localhost" return func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) {
ctx := r.Context()
// Get the RPC method name from context (set by grpc-gateway after routing)
rpcMethod, _ := runtime.RPCMethod(ctx)
// Extract credentials from HTTP headers
var sessionCookie string
if cookie, err := r.Cookie("user_session"); err == nil {
sessionCookie = cookie.Value
} }
target = fmt.Sprintf("%s:%d", addr, s.Profile.Port) authHeader := r.Header.Get("Authorization")
} else {
target = fmt.Sprintf("unix:%s", s.Profile.UNIXSock) result := authenticator.Authenticate(ctx, sessionCookie, authHeader)
// Enforce authentication for non-public methods
if result == nil && !IsPublicMethod(rpcMethod) {
http.Error(w, `{"code": 16, "message": "authentication required"}`, http.StatusUnauthorized)
return
}
// Set user in context (may be nil for public endpoints)
if result != nil {
ctx = auth.SetUserInContext(ctx, result.User, result.SessionID, result.AccessToken)
r = r.WithContext(ctx)
}
next(w, r, pathParams)
} }
conn, err := grpc.NewClient(
target,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(math.MaxInt32)),
)
if err != nil {
return err
} }
gwMux := runtime.NewServeMux() // Create gRPC-Gateway mux with auth middleware.
if err := v1pb.RegisterInstanceServiceHandler(ctx, gwMux, conn); err != nil { gwMux := runtime.NewServeMux(
runtime.WithMiddlewares(gatewayAuthMiddleware),
)
if err := v1pb.RegisterInstanceServiceHandlerServer(ctx, gwMux, s); err != nil {
return err return err
} }
if err := v1pb.RegisterAuthServiceHandler(ctx, gwMux, conn); err != nil { if err := v1pb.RegisterAuthServiceHandlerServer(ctx, gwMux, s); err != nil {
return err return err
} }
if err := v1pb.RegisterUserServiceHandler(ctx, gwMux, conn); err != nil { if err := v1pb.RegisterUserServiceHandlerServer(ctx, gwMux, s); err != nil {
return err return err
} }
if err := v1pb.RegisterMemoServiceHandler(ctx, gwMux, conn); err != nil { if err := v1pb.RegisterMemoServiceHandlerServer(ctx, gwMux, s); err != nil {
return err return err
} }
if err := v1pb.RegisterAttachmentServiceHandler(ctx, gwMux, conn); err != nil { if err := v1pb.RegisterAttachmentServiceHandlerServer(ctx, gwMux, s); err != nil {
return err return err
} }
if err := v1pb.RegisterShortcutServiceHandler(ctx, gwMux, conn); err != nil { if err := v1pb.RegisterShortcutServiceHandlerServer(ctx, gwMux, s); err != nil {
return err return err
} }
if err := v1pb.RegisterActivityServiceHandler(ctx, gwMux, conn); err != nil { if err := v1pb.RegisterActivityServiceHandlerServer(ctx, gwMux, s); err != nil {
return err return err
} }
if err := v1pb.RegisterIdentityProviderServiceHandler(ctx, gwMux, conn); err != nil { if err := v1pb.RegisterIdentityProviderServiceHandlerServer(ctx, gwMux, s); err != nil {
return err return err
} }
gwGroup := echoServer.Group("") gwGroup := echoServer.Group("")
...@@ -127,6 +124,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech ...@@ -127,6 +124,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
// Connect handlers for browser clients (replaces grpc-web). // Connect handlers for browser clients (replaces grpc-web).
logStacktraces := s.Profile.IsDev() logStacktraces := s.Profile.IsDev()
connectInterceptors := connect.WithInterceptors( connectInterceptors := connect.WithInterceptors(
NewMetadataInterceptor(), // Convert HTTP headers to gRPC metadata first
NewLoggingInterceptor(logStacktraces), NewLoggingInterceptor(logStacktraces),
NewRecoveryInterceptor(logStacktraces), NewRecoveryInterceptor(logStacktraces),
NewAuthInterceptor(s.Store, s.Secret), NewAuthInterceptor(s.Store, s.Secret),
......
...@@ -4,20 +4,15 @@ import ( ...@@ -4,20 +4,15 @@ import (
"context" "context"
"fmt" "fmt"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"runtime" "runtime"
"runtime/debug"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
grpcrecovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware" "github.com/labstack/echo/v4/middleware"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/soheilhy/cmux"
"google.golang.org/grpc"
"github.com/usememos/memos/internal/profile" "github.com/usememos/memos/internal/profile"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
...@@ -35,7 +30,6 @@ type Server struct { ...@@ -35,7 +30,6 @@ type Server struct {
Store *store.Store Store *store.Store
echoServer *echo.Echo echoServer *echo.Echo
grpcServer *grpc.Server
runnerCancelFuncs []context.CancelFunc runnerCancelFuncs []context.CancelFunc
} }
...@@ -73,20 +67,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store ...@@ -73,20 +67,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
rootGroup := echoServer.Group("") rootGroup := echoServer.Group("")
// Log full stacktraces if we're in dev apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store)
logStacktraces := profile.IsDev()
grpcServer := grpc.NewServer(
// Override the maximum receiving message size to math.MaxInt32 for uploading large attachments.
grpc.MaxRecvMsgSize(math.MaxInt32),
grpc.ChainUnaryInterceptor(
apiv1.NewLoggerInterceptor(logStacktraces).LoggerInterceptor,
newRecoveryInterceptor(logStacktraces),
apiv1.NewGRPCAuthInterceptor(store, secret).AuthenticationInterceptor,
))
s.grpcServer = grpcServer
apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store, grpcServer)
// Register HTTP file server routes BEFORE gRPC-Gateway to ensure proper range request handling for Safari. // Register HTTP file server routes BEFORE gRPC-Gateway to ensure proper range request handling for Safari.
// This uses native HTTP serving (http.ServeContent) instead of gRPC for video/audio files. // This uses native HTTP serving (http.ServeContent) instead of gRPC for video/audio files.
...@@ -103,26 +84,6 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store ...@@ -103,26 +84,6 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
return s, nil return s, nil
} }
func newRecoveryInterceptor(logStacktraces bool) grpc.UnaryServerInterceptor {
var recoveryOptions []grpcrecovery.Option
if logStacktraces {
recoveryOptions = append(recoveryOptions, grpcrecovery.WithRecoveryHandler(func(p any) error {
if p == nil {
return nil
}
switch val := p.(type) {
case runtime.Error:
return &stacktraceError{err: val, stacktrace: debug.Stack()}
default:
return nil
}
}))
}
return grpcrecovery.UnaryServerInterceptor(recoveryOptions...)
}
func (s *Server) Start(ctx context.Context) error { func (s *Server) Start(ctx context.Context) error {
var address, network string var address, network string
if len(s.Profile.UNIXSock) == 0 { if len(s.Profile.UNIXSock) == 0 {
...@@ -137,25 +98,13 @@ func (s *Server) Start(ctx context.Context) error { ...@@ -137,25 +98,13 @@ func (s *Server) Start(ctx context.Context) error {
return errors.Wrap(err, "failed to listen") return errors.Wrap(err, "failed to listen")
} }
muxServer := cmux.New(listener) // Start Echo server directly (no cmux needed - all traffic is HTTP).
go func() { s.echoServer.Listener = listener
grpcListener := muxServer.MatchWithWriters(cmux.HTTP2MatchHeaderFieldSendSettings("content-type", "application/grpc"))
if err := s.grpcServer.Serve(grpcListener); err != nil {
slog.Error("failed to serve gRPC", "error", err)
}
}()
go func() { go func() {
httpListener := muxServer.Match(cmux.HTTP1Fast(http.MethodPatch)) if err := s.echoServer.Start(address); err != nil && err != http.ErrServerClosed {
s.echoServer.Listener = httpListener
if err := s.echoServer.Start(address); err != nil {
slog.Error("failed to start echo server", "error", err) slog.Error("failed to start echo server", "error", err)
} }
}() }()
go func() {
if err := muxServer.Serve(); err != nil {
slog.Error("mux server listen error", "error", err)
}
}()
s.StartBackgroundRunners(ctx) s.StartBackgroundRunners(ctx)
return nil return nil
...@@ -179,9 +128,6 @@ func (s *Server) Shutdown(ctx context.Context) { ...@@ -179,9 +128,6 @@ func (s *Server) Shutdown(ctx context.Context) {
slog.Error("failed to shutdown server", slog.String("error", err.Error())) slog.Error("failed to shutdown server", slog.String("error", err.Error()))
} }
// Shutdown gRPC server.
s.grpcServer.GracefulStop()
// Close database connection. // Close database connection.
if err := s.Store.Close(); err != nil { if err := s.Store.Close(); err != nil {
slog.Error("failed to close database", slog.String("error", err.Error())) slog.Error("failed to close database", slog.String("error", err.Error()))
...@@ -234,23 +180,3 @@ func (s *Server) getOrUpsertInstanceBasicSetting(ctx context.Context) (*storepb. ...@@ -234,23 +180,3 @@ func (s *Server) getOrUpsertInstanceBasicSetting(ctx context.Context) (*storepb.
} }
return instanceBasicSetting, nil return instanceBasicSetting, nil
} }
// stacktraceError wraps an underlying error and captures the stacktrace. It
// implements fmt.Formatter, so it'll be rendered when invoked by something like
// `fmt.Sprint("%v", err)`.
type stacktraceError struct {
err error
stacktrace []byte
}
func (e *stacktraceError) Error() string {
return e.err.Error()
}
func (e *stacktraceError) Unwrap() error {
return e.err
}
func (e *stacktraceError) Format(f fmt.State, _ rune) {
f.Write(e.stacktrace)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment