Unverified Commit 7932f6d0 authored by Johnny's avatar Johnny Committed by GitHub

refactor: user auth improvements (#5360)

parent 2c2ef537
......@@ -7,6 +7,10 @@ web/dist
# Build artifacts
build/
bin/
memos
# Plan/design documents
docs/plans/
.DS_Store
......
......@@ -11,87 +11,104 @@ import "google/protobuf/timestamp.proto";
option go_package = "gen/api/v1";
service AuthService {
// GetCurrentSession returns the current active session information.
// This method is idempotent and safe, suitable for checking current session state.
rpc GetCurrentSession(GetCurrentSessionRequest) returns (GetCurrentSessionResponse) {
option (google.api.http) = {get: "/api/v1/auth/sessions/current"};
// GetCurrentUser returns the authenticated user's information.
// Validates the access token and returns user details.
// Similar to OIDC's /userinfo endpoint.
rpc GetCurrentUser(GetCurrentUserRequest) returns (GetCurrentUserResponse) {
option (google.api.http) = {get: "/api/v1/auth/me"};
}
// CreateSession authenticates a user and creates a new session.
// Returns the authenticated user information upon successful authentication.
rpc CreateSession(CreateSessionRequest) returns (CreateSessionResponse) {
// SignIn authenticates a user with credentials and returns tokens.
// On success, returns an access token and sets a refresh token cookie.
// Supports password-based and SSO authentication methods.
rpc SignIn(SignInRequest) returns (SignInResponse) {
option (google.api.http) = {
post: "/api/v1/auth/sessions"
post: "/api/v1/auth/signin"
body: "*"
};
}
// DeleteSession terminates the current user session.
// This is an idempotent operation that invalidates the user's authentication.
rpc DeleteSession(DeleteSessionRequest) returns (google.protobuf.Empty) {
option (google.api.http) = {delete: "/api/v1/auth/sessions/current"};
// SignOut terminates the user's authentication.
// Revokes the refresh token and clears the authentication cookie.
rpc SignOut(SignOutRequest) returns (google.protobuf.Empty) {
option (google.api.http) = {post: "/api/v1/auth/signout"};
}
// RefreshToken exchanges a valid refresh token for a new access token.
// The refresh token is read from the HttpOnly cookie.
// Returns a new short-lived access token.
rpc RefreshToken(RefreshTokenRequest) returns (RefreshTokenResponse) {
option (google.api.http) = {
post: "/api/v1/auth/refresh"
body: "*"
};
}
}
message GetCurrentSessionRequest {}
message GetCurrentUserRequest {}
message GetCurrentSessionResponse {
message GetCurrentUserResponse {
// The authenticated user's information.
User user = 1;
// Last time the session was accessed.
// Used for sliding expiration calculation (last_accessed_time + 2 weeks).
google.protobuf.Timestamp last_accessed_at = 2;
}
message CreateSessionRequest {
message SignInRequest {
// Nested message for password-based authentication credentials.
message PasswordCredentials {
// The username to sign in with.
// Required field for password-based authentication.
string username = 1 [(google.api.field_behavior) = REQUIRED];
// The password to sign in with.
// Required field for password-based authentication.
string password = 2 [(google.api.field_behavior) = REQUIRED];
}
// Nested message for SSO authentication credentials.
message SSOCredentials {
// The ID of the SSO provider.
// Required field to identify the SSO provider.
int32 idp_id = 1 [(google.api.field_behavior) = REQUIRED];
// The authorization code from the SSO provider.
// Required field for completing the SSO flow.
string code = 2 [(google.api.field_behavior) = REQUIRED];
// The redirect URI used in the SSO flow.
// Required field for security validation.
string redirect_uri = 3 [(google.api.field_behavior) = REQUIRED];
// The PKCE code verifier for enhanced security (RFC 7636).
// Optional field - if provided, enables PKCE flow protection against authorization code interception.
// Optional - enables PKCE flow protection against authorization code interception.
string code_verifier = 4 [(google.api.field_behavior) = OPTIONAL];
}
// Provide one authentication method (username/password or SSO).
// Required field to specify the authentication method.
// Authentication credentials. Provide one method.
oneof credentials {
// Username and password authentication method.
// Username and password authentication.
PasswordCredentials password_credentials = 1;
// SSO provider authentication method.
// SSO provider authentication.
SSOCredentials sso_credentials = 2;
}
}
message CreateSessionResponse {
// The authenticated user information.
message SignInResponse {
// The authenticated user's information.
User user = 1;
// Last time the session was accessed.
// Used for sliding expiration calculation (last_accessed_time + 2 weeks).
google.protobuf.Timestamp last_accessed_at = 2;
// The short-lived access token for API requests.
// Store in memory only, not in localStorage.
string access_token = 2;
// When the access token expires.
// Client should call RefreshToken before this time.
google.protobuf.Timestamp access_token_expires_at = 3;
}
message DeleteSessionRequest {}
message SignOutRequest {}
message RefreshTokenRequest {}
message RefreshTokenResponse {
// The new short-lived access token.
string access_token = 1;
// When the access token expires.
google.protobuf.Timestamp expires_at = 2;
}
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -19,6 +19,10 @@ message UserSetting {
SHORTCUTS = 4;
// The webhooks of the user.
WEBHOOKS = 5;
// Refresh tokens for the user.
REFRESH_TOKENS = 6;
// Personal access tokens for the user.
PERSONAL_ACCESS_TOKENS = 7;
}
int32 user_id = 1;
......@@ -30,6 +34,8 @@ message UserSetting {
AccessTokensUserSetting access_tokens = 5;
ShortcutsUserSetting shortcuts = 6;
WebhooksUserSetting webhooks = 7;
RefreshTokensUserSetting refresh_tokens = 8;
PersonalAccessTokensUserSetting personal_access_tokens = 9;
}
}
......@@ -83,6 +89,40 @@ message AccessTokensUserSetting {
repeated AccessToken access_tokens = 1;
}
message RefreshTokensUserSetting {
message RefreshToken {
// Unique identifier (matches 'tid' claim in JWT)
string token_id = 1;
// When the token expires
google.protobuf.Timestamp expires_at = 2;
// When the token was created
google.protobuf.Timestamp created_at = 3;
// Client information for session management UI
SessionsUserSetting.ClientInfo client_info = 4;
// Optional description
string description = 5;
}
repeated RefreshToken refresh_tokens = 1;
}
message PersonalAccessTokensUserSetting {
message PersonalAccessToken {
// Unique identifier for this token
string token_id = 1;
// SHA-256 hash of the actual token
string token_hash = 2;
// User-provided description
string description = 3;
// When the token expires (null = never)
google.protobuf.Timestamp expires_at = 4;
// When the token was created
google.protobuf.Timestamp created_at = 5;
// When the token was last used
google.protobuf.Timestamp last_used_at = 6;
}
repeated PersonalAccessToken tokens = 1;
}
message ShortcutsUserSetting {
message Shortcut {
string id = 1;
......
This diff is collapsed.
......@@ -23,6 +23,12 @@ const (
// AccessTokenContextKey stores the JWT token for token-based auth.
// Only set when authenticated via Bearer token.
AccessTokenContextKey
// UserClaimsContextKey stores the claims from access token.
UserClaimsContextKey
// RefreshTokenIDContextKey stores the refresh token ID.
RefreshTokenIDContextKey
)
// GetUserID retrieves the authenticated user's ID from the context.
......@@ -70,3 +76,25 @@ func SetUserInContext(ctx context.Context, user *store.User, sessionID, accessTo
}
return ctx
}
// UserClaims represents authenticated user info from access token.
type UserClaims struct {
UserID int32
Username string
Role string
Status string
}
// GetUserClaims retrieves the user claims from context.
// Returns nil if not authenticated via access token.
func GetUserClaims(ctx context.Context) *UserClaims {
if v, ok := ctx.Value(UserClaimsContextKey).(*UserClaims); ok {
return v
}
return nil
}
// SetUserClaimsInContext sets the user claims in context.
func SetUserClaimsInContext(ctx context.Context, claims *UserClaims) context.Context {
return context.WithValue(ctx, UserClaimsContextKey, claims)
}
......@@ -33,3 +33,16 @@ func ExtractBearerToken(authHeader string) string {
}
return parts[1]
}
// ExtractRefreshTokenFromCookie extracts the refresh token from cookie header.
func ExtractRefreshTokenFromCookie(cookieHeader string) string {
if cookieHeader == "" {
return ""
}
req := &http.Request{Header: http.Header{"Cookie": []string{cookieHeader}}}
cookie, err := req.Cookie(RefreshTokenCookieName)
if err != nil {
return ""
}
return cookie.Value
}
This diff is collapsed.
This diff is collapsed.
......@@ -9,9 +9,9 @@ package v1
// Format: Full gRPC procedure path as returned by req.Spec().Procedure (Connect)
// or info.FullMethod (gRPC interceptor).
var PublicMethods = map[string]struct{}{
// Auth Service - login flow must be accessible without auth
"/memos.api.v1.AuthService/CreateSession": {},
"/memos.api.v1.AuthService/GetCurrentSession": {},
// Auth Service - login/token endpoints must be accessible without auth
"/memos.api.v1.AuthService/SignIn": {},
"/memos.api.v1.AuthService/RefreshToken": {}, // Token refresh uses cookie, must be accessible when access token expired
// Instance Service - needed before login to show instance info
"/memos.api.v1.InstanceService/GetInstanceProfile": {},
......
......@@ -10,8 +10,8 @@ import (
func TestPublicMethodsArePublic(t *testing.T) {
publicMethods := []string{
// Auth Service
"/memos.api.v1.AuthService/CreateSession",
"/memos.api.v1.AuthService/GetCurrentSession",
"/memos.api.v1.AuthService/SignIn",
"/memos.api.v1.AuthService/RefreshToken",
// Instance Service
"/memos.api.v1.InstanceService/GetInstanceProfile",
"/memos.api.v1.InstanceService/GetInstanceSetting",
......@@ -39,8 +39,9 @@ func TestPublicMethodsArePublic(t *testing.T) {
// TestProtectedMethodsRequireAuth verifies that non-public methods are recognized as protected.
func TestProtectedMethodsRequireAuth(t *testing.T) {
protectedMethods := []string{
// Auth Service - logout requires auth
"/memos.api.v1.AuthService/DeleteSession",
// Auth Service - logout and get current user require auth
"/memos.api.v1.AuthService/SignOut",
"/memos.api.v1.AuthService/GetCurrentUser",
// Instance Service - admin operations
"/memos.api.v1.InstanceService/UpdateInstanceSetting",
// User Service - modification operations
......
......@@ -43,7 +43,7 @@ var SupportedThumbnailMimeTypes = []string{
}
func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.CreateAttachmentRequest) (*v1pb.Attachment, error) {
user, err := s.GetCurrentUser(ctx)
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
......@@ -123,7 +123,7 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
}
func (s *APIV1Service) ListAttachments(ctx context.Context, request *v1pb.ListAttachmentsRequest) (*v1pb.ListAttachmentsResponse, error) {
user, err := s.GetCurrentUser(ctx)
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
......@@ -234,7 +234,7 @@ func (s *APIV1Service) DeleteAttachment(ctx context.Context, request *v1pb.Delet
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid attachment id: %v", err)
}
user, err := s.GetCurrentUser(ctx)
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
......
This diff is collapsed.
......@@ -43,6 +43,10 @@ func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc
if xri := header.Get("X-Real-Ip"); xri != "" {
md.Set("x-real-ip", xri)
}
// Forward Cookie header for authentication methods that need it (e.g., RefreshToken)
if cookie := header.Get("Cookie"); cookie != "" {
md.Set("cookie", cookie)
}
// Set metadata in context so services can use metadata.FromIncomingContext()
ctx = metadata.NewIncomingContext(ctx, md)
......@@ -198,9 +202,16 @@ func (in *AuthInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc {
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("authentication required"))
}
// Set user in context (may be nil for public endpoints)
// Set context based on auth result
if result != nil {
ctx = auth.SetUserInContext(ctx, result.User, result.SessionID, result.AccessToken)
if result.Claims != nil {
// Access Token V2 - stateless, use claims
ctx = auth.SetUserClaimsInContext(ctx, result.Claims)
ctx = context.WithValue(ctx, auth.UserIDContextKey, result.Claims.UserID)
} else if result.User != nil {
// PAT or legacy auth - have full user
ctx = auth.SetUserInContext(ctx, result.User, result.SessionID, result.AccessToken)
}
}
return next(ctx, req)
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -14,7 +14,7 @@ import (
)
func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.SetMemoAttachmentsRequest) (*emptypb.Empty, error) {
user, err := s.GetCurrentUser(ctx)
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment