You need to sign in or sign up before continuing.
Unverified Commit 0fb83a74 authored by boojack's avatar boojack Committed by GitHub

fix(auth): harden authorization and username validation (#5890)

parent ee179985
......@@ -141,7 +141,14 @@ func (a *Authenticator) AuthenticateToUser(ctx context.Context, authHeader, cook
if !strings.HasPrefix(token, PersonalAccessTokenPrefix) {
claims, err := a.AuthenticateByAccessTokenV2(token)
if err == nil && claims != nil {
return a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
user, err := a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
if err != nil {
return nil, err
}
if user == nil || user.RowStatus == store.Archived {
return nil, nil
}
return user, nil
}
} else {
user, _, err := a.AuthenticateByPAT(ctx, token)
......@@ -174,6 +181,10 @@ func (a *Authenticator) Authenticate(ctx context.Context, authHeader string) *Au
if token != "" && !strings.HasPrefix(token, PersonalAccessTokenPrefix) {
claims, err := a.AuthenticateByAccessTokenV2(token)
if err == nil && claims != nil {
user, err := a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID})
if err != nil || user == nil || user.RowStatus == store.Archived {
return nil
}
return &AuthResult{
Claims: claims,
AccessToken: token,
......
......@@ -140,6 +140,24 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
create.Size = int64(size)
create.Blob = request.Attachment.Content
if request.Attachment.Memo != nil {
memoUID, err := ExtractMemoUIDFromName(*request.Attachment.Memo)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found: %s", *request.Attachment.Memo)
}
if !canModifyMemo(user, memo) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
create.MemoID = &memo.ID
}
if create.Payload == nil || create.Payload.MotionMedia == nil {
if detectedMotion := detectAndroidMotionMedia(create.Blob, create.Type, attachmentUID); detectedMotion != nil {
create.Payload = ensureAttachmentPayload(create.Payload)
......@@ -172,20 +190,6 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
return nil, status.Errorf(codes.Internal, "failed to save attachment blob: %v", err)
}
if request.Attachment.Memo != nil {
memoUID, err := ExtractMemoUIDFromName(*request.Attachment.Memo)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err)
}
memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to find memo: %v", err)
}
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found: %s", *request.Attachment.Memo)
}
create.MemoID = &memo.ID
}
attachment, err := s.Store.CreateAttachment(ctx, create)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create attachment: %v", err)
......
......@@ -595,6 +595,9 @@ func (s *APIV1Service) fetchCurrentUser(ctx context.Context) (*store.User, error
if user == nil {
return nil, errors.Errorf("user %d not found", userID)
}
if user.RowStatus == store.Archived {
return nil, nil
}
return user, nil
}
......
......@@ -77,3 +77,7 @@ func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error {
func isSuperUser(user *store.User) bool {
return user.Role == store.RoleAdmin
}
func canModifyMemo(user *store.User, memo *store.Memo) bool {
return user != nil && memo != nil && (memo.CreatorID == user.ID || isSuperUser(user))
}
......@@ -2,11 +2,16 @@ package v1
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"connectrpc.com/connect"
"github.com/labstack/echo/v5"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/usememos/memos/internal/profile"
)
func TestMetadataInterceptorForwardsSecurityHeaders(t *testing.T) {
......@@ -37,3 +42,24 @@ func TestMetadataInterceptorForwardsSecurityHeaders(t *testing.T) {
t.Fatalf("metadata interceptor returned error: %v", err)
}
}
func TestAllowedConnectOrigin(t *testing.T) {
service := &APIV1Service{
Profile: &profile.Profile{InstanceURL: "https://memos.example"},
}
e := echo.New()
req := httptest.NewRequest(http.MethodOptions, "http://localhost/memos.api.v1.AuthService/SignIn", nil)
req.Host = "localhost"
rec := httptest.NewRecorder()
ctx := e.NewContext(req, rec)
if !service.isAllowedConnectOrigin(ctx, "http://localhost") {
t.Fatal("expected same host origin to be allowed")
}
if !service.isAllowedConnectOrigin(ctx, "https://memos.example") {
t.Fatal("expected instance URL origin to be allowed")
}
if service.isAllowedConnectOrigin(ctx, "https://evil.example") {
t.Fatal("expected unknown origin to be denied")
}
}
......@@ -32,7 +32,7 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.CreatorID != user.ID && !isSuperUser(user) {
if !canModifyMemo(user, memo) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
if err := s.setMemoAttachmentsInternal(ctx, memo, request.Attachments); err != nil {
......
......@@ -13,7 +13,7 @@ import (
// retry loops around concurrent first-time logins.
func deriveSSOUsername() (string, error) {
username := util.GenUUID()
if err := validateUsername(username); err != nil {
if err := validateWritableUsername(username); err != nil {
return "", errors.Wrap(err, "generated UUID did not satisfy username constraints")
}
return username, nil
......
......@@ -138,6 +138,123 @@ func TestCreateAttachment(t *testing.T) {
})
}
func TestCreateAttachmentMemoPermission(t *testing.T) {
ctx := context.Background()
t.Run("owner can create attachment directly linked to memo", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
owner, err := ts.CreateRegularUser(ctx, "attachment-owner")
require.NoError(t, err)
ownerCtx := ts.CreateUserContext(ctx, owner.ID)
memo, err := ts.Service.CreateMemo(ownerCtx, &v1pb.CreateMemoRequest{
Memo: &v1pb.Memo{
Content: "memo with direct attachment",
},
})
require.NoError(t, err)
attachment, err := ts.Service.CreateAttachment(ownerCtx, &v1pb.CreateAttachmentRequest{
Attachment: &v1pb.Attachment{
Filename: "owner.txt",
Type: "text/plain",
Content: []byte("owner"),
Memo: &memo.Name,
},
})
require.NoError(t, err)
attachmentUID, err := apiv1.ExtractAttachmentUIDFromName(attachment.Name)
require.NoError(t, err)
stored, err := ts.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
require.NoError(t, err)
require.NotNil(t, stored.MemoID)
require.Equal(t, memoIDFromName(ctx, t, ts, memo.Name), *stored.MemoID)
})
t.Run("admin can create attachment directly linked to memo", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
owner, err := ts.CreateRegularUser(ctx, "attachment-admin-owner")
require.NoError(t, err)
ownerCtx := ts.CreateUserContext(ctx, owner.ID)
admin, err := ts.CreateHostUser(ctx, "attachment-admin")
require.NoError(t, err)
adminCtx := ts.CreateUserContext(ctx, admin.ID)
memo, err := ts.Service.CreateMemo(ownerCtx, &v1pb.CreateMemoRequest{
Memo: &v1pb.Memo{
Content: "memo with admin attachment",
},
})
require.NoError(t, err)
attachment, err := ts.Service.CreateAttachment(adminCtx, &v1pb.CreateAttachmentRequest{
Attachment: &v1pb.Attachment{
Filename: "admin.txt",
Type: "text/plain",
Content: []byte("admin"),
Memo: &memo.Name,
},
})
require.NoError(t, err)
attachmentUID, err := apiv1.ExtractAttachmentUIDFromName(attachment.Name)
require.NoError(t, err)
stored, err := ts.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID})
require.NoError(t, err)
require.NotNil(t, stored.MemoID)
require.Equal(t, memoIDFromName(ctx, t, ts, memo.Name), *stored.MemoID)
})
t.Run("non-owner cannot create attachment directly linked to memo", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
owner, err := ts.CreateRegularUser(ctx, "attachment-owner-denied")
require.NoError(t, err)
ownerCtx := ts.CreateUserContext(ctx, owner.ID)
other, err := ts.CreateRegularUser(ctx, "attachment-other-denied")
require.NoError(t, err)
otherCtx := ts.CreateUserContext(ctx, other.ID)
memo, err := ts.Service.CreateMemo(ownerCtx, &v1pb.CreateMemoRequest{
Memo: &v1pb.Memo{
Content: "memo with blocked attachment",
},
})
require.NoError(t, err)
_, err = ts.Service.CreateAttachment(otherCtx, &v1pb.CreateAttachmentRequest{
Attachment: &v1pb.Attachment{
Filename: "blocked.txt",
Type: "text/plain",
Content: []byte("blocked"),
Memo: &memo.Name,
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
attachments, err := ts.Store.ListAttachments(ctx, &store.FindAttachment{
CreatorID: &other.ID,
})
require.NoError(t, err)
require.Empty(t, attachments)
})
}
func memoIDFromName(ctx context.Context, t *testing.T, ts *TestService, name string) int32 {
t.Helper()
memoUID, err := apiv1.ExtractMemoUIDFromName(name)
require.NoError(t, err)
memo, err := ts.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID})
require.NoError(t, err)
require.NotNil(t, memo)
return memo.ID
}
func TestCreateAttachmentMotionMedia(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
......
......@@ -154,6 +154,23 @@ func TestListAndDeleteLinkedIdentities(t *testing.T) {
require.Empty(t, listResp.LinkedIdentities)
}
func TestListLinkedIdentitiesRequiresAuthentication(t *testing.T) {
t.Parallel()
ts := NewTestService(t)
defer ts.Cleanup()
ctx := context.Background()
user, err := ts.CreateRegularUser(ctx, "linked-identity-auth")
require.NoError(t, err)
_, err = ts.Service.ListLinkedIdentities(ctx, &v1pb.ListLinkedIdentitiesRequest{
Parent: apiv1.BuildUserName(user.Username),
})
require.Error(t, err)
require.Equal(t, codes.Unauthenticated, status.Code(err))
}
func TestCreateLinkedIdentityRejectsSecondIdentityForSameProvider(t *testing.T) {
t.Parallel()
......
......@@ -78,6 +78,33 @@ func TestAuthenticatorAccessTokenV2(t *testing.T) {
_, err = authenticator.AuthenticateByAccessTokenV2(token)
assert.Error(t, err)
})
t.Run("request authentication rejects archived user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "archived-access-token")
require.NoError(t, err)
token, _, err := auth.GenerateAccessTokenV2(
user.ID,
user.Username,
string(user.Role),
string(user.RowStatus),
[]byte(ts.Secret),
)
require.NoError(t, err)
archivedStatus := store.Archived
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
ID: user.ID,
RowStatus: &archivedStatus,
})
require.NoError(t, err)
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
result := authenticator.Authenticate(ctx, "Bearer "+token)
assert.Nil(t, result)
})
}
func TestAuthenticatorRefreshToken(t *testing.T) {
......
......@@ -52,3 +52,20 @@ func TestBatchGetUsersRejectsTooManyUsernames(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "too many usernames")
}
func TestBatchGetUsersRejectsTooManyNonEmptyUsernamesBeforeDedupe(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
usernames := make([]string, 0, 101)
for range 101 {
usernames = append(usernames, "legacy@example.com")
}
_, err := ts.Service.BatchGetUsers(ctx, &apiv1.BatchGetUsersRequest{
Usernames: usernames,
})
require.Error(t, err)
require.Contains(t, err.Error(), "too many usernames")
}
......@@ -5,6 +5,7 @@ import (
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt"
"google.golang.org/protobuf/types/known/fieldmaskpb"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
......@@ -15,6 +16,26 @@ import (
func TestUserServiceWithEmailLikeUsername(t *testing.T) {
ctx := context.Background()
t.Run("SignIn accepts email-like legacy username", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user := createLegacyPasswordUser(ctx, t, ts, "signin@example.com", "password123")
signInCtx := apiv1server.WithHeaderCarrier(ctx)
resp, err := ts.Service.SignIn(signInCtx, &apiv1.SignInRequest{
Credentials: &apiv1.SignInRequest_PasswordCredentials_{
PasswordCredentials: &apiv1.SignInRequest_PasswordCredentials{
Username: user.Username,
Password: "password123",
},
},
})
require.NoError(t, err)
require.Equal(t, user.Username, resp.User.Username)
require.NotEmpty(t, resp.AccessToken)
})
t.Run("GetUser accepts email-like username in resource name", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
......@@ -31,6 +52,38 @@ func TestUserServiceWithEmailLikeUsername(t *testing.T) {
require.Equal(t, "users/alice@example.com", got.Name)
})
t.Run("BatchGetUsers accepts email-like legacy username", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "batch@example.com")
require.NoError(t, err)
resp, err := ts.Service.BatchGetUsers(ctx, &apiv1.BatchGetUsersRequest{
Usernames: []string{" batch@example.com ", "missing@example.com", "batch@example.com"},
})
require.NoError(t, err)
require.Len(t, resp.Users, 1)
require.Equal(t, user.Username, resp.Users[0].Username)
require.Equal(t, "users/batch@example.com", resp.Users[0].Name)
})
t.Run("BatchGetUsers accepts underscore legacy username", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "legacy_batch")
require.NoError(t, err)
resp, err := ts.Service.BatchGetUsers(ctx, &apiv1.BatchGetUsersRequest{
Usernames: []string{"legacy_batch"},
})
require.NoError(t, err)
require.Len(t, resp.Users, 1)
require.Equal(t, user.Username, resp.Users[0].Username)
require.Equal(t, "users/legacy_batch", resp.Users[0].Name)
})
t.Run("ListUserSettings accepts email-like username in parent", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
......@@ -92,14 +145,70 @@ func TestUserServiceWithEmailLikeUsername(t *testing.T) {
require.Equal(t, "bob", stored.Username)
})
t.Run("UpdateUser rejects writing invalid username values", func(t *testing.T) {
for _, username := range []string{"alice@example.com", "legacy_user"} {
t.Run(username, func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "rename@example.com")
require.NoError(t, err)
authCtx := ts.CreateUserContext(ctx, user.ID)
_, err = ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: "users/rename@example.com",
Username: username,
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"username"}},
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid username")
stored, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID})
require.NoError(t, err)
require.NotNil(t, stored)
require.Equal(t, "rename@example.com", stored.Username)
})
}
})
t.Run("admin cannot rename user to invalid username", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "admin-rename-target")
require.NoError(t, err)
admin, err := ts.CreateHostUser(ctx, "rename-admin")
require.NoError(t, err)
adminCtx := ts.CreateUserContext(ctx, admin.ID)
_, err = ts.Service.UpdateUser(adminCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: apiv1server.BuildUserName(user.Username),
Username: "admin@example.com",
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"username"}},
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid username")
stored, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID})
require.NoError(t, err)
require.NotNil(t, stored)
require.Equal(t, "admin-rename-target", stored.Username)
})
t.Run("UpdateUser can archive email-like username account", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "dave@example.com")
require.NoError(t, err)
admin, err := ts.CreateHostUser(ctx, "email-admin")
require.NoError(t, err)
authCtx := ts.CreateUserContext(ctx, user.ID)
authCtx := ts.CreateUserContext(ctx, admin.ID)
updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: "users/dave@example.com",
......@@ -134,3 +243,17 @@ func TestUserServiceWithEmailLikeUsername(t *testing.T) {
require.Nil(t, deleted)
})
}
func createLegacyPasswordUser(ctx context.Context, t *testing.T, ts *TestService, username, password string) *store.User {
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
require.NoError(t, err)
user, err := ts.Store.CreateUser(ctx, &store.User{
Username: username,
Role: store.RoleUser,
Email: username,
PasswordHash: string(passwordHash),
})
require.NoError(t, err)
return user
}
......@@ -2,6 +2,8 @@ package test
import (
"context"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/require"
......@@ -10,6 +12,7 @@ import (
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
apiv1server "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
)
func TestCreateUserRegistration(t *testing.T) {
......@@ -218,6 +221,41 @@ func TestCreateUserRegistration(t *testing.T) {
require.Contains(t, err.Error(), "password must not be empty")
})
t.Run("CreateUser rejects invalid writable usernames", func(t *testing.T) {
for _, username := range []string{"alice@example.com", "legacy_user", "123"} {
t.Run(username, func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
_, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: username,
Email: "newuser@example.com",
Password: "password123",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid username")
})
}
})
t.Run("CreateUser validate only rejects invalid writable username", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
_, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "alice@example.com",
Email: "newuser@example.com",
Password: "password123",
},
ValidateOnly: true,
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid username")
})
t.Run("UpdateUser rejects empty password", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
......@@ -236,4 +274,114 @@ func TestCreateUserRegistration(t *testing.T) {
require.Error(t, err)
require.Contains(t, err.Error(), "password must not be empty")
})
t.Run("UpdateUser rejects missing user message", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "missing-message")
require.NoError(t, err)
authCtx := ts.CreateUserContext(ctx, user.ID)
_, err = ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"display_name"}},
})
require.Error(t, err)
require.Contains(t, err.Error(), "user is required")
})
t.Run("CreateUser concurrent first setup creates one admin", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
const workers = 12
var wg sync.WaitGroup
for i := range workers {
wg.Go(func() {
_, _ = ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: fmt.Sprintf("setup-user-%d", i),
Email: "setup-user@example.com",
Password: "password123",
},
})
})
}
wg.Wait()
users, err := ts.Store.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err)
adminCount := 0
for _, user := range users {
if user.Role == store.RoleAdmin {
adminCount++
}
}
require.Equal(t, 1, adminCount)
})
t.Run("UpdateUser state requires admin", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "state-user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
_, err = ts.Service.UpdateUser(userCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: apiv1server.BuildUserName(user.Username),
State: apiv1.State_ARCHIVED,
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"state"}},
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
admin, err := ts.CreateHostUser(ctx, "state-admin")
require.NoError(t, err)
adminCtx := ts.CreateUserContext(ctx, admin.ID)
updated, err := ts.Service.UpdateUser(adminCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: apiv1server.BuildUserName(user.Username),
State: apiv1.State_ARCHIVED,
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"state"}},
})
require.NoError(t, err)
require.Equal(t, apiv1.State_ARCHIVED, updated.State)
})
t.Run("archived user context is rejected", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "archived-access-user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
archived := store.Archived
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
ID: user.ID,
RowStatus: &archived,
})
require.NoError(t, err)
_, err = ts.Service.GetCurrentUser(userCtx, &apiv1.GetCurrentUserRequest{})
require.Error(t, err)
_, err = ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "should not be created",
},
})
require.Error(t, err)
_, err = ts.Service.UpdateUser(userCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: apiv1server.BuildUserName(user.Username),
State: apiv1.State_NORMAL,
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"state"}},
})
require.Error(t, err)
})
}
......@@ -26,7 +26,7 @@ func parseUsernameFromName(name string) (string, error) {
return username, nil
}
func validateUsername(username string) error {
func validateWritableUsername(username string) error {
if username == "" || isNumericUsername(username) || !base.UIDMatcher.MatchString(username) {
return errors.Errorf("invalid username %q", username)
}
......
package v1
import (
"testing"
)
func TestValidateWritableUsername(t *testing.T) {
tests := []struct {
name string
username string
wantError bool
}{
{
name: "lowercase",
username: "alice",
},
{
name: "mixed case",
username: "Alice",
},
{
name: "hyphenated",
username: "alice-smith",
},
{
name: "uuid",
username: "550e8400-e29b-41d4-a716-446655440000",
},
{
name: "empty",
username: "",
wantError: true,
},
{
name: "numeric",
username: "123",
wantError: true,
},
{
name: "email",
username: "alice@example.com",
wantError: true,
},
{
name: "underscore",
username: "alice_smith",
wantError: true,
},
{
name: "space",
username: "alice smith",
wantError: true,
},
{
name: "slash",
username: "alice/smith",
wantError: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
err := validateWritableUsername(test.username)
if test.wantError && err == nil {
t.Fatalf("validateWritableUsername(%q) succeeded, want error", test.username)
}
if !test.wantError && err != nil {
t.Fatalf("validateWritableUsername(%q) returned error: %v", test.username, err)
}
})
}
}
func TestParseUsernameFromNameAllowsLegacyUsernames(t *testing.T) {
tests := []struct {
name string
want string
wantFail bool
}{
{
name: "users/alice",
want: "alice",
},
{
name: "users/alice@example.com",
want: "alice@example.com",
},
{
name: "users/alice_smith",
want: "alice_smith",
},
{
name: "users/",
wantFail: true,
},
{
name: "invalid/alice",
wantFail: true,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got, err := parseUsernameFromName(test.name)
if test.wantFail && err == nil {
t.Fatalf("parseUsernameFromName(%q) succeeded, want error", test.name)
}
if !test.wantFail && err != nil {
t.Fatalf("parseUsernameFromName(%q) returned error: %v", test.name, err)
}
if got != test.want {
t.Fatalf("parseUsernameFromName(%q) = %q, want %q", test.name, got, test.want)
}
})
}
}
This diff is collapsed.
......@@ -3,6 +3,8 @@ package v1
import (
"context"
"net/http"
"net/url"
"strings"
"connectrpc.com/connect"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
......@@ -143,8 +145,11 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
// Wrap with CORS for browser access
corsHandler := middleware.CORSWithConfig(middleware.CORSConfig{
UnsafeAllowOriginFunc: func(_ *echo.Context, origin string) (string, bool, error) {
return origin, true, nil
UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) {
if s.isAllowedConnectOrigin(c, origin) {
return origin, true, nil
}
return "", false, nil
},
AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions},
AllowHeaders: []string{"*"},
......@@ -155,3 +160,23 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
return nil
}
func (s *APIV1Service) isAllowedConnectOrigin(c *echo.Context, origin string) bool {
originURL, err := url.Parse(origin)
if err != nil || originURL.Scheme == "" || originURL.Host == "" {
return false
}
if strings.EqualFold(originURL.Host, c.Request().Host) {
return true
}
if s.Profile == nil || s.Profile.InstanceURL == "" {
return false
}
instanceURL, err := url.Parse(s.Profile.InstanceURL)
if err != nil || instanceURL.Scheme == "" || instanceURL.Host == "" {
return false
}
return strings.EqualFold(originURL.Scheme, instanceURL.Scheme) && strings.EqualFold(originURL.Host, instanceURL.Host)
}
package store
import (
"sync"
"time"
"github.com/usememos/memos/internal/profile"
......@@ -12,6 +13,8 @@ type Store struct {
profile *profile.Profile
driver Driver
userCreateMu sync.Mutex
// Cache settings
cacheConfig cache.Config
......
......@@ -95,6 +95,30 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) {
return user, nil
}
// CreateUserIfNoUsers creates a user only when the instance has no users.
// The in-process lock prevents concurrent first-user setup requests from
// creating multiple admins in the same server process.
func (s *Store) CreateUserIfNoUsers(ctx context.Context, create *User) (*User, bool, error) {
s.userCreateMu.Lock()
defer s.userCreateMu.Unlock()
limitOne := 1
users, err := s.driver.ListUsers(ctx, &FindUser{Limit: &limitOne})
if err != nil {
return nil, false, err
}
if len(users) > 0 {
return nil, false, nil
}
user, err := s.driver.CreateUser(ctx, create)
if err != nil {
return nil, false, err
}
s.userCache.Set(ctx, userCacheKey(user.ID), user)
return user, true, nil
}
func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) {
user, err := s.driver.UpdateUser(ctx, update)
if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment