Commit 42d1650c authored by Steven's avatar Steven

chore: tweak auth service

parent 6e1b01cb
......@@ -6,19 +6,20 @@ import "api/v1/user_service.proto";
import "google/api/annotations.proto";
import "google/api/field_behavior.proto";
import "google/protobuf/empty.proto";
import "google/protobuf/timestamp.proto";
option go_package = "gen/api/v1";
service AuthService {
// GetCurrentSession returns the current active session information.
// This method is idempotent and safe, suitable for checking current session state.
rpc GetCurrentSession(GetCurrentSessionRequest) returns (User) {
rpc GetCurrentSession(GetCurrentSessionRequest) returns (GetCurrentSessionResponse) {
option (google.api.http) = {get: "/api/v1/auth/sessions/current"};
}
// CreateSession authenticates a user and creates a new session.
// Returns the authenticated user information upon successful authentication.
rpc CreateSession(CreateSessionRequest) returns (User) {
rpc CreateSession(CreateSessionRequest) returns (CreateSessionResponse) {
option (google.api.http) = {
post: "/api/v1/auth/sessions"
body: "*"
......@@ -36,6 +37,9 @@ message GetCurrentSessionRequest {}
message GetCurrentSessionResponse {
User user = 1;
// Current session expiration time (if available).
google.protobuf.Timestamp expires_at = 2;
}
message CreateSessionRequest {
......@@ -67,7 +71,7 @@ message CreateSessionRequest {
// Provide one authentication method (username/password or SSO).
// Required field to specify the authentication method.
oneof method {
oneof credentials {
// Username and password authentication method.
PasswordCredentials password_credentials = 1;
......@@ -80,4 +84,12 @@ message CreateSessionRequest {
bool never_expire = 3 [(google.api.field_behavior) = OPTIONAL];
}
message CreateSessionResponse {
// The authenticated user information.
User user = 1;
// Token expiration time.
google.protobuf.Timestamp expires_at = 2;
}
message DeleteSessionRequest {}
This diff is collapsed.
......@@ -31,10 +31,10 @@ const (
type AuthServiceClient interface {
// GetCurrentSession returns the current active session information.
// This method is idempotent and safe, suitable for checking current session state.
GetCurrentSession(ctx context.Context, in *GetCurrentSessionRequest, opts ...grpc.CallOption) (*User, error)
GetCurrentSession(ctx context.Context, in *GetCurrentSessionRequest, opts ...grpc.CallOption) (*GetCurrentSessionResponse, error)
// CreateSession authenticates a user and creates a new session.
// Returns the authenticated user information upon successful authentication.
CreateSession(ctx context.Context, in *CreateSessionRequest, opts ...grpc.CallOption) (*User, error)
CreateSession(ctx context.Context, in *CreateSessionRequest, opts ...grpc.CallOption) (*CreateSessionResponse, error)
// DeleteSession terminates the current user session.
// This is an idempotent operation that invalidates the user's authentication.
DeleteSession(ctx context.Context, in *DeleteSessionRequest, opts ...grpc.CallOption) (*emptypb.Empty, error)
......@@ -48,9 +48,9 @@ func NewAuthServiceClient(cc grpc.ClientConnInterface) AuthServiceClient {
return &authServiceClient{cc}
}
func (c *authServiceClient) GetCurrentSession(ctx context.Context, in *GetCurrentSessionRequest, opts ...grpc.CallOption) (*User, error) {
func (c *authServiceClient) GetCurrentSession(ctx context.Context, in *GetCurrentSessionRequest, opts ...grpc.CallOption) (*GetCurrentSessionResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(User)
out := new(GetCurrentSessionResponse)
err := c.cc.Invoke(ctx, AuthService_GetCurrentSession_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
......@@ -58,9 +58,9 @@ func (c *authServiceClient) GetCurrentSession(ctx context.Context, in *GetCurren
return out, nil
}
func (c *authServiceClient) CreateSession(ctx context.Context, in *CreateSessionRequest, opts ...grpc.CallOption) (*User, error) {
func (c *authServiceClient) CreateSession(ctx context.Context, in *CreateSessionRequest, opts ...grpc.CallOption) (*CreateSessionResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(User)
out := new(CreateSessionResponse)
err := c.cc.Invoke(ctx, AuthService_CreateSession_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
......@@ -84,10 +84,10 @@ func (c *authServiceClient) DeleteSession(ctx context.Context, in *DeleteSession
type AuthServiceServer interface {
// GetCurrentSession returns the current active session information.
// This method is idempotent and safe, suitable for checking current session state.
GetCurrentSession(context.Context, *GetCurrentSessionRequest) (*User, error)
GetCurrentSession(context.Context, *GetCurrentSessionRequest) (*GetCurrentSessionResponse, error)
// CreateSession authenticates a user and creates a new session.
// Returns the authenticated user information upon successful authentication.
CreateSession(context.Context, *CreateSessionRequest) (*User, error)
CreateSession(context.Context, *CreateSessionRequest) (*CreateSessionResponse, error)
// DeleteSession terminates the current user session.
// This is an idempotent operation that invalidates the user's authentication.
DeleteSession(context.Context, *DeleteSessionRequest) (*emptypb.Empty, error)
......@@ -101,10 +101,10 @@ type AuthServiceServer interface {
// pointer dereference when methods are called.
type UnimplementedAuthServiceServer struct{}
func (UnimplementedAuthServiceServer) GetCurrentSession(context.Context, *GetCurrentSessionRequest) (*User, error) {
func (UnimplementedAuthServiceServer) GetCurrentSession(context.Context, *GetCurrentSessionRequest) (*GetCurrentSessionResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentSession not implemented")
}
func (UnimplementedAuthServiceServer) CreateSession(context.Context, *CreateSessionRequest) (*User, error) {
func (UnimplementedAuthServiceServer) CreateSession(context.Context, *CreateSessionRequest) (*CreateSessionResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method CreateSession not implemented")
}
func (UnimplementedAuthServiceServer) DeleteSession(context.Context, *DeleteSessionRequest) (*emptypb.Empty, error) {
......
......@@ -141,7 +141,7 @@ paths:
"200":
description: A successful response.
schema:
$ref: '#/definitions/v1User'
$ref: '#/definitions/v1CreateSessionResponse'
default:
description: An unexpected error response.
schema:
......@@ -164,7 +164,7 @@ paths:
"200":
description: A successful response.
schema:
$ref: '#/definitions/v1User'
$ref: '#/definitions/v1GetCurrentSessionResponse'
default:
description: An unexpected error response.
schema:
......@@ -3287,6 +3287,16 @@ definitions:
description: |-
Whether the session should never expire.
Optional field that defaults to false for security.
v1CreateSessionResponse:
type: object
properties:
user:
$ref: '#/definitions/v1User'
description: The authenticated user information.
expiresAt:
type: string
format: date-time
description: Token expiration time.
v1EmbeddedContentNode:
type: object
properties:
......@@ -3301,6 +3311,15 @@ definitions:
properties:
symbol:
type: string
v1GetCurrentSessionResponse:
type: object
properties:
user:
$ref: '#/definitions/v1User'
expiresAt:
type: string
format: date-time
description: Current session expiration time (if available).
v1HTMLElementNode:
type: object
properties:
......
......@@ -3,7 +3,6 @@ package v1
var authenticationAllowlistMethods = map[string]bool{
"/memos.api.v1.WorkspaceService/GetWorkspaceProfile": true,
"/memos.api.v1.WorkspaceService/GetWorkspaceSetting": true,
"/memos.api.v1.IdentityProviderService/GetIdentityProvider": true,
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": true,
"/memos.api.v1.AuthService/CreateSession": true,
"/memos.api.v1.AuthService/GetCurrentSession": true,
......
......@@ -29,7 +29,7 @@ const (
unmatchedUsernameAndPasswordError = "unmatched username and password"
)
func (s *APIV1Service) GetCurrentSession(ctx context.Context, _ *v1pb.GetCurrentSessionRequest) (*v1pb.User, error) {
func (s *APIV1Service) GetCurrentSession(ctx context.Context, _ *v1pb.GetCurrentSessionRequest) (*v1pb.GetCurrentSessionResponse, error) {
user, err := s.GetCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err)
......@@ -50,10 +50,12 @@ func (s *APIV1Service) GetCurrentSession(ctx context.Context, _ *v1pb.GetCurrent
}
}
return convertUserFromStore(user), nil
return &v1pb.GetCurrentSessionResponse{
User: convertUserFromStore(user),
}, nil
}
func (s *APIV1Service) CreateSession(ctx context.Context, request *v1pb.CreateSessionRequest) (*v1pb.User, error) {
func (s *APIV1Service) CreateSession(ctx context.Context, request *v1pb.CreateSessionRequest) (*v1pb.CreateSessionResponse, error) {
var existingUser *store.User
if passwordCredentials := request.GetPasswordCredentials(); passwordCredentials != nil {
user, err := s.Store.GetUser(ctx, &store.FindUser{
......@@ -173,7 +175,11 @@ func (s *APIV1Service) CreateSession(ctx context.Context, request *v1pb.CreateSe
if err := s.doSignIn(ctx, existingUser, expireTime); err != nil {
return nil, status.Errorf(codes.Internal, "failed to sign in, error: %v", err)
}
return convertUserFromStore(existingUser), nil
return &v1pb.CreateSessionResponse{
User: convertUserFromStore(existingUser),
ExpiresAt: timestamppb.New(expireTime),
}, nil
}
func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
......
......@@ -13,5 +13,5 @@ func TestGetCurrentSchemaVersion(t *testing.T) {
currentSchemaVersion, err := ts.GetCurrentSchemaVersion()
require.NoError(t, err)
require.Equal(t, "0.24.2", currentSchemaVersion)
require.Equal(t, "0.25.1", currentSchemaVersion)
}
......@@ -45,7 +45,10 @@ const PasswordSignInForm = observer(() => {
try {
actionBtnLoadingState.setLoading();
await authServiceClient.createSession({ passwordCredentials: { username, password }, neverExpire: remember });
await authServiceClient.createSession({
passwordCredentials: { username, password },
neverExpire: remember,
});
await initialUserStore();
navigateTo("/");
} catch (error: any) {
......
......@@ -231,7 +231,16 @@ const userStore = (() => {
export const initialUserStore = async () => {
try {
const currentUser = await authServiceClient.getCurrentSession({});
const { user: currentUser } = await authServiceClient.getCurrentSession({});
if (!currentUser) {
// If no user is authenticated, we can skip the rest of the initialization.
userStore.state.setPartial({
currentUser: undefined,
userSetting: undefined,
userMapByName: {},
});
return;
}
const userSetting = await userServiceClient.getUserSetting({ name: currentUser.name });
userStore.state.setPartial({
currentUser: currentUser.name,
......
......@@ -7,6 +7,7 @@
/* eslint-disable */
import { BinaryReader, BinaryWriter } from "@bufbuild/protobuf/wire";
import { Empty } from "../../google/protobuf/empty";
import { Timestamp } from "../../google/protobuf/timestamp";
import { User } from "./user_service";
export const protobufPackage = "memos.api.v1";
......@@ -15,7 +16,11 @@ export interface GetCurrentSessionRequest {
}
export interface GetCurrentSessionResponse {
user?: User | undefined;
user?:
| User
| undefined;
/** Current session expiration time (if available). */
expiresAt?: Date | undefined;
}
export interface CreateSessionRequest {
......@@ -67,6 +72,15 @@ export interface CreateSessionRequest_SSOCredentials {
redirectUri: string;
}
export interface CreateSessionResponse {
/** The authenticated user information. */
user?:
| User
| undefined;
/** Token expiration time. */
expiresAt?: Date | undefined;
}
export interface DeleteSessionRequest {
}
......@@ -105,7 +119,7 @@ export const GetCurrentSessionRequest: MessageFns<GetCurrentSessionRequest> = {
};
function createBaseGetCurrentSessionResponse(): GetCurrentSessionResponse {
return { user: undefined };
return { user: undefined, expiresAt: undefined };
}
export const GetCurrentSessionResponse: MessageFns<GetCurrentSessionResponse> = {
......@@ -113,6 +127,9 @@ export const GetCurrentSessionResponse: MessageFns<GetCurrentSessionResponse> =
if (message.user !== undefined) {
User.encode(message.user, writer.uint32(10).fork()).join();
}
if (message.expiresAt !== undefined) {
Timestamp.encode(toTimestamp(message.expiresAt), writer.uint32(18).fork()).join();
}
return writer;
},
......@@ -131,6 +148,14 @@ export const GetCurrentSessionResponse: MessageFns<GetCurrentSessionResponse> =
message.user = User.decode(reader, reader.uint32());
continue;
}
case 2: {
if (tag !== 18) {
break;
}
message.expiresAt = fromTimestamp(Timestamp.decode(reader, reader.uint32()));
continue;
}
}
if ((tag & 7) === 4 || tag === 0) {
break;
......@@ -146,6 +171,7 @@ export const GetCurrentSessionResponse: MessageFns<GetCurrentSessionResponse> =
fromPartial(object: DeepPartial<GetCurrentSessionResponse>): GetCurrentSessionResponse {
const message = createBaseGetCurrentSessionResponse();
message.user = (object.user !== undefined && object.user !== null) ? User.fromPartial(object.user) : undefined;
message.expiresAt = object.expiresAt ?? undefined;
return message;
},
};
......@@ -352,6 +378,64 @@ export const CreateSessionRequest_SSOCredentials: MessageFns<CreateSessionReques
},
};
function createBaseCreateSessionResponse(): CreateSessionResponse {
return { user: undefined, expiresAt: undefined };
}
export const CreateSessionResponse: MessageFns<CreateSessionResponse> = {
encode(message: CreateSessionResponse, writer: BinaryWriter = new BinaryWriter()): BinaryWriter {
if (message.user !== undefined) {
User.encode(message.user, writer.uint32(10).fork()).join();
}
if (message.expiresAt !== undefined) {
Timestamp.encode(toTimestamp(message.expiresAt), writer.uint32(18).fork()).join();
}
return writer;
},
decode(input: BinaryReader | Uint8Array, length?: number): CreateSessionResponse {
const reader = input instanceof BinaryReader ? input : new BinaryReader(input);
let end = length === undefined ? reader.len : reader.pos + length;
const message = createBaseCreateSessionResponse();
while (reader.pos < end) {
const tag = reader.uint32();
switch (tag >>> 3) {
case 1: {
if (tag !== 10) {
break;
}
message.user = User.decode(reader, reader.uint32());
continue;
}
case 2: {
if (tag !== 18) {
break;
}
message.expiresAt = fromTimestamp(Timestamp.decode(reader, reader.uint32()));
continue;
}
}
if ((tag & 7) === 4 || tag === 0) {
break;
}
reader.skip(tag & 7);
}
return message;
},
create(base?: DeepPartial<CreateSessionResponse>): CreateSessionResponse {
return CreateSessionResponse.fromPartial(base ?? {});
},
fromPartial(object: DeepPartial<CreateSessionResponse>): CreateSessionResponse {
const message = createBaseCreateSessionResponse();
message.user = (object.user !== undefined && object.user !== null) ? User.fromPartial(object.user) : undefined;
message.expiresAt = object.expiresAt ?? undefined;
return message;
},
};
function createBaseDeleteSessionRequest(): DeleteSessionRequest {
return {};
}
......@@ -399,7 +483,7 @@ export const AuthServiceDefinition = {
name: "GetCurrentSession",
requestType: GetCurrentSessionRequest,
requestStream: false,
responseType: User,
responseType: GetCurrentSessionResponse,
responseStream: false,
options: {
_unknownFields: {
......@@ -450,7 +534,7 @@ export const AuthServiceDefinition = {
name: "CreateSession",
requestType: CreateSessionRequest,
requestStream: false,
responseType: User,
responseType: CreateSessionResponse,
responseStream: false,
options: {
_unknownFields: {
......@@ -550,6 +634,18 @@ export type DeepPartial<T> = T extends Builtin ? T
: T extends {} ? { [K in keyof T]?: DeepPartial<T[K]> }
: Partial<T>;
function toTimestamp(date: Date): Timestamp {
const seconds = Math.trunc(date.getTime() / 1_000);
const nanos = (date.getTime() % 1_000) * 1_000_000;
return { seconds, nanos };
}
function fromTimestamp(t: Timestamp): Date {
let millis = (t.seconds || 0) * 1_000;
millis += (t.nanos || 0) / 1_000_000;
return new globalThis.Date(millis);
}
export interface MessageFns<T> {
encode(message: T, writer?: BinaryWriter): BinaryWriter;
decode(input: BinaryReader | Uint8Array, length?: number): T;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment