Unverified Commit 66e65e4d authored by boojack's avatar boojack Committed by GitHub

refactor: migrate definition to api v1 (#1879)

* refactor: user api v1

* refactor: system setting to apiv1

* chore: remove unused definition

* chore: update

* chore: refactor: system setting

* chore: update

* refactor: migrate tag

* feat: migrate activity store

* refactor: migrate shortcut apiv1

* chore: update
parent b84ecc45
package api
import "github.com/usememos/memos/server/profile"
// ActivityType is the type for an activity.
type ActivityType string
const (
// User related.
// ActivityUserCreate is the type for creating users.
ActivityUserCreate ActivityType = "user.create"
// ActivityUserUpdate is the type for updating users.
ActivityUserUpdate ActivityType = "user.update"
// ActivityUserDelete is the type for deleting users.
ActivityUserDelete ActivityType = "user.delete"
// ActivityUserAuthSignIn is the type for user signin.
ActivityUserAuthSignIn ActivityType = "user.auth.signin"
// ActivityUserAuthSignUp is the type for user signup.
ActivityUserAuthSignUp ActivityType = "user.auth.signup"
// ActivityUserSettingUpdate is the type for updating user settings.
ActivityUserSettingUpdate ActivityType = "user.setting.update"
// Memo related.
// ActivityMemoCreate is the type for creating memos.
ActivityMemoCreate ActivityType = "memo.create"
// ActivityMemoUpdate is the type for updating memos.
ActivityMemoUpdate ActivityType = "memo.update"
// ActivityMemoDelete is the type for deleting memos.
ActivityMemoDelete ActivityType = "memo.delete"
// Shortcut related.
// ActivityShortcutCreate is the type for creating shortcuts.
ActivityShortcutCreate ActivityType = "shortcut.create"
// ActivityShortcutUpdate is the type for updating shortcuts.
ActivityShortcutUpdate ActivityType = "shortcut.update"
// ActivityShortcutDelete is the type for deleting shortcuts.
ActivityShortcutDelete ActivityType = "shortcut.delete"
// Resource related.
// ActivityResourceCreate is the type for creating resources.
ActivityResourceCreate ActivityType = "resource.create"
// ActivityResourceDelete is the type for deleting resources.
ActivityResourceDelete ActivityType = "resource.delete"
// Tag related.
// ActivityTagCreate is the type for creating tags.
ActivityTagCreate ActivityType = "tag.create"
// ActivityTagDelete is the type for deleting tags.
ActivityTagDelete ActivityType = "tag.delete"
// Server related.
// ActivityServerStart is the type for starting server.
ActivityServerStart ActivityType = "server.start"
)
// ActivityLevel is the level of activities.
type ActivityLevel string
const (
// ActivityInfo is the INFO level of activities.
ActivityInfo ActivityLevel = "INFO"
// ActivityWarn is the WARN level of activities.
ActivityWarn ActivityLevel = "WARN"
// ActivityError is the ERROR level of activities.
ActivityError ActivityLevel = "ERROR"
)
type ActivityUserCreatePayload struct {
UserID int `json:"userId"`
Username string `json:"username"`
Role Role `json:"role"`
}
type ActivityUserAuthSignInPayload struct {
UserID int `json:"userId"`
IP string `json:"ip"`
}
type ActivityUserAuthSignUpPayload struct {
Username string `json:"username"`
IP string `json:"ip"`
}
type ActivityMemoCreatePayload struct {
Content string `json:"content"`
Visibility string `json:"visibility"`
}
type ActivityShortcutCreatePayload struct {
Title string `json:"title"`
Payload string `json:"payload"`
}
type ActivityResourceCreatePayload struct {
Filename string `json:"filename"`
Type string `json:"type"`
Size int64 `json:"size"`
}
type ActivityTagCreatePayload struct {
TagName string `json:"tagName"`
}
type ActivityServerStartPayload struct {
ServerID string `json:"serverId"`
Profile *profile.Profile `json:"profile"`
}
type Activity struct {
ID int `json:"id"`
// Standard fields
CreatorID int `json:"creatorId"`
CreatedTs int64 `json:"createdTs"`
// Domain specific fields
Type ActivityType `json:"type"`
Level ActivityLevel `json:"level"`
Payload string `json:"payload"`
}
// ActivityCreate is the API message for creating an activity.
type ActivityCreate struct {
// Standard fields
CreatorID int
// Domain specific fields
Type ActivityType `json:"type"`
Level ActivityLevel
Payload string `json:"payload"`
}
package api
type Shortcut struct {
ID int `json:"id"`
// Standard fields
RowStatus RowStatus `json:"rowStatus"`
CreatorID int `json:"creatorId"`
CreatedTs int64 `json:"createdTs"`
UpdatedTs int64 `json:"updatedTs"`
// Domain specific fields
Title string `json:"title"`
Payload string `json:"payload"`
}
type ShortcutCreate struct {
// Standard fields
CreatorID int `json:"-"`
// Domain specific fields
Title string `json:"title"`
Payload string `json:"payload"`
}
type ShortcutPatch struct {
ID int `json:"-"`
// Standard fields
UpdatedTs *int64
RowStatus *RowStatus `json:"rowStatus"`
// Domain specific fields
Title *string `json:"title"`
Payload *string `json:"payload"`
}
type ShortcutFind struct {
ID *int
// Standard fields
CreatorID *int
// Domain specific fields
Title *string `json:"title"`
}
type ShortcutDelete struct {
ID *int
// Standard fields
CreatorID *int
}
package api
import "github.com/usememos/memos/server/profile"
type SystemStatus struct {
Host *User `json:"host"`
Profile profile.Profile `json:"profile"`
DBSize int64 `json:"dbSize"`
// System settings
// Allow sign up.
AllowSignUp bool `json:"allowSignUp"`
// Disable public memos.
DisablePublicMemos bool `json:"disablePublicMemos"`
// Max upload size.
MaxUploadSizeMiB int `json:"maxUploadSizeMiB"`
// Additional style.
AdditionalStyle string `json:"additionalStyle"`
// Additional script.
AdditionalScript string `json:"additionalScript"`
// Customized server profile, including server name and external url.
CustomizedProfile CustomizedProfile `json:"customizedProfile"`
// Storage service ID.
StorageServiceID int `json:"storageServiceId"`
// Local storage path.
LocalStoragePath string `json:"localStoragePath"`
// Memo display with updated timestamp.
MemoDisplayWithUpdatedTs bool `json:"memoDisplayWithUpdatedTs"`
}
package api
import (
"encoding/json"
"fmt"
"strings"
"golang.org/x/exp/slices"
)
type SystemSettingName string
const (
// SystemSettingServerIDName is the name of server id.
SystemSettingServerIDName SystemSettingName = "server-id"
// SystemSettingSecretSessionName is the name of secret session.
SystemSettingSecretSessionName SystemSettingName = "secret-session"
// SystemSettingAllowSignUpName is the name of allow signup setting.
SystemSettingAllowSignUpName SystemSettingName = "allow-signup"
// SystemSettingDisablePublicMemosName is the name of disable public memos setting.
SystemSettingDisablePublicMemosName SystemSettingName = "disable-public-memos"
// SystemSettingMaxUploadSizeMiBName is the name of max upload size setting.
SystemSettingMaxUploadSizeMiBName SystemSettingName = "max-upload-size-mib"
// SystemSettingAdditionalStyleName is the name of additional style.
SystemSettingAdditionalStyleName SystemSettingName = "additional-style"
// SystemSettingAdditionalScriptName is the name of additional script.
SystemSettingAdditionalScriptName SystemSettingName = "additional-script"
// SystemSettingCustomizedProfileName is the name of customized server profile.
SystemSettingCustomizedProfileName SystemSettingName = "customized-profile"
// SystemSettingStorageServiceIDName is the name of storage service ID.
SystemSettingStorageServiceIDName SystemSettingName = "storage-service-id"
// SystemSettingLocalStoragePathName is the name of local storage path.
SystemSettingLocalStoragePathName SystemSettingName = "local-storage-path"
// SystemSettingOpenAIConfigName is the name of OpenAI config.
SystemSettingOpenAIConfigName SystemSettingName = "openai-config"
// SystemSettingTelegramBotToken is the name of Telegram Bot Token.
SystemSettingTelegramBotTokenName SystemSettingName = "telegram-bot-token"
SystemSettingMemoDisplayWithUpdatedTsName SystemSettingName = "memo-display-with-updated-ts"
)
// CustomizedProfile is the struct definition for SystemSettingCustomizedProfileName system setting item.
type CustomizedProfile struct {
// Name is the server name, default is `memos`
Name string `json:"name"`
// LogoURL is the url of logo image.
LogoURL string `json:"logoUrl"`
// Description is the server description.
Description string `json:"description"`
// Locale is the server default locale.
Locale string `json:"locale"`
// Appearance is the server default appearance.
Appearance string `json:"appearance"`
// ExternalURL is the external url of server. e.g. https://usermemos.com
ExternalURL string `json:"externalUrl"`
}
type OpenAIConfig struct {
Key string `json:"key"`
Host string `json:"host"`
}
func (key SystemSettingName) String() string {
switch key {
case SystemSettingServerIDName:
return "server-id"
case SystemSettingSecretSessionName:
return "secret-session"
case SystemSettingAllowSignUpName:
return "allow-signup"
case SystemSettingDisablePublicMemosName:
return "disable-public-memos"
case SystemSettingMaxUploadSizeMiBName:
return "max-upload-size-mib"
case SystemSettingAdditionalStyleName:
return "additional-style"
case SystemSettingAdditionalScriptName:
return "additional-script"
case SystemSettingCustomizedProfileName:
return "customized-profile"
case SystemSettingStorageServiceIDName:
return "storage-service-id"
case SystemSettingLocalStoragePathName:
return "local-storage-path"
case SystemSettingOpenAIConfigName:
return "openai-config"
case SystemSettingTelegramBotTokenName:
return "telegram-bot-token"
case SystemSettingMemoDisplayWithUpdatedTsName:
return "memo-display-with-updated-ts"
}
return ""
}
type SystemSetting struct {
Name SystemSettingName `json:"name"`
// Value is a JSON string with basic value.
Value string `json:"value"`
Description string `json:"description"`
}
type SystemSettingUpsert struct {
Name SystemSettingName `json:"name"`
Value string `json:"value"`
Description string `json:"description"`
}
const systemSettingUnmarshalError = `failed to unmarshal value from system setting "%v"`
func (upsert SystemSettingUpsert) Validate() error {
switch settingName := upsert.Name; settingName {
case SystemSettingServerIDName:
return fmt.Errorf("updating %v is not allowed", settingName)
case SystemSettingAllowSignUpName:
var value bool
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingDisablePublicMemosName:
var value bool
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingMaxUploadSizeMiBName:
var value int
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingAdditionalStyleName:
var value string
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingAdditionalScriptName:
var value string
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingCustomizedProfileName:
customizedProfile := CustomizedProfile{
Name: "memos",
LogoURL: "",
Description: "",
Locale: "en",
Appearance: "system",
ExternalURL: "",
}
if err := json.Unmarshal([]byte(upsert.Value), &customizedProfile); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
if !slices.Contains(UserSettingLocaleValue, customizedProfile.Locale) {
return fmt.Errorf(`invalid locale value for system setting "%v"`, settingName)
}
if !slices.Contains(UserSettingAppearanceValue, customizedProfile.Appearance) {
return fmt.Errorf(`invalid appearance value for system setting "%v"`, settingName)
}
case SystemSettingStorageServiceIDName:
value := DatabaseStorage
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
return nil
case SystemSettingLocalStoragePathName:
value := ""
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingOpenAIConfigName:
value := OpenAIConfig{}
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingTelegramBotTokenName:
if upsert.Value == "" {
return nil
}
// Bot Token with Reverse Proxy shoule like `http.../bot<token>`
if strings.HasPrefix(upsert.Value, "http") {
slashIndex := strings.LastIndexAny(upsert.Value, "/")
if strings.HasPrefix(upsert.Value[slashIndex:], "/bot") {
return nil
}
return fmt.Errorf("token start with `http` must end with `/bot<token>`")
}
fragments := strings.Split(upsert.Value, ":")
if len(fragments) != 2 {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingMemoDisplayWithUpdatedTsName:
var value bool
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
default:
return fmt.Errorf("invalid system setting name")
}
return nil
}
type SystemSettingFind struct {
Name SystemSettingName `json:"name"`
}
package api
type Tag struct {
Name string
CreatorID int
}
type TagUpsert struct {
Name string
CreatorID int `json:"-"`
}
type TagFind struct {
CreatorID int
}
type TagDelete struct {
Name string `json:"name"`
CreatorID int
}
package api
import (
"fmt"
"github.com/usememos/memos/common"
)
// Role is the type of a role.
type Role string
const (
// Host is the HOST role.
Host Role = "HOST"
// Admin is the ADMIN role.
Admin Role = "ADMIN"
// NormalUser is the USER role.
NormalUser Role = "USER"
)
func (e Role) String() string {
switch e {
case Host:
return "HOST"
case Admin:
return "ADMIN"
case NormalUser:
return "USER"
}
return "USER"
}
type User struct {
ID int `json:"id"`
// Standard fields
RowStatus RowStatus `json:"rowStatus"`
CreatedTs int64 `json:"createdTs"`
UpdatedTs int64 `json:"updatedTs"`
// Domain specific fields
Username string `json:"username"`
Role Role `json:"role"`
Email string `json:"email"`
Nickname string `json:"nickname"`
PasswordHash string `json:"-"`
OpenID string `json:"openId"`
AvatarURL string `json:"avatarUrl"`
UserSettingList []*UserSetting `json:"userSettingList"`
}
type UserCreate struct {
// Domain specific fields
Username string `json:"username"`
Role Role `json:"role"`
Email string `json:"email"`
Nickname string `json:"nickname"`
Password string `json:"password"`
PasswordHash string
OpenID string
}
func (create UserCreate) Validate() error {
if len(create.Username) < 3 {
return fmt.Errorf("username is too short, minimum length is 3")
}
if len(create.Username) > 32 {
return fmt.Errorf("username is too long, maximum length is 32")
}
if len(create.Password) < 3 {
return fmt.Errorf("password is too short, minimum length is 3")
}
if len(create.Password) > 512 {
return fmt.Errorf("password is too long, maximum length is 512")
}
if len(create.Nickname) > 64 {
return fmt.Errorf("nickname is too long, maximum length is 64")
}
if create.Email != "" {
if len(create.Email) > 256 {
return fmt.Errorf("email is too long, maximum length is 256")
}
if !common.ValidateEmail(create.Email) {
return fmt.Errorf("invalid email format")
}
}
return nil
}
type UserPatch struct {
ID int `json:"-"`
// Standard fields
UpdatedTs *int64
RowStatus *RowStatus `json:"rowStatus"`
// Domain specific fields
Username *string `json:"username"`
Email *string `json:"email"`
Nickname *string `json:"nickname"`
Password *string `json:"password"`
ResetOpenID *bool `json:"resetOpenId"`
AvatarURL *string `json:"avatarUrl"`
PasswordHash *string
OpenID *string
}
func (patch UserPatch) Validate() error {
if patch.Username != nil && len(*patch.Username) < 3 {
return fmt.Errorf("username is too short, minimum length is 3")
}
if patch.Username != nil && len(*patch.Username) > 32 {
return fmt.Errorf("username is too long, maximum length is 32")
}
if patch.Password != nil && len(*patch.Password) < 3 {
return fmt.Errorf("password is too short, minimum length is 3")
}
if patch.Password != nil && len(*patch.Password) > 512 {
return fmt.Errorf("password is too long, maximum length is 512")
}
if patch.Nickname != nil && len(*patch.Nickname) > 64 {
return fmt.Errorf("nickname is too long, maximum length is 64")
}
if patch.AvatarURL != nil {
if len(*patch.AvatarURL) > 2<<20 {
return fmt.Errorf("avatar is too large, maximum is 2MB")
}
}
if patch.Email != nil && *patch.Email != "" {
if len(*patch.Email) > 256 {
return fmt.Errorf("email is too long, maximum length is 256")
}
if !common.ValidateEmail(*patch.Email) {
return fmt.Errorf("invalid email format")
}
}
return nil
}
type UserFind struct {
ID *int `json:"id"`
// Standard fields
RowStatus *RowStatus `json:"rowStatus"`
// Domain specific fields
Username *string `json:"username"`
Role *Role
Email *string `json:"email"`
Nickname *string `json:"nickname"`
OpenID *string
}
type UserDelete struct {
ID int
}
package api
import (
"encoding/json"
"fmt"
"strconv"
"golang.org/x/exp/slices"
)
type UserSettingKey string
const (
// UserSettingLocaleKey is the key type for user locale.
UserSettingLocaleKey UserSettingKey = "locale"
// UserSettingAppearanceKey is the key type for user appearance.
UserSettingAppearanceKey UserSettingKey = "appearance"
// UserSettingMemoVisibilityKey is the key type for user preference memo default visibility.
UserSettingMemoVisibilityKey UserSettingKey = "memo-visibility"
// UserSettingTelegramUserID is the key type for telegram UserID of memos user.
UserSettingTelegramUserIDKey UserSettingKey = "telegram-user-id"
)
// String returns the string format of UserSettingKey type.
func (key UserSettingKey) String() string {
switch key {
case UserSettingLocaleKey:
return "locale"
case UserSettingAppearanceKey:
return "appearance"
case UserSettingMemoVisibilityKey:
return "memo-visibility"
case UserSettingTelegramUserIDKey:
return "telegram-user-id"
}
return ""
}
var (
UserSettingLocaleValue = []string{
"de",
"en",
"es",
"fr",
"hr",
"it",
"ja",
"ko",
"nl",
"pl",
"pt-BR",
"ru",
"sl",
"sv",
"tr",
"uk",
"vi",
"zh-Hans",
"zh-Hant",
}
UserSettingAppearanceValue = []string{"system", "light", "dark"}
UserSettingMemoVisibilityValue = []Visibility{Private, Protected, Public}
)
type UserSetting struct {
UserID int
Key UserSettingKey `json:"key"`
// Value is a JSON string with basic value
Value string `json:"value"`
}
type UserSettingUpsert struct {
UserID int `json:"-"`
Key UserSettingKey `json:"key"`
Value string `json:"value"`
}
func (upsert UserSettingUpsert) Validate() error {
if upsert.Key == UserSettingLocaleKey {
localeValue := "en"
err := json.Unmarshal([]byte(upsert.Value), &localeValue)
if err != nil {
return fmt.Errorf("failed to unmarshal user setting locale value")
}
if !slices.Contains(UserSettingLocaleValue, localeValue) {
return fmt.Errorf("invalid user setting locale value")
}
} else if upsert.Key == UserSettingAppearanceKey {
appearanceValue := "system"
err := json.Unmarshal([]byte(upsert.Value), &appearanceValue)
if err != nil {
return fmt.Errorf("failed to unmarshal user setting appearance value")
}
if !slices.Contains(UserSettingAppearanceValue, appearanceValue) {
return fmt.Errorf("invalid user setting appearance value")
}
} else if upsert.Key == UserSettingMemoVisibilityKey {
memoVisibilityValue := Private
err := json.Unmarshal([]byte(upsert.Value), &memoVisibilityValue)
if err != nil {
return fmt.Errorf("failed to unmarshal user setting memo visibility value")
}
if !slices.Contains(UserSettingMemoVisibilityValue, memoVisibilityValue) {
return fmt.Errorf("invalid user setting memo visibility value")
}
} else if upsert.Key == UserSettingTelegramUserIDKey {
var s string
err := json.Unmarshal([]byte(upsert.Value), &s)
if err != nil {
return fmt.Errorf("invalid user setting telegram user id value")
}
if s == "" {
return nil
}
if _, err := strconv.Atoi(s); err != nil {
return fmt.Errorf("invalid user setting telegram user id value")
}
} else {
return fmt.Errorf("invalid user setting key")
}
return nil
}
type UserSettingFind struct {
UserID *int
Key UserSettingKey `json:"key"`
}
type UserSettingDelete struct {
UserID int
}
...@@ -59,6 +59,10 @@ const ( ...@@ -59,6 +59,10 @@ const (
ActivityServerStart ActivityType = "server.start" ActivityServerStart ActivityType = "server.start"
) )
func (t ActivityType) String() string {
return string(t)
}
// ActivityLevel is the level of activities. // ActivityLevel is the level of activities.
type ActivityLevel string type ActivityLevel string
...@@ -71,6 +75,10 @@ const ( ...@@ -71,6 +75,10 @@ const (
ActivityError ActivityLevel = "ERROR" ActivityError ActivityLevel = "ERROR"
) )
func (l ActivityLevel) String() string {
return string(l)
}
type ActivityUserCreatePayload struct { type ActivityUserCreatePayload struct {
UserID int `json:"userId"` UserID int `json:"userId"`
Username string `json:"username"` Username string `json:"username"`
......
...@@ -85,7 +85,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { ...@@ -85,7 +85,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
} }
var userInfo *idp.IdentityProviderUserInfo var userInfo *idp.IdentityProviderUserInfo
if identityProvider.Type == store.IdentityProviderOAuth2 { if identityProvider.Type == store.IdentityProviderOAuth2Type {
oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config) oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err)
...@@ -121,7 +121,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { ...@@ -121,7 +121,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
userCreate := &store.User{ userCreate := &store.User{
Username: userInfo.Identifier, Username: userInfo.Identifier,
// The new signup user should be normal user by default. // The new signup user should be normal user by default.
Role: store.NormalUser, Role: store.RoleUser,
Nickname: userInfo.DisplayName, Nickname: userInfo.DisplayName,
Email: userInfo.Email, Email: userInfo.Email,
OpenID: common.GenUUID(), OpenID: common.GenUUID(),
...@@ -135,7 +135,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { ...@@ -135,7 +135,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err)
} }
userCreate.PasswordHash = string(passwordHash) userCreate.PasswordHash = string(passwordHash)
user, err = s.Store.CreateUserV1(ctx, userCreate) user, err = s.Store.CreateUser(ctx, userCreate)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
} }
...@@ -160,7 +160,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { ...@@ -160,7 +160,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err)
} }
hostUserType := store.Host hostUserType := store.RoleHost
existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{ existedHostUsers, err := s.Store.ListUsers(ctx, &store.FindUser{
Role: &hostUserType, Role: &hostUserType,
}) })
...@@ -171,13 +171,13 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { ...@@ -171,13 +171,13 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
userCreate := &store.User{ userCreate := &store.User{
Username: signup.Username, Username: signup.Username,
// The new signup user should be normal user by default. // The new signup user should be normal user by default.
Role: store.NormalUser, Role: store.RoleUser,
Nickname: signup.Username, Nickname: signup.Username,
OpenID: common.GenUUID(), OpenID: common.GenUUID(),
} }
if len(existedHostUsers) == 0 { if len(existedHostUsers) == 0 {
// Change the default role to host if there is no host user. // Change the default role to host if there is no host user.
userCreate.Role = store.Host userCreate.Role = store.RoleHost
} else { } else {
allowSignUpSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{ allowSignUpSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: SystemSettingAllowSignUpName.String(), Name: SystemSettingAllowSignUpName.String(),
...@@ -204,7 +204,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) { ...@@ -204,7 +204,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group) {
} }
userCreate.PasswordHash = string(passwordHash) userCreate.PasswordHash = string(passwordHash)
user, err := s.Store.CreateUserV1(ctx, userCreate) user, err := s.Store.CreateUser(ctx, userCreate)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err)
} }
...@@ -234,7 +234,7 @@ func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.User ...@@ -234,7 +234,7 @@ func (s *APIV1Service) createAuthSignInActivity(c echo.Context, user *store.User
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := s.Store.CreateActivityV1(ctx, &store.ActivityMessage{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: user.ID, CreatorID: user.ID,
Type: string(ActivityUserAuthSignIn), Type: string(ActivityUserAuthSignIn),
Level: string(ActivityInfo), Level: string(ActivityInfo),
...@@ -256,7 +256,7 @@ func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User ...@@ -256,7 +256,7 @@ func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := s.Store.CreateActivityV1(ctx, &store.ActivityMessage{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: user.ID, CreatorID: user.ID,
Type: string(ActivityUserAuthSignUp), Type: string(ActivityUserAuthSignUp),
Level: string(ActivityInfo), Level: string(ActivityInfo),
......
...@@ -13,12 +13,6 @@ const ( ...@@ -13,12 +13,6 @@ const (
Archived RowStatus = "ARCHIVED" Archived RowStatus = "ARCHIVED"
) )
func (e RowStatus) String() string { func (r RowStatus) String() string {
switch e { return string(r)
case Normal:
return "NORMAL"
case Archived:
return "ARCHIVED"
}
return ""
} }
...@@ -14,9 +14,13 @@ import ( ...@@ -14,9 +14,13 @@ import (
type IdentityProviderType string type IdentityProviderType string
const ( const (
IdentityProviderOAuth2 IdentityProviderType = "OAUTH2" IdentityProviderOAuth2Type IdentityProviderType = "OAUTH2"
) )
func (t IdentityProviderType) String() string {
return string(t)
}
type IdentityProviderConfig struct { type IdentityProviderConfig struct {
OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"` OAuth2Config *IdentityProviderOAuth2Config `json:"oauth2Config"`
} }
...@@ -53,7 +57,7 @@ type CreateIdentityProviderRequest struct { ...@@ -53,7 +57,7 @@ type CreateIdentityProviderRequest struct {
} }
type UpdateIdentityProviderRequest struct { type UpdateIdentityProviderRequest struct {
ID int ID int `json:"-"`
Type IdentityProviderType `json:"type"` Type IdentityProviderType `json:"type"`
Name *string `json:"name"` Name *string `json:"name"`
IdentifierFilter *string `json:"identifierFilter"` IdentifierFilter *string `json:"identifierFilter"`
...@@ -74,7 +78,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { ...@@ -74,7 +78,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != store.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
...@@ -108,7 +112,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { ...@@ -108,7 +112,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != store.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
...@@ -153,7 +157,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { ...@@ -153,7 +157,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role == store.Host { if user == nil || user.Role == store.RoleHost {
isHostUser = true isHostUser = true
} }
} }
...@@ -183,7 +187,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { ...@@ -183,7 +187,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != store.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
...@@ -217,7 +221,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { ...@@ -217,7 +221,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != store.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
......
...@@ -82,7 +82,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e ...@@ -82,7 +82,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
} }
// Skip validation for server status endpoints. // Skip validation for server status endpoints.
if common.HasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet { if common.HasPrefixes(path, "/api/v1/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet {
return next(c) return next(c)
} }
...@@ -93,7 +93,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e ...@@ -93,7 +93,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
return next(c) return next(c)
} }
// When the request is not authenticated, we allow the user to access the memo endpoints for those public memos. // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos.
if common.HasPrefixes(path, "/api/status", "/api/memo") && method == http.MethodGet { if common.HasPrefixes(path, "/api/v1/status", "/api/memo") && method == http.MethodGet {
return next(c) return next(c)
} }
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
......
...@@ -13,13 +13,5 @@ const ( ...@@ -13,13 +13,5 @@ const (
) )
func (v Visibility) String() string { func (v Visibility) String() string {
switch v { return string(v)
case Public:
return "PUBLIC"
case Protected:
return "PROTECTED"
case Private:
return "PRIVATE"
}
return "PRIVATE"
} }
package server package v1
import ( import (
"encoding/json" "encoding/json"
...@@ -7,34 +7,79 @@ import ( ...@@ -7,34 +7,79 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/pkg/errors"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
) )
func (s *Server) registerShortcutRoutes(g *echo.Group) { type Shortcut struct {
ID int `json:"id"`
// Standard fields
RowStatus RowStatus `json:"rowStatus"`
CreatorID int `json:"creatorId"`
CreatedTs int64 `json:"createdTs"`
UpdatedTs int64 `json:"updatedTs"`
// Domain specific fields
Title string `json:"title"`
Payload string `json:"payload"`
}
type CreateShortcutRequest struct {
Title string `json:"title"`
Payload string `json:"payload"`
}
type UpdateShortcutRequest struct {
RowStatus *RowStatus `json:"rowStatus"`
Title *string `json:"title"`
Payload *string `json:"payload"`
}
type ShortcutFind struct {
ID *int
// Standard fields
CreatorID *int
// Domain specific fields
Title *string `json:"title"`
}
type ShortcutDelete struct {
ID *int
// Standard fields
CreatorID *int
}
func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) {
g.POST("/shortcut", func(c echo.Context) error { g.POST("/shortcut", func(c echo.Context) error {
ctx := c.Request().Context() ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok { if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
shortcutCreate := &api.ShortcutCreate{} shortcutCreate := &CreateShortcutRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(shortcutCreate); err != nil { if err := json.NewDecoder(c.Request().Body).Decode(shortcutCreate); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post shortcut request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post shortcut request").SetInternal(err)
} }
shortcutCreate.CreatorID = userID shortcut, err := s.Store.CreateShortcut(ctx, &store.Shortcut{
shortcut, err := s.Store.CreateShortcut(ctx, shortcutCreate) CreatorID: userID,
Title: shortcutCreate.Title,
Payload: shortcutCreate.Payload,
})
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create shortcut").SetInternal(err)
} }
if err := s.createShortcutCreateActivity(c, shortcut); err != nil {
shortcutMessage := convertShortcutFromStore(shortcut)
if err := s.createShortcutCreateActivity(c, shortcutMessage); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(shortcut)) return c.JSON(http.StatusOK, shortcutMessage)
}) })
g.PATCH("/shortcut/:shortcutId", func(c echo.Context) error { g.PATCH("/shortcut/:shortcutId", func(c echo.Context) error {
...@@ -48,10 +93,9 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { ...@@ -48,10 +93,9 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
} }
shortcutFind := &api.ShortcutFind{ shortcut, err := s.Store.GetShortcut(ctx, &store.FindShortcut{
ID: &shortcutID, ID: &shortcutID,
} })
shortcut, err := s.Store.FindShortcut(ctx, shortcutFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err)
} }
...@@ -59,20 +103,32 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { ...@@ -59,20 +103,32 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
request := &UpdateShortcutRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(request); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch shortcut request").SetInternal(err)
}
currentTs := time.Now().Unix() currentTs := time.Now().Unix()
shortcutPatch := &api.ShortcutPatch{ shortcutUpdate := &store.UpdateShortcut{
ID: shortcutID,
UpdatedTs: &currentTs, UpdatedTs: &currentTs,
} }
if err := json.NewDecoder(c.Request().Body).Decode(shortcutPatch); err != nil { if request.RowStatus != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch shortcut request").SetInternal(err) rowStatus := store.RowStatus(*request.RowStatus)
shortcutUpdate.RowStatus = &rowStatus
}
if request.Title != nil {
shortcutUpdate.Title = request.Title
}
if request.Payload != nil {
shortcutUpdate.Payload = request.Payload
} }
shortcutPatch.ID = shortcutID shortcut, err = s.Store.UpdateShortcut(ctx, shortcutUpdate)
shortcut, err = s.Store.PatchShortcut(ctx, shortcutPatch)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch shortcut").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(shortcut)) return c.JSON(http.StatusOK, convertShortcutFromStore(shortcut))
}) })
g.GET("/shortcut", func(c echo.Context) error { g.GET("/shortcut", func(c echo.Context) error {
...@@ -82,14 +138,17 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { ...@@ -82,14 +138,17 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find shortcut") return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find shortcut")
} }
shortcutFind := &api.ShortcutFind{ list, err := s.Store.ListShortcuts(ctx, &store.FindShortcut{
CreatorID: &userID, CreatorID: &userID,
} })
list, err := s.Store.FindShortcutList(ctx, shortcutFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch shortcut list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get shortcut list").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(list)) shortcutMessageList := make([]*Shortcut, 0, len(list))
for _, shortcut := range list {
shortcutMessageList = append(shortcutMessageList, convertShortcutFromStore(shortcut))
}
return c.JSON(http.StatusOK, shortcutMessageList)
}) })
g.GET("/shortcut/:shortcutId", func(c echo.Context) error { g.GET("/shortcut/:shortcutId", func(c echo.Context) error {
...@@ -99,14 +158,16 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { ...@@ -99,14 +158,16 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
} }
shortcutFind := &api.ShortcutFind{ shortcut, err := s.Store.GetShortcut(ctx, &store.FindShortcut{
ID: &shortcutID, ID: &shortcutID,
} })
shortcut, err := s.Store.FindShortcut(ctx, shortcutFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch shortcut by ID %d", *shortcutFind.ID)).SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to fetch shortcut by ID %d", shortcutID)).SetInternal(err)
}
if shortcut == nil {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut by ID %d not found", shortcutID))
} }
return c.JSON(http.StatusOK, composeResponse(shortcut)) return c.JSON(http.StatusOK, convertShortcutFromStore(shortcut))
}) })
g.DELETE("/shortcut/:shortcutId", func(c echo.Context) error { g.DELETE("/shortcut/:shortcutId", func(c echo.Context) error {
...@@ -120,10 +181,9 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { ...@@ -120,10 +181,9 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err)
} }
shortcutFind := &api.ShortcutFind{ shortcut, err := s.Store.GetShortcut(ctx, &store.FindShortcut{
ID: &shortcutID, ID: &shortcutID,
} })
shortcut, err := s.Store.FindShortcut(ctx, shortcutFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find shortcut").SetInternal(err)
} }
...@@ -131,22 +191,18 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) { ...@@ -131,22 +191,18 @@ func (s *Server) registerShortcutRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
shortcutDelete := &api.ShortcutDelete{ if err := s.Store.DeleteShortcut(ctx, &store.DeleteShortcut{
ID: &shortcutID, ID: &shortcutID,
} }); err != nil {
if err := s.Store.DeleteShortcut(ctx, shortcutDelete); err != nil {
if common.ErrorCode(err) == common.NotFound {
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Shortcut ID not found: %d", shortcutID))
}
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete shortcut").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete shortcut").SetInternal(err)
} }
return c.JSON(http.StatusOK, true) return c.JSON(http.StatusOK, true)
}) })
} }
func (s *Server) createShortcutCreateActivity(c echo.Context, shortcut *api.Shortcut) error { func (s *APIV1Service) createShortcutCreateActivity(c echo.Context, shortcut *Shortcut) error {
ctx := c.Request().Context() ctx := c.Request().Context()
payload := api.ActivityShortcutCreatePayload{ payload := ActivityShortcutCreatePayload{
Title: shortcut.Title, Title: shortcut.Title,
Payload: shortcut.Payload, Payload: shortcut.Payload,
} }
...@@ -154,10 +210,10 @@ func (s *Server) createShortcutCreateActivity(c echo.Context, shortcut *api.Shor ...@@ -154,10 +210,10 @@ func (s *Server) createShortcutCreateActivity(c echo.Context, shortcut *api.Shor
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: shortcut.CreatorID, CreatorID: shortcut.CreatorID,
Type: api.ActivityShortcutCreate, Type: ActivityShortcutCreate.String(),
Level: api.ActivityInfo, Level: ActivityInfo.String(),
Payload: string(payloadBytes), Payload: string(payloadBytes),
}) })
if err != nil || activity == nil { if err != nil || activity == nil {
...@@ -165,3 +221,15 @@ func (s *Server) createShortcutCreateActivity(c echo.Context, shortcut *api.Shor ...@@ -165,3 +221,15 @@ func (s *Server) createShortcutCreateActivity(c echo.Context, shortcut *api.Shor
} }
return err return err
} }
func convertShortcutFromStore(shortcut *store.Shortcut) *Shortcut {
return &Shortcut{
ID: shortcut.ID,
RowStatus: RowStatus(shortcut.RowStatus),
CreatorID: shortcut.CreatorID,
Title: shortcut.Title,
Payload: shortcut.Payload,
CreatedTs: shortcut.CreatedTs,
UpdatedTs: shortcut.UpdatedTs,
}
}
package v1
const (
// LocalStorage means the storage service is local file system.
LocalStorage = -1
// DatabaseStorage means the storage service is database.
DatabaseStorage = 0
)
This diff is collapsed.
...@@ -3,7 +3,11 @@ package v1 ...@@ -3,7 +3,11 @@ package v1
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http"
"strings" "strings"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/store"
) )
type SystemSettingName string type SystemSettingName string
...@@ -29,10 +33,9 @@ const ( ...@@ -29,10 +33,9 @@ const (
SystemSettingStorageServiceIDName SystemSettingName = "storage-service-id" SystemSettingStorageServiceIDName SystemSettingName = "storage-service-id"
// SystemSettingLocalStoragePathName is the name of local storage path. // SystemSettingLocalStoragePathName is the name of local storage path.
SystemSettingLocalStoragePathName SystemSettingName = "local-storage-path" SystemSettingLocalStoragePathName SystemSettingName = "local-storage-path"
// SystemSettingOpenAIConfigName is the name of OpenAI config.
SystemSettingOpenAIConfigName SystemSettingName = "openai-config"
// SystemSettingTelegramBotToken is the name of Telegram Bot Token. // SystemSettingTelegramBotToken is the name of Telegram Bot Token.
SystemSettingTelegramBotTokenName SystemSettingName = "telegram-bot-token" SystemSettingTelegramBotTokenName SystemSettingName = "telegram-bot-token"
// SystemSettingMemoDisplayWithUpdatedTsName is the name of memo display with updated ts.
SystemSettingMemoDisplayWithUpdatedTsName SystemSettingName = "memo-display-with-updated-ts" SystemSettingMemoDisplayWithUpdatedTsName SystemSettingName = "memo-display-with-updated-ts"
) )
...@@ -52,41 +55,8 @@ type CustomizedProfile struct { ...@@ -52,41 +55,8 @@ type CustomizedProfile struct {
ExternalURL string `json:"externalUrl"` ExternalURL string `json:"externalUrl"`
} }
type OpenAIConfig struct {
Key string `json:"key"`
Host string `json:"host"`
}
func (key SystemSettingName) String() string { func (key SystemSettingName) String() string {
switch key { return string(key)
case SystemSettingServerIDName:
return "server-id"
case SystemSettingSecretSessionName:
return "secret-session"
case SystemSettingAllowSignUpName:
return "allow-signup"
case SystemSettingDisablePublicMemosName:
return "disable-public-memos"
case SystemSettingMaxUploadSizeMiBName:
return "max-upload-size-mib"
case SystemSettingAdditionalStyleName:
return "additional-style"
case SystemSettingAdditionalScriptName:
return "additional-script"
case SystemSettingCustomizedProfileName:
return "customized-profile"
case SystemSettingStorageServiceIDName:
return "storage-service-id"
case SystemSettingLocalStoragePathName:
return "local-storage-path"
case SystemSettingOpenAIConfigName:
return "openai-config"
case SystemSettingTelegramBotTokenName:
return "telegram-bot-token"
case SystemSettingMemoDisplayWithUpdatedTsName:
return "memo-display-with-updated-ts"
}
return ""
} }
type SystemSetting struct { type SystemSetting struct {
...@@ -96,7 +66,7 @@ type SystemSetting struct { ...@@ -96,7 +66,7 @@ type SystemSetting struct {
Description string `json:"description"` Description string `json:"description"`
} }
type SystemSettingUpsert struct { type UpsertSystemSettingRequest struct {
Name SystemSettingName `json:"name"` Name SystemSettingName `json:"name"`
Value string `json:"value"` Value string `json:"value"`
Description string `json:"description"` Description string `json:"description"`
...@@ -104,7 +74,7 @@ type SystemSettingUpsert struct { ...@@ -104,7 +74,7 @@ type SystemSettingUpsert struct {
const systemSettingUnmarshalError = `failed to unmarshal value from system setting "%v"` const systemSettingUnmarshalError = `failed to unmarshal value from system setting "%v"`
func (upsert SystemSettingUpsert) Validate() error { func (upsert UpsertSystemSettingRequest) Validate() error {
switch settingName := upsert.Name; settingName { switch settingName := upsert.Name; settingName {
case SystemSettingServerIDName: case SystemSettingServerIDName:
return fmt.Errorf("updating %v is not allowed", settingName) return fmt.Errorf("updating %v is not allowed", settingName)
...@@ -157,11 +127,6 @@ func (upsert SystemSettingUpsert) Validate() error { ...@@ -157,11 +127,6 @@ func (upsert SystemSettingUpsert) Validate() error {
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil { if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName) return fmt.Errorf(systemSettingUnmarshalError, settingName)
} }
case SystemSettingOpenAIConfigName:
value := OpenAIConfig{}
if err := json.Unmarshal([]byte(upsert.Value), &value); err != nil {
return fmt.Errorf(systemSettingUnmarshalError, settingName)
}
case SystemSettingTelegramBotTokenName: case SystemSettingTelegramBotTokenName:
if upsert.Value == "" { if upsert.Value == "" {
return nil return nil
...@@ -189,6 +154,77 @@ func (upsert SystemSettingUpsert) Validate() error { ...@@ -189,6 +154,77 @@ func (upsert SystemSettingUpsert) Validate() error {
return nil return nil
} }
type SystemSettingFind struct { func (s *APIV1Service) registerSystemSettingRoutes(g *echo.Group) {
Name SystemSettingName `json:"name"` g.POST("/system/setting", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
systemSettingUpsert := &UpsertSystemSettingRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(systemSettingUpsert); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post system setting request").SetInternal(err)
}
if err := systemSettingUpsert.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "invalid system setting").SetInternal(err)
}
systemSetting, err := s.Store.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: systemSettingUpsert.Name.String(),
Value: systemSettingUpsert.Value,
Description: systemSettingUpsert.Description,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert system setting").SetInternal(err)
}
return c.JSON(http.StatusOK, convertSystemSettingFromStore(systemSetting))
})
g.GET("/system/setting", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
}
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
}
if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
}
list, err := s.Store.ListSystemSettings(ctx, &store.FindSystemSetting{})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting list").SetInternal(err)
}
systemSettingList := make([]*SystemSetting, 0, len(list))
for _, systemSetting := range list {
systemSettingList = append(systemSettingList, convertSystemSettingFromStore(systemSetting))
}
return c.JSON(http.StatusOK, systemSettingList)
})
}
func convertSystemSettingFromStore(systemSetting *store.SystemSetting) *SystemSetting {
return &SystemSetting{
Name: SystemSettingName(systemSetting.Name),
Value: systemSetting.Value,
Description: systemSetting.Description,
}
} }
package server package v1
import ( import (
"encoding/json" "encoding/json"
...@@ -7,16 +7,26 @@ import ( ...@@ -7,16 +7,26 @@ import (
"regexp" "regexp"
"sort" "sort"
"github.com/labstack/echo/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/labstack/echo/v4"
) )
func (s *Server) registerTagRoutes(g *echo.Group) { type Tag struct {
Name string
CreatorID int
}
type UpsertTagRequest struct {
Name string `json:"name"`
}
type DeleteTagRequest struct {
Name string `json:"name"`
}
func (s *APIV1Service) registerTagRoutes(g *echo.Group) {
g.POST("/tag", func(c echo.Context) error { g.POST("/tag", func(c echo.Context) error {
ctx := c.Request().Context() ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int) userID, ok := c.Get(getUserIDContextKey()).(int)
...@@ -24,7 +34,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) { ...@@ -24,7 +34,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
tagUpsert := &api.TagUpsert{} tagUpsert := &UpsertTagRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(tagUpsert); err != nil { if err := json.NewDecoder(c.Request().Body).Decode(tagUpsert); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post tag request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post tag request").SetInternal(err)
} }
...@@ -32,15 +42,18 @@ func (s *Server) registerTagRoutes(g *echo.Group) { ...@@ -32,15 +42,18 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Tag name shouldn't be empty") return echo.NewHTTPError(http.StatusBadRequest, "Tag name shouldn't be empty")
} }
tagUpsert.CreatorID = userID tag, err := s.Store.UpsertTagV1(ctx, &store.Tag{
tag, err := s.Store.UpsertTag(ctx, tagUpsert) Name: tagUpsert.Name,
CreatorID: userID,
})
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert tag").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert tag").SetInternal(err)
} }
if err := s.createTagCreateActivity(c, tag); err != nil { tagMessage := convertTagFromStore(tag)
if err := s.createTagCreateActivity(c, tagMessage); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(tag.Name)) return c.JSON(http.StatusOK, tagMessage.Name)
}) })
g.GET("/tag", func(c echo.Context) error { g.GET("/tag", func(c echo.Context) error {
...@@ -50,19 +63,18 @@ func (s *Server) registerTagRoutes(g *echo.Group) { ...@@ -50,19 +63,18 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag") return echo.NewHTTPError(http.StatusBadRequest, "Missing user id to find tag")
} }
tagFind := &api.TagFind{ list, err := s.Store.ListTags(ctx, &store.FindTag{
CreatorID: userID, CreatorID: userID,
} })
tagList, err := s.Store.FindTagList(ctx, tagFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err)
} }
tagNameList := []string{} tagNameList := []string{}
for _, tag := range tagList { for _, tag := range list {
tagNameList = append(tagNameList, tag.Name) tagNameList = append(tagNameList, tag.Name)
} }
return c.JSON(http.StatusOK, composeResponse(tagNameList)) return c.JSON(http.StatusOK, tagNameList)
}) })
g.GET("/tag/suggestion", func(c echo.Context) error { g.GET("/tag/suggestion", func(c echo.Context) error {
...@@ -83,15 +95,14 @@ func (s *Server) registerTagRoutes(g *echo.Group) { ...@@ -83,15 +95,14 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find memo list").SetInternal(err)
} }
tagFind := &api.TagFind{ list, err := s.Store.ListTags(ctx, &store.FindTag{
CreatorID: userID, CreatorID: userID,
} })
existTagList, err := s.Store.FindTagList(ctx, tagFind)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find tag list").SetInternal(err)
} }
tagNameList := []string{} tagNameList := []string{}
for _, tag := range existTagList { for _, tag := range list {
tagNameList = append(tagNameList, tag.Name) tagNameList = append(tagNameList, tag.Name)
} }
...@@ -108,7 +119,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) { ...@@ -108,7 +119,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
tagList = append(tagList, tag) tagList = append(tagList, tag)
} }
sort.Strings(tagList) sort.Strings(tagList)
return c.JSON(http.StatusOK, composeResponse(tagList)) return c.JSON(http.StatusOK, tagList)
}) })
g.POST("/tag/delete", func(c echo.Context) error { g.POST("/tag/delete", func(c echo.Context) error {
...@@ -118,7 +129,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) { ...@@ -118,7 +129,7 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
tagDelete := &api.TagDelete{} tagDelete := &DeleteTagRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(tagDelete); err != nil { if err := json.NewDecoder(c.Request().Body).Decode(tagDelete); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post tag request").SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post tag request").SetInternal(err)
} }
...@@ -126,17 +137,45 @@ func (s *Server) registerTagRoutes(g *echo.Group) { ...@@ -126,17 +137,45 @@ func (s *Server) registerTagRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, "Tag name shouldn't be empty") return echo.NewHTTPError(http.StatusBadRequest, "Tag name shouldn't be empty")
} }
tagDelete.CreatorID = userID err := s.Store.DeleteTag(ctx, &store.DeleteTag{
if err := s.Store.DeleteTag(ctx, tagDelete); err != nil { Name: tagDelete.Name,
if common.ErrorCode(err) == common.NotFound { CreatorID: userID,
return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Tag name not found: %s", tagDelete.Name)) })
} if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete tag name: %v", tagDelete.Name)).SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to delete tag name: %v", tagDelete.Name)).SetInternal(err)
} }
return c.JSON(http.StatusOK, true) return c.JSON(http.StatusOK, true)
}) })
} }
func (s *APIV1Service) createTagCreateActivity(c echo.Context, tag *Tag) error {
ctx := c.Request().Context()
payload := ActivityTagCreatePayload{
TagName: tag.Name,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return errors.Wrap(err, "failed to marshal activity payload")
}
activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: tag.CreatorID,
Type: ActivityTagCreate.String(),
Level: ActivityInfo.String(),
Payload: string(payloadBytes),
})
if err != nil || activity == nil {
return errors.Wrap(err, "failed to create activity")
}
return err
}
func convertTagFromStore(tag *store.Tag) *Tag {
return &Tag{
Name: tag.Name,
CreatorID: tag.CreatorID,
}
}
var tagRegexp = regexp.MustCompile(`#([^\s#]+)`) var tagRegexp = regexp.MustCompile(`#([^\s#]+)`)
func findTagListFromMemoContent(memoContent string) []string { func findTagListFromMemoContent(memoContent string) []string {
...@@ -154,24 +193,3 @@ func findTagListFromMemoContent(memoContent string) []string { ...@@ -154,24 +193,3 @@ func findTagListFromMemoContent(memoContent string) []string {
sort.Strings(tagList) sort.Strings(tagList)
return tagList return tagList
} }
func (s *Server) createTagCreateActivity(c echo.Context, tag *api.Tag) error {
ctx := c.Request().Context()
payload := api.ActivityTagCreatePayload{
TagName: tag.Name,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return errors.Wrap(err, "failed to marshal activity payload")
}
activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{
CreatorID: tag.CreatorID,
Type: api.ActivityTagCreate,
Level: api.ActivityInfo,
Payload: string(payloadBytes),
})
if err != nil || activity == nil {
return errors.Wrap(err, "failed to create activity")
}
return err
}
package server package v1
import ( import (
"testing" "testing"
......
package v1
import "github.com/labstack/echo/v4"
func (*APIV1Service) registerTestRoutes(g *echo.Group) {
g.GET("/test", func(c echo.Context) error {
return c.String(200, "Hello World")
})
}
This diff is collapsed.
...@@ -3,8 +3,10 @@ package v1 ...@@ -3,8 +3,10 @@ package v1
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"strconv" "net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/store"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
...@@ -63,19 +65,18 @@ var ( ...@@ -63,19 +65,18 @@ var (
) )
type UserSetting struct { type UserSetting struct {
UserID int UserID int `json:"userId"`
Key UserSettingKey `json:"key"` Key UserSettingKey `json:"key"`
// Value is a JSON string with basic value Value string `json:"value"`
Value string `json:"value"`
} }
type UserSettingUpsert struct { type UpsertUserSettingRequest struct {
UserID int `json:"-"` UserID int `json:"-"`
Key UserSettingKey `json:"key"` Key UserSettingKey `json:"key"`
Value string `json:"value"` Value string `json:"value"`
} }
func (upsert UserSettingUpsert) Validate() error { func (upsert UpsertUserSettingRequest) Validate() error {
if upsert.Key == UserSettingLocaleKey { if upsert.Key == UserSettingLocaleKey {
localeValue := "en" localeValue := "en"
err := json.Unmarshal([]byte(upsert.Value), &localeValue) err := json.Unmarshal([]byte(upsert.Value), &localeValue)
...@@ -104,18 +105,11 @@ func (upsert UserSettingUpsert) Validate() error { ...@@ -104,18 +105,11 @@ func (upsert UserSettingUpsert) Validate() error {
return fmt.Errorf("invalid user setting memo visibility value") return fmt.Errorf("invalid user setting memo visibility value")
} }
} else if upsert.Key == UserSettingTelegramUserIDKey { } else if upsert.Key == UserSettingTelegramUserIDKey {
var s string var key string
err := json.Unmarshal([]byte(upsert.Value), &s) err := json.Unmarshal([]byte(upsert.Value), &key)
if err != nil { if err != nil {
return fmt.Errorf("invalid user setting telegram user id value") return fmt.Errorf("invalid user setting telegram user id value")
} }
if s == "" {
return nil
}
if _, err := strconv.Atoi(s); err != nil {
return fmt.Errorf("invalid user setting telegram user id value")
}
} else { } else {
return fmt.Errorf("invalid user setting key") return fmt.Errorf("invalid user setting key")
} }
...@@ -123,12 +117,41 @@ func (upsert UserSettingUpsert) Validate() error { ...@@ -123,12 +117,41 @@ func (upsert UserSettingUpsert) Validate() error {
return nil return nil
} }
type UserSettingFind struct { func (s *APIV1Service) registerUserSettingRoutes(g *echo.Group) {
UserID *int g.POST("/user/setting", func(c echo.Context) error {
ctx := c.Request().Context()
userID, ok := c.Get(getUserIDContextKey()).(int)
if !ok {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session")
}
Key UserSettingKey `json:"key"` userSettingUpsert := &UpsertUserSettingRequest{}
if err := json.NewDecoder(c.Request().Body).Decode(userSettingUpsert); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post user setting upsert request").SetInternal(err)
}
if err := userSettingUpsert.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid user setting format").SetInternal(err)
}
userSettingUpsert.UserID = userID
userSetting, err := s.Store.UpsertUserSetting(ctx, &store.UserSetting{
UserID: userID,
Key: userSettingUpsert.Key.String(),
Value: userSettingUpsert.Value,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to upsert user setting").SetInternal(err)
}
userSettingMessage := convertUserSettingFromStore(userSetting)
return c.JSON(http.StatusOK, userSettingMessage)
})
} }
type UserSettingDelete struct { func convertUserSettingFromStore(userSetting *store.UserSetting) *UserSetting {
UserID int return &UserSetting{
UserID: userSetting.UserID,
Key: UserSettingKey(userSetting.Key),
Value: userSetting.Value,
}
} }
...@@ -25,7 +25,12 @@ func (s *APIV1Service) Register(rootGroup *echo.Group) { ...@@ -25,7 +25,12 @@ func (s *APIV1Service) Register(rootGroup *echo.Group) {
apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc { apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, s.Secret) return JWTMiddleware(s, next, s.Secret)
}) })
s.registerTestRoutes(apiV1Group) s.registerSystemRoutes(apiV1Group)
s.registerSystemSettingRoutes(apiV1Group)
s.registerAuthRoutes(apiV1Group) s.registerAuthRoutes(apiV1Group)
s.registerIdentityProviderRoutes(apiV1Group) s.registerIdentityProviderRoutes(apiV1Group)
s.registerUserRoutes(apiV1Group)
s.registerUserSettingRoutes(apiV1Group)
s.registerTagRoutes(apiV1Group)
s.registerShortcutRoutes(apiV1Group)
} }
...@@ -4,8 +4,8 @@ import ( ...@@ -4,8 +4,8 @@ import (
"net/http" "net/http"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/store"
) )
type response struct { type response struct {
...@@ -39,10 +39,9 @@ func (s *Server) defaultAuthSkipper(c echo.Context) bool { ...@@ -39,10 +39,9 @@ func (s *Server) defaultAuthSkipper(c echo.Context) bool {
// If there is openId in query string and related user is found, then skip auth. // If there is openId in query string and related user is found, then skip auth.
openID := c.QueryParam("openId") openID := c.QueryParam("openId")
if openID != "" { if openID != "" {
userFind := &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
OpenID: &openID, OpenID: &openID,
} })
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil && common.ErrorCode(err) != common.NotFound {
return false return false
} }
......
...@@ -81,11 +81,6 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha ...@@ -81,11 +81,6 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
return next(c) return next(c)
} }
// Skip validation for server status endpoints.
if common.HasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet {
return next(c)
}
token := findAccessToken(c) token := findAccessToken(c)
if token == "" { if token == "" {
// Allow the user to access the public endpoints. // Allow the user to access the public endpoints.
...@@ -93,7 +88,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha ...@@ -93,7 +88,7 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
return next(c) return next(c)
} }
// When the request is not authenticated, we allow the user to access the memo endpoints for those public memos. // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos.
if common.HasPrefixes(path, "/api/status", "/api/memo") && method == http.MethodGet { if common.HasPrefixes(path, "/api/memo") && method == http.MethodGet {
return next(c) return next(c)
} }
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
......
...@@ -60,10 +60,10 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { ...@@ -60,10 +60,10 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
} }
// Find disable public memos system setting. // Find disable public memos system setting.
disablePublicMemosSystemSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ disablePublicMemosSystemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: api.SystemSettingDisablePublicMemosName, Name: apiv1.SystemSettingDisablePublicMemosName.String(),
}) })
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find system setting").SetInternal(err)
} }
if disablePublicMemosSystemSetting != nil { if disablePublicMemosSystemSetting != nil {
...@@ -73,14 +73,14 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { ...@@ -73,14 +73,14 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal system setting").SetInternal(err)
} }
if disablePublicMemos { if disablePublicMemos {
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
// Enforce normal user to create private memo if public memos are disabled. // Enforce normal user to create private memo if public memos are disabled.
if user.Role == "USER" { if user.Role == store.RoleUser {
createMemoRequest.Visibility = api.Private createMemoRequest.Visibility = api.Private
} }
} }
...@@ -91,7 +91,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { ...@@ -91,7 +91,7 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create memo").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create memo").SetInternal(err)
} }
if err := createMemoCreateActivity(c.Request().Context(), s.Store, memoMessage); err != nil { if err := s.createMemoCreateActivity(ctx, memoMessage); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
} }
...@@ -561,8 +561,8 @@ func (s *Server) registerMemoRoutes(g *echo.Group) { ...@@ -561,8 +561,8 @@ func (s *Server) registerMemoRoutes(g *echo.Group) {
}) })
} }
func createMemoCreateActivity(ctx context.Context, store *store.Store, memo *store.MemoMessage) error { func (s *Server) createMemoCreateActivity(ctx context.Context, memo *store.MemoMessage) error {
payload := api.ActivityMemoCreatePayload{ payload := apiv1.ActivityMemoCreatePayload{
Content: memo.Content, Content: memo.Content,
Visibility: memo.Visibility.String(), Visibility: memo.Visibility.String(),
} }
...@@ -570,10 +570,10 @@ func createMemoCreateActivity(ctx context.Context, store *store.Store, memo *sto ...@@ -570,10 +570,10 @@ func createMemoCreateActivity(ctx context.Context, store *store.Store, memo *sto
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := store.CreateActivity(ctx, &api.ActivityCreate{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: memo.CreatorID, CreatorID: memo.CreatorID,
Type: api.ActivityMemoCreate, Type: apiv1.ActivityMemoCreate.String(),
Level: api.ActivityInfo, Level: apiv1.ActivityInfo.String(),
Payload: string(payloadBytes), Payload: string(payloadBytes),
}) })
if err != nil || activity == nil { if err != nil || activity == nil {
...@@ -654,7 +654,7 @@ func (s *Server) composeMemoMessageToMemoResponse(ctx context.Context, memoMessa ...@@ -654,7 +654,7 @@ func (s *Server) composeMemoMessageToMemoResponse(ctx context.Context, memoMessa
} }
// Compose creator name. // Compose creator name.
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &memoResponse.CreatorID, ID: &memoResponse.CreatorID,
}) })
if err != nil { if err != nil {
...@@ -699,10 +699,10 @@ func (s *Server) composeMemoMessageToMemoResponse(ctx context.Context, memoMessa ...@@ -699,10 +699,10 @@ func (s *Server) composeMemoMessageToMemoResponse(ctx context.Context, memoMessa
} }
func (s *Server) getMemoDisplayWithUpdatedTsSettingValue(ctx context.Context) (bool, error) { func (s *Server) getMemoDisplayWithUpdatedTsSettingValue(ctx context.Context) (bool, error) {
memoDisplayWithUpdatedTsSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ memoDisplayWithUpdatedTsSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: api.SystemSettingMemoDisplayWithUpdatedTsName, Name: apiv1.SystemSettingMemoDisplayWithUpdatedTsName.String(),
}) })
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil {
return false, errors.Wrap(err, "failed to find system setting") return false, errors.Wrap(err, "failed to find system setting")
} }
memoDisplayWithUpdatedTs := false memoDisplayWithUpdatedTs := false
......
package server
import (
"encoding/json"
"net/http"
"github.com/labstack/echo/v4"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
"github.com/usememos/memos/plugin/openai"
)
func (s *Server) registerOpenAIRoutes(g *echo.Group) {
g.POST("/openai/chat-completion", func(c echo.Context) error {
ctx := c.Request().Context()
openAIConfigSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{
Name: api.SystemSettingOpenAIConfigName,
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find openai key").SetInternal(err)
}
openAIConfig := api.OpenAIConfig{}
if openAIConfigSetting != nil {
err = json.Unmarshal([]byte(openAIConfigSetting.Value), &openAIConfig)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to unmarshal openai system setting value").SetInternal(err)
}
}
if openAIConfig.Key == "" {
return echo.NewHTTPError(http.StatusBadRequest, "OpenAI API key not set")
}
messages := []openai.ChatCompletionMessage{}
if err := json.NewDecoder(c.Request().Body).Decode(&messages); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post chat completion request").SetInternal(err)
}
if len(messages) == 0 {
return echo.NewHTTPError(http.StatusBadRequest, "No messages provided")
}
result, err := openai.PostChatCompletion(messages, openAIConfig.Key, openAIConfig.Host)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to post chat completion").SetInternal(err)
}
return c.JSON(http.StatusOK, composeResponse(result))
})
}
...@@ -22,6 +22,7 @@ import ( ...@@ -22,6 +22,7 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/usememos/memos/api" "github.com/usememos/memos/api"
apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/common/log" "github.com/usememos/memos/common/log"
"github.com/usememos/memos/plugin/storage/s3" "github.com/usememos/memos/plugin/storage/s3"
...@@ -102,7 +103,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { ...@@ -102,7 +103,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err)
} }
if err := createResourceCreateActivity(c.Request().Context(), s.Store, resource); err != nil { if err := s.createResourceCreateActivity(ctx, resource); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(resource)) return c.JSON(http.StatusOK, composeResponse(resource))
...@@ -116,7 +117,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { ...@@ -116,7 +117,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
} }
// This is the backend default max upload size limit. // This is the backend default max upload size limit.
maxUploadSetting := s.Store.GetSystemSettingValueOrDefault(&ctx, api.SystemSettingMaxUploadSizeMiBName, "32") maxUploadSetting := s.Store.GetSystemSettingValueWithDefault(&ctx, apiv1.SystemSettingMaxUploadSizeMiBName.String(), "32")
var settingMaxUploadSizeBytes int var settingMaxUploadSizeBytes int
if settingMaxUploadSizeMiB, err := strconv.Atoi(maxUploadSetting); err == nil { if settingMaxUploadSizeMiB, err := strconv.Atoi(maxUploadSetting); err == nil {
settingMaxUploadSizeBytes = settingMaxUploadSizeMiB * MebiByte settingMaxUploadSizeBytes = settingMaxUploadSizeMiB * MebiByte
...@@ -150,8 +151,8 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { ...@@ -150,8 +151,8 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
defer sourceFile.Close() defer sourceFile.Close()
var resourceCreate *api.ResourceCreate var resourceCreate *api.ResourceCreate
systemSettingStorageServiceID, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{Name: api.SystemSettingStorageServiceIDName}) systemSettingStorageServiceID, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{Name: apiv1.SystemSettingStorageServiceIDName.String()})
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err)
} }
storageServiceID := api.DatabaseStorage storageServiceID := api.DatabaseStorage
...@@ -179,7 +180,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { ...@@ -179,7 +180,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
// filepath.Join() should be used for local file paths, // filepath.Join() should be used for local file paths,
// as it handles the os-specific path separator automatically. // as it handles the os-specific path separator automatically.
// path.Join() always uses '/' as path separator. // path.Join() always uses '/' as path separator.
systemSettingLocalStoragePath, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{Name: api.SystemSettingLocalStoragePathName}) systemSettingLocalStoragePath, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{Name: apiv1.SystemSettingLocalStoragePathName.String()})
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find local storage path setting").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find local storage path setting").SetInternal(err)
} }
...@@ -265,7 +266,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) { ...@@ -265,7 +266,7 @@ func (s *Server) registerResourceRoutes(g *echo.Group) {
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create resource").SetInternal(err)
} }
if err := createResourceCreateActivity(c.Request().Context(), s.Store, resource); err != nil { if err := s.createResourceCreateActivity(ctx, resource); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create activity").SetInternal(err)
} }
return c.JSON(http.StatusOK, composeResponse(resource)) return c.JSON(http.StatusOK, composeResponse(resource))
...@@ -530,8 +531,8 @@ func (s *Server) registerResourcePublicRoutes(g *echo.Group) { ...@@ -530,8 +531,8 @@ func (s *Server) registerResourcePublicRoutes(g *echo.Group) {
}) })
} }
func createResourceCreateActivity(ctx context.Context, store *store.Store, resource *api.Resource) error { func (s *Server) createResourceCreateActivity(ctx context.Context, resource *api.Resource) error {
payload := api.ActivityResourceCreatePayload{ payload := apiv1.ActivityResourceCreatePayload{
Filename: resource.Filename, Filename: resource.Filename,
Type: resource.Type, Type: resource.Type,
Size: resource.Size, Size: resource.Size,
...@@ -540,10 +541,10 @@ func createResourceCreateActivity(ctx context.Context, store *store.Store, resou ...@@ -540,10 +541,10 @@ func createResourceCreateActivity(ctx context.Context, store *store.Store, resou
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := store.CreateActivity(ctx, &api.ActivityCreate{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: resource.CreatorID, CreatorID: resource.CreatorID,
Type: api.ActivityResourceCreate, Type: apiv1.ActivityResourceCreate.String(),
Level: api.ActivityInfo, Level: apiv1.ActivityInfo.String(),
Payload: string(payloadBytes), Payload: string(payloadBytes),
}) })
if err != nil || activity == nil { if err != nil || activity == nil {
......
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"github.com/gorilla/feeds" "github.com/gorilla/feeds"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/usememos/memos/api" "github.com/usememos/memos/api"
apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
"github.com/yuin/goldmark" "github.com/yuin/goldmark"
...@@ -80,7 +81,7 @@ func (s *Server) registerRSSRoutes(g *echo.Group) { ...@@ -80,7 +81,7 @@ func (s *Server) registerRSSRoutes(g *echo.Group) {
const MaxRSSItemCount = 100 const MaxRSSItemCount = 100
const MaxRSSItemTitleLength = 100 const MaxRSSItemTitleLength = 100
func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.MemoMessage, baseURL string, profile *api.CustomizedProfile) (string, error) { func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.MemoMessage, baseURL string, profile *apiv1.CustomizedProfile) (string, error) {
feed := &feeds.Feed{ feed := &feeds.Feed{
Title: profile.Name, Title: profile.Name,
Link: &feeds.Link{Href: baseURL}, Link: &feeds.Link{Href: baseURL},
...@@ -126,15 +127,14 @@ func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store. ...@@ -126,15 +127,14 @@ func (s *Server) generateRSSFromMemoList(ctx context.Context, memoList []*store.
return rss, nil return rss, nil
} }
func (s *Server) getSystemCustomizedProfile(ctx context.Context) (*api.CustomizedProfile, error) { func (s *Server) getSystemCustomizedProfile(ctx context.Context) (*apiv1.CustomizedProfile, error) {
systemSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{ systemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: api.SystemSettingCustomizedProfileName, Name: apiv1.SystemSettingCustomizedProfileName.String(),
}) })
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil {
return nil, err return nil, err
} }
customizedProfile := &apiv1.CustomizedProfile{
customizedProfile := &api.CustomizedProfile{
Name: "memos", Name: "memos",
LogoURL: "", LogoURL: "",
Description: "", Description: "",
......
...@@ -6,9 +6,10 @@ import ( ...@@ -6,9 +6,10 @@ import (
"fmt" "fmt"
"time" "time"
"github.com/google/uuid"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/usememos/memos/api" apiv1 "github.com/usememos/memos/api/v1"
apiV1 "github.com/usememos/memos/api/v1" "github.com/usememos/memos/common"
"github.com/usememos/memos/plugin/telegram" "github.com/usememos/memos/plugin/telegram"
"github.com/usememos/memos/server/profile" "github.com/usememos/memos/server/profile"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
...@@ -97,18 +98,13 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store ...@@ -97,18 +98,13 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, s.Secret) return JWTMiddleware(s, next, s.Secret)
}) })
s.registerSystemRoutes(apiGroup)
s.registerUserRoutes(apiGroup)
s.registerMemoRoutes(apiGroup) s.registerMemoRoutes(apiGroup)
s.registerMemoResourceRoutes(apiGroup) s.registerMemoResourceRoutes(apiGroup)
s.registerShortcutRoutes(apiGroup)
s.registerResourceRoutes(apiGroup) s.registerResourceRoutes(apiGroup)
s.registerTagRoutes(apiGroup)
s.registerStorageRoutes(apiGroup) s.registerStorageRoutes(apiGroup)
s.registerOpenAIRoutes(apiGroup)
s.registerMemoRelationRoutes(apiGroup) s.registerMemoRelationRoutes(apiGroup)
apiV1Service := apiV1.NewAPIV1Service(s.Secret, profile, store) apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store)
apiV1Service.Register(rootGroup) apiV1Service.Register(rootGroup)
return s, nil return s, nil
...@@ -145,8 +141,46 @@ func (s *Server) GetEcho() *echo.Echo { ...@@ -145,8 +141,46 @@ func (s *Server) GetEcho() *echo.Echo {
return s.e return s.e
} }
func (s *Server) getSystemServerID(ctx context.Context) (string, error) {
serverIDSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: apiv1.SystemSettingServerIDName.String(),
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return "", err
}
if serverIDSetting == nil || serverIDSetting.Value == "" {
serverIDSetting, err = s.Store.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: apiv1.SystemSettingServerIDName.String(),
Value: uuid.NewString(),
})
if err != nil {
return "", err
}
}
return serverIDSetting.Value, nil
}
func (s *Server) getSystemSecretSessionName(ctx context.Context) (string, error) {
secretSessionNameValue, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{
Name: apiv1.SystemSettingSecretSessionName.String(),
})
if err != nil && common.ErrorCode(err) != common.NotFound {
return "", err
}
if secretSessionNameValue == nil || secretSessionNameValue.Value == "" {
secretSessionNameValue, err = s.Store.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: apiv1.SystemSettingSecretSessionName.String(),
Value: uuid.NewString(),
})
if err != nil {
return "", err
}
}
return secretSessionNameValue.Value, nil
}
func (s *Server) createServerStartActivity(ctx context.Context) error { func (s *Server) createServerStartActivity(ctx context.Context) error {
payload := api.ActivityServerStartPayload{ payload := apiv1.ActivityServerStartPayload{
ServerID: s.ID, ServerID: s.ID,
Profile: s.Profile, Profile: s.Profile,
} }
...@@ -154,10 +188,10 @@ func (s *Server) createServerStartActivity(ctx context.Context) error { ...@@ -154,10 +188,10 @@ func (s *Server) createServerStartActivity(ctx context.Context) error {
if err != nil { if err != nil {
return errors.Wrap(err, "failed to marshal activity payload") return errors.Wrap(err, "failed to marshal activity payload")
} }
activity, err := s.Store.CreateActivity(ctx, &api.ActivityCreate{ activity, err := s.Store.CreateActivity(ctx, &store.ActivityMessage{
CreatorID: api.UnknownID, CreatorID: apiv1.UnknownID,
Type: api.ActivityServerStart, Type: apiv1.ActivityServerStart.String(),
Level: api.ActivityInfo, Level: apiv1.ActivityInfo.String(),
Payload: string(payloadBytes), Payload: string(payloadBytes),
}) })
if err != nil || activity == nil { if err != nil || activity == nil {
......
...@@ -8,7 +8,9 @@ import ( ...@@ -8,7 +8,9 @@ import (
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/usememos/memos/api" "github.com/usememos/memos/api"
apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/store"
) )
func (s *Server) registerStorageRoutes(g *echo.Group) { func (s *Server) registerStorageRoutes(g *echo.Group) {
...@@ -19,13 +21,13 @@ func (s *Server) registerStorageRoutes(g *echo.Group) { ...@@ -19,13 +21,13 @@ func (s *Server) registerStorageRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != api.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
...@@ -48,13 +50,13 @@ func (s *Server) registerStorageRoutes(g *echo.Group) { ...@@ -48,13 +50,13 @@ func (s *Server) registerStorageRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != api.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
...@@ -84,14 +86,14 @@ func (s *Server) registerStorageRoutes(g *echo.Group) { ...@@ -84,14 +86,14 @@ func (s *Server) registerStorageRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
// We should only show storage list to host user. // We should only show storage list to host user.
if user == nil || user.Role != api.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
...@@ -109,13 +111,13 @@ func (s *Server) registerStorageRoutes(g *echo.Group) { ...@@ -109,13 +111,13 @@ func (s *Server) registerStorageRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session")
} }
user, err := s.Store.FindUser(ctx, &api.UserFind{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err)
} }
if user == nil || user.Role != api.Host { if user == nil || user.Role != store.RoleHost {
return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized")
} }
...@@ -124,8 +126,8 @@ func (s *Server) registerStorageRoutes(g *echo.Group) { ...@@ -124,8 +126,8 @@ func (s *Server) registerStorageRoutes(g *echo.Group) {
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("storageId"))).SetInternal(err) return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("storageId"))).SetInternal(err)
} }
systemSetting, err := s.Store.FindSystemSetting(ctx, &api.SystemSettingFind{Name: api.SystemSettingStorageServiceIDName}) systemSetting, err := s.Store.GetSystemSetting(ctx, &store.FindSystemSetting{Name: apiv1.SystemSettingStorageServiceIDName.String()})
if err != nil && common.ErrorCode(err) != common.NotFound { if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find storage").SetInternal(err)
} }
if systemSetting != nil { if systemSetting != nil {
......
...@@ -24,7 +24,7 @@ func newTelegramHandler(store *store.Store) *telegramHandler { ...@@ -24,7 +24,7 @@ func newTelegramHandler(store *store.Store) *telegramHandler {
} }
func (t *telegramHandler) BotToken(ctx context.Context) string { func (t *telegramHandler) BotToken(ctx context.Context) string {
return t.store.GetSystemSettingValueOrDefault(&ctx, api.SystemSettingTelegramBotTokenName, "") return t.store.GetSystemSettingValueWithDefault(&ctx, apiv1.SystemSettingTelegramBotTokenName.String(), "")
} }
const ( const (
...@@ -80,11 +80,6 @@ func (t *telegramHandler) MessageHandle(ctx context.Context, bot *telegram.Bot, ...@@ -80,11 +80,6 @@ func (t *telegramHandler) MessageHandle(ctx context.Context, bot *telegram.Bot,
return err return err
} }
if err := createMemoCreateActivity(ctx, t.store, memoMessage); err != nil {
_, err := bot.EditMessage(ctx, message.Chat.ID, reply.MessageID, fmt.Sprintf("failed to createMemoCreateActivity: %s", err), nil)
return err
}
// create resources // create resources
for filename, blob := range blobs { for filename, blob := range blobs {
// TODO support more // TODO support more
...@@ -108,10 +103,6 @@ func (t *telegramHandler) MessageHandle(ctx context.Context, bot *telegram.Bot, ...@@ -108,10 +103,6 @@ func (t *telegramHandler) MessageHandle(ctx context.Context, bot *telegram.Bot,
_, err := bot.EditMessage(ctx, message.Chat.ID, reply.MessageID, fmt.Sprintf("failed to CreateResource: %s", err), nil) _, err := bot.EditMessage(ctx, message.Chat.ID, reply.MessageID, fmt.Sprintf("failed to CreateResource: %s", err), nil)
return err return err
} }
if err := createResourceCreateActivity(ctx, t.store, resource); err != nil {
_, err := bot.EditMessage(ctx, message.Chat.ID, reply.MessageID, fmt.Sprintf("failed to createResourceCreateActivity: %s", err), nil)
return err
}
_, err = t.store.UpsertMemoResource(ctx, &api.MemoResourceUpsert{ _, err = t.store.UpsertMemoResource(ctx, &api.MemoResourceUpsert{
MemoID: memoMessage.ID, MemoID: memoMessage.ID,
......
This diff is collapsed.
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common" "github.com/usememos/memos/common"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
...@@ -33,10 +32,8 @@ func (s setupService) Setup(ctx context.Context, hostUsername, hostPassword stri ...@@ -33,10 +32,8 @@ func (s setupService) Setup(ctx context.Context, hostUsername, hostPassword stri
} }
func (s setupService) makeSureHostUserNotExists(ctx context.Context) error { func (s setupService) makeSureHostUserNotExists(ctx context.Context) error {
hostUserType := api.Host hostUserType := store.RoleHost
existedHostUsers, err := s.store.FindUserList(ctx, &api.UserFind{ existedHostUsers, err := s.store.ListUsers(ctx, &store.FindUser{Role: &hostUserType})
Role: &hostUserType,
})
if err != nil { if err != nil {
return fmt.Errorf("find user list: %w", err) return fmt.Errorf("find user list: %w", err)
} }
...@@ -52,7 +49,7 @@ func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword ...@@ -52,7 +49,7 @@ func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword
userCreate := &store.User{ userCreate := &store.User{
Username: hostUsername, Username: hostUsername,
// The new signup user should be normal user by default. // The new signup user should be normal user by default.
Role: store.Host, Role: store.RoleHost,
Nickname: hostUsername, Nickname: hostUsername,
OpenID: common.GenUUID(), OpenID: common.GenUUID(),
} }
...@@ -87,7 +84,7 @@ func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword ...@@ -87,7 +84,7 @@ func (s setupService) createUser(ctx context.Context, hostUsername, hostPassword
} }
userCreate.PasswordHash = string(passwordHash) userCreate.PasswordHash = string(passwordHash)
if _, err := s.store.CreateUserV1(ctx, userCreate); err != nil { if _, err := s.store.CreateUser(ctx, userCreate); err != nil {
return fmt.Errorf("failed to create user: %w", err) return fmt.Errorf("failed to create user: %w", err)
} }
......
...@@ -2,9 +2,6 @@ package store ...@@ -2,9 +2,6 @@ package store
import ( import (
"context" "context"
"database/sql"
"github.com/usememos/memos/api"
) )
type ActivityMessage struct { type ActivityMessage struct {
...@@ -20,8 +17,8 @@ type ActivityMessage struct { ...@@ -20,8 +17,8 @@ type ActivityMessage struct {
Payload string Payload string
} }
// CreateActivityV1 creates an instance of Activity. // CreateActivity creates an instance of Activity.
func (s *Store) CreateActivityV1(ctx context.Context, create *ActivityMessage) (*ActivityMessage, error) { func (s *Store) CreateActivity(ctx context.Context, create *ActivityMessage) (*ActivityMessage, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, FormatError(err)
...@@ -51,80 +48,3 @@ func (s *Store) CreateActivityV1(ctx context.Context, create *ActivityMessage) ( ...@@ -51,80 +48,3 @@ func (s *Store) CreateActivityV1(ctx context.Context, create *ActivityMessage) (
activityMessage := create activityMessage := create
return activityMessage, nil return activityMessage, nil
} }
// activityRaw is the store model for an Activity.
// Fields have exactly the same meanings as Activity.
type activityRaw struct {
ID int
// Standard fields
CreatorID int
CreatedTs int64
// Domain specific fields
Type api.ActivityType
Level api.ActivityLevel
Payload string
}
// toActivity creates an instance of Activity based on the ActivityRaw.
func (raw *activityRaw) toActivity() *api.Activity {
return &api.Activity{
ID: raw.ID,
CreatorID: raw.CreatorID,
CreatedTs: raw.CreatedTs,
Type: raw.Type,
Level: raw.Level,
Payload: raw.Payload,
}
}
// CreateActivity creates an instance of Activity.
func (s *Store) CreateActivity(ctx context.Context, create *api.ActivityCreate) (*api.Activity, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
activityRaw, err := createActivity(ctx, tx, create)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
activity := activityRaw.toActivity()
return activity, nil
}
// createActivity creates a new activity.
func createActivity(ctx context.Context, tx *sql.Tx, create *api.ActivityCreate) (*activityRaw, error) {
query := `
INSERT INTO activity (
creator_id,
type,
level,
payload
)
VALUES (?, ?, ?, ?)
RETURNING id, type, level, payload, creator_id, created_ts
`
var activityRaw activityRaw
if err := tx.QueryRowContext(ctx, query, create.CreatorID, create.Type, create.Level, create.Payload).Scan(
&activityRaw.ID,
&activityRaw.Type,
&activityRaw.Level,
&activityRaw.Payload,
&activityRaw.CreatorID,
&activityRaw.CreatedTs,
); err != nil {
return nil, FormatError(err)
}
return &activityRaw, nil
}
...@@ -4,6 +4,6 @@ import ( ...@@ -4,6 +4,6 @@ import (
"fmt" "fmt"
) )
func getUserSettingCacheKeyV1(userID int, key string) string { func getUserSettingCacheKey(userID int, key string) string {
return fmt.Sprintf("%d-%s", userID, key) return fmt.Sprintf("%d-%s", userID, key)
} }
...@@ -11,9 +11,13 @@ import ( ...@@ -11,9 +11,13 @@ import (
type IdentityProviderType string type IdentityProviderType string
const ( const (
IdentityProviderOAuth2 IdentityProviderType = "OAUTH2" IdentityProviderOAuth2Type IdentityProviderType = "OAUTH2"
) )
func (t IdentityProviderType) String() string {
return string(t)
}
type IdentityProviderConfig struct { type IdentityProviderConfig struct {
OAuth2Config *IdentityProviderOAuth2Config OAuth2Config *IdentityProviderOAuth2Config
} }
...@@ -66,7 +70,7 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv ...@@ -66,7 +70,7 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv
defer tx.Rollback() defer tx.Rollback()
var configBytes []byte var configBytes []byte
if create.Type == IdentityProviderOAuth2 { if create.Type == IdentityProviderOAuth2Type {
configBytes, err = json.Marshal(create.Config.OAuth2Config) configBytes, err = json.Marshal(create.Config.OAuth2Config)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -167,7 +171,7 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti ...@@ -167,7 +171,7 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti
} }
if v := update.Config; v != nil { if v := update.Config; v != nil {
var configBytes []byte var configBytes []byte
if update.Type == IdentityProviderOAuth2 { if update.Type == IdentityProviderOAuth2Type {
configBytes, err = json.Marshal(update.Config.OAuth2Config) configBytes, err = json.Marshal(update.Config.OAuth2Config)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -197,7 +201,7 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti ...@@ -197,7 +201,7 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti
return nil, err return nil, err
} }
if identityProvider.Type == IdentityProviderOAuth2 { if identityProvider.Type == IdentityProviderOAuth2Type {
oauth2Config := &IdentityProviderOAuth2Config{} oauth2Config := &IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err return nil, err
...@@ -279,7 +283,7 @@ func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityPr ...@@ -279,7 +283,7 @@ func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityPr
return nil, err return nil, err
} }
if identityProvider.Type == IdentityProviderOAuth2 { if identityProvider.Type == IdentityProviderOAuth2Type {
oauth2Config := &IdentityProviderOAuth2Config{} oauth2Config := &IdentityProviderOAuth2Config{}
if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil {
return nil, err return nil, err
......
...@@ -3,20 +3,14 @@ package store ...@@ -3,20 +3,14 @@ package store
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings" "strings"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
) )
// shortcutRaw is the store model for an Shortcut. type Shortcut struct {
// Fields have exactly the same meanings as Shortcut.
type shortcutRaw struct {
ID int ID int
// Standard fields // Standard fields
RowStatus api.RowStatus RowStatus RowStatus
CreatorID int CreatorID int
CreatedTs int64 CreatedTs int64
UpdatedTs int64 UpdatedTs int64
...@@ -26,176 +20,115 @@ type shortcutRaw struct { ...@@ -26,176 +20,115 @@ type shortcutRaw struct {
Payload string Payload string
} }
func (raw *shortcutRaw) toShortcut() *api.Shortcut { type UpdateShortcut struct {
return &api.Shortcut{ ID int
ID: raw.ID,
RowStatus: raw.RowStatus,
CreatorID: raw.CreatorID,
CreatedTs: raw.CreatedTs,
UpdatedTs: raw.UpdatedTs,
Title: raw.Title, UpdatedTs *int64
Payload: raw.Payload, RowStatus *RowStatus
} Title *string
Payload *string
} }
func (s *Store) CreateShortcut(ctx context.Context, create *api.ShortcutCreate) (*api.Shortcut, error) { type FindShortcut struct {
tx, err := s.db.BeginTx(ctx, nil) ID *int
if err != nil { CreatorID *int
return nil, FormatError(err) Title *string
} }
defer tx.Rollback()
shortcutRaw, err := createShortcut(ctx, tx, create)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, FormatError(err)
}
s.shortcutCache.Store(shortcutRaw.ID, shortcutRaw)
shortcut := shortcutRaw.toShortcut()
return shortcut, nil type DeleteShortcut struct {
ID *int
CreatorID *int
} }
func (s *Store) PatchShortcut(ctx context.Context, patch *api.ShortcutPatch) (*api.Shortcut, error) { func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
shortcutRaw, err := patchShortcut(ctx, tx, patch) query := `
if err != nil { INSERT INTO shortcut (
title,
payload,
creator_id
)
VALUES (?, ?, ?)
RETURNING id, created_ts, updated_ts, row_status
`
if err := tx.QueryRowContext(ctx, query, create.Title, create.Payload, create.CreatorID).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err return nil, err
} }
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return nil, FormatError(err) return nil, err
} }
s.shortcutCache.Store(shortcutRaw.ID, shortcutRaw) shortcut := create
shortcut := shortcutRaw.toShortcut()
return shortcut, nil return shortcut, nil
} }
func (s *Store) FindShortcutList(ctx context.Context, find *api.ShortcutFind) ([]*api.Shortcut, error) { func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
shortcutRawList, err := findShortcutList(ctx, tx, find) list, err := listShortcuts(ctx, tx, find)
if err != nil { if err != nil {
return nil, err return nil, err
} }
list := []*api.Shortcut{}
for _, raw := range shortcutRawList {
list = append(list, raw.toShortcut())
}
return list, nil return list, nil
} }
func (s *Store) FindShortcut(ctx context.Context, find *api.ShortcutFind) (*api.Shortcut, error) { func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, error) {
if find.ID != nil {
if shortcut, ok := s.shortcutCache.Load(*find.ID); ok {
return shortcut.(*shortcutRaw).toShortcut(), nil
}
}
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
list, err := findShortcutList(ctx, tx, find) list, err := listShortcuts(ctx, tx, find)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(list) == 0 { if len(list) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} return nil, nil
} }
shortcutRaw := list[0] shortcut := list[0]
s.shortcutCache.Store(shortcutRaw.ID, shortcutRaw)
shortcut := shortcutRaw.toShortcut()
return shortcut, nil return shortcut, nil
} }
func (s *Store) DeleteShortcut(ctx context.Context, delete *api.ShortcutDelete) error { func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
err = deleteShortcut(ctx, tx, delete)
if err != nil {
return FormatError(err)
}
if err := tx.Commit(); err != nil {
return FormatError(err)
}
s.shortcutCache.Delete(*delete.ID)
return nil
}
func createShortcut(ctx context.Context, tx *sql.Tx, create *api.ShortcutCreate) (*shortcutRaw, error) {
query := `
INSERT INTO shortcut (
title,
payload,
creator_id
)
VALUES (?, ?, ?)
RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status
`
var shortcutRaw shortcutRaw
if err := tx.QueryRowContext(ctx, query, create.Title, create.Payload, create.CreatorID).Scan(
&shortcutRaw.ID,
&shortcutRaw.Title,
&shortcutRaw.Payload,
&shortcutRaw.CreatorID,
&shortcutRaw.CreatedTs,
&shortcutRaw.UpdatedTs,
&shortcutRaw.RowStatus,
); err != nil {
return nil, FormatError(err)
}
return &shortcutRaw, nil
}
func patchShortcut(ctx context.Context, tx *sql.Tx, patch *api.ShortcutPatch) (*shortcutRaw, error) {
set, args := []string{}, []any{} set, args := []string{}, []any{}
if v := update.UpdatedTs; v != nil {
if v := patch.UpdatedTs; v != nil {
set, args = append(set, "updated_ts = ?"), append(args, *v) set, args = append(set, "updated_ts = ?"), append(args, *v)
} }
if v := patch.Title; v != nil { if v := update.Title; v != nil {
set, args = append(set, "title = ?"), append(args, *v) set, args = append(set, "title = ?"), append(args, *v)
} }
if v := patch.Payload; v != nil { if v := update.Payload; v != nil {
set, args = append(set, "payload = ?"), append(args, *v) set, args = append(set, "payload = ?"), append(args, *v)
} }
if v := patch.RowStatus; v != nil { if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v) set, args = append(set, "row_status = ?"), append(args, *v)
} }
args = append(args, update.ID)
args = append(args, patch.ID)
query := ` query := `
UPDATE shortcut UPDATE shortcut
...@@ -203,23 +136,55 @@ func patchShortcut(ctx context.Context, tx *sql.Tx, patch *api.ShortcutPatch) (* ...@@ -203,23 +136,55 @@ func patchShortcut(ctx context.Context, tx *sql.Tx, patch *api.ShortcutPatch) (*
WHERE id = ? WHERE id = ?
RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status RETURNING id, title, payload, creator_id, created_ts, updated_ts, row_status
` `
var shortcutRaw shortcutRaw shortcut := &Shortcut{}
if err := tx.QueryRowContext(ctx, query, args...).Scan( if err := tx.QueryRowContext(ctx, query, args...).Scan(
&shortcutRaw.ID, &shortcut.ID,
&shortcutRaw.Title, &shortcut.Title,
&shortcutRaw.Payload, &shortcut.Payload,
&shortcutRaw.CreatorID, &shortcut.CreatorID,
&shortcutRaw.CreatedTs, &shortcut.CreatedTs,
&shortcutRaw.UpdatedTs, &shortcut.UpdatedTs,
&shortcutRaw.RowStatus, &shortcut.RowStatus,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return shortcut, nil
}
func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
where, args := []string{}, []any{}
if v := delete.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := delete.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
stmt := `DELETE FROM shortcut WHERE ` + strings.Join(where, " AND ")
if _, err := tx.ExecContext(ctx, stmt, args...); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err
} }
return &shortcutRaw, nil s.shortcutCache.Delete(*delete.ID)
return nil
} }
func findShortcutList(ctx context.Context, tx *sql.Tx, find *api.ShortcutFind) ([]*shortcutRaw, error) { func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shortcut, error) {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil { if v := find.ID; v != nil {
...@@ -251,53 +216,28 @@ func findShortcutList(ctx context.Context, tx *sql.Tx, find *api.ShortcutFind) ( ...@@ -251,53 +216,28 @@ func findShortcutList(ctx context.Context, tx *sql.Tx, find *api.ShortcutFind) (
} }
defer rows.Close() defer rows.Close()
shortcutRawList := make([]*shortcutRaw, 0) list := make([]*Shortcut, 0)
for rows.Next() { for rows.Next() {
var shortcutRaw shortcutRaw var shortcut Shortcut
if err := rows.Scan( if err := rows.Scan(
&shortcutRaw.ID, &shortcut.ID,
&shortcutRaw.Title, &shortcut.Title,
&shortcutRaw.Payload, &shortcut.Payload,
&shortcutRaw.CreatorID, &shortcut.CreatorID,
&shortcutRaw.CreatedTs, &shortcut.CreatedTs,
&shortcutRaw.UpdatedTs, &shortcut.UpdatedTs,
&shortcutRaw.RowStatus, &shortcut.RowStatus,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
list = append(list, &shortcut)
shortcutRawList = append(shortcutRawList, &shortcutRaw)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
return shortcutRawList, nil return list, nil
}
func deleteShortcut(ctx context.Context, tx *sql.Tx, delete *api.ShortcutDelete) error {
where, args := []string{}, []any{}
if v := delete.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := delete.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
stmt := `DELETE FROM shortcut WHERE ` + strings.Join(where, " AND ")
result, err := tx.ExecContext(ctx, stmt, args...)
if err != nil {
return FormatError(err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("shortcut not found")}
}
return nil
} }
func vacuumShortcut(ctx context.Context, tx *sql.Tx) error { func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {
......
...@@ -12,9 +12,8 @@ import ( ...@@ -12,9 +12,8 @@ import (
type Store struct { type Store struct {
Profile *profile.Profile Profile *profile.Profile
db *sql.DB db *sql.DB
systemSettingCache sync.Map // map[string]*systemSettingRaw systemSettingCache sync.Map // map[string]*SystemSetting
userCache sync.Map // map[int]*userRaw userCache sync.Map // map[int]*User
userV1Cache sync.Map // map[string]*User
userSettingCache sync.Map // map[string]*UserSetting userSettingCache sync.Map // map[string]*UserSetting
shortcutCache sync.Map // map[int]*shortcutRaw shortcutCache sync.Map // map[int]*shortcutRaw
idpCache sync.Map // map[int]*IdentityProvider idpCache sync.Map // map[int]*IdentityProvider
...@@ -36,7 +35,7 @@ func (s *Store) GetDB() *sql.DB { ...@@ -36,7 +35,7 @@ func (s *Store) GetDB() *sql.DB {
func (s *Store) Vacuum(ctx context.Context) error { func (s *Store) Vacuum(ctx context.Context) error {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return FormatError(err) return err
} }
defer tx.Rollback() defer tx.Rollback()
...@@ -45,7 +44,7 @@ func (s *Store) Vacuum(ctx context.Context) error { ...@@ -45,7 +44,7 @@ func (s *Store) Vacuum(ctx context.Context) error {
} }
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return FormatError(err) return err
} }
// Vacuum sqlite database file size after deleting resource. // Vacuum sqlite database file size after deleting resource.
......
...@@ -3,11 +3,7 @@ package store ...@@ -3,11 +3,7 @@ package store
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strings" "strings"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
) )
type SystemSetting struct { type SystemSetting struct {
...@@ -20,10 +16,39 @@ type FindSystemSetting struct { ...@@ -20,10 +16,39 @@ type FindSystemSetting struct {
Name string Name string
} }
func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) (*SystemSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
INSERT INTO system_setting (
name, value, description
)
VALUES (?, ?, ?)
ON CONFLICT(name) DO UPDATE
SET
value = EXCLUDED.value,
description = EXCLUDED.description
`
if _, err := tx.ExecContext(ctx, query, upsert.Name, upsert.Value, upsert.Description); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
systemSetting := upsert
return systemSetting, nil
}
func (s *Store) ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) { func (s *Store) ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
...@@ -47,7 +72,7 @@ func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) ( ...@@ -47,7 +72,7 @@ func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) (
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
...@@ -65,6 +90,15 @@ func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) ( ...@@ -65,6 +90,15 @@ func (s *Store) GetSystemSetting(ctx context.Context, find *FindSystemSetting) (
return systemSettingMessage, nil return systemSettingMessage, nil
} }
func (s *Store) GetSystemSettingValueWithDefault(ctx *context.Context, settingName string, defaultValue string) string {
if setting, err := s.GetSystemSetting(*ctx, &FindSystemSetting{
Name: settingName,
}); err == nil && setting != nil {
return setting.Value
}
return defaultValue
}
func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting) ([]*SystemSetting, error) { func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting) ([]*SystemSetting, error) {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if find.Name != "" { if find.Name != "" {
...@@ -81,7 +115,7 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting ...@@ -81,7 +115,7 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting
rows, err := tx.QueryContext(ctx, query, args...) rows, err := tx.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer rows.Close() defer rows.Close()
...@@ -93,7 +127,7 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting ...@@ -93,7 +127,7 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting
&systemSettingMessage.Value, &systemSettingMessage.Value,
&systemSettingMessage.Description, &systemSettingMessage.Description,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, err
} }
list = append(list, systemSettingMessage) list = append(list, systemSettingMessage)
} }
...@@ -104,160 +138,3 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting ...@@ -104,160 +138,3 @@ func listSystemSettings(ctx context.Context, tx *sql.Tx, find *FindSystemSetting
return list, nil return list, nil
} }
type systemSettingRaw struct {
Name api.SystemSettingName
Value string
Description string
}
func (raw *systemSettingRaw) toSystemSetting() *api.SystemSetting {
return &api.SystemSetting{
Name: raw.Name,
Value: raw.Value,
Description: raw.Description,
}
}
func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *api.SystemSettingUpsert) (*api.SystemSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
systemSettingRaw, err := upsertSystemSetting(ctx, tx, upsert)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
systemSetting := systemSettingRaw.toSystemSetting()
s.systemSettingCache.Store(systemSettingRaw.Name, systemSettingRaw)
return systemSetting, nil
}
func (s *Store) FindSystemSettingList(ctx context.Context, find *api.SystemSettingFind) ([]*api.SystemSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
systemSettingRawList, err := findSystemSettingList(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
list := []*api.SystemSetting{}
for _, raw := range systemSettingRawList {
s.systemSettingCache.Store(raw.Name, raw)
list = append(list, raw.toSystemSetting())
}
return list, nil
}
func (s *Store) FindSystemSetting(ctx context.Context, find *api.SystemSettingFind) (*api.SystemSetting, error) {
if systemSetting, ok := s.systemSettingCache.Load(find.Name); ok {
systemSettingRaw := systemSetting.(*systemSettingRaw)
return systemSettingRaw.toSystemSetting(), nil
}
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, FormatError(err)
}
defer tx.Rollback()
systemSettingRawList, err := findSystemSettingList(ctx, tx, find)
if err != nil {
return nil, err
}
if len(systemSettingRawList) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")}
}
systemSettingRaw := systemSettingRawList[0]
s.systemSettingCache.Store(systemSettingRaw.Name, systemSettingRaw)
return systemSettingRaw.toSystemSetting(), nil
}
func (s *Store) GetSystemSettingValueOrDefault(ctx *context.Context, find api.SystemSettingName, defaultValue string) string {
if setting, err := s.FindSystemSetting(*ctx, &api.SystemSettingFind{
Name: find,
}); err == nil {
return setting.Value
}
return defaultValue
}
func upsertSystemSetting(ctx context.Context, tx *sql.Tx, upsert *api.SystemSettingUpsert) (*systemSettingRaw, error) {
query := `
INSERT INTO system_setting (
name, value, description
)
VALUES (?, ?, ?)
ON CONFLICT(name) DO UPDATE
SET
value = EXCLUDED.value,
description = EXCLUDED.description
RETURNING name, value, description
`
var systemSettingRaw systemSettingRaw
if err := tx.QueryRowContext(ctx, query, upsert.Name, upsert.Value, upsert.Description).Scan(
&systemSettingRaw.Name,
&systemSettingRaw.Value,
&systemSettingRaw.Description,
); err != nil {
return nil, FormatError(err)
}
return &systemSettingRaw, nil
}
func findSystemSettingList(ctx context.Context, tx *sql.Tx, find *api.SystemSettingFind) ([]*systemSettingRaw, error) {
where, args := []string{"1 = 1"}, []any{}
if find.Name.String() != "" {
where, args = append(where, "name = ?"), append(args, find.Name.String())
}
query := `
SELECT
name,
value,
description
FROM system_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, FormatError(err)
}
defer rows.Close()
systemSettingRawList := make([]*systemSettingRaw, 0)
for rows.Next() {
var systemSettingRaw systemSettingRaw
if err := rows.Scan(
&systemSettingRaw.Name,
&systemSettingRaw.Value,
&systemSettingRaw.Description,
); err != nil {
return nil, FormatError(err)
}
systemSettingRawList = append(systemSettingRawList, &systemSettingRaw)
}
if err := rows.Err(); err != nil {
return nil, FormatError(err)
}
return systemSettingRawList, nil
}
...@@ -5,32 +5,39 @@ import ( ...@@ -5,32 +5,39 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"strings" "strings"
"github.com/usememos/memos/api"
"github.com/usememos/memos/common"
) )
type tagRaw struct { type Tag struct {
Name string Name string
CreatorID int CreatorID int
} }
func (raw *tagRaw) toTag() *api.Tag { type FindTag struct {
return &api.Tag{ CreatorID int
Name: raw.Name,
CreatorID: raw.CreatorID,
}
} }
func (s *Store) UpsertTag(ctx context.Context, upsert *api.TagUpsert) (*api.Tag, error) { type DeleteTag struct {
Name string
CreatorID int
}
func (s *Store) UpsertTagV1(ctx context.Context, upsert *Tag) (*Tag, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
defer tx.Rollback() defer tx.Rollback()
tagRaw, err := upsertTag(ctx, tx, upsert) query := `
if err != nil { INSERT INTO tag (
name, creator_id
)
VALUES (?, ?)
ON CONFLICT(name, creator_id) DO UPDATE
SET
name = EXCLUDED.name
`
if _, err := tx.ExecContext(ctx, query, upsert.Name, upsert.CreatorID); err != nil {
return nil, err return nil, err
} }
...@@ -38,74 +45,18 @@ func (s *Store) UpsertTag(ctx context.Context, upsert *api.TagUpsert) (*api.Tag, ...@@ -38,74 +45,18 @@ func (s *Store) UpsertTag(ctx context.Context, upsert *api.TagUpsert) (*api.Tag,
return nil, err return nil, err
} }
tag := tagRaw.toTag() tag := upsert
return tag, nil return tag, nil
} }
func (s *Store) FindTagList(ctx context.Context, find *api.TagFind) ([]*api.Tag, error) { func (s *Store) ListTags(ctx context.Context, find *FindTag) ([]*Tag, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
defer tx.Rollback() defer tx.Rollback()
tagRawList, err := findTagList(ctx, tx, find)
if err != nil {
return nil, err
}
list := []*api.Tag{}
for _, raw := range tagRawList {
list = append(list, raw.toTag())
}
return list, nil
}
func (s *Store) DeleteTag(ctx context.Context, delete *api.TagDelete) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
}
defer tx.Rollback()
if err := deleteTag(ctx, tx, delete); err != nil {
return FormatError(err)
}
if err := tx.Commit(); err != nil {
return FormatError(err)
}
return nil
}
func upsertTag(ctx context.Context, tx *sql.Tx, upsert *api.TagUpsert) (*tagRaw, error) {
query := `
INSERT INTO tag (
name, creator_id
)
VALUES (?, ?)
ON CONFLICT(name, creator_id) DO UPDATE
SET
name = EXCLUDED.name
RETURNING name, creator_id
`
var tagRaw tagRaw
if err := tx.QueryRowContext(ctx, query, upsert.Name, upsert.CreatorID).Scan(
&tagRaw.Name,
&tagRaw.CreatorID,
); err != nil {
return nil, FormatError(err)
}
return &tagRaw, nil
}
func findTagList(ctx context.Context, tx *sql.Tx, find *api.TagFind) ([]*tagRaw, error) {
where, args := []string{"creator_id = ?"}, []any{find.CreatorID} where, args := []string{"creator_id = ?"}, []any{find.CreatorID}
query := ` query := `
SELECT SELECT
name, name,
...@@ -120,38 +71,48 @@ func findTagList(ctx context.Context, tx *sql.Tx, find *api.TagFind) ([]*tagRaw, ...@@ -120,38 +71,48 @@ func findTagList(ctx context.Context, tx *sql.Tx, find *api.TagFind) ([]*tagRaw,
} }
defer rows.Close() defer rows.Close()
tagRawList := make([]*tagRaw, 0) list := []*Tag{}
for rows.Next() { for rows.Next() {
var tagRaw tagRaw tag := &Tag{}
if err := rows.Scan( if err := rows.Scan(
&tagRaw.Name, &tag.Name,
&tagRaw.CreatorID, &tag.CreatorID,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
tagRawList = append(tagRawList, &tagRaw) list = append(list, tag)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, FormatError(err) return nil, FormatError(err)
} }
return tagRawList, nil return list, nil
} }
func deleteTag(ctx context.Context, tx *sql.Tx, delete *api.TagDelete) error { func (s *Store) DeleteTag(ctx context.Context, delete *DeleteTag) error {
where, args := []string{"name = ?", "creator_id = ?"}, []any{delete.Name, delete.CreatorID} tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return FormatError(err)
}
defer tx.Rollback()
stmt := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ") where, args := []string{"name = ?", "creator_id = ?"}, []any{delete.Name, delete.CreatorID}
result, err := tx.ExecContext(ctx, stmt, args...) query := `DELETE FROM tag WHERE ` + strings.Join(where, " AND ")
result, err := tx.ExecContext(ctx, query, args...)
if err != nil { if err != nil {
return FormatError(err) return FormatError(err)
} }
rows, _ := result.RowsAffected() rows, _ := result.RowsAffected()
if rows == 0 { if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("tag not found")} return fmt.Errorf("tag not found")
}
if err := tx.Commit(); err != nil {
// Prevent linter warning.
return err
} }
return nil return nil
......
This diff is collapsed.
...@@ -20,7 +20,7 @@ type FindUserSetting struct { ...@@ -20,7 +20,7 @@ type FindUserSetting struct {
func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) { func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
...@@ -41,14 +41,14 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*Us ...@@ -41,14 +41,14 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*Us
} }
userSetting := upsert userSetting := upsert
s.userSettingCache.Store(getUserSettingCacheKeyV1(userSetting.UserID, userSetting.Key), userSetting) s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
return userSetting, nil return userSetting, nil
} }
func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) { func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) {
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
...@@ -58,21 +58,21 @@ func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([] ...@@ -58,21 +58,21 @@ func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]
} }
for _, userSetting := range userSettingList { for _, userSetting := range userSettingList {
s.userSettingCache.Store(getUserSettingCacheKeyV1(userSetting.UserID, userSetting.Key), userSetting) s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
} }
return userSettingList, nil return userSettingList, nil
} }
func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) { func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) {
if find.UserID != nil { if find.UserID != nil {
if cache, ok := s.userSettingCache.Load(getUserSettingCacheKeyV1(*find.UserID, find.Key)); ok { if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key)); ok {
return cache.(*UserSetting), nil return cache.(*UserSetting), nil
} }
} }
tx, err := s.db.BeginTx(ctx, nil) tx, err := s.db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
...@@ -84,8 +84,9 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*Use ...@@ -84,8 +84,9 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*Use
if len(list) == 0 { if len(list) == 0 {
return nil, nil return nil, nil
} }
userSetting := list[0] userSetting := list[0]
s.userSettingCache.Store(getUserSettingCacheKeyV1(userSetting.UserID, userSetting.Key), userSetting) s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting)
return userSetting, nil return userSetting, nil
} }
...@@ -108,7 +109,7 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([ ...@@ -108,7 +109,7 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([
WHERE ` + strings.Join(where, " AND ") WHERE ` + strings.Join(where, " AND ")
rows, err := tx.QueryContext(ctx, query, args...) rows, err := tx.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, FormatError(err) return nil, err
} }
defer rows.Close() defer rows.Close()
...@@ -120,13 +121,13 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([ ...@@ -120,13 +121,13 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([
&userSetting.Key, &userSetting.Key,
&userSetting.Value, &userSetting.Value,
); err != nil { ); err != nil {
return nil, FormatError(err) return nil, err
} }
userSettingList = append(userSettingList, &userSetting) userSettingList = append(userSettingList, &userSetting)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, FormatError(err) return nil, err
} }
return userSettingList, nil return userSettingList, nil
...@@ -145,7 +146,7 @@ func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { ...@@ -145,7 +146,7 @@ func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
)` )`
_, err := tx.ExecContext(ctx, stmt) _, err := tx.ExecContext(ctx, stmt)
if err != nil { if err != nil {
return FormatError(err) return err
} }
return nil return nil
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/usememos/memos/api"
apiv1 "github.com/usememos/memos/api/v1" apiv1 "github.com/usememos/memos/api/v1"
) )
...@@ -27,7 +26,7 @@ func TestAuthServer(t *testing.T) { ...@@ -27,7 +26,7 @@ func TestAuthServer(t *testing.T) {
require.Equal(t, signup.Username, user.Username) require.Equal(t, signup.Username, user.Username)
} }
func (s *TestingServer) postAuthSignup(signup *apiv1.SignUp) (*api.User, error) { func (s *TestingServer) postAuthSignup(signup *apiv1.SignUp) (*apiv1.User, error) {
rawData, err := json.Marshal(&signup) rawData, err := json.Marshal(&signup)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to marshal signup") return nil, errors.Wrap(err, "failed to marshal signup")
...@@ -44,7 +43,7 @@ func (s *TestingServer) postAuthSignup(signup *apiv1.SignUp) (*api.User, error) ...@@ -44,7 +43,7 @@ func (s *TestingServer) postAuthSignup(signup *apiv1.SignUp) (*api.User, error)
return nil, errors.Wrap(err, "fail to read response body") return nil, errors.Wrap(err, "fail to read response body")
} }
user := &api.User{} user := &apiv1.User{}
if err = json.Unmarshal(buf.Bytes(), user); err != nil { if err = json.Unmarshal(buf.Bytes(), user); err != nil {
return nil, errors.Wrap(err, "fail to unmarshal post signup response") return nil, errors.Wrap(err, "fail to unmarshal post signup response")
} }
......
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/usememos/memos/api"
apiv1 "github.com/usememos/memos/api/v1" apiv1 "github.com/usememos/memos/api/v1"
) )
...@@ -20,7 +19,7 @@ func TestSystemServer(t *testing.T) { ...@@ -20,7 +19,7 @@ func TestSystemServer(t *testing.T) {
status, err := s.getSystemStatus() status, err := s.getSystemStatus()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, (*api.User)(nil), status.Host) require.Equal(t, (*apiv1.User)(nil), status.Host)
signup := &apiv1.SignUp{ signup := &apiv1.SignUp{
Username: "testuser", Username: "testuser",
...@@ -36,8 +35,8 @@ func TestSystemServer(t *testing.T) { ...@@ -36,8 +35,8 @@ func TestSystemServer(t *testing.T) {
require.Equal(t, user.Username, status.Host.Username) require.Equal(t, user.Username, status.Host.Username)
} }
func (s *TestingServer) getSystemStatus() (*api.SystemStatus, error) { func (s *TestingServer) getSystemStatus() (*apiv1.SystemStatus, error) {
body, err := s.get("/api/status", nil) body, err := s.get("/api/v1/status", nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -48,12 +47,9 @@ func (s *TestingServer) getSystemStatus() (*api.SystemStatus, error) { ...@@ -48,12 +47,9 @@ func (s *TestingServer) getSystemStatus() (*api.SystemStatus, error) {
return nil, errors.Wrap(err, "fail to read response body") return nil, errors.Wrap(err, "fail to read response body")
} }
type SystemStatusResponse struct { systemStatus := &apiv1.SystemStatus{}
Data *api.SystemStatus `json:"data"` if err = json.Unmarshal(buf.Bytes(), systemStatus); err != nil {
}
res := new(SystemStatusResponse)
if err = json.Unmarshal(buf.Bytes(), res); err != nil {
return nil, errors.Wrap(err, "fail to unmarshal get system status response") return nil, errors.Wrap(err, "fail to unmarshal get system status response")
} }
return res.Data, nil return systemStatus, nil
} }
...@@ -14,7 +14,7 @@ func TestIdentityProviderStore(t *testing.T) { ...@@ -14,7 +14,7 @@ func TestIdentityProviderStore(t *testing.T) {
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProvider{ createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProvider{
Name: "GitHub OAuth", Name: "GitHub OAuth",
Type: store.IdentityProviderOAuth2, Type: store.IdentityProviderOAuth2Type,
IdentifierFilter: "", IdentifierFilter: "",
Config: &store.IdentityProviderConfig{ Config: &store.IdentityProviderConfig{
OAuth2Config: &store.IdentityProviderOAuth2Config{ OAuth2Config: &store.IdentityProviderOAuth2Config{
......
...@@ -6,30 +6,30 @@ import ( ...@@ -6,30 +6,30 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/usememos/memos/api" apiv1 "github.com/usememos/memos/api/v1"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
func TestSystemSettingStore(t *testing.T) { func TestSystemSettingStore(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
_, err := ts.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ _, err := ts.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: api.SystemSettingServerIDName, Name: apiv1.SystemSettingServerIDName.String(),
Value: "test_server_id", Value: "test_server_id",
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = ts.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ _, err = ts.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: api.SystemSettingSecretSessionName, Name: apiv1.SystemSettingSecretSessionName.String(),
Value: "test_secret_session_name", Value: "test_secret_session_name",
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = ts.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ _, err = ts.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: api.SystemSettingAllowSignUpName, Name: apiv1.SystemSettingAllowSignUpName.String(),
Value: "true", Value: "true",
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = ts.UpsertSystemSetting(ctx, &api.SystemSettingUpsert{ _, err = ts.UpsertSystemSetting(ctx, &store.SystemSetting{
Name: api.SystemSettingLocalStoragePathName, Name: apiv1.SystemSettingLocalStoragePathName.String(),
Value: "/tmp/memos", Value: "/tmp/memos",
}) })
require.NoError(t, err) require.NoError(t, err)
......
...@@ -13,15 +13,21 @@ func TestUserSettingStore(t *testing.T) { ...@@ -13,15 +13,21 @@ func TestUserSettingStore(t *testing.T) {
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
user, err := createTestingHostUser(ctx, ts) user, err := createTestingHostUser(ctx, ts)
require.NoError(t, err) require.NoError(t, err)
_, err = ts.UpsertUserSetting(ctx, &store.UserSetting{ testSetting, err := ts.UpsertUserSetting(ctx, &store.UserSetting{
UserID: user.ID, UserID: user.ID,
Key: "test_key", Key: "test_key",
Value: "test_value", Value: "test_value",
}) })
require.NoError(t, err) require.NoError(t, err)
localeSetting, err := ts.UpsertUserSetting(ctx, &store.UserSetting{
UserID: user.ID,
Key: "locale",
Value: "zh",
})
require.NoError(t, err)
list, err := ts.ListUserSettings(ctx, &store.FindUserSetting{}) list, err := ts.ListUserSettings(ctx, &store.FindUserSetting{})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(list)) require.Equal(t, 2, len(list))
require.Equal(t, "test_key", list[0].Key) require.Equal(t, testSetting, list[0])
require.Equal(t, "test_value", list[0].Value) require.Equal(t, localeSetting, list[1])
} }
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/usememos/memos/api"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
...@@ -18,7 +17,7 @@ func TestUserStore(t *testing.T) { ...@@ -18,7 +17,7 @@ func TestUserStore(t *testing.T) {
users, err := ts.ListUsers(ctx, &store.FindUser{}) users, err := ts.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, len(users)) require.Equal(t, 1, len(users))
require.Equal(t, store.Host, users[0].Role) require.Equal(t, store.RoleHost, users[0].Role)
require.Equal(t, user, users[0]) require.Equal(t, user, users[0])
userPatchNickname := "test_nickname_2" userPatchNickname := "test_nickname_2"
userPatch := &store.UpdateUser{ userPatch := &store.UpdateUser{
...@@ -28,7 +27,7 @@ func TestUserStore(t *testing.T) { ...@@ -28,7 +27,7 @@ func TestUserStore(t *testing.T) {
user, err = ts.UpdateUser(ctx, userPatch) user, err = ts.UpdateUser(ctx, userPatch)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, userPatchNickname, user.Nickname) require.Equal(t, userPatchNickname, user.Nickname)
err = ts.DeleteUser(ctx, &api.UserDelete{ err = ts.DeleteUser(ctx, &store.DeleteUser{
ID: user.ID, ID: user.ID,
}) })
require.NoError(t, err) require.NoError(t, err)
...@@ -40,7 +39,7 @@ func TestUserStore(t *testing.T) { ...@@ -40,7 +39,7 @@ func TestUserStore(t *testing.T) {
func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, error) { func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, error) {
userCreate := &store.User{ userCreate := &store.User{
Username: "test", Username: "test",
Role: store.Host, Role: store.RoleHost,
Email: "test@test.com", Email: "test@test.com",
Nickname: "test_nickname", Nickname: "test_nickname",
OpenID: "test_open_id", OpenID: "test_open_id",
...@@ -50,6 +49,6 @@ func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, e ...@@ -50,6 +49,6 @@ func createTestingHostUser(ctx context.Context, ts *store.Store) (*store.User, e
return nil, err return nil, err
} }
userCreate.PasswordHash = string(passwordHash) userCreate.PasswordHash = string(passwordHash)
user, err := ts.CreateUserV1(ctx, userCreate) user, err := ts.CreateUser(ctx, userCreate)
return user, err return user, err
} }
...@@ -31,7 +31,7 @@ const CreateTagDialog: React.FC<Props> = (props: Props) => { ...@@ -31,7 +31,7 @@ const CreateTagDialog: React.FC<Props> = (props: Props) => {
useEffect(() => { useEffect(() => {
getTagSuggestionList().then(({ data }) => { getTagSuggestionList().then(({ data }) => {
setSuggestTagNameList(data.data.filter((tag) => validateTagName(tag))); setSuggestTagNameList(data.filter((tag) => validateTagName(tag)));
}); });
}, [tagNameList]); }, [tagNameList]);
......
...@@ -29,7 +29,7 @@ const PreferencesSection = () => { ...@@ -29,7 +29,7 @@ const PreferencesSection = () => {
}, []); }, []);
const fetchUserList = async () => { const fetchUserList = async () => {
const { data } = (await api.getUserList()).data; const { data } = await api.getUserList();
setUserList(data); setUserList(data);
}; };
......
...@@ -39,7 +39,7 @@ const SystemSection = () => { ...@@ -39,7 +39,7 @@ const SystemSection = () => {
}, []); }, []);
useEffect(() => { useEffect(() => {
api.getSystemSetting().then(({ data: { data: systemSettings } }) => { api.getSystemSetting().then(({ data: systemSettings }) => {
const telegramBotSetting = systemSettings.find((setting) => setting.name === "telegram-bot-token"); const telegramBotSetting = systemSettings.find((setting) => setting.name === "telegram-bot-token");
if (telegramBotSetting) { if (telegramBotSetting) {
setTelegramBotToken(telegramBotSetting.value); setTelegramBotToken(telegramBotSetting.value);
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment