Unverified Commit 50638040 authored by boojack's avatar boojack Committed by GitHub

fix: reduce list memo query overhead (#5880)

parent ebc0e10f
...@@ -297,18 +297,22 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq ...@@ -297,18 +297,22 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
} }
// RELATIONS (batch load to avoid N+1) // RELATIONS (batch load to avoid N+1)
relationMap, err := s.batchConvertMemoRelations(ctx, memos) relationMap, err := s.batchConvertMemoRelations(ctx, memos, false)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to batch load memo relations") return nil, status.Errorf(codes.Internal, "failed to batch load memo relations")
} }
creatorIDs := make([]int32, 0, len(memos)) creatorIDs := make([]int32, 0, len(memos)+len(reactions))
for _, memo := range memos { for _, memo := range memos {
creatorIDs = append(creatorIDs, memo.CreatorID) creatorIDs = append(creatorIDs, memo.CreatorID)
} }
for _, reaction := range reactions {
creatorIDs = append(creatorIDs, reaction.CreatorID)
}
creatorMap, err := s.listUsersByID(ctx, creatorIDs) creatorMap, err := s.listUsersByID(ctx, creatorIDs)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo creators: %v", err) return nil, status.Errorf(codes.Internal, "failed to list memo creators: %v", err)
} }
conversionOptions := memoConversionOptions{displayWithUpdateTime: instanceMemoRelatedSetting.DisplayWithUpdateTime}
for _, memo := range memos { for _, memo := range memos {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID) memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
...@@ -316,7 +320,7 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq ...@@ -316,7 +320,7 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
attachments := attachmentMap[memo.ID] attachments := attachmentMap[memo.ID]
relations := relationMap[memo.ID] relations := relationMap[memo.ID]
memoMessage, err := s.convertMemoFromStoreWithCreators(ctx, memo, reactions, attachments, relations, creatorMap) memoMessage, err := s.convertMemoFromStoreWithCreatorsAndOptions(ctx, memo, reactions, attachments, relations, creatorMap, conversionOptions)
if err != nil { if err != nil {
if stderrors.Is(err, errMemoCreatorNotFound) { if stderrors.Is(err, errMemoCreatorNotFound) {
slog.Warn("Skipping memo with missing creator", slog.Warn("Skipping memo with missing creator",
...@@ -798,18 +802,26 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM ...@@ -798,18 +802,26 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
} }
// RELATIONS (batch load to avoid N+1) // RELATIONS (batch load to avoid N+1)
relationMap, err := s.batchConvertMemoRelations(ctx, memos) relationMap, err := s.batchConvertMemoRelations(ctx, memos, false)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to batch load memo relations") return nil, status.Errorf(codes.Internal, "failed to batch load memo relations")
} }
creatorIDs := make([]int32, 0, len(memos)) creatorIDs := make([]int32, 0, len(memos)+len(reactions))
for _, memo := range memos { for _, memo := range memos {
creatorIDs = append(creatorIDs, memo.CreatorID) creatorIDs = append(creatorIDs, memo.CreatorID)
} }
for _, reaction := range reactions {
creatorIDs = append(creatorIDs, reaction.CreatorID)
}
creatorMap, err := s.listUsersByID(ctx, creatorIDs) creatorMap, err := s.listUsersByID(ctx, creatorIDs)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo creators: %v", err) return nil, status.Errorf(codes.Internal, "failed to list memo creators: %v", err)
} }
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
conversionOptions := memoConversionOptions{displayWithUpdateTime: instanceMemoRelatedSetting.DisplayWithUpdateTime}
var memosResponse []*v1pb.Memo var memosResponse []*v1pb.Memo
for _, m := range memos { for _, m := range memos {
...@@ -818,7 +830,7 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM ...@@ -818,7 +830,7 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
attachments := attachmentMap[m.ID] attachments := attachmentMap[m.ID]
relations := relationMap[m.ID] relations := relationMap[m.ID]
memoMessage, err := s.convertMemoFromStoreWithCreators(ctx, m, reactions, attachments, relations, creatorMap) memoMessage, err := s.convertMemoFromStoreWithCreatorsAndOptions(ctx, m, reactions, attachments, relations, creatorMap, conversionOptions)
if err != nil { if err != nil {
if stderrors.Is(err, errMemoCreatorNotFound) { if stderrors.Is(err, errMemoCreatorNotFound) {
slog.Warn("Skipping memo comment with missing creator", slog.Warn("Skipping memo comment with missing creator",
......
...@@ -20,6 +20,10 @@ var ( ...@@ -20,6 +20,10 @@ var (
errReactionCreatorNotFound = stderrors.New("reaction creator not found") errReactionCreatorNotFound = stderrors.New("reaction creator not found")
) )
type memoConversionOptions struct {
displayWithUpdateTime bool
}
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation) (*v1pb.Memo, error) { func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation) (*v1pb.Memo, error) {
creatorMap, err := s.listUsersByID(ctx, []int32{memo.CreatorID}) creatorMap, err := s.listUsersByID(ctx, []int32{memo.CreatorID})
if err != nil { if err != nil {
...@@ -29,12 +33,24 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem ...@@ -29,12 +33,24 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem
} }
func (s *APIV1Service) convertMemoFromStoreWithCreators(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation, creatorMap map[int32]*store.User) (*v1pb.Memo, error) { func (s *APIV1Service) convertMemoFromStoreWithCreators(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation, creatorMap map[int32]*store.User) (*v1pb.Memo, error) {
displayTs := memo.CreatedTs
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx) instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to get instance memo related setting") return nil, errors.Wrap(err, "failed to get instance memo related setting")
} }
if instanceMemoRelatedSetting.DisplayWithUpdateTime { return s.convertMemoFromStoreWithCreatorsAndOptions(
ctx,
memo,
reactions,
attachments,
relations,
creatorMap,
memoConversionOptions{displayWithUpdateTime: instanceMemoRelatedSetting.DisplayWithUpdateTime},
)
}
func (s *APIV1Service) convertMemoFromStoreWithCreatorsAndOptions(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation, creatorMap map[int32]*store.User, options memoConversionOptions) (*v1pb.Memo, error) {
displayTs := memo.CreatedTs
if options.displayWithUpdateTime {
displayTs = memo.UpdatedTs displayTs = memo.UpdatedTs
} }
...@@ -65,10 +81,11 @@ func (s *APIV1Service) convertMemoFromStoreWithCreators(ctx context.Context, mem ...@@ -65,10 +81,11 @@ func (s *APIV1Service) convertMemoFromStoreWithCreators(ctx context.Context, mem
memoMessage.Parent = &parentName memoMessage.Parent = &parentName
} }
memoMessage.Reactions, err = s.convertReactionsFromStoreWithCreators(ctx, reactions, creatorMap) reactionMessages, err := s.convertReactionsFromStoreWithCreators(ctx, reactions, creatorMap)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to convert reactions") return nil, errors.Wrap(err, "failed to convert reactions")
} }
memoMessage.Reactions = reactionMessages
if relations != nil { if relations != nil {
memoMessage.Relations = relations memoMessage.Relations = relations
...@@ -179,7 +196,7 @@ func convertReactionFromStoreWithCreators(reaction *store.Reaction, creatorsByID ...@@ -179,7 +196,7 @@ func convertReactionFromStoreWithCreators(reaction *store.Reaction, creatorsByID
// batchConvertMemoRelations batch-loads relations for a list of memos and returns // batchConvertMemoRelations batch-loads relations for a list of memos and returns
// a map from memo ID to its converted relations. This avoids N+1 queries when listing memos. // a map from memo ID to its converted relations. This avoids N+1 queries when listing memos.
func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*store.Memo) (map[int32][]*v1pb.MemoRelation, error) { func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*store.Memo, includeSnippets bool) (map[int32][]*v1pb.MemoRelation, error) {
if len(memos) == 0 { if len(memos) == 0 {
return map[int32][]*v1pb.MemoRelation{}, nil return map[int32][]*v1pb.MemoRelation{}, nil
} }
...@@ -202,14 +219,21 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s ...@@ -202,14 +219,21 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s
memoIDSet[m.ID] = true memoIDSet[m.ID] = true
} }
// Single batch query to get all relations involving any of these memos. outgoingRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
allRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{ SourceMemoIDList: memoIDs,
MemoIDList: memoIDs, MemoFilter: &memoFilter,
})
if err != nil {
return nil, errors.Wrap(err, "failed to batch list outgoing memo relations")
}
incomingRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoIDList: memoIDs,
MemoFilter: &memoFilter, MemoFilter: &memoFilter,
}) })
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to batch list memo relations") return nil, errors.Wrap(err, "failed to batch list incoming memo relations")
} }
allRelations := mergeMemoRelations(outgoingRelations, incomingRelations)
// Collect all memo IDs referenced in relations that we need to resolve. // Collect all memo IDs referenced in relations that we need to resolve.
neededIDs := make(map[int32]bool) neededIDs := make(map[int32]bool)
...@@ -220,10 +244,16 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s ...@@ -220,10 +244,16 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s
// Build ID→UID map from the memos we already have. // Build ID→UID map from the memos we already have.
memoIDToUID := make(map[int32]string, len(memos)) memoIDToUID := make(map[int32]string, len(memos))
memoIDToContent := make(map[int32]string, len(memos)) memoIDToSnippet := make(map[int32]string, len(memos))
for _, m := range memos { for _, m := range memos {
memoIDToUID[m.ID] = m.UID memoIDToUID[m.ID] = m.UID
memoIDToContent[m.ID] = m.Content if includeSnippets {
snippet, err := s.getMemoContentSnippet(m.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get memo content snippet")
}
memoIDToSnippet[m.ID] = snippet
}
delete(neededIDs, m.ID) delete(neededIDs, m.ID)
} }
...@@ -233,13 +263,20 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s ...@@ -233,13 +263,20 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s
for id := range neededIDs { for id := range neededIDs {
extraIDs = append(extraIDs, id) extraIDs = append(extraIDs, id)
} }
extraMemos, err := s.Store.ListMemos(ctx, &store.FindMemo{IDList: extraIDs}) extraFind := &store.FindMemo{IDList: extraIDs, ExcludeContent: !includeSnippets}
extraMemos, err := s.Store.ListMemos(ctx, extraFind)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to batch fetch related memos") return nil, errors.Wrap(err, "failed to batch fetch related memos")
} }
for _, m := range extraMemos { for _, m := range extraMemos {
memoIDToUID[m.ID] = m.UID memoIDToUID[m.ID] = m.UID
memoIDToContent[m.ID] = m.Content if includeSnippets {
snippet, err := s.getMemoContentSnippet(m.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get related memo content snippet")
}
memoIDToSnippet[m.ID] = snippet
}
} }
} }
...@@ -252,16 +289,14 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s ...@@ -252,16 +289,14 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s
continue continue
} }
memoSnippet, _ := s.getMemoContentSnippet(memoIDToContent[r.MemoID])
relatedSnippet, _ := s.getMemoContentSnippet(memoIDToContent[r.RelatedMemoID])
relation := &v1pb.MemoRelation{ relation := &v1pb.MemoRelation{
Memo: &v1pb.MemoRelation_Memo{ Memo: &v1pb.MemoRelation_Memo{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memoUID), Name: fmt.Sprintf("%s%s", MemoNamePrefix, memoUID),
Snippet: memoSnippet, Snippet: memoIDToSnippet[r.MemoID],
}, },
RelatedMemo: &v1pb.MemoRelation_Memo{ RelatedMemo: &v1pb.MemoRelation_Memo{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, relatedUID), Name: fmt.Sprintf("%s%s", MemoNamePrefix, relatedUID),
Snippet: relatedSnippet, Snippet: memoIDToSnippet[r.RelatedMemoID],
}, },
Type: convertMemoRelationTypeFromStore(r.Type), Type: convertMemoRelationTypeFromStore(r.Type),
} }
...@@ -280,13 +315,29 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s ...@@ -280,13 +315,29 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s
// loadMemoRelations loads relations for a single memo and converts them to API format. // loadMemoRelations loads relations for a single memo and converts them to API format.
func (s *APIV1Service) loadMemoRelations(ctx context.Context, memo *store.Memo) ([]*v1pb.MemoRelation, error) { func (s *APIV1Service) loadMemoRelations(ctx context.Context, memo *store.Memo) ([]*v1pb.MemoRelation, error) {
relationMap, err := s.batchConvertMemoRelations(ctx, []*store.Memo{memo}) relationMap, err := s.batchConvertMemoRelations(ctx, []*store.Memo{memo}, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return relationMap[memo.ID], nil return relationMap[memo.ID], nil
} }
func mergeMemoRelations(groups ...[]*store.MemoRelation) []*store.MemoRelation {
seen := make(map[string]struct{})
merged := make([]*store.MemoRelation, 0)
for _, relations := range groups {
for _, relation := range relations {
key := fmt.Sprintf("%d:%d:%s", relation.MemoID, relation.RelatedMemoID, relation.Type)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
merged = append(merged, relation)
}
}
return merged
}
func convertMemoPropertyFromStore(property *storepb.MemoPayload_Property) *v1pb.Memo_Property { func convertMemoPropertyFromStore(property *storepb.MemoPayload_Property) *v1pb.Memo_Property {
if property == nil { if property == nil {
return nil return nil
......
...@@ -176,7 +176,7 @@ func (s *APIV1Service) GetMemoByShare(ctx context.Context, request *v1pb.GetMemo ...@@ -176,7 +176,7 @@ func (s *APIV1Service) GetMemoByShare(ctx context.Context, request *v1pb.GetMemo
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments") return nil, status.Errorf(codes.Internal, "failed to list attachments")
} }
relations, err := s.batchConvertMemoRelations(ctx, []*store.Memo{memo}) relations, err := s.batchConvertMemoRelations(ctx, []*store.Memo{memo}, true)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to load memo relations") return nil, status.Errorf(codes.Internal, "failed to load memo relations")
} }
......
package test
import (
"context"
"fmt"
"path/filepath"
"testing"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/internal/version"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
"github.com/usememos/memos/store/db"
)
const (
benchmarkTopLevelMemoCount = 5000
benchmarkPageSize = 16
)
type benchmarkService struct {
*TestService
hostUser *store.User
authenticatedCtx context.Context
publicCtx context.Context
pageTenToken string
commentParentName string
}
func newBenchmarkService(tb testing.TB) *benchmarkService {
tb.Helper()
ctx := context.Background()
testService := newTestingServiceForTB(tb)
hostUser, err := testService.CreateHostUser(ctx, "bench-host")
if err != nil {
tb.Fatalf("failed to create host user: %v", err)
}
commentParentName, err := seedListMemosBenchmarkData(ctx, testService.Store, hostUser)
if err != nil {
tb.Fatalf("failed to seed benchmark data: %v", err)
}
authenticatedCtx := context.WithValue(context.Background(), auth.UserIDContextKey, hostUser.ID)
pageTenToken, err := getListMemosPageToken(authenticatedCtx, testService.Service, 10, benchmarkPageSize)
if err != nil {
tb.Fatalf("failed to build page token: %v", err)
}
return &benchmarkService{
TestService: testService,
hostUser: hostUser,
authenticatedCtx: authenticatedCtx,
publicCtx: context.Background(),
pageTenToken: pageTenToken,
commentParentName: commentParentName,
}
}
func newTestingServiceForTB(tb testing.TB) *TestService {
tb.Helper()
ctx := context.Background()
dataDir := tb.TempDir()
testProfile := getBenchmarkProfile(dataDir)
dbDriver, err := db.NewDBDriver(testProfile)
if err != nil {
tb.Fatalf("failed to create db driver: %v", err)
}
testStore := store.New(dbDriver, testProfile)
if err := testStore.Migrate(ctx); err != nil {
tb.Fatalf("failed to migrate db: %v", err)
}
tb.Cleanup(func() {
testStore.Close()
})
service := newServiceWithProfile(testProfile, testStore)
return &TestService{
Service: service,
Store: testStore,
Profile: testProfile,
Secret: service.Secret,
}
}
func getBenchmarkProfile(dataDir string) *profile.Profile {
return &profile.Profile{
Demo: true,
Version: version.GetCurrentVersion(),
InstanceURL: "http://localhost:8080",
Driver: "sqlite",
DSN: filepath.Join(dataDir, "bench.db"),
Data: dataDir,
}
}
func newServiceWithProfile(testProfile *profile.Profile, testStore *store.Store) *apiv1.APIV1Service {
service := apiv1.NewAPIV1Service("bench-secret", testProfile, testStore)
return service
}
func seedListMemosBenchmarkData(ctx context.Context, stores *store.Store, hostUser *store.User) (string, error) {
topLevelMemos := make([]*store.Memo, 0, benchmarkTopLevelMemoCount)
commentParentName := ""
for i := 0; i < benchmarkTopLevelMemoCount; i++ {
visibility := store.Private
if i%4 == 0 {
visibility = store.Public
}
memo, err := stores.CreateMemo(ctx, &store.Memo{
UID: fmt.Sprintf("memo-%06d", i),
CreatorID: hostUser.ID,
Content: benchmarkMemoContent(i),
Visibility: visibility,
})
if err != nil {
return "", err
}
topLevelMemos = append(topLevelMemos, memo)
if i%3 == 0 {
if _, err := stores.CreateAttachment(ctx, &store.Attachment{
UID: fmt.Sprintf("att-%06d", i),
CreatorID: hostUser.ID,
Filename: fmt.Sprintf("memo-%06d.png", i),
Type: "image/png",
Size: 2048,
MemoID: &memo.ID,
}); err != nil {
return "", err
}
}
if i%5 == 0 {
if _, err := stores.UpsertReaction(ctx, &store.Reaction{
CreatorID: hostUser.ID,
ContentID: "memos/" + memo.UID,
ReactionType: "thumbs-up",
}); err != nil {
return "", err
}
}
}
for i, memo := range topLevelMemos {
if i+1 < len(topLevelMemos) && i%4 == 0 {
if _, err := stores.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo.ID,
RelatedMemoID: topLevelMemos[i+1].ID,
Type: store.MemoRelationReference,
}); err != nil {
return "", err
}
}
if i%6 == 0 {
commentMemo, err := stores.CreateMemo(ctx, &store.Memo{
UID: fmt.Sprintf("comment-%06d", i),
CreatorID: hostUser.ID,
Content: fmt.Sprintf("Comment for memo %06d", i),
Visibility: store.Private,
})
if err != nil {
return "", err
}
if _, err := stores.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: commentMemo.ID,
RelatedMemoID: memo.ID,
Type: store.MemoRelationComment,
}); err != nil {
return "", err
}
if commentParentName == "" {
commentParentName = "memos/" + memo.UID
}
}
}
return commentParentName, nil
}
func benchmarkMemoContent(i int) string {
return fmt.Sprintf("# Bench Memo %06d\n\nThis is benchmark memo %06d with enough content to exercise snippet generation.\n\n- task one\n- task two\n", i, i)
}
func getListMemosPageToken(ctx context.Context, service *apiv1.APIV1Service, page int, pageSize int32) (string, error) {
pageToken := ""
for range page - 1 {
resp, err := service.ListMemos(ctx, &v1pb.ListMemosRequest{
PageSize: pageSize,
PageToken: pageToken,
})
if err != nil {
return "", err
}
pageToken = resp.NextPageToken
if pageToken == "" {
break
}
}
return pageToken, nil
}
func BenchmarkListMemos(b *testing.B) {
bench := newBenchmarkService(b)
b.Run("authenticated_first_page", func(b *testing.B) {
req := &v1pb.ListMemosRequest{PageSize: benchmarkPageSize}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := bench.Service.ListMemos(bench.authenticatedCtx, req)
if err != nil {
b.Fatalf("ListMemos failed: %v", err)
}
if len(resp.Memos) == 0 {
b.Fatal("expected memos in authenticated benchmark response")
}
}
})
b.Run("authenticated_page_ten", func(b *testing.B) {
req := &v1pb.ListMemosRequest{PageSize: benchmarkPageSize, PageToken: bench.pageTenToken}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := bench.Service.ListMemos(bench.authenticatedCtx, req)
if err != nil {
b.Fatalf("ListMemos failed: %v", err)
}
if len(resp.Memos) == 0 {
b.Fatal("expected memos in paged benchmark response")
}
}
})
b.Run("public_first_page", func(b *testing.B) {
req := &v1pb.ListMemosRequest{PageSize: benchmarkPageSize}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := bench.Service.ListMemos(bench.publicCtx, req)
if err != nil {
b.Fatalf("ListMemos failed: %v", err)
}
if len(resp.Memos) == 0 {
b.Fatal("expected memos in public benchmark response")
}
}
})
}
func BenchmarkListMemoCommentsPreview(b *testing.B) {
bench := newBenchmarkService(b)
if bench.commentParentName == "" {
b.Fatal("expected seeded memo with comments")
}
req := &v1pb.ListMemoCommentsRequest{
Name: bench.commentParentName,
PageSize: 3,
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := bench.Service.ListMemoComments(bench.authenticatedCtx, req)
if err != nil {
b.Fatalf("ListMemoComments failed: %v", err)
}
if len(resp.Memos) == 0 {
b.Fatal("expected comments in benchmark response")
}
}
}
...@@ -54,6 +54,22 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation ...@@ -54,6 +54,22 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
} }
where = append(where, fmt.Sprintf("(`memo_id` IN (%s) OR `related_memo_id` IN (%s))", inClause, inClause)) where = append(where, fmt.Sprintf("(`memo_id` IN (%s) OR `related_memo_id` IN (%s))", inClause, inClause))
} }
if len(find.SourceMemoIDList) > 0 {
placeholders := make([]string, len(find.SourceMemoIDList))
for i, id := range find.SourceMemoIDList {
placeholders[i] = "?"
args = append(args, id)
}
where = append(where, fmt.Sprintf("`memo_id` IN (%s)", strings.Join(placeholders, ", ")))
}
if len(find.RelatedMemoIDList) > 0 {
placeholders := make([]string, len(find.RelatedMemoIDList))
for i, id := range find.RelatedMemoIDList {
placeholders[i] = "?"
args = append(args, id)
}
where = append(where, fmt.Sprintf("`related_memo_id` IN (%s)", strings.Join(placeholders, ", ")))
}
if find.MemoFilter != nil { if find.MemoFilter != nil {
engine, err := filter.DefaultEngine() engine, err := filter.DefaultEngine()
if err != nil { if err != nil {
......
...@@ -50,18 +50,34 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation ...@@ -50,18 +50,34 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, *find.Type) where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, *find.Type)
} }
if len(find.MemoIDList) > 0 { if len(find.MemoIDList) > 0 {
memoPlaceholders := make([]string, len(find.MemoIDList)) placeholders := make([]string, len(find.MemoIDList))
for i, id := range find.MemoIDList { for i, id := range find.MemoIDList {
memoPlaceholders[i] = placeholder(len(args) + 1) placeholders[i] = placeholder(len(args) + 1)
args = append(args, id) args = append(args, id)
} }
inClause := strings.Join(placeholders, ", ")
relatedPlaceholders := make([]string, len(find.MemoIDList)) relatedPlaceholders := make([]string, len(find.MemoIDList))
for i, id := range find.MemoIDList { for i, id := range find.MemoIDList {
relatedPlaceholders[i] = placeholder(len(args) + 1) relatedPlaceholders[i] = placeholder(len(args) + 1)
args = append(args, id) args = append(args, id)
} }
where = append(where, fmt.Sprintf("(memo_id IN (%s) OR related_memo_id IN (%s))", where = append(where, fmt.Sprintf("(memo_id IN (%s) OR related_memo_id IN (%s))", inClause, strings.Join(relatedPlaceholders, ", ")))
strings.Join(memoPlaceholders, ", "), strings.Join(relatedPlaceholders, ", "))) }
if len(find.SourceMemoIDList) > 0 {
placeholders := make([]string, len(find.SourceMemoIDList))
for i, id := range find.SourceMemoIDList {
placeholders[i] = placeholder(len(args) + 1)
args = append(args, id)
}
where = append(where, fmt.Sprintf("memo_id IN (%s)", strings.Join(placeholders, ", ")))
}
if len(find.RelatedMemoIDList) > 0 {
placeholders := make([]string, len(find.RelatedMemoIDList))
for i, id := range find.RelatedMemoIDList {
placeholders[i] = placeholder(len(args) + 1)
args = append(args, id)
}
where = append(where, fmt.Sprintf("related_memo_id IN (%s)", strings.Join(placeholders, ", ")))
} }
if find.MemoFilter != nil { if find.MemoFilter != nil {
engine, err := filter.DefaultEngine() engine, err := filter.DefaultEngine()
......
...@@ -56,12 +56,27 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation ...@@ -56,12 +56,27 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
args = append(args, id) args = append(args, id)
} }
inClause := strings.Join(placeholders, ", ") inClause := strings.Join(placeholders, ", ")
// Duplicate args for the second IN clause.
for _, id := range find.MemoIDList { for _, id := range find.MemoIDList {
args = append(args, id) args = append(args, id)
} }
where = append(where, fmt.Sprintf("(memo_id IN (%s) OR related_memo_id IN (%s))", inClause, inClause)) where = append(where, fmt.Sprintf("(memo_id IN (%s) OR related_memo_id IN (%s))", inClause, inClause))
} }
if len(find.SourceMemoIDList) > 0 {
placeholders := make([]string, len(find.SourceMemoIDList))
for i, id := range find.SourceMemoIDList {
placeholders[i] = "?"
args = append(args, id)
}
where = append(where, fmt.Sprintf("memo_id IN (%s)", strings.Join(placeholders, ", ")))
}
if len(find.RelatedMemoIDList) > 0 {
placeholders := make([]string, len(find.RelatedMemoIDList))
for i, id := range find.RelatedMemoIDList {
placeholders[i] = "?"
args = append(args, id)
}
where = append(where, fmt.Sprintf("related_memo_id IN (%s)", strings.Join(placeholders, ", ")))
}
if find.MemoFilter != nil { if find.MemoFilter != nil {
engine, err := filter.DefaultEngine() engine, err := filter.DefaultEngine()
if err != nil { if err != nil {
......
...@@ -26,6 +26,10 @@ type FindMemoRelation struct { ...@@ -26,6 +26,10 @@ type FindMemoRelation struct {
MemoFilter *string MemoFilter *string
// MemoIDList matches relations where memo_id OR related_memo_id is in the list. // MemoIDList matches relations where memo_id OR related_memo_id is in the list.
MemoIDList []int32 MemoIDList []int32
// SourceMemoIDList matches relations where memo_id is in the list.
SourceMemoIDList []int32
// RelatedMemoIDList matches relations where related_memo_id is in the list.
RelatedMemoIDList []int32
Limit *int Limit *int
Offset *int Offset *int
} }
......
...@@ -5,6 +5,7 @@ import type { MemoRelation } from "@/types/proto/api/v1/memo_service_pb"; ...@@ -5,6 +5,7 @@ import type { MemoRelation } from "@/types/proto/api/v1/memo_service_pb";
import { useTranslate } from "@/utils/i18n"; import { useTranslate } from "@/utils/i18n";
import RelationCard from "./RelationCard"; import RelationCard from "./RelationCard";
import { getRelationBuckets, getRelationMemo, getRelationMemoName, type RelationDirection } from "./relationHelpers"; import { getRelationBuckets, getRelationMemo, getRelationMemoName, type RelationDirection } from "./relationHelpers";
import { useResolvedRelationMemos } from "./useResolvedRelationMemos";
interface RelationListViewProps { interface RelationListViewProps {
relations: MemoRelation[]; relations: MemoRelation[];
...@@ -16,6 +17,7 @@ interface RelationListViewProps { ...@@ -16,6 +17,7 @@ interface RelationListViewProps {
function RelationListView({ relations, currentMemoName, parentPage, className }: RelationListViewProps) { function RelationListView({ relations, currentMemoName, parentPage, className }: RelationListViewProps) {
const t = useTranslate(); const t = useTranslate();
const [activeTab, setActiveTab] = useState<"referencing" | "referenced">("referencing"); const [activeTab, setActiveTab] = useState<"referencing" | "referenced">("referencing");
const resolvedMemos = useResolvedRelationMemos(relations);
const { referencing: referencingRelations, referenced: referencedRelations } = useMemo( const { referencing: referencingRelations, referenced: referencedRelations } = useMemo(
() => getRelationBuckets(relations, currentMemoName), () => getRelationBuckets(relations, currentMemoName),
...@@ -60,9 +62,15 @@ function RelationListView({ relations, currentMemoName, parentPage, className }: ...@@ -60,9 +62,15 @@ function RelationListView({ relations, currentMemoName, parentPage, className }:
} }
contentClassName="flex flex-col gap-0 p-1.5" contentClassName="flex flex-col gap-0 p-1.5"
> >
{activeRelations.map((relation) => ( {activeRelations.map((relation) => {
<RelationCard key={getRelationMemoName(relation, direction)} memo={getRelationMemo(relation, direction)!} parentPage={parentPage} /> const memo = getRelationMemo(relation, direction);
))} if (!memo) {
return null;
}
return (
<RelationCard key={getRelationMemoName(relation, direction)} memo={resolvedMemos[memo.name] ?? memo} parentPage={parentPage} />
);
})}
</MetadataSection> </MetadataSection>
); );
} }
......
...@@ -11,9 +11,10 @@ export const useResolvedRelationMemos = (relations: MemoRelation[]) => { ...@@ -11,9 +11,10 @@ export const useResolvedRelationMemos = (relations: MemoRelation[]) => {
const names = new Set<string>(); const names = new Set<string>();
for (const relation of relations) { for (const relation of relations) {
const relatedMemo = relation.relatedMemo; for (const memo of [relation.memo, relation.relatedMemo]) {
if (relatedMemo?.name && !relatedMemo.snippet && !resolvedMemos[relatedMemo.name]) { if (memo?.name && !memo.snippet && !resolvedMemos[memo.name]) {
names.add(relatedMemo.name); names.add(memo.name);
}
} }
} }
......
...@@ -10,7 +10,7 @@ const MemoCommentListView: React.FC = () => { ...@@ -10,7 +10,7 @@ const MemoCommentListView: React.FC = () => {
const { memo } = useMemoViewContext(); const { memo } = useMemoViewContext();
const { isInMemoDetailPage, commentAmount } = useMemoViewDerived(); const { isInMemoDetailPage, commentAmount } = useMemoViewDerived();
const { data } = useMemoComments(memo.name, { enabled: !isInMemoDetailPage && commentAmount > 0 }); const { data } = useMemoComments(memo.name, { enabled: !isInMemoDetailPage && commentAmount > 0, pageSize: 3 });
const comments = data?.memos ?? []; const comments = data?.memos ?? [];
const displayedComments = comments.slice(0, 3); const displayedComments = comments.slice(0, 3);
const { data: commentCreators } = useUsersByNames(displayedComments.map((comment) => comment.creator)); const { data: commentCreators } = useUsersByNames(displayedComments.map((comment) => comment.creator));
......
...@@ -5,7 +5,7 @@ import { useInfiniteQuery, useMutation, useQuery, useQueryClient } from "@tansta ...@@ -5,7 +5,7 @@ import { useInfiniteQuery, useMutation, useQuery, useQueryClient } from "@tansta
import { memoServiceClient } from "@/connect"; import { memoServiceClient } from "@/connect";
import { userKeys } from "@/hooks/useUserQueries"; import { userKeys } from "@/hooks/useUserQueries";
import type { ListMemosRequest, ListMemosResponse, Memo } from "@/types/proto/api/v1/memo_service_pb"; import type { ListMemosRequest, ListMemosResponse, Memo } from "@/types/proto/api/v1/memo_service_pb";
import { ListMemosRequestSchema, MemoSchema } from "@/types/proto/api/v1/memo_service_pb"; import { ListMemoCommentsRequestSchema, ListMemosRequestSchema, MemoSchema } from "@/types/proto/api/v1/memo_service_pb";
// Query keys factory for consistent cache management // Query keys factory for consistent cache management
export const memoKeys = { export const memoKeys = {
...@@ -243,11 +243,16 @@ export function useDeleteMemo() { ...@@ -243,11 +243,16 @@ export function useDeleteMemo() {
}); });
} }
export function useMemoComments(name: string, options?: { enabled?: boolean }) { export function useMemoComments(name: string, options?: { enabled?: boolean; pageSize?: number }) {
return useQuery({ return useQuery({
queryKey: memoKeys.comments(name), queryKey: [...memoKeys.comments(name), options?.pageSize ?? 0],
queryFn: async () => { queryFn: async () => {
const response = await memoServiceClient.listMemoComments({ name }); const response = await memoServiceClient.listMemoComments(
create(ListMemoCommentsRequestSchema, {
name,
pageSize: options?.pageSize ?? 0,
}),
);
return response; return response;
}, },
enabled: options?.enabled ?? true, enabled: options?.enabled ?? true,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment