Commit c4566376 authored by boojack

fix(api): reduce memory pressure in backend paths

parent 8479e1d5
......@@ -6,6 +6,35 @@ import (
"github.com/pkg/errors"
)
// asyncEmailRequest pairs an SMTP configuration with the message that a
// background worker should deliver via Send.
type asyncEmailRequest struct {
config *Config
message *Message
}
// asyncEmailQueue buffers up to 128 pending email requests. SendAsync drops
// new requests (with a warning) when the buffer is full rather than blocking.
var asyncEmailQueue = make(chan asyncEmailRequest, 128)
func init() {
for range 2 {
go func() {
for request := range asyncEmailQueue {
if err := Send(request.config, request.message); err != nil {
recipients := ""
if request.message != nil && len(request.message.To) > 0 {
recipients = request.message.To[0]
if len(request.message.To) > 1 {
recipients += " and others"
}
}
slog.Warn("Failed to send email asynchronously",
slog.String("recipients", recipients),
slog.Any("error", err))
}
}
}()
}
}
// Send sends an email synchronously.
// Returns an error if the email fails to send.
func Send(config *Config, message *Message) error {
......@@ -21,23 +50,12 @@ func Send(config *Config, message *Message) error {
}
// SendAsync sends an email asynchronously.
// It spawns a new goroutine to handle the sending and does not wait for the response.
// It enqueues the message for bounded asynchronous sending and does not wait for the response.
// Any errors are logged but not returned.
func SendAsync(config *Config, message *Message) {
go func() {
if err := Send(config, message); err != nil {
// Since we're in a goroutine, we can only log the error
recipients := ""
if message != nil && len(message.To) > 0 {
recipients = message.To[0]
if len(message.To) > 1 {
recipients += " and others"
}
}
slog.Warn("Failed to send email asynchronously",
slog.String("recipients", recipients),
slog.Any("error", err))
}
}()
select {
case asyncEmailQueue <- asyncEmailRequest{config: config, message: message}:
default:
slog.Warn("Dropped email because the async queue is full")
}
}
......@@ -28,8 +28,25 @@ var (
DialContext: safeDialContext,
},
}
asyncPostQueue = make(chan *WebhookRequestPayload, 128)
)
func init() {
for range 4 {
go func() {
for payload := range asyncPostQueue {
if err := Post(payload); err != nil {
slog.Warn("Failed to dispatch webhook asynchronously",
slog.String("url", payload.URL),
slog.String("activityType", payload.ActivityType),
slog.Any("err", err))
}
}
}()
}
}
// safeDialContext is a net.Dialer.DialContext replacement that resolves the target
// hostname and rejects any address that falls within a reserved/private IP range.
func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
......@@ -82,7 +99,7 @@ func Post(requestPayload *WebhookRequestPayload) error {
}
defer resp.Body.Close()
b, err := io.ReadAll(resp.Body)
b, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil {
return errors.Wrapf(err, "failed to read webhook response from %s", requestPayload.URL)
}
......@@ -107,14 +124,17 @@ func Post(requestPayload *WebhookRequestPayload) error {
}
// PostAsync posts the message to webhook endpoint asynchronously.
// It spawns a new goroutine to handle the request and does not wait for the response.
// It enqueues the request for bounded asynchronous dispatch and does not wait for the response.
func PostAsync(requestPayload *WebhookRequestPayload) {
go func() {
if err := Post(requestPayload); err != nil {
slog.Warn("Failed to dispatch webhook asynchronously",
slog.String("url", requestPayload.URL),
slog.String("activityType", requestPayload.ActivityType),
slog.Any("err", err))
}
}()
if requestPayload == nil {
slog.Warn("Dropped webhook dispatch because payload is nil")
return
}
select {
case asyncPostQueue <- requestPayload:
default:
slog.Warn("Dropped webhook dispatch because the async queue is full",
slog.String("url", requestPayload.URL),
slog.String("activityType", requestPayload.ActivityType))
}
}
package webhook
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestPostAsyncNilPayloadDoesNotPanic(t *testing.T) {
require.NotPanics(t, func() {
PostAsync(nil)
})
}
......@@ -2,6 +2,8 @@ package v1
import (
"bytes"
"encoding/binary"
"hash/crc32"
"image"
"image/color"
"image/jpeg"
......@@ -189,3 +191,42 @@ func TestStripImageExif(t *testing.T) {
assert.Error(t, err)
})
}
func TestValidateImagePixelCountRejectsOversizedDimensions(t *testing.T) {
t.Parallel()
err := validateImagePixelCount(testPNGHeaderWithDimensions(100_000, 100_000))
require.Error(t, err)
require.Contains(t, err.Error(), "image dimensions exceed maximum")
}
func TestStripImageExifRejectsOversizedDimensionsBeforeDecode(t *testing.T) {
t.Parallel()
_, err := stripImageExif(testPNGHeaderWithDimensions(100_000, 100_000), "image/png")
require.Error(t, err)
require.Contains(t, err.Error(), "image dimensions exceed maximum")
}
// testPNGHeaderWithDimensions builds a minimal, well-formed PNG byte stream
// (signature + IHDR + IEND) whose header declares the given width and height.
// Only the header needs to be valid: callers use it to exercise checks built
// on image.DecodeConfig, which reads the header and never touches pixel data.
func testPNGHeaderWithDimensions(width, height uint32) []byte {
	var out bytes.Buffer
	// Fixed 8-byte PNG signature.
	out.Write([]byte{0x89, 'P', 'N', 'G', '\r', '\n', 0x1a, '\n'})
	// 13-byte IHDR payload: width, height, bit depth 8, color type 2
	// (truecolor); compression/filter/interlace bytes stay zero.
	header := make([]byte, 13)
	binary.BigEndian.PutUint32(header[:4], width)
	binary.BigEndian.PutUint32(header[4:8], height)
	header[8] = 8
	header[9] = 2
	writePNGChunk(&out, "IHDR", header)
	writePNGChunk(&out, "IEND", nil)
	return out.Bytes()
}

// writePNGChunk appends one PNG chunk — big-endian length, 4-byte type, data,
// then the CRC-32 of type+data — in the layout required by the PNG spec.
func writePNGChunk(buf *bytes.Buffer, chunkType string, data []byte) {
	_ = binary.Write(buf, binary.BigEndian, uint32(len(data)))
	buf.WriteString(chunkType)
	buf.Write(data)
	sum := crc32.NewIEEE()
	sum.Write([]byte(chunkType))
	sum.Write(data)
	_ = binary.Write(buf, binary.BigEndian, sum.Sum32())
}
......@@ -5,6 +5,7 @@ import (
"context"
"encoding/binary"
"fmt"
"image"
"io"
"log/slog"
"mime"
......@@ -45,6 +46,7 @@ const (
// Quality 95 maintains visual quality while ensuring metadata is removed.
defaultJPEGQuality = 95
maxBatchDeleteAttachments = 100
maxImagePixels = 50_000_000
)
var SupportedThumbnailMimeTypes = []string{
......@@ -148,12 +150,18 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
// Strip EXIF metadata from images for privacy protection.
// This removes sensitive information like GPS location, device details, etc.
if shouldStripExif(create.Type) && !isAndroidMotionContainer(create.Payload.GetMotionMedia()) {
if strippedBlob, err := stripImageExif(create.Blob, create.Type); err != nil {
release, err := s.acquireImageProcessingSlot(ctx)
if err != nil {
return nil, status.Errorf(codes.ResourceExhausted, "too many image processing requests")
}
strippedBlob, stripErr := stripImageExif(create.Blob, create.Type)
release()
if stripErr != nil {
// Log warning but continue with original image to ensure uploads don't fail.
slog.Warn("failed to strip EXIF metadata from image",
slog.String("type", create.Type),
slog.String("filename", create.Filename),
slog.String("error", err.Error()))
slog.String("error", stripErr.Error()))
} else {
create.Blob = strippedBlob
create.Size = int64(len(strippedBlob))
......@@ -745,6 +753,32 @@ func shouldStripExif(mimeType string) bool {
return exifCapableImageTypes[mimeType]
}
// acquireImageProcessingSlot reserves one unit of the image-processing
// semaphore, blocking until a slot is free or ctx is done, and returns a
// release callback that must be invoked exactly once when processing
// finishes. When the semaphore is nil, throttling is disabled: acquisition
// always succeeds and the returned callback is a no-op.
func (s *APIV1Service) acquireImageProcessingSlot(ctx context.Context) (func(), error) {
	if s.imageProcessingSemaphore == nil {
		return func() {}, nil
	}
	if err := s.imageProcessingSemaphore.Acquire(ctx, 1); err != nil {
		// Typically ctx cancellation/deadline; no slot was taken.
		return nil, err
	}
	release := func() {
		s.imageProcessingSemaphore.Release(1)
	}
	return release, nil
}
// validateImagePixelCount rejects images whose declared width*height exceeds
// maxImagePixels, guarding downstream decoding against decompression bombs.
// Only the header is inspected via image.DecodeConfig; pixel data is never
// decoded here. An unparseable header is deliberately allowed through so the
// caller's own decode step can report the real error.
func validateImagePixelCount(imageData []byte) error {
	config, _, decodeErr := image.DecodeConfig(bytes.NewReader(imageData))
	if decodeErr != nil {
		// Best-effort check: defer malformed-image errors to the full decode.
		return nil
	}
	if config.Width <= 0 || config.Height <= 0 {
		return errors.New("invalid image dimensions")
	}
	// Compare via division so Width*Height cannot overflow int.
	if config.Width > maxImagePixels/config.Height {
		return errors.Errorf("image dimensions exceed maximum of %d pixels", maxImagePixels)
	}
	return nil
}
// stripImageExif removes EXIF metadata from image files by decoding and re-encoding them.
// This prevents exposure of sensitive metadata such as GPS location, camera details, and timestamps.
//
......@@ -759,6 +793,10 @@ func shouldStripExif(mimeType string) bool {
//
// Returns the cleaned image data without any EXIF metadata, or an error if processing fails.
func stripImageExif(imageData []byte, mimeType string) ([]byte, error) {
if err := validateImagePixelCount(imageData); err != nil {
return nil, err
}
// Decode image with automatic EXIF orientation correction.
// This ensures the image displays correctly after metadata removal.
img, err := imaging.Decode(bytes.NewReader(imageData), imaging.AutoOrientation(true))
......
......@@ -44,6 +44,17 @@ func getPageToken(limit int, offset int) (string, error) {
})
}
// normalizePageSize converts a client-supplied page size into a usable limit:
// non-positive values fall back to DefaultPageSize, and anything above
// MaxPageSize is clamped to MaxPageSize.
func normalizePageSize(pageSize int32) int {
	switch limit := int(pageSize); {
	case limit <= 0:
		return DefaultPageSize
	case limit > MaxPageSize:
		return MaxPageSize
	default:
		return limit
	}
}
func marshalPageToken(pageToken *v1pb.PageToken) (string, error) {
b, err := proto.Marshal(pageToken)
if err != nil {
......
package v1
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestNormalizePageSize covers the three normalization branches: fallback for
// non-positive sizes, pass-through for valid sizes, and clamping at
// MaxPageSize.
func TestNormalizePageSize(t *testing.T) {
	t.Parallel()
	cases := []struct {
		name     string
		pageSize int32
		want     int
	}{
		{name: "default for zero", pageSize: 0, want: DefaultPageSize},
		{name: "default for negative", pageSize: -1, want: DefaultPageSize},
		{name: "preserves valid size", pageSize: 42, want: 42},
		{name: "clamps oversized size", pageSize: int32(MaxPageSize + 1), want: MaxPageSize},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			require.Equal(t, tc.want, normalizePageSize(tc.pageSize))
		})
	}
}
......@@ -232,14 +232,15 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
}
limit = int(pageToken.Limit)
limit = normalizePageSize(pageToken.Limit)
offset = int(pageToken.Offset)
if offset < 0 {
offset = 0
}
} else {
limit = int(request.PageSize)
}
if limit <= 0 {
limit = DefaultPageSize
limit = normalizePageSize(request.PageSize)
}
limit = min(limit, MaxPageSize)
limitPlusOne := limit + 1
memoFind.Limit = &limitPlusOne
memoFind.Offset = &offset
......@@ -715,18 +716,45 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
}
memoRelationComment := store.MemoRelationComment
var limit, offset int
if request.PageToken != "" {
var pageToken v1pb.PageToken
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
}
limit = normalizePageSize(pageToken.Limit)
offset = int(pageToken.Offset)
if offset < 0 {
offset = 0
}
} else {
limit = normalizePageSize(request.PageSize)
}
limitPlusOne := limit + 1
memoRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &memo.ID,
Type: &memoRelationComment,
MemoFilter: &memoFilter,
Limit: &limitPlusOne,
Offset: &offset,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo relations")
}
nextPageToken := ""
if len(memoRelations) == limitPlusOne {
memoRelations = memoRelations[:limit]
nextPageToken, err = getPageToken(limit, offset+limit)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
}
}
if len(memoRelations) == 0 {
response := &v1pb.ListMemoCommentsResponse{
Memos: []*v1pb.Memo{},
Memos: []*v1pb.Memo{},
NextPageToken: nextPageToken,
}
return response, nil
}
......@@ -807,7 +835,8 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
}
response := &v1pb.ListMemoCommentsResponse{
Memos: memosResponse,
Memos: memosResponse,
NextPageToken: nextPageToken,
}
return response, nil
}
......
......@@ -373,6 +373,46 @@ func TestListMemoCommentsSkipsCommentsWithMissingCreators(t *testing.T) {
require.Empty(t, resp.Memos)
}
// TestListMemoCommentsPaginates verifies that ListMemoComments honors
// PageSize, returns a next-page token when more comments remain, and that the
// returned token fetches exactly the remaining comments with no further token.
func TestListMemoCommentsPaginates(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
owner, err := ts.CreateRegularUser(ctx, "comment-page-owner")
require.NoError(t, err)
ownerCtx := ts.CreateUserContext(ctx, owner.ID)
// Create one parent memo to hang the comments off.
memo, err := ts.Service.CreateMemo(ownerCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "memo with paged comments",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
// Three comments against a page size of two forces exactly two pages.
for i := 0; i < 3; i++ {
_, err = ts.Service.CreateMemoComment(ownerCtx, &apiv1.CreateMemoCommentRequest{
Name: memo.Name,
Comment: &apiv1.Memo{
Content: fmt.Sprintf("comment %d", i),
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
}
// First page: full page of two plus a token for the remainder.
firstPage, err := ts.Service.ListMemoComments(ownerCtx, &apiv1.ListMemoCommentsRequest{Name: memo.Name, PageSize: 2})
require.NoError(t, err)
require.Len(t, firstPage.Memos, 2)
require.NotEmpty(t, firstPage.NextPageToken)
// Second page: the single remaining comment and no further token.
secondPage, err := ts.Service.ListMemoComments(ownerCtx, &apiv1.ListMemoCommentsRequest{Name: memo.Name, PageToken: firstPage.NextPageToken})
require.NoError(t, err)
require.Len(t, secondPage.Memos, 1)
require.Empty(t, secondPage.NextPageToken)
}
// TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments.
// This addresses issue #5483: https://github.com/usememos/memos/issues/5483
func TestCreateMemoWithCustomTimestamps(t *testing.T) {
......
......@@ -17,6 +17,8 @@ import (
"github.com/usememos/memos/store"
)
// maxAPIRequestBytes caps the size of a single API request body (256 MiB);
// it is applied to both the gRPC-gateway and Connect handlers.
const maxAPIRequestBytes = 256 << 20
type APIV1Service struct {
v1pb.UnimplementedInstanceServiceServer
v1pb.UnimplementedAuthServiceServer
......@@ -34,7 +36,8 @@ type APIV1Service struct {
SSEHub *SSEHub
// thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion
thumbnailSemaphore *semaphore.Weighted
thumbnailSemaphore *semaphore.Weighted
imageProcessingSemaphore *semaphore.Weighted
}
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store) *APIV1Service {
......@@ -43,12 +46,13 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store
markdown.WithMentionExtension(),
)
return &APIV1Service{
Secret: secret,
Profile: profile,
Store: store,
MarkdownService: markdownService,
SSEHub: NewSSEHub(),
thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations
Secret: secret,
Profile: profile,
Store: store,
MarkdownService: markdownService,
SSEHub: NewSSEHub(),
thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations
imageProcessingSemaphore: semaphore.NewWeighted(2),
}
}
......@@ -120,7 +124,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
}))
// Register SSE endpoint with same CORS as rest of /api/v1.
RegisterSSERoutes(gwGroup, s.SSEHub, s.Store, s.Secret)
handler := echo.WrapHandler(gwMux)
handler := echo.WrapHandler(http.MaxBytesHandler(gwMux, maxAPIRequestBytes))
gwGroup.Any("/api/v1/*", handler)
gwGroup.Any("/file/*", handler)
......@@ -135,7 +139,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
)
connectMux := http.NewServeMux()
connectHandler := NewConnectServiceHandler(s)
connectHandler.RegisterConnectHandlers(connectMux, connectInterceptors)
connectHandler.RegisterConnectHandlers(connectMux, connectInterceptors, connect.WithReadMaxBytes(maxAPIRequestBytes))
// Wrap with CORS for browser access
corsHandler := middleware.CORSWithConfig(middleware.CORSConfig{
......@@ -147,7 +151,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
AllowCredentials: true,
})
connectGroup := echoServer.Group("", corsHandler)
connectGroup.Any("/memos.api.v1.*", echo.WrapHandler(connectMux))
connectGroup.Any("/memos.api.v1.*", echo.WrapHandler(http.MaxBytesHandler(connectMux, maxAPIRequestBytes)))
return nil
}
......@@ -223,18 +223,14 @@ func (s *FileServerService) serveMediaStream(c *echo.Context, attachment *store.
// serveStaticFile serves non-streaming files (images, documents, etc.).
func (s *FileServerService) serveStaticFile(c *echo.Context, attachment *store.Attachment, contentType string, wantThumbnail bool) error {
blob, err := s.getAttachmentBlob(attachment)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment blob").Wrap(err)
}
// Generate thumbnail for supported image types.
if wantThumbnail && thumbnailSupportedTypes[attachment.Type] {
if thumbnailBlob, err := s.getOrGenerateThumbnail(c.Request().Context(), attachment); err != nil {
slog.Warn("failed to get thumbnail", "error", err)
} else {
blob = thumbnailBlob
contentType = "image/jpeg"
setSecurityHeaders(c)
setMediaHeaders(c, "image/jpeg", attachment.Type)
return c.Blob(http.StatusOK, "image/jpeg", thumbnailBlob)
}
}
......@@ -246,7 +242,24 @@ func (s *FileServerService) serveStaticFile(c *echo.Context, attachment *store.A
c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("attachment; filename=%q", attachment.Filename))
}
return c.Blob(http.StatusOK, contentType, blob)
switch attachment.StorageType {
case storepb.AttachmentStorageType_LOCAL:
filePath, err := s.resolveLocalPath(attachment.Reference)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to resolve file path").Wrap(err)
}
http.ServeFile(c.Response(), c.Request(), filePath)
return nil
case storepb.AttachmentStorageType_S3:
reader, err := s.getAttachmentReader(c.Request().Context(), attachment)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment reader").Wrap(err)
}
defer reader.Close()
return c.Stream(http.StatusOK, contentType, reader)
default:
return c.Blob(http.StatusOK, contentType, attachment.Blob)
}
}
// =============================================================================
......@@ -260,7 +273,7 @@ func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]b
return s.readLocalFile(attachment.Reference)
case storepb.AttachmentStorageType_S3:
return s.downloadFromS3(attachment)
return s.downloadFromS3(context.Background(), attachment)
default:
return attachment.Blob, nil
......@@ -268,7 +281,7 @@ func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]b
}
// getAttachmentReader returns a reader for streaming attachment content.
func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (io.ReadCloser, error) {
func (s *FileServerService) getAttachmentReader(ctx context.Context, attachment *store.Attachment) (io.ReadCloser, error) {
switch attachment.StorageType {
case storepb.AttachmentStorageType_LOCAL:
filePath, err := s.resolveLocalPath(attachment.Reference)
......@@ -289,7 +302,7 @@ func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (i
if err != nil {
return nil, err
}
reader, err := s3Client.GetObjectStream(context.Background(), s3Object.Key)
reader, err := s3Client.GetObjectStream(ctx, s3Object.Key)
if err != nil {
return nil, errors.Wrap(err, "failed to stream from S3")
}
......@@ -356,13 +369,13 @@ func (*FileServerService) createS3Client(attachment *store.Attachment) (*s3.Clie
}
// downloadFromS3 downloads the entire object from S3.
func (s *FileServerService) downloadFromS3(attachment *store.Attachment) ([]byte, error) {
func (s *FileServerService) downloadFromS3(ctx context.Context, attachment *store.Attachment) ([]byte, error) {
client, s3Object, err := s.createS3Client(attachment)
if err != nil {
return nil, err
}
blob, err := client.GetObject(context.Background(), s3Object.Key)
blob, err := client.GetObject(ctx, s3Object.Key)
if err != nil {
return nil, errors.Wrap(err, "failed to download from S3")
}
......@@ -411,7 +424,7 @@ func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachme
return blob, nil
}
return s.generateThumbnail(attachment, thumbnailPath)
return s.generateThumbnail(ctx, attachment, thumbnailPath)
}
// getThumbnailPath returns the file path for a cached thumbnail.
......@@ -435,8 +448,8 @@ func (*FileServerService) readCachedThumbnail(path string) ([]byte, error) {
}
// generateThumbnail creates a new thumbnail and saves it to disk.
func (s *FileServerService) generateThumbnail(attachment *store.Attachment, thumbnailPath string) ([]byte, error) {
reader, err := s.getAttachmentReader(attachment)
func (s *FileServerService) generateThumbnail(ctx context.Context, attachment *store.Attachment, thumbnailPath string) ([]byte, error) {
reader, err := s.getAttachmentReader(ctx, attachment)
if err != nil {
return nil, errors.Wrap(err, "failed to get attachment reader")
}
......
......@@ -73,6 +73,52 @@ func TestServeAttachmentFile_ShareTokenAllowsDirectMemoAttachment(t *testing.T)
require.Equal(t, "memo attachment", rec.Body.String())
}
// TestServeAttachmentFile_LocalStaticFileSupportsRangeRequests verifies that
// a locally stored attachment served through the file routes honors HTTP
// Range requests: a bytes=2-5 request must yield 206 Partial Content with the
// matching body slice and Content-Range header.
func TestServeAttachmentFile_LocalStaticFileSupportsRangeRequests(t *testing.T) {
ctx := context.Background()
svc, fs, _, cleanup := newShareAttachmentTestServices(ctx, t)
defer cleanup()
creator, err := svc.Store.CreateUser(ctx, &store.User{
Username: "range-owner",
Role: store.RoleUser,
Email: "range-owner@example.com",
})
require.NoError(t, err)
creatorCtx := context.WithValue(ctx, auth.UserIDContextKey, creator.ID)
// Ten known bytes so the expected range slice is easy to assert.
attachment, err := svc.CreateAttachment(creatorCtx, &apiv1.CreateAttachmentRequest{
Attachment: &apiv1.Attachment{
Filename: "range.txt",
Type: "text/plain",
Content: []byte("0123456789"),
},
})
require.NoError(t, err)
// Attach it to a public memo so the file route permits access.
_, err = svc.CreateMemo(creatorCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "range memo",
Visibility: apiv1.Visibility_PUBLIC,
Attachments: []*apiv1.Attachment{
{Name: attachment.Name},
},
},
})
require.NoError(t, err)
e := echo.New()
fs.RegisterRoutes(e)
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/file/%s/%s", attachment.Name, attachment.Filename), nil)
req.Header.Set("Range", "bytes=2-5")
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
// Inclusive range 2-5 of "0123456789" is "2345".
require.Equal(t, http.StatusPartialContent, rec.Code)
require.Equal(t, "2345", rec.Body.String())
require.Equal(t, "bytes 2-5/10", rec.Header().Get("Content-Range"))
}
func TestServeAttachmentFile_ShareTokenRejectsCommentAttachment(t *testing.T) {
ctx := context.Background()
svc, fs, _, cleanup := newShareAttachmentTestServices(ctx, t)
......
......@@ -34,13 +34,13 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation)
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
where, args := []string{"TRUE"}, []any{}
if find.MemoID != nil {
where, args = append(where, "`memo_id` = ?"), append(args, find.MemoID)
where, args = append(where, "`memo_id` = ?"), append(args, *find.MemoID)
}
if find.RelatedMemoID != nil {
where, args = append(where, "`related_memo_id` = ?"), append(args, find.RelatedMemoID)
where, args = append(where, "`related_memo_id` = ?"), append(args, *find.RelatedMemoID)
}
if find.Type != nil {
where, args = append(where, "`type` = ?"), append(args, find.Type)
where, args = append(where, "`type` = ?"), append(args, *find.Type)
}
if len(find.MemoIDList) > 0 {
placeholders := make([]string, len(find.MemoIDList))
......@@ -73,7 +73,15 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
}
}
rows, err := d.db.QueryContext(ctx, "SELECT `memo_id`, `related_memo_id`, `type` FROM `memo_relation` WHERE "+strings.Join(where, " AND "), args...)
query := "SELECT `memo_id`, `related_memo_id`, `type` FROM `memo_relation` WHERE " + strings.Join(where, " AND ") + " ORDER BY `memo_id` DESC"
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
}
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
......@@ -102,13 +110,13 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
where, args := []string{"TRUE"}, []any{}
if delete.MemoID != nil {
where, args = append(where, "`memo_id` = ?"), append(args, delete.MemoID)
where, args = append(where, "`memo_id` = ?"), append(args, *delete.MemoID)
}
if delete.RelatedMemoID != nil {
where, args = append(where, "`related_memo_id` = ?"), append(args, delete.RelatedMemoID)
where, args = append(where, "`related_memo_id` = ?"), append(args, *delete.RelatedMemoID)
}
if delete.Type != nil {
where, args = append(where, "`type` = ?"), append(args, delete.Type)
where, args = append(where, "`type` = ?"), append(args, *delete.Type)
}
stmt := "DELETE FROM `memo_relation` WHERE " + strings.Join(where, " AND ")
result, err := d.db.ExecContext(ctx, stmt, args...)
......
......@@ -41,13 +41,13 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation)
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
where, args := []string{"1 = 1"}, []any{}
if find.MemoID != nil {
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID)
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *find.MemoID)
}
if find.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, find.RelatedMemoID)
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, *find.RelatedMemoID)
}
if find.Type != nil {
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type)
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, *find.Type)
}
if len(find.MemoIDList) > 0 {
memoPlaceholders := make([]string, len(find.MemoIDList))
......@@ -93,13 +93,22 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
}
}
rows, err := d.db.QueryContext(ctx, `
query := `
SELECT
memo_id,
related_memo_id,
type
FROM memo_relation
WHERE `+strings.Join(where, " AND "), args...)
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY memo_id DESC`
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
}
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
......@@ -128,13 +137,13 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
where, args := []string{"1 = 1"}, []any{}
if delete.MemoID != nil {
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, delete.MemoID)
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *delete.MemoID)
}
if delete.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, delete.RelatedMemoID)
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, *delete.RelatedMemoID)
}
if delete.Type != nil {
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, delete.Type)
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, *delete.Type)
}
stmt := `DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ")
result, err := d.db.ExecContext(ctx, stmt, args...)
......
......@@ -41,13 +41,13 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation)
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
where, args := []string{"TRUE"}, []any{}
if find.MemoID != nil {
where, args = append(where, "memo_id = ?"), append(args, find.MemoID)
where, args = append(where, "memo_id = ?"), append(args, *find.MemoID)
}
if find.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID)
where, args = append(where, "related_memo_id = ?"), append(args, *find.RelatedMemoID)
}
if find.Type != nil {
where, args = append(where, "type = ?"), append(args, find.Type)
where, args = append(where, "type = ?"), append(args, *find.Type)
}
if len(find.MemoIDList) > 0 {
placeholders := make([]string, len(find.MemoIDList))
......@@ -78,13 +78,22 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
}
}
rows, err := d.db.QueryContext(ctx, `
query := `
SELECT
memo_id,
related_memo_id,
type
FROM memo_relation
WHERE `+strings.Join(where, " AND "), args...)
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY memo_id DESC`
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
}
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
......@@ -113,13 +122,13 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
where, args := []string{"TRUE"}, []any{}
if delete.MemoID != nil {
where, args = append(where, "memo_id = ?"), append(args, delete.MemoID)
where, args = append(where, "memo_id = ?"), append(args, *delete.MemoID)
}
if delete.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID)
where, args = append(where, "related_memo_id = ?"), append(args, *delete.RelatedMemoID)
}
if delete.Type != nil {
where, args = append(where, "type = ?"), append(args, delete.Type)
where, args = append(where, "type = ?"), append(args, *delete.Type)
}
stmt := `
DELETE FROM memo_relation
......
......@@ -26,6 +26,8 @@ type FindMemoRelation struct {
MemoFilter *string
// MemoIDList matches relations where memo_id OR related_memo_id is in the list.
MemoIDList []int32
Limit *int
Offset *int
}
type DeleteMemoRelation struct {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment