Unverified Commit 583c3d24 authored by boojack's avatar boojack Committed by GitHub

feat(mcp): harden tool exposure and side effects (#5850)

parent 0fc1dab2
...@@ -12,6 +12,35 @@ DELETE /mcp (optional session termination) ...@@ -12,6 +12,35 @@ DELETE /mcp (optional session termination)
Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26). Transport: [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports) (single endpoint, MCP spec 2025-03-26).
### Tool Filtering
The default `/mcp` endpoint exposes all tools. Clients can opt into a smaller
tool surface with GitHub-style headers or route aliases:
| Control | Description |
|---|---|
| `X-MCP-Readonly: true` | Hide and block mutating tools |
| `X-MCP-Toolsets: memos,tags,attachments,relations,reactions` | Limit the default tool list to selected toolsets |
| `X-MCP-Tools: list_tags,get_memo` | Add specific tools to the selected toolset list |
| `X-MCP-Exclude-Tools: delete_memo` | Remove specific tools |
Equivalent aliases:
```text
/mcp/readonly
/mcp/x/{toolsets}
/mcp/x/{toolsets}/readonly
```
Examples:
```text
/mcp/x/memos,tags/readonly
X-MCP-Toolsets: memos
X-MCP-Tools: list_tags
X-MCP-Exclude-Tools: delete_memo
```
## Capabilities ## Capabilities
The server advertises the following MCP capabilities: The server advertises the following MCP capabilities:
...@@ -126,7 +155,9 @@ claude mcp add --scope user --transport http memos http://localhost:5230/mcp \ ...@@ -126,7 +155,9 @@ claude mcp add --scope user --transport http memos http://localhost:5230/mcp \
| File | Responsibility | | File | Responsibility |
|---|---| |---|---|
| `mcp.go` | `MCPService` struct, constructor, route registration, auth middleware | | `mcp.go` | `MCPService` struct, constructor, route registration, auth middleware, tool filtering |
| `tool_metadata.go` | Toolsets, read-only metadata, annotations, structured result helpers |
| `api_helpers.go` | Conversion helpers for calling API service methods from MCP tools |
| `tools_memo.go` | Memo CRUD tools + helpers (JSON types, visibility/access checks) | | `tools_memo.go` | Memo CRUD tools + helpers (JSON types, visibility/access checks) |
| `tools_tag.go` | Tag listing tool | | `tools_tag.go` | Tag listing tool |
| `tools_attachment.go` | Attachment listing, metadata, deletion, linking tools | | `tools_attachment.go` | Attachment listing, metadata, deletion, linking tools |
......
package mcp
import (
"context"
"github.com/pkg/errors"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
)
func visibilityToProto(visibility store.Visibility) v1pb.Visibility {
switch visibility {
case store.Protected:
return v1pb.Visibility_PROTECTED
case store.Public:
return v1pb.Visibility_PUBLIC
default:
return v1pb.Visibility_PRIVATE
}
}
func rowStatusToProto(rowStatus store.RowStatus) v1pb.State {
switch rowStatus {
case store.Archived:
return v1pb.State_ARCHIVED
default:
return v1pb.State_NORMAL
}
}
func (s *MCPService) loadMemoJSONByName(ctx context.Context, name string) (memoJSON, error) {
uid, err := parseMemoUID(name)
if err != nil {
return memoJSON{}, err
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return memoJSON{}, errors.Wrap(err, "failed to get memo")
}
if memo == nil {
return memoJSON{}, errors.New("memo not found")
}
return storeMemoToJSONWithStore(ctx, s.store, memo)
}
func (s *MCPService) loadReactionJSONByID(ctx context.Context, reactionID int32) (reactionJSON, error) {
reaction, err := s.store.GetReaction(ctx, &store.FindReaction{ID: &reactionID})
if err != nil {
return reactionJSON{}, errors.Wrap(err, "failed to get reaction")
}
if reaction == nil {
return reactionJSON{}, errors.New("reaction not found")
}
creator, err := lookupUsername(ctx, s.store, reaction.CreatorID)
if err != nil {
return reactionJSON{}, errors.Wrap(err, "failed to resolve reaction creator")
}
return reactionJSON{
ID: reaction.ID,
Creator: creator,
ReactionType: reaction.ReactionType,
CreateTime: reaction.CreatedTs,
}, nil
}
func (s *MCPService) loadReactionJSONByName(ctx context.Context, name string) (reactionJSON, error) {
_, reactionID, err := apiv1.ExtractMemoReactionIDFromName(name)
if err != nil {
return reactionJSON{}, err
}
return s.loadReactionJSONByID(ctx, reactionID)
}
package mcp package mcp
import ( import (
"context"
"fmt"
"net/http" "net/http"
"strings"
"github.com/labstack/echo/v5" "github.com/labstack/echo/v5"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/usememos/memos/internal/profile" "github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
const (
headerMCPReadonly = "X-MCP-Readonly"
headerMCPToolsets = "X-MCP-Toolsets"
headerMCPTools = "X-MCP-Tools"
headerMCPExcludeTools = "X-MCP-Exclude-Tools"
)
type mcpRequestConfigContextKey struct{}
type MCPService struct { type MCPService struct {
profile *profile.Profile profile *profile.Profile
store *store.Store store *store.Store
apiV1Service *apiv1.APIV1Service
authenticator *auth.Authenticator authenticator *auth.Authenticator
} }
func NewMCPService(profile *profile.Profile, store *store.Store, secret string) *MCPService { func NewMCPService(profile *profile.Profile, store *store.Store, secret string, apiV1Service *apiv1.APIV1Service) *MCPService {
return &MCPService{ return &MCPService{
profile: profile, profile: profile,
store: store, store: store,
apiV1Service: apiV1Service,
authenticator: auth.NewAuthenticator(store, secret), authenticator: auth.NewAuthenticator(store, secret),
} }
} }
...@@ -31,6 +47,10 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { ...@@ -31,6 +47,10 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpserver.WithResourceCapabilities(true, true), mcpserver.WithResourceCapabilities(true, true),
mcpserver.WithPromptCapabilities(true), mcpserver.WithPromptCapabilities(true),
mcpserver.WithLogging(), mcpserver.WithLogging(),
mcpserver.WithToolFilter(s.filterTools),
mcpserver.WithToolHandlerMiddleware(s.enforceToolAccess),
mcpserver.WithRecovery(),
mcpserver.WithResourceRecovery(),
) )
s.registerMemoTools(mcpSrv) s.registerMemoTools(mcpSrv)
s.registerTagTools(mcpSrv) s.registerTagTools(mcpSrv)
...@@ -40,7 +60,9 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { ...@@ -40,7 +60,9 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
s.registerMemoResources(mcpSrv) s.registerMemoResources(mcpSrv)
s.registerPrompts(mcpSrv) s.registerPrompts(mcpSrv)
httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv) httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv,
mcpserver.WithHTTPContextFunc(s.withRequestConfig),
)
mcpGroup := echoServer.Group("") mcpGroup := echoServer.Group("")
mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
...@@ -52,7 +74,18 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { ...@@ -52,7 +74,18 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
headers := c.Response().Header() headers := c.Response().Header()
headers.Set("Vary", "Origin") headers.Set("Vary", "Origin")
headers.Set("Access-Control-Allow-Origin", origin) headers.Set("Access-Control-Allow-Origin", origin)
headers.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, Accept, Mcp-Session-Id, MCP-Protocol-Version, Last-Event-ID") headers.Set("Access-Control-Allow-Headers", strings.Join([]string{
"Authorization",
"Content-Type",
"Accept",
"Mcp-Session-Id",
"MCP-Protocol-Version",
"Last-Event-ID",
headerMCPReadonly,
headerMCPToolsets,
headerMCPTools,
headerMCPExcludeTools,
}, ", "))
headers.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") headers.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
if c.Request().Method == http.MethodOptions { if c.Request().Method == http.MethodOptions {
return c.NoContent(http.StatusNoContent) return c.NoContent(http.StatusNoContent)
...@@ -72,4 +105,147 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { ...@@ -72,4 +105,147 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
} }
}) })
mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler)) mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler))
mcpGroup.Any("/mcp/readonly", echo.WrapHandler(httpHandler))
mcpGroup.Any("/mcp/x/:toolsets", echo.WrapHandler(httpHandler))
mcpGroup.Any("/mcp/x/:toolsets/readonly", echo.WrapHandler(httpHandler))
}
func (*MCPService) withRequestConfig(ctx context.Context, r *http.Request) context.Context {
return context.WithValue(ctx, mcpRequestConfigContextKey{}, parseMCPRequestConfig(r))
}
func (*MCPService) filterTools(ctx context.Context, tools []mcp.Tool) []mcp.Tool {
cfg := mcpRequestConfigFromContext(ctx)
filtered := make([]mcp.Tool, 0, len(tools))
for _, tool := range tools {
if cfg.allowsTool(tool.Name) {
filtered = append(filtered, tool)
}
}
return filtered
}
func (*MCPService) enforceToolAccess(next mcpserver.ToolHandlerFunc) mcpserver.ToolHandlerFunc {
return func(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
cfg := mcpRequestConfigFromContext(ctx)
if !cfg.allowsTool(req.Params.Name) {
return mcp.NewToolResultError(fmt.Sprintf("tool %q is not enabled by MCP configuration", req.Params.Name)), nil
}
return next(ctx, req)
}
}
type mcpRequestConfig struct {
readOnly bool
toolsets map[string]struct{}
includeTools map[string]struct{}
excludeTools map[string]struct{}
}
func mcpRequestConfigFromContext(ctx context.Context) mcpRequestConfig {
if cfg, ok := ctx.Value(mcpRequestConfigContextKey{}).(mcpRequestConfig); ok {
return cfg
}
return mcpRequestConfig{}
}
func parseMCPRequestConfig(r *http.Request) mcpRequestConfig {
cfg := mcpRequestConfig{}
pathToolsets, pathReadonly := parseMCPPathConfig(r.URL.Path)
cfg.readOnly = pathReadonly || parseBoolHeader(r.Header.Get(headerMCPReadonly))
cfg.toolsets = mergeStringSets(cfg.toolsets, pathToolsets)
cfg.toolsets = mergeStringSets(cfg.toolsets, parseCommaSet(r.Header.Get(headerMCPToolsets), strings.ToLower))
cfg.includeTools = parseCommaSet(r.Header.Get(headerMCPTools), keepString)
cfg.excludeTools = parseCommaSet(r.Header.Get(headerMCPExcludeTools), keepString)
return cfg
}
func parseMCPPathConfig(path string) (map[string]struct{}, bool) {
trimmed := strings.Trim(path, "/")
if trimmed == "mcp/readonly" {
return nil, true
}
const prefix = "mcp/x/"
if !strings.HasPrefix(trimmed, prefix) {
return nil, false
}
rest := strings.TrimPrefix(trimmed, prefix)
readOnly := false
if strings.HasSuffix(rest, "/readonly") {
readOnly = true
rest = strings.TrimSuffix(rest, "/readonly")
}
return parseCommaSet(rest, strings.ToLower), readOnly
}
func (cfg mcpRequestConfig) allowsTool(name string) bool {
if _, known := allMCPToolNames[name]; !known {
return false
}
if cfg.readOnly {
if _, mutates := mcpMutationTools[name]; mutates {
return false
}
}
if _, excluded := cfg.excludeTools[name]; excluded {
return false
}
if _, included := cfg.includeTools[name]; included {
return true
}
if len(cfg.toolsets) == 0 {
return true
}
for toolset := range cfg.toolsets {
if _, ok := mcpToolsByToolset[toolset][name]; ok {
return true
}
}
return false
}
func parseBoolHeader(value string) bool {
switch strings.ToLower(strings.TrimSpace(value)) {
case "1", "t", "true", "y", "yes", "on":
return true
default:
return false
}
}
func parseCommaSet(value string, normalize func(string) string) map[string]struct{} {
if value == "" {
return nil
}
result := map[string]struct{}{}
for _, item := range strings.Split(value, ",") {
item = strings.TrimSpace(item)
if item == "" {
continue
}
result[normalize(item)] = struct{}{}
}
if len(result) == 0 {
return nil
}
return result
}
func mergeStringSets(dst map[string]struct{}, src map[string]struct{}) map[string]struct{} {
if len(src) == 0 {
return dst
}
if dst == nil {
dst = map[string]struct{}{}
}
for item := range src {
dst[item] = struct{}{}
}
return dst
}
func keepString(s string) string {
return s
} }
package mcp package mcp
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"net/http"
"net/http/httptest" "net/http/httptest"
"reflect"
"testing" "testing"
"time"
"unsafe"
"github.com/labstack/echo/v5"
"github.com/lithammer/shortuuid/v4" "github.com/lithammer/shortuuid/v4"
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
...@@ -13,6 +19,7 @@ import ( ...@@ -13,6 +19,7 @@ import (
"github.com/usememos/memos/internal/profile" "github.com/usememos/memos/internal/profile"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
apiv1service "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
teststore "github.com/usememos/memos/store/test" teststore "github.com/usememos/memos/store/test"
) )
...@@ -31,10 +38,12 @@ func newTestMCPService(t *testing.T) *testMCPService { ...@@ -31,10 +38,12 @@ func newTestMCPService(t *testing.T) *testMCPService {
require.NoError(t, stores.Close()) require.NoError(t, stores.Close())
}) })
svc := NewMCPService(&profile.Profile{ profile := &profile.Profile{
Driver: "sqlite", Driver: "sqlite",
InstanceURL: "https://notes.example.com", InstanceURL: "https://notes.example.com",
}, stores, "test-secret") }
apiV1Service := apiv1service.NewAPIV1Service("test-secret", profile, stores)
svc := NewMCPService(profile, stores, "test-secret", apiV1Service)
return &testMCPService{ return &testMCPService{
service: svc, service: svc,
store: stores, store: stores,
...@@ -115,6 +124,125 @@ func firstText(t *testing.T, result *mcp.CallToolResult) string { ...@@ -115,6 +124,125 @@ func firstText(t *testing.T, result *mcp.CallToolResult) string {
return text.Text return text.Text
} }
func initializeMCPHTTP(t *testing.T, e *echo.Echo, path string, headers map[string]string) string {
t.Helper()
payload := map[string]any{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": map[string]any{
"protocolVersion": "2025-06-18",
"capabilities": map[string]any{},
"clientInfo": map[string]any{
"name": "mcp-test",
"version": "1.0.0",
},
},
}
resp := postMCPHTTP(t, e, path, "", headers, payload)
require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
sessionID := resp.Header().Get("Mcp-Session-Id")
require.NotEmpty(t, sessionID)
return sessionID
}
func callMCPHTTP(t *testing.T, e *echo.Echo, path string, sessionID string, headers map[string]string, method string, params any) map[string]any {
t.Helper()
payload := map[string]any{
"jsonrpc": "2.0",
"id": 2,
"method": method,
}
if params != nil {
payload["params"] = params
}
resp := postMCPHTTP(t, e, path, sessionID, headers, payload)
require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
var decoded map[string]any
require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &decoded))
return decoded
}
func postMCPHTTP(t *testing.T, e *echo.Echo, path string, sessionID string, headers map[string]string, payload map[string]any) *httptest.ResponseRecorder {
t.Helper()
body, err := json.Marshal(payload)
require.NoError(t, err)
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
if sessionID != "" {
req.Header.Set("Mcp-Session-Id", sessionID)
}
for key, value := range headers {
req.Header.Set(key, value)
}
resp := httptest.NewRecorder()
e.ServeHTTP(resp, req)
return resp
}
func toolNamesFromListResponse(t *testing.T, response map[string]any) map[string]struct{} {
t.Helper()
result, ok := response["result"].(map[string]any)
require.True(t, ok, "missing result: %#v", response)
rawTools, ok := result["tools"].([]any)
require.True(t, ok, "missing tools: %#v", result)
names := map[string]struct{}{}
for _, rawTool := range rawTools {
tool, ok := rawTool.(map[string]any)
require.True(t, ok)
name, ok := tool["name"].(string)
require.True(t, ok)
names[name] = struct{}{}
}
return names
}
func requireToolPresent(t *testing.T, names map[string]struct{}, name string) {
t.Helper()
_, ok := names[name]
require.True(t, ok, "expected tool %q to be present in %#v", name, names)
}
func requireToolAbsent(t *testing.T, names map[string]struct{}, name string) {
t.Helper()
_, ok := names[name]
require.False(t, ok, "expected tool %q to be absent in %#v", name, names)
}
func nextSSEEvent(t *testing.T, client *apiv1service.SSEClient) *apiv1service.SSEEvent {
t.Helper()
events := sseClientEvents(t, client)
var data []byte
select {
case eventData, ok := <-events:
require.True(t, ok, "SSE client channel closed")
data = eventData
case <-time.After(time.Second):
t.Fatal("timed out waiting for SSE event")
}
var event apiv1service.SSEEvent
require.NoError(t, json.Unmarshal(data, &event))
return &event
}
func requireNoSSEEvent(t *testing.T, client *apiv1service.SSEClient) {
t.Helper()
select {
case eventData, ok := <-sseClientEvents(t, client):
require.True(t, ok, "SSE client channel closed")
t.Fatalf("unexpected SSE event received: %s", string(eventData))
case <-time.After(150 * time.Millisecond):
}
}
func sseClientEvents(t *testing.T, client *apiv1service.SSEClient) <-chan []byte {
t.Helper()
field := reflect.ValueOf(client).Elem().FieldByName("events")
events, ok := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface().(chan []byte)
require.True(t, ok)
return events
}
func TestHandleGetMemoAndReadResourceDenyArchivedMemoToNonCreator(t *testing.T) { func TestHandleGetMemoAndReadResourceDenyArchivedMemoToNonCreator(t *testing.T) {
ts := newTestMCPService(t) ts := newTestMCPService(t)
owner := ts.createUser(t, "owner") owner := ts.createUser(t, "owner")
...@@ -273,3 +401,205 @@ func TestIsAllowedOrigin(t *testing.T) { ...@@ -273,3 +401,205 @@ func TestIsAllowedOrigin(t *testing.T) {
require.False(t, ts.service.isAllowedOrigin(req)) require.False(t, ts.service.isAllowedOrigin(req))
}) })
} }
func TestMCPToolFilteringRoutesAndHeaders(t *testing.T) {
ts := newTestMCPService(t)
e := echo.New()
ts.service.RegisterRoutes(e)
t.Run("default endpoint lists all tools", func(t *testing.T) {
sessionID := initializeMCPHTTP(t, e, "/mcp", nil)
response := callMCPHTTP(t, e, "/mcp", sessionID, nil, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
require.Len(t, names, len(allMCPToolNames))
requireToolPresent(t, names, "create_memo")
requireToolPresent(t, names, "list_tags")
requireToolPresent(t, names, "upsert_reaction")
})
t.Run("readonly header hides and blocks mutation tools", func(t *testing.T) {
headers := map[string]string{headerMCPReadonly: "true"}
sessionID := initializeMCPHTTP(t, e, "/mcp", nil)
response := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
requireToolPresent(t, names, "list_memos")
requireToolPresent(t, names, "list_tags")
requireToolAbsent(t, names, "create_memo")
requireToolAbsent(t, names, "delete_memo")
callResponse := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/call", map[string]any{
"name": "create_memo",
"arguments": map[string]any{"content": "blocked"},
})
result, ok := callResponse["result"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, result["isError"])
rawContent, ok := result["content"].([]any)
require.True(t, ok)
content, ok := rawContent[0].(map[string]any)
require.True(t, ok)
require.Contains(t, content["text"], "not enabled")
})
t.Run("readonly alias applies path config", func(t *testing.T) {
sessionID := initializeMCPHTTP(t, e, "/mcp/readonly", nil)
response := callMCPHTTP(t, e, "/mcp/readonly", sessionID, nil, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
requireToolPresent(t, names, "get_memo")
requireToolAbsent(t, names, "create_memo")
requireToolAbsent(t, names, "upsert_reaction")
})
t.Run("toolsets include and exclude compose", func(t *testing.T) {
headers := map[string]string{
headerMCPToolsets: "memos",
headerMCPTools: "list_tags",
headerMCPExcludeTools: "get_memo",
}
sessionID := initializeMCPHTTP(t, e, "/mcp", nil)
response := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
requireToolPresent(t, names, "list_memos")
requireToolPresent(t, names, "list_tags")
requireToolAbsent(t, names, "get_memo")
requireToolAbsent(t, names, "list_attachments")
})
t.Run("path toolsets and readonly compose", func(t *testing.T) {
sessionID := initializeMCPHTTP(t, e, "/mcp/x/memos,tags/readonly", nil)
response := callMCPHTTP(t, e, "/mcp/x/memos,tags/readonly", sessionID, nil, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
requireToolPresent(t, names, "list_memos")
requireToolPresent(t, names, "list_tags")
requireToolAbsent(t, names, "create_memo")
requireToolAbsent(t, names, "list_attachments")
})
t.Run("unknown toolset returns empty tool list", func(t *testing.T) {
sessionID := initializeMCPHTTP(t, e, "/mcp/x/notreal", nil)
response := callMCPHTTP(t, e, "/mcp/x/notreal", sessionID, nil, "tools/list", map[string]any{})
names := toolNamesFromListResponse(t, response)
require.Empty(t, names)
})
}
func TestMCPMemoAndReactionMutationsEmitSSEEvents(t *testing.T) {
ts := newTestMCPService(t)
user := ts.createUser(t, "author")
ctx := withUser(context.Background(), user.ID)
client := ts.service.apiV1Service.SSEHub.Subscribe(user.ID, store.RoleUser)
defer ts.service.apiV1Service.SSEHub.Unsubscribe(client)
createResult, err := ts.service.handleCreateMemo(ctx, toolRequest("create_memo", map[string]any{
"content": "created from MCP",
"visibility": "PRIVATE",
}))
require.NoError(t, err)
require.False(t, createResult.IsError)
createEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoCreated, createEvent.Type)
var created memoJSON
require.NoError(t, json.Unmarshal([]byte(firstText(t, createResult)), &created))
updateResult, err := ts.service.handleUpdateMemo(ctx, toolRequest("update_memo", map[string]any{
"name": created.Name,
"content": "updated from MCP",
}))
require.NoError(t, err)
require.False(t, updateResult.IsError)
updateEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoUpdated, updateEvent.Type)
commentResult, err := ts.service.handleCreateMemoComment(ctx, toolRequest("create_memo_comment", map[string]any{
"name": created.Name,
"content": "comment from MCP",
}))
require.NoError(t, err)
require.False(t, commentResult.IsError)
commentEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoCommentCreated, commentEvent.Type)
require.Equal(t, created.Name, commentEvent.Name)
upsertReactionResult, err := ts.service.handleUpsertReaction(ctx, toolRequest("upsert_reaction", map[string]any{
"name": created.Name,
"reaction_type": "👍",
}))
require.NoError(t, err)
require.False(t, upsertReactionResult.IsError)
reactionEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventReactionUpserted, reactionEvent.Type)
var reaction reactionJSON
require.NoError(t, json.Unmarshal([]byte(firstText(t, upsertReactionResult)), &reaction))
deleteReactionResult, err := ts.service.handleDeleteReaction(ctx, toolRequest("delete_reaction", map[string]any{
"id": float64(reaction.ID),
}))
require.NoError(t, err)
require.False(t, deleteReactionResult.IsError)
deleteReactionEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventReactionDeleted, deleteReactionEvent.Type)
deleteResult, err := ts.service.handleDeleteMemo(ctx, toolRequest("delete_memo", map[string]any{
"name": created.Name,
}))
require.NoError(t, err)
require.False(t, deleteResult.IsError)
deleteEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoDeleted, deleteEvent.Type)
}
func TestMCPRelationAndAttachmentMutationsEmitMemoUpdated(t *testing.T) {
ts := newTestMCPService(t)
user := ts.createUser(t, "owner")
ctx := withUser(context.Background(), user.ID)
source := ts.createMemo(t, user.ID, store.Private, "source")
target := ts.createMemo(t, user.ID, store.Private, "target")
attachment := ts.createAttachment(t, user.ID, nil)
client := ts.service.apiV1Service.SSEHub.Subscribe(user.ID, store.RoleUser)
defer ts.service.apiV1Service.SSEHub.Unsubscribe(client)
relationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{
"name": "memos/" + source.UID,
"related_memo": "memos/" + target.UID,
}))
require.NoError(t, err)
require.False(t, relationResult.IsError)
relationEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoUpdated, relationEvent.Type)
require.Equal(t, "memos/"+source.UID, relationEvent.Name)
duplicateRelationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{
"name": "memos/" + source.UID,
"related_memo": "memos/" + target.UID,
}))
require.NoError(t, err)
require.False(t, duplicateRelationResult.IsError)
requireNoSSEEvent(t, client)
selfRelationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{
"name": "memos/" + source.UID,
"related_memo": "memos/" + source.UID,
}))
require.NoError(t, err)
require.True(t, selfRelationResult.IsError)
require.Contains(t, firstText(t, selfRelationResult), "itself")
linkResult, err := ts.service.handleLinkAttachmentToMemo(ctx, toolRequest("link_attachment_to_memo", map[string]any{
"name": "attachments/" + attachment.UID,
"memo": "memos/" + source.UID,
}))
require.NoError(t, err)
require.False(t, linkResult.IsError)
attachmentEvent := nextSSEEvent(t, client)
require.Equal(t, apiv1service.SSEEventMemoUpdated, attachmentEvent.Type)
require.Equal(t, "memos/"+source.UID, attachmentEvent.Name)
relinkResult, err := ts.service.handleLinkAttachmentToMemo(ctx, toolRequest("link_attachment_to_memo", map[string]any{
"name": "attachments/" + attachment.UID,
"memo": "memos/" + source.UID,
}))
require.NoError(t, err)
require.False(t, relinkResult.IsError)
requireNoSSEEvent(t, client)
}
package mcp
import "github.com/mark3labs/mcp-go/mcp"
var mcpToolsByToolset = map[string]map[string]struct{}{
"memos": stringSet(
"list_memos",
"get_memo",
"create_memo",
"update_memo",
"delete_memo",
"search_memos",
"list_memo_comments",
"create_memo_comment",
),
"tags": stringSet(
"list_tags",
),
"attachments": stringSet(
"list_attachments",
"get_attachment",
"delete_attachment",
"link_attachment_to_memo",
),
"relations": stringSet(
"list_memo_relations",
"create_memo_relation",
"delete_memo_relation",
),
"reactions": stringSet(
"list_reactions",
"upsert_reaction",
"delete_reaction",
),
}
var allMCPToolNames = func() map[string]struct{} {
names := map[string]struct{}{}
for _, tools := range mcpToolsByToolset {
for name := range tools {
names[name] = struct{}{}
}
}
return names
}()
var mcpMutationTools = stringSet(
"create_memo",
"update_memo",
"delete_memo",
"create_memo_comment",
"delete_attachment",
"link_attachment_to_memo",
"create_memo_relation",
"delete_memo_relation",
"upsert_reaction",
"delete_reaction",
)
type deletedJSON struct {
Deleted bool `json:"deleted"`
}
func stringSet(values ...string) map[string]struct{} {
result := make(map[string]struct{}, len(values))
for _, value := range values {
result[value] = struct{}{}
}
return result
}
func readOnlyToolOptions(title string, description string, opts ...mcp.ToolOption) []mcp.ToolOption {
return annotatedToolOptions(title, description, true, false, true, false, opts...)
}
func createToolOptions(title string, description string, idempotent bool, opts ...mcp.ToolOption) []mcp.ToolOption {
return annotatedToolOptions(title, description, false, false, idempotent, false, opts...)
}
func updateToolOptions(title string, description string, opts ...mcp.ToolOption) []mcp.ToolOption {
return annotatedToolOptions(title, description, false, true, false, false, opts...)
}
func annotatedToolOptions(title string, description string, readOnly bool, destructive bool, idempotent bool, openWorld bool, opts ...mcp.ToolOption) []mcp.ToolOption {
base := []mcp.ToolOption{
mcp.WithTitleAnnotation(title),
mcp.WithDescription(description),
mcp.WithReadOnlyHintAnnotation(readOnly),
mcp.WithDestructiveHintAnnotation(destructive),
mcp.WithIdempotentHintAnnotation(idempotent),
mcp.WithOpenWorldHintAnnotation(openWorld),
}
return append(base, opts...)
}
func newToolResultJSON(v any) (*mcp.CallToolResult, error) {
return mcp.NewToolResultJSON(v)
}
func newDeletedToolResult() (*mcp.CallToolResult, error) {
return newToolResultJSON(deletedJSON{Deleted: true})
}
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors" "github.com/pkg/errors"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
...@@ -26,6 +27,11 @@ type attachmentJSON struct { ...@@ -26,6 +27,11 @@ type attachmentJSON struct {
Memo string `json:"memo,omitempty"` Memo string `json:"memo,omitempty"`
} }
type attachmentListJSON struct {
Attachments []attachmentJSON `json:"attachments"`
HasMore bool `json:"has_more"`
}
func storeAttachmentToJSON(ctx context.Context, stores *store.Store, a *store.Attachment) (attachmentJSON, error) { func storeAttachmentToJSON(ctx context.Context, stores *store.Store, a *store.Attachment) (attachmentJSON, error) {
creator, err := lookupUsername(ctx, stores, a.CreatorID) creator, err := lookupUsername(ctx, stores, a.CreatorID)
if err != nil { if err != nil {
...@@ -98,26 +104,34 @@ func parseAttachmentUID(name string) (string, error) { ...@@ -98,26 +104,34 @@ func parseAttachmentUID(name string) (string, error) {
func (s *MCPService) registerAttachmentTools(mcpSrv *mcpserver.MCPServer) { func (s *MCPService) registerAttachmentTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_attachments", mcpSrv.AddTool(mcp.NewTool("list_attachments",
mcp.WithDescription("List attachments owned by the authenticated user. Supports pagination and optional filtering by linked memo."), readOnlyToolOptions("List attachments", "List attachments owned by the authenticated user. Supports pagination and optional filtering by linked memo.",
mcp.WithNumber("page_size", mcp.Description("Maximum attachments to return (1–100, default 20)")), mcp.WithNumber("page_size", mcp.Description("Maximum attachments to return (1–100, default 20)")),
mcp.WithNumber("page", mcp.Description("Zero-based page index (default 0)")), mcp.WithNumber("page", mcp.Description("Zero-based page index (default 0)")),
mcp.WithString("memo", mcp.Description(`Filter by linked memo resource name, e.g. "memos/abc123"`)), mcp.WithString("memo", mcp.Description(`Filter by linked memo resource name, e.g. "memos/abc123"`)),
mcp.WithOutputSchema[attachmentListJSON](),
)...,
), s.handleListAttachments) ), s.handleListAttachments)
mcpSrv.AddTool(mcp.NewTool("get_attachment", mcpSrv.AddTool(mcp.NewTool("get_attachment",
mcp.WithDescription("Get a single attachment's metadata by resource name. Requires authentication."), readOnlyToolOptions("Get attachment", "Get a single attachment's metadata by resource name. Requires authentication.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
mcp.WithOutputSchema[attachmentJSON](),
)...,
), s.handleGetAttachment) ), s.handleGetAttachment)
mcpSrv.AddTool(mcp.NewTool("delete_attachment", mcpSrv.AddTool(mcp.NewTool("delete_attachment",
mcp.WithDescription("Permanently delete an attachment and its stored file. Requires authentication and ownership."), updateToolOptions("Delete attachment", "Permanently delete an attachment and its stored file. Requires authentication and ownership.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
mcp.WithOutputSchema[deletedJSON](),
)...,
), s.handleDeleteAttachment) ), s.handleDeleteAttachment)
mcpSrv.AddTool(mcp.NewTool("link_attachment_to_memo", mcpSrv.AddTool(mcp.NewTool("link_attachment_to_memo",
mcp.WithDescription("Link an existing attachment to a memo. Requires authentication and ownership of the attachment."), createToolOptions("Link attachment to memo", "Link an existing attachment to a memo. Requires authentication and ownership of the attachment.", true,
mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Attachment resource name, e.g. "attachments/abc123"`)),
mcp.WithString("memo", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("memo", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithOutputSchema[attachmentJSON](),
)...,
), s.handleLinkAttachmentToMemo) ), s.handleLinkAttachmentToMemo)
} }
...@@ -189,15 +203,7 @@ func (s *MCPService) handleListAttachments(ctx context.Context, req mcp.CallTool ...@@ -189,15 +203,7 @@ func (s *MCPService) handleListAttachments(ctx context.Context, req mcp.CallTool
results[i] = result results[i] = result
} }
type listResponse struct { return newToolResultJSON(attachmentListJSON{Attachments: results, HasMore: hasMore})
Attachments []attachmentJSON `json:"attachments"`
HasMore bool `json:"has_more"`
}
out, err := marshalJSON(listResponse{Attachments: results, HasMore: hasMore})
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -224,16 +230,11 @@ func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRe ...@@ -224,16 +230,11 @@ func (s *MCPService) handleGetAttachment(ctx context.Context, req mcp.CallToolRe
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
} }
out, err := marshalJSON(result) return newToolResultJSON(result)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx) if _, err := extractUserID(ctx); err != nil {
if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
...@@ -242,18 +243,10 @@ func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToo ...@@ -242,18 +243,10 @@ func (s *MCPService) handleDeleteAttachment(ctx context.Context, req mcp.CallToo
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
attachment, err := s.store.GetAttachment(ctx, &store.FindAttachment{UID: &uid, CreatorID: &userID}) if _, err := s.apiV1Service.DeleteAttachment(ctx, &v1pb.DeleteAttachmentRequest{Name: "attachments/" + uid}); err != nil {
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to find attachment: %v", err)), nil
}
if attachment == nil {
return mcp.NewToolResultError("attachment not found"), nil
}
if err := s.store.DeleteAttachment(ctx, &store.DeleteAttachment{ID: attachment.ID}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete attachment: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to delete attachment: %v", err)), nil
} }
return mcp.NewToolResultText(`{"deleted":true}`), nil return newDeletedToolResult()
} }
func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -293,9 +286,30 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal ...@@ -293,9 +286,30 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
if err := s.store.UpdateAttachment(ctx, &store.UpdateAttachment{ currentAttachments, err := s.store.ListAttachments(ctx, &store.FindAttachment{MemoID: &memo.ID})
ID: attachment.ID, if err != nil {
MemoID: &memo.ID, return mcp.NewToolResultError(fmt.Sprintf("failed to list memo attachments: %v", err)), nil
}
requestAttachments := make([]*v1pb.Attachment, 0, len(currentAttachments)+1)
var currentTarget *store.Attachment
for _, current := range currentAttachments {
requestAttachments = append(requestAttachments, &v1pb.Attachment{Name: "attachments/" + current.UID})
if current.ID == attachment.ID {
currentTarget = current
}
}
if currentTarget != nil {
result, err := storeAttachmentToJSON(ctx, s.store, currentTarget)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
}
return newToolResultJSON(result)
}
requestAttachments = append(requestAttachments, &v1pb.Attachment{Name: "attachments/" + uid})
if _, err := s.apiV1Service.SetMemoAttachments(ctx, &v1pb.SetMemoAttachmentsRequest{
Name: "memos/" + memoUID,
Attachments: requestAttachments,
}); err != nil { }); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to link attachment: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to link attachment: %v", err)), nil
} }
...@@ -309,9 +323,5 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal ...@@ -309,9 +323,5 @@ func (s *MCPService) handleLinkAttachmentToMemo(ctx context.Context, req mcp.Cal
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to resolve attachment creator: %v", err)), nil
} }
out, err := marshalJSON(result) return newToolResultJSON(result)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
...@@ -2,52 +2,19 @@ package mcp ...@@ -2,52 +2,19 @@ package mcp
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"regexp"
"strings" "strings"
"github.com/lithammer/shortuuid/v4"
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors" "github.com/pkg/errors"
"google.golang.org/protobuf/types/known/fieldmaskpb"
storepb "github.com/usememos/memos/proto/gen/store" v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
// tagRegexp matches #tag patterns in memo content.
// A tag must start with a letter and contain no whitespace or # characters.
var tagRegexp = regexp.MustCompile(`(?:^|\s)#([A-Za-z][^\s#]*)`)
// extractTags does a best-effort extraction of #tags from raw markdown content.
// It is used when creating or updating memos via MCP to pre-populate Payload.Tags.
// The full markdown service may later rebuild a more accurate payload.
func extractTags(content string) []string {
matches := tagRegexp.FindAllStringSubmatch(content, -1)
seen := make(map[string]struct{}, len(matches))
tags := make([]string, 0, len(matches))
for _, m := range matches {
tag := m[1]
if _, ok := seen[tag]; !ok {
seen[tag] = struct{}{}
tags = append(tags, tag)
}
}
return tags
}
// buildPayload constructs a MemoPayload with tags extracted from content.
// Returns nil when no tags are found so the store omits the payload entirely.
func buildPayload(content string) *storepb.MemoPayload {
tags := extractTags(content)
if len(tags) == 0 {
return nil
}
return &storepb.MemoPayload{Tags: tags}
}
// propertyJSON is the serialisable form of MemoPayload.Property. // propertyJSON is the serialisable form of MemoPayload.Property.
type propertyJSON struct { type propertyJSON struct {
HasLink bool `json:"has_link"` HasLink bool `json:"has_link"`
...@@ -72,6 +39,11 @@ type memoJSON struct { ...@@ -72,6 +39,11 @@ type memoJSON struct {
Parent string `json:"parent,omitempty"` Parent string `json:"parent,omitempty"`
} }
type memoListJSON struct {
Memos []memoJSON `json:"memos"`
HasMore bool `json:"has_more"`
}
func storeMemoToJSON(m *store.Memo) memoJSON { func storeMemoToJSON(m *store.Memo) memoJSON {
j := memoJSON{ j := memoJSON{
Name: "memos/" + m.UID, Name: "memos/" + m.UID,
...@@ -205,75 +177,81 @@ func extractUserID(ctx context.Context) (int32, error) { ...@@ -205,75 +177,81 @@ func extractUserID(ctx context.Context) (int32, error) {
return id, nil return id, nil
} }
func marshalJSON(v any) (string, error) {
b, err := json.Marshal(v)
if err != nil {
return "", err
}
return string(b), nil
}
func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) { func (s *MCPService) registerMemoTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_memos", mcpSrv.AddTool(mcp.NewTool("list_memos",
mcp.WithDescription("List memos visible to the caller. Authenticated users see their own memos plus public and protected memos; unauthenticated callers see only public memos."), readOnlyToolOptions("List memos", "List memos visible to the caller. Authenticated users see their own memos plus public and protected memos; unauthenticated callers see only public memos.",
mcp.WithNumber("page_size", mcp.Description("Maximum memos to return (1–100, default 20)")), mcp.WithNumber("page_size", mcp.Description("Maximum memos to return (1–100, default 20)")),
mcp.WithNumber("page", mcp.Description("Zero-based page index for pagination (default 0)")), mcp.WithNumber("page", mcp.Description("Zero-based page index for pagination (default 0)")),
mcp.WithString("state", mcp.WithString("state",
mcp.Enum("NORMAL", "ARCHIVED"), mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"), mcp.Description("Filter by state: NORMAL (default) or ARCHIVED"),
), ),
mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")), mcp.WithBoolean("order_by_pinned", mcp.Description("When true, pinned memos appear first (default false)")),
mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)), mcp.WithString("filter", mcp.Description(`Optional CEL filter (supported subset of standard CEL syntax), e.g. content.contains("keyword") or tags.exists(t, t == "work")`)),
mcp.WithOutputSchema[memoListJSON](),
)...,
), s.handleListMemos) ), s.handleListMemos)
mcpSrv.AddTool(mcp.NewTool("get_memo", mcpSrv.AddTool(mcp.NewTool("get_memo",
mcp.WithDescription("Get a single memo by resource name. Public memos are accessible without authentication."), readOnlyToolOptions("Get memo", "Get a single memo by resource name. Public memos are accessible without authentication.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithOutputSchema[memoJSON](),
)...,
), s.handleGetMemo) ), s.handleGetMemo)
mcpSrv.AddTool(mcp.NewTool("create_memo", mcpSrv.AddTool(mcp.NewTool("create_memo",
mcp.WithDescription("Create a new memo. Requires authentication."), createToolOptions("Create memo", "Create a new memo. Requires authentication.", false,
mcp.WithString("content", mcp.Required(), mcp.Description("Memo content in Markdown. Use #tag syntax for tagging.")), mcp.WithString("content", mcp.Required(), mcp.Description("Memo content in Markdown. Use #tag syntax for tagging.")),
mcp.WithString("visibility", mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"), mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("Visibility (default: PRIVATE)"), mcp.Description("Visibility (default: PRIVATE)"),
), ),
mcp.WithOutputSchema[memoJSON](),
)...,
), s.handleCreateMemo) ), s.handleCreateMemo)
mcpSrv.AddTool(mcp.NewTool("update_memo", mcpSrv.AddTool(mcp.NewTool("update_memo",
mcp.WithDescription("Update a memo's content, visibility, pin state, or archive state. Requires authentication and ownership. Omit any field to leave it unchanged."), updateToolOptions("Update memo", "Update a memo's content, visibility, pin state, or archive state. Requires authentication and ownership. Omit any field to leave it unchanged.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Description("New Markdown content")), mcp.WithString("content", mcp.Description("New Markdown content")),
mcp.WithString("visibility", mcp.WithString("visibility",
mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"), mcp.Enum("PRIVATE", "PROTECTED", "PUBLIC"),
mcp.Description("New visibility"), mcp.Description("New visibility"),
), ),
mcp.WithBoolean("pinned", mcp.Description("Pin or unpin the memo")), mcp.WithBoolean("pinned", mcp.Description("Pin or unpin the memo")),
mcp.WithString("state", mcp.WithString("state",
mcp.Enum("NORMAL", "ARCHIVED"), mcp.Enum("NORMAL", "ARCHIVED"),
mcp.Description("Set to ARCHIVED to archive, NORMAL to restore"), mcp.Description("Set to ARCHIVED to archive, NORMAL to restore"),
), ),
mcp.WithOutputSchema[memoJSON](),
)...,
), s.handleUpdateMemo) ), s.handleUpdateMemo)
mcpSrv.AddTool(mcp.NewTool("delete_memo", mcpSrv.AddTool(mcp.NewTool("delete_memo",
mcp.WithDescription("Permanently delete a memo. Requires authentication and ownership."), updateToolOptions("Delete memo", "Permanently delete a memo. Requires authentication and ownership.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithOutputSchema[deletedJSON](),
)...,
), s.handleDeleteMemo) ), s.handleDeleteMemo)
mcpSrv.AddTool(mcp.NewTool("search_memos", mcpSrv.AddTool(mcp.NewTool("search_memos",
mcp.WithDescription("Search memo content. Authenticated users search their own and visible memos; unauthenticated callers search public memos only."), readOnlyToolOptions("Search memos", "Search memo content. Authenticated users search their own and visible memos; unauthenticated callers search public memos only.",
mcp.WithString("query", mcp.Required(), mcp.Description("Text to search for in memo content")), mcp.WithString("query", mcp.Required(), mcp.Description("Text to search for in memo content")),
)...,
), s.handleSearchMemos) ), s.handleSearchMemos)
mcpSrv.AddTool(mcp.NewTool("list_memo_comments", mcpSrv.AddTool(mcp.NewTool("list_memo_comments",
mcp.WithDescription("List comments on a memo. Visibility rules for comments match those of the parent memo."), readOnlyToolOptions("List memo comments", "List comments on a memo. Visibility rules for comments match those of the parent memo.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
)...,
), s.handleListMemoComments) ), s.handleListMemoComments)
mcpSrv.AddTool(mcp.NewTool("create_memo_comment", mcpSrv.AddTool(mcp.NewTool("create_memo_comment",
mcp.WithDescription("Add a comment to a memo. The comment inherits the parent memo's visibility. Requires authentication."), createToolOptions("Create memo comment", "Add a comment to a memo. The comment inherits the parent memo's visibility. Requires authentication.", false,
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name to comment on, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name to comment on, e.g. "memos/abc123"`)),
mcp.WithString("content", mcp.Required(), mcp.Description("Comment content in Markdown")), mcp.WithString("content", mcp.Required(), mcp.Description("Comment content in Markdown")),
mcp.WithOutputSchema[memoJSON](),
)...,
), s.handleCreateMemoComment) ), s.handleCreateMemoComment)
} }
...@@ -342,15 +320,7 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques ...@@ -342,15 +320,7 @@ func (s *MCPService) handleListMemos(ctx context.Context, req mcp.CallToolReques
results[i] = result results[i] = result
} }
type listResponse struct { return newToolResultJSON(memoListJSON{Memos: results, HasMore: hasMore})
Memos []memoJSON `json:"memos"`
HasMore bool `json:"has_more"`
}
out, err := marshalJSON(listResponse{Memos: results, HasMore: hasMore})
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -376,16 +346,11 @@ func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest) ...@@ -376,16 +346,11 @@ func (s *MCPService) handleGetMemo(ctx context.Context, req mcp.CallToolRequest)
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil
} }
out, err := marshalJSON(result) return newToolResultJSON(result)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx) if _, err := extractUserID(ctx); err != nil {
if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
...@@ -398,31 +363,25 @@ func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolReque ...@@ -398,31 +363,25 @@ func (s *MCPService) handleCreateMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
memo, err := s.store.CreateMemo(ctx, &store.Memo{ created, err := s.apiV1Service.CreateMemo(ctx, &v1pb.CreateMemoRequest{
UID: shortuuid.New(), Memo: &v1pb.Memo{
CreatorID: userID, Content: content,
Content: content, Visibility: visibilityToProto(visibility),
Visibility: visibility, },
Payload: buildPayload(content),
}) })
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to create memo: %v", err)), nil
} }
result, err := storeMemoToJSONWithStore(ctx, s.store, memo) result, err := s.loadMemoJSONByName(ctx, created.Name)
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil return mcp.NewToolResultError(err.Error()), nil
}
out, err := marshalJSON(result)
if err != nil {
return nil, err
} }
return mcp.NewToolResultText(out), nil return newToolResultJSON(result)
} }
func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx) if _, err := extractUserID(ctx); err != nil {
if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
...@@ -431,66 +390,56 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque ...@@ -431,66 +390,56 @@ func (s *MCPService) handleUpdateMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) update := &v1pb.Memo{Name: "memos/" + uid}
if err != nil { updateMask := &fieldmaskpb.FieldMask{}
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if memo == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if err := checkMemoOwnership(memo, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
update := &store.UpdateMemo{ID: memo.ID}
args := req.GetArguments() args := req.GetArguments()
if v := req.GetString("content", ""); v != "" { if v := req.GetString("content", ""); v != "" {
update.Content = &v update.Content = v
update.Payload = buildPayload(v) updateMask.Paths = append(updateMask.Paths, "content")
} }
if v := req.GetString("visibility", ""); v != "" { if v := req.GetString("visibility", ""); v != "" {
vis, err := parseVisibility(v) vis, err := parseVisibility(v)
if err != nil { if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
update.Visibility = &vis update.Visibility = visibilityToProto(vis)
updateMask.Paths = append(updateMask.Paths, "visibility")
} }
if v := req.GetString("state", ""); v != "" { if v := req.GetString("state", ""); v != "" {
rs, err := parseRowStatus(v) rs, err := parseRowStatus(v)
if err != nil { if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
update.RowStatus = &rs update.State = rowStatusToProto(rs)
updateMask.Paths = append(updateMask.Paths, "state")
} }
if _, ok := args["pinned"]; ok { if _, ok := args["pinned"]; ok {
pinned := req.GetBool("pinned", false) update.Pinned = req.GetBool("pinned", false)
update.Pinned = &pinned updateMask.Paths = append(updateMask.Paths, "pinned")
} }
if err := s.store.UpdateMemo(ctx, update); err != nil { if len(updateMask.Paths) == 0 {
return mcp.NewToolResultError(fmt.Sprintf("failed to update memo: %v", err)), nil return mcp.NewToolResultError("at least one field must be provided to update"), nil
} }
updated, err := s.store.GetMemo(ctx, &store.FindMemo{ID: &memo.ID}) updated, err := s.apiV1Service.UpdateMemo(ctx, &v1pb.UpdateMemoRequest{
Memo: update,
UpdateMask: updateMask,
})
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to fetch updated memo: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to update memo: %v", err)), nil
} }
result, err := storeMemoToJSONWithStore(ctx, s.store, updated) result, err := s.loadMemoJSONByName(ctx, updated.Name)
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil return mcp.NewToolResultError(err.Error()), nil
}
out, err := marshalJSON(result)
if err != nil {
return nil, err
} }
return mcp.NewToolResultText(out), nil return newToolResultJSON(result)
} }
func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx) if _, err := extractUserID(ctx); err != nil {
if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
...@@ -499,21 +448,10 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque ...@@ -499,21 +448,10 @@ func (s *MCPService) handleDeleteMemo(ctx context.Context, req mcp.CallToolReque
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid}) if _, err := s.apiV1Service.DeleteMemo(ctx, &v1pb.DeleteMemoRequest{Name: "memos/" + uid}); err != nil {
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
}
if memo == nil {
return mcp.NewToolResultError("memo not found"), nil
}
if err := checkMemoOwnership(memo, userID); err != nil {
return mcp.NewToolResultError(err.Error()), nil
}
if err := s.store.DeleteMemo(ctx, &store.DeleteMemo{ID: memo.ID}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete memo: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to delete memo: %v", err)), nil
} }
return mcp.NewToolResultText(`{"deleted":true}`), nil return newDeletedToolResult()
} }
func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -557,11 +495,7 @@ func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequ ...@@ -557,11 +495,7 @@ func (s *MCPService) handleSearchMemos(ctx context.Context, req mcp.CallToolRequ
} }
results[i] = result results[i] = result
} }
out, err := marshalJSON(results) return newToolResultJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -592,8 +526,7 @@ func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToo ...@@ -592,8 +526,7 @@ func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToo
return mcp.NewToolResultError(fmt.Sprintf("failed to list relations: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to list relations: %v", err)), nil
} }
if len(relations) == 0 { if len(relations) == 0 {
out, _ := marshalJSON([]memoJSON{}) return newToolResultJSON([]memoJSON{})
return mcp.NewToolResultText(out), nil
} }
commentIDs := make([]int32, len(relations)) commentIDs := make([]int32, len(relations))
...@@ -626,11 +559,7 @@ func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToo ...@@ -626,11 +559,7 @@ func (s *MCPService) handleListMemoComments(ctx context.Context, req mcp.CallToo
results = append(results, result) results = append(results, result)
} }
} }
out, err := marshalJSON(results) return newToolResultJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -659,33 +588,20 @@ func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallTo ...@@ -659,33 +588,20 @@ func (s *MCPService) handleCreateMemoComment(ctx context.Context, req mcp.CallTo
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
comment, err := s.store.CreateMemo(ctx, &store.Memo{ comment, err := s.apiV1Service.CreateMemoComment(ctx, &v1pb.CreateMemoCommentRequest{
UID: shortuuid.New(), Name: "memos/" + uid,
CreatorID: userID, Comment: &v1pb.Memo{
Content: content, Content: content,
Visibility: parent.Visibility, Visibility: visibilityToProto(parent.Visibility),
Payload: buildPayload(content), },
ParentUID: &parent.UID,
}) })
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to create comment: %v", err)), nil
} }
if _, err = s.store.UpsertMemoRelation(ctx, &store.MemoRelation{ result, err := s.loadMemoJSONByName(ctx, comment.Name)
MemoID: comment.ID,
RelatedMemoID: parent.ID,
Type: store.MemoRelationComment,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to link comment: %v", err)), nil
}
result, err := storeMemoToJSONWithStore(ctx, s.store, comment)
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memo creator: %v", err)), nil return mcp.NewToolResultError(err.Error()), nil
}
out, err := marshalJSON(result)
if err != nil {
return nil, err
} }
return mcp.NewToolResultText(out), nil return newToolResultJSON(result)
} }
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
...@@ -20,19 +21,24 @@ type reactionJSON struct { ...@@ -20,19 +21,24 @@ type reactionJSON struct {
func (s *MCPService) registerReactionTools(mcpSrv *mcpserver.MCPServer) { func (s *MCPService) registerReactionTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_reactions", mcpSrv.AddTool(mcp.NewTool("list_reactions",
mcp.WithDescription("List all reactions on a memo. Returns reaction type and creator for each reaction."), readOnlyToolOptions("List reactions", "List all reactions on a memo. Returns reaction type and creator for each reaction.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
)...,
), s.handleListReactions) ), s.handleListReactions)
mcpSrv.AddTool(mcp.NewTool("upsert_reaction", mcpSrv.AddTool(mcp.NewTool("upsert_reaction",
mcp.WithDescription("Add a reaction (emoji) to a memo. If the same reaction already exists from the same user, this is a no-op. Requires authentication."), createToolOptions("Upsert reaction", "Add a reaction (emoji) to a memo. If the same reaction already exists from the same user, this is a no-op. Requires authentication.", true,
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("reaction_type", mcp.Required(), mcp.Description(`Reaction emoji, e.g. "👍", "❤️", "🎉"`)), mcp.WithString("reaction_type", mcp.Required(), mcp.Description(`Reaction emoji, e.g. "👍", "❤️", "🎉"`)),
mcp.WithOutputSchema[reactionJSON](),
)...,
), s.handleUpsertReaction) ), s.handleUpsertReaction)
mcpSrv.AddTool(mcp.NewTool("delete_reaction", mcpSrv.AddTool(mcp.NewTool("delete_reaction",
mcp.WithDescription("Remove a reaction by its ID. Requires authentication and ownership of the reaction."), updateToolOptions("Delete reaction", "Remove a reaction by its ID. Requires authentication and ownership of the reaction.",
mcp.WithNumber("id", mcp.Required(), mcp.Description("Reaction ID to delete")), mcp.WithNumber("id", mcp.Required(), mcp.Description("Reaction ID to delete")),
mcp.WithOutputSchema[deletedJSON](),
)...,
), s.handleDeleteReaction) ), s.handleDeleteReaction)
} }
...@@ -83,11 +89,7 @@ func (s *MCPService) handleListReactions(ctx context.Context, req mcp.CallToolRe ...@@ -83,11 +89,7 @@ func (s *MCPService) handleListReactions(ctx context.Context, req mcp.CallToolRe
} }
} }
out, err := marshalJSON(results) return newToolResultJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -133,34 +135,26 @@ func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolR ...@@ -133,34 +135,26 @@ func (s *MCPService) handleUpsertReaction(ctx context.Context, req mcp.CallToolR
} }
contentID := "memos/" + uid contentID := "memos/" + uid
reaction, err := s.store.UpsertReaction(ctx, &store.Reaction{ reaction, err := s.apiV1Service.UpsertMemoReaction(ctx, &v1pb.UpsertMemoReactionRequest{
CreatorID: userID, Name: contentID,
ContentID: contentID, Reaction: &v1pb.Reaction{
ReactionType: reactionType, ContentId: contentID,
ReactionType: reactionType,
},
}) })
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to upsert reaction: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to upsert reaction: %v", err)), nil
} }
creator, err := lookupUsername(ctx, s.store, reaction.CreatorID) result, err := s.loadReactionJSONByName(ctx, reaction.Name)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to resolve reaction creator: %v", err)), nil
}
out, err := marshalJSON(reactionJSON{
ID: reaction.ID,
Creator: creator,
ReactionType: reaction.ReactionType,
CreateTime: reaction.CreatedTs,
})
if err != nil { if err != nil {
return nil, err return mcp.NewToolResultError(err.Error()), nil
} }
return mcp.NewToolResultText(out), nil return newToolResultJSON(result)
} }
func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID, err := extractUserID(ctx) if _, err := extractUserID(ctx); err != nil {
if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
...@@ -176,12 +170,11 @@ func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolR ...@@ -176,12 +170,11 @@ func (s *MCPService) handleDeleteReaction(ctx context.Context, req mcp.CallToolR
if reaction == nil { if reaction == nil {
return mcp.NewToolResultError("reaction not found"), nil return mcp.NewToolResultError("reaction not found"), nil
} }
if reaction.CreatorID != userID {
return mcp.NewToolResultError("permission denied: can only delete your own reactions"), nil
}
if err := s.store.DeleteReaction(ctx, &store.DeleteReaction{ID: reactionID}); err != nil { if _, err := s.apiV1Service.DeleteMemoReaction(ctx, &v1pb.DeleteMemoReactionRequest{
Name: fmt.Sprintf("%s/reactions/%d", reaction.ContentID, reactionID),
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete reaction: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to delete reaction: %v", err)), nil
} }
return mcp.NewToolResultText(`{"deleted":true}`), nil return newDeletedToolResult()
} }
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth" "github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
...@@ -19,24 +20,29 @@ type relationJSON struct { ...@@ -19,24 +20,29 @@ type relationJSON struct {
func (s *MCPService) registerRelationTools(mcpSrv *mcpserver.MCPServer) { func (s *MCPService) registerRelationTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_memo_relations", mcpSrv.AddTool(mcp.NewTool("list_memo_relations",
mcp.WithDescription("List all relations (references and comments) for a memo. Requires read access to the memo."), readOnlyToolOptions("List memo relations", "List all relations (references and comments) for a memo. Requires read access to the memo.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("type", mcp.WithString("type",
mcp.Enum("REFERENCE", "COMMENT"), mcp.Enum("REFERENCE", "COMMENT"),
mcp.Description("Filter by relation type (optional)"), mcp.Description("Filter by relation type (optional)"),
), ),
)...,
), s.handleListMemoRelations) ), s.handleListMemoRelations)
mcpSrv.AddTool(mcp.NewTool("create_memo_relation", mcpSrv.AddTool(mcp.NewTool("create_memo_relation",
mcp.WithDescription("Create a reference relation between two memos. Requires authentication. For comments, use create_memo_comment instead."), createToolOptions("Create memo relation", "Create a reference relation between two memos. Requires authentication. For comments, use create_memo_comment instead.", true,
mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)), mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)),
mcp.WithOutputSchema[relationJSON](),
)...,
), s.handleCreateMemoRelation) ), s.handleCreateMemoRelation)
mcpSrv.AddTool(mcp.NewTool("delete_memo_relation", mcpSrv.AddTool(mcp.NewTool("delete_memo_relation",
mcp.WithDescription("Delete a reference relation between two memos. Requires authentication and ownership of the source memo."), updateToolOptions("Delete memo relation", "Delete a reference relation between two memos. Requires authentication and ownership of the source memo.",
mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)), mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)),
mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)), mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)),
mcp.WithOutputSchema[deletedJSON](),
)...,
), s.handleDeleteMemoRelation) ), s.handleDeleteMemoRelation)
} }
...@@ -113,11 +119,7 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo ...@@ -113,11 +119,7 @@ func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallTo
}) })
} }
out, err := marshalJSON(results) return newToolResultJSON(results)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -134,6 +136,9 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT ...@@ -134,6 +136,9 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
if err != nil { if err != nil {
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
if srcUID == dstUID {
return mcp.NewToolResultError("cannot create a relation from a memo to itself"), nil
}
srcMemo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &srcUID}) srcMemo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &srcUID})
if err != nil { if err != nil {
...@@ -157,24 +162,24 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT ...@@ -157,24 +162,24 @@ func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallT
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
relation, err := s.store.UpsertMemoRelation(ctx, &store.MemoRelation{ relations, changed, err := s.buildReferenceRelationSet(ctx, srcMemo, &dstMemo.UID, nil)
MemoID: srcMemo.ID,
RelatedMemoID: dstMemo.ID,
Type: store.MemoRelationReference,
})
if err != nil { if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create relation: %v", err)), nil return mcp.NewToolResultError(fmt.Sprintf("failed to build relation set: %v", err)), nil
}
if changed {
if _, err := s.apiV1Service.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
Name: "memos/" + srcUID,
Relations: relations,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to create relation: %v", err)), nil
}
} }
out, err := marshalJSON(relationJSON{ return newToolResultJSON(relationJSON{
Memo: "memos/" + srcUID, Memo: "memos/" + srcUID,
RelatedMemo: "memos/" + dstUID, RelatedMemo: "memos/" + dstUID,
Type: string(relation.Type), Type: string(store.MemoRelationReference),
}) })
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
...@@ -214,13 +219,79 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT ...@@ -214,13 +219,79 @@ func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallT
return mcp.NewToolResultError(err.Error()), nil return mcp.NewToolResultError(err.Error()), nil
} }
refType := store.MemoRelationReference relations, changed, err := s.buildReferenceRelationSet(ctx, srcMemo, nil, &dstMemo.UID)
if err := s.store.DeleteMemoRelation(ctx, &store.DeleteMemoRelation{ if err != nil {
MemoID: &srcMemo.ID, return mcp.NewToolResultError(fmt.Sprintf("failed to build relation set: %v", err)), nil
RelatedMemoID: &dstMemo.ID, }
Type: &refType, if changed {
}); err != nil { if _, err := s.apiV1Service.SetMemoRelations(ctx, &v1pb.SetMemoRelationsRequest{
return mcp.NewToolResultError(fmt.Sprintf("failed to delete relation: %v", err)), nil Name: "memos/" + srcUID,
Relations: relations,
}); err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to delete relation: %v", err)), nil
}
}
return newDeletedToolResult()
}
func (s *MCPService) buildReferenceRelationSet(ctx context.Context, source *store.Memo, includeUID *string, excludeUID *string) ([]*v1pb.MemoRelation, bool, error) {
referenceType := store.MemoRelationReference
relations, err := s.store.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoIDList: []int32{source.ID},
Type: &referenceType,
})
if err != nil {
return nil, false, err
}
idSet := make(map[int32]struct{}, len(relations))
for _, relation := range relations {
idSet[relation.RelatedMemoID] = struct{}{}
}
ids := make([]int32, 0, len(idSet))
for id := range idSet {
ids = append(ids, id)
}
memosByID := map[int32]*store.Memo{}
if len(ids) > 0 {
memos, err := s.store.ListMemos(ctx, &store.FindMemo{IDList: ids, ExcludeContent: true})
if err != nil {
return nil, false, err
}
for _, memo := range memos {
memosByID[memo.ID] = memo
}
}
result := make([]*v1pb.MemoRelation, 0, len(relations)+1)
seenUIDs := map[string]struct{}{}
changed := false
for _, relation := range relations {
relatedMemo := memosByID[relation.RelatedMemoID]
if relatedMemo == nil {
continue
}
if excludeUID != nil && relatedMemo.UID == *excludeUID {
changed = true
continue
}
result = append(result, newReferenceRelation(source.UID, relatedMemo.UID))
seenUIDs[relatedMemo.UID] = struct{}{}
}
if includeUID != nil {
if _, seen := seenUIDs[*includeUID]; !seen && source.UID != *includeUID {
result = append(result, newReferenceRelation(source.UID, *includeUID))
changed = true
}
}
return result, changed, nil
}
func newReferenceRelation(sourceUID string, relatedUID string) *v1pb.MemoRelation {
return &v1pb.MemoRelation{
Memo: &v1pb.MemoRelation_Memo{Name: "memos/" + sourceUID},
RelatedMemo: &v1pb.MemoRelation_Memo{Name: "memos/" + relatedUID},
Type: v1pb.MemoRelation_REFERENCE,
} }
return mcp.NewToolResultText(`{"deleted":true}`), nil
} }
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) { func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_tags", mcpSrv.AddTool(mcp.NewTool("list_tags",
mcp.WithDescription("List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically."), readOnlyToolOptions("List tags", "List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically.")...,
), s.handleListTags) ), s.handleListTags)
} }
...@@ -70,9 +70,5 @@ func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) ...@@ -70,9 +70,5 @@ func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest)
} }
}) })
out, err := marshalJSON(entries) return newToolResultJSON(entries)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
} }
...@@ -89,7 +89,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store ...@@ -89,7 +89,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
} }
// Register MCP server. // Register MCP server.
mcpService := mcprouter.NewMCPService(s.Profile, s.Store, s.Secret) mcpService := mcprouter.NewMCPService(s.Profile, s.Store, s.Secret, apiV1Service)
mcpService.RegisterRoutes(echoServer) mcpService.RegisterRoutes(echoServer)
return s, nil return s, nil
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment