Unverified Commit d688914b authored by boojack's avatar boojack Committed by GitHub

feat(auth): add SSO user identity linkage (#5883)

parent 50638040
...@@ -9,9 +9,6 @@ build/ ...@@ -9,9 +9,6 @@ build/
bin/ bin/
memos memos
# Plan/design documents
docs/plans/
.DS_Store .DS_Store
# Jetbrains # Jetbrains
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
SSO sign-in in memos currently treats the IdP-provided identifier as the local username. The identifier value comes from the OAuth2 UserInfo claim named in `FieldMapping.identifier`, while local usernames are validated by `validateUsername` against `base.UIDMatcher`. Real IdPs frequently emit identifiers such as email addresses, opaque subject IDs, or provider-specific account IDs that are valid authentication subjects but are not valid memos usernames. SSO sign-in in memos currently treats the IdP-provided identifier as the local username. The identifier value comes from the OAuth2 UserInfo claim named in `FieldMapping.identifier`, while local usernames are validated by `validateUsername` against `base.UIDMatcher`. Real IdPs frequently emit identifiers such as email addresses, opaque subject IDs, or provider-specific account IDs that are valid authentication subjects but are not valid memos usernames.
The existing issue artifacts under `docs/issues/2026-04-21-sso-user-identity-linkage/` already scope a persistent linkage between SSO identities and local users. A broader review of upstream open source schemas now shows that similar systems converge on separating external identity from the local user row, but do not converge on one universal table name or one exact column set. That difference matters because the implementation problem is narrower than "copy one upstream schema exactly" and broader than "pick any new table name locally." The existing issue artifacts under `docs/plans/2026-04-21-sso-user-identity-linkage/` already scope a persistent linkage between SSO identities and local users. A broader review of upstream open source schemas now shows that similar systems converge on separating external identity from the local user row, but do not converge on one universal table name or one exact column set. That difference matters because the implementation problem is narrower than "copy one upstream schema exactly" and broader than "pick any new table name locally."
## Issue Statement ## Issue Statement
......
## Execution Log
### T1: Add `user_identity` migrations + LATEST.sql updates
**Status**: Completed
**Files Changed**:
- Created: `store/migration/sqlite/0.28/00__user_identity.sql`
- Created: `store/migration/postgres/0.28/00__user_identity.sql`
- Created: `store/migration/mysql/0.28/00__user_identity.sql`
- Modified: `store/migration/sqlite/LATEST.sql`
- Modified: `store/migration/postgres/LATEST.sql`
- Modified: `store/migration/mysql/LATEST.sql`
**Validation**:
- `rg 'CREATE TABLE \`?user_identity\`?' store/migration` — PASS (hits in all 6 expected files).
- `rg 'UNIQUE \(\`?provider\`?, \`?extern_uid\`?\)' store/migration` — PASS (6 hits).
- `go build ./...` — PASS.
**Path Corrections**: None.
**Deviations**: None.
### T2: Add `store.UserIdentity` model, `Store` methods, and driver interface
**Status**: Completed
**Files Changed**:
- Created: `store/user_identity.go`
- Modified: `store/driver.go`
**Validation**:
- Interface-only build is expected to fail until T3–T5; deferred compile check to T5.
- `rg 'CreateUserIdentity|ListUserIdentities' store/driver.go store/user_identity.go` — PASS (method declarations present in both files).
**Path Corrections**: None.
**Deviations**: None.
### T3: Implement SQLite driver for `user_identity`
**Status**: Completed
**Files Changed**:
- Created: `store/db/sqlite/user_identity.go`
**Validation**:
- `go build ./store/db/sqlite/...` — PASS.
**Path Corrections**: None.
**Deviations**: None.
### T4: Implement Postgres driver for `user_identity`
**Status**: Completed
**Files Changed**:
- Created: `store/db/postgres/user_identity.go`
**Validation**:
- `go build ./store/db/postgres/...` — PASS.
**Path Corrections**: None.
**Deviations**: None.
### T5: Implement MySQL driver for `user_identity`
**Status**: Completed
**Files Changed**:
- Created: `store/db/mysql/user_identity.go`
**Validation**:
- `go build ./...` — PASS (whole repo compiles; all drivers satisfy the `Driver` interface).
**Path Corrections**: None.
**Deviations**: None.
### T6: Add store-layer tests for `user_identity`
**Status**: Completed
**Files Changed**:
- Created: `store/test/user_identity_test.go`
**Validation**:
- `DRIVER=sqlite go test ./store/test/ -run TestUserIdentity -count=1 -v` — PASS:
- `TestUserIdentityCreateAndGet` — PASS
- `TestUserIdentityListByUserID` — PASS
- `TestUserIdentityUniqueConflict` — PASS
- `TestUserIdentitySameExternUIDDifferentProviders` — PASS
**Path Corrections**: None.
**Deviations**: None.
### T7: Add SSO username derivation helper
**Status**: Completed
**Files Changed**:
- Created: `server/router/api/v1/sso_username.go`
**Validation**:
- `go build ./server/router/api/v1/...` — PASS.
- `go vet ./server/router/api/v1/...` — PASS.
**Path Corrections**: None.
**Deviations**: None.
### T8: Route SSO sign-in through `user_identity` linkage
**Status**: Completed
**Files Changed**:
- Modified: `server/router/api/v1/auth_service.go`
- `SignIn` SSO branch now delegates user resolution to a new `resolveSSOUser` method.
- `resolveSSOUser` does: `user_identity` lookup → hit path (load user by linked `user_id`); miss path (registration gate → `deriveSSOUsername` → create user → create linkage → race recovery on unique(provider, extern_uid)).
- Added `isUserIdentityUniqueViolation` helper (string match on the three backends' unique-constraint error strings, matching the pattern in `memo_service.go:103–105`).
**Validation**:
- `go build ./...` — PASS.
- `go vet ./...` — PASS.
- `DRIVER=sqlite go test ./store/test/ -run TestUserIdentity -count=1` — PASS (regression check).
**Path Corrections**:
- The plan pseudocode referenced `identityProvider.UID`; the actual protobuf type `storepb.IdentityProvider` exposes the field as `Uid`. Used `identityProvider.Uid` in the implementation. No semantic deviation.
**Deviations**: None.
## Completion Declaration
**All tasks completed successfully.**
This diff is collapsed.
...@@ -90,6 +90,33 @@ service UserService { ...@@ -90,6 +90,33 @@ service UserService {
option (google.api.method_signature) = "parent"; option (google.api.method_signature) = "parent";
} }
// ListLinkedIdentities returns a list of linked SSO identities for a user.
rpc ListLinkedIdentities(ListLinkedIdentitiesRequest) returns (ListLinkedIdentitiesResponse) {
option (google.api.http) = {get: "/api/v1/{parent=users/*}/linkedIdentities"};
option (google.api.method_signature) = "parent";
}
// CreateLinkedIdentity links an SSO identity to the authenticated user.
rpc CreateLinkedIdentity(CreateLinkedIdentityRequest) returns (LinkedIdentity) {
option (google.api.http) = {
post: "/api/v1/{parent=users/*}/linkedIdentities"
body: "*"
};
option (google.api.method_signature) = "parent,idp_name";
}
// GetLinkedIdentity gets a linked SSO identity for a user.
rpc GetLinkedIdentity(GetLinkedIdentityRequest) returns (LinkedIdentity) {
option (google.api.http) = {get: "/api/v1/{name=users/*/linkedIdentities/*}"};
option (google.api.method_signature) = "name";
}
// DeleteLinkedIdentity unlinks an SSO identity from a user.
rpc DeleteLinkedIdentity(DeleteLinkedIdentityRequest) returns (google.protobuf.Empty) {
option (google.api.http) = {delete: "/api/v1/{name=users/*/linkedIdentities/*}"};
option (google.api.method_signature) = "name";
}
// ListPersonalAccessTokens returns a list of Personal Access Tokens (PATs) for a user. // ListPersonalAccessTokens returns a list of Personal Access Tokens (PATs) for a user.
// PATs are long-lived tokens for API/script access, distinct from short-lived JWT access tokens. // PATs are long-lived tokens for API/script access, distinct from short-lived JWT access tokens.
rpc ListPersonalAccessTokens(ListPersonalAccessTokensRequest) returns (ListPersonalAccessTokensResponse) { rpc ListPersonalAccessTokens(ListPersonalAccessTokensRequest) returns (ListPersonalAccessTokensResponse) {
...@@ -466,6 +493,87 @@ message ListUserSettingsResponse { ...@@ -466,6 +493,87 @@ message ListUserSettingsResponse {
int32 total_size = 3; int32 total_size = 3;
} }
// LinkedIdentity represents an SSO identity linked to a user account.
message LinkedIdentity {
option (google.api.resource) = {
type: "memos.api.v1/LinkedIdentity"
pattern: "users/{user}/linkedIdentities/{linked_identity}"
singular: "linkedIdentity"
plural: "linkedIdentities"
};
// The resource name of the linked identity.
// Format: users/{user}/linkedIdentities/{linked_identity}
string name = 1 [(google.api.field_behavior) = IDENTIFIER];
// The resource name of the identity provider.
// Format: identity-providers/{uid}
string idp_name = 2 [
(google.api.field_behavior) = OUTPUT_ONLY,
(google.api.resource_reference) = {type: "memos.api.v1/IdentityProvider"}
];
// The external user identifier from the identity provider.
string extern_uid = 3 [(google.api.field_behavior) = OUTPUT_ONLY];
}
message ListLinkedIdentitiesRequest {
// Required. The parent resource whose linked identities will be listed.
// Format: users/{user}
string parent = 1 [
(google.api.field_behavior) = REQUIRED,
(google.api.resource_reference) = {type: "memos.api.v1/User"}
];
}
message ListLinkedIdentitiesResponse {
// The list of linked identities.
repeated LinkedIdentity linked_identities = 1;
}
message CreateLinkedIdentityRequest {
// Required. The parent user who owns the linked identity.
// Format: users/{user}
string parent = 1 [
(google.api.field_behavior) = REQUIRED,
(google.api.resource_reference) = {type: "memos.api.v1/User"}
];
// Required. The identity provider to link.
// Format: identity-providers/{uid}
string idp_name = 2 [
(google.api.field_behavior) = REQUIRED,
(google.api.resource_reference) = {type: "memos.api.v1/IdentityProvider"}
];
// Required. The authorization code from the identity provider.
string code = 3 [(google.api.field_behavior) = REQUIRED];
// Required. The redirect URI used in the OAuth flow.
string redirect_uri = 4 [(google.api.field_behavior) = REQUIRED];
// Optional. The PKCE code verifier used in the OAuth flow.
string code_verifier = 5 [(google.api.field_behavior) = OPTIONAL];
}
message GetLinkedIdentityRequest {
// Required. The resource name of the linked identity to get.
// Format: users/{user}/linkedIdentities/{linked_identity}
string name = 1 [
(google.api.field_behavior) = REQUIRED,
(google.api.resource_reference) = {type: "memos.api.v1/LinkedIdentity"}
];
}
message DeleteLinkedIdentityRequest {
// Required. The resource name of the linked identity to delete.
// Format: users/{user}/linkedIdentities/{linked_identity}
string name = 1 [
(google.api.field_behavior) = REQUIRED,
(google.api.resource_reference) = {type: "memos.api.v1/LinkedIdentity"}
];
}
// PersonalAccessToken represents a long-lived token for API/script access. // PersonalAccessToken represents a long-lived token for API/script access.
// PATs are distinct from short-lived JWT access tokens used for session authentication. // PATs are distinct from short-lived JWT access tokens used for session authentication.
message PersonalAccessToken { message PersonalAccessToken {
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -1353,6 +1353,123 @@ paths: ...@@ -1353,6 +1353,123 @@ paths:
application/json: application/json:
schema: schema:
$ref: '#/components/schemas/Status' $ref: '#/components/schemas/Status'
/api/v1/users/{user}/linkedIdentities:
get:
tags:
- UserService
description: ListLinkedIdentities returns a list of linked SSO identities for a user.
operationId: UserService_ListLinkedIdentities
parameters:
- name: user
in: path
description: The user id.
required: true
schema:
type: string
responses:
"200":
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/ListLinkedIdentitiesResponse'
default:
description: Default error response
content:
application/json:
schema:
$ref: '#/components/schemas/Status'
post:
tags:
- UserService
description: CreateLinkedIdentity links an SSO identity to the authenticated user.
operationId: UserService_CreateLinkedIdentity
parameters:
- name: user
in: path
description: The user id.
required: true
schema:
type: string
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/CreateLinkedIdentityRequest'
required: true
responses:
"200":
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/LinkedIdentity'
default:
description: Default error response
content:
application/json:
schema:
$ref: '#/components/schemas/Status'
/api/v1/users/{user}/linkedIdentities/{linkedIdentity}:
get:
tags:
- UserService
description: GetLinkedIdentity gets a linked SSO identity for a user.
operationId: UserService_GetLinkedIdentity
parameters:
- name: user
in: path
description: The user id.
required: true
schema:
type: string
- name: linkedIdentity
in: path
description: The linkedIdentity id.
required: true
schema:
type: string
responses:
"200":
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/LinkedIdentity'
default:
description: Default error response
content:
application/json:
schema:
$ref: '#/components/schemas/Status'
delete:
tags:
- UserService
description: DeleteLinkedIdentity unlinks an SSO identity from a user.
operationId: UserService_DeleteLinkedIdentity
parameters:
- name: user
in: path
description: The user id.
required: true
schema:
type: string
- name: linkedIdentity
in: path
description: The linkedIdentity id.
required: true
schema:
type: string
responses:
"200":
description: OK
content: {}
default:
description: Default error response
content:
application/json:
schema:
$ref: '#/components/schemas/Status'
/api/v1/users/{user}/notifications: /api/v1/users/{user}/notifications:
get: get:
tags: tags:
...@@ -2269,6 +2386,33 @@ components: ...@@ -2269,6 +2386,33 @@ components:
}; };
// ... // ...
CreateLinkedIdentityRequest:
required:
- parent
- idpName
- code
- redirectUri
type: object
properties:
parent:
type: string
description: |-
Required. The parent user who owns the linked identity.
Format: users/{user}
idpName:
type: string
description: |-
Required. The identity provider to link.
Format: identity-providers/{uid}
code:
type: string
description: Required. The authorization code from the identity provider.
redirectUri:
type: string
description: Required. The redirect URI used in the OAuth flow.
codeVerifier:
type: string
description: Optional. The PKCE code verifier used in the OAuth flow.
CreatePersonalAccessTokenRequest: CreatePersonalAccessTokenRequest:
required: required:
- parent - parent
...@@ -2555,6 +2699,25 @@ components: ...@@ -2555,6 +2699,25 @@ components:
so a single entry like "project/.*" matches all tags under that prefix. so a single entry like "project/.*" matches all tags under that prefix.
Exact tag names are also valid (they are trivially valid regex patterns). Exact tag names are also valid (they are trivially valid regex patterns).
description: Tag metadata configuration. description: Tag metadata configuration.
LinkedIdentity:
type: object
properties:
name:
type: string
description: |-
The resource name of the linked identity.
Format: users/{user}/linkedIdentities/{linked_identity}
idpName:
readOnly: true
type: string
description: |-
The resource name of the identity provider.
Format: identity-providers/{uid}
externUid:
readOnly: true
type: string
description: The external user identifier from the identity provider.
description: LinkedIdentity represents an SSO identity linked to a user account.
ListAllUserStatsResponse: ListAllUserStatsResponse:
type: object type: object
properties: properties:
...@@ -2588,6 +2751,14 @@ components: ...@@ -2588,6 +2751,14 @@ components:
items: items:
$ref: '#/components/schemas/IdentityProvider' $ref: '#/components/schemas/IdentityProvider'
description: The list of identity providers. description: The list of identity providers.
ListLinkedIdentitiesResponse:
type: object
properties:
linkedIdentities:
type: array
items:
$ref: '#/components/schemas/LinkedIdentity'
description: The list of linked identities.
ListMemoAttachmentsResponse: ListMemoAttachmentsResponse:
type: object type: object
properties: properties:
......
This diff is collapsed.
...@@ -112,11 +112,9 @@ func (s *ConnectServiceHandler) UpdateUser(ctx context.Context, req *connect.Req ...@@ -112,11 +112,9 @@ func (s *ConnectServiceHandler) UpdateUser(ctx context.Context, req *connect.Req
} }
func (s *ConnectServiceHandler) DeleteUser(ctx context.Context, req *connect.Request[v1pb.DeleteUserRequest]) (*connect.Response[emptypb.Empty], error) { func (s *ConnectServiceHandler) DeleteUser(ctx context.Context, req *connect.Request[v1pb.DeleteUserRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteUser(ctx, req.Msg) return connectWithHeaderCarrier(ctx, func(ctx context.Context) (*emptypb.Empty, error) {
if err != nil { return s.APIV1Service.DeleteUser(ctx, req.Msg)
return nil, convertGRPCError(err) })
}
return connect.NewResponse(resp), nil
} }
func (s *ConnectServiceHandler) ListAllUserStats(ctx context.Context, req *connect.Request[v1pb.ListAllUserStatsRequest]) (*connect.Response[v1pb.ListAllUserStatsResponse], error) { func (s *ConnectServiceHandler) ListAllUserStats(ctx context.Context, req *connect.Request[v1pb.ListAllUserStatsRequest]) (*connect.Response[v1pb.ListAllUserStatsResponse], error) {
...@@ -159,6 +157,38 @@ func (s *ConnectServiceHandler) ListUserSettings(ctx context.Context, req *conne ...@@ -159,6 +157,38 @@ func (s *ConnectServiceHandler) ListUserSettings(ctx context.Context, req *conne
return connect.NewResponse(resp), nil return connect.NewResponse(resp), nil
} }
func (s *ConnectServiceHandler) ListLinkedIdentities(ctx context.Context, req *connect.Request[v1pb.ListLinkedIdentitiesRequest]) (*connect.Response[v1pb.ListLinkedIdentitiesResponse], error) {
resp, err := s.APIV1Service.ListLinkedIdentities(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) CreateLinkedIdentity(ctx context.Context, req *connect.Request[v1pb.CreateLinkedIdentityRequest]) (*connect.Response[v1pb.LinkedIdentity], error) {
resp, err := s.APIV1Service.CreateLinkedIdentity(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) GetLinkedIdentity(ctx context.Context, req *connect.Request[v1pb.GetLinkedIdentityRequest]) (*connect.Response[v1pb.LinkedIdentity], error) {
resp, err := s.APIV1Service.GetLinkedIdentity(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) DeleteLinkedIdentity(ctx context.Context, req *connect.Request[v1pb.DeleteLinkedIdentityRequest]) (*connect.Response[emptypb.Empty], error) {
resp, err := s.APIV1Service.DeleteLinkedIdentity(ctx, req.Msg)
if err != nil {
return nil, convertGRPCError(err)
}
return connect.NewResponse(resp), nil
}
func (s *ConnectServiceHandler) ListPersonalAccessTokens(ctx context.Context, req *connect.Request[v1pb.ListPersonalAccessTokensRequest]) (*connect.Response[v1pb.ListPersonalAccessTokensResponse], error) { func (s *ConnectServiceHandler) ListPersonalAccessTokens(ctx context.Context, req *connect.Request[v1pb.ListPersonalAccessTokensRequest]) (*connect.Response[v1pb.ListPersonalAccessTokensResponse], error) {
resp, err := s.APIV1Service.ListPersonalAccessTokens(ctx, req.Msg) resp, err := s.APIV1Service.ListPersonalAccessTokens(ctx, req.Msg)
if err != nil { if err != nil {
......
package v1
import (
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
)
// deriveSSOUsername produces the local username for a new SSO-created user.
//
// The current policy is to use a standard UUID string directly. This keeps the
// username independent of IdP profile fields and avoids availability probes or
// retry loops around concurrent first-time logins.
func deriveSSOUsername() (string, error) {
username := util.GenUUID()
if err := validateUsername(username); err != nil {
return "", errors.Wrap(err, "generated UUID did not satisfy username constraints")
}
return username, nil
}
package test
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
)
func TestCreateLinkedIdentityBindsCurrentUser(t *testing.T) {
t.Parallel()
ts := NewTestService(t)
defer ts.Cleanup()
ctx := context.Background()
currentUser, err := ts.CreateRegularUser(ctx, "alice")
require.NoError(t, err)
mockIDP := newMockOAuthServer(t, "bind-code", "bind-access-token", map[string]any{
"sub": "google-sub-1",
"name": "Alice Example",
"email": "alice@example.com",
})
defer mockIDP.Close()
idpName := createTestingOAuthIdentityProvider(ctx, t, ts, mockIDP.URL, "google-bind")
beforeUsers, err := ts.Store.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err)
authCtx := ts.CreateUserContext(apiv1.WithHeaderCarrier(ctx), currentUser.ID)
response, err := ts.Service.CreateLinkedIdentity(authCtx, &v1pb.CreateLinkedIdentityRequest{
Parent: apiv1.BuildUserName(currentUser.Username),
IdpName: idpName,
Code: "bind-code",
RedirectUri: "http://localhost:8080/auth/callback",
})
require.NoError(t, err)
require.NotNil(t, response)
require.Equal(t, apiv1.BuildUserName(currentUser.Username)+"/linkedIdentities/google-bind", response.Name)
require.Equal(t, apiv1.IdentityProviderNamePrefix+"google-bind", response.IdpName)
require.Equal(t, "google-sub-1", response.ExternUid)
afterUsers, err := ts.Store.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err)
require.Len(t, afterUsers, len(beforeUsers))
provider := "google-bind"
externUID := "google-sub-1"
identity, err := ts.Store.GetUserIdentity(ctx, &store.FindUserIdentity{
Provider: &provider,
ExternUID: &externUID,
})
require.NoError(t, err)
require.NotNil(t, identity)
require.Equal(t, currentUser.ID, identity.UserID)
}
func TestCreateLinkedIdentityRejectsBindingIdentityLinkedToAnotherUser(t *testing.T) {
t.Parallel()
ts := NewTestService(t)
defer ts.Cleanup()
ctx := context.Background()
owner, err := ts.CreateRegularUser(ctx, "owner")
require.NoError(t, err)
binder, err := ts.CreateRegularUser(ctx, "binder")
require.NoError(t, err)
mockIDP := newMockOAuthServer(t, "conflict-code", "conflict-access-token", map[string]any{
"sub": "google-sub-2",
"name": "Conflict Example",
"email": "conflict@example.com",
})
defer mockIDP.Close()
idpName := createTestingOAuthIdentityProvider(ctx, t, ts, mockIDP.URL, "google-conflict")
_, err = ts.Store.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: owner.ID,
Provider: "google-conflict",
ExternUID: "google-sub-2",
})
require.NoError(t, err)
authCtx := ts.CreateUserContext(apiv1.WithHeaderCarrier(ctx), binder.ID)
_, err = ts.Service.CreateLinkedIdentity(authCtx, &v1pb.CreateLinkedIdentityRequest{
Parent: apiv1.BuildUserName(binder.Username),
IdpName: idpName,
Code: "conflict-code",
RedirectUri: "http://localhost:8080/auth/callback",
})
require.Error(t, err)
require.Equal(t, codes.AlreadyExists, status.Code(err))
}
func TestListAndDeleteLinkedIdentities(t *testing.T) {
t.Parallel()
ts := NewTestService(t)
defer ts.Cleanup()
ctx := context.Background()
currentUser, err := ts.CreateRegularUser(ctx, "alice")
require.NoError(t, err)
_, err = ts.Store.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: currentUser.ID,
Provider: "google",
ExternUID: "alice@gmail.com",
})
require.NoError(t, err)
authCtx := ts.CreateUserContext(ctx, currentUser.ID)
listResp, err := ts.Service.ListLinkedIdentities(authCtx, &v1pb.ListLinkedIdentitiesRequest{
Parent: apiv1.BuildUserName(currentUser.Username),
})
require.NoError(t, err)
require.Len(t, listResp.LinkedIdentities, 1)
linkedIdentityName := apiv1.BuildUserName(currentUser.Username) + "/linkedIdentities/google"
require.Equal(t, linkedIdentityName, listResp.LinkedIdentities[0].Name)
require.Equal(t, apiv1.IdentityProviderNamePrefix+"google", listResp.LinkedIdentities[0].IdpName)
require.Equal(t, "alice@gmail.com", listResp.LinkedIdentities[0].ExternUid)
got, err := ts.Service.GetLinkedIdentity(authCtx, &v1pb.GetLinkedIdentityRequest{
Name: linkedIdentityName,
})
require.NoError(t, err)
require.Equal(t, linkedIdentityName, got.Name)
require.Equal(t, apiv1.IdentityProviderNamePrefix+"google", got.IdpName)
require.Equal(t, "alice@gmail.com", got.ExternUid)
_, err = ts.Service.DeleteLinkedIdentity(authCtx, &v1pb.DeleteLinkedIdentityRequest{
Name: linkedIdentityName,
})
require.NoError(t, err)
listResp, err = ts.Service.ListLinkedIdentities(authCtx, &v1pb.ListLinkedIdentitiesRequest{
Parent: apiv1.BuildUserName(currentUser.Username),
})
require.NoError(t, err)
require.Empty(t, listResp.LinkedIdentities)
}
func TestCreateLinkedIdentityRejectsSecondIdentityForSameProvider(t *testing.T) {
t.Parallel()
ts := NewTestService(t)
defer ts.Cleanup()
ctx := context.Background()
currentUser, err := ts.CreateRegularUser(ctx, "alice")
require.NoError(t, err)
_, err = ts.Store.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: currentUser.ID,
Provider: "google-provider",
ExternUID: "google-sub-1",
})
require.NoError(t, err)
mockIDP := newMockOAuthServer(t, "second-code", "second-access-token", map[string]any{
"sub": "google-sub-2",
"name": "Alice Example",
"email": "alice@example.com",
})
defer mockIDP.Close()
idpName := createTestingOAuthIdentityProvider(ctx, t, ts, mockIDP.URL, "google-provider")
authCtx := ts.CreateUserContext(apiv1.WithHeaderCarrier(ctx), currentUser.ID)
_, err = ts.Service.CreateLinkedIdentity(authCtx, &v1pb.CreateLinkedIdentityRequest{
Parent: apiv1.BuildUserName(currentUser.Username),
IdpName: idpName,
Code: "second-code",
RedirectUri: "http://localhost:8080/auth/callback",
})
require.Error(t, err)
require.Equal(t, codes.AlreadyExists, status.Code(err))
}
func createTestingOAuthIdentityProvider(ctx context.Context, t *testing.T, ts *TestService, serverURL, uid string) string {
t.Helper()
idp, err := ts.Store.CreateIdentityProvider(ctx, &storepb.IdentityProvider{
Uid: uid,
Name: "Google",
Type: storepb.IdentityProvider_OAUTH2,
Config: &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: "test-client-id",
ClientSecret: "test-client-secret",
AuthUrl: serverURL + "/oauth2/authorize",
TokenUrl: serverURL + "/oauth2/token",
UserInfoUrl: serverURL + "/oauth2/userinfo",
FieldMapping: &storepb.FieldMapping{
Identifier: "sub",
DisplayName: "name",
Email: "email",
},
},
},
},
})
require.NoError(t, err)
return apiv1.IdentityProviderNamePrefix + idp.Uid
}
func newMockOAuthServer(t *testing.T, code, accessToken string, userInfo map[string]any) *httptest.Server {
t.Helper()
userInfoBytes, err := json.Marshal(userInfo)
require.NoError(t, err)
mux := http.NewServeMux()
mux.HandleFunc("/oauth2/token", func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
body, err := io.ReadAll(r.Body)
require.NoError(t, err)
values, err := url.ParseQuery(string(body))
require.NoError(t, err)
require.Equal(t, code, values.Get("code"))
require.Equal(t, "authorization_code", values.Get("grant_type"))
w.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w).Encode(map[string]any{
"access_token": accessToken,
"token_type": "Bearer",
"expires_in": 3600,
})
require.NoError(t, err)
})
mux.HandleFunc("/oauth2/userinfo", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, err := w.Write(userInfoBytes)
require.NoError(t, err)
})
return httptest.NewServer(mux)
}
package test
import (
"context"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/timestamppb"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
)
func TestDeleteUserSelfDeleteCleansAccountDataAndAuthCookies(t *testing.T) {
t.Parallel()
ts := NewTestService(t)
defer ts.Cleanup()
ctx := context.Background()
user, err := ts.CreateRegularUser(ctx, "alice")
require.NoError(t, err)
_, err = ts.Store.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: user.ID,
Provider: "google",
ExternUID: "alice-google-sub",
})
require.NoError(t, err)
err = ts.Store.AddUserRefreshToken(ctx, user.ID, &storepb.RefreshTokensUserSetting_RefreshToken{
TokenId: "refresh-token-id",
ExpiresAt: timestamppb.New(time.Now().Add(time.Hour)),
CreatedAt: timestamppb.Now(),
})
require.NoError(t, err)
headerCtx := apiv1.WithHeaderCarrier(ctx)
authCtx := ts.CreateUserContext(headerCtx, user.ID)
_, err = ts.Service.DeleteUser(authCtx, &v1pb.DeleteUserRequest{
Name: apiv1.BuildUserName(user.Username),
})
require.NoError(t, err)
deletedUser, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID})
require.NoError(t, err)
require.Nil(t, deletedUser)
identities, err := ts.Store.ListUserIdentities(ctx, &store.FindUserIdentity{UserID: &user.ID})
require.NoError(t, err)
require.Empty(t, identities)
refreshSetting, err := ts.Store.GetUserSetting(ctx, &store.FindUserSetting{
UserID: &user.ID,
Key: storepb.UserSetting_REFRESH_TOKENS,
})
require.NoError(t, err)
require.Nil(t, refreshSetting)
carrier := apiv1.GetHeaderCarrier(authCtx)
require.NotNil(t, carrier)
require.Contains(t, strings.ToLower(carrier.Get("Set-Cookie")), "memos_refresh=")
}
...@@ -335,12 +335,29 @@ func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserR ...@@ -335,12 +335,29 @@ func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserR
if currentUser.ID != userID && currentUser.Role != store.RoleAdmin { if currentUser.ID != userID && currentUser.Role != store.RoleAdmin {
return nil, status.Errorf(codes.PermissionDenied, "permission denied") return nil, status.Errorf(codes.PermissionDenied, "permission denied")
} }
isSelfDelete := currentUser.ID == userID
if err := s.Store.DeleteUserIdentities(ctx, &store.DeleteUserIdentity{
UserID: &userID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete user identities: %v", err)
}
if err := s.Store.DeleteUserSettings(ctx, &store.DeleteUserSetting{
UserID: &userID,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete user settings: %v", err)
}
if err := s.Store.DeleteUser(ctx, &store.DeleteUser{ if err := s.Store.DeleteUser(ctx, &store.DeleteUser{
ID: user.ID, ID: user.ID,
}); err != nil { }); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete user: %v", err) return nil, status.Errorf(codes.Internal, "failed to delete user: %v", err)
} }
if isSelfDelete {
if err := s.clearAuthCookies(ctx); err != nil {
slog.Warn("failed to clear auth cookies after self delete", "user_id", userID, "error", err)
}
}
return &emptypb.Empty{}, nil return &emptypb.Empty{}, nil
} }
...@@ -390,6 +407,27 @@ func (s *APIV1Service) resolveUserAndWebhookIDFromName(ctx context.Context, name ...@@ -390,6 +407,27 @@ func (s *APIV1Service) resolveUserAndWebhookIDFromName(ctx context.Context, name
return user, parts[3], nil return user, parts[3], nil
} }
func (s *APIV1Service) resolveUserAndLinkedIdentityProviderFromName(ctx context.Context, name string) (*store.User, string, error) {
parts := strings.Split(name, "/")
if len(parts) != 4 || parts[0] != "users" || parts[2] != "linkedIdentities" {
return nil, "", errors.Errorf("invalid linked identity name: %s", name)
}
user, err := s.resolveUserFromName(ctx, BuildUserName(parts[1]))
if err != nil {
return nil, "", err
}
return user, parts[3], nil
}
func convertLinkedIdentityFromStore(user *store.User, identity *store.UserIdentity) *v1pb.LinkedIdentity {
return &v1pb.LinkedIdentity{
Name: fmt.Sprintf("%s/linkedIdentities/%s", BuildUserName(user.Username), identity.Provider),
IdpName: IdentityProviderNamePrefix + identity.Provider,
ExternUid: identity.ExternUID,
}
}
func (s *APIV1Service) resolveUserAndNotificationIDFromName(ctx context.Context, name string) (*store.User, int32, error) { func (s *APIV1Service) resolveUserAndNotificationIDFromName(ctx context.Context, name string) (*store.User, int32, error) {
parts := strings.Split(name, "/") parts := strings.Split(name, "/")
if len(parts) != 4 || parts[0] != "users" || parts[2] != "notifications" { if len(parts) != 4 || parts[0] != "users" || parts[2] != "notifications" {
...@@ -597,6 +635,141 @@ func (s *APIV1Service) ListUserSettings(ctx context.Context, request *v1pb.ListU ...@@ -597,6 +635,141 @@ func (s *APIV1Service) ListUserSettings(ctx context.Context, request *v1pb.ListU
return response, nil return response, nil
} }
func (s *APIV1Service) ListLinkedIdentities(ctx context.Context, request *v1pb.ListLinkedIdentitiesRequest) (*v1pb.ListLinkedIdentitiesResponse, error) {
user, err := s.resolveUserFromName(ctx, request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err)
}
userID := user.ID
claims := auth.GetUserClaims(ctx)
if claims == nil || claims.UserID != userID {
currentUser, _ := s.fetchCurrentUser(ctx)
if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
identities, err := s.Store.ListUserIdentities(ctx, &store.FindUserIdentity{UserID: &userID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list linked identities: %v", err)
}
response := &v1pb.ListLinkedIdentitiesResponse{
LinkedIdentities: []*v1pb.LinkedIdentity{},
}
for _, identity := range identities {
response.LinkedIdentities = append(response.LinkedIdentities, convertLinkedIdentityFromStore(user, identity))
}
return response, nil
}
func (s *APIV1Service) CreateLinkedIdentity(ctx context.Context, request *v1pb.CreateLinkedIdentityRequest) (*v1pb.LinkedIdentity, error) {
user, err := s.resolveUserFromName(ctx, request.Parent)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid parent: %v", err)
}
currentUser, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
}
if currentUser == nil {
return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
}
if currentUser.ID != user.ID {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
identityProvider, userInfo, err := s.resolveSSOIdentity(ctx, request.IdpName, request.Code, request.RedirectUri, request.CodeVerifier)
if err != nil {
return nil, err
}
provider := identityProvider.Uid
externUID := userInfo.Identifier
if _, err := s.bindSSOIdentityToUser(ctx, currentUser, provider, externUID); err != nil {
return nil, err
}
identity, err := s.Store.GetUserIdentity(ctx, &store.FindUserIdentity{
UserID: &currentUser.ID,
Provider: &provider,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get linked identity: %v", err)
}
if identity == nil {
return nil, status.Errorf(codes.Internal, "linked identity not found after creation")
}
return convertLinkedIdentityFromStore(user, identity), nil
}
func (s *APIV1Service) GetLinkedIdentity(ctx context.Context, request *v1pb.GetLinkedIdentityRequest) (*v1pb.LinkedIdentity, error) {
user, provider, err := s.resolveUserAndLinkedIdentityProviderFromName(ctx, request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid linked identity name: %v", err)
}
userID := user.ID
claims := auth.GetUserClaims(ctx)
if claims == nil || claims.UserID != userID {
currentUser, _ := s.fetchCurrentUser(ctx)
if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
identity, err := s.Store.GetUserIdentity(ctx, &store.FindUserIdentity{
UserID: &userID,
Provider: &provider,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get linked identity: %v", err)
}
if identity == nil {
return nil, status.Errorf(codes.NotFound, "linked identity not found")
}
return convertLinkedIdentityFromStore(user, identity), nil
}
func (s *APIV1Service) DeleteLinkedIdentity(ctx context.Context, request *v1pb.DeleteLinkedIdentityRequest) (*emptypb.Empty, error) {
user, provider, err := s.resolveUserAndLinkedIdentityProviderFromName(ctx, request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid linked identity name: %v", err)
}
userID := user.ID
claims := auth.GetUserClaims(ctx)
if claims == nil || claims.UserID != userID {
currentUser, _ := s.fetchCurrentUser(ctx)
if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
}
existing, err := s.Store.GetUserIdentity(ctx, &store.FindUserIdentity{
UserID: &userID,
Provider: &provider,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get linked identity: %v", err)
}
if existing == nil {
return nil, status.Errorf(codes.NotFound, "linked identity not found")
}
if err := s.Store.DeleteUserIdentities(ctx, &store.DeleteUserIdentity{
UserID: &userID,
Provider: &provider,
}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete linked identity: %v", err)
}
return &emptypb.Empty{}, nil
}
// ListPersonalAccessTokens retrieves all Personal Access Tokens (PATs) for a user. // ListPersonalAccessTokens retrieves all Personal Access Tokens (PATs) for a user.
// //
// Personal Access Tokens are used for: // Personal Access Tokens are used for:
......
package mysql
import (
"context"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
)
func (d *DB) CreateUserIdentity(ctx context.Context, create *store.UserIdentity) (*store.UserIdentity, error) {
stmt := "INSERT INTO `user_identity` (`user_id`, `provider`, `extern_uid`) VALUES (?, ?, ?)"
result, err := d.db.ExecContext(ctx, stmt, create.UserID, create.Provider, create.ExternUID)
if err != nil {
return nil, err
}
rawID, err := result.LastInsertId()
if err != nil {
return nil, err
}
id := int32(rawID)
list, err := d.ListUserIdentities(ctx, &store.FindUserIdentity{ID: &id})
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, errors.Errorf("failed to create user identity")
}
return list[0], nil
}
func (d *DB) ListUserIdentities(ctx context.Context, find *store.FindUserIdentity) ([]*store.UserIdentity, error) {
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "`id` = ?"), append(args, *find.ID)
}
if find.UserID != nil {
where, args = append(where, "`user_id` = ?"), append(args, *find.UserID)
}
if find.Provider != nil {
where, args = append(where, "`provider` = ?"), append(args, *find.Provider)
}
if find.ExternUID != nil {
where, args = append(where, "`extern_uid` = ?"), append(args, *find.ExternUID)
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
user_id,
provider,
extern_uid,
created_ts,
updated_ts
FROM user_identity
WHERE `+strings.Join(where, " AND ")+`
ORDER BY id ASC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.UserIdentity{}
for rows.Next() {
ui := &store.UserIdentity{}
if err := rows.Scan(
&ui.ID,
&ui.UserID,
&ui.Provider,
&ui.ExternUID,
&ui.CreatedTs,
&ui.UpdatedTs,
); err != nil {
return nil, err
}
list = append(list, ui)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteUserIdentities(ctx context.Context, delete *store.DeleteUserIdentity) error {
where, args := []string{"1 = 1"}, []any{}
if delete.ID != nil {
where, args = append(where, "`id` = ?"), append(args, *delete.ID)
}
if delete.UserID != nil {
where, args = append(where, "`user_id` = ?"), append(args, *delete.UserID)
}
if delete.Provider != nil {
where, args = append(where, "`provider` = ?"), append(args, *delete.Provider)
}
if _, err := d.db.ExecContext(ctx, "DELETE FROM `user_identity` WHERE "+strings.Join(where, " AND "), args...); err != nil {
return err
}
return nil
}
...@@ -57,6 +57,22 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ...@@ -57,6 +57,22 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
return userSettingList, nil return userSettingList, nil
} }
func (d *DB) DeleteUserSettings(ctx context.Context, delete *store.DeleteUserSetting) error {
where, args := []string{"1 = 1"}, []any{}
if v := delete.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
where, args = append(where, "`key` = ?"), append(args, v.String())
}
if v := delete.UserID; v != nil {
where, args = append(where, "`user_id` = ?"), append(args, *v)
}
if _, err := d.db.ExecContext(ctx, "DELETE FROM `user_setting` WHERE "+strings.Join(where, " AND "), args...); err != nil {
return err
}
return nil
}
func (d *DB) GetUserByPATHash(ctx context.Context, tokenHash string) (*store.PATQueryResult, error) { func (d *DB) GetUserByPATHash(ctx context.Context, tokenHash string) (*store.PATQueryResult, error) {
query := ` query := `
SELECT SELECT
......
package postgres
import (
"context"
"strings"
"github.com/usememos/memos/store"
)
func (d *DB) CreateUserIdentity(ctx context.Context, create *store.UserIdentity) (*store.UserIdentity, error) {
stmt := "INSERT INTO user_identity (user_id, provider, extern_uid) VALUES (" + placeholders(3) + ") RETURNING id, created_ts, updated_ts"
if err := d.db.QueryRowContext(ctx, stmt, create.UserID, create.Provider, create.ExternUID).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
); err != nil {
return nil, err
}
return create, nil
}
func (d *DB) ListUserIdentities(ctx context.Context, find *store.FindUserIdentity) ([]*store.UserIdentity, error) {
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
}
if find.UserID != nil {
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID)
}
if find.Provider != nil {
where, args = append(where, "provider = "+placeholder(len(args)+1)), append(args, *find.Provider)
}
if find.ExternUID != nil {
where, args = append(where, "extern_uid = "+placeholder(len(args)+1)), append(args, *find.ExternUID)
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
user_id,
provider,
extern_uid,
created_ts,
updated_ts
FROM user_identity
WHERE `+strings.Join(where, " AND ")+`
ORDER BY id ASC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.UserIdentity{}
for rows.Next() {
ui := &store.UserIdentity{}
if err := rows.Scan(
&ui.ID,
&ui.UserID,
&ui.Provider,
&ui.ExternUID,
&ui.CreatedTs,
&ui.UpdatedTs,
); err != nil {
return nil, err
}
list = append(list, ui)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteUserIdentities(ctx context.Context, delete *store.DeleteUserIdentity) error {
where, args := []string{"1 = 1"}, []any{}
if delete.ID != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *delete.ID)
}
if delete.UserID != nil {
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *delete.UserID)
}
if delete.Provider != nil {
where, args = append(where, "provider = "+placeholder(len(args)+1)), append(args, *delete.Provider)
}
if _, err := d.db.ExecContext(ctx, "DELETE FROM user_identity WHERE "+strings.Join(where, " AND "), args...); err != nil {
return err
}
return nil
}
...@@ -70,6 +70,22 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ...@@ -70,6 +70,22 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
return userSettingList, nil return userSettingList, nil
} }
func (d *DB) DeleteUserSettings(ctx context.Context, delete *store.DeleteUserSetting) error {
where, args := []string{"1 = 1"}, []any{}
if v := delete.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, v.String())
}
if v := delete.UserID; v != nil {
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *v)
}
if _, err := d.db.ExecContext(ctx, "DELETE FROM user_setting WHERE "+strings.Join(where, " AND "), args...); err != nil {
return err
}
return nil
}
func (d *DB) GetUserByPATHash(ctx context.Context, tokenHash string) (*store.PATQueryResult, error) { func (d *DB) GetUserByPATHash(ctx context.Context, tokenHash string) (*store.PATQueryResult, error) {
// Simplified query: fetch all PERSONAL_ACCESS_TOKENS rows and search in Go // Simplified query: fetch all PERSONAL_ACCESS_TOKENS rows and search in Go
// This matches SQLite/MySQL behavior and avoids PostgreSQL's strict JSONB errors // This matches SQLite/MySQL behavior and avoids PostgreSQL's strict JSONB errors
......
package sqlite
import (
"context"
"strings"
"github.com/usememos/memos/store"
)
func (d *DB) CreateUserIdentity(ctx context.Context, create *store.UserIdentity) (*store.UserIdentity, error) {
stmt := "INSERT INTO `user_identity` (`user_id`, `provider`, `extern_uid`) VALUES (?, ?, ?) RETURNING `id`, `created_ts`, `updated_ts`"
if err := d.db.QueryRowContext(ctx, stmt, create.UserID, create.Provider, create.ExternUID).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
); err != nil {
return nil, err
}
return create, nil
}
func (d *DB) ListUserIdentities(ctx context.Context, find *store.FindUserIdentity) ([]*store.UserIdentity, error) {
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "`id` = ?"), append(args, *find.ID)
}
if find.UserID != nil {
where, args = append(where, "`user_id` = ?"), append(args, *find.UserID)
}
if find.Provider != nil {
where, args = append(where, "`provider` = ?"), append(args, *find.Provider)
}
if find.ExternUID != nil {
where, args = append(where, "`extern_uid` = ?"), append(args, *find.ExternUID)
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
user_id,
provider,
extern_uid,
created_ts,
updated_ts
FROM user_identity
WHERE `+strings.Join(where, " AND ")+`
ORDER BY id ASC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.UserIdentity{}
for rows.Next() {
ui := &store.UserIdentity{}
if err := rows.Scan(
&ui.ID,
&ui.UserID,
&ui.Provider,
&ui.ExternUID,
&ui.CreatedTs,
&ui.UpdatedTs,
); err != nil {
return nil, err
}
list = append(list, ui)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteUserIdentities(ctx context.Context, delete *store.DeleteUserIdentity) error {
where, args := []string{"1 = 1"}, []any{}
if delete.ID != nil {
where, args = append(where, "`id` = ?"), append(args, *delete.ID)
}
if delete.UserID != nil {
where, args = append(where, "`user_id` = ?"), append(args, *delete.UserID)
}
if delete.Provider != nil {
where, args = append(where, "`provider` = ?"), append(args, *delete.Provider)
}
if _, err := d.db.ExecContext(ctx, "DELETE FROM `user_identity` WHERE "+strings.Join(where, " AND "), args...); err != nil {
return err
}
return nil
}
...@@ -69,6 +69,22 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ...@@ -69,6 +69,22 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
return userSettingList, nil return userSettingList, nil
} }
func (d *DB) DeleteUserSettings(ctx context.Context, delete *store.DeleteUserSetting) error {
where, args := []string{"1 = 1"}, []any{}
if v := delete.Key; v != storepb.UserSetting_KEY_UNSPECIFIED {
where, args = append(where, "key = ?"), append(args, v.String())
}
if v := delete.UserID; v != nil {
where, args = append(where, "user_id = ?"), append(args, *v)
}
if _, err := d.db.ExecContext(ctx, "DELETE FROM user_setting WHERE "+strings.Join(where, " AND "), args...); err != nil {
return err
}
return nil
}
func (d *DB) GetUserByPATHash(ctx context.Context, tokenHash string) (*store.PATQueryResult, error) { func (d *DB) GetUserByPATHash(ctx context.Context, tokenHash string) (*store.PATQueryResult, error) {
query := ` query := `
SELECT SELECT
......
...@@ -45,6 +45,7 @@ type Driver interface { ...@@ -45,6 +45,7 @@ type Driver interface {
// UserSetting model related methods. // UserSetting model related methods.
UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error)
ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error)
DeleteUserSettings(ctx context.Context, delete *DeleteUserSetting) error
GetUserByPATHash(ctx context.Context, tokenHash string) (*PATQueryResult, error) GetUserByPATHash(ctx context.Context, tokenHash string) (*PATQueryResult, error)
// IdentityProvider model related methods. // IdentityProvider model related methods.
...@@ -70,4 +71,9 @@ type Driver interface { ...@@ -70,4 +71,9 @@ type Driver interface {
ListMemoShares(ctx context.Context, find *FindMemoShare) ([]*MemoShare, error) ListMemoShares(ctx context.Context, find *FindMemoShare) ([]*MemoShare, error)
GetMemoShare(ctx context.Context, find *FindMemoShare) (*MemoShare, error) GetMemoShare(ctx context.Context, find *FindMemoShare) (*MemoShare, error)
DeleteMemoShare(ctx context.Context, delete *DeleteMemoShare) error DeleteMemoShare(ctx context.Context, delete *DeleteMemoShare) error
// UserIdentity model related methods.
CreateUserIdentity(ctx context.Context, create *UserIdentity) (*UserIdentity, error)
ListUserIdentities(ctx context.Context, find *FindUserIdentity) ([]*UserIdentity, error)
DeleteUserIdentities(ctx context.Context, delete *DeleteUserIdentity) error
} }
-- user_identity stores the linkage between an external identity subject and a local user.
-- (provider, extern_uid) is unique across the table; provider stores the idp.uid.
-- Each local user can link at most one external account per provider.
CREATE TABLE `user_identity` (
`id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
`user_id` INT NOT NULL,
`provider` VARCHAR(256) NOT NULL,
`extern_uid` VARCHAR(256) NOT NULL,
`created_ts` BIGINT NOT NULL DEFAULT (UNIX_TIMESTAMP()),
`updated_ts` BIGINT NOT NULL DEFAULT (UNIX_TIMESTAMP()),
UNIQUE (`provider`, `extern_uid`),
UNIQUE (`user_id`, `provider`)
);
CREATE INDEX `idx_user_identity_user_id` ON `user_identity`(`user_id`);
...@@ -109,3 +109,17 @@ CREATE TABLE `memo_share` ( ...@@ -109,3 +109,17 @@ CREATE TABLE `memo_share` (
); );
CREATE INDEX `idx_memo_share_memo_id` ON `memo_share`(`memo_id`); CREATE INDEX `idx_memo_share_memo_id` ON `memo_share`(`memo_id`);
-- user_identity
CREATE TABLE `user_identity` (
`id` INT NOT NULL AUTO_INCREMENT PRIMARY KEY,
`user_id` INT NOT NULL,
`provider` VARCHAR(256) NOT NULL,
`extern_uid` VARCHAR(256) NOT NULL,
`created_ts` BIGINT NOT NULL DEFAULT (UNIX_TIMESTAMP()),
`updated_ts` BIGINT NOT NULL DEFAULT (UNIX_TIMESTAMP()),
UNIQUE (`provider`, `extern_uid`),
UNIQUE (`user_id`, `provider`)
);
CREATE INDEX `idx_user_identity_user_id` ON `user_identity`(`user_id`);
-- user_identity stores the linkage between an external identity subject and a local user.
-- (provider, extern_uid) is unique across the table; provider stores the idp.uid.
-- Each local user can link at most one external account per provider.
CREATE TABLE user_identity (
id SERIAL PRIMARY KEY,
user_id INTEGER NOT NULL,
provider TEXT NOT NULL,
extern_uid TEXT NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
UNIQUE (provider, extern_uid),
UNIQUE (user_id, provider)
);
CREATE INDEX idx_user_identity_user_id ON user_identity(user_id);
...@@ -109,3 +109,17 @@ CREATE TABLE memo_share ( ...@@ -109,3 +109,17 @@ CREATE TABLE memo_share (
); );
CREATE INDEX idx_memo_share_memo_id ON memo_share(memo_id); CREATE INDEX idx_memo_share_memo_id ON memo_share(memo_id);
-- user_identity
CREATE TABLE user_identity (
id SERIAL PRIMARY KEY,
user_id INTEGER NOT NULL,
provider TEXT NOT NULL,
extern_uid TEXT NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
UNIQUE (provider, extern_uid),
UNIQUE (user_id, provider)
);
CREATE INDEX idx_user_identity_user_id ON user_identity(user_id);
-- user_identity stores the linkage between an external identity subject and a local user.
-- (provider, extern_uid) is unique across the table; provider stores the idp.uid.
-- Each local user can link at most one external account per provider.
CREATE TABLE user_identity (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
provider TEXT NOT NULL,
extern_uid TEXT NOT NULL,
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
UNIQUE (provider, extern_uid),
UNIQUE (user_id, provider)
);
CREATE INDEX idx_user_identity_user_id ON user_identity(user_id);
...@@ -110,3 +110,17 @@ CREATE TABLE memo_share ( ...@@ -110,3 +110,17 @@ CREATE TABLE memo_share (
); );
CREATE INDEX idx_memo_share_memo_id ON memo_share(memo_id); CREATE INDEX idx_memo_share_memo_id ON memo_share(memo_id);
-- user_identity
CREATE TABLE user_identity (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id INTEGER NOT NULL,
provider TEXT NOT NULL,
extern_uid TEXT NOT NULL,
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
UNIQUE (provider, extern_uid),
UNIQUE (user_id, provider)
);
CREATE INDEX idx_user_identity_user_id ON user_identity(user_id);
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
func TestUserIdentityCreateAndGet(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
defer ts.Close()
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
provider := "idp-uid-1"
externUID := "jane@example.com"
created, err := ts.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: user.ID,
Provider: provider,
ExternUID: externUID,
})
require.NoError(t, err)
require.NotZero(t, created.ID)
require.NotZero(t, created.CreatedTs)
require.Equal(t, user.ID, created.UserID)
require.Equal(t, provider, created.Provider)
require.Equal(t, externUID, created.ExternUID)
got, err := ts.GetUserIdentity(ctx, &store.FindUserIdentity{
Provider: &provider,
ExternUID: &externUID,
})
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, created.ID, got.ID)
require.Equal(t, user.ID, got.UserID)
// Miss returns (nil, nil).
missingProvider := "idp-uid-missing"
notFound, err := ts.GetUserIdentity(ctx, &store.FindUserIdentity{
Provider: &missingProvider,
ExternUID: &externUID,
})
require.NoError(t, err)
require.Nil(t, notFound)
}
func TestUserIdentityListByUserID(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
defer ts.Close()
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: user.ID,
Provider: "idp-A",
ExternUID: "sub-a-1",
})
require.NoError(t, err)
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: user.ID,
Provider: "idp-B",
ExternUID: "sub-b-1",
})
require.NoError(t, err)
list, err := ts.ListUserIdentities(ctx, &store.FindUserIdentity{
UserID: &user.ID,
})
require.NoError(t, err)
require.Len(t, list, 2)
}
func TestUserIdentityUniqueConflict(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
defer ts.Close()
userA, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
userB, err := createTestingUserWithRole(ctx, ts, "conflict_user", store.RoleUser)
require.NoError(t, err)
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: userA.ID,
Provider: "idp-A",
ExternUID: "sub-1",
})
require.NoError(t, err)
// Second insert with the same (provider, extern_uid) must fail regardless of user_id.
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: userB.ID,
Provider: "idp-A",
ExternUID: "sub-1",
})
require.Error(t, err)
}
func TestUserIdentitySameExternUIDDifferentProviders(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
defer ts.Close()
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: user.ID,
Provider: "idp-A",
ExternUID: "sub-1",
})
require.NoError(t, err)
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: user.ID,
Provider: "idp-B",
ExternUID: "sub-1",
})
require.NoError(t, err)
externUID := "sub-1"
list, err := ts.ListUserIdentities(ctx, &store.FindUserIdentity{
ExternUID: &externUID,
})
require.NoError(t, err)
require.Len(t, list, 2)
}
func TestUserIdentitySameUserSameProviderConflicts(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
defer ts.Close()
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: user.ID,
Provider: "idp-A",
ExternUID: "sub-1",
})
require.NoError(t, err)
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: user.ID,
Provider: "idp-A",
ExternUID: "sub-2",
})
require.Error(t, err)
}
func TestUserIdentityDeleteByUserAndProvider(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
defer ts.Close()
user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err)
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: user.ID,
Provider: "idp-A",
ExternUID: "sub-a-1",
})
require.NoError(t, err)
_, err = ts.CreateUserIdentity(ctx, &store.UserIdentity{
UserID: user.ID,
Provider: "idp-B",
ExternUID: "sub-b-1",
})
require.NoError(t, err)
provider := "idp-A"
err = ts.DeleteUserIdentities(ctx, &store.DeleteUserIdentity{
UserID: &user.ID,
Provider: &provider,
})
require.NoError(t, err)
list, err := ts.ListUserIdentities(ctx, &store.FindUserIdentity{
UserID: &user.ID,
})
require.NoError(t, err)
require.Len(t, list, 1)
require.Equal(t, "idp-B", list[0].Provider)
}
package store
import "context"
// UserIdentity is the linkage between an external identity subject and a local user.
// Uniqueness is enforced on (Provider, ExternUID); one local user may have multiple
// identities across different providers.
type UserIdentity struct {
ID int32
UserID int32
Provider string
ExternUID string
CreatedTs int64
UpdatedTs int64
}
// FindUserIdentity is used to filter user identities in list/get queries.
type FindUserIdentity struct {
ID *int32
UserID *int32
Provider *string
ExternUID *string
}
// DeleteUserIdentity is used to delete user identity linkage rows.
type DeleteUserIdentity struct {
ID *int32
UserID *int32
Provider *string
}
// CreateUserIdentity creates a new external-identity linkage record.
// Returns the driver error on unique-constraint violation; callers are responsible
// for reconciling concurrent first-login races on (Provider, ExternUID).
func (s *Store) CreateUserIdentity(ctx context.Context, create *UserIdentity) (*UserIdentity, error) {
return s.driver.CreateUserIdentity(ctx, create)
}
// ListUserIdentities returns all linkage records matching the filter.
func (s *Store) ListUserIdentities(ctx context.Context, find *FindUserIdentity) ([]*UserIdentity, error) {
return s.driver.ListUserIdentities(ctx, find)
}
// DeleteUserIdentities deletes all linkage records matching the filter.
func (s *Store) DeleteUserIdentities(ctx context.Context, delete *DeleteUserIdentity) error {
return s.driver.DeleteUserIdentities(ctx, delete)
}
// GetUserIdentity returns the first linkage record matching the filter, or nil if none found.
func (s *Store) GetUserIdentity(ctx context.Context, find *FindUserIdentity) (*UserIdentity, error) {
list, err := s.ListUserIdentities(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}
...@@ -21,6 +21,11 @@ type FindUserSetting struct { ...@@ -21,6 +21,11 @@ type FindUserSetting struct {
Key storepb.UserSetting_Key Key storepb.UserSetting_Key
} }
type DeleteUserSetting struct {
UserID *int32
Key storepb.UserSetting_Key
}
// RefreshTokenQueryResult contains the result of querying a refresh token. // RefreshTokenQueryResult contains the result of querying a refresh token.
type RefreshTokenQueryResult struct { type RefreshTokenQueryResult struct {
UserID int32 UserID int32
...@@ -102,6 +107,23 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*sto ...@@ -102,6 +107,23 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*sto
return userSetting, nil return userSetting, nil
} }
func (s *Store) DeleteUserSettings(ctx context.Context, delete *DeleteUserSetting) error {
existing, err := s.ListUserSettings(ctx, &FindUserSetting{
UserID: delete.UserID,
Key: delete.Key,
})
if err != nil {
return err
}
if err := s.driver.DeleteUserSettings(ctx, delete); err != nil {
return err
}
for _, setting := range existing {
s.userSettingCache.Delete(ctx, getUserSettingCacheKey(setting.UserId, setting.Key.String()))
}
return nil
}
// GetUserByPATHash finds a user by PAT hash. // GetUserByPATHash finds a user by PAT hash.
func (s *Store) GetUserByPATHash(ctx context.Context, tokenHash string) (*PATQueryResult, error) { func (s *Store) GetUserByPATHash(ctx context.Context, tokenHash string) (*PATQueryResult, error) {
result, err := s.driver.GetUserByPATHash(ctx, tokenHash) result, err := s.driver.GetUserByPATHash(ctx, tokenHash)
......
import { useEffect, useMemo, useState } from "react";
import { toast } from "react-hot-toast";
import { Button } from "@/components/ui/button";
import { identityProviderServiceClient, userServiceClient } from "@/connect";
import { absolutifyLink } from "@/helpers/utils";
import useCurrentUser from "@/hooks/useCurrentUser";
import { handleError } from "@/lib/error";
import { IdentityProvider, IdentityProvider_Type } from "@/types/proto/api/v1/idp_service_pb";
import { LinkedIdentity } from "@/types/proto/api/v1/user_service_pb";
import { useTranslate } from "@/utils/i18n";
import { storeOAuthState } from "@/utils/oauth";
import SettingGroup from "./SettingGroup";
import SettingTable from "./SettingTable";
interface LinkedIdentityRow extends Record<string, unknown> {
name: string;
title: string;
externUid: string;
linkedIdentity?: LinkedIdentity;
identityProvider: IdentityProvider;
}
const LinkedIdentitySection = () => {
const t = useTranslate();
const currentUser = useCurrentUser();
const [identityProviderList, setIdentityProviderList] = useState<IdentityProvider[]>([]);
const [linkedIdentityList, setLinkedIdentityList] = useState<LinkedIdentity[]>([]);
const fetchData = async () => {
if (!currentUser?.name) {
return;
}
const [{ identityProviders }, { linkedIdentities }] = await Promise.all([
identityProviderServiceClient.listIdentityProviders({}),
userServiceClient.listLinkedIdentities({ parent: currentUser.name }),
]);
setIdentityProviderList(identityProviders);
setLinkedIdentityList(linkedIdentities);
};
useEffect(() => {
if (!currentUser?.name) {
return;
}
fetchData().catch((error: unknown) => {
handleError(error, toast.error, {
context: "Load linked identities",
});
});
}, [currentUser?.name]);
const oauthIdentityProviders = useMemo(
() => identityProviderList.filter((identityProvider) => identityProvider.type === IdentityProvider_Type.OAUTH2),
[identityProviderList],
);
const linkedIdentityByProviderName = useMemo(() => {
const mapping = new Map<string, LinkedIdentity>();
for (const linkedIdentity of linkedIdentityList) {
if (!mapping.has(linkedIdentity.idpName)) {
mapping.set(linkedIdentity.idpName, linkedIdentity);
}
}
return mapping;
}, [linkedIdentityList]);
const rows = useMemo<LinkedIdentityRow[]>(
() =>
oauthIdentityProviders.map((identityProvider) => {
const linkedIdentity = linkedIdentityByProviderName.get(identityProvider.name);
return {
name: identityProvider.name,
title: identityProvider.title,
externUid: linkedIdentity?.externUid ?? "",
linkedIdentity,
identityProvider,
};
}),
[linkedIdentityByProviderName, oauthIdentityProviders],
);
const handleLinkIdentityProvider = async (identityProvider: IdentityProvider) => {
if (!currentUser?.name) {
return;
}
const redirectUri = absolutifyLink("/auth/callback");
const oauth2Config = identityProvider.config?.config?.case === "oauth2Config" ? identityProvider.config.config.value : undefined;
if (!oauth2Config) {
toast.error("Identity provider configuration is invalid.");
return;
}
try {
const returnUrl = `${window.location.pathname}${window.location.search}${window.location.hash}`;
const { state, codeChallenge } = await storeOAuthState(identityProvider.name, "link", returnUrl, currentUser.name);
let authUrl = `${oauth2Config.authUrl}?client_id=${
oauth2Config.clientId
}&redirect_uri=${encodeURIComponent(redirectUri)}&state=${state}&response_type=code&scope=${encodeURIComponent(
oauth2Config.scopes.join(" "),
)}`;
if (codeChallenge) {
authUrl += `&code_challenge=${codeChallenge}&code_challenge_method=S256`;
}
window.location.href = authUrl;
} catch (error) {
handleError(error, toast.error, {
context: "Failed to initiate OAuth flow",
fallbackMessage: "Failed to initiate account linking. Please try again.",
});
}
};
const handleUnlinkIdentityProvider = async (row: LinkedIdentityRow) => {
if (!row.linkedIdentity?.name) {
return;
}
try {
await userServiceClient.deleteLinkedIdentity({
name: row.linkedIdentity.name,
});
await fetchData();
toast.success(`Unlinked ${row.title}.`);
} catch (error) {
handleError(error, toast.error, {
context: "Delete linked identity",
fallbackMessage: "Failed to unlink identity provider.",
});
}
};
if (oauthIdentityProviders.length === 0) {
return null;
}
return (
<SettingGroup
showSeparator
title="SSO accounts"
description="Each provider can be linked to this account at most once. A linked row shows the current extern_uid and can be unlinked."
>
<SettingTable<LinkedIdentityRow>
columns={[
{
key: "title",
header: "SSO provider",
render: (_, row: LinkedIdentityRow) => <span className="text-foreground">{row.title}</span>,
},
{
key: "externUid",
header: "extern_uid",
render: (_, row: LinkedIdentityRow) => (
<span className={row.externUid ? "text-foreground" : "text-muted-foreground"}>
{row.externUid || t("attachment-library.labels.not-linked")}
</span>
),
},
{
key: "actions",
header: "",
className: "text-right",
render: (_, row: LinkedIdentityRow) =>
row.linkedIdentity ? (
<Button variant="outline" size="sm" onClick={() => handleUnlinkIdentityProvider(row)}>
Unlink
</Button>
) : (
<Button variant="outline" size="sm" onClick={() => handleLinkIdentityProvider(row.identityProvider)}>
{t("common.link")}
</Button>
),
},
]}
data={rows}
emptyMessage="No SSO providers found."
getRowKey={(row) => row.name}
/>
</SettingGroup>
);
};
export default LinkedIdentitySection;
import { MoreVerticalIcon, PenLineIcon } from "lucide-react"; import { MoreVerticalIcon, PenLineIcon } from "lucide-react";
import { useState } from "react";
import toast from "react-hot-toast";
import ConfirmDialog from "@/components/ConfirmDialog";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
import { userServiceClient } from "@/connect";
import { useAuth } from "@/contexts/AuthContext";
import useCurrentUser from "@/hooks/useCurrentUser"; import useCurrentUser from "@/hooks/useCurrentUser";
import { useDialog } from "@/hooks/useDialog"; import { useDialog } from "@/hooks/useDialog";
import useNavigateTo from "@/hooks/useNavigateTo";
import { handleError } from "@/lib/error";
import { ROUTES } from "@/router/routes";
import { useTranslate } from "@/utils/i18n"; import { useTranslate } from "@/utils/i18n";
import ChangeMemberPasswordDialog from "../ChangeMemberPasswordDialog"; import ChangeMemberPasswordDialog from "../ChangeMemberPasswordDialog";
import UpdateAccountDialog from "../UpdateAccountDialog"; import UpdateAccountDialog from "../UpdateAccountDialog";
import UserAvatar from "../UserAvatar"; import UserAvatar from "../UserAvatar";
import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger } from "../ui/dropdown-menu"; import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger } from "../ui/dropdown-menu";
import AccessTokenSection from "./AccessTokenSection"; import AccessTokenSection from "./AccessTokenSection";
import LinkedIdentitySection from "./LinkedIdentitySection";
import SettingGroup from "./SettingGroup"; import SettingGroup from "./SettingGroup";
import SettingSection from "./SettingSection"; import SettingSection from "./SettingSection";
const MyAccountSection = () => { const MyAccountSection = () => {
const t = useTranslate(); const t = useTranslate();
const user = useCurrentUser(); const user = useCurrentUser();
const { logout } = useAuth();
const navigateTo = useNavigateTo();
const accountDialog = useDialog(); const accountDialog = useDialog();
const passwordDialog = useDialog(); const passwordDialog = useDialog();
const [deleteOpen, setDeleteOpen] = useState(false);
const handleDeleteAccount = async () => {
if (!user?.name) {
return;
}
try {
await userServiceClient.deleteUser({ name: user.name });
await logout();
toast.success(t("setting.member.delete-success", { username: user.username }));
navigateTo(ROUTES.AUTH, { replace: true });
} catch (error) {
handleError(error, toast.error, { context: "Delete account" });
throw error;
}
};
return ( return (
<SettingSection title={t("setting.my-account.label")}> <SettingSection title={t("setting.my-account.label")}>
...@@ -42,21 +69,35 @@ const MyAccountSection = () => { ...@@ -42,21 +69,35 @@ const MyAccountSection = () => {
</DropdownMenuTrigger> </DropdownMenuTrigger>
<DropdownMenuContent align="end"> <DropdownMenuContent align="end">
<DropdownMenuItem onClick={passwordDialog.open}>{t("setting.account.change-password")}</DropdownMenuItem> <DropdownMenuItem onClick={passwordDialog.open}>{t("setting.account.change-password")}</DropdownMenuItem>
<DropdownMenuItem onClick={() => setDeleteOpen(true)} className="text-destructive focus:text-destructive">
{t("setting.account.delete-account")}
</DropdownMenuItem>
</DropdownMenuContent> </DropdownMenuContent>
</DropdownMenu> </DropdownMenu>
</div> </div>
</div> </div>
</SettingGroup> </SettingGroup>
<SettingGroup showSeparator> <LinkedIdentitySection />
<AccessTokenSection /> <AccessTokenSection />
</SettingGroup>
{/* Update Account Dialog */} {/* Update Account Dialog */}
<UpdateAccountDialog open={accountDialog.isOpen} onOpenChange={accountDialog.setOpen} /> <UpdateAccountDialog open={accountDialog.isOpen} onOpenChange={accountDialog.setOpen} />
{/* Change Password Dialog */} {/* Change Password Dialog */}
<ChangeMemberPasswordDialog open={passwordDialog.isOpen} onOpenChange={passwordDialog.setOpen} user={user} /> <ChangeMemberPasswordDialog open={passwordDialog.isOpen} onOpenChange={passwordDialog.setOpen} user={user} />
<ConfirmDialog
open={deleteOpen}
onOpenChange={setDeleteOpen}
title={user ? t("setting.member.delete-warning", { username: user.username }) : ""}
description={t("setting.member.delete-warning-description")}
confirmLabel={t("common.delete")}
cancelLabel={t("common.cancel")}
onConfirm={handleDeleteAccount}
confirmVariant="destructive"
/>
</SettingSection> </SettingSection>
); );
}; };
......
import { MoreVerticalIcon, PlusIcon } from "lucide-react"; import { MoreVerticalIcon, PlusIcon } from "lucide-react";
import { useEffect, useState } from "react"; import { useEffect, useMemo, useState } from "react";
import { toast } from "react-hot-toast"; import { toast } from "react-hot-toast";
import ConfirmDialog from "@/components/ConfirmDialog"; import ConfirmDialog from "@/components/ConfirmDialog";
import { Button } from "@/components/ui/button"; import { Button } from "@/components/ui/button";
...@@ -7,13 +7,23 @@ import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigge ...@@ -7,13 +7,23 @@ import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigge
import { identityProviderServiceClient } from "@/connect"; import { identityProviderServiceClient } from "@/connect";
import { useDialog } from "@/hooks/useDialog"; import { useDialog } from "@/hooks/useDialog";
import { handleError } from "@/lib/error"; import { handleError } from "@/lib/error";
import { IdentityProvider } from "@/types/proto/api/v1/idp_service_pb"; import { IdentityProvider, IdentityProvider_Type } from "@/types/proto/api/v1/idp_service_pb";
import { useTranslate } from "@/utils/i18n"; import { useTranslate } from "@/utils/i18n";
import CreateIdentityProviderDialog from "../CreateIdentityProviderDialog"; import CreateIdentityProviderDialog from "../CreateIdentityProviderDialog";
import LearnMore from "../LearnMore"; import LearnMore from "../LearnMore";
import SettingSection from "./SettingSection"; import SettingSection from "./SettingSection";
import SettingTable from "./SettingTable"; import SettingTable from "./SettingTable";
interface IdentityProviderRow extends Record<string, unknown> {
name: string;
providerUid: string;
title: string;
typeLabel: string;
provider: IdentityProvider;
}
const getIdentityProviderUID = (name: string) => name.replace(/^identity-providers\//, "");
const SSOSection = () => { const SSOSection = () => {
const t = useTranslate(); const t = useTranslate();
const [identityProviderList, setIdentityProviderList] = useState<IdentityProvider[]>([]); const [identityProviderList, setIdentityProviderList] = useState<IdentityProvider[]>([]);
...@@ -22,14 +32,32 @@ const SSOSection = () => { ...@@ -22,14 +32,32 @@ const SSOSection = () => {
const idpDialog = useDialog(); const idpDialog = useDialog();
const fetchIdentityProviderList = async () => { const fetchIdentityProviderList = async () => {
try {
const { identityProviders } = await identityProviderServiceClient.listIdentityProviders({}); const { identityProviders } = await identityProviderServiceClient.listIdentityProviders({});
setIdentityProviderList(identityProviders); setIdentityProviderList(identityProviders);
} catch (error: unknown) {
handleError(error, toast.error, {
context: "Load identity providers",
});
}
}; };
useEffect(() => { useEffect(() => {
fetchIdentityProviderList(); void fetchIdentityProviderList();
}, []); }, []);
const rows = useMemo<IdentityProviderRow[]>(
() =>
identityProviderList.map((provider) => ({
name: provider.name,
providerUid: getIdentityProviderUID(provider.name),
title: provider.title,
typeLabel: IdentityProvider_Type[provider.type] ?? "TYPE_UNSPECIFIED",
provider,
})),
[identityProviderList],
);
const handleDeleteIdentityProvider = (identityProvider: IdentityProvider) => { const handleDeleteIdentityProvider = (identityProvider: IdentityProvider) => {
setDeleteTarget(identityProvider); setDeleteTarget(identityProvider);
}; };
...@@ -88,20 +116,25 @@ const SSOSection = () => { ...@@ -88,20 +116,25 @@ const SSOSection = () => {
<SettingTable <SettingTable
columns={[ columns={[
{ {
key: "title", key: "providerUid",
header: t("common.name"), header: "provider_uid",
render: (_, provider: IdentityProvider) => ( render: (_, row: IdentityProviderRow) => (
<span className="text-foreground"> <div className="flex flex-col">
{provider.title} <span className="text-foreground">{row.providerUid}</span>
<span className="ml-2 text-sm text-muted-foreground">({provider.type})</span> {row.title ? <span className="text-sm text-muted-foreground">{row.title}</span> : null}
</span> </div>
), ),
}, },
{
key: "typeLabel",
header: t("common.type"),
render: (_, row: IdentityProviderRow) => <span className="text-muted-foreground">{row.typeLabel}</span>,
},
{ {
key: "actions", key: "actions",
header: "", header: "",
className: "text-right", className: "text-right",
render: (_, provider: IdentityProvider) => ( render: (_, row: IdentityProviderRow) => (
<DropdownMenu> <DropdownMenu>
<DropdownMenuTrigger asChild> <DropdownMenuTrigger asChild>
<Button variant="outline" size="sm"> <Button variant="outline" size="sm">
...@@ -109,9 +142,9 @@ const SSOSection = () => { ...@@ -109,9 +142,9 @@ const SSOSection = () => {
</Button> </Button>
</DropdownMenuTrigger> </DropdownMenuTrigger>
<DropdownMenuContent align="end" sideOffset={2}> <DropdownMenuContent align="end" sideOffset={2}>
<DropdownMenuItem onClick={() => handleEditIdentityProvider(provider)}>{t("common.edit")}</DropdownMenuItem> <DropdownMenuItem onClick={() => handleEditIdentityProvider(row.provider)}>{t("common.edit")}</DropdownMenuItem>
<DropdownMenuItem <DropdownMenuItem
onClick={() => handleDeleteIdentityProvider(provider)} onClick={() => handleDeleteIdentityProvider(row.provider)}
className="text-destructive focus:text-destructive" className="text-destructive focus:text-destructive"
> >
{t("common.delete")} {t("common.delete")}
...@@ -121,9 +154,9 @@ const SSOSection = () => { ...@@ -121,9 +154,9 @@ const SSOSection = () => {
), ),
}, },
]} ]}
data={identityProviderList} data={rows}
emptyMessage={t("setting.sso.no-sso-found")} emptyMessage={t("setting.sso.no-sso-found")}
getRowKey={(provider) => provider.name} getRowKey={(row) => row.name}
/> />
<CreateIdentityProviderDialog <CreateIdentityProviderDialog
......
...@@ -386,6 +386,7 @@ ...@@ -386,6 +386,7 @@
}, },
"account": { "account": {
"change-password": "Change password", "change-password": "Change password",
"delete-account": "Delete account",
"email-note": "Optional", "email-note": "Optional",
"export-memos": "Export Memos", "export-memos": "Export Memos",
"nickname-note": "Displayed in the banner", "nickname-note": "Displayed in the banner",
......
...@@ -493,6 +493,7 @@ ...@@ -493,6 +493,7 @@
}, },
"account": { "account": {
"change-password": "修改密码", "change-password": "修改密码",
"delete-account": "删除账号",
"email-note": "可选", "email-note": "可选",
"export-memos": "导出备忘录", "export-memos": "导出备忘录",
"nickname-note": "显示在横幅中", "nickname-note": "显示在横幅中",
......
...@@ -2,7 +2,7 @@ import { timestampDate } from "@bufbuild/protobuf/wkt"; ...@@ -2,7 +2,7 @@ import { timestampDate } from "@bufbuild/protobuf/wkt";
import { useEffect, useRef, useState } from "react"; import { useEffect, useRef, useState } from "react";
import { useSearchParams } from "react-router-dom"; import { useSearchParams } from "react-router-dom";
import { setAccessToken } from "@/auth-state"; import { setAccessToken } from "@/auth-state";
import { authServiceClient } from "@/connect"; import { authServiceClient, userServiceClient } from "@/connect";
import { useAuth } from "@/contexts/AuthContext"; import { useAuth } from "@/contexts/AuthContext";
import { absolutifyLink } from "@/helpers/utils"; import { absolutifyLink } from "@/helpers/utils";
import useNavigateTo from "@/hooks/useNavigateTo"; import useNavigateTo from "@/hooks/useNavigateTo";
...@@ -18,7 +18,7 @@ interface State { ...@@ -18,7 +18,7 @@ interface State {
const AuthCallback = () => { const AuthCallback = () => {
const navigateTo = useNavigateTo(); const navigateTo = useNavigateTo();
const { initialize } = useAuth(); const { currentUser, initialize, isInitialized } = useAuth();
const [searchParams] = useSearchParams(); const [searchParams] = useSearchParams();
const handledRef = useRef(false); const handledRef = useRef(false);
const [state, setState] = useState<State>({ const [state, setState] = useState<State>({
...@@ -27,10 +27,12 @@ const AuthCallback = () => { ...@@ -27,10 +27,12 @@ const AuthCallback = () => {
}); });
useEffect(() => { useEffect(() => {
if (!isInitialized) {
return;
}
if (handledRef.current) { if (handledRef.current) {
return; return;
} }
handledRef.current = true;
// Check for OAuth error response first (e.g., user denied access) // Check for OAuth error response first (e.g., user denied access)
const error = searchParams.get("error"); const error = searchParams.get("error");
const errorDescription = searchParams.get("error_description"); const errorDescription = searchParams.get("error_description");
...@@ -74,11 +76,27 @@ const AuthCallback = () => { ...@@ -74,11 +76,27 @@ const AuthCallback = () => {
return; return;
} }
const { identityProviderName, returnUrl, codeVerifier } = validatedState; const { flowMode, identityProviderName, returnUrl, linkingUserName, codeVerifier } = validatedState;
const redirectUri = absolutifyLink("/auth/callback"); const redirectUri = absolutifyLink("/auth/callback");
handledRef.current = true;
(async () => { (async () => {
try { try {
if (flowMode === "link") {
if (!currentUser?.name) {
throw new Error("Failed to link account. Please sign in to Memos again and retry.");
}
if (linkingUserName && currentUser.name !== linkingUserName) {
throw new Error("The signed-in user changed before the OAuth callback completed. Please retry linking from account settings.");
}
await userServiceClient.createLinkedIdentity({
parent: currentUser.name,
idpName: identityProviderName,
code,
redirectUri,
codeVerifier: codeVerifier || "",
});
} else {
const response = await authServiceClient.signIn({ const response = await authServiceClient.signIn({
credentials: { credentials: {
case: "ssoCredentials", case: "ssoCredentials",
...@@ -94,6 +112,7 @@ const AuthCallback = () => { ...@@ -94,6 +112,7 @@ const AuthCallback = () => {
if (response.accessToken) { if (response.accessToken) {
setAccessToken(response.accessToken, response.accessTokenExpiresAt ? timestampDate(response.accessTokenExpiresAt) : undefined); setAccessToken(response.accessToken, response.accessTokenExpiresAt ? timestampDate(response.accessTokenExpiresAt) : undefined);
} }
}
setState({ setState({
loading: false, loading: false,
errorMessage: "", errorMessage: "",
...@@ -116,7 +135,7 @@ const AuthCallback = () => { ...@@ -116,7 +135,7 @@ const AuthCallback = () => {
}); });
} }
})(); })();
}, [searchParams, navigateTo]); }, [currentUser?.name, initialize, isInitialized, navigateTo, searchParams]);
if (state.loading) return null; if (state.loading) return null;
......
...@@ -44,7 +44,7 @@ const SignIn = () => { ...@@ -44,7 +44,7 @@ const SignIn = () => {
try { try {
// Generate and store secure state parameter with CSRF protection // Generate and store secure state parameter with CSRF protection
// Also generate PKCE parameters (code_challenge) for enhanced security if available // Also generate PKCE parameters (code_challenge) for enhanced security if available
const { state, codeChallenge } = await storeOAuthState(identityProvider.name, redirectTarget); const { state, codeChallenge } = await storeOAuthState(identityProvider.name, "signin", redirectTarget);
// Build OAuth authorization URL with secure state // Build OAuth authorization URL with secure state
// Include PKCE if available (requires HTTPS/localhost for crypto.subtle) // Include PKCE if available (requires HTTPS/localhost for crypto.subtle)
......
const STATE_STORAGE_KEY = "oauth_state"; const STATE_STORAGE_KEY = "oauth_state";
const STATE_EXPIRY_MS = 10 * 60 * 1000; // 10 minutes const STATE_EXPIRY_MS = 10 * 60 * 1000; // 10 minutes
export type OAuthFlowMode = "signin" | "link";
interface OAuthState { interface OAuthState {
state: string; state: string;
identityProviderName: string; identityProviderName: string;
flowMode: OAuthFlowMode;
timestamp: number; timestamp: number;
returnUrl?: string; returnUrl?: string;
linkingUserName?: string;
codeVerifier?: string; // PKCE code_verifier codeVerifier?: string; // PKCE code_verifier
} }
...@@ -44,7 +48,9 @@ function base64UrlEncode(buffer: Uint8Array): string { ...@@ -44,7 +48,9 @@ function base64UrlEncode(buffer: Uint8Array): string {
// PKCE is optional - if crypto APIs are unavailable (HTTP context), falls back to standard OAuth // PKCE is optional - if crypto APIs are unavailable (HTTP context), falls back to standard OAuth
export async function storeOAuthState( export async function storeOAuthState(
identityProviderName: string, identityProviderName: string,
flowMode: OAuthFlowMode,
returnUrl?: string, returnUrl?: string,
linkingUserName?: string,
): Promise<{ state: string; codeChallenge?: string }> { ): Promise<{ state: string; codeChallenge?: string }> {
const state = generateSecureState(); const state = generateSecureState();
...@@ -74,8 +80,10 @@ export async function storeOAuthState( ...@@ -74,8 +80,10 @@ export async function storeOAuthState(
const stateData: OAuthState = { const stateData: OAuthState = {
state, state,
identityProviderName, identityProviderName,
flowMode,
timestamp: Date.now(), timestamp: Date.now(),
returnUrl, returnUrl,
linkingUserName,
codeVerifier, // Store for later retrieval in callback (undefined if PKCE not available) codeVerifier, // Store for later retrieval in callback (undefined if PKCE not available)
}; };
...@@ -90,8 +98,10 @@ export async function storeOAuthState( ...@@ -90,8 +98,10 @@ export async function storeOAuthState(
} }
// Validate and retrieve OAuth state from storage (CSRF protection) // Validate and retrieve OAuth state from storage (CSRF protection)
// Returns identityProviderName, returnUrl, and codeVerifier for PKCE // Returns identityProviderName, flowMode, returnUrl, linkingUserName, and codeVerifier for PKCE
export function validateOAuthState(stateParam: string): { identityProviderName: string; returnUrl?: string; codeVerifier?: string } | null { export function validateOAuthState(
stateParam: string,
): { identityProviderName: string; flowMode: OAuthFlowMode; returnUrl?: string; linkingUserName?: string; codeVerifier?: string } | null {
try { try {
const storedData = sessionStorage.getItem(STATE_STORAGE_KEY); const storedData = sessionStorage.getItem(STATE_STORAGE_KEY);
if (!storedData) { if (!storedData) {
...@@ -119,7 +129,9 @@ export function validateOAuthState(stateParam: string): { identityProviderName: ...@@ -119,7 +129,9 @@ export function validateOAuthState(stateParam: string): { identityProviderName:
sessionStorage.removeItem(STATE_STORAGE_KEY); sessionStorage.removeItem(STATE_STORAGE_KEY);
return { return {
identityProviderName: stateData.identityProviderName, identityProviderName: stateData.identityProviderName,
flowMode: stateData.flowMode || "signin",
returnUrl: stateData.returnUrl, returnUrl: stateData.returnUrl,
linkingUserName: stateData.linkingUserName,
codeVerifier: stateData.codeVerifier, // Return PKCE code_verifier codeVerifier: stateData.codeVerifier, // Return PKCE code_verifier
}; };
} catch (error) { } catch (error) {
......
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { storeOAuthState, validateOAuthState } from "@/utils/oauth";
describe("oauth state", () => {
beforeEach(() => {
sessionStorage.clear();
});
afterEach(() => {
sessionStorage.clear();
});
it("round-trips the linking user for link flows", async () => {
const { state } = await storeOAuthState("identity-providers/google", "link", "/settings", "users/alice");
expect(validateOAuthState(state)).toEqual({
identityProviderName: "identity-providers/google",
flowMode: "link",
returnUrl: "/settings",
linkingUserName: "users/alice",
codeVerifier: expect.any(String),
});
});
it("defaults older states to signin without a linking user", () => {
sessionStorage.setItem(
"oauth_state",
JSON.stringify({
state: "legacy-state",
identityProviderName: "identity-providers/google",
timestamp: Date.now(),
returnUrl: "/auth",
}),
);
expect(validateOAuthState("legacy-state")).toEqual({
identityProviderName: "identity-providers/google",
flowMode: "signin",
returnUrl: "/auth",
linkingUserName: undefined,
codeVerifier: undefined,
});
});
});
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment