Commit c3e7e2c3 authored by boojack's avatar boojack

fix: normalize attachment MIME types before validation

parent aafcc21a
...@@ -82,24 +82,25 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat ...@@ -82,24 +82,25 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat
if !validateFilename(request.Attachment.Filename) { if !validateFilename(request.Attachment.Filename) {
return nil, status.Errorf(codes.InvalidArgument, "filename contains invalid characters or format") return nil, status.Errorf(codes.InvalidArgument, "filename contains invalid characters or format")
} }
if request.Attachment.Type == "" { normalizedMimeType := request.Attachment.Type
if normalizedMimeType == "" {
ext := filepath.Ext(request.Attachment.Filename) ext := filepath.Ext(request.Attachment.Filename)
mimeType := mime.TypeByExtension(ext) mimeType := mime.TypeByExtension(ext)
if mimeType == "" { if mimeType == "" {
mimeType = http.DetectContentType(request.Attachment.Content) mimeType = http.DetectContentType(request.Attachment.Content)
} }
// ParseMediaType to strip parameters if normalizedType, ok := normalizeMimeType(mimeType); ok {
mediaType, _, err := mime.ParseMediaType(mimeType) normalizedMimeType = normalizedType
if err == nil {
request.Attachment.Type = mediaType
} }
} }
if request.Attachment.Type == "" { if normalizedMimeType == "" {
request.Attachment.Type = "application/octet-stream" normalizedMimeType = "application/octet-stream"
} }
if !isValidMimeType(request.Attachment.Type) { normalizedType, ok := normalizeMimeType(normalizedMimeType)
if !ok {
return nil, status.Errorf(codes.InvalidArgument, "invalid MIME type format") return nil, status.Errorf(codes.InvalidArgument, "invalid MIME type format")
} }
request.Attachment.Type = normalizedType
attachmentUID, err := ValidateAndGenerateUID(request.AttachmentId) attachmentUID, err := ValidateAndGenerateUID(request.AttachmentId)
if err != nil { if err != nil {
...@@ -617,16 +618,18 @@ func validateFilename(filename string) bool { ...@@ -617,16 +618,18 @@ func validateFilename(filename string) bool {
return true return true
} }
func isValidMimeType(mimeType string) bool { func normalizeMimeType(mimeType string) (string, bool) {
// Reject empty or excessively long MIME types mimeType = strings.TrimSpace(mimeType)
if mimeType == "" || len(mimeType) > 255 { if mimeType == "" || len(mimeType) > 255 {
return false return "", false
}
mediaType, _, err := mime.ParseMediaType(mimeType)
if err != nil || mediaType == "" || len(mediaType) > 255 {
return "", false
} }
// MIME type must match the pattern: type/subtype return mediaType, true
// Allow common characters in MIME types per RFC 2045
matched, _ := regexp.MatchString(`^[a-zA-Z0-9][a-zA-Z0-9!#$&^_.+-]{0,126}/[a-zA-Z0-9][a-zA-Z0-9!#$&^_.+-]{0,126}$`, mimeType)
return matched
} }
func (s *APIV1Service) validateAttachmentFilter(ctx context.Context, filterStr string) error { func (s *APIV1Service) validateAttachmentFilter(ctx context.Context, filterStr string) error {
......
...@@ -61,6 +61,30 @@ func TestCreateAttachment(t *testing.T) { ...@@ -61,6 +61,30 @@ func TestCreateAttachment(t *testing.T) {
require.Equal(t, "application/octet-stream", attachment.Type) require.Equal(t, "application/octet-stream", attachment.Type)
}) })
t.Run("Type_WithParameters_NormalizedBeforeValidation", func(t *testing.T) {
attachment, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
Attachment: &v1pb.Attachment{
Filename: "voice-note.webm",
Type: "audio/webm;codecs=opus",
Content: []byte("fake webm content"),
},
})
require.NoError(t, err)
require.Equal(t, "audio/webm", attachment.Type)
})
t.Run("Type_InvalidFormat_Rejected", func(t *testing.T) {
_, err := ts.Service.CreateAttachment(userCtx, &v1pb.CreateAttachmentRequest{
Attachment: &v1pb.Attachment{
Filename: "broken.webm",
Type: `audio/webm;codecs="unterminated`,
Content: []byte("fake webm content"),
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid MIME type format")
})
t.Run("LocalStorage_PathCollisionUsesUniqueReference", func(t *testing.T) { t.Run("LocalStorage_PathCollisionUsesUniqueReference", func(t *testing.T) {
_, err := ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ _, err := ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_STORAGE, Key: storepb.InstanceSettingKey_STORAGE,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment