Commit c4566376 authored by boojack's avatar boojack

fix(api): reduce memory pressure in backend paths

parent 8479e1d5
...@@ -6,6 +6,35 @@ import ( ...@@ -6,6 +6,35 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type asyncEmailRequest struct {
config *Config
message *Message
}
var asyncEmailQueue = make(chan asyncEmailRequest, 128)
func init() {
for range 2 {
go func() {
for request := range asyncEmailQueue {
if err := Send(request.config, request.message); err != nil {
recipients := ""
if request.message != nil && len(request.message.To) > 0 {
recipients = request.message.To[0]
if len(request.message.To) > 1 {
recipients += " and others"
}
}
slog.Warn("Failed to send email asynchronously",
slog.String("recipients", recipients),
slog.Any("error", err))
}
}
}()
}
}
// Send sends an email synchronously. // Send sends an email synchronously.
// Returns an error if the email fails to send. // Returns an error if the email fails to send.
func Send(config *Config, message *Message) error { func Send(config *Config, message *Message) error {
...@@ -21,23 +50,12 @@ func Send(config *Config, message *Message) error { ...@@ -21,23 +50,12 @@ func Send(config *Config, message *Message) error {
} }
// SendAsync sends an email asynchronously. // SendAsync sends an email asynchronously.
// It spawns a new goroutine to handle the sending and does not wait for the response. // It enqueues the message for bounded asynchronous sending and does not wait for the response.
// Any errors are logged but not returned. // Any errors are logged but not returned.
func SendAsync(config *Config, message *Message) { func SendAsync(config *Config, message *Message) {
go func() { select {
if err := Send(config, message); err != nil { case asyncEmailQueue <- asyncEmailRequest{config: config, message: message}:
// Since we're in a goroutine, we can only log the error default:
recipients := "" slog.Warn("Dropped email because the async queue is full")
if message != nil && len(message.To) > 0 {
recipients = message.To[0]
if len(message.To) > 1 {
recipients += " and others"
} }
}
slog.Warn("Failed to send email asynchronously",
slog.String("recipients", recipients),
slog.Any("error", err))
}
}()
} }
...@@ -28,8 +28,25 @@ var ( ...@@ -28,8 +28,25 @@ var (
DialContext: safeDialContext, DialContext: safeDialContext,
}, },
} }
asyncPostQueue = make(chan *WebhookRequestPayload, 128)
) )
func init() {
for range 4 {
go func() {
for payload := range asyncPostQueue {
if err := Post(payload); err != nil {
slog.Warn("Failed to dispatch webhook asynchronously",
slog.String("url", payload.URL),
slog.String("activityType", payload.ActivityType),
slog.Any("err", err))
}
}
}()
}
}
// safeDialContext is a net.Dialer.DialContext replacement that resolves the target // safeDialContext is a net.Dialer.DialContext replacement that resolves the target
// hostname and rejects any address that falls within a reserved/private IP range. // hostname and rejects any address that falls within a reserved/private IP range.
func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) { func safeDialContext(ctx context.Context, network, addr string) (net.Conn, error) {
...@@ -82,7 +99,7 @@ func Post(requestPayload *WebhookRequestPayload) error { ...@@ -82,7 +99,7 @@ func Post(requestPayload *WebhookRequestPayload) error {
} }
defer resp.Body.Close() defer resp.Body.Close()
b, err := io.ReadAll(resp.Body) b, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to read webhook response from %s", requestPayload.URL) return errors.Wrapf(err, "failed to read webhook response from %s", requestPayload.URL)
} }
...@@ -107,14 +124,17 @@ func Post(requestPayload *WebhookRequestPayload) error { ...@@ -107,14 +124,17 @@ func Post(requestPayload *WebhookRequestPayload) error {
} }
// PostAsync posts the message to webhook endpoint asynchronously. // PostAsync posts the message to webhook endpoint asynchronously.
// It spawns a new goroutine to handle the request and does not wait for the response. // It enqueues the request for bounded asynchronous dispatch and does not wait for the response.
func PostAsync(requestPayload *WebhookRequestPayload) { func PostAsync(requestPayload *WebhookRequestPayload) {
go func() { if requestPayload == nil {
if err := Post(requestPayload); err != nil { slog.Warn("Dropped webhook dispatch because payload is nil")
slog.Warn("Failed to dispatch webhook asynchronously", return
}
select {
case asyncPostQueue <- requestPayload:
default:
slog.Warn("Dropped webhook dispatch because the async queue is full",
slog.String("url", requestPayload.URL), slog.String("url", requestPayload.URL),
slog.String("activityType", requestPayload.ActivityType), slog.String("activityType", requestPayload.ActivityType))
slog.Any("err", err))
} }
}()
} }
package webhook package webhook
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestPostAsyncNilPayloadDoesNotPanic(t *testing.T) {
require.NotPanics(t, func() {
PostAsync(nil)
})
}
...@@ -2,6 +2,8 @@ package v1 ...@@ -2,6 +2,8 @@ package v1
import ( import (
"bytes" "bytes"
"encoding/binary"
"hash/crc32"
"image" "image"
"image/color" "image/color"
"image/jpeg" "image/jpeg"
...@@ -189,3 +191,42 @@ func TestStripImageExif(t *testing.T) { ...@@ -189,3 +191,42 @@ func TestStripImageExif(t *testing.T) {
assert.Error(t, err) assert.Error(t, err)
}) })
} }
func TestValidateImagePixelCountRejectsOversizedDimensions(t *testing.T) {
t.Parallel()
err := validateImagePixelCount(testPNGHeaderWithDimensions(100_000, 100_000))
require.Error(t, err)
require.Contains(t, err.Error(), "image dimensions exceed maximum")
}
func TestStripImageExifRejectsOversizedDimensionsBeforeDecode(t *testing.T) {
t.Parallel()
_, err := stripImageExif(testPNGHeaderWithDimensions(100_000, 100_000), "image/png")
require.Error(t, err)
require.Contains(t, err.Error(), "image dimensions exceed maximum")
}
func testPNGHeaderWithDimensions(width, height uint32) []byte {
var buf bytes.Buffer
buf.Write([]byte{0x89, 'P', 'N', 'G', '\r', '\n', 0x1a, '\n'})
ihdr := make([]byte, 13)
binary.BigEndian.PutUint32(ihdr[0:4], width)
binary.BigEndian.PutUint32(ihdr[4:8], height)
ihdr[8] = 8
ihdr[9] = 2
writePNGChunk(&buf, "IHDR", ihdr)
writePNGChunk(&buf, "IEND", nil)
return buf.Bytes()
}
func writePNGChunk(buf *bytes.Buffer, chunkType string, data []byte) {
_ = binary.Write(buf, binary.BigEndian, uint32(len(data)))
buf.WriteString(chunkType)
buf.Write(data)
crc := crc32.ChecksumIEEE(append([]byte(chunkType), data...))
_ = binary.Write(buf, binary.BigEndian, crc)
}
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"context" "context"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"image"
"io" "io"
"log/slog" "log/slog"
"mime" "mime"
...@@ -45,6 +46,7 @@ const ( ...@@ -45,6 +46,7 @@ const (
// Quality 95 maintains visual quality while ensuring metadata is removed. // Quality 95 maintains visual quality while ensuring metadata is removed.
defaultJPEGQuality = 95 defaultJPEGQuality = 95
maxBatchDeleteAttachments = 100 maxBatchDeleteAttachments = 100
maxImagePixels = 50_000_000
) )
var SupportedThumbnailMimeTypes = []string{ var SupportedThumbnailMimeTypes = []string{
...@@ -148,12 +150,18 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat ...@@ -148,12 +150,18 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
// Strip EXIF metadata from images for privacy protection. // Strip EXIF metadata from images for privacy protection.
// This removes sensitive information like GPS location, device details, etc. // This removes sensitive information like GPS location, device details, etc.
if shouldStripExif(create.Type) && !isAndroidMotionContainer(create.Payload.GetMotionMedia()) { if shouldStripExif(create.Type) && !isAndroidMotionContainer(create.Payload.GetMotionMedia()) {
if strippedBlob, err := stripImageExif(create.Blob, create.Type); err != nil { release, err := s.acquireImageProcessingSlot(ctx)
if err != nil {
return nil, status.Errorf(codes.ResourceExhausted, "too many image processing requests")
}
strippedBlob, stripErr := stripImageExif(create.Blob, create.Type)
release()
if stripErr != nil {
// Log warning but continue with original image to ensure uploads don't fail. // Log warning but continue with original image to ensure uploads don't fail.
slog.Warn("failed to strip EXIF metadata from image", slog.Warn("failed to strip EXIF metadata from image",
slog.String("type", create.Type), slog.String("type", create.Type),
slog.String("filename", create.Filename), slog.String("filename", create.Filename),
slog.String("error", err.Error())) slog.String("error", stripErr.Error()))
} else { } else {
create.Blob = strippedBlob create.Blob = strippedBlob
create.Size = int64(len(strippedBlob)) create.Size = int64(len(strippedBlob))
...@@ -745,6 +753,32 @@ func shouldStripExif(mimeType string) bool { ...@@ -745,6 +753,32 @@ func shouldStripExif(mimeType string) bool {
return exifCapableImageTypes[mimeType] return exifCapableImageTypes[mimeType]
} }
func (s *APIV1Service) acquireImageProcessingSlot(ctx context.Context) (func(), error) {
if s.imageProcessingSemaphore == nil {
return func() {}, nil
}
if err := s.imageProcessingSemaphore.Acquire(ctx, 1); err != nil {
return nil, err
}
return func() {
s.imageProcessingSemaphore.Release(1)
}, nil
}
func validateImagePixelCount(imageData []byte) error {
config, _, err := image.DecodeConfig(bytes.NewReader(imageData))
if err != nil {
return nil
}
if config.Width <= 0 || config.Height <= 0 {
return errors.New("invalid image dimensions")
}
if config.Width > maxImagePixels/config.Height {
return errors.Errorf("image dimensions exceed maximum of %d pixels", maxImagePixels)
}
return nil
}
// stripImageExif removes EXIF metadata from image files by decoding and re-encoding them. // stripImageExif removes EXIF metadata from image files by decoding and re-encoding them.
// This prevents exposure of sensitive metadata such as GPS location, camera details, and timestamps. // This prevents exposure of sensitive metadata such as GPS location, camera details, and timestamps.
// //
...@@ -759,6 +793,10 @@ func shouldStripExif(mimeType string) bool { ...@@ -759,6 +793,10 @@ func shouldStripExif(mimeType string) bool {
// //
// Returns the cleaned image data without any EXIF metadata, or an error if processing fails. // Returns the cleaned image data without any EXIF metadata, or an error if processing fails.
func stripImageExif(imageData []byte, mimeType string) ([]byte, error) { func stripImageExif(imageData []byte, mimeType string) ([]byte, error) {
if err := validateImagePixelCount(imageData); err != nil {
return nil, err
}
// Decode image with automatic EXIF orientation correction. // Decode image with automatic EXIF orientation correction.
// This ensures the image displays correctly after metadata removal. // This ensures the image displays correctly after metadata removal.
img, err := imaging.Decode(bytes.NewReader(imageData), imaging.AutoOrientation(true)) img, err := imaging.Decode(bytes.NewReader(imageData), imaging.AutoOrientation(true))
......
...@@ -44,6 +44,17 @@ func getPageToken(limit int, offset int) (string, error) { ...@@ -44,6 +44,17 @@ func getPageToken(limit int, offset int) (string, error) {
}) })
} }
func normalizePageSize(pageSize int32) int {
limit := int(pageSize)
if limit <= 0 {
return DefaultPageSize
}
if limit > MaxPageSize {
return MaxPageSize
}
return limit
}
func marshalPageToken(pageToken *v1pb.PageToken) (string, error) { func marshalPageToken(pageToken *v1pb.PageToken) (string, error) {
b, err := proto.Marshal(pageToken) b, err := proto.Marshal(pageToken)
if err != nil { if err != nil {
......
package v1
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestNormalizePageSize(t *testing.T) {
t.Parallel()
tests := []struct {
name string
pageSize int32
want int
}{
{
name: "default for zero",
pageSize: 0,
want: DefaultPageSize,
},
{
name: "default for negative",
pageSize: -1,
want: DefaultPageSize,
},
{
name: "preserves valid size",
pageSize: 42,
want: 42,
},
{
name: "clamps oversized size",
pageSize: int32(MaxPageSize + 1),
want: MaxPageSize,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tt.want, normalizePageSize(tt.pageSize))
})
}
}
...@@ -232,14 +232,15 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq ...@@ -232,14 +232,15 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil { if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err) return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
} }
limit = int(pageToken.Limit) limit = normalizePageSize(pageToken.Limit)
offset = int(pageToken.Offset) offset = int(pageToken.Offset)
} else { if offset < 0 {
limit = int(request.PageSize) offset = 0
} }
if limit <= 0 { } else {
limit = DefaultPageSize limit = normalizePageSize(request.PageSize)
} }
limit = min(limit, MaxPageSize)
limitPlusOne := limit + 1 limitPlusOne := limit + 1
memoFind.Limit = &limitPlusOne memoFind.Limit = &limitPlusOne
memoFind.Offset = &offset memoFind.Offset = &offset
...@@ -715,18 +716,45 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM ...@@ -715,18 +716,45 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID) memoFilter = fmt.Sprintf(`creator_id == %d || visibility in ["PUBLIC", "PROTECTED"]`, currentUser.ID)
} }
memoRelationComment := store.MemoRelationComment memoRelationComment := store.MemoRelationComment
var limit, offset int
if request.PageToken != "" {
var pageToken v1pb.PageToken
if err := unmarshalPageToken(request.PageToken, &pageToken); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid page token: %v", err)
}
limit = normalizePageSize(pageToken.Limit)
offset = int(pageToken.Offset)
if offset < 0 {
offset = 0
}
} else {
limit = normalizePageSize(request.PageSize)
}
limitPlusOne := limit + 1
memoRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{ memoRelations, err := s.Store.ListMemoRelations(ctx, &store.FindMemoRelation{
RelatedMemoID: &memo.ID, RelatedMemoID: &memo.ID,
Type: &memoRelationComment, Type: &memoRelationComment,
MemoFilter: &memoFilter, MemoFilter: &memoFilter,
Limit: &limitPlusOne,
Offset: &offset,
}) })
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memo relations") return nil, status.Errorf(codes.Internal, "failed to list memo relations")
} }
nextPageToken := ""
if len(memoRelations) == limitPlusOne {
memoRelations = memoRelations[:limit]
nextPageToken, err = getPageToken(limit, offset+limit)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get next page token, error: %v", err)
}
}
if len(memoRelations) == 0 { if len(memoRelations) == 0 {
response := &v1pb.ListMemoCommentsResponse{ response := &v1pb.ListMemoCommentsResponse{
Memos: []*v1pb.Memo{}, Memos: []*v1pb.Memo{},
NextPageToken: nextPageToken,
} }
return response, nil return response, nil
} }
...@@ -808,6 +836,7 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM ...@@ -808,6 +836,7 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
response := &v1pb.ListMemoCommentsResponse{ response := &v1pb.ListMemoCommentsResponse{
Memos: memosResponse, Memos: memosResponse,
NextPageToken: nextPageToken,
} }
return response, nil return response, nil
} }
......
...@@ -373,6 +373,46 @@ func TestListMemoCommentsSkipsCommentsWithMissingCreators(t *testing.T) { ...@@ -373,6 +373,46 @@ func TestListMemoCommentsSkipsCommentsWithMissingCreators(t *testing.T) {
require.Empty(t, resp.Memos) require.Empty(t, resp.Memos)
} }
func TestListMemoCommentsPaginates(t *testing.T) {
ctx := context.Background()
ts := NewTestService(t)
defer ts.Cleanup()
owner, err := ts.CreateRegularUser(ctx, "comment-page-owner")
require.NoError(t, err)
ownerCtx := ts.CreateUserContext(ctx, owner.ID)
memo, err := ts.Service.CreateMemo(ownerCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "memo with paged comments",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
for i := 0; i < 3; i++ {
_, err = ts.Service.CreateMemoComment(ownerCtx, &apiv1.CreateMemoCommentRequest{
Name: memo.Name,
Comment: &apiv1.Memo{
Content: fmt.Sprintf("comment %d", i),
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
}
firstPage, err := ts.Service.ListMemoComments(ownerCtx, &apiv1.ListMemoCommentsRequest{Name: memo.Name, PageSize: 2})
require.NoError(t, err)
require.Len(t, firstPage.Memos, 2)
require.NotEmpty(t, firstPage.NextPageToken)
secondPage, err := ts.Service.ListMemoComments(ownerCtx, &apiv1.ListMemoCommentsRequest{Name: memo.Name, PageToken: firstPage.NextPageToken})
require.NoError(t, err)
require.Len(t, secondPage.Memos, 1)
require.Empty(t, secondPage.NextPageToken)
}
// TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments. // TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments.
// This addresses issue #5483: https://github.com/usememos/memos/issues/5483 // This addresses issue #5483: https://github.com/usememos/memos/issues/5483
func TestCreateMemoWithCustomTimestamps(t *testing.T) { func TestCreateMemoWithCustomTimestamps(t *testing.T) {
......
...@@ -17,6 +17,8 @@ import ( ...@@ -17,6 +17,8 @@ import (
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
) )
const maxAPIRequestBytes = 256 << 20
type APIV1Service struct { type APIV1Service struct {
v1pb.UnimplementedInstanceServiceServer v1pb.UnimplementedInstanceServiceServer
v1pb.UnimplementedAuthServiceServer v1pb.UnimplementedAuthServiceServer
...@@ -35,6 +37,7 @@ type APIV1Service struct { ...@@ -35,6 +37,7 @@ type APIV1Service struct {
// thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion // thumbnailSemaphore limits concurrent thumbnail generation to prevent memory exhaustion
thumbnailSemaphore *semaphore.Weighted thumbnailSemaphore *semaphore.Weighted
imageProcessingSemaphore *semaphore.Weighted
} }
func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store) *APIV1Service { func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store) *APIV1Service {
...@@ -49,6 +52,7 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store ...@@ -49,6 +52,7 @@ func NewAPIV1Service(secret string, profile *profile.Profile, store *store.Store
MarkdownService: markdownService, MarkdownService: markdownService,
SSEHub: NewSSEHub(), SSEHub: NewSSEHub(),
thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations thumbnailSemaphore: semaphore.NewWeighted(3), // Limit to 3 concurrent thumbnail generations
imageProcessingSemaphore: semaphore.NewWeighted(2),
} }
} }
...@@ -120,7 +124,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech ...@@ -120,7 +124,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
})) }))
// Register SSE endpoint with same CORS as rest of /api/v1. // Register SSE endpoint with same CORS as rest of /api/v1.
RegisterSSERoutes(gwGroup, s.SSEHub, s.Store, s.Secret) RegisterSSERoutes(gwGroup, s.SSEHub, s.Store, s.Secret)
handler := echo.WrapHandler(gwMux) handler := echo.WrapHandler(http.MaxBytesHandler(gwMux, maxAPIRequestBytes))
gwGroup.Any("/api/v1/*", handler) gwGroup.Any("/api/v1/*", handler)
gwGroup.Any("/file/*", handler) gwGroup.Any("/file/*", handler)
...@@ -135,7 +139,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech ...@@ -135,7 +139,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
) )
connectMux := http.NewServeMux() connectMux := http.NewServeMux()
connectHandler := NewConnectServiceHandler(s) connectHandler := NewConnectServiceHandler(s)
connectHandler.RegisterConnectHandlers(connectMux, connectInterceptors) connectHandler.RegisterConnectHandlers(connectMux, connectInterceptors, connect.WithReadMaxBytes(maxAPIRequestBytes))
// Wrap with CORS for browser access // Wrap with CORS for browser access
corsHandler := middleware.CORSWithConfig(middleware.CORSConfig{ corsHandler := middleware.CORSWithConfig(middleware.CORSConfig{
...@@ -147,7 +151,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech ...@@ -147,7 +151,7 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
AllowCredentials: true, AllowCredentials: true,
}) })
connectGroup := echoServer.Group("", corsHandler) connectGroup := echoServer.Group("", corsHandler)
connectGroup.Any("/memos.api.v1.*", echo.WrapHandler(connectMux)) connectGroup.Any("/memos.api.v1.*", echo.WrapHandler(http.MaxBytesHandler(connectMux, maxAPIRequestBytes)))
return nil return nil
} }
...@@ -223,18 +223,14 @@ func (s *FileServerService) serveMediaStream(c *echo.Context, attachment *store. ...@@ -223,18 +223,14 @@ func (s *FileServerService) serveMediaStream(c *echo.Context, attachment *store.
// serveStaticFile serves non-streaming files (images, documents, etc.). // serveStaticFile serves non-streaming files (images, documents, etc.).
func (s *FileServerService) serveStaticFile(c *echo.Context, attachment *store.Attachment, contentType string, wantThumbnail bool) error { func (s *FileServerService) serveStaticFile(c *echo.Context, attachment *store.Attachment, contentType string, wantThumbnail bool) error {
blob, err := s.getAttachmentBlob(attachment)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment blob").Wrap(err)
}
// Generate thumbnail for supported image types. // Generate thumbnail for supported image types.
if wantThumbnail && thumbnailSupportedTypes[attachment.Type] { if wantThumbnail && thumbnailSupportedTypes[attachment.Type] {
if thumbnailBlob, err := s.getOrGenerateThumbnail(c.Request().Context(), attachment); err != nil { if thumbnailBlob, err := s.getOrGenerateThumbnail(c.Request().Context(), attachment); err != nil {
slog.Warn("failed to get thumbnail", "error", err) slog.Warn("failed to get thumbnail", "error", err)
} else { } else {
blob = thumbnailBlob setSecurityHeaders(c)
contentType = "image/jpeg" setMediaHeaders(c, "image/jpeg", attachment.Type)
return c.Blob(http.StatusOK, "image/jpeg", thumbnailBlob)
} }
} }
...@@ -246,7 +242,24 @@ func (s *FileServerService) serveStaticFile(c *echo.Context, attachment *store.A ...@@ -246,7 +242,24 @@ func (s *FileServerService) serveStaticFile(c *echo.Context, attachment *store.A
c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("attachment; filename=%q", attachment.Filename)) c.Response().Header().Set(echo.HeaderContentDisposition, fmt.Sprintf("attachment; filename=%q", attachment.Filename))
} }
return c.Blob(http.StatusOK, contentType, blob) switch attachment.StorageType {
case storepb.AttachmentStorageType_LOCAL:
filePath, err := s.resolveLocalPath(attachment.Reference)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to resolve file path").Wrap(err)
}
http.ServeFile(c.Response(), c.Request(), filePath)
return nil
case storepb.AttachmentStorageType_S3:
reader, err := s.getAttachmentReader(c.Request().Context(), attachment)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "failed to get attachment reader").Wrap(err)
}
defer reader.Close()
return c.Stream(http.StatusOK, contentType, reader)
default:
return c.Blob(http.StatusOK, contentType, attachment.Blob)
}
} }
// ============================================================================= // =============================================================================
...@@ -260,7 +273,7 @@ func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]b ...@@ -260,7 +273,7 @@ func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]b
return s.readLocalFile(attachment.Reference) return s.readLocalFile(attachment.Reference)
case storepb.AttachmentStorageType_S3: case storepb.AttachmentStorageType_S3:
return s.downloadFromS3(attachment) return s.downloadFromS3(context.Background(), attachment)
default: default:
return attachment.Blob, nil return attachment.Blob, nil
...@@ -268,7 +281,7 @@ func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]b ...@@ -268,7 +281,7 @@ func (s *FileServerService) getAttachmentBlob(attachment *store.Attachment) ([]b
} }
// getAttachmentReader returns a reader for streaming attachment content. // getAttachmentReader returns a reader for streaming attachment content.
func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (io.ReadCloser, error) { func (s *FileServerService) getAttachmentReader(ctx context.Context, attachment *store.Attachment) (io.ReadCloser, error) {
switch attachment.StorageType { switch attachment.StorageType {
case storepb.AttachmentStorageType_LOCAL: case storepb.AttachmentStorageType_LOCAL:
filePath, err := s.resolveLocalPath(attachment.Reference) filePath, err := s.resolveLocalPath(attachment.Reference)
...@@ -289,7 +302,7 @@ func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (i ...@@ -289,7 +302,7 @@ func (s *FileServerService) getAttachmentReader(attachment *store.Attachment) (i
if err != nil { if err != nil {
return nil, err return nil, err
} }
reader, err := s3Client.GetObjectStream(context.Background(), s3Object.Key) reader, err := s3Client.GetObjectStream(ctx, s3Object.Key)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to stream from S3") return nil, errors.Wrap(err, "failed to stream from S3")
} }
...@@ -356,13 +369,13 @@ func (*FileServerService) createS3Client(attachment *store.Attachment) (*s3.Clie ...@@ -356,13 +369,13 @@ func (*FileServerService) createS3Client(attachment *store.Attachment) (*s3.Clie
} }
// downloadFromS3 downloads the entire object from S3. // downloadFromS3 downloads the entire object from S3.
func (s *FileServerService) downloadFromS3(attachment *store.Attachment) ([]byte, error) { func (s *FileServerService) downloadFromS3(ctx context.Context, attachment *store.Attachment) ([]byte, error) {
client, s3Object, err := s.createS3Client(attachment) client, s3Object, err := s.createS3Client(attachment)
if err != nil { if err != nil {
return nil, err return nil, err
} }
blob, err := client.GetObject(context.Background(), s3Object.Key) blob, err := client.GetObject(ctx, s3Object.Key)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to download from S3") return nil, errors.Wrap(err, "failed to download from S3")
} }
...@@ -411,7 +424,7 @@ func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachme ...@@ -411,7 +424,7 @@ func (s *FileServerService) getOrGenerateThumbnail(ctx context.Context, attachme
return blob, nil return blob, nil
} }
return s.generateThumbnail(attachment, thumbnailPath) return s.generateThumbnail(ctx, attachment, thumbnailPath)
} }
// getThumbnailPath returns the file path for a cached thumbnail. // getThumbnailPath returns the file path for a cached thumbnail.
...@@ -435,8 +448,8 @@ func (*FileServerService) readCachedThumbnail(path string) ([]byte, error) { ...@@ -435,8 +448,8 @@ func (*FileServerService) readCachedThumbnail(path string) ([]byte, error) {
} }
// generateThumbnail creates a new thumbnail and saves it to disk. // generateThumbnail creates a new thumbnail and saves it to disk.
func (s *FileServerService) generateThumbnail(attachment *store.Attachment, thumbnailPath string) ([]byte, error) { func (s *FileServerService) generateThumbnail(ctx context.Context, attachment *store.Attachment, thumbnailPath string) ([]byte, error) {
reader, err := s.getAttachmentReader(attachment) reader, err := s.getAttachmentReader(ctx, attachment)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to get attachment reader") return nil, errors.Wrap(err, "failed to get attachment reader")
} }
......
...@@ -73,6 +73,52 @@ func TestServeAttachmentFile_ShareTokenAllowsDirectMemoAttachment(t *testing.T) ...@@ -73,6 +73,52 @@ func TestServeAttachmentFile_ShareTokenAllowsDirectMemoAttachment(t *testing.T)
require.Equal(t, "memo attachment", rec.Body.String()) require.Equal(t, "memo attachment", rec.Body.String())
} }
func TestServeAttachmentFile_LocalStaticFileSupportsRangeRequests(t *testing.T) {
ctx := context.Background()
svc, fs, _, cleanup := newShareAttachmentTestServices(ctx, t)
defer cleanup()
creator, err := svc.Store.CreateUser(ctx, &store.User{
Username: "range-owner",
Role: store.RoleUser,
Email: "range-owner@example.com",
})
require.NoError(t, err)
creatorCtx := context.WithValue(ctx, auth.UserIDContextKey, creator.ID)
attachment, err := svc.CreateAttachment(creatorCtx, &apiv1.CreateAttachmentRequest{
Attachment: &apiv1.Attachment{
Filename: "range.txt",
Type: "text/plain",
Content: []byte("0123456789"),
},
})
require.NoError(t, err)
_, err = svc.CreateMemo(creatorCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "range memo",
Visibility: apiv1.Visibility_PUBLIC,
Attachments: []*apiv1.Attachment{
{Name: attachment.Name},
},
},
})
require.NoError(t, err)
e := echo.New()
fs.RegisterRoutes(e)
req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("/file/%s/%s", attachment.Name, attachment.Filename), nil)
req.Header.Set("Range", "bytes=2-5")
rec := httptest.NewRecorder()
e.ServeHTTP(rec, req)
require.Equal(t, http.StatusPartialContent, rec.Code)
require.Equal(t, "2345", rec.Body.String())
require.Equal(t, "bytes 2-5/10", rec.Header().Get("Content-Range"))
}
func TestServeAttachmentFile_ShareTokenRejectsCommentAttachment(t *testing.T) { func TestServeAttachmentFile_ShareTokenRejectsCommentAttachment(t *testing.T) {
ctx := context.Background() ctx := context.Background()
svc, fs, _, cleanup := newShareAttachmentTestServices(ctx, t) svc, fs, _, cleanup := newShareAttachmentTestServices(ctx, t)
......
...@@ -34,13 +34,13 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) ...@@ -34,13 +34,13 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation)
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) { func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
where, args := []string{"TRUE"}, []any{} where, args := []string{"TRUE"}, []any{}
if find.MemoID != nil { if find.MemoID != nil {
where, args = append(where, "`memo_id` = ?"), append(args, find.MemoID) where, args = append(where, "`memo_id` = ?"), append(args, *find.MemoID)
} }
if find.RelatedMemoID != nil { if find.RelatedMemoID != nil {
where, args = append(where, "`related_memo_id` = ?"), append(args, find.RelatedMemoID) where, args = append(where, "`related_memo_id` = ?"), append(args, *find.RelatedMemoID)
} }
if find.Type != nil { if find.Type != nil {
where, args = append(where, "`type` = ?"), append(args, find.Type) where, args = append(where, "`type` = ?"), append(args, *find.Type)
} }
if len(find.MemoIDList) > 0 { if len(find.MemoIDList) > 0 {
placeholders := make([]string, len(find.MemoIDList)) placeholders := make([]string, len(find.MemoIDList))
...@@ -73,7 +73,15 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation ...@@ -73,7 +73,15 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
} }
} }
rows, err := d.db.QueryContext(ctx, "SELECT `memo_id`, `related_memo_id`, `type` FROM `memo_relation` WHERE "+strings.Join(where, " AND "), args...) query := "SELECT `memo_id`, `related_memo_id`, `type` FROM `memo_relation` WHERE " + strings.Join(where, " AND ") + " ORDER BY `memo_id` DESC"
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
}
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -102,13 +110,13 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation ...@@ -102,13 +110,13 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error { func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
where, args := []string{"TRUE"}, []any{} where, args := []string{"TRUE"}, []any{}
if delete.MemoID != nil { if delete.MemoID != nil {
where, args = append(where, "`memo_id` = ?"), append(args, delete.MemoID) where, args = append(where, "`memo_id` = ?"), append(args, *delete.MemoID)
} }
if delete.RelatedMemoID != nil { if delete.RelatedMemoID != nil {
where, args = append(where, "`related_memo_id` = ?"), append(args, delete.RelatedMemoID) where, args = append(where, "`related_memo_id` = ?"), append(args, *delete.RelatedMemoID)
} }
if delete.Type != nil { if delete.Type != nil {
where, args = append(where, "`type` = ?"), append(args, delete.Type) where, args = append(where, "`type` = ?"), append(args, *delete.Type)
} }
stmt := "DELETE FROM `memo_relation` WHERE " + strings.Join(where, " AND ") stmt := "DELETE FROM `memo_relation` WHERE " + strings.Join(where, " AND ")
result, err := d.db.ExecContext(ctx, stmt, args...) result, err := d.db.ExecContext(ctx, stmt, args...)
......
...@@ -41,13 +41,13 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) ...@@ -41,13 +41,13 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation)
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) { func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if find.MemoID != nil { if find.MemoID != nil {
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, find.MemoID) where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *find.MemoID)
} }
if find.RelatedMemoID != nil { if find.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, find.RelatedMemoID) where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, *find.RelatedMemoID)
} }
if find.Type != nil { if find.Type != nil {
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type) where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, *find.Type)
} }
if len(find.MemoIDList) > 0 { if len(find.MemoIDList) > 0 {
memoPlaceholders := make([]string, len(find.MemoIDList)) memoPlaceholders := make([]string, len(find.MemoIDList))
...@@ -93,13 +93,22 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation ...@@ -93,13 +93,22 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
} }
} }
rows, err := d.db.QueryContext(ctx, ` query := `
SELECT SELECT
memo_id, memo_id,
related_memo_id, related_memo_id,
type type
FROM memo_relation FROM memo_relation
WHERE `+strings.Join(where, " AND "), args...) WHERE ` + strings.Join(where, " AND ") + `
ORDER BY memo_id DESC`
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
}
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -128,13 +137,13 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation ...@@ -128,13 +137,13 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error { func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if delete.MemoID != nil { if delete.MemoID != nil {
where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, delete.MemoID) where, args = append(where, "memo_id = "+placeholder(len(args)+1)), append(args, *delete.MemoID)
} }
if delete.RelatedMemoID != nil { if delete.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, delete.RelatedMemoID) where, args = append(where, "related_memo_id = "+placeholder(len(args)+1)), append(args, *delete.RelatedMemoID)
} }
if delete.Type != nil { if delete.Type != nil {
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, delete.Type) where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, *delete.Type)
} }
stmt := `DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ") stmt := `DELETE FROM memo_relation WHERE ` + strings.Join(where, " AND ")
result, err := d.db.ExecContext(ctx, stmt, args...) result, err := d.db.ExecContext(ctx, stmt, args...)
......
...@@ -41,13 +41,13 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) ...@@ -41,13 +41,13 @@ func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation)
func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) { func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) {
where, args := []string{"TRUE"}, []any{} where, args := []string{"TRUE"}, []any{}
if find.MemoID != nil { if find.MemoID != nil {
where, args = append(where, "memo_id = ?"), append(args, find.MemoID) where, args = append(where, "memo_id = ?"), append(args, *find.MemoID)
} }
if find.RelatedMemoID != nil { if find.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = ?"), append(args, find.RelatedMemoID) where, args = append(where, "related_memo_id = ?"), append(args, *find.RelatedMemoID)
} }
if find.Type != nil { if find.Type != nil {
where, args = append(where, "type = ?"), append(args, find.Type) where, args = append(where, "type = ?"), append(args, *find.Type)
} }
if len(find.MemoIDList) > 0 { if len(find.MemoIDList) > 0 {
placeholders := make([]string, len(find.MemoIDList)) placeholders := make([]string, len(find.MemoIDList))
...@@ -78,13 +78,22 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation ...@@ -78,13 +78,22 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
} }
} }
rows, err := d.db.QueryContext(ctx, ` query := `
SELECT SELECT
memo_id, memo_id,
related_memo_id, related_memo_id,
type type
FROM memo_relation FROM memo_relation
WHERE `+strings.Join(where, " AND "), args...) WHERE ` + strings.Join(where, " AND ") + `
ORDER BY memo_id DESC`
if find.Limit != nil {
query = fmt.Sprintf("%s LIMIT %d", query, *find.Limit)
if find.Offset != nil {
query = fmt.Sprintf("%s OFFSET %d", query, *find.Offset)
}
}
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -113,13 +122,13 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation ...@@ -113,13 +122,13 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error { func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error {
where, args := []string{"TRUE"}, []any{} where, args := []string{"TRUE"}, []any{}
if delete.MemoID != nil { if delete.MemoID != nil {
where, args = append(where, "memo_id = ?"), append(args, delete.MemoID) where, args = append(where, "memo_id = ?"), append(args, *delete.MemoID)
} }
if delete.RelatedMemoID != nil { if delete.RelatedMemoID != nil {
where, args = append(where, "related_memo_id = ?"), append(args, delete.RelatedMemoID) where, args = append(where, "related_memo_id = ?"), append(args, *delete.RelatedMemoID)
} }
if delete.Type != nil { if delete.Type != nil {
where, args = append(where, "type = ?"), append(args, delete.Type) where, args = append(where, "type = ?"), append(args, *delete.Type)
} }
stmt := ` stmt := `
DELETE FROM memo_relation DELETE FROM memo_relation
......
...@@ -26,6 +26,8 @@ type FindMemoRelation struct { ...@@ -26,6 +26,8 @@ type FindMemoRelation struct {
MemoFilter *string MemoFilter *string
// MemoIDList matches relations where memo_id OR related_memo_id is in the list. // MemoIDList matches relations where memo_id OR related_memo_id is in the list.
MemoIDList []int32 MemoIDList []int32
Limit *int
Offset *int
} }
type DeleteMemoRelation struct { type DeleteMemoRelation struct {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment