Commit 803d488a authored by Johnny's avatar Johnny

feat(mcp): refactor MCP server to standard protocol structure

- Replace PAT-only auth with optional auth supporting both PAT and JWT
  via auth.Authenticator.Authenticate(); unauthenticated requests see
  only public memos, matching REST API visibility semantics
- Inline auth middleware into mcp.go following fileserver pattern;
  remove auth_middleware.go
- Introduce memoJSON response type that correctly serialises store.Memo
  (including Payload.Tags and Payload.Property) without proto marshalling
- Add tools: list_memo_comments, create_memo_comment, list_tags
- Extend list_memos with state (NORMAL/ARCHIVED), order_by_pinned, and
  page parameters
- Extend update_memo with pinned and state parameters
- Extract #tags from content on create/update via regex to pre-populate
  Payload.Tags without requiring a full markdown service rebuild
- Add MCP Resources: memo://memos/{uid} template returns memo as
  Markdown with YAML frontmatter, allowing clients to read memos by URI
- Add MCP Prompts: capture (save a thought) and review (search + summarise)
parent 16576be1
package mcp
import (
"net/http"
"github.com/labstack/echo/v5"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
func newAuthMiddleware(s *store.Store, secret string) echo.MiddlewareFunc {
authenticator := auth.NewAuthenticator(s, secret)
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
token := auth.ExtractBearerToken(c.Request().Header.Get("Authorization"))
if token == "" {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "a personal access token is required"})
}
user, pat, err := authenticator.AuthenticateByPAT(c.Request().Context(), token)
if err != nil || user == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"})
}
ctx := auth.SetUserInContext(c.Request().Context(), user, pat.GetTokenId())
c.SetRequest(c.Request().WithContext(ctx))
return next(c)
}
}
}
package mcp package mcp
import ( import (
"net/http"
"github.com/labstack/echo/v5" "github.com/labstack/echo/v5"
"github.com/labstack/echo/v5/middleware" "github.com/labstack/echo/v5/middleware"
mcpserver "github.com/mark3labs/mcp-go/server" mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
type MCPService struct { type MCPService struct {
store *store.Store store *store.Store
secret string authenticator *auth.Authenticator
} }
func NewMCPService(store *store.Store, secret string) *MCPService { func NewMCPService(store *store.Store, secret string) *MCPService {
return &MCPService{store: store, secret: secret} return &MCPService{
store: store,
authenticator: auth.NewAuthenticator(store, secret),
}
} }
func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpSrv := mcpserver.NewMCPServer("Memos", "1.0.0", mcpserver.WithToolCapabilities(false)) mcpSrv := mcpserver.NewMCPServer("Memos", "1.0.0",
mcpserver.WithToolCapabilities(false),
)
s.registerMemoTools(mcpSrv) s.registerMemoTools(mcpSrv)
s.registerTagTools(mcpSrv)
s.registerMemoResources(mcpSrv)
s.registerPrompts(mcpSrv)
httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv) httpHandler := mcpserver.NewStreamableHTTPServer(mcpSrv)
...@@ -27,6 +38,19 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) { ...@@ -27,6 +38,19 @@ func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{ mcpGroup.Use(middleware.CORSWithConfig(middleware.CORSConfig{
AllowOrigins: []string{"*"}, AllowOrigins: []string{"*"},
})) }))
mcpGroup.Use(newAuthMiddleware(s.store, s.secret)) mcpGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c *echo.Context) error {
authHeader := c.Request().Header.Get("Authorization")
if authHeader != "" {
result := s.authenticator.Authenticate(c.Request().Context(), authHeader)
if result == nil {
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired token"})
}
ctx := auth.ApplyToContext(c.Request().Context(), result)
c.SetRequest(c.Request().WithContext(ctx))
}
return next(c)
}
})
mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler)) mcpGroup.Any("/mcp", echo.WrapHandler(httpHandler))
} }
package mcp
import (
"context"
"errors"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
)
func (s *MCPService) registerPrompts(mcpSrv *mcpserver.MCPServer) {
// capture — turns free-form user input into a structured create_memo call.
mcpSrv.AddPrompt(
mcp.NewPrompt("capture",
mcp.WithPromptDescription("Capture a thought, idea, or note as a new memo. "+
"Use this prompt when the user wants to quickly save something. "+
"The assistant will call create_memo with the provided content."),
mcp.WithArgument("content",
mcp.ArgumentDescription("The text to save as a memo"),
mcp.RequiredArgument(),
),
mcp.WithArgument("tags",
mcp.ArgumentDescription("Comma-separated tags to apply, e.g. \"work,project\""),
),
),
s.handleCapturePrompt,
)
// review — surfaces existing memos on a topic for summarisation.
mcpSrv.AddPrompt(
mcp.NewPrompt("review",
mcp.WithPromptDescription("Search and review memos on a given topic. "+
"The assistant will call search_memos and summarise the results."),
mcp.WithArgument("topic",
mcp.ArgumentDescription("Topic or keyword to search for"),
mcp.RequiredArgument(),
),
),
s.handleReviewPrompt,
)
}
func (*MCPService) handleCapturePrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
content := req.Params.Arguments["content"]
if content == "" {
return nil, errors.New("content argument is required")
}
tags := req.Params.Arguments["tags"]
instruction := fmt.Sprintf(
"Please save the following as a new private memo using the create_memo tool.\n\nContent:\n%s",
content,
)
if tags != "" {
instruction += fmt.Sprintf("\n\nAppend these tags inline using #tag syntax: %s", tags)
}
return &mcp.GetPromptResult{
Description: "Capture a memo",
Messages: []mcp.PromptMessage{
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
},
}, nil
}
func (*MCPService) handleReviewPrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
topic := req.Params.Arguments["topic"]
if topic == "" {
return nil, errors.New("topic argument is required")
}
instruction := fmt.Sprintf(
"Please use the search_memos tool to find memos about %q, then provide a concise summary of what has been written on this topic, grouped by theme. Include the memo names so the user can reference them.",
topic,
)
return &mcp.GetPromptResult{
Description: fmt.Sprintf("Review memos about %q", topic),
Messages: []mcp.PromptMessage{
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
},
}, nil
}
package mcp
import (
"context"
"fmt"
"strings"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/pkg/errors"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// Memo resource URI scheme: memo://memos/{uid}
// Clients can read any memo they have access to by URI without calling a tool.
func (s *MCPService) registerMemoResources(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddResourceTemplate(
mcp.NewResourceTemplate(
"memo://memos/{uid}",
"Memo",
mcp.WithTemplateDescription("A single Memos note identified by its UID. Returns the memo content as Markdown with a YAML frontmatter header containing metadata."),
mcp.WithTemplateMIMEType("text/markdown"),
),
s.handleReadMemoResource,
)
}
func (s *MCPService) handleReadMemoResource(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
userID := auth.GetUserID(ctx)
// URI format: memo://memos/{uid}
uid := strings.TrimPrefix(req.Params.URI, "memo://memos/")
if uid == req.Params.URI || uid == "" {
return nil, errors.Errorf("invalid memo URI %q: expected memo://memos/<uid>", req.Params.URI)
}
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
if err != nil {
return nil, errors.Wrap(err, "failed to get memo")
}
if memo == nil {
return nil, errors.Errorf("memo not found: %s", uid)
}
if err := checkMemoAccess(memo, userID); err != nil {
return nil, err
}
j := storeMemoToJSON(memo)
text := formatMemoMarkdown(j)
return []mcp.ResourceContents{
mcp.TextResourceContents{
URI: req.Params.URI,
MIMEType: "text/markdown",
Text: text,
},
}, nil
}
// formatMemoMarkdown renders a memo as Markdown with a YAML frontmatter header.
func formatMemoMarkdown(j memoJSON) string {
var sb strings.Builder
sb.WriteString("---\n")
fmt.Fprintf(&sb, "name: %s\n", j.Name)
fmt.Fprintf(&sb, "creator: %s\n", j.Creator)
fmt.Fprintf(&sb, "visibility: %s\n", j.Visibility)
fmt.Fprintf(&sb, "state: %s\n", j.State)
fmt.Fprintf(&sb, "pinned: %v\n", j.Pinned)
if len(j.Tags) > 0 {
fmt.Fprintf(&sb, "tags: [%s]\n", strings.Join(j.Tags, ", "))
}
fmt.Fprintf(&sb, "create_time: %d\n", j.CreateTime)
fmt.Fprintf(&sb, "update_time: %d\n", j.UpdateTime)
if j.Parent != "" {
fmt.Fprintf(&sb, "parent: %s\n", j.Parent)
}
sb.WriteString("---\n\n")
sb.WriteString(j.Content)
return sb.String()
}
This diff is collapsed.
package mcp
import (
"context"
"fmt"
"sort"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) {
mcpSrv.AddTool(mcp.NewTool("list_tags",
mcp.WithDescription("List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically."),
), s.handleListTags)
}
type tagEntry struct {
Tag string `json:"tag"`
Count int `json:"count"`
}
func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
userID := auth.GetUserID(ctx)
rowStatus := store.Normal
find := &store.FindMemo{
ExcludeComments: true,
ExcludeContent: true,
RowStatus: &rowStatus,
}
applyVisibilityFilter(find, userID)
memos, err := s.store.ListMemos(ctx, find)
if err != nil {
return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil
}
counts := make(map[string]int)
for _, m := range memos {
if m.Payload == nil {
continue
}
for _, tag := range m.Payload.Tags {
counts[tag]++
}
}
entries := make([]tagEntry, 0, len(counts))
for tag, count := range counts {
entries = append(entries, tagEntry{Tag: tag, Count: count})
}
sort.Slice(entries, func(i, j int) bool {
if entries[i].Count != entries[j].Count {
return entries[i].Count > entries[j].Count
}
return entries[i].Tag < entries[j].Tag
})
out, err := marshalJSON(entries)
if err != nil {
return nil, err
}
return mcp.NewToolResultText(out), nil
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment