Unverified Commit 583c3d24 authored by boojack's avatar boojack Committed by GitHub

feat(mcp): harden tool exposure and side effects (#5850)

parent 0fc1dab2
......@@ -12,6 +12,35 @@ DELETE /mcp (optional session termination)
Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26).
### Tool Filtering
The default `/mcp` endpoint exposes all tools. Clients can opt into a smaller
tool surface with GitHub-style headers or route aliases:
| Control | Description |
|---|---|
| `X-MCP-Readonly: true` | Hide and block mutating tools |
| `X-MCP-Toolsets: memos,tags,attachments,relations,reactions` | Limit the default tool list to selected toolsets |
| `X-MCP-Tools: list_tags,get_memo` | Add specific tools to the selected toolset list |
| `X-MCP-Exclude-Tools: delete_memo` | Remove specific tools |
Equivalent aliases:
```text
/mcp/readonly
/mcp/x/{toolsets}
/mcp/x/{toolsets}/readonly
```
Examples:
```text
/mcp/x/memos,tags/readonly
X-MCP-Toolsets: memos
X-MCP-Tools: list_tags
X-MCP-Exclude-Tools: delete_memo
```
## Capabilities
The server advertises the following MCP capabilities:
......@@ -126,7 +155,9 @@ claude mcp add --scope user --transport http memos http://localhost:5230/mcp \
| File | Responsibility |
|---|---|
| `mcp.go` | `MCPService` struct, constructor, route registration, auth middleware |
| `mcp.go` | `MCPService` struct, constructor, route registration, auth middleware, tool filtering |
| `tool_metadata.go` | Toolsets, read-only metadata, annotations, structured result helpers |
| `api_helpers.go` | Conversion helpers for calling API service methods from MCP tools |
| `tools_memo.go` | Memo CRUD tools + helpers (JSON types, visibility/access checks) |
| `tools_tag.go` | Tag listing tool |
| `tools_attachment.go` | Attachment listing, metadata, deletion, linking tools |
......
package mcp
import (
"context"
"github.com/pkg/errors"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
)
func visibilityToProto(visibility store.Visibility) v1pb.Visibility {
switch visibility {
case store.Protected:
return v1pb.Visibility_PROTECTED
case store.Public:
return v1pb.Visibility_PUBLIC
default:
return v1pb.Visibility_PRIVATE
}
}
func rowStatusToProto(rowStatus store.RowStatus) v1pb.State {
switch rowStatus {
case store.Archived:
return v1pb.State_ARCHIVED
default:
return v1pb.State_NORMAL
}
}
func (s *MCPService) loadMemoJSONByName(ctx context.Context, name string) (memoJSON, error) {
uid, err := parseMemoUID(name)
if err != nil {
return memoJSON{}, err
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return memoJSON{}, errors.Wrap(err, "failed to get memo")
}
if memo == nil {
return memoJSON{}, errors.New("memo not found")
}
return storeMemoToJSONWithStore(ctx, s.store, memo)
}
func (s *MCPService) loadReactionJSONByID(ctx context.Context, reactionID int32) (reactionJSON, error) {
reaction, err := s.store.GetReaction(ctx, &store.FindReaction{ID: &reactionID})
if err != nil {
return reactionJSON{}, errors.Wrap(err, "failed to get reaction")
}
if reaction == nil {
return reactionJSON{}, errors.New("reaction not found")
}
creator, err := lookupUsername(ctx, s.store, reaction.CreatorID)
if err != nil {
return reactionJSON{}, errors.Wrap(err, "failed to resolve reaction creator")
}
return reactionJSON{
ID: reaction.ID,
Creator: creator,
ReactionType: reaction.ReactionType,
CreateTime: reaction.CreatedTs,
}, nil
}
func (s *MCPService) loadReactionJSONByName(ctx context.Context, name string) (reactionJSON, error) {
_, reactionID, err := apiv1.ExtractMemoReactionIDFromName(name)
if err != nil {
return reactionJSON{}, err
}
return s.loadReactionJSONByID(ctx, reactionID)
}
package mcp
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/labstack/echo/v5"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/server/auth"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
)
const (
headerMCPReadonly = "X-MCP-Readonly"
headerMCPToolsets = "X-MCP-Toolsets"
headerMCPTools = "X-MCP-Tools"
headerMCPExcludeTools = "X-MCP-Exclude-Tools"
)
type mcpRequestConfigContextKey struct{}
type MCPService struct {
profile *profile.Profile
store *store.Store
apiV1Service *apiv1.APIV1Service
authenticator *auth.Authenticator
}
func NewMCPService(profile *profile.Profile, store *store.Store, secret string) *MCPService {
func NewMCPService(profile *profile.Profile, store *store.Store, secret string, apiV1Service *apiv1.APIV1Service) *MCPService {
return &MCPService{
profile: profile,
store: store,
apiV1Service: apiV1Service,
authenticator: auth.NewAuthenticator(store, secret),
}
}
......@@ -31,6 +47,10 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpserver.WithResourceCapabilities(true, true),
mcpserver.WithPromptCapabilities(true),
mcpserver.WithLogging(),
mcpserver.WithToolFilter(s.filterTools),
mcpserver.WithToolHandlerMiddleware(s.enforceToolAccess),
mcpserver.WithRecovery(),
mcpserver.WithResourceRecovery(),
)
s.registerMemoTools(mcpSrv)
s.registerTagTools(mcpSrv)
......@@ -40,7 +60,9 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
s.registerMemoResources(mcpSrv)
s.registerPrompts(mcpSrv)
httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv)
httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv,
mcpserver.WithHTTPContextFunc(s.withRequestConfig),
)
mcpGroup := echoServer.Group("")
mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
......@@ -52,7 +74,18 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
headers := c.Response().Header()
headers.Set("Vary", "Origin")
headers.Set("Access-Control-Allow-Origin", origin)
headers.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Accept, Mcp-Session-Id, MCP-Protocol-Version, Last-Event-ID")
headers.Set("Access-Control-Allow-Headers", strings.Join([]string{
"Authorization",
"Content-Type",
"Accept",
"Mcp-Session-Id",
"MCP-Protocol-Version",
"Last-Event-ID",
headerMCPReadonly,
headerMCPToolsets,
headerMCPTools,
headerMCPExcludeTools,
}, ", "))
headers.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
if c.Request().Method == http.MethodOptions {
return c.NoContent(http.StatusNoContent)
......@@ -72,4 +105,147 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
}
})
mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler))
mcpGroup.Any("/mcp/readonly", echo.WrapHandler(httpHandler))
mcpGroup.Any("/mcp/x/:toolsets", echo.WrapHandler(httpHandler))
mcpGroup.Any("/mcp/x/:toolsets/readonly", echo.WrapHandler(httpHandler))
}
func (*MCPService) withRequestConfig(ctx context.Context, r *http.Request) context.Context {
return context.WithValue(ctx, mcpRequestConfigContextKey{}, parseMCPRequestConfig(r))
}
func (*MCPService) filterTools(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
cfg := mcpRequestConfigFromContext(ctx)
filtered := make([]mcp.Tool, 0, len(tools))
for _, tool := range tools {
if cfg.allowsTool(tool.Name) {
filtered = append(filtered, tool)
}
}
return filtered
}
func (*MCPService) enforceToolAccess(next mcpserver.ToolHandlerFunc) mcpserver.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
cfg := mcpRequestConfigFromContext(ctx)
if !cfg.allowsTool(req.Params.Name) {
return mcp.NewToolResultError(fmt.Sprintf("tool %q is not enabled by MCP configuration", req.Params.Name)), nil
}
return next(ctx, req)
}
}
type mcpRequestConfig struct {
readOnly bool
toolsets map[string]struct{}
includeTools map[string]struct{}
excludeTools map[string]struct{}
}
func mcpRequestConfigFromContext(ctx context.Context) mcpRequestConfig {
if cfg, ok := ctx.Value(mcpRequestConfigContextKey{}).(mcpRequestConfig); ok {
return cfg
}
return mcpRequestConfig{}
}
func parseMCPRequestConfig(r *http.Request) mcpRequestConfig {
cfg := mcpRequestConfig{}
pathToolsets, pathReadonly := parseMCPPathConfig(r.URL.Path)
cfg.readOnly = pathReadonly || parseBoolHeader(r.Header.Get(headerMCPReadonly))
cfg.toolsets = mergeStringSets(cfg.toolsets, pathToolsets)
cfg.toolsets = mergeStringSets(cfg.toolsets, parseCommaSet(r.Header.Get(headerMCPToolsets), strings.ToLower))
cfg.includeTools = parseCommaSet(r.Header.Get(headerMCPTools), keepString)
cfg.excludeTools = parseCommaSet(r.Header.Get(headerMCPExcludeTools), keepString)
return cfg
}
func parseMCPPathConfig(path string) (map[string]struct{}, bool) {
trimmed := strings.Trim(path, "/")
if trimmed == "mcp/readonly" {
return nil, true
}
const prefix = "mcp/x/"
if !strings.HasPrefix(trimmed, prefix) {
return nil, false
}
rest := strings.TrimPrefix(trimmed, prefix)
readOnly := false
if strings.HasSuffix(rest, "/readonly") {
readOnly = true
rest = strings.TrimSuffix(rest, "/readonly")
}
return parseCommaSet(rest, strings.ToLower), readOnly
}
func (cfg mcpRequestConfig) allowsTool(name string) bool {
if _, known := allMCPToolNames[name]; !known {
return false
}
if cfg.readOnly {
if _, mutates := mcpMutationTools[name]; mutates {
return false
}
}
if _, excluded := cfg.excludeTools[name]; excluded {
return false
}
if _, included := cfg.includeTools[name]; included {
return true
}
if len(cfg.toolsets) == 0 {
return true
}
for toolset := range cfg.toolsets {
if _, ok := mcpToolsByToolset[toolset][name]; ok {
return true
}
}
return false
}
func parseBoolHeader(value string) bool {
switch strings.ToLower(strings.TrimSpace(value)) {
case "1", "t", "true", "y", "yes", "on":
return true
default:
return false
}
}
func parseCommaSet(value string, normalize func(string) string) map[string]struct{} {
if value == "" {
return nil
}
result := map[string]struct{}{}
for _, item := range strings.Split(value, ",") {
item = strings.TrimSpace(item)
if item == "" {
continue
}
result[normalize(item)] = struct{}{}
}
if len(result) == 0 {
return nil
}
return result
}
func mergeStringSets(dst map[string]struct{}, src map[string]struct{}) map[string]struct{} {
if len(src) == 0 {
return dst
}
if dst == nil {
dst = map[string]struct{}{}
}
for item := range src {
dst[item] = struct{}{}
}
return dst
}
func keepString(s string) string {
return s
}
package mcp
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"unsafe"
"github.com/labstack/echo/v5"
"github.com/lithammer/shortuuid/v4"
"github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/require"
......@@ -13,6 +19,7 @@ import (
"github.com/usememos/memos/internal/profile"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth"
apiv1service "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
teststore "github.com/usememos/memos/store/test"
)
......@@ -31,10 +38,12 @@ func newTestMCPService(t *testing.T) *testMCPService {
require.NoError(t, stores.Close())
})
svc := NewMCPService(&profile.Profile{
profile := &profile.Profile{
Driver: "sqlite",
InstanceURL: "https://notes.example.com",
}, stores, "test-secret")
}
apiV1Service := apiv1service.NewAPIV1Service("test-secret", profile, stores)
svc := NewMCPService(profile, stores, "test-secret", apiV1Service)
return &testMCPService{
service: svc,
store: stores,
......@@ -115,6 +124,125 @@ func firstText(t *testing.T, result *mcp.CallToolResult) string {
return text.Text
}
func initializeMCPHTTP(t *testing.T, e *echo.Echo, path string, headers map[string]string) string {
t.Helper()
payload := map[string]any{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": map[string]any{
"protocolVersion": "2025-06-18",
"capabilities": map[string]any{},
"clientInfo": map[string]any{
"name": "mcp-test",
"version": "1.0.0",
},
},
}
resp := postMCPHTTP(t, e, path, "", headers, payload)
require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
sessionID := resp.Header().Get("Mcp-Session-Id")
require.NotEmpty(t, sessionID)
return sessionID
}
func callMCPHTTP(t *testing.T, e *echo.Echo, path string, sessionID string, headers map[string]string, method string, params any) map[string]any {
t.Helper()
payload := map[string]any{
"jsonrpc": "2.0",
"id": 2,
"method": method,
}
if params != nil {
payload["params"] = params
}
resp := postMCPHTTP(t, e, path, sessionID, headers, payload)
require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
var decoded map[string]any
require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &decoded))
return decoded
}
func postMCPHTTP(t *testing.T, e *echo.Echo, path string, sessionID string, headers map[string]string, payload map[string]any) *httptest.ResponseRecorder {
t.Helper()
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
if sessionID != "" {
req.Header.Set("Mcp-Session-Id", sessionID)
}
for key, value := range headers {
req.Header.Set(key, value)
}
resp := httptest.NewRecorder()
e.ServeHTTP(resp, req)
return resp
}
func toolNamesFromListResponse(t *testing.T, response map[string]any) map[string]struct{} {
t.Helper()
result, ok := response["result"].(map[string]any)
require.True(t, ok, "missing result: %#v", response)
rawTools, ok := result["tools"].([]any)
require.True(t, ok, "missing tools: %#v", result)
names := map[string]struct{}{}
for _, rawTool := range rawTools {
tool, ok := rawTool.(map[string]any)
require.True(t, ok)
name, ok := tool["name"].(string)
require.True(t, ok)
names[name] = struct{}{}
}
return names
}
func requireToolPresent(t *testing.T, names map[string]struct{}, name string) {
t.Helper()
_, ok := names[name]
require.True(t, ok, "expected tool %q to be present in %#v", name, names)
}
func requireToolAbsent(t *testing.T, names map[string]struct{}, name string) {
t.Helper()
_, ok := names[name]
require.False(t, ok, "expected tool %q to be absent in %#v", name, names)
}
func nextSSEEvent(t *testing.T, client *apiv1service.SSEClient) *apiv1service.SSEEvent {
t.Helper()
events := sseClientEvents(t, client)
var data []byte
select {
case eventData, ok := <-events:
require.True(t, ok, "SSE client channel closed")
data = eventData
case <-time.After(time.Second):
t.Fatal("timed out waiting for SSE event")
}
var event apiv1service.SSEEvent
require.NoError(t, json.Unmarshal(data, &event))
return &event
}
func requireNoSSEEvent(t *testing.T, client *apiv1service.SSEClient) {
t.Helper()
select {
case eventData, ok := <-sseClientEvents(t, client):
require.True(t, ok, "SSE client channel closed")
t.Fatalf("unexpected SSE event received: %s", string(eventData))
case <-time.After(150 * time.Millisecond):
}
}
func sseClientEvents(t *testing.T, client *apiv1service.SSEClient) <-chan []byte {
t.Helper()
field := reflect.ValueOf(client).Elem().FieldByName("events")
events, ok := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface().(chan []byte)
require.True(t, ok)
return events
}
func TestHandleGetMemoAndReadResourceDenyArchivedMemoToNonCreator(t *testing.T) {
ts := newTestMCPService(t)
owner := ts.createUser(t, "owner")
......@@ -273,3 +401,205 @@ func TestIsAllowedOrigin(t *testing.T) {
require.False(t, ts.service.isAllowedOrigin(req))
})
}
func TestMCPToolFilteringRoutesAndHeaders(t *testing.T) {
ts := newTestMCPService(t)
e := echo.New()
ts.service.RegisterRoutes(e)
t.Run("default endpoint lists all tools", func(t *testing.T) {
sessionID := initializeMCPHTTP(t, e, "/mcp", nil)
response := callMCPHTTP(t, e, "/mcp", sessionID, nil, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
require.Len(t, names, len(allMCPToolNames))
requireToolPresent(t, names, "create_memo")
requireToolPresent(t, names, "list_tags")
requireToolPresent(t, names, "upsert_reaction")
})
t.Run("readonly header hides and blocks mutation tools", func(t *testing.T) {
headers := map[string]string{headerMCPReadonly: "true"}
sessionID := initializeMCPHTTP(t, e, "/mcp", nil)
response := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
requireToolPresent(t, names, "list_memos")
requireToolPresent(t, names, "list_tags")
requireToolAbsent(t, names, "create_memo")
requireToolAbsent(t, names, "delete_memo")
callResponse := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/call", map[string]any{
"name": "create_memo",
"arguments": map[string]any{"content": "blocked"},
})
result, ok := callResponse["result"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, result["isError"])
rawContent, ok := result["content"].([]any)
require.True(t, ok)
content, ok := rawContent[0].(map[string]any)
require.True(t, ok)
require.Contains(t, content["text"], "not enabled")
})
t.Run("readonly alias applies path config", func(t *testing.T) {
sessionID := initializeMCPHTTP(t, e, "/mcp/readonly", nil)
response := callMCPHTTP(t, e, "/mcp/readonly", sessionID, nil, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
requireToolPresent(t, names, "get_memo")
requireToolAbsent(t, names, "create_memo")
requireToolAbsent(t, names, "upsert_reaction")
})
t.Run("toolsets include and exclude compose", func(t *testing.T) {
headers := map[string]string{
headerMCPToolsets: "memos",
headerMCPTools: "list_tags",
headerMCPExcludeTools: "get_memo",
}
sessionID := initializeMCPHTTP(t, e, "/mcp", nil)
response := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
requireToolPresent(t, names, "list_memos")
requireToolPresent(t, names, "list_tags")
requireToolAbsent(t, names, "get_memo")
requireToolAbsent(t, names, "list_attachments")
})
t.Run("path toolsets and readonly compose", func(t *testing.T) {
sessionID := initializeMCPHTTP(t, e, "/mcp/x/memos,tags/readonly", nil)
response := callMCPHTTP(t, e, "/mcp/x/memos,tags/readonly", sessionID, nil, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
requireToolPresent(t, names, "list_memos")
requireToolPresent(t, names, "list_tags")
requireToolAbsent(t, names, "create_memo")
requireToolAbsent(t, names, "list_attachments")
})
t.Run("unknown toolset returns empty tool list", func(t *testing.T) {
sessionID := initializeMCPHTTP(t, e, "/mcp/x/notreal", nil)
response := callMCPHTTP(t, e, "/mcp/x/notreal", sessionID, nil, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
require.Empty(t, names)
})
}
func TestMCPMemoAndReactionMutationsEmitSSEEvents(t *testing.T) {
ts := newTestMCPService(t)
user := ts.createUser(t, "author")
ctx := withUser(context.Background(), user.ID)
client := ts.service.apiV1Service.SSEHub.Subscribe(user.ID, store.RoleUser)
defer ts.service.apiV1Service.SSEHub.Unsubscribe(client)
createResult, err := ts.service.handleCreateMemo(ctx, toolRequest("create_memo", map[string]any{
"content": "created from MCP",
"visibility": "PRIVATE",
}))
require.NoError(t, err)
require.False(t, createResult.IsError)
createEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoCreated, createEvent.Type)
var created memoJSON
require.NoError(t, json.Unmarshal([]byte(firstText(t, createResult)), &created))
updateResult, err := ts.service.handleUpdateMemo(ctx, toolRequest("update_memo", map[string]any{
"name": created.Name,
"content": "updated from MCP",
}))
require.NoError(t, err)
require.False(t, updateResult.IsError)
updateEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoUpdated, updateEvent.Type)
commentResult, err := ts.service.handleCreateMemoComment(ctx, toolRequest("create_memo_comment", map[string]any{
"name": created.Name,
"content": "comment from MCP",
}))
require.NoError(t, err)
require.False(t, commentResult.IsError)
commentEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoCommentCreated, commentEvent.Type)
require.Equal(t, created.Name, commentEvent.Name)
upsertReactionResult, err := ts.service.handleUpsertReaction(ctx, toolRequest("upsert_reaction", map[string]any{
"name": created.Name,
"reaction_type": "👍",
}))
require.NoError(t, err)
require.False(t, upsertReactionResult.IsError)
reactionEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventReactionUpserted, reactionEvent.Type)
var reaction reactionJSON
require.NoError(t, json.Unmarshal([]byte(firstText(t, upsertReactionResult)), &reaction))
deleteReactionResult, err := ts.service.handleDeleteReaction(ctx, toolRequest("delete_reaction", map[string]any{
"id": float64(reaction.ID),
}))
require.NoError(t, err)
require.False(t, deleteReactionResult.IsError)
deleteReactionEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventReactionDeleted, deleteReactionEvent.Type)
deleteResult, err := ts.service.handleDeleteMemo(ctx, toolRequest("delete_memo", map[string]any{
"name": created.Name,
}))
require.NoError(t, err)
require.False(t, deleteResult.IsError)
deleteEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoDeleted, deleteEvent.Type)
}
func TestMCPRelationAndAttachmentMutationsEmitMemoUpdated(t *testing.T) {
ts := newTestMCPService(t)
user := ts.createUser(t, "owner")
ctx := withUser(context.Background(), user.ID)
source := ts.createMemo(t, user.ID, store.Private, "source")
target := ts.createMemo(t, user.ID, store.Private, "target")
attachment := ts.createAttachment(t, user.ID, nil)
client := ts.service.apiV1Service.SSEHub.Subscribe(user.ID, store.RoleUser)
defer ts.service.apiV1Service.SSEHub.Unsubscribe(client)
relationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{
"name": "memos/" + source.UID,
"related_memo": "memos/" + target.UID,
}))
require.NoError(t, err)
require.False(t, relationResult.IsError)
relationEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoUpdated, relationEvent.Type)
require.Equal(t, "memos/"+source.UID, relationEvent.Name)
duplicateRelationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{
"name": "memos/" + source.UID,
"related_memo": "memos/" + target.UID,
}))
require.NoError(t, err)
require.False(t, duplicateRelationResult.IsError)
requireNoSSEEvent(t, client)
selfRelationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{
"name": "memos/" + source.UID,
"related_memo": "memos/" + source.UID,
}))
require.NoError(t, err)
require.True(t, selfRelationResult.IsError)
require.Contains(t, firstText(t, selfRelationResult), "itself")
linkResult, err := ts.service.handleLinkAttachmentToMemo(ctx, toolRequest("link_attachment_to_memo", map[string]any{
"name": "attachments/" + attachment.UID,
"memo": "memos/" + source.UID,
}))
require.NoError(t, err)
require.False(t, linkResult.IsError)
attachmentEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoUpdated, attachmentEvent.Type)
require.Equal(t, "memos/"+source.UID, attachmentEvent.Name)
relinkResult, err := ts.service.handleLinkAttachmentToMemo(ctx, toolRequest("link_attachment_to_memo", map[string]any{
"name": "attachments/" + attachment.UID,
"memo": "memos/" + source.UID,
}))
require.NoError(t, err)
require.False(t, relinkResult.IsError)
requireNoSSEEvent(t, client)
}
package mcp
import "github.com/mark3labs/mcp-go/mcp"
var mcpToolsByToolset = map[string]map[string]struct{}{
"memos": stringSet(
"list_memos",
"get_memo",
"create_memo",
"update_memo",
"delete_memo",
"search_memos",
"list_memo_comments",
"create_memo_comment",
),
"tags": stringSet(
"list_tags",
),
"attachments": stringSet(
"list_attachments",
"get_attachment",
"delete_attachment",
"link_attachment_to_memo",
),
"relations": stringSet(
"list_memo_relations",
"create_memo_relation",
"delete_memo_relation",
),
"reactions": stringSet(
"list_reactions",
"upsert_reaction",
"delete_reaction",
),
}
var allMCPToolNames = func() map[string]struct{} {
names := map[string]struct{}{}
for _, tools := range mcpToolsByToolset {
for name := range tools {
names[name] = struct{}{}
}
}
return names
}()
var mcpMutationTools = stringSet(
"create_memo",
"update_memo",
"delete_memo",
"create_memo_comment",
"delete_attachment",
"link_attachment_to_memo",
"create_memo_relation",
"delete_memo_relation",
"upsert_reaction",
"delete_reaction",
)
type deletedJSON struct {
Deleted bool `json:"deleted"`
}
func stringSet(values ...string) map[string]struct{} {
result := make(map[string]struct{}, len(values))
for _, value := range values {
result[value] = struct{}{}
}
return result
}
func readOnlyToolOptions(title string, description string, opts ...mcp.ToolOption) []mcp.ToolOption {
return annotatedToolOptions(title, description, true, false, true, false, opts...)
}
func createToolOptions(title string, description string, idempotent bool, opts ...mcp.ToolOption) []mcp.ToolOption {
return annotatedToolOptions(title, description, false, false, idempotent, false, opts...)
}
func updateToolOptions(title string, description string, opts ...mcp.ToolOption) []mcp.ToolOption {
return annotatedToolOptions(title, description, false, true, false, false, opts...)
}
func annotatedToolOptions(title string, description string, readOnly bool, destructive bool, idempotent bool, openWorld bool, opts ...mcp.ToolOption) []mcp.ToolOption {
base := []mcp.ToolOption{
mcp.WithTitleAnnotation(title),
mcp.WithDescription(description),
mcp.WithReadOnlyHintAnnotation(readOnly),
mcp.WithDestructiveHintAnnotation(destructive),
mcp.WithIdempotentHintAnnotation(idempotent),
mcp.WithOpenWorldHintAnnotation(openWorld),
}
return append(base, opts...)
}
func newToolResultJSON(v any) (*mcp.CallToolResult, error) {
return mcp.NewToolResultJSON(v)
}
func newDeletedToolResult() (*mcp.CallToolResult, error) {
return newToolResultJSON(deletedJSON{Deleted: true})
}
......@@ -9,6 +9,7 @@ import (
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
......@@ -26,6 +27,11 @@ type attachmentJSON struct {
Memo string `json:"memo,omitempty"`
}
type attachmentListJSON struct {
Attachments []attachmentJSON `json:"attachments"`
HasMore bool `json:"has_more"`
}
func storeAttachmentToJSON(ctx context.Context, stores *store.Store, a *store.Attachment) (attachmentJSON, error) {
creator, err := lookupUsername(ctx, stores, a.CreatorID)
if err != nil {
......@@ -98,26 +104,34 @@ func parseAttachmentUID(name string) (string, error) {
func (s *MCPService) registerAttachmentTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_attachments",
mcp.WithDescription("List attachments owned by the authenticated user. Supports pagination and optional filtering by linked memo."),
mcp.WithNumber("page_size", mcp.Description("Maximum attachments to return (1–100, default 20)")),
mcp.WithNumber("page", mcp.Description("Zero-based page index (default 0)")),
mcp.WithString("memo", mcp.Description(`Filter by linked memo resource name, e.g. "memos/abc123"`)),
readOnlyToolOptions("List attachments", "List attachments owned by the authenticated user. Supports pagination and optional filtering by linked memo.",
mcp.WithNumber("page_size", mcp.Description("Maximum attachments to return (1–100, default 20)")),
mcp.WithNumber("page", mcp.Description("Zero-based page index (default 0)")),
mcp.WithString("memo", mcp.Description(`Filter by linked memo resource name, e.g. "memos/abc123"`)),
mcp.WithOutputSchema[attachmentListJSON](),
)...,
), s.handleListAttachments)
mcpSrv.AddTool(mcp.NewTool("get_attachment",
mcp.WithDescription("Get a single attachment's metadata by resource name. Requires authentication."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
readOnlyToolOptions("Get attachment", "Get a single attachment's metadata by resource name. Requires authentication.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
mcp.WithOutputSchema[attachmentJSON](),
)...,
), s.handleGetAttachment)
mcpSrv.AddTool(mcp.NewTool("delete_attachment",
mcp.WithDescription("Permanently delete an attachment and its stored file. Requires authentication and ownership."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
updateToolOptions("Delete attachment", "Permanently delete an attachment and its stored file. Requires authentication and ownership.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
mcp.WithOutputSchema[deletedJSON](),
)...,
), s.handleDeleteAttachment)
mcpSrv.AddTool(mcp.NewTool("link_attachment_to_memo",
mcp.WithDescription("Link an existing attachment to a memo. Requires authentication and ownership of the attachment."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
mcp.WithString("memo", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
createToolOptions("Link attachment to memo", "Link an existing attachment to a memo. Requires authentication and ownership of the attachment.", true,
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
mcp.WithString("memo", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithOutputSchema[attachmentJSON](),
)...,
), s.handleLinkAttachmentToMemo)
}
......@@ -189,15 +203,7 @@ func (s *MCPService) handleListAttachments(ctx context.Context, req mcp.CallTool
results[i] = result
}
type listResponse struct {
Attachments []attachmentJSON `json:"attachments"`
HasMore bool `json:"has_more"`
}
out, err := marshalJSON(listResponse{Attachments: results, HasMore: hasMore})
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(attachmentListJSON{Attachments: results, HasMore: hasMore})
}
func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
......@@ -224,16 +230,11 @@ func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRe
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
}
out, err := marshalJSON(result)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(result)
}
func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
if _, err := extractUserID(ctx); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
......@@ -242,18 +243,10 @@ func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToo
return mcp.NewToolResultError(err.Error()), nil
}
attachment, err := s.store.GetAttachment(ctx, &store.FindAttachment{UID: &uid, CreatorID: &userID})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to find attachment: %v", err)), nil
}
if attachment == nil {
return mcp.NewToolResultError("attachment not found"), nil
}
if err := s.store.DeleteAttachment(ctx, &store.DeleteAttachment{ID: attachment.ID}); err != nil {
if _, err := s.apiV1Service.DeleteAttachment(ctx, &v1pb.DeleteAttachmentRequest{Name: "attachments/" + uid}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete attachment: %v", err)), nil
}
return mcp.NewToolResultText(`{"deleted":true}`), nil
return newDeletedToolResult()
}
func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
......@@ -293,9 +286,30 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal
return mcp.NewToolResultError(err.Error()), nil
}
if err := s.store.UpdateAttachment(ctx, &store.UpdateAttachment{
ID: attachment.ID,
MemoID: &memo.ID,
currentAttachments, err := s.store.ListAttachments(ctx, &store.FindAttachment{MemoID: &memo.ID})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list memo attachments: %v", err)), nil
}
requestAttachments := make([]*v1pb.Attachment, 0, len(currentAttachments)+1)
var currentTarget *store.Attachment
for _, current := range currentAttachments {
requestAttachments = append(requestAttachments, &v1pb.Attachment{Name: "attachments/" + current.UID})
if current.ID == attachment.ID {
currentTarget = current
}
}
if currentTarget != nil {
result, err := storeAttachmentToJSON(ctx, s.store, currentTarget)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
}
return newToolResultJSON(result)
}
requestAttachments = append(requestAttachments, &v1pb.Attachment{Name: "attachments/" + uid})
if _, err := s.apiV1Service.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
Name: "memos/" + memoUID,
Attachments: requestAttachments,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to link attachment: %v", err)), nil
}
......@@ -309,9 +323,5 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
}
out, err := marshalJSON(result)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(result)
}
......@@ -2,52 +2,19 @@ package mcp
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strings"
"github.com/lithammer/shortuuid/v4"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
"google.golang.org/protobuf/types/known/fieldmaskpb"
storepb "github.com/usememos/memos/proto/gen/store"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// tagRegexp matches #tag patterns in memo content.
// A tag must start with a letter and contain no whitespace or # characters.
var tagRegexp = regexp.MustCompile(`(?:^|\s)#([A-Za-z][^\s#]*)`)
// extractTags does a best-effort extraction of #tags from raw markdown content.
// It is used when creating or updating memos via MCP to pre-populate Payload.Tags.
// The full markdown service may later rebuild a more accurate payload.
func extractTags(content string) []string {
matches := tagRegexp.FindAllStringSubmatch(content, -1)
seen := make(map[string]struct{}, len(matches))
tags := make([]string, 0, len(matches))
for _, m := range matches {
tag := m[1]
if _, ok := seen[tag]; !ok {
seen[tag] = struct{}{}
tags = append(tags, tag)
}
}
return tags
}
// buildPayload constructs a MemoPayload with tags extracted from content.
// Returns nil when no tags are found so the store omits the payload entirely.
func buildPayload(content string) *storepb.MemoPayload {
tags := extractTags(content)
if len(tags) == 0 {
return nil
}
return &storepb.MemoPayload{Tags: tags}
}
// propertyJSON is the serialisable form of MemoPayload.Property.
type propertyJSON struct {
HasLink bool `json:"has_link"`
......@@ -72,6 +39,11 @@ type memoJSON struct {
Parent string `json:"parent,omitempty"`
}
type memoListJSON struct {
Memos []memoJSON `json:"memos"`
HasMore bool `json:"has_more"`
}
func storeMemoToJSON(m *store.Memo) memoJSON {
j := memoJSON{
Name: "memos/" + m.UID,
......@@ -205,75 +177,81 @@ func extractUserID(ctx context.Context) (int32, error) {
return id, nil
}
func marshalJSON(v any) (string, error) {
b, err := json.Marshal(v)
if err != nil {
return "", err
}
return string(b), nil
}
func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_memos",
mcp.WithDescription("List memos visible to the caller. Authenticated users see their own memos plus public and protected memos; unauthenticated callers see only public memos."),
mcp.WithNumber("page_size", mcp.Description("Maximum memos to return (1–100, default 20)")),
mcp.WithNumber("page", mcp.Description("Zero-based page index for pagination (default 0)")),
mcp.WithString("state",
mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
),
mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
readOnlyToolOptions("List memos", "List memos visible to the caller. Authenticated users see their own memos plus public and protected memos; unauthenticated callers see only public memos.",
mcp.WithNumber("page_size", mcp.Description("Maximum memos to return (1–100, default 20)")),
mcp.WithNumber("page", mcp.Description("Zero-based page index for pagination (default 0)")),
mcp.WithString("state",
mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
),
mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
mcp.WithOutputSchema[memoListJSON](),
)...,
), s.handleListMemos)
mcpSrv.AddTool(mcp.NewTool("get_memo",
mcp.WithDescription("Get a single memo by resource name. Public memos are accessible without authentication."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
readOnlyToolOptions("Get memo", "Get a single memo by resource name. Public memos are accessible without authentication.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithOutputSchema[memoJSON](),
)...,
), s.handleGetMemo)
mcpSrv.AddTool(mcp.NewTool("create_memo",
mcp.WithDescription("Create a new memo. Requires authentication."),
mcp.WithString("content", mcp.Required(), mcp.Description("Memo content in Markdown. Use #tag syntax for tagging.")),
mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("Visibility (default: PRIVATE)"),
),
createToolOptions("Create memo", "Create a new memo. Requires authentication.", false,
mcp.WithString("content", mcp.Required(), mcp.Description("Memo content in Markdown. Use #tag syntax for tagging.")),
mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("Visibility (default: PRIVATE)"),
),
mcp.WithOutputSchema[memoJSON](),
)...,
), s.handleCreateMemo)
mcpSrv.AddTool(mcp.NewTool("update_memo",
mcp.WithDescription("Update a memo's content, visibility, pin state, or archive state. Requires authentication and ownership. Omit any field to leave it unchanged."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Description("New Markdown content")),
mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("New visibility"),
),
mcp.WithBoolean("pinned", mcp.Description("Pin or unpin the memo")),
mcp.WithString("state",
mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Set to ARCHIVED to archive, NORMAL to restore"),
),
updateToolOptions("Update memo", "Update a memo's content, visibility, pin state, or archive state. Requires authentication and ownership. Omit any field to leave it unchanged.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Description("New Markdown content")),
mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("New visibility"),
),
mcp.WithBoolean("pinned", mcp.Description("Pin or unpin the memo")),
mcp.WithString("state",
mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Set to ARCHIVED to archive, NORMAL to restore"),
),
mcp.WithOutputSchema[memoJSON](),
)...,
), s.handleUpdateMemo)
mcpSrv.AddTool(mcp.NewTool("delete_memo",
mcp.WithDescription("Permanently delete a memo. Requires authentication and ownership."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
updateToolOptions("Delete memo", "Permanently delete a memo. Requires authentication and ownership.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithOutputSchema[deletedJSON](),
)...,
), s.handleDeleteMemo)
mcpSrv.AddTool(mcp.NewTool("search_memos",
mcp.WithDescription("Search memo content. Authenticated users search their own and visible memos; unauthenticated callers search public memos only."),
mcp.WithString("query", mcp.Required(), mcp.Description("Text to search for in memo content")),
readOnlyToolOptions("Search memos", "Search memo content. Authenticated users search their own and visible memos; unauthenticated callers search public memos only.",
mcp.WithString("query", mcp.Required(), mcp.Description("Text to search for in memo content")),
)...,
), s.handleSearchMemos)
mcpSrv.AddTool(mcp.NewTool("list_memo_comments",
mcp.WithDescription("List comments on a memo. Visibility rules for comments match those of the parent memo."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
readOnlyToolOptions("List memo comments", "List comments on a memo. Visibility rules for comments match those of the parent memo.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
)...,
), s.handleListMemoComments)
mcpSrv.AddTool(mcp.NewTool("create_memo_comment",
mcp.WithDescription("Add a comment to a memo. The comment inherits the parent memo's visibility. Requires authentication."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name to comment on, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Required(), mcp.Description("Comment content in Markdown")),
createToolOptions("Create memo comment", "Add a comment to a memo. The comment inherits the parent memo's visibility. Requires authentication.", false,
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name to comment on, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Required(), mcp.Description("Comment content in Markdown")),
mcp.WithOutputSchema[memoJSON](),
)...,
), s.handleCreateMemoComment)
}
......@@ -342,15 +320,7 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques
results[i] = result
}
type listResponse struct {
Memos []memoJSON `json:"memos"`
HasMore bool `json:"has_more"`
}
out, err := marshalJSON(listResponse{Memos: results, HasMore: hasMore})
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(memoListJSON{Memos: results, HasMore: hasMore})
}
func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
......@@ -376,16 +346,11 @@ func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
}
out, err := marshalJSON(result)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(result)
}
func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
if _, err := extractUserID(ctx); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
......@@ -398,31 +363,25 @@ func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(err.Error()), nil
}
memo, err := s.store.CreateMemo(ctx, &store.Memo{
UID: shortuuid.New(),
CreatorID: userID,
Content: content,
Visibility: visibility,
Payload: buildPayload(content),
created, err := s.apiV1Service.CreateMemo(ctx, &v1pb.CreateMemoRequest{
Memo: &v1pb.Memo{
Content: content,
Visibility: visibilityToProto(visibility),
},
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil
}
result, err := storeMemoToJSONWithStore(ctx, s.store, memo)
result, err := s.loadMemoJSONByName(ctx, created.Name)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
}
out, err := marshalJSON(result)
if err != nil {
return nil, err
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(result)
}
func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
if _, err := extractUserID(ctx); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
......@@ -431,66 +390,56 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(err.Error()), nil
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if memo == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if err := checkMemoOwnership(memo, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
update := &store.UpdateMemo{ID: memo.ID}
update := &v1pb.Memo{Name: "memos/" + uid}
updateMask := &fieldmaskpb.FieldMask{}
args := req.GetArguments()
if v := req.GetString("content", ""); v != "" {
update.Content = &v
update.Payload = buildPayload(v)
update.Content = v
updateMask.Paths = append(updateMask.Paths, "content")
}
if v := req.GetString("visibility", ""); v != "" {
vis, err := parseVisibility(v)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
update.Visibility = &vis
update.Visibility = visibilityToProto(vis)
updateMask.Paths = append(updateMask.Paths, "visibility")
}
if v := req.GetString("state", ""); v != "" {
rs, err := parseRowStatus(v)
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
update.RowStatus = &rs
update.State = rowStatusToProto(rs)
updateMask.Paths = append(updateMask.Paths, "state")
}
if _, ok := args["pinned"]; ok {
pinned := req.GetBool("pinned", false)
update.Pinned = &pinned
update.Pinned = req.GetBool("pinned", false)
updateMask.Paths = append(updateMask.Paths, "pinned")
}
if err := s.store.UpdateMemo(ctx, update); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to update memo: %v", err)), nil
if len(updateMask.Paths) == 0 {
return mcp.NewToolResultError("at least one field must be provided to update"), nil
}
updated, err := s.store.GetMemo(ctx, &store.FindMemo{ID: &memo.ID})
updated, err := s.apiV1Service.UpdateMemo(ctx, &v1pb.UpdateMemoRequest{
Memo: update,
UpdateMask: updateMask,
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated memo: %v", err)), nil
return mcp.NewToolResultError(fmt.Sprintf("failed to update memo: %v", err)), nil
}
result, err := storeMemoToJSONWithStore(ctx, s.store, updated)
result, err := s.loadMemoJSONByName(ctx, updated.Name)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
}
out, err := marshalJSON(result)
if err != nil {
return nil, err
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(result)
}
func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
if _, err := extractUserID(ctx); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
......@@ -499,21 +448,10 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(err.Error()), nil
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if memo == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if err := checkMemoOwnership(memo, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
if _, err := s.apiV1Service.DeleteMemo(ctx, &v1pb.DeleteMemoRequest{Name: "memos/" + uid}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete memo: %v", err)), nil
}
return mcp.NewToolResultText(`{"deleted":true}`), nil
return newDeletedToolResult()
}
func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
......@@ -557,11 +495,7 @@ func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequ
}
results[i] = result
}
out, err := marshalJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(results)
}
func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
......@@ -592,8 +526,7 @@ func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToo
return mcp.NewToolResultError(fmt.Sprintf("failed to list relations: %v", err)), nil
}
if len(relations) == 0 {
out, _ := marshalJSON([]memoJSON{})
return mcp.NewToolResultText(out), nil
return newToolResultJSON([]memoJSON{})
}
commentIDs := make([]int32, len(relations))
......@@ -626,11 +559,7 @@ func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToo
results = append(results, result)
}
}
out, err := marshalJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(results)
}
func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
......@@ -659,33 +588,20 @@ func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallTo
return mcp.NewToolResultError(err.Error()), nil
}
comment, err := s.store.CreateMemo(ctx, &store.Memo{
UID: shortuuid.New(),
CreatorID: userID,
Content: content,
Visibility: parent.Visibility,
Payload: buildPayload(content),
ParentUID: &parent.UID,
comment, err := s.apiV1Service.CreateMemoComment(ctx, &v1pb.CreateMemoCommentRequest{
Name: "memos/" + uid,
Comment: &v1pb.Memo{
Content: content,
Visibility: visibilityToProto(parent.Visibility),
},
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %v", err)), nil
}
if _, err = s.store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: comment.ID,
RelatedMemoID: parent.ID,
Type: store.MemoRelationComment,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to link comment: %v", err)), nil
}
result, err := storeMemoToJSONWithStore(ctx, s.store, comment)
result, err := s.loadMemoJSONByName(ctx, comment.Name)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
}
out, err := marshalJSON(result)
if err != nil {
return nil, err
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(result)
}
......@@ -7,6 +7,7 @@ import (
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
......@@ -20,19 +21,24 @@ type reactionJSON struct {
func (s *MCPService) registerReactionTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_reactions",
mcp.WithDescription("List all reactions on a memo. Returns reaction type and creator for each reaction."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
readOnlyToolOptions("List reactions", "List all reactions on a memo. Returns reaction type and creator for each reaction.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
)...,
), s.handleListReactions)
mcpSrv.AddTool(mcp.NewTool("upsert_reaction",
mcp.WithDescription("Add a reaction (emoji) to a memo. If the same reaction already exists from the same user, this is a no-op. Requires authentication."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("reaction_type", mcp.Required(), mcp.Description(`Reaction emoji, e.g. "👍", "❤️", "🎉"`)),
createToolOptions("Upsert reaction", "Add a reaction (emoji) to a memo. If the same reaction already exists from the same user, this is a no-op. Requires authentication.", true,
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("reaction_type", mcp.Required(), mcp.Description(`Reaction emoji, e.g. "👍", "❤️", "🎉"`)),
mcp.WithOutputSchema[reactionJSON](),
)...,
), s.handleUpsertReaction)
mcpSrv.AddTool(mcp.NewTool("delete_reaction",
mcp.WithDescription("Remove a reaction by its ID. Requires authentication and ownership of the reaction."),
mcp.WithNumber("id", mcp.Required(), mcp.Description("Reaction ID to delete")),
updateToolOptions("Delete reaction", "Remove a reaction by its ID. Requires authentication and ownership of the reaction.",
mcp.WithNumber("id", mcp.Required(), mcp.Description("Reaction ID to delete")),
mcp.WithOutputSchema[deletedJSON](),
)...,
), s.handleDeleteReaction)
}
......@@ -83,11 +89,7 @@ func (s *MCPService) handleListReactions(ctx context.Context, req mcp.CallToolRe
}
}
out, err := marshalJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(results)
}
func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
......@@ -133,34 +135,26 @@ func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolR
}
contentID := "memos/" + uid
reaction, err := s.store.UpsertReaction(ctx, &store.Reaction{
CreatorID: userID,
ContentID: contentID,
ReactionType: reactionType,
reaction, err := s.apiV1Service.UpsertMemoReaction(ctx, &v1pb.UpsertMemoReactionRequest{
Name: contentID,
Reaction: &v1pb.Reaction{
ContentId: contentID,
ReactionType: reactionType,
},
})
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to upsert reaction: %v", err)), nil
}
creator, err := lookupUsername(ctx, s.store, reaction.CreatorID)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve reaction creator: %v", err)), nil
}
out, err := marshalJSON(reactionJSON{
ID: reaction.ID,
Creator: creator,
ReactionType: reaction.ReactionType,
CreateTime: reaction.CreatedTs,
})
result, err := s.loadReactionJSONByName(ctx, reaction.Name)
if err != nil {
return nil, err
return mcp.NewToolResultError(err.Error()), nil
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(result)
}
func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx)
if err != nil {
if _, err := extractUserID(ctx); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
......@@ -176,12 +170,11 @@ func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolR
if reaction == nil {
return mcp.NewToolResultError("reaction not found"), nil
}
if reaction.CreatorID != userID {
return mcp.NewToolResultError("permission denied: can only delete your own reactions"), nil
}
if err := s.store.DeleteReaction(ctx, &store.DeleteReaction{ID: reactionID}); err != nil {
if _, err := s.apiV1Service.DeleteMemoReaction(ctx, &v1pb.DeleteMemoReactionRequest{
Name: fmt.Sprintf("%s/reactions/%d", reaction.ContentID, reactionID),
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete reaction: %v", err)), nil
}
return mcp.NewToolResultText(`{"deleted":true}`), nil
return newDeletedToolResult()
}
......@@ -7,6 +7,7 @@ import (
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
......@@ -19,24 +20,29 @@ type relationJSON struct {
func (s *MCPService) registerRelationTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_memo_relations",
mcp.WithDescription("List all relations (references and comments) for a memo. Requires read access to the memo."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("type",
mcp.Enum("REFERENCE", "COMMENT"),
mcp.Description("Filter by relation type (optional)"),
),
readOnlyToolOptions("List memo relations", "List all relations (references and comments) for a memo. Requires read access to the memo.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("type",
mcp.Enum("REFERENCE", "COMMENT"),
mcp.Description("Filter by relation type (optional)"),
),
)...,
), s.handleListMemoRelations)
mcpSrv.AddTool(mcp.NewTool("create_memo_relation",
mcp.WithDescription("Create a reference relation between two memos. Requires authentication. For comments, use create_memo_comment instead."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)),
createToolOptions("Create memo relation", "Create a reference relation between two memos. Requires authentication. For comments, use create_memo_comment instead.", true,
mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)),
mcp.WithOutputSchema[relationJSON](),
)...,
), s.handleCreateMemoRelation)
mcpSrv.AddTool(mcp.NewTool("delete_memo_relation",
mcp.WithDescription("Delete a reference relation between two memos. Requires authentication and ownership of the source memo."),
mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)),
updateToolOptions("Delete memo relation", "Delete a reference relation between two memos. Requires authentication and ownership of the source memo.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)),
mcp.WithOutputSchema[deletedJSON](),
)...,
), s.handleDeleteMemoRelation)
}
......@@ -113,11 +119,7 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo
})
}
out, err := marshalJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(results)
}
func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
......@@ -134,6 +136,9 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
if err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
if srcUID == dstUID {
return mcp.NewToolResultError("cannot create a relation from a memo to itself"), nil
}
srcMemo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &srcUID})
if err != nil {
......@@ -157,24 +162,24 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
return mcp.NewToolResultError(err.Error()), nil
}
relation, err := s.store.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: srcMemo.ID,
RelatedMemoID: dstMemo.ID,
Type: store.MemoRelationReference,
})
relations, changed, err := s.buildReferenceRelationSet(ctx, srcMemo, &dstMemo.UID, nil)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create relation: %v", err)), nil
return mcp.NewToolResultError(fmt.Sprintf("failed to build relation set: %v", err)), nil
}
if changed {
if _, err := s.apiV1Service.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
Name: "memos/" + srcUID,
Relations: relations,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create relation: %v", err)), nil
}
}
out, err := marshalJSON(relationJSON{
return newToolResultJSON(relationJSON{
Memo: "memos/" + srcUID,
RelatedMemo: "memos/" + dstUID,
Type: string(relation.Type),
Type: string(store.MemoRelationReference),
})
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
......@@ -214,13 +219,79 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT
return mcp.NewToolResultError(err.Error()), nil
}
refType := store.MemoRelationReference
if err := s.store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{
MemoID: &srcMemo.ID,
RelatedMemoID: &dstMemo.ID,
Type: &refType,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete relation: %v", err)), nil
relations, changed, err := s.buildReferenceRelationSet(ctx, srcMemo, nil, &dstMemo.UID)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to build relation set: %v", err)), nil
}
if changed {
if _, err := s.apiV1Service.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
Name: "memos/" + srcUID,
Relations: relations,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete relation: %v", err)), nil
}
}
return newDeletedToolResult()
}
func (s *MCPService) buildReferenceRelationSet(ctx context.Context, source *store.Memo, includeUID *string, excludeUID *string) ([]*v1pb.MemoRelation, bool, error) {
referenceType := store.MemoRelationReference
relations, err := s.store.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoIDList: []int32{source.ID},
Type: &referenceType,
})
if err != nil {
return nil, false, err
}
idSet := make(map[int32]struct{}, len(relations))
for _, relation := range relations {
idSet[relation.RelatedMemoID] = struct{}{}
}
ids := make([]int32, 0, len(idSet))
for id := range idSet {
ids = append(ids, id)
}
memosByID := map[int32]*store.Memo{}
if len(ids) > 0 {
memos, err := s.store.ListMemos(ctx, &store.FindMemo{IDList: ids, ExcludeContent: true})
if err != nil {
return nil, false, err
}
for _, memo := range memos {
memosByID[memo.ID] = memo
}
}
result := make([]*v1pb.MemoRelation, 0, len(relations)+1)
seenUIDs := map[string]struct{}{}
changed := false
for _, relation := range relations {
relatedMemo := memosByID[relation.RelatedMemoID]
if relatedMemo == nil {
continue
}
if excludeUID != nil && relatedMemo.UID == *excludeUID {
changed = true
continue
}
result = append(result, newReferenceRelation(source.UID, relatedMemo.UID))
seenUIDs[relatedMemo.UID] = struct{}{}
}
if includeUID != nil {
if _, seen := seenUIDs[*includeUID]; !seen && source.UID != *includeUID {
result = append(result, newReferenceRelation(source.UID, *includeUID))
changed = true
}
}
return result, changed, nil
}
func newReferenceRelation(sourceUID string, relatedUID string) *v1pb.MemoRelation {
return &v1pb.MemoRelation{
Memo: &v1pb.MemoRelation_Memo{Name: "memos/" + sourceUID},
RelatedMemo: &v1pb.MemoRelation_Memo{Name: "memos/" + relatedUID},
Type: v1pb.MemoRelation_REFERENCE,
}
return mcp.NewToolResultText(`{"deleted":true}`), nil
}
......@@ -14,7 +14,7 @@ import (
func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_tags",
mcp.WithDescription("List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically."),
readOnlyToolOptions("List tags", "List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically.")...,
), s.handleListTags)
}
......@@ -70,9 +70,5 @@ func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest)
}
})
out, err := marshalJSON(entries)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
return newToolResultJSON(entries)
}
......@@ -89,7 +89,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
}
// Register MCP server.
mcpService := mcprouter.NewMCPService(s.Profile, s.Store, s.Secret)
mcpService := mcprouter.NewMCPService(s.Profile, s.Store, s.Secret, apiV1Service)
mcpService.RegisterRoutes(echoServer)
return s, nil
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment