Commit 803d488a authored by Johnny's avatar Johnny

feat(mcp): refactor MCP server to standard protocol structure

- Replace PAT-only auth with optional auth supporting both PAT and JWT
  via auth.Authenticator.Authenticate(); unauthenticated requests see
  only public memos, matching REST API visibility semantics
- Inline auth middleware into mcp.go following fileserver pattern;
  remove auth_middleware.go
- Introduce memoJSON response type that correctly serialises store.Memo
  (including Payload.Tags and Payload.Property) without proto marshalling
- Add tools: list_memo_comments, create_memo_comment, list_tags
- Extend list_memos with state (NORMAL/ARCHIVED), order_by_pinned, and
  page parameters
- Extend update_memo with pinned and state parameters
- Extract #tags from content on create/update via regex to pre-populate
  Payload.Tags without requiring a full markdown service rebuild
- Add MCP Resources: memo://memos/{uid} template returns memo as
  Markdown with YAML frontmatter, allowing clients to read memos by URI
- Add MCP Prompts: capture (save a thought) and review (search + summarise)
parent 16576be1
package mcp
import (
"net/http"
"github.com/labstack/echo/v5"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
func newAuthMiddleware(s *store.Store, secret string) echo.MiddlewareFunc {
authenticator := auth.NewAuthenticator(s, secret)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
token := auth.ExtractBearerToken(c.Request().Header.Get("Authorization"))
if token == "" {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "a personal access token is required"})
}
user, pat, err := authenticator.AuthenticateByPAT(c.Request().Context(), token)
if err != nil || user == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"})
}
ctx := auth.SetUserInContext(c.Request().Context(), user, pat.GetTokenId())
c.SetRequest(c.Request().WithContext(ctx))
return next(c)
}
}
}
package mcp package mcp
import ( import (
"net/http"
"github.com/labstack/echo/v5" "github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware" "github.com/labstack/echo/v5/middleware"
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
type MCPService struct { type MCPService struct {
store *store.Store store *store.Store
secret string authenticator *auth.Authenticator
} }
func NewMCPService(store *store.Store, secret string) *MCPService { func NewMCPService(store *store.Store, secret string) *MCPService {
return &MCPService{store: store, secret: secret} return &MCPService{
store: store,
authenticator: auth.NewAuthenticator(store, secret),
}
} }
func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpSrv := mcpserver.NewMCPServer("Memos", "1.0.0", mcpserver.WithToolCapabilities(false)) mcpSrv := mcpserver.NewMCPServer("Memos", "1.0.0",
mcpserver.WithToolCapabilities(false),
)
s.registerMemoTools(mcpSrv) s.registerMemoTools(mcpSrv)
s.registerTagTools(mcpSrv)
s.registerMemoResources(mcpSrv)
s.registerPrompts(mcpSrv)
httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv) httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv)
...@@ -27,6 +38,19 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { ...@@ -27,6 +38,19 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{ mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"}, AllowOrigins: []string{"*"},
})) }))
mcpGroup.Use(newAuthMiddleware(s.store, s.secret)) mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
authHeader := c.Request().Header.Get("Authorization")
if authHeader != "" {
result := s.authenticator.Authenticate(c.Request().Context(), authHeader)
if result == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired token"})
}
ctx := auth.ApplyToContext(c.Request().Context(), result)
c.SetRequest(c.Request().WithContext(ctx))
}
return next(c)
}
})
mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler)) mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler))
} }
package mcp
import (
"context"
"errors"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
)
func (s *MCPService) registerPrompts(mcpSrv *mcpserver.MCPServer) {
// capture — turns free-form user input into a structured create_memo call.
mcpSrv.AddPrompt(
mcp.NewPrompt("capture",
mcp.WithPromptDescription("Capture a thought, idea, or note as a new memo. "+
"Use this prompt when the user wants to quickly save something. "+
"The assistant will call create_memo with the provided content."),
mcp.WithArgument("content",
mcp.ArgumentDescription("The text to save as a memo"),
mcp.RequiredArgument(),
),
mcp.WithArgument("tags",
mcp.ArgumentDescription("Comma-separated tags to apply, e.g. \"work,project\""),
),
),
s.handleCapturePrompt,
)
// review — surfaces existing memos on a topic for summarisation.
mcpSrv.AddPrompt(
mcp.NewPrompt("review",
mcp.WithPromptDescription("Search and review memos on a given topic. "+
"The assistant will call search_memos and summarise the results."),
mcp.WithArgument("topic",
mcp.ArgumentDescription("Topic or keyword to search for"),
mcp.RequiredArgument(),
),
),
s.handleReviewPrompt,
)
}
func (*MCPService) handleCapturePrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
content := req.Params.Arguments["content"]
if content == "" {
return nil, errors.New("content argument is required")
}
tags := req.Params.Arguments["tags"]
instruction := fmt.Sprintf(
"Please save the following as a new private memo using the create_memo tool.\n\nContent:\n%s",
content,
)
if tags != "" {
instruction += fmt.Sprintf("\n\nAppend these tags inline using #tag syntax: %s", tags)
}
return &mcp.GetPromptResult{
Description: "Capture a memo",
Messages: []mcp.PromptMessage{
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
},
}, nil
}
func (*MCPService) handleReviewPrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
topic := req.Params.Arguments["topic"]
if topic == "" {
return nil, errors.New("topic argument is required")
}
instruction := fmt.Sprintf(
"Please use the search_memos tool to find memos about %q, then provide a concise summary of what has been written on this topic, grouped by theme. Include the memo names so the user can reference them.",
topic,
)
return &mcp.GetPromptResult{
Description: fmt.Sprintf("Review memos about %q", topic),
Messages: []mcp.PromptMessage{
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
},
}, nil
}
package mcp
import (
"context"
"fmt"
"strings"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// Memo resource URI scheme: memo://memos/{uid}
// Clients can read any memo they have access to by URI without calling a tool.
func (s *MCPService) registerMemoResources(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddResourceTemplate(
mcp.NewResourceTemplate(
"memo://memos/{uid}",
"Memo",
mcp.WithTemplateDescription("A single Memos note identified by its UID. Returns the memo content as Markdown with a YAML frontmatter header containing metadata."),
mcp.WithTemplateMIMEType("text/markdown"),
),
s.handleReadMemoResource,
)
}
func (s *MCPService) handleReadMemoResource(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
userID := auth.GetUserID(ctx)
// URI format: memo://memos/{uid}
uid := strings.TrimPrefix(req.Params.URI, "memo://memos/")
if uid == req.Params.URI || uid == "" {
return nil, errors.Errorf("invalid memo URI %q: expected memo://memos/<uid>", req.Params.URI)
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo")
}
if memo == nil {
return nil, errors.Errorf("memo not found: %s", uid)
}
if err := checkMemoAccess(memo, userID); err != nil {
return nil, err
}
j := storeMemoToJSON(memo)
text := formatMemoMarkdown(j)
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: req.Params.URI,
MIMEType: "text/markdown",
Text: text,
},
}, nil
}
// formatMemoMarkdown renders a memo as Markdown with a YAML frontmatter header.
func formatMemoMarkdown(j memoJSON) string {
var sb strings.Builder
sb.WriteString("---\n")
fmt.Fprintf(&sb, "name: %s\n", j.Name)
fmt.Fprintf(&sb, "creator: %s\n", j.Creator)
fmt.Fprintf(&sb, "visibility: %s\n", j.Visibility)
fmt.Fprintf(&sb, "state: %s\n", j.State)
fmt.Fprintf(&sb, "pinned: %v\n", j.Pinned)
if len(j.Tags) > 0 {
fmt.Fprintf(&sb, "tags: [%s]\n", strings.Join(j.Tags, ", "))
}
fmt.Fprintf(&sb, "create_time: %d\n", j.CreateTime)
fmt.Fprintf(&sb, "update_time: %d\n", j.UpdateTime)
if j.Parent != "" {
fmt.Fprintf(&sb, "parent: %s\n", j.Parent)
}
sb.WriteString("---\n\n")
sb.WriteString(j.Content)
return sb.String()
}
...@@ -3,22 +3,166 @@ package mcp ...@@ -3,22 +3,166 @@ package mcp
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"regexp"
"strings" "strings"
"github.com/lithammer/shortuuid/v4" "github.com/lithammer/shortuuid/v4"
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
// tagRegexp matches #tag patterns in memo content.
// A tag must start with a letter and contain no whitespace or # characters.
var tagRegexp = regexp.MustCompile(`(?:^|\s)#([A-Za-z][^\s#]*)`)
// extractTags does a best-effort extraction of #tags from raw markdown content.
// It is used when creating or updating memos via MCP to pre-populate Payload.Tags.
// The full markdown service may later rebuild a more accurate payload.
func extractTags(content string) []string {
matches := tagRegexp.FindAllStringSubmatch(content, -1)
seen := make(map[string]struct{}, len(matches))
tags := make([]string, 0, len(matches))
for _, m := range matches {
tag := m[1]
if _, ok := seen[tag]; !ok {
seen[tag] = struct{}{}
tags = append(tags, tag)
}
}
return tags
}
// buildPayload constructs a MemoPayload with tags extracted from content.
// Returns nil when no tags are found so the store omits the payload entirely.
func buildPayload(content string) *storepb.MemoPayload {
tags := extractTags(content)
if len(tags) == 0 {
return nil
}
return &storepb.MemoPayload{Tags: tags}
}
// propertyJSON is the serialisable form of MemoPayload.Property.
type propertyJSON struct {
HasLink bool `json:"has_link"`
HasTaskList bool `json:"has_task_list"`
HasCode bool `json:"has_code"`
HasIncompleteTasks bool `json:"has_incomplete_tasks"`
}
// memoJSON is the canonical response shape for all MCP memo results.
// It serialises correctly with standard encoding/json (no proto marshalling needed).
type memoJSON struct {
Name string `json:"name"`
Creator string `json:"creator"`
CreateTime int64 `json:"create_time"`
UpdateTime int64 `json:"update_time"`
Content string `json:"content,omitempty"`
Visibility string `json:"visibility"`
Tags []string `json:"tags"`
Pinned bool `json:"pinned"`
State string `json:"state"`
Property *propertyJSON `json:"property,omitempty"`
Parent string `json:"parent,omitempty"`
}
func storeMemoToJSON(m *store.Memo) memoJSON {
j := memoJSON{
Name: "memos/" + m.UID,
Creator: fmt.Sprintf("users/%d", m.CreatorID),
CreateTime: m.CreatedTs,
UpdateTime: m.UpdatedTs,
Content: m.Content,
Visibility: string(m.Visibility),
Pinned: m.Pinned,
State: string(m.RowStatus),
Tags: []string{},
}
if m.Payload != nil {
if len(m.Payload.Tags) > 0 {
j.Tags = m.Payload.Tags
}
if p := m.Payload.Property; p != nil && (p.HasLink || p.HasTaskList || p.HasCode || p.HasIncompleteTasks) {
j.Property = &propertyJSON{
HasLink: p.HasLink,
HasTaskList: p.HasTaskList,
HasCode: p.HasCode,
HasIncompleteTasks: p.HasIncompleteTasks,
}
}
}
if m.ParentUID != nil {
j.Parent = "memos/" + *m.ParentUID
}
return j
}
// checkMemoAccess returns an error if the caller cannot read memo.
// userID == 0 means anonymous.
func checkMemoAccess(memo *store.Memo, userID int32) error {
switch memo.Visibility {
case store.Protected:
if userID == 0 {
return errors.New("permission denied")
}
case store.Private:
if memo.CreatorID != userID {
return errors.New("permission denied")
}
default:
// store.Public and any unknown visibility: allow
}
return nil
}
// applyVisibilityFilter restricts find to memos the caller may see.
func applyVisibilityFilter(find *store.FindMemo, userID int32) {
if userID == 0 {
find.VisibilityList = []store.Visibility{store.Public}
} else {
find.Filters = append(find.Filters, fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, userID))
}
}
// parseMemoUID extracts the UID from a "memos/<uid>" resource name.
func parseMemoUID(name string) (string, error) {
uid, ok := strings.CutPrefix(name, "memos/")
if !ok || uid == "" {
return "", errors.Errorf(`memo name must be in the format "memos/<uid>", got %q`, name)
}
return uid, nil
}
// parseVisibility validates a visibility string and returns the store constant.
func parseVisibility(s string) (store.Visibility, error) {
switch v := store.Visibility(s); v {
case store.Public, store.Protected, store.Private:
return v, nil
default:
return "", errors.Errorf("visibility must be PRIVATE, PROTECTED, or PUBLIC; got %q", s)
}
}
// parseRowStatus validates a state string and returns the store constant.
func parseRowStatus(s string) (store.RowStatus, error) {
switch rs := store.RowStatus(s); rs {
case store.Normal, store.Archived:
return rs, nil
default:
return "", errors.Errorf("state must be NORMAL or ARCHIVED; got %q", s)
}
}
func extractUserID(ctx context.Context) (int32, error) { func extractUserID(ctx context.Context) (int32, error) {
id := auth.GetUserID(ctx) id := auth.GetUserID(ctx)
if id == 0 { if id == 0 {
return 0, errors.New("unauthenticated") return 0, errors.New("unauthenticated: a personal access token is required")
} }
return id, nil return id, nil
} }
...@@ -32,58 +176,71 @@ func marshalJSON(v any) (string, error) { ...@@ -32,58 +176,71 @@ func marshalJSON(v any) (string, error) {
} }
func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) { func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
listTool := mcp.NewTool("list_memos", mcpSrv.AddTool(mcp.NewTool("list_memos",
mcp.WithDescription("List the authenticated user's memos"), mcp.WithDescription("List memos visible to the caller. Authenticated users see their own memos plus public and protected memos; unauthenticated callers see only public memos."),
mcp.WithNumber("page_size", mcp.Description("Max memos to return, default 20")), mcp.WithNumber("page_size", mcp.Description("Maximum memos to return (1–100, default 20)")),
mcp.WithString("filter", mcp.Description(`CEL filter expression, e.g. content.contains("keyword")`)), mcp.WithNumber("page", mcp.Description("Zero-based page index for pagination (default 0)")),
) mcp.WithString("state",
mcpSrv.AddTool(listTool, s.handleListMemos) mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
getTool := mcp.NewTool("get_memo", ),
mcp.WithDescription("Get a single memo by resource name"), mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
mcp.WithString("filter", mcp.Description(`Optional CEL filter, e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
), s.handleListMemos)
mcpSrv.AddTool(mcp.NewTool("get_memo",
mcp.WithDescription("Get a single memo by resource name. Public memos are accessible without authentication."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
) ), s.handleGetMemo)
mcpSrv.AddTool(getTool, s.handleGetMemo)
createTool := mcp.NewTool("create_memo", mcpSrv.AddTool(mcp.NewTool("create_memo",
mcp.WithDescription("Create a new memo"), mcp.WithDescription("Create a new memo. Requires authentication."),
mcp.WithString("content", mcp.Required(), mcp.Description("Memo content")), mcp.WithString("content", mcp.Required(), mcp.Description("Memo content in Markdown. Use #tag syntax for tagging.")),
mcp.WithString("visibility", mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"), mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("Visibility: PRIVATE (default), PROTECTED, or PUBLIC"), mcp.Description("Visibility (default: PRIVATE)"),
), ),
) ), s.handleCreateMemo)
mcpSrv.AddTool(createTool, s.handleCreateMemo)
updateTool := mcp.NewTool("update_memo", mcpSrv.AddTool(mcp.NewTool("update_memo",
mcp.WithDescription("Update a memo's content or visibility"), mcp.WithDescription("Update a memo's content, visibility, pin state, or archive state. Requires authentication and ownership. Omit any field to leave it unchanged."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Description("New content (omit to leave unchanged)")), mcp.WithString("content", mcp.Description("New Markdown content")),
mcp.WithString("visibility", mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"), mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("New visibility (omit to leave unchanged)"), mcp.Description("New visibility"),
),
mcp.WithBoolean("pinned", mcp.Description("Pin or unpin the memo")),
mcp.WithString("state",
mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Set to ARCHIVED to archive, NORMAL to restore"),
), ),
) ), s.handleUpdateMemo)
mcpSrv.AddTool(updateTool, s.handleUpdateMemo)
deleteTool := mcp.NewTool("delete_memo", mcpSrv.AddTool(mcp.NewTool("delete_memo",
mcp.WithDescription("Delete a memo"), mcp.WithDescription("Permanently delete a memo. Requires authentication and ownership."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
) ), s.handleDeleteMemo)
mcpSrv.AddTool(deleteTool, s.handleDeleteMemo)
mcpSrv.AddTool(mcp.NewTool("search_memos",
searchTool := mcp.NewTool("search_memos", mcp.WithDescription("Search memo content. Authenticated users search their own and visible memos; unauthenticated callers search public memos only."),
mcp.WithDescription("Search memo content using a text query"), mcp.WithString("query", mcp.Required(), mcp.Description("Text to search for in memo content")),
mcp.WithString("query", mcp.Required(), mcp.Description("Text to search in memo content")), ), s.handleSearchMemos)
)
mcpSrv.AddTool(searchTool, s.handleSearchMemos) mcpSrv.AddTool(mcp.NewTool("list_memo_comments",
mcp.WithDescription("List comments on a memo. Visibility rules for comments match those of the parent memo."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
), s.handleListMemoComments)
mcpSrv.AddTool(mcp.NewTool("create_memo_comment",
mcp.WithDescription("Add a comment to a memo. The comment inherits the parent memo's visibility. Requires authentication."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name to comment on, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Required(), mcp.Description("Comment content in Markdown")),
), s.handleCreateMemoComment)
} }
func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx) userID := auth.GetUserID(ctx)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
pageSize := req.GetInt("page_size", 20) pageSize := req.GetInt("page_size", 20)
if pageSize <= 0 { if pageSize <= 0 {
...@@ -92,31 +249,54 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques ...@@ -92,31 +249,54 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques
if pageSize > 100 { if pageSize > 100 {
pageSize = 100 pageSize = 100
} }
filterExpr := req.GetString("filter", "") page := req.GetInt("page", 0)
if page < 0 {
page = 0
}
rowStatus := store.Normal var rowStatus *store.RowStatus
limitPlusOne := pageSize + 1 if state := req.GetString("state", "NORMAL"); state != "" {
zero := 0 rs, err := parseRowStatus(state)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
rowStatus = &rs
}
limit := pageSize + 1
offset := page * pageSize
find := &store.FindMemo{ find := &store.FindMemo{
CreatorID: &userID,
ExcludeComments: true, ExcludeComments: true,
RowStatus: &rowStatus, RowStatus: rowStatus,
Limit: &limitPlusOne, Limit: &limit,
Offset: &zero, Offset: &offset,
OrderByPinned: req.GetBool("order_by_pinned", false),
} }
if filterExpr != "" { applyVisibilityFilter(find, userID)
find.Filters = append(find.Filters, filterExpr) if filter := req.GetString("filter", ""); filter != "" {
find.Filters = append(find.Filters, filter)
} }
memos, err := s.store.ListMemos(ctx, find) memos, err := s.store.ListMemos(ctx, find)
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil
} }
if len(memos) == limitPlusOne {
hasMore := len(memos) > pageSize
if hasMore {
memos = memos[:pageSize] memos = memos[:pageSize]
} }
out, err := marshalJSON(memos) results := make([]memoJSON, len(memos))
for i, m := range memos {
results[i] = storeMemoToJSON(m)
}
type listResponse struct {
Memos []memoJSON `json:"memos"`
HasMore bool `json:"has_more"`
}
out, err := marshalJSON(listResponse{Memos: results, HasMore: hasMore})
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -124,20 +304,13 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques ...@@ -124,20 +304,13 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques
} }
func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx) userID := auth.GetUserID(ctx)
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil { if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
name := req.GetString("name", "")
if name == "" {
return mcp.NewToolResultError("name is required"), nil
}
uid, found := strings.CutPrefix(name, "memos/")
if !found || uid == "" {
return mcp.NewToolResultError(`name must be in the format "memos/<uid>"`), nil
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
...@@ -145,11 +318,11 @@ func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) ...@@ -145,11 +318,11 @@ func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest)
if memo == nil { if memo == nil {
return mcp.NewToolResultError("memo not found"), nil return mcp.NewToolResultError("memo not found"), nil
} }
if memo.Visibility == store.Private && memo.CreatorID != userID { if err := checkMemoAccess(memo, userID); err != nil {
return mcp.NewToolResultError("permission denied"), nil return mcp.NewToolResultError(err.Error()), nil
} }
out, err := marshalJSON(memo) out, err := marshalJSON(storeMemoToJSON(memo))
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -166,26 +339,23 @@ func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolReque ...@@ -166,26 +339,23 @@ func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolReque
if content == "" { if content == "" {
return mcp.NewToolResultError("content is required"), nil return mcp.NewToolResultError("content is required"), nil
} }
visibility, err := parseVisibility(req.GetString("visibility", "PRIVATE"))
visibility := req.GetString("visibility", "PRIVATE") if err != nil {
switch visibility { return mcp.NewToolResultError(err.Error()), nil
case "PRIVATE", "PROTECTED", "PUBLIC":
default:
return mcp.NewToolResultError("visibility must be PRIVATE, PROTECTED, or PUBLIC"), nil
} }
create := &store.Memo{ memo, err := s.store.CreateMemo(ctx, &store.Memo{
UID: shortuuid.New(), UID: shortuuid.New(),
CreatorID: userID, CreatorID: userID,
Content: content, Content: content,
Visibility: store.Visibility(visibility), Visibility: visibility,
} Payload: buildPayload(content),
memo, err := s.store.CreateMemo(ctx, create) })
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil
} }
out, err := marshalJSON(memo) out, err := marshalJSON(storeMemoToJSON(memo))
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -198,13 +368,9 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque ...@@ -198,13 +368,9 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
name := req.GetString("name", "") uid, err := parseMemoUID(req.GetString("name", ""))
if name == "" { if err != nil {
return mcp.NewToolResultError("name is required"), nil return mcp.NewToolResultError(err.Error()), nil
}
uid, found := strings.CutPrefix(name, "memos/")
if !found || uid == "" {
return mcp.NewToolResultError(`name must be in the format "memos/<uid>"`), nil
} }
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
...@@ -219,17 +385,29 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque ...@@ -219,17 +385,29 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque
} }
update := &store.UpdateMemo{ID: memo.ID} update := &store.UpdateMemo{ID: memo.ID}
if content := req.GetString("content", ""); content != "" { args := req.GetArguments()
update.Content = &content
if v := req.GetString("content", ""); v != "" {
update.Content = &v
update.Payload = buildPayload(v)
} }
if vis := req.GetString("visibility", ""); vis != "" { if v := req.GetString("visibility", ""); v != "" {
switch vis { vis, err := parseVisibility(v)
case "PRIVATE", "PROTECTED", "PUBLIC": if err != nil {
default: return mcp.NewToolResultError(err.Error()), nil
return mcp.NewToolResultError("visibility must be PRIVATE, PROTECTED, or PUBLIC"), nil }
update.Visibility = &vis
}
if v := req.GetString("state", ""); v != "" {
rs, err := parseRowStatus(v)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
update.RowStatus = &rs
} }
v := store.Visibility(vis) if _, ok := args["pinned"]; ok {
update.Visibility = &v pinned := req.GetBool("pinned", false)
update.Pinned = &pinned
} }
if err := s.store.UpdateMemo(ctx, update); err != nil { if err := s.store.UpdateMemo(ctx, update); err != nil {
...@@ -241,7 +419,7 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque ...@@ -241,7 +419,7 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated memo: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated memo: %v", err)), nil
} }
out, err := marshalJSON(updated) out, err := marshalJSON(storeMemoToJSON(updated))
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -254,13 +432,9 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque ...@@ -254,13 +432,9 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
name := req.GetString("name", "") uid, err := parseMemoUID(req.GetString("name", ""))
if name == "" { if err != nil {
return mcp.NewToolResultError("name is required"), nil return mcp.NewToolResultError(err.Error()), nil
}
uid, found := strings.CutPrefix(name, "memos/")
if !found || uid == "" {
return mcp.NewToolResultError(`name must be in the format "memos/<uid>"`), nil
} }
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
...@@ -277,40 +451,147 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque ...@@ -277,40 +451,147 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque
if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil { if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete memo: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to delete memo: %v", err)), nil
} }
return mcp.NewToolResultText("memo deleted"), nil return mcp.NewToolResultText(`{"deleted":true}`), nil
} }
func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx) userID := auth.GetUserID(ctx)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
query := req.GetString("query", "") query := req.GetString("query", "")
if query == "" { if query == "" {
return mcp.NewToolResultError("query is required"), nil return mcp.NewToolResultError("query is required"), nil
} }
rowStatus := store.Normal
limit := 50 limit := 50
zero := 0 zero := 0
rowStatus := store.Normal
find := &store.FindMemo{ find := &store.FindMemo{
ExcludeComments: true, ExcludeComments: true,
RowStatus: &rowStatus, RowStatus: &rowStatus,
Limit: &limit, Limit: &limit,
Offset: &zero, Offset: &zero,
Filters: []string{ Filters: []string{fmt.Sprintf(`content.contains(%q)`, query)},
fmt.Sprintf("creator_id == %d", userID),
fmt.Sprintf(`content.contains(%q)`, query),
},
} }
applyVisibilityFilter(find, userID)
memos, err := s.store.ListMemos(ctx, find) memos, err := s.store.ListMemos(ctx, find)
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to search memos: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to search memos: %v", err)), nil
} }
out, err := marshalJSON(memos) results := make([]memoJSON, len(memos))
for i, m := range memos {
results[i] = storeMemoToJSON(m)
}
out, err := marshalJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := auth.GetUserID(ctx)
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
parent, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if parent == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if err := checkMemoAccess(parent, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
relationType := store.MemoRelationComment
relations, err := s.store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &parent.ID,
Type: &relationType,
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list relations: %v", err)), nil
}
if len(relations) == 0 {
out, _ := marshalJSON([]memoJSON{})
return mcp.NewToolResultText(out), nil
}
commentIDs := make([]int32, len(relations))
for i, r := range relations {
commentIDs[i] = r.MemoID
}
memos, err := s.store.ListMemos(ctx, &store.FindMemo{IDList: commentIDs})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list comments: %v", err)), nil
}
results := make([]memoJSON, 0, len(memos))
for _, m := range memos {
if checkMemoAccess(m, userID) == nil {
results = append(results, storeMemoToJSON(m))
}
}
out, err := marshalJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
uid, err := parseMemoUID(req.GetString("name", ""))
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
content := req.GetString("content", "")
if content == "" {
return mcp.NewToolResultError("content is required"), nil
}
parent, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if parent == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if err := checkMemoAccess(parent, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
comment, err := s.store.CreateMemo(ctx, &store.Memo{
UID: shortuuid.New(),
CreatorID: userID,
Content: content,
Visibility: parent.Visibility,
Payload: buildPayload(content),
ParentUID: &parent.UID,
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %v", err)), nil
}
if _, err = s.store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: comment.ID,
RelatedMemoID: parent.ID,
Type: store.MemoRelationComment,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to link comment: %v", err)), nil
}
out, err := marshalJSON(storeMemoToJSON(comment))
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
package mcp
import (
"context"
"fmt"
"sort"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_tags",
mcp.WithDescription("List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically."),
), s.handleListTags)
}
type tagEntry struct {
Tag string `json:"tag"`
Count int `json:"count"`
}
func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := auth.GetUserID(ctx)
rowStatus := store.Normal
find := &store.FindMemo{
ExcludeComments: true,
ExcludeContent: true,
RowStatus: &rowStatus,
}
applyVisibilityFilter(find, userID)
memos, err := s.store.ListMemos(ctx, find)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil
}
counts := make(map[string]int)
for _, m := range memos {
if m.Payload == nil {
continue
}
for _, tag := range m.Payload.Tags {
counts[tag]++
}
}
entries := make([]tagEntry, 0, len(counts))
for tag, count := range counts {
entries = append(entries, tagEntry{Tag: tag, Count: count})
}
sort.Slice(entries, func(i, j int) bool {
if entries[i].Count != entries[j].Count {
return entries[i].Count > entries[j].Count
}
return entries[i].Tag < entries[j].Tag
})
out, err := marshalJSON(entries)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment