Unverified Commit 583c3d24 authored by boojack's avatar boojack Committed by GitHub

feat(mcp): harden tool exposure and side effects (#5850)

parent 0fc1dab2
...@@ -12,6 +12,35 @@ DELETE /mcp (optional session termination) ...@@ -12,6 +12,35 @@ DELETE /mcp (optional session termination)
Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26). Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26).
### Tool Filtering
The default `/mcp` endpoint exposes all tools. Clients can opt into a smaller
tool surface with GitHub-style headers or route aliases:
| Control | Description |
|---|---|
| `X-MCP-Readonly: true` | Hide and block mutating tools |
| `X-MCP-Toolsets: memos,tags,attachments,relations,reactions` | Limit the default tool list to selected toolsets |
| `X-MCP-Tools: list_tags,get_memo` | Add specific tools to the selected toolset list |
| `X-MCP-Exclude-Tools: delete_memo` | Remove specific tools |
Equivalent aliases:
```text
/mcp/readonly
/mcp/x/{toolsets}
/mcp/x/{toolsets}/readonly
```
Examples:
```text
/mcp/x/memos,tags/readonly
X-MCP-Toolsets: memos
X-MCP-Tools: list_tags
X-MCP-Exclude-Tools: delete_memo
```
## Capabilities ## Capabilities
The server advertises the following MCP capabilities: The server advertises the following MCP capabilities:
...@@ -126,7 +155,9 @@ claude mcp add --scope user --transport http memos http://localhost:5230/mcp \ ...@@ -126,7 +155,9 @@ claude mcp add --scope user --transport http memos http://localhost:5230/mcp \
| File | Responsibility | | File | Responsibility |
|---|---| |---|---|
| `mcp.go` | `MCPService` struct, constructor, route registration, auth middleware | | `mcp.go` | `MCPService` struct, constructor, route registration, auth middleware, tool filtering |
| `tool_metadata.go` | Toolsets, read-only metadata, annotations, structured result helpers |
| `api_helpers.go` | Conversion helpers for calling API service methods from MCP tools |
| `tools_memo.go` | Memo CRUD tools + helpers (JSON types, visibility/access checks) | | `tools_memo.go` | Memo CRUD tools + helpers (JSON types, visibility/access checks) |
| `tools_tag.go` | Tag listing tool | | `tools_tag.go` | Tag listing tool |
| `tools_attachment.go` | Attachment listing, metadata, deletion, linking tools | | `tools_attachment.go` | Attachment listing, metadata, deletion, linking tools |
......
package mcp
import (
"context"
"github.com/pkg/errors"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
)
func visibilityToProto(visibility store.Visibility) v1pb.Visibility {
switch visibility {
case store.Protected:
return v1pb.Visibility_PROTECTED
case store.Public:
return v1pb.Visibility_PUBLIC
default:
return v1pb.Visibility_PRIVATE
}
}
func rowStatusToProto(rowStatus store.RowStatus) v1pb.State {
switch rowStatus {
case store.Archived:
return v1pb.State_ARCHIVED
default:
return v1pb.State_NORMAL
}
}
func (s *MCPService) loadMemoJSONByName(ctx context.Context, name string) (memoJSON, error) {
uid, err := parseMemoUID(name)
if err != nil {
return memoJSON{}, err
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return memoJSON{}, errors.Wrap(err, "failed to get memo")
}
if memo == nil {
return memoJSON{}, errors.New("memo not found")
}
return storeMemoToJSONWithStore(ctx, s.store, memo)
}
func (s *MCPService) loadReactionJSONByID(ctx context.Context, reactionID int32) (reactionJSON, error) {
reaction, err := s.store.GetReaction(ctx, &store.FindReaction{ID: &reactionID})
if err != nil {
return reactionJSON{}, errors.Wrap(err, "failed to get reaction")
}
if reaction == nil {
return reactionJSON{}, errors.New("reaction not found")
}
creator, err := lookupUsername(ctx, s.store, reaction.CreatorID)
if err != nil {
return reactionJSON{}, errors.Wrap(err, "failed to resolve reaction creator")
}
return reactionJSON{
ID: reaction.ID,
Creator: creator,
ReactionType: reaction.ReactionType,
CreateTime: reaction.CreatedTs,
}, nil
}
func (s *MCPService) loadReactionJSONByName(ctx context.Context, name string) (reactionJSON, error) {
_, reactionID, err := apiv1.ExtractMemoReactionIDFromName(name)
if err != nil {
return reactionJSON{}, err
}
return s.loadReactionJSONByID(ctx, reactionID)
}
package mcp package mcp
import ( import (
"context"
"fmt"
"net/http" "net/http"
"strings"
"github.com/labstack/echo/v5" "github.com/labstack/echo/v5"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/usememos/memos/internal/profile" "github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
const (
headerMCPReadonly = "X-MCP-Readonly"
headerMCPToolsets = "X-MCP-Toolsets"
headerMCPTools = "X-MCP-Tools"
headerMCPExcludeTools = "X-MCP-Exclude-Tools"
)
type mcpRequestConfigContextKey struct{}
type MCPService struct { type MCPService struct {
profile *profile.Profile profile *profile.Profile
store *store.Store store *store.Store
apiV1Service *apiv1.APIV1Service
authenticator *auth.Authenticator authenticator *auth.Authenticator
} }
func NewMCPService(profile *profile.Profile, store *store.Store, secret string) *MCPService { func NewMCPService(profile *profile.Profile, store *store.Store, secret string, apiV1Service *apiv1.APIV1Service) *MCPService {
return &MCPService{ return &MCPService{
profile: profile, profile: profile,
store: store, store: store,
apiV1Service: apiV1Service,
authenticator: auth.NewAuthenticator(store, secret), authenticator: auth.NewAuthenticator(store, secret),
} }
} }
...@@ -31,6 +47,10 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { ...@@ -31,6 +47,10 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpserver.WithResourceCapabilities(true, true), mcpserver.WithResourceCapabilities(true, true),
mcpserver.WithPromptCapabilities(true), mcpserver.WithPromptCapabilities(true),
mcpserver.WithLogging(), mcpserver.WithLogging(),
mcpserver.WithToolFilter(s.filterTools),
mcpserver.WithToolHandlerMiddleware(s.enforceToolAccess),
mcpserver.WithRecovery(),
mcpserver.WithResourceRecovery(),
) )
s.registerMemoTools(mcpSrv) s.registerMemoTools(mcpSrv)
s.registerTagTools(mcpSrv) s.registerTagTools(mcpSrv)
...@@ -40,7 +60,9 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { ...@@ -40,7 +60,9 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
s.registerMemoResources(mcpSrv) s.registerMemoResources(mcpSrv)
s.registerPrompts(mcpSrv) s.registerPrompts(mcpSrv)
httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv) httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv,
mcpserver.WithHTTPContextFunc(s.withRequestConfig),
)
mcpGroup := echoServer.Group("") mcpGroup := echoServer.Group("")
mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
...@@ -52,7 +74,18 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { ...@@ -52,7 +74,18 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
headers := c.Response().Header() headers := c.Response().Header()
headers.Set("Vary", "Origin") headers.Set("Vary", "Origin")
headers.Set("Access-Control-Allow-Origin", origin) headers.Set("Access-Control-Allow-Origin", origin)
headers.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Accept, Mcp-Session-Id, MCP-Protocol-Version, Last-Event-ID") headers.Set("Access-Control-Allow-Headers", strings.Join([]string{
"Authorization",
"Content-Type",
"Accept",
"Mcp-Session-Id",
"MCP-Protocol-Version",
"Last-Event-ID",
headerMCPReadonly,
headerMCPToolsets,
headerMCPTools,
headerMCPExcludeTools,
}, ", "))
headers.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") headers.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
if c.Request().Method == http.MethodOptions { if c.Request().Method == http.MethodOptions {
return c.NoContent(http.StatusNoContent) return c.NoContent(http.StatusNoContent)
...@@ -72,4 +105,147 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { ...@@ -72,4 +105,147 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
} }
}) })
mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler)) mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler))
mcpGroup.Any("/mcp/readonly", echo.WrapHandler(httpHandler))
mcpGroup.Any("/mcp/x/:toolsets", echo.WrapHandler(httpHandler))
mcpGroup.Any("/mcp/x/:toolsets/readonly", echo.WrapHandler(httpHandler))
}
func (*MCPService) withRequestConfig(ctx context.Context, r *http.Request) context.Context {
return context.WithValue(ctx, mcpRequestConfigContextKey{}, parseMCPRequestConfig(r))
}
func (*MCPService) filterTools(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
cfg := mcpRequestConfigFromContext(ctx)
filtered := make([]mcp.Tool, 0, len(tools))
for _, tool := range tools {
if cfg.allowsTool(tool.Name) {
filtered = append(filtered, tool)
}
}
return filtered
}
func (*MCPService) enforceToolAccess(next mcpserver.ToolHandlerFunc) mcpserver.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
cfg := mcpRequestConfigFromContext(ctx)
if !cfg.allowsTool(req.Params.Name) {
return mcp.NewToolResultError(fmt.Sprintf("tool %q is not enabled by MCP configuration", req.Params.Name)), nil
}
return next(ctx, req)
}
}
type mcpRequestConfig struct {
readOnly bool
toolsets map[string]struct{}
includeTools map[string]struct{}
excludeTools map[string]struct{}
}
func mcpRequestConfigFromContext(ctx context.Context) mcpRequestConfig {
if cfg, ok := ctx.Value(mcpRequestConfigContextKey{}).(mcpRequestConfig); ok {
return cfg
}
return mcpRequestConfig{}
}
func parseMCPRequestConfig(r *http.Request) mcpRequestConfig {
cfg := mcpRequestConfig{}
pathToolsets, pathReadonly := parseMCPPathConfig(r.URL.Path)
cfg.readOnly = pathReadonly || parseBoolHeader(r.Header.Get(headerMCPReadonly))
cfg.toolsets = mergeStringSets(cfg.toolsets, pathToolsets)
cfg.toolsets = mergeStringSets(cfg.toolsets, parseCommaSet(r.Header.Get(headerMCPToolsets), strings.ToLower))
cfg.includeTools = parseCommaSet(r.Header.Get(headerMCPTools), keepString)
cfg.excludeTools = parseCommaSet(r.Header.Get(headerMCPExcludeTools), keepString)
return cfg
}
func parseMCPPathConfig(path string) (map[string]struct{}, bool) {
trimmed := strings.Trim(path, "/")
if trimmed == "mcp/readonly" {
return nil, true
}
const prefix = "mcp/x/"
if !strings.HasPrefix(trimmed, prefix) {
return nil, false
}
rest := strings.TrimPrefix(trimmed, prefix)
readOnly := false
if strings.HasSuffix(rest, "/readonly") {
readOnly = true
rest = strings.TrimSuffix(rest, "/readonly")
}
return parseCommaSet(rest, strings.ToLower), readOnly
}
func (cfg mcpRequestConfig) allowsTool(name string) bool {
if _, known := allMCPToolNames[name]; !known {
return false
}
if cfg.readOnly {
if _, mutates := mcpMutationTools[name]; mutates {
return false
}
}
if _, excluded := cfg.excludeTools[name]; excluded {
return false
}
if _, included := cfg.includeTools[name]; included {
return true
}
if len(cfg.toolsets) == 0 {
return true
}
for toolset := range cfg.toolsets {
if _, ok := mcpToolsByToolset[toolset][name]; ok {
return true
}
}
return false
}
func parseBoolHeader(value string) bool {
switch strings.ToLower(strings.TrimSpace(value)) {
case "1", "t", "true", "y", "yes", "on":
return true
default:
return false
}
}
func parseCommaSet(value string, normalize func(string) string) map[string]struct{} {
if value == "" {
return nil
}
result := map[string]struct{}{}
for _, item := range strings.Split(value, ",") {
item = strings.TrimSpace(item)
if item == "" {
continue
}
result[normalize(item)] = struct{}{}
}
if len(result) == 0 {
return nil
}
return result
}
func mergeStringSets(dst map[string]struct{}, src map[string]struct{}) map[string]struct{} {
if len(src) == 0 {
return dst
}
if dst == nil {
dst = map[string]struct{}{}
}
for item := range src {
dst[item] = struct{}{}
}
return dst
}
func keepString(s string) string {
return s
} }
This diff is collapsed.
package mcp
import "github.com/mark3labs/mcp-go/mcp"
var mcpToolsByToolset = map[string]map[string]struct{}{
"memos": stringSet(
"list_memos",
"get_memo",
"create_memo",
"update_memo",
"delete_memo",
"search_memos",
"list_memo_comments",
"create_memo_comment",
),
"tags": stringSet(
"list_tags",
),
"attachments": stringSet(
"list_attachments",
"get_attachment",
"delete_attachment",
"link_attachment_to_memo",
),
"relations": stringSet(
"list_memo_relations",
"create_memo_relation",
"delete_memo_relation",
),
"reactions": stringSet(
"list_reactions",
"upsert_reaction",
"delete_reaction",
),
}
var allMCPToolNames = func() map[string]struct{} {
names := map[string]struct{}{}
for _, tools := range mcpToolsByToolset {
for name := range tools {
names[name] = struct{}{}
}
}
return names
}()
var mcpMutationTools = stringSet(
"create_memo",
"update_memo",
"delete_memo",
"create_memo_comment",
"delete_attachment",
"link_attachment_to_memo",
"create_memo_relation",
"delete_memo_relation",
"upsert_reaction",
"delete_reaction",
)
type deletedJSON struct {
Deleted bool `json:"deleted"`
}
func stringSet(values ...string) map[string]struct{} {
result := make(map[string]struct{}, len(values))
for _, value := range values {
result[value] = struct{}{}
}
return result
}
func readOnlyToolOptions(title string, description string, opts ...mcp.ToolOption) []mcp.ToolOption {
return annotatedToolOptions(title, description, true, false, true, false, opts...)
}
func createToolOptions(title string, description string, idempotent bool, opts ...mcp.ToolOption) []mcp.ToolOption {
return annotatedToolOptions(title, description, false, false, idempotent, false, opts...)
}
func updateToolOptions(title string, description string, opts ...mcp.ToolOption) []mcp.ToolOption {
return annotatedToolOptions(title, description, false, true, false, false, opts...)
}
func annotatedToolOptions(title string, description string, readOnly bool, destructive bool, idempotent bool, openWorld bool, opts ...mcp.ToolOption) []mcp.ToolOption {
base := []mcp.ToolOption{
mcp.WithTitleAnnotation(title),
mcp.WithDescription(description),
mcp.WithReadOnlyHintAnnotation(readOnly),
mcp.WithDestructiveHintAnnotation(destructive),
mcp.WithIdempotentHintAnnotation(idempotent),
mcp.WithOpenWorldHintAnnotation(openWorld),
}
return append(base, opts...)
}
func newToolResultJSON(v any) (*mcp.CallToolResult, error) {
return mcp.NewToolResultJSON(v)
}
func newDeletedToolResult() (*mcp.CallToolResult, error) {
return newToolResultJSON(deletedJSON{Deleted: true})
}
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors" "github.com/pkg/errors"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
...@@ -26,6 +27,11 @@ type attachmentJSON struct { ...@@ -26,6 +27,11 @@ type attachmentJSON struct {
Memo string `json:"memo,omitempty"` Memo string `json:"memo,omitempty"`
} }
type attachmentListJSON struct {
Attachments []attachmentJSON `json:"attachments"`
HasMore bool `json:"has_more"`
}
func storeAttachmentToJSON(ctx context.Context, stores *store.Store, a *store.Attachment) (attachmentJSON, error) { func storeAttachmentToJSON(ctx context.Context, stores *store.Store, a *store.Attachment) (attachmentJSON, error) {
creator, err := lookupUsername(ctx, stores, a.CreatorID) creator, err := lookupUsername(ctx, stores, a.CreatorID)
if err != nil { if err != nil {
...@@ -98,26 +104,34 @@ func parseAttachmentUID(name string) (string, error) { ...@@ -98,26 +104,34 @@ func parseAttachmentUID(name string) (string, error) {
func (s *MCPService) registerAttachmentTools(mcpSrv *mcpserver.MCPServer) { func (s *MCPService) registerAttachmentTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_attachments", mcpSrv.AddTool(mcp.NewTool("list_attachments",
mcp.WithDescription("List attachments owned by the authenticated user. Supports pagination and optional filtering by linked memo."), readOnlyToolOptions("List attachments", "List attachments owned by the authenticated user. Supports pagination and optional filtering by linked memo.",
mcp.WithNumber("page_size", mcp.Description("Maximum attachments to return (1–100, default 20)")), mcp.WithNumber("page_size", mcp.Description("Maximum attachments to return (1–100, default 20)")),
mcp.WithNumber("page", mcp.Description("Zero-based page index (default 0)")), mcp.WithNumber("page", mcp.Description("Zero-based page index (default 0)")),
mcp.WithString("memo", mcp.Description(`Filter by linked memo resource name, e.g. "memos/abc123"`)), mcp.WithString("memo", mcp.Description(`Filter by linked memo resource name, e.g. "memos/abc123"`)),
mcp.WithOutputSchema[attachmentListJSON](),
)...,
), s.handleListAttachments) ), s.handleListAttachments)
mcpSrv.AddTool(mcp.NewTool("get_attachment", mcpSrv.AddTool(mcp.NewTool("get_attachment",
mcp.WithDescription("Get a single attachment's metadata by resource name. Requires authentication."), readOnlyToolOptions("Get attachment", "Get a single attachment's metadata by resource name. Requires authentication.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
mcp.WithOutputSchema[attachmentJSON](),
)...,
), s.handleGetAttachment) ), s.handleGetAttachment)
mcpSrv.AddTool(mcp.NewTool("delete_attachment", mcpSrv.AddTool(mcp.NewTool("delete_attachment",
mcp.WithDescription("Permanently delete an attachment and its stored file. Requires authentication and ownership."), updateToolOptions("Delete attachment", "Permanently delete an attachment and its stored file. Requires authentication and ownership.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
mcp.WithOutputSchema[deletedJSON](),
)...,
), s.handleDeleteAttachment) ), s.handleDeleteAttachment)
mcpSrv.AddTool(mcp.NewTool("link_attachment_to_memo", mcpSrv.AddTool(mcp.NewTool("link_attachment_to_memo",
mcp.WithDescription("Link an existing attachment to a memo. Requires authentication and ownership of the attachment."), createToolOptions("Link attachment to memo", "Link an existing attachment to a memo. Requires authentication and ownership of the attachment.", true,
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
mcp.WithString("memo", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("memo", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithOutputSchema[attachmentJSON](),
)...,
), s.handleLinkAttachmentToMemo) ), s.handleLinkAttachmentToMemo)
} }
...@@ -189,15 +203,7 @@ func (s *MCPService) handleListAttachments(ctx context.Context, req mcp.CallTool ...@@ -189,15 +203,7 @@ func (s *MCPService) handleListAttachments(ctx context.Context, req mcp.CallTool
results[i] = result results[i] = result
} }
type listResponse struct { return newToolResultJSON(attachmentListJSON{Attachments: results, HasMore: hasMore})
Attachments []attachmentJSON `json:"attachments"`
HasMore bool `json:"has_more"`
}
out, err := marshalJSON(listResponse{Attachments: results, HasMore: hasMore})
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -224,16 +230,11 @@ func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRe ...@@ -224,16 +230,11 @@ func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRe
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
} }
out, err := marshalJSON(result) return newToolResultJSON(result)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx) if _, err := extractUserID(ctx); err != nil {
if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
...@@ -242,18 +243,10 @@ func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToo ...@@ -242,18 +243,10 @@ func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToo
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
attachment, err := s.store.GetAttachment(ctx, &store.FindAttachment{UID: &uid, CreatorID: &userID}) if _, err := s.apiV1Service.DeleteAttachment(ctx, &v1pb.DeleteAttachmentRequest{Name: "attachments/" + uid}); err != nil {
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to find attachment: %v", err)), nil
}
if attachment == nil {
return mcp.NewToolResultError("attachment not found"), nil
}
if err := s.store.DeleteAttachment(ctx, &store.DeleteAttachment{ID: attachment.ID}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete attachment: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to delete attachment: %v", err)), nil
} }
return mcp.NewToolResultText(`{"deleted":true}`), nil return newDeletedToolResult()
} }
func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -293,9 +286,30 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal ...@@ -293,9 +286,30 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
if err := s.store.UpdateAttachment(ctx, &store.UpdateAttachment{ currentAttachments, err := s.store.ListAttachments(ctx, &store.FindAttachment{MemoID: &memo.ID})
ID: attachment.ID, if err != nil {
MemoID: &memo.ID, return mcp.NewToolResultError(fmt.Sprintf("failed to list memo attachments: %v", err)), nil
}
requestAttachments := make([]*v1pb.Attachment, 0, len(currentAttachments)+1)
var currentTarget *store.Attachment
for _, current := range currentAttachments {
requestAttachments = append(requestAttachments, &v1pb.Attachment{Name: "attachments/" + current.UID})
if current.ID == attachment.ID {
currentTarget = current
}
}
if currentTarget != nil {
result, err := storeAttachmentToJSON(ctx, s.store, currentTarget)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
}
return newToolResultJSON(result)
}
requestAttachments = append(requestAttachments, &v1pb.Attachment{Name: "attachments/" + uid})
if _, err := s.apiV1Service.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
Name: "memos/" + memoUID,
Attachments: requestAttachments,
}); err != nil { }); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to link attachment: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to link attachment: %v", err)), nil
} }
...@@ -309,9 +323,5 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal ...@@ -309,9 +323,5 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
} }
out, err := marshalJSON(result) return newToolResultJSON(result)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
This diff is collapsed.
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
...@@ -20,19 +21,24 @@ type reactionJSON struct { ...@@ -20,19 +21,24 @@ type reactionJSON struct {
func (s *MCPService) registerReactionTools(mcpSrv *mcpserver.MCPServer) { func (s *MCPService) registerReactionTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_reactions", mcpSrv.AddTool(mcp.NewTool("list_reactions",
mcp.WithDescription("List all reactions on a memo. Returns reaction type and creator for each reaction."), readOnlyToolOptions("List reactions", "List all reactions on a memo. Returns reaction type and creator for each reaction.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
)...,
), s.handleListReactions) ), s.handleListReactions)
mcpSrv.AddTool(mcp.NewTool("upsert_reaction", mcpSrv.AddTool(mcp.NewTool("upsert_reaction",
mcp.WithDescription("Add a reaction (emoji) to a memo. If the same reaction already exists from the same user, this is a no-op. Requires authentication."), createToolOptions("Upsert reaction", "Add a reaction (emoji) to a memo. If the same reaction already exists from the same user, this is a no-op. Requires authentication.", true,
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("reaction_type", mcp.Required(), mcp.Description(`Reaction emoji, e.g. "👍", "❤️", "🎉"`)), mcp.WithString("reaction_type", mcp.Required(), mcp.Description(`Reaction emoji, e.g. "👍", "❤️", "🎉"`)),
mcp.WithOutputSchema[reactionJSON](),
)...,
), s.handleUpsertReaction) ), s.handleUpsertReaction)
mcpSrv.AddTool(mcp.NewTool("delete_reaction", mcpSrv.AddTool(mcp.NewTool("delete_reaction",
mcp.WithDescription("Remove a reaction by its ID. Requires authentication and ownership of the reaction."), updateToolOptions("Delete reaction", "Remove a reaction by its ID. Requires authentication and ownership of the reaction.",
mcp.WithNumber("id", mcp.Required(), mcp.Description("Reaction ID to delete")), mcp.WithNumber("id", mcp.Required(), mcp.Description("Reaction ID to delete")),
mcp.WithOutputSchema[deletedJSON](),
)...,
), s.handleDeleteReaction) ), s.handleDeleteReaction)
} }
...@@ -83,11 +89,7 @@ func (s *MCPService) handleListReactions(ctx context.Context, req mcp.CallToolRe ...@@ -83,11 +89,7 @@ func (s *MCPService) handleListReactions(ctx context.Context, req mcp.CallToolRe
} }
} }
out, err := marshalJSON(results) return newToolResultJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -133,34 +135,26 @@ func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolR ...@@ -133,34 +135,26 @@ func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolR
} }
contentID := "memos/" + uid contentID := "memos/" + uid
reaction, err := s.store.UpsertReaction(ctx, &store.Reaction{ reaction, err := s.apiV1Service.UpsertMemoReaction(ctx, &v1pb.UpsertMemoReactionRequest{
CreatorID: userID, Name: contentID,
ContentID: contentID, Reaction: &v1pb.Reaction{
ContentId: contentID,
ReactionType: reactionType, ReactionType: reactionType,
},
}) })
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to upsert reaction: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to upsert reaction: %v", err)), nil
} }
creator, err := lookupUsername(ctx, s.store, reaction.CreatorID) result, err := s.loadReactionJSONByName(ctx, reaction.Name)
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve reaction creator: %v", err)), nil return mcp.NewToolResultError(err.Error()), nil
}
out, err := marshalJSON(reactionJSON{
ID: reaction.ID,
Creator: creator,
ReactionType: reaction.ReactionType,
CreateTime: reaction.CreatedTs,
})
if err != nil {
return nil, err
} }
return mcp.NewToolResultText(out), nil return newToolResultJSON(result)
} }
func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx) if _, err := extractUserID(ctx); err != nil {
if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
...@@ -176,12 +170,11 @@ func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolR ...@@ -176,12 +170,11 @@ func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolR
if reaction == nil { if reaction == nil {
return mcp.NewToolResultError("reaction not found"), nil return mcp.NewToolResultError("reaction not found"), nil
} }
if reaction.CreatorID != userID {
return mcp.NewToolResultError("permission denied: can only delete your own reactions"), nil
}
if err := s.store.DeleteReaction(ctx, &store.DeleteReaction{ID: reactionID}); err != nil { if _, err := s.apiV1Service.DeleteMemoReaction(ctx, &v1pb.DeleteMemoReactionRequest{
Name: fmt.Sprintf("%s/reactions/%d", reaction.ContentID, reactionID),
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete reaction: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to delete reaction: %v", err)), nil
} }
return mcp.NewToolResultText(`{"deleted":true}`), nil return newDeletedToolResult()
} }
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
...@@ -19,24 +20,29 @@ type relationJSON struct { ...@@ -19,24 +20,29 @@ type relationJSON struct {
func (s *MCPService) registerRelationTools(mcpSrv *mcpserver.MCPServer) { func (s *MCPService) registerRelationTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_memo_relations", mcpSrv.AddTool(mcp.NewTool("list_memo_relations",
mcp.WithDescription("List all relations (references and comments) for a memo. Requires read access to the memo."), readOnlyToolOptions("List memo relations", "List all relations (references and comments) for a memo. Requires read access to the memo.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("type", mcp.WithString("type",
mcp.Enum("REFERENCE", "COMMENT"), mcp.Enum("REFERENCE", "COMMENT"),
mcp.Description("Filter by relation type (optional)"), mcp.Description("Filter by relation type (optional)"),
), ),
)...,
), s.handleListMemoRelations) ), s.handleListMemoRelations)
mcpSrv.AddTool(mcp.NewTool("create_memo_relation", mcpSrv.AddTool(mcp.NewTool("create_memo_relation",
mcp.WithDescription("Create a reference relation between two memos. Requires authentication. For comments, use create_memo_comment instead."), createToolOptions("Create memo relation", "Create a reference relation between two memos. Requires authentication. For comments, use create_memo_comment instead.", true,
mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)), mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)),
mcp.WithOutputSchema[relationJSON](),
)...,
), s.handleCreateMemoRelation) ), s.handleCreateMemoRelation)
mcpSrv.AddTool(mcp.NewTool("delete_memo_relation", mcpSrv.AddTool(mcp.NewTool("delete_memo_relation",
mcp.WithDescription("Delete a reference relation between two memos. Requires authentication and ownership of the source memo."), updateToolOptions("Delete memo relation", "Delete a reference relation between two memos. Requires authentication and ownership of the source memo.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)), mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)),
mcp.WithOutputSchema[deletedJSON](),
)...,
), s.handleDeleteMemoRelation) ), s.handleDeleteMemoRelation)
} }
...@@ -113,11 +119,7 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo ...@@ -113,11 +119,7 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo
}) })
} }
out, err := marshalJSON(results) return newToolResultJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -134,6 +136,9 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT ...@@ -134,6 +136,9 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
if err != nil { if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
if srcUID == dstUID {
return mcp.NewToolResultError("cannot create a relation from a memo to itself"), nil
}
srcMemo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &srcUID}) srcMemo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &srcUID})
if err != nil { if err != nil {
...@@ -157,24 +162,24 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT ...@@ -157,24 +162,24 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
relation, err := s.store.UpsertMemoRelation(ctx, &store.MemoRelation{ relations, changed, err := s.buildReferenceRelationSet(ctx, srcMemo, &dstMemo.UID, nil)
MemoID: srcMemo.ID,
RelatedMemoID: dstMemo.ID,
Type: store.MemoRelationReference,
})
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to build relation set: %v", err)), nil
}
if changed {
if _, err := s.apiV1Service.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
Name: "memos/" + srcUID,
Relations: relations,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create relation: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to create relation: %v", err)), nil
} }
}
out, err := marshalJSON(relationJSON{ return newToolResultJSON(relationJSON{
Memo: "memos/" + srcUID, Memo: "memos/" + srcUID,
RelatedMemo: "memos/" + dstUID, RelatedMemo: "memos/" + dstUID,
Type: string(relation.Type), Type: string(store.MemoRelationReference),
}) })
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -214,13 +219,79 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT ...@@ -214,13 +219,79 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
refType := store.MemoRelationReference relations, changed, err := s.buildReferenceRelationSet(ctx, srcMemo, nil, &dstMemo.UID)
if err := s.store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{ if err != nil {
MemoID: &srcMemo.ID, return mcp.NewToolResultError(fmt.Sprintf("failed to build relation set: %v", err)), nil
RelatedMemoID: &dstMemo.ID, }
Type: &refType, if changed {
if _, err := s.apiV1Service.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
Name: "memos/" + srcUID,
Relations: relations,
}); err != nil { }); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete relation: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to delete relation: %v", err)), nil
} }
return mcp.NewToolResultText(`{"deleted":true}`), nil }
return newDeletedToolResult()
}
func (s *MCPService) buildReferenceRelationSet(ctx context.Context, source *store.Memo, includeUID *string, excludeUID *string) ([]*v1pb.MemoRelation, bool, error) {
referenceType := store.MemoRelationReference
relations, err := s.store.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoIDList: []int32{source.ID},
Type: &referenceType,
})
if err != nil {
return nil, false, err
}
idSet := make(map[int32]struct{}, len(relations))
for _, relation := range relations {
idSet[relation.RelatedMemoID] = struct{}{}
}
ids := make([]int32, 0, len(idSet))
for id := range idSet {
ids = append(ids, id)
}
memosByID := map[int32]*store.Memo{}
if len(ids) > 0 {
memos, err := s.store.ListMemos(ctx, &store.FindMemo{IDList: ids, ExcludeContent: true})
if err != nil {
return nil, false, err
}
for _, memo := range memos {
memosByID[memo.ID] = memo
}
}
result := make([]*v1pb.MemoRelation, 0, len(relations)+1)
seenUIDs := map[string]struct{}{}
changed := false
for _, relation := range relations {
relatedMemo := memosByID[relation.RelatedMemoID]
if relatedMemo == nil {
continue
}
if excludeUID != nil && relatedMemo.UID == *excludeUID {
changed = true
continue
}
result = append(result, newReferenceRelation(source.UID, relatedMemo.UID))
seenUIDs[relatedMemo.UID] = struct{}{}
}
if includeUID != nil {
if _, seen := seenUIDs[*includeUID]; !seen && source.UID != *includeUID {
result = append(result, newReferenceRelation(source.UID, *includeUID))
changed = true
}
}
return result, changed, nil
}
func newReferenceRelation(sourceUID string, relatedUID string) *v1pb.MemoRelation {
return &v1pb.MemoRelation{
Memo: &v1pb.MemoRelation_Memo{Name: "memos/" + sourceUID},
RelatedMemo: &v1pb.MemoRelation_Memo{Name: "memos/" + relatedUID},
Type: v1pb.MemoRelation_REFERENCE,
}
} }
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) { func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_tags", mcpSrv.AddTool(mcp.NewTool("list_tags",
mcp.WithDescription("List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically."), readOnlyToolOptions("List tags", "List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically.")...,
), s.handleListTags) ), s.handleListTags)
} }
...@@ -70,9 +70,5 @@ func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) ...@@ -70,9 +70,5 @@ func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest)
} }
}) })
out, err := marshalJSON(entries) return newToolResultJSON(entries)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
...@@ -89,7 +89,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store ...@@ -89,7 +89,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
} }
// Register MCP server. // Register MCP server.
mcpService := mcprouter.NewMCPService(s.Profile, s.Store, s.Secret) mcpService := mcprouter.NewMCPService(s.Profile, s.Store, s.Secret, apiV1Service)
mcpService.RegisterRoutes(echoServer) mcpService.RegisterRoutes(echoServer)
return s, nil return s, nil
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment