Commit c373131b authored by Steven's avatar Steven

chore: migrate idp service

parent a7770326
...@@ -12,21 +12,21 @@ import ( ...@@ -12,21 +12,21 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
"github.com/usememos/memos/plugin/idp" "github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/store" storepb "github.com/usememos/memos/proto/gen/store"
) )
// IdentityProvider represents an OAuth2 Identity Provider. // IdentityProvider represents an OAuth2 Identity Provider.
type IdentityProvider struct { type IdentityProvider struct {
config *store.IdentityProviderOAuth2Config config *storepb.OAuth2Config
} }
// NewIdentityProvider initializes a new OAuth2 Identity Provider with the given configuration. // NewIdentityProvider initializes a new OAuth2 Identity Provider with the given configuration.
func NewIdentityProvider(config *store.IdentityProviderOAuth2Config) (*IdentityProvider, error) { func NewIdentityProvider(config *storepb.OAuth2Config) (*IdentityProvider, error) {
for v, field := range map[string]string{ for v, field := range map[string]string{
config.ClientID: "clientId", config.ClientId: "clientId",
config.ClientSecret: "clientSecret", config.ClientSecret: "clientSecret",
config.TokenURL: "tokenUrl", config.TokenUrl: "tokenUrl",
config.UserInfoURL: "userInfoUrl", config.UserInfoUrl: "userInfoUrl",
config.FieldMapping.Identifier: "fieldMapping.identifier", config.FieldMapping.Identifier: "fieldMapping.identifier",
} { } {
if v == "" { if v == "" {
...@@ -42,13 +42,13 @@ func NewIdentityProvider(config *store.IdentityProviderOAuth2Config) (*IdentityP ...@@ -42,13 +42,13 @@ func NewIdentityProvider(config *store.IdentityProviderOAuth2Config) (*IdentityP
// ExchangeToken returns the exchanged OAuth2 token using the given authorization code. // ExchangeToken returns the exchanged OAuth2 token using the given authorization code.
func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) { func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code string) (string, error) {
conf := &oauth2.Config{ conf := &oauth2.Config{
ClientID: p.config.ClientID, ClientID: p.config.ClientId,
ClientSecret: p.config.ClientSecret, ClientSecret: p.config.ClientSecret,
RedirectURL: redirectURL, RedirectURL: redirectURL,
Scopes: p.config.Scopes, Scopes: p.config.Scopes,
Endpoint: oauth2.Endpoint{ Endpoint: oauth2.Endpoint{
AuthURL: p.config.AuthURL, AuthURL: p.config.AuthUrl,
TokenURL: p.config.TokenURL, TokenURL: p.config.TokenUrl,
AuthStyle: oauth2.AuthStyleInParams, AuthStyle: oauth2.AuthStyleInParams,
}, },
} }
...@@ -69,7 +69,7 @@ func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code ...@@ -69,7 +69,7 @@ func (p *IdentityProvider) ExchangeToken(ctx context.Context, redirectURL, code
// UserInfo returns the parsed user information using the given OAuth2 token. // UserInfo returns the parsed user information using the given OAuth2 token.
func (p *IdentityProvider) UserInfo(token string) (*idp.IdentityProviderUserInfo, error) { func (p *IdentityProvider) UserInfo(token string) (*idp.IdentityProviderUserInfo, error) {
client := &http.Client{} client := &http.Client{}
req, err := http.NewRequest(http.MethodGet, p.config.UserInfoURL, nil) req, err := http.NewRequest(http.MethodGet, p.config.UserInfoUrl, nil)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to new http request") return nil, errors.Wrap(err, "failed to new http request")
} }
......
...@@ -14,24 +14,24 @@ import ( ...@@ -14,24 +14,24 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/idp" "github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/store" storepb "github.com/usememos/memos/proto/gen/store"
) )
func TestNewIdentityProvider(t *testing.T) { func TestNewIdentityProvider(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
config *store.IdentityProviderOAuth2Config config *storepb.OAuth2Config
containsErr string containsErr string
}{ }{
{ {
name: "no tokenUrl", name: "no tokenUrl",
config: &store.IdentityProviderOAuth2Config{ config: &storepb.OAuth2Config{
ClientID: "test-client-id", ClientId: "test-client-id",
ClientSecret: "test-client-secret", ClientSecret: "test-client-secret",
AuthURL: "", AuthUrl: "",
TokenURL: "", TokenUrl: "",
UserInfoURL: "https://example.com/api/user", UserInfoUrl: "https://example.com/api/user",
FieldMapping: &store.FieldMapping{ FieldMapping: &storepb.FieldMapping{
Identifier: "login", Identifier: "login",
}, },
}, },
...@@ -39,13 +39,13 @@ func TestNewIdentityProvider(t *testing.T) { ...@@ -39,13 +39,13 @@ func TestNewIdentityProvider(t *testing.T) {
}, },
{ {
name: "no userInfoUrl", name: "no userInfoUrl",
config: &store.IdentityProviderOAuth2Config{ config: &storepb.OAuth2Config{
ClientID: "test-client-id", ClientId: "test-client-id",
ClientSecret: "test-client-secret", ClientSecret: "test-client-secret",
AuthURL: "", AuthUrl: "",
TokenURL: "https://example.com/token", TokenUrl: "https://example.com/token",
UserInfoURL: "", UserInfoUrl: "",
FieldMapping: &store.FieldMapping{ FieldMapping: &storepb.FieldMapping{
Identifier: "login", Identifier: "login",
}, },
}, },
...@@ -53,13 +53,13 @@ func TestNewIdentityProvider(t *testing.T) { ...@@ -53,13 +53,13 @@ func TestNewIdentityProvider(t *testing.T) {
}, },
{ {
name: "no field mapping identifier", name: "no field mapping identifier",
config: &store.IdentityProviderOAuth2Config{ config: &storepb.OAuth2Config{
ClientID: "test-client-id", ClientId: "test-client-id",
ClientSecret: "test-client-secret", ClientSecret: "test-client-secret",
AuthURL: "", AuthUrl: "",
TokenURL: "https://example.com/token", TokenUrl: "https://example.com/token",
UserInfoURL: "https://example.com/api/user", UserInfoUrl: "https://example.com/api/user",
FieldMapping: &store.FieldMapping{ FieldMapping: &storepb.FieldMapping{
Identifier: "", Identifier: "",
}, },
}, },
...@@ -113,7 +113,7 @@ func TestIdentityProvider(t *testing.T) { ...@@ -113,7 +113,7 @@ func TestIdentityProvider(t *testing.T) {
ctx := context.Background() ctx := context.Background()
const ( const (
testClientID = "test-client-id" testClientId = "test-client-id"
testCode = "test-code" testCode = "test-code"
testAccessToken = "test-access-token" testAccessToken = "test-access-token"
testSubject = "123456789" testSubject = "123456789"
...@@ -132,12 +132,12 @@ func TestIdentityProvider(t *testing.T) { ...@@ -132,12 +132,12 @@ func TestIdentityProvider(t *testing.T) {
s := newMockServer(t, testCode, testAccessToken, userInfo) s := newMockServer(t, testCode, testAccessToken, userInfo)
oauth2, err := NewIdentityProvider( oauth2, err := NewIdentityProvider(
&store.IdentityProviderOAuth2Config{ &storepb.OAuth2Config{
ClientID: testClientID, ClientId: testClientId,
ClientSecret: "test-client-secret", ClientSecret: "test-client-secret",
TokenURL: fmt.Sprintf("%s/oauth2/token", s.URL), TokenUrl: fmt.Sprintf("%s/oauth2/token", s.URL),
UserInfoURL: fmt.Sprintf("%s/oauth2/userinfo", s.URL), UserInfoUrl: fmt.Sprintf("%s/oauth2/userinfo", s.URL),
FieldMapping: &store.FieldMapping{ FieldMapping: &storepb.FieldMapping{
Identifier: "sub", Identifier: "sub",
DisplayName: "name", DisplayName: "name",
Email: "email", Email: "email",
......
...@@ -54,7 +54,7 @@ message IdentityProvider { ...@@ -54,7 +54,7 @@ message IdentityProvider {
message IdentityProviderConfig { message IdentityProviderConfig {
oneof config { oneof config {
OAuth2Config oauth2 = 1; OAuth2Config oauth2_config = 1;
} }
} }
......
...@@ -53,7 +53,7 @@ message Storage { ...@@ -53,7 +53,7 @@ message Storage {
} }
message StorageConfig { message StorageConfig {
oneof storage_config { oneof config {
S3Config s3_config = 1; S3Config s3_config = 1;
} }
} }
......
...@@ -1160,7 +1160,7 @@ Used internally for obfuscating the page token. ...@@ -1160,7 +1160,7 @@ Used internally for obfuscating the page token.
| Field | Type | Label | Description | | Field | Type | Label | Description |
| ----- | ---- | ----- | ----------- | | ----- | ---- | ----- | ----------- |
| oauth2 | [OAuth2Config](#memos-api-v2-OAuth2Config) | | | | oauth2_config | [OAuth2Config](#memos-api-v2-OAuth2Config) | | |
......
This diff is collapsed.
This diff is collapsed.
...@@ -13,9 +13,12 @@ ...@@ -13,9 +13,12 @@
- [store/idp.proto](#store_idp-proto) - [store/idp.proto](#store_idp-proto)
- [FieldMapping](#memos-store-FieldMapping) - [FieldMapping](#memos-store-FieldMapping)
- [IdentityProvider](#memos-store-IdentityProvider)
- [IdentityProviderConfig](#memos-store-IdentityProviderConfig) - [IdentityProviderConfig](#memos-store-IdentityProviderConfig)
- [OAuth2Config](#memos-store-OAuth2Config) - [OAuth2Config](#memos-store-OAuth2Config)
- [IdentityProvider.Type](#memos-store-IdentityProvider-Type)
- [store/inbox.proto](#store_inbox-proto) - [store/inbox.proto](#store_inbox-proto)
- [InboxMessage](#memos-store-InboxMessage) - [InboxMessage](#memos-store-InboxMessage)
...@@ -175,6 +178,25 @@ ...@@ -175,6 +178,25 @@
<a name="memos-store-IdentityProvider"></a>
### IdentityProvider
| Field | Type | Label | Description |
| ----- | ---- | ----- | ----------- |
| id | [int32](#int32) | | |
| name | [string](#string) | | |
| type | [IdentityProvider.Type](#memos-store-IdentityProvider-Type) | | |
| identifier_filter | [string](#string) | | |
| config | [IdentityProviderConfig](#memos-store-IdentityProviderConfig) | | |
<a name="memos-store-IdentityProviderConfig"></a> <a name="memos-store-IdentityProviderConfig"></a>
### IdentityProviderConfig ### IdentityProviderConfig
...@@ -183,7 +205,7 @@ ...@@ -183,7 +205,7 @@
| Field | Type | Label | Description | | Field | Type | Label | Description |
| ----- | ---- | ----- | ----------- | | ----- | ---- | ----- | ----------- |
| oauth2 | [OAuth2Config](#memos-store-OAuth2Config) | | | | oauth2_config | [OAuth2Config](#memos-store-OAuth2Config) | | |
...@@ -213,6 +235,18 @@ ...@@ -213,6 +235,18 @@
<a name="memos-store-IdentityProvider-Type"></a>
### IdentityProvider.Type
| Name | Number | Description |
| ---- | ------ | ----------- |
| TYPE_UNSPECIFIED | 0 | |
| OAUTH2 | 1 | |
......
This diff is collapsed.
...@@ -4,9 +4,22 @@ package memos.store; ...@@ -4,9 +4,22 @@ package memos.store;
option go_package = "gen/store"; option go_package = "gen/store";
message IdentityProvider {
int32 id = 1;
string name = 2;
enum Type {
TYPE_UNSPECIFIED = 0;
OAUTH2 = 1;
}
Type type = 3;
string identifier_filter = 4;
IdentityProviderConfig config = 5;
}
message IdentityProviderConfig { message IdentityProviderConfig {
oneof config { oneof config {
OAuth2Config oauth2 = 1; OAuth2Config oauth2_config = 1;
} }
} }
......
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"regexp"
"strings" "strings"
"time" "time"
...@@ -14,8 +13,6 @@ import ( ...@@ -14,8 +13,6 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/usememos/memos/internal/util" "github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/route/api/auth" "github.com/usememos/memos/server/route/api/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
...@@ -40,7 +37,6 @@ type SignUp struct { ...@@ -40,7 +37,6 @@ type SignUp struct {
func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
g.POST("/auth/signin", s.SignIn) g.POST("/auth/signin", s.SignIn)
g.POST("/auth/signin/sso", s.SignInSSO)
g.POST("/auth/signout", s.SignOut) g.POST("/auth/signout", s.SignOut)
g.POST("/auth/signup", s.SignUp) g.POST("/auth/signup", s.SignUp)
} }
...@@ -111,117 +107,6 @@ func (s *APIV1Service) SignIn(c echo.Context) error { ...@@ -111,117 +107,6 @@ func (s *APIV1Service) SignIn(c echo.Context) error {
return c.JSON(http.StatusOK, userMessage) return c.JSON(http.StatusOK, userMessage)
} }
// SignInSSO godoc
//
// @Summary Sign-in to memos using SSO.
// @Tags auth
// @Accept json
// @Produce json
// @Param body body SSOSignIn true "SSO sign-in object"
// @Success 200 {object} store.User "User information"
// @Failure 400 {object} nil "Malformatted signin request"
// @Failure 401 {object} nil "Access denied, identifier does not match the filter."
// @Failure 403 {object} nil "User has been archived with username {username}"
// @Failure 404 {object} nil "Identity provider not found"
// @Failure 500 {object} nil "Failed to find identity provider | Failed to create identity provider instance | Failed to exchange token | Failed to get user info | Failed to compile identifier filter | Incorrect login credentials, please try again | Failed to generate random password | Failed to generate password hash | Failed to create user | Failed to generate tokens | Failed to create activity"
// @Router /api/v1/auth/signin/sso [POST]
func (s *APIV1Service) SignInSSO(c echo.Context) error {
ctx := c.Request().Context()
signin := &SSOSignIn{}
if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err)
}
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{
ID: &signin.IdentityProviderID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider").SetInternal(err)
}
if identityProvider == nil {
return echo.NewHTTPError(http.StatusNotFound, "Identity provider not found")
}
var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == store.IdentityProviderOAuth2Type {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err)
}
token, err := oauth2IdentityProvider.ExchangeToken(ctx, signin.RedirectURI, signin.Code)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to exchange token").SetInternal(err)
}
userInfo, err = oauth2IdentityProvider.UserInfo(token)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user info").SetInternal(err)
}
}
identifierFilter := identityProvider.IdentifierFilter
if identifierFilter != "" {
identifierFilterRegex, err := regexp.Compile(identifierFilter)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to compile identifier filter").SetInternal(err)
}
if !identifierFilterRegex.MatchString(userInfo.Identifier) {
return echo.NewHTTPError(http.StatusUnauthorized, "Access denied, identifier does not match the filter.").SetInternal(err)
}
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
Username: &userInfo.Identifier,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Incorrect login credentials, please try again")
}
if user == nil {
workspaceGeneralSetting, err := s.Store.GetWorkspaceGeneralSetting(ctx)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err)
}
if workspaceGeneralSetting.DisallowSignup {
return echo.NewHTTPError(http.StatusUnauthorized, "signup is disabled").SetInternal(err)
}
userCreate := &store.User{
Username: userInfo.Identifier,
// The new signup user should be normal user by default.
Role: store.RoleUser,
Nickname: userInfo.DisplayName,
Email: userInfo.Email,
}
password, err := util.RandomString(20)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate random password").SetInternal(err)
}
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
}
userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
}
}
if user.RowStatus == store.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", userInfo.Identifier))
}
accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, time.Now().Add(auth.AccessTokenDuration), []byte(s.Secret))
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
}
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert access token, err: %s", err)).SetInternal(err)
}
cookieExp := time.Now().Add(auth.CookieExpDuration)
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
userMessage := convertUserFromStore(user)
return c.JSON(http.StatusOK, userMessage)
}
// SignOut godoc // SignOut godoc
// //
// @Summary Sign-out from memos. // @Summary Sign-out from memos.
......
This diff is collapsed.
...@@ -68,7 +68,6 @@ func (s *APIV1Service) Register(rootGroup *echo.Group) { ...@@ -68,7 +68,6 @@ func (s *APIV1Service) Register(rootGroup *echo.Group) {
s.registerSystemRoutes(apiV1Group) s.registerSystemRoutes(apiV1Group)
s.registerSystemSettingRoutes(apiV1Group) s.registerSystemSettingRoutes(apiV1Group)
s.registerAuthRoutes(apiV1Group) s.registerAuthRoutes(apiV1Group)
s.registerIdentityProviderRoutes(apiV1Group)
s.registerUserRoutes(apiV1Group) s.registerUserRoutes(apiV1Group)
s.registerTagRoutes(apiV1Group) s.registerTagRoutes(apiV1Group)
s.registerStorageRoutes(apiV1Group) s.registerStorageRoutes(apiV1Group)
......
...@@ -4,6 +4,8 @@ var authenticationAllowlistMethods = map[string]bool{ ...@@ -4,6 +4,8 @@ var authenticationAllowlistMethods = map[string]bool{
"/memos.api.v2.WorkspaceService/GetWorkspaceProfile": true, "/memos.api.v2.WorkspaceService/GetWorkspaceProfile": true,
"/memos.api.v2.WorkspaceSettingService/GetWorkspaceSetting": true, "/memos.api.v2.WorkspaceSettingService/GetWorkspaceSetting": true,
"/memos.api.v2.WorkspaceSettingService/ListWorkspaceSettings": true, "/memos.api.v2.WorkspaceSettingService/ListWorkspaceSettings": true,
"/memos.api.v2.IdentityProviderService/ListIdentityProviders": true,
"/memos.api.v2.IdentityProviderService/GetIdentityProvider": true,
"/memos.api.v2.AuthService/GetAuthStatus": true, "/memos.api.v2.AuthService/GetAuthStatus": true,
"/memos.api.v2.AuthService/SignIn": true, "/memos.api.v2.AuthService/SignIn": true,
"/memos.api.v2.AuthService/SignInWithSSO": true, "/memos.api.v2.AuthService/SignInWithSSO": true,
......
...@@ -181,42 +181,42 @@ paths: ...@@ -181,42 +181,42 @@ paths:
in: query in: query
required: false required: false
type: string type: string
- name: identityProvider.config.oauth2.clientId - name: identityProvider.config.oauth2Config.clientId
in: query in: query
required: false required: false
type: string type: string
- name: identityProvider.config.oauth2.clientSecret - name: identityProvider.config.oauth2Config.clientSecret
in: query in: query
required: false required: false
type: string type: string
- name: identityProvider.config.oauth2.authUrl - name: identityProvider.config.oauth2Config.authUrl
in: query in: query
required: false required: false
type: string type: string
- name: identityProvider.config.oauth2.tokenUrl - name: identityProvider.config.oauth2Config.tokenUrl
in: query in: query
required: false required: false
type: string type: string
- name: identityProvider.config.oauth2.userInfoUrl - name: identityProvider.config.oauth2Config.userInfoUrl
in: query in: query
required: false required: false
type: string type: string
- name: identityProvider.config.oauth2.scopes - name: identityProvider.config.oauth2Config.scopes
in: query in: query
required: false required: false
type: array type: array
items: items:
type: string type: string
collectionFormat: multi collectionFormat: multi
- name: identityProvider.config.oauth2.fieldMapping.identifier - name: identityProvider.config.oauth2Config.fieldMapping.identifier
in: query in: query
required: false required: false
type: string type: string
- name: identityProvider.config.oauth2.fieldMapping.displayName - name: identityProvider.config.oauth2Config.fieldMapping.displayName
in: query in: query
required: false required: false
type: string type: string
- name: identityProvider.config.oauth2.fieldMapping.email - name: identityProvider.config.oauth2Config.fieldMapping.email
in: query in: query
required: false required: false
type: string type: string
...@@ -1070,7 +1070,7 @@ paths: ...@@ -1070,7 +1070,7 @@ paths:
type: object type: object
properties: properties:
type: type:
$ref: '#/definitions/v2IdentityProviderType' $ref: '#/definitions/apiv2IdentityProviderType'
title: title:
type: string type: string
identifierFilter: identifierFilter:
...@@ -2006,11 +2006,33 @@ definitions: ...@@ -2006,11 +2006,33 @@ definitions:
type: string type: string
email: email:
type: string type: string
apiv2IdentityProvider:
type: object
properties:
name:
type: string
title: |-
The name of the identityProvider.
Format: identityProviders/{id}
type:
$ref: '#/definitions/apiv2IdentityProviderType'
title:
type: string
identifierFilter:
type: string
config:
$ref: '#/definitions/apiv2IdentityProviderConfig'
apiv2IdentityProviderConfig: apiv2IdentityProviderConfig:
type: object type: object
properties: properties:
oauth2: oauth2Config:
$ref: '#/definitions/apiv2OAuth2Config' $ref: '#/definitions/apiv2OAuth2Config'
apiv2IdentityProviderType:
type: string
enum:
- TYPE_UNSPECIFIED
- OAUTH2
default: TYPE_UNSPECIFIED
apiv2OAuth2Config: apiv2OAuth2Config:
type: object type: object
properties: properties:
...@@ -2293,7 +2315,7 @@ definitions: ...@@ -2293,7 +2315,7 @@ definitions:
type: object type: object
properties: properties:
identityProvider: identityProvider:
$ref: '#/definitions/v2IdentityProvider' $ref: '#/definitions/apiv2IdentityProvider'
description: The created identityProvider. description: The created identityProvider.
v2CreateMemoCommentResponse: v2CreateMemoCommentResponse:
type: object type: object
...@@ -2389,7 +2411,7 @@ definitions: ...@@ -2389,7 +2411,7 @@ definitions:
type: object type: object
properties: properties:
identityProvider: identityProvider:
$ref: '#/definitions/v2IdentityProvider' $ref: '#/definitions/apiv2IdentityProvider'
description: The identityProvider. description: The identityProvider.
v2GetLinkMetadataResponse: v2GetLinkMetadataResponse:
type: object type: object
...@@ -2454,28 +2476,6 @@ definitions: ...@@ -2454,28 +2476,6 @@ definitions:
properties: properties:
setting: setting:
$ref: '#/definitions/apiv2WorkspaceSetting' $ref: '#/definitions/apiv2WorkspaceSetting'
v2IdentityProvider:
type: object
properties:
name:
type: string
title: |-
The name of the identityProvider.
Format: identityProviders/{id}
type:
$ref: '#/definitions/v2IdentityProviderType'
title:
type: string
identifierFilter:
type: string
config:
$ref: '#/definitions/apiv2IdentityProviderConfig'
v2IdentityProviderType:
type: string
enum:
- TYPE_UNSPECIFIED
- OAUTH2
default: TYPE_UNSPECIFIED
v2Inbox: v2Inbox:
type: object type: object
properties: properties:
...@@ -2530,7 +2530,7 @@ definitions: ...@@ -2530,7 +2530,7 @@ definitions:
type: array type: array
items: items:
type: object type: object
$ref: '#/definitions/v2IdentityProvider' $ref: '#/definitions/apiv2IdentityProvider'
v2ListInboxesResponse: v2ListInboxesResponse:
type: object type: object
properties: properties:
...@@ -2820,7 +2820,7 @@ definitions: ...@@ -2820,7 +2820,7 @@ definitions:
type: object type: object
properties: properties:
identityProvider: identityProvider:
$ref: '#/definitions/v2IdentityProvider' $ref: '#/definitions/apiv2IdentityProvider'
description: The updated identityProvider. description: The updated identityProvider.
v2UpdateInboxResponse: v2UpdateInboxResponse:
type: object type: object
......
...@@ -18,6 +18,7 @@ import ( ...@@ -18,6 +18,7 @@ import (
"github.com/usememos/memos/plugin/idp" "github.com/usememos/memos/plugin/idp"
"github.com/usememos/memos/plugin/idp/oauth2" "github.com/usememos/memos/plugin/idp/oauth2"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2" apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/route/api/auth" "github.com/usememos/memos/server/route/api/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
...@@ -71,7 +72,7 @@ func (s *APIV2Service) SignIn(ctx context.Context, request *apiv2pb.SignInReques ...@@ -71,7 +72,7 @@ func (s *APIV2Service) SignIn(ctx context.Context, request *apiv2pb.SignInReques
} }
func (s *APIV2Service) SignInWithSSO(ctx context.Context, request *apiv2pb.SignInWithSSORequest) (*apiv2pb.SignInWithSSOResponse, error) { func (s *APIV2Service) SignInWithSSO(ctx context.Context, request *apiv2pb.SignInWithSSORequest) (*apiv2pb.SignInWithSSOResponse, error) {
identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ identityProvider, err := s.Store.GetIdentityProviderV1(ctx, &store.FindIdentityProvider{
ID: &request.IdpId, ID: &request.IdpId,
}) })
if err != nil { if err != nil {
...@@ -82,8 +83,8 @@ func (s *APIV2Service) SignInWithSSO(ctx context.Context, request *apiv2pb.SignI ...@@ -82,8 +83,8 @@ func (s *APIV2Service) SignInWithSSO(ctx context.Context, request *apiv2pb.SignI
} }
var userInfo *idp.IdentityProviderUserInfo var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == store.IdentityProviderOAuth2Type { if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config) oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.GetOauth2Config())
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create oauth2 identity provider, err: %s", err)) return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create oauth2 identity provider, err: %s", err))
} }
......
package v2
import (
"context"
"fmt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
func (s *APIV2Service) CreateIdentityProvider(ctx context.Context, request *apiv2pb.CreateIdentityProviderRequest) (*apiv2pb.CreateIdentityProviderResponse, error) {
currentUser, err := getCurrentUser(ctx, s.Store)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get user: %v", err)
}
if currentUser.Role != store.RoleHost {
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}
identityProvider, err := s.Store.CreateIdentityProviderV1(ctx, convertIdentityProviderToStore(request.IdentityProvider))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to create identity provider, error: %+v", err)
}
return &apiv2pb.CreateIdentityProviderResponse{
IdentityProvider: convertIdentityProviderFromStore(identityProvider),
}, nil
}
func (s *APIV2Service) ListIdentityProviders(ctx context.Context, _ *apiv2pb.ListIdentityProvidersRequest) (*apiv2pb.ListIdentityProvidersResponse, error) {
identityProviders, err := s.Store.ListIdentityProvidersV1(ctx, &store.FindIdentityProvider{})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list identity providers, error: %+v", err)
}
response := &apiv2pb.ListIdentityProvidersResponse{
IdentityProviders: []*apiv2pb.IdentityProvider{},
}
for _, identityProvider := range identityProviders {
response.IdentityProviders = append(response.IdentityProviders, convertIdentityProviderFromStore(identityProvider))
}
return response, nil
}
func (s *APIV2Service) GetIdentityProvider(ctx context.Context, request *apiv2pb.GetIdentityProviderRequest) (*apiv2pb.GetIdentityProviderResponse, error) {
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
identityProvider, err := s.Store.GetIdentityProviderV1(ctx, &store.FindIdentityProvider{
ID: &id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get identity provider, error: %+v", err)
}
if identityProvider == nil {
return nil, status.Errorf(codes.NotFound, "identity provider not found")
}
return &apiv2pb.GetIdentityProviderResponse{
IdentityProvider: convertIdentityProviderFromStore(identityProvider),
}, nil
}
func (s *APIV2Service) UpdateIdentityProvider(ctx context.Context, request *apiv2pb.UpdateIdentityProviderRequest) (*apiv2pb.UpdateIdentityProviderResponse, error) {
if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "update_mask is required")
}
id, err := ExtractIdentityProviderIDFromName(request.IdentityProvider.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
update := &store.UpdateIdentityProviderV1{
ID: id,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[request.IdentityProvider.Type.String()]),
}
for _, field := range request.UpdateMask.Paths {
switch field {
case "title":
update.Name = &request.IdentityProvider.Title
case "config":
update.Config = convertIdentityProviderConfigToStore(request.IdentityProvider.Type, request.IdentityProvider.Config)
}
}
identityProvider, err := s.Store.UpdateIdentityProviderV1(ctx, update)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to update identity provider, error: %+v", err)
}
return &apiv2pb.UpdateIdentityProviderResponse{
IdentityProvider: convertIdentityProviderFromStore(identityProvider),
}, nil
}
func (s *APIV2Service) DeleteIdentityProvider(ctx context.Context, request *apiv2pb.DeleteIdentityProviderRequest) (*apiv2pb.DeleteIdentityProviderResponse, error) {
id, err := ExtractIdentityProviderIDFromName(request.Name)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err)
}
if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil {
return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err)
}
return &apiv2pb.DeleteIdentityProviderResponse{}, nil
}
func convertIdentityProviderFromStore(identityProvider *storepb.IdentityProvider) *apiv2pb.IdentityProvider {
temp := &apiv2pb.IdentityProvider{
Name: fmt.Sprintf("%s%d", IdentityProviderNamePrefix, identityProvider.Id),
Title: identityProvider.Name,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: apiv2pb.IdentityProvider_Type(apiv2pb.IdentityProvider_Type_value[identityProvider.Type.String()]),
}
if identityProvider.Type == storepb.IdentityProvider_OAUTH2 {
oauth2Config := identityProvider.Config.GetOauth2Config()
temp.Config = &apiv2pb.IdentityProviderConfig{
Config: &apiv2pb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &apiv2pb.OAuth2Config{
ClientId: oauth2Config.ClientId,
ClientSecret: oauth2Config.ClientSecret,
AuthUrl: oauth2Config.AuthUrl,
TokenUrl: oauth2Config.TokenUrl,
UserInfoUrl: oauth2Config.UserInfoUrl,
Scopes: oauth2Config.Scopes,
FieldMapping: &apiv2pb.FieldMapping{
Identifier: oauth2Config.FieldMapping.Identifier,
DisplayName: oauth2Config.FieldMapping.DisplayName,
Email: oauth2Config.FieldMapping.Email,
},
},
},
}
}
return temp
}
func convertIdentityProviderToStore(identityProvider *apiv2pb.IdentityProvider) *storepb.IdentityProvider {
id, _ := ExtractIdentityProviderIDFromName(identityProvider.Name)
temp := &storepb.IdentityProvider{
Id: id,
Name: identityProvider.Title,
IdentifierFilter: identityProvider.IdentifierFilter,
Type: storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[identityProvider.Type.String()]),
Config: convertIdentityProviderConfigToStore(identityProvider.Type, identityProvider.Config),
}
return temp
}
func convertIdentityProviderConfigToStore(identityProviderType apiv2pb.IdentityProvider_Type, config *apiv2pb.IdentityProviderConfig) *storepb.IdentityProviderConfig {
if identityProviderType == apiv2pb.IdentityProvider_OAUTH2 {
oauth2Config := config.GetOauth2Config()
return &storepb.IdentityProviderConfig{
Config: &storepb.IdentityProviderConfig_Oauth2Config{
Oauth2Config: &storepb.OAuth2Config{
ClientId: oauth2Config.ClientId,
ClientSecret: oauth2Config.ClientSecret,
AuthUrl: oauth2Config.AuthUrl,
TokenUrl: oauth2Config.TokenUrl,
UserInfoUrl: oauth2Config.UserInfoUrl,
Scopes: oauth2Config.Scopes,
FieldMapping: &storepb.FieldMapping{
Identifier: oauth2Config.FieldMapping.Identifier,
DisplayName: oauth2Config.FieldMapping.DisplayName,
Email: oauth2Config.FieldMapping.Email,
},
},
},
}
}
return nil
}
...@@ -16,6 +16,7 @@ const ( ...@@ -16,6 +16,7 @@ const (
ResourceNamePrefix = "resources/" ResourceNamePrefix = "resources/"
InboxNamePrefix = "inboxes/" InboxNamePrefix = "inboxes/"
StorageNamePrefix = "storages/" StorageNamePrefix = "storages/"
IdentityProviderNamePrefix = "identityProviders/"
) )
// GetNameParentTokens returns the tokens from a resource name. // GetNameParentTokens returns the tokens from a resource name.
...@@ -110,3 +111,15 @@ func ExtractStorageIDFromName(name string) (int32, error) { ...@@ -110,3 +111,15 @@ func ExtractStorageIDFromName(name string) (int32, error) {
} }
return id, nil return id, nil
} }
func ExtractIdentityProviderIDFromName(name string) (int32, error) {
tokens, err := GetNameParentTokens(name, IdentityProviderNamePrefix)
if err != nil {
return 0, err
}
id, err := util.ConvertStringToInt32(tokens[0])
if err != nil {
return 0, errors.Errorf("invalid identity provider ID %q", tokens[0])
}
return id, nil
}
...@@ -105,7 +105,7 @@ func convertStorageFromStore(storage *storepb.Storage) *apiv2pb.Storage { ...@@ -105,7 +105,7 @@ func convertStorageFromStore(storage *storepb.Storage) *apiv2pb.Storage {
if storage.Type == storepb.Storage_S3 { if storage.Type == storepb.Storage_S3 {
s3Config := storage.Config.GetS3Config() s3Config := storage.Config.GetS3Config()
temp.Config = &apiv2pb.StorageConfig{ temp.Config = &apiv2pb.StorageConfig{
StorageConfig: &apiv2pb.StorageConfig_S3Config{ Config: &apiv2pb.StorageConfig_S3Config{
S3Config: &apiv2pb.S3Config{ S3Config: &apiv2pb.S3Config{
EndPoint: s3Config.EndPoint, EndPoint: s3Config.EndPoint,
Path: s3Config.Path, Path: s3Config.Path,
......
...@@ -32,6 +32,7 @@ type APIV2Service struct { ...@@ -32,6 +32,7 @@ type APIV2Service struct {
apiv2pb.UnimplementedWebhookServiceServer apiv2pb.UnimplementedWebhookServiceServer
apiv2pb.UnimplementedLinkServiceServer apiv2pb.UnimplementedLinkServiceServer
apiv2pb.UnimplementedStorageServiceServer apiv2pb.UnimplementedStorageServiceServer
apiv2pb.UnimplementedIdentityProviderServiceServer
Secret string Secret string
Profile *profile.Profile Profile *profile.Profile
...@@ -70,6 +71,7 @@ func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store ...@@ -70,6 +71,7 @@ func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store
apiv2pb.RegisterWebhookServiceServer(grpcServer, apiv2Service) apiv2pb.RegisterWebhookServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterLinkServiceServer(grpcServer, apiv2Service) apiv2pb.RegisterLinkServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterStorageServiceServer(grpcServer, apiv2Service) apiv2pb.RegisterStorageServiceServer(grpcServer, apiv2Service)
apiv2pb.RegisterIdentityProviderServiceServer(grpcServer, apiv2Service)
reflection.Register(grpcServer) reflection.Register(grpcServer)
return apiv2Service return apiv2Service
...@@ -129,6 +131,9 @@ func (s *APIV2Service) RegisterGateway(ctx context.Context, e *echo.Echo) error ...@@ -129,6 +131,9 @@ func (s *APIV2Service) RegisterGateway(ctx context.Context, e *echo.Echo) error
if err := apiv2pb.RegisterStorageServiceHandler(context.Background(), gwMux, conn); err != nil { if err := apiv2pb.RegisterStorageServiceHandler(context.Background(), gwMux, conn); err != nil {
return err return err
} }
if err := apiv2pb.RegisterIdentityProviderServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
e.Any("/api/v2/*", echo.WrapHandler(gwMux)) e.Any("/api/v2/*", echo.WrapHandler(gwMux))
// GRPC web proxy. // GRPC web proxy.
......
...@@ -2,29 +2,18 @@ package mysql ...@@ -2,29 +2,18 @@ package mysql
import ( import (
"context" "context"
"encoding/json"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
var configBytes []byte
if create.Type == store.IdentityProviderOAuth2Type {
bytes, err := json.Marshal(create.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
}
placeholders := []string{"?", "?", "?", "?"} placeholders := []string{"?", "?", "?", "?"}
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"} fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
args := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)} args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")" stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ")"
result, err := d.db.ExecContext(ctx, stmt, args...) result, err := d.db.ExecContext(ctx, stmt, args...)
...@@ -58,28 +47,17 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity ...@@ -58,28 +47,17 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
var identityProviders []*store.IdentityProvider var identityProviders []*store.IdentityProvider
for rows.Next() { for rows.Next() {
var identityProvider store.IdentityProvider var identityProvider store.IdentityProvider
var identityProviderConfig string var typeString string
if err := rows.Scan( if err := rows.Scan(
&identityProvider.ID, &identityProvider.ID,
&identityProvider.Name, &identityProvider.Name,
&identityProvider.Type, &typeString,
&identityProvider.IdentifierFilter, &identityProvider.IdentifierFilter,
&identityProviderConfig, &identityProvider.Config,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
if identityProvider.Type == store.IdentityProviderOAuth2Type {
oauth2Config := &store.IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProvider.Config = &store.IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
identityProviders = append(identityProviders, &identityProvider) identityProviders = append(identityProviders, &identityProvider)
} }
...@@ -112,17 +90,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde ...@@ -112,17 +90,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
set, args = append(set, "`identifier_filter` = ?"), append(args, *v) set, args = append(set, "`identifier_filter` = ?"), append(args, *v)
} }
if v := update.Config; v != nil { if v := update.Config; v != nil {
var configBytes []byte set, args = append(set, "`config` = ?"), append(args, *v)
if update.Type == store.IdentityProviderOAuth2Type {
bytes, err := json.Marshal(update.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
}
set, args = append(set, "`config` = ?"), append(args, string(configBytes))
} }
args = append(args, update.ID) args = append(args, update.ID)
......
...@@ -2,28 +2,15 @@ package postgres ...@@ -2,28 +2,15 @@ package postgres
import ( import (
"context" "context"
"encoding/json"
"strings" "strings"
"github.com/pkg/errors" storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
var configBytes []byte
if create.Type == store.IdentityProviderOAuth2Type {
bytes, err := json.Marshal(create.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
}
fields := []string{"name", "type", "identifier_filter", "config"} fields := []string{"name", "type", "identifier_filter", "config"}
args := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)} args := []any{create.Name, create.Type.Type(), create.IdentifierFilter, create.Config}
stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id" stmt := "INSERT INTO idp (" + strings.Join(fields, ", ") + ") VALUES (" + placeholders(len(args)) + ") RETURNING id"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil { if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
return nil, err return nil, err
...@@ -58,28 +45,18 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity ...@@ -58,28 +45,18 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
var identityProviders []*store.IdentityProvider var identityProviders []*store.IdentityProvider
for rows.Next() { for rows.Next() {
var identityProvider store.IdentityProvider var identityProvider store.IdentityProvider
var identityProviderConfig string var typeString string
if err := rows.Scan( if err := rows.Scan(
&identityProvider.ID, &identityProvider.ID,
&identityProvider.Name, &identityProvider.Name,
&identityProvider.Type, &typeString,
&identityProvider.IdentifierFilter, &identityProvider.IdentifierFilter,
&identityProviderConfig, &identityProvider.Config,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
if identityProvider.Type == store.IdentityProviderOAuth2Type { identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
oauth2Config := &store.IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProvider.Config = &store.IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
identityProviders = append(identityProviders, &identityProvider) identityProviders = append(identityProviders, &identityProvider)
} }
...@@ -90,18 +67,6 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity ...@@ -90,18 +67,6 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
return identityProviders, nil return identityProviders, nil
} }
func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) {
list, err := d.ListIdentityProviders(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}
func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) { func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) {
set, args := []string{}, []any{} set, args := []string{}, []any{}
if v := update.Name; v != nil { if v := update.Name; v != nil {
...@@ -111,17 +76,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde ...@@ -111,17 +76,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
set, args = append(set, "identifier_filter = "+placeholder(len(args)+1)), append(args, *v) set, args = append(set, "identifier_filter = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Config; v != nil { if v := update.Config; v != nil {
var configBytes []byte set, args = append(set, "config = "+placeholder(len(args)+1)), append(args, *v)
if update.Type == store.IdentityProviderOAuth2Type {
bytes, err := json.Marshal(update.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
}
set, args = append(set, "config = "+placeholder(len(args)+1)), append(args, string(configBytes))
} }
stmt := ` stmt := `
...@@ -133,29 +88,18 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde ...@@ -133,29 +88,18 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
args = append(args, update.ID) args = append(args, update.ID)
var identityProvider store.IdentityProvider var identityProvider store.IdentityProvider
var identityProviderConfig string var typeString string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&identityProvider.ID, &identityProvider.ID,
&identityProvider.Name, &identityProvider.Name,
&identityProvider.Type, &typeString,
&identityProvider.IdentifierFilter, &identityProvider.IdentifierFilter,
&identityProviderConfig, &identityProvider.Config,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
if identityProvider.Type == store.IdentityProviderOAuth2Type { identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
oauth2Config := &store.IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProvider.Config = &store.IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
return &identityProvider, nil return &identityProvider, nil
} }
......
...@@ -2,30 +2,17 @@ package sqlite ...@@ -2,30 +2,17 @@ package sqlite
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"strings" "strings"
"github.com/pkg/errors" storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) {
var configBytes []byte
if create.Type == store.IdentityProviderOAuth2Type {
bytes, err := json.Marshal(create.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, errors.Errorf("unsupported idp type %s", string(create.Type))
}
placeholders := []string{"?", "?", "?", "?"} placeholders := []string{"?", "?", "?", "?"}
fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"} fields := []string{"`name`", "`type`", "`identifier_filter`", "`config`"}
args := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)} args := []any{create.Name, create.Type.String(), create.IdentifierFilter, create.Config}
stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") RETURNING `id`" stmt := "INSERT INTO `idp` (" + strings.Join(fields, ", ") + ") VALUES (" + strings.Join(placeholders, ", ") + ") RETURNING `id`"
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil { if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(&create.ID); err != nil {
...@@ -61,28 +48,17 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity ...@@ -61,28 +48,17 @@ func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentity
var identityProviders []*store.IdentityProvider var identityProviders []*store.IdentityProvider
for rows.Next() { for rows.Next() {
var identityProvider store.IdentityProvider var identityProvider store.IdentityProvider
var identityProviderConfig string var typeString string
if err := rows.Scan( if err := rows.Scan(
&identityProvider.ID, &identityProvider.ID,
&identityProvider.Name, &identityProvider.Name,
&identityProvider.Type, &typeString,
&identityProvider.IdentifierFilter, &identityProvider.IdentifierFilter,
&identityProviderConfig, &identityProvider.Config,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
if identityProvider.Type == store.IdentityProviderOAuth2Type {
oauth2Config := &store.IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProvider.Config = &store.IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
identityProviders = append(identityProviders, &identityProvider) identityProviders = append(identityProviders, &identityProvider)
} }
...@@ -102,17 +78,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde ...@@ -102,17 +78,7 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
set, args = append(set, "identifier_filter = ?"), append(args, *v) set, args = append(set, "identifier_filter = ?"), append(args, *v)
} }
if v := update.Config; v != nil { if v := update.Config; v != nil {
var configBytes []byte set, args = append(set, "config = ?"), append(args, *v)
if update.Type == store.IdentityProviderOAuth2Type {
bytes, err := json.Marshal(update.Config.OAuth2Config)
if err != nil {
return nil, err
}
configBytes = bytes
} else {
return nil, errors.Errorf("unsupported idp type %s", string(update.Type))
}
set, args = append(set, "config = ?"), append(args, string(configBytes))
} }
args = append(args, update.ID) args = append(args, update.ID)
...@@ -123,29 +89,17 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde ...@@ -123,29 +89,17 @@ func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIde
RETURNING id, name, type, identifier_filter, config RETURNING id, name, type, identifier_filter, config
` `
var identityProvider store.IdentityProvider var identityProvider store.IdentityProvider
var identityProviderConfig string var typeString string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&identityProvider.ID, &identityProvider.ID,
&identityProvider.Name, &identityProvider.Name,
&identityProvider.Type, &typeString,
&identityProvider.IdentifierFilter, &identityProvider.IdentifierFilter,
&identityProviderConfig, &identityProvider.Config,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
identityProvider.Type = storepb.IdentityProvider_Type(storepb.IdentityProvider_Type_value[typeString])
if identityProvider.Type == store.IdentityProviderOAuth2Type {
oauth2Config := &store.IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err
}
identityProvider.Config = &store.IdentityProviderConfig{
OAuth2Config: oauth2Config,
}
} else {
return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type))
}
return &identityProvider, nil return &identityProvider, nil
} }
......
...@@ -2,44 +2,19 @@ package store ...@@ -2,44 +2,19 @@ package store
import ( import (
"context" "context"
)
type IdentityProviderType string "github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
const ( storepb "github.com/usememos/memos/proto/gen/store"
IdentityProviderOAuth2Type IdentityProviderType = "OAUTH2"
) )
func (t IdentityProviderType) String() string {
return string(t)
}
type IdentityProviderConfig struct {
OAuth2Config *IdentityProviderOAuth2Config
}
type IdentityProviderOAuth2Config struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
AuthURL string `json:"authUrl"`
TokenURL string `json:"tokenUrl"`
UserInfoURL string `json:"userInfoUrl"`
Scopes []string `json:"scopes"`
FieldMapping *FieldMapping `json:"fieldMapping"`
}
type FieldMapping struct {
Identifier string `json:"identifier"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
}
type IdentityProvider struct { type IdentityProvider struct {
ID int32 ID int32
Name string Name string
Type IdentityProviderType Type storepb.IdentityProvider_Type
IdentifierFilter string IdentifierFilter string
Config *IdentityProviderConfig Config string
} }
type FindIdentityProvider struct { type FindIdentityProvider struct {
...@@ -48,64 +23,107 @@ type FindIdentityProvider struct { ...@@ -48,64 +23,107 @@ type FindIdentityProvider struct {
type UpdateIdentityProvider struct { type UpdateIdentityProvider struct {
ID int32 ID int32
Type IdentityProviderType
Name *string Name *string
IdentifierFilter *string IdentifierFilter *string
Config *IdentityProviderConfig Config *string
} }
type DeleteIdentityProvider struct { type DeleteIdentityProvider struct {
ID int32 ID int32
} }
func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) { func (s *Store) CreateIdentityProviderV1(ctx context.Context, create *storepb.IdentityProvider) (*storepb.IdentityProvider, error) {
identityProvider, err := s.driver.CreateIdentityProvider(ctx, create) raw, err := convertIdentityProviderToRaw(create)
if err != nil {
return nil, err
}
identityProviderRaw, err := s.driver.CreateIdentityProvider(ctx, raw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.idpCache.Store(identityProvider.ID, identityProvider) identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
if err != nil {
return nil, err
}
s.idpV1Cache.Store(identityProvider.Id, identityProvider)
return identityProvider, nil return identityProvider, nil
} }
func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) { func (s *Store) ListIdentityProvidersV1(ctx context.Context, find *FindIdentityProvider) ([]*storepb.IdentityProvider, error) {
identityProviders, err := s.driver.ListIdentityProviders(ctx, find) list, err := s.driver.ListIdentityProviders(ctx, find)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, item := range identityProviders { identityProviders := []*storepb.IdentityProvider{}
s.idpCache.Store(item.ID, item) for _, raw := range list {
identityProvider, err := convertIdentityProviderFromRaw(raw)
if err != nil {
return nil, err
}
s.idpV1Cache.Store(identityProvider.Id, identityProvider)
} }
return identityProviders, nil return identityProviders, nil
} }
func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) { func (s *Store) GetIdentityProviderV1(ctx context.Context, find *FindIdentityProvider) (*storepb.IdentityProvider, error) {
if find.ID != nil { if find.ID != nil {
if cache, ok := s.idpCache.Load(*find.ID); ok { if cache, ok := s.idpV1Cache.Load(*find.ID); ok {
return cache.(*IdentityProvider), nil return cache.(*storepb.IdentityProvider), nil
} }
} }
list, err := s.ListIdentityProviders(ctx, find) list, err := s.ListIdentityProvidersV1(ctx, find)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(list) == 0 { if len(list) == 0 {
return nil, nil return nil, nil
} }
if len(list) > 1 {
return nil, errors.Errorf("Found multiple identity providers with ID %d", *find.ID)
}
identityProvider := list[0] identityProvider := list[0]
return identityProvider, nil return identityProvider, nil
} }
func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) { type UpdateIdentityProviderV1 struct {
identityProvider, err := s.driver.UpdateIdentityProvider(ctx, update) ID int32
Type storepb.IdentityProvider_Type
Name *string
IdentifierFilter *string
Config *storepb.IdentityProviderConfig
}
func (s *Store) UpdateIdentityProviderV1(ctx context.Context, update *UpdateIdentityProviderV1) (*storepb.IdentityProvider, error) {
updateRaw := &UpdateIdentityProvider{
ID: update.ID,
}
if update.Name != nil {
updateRaw.Name = update.Name
}
if update.IdentifierFilter != nil {
updateRaw.IdentifierFilter = update.IdentifierFilter
}
if update.Config != nil {
configRaw, err := convertIdentityProviderConfigToRaw(update.Type, update.Config)
if err != nil {
return nil, err
}
updateRaw.Config = &configRaw
}
identityProviderRaw, err := s.driver.UpdateIdentityProvider(ctx, updateRaw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.idpCache.Store(identityProvider.ID, identityProvider) identityProvider, err := convertIdentityProviderFromRaw(identityProviderRaw)
if err != nil {
return nil, err
}
s.idpV1Cache.Store(identityProvider.Id, identityProvider)
return identityProvider, nil return identityProvider, nil
} }
...@@ -118,3 +136,57 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti ...@@ -118,3 +136,57 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti
s.idpCache.Delete(delete.ID) s.idpCache.Delete(delete.ID)
return nil return nil
} }
func convertIdentityProviderFromRaw(raw *IdentityProvider) (*storepb.IdentityProvider, error) {
identityProvider := &storepb.IdentityProvider{
Id: raw.ID,
Name: raw.Name,
Type: raw.Type,
IdentifierFilter: raw.IdentifierFilter,
}
config, err := convertIdentityProviderConfigFromRaw(identityProvider.Type, raw.Config)
if err != nil {
return nil, err
}
identityProvider.Config = config
return identityProvider, nil
}
func convertIdentityProviderToRaw(identityProvider *storepb.IdentityProvider) (*IdentityProvider, error) {
raw := &IdentityProvider{
ID: identityProvider.Id,
Name: identityProvider.Name,
Type: identityProvider.Type,
IdentifierFilter: identityProvider.IdentifierFilter,
}
configRaw, err := convertIdentityProviderConfigToRaw(identityProvider.Type, identityProvider.Config)
if err != nil {
return nil, err
}
raw.Config = configRaw
return raw, nil
}
func convertIdentityProviderConfigFromRaw(identityProviderType storepb.IdentityProvider_Type, raw string) (*storepb.IdentityProviderConfig, error) {
config := &storepb.IdentityProviderConfig{}
if identityProviderType == storepb.IdentityProvider_OAUTH2 {
oauth2Config := &storepb.OAuth2Config{}
if err := protojsonUnmarshaler.Unmarshal([]byte(raw), oauth2Config); err != nil {
return nil, errors.Wrap(err, "Failed to unmarshal OAuth2Config")
}
config.Config = &storepb.IdentityProviderConfig_Oauth2Config{Oauth2Config: oauth2Config}
}
return config, nil
}
func convertIdentityProviderConfigToRaw(identityProviderType storepb.IdentityProvider_Type, config *storepb.IdentityProviderConfig) (string, error) {
raw := ""
if identityProviderType == storepb.IdentityProvider_OAUTH2 {
bytes, err := protojson.Marshal(config.GetOauth2Config())
if err != nil {
return "", errors.Wrap(err, "Failed to marshal OAuth2Config")
}
raw = string(bytes)
}
return raw, nil
}
...@@ -4,7 +4,7 @@ import ( ...@@ -4,7 +4,7 @@ import (
"context" "context"
"github.com/pkg/errors" "github.com/pkg/errors"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/encoding/protojson"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
) )
...@@ -65,7 +65,7 @@ func (s *Store) CreateStorageV1(ctx context.Context, create *storepb.Storage) (* ...@@ -65,7 +65,7 @@ func (s *Store) CreateStorageV1(ctx context.Context, create *storepb.Storage) (*
} }
if create.Type == storepb.Storage_S3 { if create.Type == storepb.Storage_S3 {
configBytes, err := proto.Marshal(create.Config.GetS3Config()) configBytes, err := protojson.Marshal(create.Config.GetS3Config())
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to marshal s3 config") return nil, errors.Wrap(err, "failed to marshal s3 config")
} }
...@@ -174,7 +174,7 @@ func convertStorageConfigFromRaw(storageType storepb.Storage_Type, configRaw str ...@@ -174,7 +174,7 @@ func convertStorageConfigFromRaw(storageType storepb.Storage_Type, configRaw str
func convertStorageConfigToRaw(storageType storepb.Storage_Type, config *storepb.StorageConfig) (string, error) { func convertStorageConfigToRaw(storageType storepb.Storage_Type, config *storepb.StorageConfig) (string, error) {
raw := "" raw := ""
if storageType == storepb.Storage_S3 { if storageType == storepb.Storage_S3 {
bytes, err := proto.Marshal(config.GetS3Config()) bytes, err := protojson.Marshal(config.GetS3Config())
if err != nil { if err != nil {
return "", err return "", err
} }
......
...@@ -11,11 +11,11 @@ import ( ...@@ -11,11 +11,11 @@ import (
type Store struct { type Store struct {
Profile *profile.Profile Profile *profile.Profile
driver Driver driver Driver
workspaceSettingCache sync.Map // map[string]*WorkspaceSetting workspaceSettingCache sync.Map // map[string]*storepb.WorkspaceSetting
workspaceSettingV1Cache sync.Map // map[string]*storepb.WorkspaceSetting
userCache sync.Map // map[int]*User userCache sync.Map // map[int]*User
userSettingCache sync.Map // map[string]*UserSetting userSettingCache sync.Map // map[string]*UserSetting
idpCache sync.Map // map[int]*IdentityProvider idpCache sync.Map // map[int]*IdentityProvider
idpV1Cache sync.Map // map[int]*storepb.IdentityProvider
} }
// New creates a new instance of Store. // New creates a new instance of Store.
......
...@@ -32,20 +32,10 @@ func (s *Store) ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSe ...@@ -32,20 +32,10 @@ func (s *Store) ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSe
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, systemSettingMessage := range list {
s.workspaceSettingCache.Store(systemSettingMessage.Name, systemSettingMessage)
}
return list, nil return list, nil
} }
func (s *Store) GetWorkspaceSetting(ctx context.Context, find *FindWorkspaceSetting) (*WorkspaceSetting, error) { func (s *Store) GetWorkspaceSetting(ctx context.Context, find *FindWorkspaceSetting) (*WorkspaceSetting, error) {
if find.Name != "" {
if cache, ok := s.workspaceSettingCache.Load(find.Name); ok {
return cache.(*WorkspaceSetting), nil
}
}
list, err := s.ListWorkspaceSettings(ctx, find) list, err := s.ListWorkspaceSettings(ctx, find)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -56,7 +46,6 @@ func (s *Store) GetWorkspaceSetting(ctx context.Context, find *FindWorkspaceSett ...@@ -56,7 +46,6 @@ func (s *Store) GetWorkspaceSetting(ctx context.Context, find *FindWorkspaceSett
} }
systemSettingMessage := list[0] systemSettingMessage := list[0]
s.workspaceSettingCache.Store(systemSettingMessage.Name, systemSettingMessage)
return systemSettingMessage, nil return systemSettingMessage, nil
} }
...@@ -65,7 +54,6 @@ func (s *Store) DeleteWorkspaceSetting(ctx context.Context, delete *DeleteWorksp ...@@ -65,7 +54,6 @@ func (s *Store) DeleteWorkspaceSetting(ctx context.Context, delete *DeleteWorksp
if err != nil { if err != nil {
return errors.Wrap(err, "Failed to delete workspace setting") return errors.Wrap(err, "Failed to delete workspace setting")
} }
s.workspaceSettingCache.Delete(delete.Name)
return nil return nil
} }
...@@ -101,7 +89,7 @@ func (s *Store) UpsertWorkspaceSettingV1(ctx context.Context, upsert *storepb.Wo ...@@ -101,7 +89,7 @@ func (s *Store) UpsertWorkspaceSettingV1(ctx context.Context, upsert *storepb.Wo
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Failed to convert workspace setting") return nil, errors.Wrap(err, "Failed to convert workspace setting")
} }
s.workspaceSettingV1Cache.Store(workspaceSetting.Key.String(), workspaceSetting) s.workspaceSettingCache.Store(workspaceSetting.Key.String(), workspaceSetting)
return workspaceSetting, nil return workspaceSetting, nil
} }
...@@ -117,14 +105,14 @@ func (s *Store) ListWorkspaceSettingsV1(ctx context.Context, find *FindWorkspace ...@@ -117,14 +105,14 @@ func (s *Store) ListWorkspaceSettingsV1(ctx context.Context, find *FindWorkspace
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Failed to convert workspace setting") return nil, errors.Wrap(err, "Failed to convert workspace setting")
} }
s.workspaceSettingV1Cache.Store(workspaceSetting.Key.String(), workspaceSetting) s.workspaceSettingCache.Store(workspaceSetting.Key.String(), workspaceSetting)
workspaceSettings = append(workspaceSettings, workspaceSetting) workspaceSettings = append(workspaceSettings, workspaceSetting)
} }
return workspaceSettings, nil return workspaceSettings, nil
} }
func (s *Store) GetWorkspaceSettingV1(ctx context.Context, find *FindWorkspaceSetting) (*storepb.WorkspaceSetting, error) { func (s *Store) GetWorkspaceSettingV1(ctx context.Context, find *FindWorkspaceSetting) (*storepb.WorkspaceSetting, error) {
if cache, ok := s.workspaceSettingV1Cache.Load(find.Name); ok { if cache, ok := s.workspaceSettingCache.Load(find.Name); ok {
return cache.(*storepb.WorkspaceSetting), nil return cache.(*storepb.WorkspaceSetting), nil
} }
......
...@@ -6,51 +6,54 @@ import ( ...@@ -6,51 +6,54 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func TestIdentityProviderStore(t *testing.T) { func TestIdentityProviderStore(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProvider{ createdIDP, err := ts.CreateIdentityProviderV1(ctx, &storepb.IdentityProvider{
Name: "GitHub OAuth", Name: "GitHub OAuth",
Type: store.IdentityProviderOAuth2Type, Type: storepb.IdentityProvider_OAUTH2,
IdentifierFilter: "", IdentifierFilter: "",
Config: &store.IdentityProviderConfig{ Config: &storepb.IdentityProviderConfig{
OAuth2Config: &store.IdentityProviderOAuth2Config{ Config: &storepb.IdentityProviderConfig_Oauth2Config{
ClientID: "client_id", Oauth2Config: &storepb.OAuth2Config{
ClientId: "client_id",
ClientSecret: "client_secret", ClientSecret: "client_secret",
AuthURL: "https://github.com/auth", AuthUrl: "https://github.com/auth",
TokenURL: "https://github.com/token", TokenUrl: "https://github.com/token",
UserInfoURL: "https://github.com/user", UserInfoUrl: "https://github.com/user",
Scopes: []string{"login"}, Scopes: []string{"login"},
FieldMapping: &store.FieldMapping{ FieldMapping: &storepb.FieldMapping{
Identifier: "login", Identifier: "login",
DisplayName: "name", DisplayName: "name",
Email: "emai", Email: "email",
},
}, },
}, },
}, },
}) })
require.NoError(t, err) require.NoError(t, err)
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ idp, err := ts.GetIdentityProviderV1(ctx, &store.FindIdentityProvider{
ID: &createdIDP.ID, ID: &createdIDP.Id,
}) })
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, idp) require.NotNil(t, idp)
require.Equal(t, createdIDP, idp) require.Equal(t, createdIDP, idp)
newName := "My GitHub OAuth" newName := "My GitHub OAuth"
updatedIdp, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProvider{ updatedIdp, err := ts.UpdateIdentityProviderV1(ctx, &store.UpdateIdentityProviderV1{
ID: idp.ID, ID: idp.Id,
Name: &newName, Name: &newName,
}) })
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, newName, updatedIdp.Name) require.Equal(t, newName, updatedIdp.Name)
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{
ID: idp.ID, ID: idp.Id,
}) })
require.NoError(t, err) require.NoError(t, err)
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{}) idpList, err := ts.ListIdentityProvidersV1(ctx, &store.FindIdentityProvider{})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(idpList)) require.Equal(t, 0, len(idpList))
ts.Close() ts.Close()
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
"@matejmazur/react-katex": "^3.1.3", "@matejmazur/react-katex": "^3.1.3",
"@mui/joy": "5.0.0-beta.32", "@mui/joy": "5.0.0-beta.32",
"@reduxjs/toolkit": "^2.2.3", "@reduxjs/toolkit": "^2.2.3",
"axios": "^1.6.8",
"classnames": "^2.5.1", "classnames": "^2.5.1",
"copy-to-clipboard": "^3.3.3", "copy-to-clipboard": "^3.3.3",
"fuse.js": "^7.0.0", "fuse.js": "^7.0.0",
......
...@@ -23,9 +23,6 @@ dependencies: ...@@ -23,9 +23,6 @@ dependencies:
'@reduxjs/toolkit': '@reduxjs/toolkit':
specifier: ^2.2.3 specifier: ^2.2.3
version: 2.2.3(react-redux@9.1.0)(react@18.2.0) version: 2.2.3(react-redux@9.1.0)(react@18.2.0)
axios:
specifier: ^1.6.8
version: 1.6.8
classnames: classnames:
specifier: ^2.5.1 specifier: ^2.5.1
version: 2.5.1 version: 2.5.1
...@@ -2048,10 +2045,6 @@ packages: ...@@ -2048,10 +2045,6 @@ packages:
is-shared-array-buffer: 1.0.3 is-shared-array-buffer: 1.0.3
dev: true dev: true
/asynckit@0.4.0:
resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==}
dev: false
/autoprefixer@10.4.19(postcss@8.4.38): /autoprefixer@10.4.19(postcss@8.4.38):
resolution: {integrity: sha512-BaENR2+zBZ8xXhM4pUaKUxlVdxZ0EZhjvbopwnXmxRUfqDmwSpC2lAi/QXvx7NRdPCo1WKEcEF6mV64si1z4Ew==} resolution: {integrity: sha512-BaENR2+zBZ8xXhM4pUaKUxlVdxZ0EZhjvbopwnXmxRUfqDmwSpC2lAi/QXvx7NRdPCo1WKEcEF6mV64si1z4Ew==}
engines: {node: ^10 || ^12 || >=14} engines: {node: ^10 || ^12 || >=14}
...@@ -2075,16 +2068,6 @@ packages: ...@@ -2075,16 +2068,6 @@ packages:
possible-typed-array-names: 1.0.0 possible-typed-array-names: 1.0.0
dev: true dev: true
/axios@1.6.8:
resolution: {integrity: sha512-v/ZHtJDU39mDpyBoFVkETcd/uNdxrWRrg3bKpOKzXFA6Bvqopts6ALSMU3y6ijYxbw2B+wPrIv46egTzJXCLGQ==}
dependencies:
follow-redirects: 1.15.6
form-data: 4.0.0
proxy-from-env: 1.1.0
transitivePeerDependencies:
- debug
dev: false
/babel-plugin-macros@3.1.0: /babel-plugin-macros@3.1.0:
resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==} resolution: {integrity: sha512-Cg7TFGpIr01vOQNODXOOaGz2NpCU5gl8x1qJFbb6hbZxR7XrcE2vtbAsTAbJ7/xwJtUuJEw8K8Zr/AE0LHlesg==}
engines: {node: '>=10', npm: '>=6'} engines: {node: '>=10', npm: '>=6'}
...@@ -2216,13 +2199,6 @@ packages: ...@@ -2216,13 +2199,6 @@ packages:
/color-name@1.1.4: /color-name@1.1.4:
resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==}
/combined-stream@1.0.8:
resolution: {integrity: sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==}
engines: {node: '>= 0.8'}
dependencies:
delayed-stream: 1.0.0
dev: false
/commander@4.1.1: /commander@4.1.1:
resolution: {integrity: sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==} resolution: {integrity: sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==}
engines: {node: '>= 6'} engines: {node: '>= 6'}
...@@ -2681,11 +2657,6 @@ packages: ...@@ -2681,11 +2657,6 @@ packages:
robust-predicates: 3.0.2 robust-predicates: 3.0.2
dev: false dev: false
/delayed-stream@1.0.0:
resolution: {integrity: sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==}
engines: {node: '>=0.4.0'}
dev: false
/dequal@2.0.3: /dequal@2.0.3:
resolution: {integrity: sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==} resolution: {integrity: sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==}
engines: {node: '>=6'} engines: {node: '>=6'}
...@@ -3159,16 +3130,6 @@ packages: ...@@ -3159,16 +3130,6 @@ packages:
resolution: {integrity: sha512-X8cqMLLie7KsNUDSdzeN8FYK9rEt4Dt67OsG/DNGnYTSDBG4uFAJFBnUeiV+zCVAvwFy56IjM9sH51jVaEhNxw==} resolution: {integrity: sha512-X8cqMLLie7KsNUDSdzeN8FYK9rEt4Dt67OsG/DNGnYTSDBG4uFAJFBnUeiV+zCVAvwFy56IjM9sH51jVaEhNxw==}
dev: true dev: true
/follow-redirects@1.15.6:
resolution: {integrity: sha512-wWN62YITEaOpSK584EZXJafH1AGpO8RVgElfkuXbTOrPX4fIfOyEpW/CsiNd8JdYrAoOvafRTOEnvsO++qCqFA==}
engines: {node: '>=4.0'}
peerDependencies:
debug: '*'
peerDependenciesMeta:
debug:
optional: true
dev: false
/for-each@0.3.3: /for-each@0.3.3:
resolution: {integrity: sha512-jqYfLp7mo9vIyQf8ykW2v7A+2N4QjeCeI5+Dz9XraiO1ign81wjiH7Fb9vSOWvQfNtmSa4H2RoQTrrXivdUZmw==} resolution: {integrity: sha512-jqYfLp7mo9vIyQf8ykW2v7A+2N4QjeCeI5+Dz9XraiO1ign81wjiH7Fb9vSOWvQfNtmSa4H2RoQTrrXivdUZmw==}
dependencies: dependencies:
...@@ -3183,15 +3144,6 @@ packages: ...@@ -3183,15 +3144,6 @@ packages:
signal-exit: 4.1.0 signal-exit: 4.1.0
dev: false dev: false
/form-data@4.0.0:
resolution: {integrity: sha512-ETEklSGi5t0QMZuiXoA/Q6vcnxcLQP5vdugSpuAyi6SVGi2clPPp+xgEhuMaHC+zGgn31Kd235W35f7Hykkaww==}
engines: {node: '>= 6'}
dependencies:
asynckit: 0.4.0
combined-stream: 1.0.8
mime-types: 2.1.35
dev: false
/fraction.js@4.3.7: /fraction.js@4.3.7:
resolution: {integrity: sha512-ZsDfxO51wGAXREY55a7la9LScWpwv9RxIrYABrlvOFBlH/ShPnrtsXeuUIfXKKOVicNxQ+o8JTbJvjS4M89yew==} resolution: {integrity: sha512-ZsDfxO51wGAXREY55a7la9LScWpwv9RxIrYABrlvOFBlH/ShPnrtsXeuUIfXKKOVicNxQ+o8JTbJvjS4M89yew==}
dev: true dev: true
...@@ -4142,18 +4094,6 @@ packages: ...@@ -4142,18 +4094,6 @@ packages:
braces: 3.0.2 braces: 3.0.2
picomatch: 2.3.1 picomatch: 2.3.1
/mime-db@1.52.0:
resolution: {integrity: sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==}
engines: {node: '>= 0.6'}
dev: false
/mime-types@2.1.35:
resolution: {integrity: sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==}
engines: {node: '>= 0.6'}
dependencies:
mime-db: 1.52.0
dev: false
/mime@1.6.0: /mime@1.6.0:
resolution: {integrity: sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==} resolution: {integrity: sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==}
engines: {node: '>=4'} engines: {node: '>=4'}
...@@ -4564,10 +4504,6 @@ packages: ...@@ -4564,10 +4504,6 @@ packages:
long: 5.2.3 long: 5.2.3
dev: true dev: true
/proxy-from-env@1.1.0:
resolution: {integrity: sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==}
dev: false
/prr@1.0.1: /prr@1.0.1:
resolution: {integrity: sha512-yPw4Sng1gWghHQWj0B3ZggWUm4qVbPwPFcRG8KyxiU7J2OHFSoEHKS+EZ3fv5l1t9CyCiop6l/ZYeWbrgoQejw==} resolution: {integrity: sha512-yPw4Sng1gWghHQWj0B3ZggWUm4qVbPwPFcRG8KyxiU7J2OHFSoEHKS+EZ3fv5l1t9CyCiop6l/ZYeWbrgoQejw==}
requiresBuild: true requiresBuild: true
......
import { Button, Divider, IconButton, Input, Option, Select, Typography } from "@mui/joy"; import { Button, Divider, IconButton, Input, Option, Select, Typography } from "@mui/joy";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { toast } from "react-hot-toast"; import { toast } from "react-hot-toast";
import * as api from "@/helpers/api"; import { identityProviderServiceClient } from "@/grpcweb";
import { UNKNOWN_ID } from "@/helpers/consts";
import { absolutifyLink } from "@/helpers/utils"; import { absolutifyLink } from "@/helpers/utils";
import { FieldMapping, IdentityProvider, IdentityProvider_Type, OAuth2Config } from "@/types/proto/api/v2/idp_service";
import { useTranslate } from "@/utils/i18n"; import { useTranslate } from "@/utils/i18n";
import { generateDialog } from "./Dialog"; import { generateDialog } from "./Dialog";
import Icon from "./Icon"; import Icon from "./Icon";
const templateList: IdentityProvider[] = [ const templateList: IdentityProvider[] = [
{ {
id: UNKNOWN_ID, name: "",
name: "GitHub", title: "GitHub",
type: "OAUTH2", type: IdentityProvider_Type.OAUTH2,
identifierFilter: "", identifierFilter: "",
config: { config: {
oauth2Config: { oauth2Config: {
...@@ -31,9 +31,9 @@ const templateList: IdentityProvider[] = [ ...@@ -31,9 +31,9 @@ const templateList: IdentityProvider[] = [
}, },
}, },
{ {
id: UNKNOWN_ID, name: "",
name: "GitLab", title: "GitLab",
type: "OAUTH2", type: IdentityProvider_Type.OAUTH2,
identifierFilter: "", identifierFilter: "",
config: { config: {
oauth2Config: { oauth2Config: {
...@@ -52,9 +52,9 @@ const templateList: IdentityProvider[] = [ ...@@ -52,9 +52,9 @@ const templateList: IdentityProvider[] = [
}, },
}, },
{ {
id: UNKNOWN_ID, name: "",
name: "Google", title: "Google",
type: "OAUTH2", type: IdentityProvider_Type.OAUTH2,
identifierFilter: "", identifierFilter: "",
config: { config: {
oauth2Config: { oauth2Config: {
...@@ -73,9 +73,9 @@ const templateList: IdentityProvider[] = [ ...@@ -73,9 +73,9 @@ const templateList: IdentityProvider[] = [
}, },
}, },
{ {
id: UNKNOWN_ID, name: "",
name: "Custom", title: "Custom",
type: "OAUTH2", type: IdentityProvider_Type.OAUTH2,
identifierFilter: "", identifierFilter: "",
config: { config: {
oauth2Config: { oauth2Config: {
...@@ -108,8 +108,8 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => { ...@@ -108,8 +108,8 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => {
name: "", name: "",
identifierFilter: "", identifierFilter: "",
}); });
const [type, setType] = useState<IdentityProviderType>("OAUTH2"); const [type, setType] = useState<IdentityProvider_Type>(IdentityProvider_Type.OAUTH2);
const [oauth2Config, setOAuth2Config] = useState<IdentityProviderOAuth2Config>({ const [oauth2Config, setOAuth2Config] = useState<OAuth2Config>({
clientId: "", clientId: "",
clientSecret: "", clientSecret: "",
authUrl: "", authUrl: "",
...@@ -133,9 +133,10 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => { ...@@ -133,9 +133,10 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => {
identifierFilter: identityProvider.identifierFilter, identifierFilter: identityProvider.identifierFilter,
}); });
setType(identityProvider.type); setType(identityProvider.type);
if (identityProvider.type === "OAUTH2") { if (identityProvider.type === IdentityProvider_Type.OAUTH2) {
setOAuth2Config(identityProvider.config.oauth2Config); const oauth2Config = OAuth2Config.fromPartial(identityProvider.config?.oauth2Config || {});
setOAuth2Scopes(identityProvider.config.oauth2Config.scopes.join(" ")); setOAuth2Config(oauth2Config);
setOAuth2Scopes(oauth2Config.scopes.join(" "));
} }
} }
}, []); }, []);
...@@ -145,16 +146,17 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => { ...@@ -145,16 +146,17 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => {
return; return;
} }
const template = templateList.find((t) => t.name === selectedTemplate); const template = templateList.find((t) => t.title === selectedTemplate);
if (template) { if (template) {
setBasicInfo({ setBasicInfo({
name: template.name, name: template.name,
identifierFilter: template.identifierFilter, identifierFilter: template.identifierFilter,
}); });
setType(template.type); setType(template.type);
if (template.type === "OAUTH2") { if (template.type === IdentityProvider_Type.OAUTH2) {
setOAuth2Config(template.config.oauth2Config); const oauth2Config = OAuth2Config.fromPartial(template.config?.oauth2Config || {});
setOAuth2Scopes(template.config.oauth2Config.scopes.join(" ")); setOAuth2Config(oauth2Config);
setOAuth2Scopes(oauth2Config.scopes.join(" "));
} }
} }
}, [selectedTemplate]); }, [selectedTemplate]);
...@@ -174,7 +176,7 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => { ...@@ -174,7 +176,7 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => {
oauth2Config.tokenUrl === "" || oauth2Config.tokenUrl === "" ||
oauth2Config.userInfoUrl === "" || oauth2Config.userInfoUrl === "" ||
oauth2Scopes === "" || oauth2Scopes === "" ||
oauth2Config.fieldMapping.identifier === "" oauth2Config.fieldMapping?.identifier === ""
) { ) {
return false; return false;
} }
...@@ -191,7 +193,8 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => { ...@@ -191,7 +193,8 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => {
const handleConfirmBtnClick = async () => { const handleConfirmBtnClick = async () => {
try { try {
if (isCreating) { if (isCreating) {
await api.createIdentityProvider({ await identityProviderServiceClient.createIdentityProvider({
identityProvider: {
...basicInfo, ...basicInfo,
type: type, type: type,
config: { config: {
...@@ -200,19 +203,22 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => { ...@@ -200,19 +203,22 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => {
scopes: oauth2Scopes.split(" "), scopes: oauth2Scopes.split(" "),
}, },
}, },
},
}); });
toast.success(t("setting.sso-section.sso-created", { name: basicInfo.name })); toast.success(t("setting.sso-section.sso-created", { name: basicInfo.name }));
} else { } else {
await api.patchIdentityProvider({ await identityProviderServiceClient.updateIdentityProvider({
id: identityProvider.id, identityProvider: {
type: type,
...basicInfo, ...basicInfo,
type: type,
config: { config: {
oauth2Config: { oauth2Config: {
...oauth2Config, ...oauth2Config,
scopes: oauth2Scopes.split(" "), scopes: oauth2Scopes.split(" "),
}, },
}, },
},
updateMask: ["title", "identifierFilter", "config"],
}); });
toast.success(t("setting.sso-section.sso-updated", { name: basicInfo.name })); toast.success(t("setting.sso-section.sso-updated", { name: basicInfo.name }));
} }
...@@ -226,7 +232,7 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => { ...@@ -226,7 +232,7 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => {
destroy(); destroy();
}; };
const setPartialOAuth2Config = (state: Partial<IdentityProviderOAuth2Config>) => { const setPartialOAuth2Config = (state: Partial<OAuth2Config>) => {
setOAuth2Config({ setOAuth2Config({
...oauth2Config, ...oauth2Config,
...state, ...state,
...@@ -259,8 +265,8 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => { ...@@ -259,8 +265,8 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => {
</Typography> </Typography>
<Select className="mb-1 h-auto w-full" value={selectedTemplate} onChange={(_, e) => setSelectedTemplate(e ?? selectedTemplate)}> <Select className="mb-1 h-auto w-full" value={selectedTemplate} onChange={(_, e) => setSelectedTemplate(e ?? selectedTemplate)}>
{templateList.map((template) => ( {templateList.map((template) => (
<Option key={template.name} value={template.name}> <Option key={template.title} value={template.title}>
{template.name} {template.title}
</Option> </Option>
))} ))}
</Select> </Select>
...@@ -380,8 +386,10 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => { ...@@ -380,8 +386,10 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => {
<Input <Input
className="mb-2" className="mb-2"
placeholder={t("setting.sso-section.identifier")} placeholder={t("setting.sso-section.identifier")}
value={oauth2Config.fieldMapping.identifier} value={oauth2Config.fieldMapping!.identifier}
onChange={(e) => setPartialOAuth2Config({ fieldMapping: { ...oauth2Config.fieldMapping, identifier: e.target.value } })} onChange={(e) =>
setPartialOAuth2Config({ fieldMapping: { ...oauth2Config.fieldMapping, identifier: e.target.value } as FieldMapping })
}
fullWidth fullWidth
/> />
<Typography className="!mb-1" level="body-md"> <Typography className="!mb-1" level="body-md">
...@@ -390,8 +398,10 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => { ...@@ -390,8 +398,10 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => {
<Input <Input
className="mb-2" className="mb-2"
placeholder={t("setting.sso-section.display-name")} placeholder={t("setting.sso-section.display-name")}
value={oauth2Config.fieldMapping.displayName} value={oauth2Config.fieldMapping!.displayName}
onChange={(e) => setPartialOAuth2Config({ fieldMapping: { ...oauth2Config.fieldMapping, displayName: e.target.value } })} onChange={(e) =>
setPartialOAuth2Config({ fieldMapping: { ...oauth2Config.fieldMapping, displayName: e.target.value } as FieldMapping })
}
fullWidth fullWidth
/> />
<Typography className="!mb-1" level="body-md"> <Typography className="!mb-1" level="body-md">
...@@ -400,8 +410,10 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => { ...@@ -400,8 +410,10 @@ const CreateIdentityProviderDialog: React.FC<Props> = (props: Props) => {
<Input <Input
className="mb-2" className="mb-2"
placeholder={t("common.email")} placeholder={t("common.email")}
value={oauth2Config.fieldMapping.email} value={oauth2Config.fieldMapping!.email}
onChange={(e) => setPartialOAuth2Config({ fieldMapping: { ...oauth2Config.fieldMapping, email: e.target.value } })} onChange={(e) =>
setPartialOAuth2Config({ fieldMapping: { ...oauth2Config.fieldMapping, email: e.target.value } as FieldMapping })
}
fullWidth fullWidth
/> />
</> </>
......
...@@ -2,7 +2,8 @@ import { Button, Divider, Dropdown, List, ListItem, Menu, MenuButton, MenuItem } ...@@ -2,7 +2,8 @@ import { Button, Divider, Dropdown, List, ListItem, Menu, MenuButton, MenuItem }
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { toast } from "react-hot-toast"; import { toast } from "react-hot-toast";
import { Link } from "react-router-dom"; import { Link } from "react-router-dom";
import * as api from "@/helpers/api"; import { identityProviderServiceClient } from "@/grpcweb";
import { IdentityProvider } from "@/types/proto/api/v2/idp_service";
import { useTranslate } from "@/utils/i18n"; import { useTranslate } from "@/utils/i18n";
import showCreateIdentityProviderDialog from "../CreateIdentityProviderDialog"; import showCreateIdentityProviderDialog from "../CreateIdentityProviderDialog";
import { showCommonDialog } from "../Dialog/CommonDialog"; import { showCommonDialog } from "../Dialog/CommonDialog";
...@@ -18,8 +19,8 @@ const SSOSection = () => { ...@@ -18,8 +19,8 @@ const SSOSection = () => {
}, []); }, []);
const fetchIdentityProviderList = async () => { const fetchIdentityProviderList = async () => {
const { data: identityProviderList } = await api.getIdentityProviderList(); const { identityProviders } = await identityProviderServiceClient.listIdentityProviders({});
setIdentityProviderList(identityProviderList); setIdentityProviderList(identityProviders);
}; };
const handleDeleteIdentityProvider = async (identityProvider: IdentityProvider) => { const handleDeleteIdentityProvider = async (identityProvider: IdentityProvider) => {
...@@ -32,10 +33,10 @@ const SSOSection = () => { ...@@ -32,10 +33,10 @@ const SSOSection = () => {
dialogName: "delete-identity-provider-dialog", dialogName: "delete-identity-provider-dialog",
onConfirm: async () => { onConfirm: async () => {
try { try {
await api.deleteIdentityProvider(identityProvider.id); await identityProviderServiceClient.deleteIdentityProvider({ name: identityProvider.name });
} catch (error: any) { } catch (error: any) {
console.error(error); console.error(error);
toast.error(error.response.data.message); toast.error(error.details);
} }
await fetchIdentityProviderList(); await fetchIdentityProviderList();
}, },
...@@ -54,12 +55,12 @@ const SSOSection = () => { ...@@ -54,12 +55,12 @@ const SSOSection = () => {
<Divider /> <Divider />
{identityProviderList.map((identityProvider) => ( {identityProviderList.map((identityProvider) => (
<div <div
key={identityProvider.id} key={identityProvider.name}
className="py-2 w-full border-b last:border-b dark:border-zinc-700 flex flex-row items-center justify-between" className="py-2 w-full border-b last:border-b dark:border-zinc-700 flex flex-row items-center justify-between"
> >
<div className="flex flex-row items-center"> <div className="flex flex-row items-center">
<p className="ml-2"> <p className="ml-2">
{identityProvider.name} {identityProvider.title}
<span className="text-sm ml-1 opacity-40">({identityProvider.type})</span> <span className="text-sm ml-1 opacity-40">({identityProvider.type})</span>
</p> </p>
</div> </div>
......
import { createChannel, createClientFactory, FetchTransport } from "nice-grpc-web"; import { createChannel, createClientFactory, FetchTransport } from "nice-grpc-web";
import { ActivityServiceDefinition } from "./types/proto/api/v2/activity_service"; import { ActivityServiceDefinition } from "./types/proto/api/v2/activity_service";
import { AuthServiceDefinition } from "./types/proto/api/v2/auth_service"; import { AuthServiceDefinition } from "./types/proto/api/v2/auth_service";
import { IdentityProviderServiceDefinition } from "./types/proto/api/v2/idp_service";
import { InboxServiceDefinition } from "./types/proto/api/v2/inbox_service"; import { InboxServiceDefinition } from "./types/proto/api/v2/inbox_service";
import { LinkServiceDefinition } from "./types/proto/api/v2/link_service"; import { LinkServiceDefinition } from "./types/proto/api/v2/link_service";
import { MemoServiceDefinition } from "./types/proto/api/v2/memo_service"; import { MemoServiceDefinition } from "./types/proto/api/v2/memo_service";
...@@ -44,3 +45,5 @@ export const webhookServiceClient = clientFactory.create(WebhookServiceDefinitio ...@@ -44,3 +45,5 @@ export const webhookServiceClient = clientFactory.create(WebhookServiceDefinitio
export const linkServiceClient = clientFactory.create(LinkServiceDefinition, channel); export const linkServiceClient = clientFactory.create(LinkServiceDefinition, channel);
export const storageServiceClient = clientFactory.create(StorageServiceDefinition, channel); export const storageServiceClient = clientFactory.create(StorageServiceDefinition, channel);
export const identityProviderServiceClient = clientFactory.create(IdentityProviderServiceDefinition, channel);
import axios from "axios";
axios.defaults.baseURL = import.meta.env.VITE_API_BASE_URL || "";
axios.defaults.withCredentials = true;
export function getIdentityProviderList() {
return axios.get<IdentityProvider[]>(`/api/v1/idp`);
}
export function createIdentityProvider(identityProviderCreate: IdentityProviderCreate) {
return axios.post<IdentityProvider>(`/api/v1/idp`, identityProviderCreate);
}
export function patchIdentityProvider(identityProviderPatch: IdentityProviderPatch) {
return axios.patch<IdentityProvider>(`/api/v1/idp/${identityProviderPatch.id}`, identityProviderPatch);
}
export function deleteIdentityProvider(id: IdentityProviderId) {
return axios.delete(`/api/v1/idp/${id}`);
}
...@@ -5,13 +5,13 @@ import { toast } from "react-hot-toast"; ...@@ -5,13 +5,13 @@ import { toast } from "react-hot-toast";
import { Link } from "react-router-dom"; import { Link } from "react-router-dom";
import AppearanceSelect from "@/components/AppearanceSelect"; import AppearanceSelect from "@/components/AppearanceSelect";
import LocaleSelect from "@/components/LocaleSelect"; import LocaleSelect from "@/components/LocaleSelect";
import { authServiceClient } from "@/grpcweb"; import { authServiceClient, identityProviderServiceClient } from "@/grpcweb";
import * as api from "@/helpers/api";
import { absolutifyLink } from "@/helpers/utils"; import { absolutifyLink } from "@/helpers/utils";
import useLoading from "@/hooks/useLoading"; import useLoading from "@/hooks/useLoading";
import useNavigateTo from "@/hooks/useNavigateTo"; import useNavigateTo from "@/hooks/useNavigateTo";
import { useCommonContext } from "@/layouts/CommonContextProvider"; import { useCommonContext } from "@/layouts/CommonContextProvider";
import { useUserStore, useWorkspaceSettingStore } from "@/store/v1"; import { extractIdentityProviderIdFromName, useUserStore, useWorkspaceSettingStore } from "@/store/v1";
import { IdentityProvider, IdentityProvider_Type } from "@/types/proto/api/v2/idp_service";
import { WorkspaceGeneralSetting } from "@/types/proto/api/v2/workspace_setting_service"; import { WorkspaceGeneralSetting } from "@/types/proto/api/v2/workspace_setting_service";
import { WorkspaceSettingKey } from "@/types/proto/store/workspace_setting"; import { WorkspaceSettingKey } from "@/types/proto/store/workspace_setting";
import { useTranslate } from "@/utils/i18n"; import { useTranslate } from "@/utils/i18n";
...@@ -33,8 +33,8 @@ const SignIn = () => { ...@@ -33,8 +33,8 @@ const SignIn = () => {
useEffect(() => { useEffect(() => {
const fetchIdentityProviderList = async () => { const fetchIdentityProviderList = async () => {
const { data: identityProviderList } = await api.getIdentityProviderList(); const { identityProviders } = await identityProviderServiceClient.listIdentityProviders({});
setIdentityProviderList(identityProviderList); setIdentityProviderList(identityProviders);
}; };
fetchIdentityProviderList(); fetchIdentityProviderList();
}, []); }, []);
...@@ -95,10 +95,14 @@ const SignIn = () => { ...@@ -95,10 +95,14 @@ const SignIn = () => {
}; };
const handleSignInWithIdentityProvider = async (identityProvider: IdentityProvider) => { const handleSignInWithIdentityProvider = async (identityProvider: IdentityProvider) => {
const stateQueryParameter = `auth.signin.${identityProvider.name}-${identityProvider.id}`; const stateQueryParameter = `auth.signin.${identityProvider.title}-${extractIdentityProviderIdFromName(identityProvider.name)}`;
if (identityProvider.type === "OAUTH2") { if (identityProvider.type === IdentityProvider_Type.OAUTH2) {
const redirectUri = absolutifyLink("/auth/callback"); const redirectUri = absolutifyLink("/auth/callback");
const oauth2Config = identityProvider.config.oauth2Config; const oauth2Config = identityProvider.config?.oauth2Config;
if (!oauth2Config) {
toast.error("Identity provider configuration is invalid.");
return;
}
const authUrl = `${oauth2Config.authUrl}?client_id=${ const authUrl = `${oauth2Config.authUrl}?client_id=${
oauth2Config.clientId oauth2Config.clientId
}&redirect_uri=${redirectUri}&state=${stateQueryParameter}&response_type=code&scope=${encodeURIComponent( }&redirect_uri=${redirectUri}&state=${stateQueryParameter}&response_type=code&scope=${encodeURIComponent(
...@@ -185,7 +189,7 @@ const SignIn = () => { ...@@ -185,7 +189,7 @@ const SignIn = () => {
<div className="w-full flex flex-col space-y-2"> <div className="w-full flex flex-col space-y-2">
{identityProviderList.map((identityProvider) => ( {identityProviderList.map((identityProvider) => (
<Button <Button
key={identityProvider.id} key={identityProvider.name}
variant="outlined" variant="outlined"
color="neutral" color="neutral"
className="w-full" className="w-full"
......
export const WorkspaceSettingPrefix = "settings/"; export const WorkspaceSettingPrefix = "settings/";
export const UserNamePrefix = "users/"; export const UserNamePrefix = "users/";
export const MemoNamePrefix = "memos/"; export const MemoNamePrefix = "memos/";
export const IdentityProviderNamePrefix = "identityProviders/";
export const extractMemoIdFromName = (name: string) => { export const extractMemoIdFromName = (name: string) => {
return parseInt(name.split(MemoNamePrefix).pop() || "", 10); return parseInt(name.split(MemoNamePrefix).pop() || "", 10);
}; };
export const extractIdentityProviderIdFromName = (name: string) => {
return parseInt(name.split(IdentityProviderNamePrefix).pop() || "", 10);
};
type IdentityProviderId = number;
type IdentityProviderType = "OAUTH2";
interface FieldMapping {
identifier: string;
displayName: string;
email: string;
}
interface IdentityProviderOAuth2Config {
clientId: string;
clientSecret: string;
authUrl: string;
tokenUrl: string;
userInfoUrl: string;
scopes: string[];
fieldMapping: FieldMapping;
}
interface IdentityProviderConfig {
oauth2Config: IdentityProviderOAuth2Config;
}
interface IdentityProvider {
id: IdentityProviderId;
name: string;
type: IdentityProviderType;
identifierFilter: string;
config: IdentityProviderConfig;
}
interface IdentityProviderCreate {
name: string;
type: IdentityProviderType;
identifierFilter: string;
config: IdentityProviderConfig;
}
interface IdentityProviderPatch {
id: IdentityProviderId;
type: IdentityProviderType;
name?: string;
identifierFilter?: string;
config?: IdentityProviderConfig;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment