Unverified commit d87539a1, authored by memoclaw, committed by GitHub

feat: add Gemini transcription provider (#5830)

Co-authored-by: memoclaw <265580040+memoclaw@users.noreply.github.com>
parent 83ed32f1
......@@ -8,8 +8,6 @@ const (
ProviderOpenAI ProviderType = "OPENAI"
// ProviderOpenAICompatible is an OpenAI-compatible API endpoint.
ProviderOpenAICompatible ProviderType = "OPENAI_COMPATIBLE"
// ProviderAnthropic is Anthropic's API.
ProviderAnthropic ProviderType = "ANTHROPIC"
// ProviderGemini is Google's Gemini API.
ProviderGemini ProviderType = "GEMINI"
)
......
package gemini
import (
"net/http"
"net/url"
"strings"
"time"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/ai"
)
// defaultEndpoint is the Gemini API base URL used when the provider config leaves Endpoint blank.
const defaultEndpoint = "https://generativelanguage.googleapis.com/v1beta"
// Transcriber transcribes audio with Gemini audio understanding.
type Transcriber struct {
// endpoint is the API base URL; NewTranscriber stores it without trailing slashes.
endpoint string
// apiKey is sent on each request in the x-goog-api-key header.
apiKey string
// httpClient performs the requests; it can be replaced through WithHTTPClient.
httpClient *http.Client
}
// NewTranscriber creates a new Gemini transcriber.
//
// config.Endpoint is optional; when blank, defaultEndpoint is used. The
// endpoint must be an absolute http or https URL with a host, and is stored
// without trailing slashes. config.APIKey is required and is stored with
// surrounding whitespace removed.
//
// The transcriber starts with an HTTP client that has a two-minute timeout;
// options are applied in order afterwards and may replace it.
//
// It returns an error when the endpoint is not a usable URL or the API key is
// blank.
func NewTranscriber(config ai.ProviderConfig, options ...Option) (*Transcriber, error) {
	endpoint := strings.TrimSpace(config.Endpoint)
	if endpoint == "" {
		endpoint = defaultEndpoint
	}
	parsed, err := url.ParseRequestURI(endpoint)
	if err != nil {
		return nil, errors.Wrap(err, "invalid Gemini endpoint")
	}
	// ParseRequestURI also accepts bare paths such as "/v1beta". Reject those
	// here so a misconfigured endpoint fails at construction time instead of
	// on the first transcription request.
	if (parsed.Scheme != "http" && parsed.Scheme != "https") || parsed.Host == "" {
		return nil, errors.Errorf("invalid Gemini endpoint %q: must be an absolute http or https URL", endpoint)
	}

	// Trim the key like the endpoint: stray whitespace (for example a trailing
	// newline from a pasted secret) would otherwise end up in a request header.
	apiKey := strings.TrimSpace(config.APIKey)
	if apiKey == "" {
		return nil, errors.New("Gemini API key is required")
	}

	transcriber := &Transcriber{
		endpoint: strings.TrimRight(endpoint, "/"),
		apiKey:   apiKey,
		httpClient: &http.Client{
			Timeout: 2 * time.Minute,
		},
	}
	for _, option := range options {
		option(transcriber)
	}
	return transcriber, nil
}
// Option configures a Transcriber.
// NewTranscriber applies the options it receives in order, after setting its defaults.
type Option func(*Transcriber)
// WithHTTPClient returns an Option that swaps in the given HTTP client.
// A nil client is ignored, leaving the transcriber's current client in place.
func WithHTTPClient(client *http.Client) Option {
	return func(target *Transcriber) {
		if client == nil {
			return
		}
		target.httpClient = client
	}
}
package gemini
import (
"bytes"
"context"
"encoding/base64"
"encoding/json"
"io"
"mime"
"net/http"
"net/url"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/ai"
)
const (
// transcriptionInstruction is the fixed instruction that opens every transcription prompt.
transcriptionInstruction = `Transcribe the audio accurately. Return only the transcript text. Do not summarize, explain, or add content that is not spoken.`
// maxInlineAudioSizeBytes caps the raw audio accepted for inline upload at 14 MiB.
// NOTE(review): presumably sized so the base64-encoded request stays under Gemini's inline request limit — confirm against the API docs.
maxInlineAudioSizeBytes = 14 * 1024 * 1024
)
// supportedContentTypes maps each accepted (lowercase) audio media type to the
// MIME type that is sent to Gemini. Aliases such as audio/x-wav and audio/mpeg
// collapse onto a single canonical value; media types missing from this table
// are rejected by normalizeContentType.
var supportedContentTypes = map[string]string{
"audio/wav": "audio/wav",
"audio/x-wav": "audio/wav",
"audio/mp3": "audio/mp3",
"audio/mpeg": "audio/mp3",
"audio/aiff": "audio/aiff",
"audio/aac": "audio/aac",
"audio/ogg": "audio/ogg",
"audio/flac": "audio/flac",
"audio/x-flac": "audio/flac",
}
// generateContentRequest is the JSON body posted to the generateContent endpoint.
type generateContentRequest struct {
// Contents holds the conversation turns; Transcribe sends a single user turn.
Contents []content `json:"contents"`
// GenerationConfig carries generation parameters and is omitted from the JSON when empty.
GenerationConfig map[string]json.Number `json:"generationConfig,omitempty"`
}
// content is one turn of a generateContent request: an optional role plus its ordered parts.
type content struct {
Role string `json:"role,omitempty"`
Parts []part `json:"parts"`
}
// part is a single piece of a content turn. Each field is left out of the JSON
// when unset; Transcribe fills exactly one of them per part (inline audio or text).
type part struct {
Text string `json:"text,omitempty"`
InlineData *inlineData `json:"inlineData,omitempty"`
}
// inlineData embeds a payload directly in the request body.
type inlineData struct {
// MIMEType is the media type of Data.
MIMEType string `json:"mimeType"`
// Data is the payload as text; Transcribe fills it with standard base64 of the audio bytes.
Data string `json:"data"`
}
// generateContentResponse decodes the part of a generateContent reply that
// extractText reads: the text of every part of every candidate.
type generateContentResponse struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"content"`
} `json:"candidates"`
}
// errorResponse decodes the JSON error envelope of a failed request.
// extractErrorMessage uses Message; Status is decoded but not read in the code visible here.
type errorResponse struct {
Error struct {
Message string `json:"message"`
Status string `json:"status"`
} `json:"error"`
}
// Transcribe transcribes audio with Gemini generateContent.
//
// The audio is buffered in memory, base64-encoded, and sent inline in a single
// generateContent request together with the prompt built from request.Prompt
// and request.Language. The temperature is pinned to 0. A leading "models/" on
// request.Model is dropped before the name is placed in the URL path.
//
// It returns an error when the model or audio is missing, the audio is larger
// than maxInlineAudioSizeBytes, the content type is not supported, the HTTP
// call fails or answers with a non-2xx status, or the response has no text.
func (t *Transcriber) Transcribe(ctx context.Context, request ai.TranscribeRequest) (*ai.TranscribeResponse, error) {
	if strings.TrimSpace(request.Model) == "" {
		return nil, errors.New("model is required")
	}
	if request.Audio == nil {
		return nil, errors.New("audio is required")
	}

	// Read at most one byte past the limit: that is enough to detect oversized
	// audio without buffering an arbitrarily large upload in memory first.
	audio, err := io.ReadAll(io.LimitReader(request.Audio, maxInlineAudioSizeBytes+1))
	if err != nil {
		return nil, errors.Wrap(err, "failed to read audio")
	}
	if len(audio) == 0 {
		return nil, errors.New("audio is required")
	}
	if len(audio) > maxInlineAudioSizeBytes {
		return nil, errors.Errorf("audio is too large for Gemini inline transcription; maximum size is %d bytes", maxInlineAudioSizeBytes)
	}

	contentType, err := normalizeContentType(request.ContentType)
	if err != nil {
		return nil, err
	}

	prompt := buildTranscriptionPrompt(request.Prompt, request.Language)
	body, err := json.Marshal(generateContentRequest{
		Contents: []content{
			{
				Role: "user",
				// The audio part comes first, followed by the text instruction.
				Parts: []part{
					{InlineData: &inlineData{
						MIMEType: contentType,
						Data:     base64.StdEncoding.EncodeToString(audio),
					}},
					{Text: prompt},
				},
			},
		},
		GenerationConfig: map[string]json.Number{
			"temperature": json.Number("0"),
		},
	})
	if err != nil {
		return nil, errors.Wrap(err, "failed to marshal Gemini transcription request")
	}

	httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, t.endpoint+"/models/"+url.PathEscape(normalizeModelName(request.Model))+":generateContent", bytes.NewReader(body))
	if err != nil {
		return nil, errors.Wrap(err, "failed to create Gemini transcription request")
	}
	httpRequest.Header.Set("Content-Type", "application/json")
	httpRequest.Header.Set("x-goog-api-key", t.apiKey)

	httpResponse, err := t.httpClient.Do(httpRequest)
	if err != nil {
		return nil, errors.Wrap(err, "failed to send Gemini transcription request")
	}
	defer httpResponse.Body.Close()

	responseBody, err := io.ReadAll(httpResponse.Body)
	if err != nil {
		return nil, errors.Wrap(err, "failed to read Gemini transcription response")
	}
	// Anything outside 2xx is a failure; surface the API's own message when it can be parsed.
	if httpResponse.StatusCode < http.StatusOK || httpResponse.StatusCode >= http.StatusMultipleChoices {
		return nil, errors.Errorf("Gemini transcription request failed with status %d: %s", httpResponse.StatusCode, extractErrorMessage(responseBody))
	}

	var response generateContentResponse
	if err := json.Unmarshal(responseBody, &response); err != nil {
		return nil, errors.Wrap(err, "failed to unmarshal Gemini transcription response")
	}
	text := extractText(response)
	if text == "" {
		return nil, errors.New("Gemini transcription response did not include text")
	}
	return &ai.TranscribeResponse{
		Text: text,
	}, nil
}
// normalizeContentType parses a Content-Type value, discards its parameters,
// and maps the media type to the MIME type sent to Gemini.
// It returns an error when the value cannot be parsed or the media type is not
// listed in supportedContentTypes.
func normalizeContentType(contentType string) (string, error) {
	trimmed := strings.TrimSpace(contentType)
	mediaType, _, parseErr := mime.ParseMediaType(trimmed)
	if parseErr != nil {
		return "", errors.Wrap(parseErr, "invalid audio content type")
	}

	mediaType = strings.ToLower(mediaType)
	if canonical, found := supportedContentTypes[mediaType]; found {
		return canonical, nil
	}
	return "", errors.Errorf("audio content type %q is not supported by Gemini", mediaType)
}
// buildTranscriptionPrompt assembles the text prompt sent alongside the audio.
// It always starts with transcriptionInstruction, then adds a language sentence
// and a hints section when the corresponding argument is non-blank after
// trimming. Sections are separated by a blank line.
func buildTranscriptionPrompt(prompt string, language string) string {
	var builder strings.Builder
	builder.WriteString(transcriptionInstruction)

	if lang := strings.TrimSpace(language); lang != "" {
		builder.WriteString("\n\n")
		builder.WriteString("The input language is " + lang + ".")
	}
	if hints := strings.TrimSpace(prompt); hints != "" {
		builder.WriteString("\n\n")
		builder.WriteString("Context and spelling hints:\n" + hints)
	}
	return builder.String()
}
// normalizeModelName trims surrounding whitespace from model and drops a single
// leading "models/" so both "models/<name>" and "<name>" yield "<name>".
func normalizeModelName(model string) string {
	const resourcePrefix = "models/"

	name := strings.TrimSpace(model)
	if strings.HasPrefix(name, resourcePrefix) {
		return name[len(resourcePrefix):]
	}
	return name
}
// extractText gathers the text of every part of every candidate, trims each
// piece, skips the ones that are blank, and joins the rest with newlines.
// It returns an empty string when the response holds no non-blank text.
func extractText(response generateContentResponse) string {
	var joined strings.Builder
	for ci := range response.Candidates {
		parts := response.Candidates[ci].Content.Parts
		for pi := range parts {
			piece := strings.TrimSpace(parts[pi].Text)
			if piece == "" {
				continue
			}
			// Every written piece is non-empty, so a non-zero length means a
			// previous piece exists and needs a separator.
			if joined.Len() > 0 {
				joined.WriteByte('\n')
			}
			joined.WriteString(piece)
		}
	}
	return joined.String()
}
// extractErrorMessage returns the error.message field of a JSON error body.
// When the body is not valid JSON or the message is empty, the raw body is
// returned as a string instead.
func extractErrorMessage(responseBody []byte) string {
	var parsed errorResponse
	if err := json.Unmarshal(responseBody, &parsed); err != nil {
		return string(responseBody)
	}
	if message := parsed.Error.Message; message != "" {
		return message
	}
	return string(responseBody)
}
package gemini
import (
"context"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/internal/ai"
)
// TestTranscribe runs a full transcription against a stub HTTP server. The stub
// checks the request method, path, headers and JSON body, then replies with a
// single candidate whose text the test expects back from Transcribe.
//
// NOTE(review): the require.* calls inside the handler run on the server's
// goroutine, not the test goroutine; the testing package documents FailNow as
// only valid from the goroutine running the test. Consider non-fatal checks
// (t.Errorf / assert) in the handler.
func TestTranscribe(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// The "models/" prefix given in the request model below must not be repeated in the path.
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/v1beta/models/gemini-2.5-flash:generateContent", r.URL.Path)
require.Equal(t, "test-key", r.Header.Get("x-goog-api-key"))
require.Equal(t, "application/json", r.Header.Get("Content-Type"))
// Local mirror of the request fields this test inspects.
var request struct {
Contents []struct {
Parts []struct {
Text string `json:"text"`
InlineData *struct {
MIMEType string `json:"mimeType"`
Data string `json:"data"`
} `json:"inlineData"`
} `json:"parts"`
} `json:"contents"`
GenerationConfig map[string]json.Number `json:"generationConfig"`
}
require.NoError(t, json.NewDecoder(r.Body).Decode(&request))
require.Len(t, request.Contents, 1)
require.Len(t, request.Contents[0].Parts, 2)
// Part 0 is the inline audio; audio/mpeg from the request is expected as audio/mp3.
require.NotNil(t, request.Contents[0].Parts[0].InlineData)
require.Equal(t, "audio/mp3", request.Contents[0].Parts[0].InlineData.MIMEType)
audio, err := base64.StdEncoding.DecodeString(request.Contents[0].Parts[0].InlineData.Data)
require.NoError(t, err)
require.Equal(t, "audio bytes", string(audio))
// Part 1 is the text prompt: the fixed instruction plus the hints section.
require.Contains(t, request.Contents[0].Parts[1].Text, "Return only the transcript text")
require.Contains(t, request.Contents[0].Parts[1].Text, "Context and spelling hints")
require.Equal(t, json.Number("0"), request.GenerationConfig["temperature"])
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"candidates": []map[string]any{
{
"content": map[string]any{
"parts": []map[string]string{{"text": "hello from gemini"}},
},
},
},
}))
}))
defer server.Close()
transcriber, err := NewTranscriber(ai.ProviderConfig{
Endpoint: server.URL + "/v1beta",
APIKey: "test-key",
})
require.NoError(t, err)
// Bound the call so a hung stub server cannot stall the test run.
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
response, err := transcriber.Transcribe(ctx, ai.TranscribeRequest{
Model: "models/gemini-2.5-flash",
ContentType: "audio/mpeg",
Audio: strings.NewReader("audio bytes"),
Prompt: "Memos, Steven",
Language: "en",
})
require.NoError(t, err)
require.Equal(t, "hello from gemini", response.Text)
}
func TestTranscribeRejectsUnsupportedContentType(t *testing.T) {
t.Parallel()
transcriber, err := NewTranscriber(ai.ProviderConfig{
Endpoint: "https://example.com/v1beta",
APIKey: "test-key",
})
require.NoError(t, err)
_, err = transcriber.Transcribe(context.Background(), ai.TranscribeRequest{
Model: "gemini-2.5-flash",
ContentType: "video/mp4",
Audio: strings.NewReader("video bytes"),
})
require.Error(t, err)
require.Contains(t, err.Error(), "not supported by Gemini")
}
......@@ -232,8 +232,7 @@ message InstanceSetting {
AI_PROVIDER_TYPE_UNSPECIFIED = 0;
OPENAI = 1;
OPENAI_COMPATIBLE = 2;
ANTHROPIC = 3;
GEMINI = 4;
GEMINI = 3;
}
}
......
......@@ -99,8 +99,7 @@ const (
InstanceSetting_AI_PROVIDER_TYPE_UNSPECIFIED InstanceSetting_AIProviderType = 0
InstanceSetting_OPENAI InstanceSetting_AIProviderType = 1
InstanceSetting_OPENAI_COMPATIBLE InstanceSetting_AIProviderType = 2
InstanceSetting_ANTHROPIC InstanceSetting_AIProviderType = 3
InstanceSetting_GEMINI InstanceSetting_AIProviderType = 4
InstanceSetting_GEMINI InstanceSetting_AIProviderType = 3
)
// Enum value maps for InstanceSetting_AIProviderType.
......@@ -109,15 +108,13 @@ var (
0: "AI_PROVIDER_TYPE_UNSPECIFIED",
1: "OPENAI",
2: "OPENAI_COMPATIBLE",
3: "ANTHROPIC",
4: "GEMINI",
3: "GEMINI",
}
InstanceSetting_AIProviderType_value = map[string]int32{
"AI_PROVIDER_TYPE_UNSPECIFIED": 0,
"OPENAI": 1,
"OPENAI_COMPATIBLE": 2,
"ANTHROPIC": 3,
"GEMINI": 4,
"GEMINI": 3,
}
)
......@@ -1414,7 +1411,7 @@ const file_api_v1_instance_service_proto_rawDesc = "" +
"\x04demo\x18\x03 \x01(\bR\x04demo\x12!\n" +
"\finstance_url\x18\x06 \x01(\tR\vinstanceUrl\x12(\n" +
"\x05admin\x18\a \x01(\v2\x12.memos.api.v1.UserR\x05admin\"\x1b\n" +
"\x19GetInstanceProfileRequest\"\xe2\x1a\n" +
"\x19GetInstanceProfileRequest\"\xd3\x1a\n" +
"\x0fInstanceSetting\x12\x17\n" +
"\x04name\x18\x01 \x01(\tB\x03\xe0A\bR\x04name\x12W\n" +
"\x0fgeneral_setting\x18\x02 \x01(\v2,.memos.api.v1.InstanceSetting.GeneralSettingH\x00R\x0egeneralSetting\x12W\n" +
......@@ -1502,15 +1499,14 @@ const file_api_v1_instance_service_proto_rawDesc = "" +
"\fMEMO_RELATED\x10\x03\x12\b\n" +
"\x04TAGS\x10\x04\x12\x10\n" +
"\fNOTIFICATION\x10\x05\x12\x06\n" +
"\x02AI\x10\x06\"p\n" +
"\x02AI\x10\x06\"a\n" +
"\x0eAIProviderType\x12 \n" +
"\x1cAI_PROVIDER_TYPE_UNSPECIFIED\x10\x00\x12\n" +
"\n" +
"\x06OPENAI\x10\x01\x12\x15\n" +
"\x11OPENAI_COMPATIBLE\x10\x02\x12\r\n" +
"\tANTHROPIC\x10\x03\x12\n" +
"\x11OPENAI_COMPATIBLE\x10\x02\x12\n" +
"\n" +
"\x06GEMINI\x10\x04:a\xeaA^\n" +
"\x06GEMINI\x10\x03:a\xeaA^\n" +
"\x1cmemos.api.v1/InstanceSetting\x12\x1binstance/settings/{setting}*\x10instanceSettings2\x0finstanceSettingB\a\n" +
"\x05value\"U\n" +
"\x19GetInstanceSettingRequest\x128\n" +
......
......@@ -2420,7 +2420,6 @@ components:
- AI_PROVIDER_TYPE_UNSPECIFIED
- OPENAI
- OPENAI_COMPATIBLE
- ANTHROPIC
- GEMINI
type: string
format: enum
......
......@@ -99,8 +99,7 @@ const (
AIProviderType_AI_PROVIDER_TYPE_UNSPECIFIED AIProviderType = 0
AIProviderType_OPENAI AIProviderType = 1
AIProviderType_OPENAI_COMPATIBLE AIProviderType = 2
AIProviderType_ANTHROPIC AIProviderType = 3
AIProviderType_GEMINI AIProviderType = 4
AIProviderType_GEMINI AIProviderType = 3
)
// Enum value maps for AIProviderType.
......@@ -109,15 +108,13 @@ var (
0: "AI_PROVIDER_TYPE_UNSPECIFIED",
1: "OPENAI",
2: "OPENAI_COMPATIBLE",
3: "ANTHROPIC",
4: "GEMINI",
3: "GEMINI",
}
AIProviderType_value = map[string]int32{
"AI_PROVIDER_TYPE_UNSPECIFIED": 0,
"OPENAI": 1,
"OPENAI_COMPATIBLE": 2,
"ANTHROPIC": 3,
"GEMINI": 4,
"GEMINI": 3,
}
)
......@@ -1324,15 +1321,14 @@ const file_store_instance_setting_proto_rawDesc = "" +
"\fMEMO_RELATED\x10\x04\x12\b\n" +
"\x04TAGS\x10\x05\x12\x10\n" +
"\fNOTIFICATION\x10\x06\x12\x06\n" +
"\x02AI\x10\a*p\n" +
"\x02AI\x10\a*a\n" +
"\x0eAIProviderType\x12 \n" +
"\x1cAI_PROVIDER_TYPE_UNSPECIFIED\x10\x00\x12\n" +
"\n" +
"\x06OPENAI\x10\x01\x12\x15\n" +
"\x11OPENAI_COMPATIBLE\x10\x02\x12\r\n" +
"\tANTHROPIC\x10\x03\x12\n" +
"\x11OPENAI_COMPATIBLE\x10\x02\x12\n" +
"\n" +
"\x06GEMINI\x10\x04B\x9f\x01\n" +
"\x06GEMINI\x10\x03B\x9f\x01\n" +
"\x0fcom.memos.storeB\x14InstanceSettingProtoP\x01Z)github.com/usememos/memos/proto/gen/store\xa2\x02\x03MSX\xaa\x02\vMemos.Store\xca\x02\vMemos\\Store\xe2\x02\x17Memos\\Store\\GPBMetadata\xea\x02\fMemos::Storeb\x06proto3"
var (
......
......@@ -166,6 +166,5 @@ enum AIProviderType {
AI_PROVIDER_TYPE_UNSPECIFIED = 0;
OPENAI = 1;
OPENAI_COMPATIBLE = 2;
ANTHROPIC = 3;
GEMINI = 4;
GEMINI = 3;
}
......@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc/status"
"github.com/usememos/memos/internal/ai"
"github.com/usememos/memos/internal/ai/gemini"
"github.com/usememos/memos/internal/ai/openai"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
......@@ -25,16 +26,22 @@ const (
)
var supportedTranscriptionContentTypes = map[string]bool{
"audio/mpeg": true,
"audio/mp4": true,
"audio/mpga": true,
"audio/wav": true,
"audio/x-wav": true,
"audio/webm": true,
"audio/x-m4a": true,
"video/mp4": true,
"video/mpeg": true,
"video/webm": true,
"audio/aac": true,
"audio/aiff": true,
"audio/flac": true,
"audio/mpeg": true,
"audio/mp3": true,
"audio/mp4": true,
"audio/mpga": true,
"audio/ogg": true,
"audio/wav": true,
"audio/x-wav": true,
"audio/x-flac": true,
"audio/x-m4a": true,
"audio/webm": true,
"video/mp4": true,
"video/mpeg": true,
"video/webm": true,
}
// Transcribe transcribes an audio file using an instance AI provider.
......@@ -161,8 +168,6 @@ func convertAIProviderTypeFromStore(providerType storepb.AIProviderType) ai.Prov
return ai.ProviderOpenAI
case storepb.AIProviderType_OPENAI_COMPATIBLE:
return ai.ProviderOpenAICompatible
case storepb.AIProviderType_ANTHROPIC:
return ai.ProviderAnthropic
case storepb.AIProviderType_GEMINI:
return ai.ProviderGemini
default:
......@@ -174,6 +179,8 @@ func newAITranscriber(provider ai.ProviderConfig) (ai.Transcriber, error) {
switch provider.Type {
case ai.ProviderOpenAI, ai.ProviderOpenAICompatible:
return openai.NewTranscriber(provider)
case ai.ProviderGemini:
return gemini.NewTranscriber(provider)
default:
return nil, errors.Wrapf(ai.ErrCapabilityUnsupported, "provider type %q", provider.Type)
}
......
......@@ -97,6 +97,63 @@ func TestTranscribe(t *testing.T) {
require.Equal(t, "transcribed text", resp.Text)
})
t.Run("transcribes audio file with Gemini provider", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "gemini-user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
geminiServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/v1beta/models/gemini-2.5-flash:generateContent", r.URL.Path)
require.Equal(t, "gemini-key", r.Header.Get("x-goog-api-key"))
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"candidates": []map[string]any{
{
"content": map[string]any{
"parts": []map[string]string{{"text": "gemini transcript"}},
},
},
},
}))
}))
defer geminiServer.Close()
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_AI,
Value: &storepb.InstanceSetting_AiSetting{
AiSetting: &storepb.InstanceAISetting{
Providers: []*storepb.AIProviderConfig{
{
Id: "gemini-main",
Title: "Gemini",
Type: storepb.AIProviderType_GEMINI,
Endpoint: geminiServer.URL + "/v1beta",
ApiKey: "gemini-key",
Models: []string{"gemini-2.5-flash"},
DefaultModel: "gemini-2.5-flash",
},
},
},
},
})
require.NoError(t, err)
resp, err := ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{
ProviderId: "gemini-main",
Config: &v1pb.TranscriptionConfig{},
Audio: &v1pb.TranscriptionAudio{
Source: &v1pb.TranscriptionAudio_Content{Content: []byte("mp3 bytes")},
Filename: "voice.mp3",
ContentType: "audio/mp3",
},
})
require.NoError(t, err)
require.Equal(t, "gemini transcript", resp.Text)
})
t.Run("rejects unconfigured model", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
......
......@@ -41,7 +41,6 @@ type LocalAIProvider = {
const providerTypeOptions = [
InstanceSetting_AIProviderType.OPENAI,
InstanceSetting_AIProviderType.OPENAI_COMPATIBLE,
InstanceSetting_AIProviderType.ANTHROPIC,
InstanceSetting_AIProviderType.GEMINI,
];
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment