Unverified Commit 92d937b1 authored by memoclaw's avatar memoclaw Committed by GitHub

feat: replace auto-increment ID with UID for identity provider resource names (#5687)

Co-authored-by: 's avatarClaude Opus 4.6 <noreply@anthropic.com>
parent f0c44894
...@@ -64,8 +64,9 @@ message SignInRequest { ...@@ -64,8 +64,9 @@ message SignInRequest {
// Nested message for SSO authentication credentials. // Nested message for SSO authentication credentials.
message SSOCredentials { message SSOCredentials {
// The ID of the SSO provider. // The resource name of the SSO provider.
int32 idp_id = 1 [(google.api.field_behavior) = REQUIRED]; // Format: identity-providers/{uid}
string idp_name = 1 [(google.api.field_behavior) = REQUIRED];
// The authorization code from the SSO provider. // The authorization code from the SSO provider.
string code = 2 [(google.api.field_behavior) = REQUIRED]; string code = 2 [(google.api.field_behavior) = REQUIRED];
......
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.6.0 // - protoc-gen-go-grpc v1.6.1
// - protoc (unknown) // - protoc (unknown)
// source: api/v1/activity_service.proto // source: api/v1/activity_service.proto
......
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.6.0 // - protoc-gen-go-grpc v1.6.1
// - protoc (unknown) // - protoc (unknown)
// source: api/v1/attachment_service.proto // source: api/v1/attachment_service.proto
......
...@@ -440,8 +440,9 @@ func (x *SignInRequest_PasswordCredentials) GetPassword() string { ...@@ -440,8 +440,9 @@ func (x *SignInRequest_PasswordCredentials) GetPassword() string {
// Nested message for SSO authentication credentials. // Nested message for SSO authentication credentials.
type SignInRequest_SSOCredentials struct { type SignInRequest_SSOCredentials struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
// The ID of the SSO provider. // The resource name of the SSO provider.
IdpId int32 `protobuf:"varint,1,opt,name=idp_id,json=idpId,proto3" json:"idp_id,omitempty"` // Format: identity-providers/{uid}
IdpName string `protobuf:"bytes,1,opt,name=idp_name,json=idpName,proto3" json:"idp_name,omitempty"`
// The authorization code from the SSO provider. // The authorization code from the SSO provider.
Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"` Code string `protobuf:"bytes,2,opt,name=code,proto3" json:"code,omitempty"`
// The redirect URI used in the SSO flow. // The redirect URI used in the SSO flow.
...@@ -483,11 +484,11 @@ func (*SignInRequest_SSOCredentials) Descriptor() ([]byte, []int) { ...@@ -483,11 +484,11 @@ func (*SignInRequest_SSOCredentials) Descriptor() ([]byte, []int) {
return file_api_v1_auth_service_proto_rawDescGZIP(), []int{2, 1} return file_api_v1_auth_service_proto_rawDescGZIP(), []int{2, 1}
} }
func (x *SignInRequest_SSOCredentials) GetIdpId() int32 { func (x *SignInRequest_SSOCredentials) GetIdpName() string {
if x != nil { if x != nil {
return x.IdpId return x.IdpName
} }
return 0 return ""
} }
func (x *SignInRequest_SSOCredentials) GetCode() string { func (x *SignInRequest_SSOCredentials) GetCode() string {
...@@ -518,15 +519,15 @@ const file_api_v1_auth_service_proto_rawDesc = "" + ...@@ -518,15 +519,15 @@ const file_api_v1_auth_service_proto_rawDesc = "" +
"\x19api/v1/auth_service.proto\x12\fmemos.api.v1\x1a\x19api/v1/user_service.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\x17\n" + "\x19api/v1/auth_service.proto\x12\fmemos.api.v1\x1a\x19api/v1/user_service.proto\x1a\x1cgoogle/api/annotations.proto\x1a\x1fgoogle/api/field_behavior.proto\x1a\x1bgoogle/protobuf/empty.proto\x1a\x1fgoogle/protobuf/timestamp.proto\"\x17\n" +
"\x15GetCurrentUserRequest\"@\n" + "\x15GetCurrentUserRequest\"@\n" +
"\x16GetCurrentUserResponse\x12&\n" + "\x16GetCurrentUserResponse\x12&\n" +
"\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\"\xce\x03\n" + "\x04user\x18\x01 \x01(\v2\x12.memos.api.v1.UserR\x04user\"\xd2\x03\n" +
"\rSignInRequest\x12d\n" + "\rSignInRequest\x12d\n" +
"\x14password_credentials\x18\x01 \x01(\v2/.memos.api.v1.SignInRequest.PasswordCredentialsH\x00R\x13passwordCredentials\x12U\n" + "\x14password_credentials\x18\x01 \x01(\v2/.memos.api.v1.SignInRequest.PasswordCredentialsH\x00R\x13passwordCredentials\x12U\n" +
"\x0fsso_credentials\x18\x02 \x01(\v2*.memos.api.v1.SignInRequest.SSOCredentialsH\x00R\x0essoCredentials\x1aW\n" + "\x0fsso_credentials\x18\x02 \x01(\v2*.memos.api.v1.SignInRequest.SSOCredentialsH\x00R\x0essoCredentials\x1aW\n" +
"\x13PasswordCredentials\x12\x1f\n" + "\x13PasswordCredentials\x12\x1f\n" +
"\busername\x18\x01 \x01(\tB\x03\xe0A\x02R\busername\x12\x1f\n" + "\busername\x18\x01 \x01(\tB\x03\xe0A\x02R\busername\x12\x1f\n" +
"\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1a\x97\x01\n" + "\bpassword\x18\x02 \x01(\tB\x03\xe0A\x02R\bpassword\x1a\x9b\x01\n" +
"\x0eSSOCredentials\x12\x1a\n" + "\x0eSSOCredentials\x12\x1e\n" +
"\x06idp_id\x18\x01 \x01(\x05B\x03\xe0A\x02R\x05idpId\x12\x17\n" + "\bidp_name\x18\x01 \x01(\tB\x03\xe0A\x02R\aidpName\x12\x17\n" +
"\x04code\x18\x02 \x01(\tB\x03\xe0A\x02R\x04code\x12&\n" + "\x04code\x18\x02 \x01(\tB\x03\xe0A\x02R\x04code\x12&\n" +
"\fredirect_uri\x18\x03 \x01(\tB\x03\xe0A\x02R\vredirectUri\x12(\n" + "\fredirect_uri\x18\x03 \x01(\tB\x03\xe0A\x02R\vredirectUri\x12(\n" +
"\rcode_verifier\x18\x04 \x01(\tB\x03\xe0A\x01R\fcodeVerifierB\r\n" + "\rcode_verifier\x18\x04 \x01(\tB\x03\xe0A\x01R\fcodeVerifierB\r\n" +
......
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.6.0 // - protoc-gen-go-grpc v1.6.1
// - protoc (unknown) // - protoc (unknown)
// source: api/v1/auth_service.proto // source: api/v1/auth_service.proto
......
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.6.0 // - protoc-gen-go-grpc v1.6.1
// - protoc (unknown) // - protoc (unknown)
// source: api/v1/idp_service.proto // source: api/v1/idp_service.proto
......
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.6.0 // - protoc-gen-go-grpc v1.6.1
// - protoc (unknown) // - protoc (unknown)
// source: api/v1/instance_service.proto // source: api/v1/instance_service.proto
......
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.6.0 // - protoc-gen-go-grpc v1.6.1
// - protoc (unknown) // - protoc (unknown)
// source: api/v1/memo_service.proto // source: api/v1/memo_service.proto
......
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.6.0 // - protoc-gen-go-grpc v1.6.1
// - protoc (unknown) // - protoc (unknown)
// source: api/v1/shortcut_service.proto // source: api/v1/shortcut_service.proto
......
// Code generated by protoc-gen-go-grpc. DO NOT EDIT. // Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions: // versions:
// - protoc-gen-go-grpc v1.6.0 // - protoc-gen-go-grpc v1.6.1
// - protoc (unknown) // - protoc (unknown)
// source: api/v1/user_service.proto // source: api/v1/user_service.proto
......
...@@ -2757,15 +2757,16 @@ components: ...@@ -2757,15 +2757,16 @@ components:
description: Nested message for password-based authentication credentials. description: Nested message for password-based authentication credentials.
SignInRequest_SSOCredentials: SignInRequest_SSOCredentials:
required: required:
- idpId - idpName
- code - code
- redirectUri - redirectUri
type: object type: object
properties: properties:
idpId: idpName:
type: integer type: string
description: The ID of the SSO provider. description: |-
format: int32 The resource name of the SSO provider.
Format: identity-providers/{uid}
code: code:
type: string type: string
description: The authorization code from the SSO provider. description: The authorization code from the SSO provider.
......
...@@ -74,6 +74,7 @@ type IdentityProvider struct { ...@@ -74,6 +74,7 @@ type IdentityProvider struct {
Type IdentityProvider_Type `protobuf:"varint,3,opt,name=type,proto3,enum=memos.store.IdentityProvider_Type" json:"type,omitempty"` Type IdentityProvider_Type `protobuf:"varint,3,opt,name=type,proto3,enum=memos.store.IdentityProvider_Type" json:"type,omitempty"`
IdentifierFilter string `protobuf:"bytes,4,opt,name=identifier_filter,json=identifierFilter,proto3" json:"identifier_filter,omitempty"` IdentifierFilter string `protobuf:"bytes,4,opt,name=identifier_filter,json=identifierFilter,proto3" json:"identifier_filter,omitempty"`
Config *IdentityProviderConfig `protobuf:"bytes,5,opt,name=config,proto3" json:"config,omitempty"` Config *IdentityProviderConfig `protobuf:"bytes,5,opt,name=config,proto3" json:"config,omitempty"`
Uid string `protobuf:"bytes,6,opt,name=uid,proto3" json:"uid,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
...@@ -143,6 +144,13 @@ func (x *IdentityProvider) GetConfig() *IdentityProviderConfig { ...@@ -143,6 +144,13 @@ func (x *IdentityProvider) GetConfig() *IdentityProviderConfig {
return nil return nil
} }
func (x *IdentityProvider) GetUid() string {
if x != nil {
return x.Uid
}
return ""
}
type IdentityProviderConfig struct { type IdentityProviderConfig struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
// Types that are valid to be assigned to Config: // Types that are valid to be assigned to Config:
...@@ -373,13 +381,14 @@ var File_store_idp_proto protoreflect.FileDescriptor ...@@ -373,13 +381,14 @@ var File_store_idp_proto protoreflect.FileDescriptor
const file_store_idp_proto_rawDesc = "" + const file_store_idp_proto_rawDesc = "" +
"\n" + "\n" +
"\x0fstore/idp.proto\x12\vmemos.store\"\x82\x02\n" + "\x0fstore/idp.proto\x12\vmemos.store\"\x94\x02\n" +
"\x10IdentityProvider\x12\x0e\n" + "\x10IdentityProvider\x12\x0e\n" +
"\x02id\x18\x01 \x01(\x05R\x02id\x12\x12\n" + "\x02id\x18\x01 \x01(\x05R\x02id\x12\x12\n" +
"\x04name\x18\x02 \x01(\tR\x04name\x126\n" + "\x04name\x18\x02 \x01(\tR\x04name\x126\n" +
"\x04type\x18\x03 \x01(\x0e2\".memos.store.IdentityProvider.TypeR\x04type\x12+\n" + "\x04type\x18\x03 \x01(\x0e2\".memos.store.IdentityProvider.TypeR\x04type\x12+\n" +
"\x11identifier_filter\x18\x04 \x01(\tR\x10identifierFilter\x12;\n" + "\x11identifier_filter\x18\x04 \x01(\tR\x10identifierFilter\x12;\n" +
"\x06config\x18\x05 \x01(\v2#.memos.store.IdentityProviderConfigR\x06config\"(\n" + "\x06config\x18\x05 \x01(\v2#.memos.store.IdentityProviderConfigR\x06config\x12\x10\n" +
"\x03uid\x18\x06 \x01(\tR\x03uid\"(\n" +
"\x04Type\x12\x14\n" + "\x04Type\x12\x14\n" +
"\x10TYPE_UNSPECIFIED\x10\x00\x12\n" + "\x10TYPE_UNSPECIFIED\x10\x00\x12\n" +
"\n" + "\n" +
......
...@@ -15,6 +15,7 @@ message IdentityProvider { ...@@ -15,6 +15,7 @@ message IdentityProvider {
Type type = 3; Type type = 3;
string identifier_filter = 4; string identifier_filter = 4;
IdentityProviderConfig config = 5; IdentityProviderConfig config = 5;
string uid = 6;
} }
message IdentityProviderConfig { message IdentityProviderConfig {
......
...@@ -16,7 +16,6 @@ import ( ...@@ -16,7 +16,6 @@ import (
"time" "time"
"github.com/disintegration/imaging" "github.com/disintegration/imaging"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
...@@ -100,10 +99,9 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat ...@@ -100,10 +99,9 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
return nil, status.Errorf(codes.InvalidArgument, "invalid MIME type format") return nil, status.Errorf(codes.InvalidArgument, "invalid MIME type format")
} }
// Use provided attachment_id or generate a new one attachmentUID, err := ValidateAndGenerateUID(request.AttachmentId)
attachmentUID := request.AttachmentId if err != nil {
if attachmentUID == "" { return nil, err
attachmentUID = shortuuid.New()
} }
create := &store.Attachment{ create := &store.Attachment{
......
...@@ -90,8 +90,12 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) ...@@ -90,8 +90,12 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest)
existingUser = user existingUser = user
} else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil { } else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil {
// Authentication Method 2: SSO (OAuth2) authentication // Authentication Method 2: SSO (OAuth2) authentication
idpUID, err := ExtractIdentityProviderUIDFromName(ssoCredentials.IdpName)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &ssoCredentials.IdpId, UID: &idpUID,
}) })
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err) return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)
......
...@@ -25,7 +25,15 @@ func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb ...@@ -25,7 +25,15 @@ func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb
return nil, status.Errorf(codes.PermissionDenied, "permission denied") return nil, status.Errorf(codes.PermissionDenied, "permission denied")
} }
identityProvider, err := s.Store.CreateIdentityProvider(ctx, convertIdentityProviderToStore(request.IdentityProvider)) idpUID, err := ValidateAndGenerateUID(request.IdentityProviderId)
if err != nil {
return nil, err
}
storeIdp := convertIdentityProviderToStore(request.IdentityProvider)
storeIdp.Uid = idpUID
identityProvider, err := s.Store.CreateIdentityProvider(ctx, storeIdp)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err) return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
} }
...@@ -57,12 +65,12 @@ func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListId ...@@ -57,12 +65,12 @@ func (s *APIV1Service) ListIdentityProviders(ctx context.Context, _ *v1pb.ListId
} }
func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) { func (s *APIV1Service) GetIdentityProvider(ctx context.Context, request *v1pb.GetIdentityProviderRequest) (*v1pb.IdentityProvider, error) {
id, err := ExtractIdentityProviderIDFromName(request.Name) uid, err := ExtractIdentityProviderUIDFromName(request.Name)
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err) return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
} }
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &id, UID: &uid,
}) })
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err) return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
...@@ -98,12 +106,22 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb ...@@ -98,12 +106,22 @@ func (s *APIV1Service) UpdateIdentityProvider(ctx context.Context, request *v1pb
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required") return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
} }
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name) uid, err := ExtractIdentityProviderUIDFromName(request.IdentityProvider.Name)
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err) return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
} }
// Look up the IdP by UID to get the internal ID for update.
existing, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &uid})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
}
if existing == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
update := &store.UpdateIdentityProviderV1{ update := &store.UpdateIdentityProviderV1{
ID: id, ID: existing.Id,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]), Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
} }
for _, field := range request.UpdateMask.Paths { for _, field := range request.UpdateMask.Paths {
...@@ -138,13 +156,13 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb ...@@ -138,13 +156,13 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb
return nil, status.Errorf(codes.PermissionDenied, "permission denied") return nil, status.Errorf(codes.PermissionDenied, "permission denied")
} }
id, err := ExtractIdentityProviderIDFromName(request.Name) uid, err := ExtractIdentityProviderUIDFromName(request.Name)
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err) return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
} }
// Check if the identity provider exists before trying to delete it // Look up the IdP by UID to get the internal ID for deletion.
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &id}) identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &uid})
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to check identity provider existence: %v", err) return nil, status.Errorf(codes.Internal, "failed to check identity provider existence: %v", err)
} }
...@@ -152,7 +170,7 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb ...@@ -152,7 +170,7 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb
return nil, status.Errorf(codes.NotFound, "identity provider not found") return nil, status.Errorf(codes.NotFound, "identity provider not found")
} }
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil { if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: identityProvider.Id}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err) return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
} }
return &emptypb.Empty{}, nil return &emptypb.Empty{}, nil
...@@ -160,7 +178,7 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb ...@@ -160,7 +178,7 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider { func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *v1pb.IdentityProvider {
temp := &v1pb.IdentityProvider{ temp := &v1pb.IdentityProvider{
Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id), Name: fmt.Sprintf("%s%s", IdentityProviderNamePrefix, identityProvider.Uid),
Title: identityProvider.Name, Title: identityProvider.Name,
IdentifierFilter: identityProvider.IdentifierFilter, IdentifierFilter: identityProvider.IdentifierFilter,
Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]), Type: v1pb.IdentityProvider_Type(v1pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
...@@ -190,10 +208,7 @@ func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider ...@@ -190,10 +208,7 @@ func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider
} }
func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider { func convertIdentityProviderToStore(identityProvider *v1pb.IdentityProvider) *storepb.IdentityProvider {
id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
temp := &storepb.IdentityProvider{ temp := &storepb.IdentityProvider{
Id: id,
Name: identityProvider.Title, Name: identityProvider.Title,
IdentifierFilter: identityProvider.IdentifierFilter, IdentifierFilter: identityProvider.IdentifierFilter,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]), Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),
......
...@@ -7,13 +7,11 @@ import ( ...@@ -7,13 +7,11 @@ import (
"strings" "strings"
"time" "time"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/internal/base"
"github.com/usememos/memos/plugin/webhook" "github.com/usememos/memos/plugin/webhook"
v1pb "github.com/usememos/memos/proto/gen/api/v1" v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
...@@ -30,13 +28,9 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR ...@@ -30,13 +28,9 @@ func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoR
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
} }
// Use custom memo_id if provided, otherwise generate a new UUID memoUID, err := ValidateAndGenerateUID(request.MemoId)
memoUID := strings.TrimSpace(request.MemoId) if err != nil {
if memoUID == "" { return nil, err
memoUID = shortuuid.New()
} else if !base.UIDMatcher.MatchString(memoUID) {
// Validate custom memo ID format
return nil, status.Errorf(codes.InvalidArgument, "invalid memo_id format: must be 1-32 characters, alphanumeric and hyphens only, cannot start or end with hyphen")
} }
create := &store.Memo{ create := &store.Memo{
......
...@@ -4,8 +4,12 @@ import ( ...@@ -4,8 +4,12 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/usememos/memos/internal/base"
"github.com/usememos/memos/internal/util" "github.com/usememos/memos/internal/util"
) )
...@@ -133,16 +137,12 @@ func ExtractInboxIDFromName(name string) (int32, error) { ...@@ -133,16 +137,12 @@ func ExtractInboxIDFromName(name string) (int32, error) {
return id, nil return id, nil
} }
func ExtractIdentityProviderIDFromName(name string) (int32, error) { func ExtractIdentityProviderUIDFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix) tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix)
if err != nil { if err != nil {
return 0, err return "", err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid identity provider ID %q", tokens[0])
} }
return id, nil return tokens[0], nil
} }
func ExtractActivityIDFromName(name string) (int32, error) { func ExtractActivityIDFromName(name string) (int32, error) {
...@@ -156,3 +156,17 @@ func ExtractActivityIDFromName(name string) (int32, error) { ...@@ -156,3 +156,17 @@ func ExtractActivityIDFromName(name string) (int32, error) {
} }
return id, nil return id, nil
} }
// ValidateAndGenerateUID validates a user-provided UID or generates a new one.
// If provided is empty, a new shortuuid is generated.
// If provided is non-empty, it is validated against base.UIDMatcher.
func ValidateAndGenerateUID(provided string) (string, error) {
uid := strings.TrimSpace(provided)
if uid == "" {
return shortuuid.New(), nil
}
if !base.UIDMatcher.MatchString(uid) {
return "", status.Errorf(codes.InvalidArgument, "invalid ID format: must be 1-32 characters, alphanumeric and hyphens only, cannot start or end with hyphen")
}
return uid, nil
}
...@@ -11,9 +11,9 @@ import ( ...@@ -11,9 +11,9 @@ import (
) )
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
placeholders := []string{"?", "?", "?", "?"} placeholders := []string{"?", "?", "?", "?", "?"}
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"} fields := []string{"`uid`", "`name`", "`type`", "`identifier_filter`", "`config`"}
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config} args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")" stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")"
result, err := d.db.ExecContext(ctx, stmt, args...) result, err := d.db.ExecContext(ctx, stmt, args...)
...@@ -35,8 +35,11 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity ...@@ -35,8 +35,11 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
if v := find.ID; v != nil { if v := find.ID; v != nil {
where, args = append(where, "`id` = ?"), append(args, *v) where, args = append(where, "`id` = ?"), append(args, *v)
} }
if v := find.UID; v != nil {
where, args = append(where, "`uid` = ?"), append(args, *v)
}
rows, err := d.db.QueryContext(ctx, "SELECT `id`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC", rows, err := d.db.QueryContext(ctx, "SELECT `id`, `uid`, `name`, `type`, `identifier_filter`, `config` FROM `idp` WHERE "+strings.Join(where, " AND ")+" ORDER BY `id` ASC",
args..., args...,
) )
if err != nil { if err != nil {
...@@ -50,6 +53,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity ...@@ -50,6 +53,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
var typeString string var typeString string
if err := rows.Scan( if err := rows.Scan(
&identityProvider.ID, &identityProvider.ID,
&identityProvider.UID,
&identityProvider.Name, &identityProvider.Name,
&typeString, &typeString,
&identityProvider.IdentifierFilter, &identityProvider.IdentifierFilter,
......
...@@ -9,8 +9,8 @@ import ( ...@@ -9,8 +9,8 @@ import (
) )
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
fields := []string{"name", "type", "identifier_filter", "config"} fields := []string{"uid", "name", "type", "identifier_filter", "config"}
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config} args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id" stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil { if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
return nil, err return nil, err
...@@ -25,10 +25,14 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity ...@@ -25,10 +25,14 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
if v := find.ID; v != nil { if v := find.ID; v != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v) where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.UID; v != nil {
where, args = append(where, "uid = "+placeholder(len(args)+1)), append(args, *v)
}
rows, err := d.db.QueryContext(ctx, ` rows, err := d.db.QueryContext(ctx, `
SELECT SELECT
id, id,
uid,
name, name,
type, type,
identifier_filter, identifier_filter,
...@@ -48,6 +52,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity ...@@ -48,6 +52,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
var typeString string var typeString string
if err := rows.Scan( if err := rows.Scan(
&identityProvider.ID, &identityProvider.ID,
&identityProvider.UID,
&identityProvider.Name, &identityProvider.Name,
&typeString, &typeString,
&identityProvider.IdentifierFilter, &identityProvider.IdentifierFilter,
...@@ -83,7 +88,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde ...@@ -83,7 +88,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
UPDATE idp UPDATE idp
SET ` + strings.Join(set, ", ") + ` SET ` + strings.Join(set, ", ") + `
WHERE id = ` + placeholder(len(args)+1) + ` WHERE id = ` + placeholder(len(args)+1) + `
RETURNING id, name, type, identifier_filter, config RETURNING id, uid, name, type, identifier_filter, config
` `
args = append(args, update.ID) args = append(args, update.ID)
...@@ -91,6 +96,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde ...@@ -91,6 +96,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
var typeString string var typeString string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&identityProvider.ID, &identityProvider.ID,
&identityProvider.UID,
&identityProvider.Name, &identityProvider.Name,
&typeString, &typeString,
&identityProvider.IdentifierFilter, &identityProvider.IdentifierFilter,
......
...@@ -10,9 +10,9 @@ import ( ...@@ -10,9 +10,9 @@ import (
) )
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
placeholders := []string{"?", "?", "?", "?"} placeholders := []string{"?", "?", "?", "?", "?"}
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"} fields := []string{"`uid`", "`name`", "`type`", "`identifier_filter`", "`config`"}
args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config} args := []any{create.UID, create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") RETURNING `id`" stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") RETURNING `id`"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil { if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
...@@ -28,10 +28,14 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity ...@@ -28,10 +28,14 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
if v := find.ID; v != nil { if v := find.ID; v != nil {
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
} }
if v := find.UID; v != nil {
where, args = append(where, fmt.Sprintf("uid = $%d", len(args)+1)), append(args, *v)
}
rows, err := d.db.QueryContext(ctx, ` rows, err := d.db.QueryContext(ctx, `
SELECT SELECT
id, id,
uid,
name, name,
type, type,
identifier_filter, identifier_filter,
...@@ -51,6 +55,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity ...@@ -51,6 +55,7 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
var typeString string var typeString string
if err := rows.Scan( if err := rows.Scan(
&identityProvider.ID, &identityProvider.ID,
&identityProvider.UID,
&identityProvider.Name, &identityProvider.Name,
&typeString, &typeString,
&identityProvider.IdentifierFilter, &identityProvider.IdentifierFilter,
...@@ -86,12 +91,13 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde ...@@ -86,12 +91,13 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
UPDATE idp UPDATE idp
SET ` + strings.Join(set, ", ") + ` SET ` + strings.Join(set, ", ") + `
WHERE id = ? WHERE id = ?
RETURNING id, name, type, identifier_filter, config RETURNING id, uid, name, type, identifier_filter, config
` `
var identityProvider store.IdentityProvider var identityProvider store.IdentityProvider
var typeString string var typeString string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&identityProvider.ID, &identityProvider.ID,
&identityProvider.UID,
&identityProvider.Name, &identityProvider.Name,
&typeString, &typeString,
&identityProvider.IdentifierFilter, &identityProvider.IdentifierFilter,
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
type IdentityProvider struct { type IdentityProvider struct {
ID int32 ID int32
UID string
Name string Name string
Type storepb.IdentityProvider_Type Type storepb.IdentityProvider_Type
IdentifierFilter string IdentifierFilter string
...@@ -18,7 +19,8 @@ type IdentityProvider struct { ...@@ -18,7 +19,8 @@ type IdentityProvider struct {
} }
type FindIdentityProvider struct { type FindIdentityProvider struct {
ID *int32 ID *int32
UID *string
} }
type UpdateIdentityProvider struct { type UpdateIdentityProvider struct {
...@@ -130,6 +132,7 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti ...@@ -130,6 +132,7 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti
func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityProvider, error) { func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityProvider, error) {
identityProvider := &storepb.IdentityProvider{ identityProvider := &storepb.IdentityProvider{
Id: raw.ID, Id: raw.ID,
Uid: raw.UID,
Name: raw.Name, Name: raw.Name,
Type: raw.Type, Type: raw.Type,
IdentifierFilter: raw.IdentifierFilter, IdentifierFilter: raw.IdentifierFilter,
...@@ -145,6 +148,7 @@ func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityPro ...@@ -145,6 +148,7 @@ func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityPro
func convertIdentityProviderToRaw(identityProvider *storepb.IdentityProvider) (*IdentityProvider, error) { func convertIdentityProviderToRaw(identityProvider *storepb.IdentityProvider) (*IdentityProvider, error) {
raw := &IdentityProvider{ raw := &IdentityProvider{
ID: identityProvider.Id, ID: identityProvider.Id,
UID: identityProvider.Uid,
Name: identityProvider.Name, Name: identityProvider.Name,
Type: identityProvider.Type, Type: identityProvider.Type,
IdentifierFilter: identityProvider.IdentifierFilter, IdentifierFilter: identityProvider.IdentifierFilter,
......
-- Add uid column to idp table
ALTER TABLE `idp` ADD COLUMN `uid` VARCHAR(256) NOT NULL DEFAULT '';
-- Populate uid for existing rows using hex of id as a fallback
UPDATE `idp` SET `uid` = LOWER(LPAD(HEX(`id`), 8, '0')) WHERE `uid` = '';
-- Create unique index on uid
ALTER TABLE `idp` ADD UNIQUE INDEX `idx_idp_uid` (`uid`);
...@@ -80,6 +80,7 @@ CREATE TABLE `activity` ( ...@@ -80,6 +80,7 @@ CREATE TABLE `activity` (
-- idp -- idp
CREATE TABLE `idp` ( CREATE TABLE `idp` (
`id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY, `id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
`uid` VARCHAR(256) NOT NULL UNIQUE,
`name` TEXT NOT NULL, `name` TEXT NOT NULL,
`type` TEXT NOT NULL, `type` TEXT NOT NULL,
`identifier_filter` VARCHAR(256) NOT NULL DEFAULT '', `identifier_filter` VARCHAR(256) NOT NULL DEFAULT '',
......
-- Add uid column to idp table
ALTER TABLE idp ADD COLUMN uid TEXT NOT NULL DEFAULT '';
-- Populate uid for existing rows using hex of id as a fallback
UPDATE idp SET uid = LPAD(TO_HEX(id), 8, '0') WHERE uid = '';
-- Create unique index on uid
CREATE UNIQUE INDEX IF NOT EXISTS idx_idp_uid ON idp (uid);
...@@ -80,6 +80,7 @@ CREATE TABLE activity ( ...@@ -80,6 +80,7 @@ CREATE TABLE activity (
-- idp -- idp
CREATE TABLE idp ( CREATE TABLE idp (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
uid TEXT NOT NULL UNIQUE,
name TEXT NOT NULL, name TEXT NOT NULL,
type TEXT NOT NULL, type TEXT NOT NULL,
identifier_filter TEXT NOT NULL DEFAULT '', identifier_filter TEXT NOT NULL DEFAULT '',
......
-- Add uid column to idp table
ALTER TABLE idp ADD COLUMN uid TEXT NOT NULL DEFAULT '';
-- Populate uid for existing rows using hex of id as a fallback
UPDATE idp SET uid = printf('%08x', id) WHERE uid = '';
-- Create unique index on uid
CREATE UNIQUE INDEX IF NOT EXISTS idx_idp_uid ON idp (uid);
...@@ -81,6 +81,7 @@ CREATE TABLE activity ( ...@@ -81,6 +81,7 @@ CREATE TABLE activity (
-- idp -- idp
CREATE TABLE idp ( CREATE TABLE idp (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
uid TEXT NOT NULL UNIQUE,
name TEXT NOT NULL, name TEXT NOT NULL,
type TEXT NOT NULL, type TEXT NOT NULL,
identifier_filter TEXT NOT NULL DEFAULT '', identifier_filter TEXT NOT NULL DEFAULT '',
......
...@@ -15,6 +15,7 @@ func TestIdentityProviderStore(t *testing.T) { ...@@ -15,6 +15,7 @@ func TestIdentityProviderStore(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
createdIDP, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ createdIDP, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Uid: "test-github-oauth",
Name: "GitHub OAuth", Name: "GitHub OAuth",
Type: storepb.IdentityProvider_OAUTH2, Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: "", IdentifierFilter: "",
...@@ -37,6 +38,7 @@ func TestIdentityProviderStore(t *testing.T) { ...@@ -37,6 +38,7 @@ func TestIdentityProviderStore(t *testing.T) {
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "test-github-oauth", createdIDP.Uid)
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &createdIDP.Id, ID: &createdIDP.Id,
}) })
...@@ -66,7 +68,7 @@ func TestIdentityProviderGetByID(t *testing.T) { ...@@ -66,7 +68,7 @@ func TestIdentityProviderGetByID(t *testing.T) {
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
// Create IDP // Create IDP
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
require.NoError(t, err) require.NoError(t, err)
// Get by ID // Get by ID
...@@ -76,6 +78,13 @@ func TestIdentityProviderGetByID(t *testing.T) { ...@@ -76,6 +78,13 @@ func TestIdentityProviderGetByID(t *testing.T) {
require.Equal(t, idp.Id, found.Id) require.Equal(t, idp.Id, found.Id)
require.Equal(t, idp.Name, found.Name) require.Equal(t, idp.Name, found.Name)
// Get by UID
foundByUID, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{UID: &idp.Uid})
require.NoError(t, err)
require.NotNil(t, foundByUID)
require.Equal(t, idp.Id, foundByUID.Id)
require.Equal(t, idp.Uid, foundByUID.Uid)
// Get by non-existent ID // Get by non-existent ID
nonExistentID := int32(99999) nonExistentID := int32(99999)
notFound, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &nonExistentID}) notFound, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &nonExistentID})
...@@ -91,11 +100,11 @@ func TestIdentityProviderListMultiple(t *testing.T) { ...@@ -91,11 +100,11 @@ func TestIdentityProviderListMultiple(t *testing.T) {
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
// Create multiple IDPs // Create multiple IDPs
_, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth")) _, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth", "github-oauth"))
require.NoError(t, err) require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth")) _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth", "google-oauth"))
require.NoError(t, err) require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth")) _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitLab OAuth", "gitlab-oauth"))
require.NoError(t, err) require.NoError(t, err)
// List all // List all
...@@ -112,9 +121,9 @@ func TestIdentityProviderListByID(t *testing.T) { ...@@ -112,9 +121,9 @@ func TestIdentityProviderListByID(t *testing.T) {
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
// Create multiple IDPs // Create multiple IDPs
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth")) idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("GitHub OAuth", "github-oauth"))
require.NoError(t, err) require.NoError(t, err)
_, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth")) _, err = ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Google OAuth", "google-oauth"))
require.NoError(t, err) require.NoError(t, err)
// List by specific ID // List by specific ID
...@@ -131,7 +140,7 @@ func TestIdentityProviderUpdateName(t *testing.T) { ...@@ -131,7 +140,7 @@ func TestIdentityProviderUpdateName(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name")) idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original Name", "original-name"))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "Original Name", idp.Name) require.Equal(t, "Original Name", idp.Name)
...@@ -158,7 +167,7 @@ func TestIdentityProviderUpdateIdentifierFilter(t *testing.T) { ...@@ -158,7 +167,7 @@ func TestIdentityProviderUpdateIdentifierFilter(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "", idp.IdentifierFilter) require.Equal(t, "", idp.IdentifierFilter)
...@@ -185,7 +194,7 @@ func TestIdentityProviderUpdateConfig(t *testing.T) { ...@@ -185,7 +194,7 @@ func TestIdentityProviderUpdateConfig(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
require.NoError(t, err) require.NoError(t, err)
// Update config // Update config
...@@ -229,7 +238,7 @@ func TestIdentityProviderUpdateMultipleFields(t *testing.T) { ...@@ -229,7 +238,7 @@ func TestIdentityProviderUpdateMultipleFields(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original")) idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Original", "original"))
require.NoError(t, err) require.NoError(t, err)
// Update multiple fields at once // Update multiple fields at once
...@@ -253,7 +262,7 @@ func TestIdentityProviderDelete(t *testing.T) { ...@@ -253,7 +262,7 @@ func TestIdentityProviderDelete(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP")) idp, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("Test IDP", "test-idp"))
require.NoError(t, err) require.NoError(t, err)
// Delete // Delete
...@@ -274,9 +283,9 @@ func TestIdentityProviderDeleteNotAffectOthers(t *testing.T) { ...@@ -274,9 +283,9 @@ func TestIdentityProviderDeleteNotAffectOthers(t *testing.T) {
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
// Create multiple IDPs // Create multiple IDPs
idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1")) idp1, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 1", "idp-1"))
require.NoError(t, err) require.NoError(t, err)
idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2")) idp2, err := ts.CreateIdentityProvider(ctx, createTestOAuth2IDP("IDP 2", "idp-2"))
require.NoError(t, err) require.NoError(t, err)
// Delete first one // Delete first one
...@@ -304,6 +313,7 @@ func TestIdentityProviderOAuth2ConfigScopes(t *testing.T) { ...@@ -304,6 +313,7 @@ func TestIdentityProviderOAuth2ConfigScopes(t *testing.T) {
// Create IDP with multiple scopes // Create IDP with multiple scopes
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Uid: "multi-scope-oauth",
Name: "Multi-Scope OAuth", Name: "Multi-Scope OAuth",
Type: storepb.IdentityProvider_OAUTH2, Type: storepb.IdentityProvider_OAUTH2,
Config: &storepb.IdentityProviderConfig{ Config: &storepb.IdentityProviderConfig{
...@@ -343,6 +353,7 @@ func TestIdentityProviderFieldMapping(t *testing.T) { ...@@ -343,6 +353,7 @@ func TestIdentityProviderFieldMapping(t *testing.T) {
// Create IDP with custom field mapping // Create IDP with custom field mapping
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Uid: "custom-field-mapping",
Name: "Custom Field Mapping", Name: "Custom Field Mapping",
Type: storepb.IdentityProvider_OAUTH2, Type: storepb.IdentityProvider_OAUTH2,
Config: &storepb.IdentityProviderConfig{ Config: &storepb.IdentityProviderConfig{
...@@ -382,17 +393,19 @@ func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) { ...@@ -382,17 +393,19 @@ func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
uid string
filter string filter string
}{ }{
{"Domain filter", "@company\\.com$"}, {"Domain filter", "domain-filter", "@company\\.com$"},
{"Prefix filter", "^admin_"}, {"Prefix filter", "prefix-filter", "^admin_"},
{"Complex regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"}, {"Complex regex", "complex-regex", "^[a-z]+@(dept1|dept2)\\.example\\.com$"},
{"Empty filter", ""}, {"Empty filter", "empty-filter", ""},
} }
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{ idp, err := ts.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Uid: tc.uid,
Name: tc.name, Name: tc.name,
Type: storepb.IdentityProvider_OAUTH2, Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: tc.filter, IdentifierFilter: tc.filter,
...@@ -428,8 +441,9 @@ func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) { ...@@ -428,8 +441,9 @@ func TestIdentityProviderIdentifierFilterPatterns(t *testing.T) {
} }
// Helper function to create a test OAuth2 IDP. // Helper function to create a test OAuth2 IDP.
func createTestOAuth2IDP(name string) *storepb.IdentityProvider { func createTestOAuth2IDP(name, uid string) *storepb.IdentityProvider {
return &storepb.IdentityProvider{ return &storepb.IdentityProvider{
Uid: uid,
Name: name, Name: name,
Type: storepb.IdentityProvider_OAUTH2, Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: "", IdentifierFilter: "",
......
...@@ -133,6 +133,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on ...@@ -133,6 +133,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
const identityProviderTypes = [...new Set(templateList.map((t) => t.type))]; const identityProviderTypes = [...new Set(templateList.map((t) => t.type))];
const [basicInfo, setBasicInfo] = useState({ const [basicInfo, setBasicInfo] = useState({
title: "", title: "",
identifier: "",
identifierFilter: "", identifierFilter: "",
}); });
const [type, setType] = useState<IdentityProvider_Type>(IdentityProvider_Type.OAUTH2); const [type, setType] = useState<IdentityProvider_Type>(IdentityProvider_Type.OAUTH2);
...@@ -161,6 +162,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on ...@@ -161,6 +162,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
// Reset to default state when dialog is closed // Reset to default state when dialog is closed
setBasicInfo({ setBasicInfo({
title: "", title: "",
identifier: "",
identifierFilter: "", identifierFilter: "",
}); });
setType(IdentityProvider_Type.OAUTH2); setType(IdentityProvider_Type.OAUTH2);
...@@ -189,6 +191,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on ...@@ -189,6 +191,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
if (open && identityProvider) { if (open && identityProvider) {
setBasicInfo({ setBasicInfo({
title: identityProvider.title, title: identityProvider.title,
identifier: "",
identifierFilter: identityProvider.identifierFilter, identifierFilter: identityProvider.identifierFilter,
}); });
setType(identityProvider.type); setType(identityProvider.type);
...@@ -210,6 +213,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on ...@@ -210,6 +213,7 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
if (template) { if (template) {
setBasicInfo({ setBasicInfo({
title: template.title, title: template.title,
identifier: template.title.toLowerCase().replace(/[^a-z0-9]+/g, "-"),
identifierFilter: template.identifierFilter, identifierFilter: template.identifierFilter,
}); });
setType(template.type); setType(template.type);
...@@ -229,6 +233,9 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on ...@@ -229,6 +233,9 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
if (basicInfo.title === "") { if (basicInfo.title === "") {
return false; return false;
} }
if (isCreating && basicInfo.identifier === "") {
return false;
}
if (type === IdentityProvider_Type.OAUTH2) { if (type === IdentityProvider_Type.OAUTH2) {
if ( if (
oauth2Config.clientId === "" || oauth2Config.clientId === "" ||
...@@ -254,8 +261,10 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on ...@@ -254,8 +261,10 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
try { try {
if (isCreating) { if (isCreating) {
await identityProviderServiceClient.createIdentityProvider({ await identityProviderServiceClient.createIdentityProvider({
identityProviderId: basicInfo.identifier,
identityProvider: create(IdentityProviderSchema, { identityProvider: create(IdentityProviderSchema, {
...basicInfo, title: basicInfo.title,
identifierFilter: basicInfo.identifierFilter,
type: type, type: type,
config: create(IdentityProviderConfigSchema, { config: create(IdentityProviderConfigSchema, {
config: { config: {
...@@ -343,6 +352,32 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on ...@@ -343,6 +352,32 @@ function CreateIdentityProviderDialog({ open, onOpenChange, identityProvider, on
<Separator className="my-2" /> <Separator className="my-2" />
</> </>
)} )}
{isCreating && (
<>
<p className="mb-1 text-sm font-medium">
ID
<span className="text-destructive">*</span>
</p>
<Input
className="mb-2 w-full font-mono"
placeholder="e.g. github, okta-corp"
maxLength={32}
value={basicInfo.identifier}
onChange={(e) =>
setBasicInfo({
...basicInfo,
identifier: e.target.value
.toLowerCase()
.replace(/[^a-z0-9-]/g, "-")
.replace(/--+/g, "-"),
})
}
/>
<p className="mb-2 text-xs text-muted-foreground">
A unique identifier for this provider. Lowercase letters, numbers, and hyphens only.
</p>
</>
)}
<p className="mb-1 text-sm font-medium"> <p className="mb-1 text-sm font-medium">
{t("common.name")} {t("common.name")}
<span className="text-destructive">*</span> <span className="text-destructive">*</span>
......
...@@ -16,8 +16,8 @@ export const extractMemoIdFromName = (name: string) => { ...@@ -16,8 +16,8 @@ export const extractMemoIdFromName = (name: string) => {
return name.split(memoNamePrefix).pop() || ""; return name.split(memoNamePrefix).pop() || "";
}; };
export const extractIdentityProviderIdFromName = (name: string) => { export const extractIdentityProviderUidFromName = (name: string) => {
return parseInt(name.split(identityProviderNamePrefix).pop() || "", 10); return name.split(identityProviderNamePrefix).pop() || "";
}; };
// Helper function to convert InstanceSetting_Key enum value to string name // Helper function to convert InstanceSetting_Key enum value to string name
......
...@@ -72,7 +72,7 @@ const AuthCallback = () => { ...@@ -72,7 +72,7 @@ const AuthCallback = () => {
return; return;
} }
const { identityProviderId, returnUrl, codeVerifier } = validatedState; const { identityProviderName, returnUrl, codeVerifier } = validatedState;
const redirectUri = absolutifyLink("/auth/callback"); const redirectUri = absolutifyLink("/auth/callback");
(async () => { (async () => {
...@@ -81,7 +81,7 @@ const AuthCallback = () => { ...@@ -81,7 +81,7 @@ const AuthCallback = () => {
credentials: { credentials: {
case: "ssoCredentials", case: "ssoCredentials",
value: { value: {
idpId: identityProviderId, idpName: identityProviderName,
code, code,
redirectUri, redirectUri,
codeVerifier: codeVerifier || "", // Pass PKCE code_verifier for token exchange codeVerifier: codeVerifier || "", // Pass PKCE code_verifier for token exchange
......
...@@ -7,7 +7,6 @@ import { Button } from "@/components/ui/button"; ...@@ -7,7 +7,6 @@ import { Button } from "@/components/ui/button";
import { Separator } from "@/components/ui/separator"; import { Separator } from "@/components/ui/separator";
import { identityProviderServiceClient } from "@/connect"; import { identityProviderServiceClient } from "@/connect";
import { useInstance } from "@/contexts/InstanceContext"; import { useInstance } from "@/contexts/InstanceContext";
import { extractIdentityProviderIdFromName } from "@/helpers/resource-names";
import { absolutifyLink } from "@/helpers/utils"; import { absolutifyLink } from "@/helpers/utils";
import useCurrentUser from "@/hooks/useCurrentUser"; import useCurrentUser from "@/hooks/useCurrentUser";
import { handleError } from "@/lib/error"; import { handleError } from "@/lib/error";
...@@ -50,8 +49,7 @@ const SignIn = () => { ...@@ -50,8 +49,7 @@ const SignIn = () => {
try { try {
// Generate and store secure state parameter with CSRF protection // Generate and store secure state parameter with CSRF protection
// Also generate PKCE parameters (code_challenge) for enhanced security if available // Also generate PKCE parameters (code_challenge) for enhanced security if available
const identityProviderId = extractIdentityProviderIdFromName(identityProvider.name); const { state, codeChallenge } = await storeOAuthState(identityProvider.name);
const { state, codeChallenge } = await storeOAuthState(identityProviderId);
// Build OAuth authorization URL with secure state // Build OAuth authorization URL with secure state
// Include PKCE if available (requires HTTPS/localhost for crypto.subtle) // Include PKCE if available (requires HTTPS/localhost for crypto.subtle)
......
...@@ -16,7 +16,7 @@ import type { Message } from "@bufbuild/protobuf"; ...@@ -16,7 +16,7 @@ import type { Message } from "@bufbuild/protobuf";
* Describes the file api/v1/auth_service.proto. * Describes the file api/v1/auth_service.proto.
*/ */
export const file_api_v1_auth_service: GenFile = /*@__PURE__*/ export const file_api_v1_auth_service: GenFile = /*@__PURE__*/
fileDesc("ChlhcGkvdjEvYXV0aF9zZXJ2aWNlLnByb3RvEgxtZW1vcy5hcGkudjEiFwoVR2V0Q3VycmVudFVzZXJSZXF1ZXN0IjoKFkdldEN1cnJlbnRVc2VyUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyIuwCCg1TaWduSW5SZXF1ZXN0Ek8KFHBhc3N3b3JkX2NyZWRlbnRpYWxzGAEgASgLMi8ubWVtb3MuYXBpLnYxLlNpZ25JblJlcXVlc3QuUGFzc3dvcmRDcmVkZW50aWFsc0gAEkUKD3Nzb19jcmVkZW50aWFscxgCIAEoCzIqLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0LlNTT0NyZWRlbnRpYWxzSAAaQwoTUGFzc3dvcmRDcmVkZW50aWFscxIVCgh1c2VybmFtZRgBIAEoCUID4EECEhUKCHBhc3N3b3JkGAIgASgJQgPgQQIabwoOU1NPQ3JlZGVudGlhbHMSEwoGaWRwX2lkGAEgASgFQgPgQQISEQoEY29kZRgCIAEoCUID4EECEhkKDHJlZGlyZWN0X3VyaRgDIAEoCUID4EECEhoKDWNvZGVfdmVyaWZpZXIYBCABKAlCA+BBAUINCgtjcmVkZW50aWFscyKFAQoOU2lnbkluUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyEhQKDGFjY2Vzc190b2tlbhgCIAEoCRI7ChdhY2Nlc3NfdG9rZW5fZXhwaXJlc19hdBgDIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5UaW1lc3RhbXAiEAoOU2lnbk91dFJlcXVlc3QiFQoTUmVmcmVzaFRva2VuUmVxdWVzdCJcChRSZWZyZXNoVG9rZW5SZXNwb25zZRIUCgxhY2Nlc3NfdG9rZW4YASABKAkSLgoKZXhwaXJlc19hdBgCIAEoCzIaLmdvb2dsZS5wcm90b2J1Zi5UaW1lc3RhbXAyvwMKC0F1dGhTZXJ2aWNlEnQKDkdldEN1cnJlbnRVc2VyEiMubWVtb3MuYXBpLnYxLkdldEN1cnJlbnRVc2VyUmVxdWVzdBokLm1lbW9zLmFwaS52MS5HZXRDdXJyZW50VXNlclJlc3BvbnNlIheC0+STAhESDy9hcGkvdjEvYXV0aC9tZRJjCgZTaWduSW4SGy5tZW1vcy5hcGkudjEuU2lnbkluUmVxdWVzdBocLm1lbW9zLmFwaS52MS5TaWduSW5SZXNwb25zZSIegtPkkwIYOgEqIhMvYXBpL3YxL2F1dGgvc2lnbmluEl0KB1NpZ25PdXQSHC5tZW1vcy5hcGkudjEuU2lnbk91dFJlcXVlc3QaFi5nb29nbGUucHJvdG9idWYuRW1wdHkiHILT5JMCFiIUL2FwaS92MS9hdXRoL3NpZ25vdXQSdgoMUmVmcmVzaFRva2VuEiEubWVtb3MuYXBpLnYxLlJlZnJlc2hUb2tlblJlcXVlc3QaIi5tZW1vcy5hcGkudjEuUmVmcmVzaFRva2VuUmVzcG9uc2UiH4LT5JMCGToBKiIUL2FwaS92MS9hdXRoL3JlZnJlc2hCqAEKEGNvbS5tZW1vcy5hcGkudjFCEEF1dGhTZXJ2aWNlUHJvdG9QAVowZ2l0aHViLmNvbS91c2VtZW1vcy9tZW1vcy9wcm90by9nZW4vYXBpL3YxO2FwaXYxogIDTUFYqgIMTWVtb3MuQXBpLlYxygIMTWVtb3NcQXBpXFYx4gIYTWVtb3NcQXBpXFYxXEdQQk1ldGFkYXRh6gIOTWVtb3M6OkFwaTo6VjFiBnByb3RvMw", [file_api_v1_user_service, file_google_api_annotations, file_google_api_field_behavior, file_google_protobuf_empty, file_google_protobuf_timestamp]); fileDesc("ChlhcGkvdjEvYXV0aF9zZXJ2aWNlLnByb3RvEgxtZW1vcy5hcGkudjEiFwoVR2V0Q3VycmVudFVzZXJSZXF1ZXN0IjoKFkdldEN1cnJlbnRVc2VyUmVzcG9uc2USIAoEdXNlchgBIAEoCzISLm1lbW9zLmFwaS52MS5Vc2VyIu4CCg1TaWduSW5SZXF1ZXN0Ek8KFHBhc3N3b3JkX2NyZWRlbnRpYWxzGAEgASgLMi8ubWVtb3MuYXBpLnYxLlNpZ25JblJlcXVlc3QuUGFzc3dvcmRDcmVkZW50aWFsc0gAEkUKD3Nzb19jcmVkZW50aWFscxgCIAEoCzIqLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0LlNTT0NyZWRlbnRpYWxzSAAaQwoTUGFzc3dvcmRDcmVkZW50aWFscxIVCgh1c2VybmFtZRgBIAEoCUID4EECEhUKCHBhc3N3b3JkGAIgASgJQgPgQQIacQoOU1NPQ3JlZGVudGlhbHMSFQoIaWRwX25hbWUYASABKAlCA+BBAhIRCgRjb2RlGAIgASgJQgPgQQISGQoMcmVkaXJlY3RfdXJpGAMgASgJQgPgQQISGgoNY29kZV92ZXJpZmllchgEIAEoCUID4EEBQg0KC2NyZWRlbnRpYWxzIoUBCg5TaWduSW5SZXNwb25zZRIgCgR1c2VyGAEgASgLMhIubWVtb3MuYXBpLnYxLlVzZXISFAoMYWNjZXNzX3Rva2VuGAIgASgJEjsKF2FjY2Vzc190b2tlbl9leHBpcmVzX2F0GAMgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcCIQCg5TaWduT3V0UmVxdWVzdCIVChNSZWZyZXNoVG9rZW5SZXF1ZXN0IlwKFFJlZnJlc2hUb2tlblJlc3BvbnNlEhQKDGFjY2Vzc190b2tlbhgBIAEoCRIuCgpleHBpcmVzX2F0GAIgASgLMhouZ29vZ2xlLnByb3RvYnVmLlRpbWVzdGFtcDK/AwoLQXV0aFNlcnZpY2USdAoOR2V0Q3VycmVudFVzZXISIy5tZW1vcy5hcGkudjEuR2V0Q3VycmVudFVzZXJSZXF1ZXN0GiQubWVtb3MuYXBpLnYxLkdldEN1cnJlbnRVc2VyUmVzcG9uc2UiF4LT5JMCERIPL2FwaS92MS9hdXRoL21lEmMKBlNpZ25JbhIbLm1lbW9zLmFwaS52MS5TaWduSW5SZXF1ZXN0GhwubWVtb3MuYXBpLnYxLlNpZ25JblJlc3BvbnNlIh6C0+STAhg6ASoiEy9hcGkvdjEvYXV0aC9zaWduaW4SXQoHU2lnbk91dBIcLm1lbW9zLmFwaS52MS5TaWduT3V0UmVxdWVzdBoWLmdvb2dsZS5wcm90b2J1Zi5FbXB0eSIcgtPkkwIWIhQvYXBpL3YxL2F1dGgvc2lnbm91dBJ2CgxSZWZyZXNoVG9rZW4SIS5tZW1vcy5hcGkudjEuUmVmcmVzaFRva2VuUmVxdWVzdBoiLm1lbW9zLmFwaS52MS5SZWZyZXNoVG9rZW5SZXNwb25zZSIfgtPkkwIZOgEqIhQvYXBpL3YxL2F1dGgvcmVmcmVzaEKoAQoQY29tLm1lbW9zLmFwaS52MUIQQXV0aFNlcnZpY2VQcm90b1ABWjBnaXRodWIuY29tL3VzZW1lbW9zL21lbW9zL3Byb3RvL2dlbi9hcGkvdjE7YXBpdjGiAgNNQViqAgxNZW1vcy5BcGkuVjHKAgxNZW1vc1xBcGlcVjHiAhhNZW1vc1xBcGlcVjFcR1BCTWV0YWRhdGHqAg5NZW1vczo6QXBpOjpWMWIGcHJvdG8z", [file_api_v1_user_service, file_google_api_annotations, file_google_api_field_behavior, file_google_protobuf_empty, file_google_protobuf_timestamp]);
/** /**
* @generated from message memos.api.v1.GetCurrentUserRequest * @generated from message memos.api.v1.GetCurrentUserRequest
...@@ -120,11 +120,12 @@ export const SignInRequest_PasswordCredentialsSchema: GenMessage<SignInRequest_P ...@@ -120,11 +120,12 @@ export const SignInRequest_PasswordCredentialsSchema: GenMessage<SignInRequest_P
*/ */
export type SignInRequest_SSOCredentials = Message<"memos.api.v1.SignInRequest.SSOCredentials"> & { export type SignInRequest_SSOCredentials = Message<"memos.api.v1.SignInRequest.SSOCredentials"> & {
/** /**
* The ID of the SSO provider. * The resource name of the SSO provider.
* Format: identity-providers/{uid}
* *
* @generated from field: int32 idp_id = 1; * @generated from field: string idp_name = 1;
*/ */
idpId: number; idpName: string;
/** /**
* The authorization code from the SSO provider. * The authorization code from the SSO provider.
......
...@@ -3,7 +3,7 @@ const STATE_EXPIRY_MS = 10 * 60 * 1000; // 10 minutes ...@@ -3,7 +3,7 @@ const STATE_EXPIRY_MS = 10 * 60 * 1000; // 10 minutes
interface OAuthState { interface OAuthState {
state: string; state: string;
identityProviderId: number; identityProviderName: string;
timestamp: number; timestamp: number;
returnUrl?: string; returnUrl?: string;
codeVerifier?: string; // PKCE code_verifier codeVerifier?: string; // PKCE code_verifier
...@@ -42,7 +42,10 @@ function base64UrlEncode(buffer: Uint8Array): string { ...@@ -42,7 +42,10 @@ function base64UrlEncode(buffer: Uint8Array): string {
// Store OAuth state and PKCE parameters in sessionStorage // Store OAuth state and PKCE parameters in sessionStorage
// Returns state and optional codeChallenge for use in authorization URL // Returns state and optional codeChallenge for use in authorization URL
// PKCE is optional - if crypto APIs are unavailable (HTTP context), falls back to standard OAuth // PKCE is optional - if crypto APIs are unavailable (HTTP context), falls back to standard OAuth
export async function storeOAuthState(identityProviderId: number, returnUrl?: string): Promise<{ state: string; codeChallenge?: string }> { export async function storeOAuthState(
identityProviderName: string,
returnUrl?: string,
): Promise<{ state: string; codeChallenge?: string }> {
const state = generateSecureState(); const state = generateSecureState();
// Try to generate PKCE parameters if crypto.subtle is available (HTTPS/localhost) // Try to generate PKCE parameters if crypto.subtle is available (HTTPS/localhost)
...@@ -70,7 +73,7 @@ export async function storeOAuthState(identityProviderId: number, returnUrl?: st ...@@ -70,7 +73,7 @@ export async function storeOAuthState(identityProviderId: number, returnUrl?: st
const stateData: OAuthState = { const stateData: OAuthState = {
state, state,
identityProviderId, identityProviderName,
timestamp: Date.now(), timestamp: Date.now(),
returnUrl, returnUrl,
codeVerifier, // Store for later retrieval in callback (undefined if PKCE not available) codeVerifier, // Store for later retrieval in callback (undefined if PKCE not available)
...@@ -87,8 +90,8 @@ export async function storeOAuthState(identityProviderId: number, returnUrl?: st ...@@ -87,8 +90,8 @@ export async function storeOAuthState(identityProviderId: number, returnUrl?: st
} }
// Validate and retrieve OAuth state from storage (CSRF protection) // Validate and retrieve OAuth state from storage (CSRF protection)
// Returns identityProviderId, returnUrl, and codeVerifier for PKCE // Returns identityProviderName, returnUrl, and codeVerifier for PKCE
export function validateOAuthState(stateParam: string): { identityProviderId: number; returnUrl?: string; codeVerifier?: string } | null { export function validateOAuthState(stateParam: string): { identityProviderName: string; returnUrl?: string; codeVerifier?: string } | null {
try { try {
const storedData = sessionStorage.getItem(STATE_STORAGE_KEY); const storedData = sessionStorage.getItem(STATE_STORAGE_KEY);
if (!storedData) { if (!storedData) {
...@@ -115,7 +118,7 @@ export function validateOAuthState(stateParam: string): { identityProviderId: nu ...@@ -115,7 +118,7 @@ export function validateOAuthState(stateParam: string): { identityProviderId: nu
// State is valid, clean up and return data // State is valid, clean up and return data
sessionStorage.removeItem(STATE_STORAGE_KEY); sessionStorage.removeItem(STATE_STORAGE_KEY);
return { return {
identityProviderId: stateData.identityProviderId, identityProviderName: stateData.identityProviderName,
returnUrl: stateData.returnUrl, returnUrl: stateData.returnUrl,
codeVerifier: stateData.codeVerifier, // Return PKCE code_verifier codeVerifier: stateData.codeVerifier, // Return PKCE code_verifier
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment