Unverified Commit b34aded3 authored by boojack's avatar boojack Committed by GitHub

refactor: migration idp api (#1842)

* refactor: migration idp api

* chore: update
parent 4ed9a3a0
package api
type IdentityProviderType string
const (
IdentityProviderOAuth2 IdentityProviderType = "OAUTH2"
)
type IdentityProviderConfig struct {
OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"`
}
type IdentityProviderOAuth2Config struct {
ClientID string `json:"clientId"`
ClientSecret string `json:"clientSecret"`
AuthURL string `json:"authUrl"`
TokenURL string `json:"tokenUrl"`
UserInfoURL string `json:"userInfoUrl"`
Scopes []string `json:"scopes"`
FieldMapping *FieldMapping `json:"fieldMapping"`
}
type FieldMapping struct {
Identifier string `json:"identifier"`
DisplayName string `json:"displayName"`
Email string `json:"email"`
}
type IdentityProvider struct {
ID int `json:"id"`
Name string `json:"name"`
Type IdentityProviderType `json:"type"`
IdentifierFilter string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
type IdentityProviderCreate struct {
Name string `json:"name"`
Type IdentityProviderType `json:"type"`
IdentifierFilter string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
type IdentityProviderFind struct {
ID *int
}
type IdentityProviderPatch struct {
ID int
Type IdentityProviderType `json:"type"`
Name *string `json:"name"`
IdentifierFilter *string `json:"identifierFilter"`
Config *IdentityProviderConfig `json:"config"`
}
type IdentityProviderDelete struct {
ID int
}
package v1
// UnknownID is the ID for unknowns.
const UnknownID = -1
// RowStatus is the status for a row.
type RowStatus string
const (
// Normal is the status for a normal row.
Normal RowStatus = "NORMAL"
// Archived is the status for an archived row.
Archived RowStatus = "ARCHIVED"
)
func (e RowStatus) String() string {
switch e {
case Normal:
return "NORMAL"
case Archived:
return "ARCHIVED"
}
return ""
}
This diff is collapsed.
package v1
// Role is the type of a role.
type Role string
const (
// Host is the HOST role.
Host Role = "HOST"
// Admin is the ADMIN role.
Admin Role = "ADMIN"
// NormalUser is the USER role.
NormalUser Role = "USER"
)
func (e Role) String() string {
switch e {
case Host:
return "HOST"
case Admin:
return "ADMIN"
case NormalUser:
return "USER"
}
return "USER"
}
......@@ -6,6 +6,17 @@ import (
"github.com/usememos/memos/store"
)
const (
// Context section
// The key name used to store user id in the context
// user id is extracted from the jwt token subject field.
userIDContextKey = "user-id"
)
func getUserIDContextKey() string {
return userIDContextKey
}
type APIV1Service struct {
Secret string
Profile *profile.Profile
......@@ -20,8 +31,8 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store
}
}
func (s *APIV1Service) Register(e *echo.Echo) {
apiV1Group := e.Group("/api/v1")
func (s *APIV1Service) Register(apiV1Group *echo.Group) {
s.registerTestRoutes(apiV1Group)
s.registerAuthRoutes(apiV1Group, s.Secret)
s.registerIdentityProviderRoutes(apiV1Group)
}
......@@ -22,6 +22,10 @@ const (
userIDContextKey = "user-id"
)
func getUserIDContextKey() string {
return userIDContextKey
}
// Claims creates a struct that will be encoded to a JWT.
// We add jwt.RegisteredClaims as an embedded type, to provide fields such as name.
type Claims struct {
......@@ -29,10 +33,6 @@ type Claims struct {
jwt.RegisteredClaims
}
func getUserIDContextKey() string {
return userIDContextKey
}
func extractTokenFromHeader(c echo.Context) (string, error) {
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
......@@ -82,7 +82,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
}
// Skip validation for server status endpoints.
if common.HasPrefixes(path, "/api/ping", "/api/idp", "/api/user/:id") && method == http.MethodGet {
if common.HasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet {
return next(c)
}
......@@ -111,15 +111,9 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
return echo.NewHTTPError(http.StatusUnauthorized,
fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment",
claims.Audience,
auth.AccessTokenAudienceName,
))
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName))
}
generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration
if err != nil {
var ve *jwt.ValidationError
......@@ -130,11 +124,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
generateToken = true
}
} else {
return &echo.HTTPError{
Code: http.StatusUnauthorized,
Message: "Invalid or expired access token",
Internal: err,
}
return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token"))
}
}
......
......@@ -105,12 +105,15 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
s.registerResourceRoutes(apiGroup)
s.registerTagRoutes(apiGroup)
s.registerStorageRoutes(apiGroup)
s.registerIdentityProviderRoutes(apiGroup)
s.registerOpenAIRoutes(apiGroup)
s.registerMemoRelationRoutes(apiGroup)
apiV1Group := e.Group("/api/v1")
apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, s.Secret)
})
apiV1Service := apiV1.NewAPIV1Service(s.Secret, profile, store)
apiV1Service.Register(e)
apiV1Service.Register(apiV1Group)
return s, nil
}
......
......@@ -27,6 +27,211 @@ func (s *Store) SeedDataForNewUser(ctx context.Context, user *api.User) error {
return err
}
// Role is the type of a role.
type Role string
const (
// Host is the HOST role.
Host Role = "HOST"
// Admin is the ADMIN role.
Admin Role = "ADMIN"
// NormalUser is the USER role.
NormalUser Role = "USER"
)
func (e Role) String() string {
switch e {
case Host:
return "HOST"
case Admin:
return "ADMIN"
case NormalUser:
return "USER"
}
return "USER"
}
type UserMessage struct {
ID int
// Standard fields
RowStatus RowStatus
CreatedTs int64
UpdatedTs int64
// Domain specific fields
Username string
Role Role
Email string
Nickname string
PasswordHash string
OpenID string
AvatarURL string
}
type FindUserMessage struct {
ID *int
// Standard fields
RowStatus *RowStatus
// Domain specific fields
Username *string
Role *Role
Email *string
Nickname *string
OpenID *string
}
func (s *Store) CreateUserV1(ctx context.Context, create *UserMessage) (*UserMessage, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
query := `
INSERT INTO user (
username,
role,
email,
nickname,
password_hash,
open_id
)
VALUES (?, ?, ?, ?, ?, ?)
RETURNING id, avatar_url, created_ts, updated_ts, row_status
`
if err := tx.QueryRowContext(ctx, query,
create.Username,
create.Role,
create.Email,
create.Nickname,
create.PasswordHash,
create.OpenID,
).Scan(
&create.ID,
&create.AvatarURL,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, FormatError(err)
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
userMessage := create
return userMessage, nil
}
func (s *Store) ListUsers(ctx context.Context, find *FindUserMessage) ([]*UserMessage, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := listUsers(ctx, tx, find)
if err != nil {
return nil, err
}
return list, nil
}
func (s *Store) GetUser(ctx context.Context, find *FindUserMessage) (*UserMessage, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
list, err := listUsers(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user not found")}
}
memoMessage := list[0]
return memoMessage, nil
}
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUserMessage) ([]*UserMessage, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.Username; v != nil {
where, args = append(where, "username = ?"), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = ?"), append(args, *v)
}
if v := find.Email; v != nil {
where, args = append(where, "email = ?"), append(args, *v)
}
if v := find.Nickname; v != nil {
where, args = append(where, "nickname = ?"), append(args, *v)
}
if v := find.OpenID; v != nil {
where, args = append(where, "open_id = ?"), append(args, *v)
}
query := `
SELECT
id,
username,
role,
email,
nickname,
password_hash,
open_id,
avatar_url,
created_ts,
updated_ts,
row_status
FROM user
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC, row_status DESC
`
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, FormatError(err)
}
defer rows.Close()
userMessageList := make([]*UserMessage, 0)
for rows.Next() {
var userMessage UserMessage
if err := rows.Scan(
&userMessage.ID,
&userMessage.Username,
&userMessage.Role,
&userMessage.Email,
&userMessage.Nickname,
&userMessage.PasswordHash,
&userMessage.OpenID,
&userMessage.AvatarURL,
&userMessage.CreatedTs,
&userMessage.UpdatedTs,
&userMessage.RowStatus,
); err != nil {
return nil, FormatError(err)
}
userMessageList = append(userMessageList, &userMessage)
}
if err := rows.Err(); err != nil {
return nil, FormatError(err)
}
return userMessageList, nil
}
// userRaw is the store model for an User.
// Fields have exactly the same meanings as User.
type userRaw struct {
......
package teststore
import (
"context"
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/store"
)
func TestIdentityProviderStore(t *testing.T) {
ctx := context.Background()
ts := NewTestingStore(ctx, t)
createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{
Name: "GitHub OAuth",
Type: store.IdentityProviderOAuth2,
IdentifierFilter: "",
Config: &store.IdentityProviderConfig{
OAuth2Config: &store.IdentityProviderOAuth2Config{
ClientID: "asd",
ClientSecret: "123",
AuthURL: "https://github.com/auth",
TokenURL: "https://github.com/token",
UserInfoURL: "https://github.com/user",
Scopes: []string{"login"},
FieldMapping: &store.FieldMapping{
Identifier: "login",
DisplayName: "name",
Email: "emai",
},
},
},
})
require.NoError(t, err)
idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{
ID: &createdIDP.ID,
})
require.NoError(t, err)
require.Equal(t, createdIDP, idp)
err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{
ID: idp.ID,
})
require.NoError(t, err)
idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{})
require.NoError(t, err)
require.Equal(t, 0, len(idpList))
}
......@@ -17,9 +17,7 @@ const SSOSection = () => {
}, []);
const fetchIdentityProviderList = async () => {
const {
data: { data: identityProviderList },
} = await api.getIdentityProviderList();
const { data: identityProviderList } = await api.getIdentityProviderList();
setIdentityProviderList(identityProviderList);
};
......
......@@ -246,19 +246,19 @@ export function deleteStorage(storageId: StorageId) {
}
export function getIdentityProviderList() {
return axios.get<ResponseObject<IdentityProvider[]>>(`/api/idp`);
return axios.get<IdentityProvider[]>(`/api/v1/idp`);
}
export function createIdentityProvider(identityProviderCreate: IdentityProviderCreate) {
return axios.post<ResponseObject<IdentityProvider>>(`/api/idp`, identityProviderCreate);
return axios.post<IdentityProvider>(`/api/v1/idp`, identityProviderCreate);
}
export function patchIdentityProvider(identityProviderPatch: IdentityProviderPatch) {
return axios.patch<ResponseObject<IdentityProvider>>(`/api/idp/${identityProviderPatch.id}`, identityProviderPatch);
return axios.patch<IdentityProvider>(`/api/v1/idp/${identityProviderPatch.id}`, identityProviderPatch);
}
export function deleteIdentityProvider(id: IdentityProviderId) {
return axios.delete(`/api/idp/${id}`);
return axios.delete(`/api/v1/idp/${id}`);
}
export async function getRepoStarCount() {
......
......@@ -24,9 +24,7 @@ const Auth = () => {
useEffect(() => {
userStore.doSignOut().catch();
const fetchIdentityProviderList = async () => {
const {
data: { data: identityProviderList },
} = await api.getIdentityProviderList();
const { data: identityProviderList } = await api.getIdentityProviderList();
setIdentityProviderList(identityProviderList);
};
fetchIdentityProviderList();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment