Unverified Commit 30c0611a authored by boojack's avatar boojack Committed by GitHub

fix: fix legacy username auth flows (#5885)

parent d688914b
...@@ -529,6 +529,36 @@ func (s *APIV1Service) clearAuthCookies(ctx context.Context) error { ...@@ -529,6 +529,36 @@ func (s *APIV1Service) clearAuthCookies(ctx context.Context) error {
return nil return nil
} }
func isSecureRequest(ctx context.Context) bool {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return false
}
for _, value := range md.Get("x-forwarded-proto") {
for _, proto := range strings.Split(value, ",") {
if strings.EqualFold(strings.TrimSpace(proto), "https") {
return true
}
}
}
for _, value := range md.Get("forwarded") {
lowerValue := strings.ToLower(value)
if strings.Contains(lowerValue, "proto=https") {
return true
}
}
for _, value := range md.Get("origin") {
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(value)), "https://") {
return true
}
}
return false
}
func (*APIV1Service) buildRefreshTokenCookie(ctx context.Context, refreshToken string, expireTime time.Time) string { func (*APIV1Service) buildRefreshTokenCookie(ctx context.Context, refreshToken string, expireTime time.Time) string {
attrs := []string{ attrs := []string{
fmt.Sprintf("%s=%s", auth.RefreshTokenCookieName, refreshToken), fmt.Sprintf("%s=%s", auth.RefreshTokenCookieName, refreshToken),
...@@ -543,19 +573,7 @@ func (*APIV1Service) buildRefreshTokenCookie(ctx context.Context, refreshToken s ...@@ -543,19 +573,7 @@ func (*APIV1Service) buildRefreshTokenCookie(ctx context.Context, refreshToken s
attrs = append(attrs, "Expires="+expireTime.UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT")) attrs = append(attrs, "Expires="+expireTime.UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT"))
} }
// Try to determine if the request is HTTPS by checking the origin header if isSecureRequest(ctx) {
// Default to non-HTTPS (Lax SameSite) if metadata is not available
isHTTPS := false
if md, ok := metadata.FromIncomingContext(ctx); ok {
for _, v := range md.Get("origin") {
if strings.HasPrefix(v, "https://") {
isHTTPS = true
break
}
}
}
if isHTTPS {
attrs = append(attrs, "SameSite=Lax", "Secure") attrs = append(attrs, "SameSite=Lax", "Secure")
} else { } else {
attrs = append(attrs, "SameSite=Lax") attrs = append(attrs, "SameSite=Lax")
......
...@@ -2,7 +2,9 @@ package v1 ...@@ -2,7 +2,9 @@ package v1
import ( import (
"context" "context"
"strings"
"testing" "testing"
"time"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
...@@ -177,3 +179,50 @@ func TestClientInfoExamples(t *testing.T) { ...@@ -177,3 +179,50 @@ func TestClientInfoExamples(t *testing.T) {
}) })
} }
} }
func TestBuildRefreshTokenCookieSecureFlag(t *testing.T) {
service := &APIV1Service{}
t.Run("sets Secure for https origin", func(t *testing.T) {
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(
"origin", "https://memos.example",
))
cookie := service.buildRefreshTokenCookie(ctx, "token", testCookieExpiry())
if !containsCookieAttribute(cookie, "Secure") {
t.Fatalf("expected Secure attribute in cookie: %s", cookie)
}
})
t.Run("sets Secure for forwarded proto", func(t *testing.T) {
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(
"x-forwarded-proto", "https",
))
cookie := service.buildRefreshTokenCookie(ctx, "token", testCookieExpiry())
if !containsCookieAttribute(cookie, "Secure") {
t.Fatalf("expected Secure attribute in cookie: %s", cookie)
}
})
t.Run("omits Secure for plain http", func(t *testing.T) {
ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs(
"origin", "http://memos.example",
))
cookie := service.buildRefreshTokenCookie(ctx, "token", testCookieExpiry())
if containsCookieAttribute(cookie, "Secure") {
t.Fatalf("did not expect Secure attribute in cookie: %s", cookie)
}
})
}
func testCookieExpiry() time.Time {
return time.Date(2030, time.January, 2, 3, 4, 5, 0, time.UTC)
}
func containsCookieAttribute(cookie, attr string) bool {
for _, part := range strings.Split(cookie, ";") {
if strings.EqualFold(strings.TrimSpace(part), attr) {
return true
}
}
return false
}
...@@ -38,12 +38,21 @@ func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc ...@@ -38,12 +38,21 @@ func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc
if ua := header.Get("User-Agent"); ua != "" { if ua := header.Get("User-Agent"); ua != "" {
md.Set("user-agent", ua) md.Set("user-agent", ua)
} }
if origin := header.Get("Origin"); origin != "" {
md.Set("origin", origin)
}
if xff := header.Get("X-Forwarded-For"); xff != "" { if xff := header.Get("X-Forwarded-For"); xff != "" {
md.Set("x-forwarded-for", xff) md.Set("x-forwarded-for", xff)
} }
if xfp := header.Get("X-Forwarded-Proto"); xfp != "" {
md.Set("x-forwarded-proto", xfp)
}
if xri := header.Get("X-Real-Ip"); xri != "" { if xri := header.Get("X-Real-Ip"); xri != "" {
md.Set("x-real-ip", xri) md.Set("x-real-ip", xri)
} }
if forwarded := header.Get("Forwarded"); forwarded != "" {
md.Set("forwarded", forwarded)
}
// Forward Cookie header for authentication methods that need it (e.g., RefreshToken) // Forward Cookie header for authentication methods that need it (e.g., RefreshToken)
if cookie := header.Get("Cookie"); cookie != "" { if cookie := header.Get("Cookie"); cookie != "" {
md.Set("cookie", cookie) md.Set("cookie", cookie)
......
package v1
import (
"context"
"testing"
"connectrpc.com/connect"
"google.golang.org/grpc/metadata"
"google.golang.org/protobuf/types/known/emptypb"
)
func TestMetadataInterceptorForwardsSecurityHeaders(t *testing.T) {
interceptor := NewMetadataInterceptor()
req := connect.NewRequest(&emptypb.Empty{})
req.Header().Set("Origin", "https://memos.example")
req.Header().Set("X-Forwarded-Proto", "https")
req.Header().Set("Forwarded", "for=203.0.113.1;proto=https")
handler := interceptor.WrapUnary(func(ctx context.Context, _ connect.AnyRequest) (connect.AnyResponse, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
t.Fatal("expected metadata in context")
}
if got := md.Get("origin"); len(got) != 1 || got[0] != "https://memos.example" {
t.Fatalf("unexpected origin metadata: %v", got)
}
if got := md.Get("x-forwarded-proto"); len(got) != 1 || got[0] != "https" {
t.Fatalf("unexpected x-forwarded-proto metadata: %v", got)
}
if got := md.Get("forwarded"); len(got) != 1 || got[0] != "for=203.0.113.1;proto=https" {
t.Fatalf("unexpected forwarded metadata: %v", got)
}
return connect.NewResponse(&emptypb.Empty{}), nil
})
if _, err := handler(context.Background(), req); err != nil {
t.Fatalf("metadata interceptor returned error: %v", err)
}
}
...@@ -94,7 +94,7 @@ func TestListShortcuts(t *testing.T) { ...@@ -94,7 +94,7 @@ func TestListShortcuts(t *testing.T) {
require.Contains(t, err.Error(), "permission denied") require.Contains(t, err.Error(), "permission denied")
}) })
t.Run("ListShortcuts rejects numeric parent", func(t *testing.T) { t.Run("ListShortcuts returns not found for numeric parent", func(t *testing.T) {
ts := NewTestService(t) ts := NewTestService(t)
defer ts.Cleanup() defer ts.Cleanup()
...@@ -107,7 +107,7 @@ func TestListShortcuts(t *testing.T) { ...@@ -107,7 +107,7 @@ func TestListShortcuts(t *testing.T) {
Parent: "users/1", Parent: "users/1",
}) })
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "invalid user name") require.Contains(t, err.Error(), "user not found")
}) })
} }
......
...@@ -202,7 +202,7 @@ func TestListUserNotificationsRejectsNumericParent(t *testing.T) { ...@@ -202,7 +202,7 @@ func TestListUserNotificationsRejectsNumericParent(t *testing.T) {
Parent: "users/1", Parent: "users/1",
}) })
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "invalid user name") require.Contains(t, err.Error(), "user not found")
} }
func TestListUserNotificationsIncludesMemoMentionPayload(t *testing.T) { func TestListUserNotificationsIncludesMemoMentionPayload(t *testing.T) {
......
...@@ -5,8 +5,11 @@ import ( ...@@ -5,8 +5,11 @@ import (
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/fieldmaskpb"
apiv1 "github.com/usememos/memos/proto/gen/api/v1" apiv1 "github.com/usememos/memos/proto/gen/api/v1"
apiv1server "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
) )
func TestUserResourceName(t *testing.T) { func TestUserResourceName(t *testing.T) {
...@@ -100,7 +103,7 @@ func TestUserResourceName(t *testing.T) { ...@@ -100,7 +103,7 @@ func TestUserResourceName(t *testing.T) {
require.Contains(t, err.Error(), "invalid username") require.Contains(t, err.Error(), "invalid username")
}) })
t.Run("GetUser rejects numeric user resource names", func(t *testing.T) { t.Run("GetUser returns not found for numeric user resource names", func(t *testing.T) {
ts := NewTestService(t) ts := NewTestService(t)
defer ts.Cleanup() defer ts.Cleanup()
...@@ -111,6 +114,84 @@ func TestUserResourceName(t *testing.T) { ...@@ -111,6 +114,84 @@ func TestUserResourceName(t *testing.T) {
Name: "users/1", Name: "users/1",
}) })
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "invalid user name") require.Contains(t, err.Error(), "user not found")
})
t.Run("legacy invalid username remains addressable for get update and delete", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
legacyUser, err := ts.CreateRegularUser(ctx, "legacy_user")
require.NoError(t, err)
got, err := ts.Service.GetUser(ctx, &apiv1.GetUserRequest{
Name: "users/legacy_user",
})
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, "users/legacy_user", got.Name)
authCtx := ts.CreateUserContext(apiv1server.WithHeaderCarrier(ctx), legacyUser.ID)
updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: apiv1server.BuildUserName(legacyUser.Username),
DisplayName: "Legacy User",
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"display_name"}},
})
require.NoError(t, err)
require.Equal(t, "Legacy User", updated.DisplayName)
_, err = ts.Service.DeleteUser(authCtx, &apiv1.DeleteUserRequest{
Name: apiv1server.BuildUserName(legacyUser.Username),
})
require.NoError(t, err)
deleted, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &legacyUser.ID})
require.NoError(t, err)
require.Nil(t, deleted)
})
t.Run("email-like legacy username can be renamed to a valid username", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
legacyUser, err := ts.CreateRegularUser(ctx, "alice@example.com")
require.NoError(t, err)
authCtx := ts.CreateUserContext(apiv1server.WithHeaderCarrier(ctx), legacyUser.ID)
updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: apiv1server.BuildUserName(legacyUser.Username),
Username: "alice",
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"username"}},
})
require.NoError(t, err)
require.Equal(t, "users/alice", updated.Name)
require.Equal(t, "alice", updated.Username)
renamed, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &legacyUser.ID})
require.NoError(t, err)
require.NotNil(t, renamed)
require.Equal(t, "alice", renamed.Username)
})
t.Run("email-like legacy username can be deleted", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
legacyUser, err := ts.CreateRegularUser(ctx, "bob@example.com")
require.NoError(t, err)
authCtx := ts.CreateUserContext(apiv1server.WithHeaderCarrier(ctx), legacyUser.ID)
_, err = ts.Service.DeleteUser(authCtx, &apiv1.DeleteUserRequest{
Name: apiv1server.BuildUserName(legacyUser.Username),
})
require.NoError(t, err)
deleted, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &legacyUser.ID})
require.NoError(t, err)
require.Nil(t, deleted)
}) })
} }
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/fieldmaskpb"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
apiv1server "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
)
func TestUserServiceWithEmailLikeUsername(t *testing.T) {
ctx := context.Background()
t.Run("GetUser accepts email-like username in resource name", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "alice@example.com")
require.NoError(t, err)
got, err := ts.Service.GetUser(ctx, &apiv1.GetUserRequest{
Name: "users/alice@example.com",
})
require.NoError(t, err)
require.NotNil(t, got)
require.Equal(t, user.Username, got.Username)
require.Equal(t, "users/alice@example.com", got.Name)
})
t.Run("ListUserSettings accepts email-like username in parent", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "alice@example.com")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
resp, err := ts.Service.ListUserSettings(userCtx, &apiv1.ListUserSettingsRequest{
Parent: "users/alice@example.com",
})
require.NoError(t, err)
require.NotNil(t, resp)
require.NotEmpty(t, resp.Settings)
})
t.Run("UpdateUser can change non-username fields for email-like username", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "alice@example.com")
require.NoError(t, err)
authCtx := ts.CreateUserContext(ctx, user.ID)
updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: "users/alice@example.com",
DisplayName: "Alice Example",
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"display_name"}},
})
require.NoError(t, err)
require.Equal(t, "Alice Example", updated.DisplayName)
require.Equal(t, "users/alice@example.com", updated.Name)
})
t.Run("UpdateUser can rename email-like username to valid username", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "bob@example.com")
require.NoError(t, err)
authCtx := ts.CreateUserContext(ctx, user.ID)
updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: "users/bob@example.com",
Username: "bob",
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"username"}},
})
require.NoError(t, err)
require.Equal(t, "bob", updated.Username)
require.Equal(t, apiv1server.BuildUserName("bob"), updated.Name)
stored, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID})
require.NoError(t, err)
require.NotNil(t, stored)
require.Equal(t, "bob", stored.Username)
})
t.Run("UpdateUser can archive email-like username account", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "dave@example.com")
require.NoError(t, err)
authCtx := ts.CreateUserContext(ctx, user.ID)
updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: "users/dave@example.com",
State: apiv1.State_ARCHIVED,
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"state"}},
})
require.NoError(t, err)
require.Equal(t, apiv1.State_ARCHIVED, updated.State)
stored, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID})
require.NoError(t, err)
require.NotNil(t, stored)
require.Equal(t, store.Archived, stored.RowStatus)
})
t.Run("DeleteUser can remove email-like username account", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "carol@example.com")
require.NoError(t, err)
authCtx := ts.CreateUserContext(apiv1server.WithHeaderCarrier(ctx), user.ID)
_, err = ts.Service.DeleteUser(authCtx, &apiv1.DeleteUserRequest{
Name: "users/carol@example.com",
})
require.NoError(t, err)
deleted, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID})
require.NoError(t, err)
require.Nil(t, deleted)
})
}
...@@ -5,9 +5,11 @@ import ( ...@@ -5,9 +5,11 @@ import (
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/fieldmaskpb"
apiv1 "github.com/usememos/memos/proto/gen/api/v1" apiv1 "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
apiv1server "github.com/usememos/memos/server/router/api/v1"
) )
func TestCreateUserRegistration(t *testing.T) { func TestCreateUserRegistration(t *testing.T) {
...@@ -172,4 +174,66 @@ func TestCreateUserRegistration(t *testing.T) { ...@@ -172,4 +174,66 @@ func TestCreateUserRegistration(t *testing.T) {
require.Equal(t, "users/wannabeadmin", createdUser.Name) require.Equal(t, "users/wannabeadmin", createdUser.Name)
require.Equal(t, apiv1.User_USER, createdUser.Role, "Unauthenticated users can only create USER role") require.Equal(t, apiv1.User_USER, createdUser.Role, "Unauthenticated users can only create USER role")
}) })
t.Run("CreateUser blocked when password auth disabled for self signup", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
_, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
DisallowPasswordAuth: true,
},
},
})
require.NoError(t, err)
_, err = ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "password signup is not allowed")
})
t.Run("CreateUser rejects empty password", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
_, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "password must not be empty")
})
t.Run("UpdateUser rejects empty password", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "alice")
require.NoError(t, err)
authCtx := ts.CreateUserContext(ctx, user.ID)
_, err = ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{
User: &apiv1.User{
Name: apiv1server.BuildUserName(user.Username),
Password: "",
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"password"}},
})
require.Error(t, err)
require.Contains(t, err.Error(), "password must not be empty")
})
} }
...@@ -108,5 +108,5 @@ func TestGetUserStats_TagCount(t *testing.T) { ...@@ -108,5 +108,5 @@ func TestGetUserStats_TagCount(t *testing.T) {
Name: "users/1", Name: "users/1",
}) })
require.Error(t, err) require.Error(t, err)
require.Contains(t, err.Error(), "invalid user name") require.Contains(t, err.Error(), "user not found")
} }
...@@ -14,8 +14,7 @@ func BuildUserName(username string) string { ...@@ -14,8 +14,7 @@ func BuildUserName(username string) string {
return UserNamePrefix + username return UserNamePrefix + username
} }
// ExtractUsernameFromName extracts the username token from a user resource name. func parseUsernameFromName(name string) (string, error) {
func ExtractUsernameFromName(name string) (string, error) {
tokens, err := GetNameParentTokens(name, UserNamePrefix) tokens, err := GetNameParentTokens(name, UserNamePrefix)
if err != nil { if err != nil {
return "", err return "", err
...@@ -24,9 +23,6 @@ func ExtractUsernameFromName(name string) (string, error) { ...@@ -24,9 +23,6 @@ func ExtractUsernameFromName(name string) (string, error) {
if username == "" { if username == "" {
return "", errors.Errorf("invalid user name %q", name) return "", errors.Errorf("invalid user name %q", name)
} }
if err := validateUsername(username); err != nil {
return "", err
}
return username, nil return username, nil
} }
...@@ -51,7 +47,7 @@ func isNumericUsername(username string) bool { ...@@ -51,7 +47,7 @@ func isNumericUsername(username string) bool {
// ResolveUserByName resolves a username-based user resource name to a store user. // ResolveUserByName resolves a username-based user resource name to a store user.
func ResolveUserByName(ctx context.Context, stores *store.Store, name string) (*store.User, error) { func ResolveUserByName(ctx context.Context, stores *store.Store, name string) (*store.User, error) {
username, err := ExtractUsernameFromName(name) username, err := parseUsernameFromName(name)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -30,6 +30,13 @@ import ( ...@@ -30,6 +30,13 @@ import (
const maxBatchGetUsers = 100 const maxBatchGetUsers = 100
func validatePassword(password string) error {
if password == "" {
return errors.New("password must not be empty")
}
return nil
}
func (s *APIV1Service) ListUsers(ctx context.Context, request *v1pb.ListUsersRequest) (*v1pb.ListUsersResponse, error) { func (s *APIV1Service) ListUsers(ctx context.Context, request *v1pb.ListUsersRequest) (*v1pb.ListUsersResponse, error) {
currentUser, err := s.fetchCurrentUser(ctx) currentUser, err := s.fetchCurrentUser(ctx)
if err != nil { if err != nil {
...@@ -156,6 +163,9 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR ...@@ -156,6 +163,9 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR
if instanceGeneralSetting.DisallowUserRegistration { if instanceGeneralSetting.DisallowUserRegistration {
return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed") return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed")
} }
if instanceGeneralSetting.DisallowPasswordAuth {
return nil, status.Errorf(codes.PermissionDenied, "password signup is not allowed")
}
} }
} }
...@@ -179,6 +189,9 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR ...@@ -179,6 +189,9 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR
if err := validateUsername(request.User.Username); err != nil { if err := validateUsername(request.User.Username); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username) return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username)
} }
if err := validatePassword(request.User.Password); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
// If validate_only is true, just validate without creating // If validate_only is true, just validate without creating
if request.ValidateOnly { if request.ValidateOnly {
...@@ -294,6 +307,9 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR ...@@ -294,6 +307,9 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR
role := convertUserRoleToStore(request.User.Role) role := convertUserRoleToStore(request.User.Role)
update.Role = &role update.Role = &role
case "password": case "password":
if err := validatePassword(request.User.Password); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost) passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate password hash: %v", err) return nil, status.Errorf(codes.Internal, "failed to generate password hash: %v", err)
......
...@@ -30,6 +30,7 @@ const SignUp = () => { ...@@ -30,6 +30,7 @@ const SignUp = () => {
const [searchParams] = useSearchParams(); const [searchParams] = useSearchParams();
const redirectTarget = getSafeRedirectPath(searchParams.get(AUTH_REDIRECT_PARAM)); const redirectTarget = getSafeRedirectPath(searchParams.get(AUTH_REDIRECT_PARAM));
const signInPath = searchParams.toString() ? `${ROUTES.AUTH}?${searchParams.toString()}` : ROUTES.AUTH; const signInPath = searchParams.toString() ? `${ROUTES.AUTH}?${searchParams.toString()}` : ROUTES.AUTH;
const canUsePasswordSignUp = !instanceGeneralSetting.disallowUserRegistration && !instanceGeneralSetting.disallowPasswordAuth;
const handleUsernameInputChanged = (e: React.ChangeEvent<HTMLInputElement>) => { const handleUsernameInputChanged = (e: React.ChangeEvent<HTMLInputElement>) => {
const text = e.target.value as string; const text = e.target.value as string;
...@@ -93,7 +94,7 @@ const SignUp = () => { ...@@ -93,7 +94,7 @@ const SignUp = () => {
<img className="h-14 w-auto rounded-full shadow" src={instanceGeneralSetting.customProfile?.logoUrl || "/logo.webp"} alt="" /> <img className="h-14 w-auto rounded-full shadow" src={instanceGeneralSetting.customProfile?.logoUrl || "/logo.webp"} alt="" />
<p className="ml-2 text-5xl text-foreground opacity-80">{instanceGeneralSetting.customProfile?.title || "Memos"}</p> <p className="ml-2 text-5xl text-foreground opacity-80">{instanceGeneralSetting.customProfile?.title || "Memos"}</p>
</div> </div>
{!instanceGeneralSetting.disallowUserRegistration ? ( {canUsePasswordSignUp ? (
<> <>
<p className="w-full text-2xl mt-2 text-muted-foreground">{t("auth.create-your-account")}</p> <p className="w-full text-2xl mt-2 text-muted-foreground">{t("auth.create-your-account")}</p>
<form className="w-full mt-2" onSubmit={handleFormSubmit}> <form className="w-full mt-2" onSubmit={handleFormSubmit}>
...@@ -137,6 +138,8 @@ const SignUp = () => { ...@@ -137,6 +138,8 @@ const SignUp = () => {
</div> </div>
</form> </form>
</> </>
) : instanceGeneralSetting.disallowPasswordAuth ? (
<p className="w-full text-2xl mt-2 text-muted-foreground">Password sign up is not allowed.</p>
) : ( ) : (
<p className="w-full text-2xl mt-2 text-muted-foreground">Sign up is not allowed.</p> <p className="w-full text-2xl mt-2 text-muted-foreground">Sign up is not allowed.</p>
)} )}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment