Commit ca799906 authored by johnnyjoy's avatar johnnyjoy

refactor: merge sign in requests

parent a0f68895
...@@ -13,14 +13,10 @@ service AuthService { ...@@ -13,14 +13,10 @@ service AuthService {
rpc GetAuthStatus(GetAuthStatusRequest) returns (User) { rpc GetAuthStatus(GetAuthStatusRequest) returns (User) {
option (google.api.http) = {post: "/api/v1/auth/status"}; option (google.api.http) = {post: "/api/v1/auth/status"};
} }
// SignIn signs in the user with the given username and password. // SignIn signs in the user.
rpc SignIn(SignInRequest) returns (User) { rpc SignIn(SignInRequest) returns (User) {
option (google.api.http) = {post: "/api/v1/auth/signin"}; option (google.api.http) = {post: "/api/v1/auth/signin"};
} }
// SignInWithSSO signs in the user with the given SSO code.
rpc SignInWithSSO(SignInWithSSORequest) returns (User) {
option (google.api.http) = {post: "/api/v1/auth/signin/sso"};
}
// SignUp signs up the user with the given username and password. // SignUp signs up the user with the given username and password.
rpc SignUp(SignUpRequest) returns (User) { rpc SignUp(SignUpRequest) returns (User) {
option (google.api.http) = {post: "/api/v1/auth/signup"}; option (google.api.http) = {post: "/api/v1/auth/signup"};
...@@ -38,15 +34,26 @@ message GetAuthStatusResponse { ...@@ -38,15 +34,26 @@ message GetAuthStatusResponse {
} }
message SignInRequest { message SignInRequest {
// Provide one authentication method (username/password or SSO).
oneof method {
// Username and password authentication method.
PasswordCredentials password_credentials = 1;
// SSO provider authentication method.
SSOCredentials sso_credentials = 2;
}
// Whether the session should never expire.
bool never_expire = 3;
}
message PasswordCredentials {
// The username to sign in with. // The username to sign in with.
string username = 1; string username = 1;
// The password to sign in with. // The password to sign in with.
string password = 2; string password = 2;
// Whether the session should never expire.
bool never_expire = 3;
} }
message SignInWithSSORequest { message SSOCredentials {
// The ID of the SSO provider. // The ID of the SSO provider.
int32 idp_id = 1; int32 idp_id = 1;
// The code to sign in with. // The code to sign in with.
......
This diff is collapsed.
...@@ -87,39 +87,6 @@ func local_request_AuthService_SignIn_0(ctx context.Context, marshaler runtime.M ...@@ -87,39 +87,6 @@ func local_request_AuthService_SignIn_0(ctx context.Context, marshaler runtime.M
return msg, metadata, err return msg, metadata, err
} }
var filter_AuthService_SignInWithSSO_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)}
func request_AuthService_SignInWithSSO_0(ctx context.Context, marshaler runtime.Marshaler, client AuthServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq SignInWithSSORequest
metadata runtime.ServerMetadata
)
io.Copy(io.Discard, req.Body)
if err := req.ParseForm(); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_AuthService_SignInWithSSO_0); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
msg, err := client.SignInWithSSO(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
return msg, metadata, err
}
func local_request_AuthService_SignInWithSSO_0(ctx context.Context, marshaler runtime.Marshaler, server AuthServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq SignInWithSSORequest
metadata runtime.ServerMetadata
)
if err := req.ParseForm(); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_AuthService_SignInWithSSO_0); err != nil {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
msg, err := server.SignInWithSSO(ctx, &protoReq)
return msg, metadata, err
}
var filter_AuthService_SignUp_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} var filter_AuthService_SignUp_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)}
func request_AuthService_SignUp_0(ctx context.Context, marshaler runtime.Marshaler, client AuthServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { func request_AuthService_SignUp_0(ctx context.Context, marshaler runtime.Marshaler, client AuthServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
...@@ -218,26 +185,6 @@ func RegisterAuthServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux ...@@ -218,26 +185,6 @@ func RegisterAuthServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux
} }
forward_AuthService_SignIn_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) forward_AuthService_SignIn_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
}) })
mux.Handle(http.MethodPost, pattern_AuthService_SignInWithSSO_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
var stream runtime.ServerTransportStream
ctx = grpc.NewContextWithServerTransportStream(ctx, &stream)
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/memos.api.v1.AuthService/SignInWithSSO", runtime.WithHTTPPathPattern("/api/v1/auth/signin/sso"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := local_request_AuthService_SignInWithSSO_0(annotatedContext, inboundMarshaler, server, req, pathParams)
md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer())
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_AuthService_SignInWithSSO_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_AuthService_SignUp_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { mux.Handle(http.MethodPost, pattern_AuthService_SignUp_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context()) ctx, cancel := context.WithCancel(req.Context())
defer cancel() defer cancel()
...@@ -352,23 +299,6 @@ func RegisterAuthServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux ...@@ -352,23 +299,6 @@ func RegisterAuthServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux
} }
forward_AuthService_SignIn_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) forward_AuthService_SignIn_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
}) })
mux.Handle(http.MethodPost, pattern_AuthService_SignInWithSSO_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/memos.api.v1.AuthService/SignInWithSSO", runtime.WithHTTPPathPattern("/api/v1/auth/signin/sso"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := request_AuthService_SignInWithSSO_0(annotatedContext, inboundMarshaler, client, req, pathParams)
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_AuthService_SignInWithSSO_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
mux.Handle(http.MethodPost, pattern_AuthService_SignUp_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { mux.Handle(http.MethodPost, pattern_AuthService_SignUp_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context()) ctx, cancel := context.WithCancel(req.Context())
defer cancel() defer cancel()
...@@ -409,7 +339,6 @@ func RegisterAuthServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux ...@@ -409,7 +339,6 @@ func RegisterAuthServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux
var ( var (
pattern_AuthService_GetAuthStatus_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "status"}, "")) pattern_AuthService_GetAuthStatus_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "status"}, ""))
pattern_AuthService_SignIn_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "signin"}, "")) pattern_AuthService_SignIn_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "signin"}, ""))
pattern_AuthService_SignInWithSSO_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3, 2, 4}, []string{"api", "v1", "auth", "signin", "sso"}, ""))
pattern_AuthService_SignUp_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "signup"}, "")) pattern_AuthService_SignUp_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "signup"}, ""))
pattern_AuthService_SignOut_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "signout"}, "")) pattern_AuthService_SignOut_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "auth", "signout"}, ""))
) )
...@@ -417,7 +346,6 @@ var ( ...@@ -417,7 +346,6 @@ var (
var ( var (
forward_AuthService_GetAuthStatus_0 = runtime.ForwardResponseMessage forward_AuthService_GetAuthStatus_0 = runtime.ForwardResponseMessage
forward_AuthService_SignIn_0 = runtime.ForwardResponseMessage forward_AuthService_SignIn_0 = runtime.ForwardResponseMessage
forward_AuthService_SignInWithSSO_0 = runtime.ForwardResponseMessage
forward_AuthService_SignUp_0 = runtime.ForwardResponseMessage forward_AuthService_SignUp_0 = runtime.ForwardResponseMessage
forward_AuthService_SignOut_0 = runtime.ForwardResponseMessage forward_AuthService_SignOut_0 = runtime.ForwardResponseMessage
) )
...@@ -22,7 +22,6 @@ const _ = grpc.SupportPackageIsVersion9 ...@@ -22,7 +22,6 @@ const _ = grpc.SupportPackageIsVersion9
const ( const (
AuthService_GetAuthStatus_FullMethodName = "/memos.api.v1.AuthService/GetAuthStatus" AuthService_GetAuthStatus_FullMethodName = "/memos.api.v1.AuthService/GetAuthStatus"
AuthService_SignIn_FullMethodName = "/memos.api.v1.AuthService/SignIn" AuthService_SignIn_FullMethodName = "/memos.api.v1.AuthService/SignIn"
AuthService_SignInWithSSO_FullMethodName = "/memos.api.v1.AuthService/SignInWithSSO"
AuthService_SignUp_FullMethodName = "/memos.api.v1.AuthService/SignUp" AuthService_SignUp_FullMethodName = "/memos.api.v1.AuthService/SignUp"
AuthService_SignOut_FullMethodName = "/memos.api.v1.AuthService/SignOut" AuthService_SignOut_FullMethodName = "/memos.api.v1.AuthService/SignOut"
) )
...@@ -33,10 +32,8 @@ const ( ...@@ -33,10 +32,8 @@ const (
type AuthServiceClient interface { type AuthServiceClient interface {
// GetAuthStatus returns the current auth status of the user. // GetAuthStatus returns the current auth status of the user.
GetAuthStatus(ctx context.Context, in *GetAuthStatusRequest, opts ...grpc.CallOption) (*User, error) GetAuthStatus(ctx context.Context, in *GetAuthStatusRequest, opts ...grpc.CallOption) (*User, error)
// SignIn signs in the user with the given username and password. // SignIn signs in the user.
SignIn(ctx context.Context, in *SignInRequest, opts ...grpc.CallOption) (*User, error) SignIn(ctx context.Context, in *SignInRequest, opts ...grpc.CallOption) (*User, error)
// SignInWithSSO signs in the user with the given SSO code.
SignInWithSSO(ctx context.Context, in *SignInWithSSORequest, opts ...grpc.CallOption) (*User, error)
// SignUp signs up the user with the given username and password. // SignUp signs up the user with the given username and password.
SignUp(ctx context.Context, in *SignUpRequest, opts ...grpc.CallOption) (*User, error) SignUp(ctx context.Context, in *SignUpRequest, opts ...grpc.CallOption) (*User, error)
// SignOut signs out the user. // SignOut signs out the user.
...@@ -71,16 +68,6 @@ func (c *authServiceClient) SignIn(ctx context.Context, in *SignInRequest, opts ...@@ -71,16 +68,6 @@ func (c *authServiceClient) SignIn(ctx context.Context, in *SignInRequest, opts
return out, nil return out, nil
} }
func (c *authServiceClient) SignInWithSSO(ctx context.Context, in *SignInWithSSORequest, opts ...grpc.CallOption) (*User, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(User)
err := c.cc.Invoke(ctx, AuthService_SignInWithSSO_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *authServiceClient) SignUp(ctx context.Context, in *SignUpRequest, opts ...grpc.CallOption) (*User, error) { func (c *authServiceClient) SignUp(ctx context.Context, in *SignUpRequest, opts ...grpc.CallOption) (*User, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(User) out := new(User)
...@@ -107,10 +94,8 @@ func (c *authServiceClient) SignOut(ctx context.Context, in *SignOutRequest, opt ...@@ -107,10 +94,8 @@ func (c *authServiceClient) SignOut(ctx context.Context, in *SignOutRequest, opt
type AuthServiceServer interface { type AuthServiceServer interface {
// GetAuthStatus returns the current auth status of the user. // GetAuthStatus returns the current auth status of the user.
GetAuthStatus(context.Context, *GetAuthStatusRequest) (*User, error) GetAuthStatus(context.Context, *GetAuthStatusRequest) (*User, error)
// SignIn signs in the user with the given username and password. // SignIn signs in the user.
SignIn(context.Context, *SignInRequest) (*User, error) SignIn(context.Context, *SignInRequest) (*User, error)
// SignInWithSSO signs in the user with the given SSO code.
SignInWithSSO(context.Context, *SignInWithSSORequest) (*User, error)
// SignUp signs up the user with the given username and password. // SignUp signs up the user with the given username and password.
SignUp(context.Context, *SignUpRequest) (*User, error) SignUp(context.Context, *SignUpRequest) (*User, error)
// SignOut signs out the user. // SignOut signs out the user.
...@@ -131,9 +116,6 @@ func (UnimplementedAuthServiceServer) GetAuthStatus(context.Context, *GetAuthSta ...@@ -131,9 +116,6 @@ func (UnimplementedAuthServiceServer) GetAuthStatus(context.Context, *GetAuthSta
func (UnimplementedAuthServiceServer) SignIn(context.Context, *SignInRequest) (*User, error) { func (UnimplementedAuthServiceServer) SignIn(context.Context, *SignInRequest) (*User, error) {
return nil, status.Errorf(codes.Unimplemented, "method SignIn not implemented") return nil, status.Errorf(codes.Unimplemented, "method SignIn not implemented")
} }
func (UnimplementedAuthServiceServer) SignInWithSSO(context.Context, *SignInWithSSORequest) (*User, error) {
return nil, status.Errorf(codes.Unimplemented, "method SignInWithSSO not implemented")
}
func (UnimplementedAuthServiceServer) SignUp(context.Context, *SignUpRequest) (*User, error) { func (UnimplementedAuthServiceServer) SignUp(context.Context, *SignUpRequest) (*User, error) {
return nil, status.Errorf(codes.Unimplemented, "method SignUp not implemented") return nil, status.Errorf(codes.Unimplemented, "method SignUp not implemented")
} }
...@@ -197,24 +179,6 @@ func _AuthService_SignIn_Handler(srv interface{}, ctx context.Context, dec func( ...@@ -197,24 +179,6 @@ func _AuthService_SignIn_Handler(srv interface{}, ctx context.Context, dec func(
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _AuthService_SignInWithSSO_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SignInWithSSORequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AuthServiceServer).SignInWithSSO(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: AuthService_SignInWithSSO_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AuthServiceServer).SignInWithSSO(ctx, req.(*SignInWithSSORequest))
}
return interceptor(ctx, in, info, handler)
}
func _AuthService_SignUp_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { func _AuthService_SignUp_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(SignUpRequest) in := new(SignUpRequest)
if err := dec(in); err != nil { if err := dec(in); err != nil {
...@@ -266,10 +230,6 @@ var AuthService_ServiceDesc = grpc.ServiceDesc{ ...@@ -266,10 +230,6 @@ var AuthService_ServiceDesc = grpc.ServiceDesc{
MethodName: "SignIn", MethodName: "SignIn",
Handler: _AuthService_SignIn_Handler, Handler: _AuthService_SignIn_Handler,
}, },
{
MethodName: "SignInWithSSO",
Handler: _AuthService_SignInWithSSO_Handler,
},
{ {
MethodName: "SignUp", MethodName: "SignUp",
Handler: _AuthService_SignUp_Handler, Handler: _AuthService_SignUp_Handler,
......
This diff is collapsed.
...@@ -44,8 +44,10 @@ func (s *APIV1Service) GetAuthStatus(ctx context.Context, _ *v1pb.GetAuthStatusR ...@@ -44,8 +44,10 @@ func (s *APIV1Service) GetAuthStatus(ctx context.Context, _ *v1pb.GetAuthStatusR
} }
func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.User, error) { func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) (*v1pb.User, error) {
var existingUser *store.User
if passwordCredentials := request.GetPasswordCredentials(); passwordCredentials != nil {
user, err := s.Store.GetUser(ctx, &store.FindUser{ user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &request.Username, Username: &passwordCredentials.Username,
}) })
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err) return nil, status.Errorf(codes.Internal, "failed to get user, error: %v", err)
...@@ -54,10 +56,9 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) ...@@ -54,10 +56,9 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest)
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError) return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
} }
// Compare the stored hashed password, with the hashed version of the password that was received. // Compare the stored hashed password, with the hashed version of the password that was received.
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(request.Password)); err != nil { if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(passwordCredentials.Password)); err != nil {
return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError) return nil, status.Errorf(codes.InvalidArgument, unmatchedUsernameAndPasswordError)
} }
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx) workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err) return nil, status.Errorf(codes.Internal, "failed to get workspace general setting, error: %v", err)
...@@ -66,24 +67,10 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest) ...@@ -66,24 +67,10 @@ func (s *APIV1Service) SignIn(ctx context.Context, request *v1pb.SignInRequest)
if workspaceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser { if workspaceGeneralSetting.DisallowPasswordAuth && user.Role == store.RoleUser {
return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed") return nil, status.Errorf(codes.PermissionDenied, "password signin is not allowed")
} }
if user.RowStatus == store.Archived { existingUser = user
return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", request.Username) } else if ssoCredentials := request.GetSsoCredentials(); ssoCredentials != nil {
}
expireTime := time.Now().Add(AccessTokenDuration)
if request.NeverExpire {
// Set the expire time to 100 years.
expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
}
if err := s.doSignIn(ctx, user, expireTime); err != nil {
return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err)
}
return convertUserFromStore(user), nil
}
func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWithSSORequest) (*v1pb.User, error) {
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &request.IdpId, ID: &ssoCredentials.IdpId,
}) })
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err) return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %v", err)
...@@ -98,7 +85,7 @@ func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWi ...@@ -98,7 +85,7 @@ func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWi
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err) return nil, status.Errorf(codes.Internal, "failed to create oauth2 identity provider, error: %v", err)
} }
token, err := oauth2IdentityProvider.ExchangeToken(ctx, request.RedirectUri, request.Code) token, err := oauth2IdentityProvider.ExchangeToken(ctx, ssoCredentials.RedirectUri, ssoCredentials.Code)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err) return nil, status.Errorf(codes.Internal, "failed to exchange token, error: %v", err)
} }
...@@ -158,14 +145,25 @@ func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWi ...@@ -158,14 +145,25 @@ func (s *APIV1Service) SignInWithSSO(ctx context.Context, request *v1pb.SignInWi
return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err) return nil, status.Errorf(codes.Internal, "failed to create user, error: %v", err)
} }
} }
if user.RowStatus == store.Archived { existingUser = user
return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", userInfo.Identifier)
} }
if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil { if existingUser == nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid credentials")
}
if existingUser.RowStatus == store.Archived {
return nil, status.Errorf(codes.PermissionDenied, "user has been archived with username %s", existingUser.Username)
}
expireTime := time.Now().Add(AccessTokenDuration)
if request.NeverExpire {
// Set the expire time to 100 years.
expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
}
if err := s.doSignIn(ctx, existingUser, expireTime); err != nil {
return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err) return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err)
} }
return convertUserFromStore(user), nil return convertUserFromStore(existingUser), nil
} }
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error { func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
......
...@@ -45,7 +45,7 @@ const PasswordSignInForm = observer(() => { ...@@ -45,7 +45,7 @@ const PasswordSignInForm = observer(() => {
try { try {
actionBtnLoadingState.setLoading(); actionBtnLoadingState.setLoading();
await authServiceClient.signIn({ username, password, neverExpire: remember }); await authServiceClient.signIn({ passwordCredentials: { username, password }, neverExpire: remember });
await initialUserStore(); await initialUserStore();
navigateTo("/"); navigateTo("/");
} catch (error: any) { } catch (error: any) {
......
...@@ -45,10 +45,12 @@ const AuthCallback = () => { ...@@ -45,10 +45,12 @@ const AuthCallback = () => {
const redirectUri = absolutifyLink("/auth/callback"); const redirectUri = absolutifyLink("/auth/callback");
(async () => { (async () => {
try { try {
await authServiceClient.signInWithSSO({ await authServiceClient.signIn({
ssoCredentials: {
idpId: identityProviderId, idpId: identityProviderId,
code, code,
redirectUri, redirectUri,
},
}); });
setState({ setState({
loading: false, loading: false,
......
...@@ -19,15 +19,26 @@ export interface GetAuthStatusResponse { ...@@ -19,15 +19,26 @@ export interface GetAuthStatusResponse {
} }
export interface SignInRequest { export interface SignInRequest {
/** Username and password authentication method. */
passwordCredentials?:
| PasswordCredentials
| undefined;
/** SSO provider authentication method. */
ssoCredentials?:
| SSOCredentials
| undefined;
/** Whether the session should never expire. */
neverExpire: boolean;
}
export interface PasswordCredentials {
/** The username to sign in with. */ /** The username to sign in with. */
username: string; username: string;
/** The password to sign in with. */ /** The password to sign in with. */
password: string; password: string;
/** Whether the session should never expire. */
neverExpire: boolean;
} }
export interface SignInWithSSORequest { export interface SSOCredentials {
/** The ID of the SSO provider. */ /** The ID of the SSO provider. */
idpId: number; idpId: number;
/** The code to sign in with. */ /** The code to sign in with. */
...@@ -127,16 +138,16 @@ export const GetAuthStatusResponse: MessageFns<GetAuthStatusResponse> = { ...@@ -127,16 +138,16 @@ export const GetAuthStatusResponse: MessageFns<GetAuthStatusResponse> = {
}; };
function createBaseSignInRequest(): SignInRequest { function createBaseSignInRequest(): SignInRequest {
return { username: "", password: "", neverExpire: false }; return { passwordCredentials: undefined, ssoCredentials: undefined, neverExpire: false };
} }
export const SignInRequest: MessageFns<SignInRequest> = { export const SignInRequest: MessageFns<SignInRequest> = {
encode(message: SignInRequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { encode(message: SignInRequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter {
if (message.username !== "") { if (message.passwordCredentials !== undefined) {
writer.uint32(10).string(message.username); PasswordCredentials.encode(message.passwordCredentials, writer.uint32(10).fork()).join();
} }
if (message.password !== "") { if (message.ssoCredentials !== undefined) {
writer.uint32(18).string(message.password); SSOCredentials.encode(message.ssoCredentials, writer.uint32(18).fork()).join();
} }
if (message.neverExpire !== false) { if (message.neverExpire !== false) {
writer.uint32(24).bool(message.neverExpire); writer.uint32(24).bool(message.neverExpire);
...@@ -156,7 +167,7 @@ export const SignInRequest: MessageFns<SignInRequest> = { ...@@ -156,7 +167,7 @@ export const SignInRequest: MessageFns<SignInRequest> = {
break; break;
} }
message.username = reader.string(); message.passwordCredentials = PasswordCredentials.decode(reader, reader.uint32());
continue; continue;
} }
case 2: { case 2: {
...@@ -164,7 +175,7 @@ export const SignInRequest: MessageFns<SignInRequest> = { ...@@ -164,7 +175,7 @@ export const SignInRequest: MessageFns<SignInRequest> = {
break; break;
} }
message.password = reader.string(); message.ssoCredentials = SSOCredentials.decode(reader, reader.uint32());
continue; continue;
} }
case 3: { case 3: {
...@@ -189,19 +200,81 @@ export const SignInRequest: MessageFns<SignInRequest> = { ...@@ -189,19 +200,81 @@ export const SignInRequest: MessageFns<SignInRequest> = {
}, },
fromPartial(object: DeepPartial<SignInRequest>): SignInRequest { fromPartial(object: DeepPartial<SignInRequest>): SignInRequest {
const message = createBaseSignInRequest(); const message = createBaseSignInRequest();
message.passwordCredentials = (object.passwordCredentials !== undefined && object.passwordCredentials !== null)
? PasswordCredentials.fromPartial(object.passwordCredentials)
: undefined;
message.ssoCredentials = (object.ssoCredentials !== undefined && object.ssoCredentials !== null)
? SSOCredentials.fromPartial(object.ssoCredentials)
: undefined;
message.neverExpire = object.neverExpire ?? false;
return message;
},
};
function createBasePasswordCredentials(): PasswordCredentials {
return { username: "", password: "" };
}
export const PasswordCredentials: MessageFns<PasswordCredentials> = {
encode(message: PasswordCredentials, writer: BinaryWriter = new BinaryWriter()): BinaryWriter {
if (message.username !== "") {
writer.uint32(10).string(message.username);
}
if (message.password !== "") {
writer.uint32(18).string(message.password);
}
return writer;
},
decode(input: BinaryReader | Uint8Array, length?: number): PasswordCredentials {
const reader = input instanceof BinaryReader ? input : new BinaryReader(input);
let end = length === undefined ? reader.len : reader.pos + length;
const message = createBasePasswordCredentials();
while (reader.pos < end) {
const tag = reader.uint32();
switch (tag >>> 3) {
case 1: {
if (tag !== 10) {
break;
}
message.username = reader.string();
continue;
}
case 2: {
if (tag !== 18) {
break;
}
message.password = reader.string();
continue;
}
}
if ((tag & 7) === 4 || tag === 0) {
break;
}
reader.skip(tag & 7);
}
return message;
},
create(base?: DeepPartial<PasswordCredentials>): PasswordCredentials {
return PasswordCredentials.fromPartial(base ?? {});
},
fromPartial(object: DeepPartial<PasswordCredentials>): PasswordCredentials {
const message = createBasePasswordCredentials();
message.username = object.username ?? ""; message.username = object.username ?? "";
message.password = object.password ?? ""; message.password = object.password ?? "";
message.neverExpire = object.neverExpire ?? false;
return message; return message;
}, },
}; };
function createBaseSignInWithSSORequest(): SignInWithSSORequest { function createBaseSSOCredentials(): SSOCredentials {
return { idpId: 0, code: "", redirectUri: "" }; return { idpId: 0, code: "", redirectUri: "" };
} }
export const SignInWithSSORequest: MessageFns<SignInWithSSORequest> = { export const SSOCredentials: MessageFns<SSOCredentials> = {
encode(message: SignInWithSSORequest, writer: BinaryWriter = new BinaryWriter()): BinaryWriter { encode(message: SSOCredentials, writer: BinaryWriter = new BinaryWriter()): BinaryWriter {
if (message.idpId !== 0) { if (message.idpId !== 0) {
writer.uint32(8).int32(message.idpId); writer.uint32(8).int32(message.idpId);
} }
...@@ -214,10 +287,10 @@ export const SignInWithSSORequest: MessageFns<SignInWithSSORequest> = { ...@@ -214,10 +287,10 @@ export const SignInWithSSORequest: MessageFns<SignInWithSSORequest> = {
return writer; return writer;
}, },
decode(input: BinaryReader | Uint8Array, length?: number): SignInWithSSORequest { decode(input: BinaryReader | Uint8Array, length?: number): SSOCredentials {
const reader = input instanceof BinaryReader ? input : new BinaryReader(input); const reader = input instanceof BinaryReader ? input : new BinaryReader(input);
let end = length === undefined ? reader.len : reader.pos + length; let end = length === undefined ? reader.len : reader.pos + length;
const message = createBaseSignInWithSSORequest(); const message = createBaseSSOCredentials();
while (reader.pos < end) { while (reader.pos < end) {
const tag = reader.uint32(); const tag = reader.uint32();
switch (tag >>> 3) { switch (tag >>> 3) {
...@@ -254,11 +327,11 @@ export const SignInWithSSORequest: MessageFns<SignInWithSSORequest> = { ...@@ -254,11 +327,11 @@ export const SignInWithSSORequest: MessageFns<SignInWithSSORequest> = {
return message; return message;
}, },
create(base?: DeepPartial<SignInWithSSORequest>): SignInWithSSORequest { create(base?: DeepPartial<SSOCredentials>): SSOCredentials {
return SignInWithSSORequest.fromPartial(base ?? {}); return SSOCredentials.fromPartial(base ?? {});
}, },
fromPartial(object: DeepPartial<SignInWithSSORequest>): SignInWithSSORequest { fromPartial(object: DeepPartial<SSOCredentials>): SSOCredentials {
const message = createBaseSignInWithSSORequest(); const message = createBaseSSOCredentials();
message.idpId = object.idpId ?? 0; message.idpId = object.idpId ?? 0;
message.code = object.code ?? ""; message.code = object.code ?? "";
message.redirectUri = object.redirectUri ?? ""; message.redirectUri = object.redirectUri ?? "";
...@@ -401,7 +474,7 @@ export const AuthServiceDefinition = { ...@@ -401,7 +474,7 @@ export const AuthServiceDefinition = {
}, },
}, },
}, },
/** SignIn signs in the user with the given username and password. */ /** SignIn signs in the user. */
signIn: { signIn: {
name: "SignIn", name: "SignIn",
requestType: SignInRequest, requestType: SignInRequest,
...@@ -439,48 +512,6 @@ export const AuthServiceDefinition = { ...@@ -439,48 +512,6 @@ export const AuthServiceDefinition = {
}, },
}, },
}, },
/** SignInWithSSO signs in the user with the given SSO code. */
signInWithSSO: {
name: "SignInWithSSO",
requestType: SignInWithSSORequest,
requestStream: false,
responseType: User,
responseStream: false,
options: {
_unknownFields: {
578365826: [
new Uint8Array([
25,
34,
23,
47,
97,
112,
105,
47,
118,
49,
47,
97,
117,
116,
104,
47,
115,
105,
103,
110,
105,
110,
47,
115,
115,
111,
]),
],
},
},
},
/** SignUp signs up the user with the given username and password. */ /** SignUp signs up the user with the given username and password. */
signUp: { signUp: {
name: "SignUp", name: "SignUp",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment