Unverified Commit 50638040 authored by boojack's avatar boojack Committed by GitHub

fix: reduce list memo query overhead (#5880)

parent ebc0e10f
......@@ -297,18 +297,22 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
}
// RELATIONS (batch load to avoid N+1)
relationMap, err := s.batchConvertMemoRelations(ctx, memos)
relationMap, err := s.batchConvertMemoRelations(ctx, memos, false)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to batch load memo relations")
}
creatorIDs := make([]int32, 0, len(memos))
creatorIDs := make([]int32, 0, len(memos)+len(reactions))
for _, memo := range memos {
creatorIDs = append(creatorIDs, memo.CreatorID)
}
for _, reaction := range reactions {
creatorIDs = append(creatorIDs, reaction.CreatorID)
}
creatorMap, err := s.listUsersByID(ctx, creatorIDs)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo creators: %v", err)
}
conversionOptions := memoConversionOptions{displayWithUpdateTime: instanceMemoRelatedSetting.DisplayWithUpdateTime}
for _, memo := range memos {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
......@@ -316,7 +320,7 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
attachments := attachmentMap[memo.ID]
relations := relationMap[memo.ID]
memoMessage, err := s.convertMemoFromStoreWithCreators(ctx, memo, reactions, attachments, relations, creatorMap)
memoMessage, err := s.convertMemoFromStoreWithCreatorsAndOptions(ctx, memo, reactions, attachments, relations, creatorMap, conversionOptions)
if err != nil {
if stderrors.Is(err, errMemoCreatorNotFound) {
slog.Warn("Skipping memo with missing creator",
......@@ -798,18 +802,26 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
}
// RELATIONS (batch load to avoid N+1)
relationMap, err := s.batchConvertMemoRelations(ctx, memos)
relationMap, err := s.batchConvertMemoRelations(ctx, memos, false)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to batch load memo relations")
}
creatorIDs := make([]int32, 0, len(memos))
creatorIDs := make([]int32, 0, len(memos)+len(reactions))
for _, memo := range memos {
creatorIDs = append(creatorIDs, memo.CreatorID)
}
for _, reaction := range reactions {
creatorIDs = append(creatorIDs, reaction.CreatorID)
}
creatorMap, err := s.listUsersByID(ctx, creatorIDs)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo creators: %v", err)
}
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get instance memo related setting")
}
conversionOptions := memoConversionOptions{displayWithUpdateTime: instanceMemoRelatedSetting.DisplayWithUpdateTime}
var memosResponse []*v1pb.Memo
for _, m := range memos {
......@@ -818,7 +830,7 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
attachments := attachmentMap[m.ID]
relations := relationMap[m.ID]
memoMessage, err := s.convertMemoFromStoreWithCreators(ctx, m, reactions, attachments, relations, creatorMap)
memoMessage, err := s.convertMemoFromStoreWithCreatorsAndOptions(ctx, m, reactions, attachments, relations, creatorMap, conversionOptions)
if err != nil {
if stderrors.Is(err, errMemoCreatorNotFound) {
slog.Warn("Skipping memo comment with missing creator",
......
......@@ -20,6 +20,10 @@ var (
errReactionCreatorNotFound = stderrors.New("reaction creator not found")
)
type memoConversionOptions struct {
displayWithUpdateTime bool
}
func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation) (*v1pb.Memo, error) {
creatorMap, err := s.listUsersByID(ctx, []int32{memo.CreatorID})
if err != nil {
......@@ -29,12 +33,24 @@ func (s *APIV1Service) convertMemoFromStore(ctx context.Context, memo *store.Mem
}
func (s *APIV1Service) convertMemoFromStoreWithCreators(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation, creatorMap map[int32]*store.User) (*v1pb.Memo, error) {
displayTs := memo.CreatedTs
instanceMemoRelatedSetting, err := s.Store.GetInstanceMemoRelatedSetting(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get instance memo related setting")
}
if instanceMemoRelatedSetting.DisplayWithUpdateTime {
return s.convertMemoFromStoreWithCreatorsAndOptions(
ctx,
memo,
reactions,
attachments,
relations,
creatorMap,
memoConversionOptions{displayWithUpdateTime: instanceMemoRelatedSetting.DisplayWithUpdateTime},
)
}
func (s *APIV1Service) convertMemoFromStoreWithCreatorsAndOptions(ctx context.Context, memo *store.Memo, reactions []*store.Reaction, attachments []*store.Attachment, relations []*v1pb.MemoRelation, creatorMap map[int32]*store.User, options memoConversionOptions) (*v1pb.Memo, error) {
displayTs := memo.CreatedTs
if options.displayWithUpdateTime {
displayTs = memo.UpdatedTs
}
......@@ -65,10 +81,11 @@ func (s *APIV1Service) convertMemoFromStoreWithCreators(ctx context.Context, mem
memoMessage.Parent = &parentName
}
memoMessage.Reactions, err = s.convertReactionsFromStoreWithCreators(ctx, reactions, creatorMap)
reactionMessages, err := s.convertReactionsFromStoreWithCreators(ctx, reactions, creatorMap)
if err != nil {
return nil, errors.Wrap(err, "failed to convert reactions")
}
memoMessage.Reactions = reactionMessages
if relations != nil {
memoMessage.Relations = relations
......@@ -179,7 +196,7 @@ func convertReactionFromStoreWithCreators(reaction *store.Reaction, creatorsByID
// batchConvertMemoRelations batch-loads relations for a list of memos and returns
// a map from memo ID to its converted relations. This avoids N+1 queries when listing memos.
func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*store.Memo) (map[int32][]*v1pb.MemoRelation, error) {
func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*store.Memo, includeSnippets bool) (map[int32][]*v1pb.MemoRelation, error) {
if len(memos) == 0 {
return map[int32][]*v1pb.MemoRelation{}, nil
}
......@@ -202,14 +219,21 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s
memoIDSet[m.ID] = true
}
// Single batch query to get all relations involving any of these memos.
allRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
MemoIDList: memoIDs,
outgoingRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
SourceMemoIDList: memoIDs,
MemoFilter: &memoFilter,
})
if err != nil {
return nil, errors.Wrap(err, "failed to batch list outgoing memo relations")
}
incomingRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoIDList: memoIDs,
MemoFilter: &memoFilter,
})
if err != nil {
return nil, errors.Wrap(err, "failed to batch list memo relations")
return nil, errors.Wrap(err, "failed to batch list incoming memo relations")
}
allRelations := mergeMemoRelations(outgoingRelations, incomingRelations)
// Collect all memo IDs referenced in relations that we need to resolve.
neededIDs := make(map[int32]bool)
......@@ -220,10 +244,16 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s
// Build ID→UID map from the memos we already have.
memoIDToUID := make(map[int32]string, len(memos))
memoIDToContent := make(map[int32]string, len(memos))
memoIDToSnippet := make(map[int32]string, len(memos))
for _, m := range memos {
memoIDToUID[m.ID] = m.UID
memoIDToContent[m.ID] = m.Content
if includeSnippets {
snippet, err := s.getMemoContentSnippet(m.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get memo content snippet")
}
memoIDToSnippet[m.ID] = snippet
}
delete(neededIDs, m.ID)
}
......@@ -233,13 +263,20 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s
for id := range neededIDs {
extraIDs = append(extraIDs, id)
}
extraMemos, err := s.Store.ListMemos(ctx, &store.FindMemo{IDList: extraIDs})
extraFind := &store.FindMemo{IDList: extraIDs, ExcludeContent: !includeSnippets}
extraMemos, err := s.Store.ListMemos(ctx, extraFind)
if err != nil {
return nil, errors.Wrap(err, "failed to batch fetch related memos")
}
for _, m := range extraMemos {
memoIDToUID[m.ID] = m.UID
memoIDToContent[m.ID] = m.Content
if includeSnippets {
snippet, err := s.getMemoContentSnippet(m.Content)
if err != nil {
return nil, errors.Wrap(err, "failed to get related memo content snippet")
}
memoIDToSnippet[m.ID] = snippet
}
}
}
......@@ -252,16 +289,14 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s
continue
}
memoSnippet, _ := s.getMemoContentSnippet(memoIDToContent[r.MemoID])
relatedSnippet, _ := s.getMemoContentSnippet(memoIDToContent[r.RelatedMemoID])
relation := &v1pb.MemoRelation{
Memo: &v1pb.MemoRelation_Memo{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, memoUID),
Snippet: memoSnippet,
Snippet: memoIDToSnippet[r.MemoID],
},
RelatedMemo: &v1pb.MemoRelation_Memo{
Name: fmt.Sprintf("%s%s", MemoNamePrefix, relatedUID),
Snippet: relatedSnippet,
Snippet: memoIDToSnippet[r.RelatedMemoID],
},
Type: convertMemoRelationTypeFromStore(r.Type),
}
......@@ -280,13 +315,29 @@ func (s *APIV1Service) batchConvertMemoRelations(ctx context.Context, memos []*s
// loadMemoRelations loads relations for a single memo and converts them to API format.
func (s *APIV1Service) loadMemoRelations(ctx context.Context, memo *store.Memo) ([]*v1pb.MemoRelation, error) {
relationMap, err := s.batchConvertMemoRelations(ctx, []*store.Memo{memo})
relationMap, err := s.batchConvertMemoRelations(ctx, []*store.Memo{memo}, true)
if err != nil {
return nil, err
}
return relationMap[memo.ID], nil
}
func mergeMemoRelations(groups ...[]*store.MemoRelation) []*store.MemoRelation {
seen := make(map[string]struct{})
merged := make([]*store.MemoRelation, 0)
for _, relations := range groups {
for _, relation := range relations {
key := fmt.Sprintf("%d:%d:%s", relation.MemoID, relation.RelatedMemoID, relation.Type)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
merged = append(merged, relation)
}
}
return merged
}
func convertMemoPropertyFromStore(property *storepb.MemoPayload_Property) *v1pb.Memo_Property {
if property == nil {
return nil
......
......@@ -176,7 +176,7 @@ func (s *APIV1Service) GetMemoByShare(ctx context.Context, request *v1pb.GetMemo
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
relations, err := s.batchConvertMemoRelations(ctx, []*store.Memo{memo})
relations, err := s.batchConvertMemoRelations(ctx, []*store.Memo{memo}, true)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to load memo relations")
}
......
package test
import (
"context"
"fmt"
"path/filepath"
"testing"
"github.com/usememos/memos/internal/profile"
"github.com/usememos/memos/internal/version"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth"
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/store"
"github.com/usememos/memos/store/db"
)
const (
benchmarkTopLevelMemoCount = 5000
benchmarkPageSize = 16
)
type benchmarkService struct {
*TestService
hostUser *store.User
authenticatedCtx context.Context
publicCtx context.Context
pageTenToken string
commentParentName string
}
func newBenchmarkService(tb testing.TB) *benchmarkService {
tb.Helper()
ctx := context.Background()
testService := newTestingServiceForTB(tb)
hostUser, err := testService.CreateHostUser(ctx, "bench-host")
if err != nil {
tb.Fatalf("failed to create host user: %v", err)
}
commentParentName, err := seedListMemosBenchmarkData(ctx, testService.Store, hostUser)
if err != nil {
tb.Fatalf("failed to seed benchmark data: %v", err)
}
authenticatedCtx := context.WithValue(context.Background(), auth.UserIDContextKey, hostUser.ID)
pageTenToken, err := getListMemosPageToken(authenticatedCtx, testService.Service, 10, benchmarkPageSize)
if err != nil {
tb.Fatalf("failed to build page token: %v", err)
}
return &benchmarkService{
TestService: testService,
hostUser: hostUser,
authenticatedCtx: authenticatedCtx,
publicCtx: context.Background(),
pageTenToken: pageTenToken,
commentParentName: commentParentName,
}
}
func newTestingServiceForTB(tb testing.TB) *TestService {
tb.Helper()
ctx := context.Background()
dataDir := tb.TempDir()
testProfile := getBenchmarkProfile(dataDir)
dbDriver, err := db.NewDBDriver(testProfile)
if err != nil {
tb.Fatalf("failed to create db driver: %v", err)
}
testStore := store.New(dbDriver, testProfile)
if err := testStore.Migrate(ctx); err != nil {
tb.Fatalf("failed to migrate db: %v", err)
}
tb.Cleanup(func() {
testStore.Close()
})
service := newServiceWithProfile(testProfile, testStore)
return &TestService{
Service: service,
Store: testStore,
Profile: testProfile,
Secret: service.Secret,
}
}
func getBenchmarkProfile(dataDir string) *profile.Profile {
return &profile.Profile{
Demo: true,
Version: version.GetCurrentVersion(),
InstanceURL: "http://localhost:8080",
Driver: "sqlite",
DSN: filepath.Join(dataDir, "bench.db"),
Data: dataDir,
}
}
func newServiceWithProfile(testProfile *profile.Profile, testStore *store.Store) *apiv1.APIV1Service {
service := apiv1.NewAPIV1Service("bench-secret", testProfile, testStore)
return service
}
func seedListMemosBenchmarkData(ctx context.Context, stores *store.Store, hostUser *store.User) (string, error) {
topLevelMemos := make([]*store.Memo, 0, benchmarkTopLevelMemoCount)
commentParentName := ""
for i := 0; i < benchmarkTopLevelMemoCount; i++ {
visibility := store.Private
if i%4 == 0 {
visibility = store.Public
}
memo, err := stores.CreateMemo(ctx, &store.Memo{
UID: fmt.Sprintf("memo-%06d", i),
CreatorID: hostUser.ID,
Content: benchmarkMemoContent(i),
Visibility: visibility,
})
if err != nil {
return "", err
}
topLevelMemos = append(topLevelMemos, memo)
if i%3 == 0 {
if _, err := stores.CreateAttachment(ctx, &store.Attachment{
UID: fmt.Sprintf("att-%06d", i),
CreatorID: hostUser.ID,
Filename: fmt.Sprintf("memo-%06d.png", i),
Type: "image/png",
Size: 2048,
MemoID: &memo.ID,
}); err != nil {
return "", err
}
}
if i%5 == 0 {
if _, err := stores.UpsertReaction(ctx, &store.Reaction{
CreatorID: hostUser.ID,
ContentID: "memos/" + memo.UID,
ReactionType: "thumbs-up",
}); err != nil {
return "", err
}
}
}
for i, memo := range topLevelMemos {
if i+1 < len(topLevelMemos) && i%4 == 0 {
if _, err := stores.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: memo.ID,
RelatedMemoID: topLevelMemos[i+1].ID,
Type: store.MemoRelationReference,
}); err != nil {
return "", err
}
}
if i%6 == 0 {
commentMemo, err := stores.CreateMemo(ctx, &store.Memo{
UID: fmt.Sprintf("comment-%06d", i),
CreatorID: hostUser.ID,
Content: fmt.Sprintf("Comment for memo %06d", i),
Visibility: store.Private,
})
if err != nil {
return "", err
}
if _, err := stores.UpsertMemoRelation(ctx, &store.MemoRelation{
MemoID: commentMemo.ID,
RelatedMemoID: memo.ID,
Type: store.MemoRelationComment,
}); err != nil {
return "", err
}
if commentParentName == "" {
commentParentName = "memos/" + memo.UID
}
}
}
return commentParentName, nil
}
func benchmarkMemoContent(i int) string {
return fmt.Sprintf("# Bench Memo %06d\n\nThis is benchmark memo %06d with enough content to exercise snippet generation.\n\n- task one\n- task two\n", i, i)
}
func getListMemosPageToken(ctx context.Context, service *apiv1.APIV1Service, page int, pageSize int32) (string, error) {
pageToken := ""
for range page - 1 {
resp, err := service.ListMemos(ctx, &v1pb.ListMemosRequest{
PageSize: pageSize,
PageToken: pageToken,
})
if err != nil {
return "", err
}
pageToken = resp.NextPageToken
if pageToken == "" {
break
}
}
return pageToken, nil
}
func BenchmarkListMemos(b *testing.B) {
bench := newBenchmarkService(b)
b.Run("authenticated_first_page", func(b *testing.B) {
req := &v1pb.ListMemosRequest{PageSize: benchmarkPageSize}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := bench.Service.ListMemos(bench.authenticatedCtx, req)
if err != nil {
b.Fatalf("ListMemos failed: %v", err)
}
if len(resp.Memos) == 0 {
b.Fatal("expected memos in authenticated benchmark response")
}
}
})
b.Run("authenticated_page_ten", func(b *testing.B) {
req := &v1pb.ListMemosRequest{PageSize: benchmarkPageSize, PageToken: bench.pageTenToken}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := bench.Service.ListMemos(bench.authenticatedCtx, req)
if err != nil {
b.Fatalf("ListMemos failed: %v", err)
}
if len(resp.Memos) == 0 {
b.Fatal("expected memos in paged benchmark response")
}
}
})
b.Run("public_first_page", func(b *testing.B) {
req := &v1pb.ListMemosRequest{PageSize: benchmarkPageSize}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := bench.Service.ListMemos(bench.publicCtx, req)
if err != nil {
b.Fatalf("ListMemos failed: %v", err)
}
if len(resp.Memos) == 0 {
b.Fatal("expected memos in public benchmark response")
}
}
})
}
func BenchmarkListMemoCommentsPreview(b *testing.B) {
bench := newBenchmarkService(b)
if bench.commentParentName == "" {
b.Fatal("expected seeded memo with comments")
}
req := &v1pb.ListMemoCommentsRequest{
Name: bench.commentParentName,
PageSize: 3,
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resp, err := bench.Service.ListMemoComments(bench.authenticatedCtx, req)
if err != nil {
b.Fatalf("ListMemoComments failed: %v", err)
}
if len(resp.Memos) == 0 {
b.Fatal("expected comments in benchmark response")
}
}
}
......@@ -54,6 +54,22 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
}
where = append(where, fmt.Sprintf("(`memo_id` IN (%s) OR `related_memo_id` IN (%s))", inClause, inClause))
}
if len(find.SourceMemoIDList) > 0 {
placeholders := make([]string, len(find.SourceMemoIDList))
for i, id := range find.SourceMemoIDList {
placeholders[i] = "?"
args = append(args, id)
}
where = append(where, fmt.Sprintf("`memo_id` IN (%s)", strings.Join(placeholders, ", ")))
}
if len(find.RelatedMemoIDList) > 0 {
placeholders := make([]string, len(find.RelatedMemoIDList))
for i, id := range find.RelatedMemoIDList {
placeholders[i] = "?"
args = append(args, id)
}
where = append(where, fmt.Sprintf("`related_memo_id` IN (%s)", strings.Join(placeholders, ", ")))
}
if find.MemoFilter != nil {
engine, err := filter.DefaultEngine()
if err != nil {
......
......@@ -50,18 +50,34 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, *find.Type)
}
if len(find.MemoIDList) > 0 {
memoPlaceholders := make([]string, len(find.MemoIDList))
placeholders := make([]string, len(find.MemoIDList))
for i, id := range find.MemoIDList {
memoPlaceholders[i] = placeholder(len(args) + 1)
placeholders[i] = placeholder(len(args) + 1)
args = append(args, id)
}
inClause := strings.Join(placeholders, ", ")
relatedPlaceholders := make([]string, len(find.MemoIDList))
for i, id := range find.MemoIDList {
relatedPlaceholders[i] = placeholder(len(args) + 1)
args = append(args, id)
}
where = append(where, fmt.Sprintf("(memo_id IN (%s) OR related_memo_id IN (%s))",
strings.Join(memoPlaceholders, ", "), strings.Join(relatedPlaceholders, ", ")))
where = append(where, fmt.Sprintf("(memo_id IN (%s) OR related_memo_id IN (%s))", inClause, strings.Join(relatedPlaceholders, ", ")))
}
if len(find.SourceMemoIDList) > 0 {
placeholders := make([]string, len(find.SourceMemoIDList))
for i, id := range find.SourceMemoIDList {
placeholders[i] = placeholder(len(args) + 1)
args = append(args, id)
}
where = append(where, fmt.Sprintf("memo_id IN (%s)", strings.Join(placeholders, ", ")))
}
if len(find.RelatedMemoIDList) > 0 {
placeholders := make([]string, len(find.RelatedMemoIDList))
for i, id := range find.RelatedMemoIDList {
placeholders[i] = placeholder(len(args) + 1)
args = append(args, id)
}
where = append(where, fmt.Sprintf("related_memo_id IN (%s)", strings.Join(placeholders, ", ")))
}
if find.MemoFilter != nil {
engine, err := filter.DefaultEngine()
......
......@@ -56,12 +56,27 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
args = append(args, id)
}
inClause := strings.Join(placeholders, ", ")
// Duplicate args for the second IN clause.
for _, id := range find.MemoIDList {
args = append(args, id)
}
where = append(where, fmt.Sprintf("(memo_id IN (%s) OR related_memo_id IN (%s))", inClause, inClause))
}
if len(find.SourceMemoIDList) > 0 {
placeholders := make([]string, len(find.SourceMemoIDList))
for i, id := range find.SourceMemoIDList {
placeholders[i] = "?"
args = append(args, id)
}
where = append(where, fmt.Sprintf("memo_id IN (%s)", strings.Join(placeholders, ", ")))
}
if len(find.RelatedMemoIDList) > 0 {
placeholders := make([]string, len(find.RelatedMemoIDList))
for i, id := range find.RelatedMemoIDList {
placeholders[i] = "?"
args = append(args, id)
}
where = append(where, fmt.Sprintf("related_memo_id IN (%s)", strings.Join(placeholders, ", ")))
}
if find.MemoFilter != nil {
engine, err := filter.DefaultEngine()
if err != nil {
......
......@@ -26,6 +26,10 @@ type FindMemoRelation struct {
MemoFilter *string
// MemoIDList matches relations where memo_id OR related_memo_id is in the list.
MemoIDList []int32
// SourceMemoIDList matches relations where memo_id is in the list.
SourceMemoIDList []int32
// RelatedMemoIDList matches relations where related_memo_id is in the list.
RelatedMemoIDList []int32
Limit *int
Offset *int
}
......
......@@ -5,6 +5,7 @@ import type { MemoRelation } from "@/types/proto/api/v1/memo_service_pb";
import { useTranslate } from "@/utils/i18n";
import RelationCard from "./RelationCard";
import { getRelationBuckets, getRelationMemo, getRelationMemoName, type RelationDirection } from "./relationHelpers";
import { useResolvedRelationMemos } from "./useResolvedRelationMemos";
interface RelationListViewProps {
relations: MemoRelation[];
......@@ -16,6 +17,7 @@ interface RelationListViewProps {
function RelationListView({ relations, currentMemoName, parentPage, className }: RelationListViewProps) {
const t = useTranslate();
const [activeTab, setActiveTab] = useState<"referencing" | "referenced">("referencing");
const resolvedMemos = useResolvedRelationMemos(relations);
const { referencing: referencingRelations, referenced: referencedRelations } = useMemo(
() => getRelationBuckets(relations, currentMemoName),
......@@ -60,9 +62,15 @@ function RelationListView({ relations, currentMemoName, parentPage, className }:
}
contentClassName="flex flex-col gap-0 p-1.5"
>
{activeRelations.map((relation) => (
<RelationCard key={getRelationMemoName(relation, direction)} memo={getRelationMemo(relation, direction)!} parentPage={parentPage} />
))}
{activeRelations.map((relation) => {
const memo = getRelationMemo(relation, direction);
if (!memo) {
return null;
}
return (
<RelationCard key={getRelationMemoName(relation, direction)} memo={resolvedMemos[memo.name] ?? memo} parentPage={parentPage} />
);
})}
</MetadataSection>
);
}
......
......@@ -11,9 +11,10 @@ export const useResolvedRelationMemos = (relations: MemoRelation[]) => {
const names = new Set<string>();
for (const relation of relations) {
const relatedMemo = relation.relatedMemo;
if (relatedMemo?.name && !relatedMemo.snippet && !resolvedMemos[relatedMemo.name]) {
names.add(relatedMemo.name);
for (const memo of [relation.memo, relation.relatedMemo]) {
if (memo?.name && !memo.snippet && !resolvedMemos[memo.name]) {
names.add(memo.name);
}
}
}
......
......@@ -10,7 +10,7 @@ const MemoCommentListView: React.FC = () => {
const { memo } = useMemoViewContext();
const { isInMemoDetailPage, commentAmount } = useMemoViewDerived();
const { data } = useMemoComments(memo.name, { enabled: !isInMemoDetailPage && commentAmount > 0 });
const { data } = useMemoComments(memo.name, { enabled: !isInMemoDetailPage && commentAmount > 0, pageSize: 3 });
const comments = data?.memos ?? [];
const displayedComments = comments.slice(0, 3);
const { data: commentCreators } = useUsersByNames(displayedComments.map((comment) => comment.creator));
......
......@@ -5,7 +5,7 @@ import { useInfiniteQuery, useMutation, useQuery, useQueryClient } from "@tansta
import { memoServiceClient } from "@/connect";
import { userKeys } from "@/hooks/useUserQueries";
import type { ListMemosRequest, ListMemosResponse, Memo } from "@/types/proto/api/v1/memo_service_pb";
import { ListMemosRequestSchema, MemoSchema } from "@/types/proto/api/v1/memo_service_pb";
import { ListMemoCommentsRequestSchema, ListMemosRequestSchema, MemoSchema } from "@/types/proto/api/v1/memo_service_pb";
// Query keys factory for consistent cache management
export const memoKeys = {
......@@ -243,11 +243,16 @@ export function useDeleteMemo() {
});
}
export function useMemoComments(name: string, options?: { enabled?: boolean }) {
export function useMemoComments(name: string, options?: { enabled?: boolean; pageSize?: number }) {
return useQuery({
queryKey: memoKeys.comments(name),
queryKey: [...memoKeys.comments(name), options?.pageSize ?? 0],
queryFn: async () => {
const response = await memoServiceClient.listMemoComments({ name });
const response = await memoServiceClient.listMemoComments(
create(ListMemoCommentsRequestSchema, {
name,
pageSize: options?.pageSize ?? 0,
}),
);
return response;
},
enabled: options?.enabled ?? true,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment