Unverified commit a7fd1dac, authored by boojack and committed by GitHub

refactor(ai): use official provider SDKs (#5845)

parent f394e946
...@@ -19,6 +19,7 @@ require ( ...@@ -19,6 +19,7 @@ require (
github.com/lib/pq v1.11.2 github.com/lib/pq v1.11.2
github.com/lithammer/shortuuid/v4 v4.2.0 github.com/lithammer/shortuuid/v4 v4.2.0
github.com/mark3labs/mcp-go v0.45.0 github.com/mark3labs/mcp-go v0.45.0
github.com/openai/openai-go/v3 v3.31.0
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/spf13/cobra v1.10.2 github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0
...@@ -32,6 +33,7 @@ require ( ...@@ -32,6 +33,7 @@ require (
golang.org/x/net v0.52.0 golang.org/x/net v0.52.0
golang.org/x/oauth2 v0.36.0 golang.org/x/oauth2 v0.36.0
golang.org/x/sync v0.20.0 golang.org/x/sync v0.20.0
google.golang.org/genai v1.54.0
google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5 google.golang.org/genproto v0.0.0-20260316180232-0b37fe3546d5
google.golang.org/genproto/googleapis/api v0.0.0-20260316172706-e463d84ca32d google.golang.org/genproto/googleapis/api v0.0.0-20260316172706-e463d84ca32d
google.golang.org/grpc v1.79.2 google.golang.org/grpc v1.79.2
...@@ -40,6 +42,9 @@ require ( ...@@ -40,6 +42,9 @@ require (
require ( require (
cel.dev/expr v0.25.1 // indirect cel.dev/expr v0.25.1 // indirect
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/auth v0.9.3 // indirect
cloud.google.com/go/compute/metadata v0.9.0 // indirect
dario.cat/mergo v1.0.2 // indirect dario.cat/mergo v1.0.2 // indirect
filippo.io/edwards25519 v1.1.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect
...@@ -66,6 +71,11 @@ require ( ...@@ -66,6 +71,11 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-ole/go-ole v1.2.6 // indirect
github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect
github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/s2a-go v0.1.8 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.4 // indirect
github.com/gorilla/websocket v1.5.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/invopop/jsonschema v0.13.0 // indirect github.com/invopop/jsonschema v0.13.0 // indirect
github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/compress v1.18.2 // indirect
...@@ -94,11 +104,16 @@ require ( ...@@ -94,11 +104,16 @@ require (
github.com/spf13/cast v1.10.0 // indirect github.com/spf13/cast v1.10.0 // indirect
github.com/spf13/pflag v1.0.10 // indirect github.com/spf13/pflag v1.0.10 // indirect
github.com/subosito/gotenv v1.6.0 // indirect github.com/subosito/gotenv v1.6.0 // indirect
github.com/tidwall/gjson v1.18.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/tklauser/go-sysconf v0.3.16 // indirect github.com/tklauser/go-sysconf v0.3.16 // indirect
github.com/tklauser/numcpus v0.11.0 // indirect github.com/tklauser/numcpus v0.11.0 // indirect
github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
go.opentelemetry.io/otel v1.41.0 // indirect go.opentelemetry.io/otel v1.41.0 // indirect
......
This diff is collapsed.
package ai
import (
"net/http"
"net/url"
"strings"
"time"
"github.com/pkg/errors"
)
// defaultHTTPTimeout is the timeout configured on the HTTP client that
// NewTranscriber creates before any TranscriberOption is applied.
const defaultHTTPTimeout = 2 * time.Minute
// transcriberOptions holds the settings that TranscriberOption values modify
// and that the provider-specific constructors receive.
type transcriberOptions struct {
	// httpClient is the HTTP client passed on to the provider transcriber.
	httpClient *http.Client
}
// TranscriberOption configures a transcriber. Each option mutates the
// transcriberOptions value it is given.
type TranscriberOption func(*transcriberOptions)
// WithHTTPClient returns an option that makes the transcriber use client for
// its HTTP traffic. A nil client is ignored, so the previously configured
// client stays in place.
func WithHTTPClient(client *http.Client) TranscriberOption {
	return func(target *transcriberOptions) {
		if client == nil {
			return
		}
		target.httpClient = client
	}
}
// NewTranscriber creates a transcriber for the provider described by config.
//
// The transcriber starts with an HTTP client whose timeout is
// defaultHTTPTimeout; options are then applied in order (nil options are
// skipped). config.Type selects the implementation. An unknown type yields an
// error wrapping ErrCapabilityUnsupported.
//
// On failure the returned Transcriber is the nil interface value. The
// provider constructors return concrete pointer types, so their results are
// checked here before being converted to the interface; forwarding them
// directly would wrap a nil pointer in a non-nil interface.
func NewTranscriber(config ProviderConfig, options ...TranscriberOption) (Transcriber, error) {
	resolved := transcriberOptions{
		httpClient: &http.Client{Timeout: defaultHTTPTimeout},
	}
	for _, applyOption := range options {
		if applyOption == nil {
			continue
		}
		applyOption(&resolved)
	}

	switch config.Type {
	case ProviderOpenAI:
		transcriber, err := newOpenAITranscriber(config, resolved)
		if err != nil {
			return nil, err
		}
		return transcriber, nil
	case ProviderGemini:
		transcriber, err := newGeminiTranscriber(config, resolved)
		if err != nil {
			return nil, err
		}
		return transcriber, nil
	default:
		return nil, errors.Wrapf(ErrCapabilityUnsupported, "provider type %q", config.Type)
	}
}
// normalizeEndpoint returns the base URL to use for a provider.
//
// endpoint is trimmed of surrounding whitespace; when it is empty,
// defaultEndpoint is used instead. The value must parse as a request URI and
// must carry both a scheme and a host. url.ParseRequestURI alone also accepts
// bare paths ("/v1") and scheme-less host:port strings ("localhost:8080/v1",
// parsed as scheme "localhost" with an opaque remainder), neither of which is
// a usable base URL. Trailing slashes are removed from the result.
//
// providerName is only used to label the returned error.
func normalizeEndpoint(endpoint string, defaultEndpoint string, providerName string) (string, error) {
	endpoint = strings.TrimSpace(endpoint)
	if endpoint == "" {
		endpoint = defaultEndpoint
	}
	parsed, err := url.ParseRequestURI(endpoint)
	if err != nil {
		return "", errors.Wrapf(err, "invalid %s endpoint", providerName)
	}
	if parsed.Scheme == "" || parsed.Host == "" {
		return "", errors.Errorf("invalid %s endpoint %q: scheme and host are required", providerName, endpoint)
	}
	return strings.TrimRight(endpoint, "/"), nil
}
// requireAPIKey reports an error when apiKey is the empty string.
// providerName is only used to label the error message.
func requireAPIKey(apiKey string, providerName string) error {
	if apiKey != "" {
		return nil
	}
	return errors.Errorf("%s API key is required", providerName)
}
package gemini package ai
import ( import (
"bytes"
"context" "context"
"encoding/base64"
"encoding/json"
"io" "io"
"mime" "mime"
"net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
"google.golang.org/genai"
"github.com/usememos/memos/internal/ai"
) )
const ( const (
transcriptionInstruction = `Transcribe the audio accurately. Return only the transcript text. Do not summarize, explain, or add content that is not spoken.` defaultGeminiEndpoint = "https://generativelanguage.googleapis.com/v1beta"
maxInlineAudioSizeBytes = 14 * 1024 * 1024 geminiTranscriptionPrompt = `Transcribe the audio accurately. Return only the transcript text. Do not summarize, explain, or add content that is not spoken.`
maxGeminiInlineAudioSize = 14 * 1024 * 1024
defaultGeminiAPIVersion = "v1beta"
geminiProviderDisplayName = "Gemini"
geminiDefaultTemperature = float32(0)
) )
var supportedContentTypes = map[string]string{ var geminiSupportedContentTypes = map[string]string{
"audio/wav": "audio/wav", "audio/wav": "audio/wav",
"audio/x-wav": "audio/wav", "audio/x-wav": "audio/wav",
"audio/mp3": "audio/mp3", "audio/mp3": "audio/mp3",
...@@ -33,45 +32,45 @@ var supportedContentTypes = map[string]string{ ...@@ -33,45 +32,45 @@ var supportedContentTypes = map[string]string{
"audio/x-flac": "audio/flac", "audio/x-flac": "audio/flac",
} }
type generateContentRequest struct { type geminiTranscriber struct {
Contents []content `json:"contents"` client *genai.Client
GenerationConfig map[string]json.Number `json:"generationConfig,omitempty"`
}
type content struct {
Role string `json:"role,omitempty"`
Parts []part `json:"parts"`
}
type part struct {
Text string `json:"text,omitempty"`
InlineData *inlineData `json:"inlineData,omitempty"`
}
type inlineData struct {
MIMEType string `json:"mimeType"`
Data string `json:"data"`
} }
type generateContentResponse struct { func newGeminiTranscriber(config ProviderConfig, options transcriberOptions) (*geminiTranscriber, error) {
Candidates []struct { endpoint, err := normalizeEndpoint(config.Endpoint, defaultGeminiEndpoint, geminiProviderDisplayName)
Content struct { if err != nil {
Parts []struct { return nil, err
Text string `json:"text"` }
} `json:"parts"` if err := requireAPIKey(config.APIKey, geminiProviderDisplayName); err != nil {
} `json:"content"` return nil, err
} `json:"candidates"` }
} baseURL, apiVersion, err := normalizeGeminiEndpoint(endpoint)
if err != nil {
return nil, err
}
httpOptions := genai.HTTPOptions{
BaseURL: baseURL,
APIVersion: apiVersion,
}
if options.httpClient.Timeout > 0 {
timeout := options.httpClient.Timeout
httpOptions.Timeout = &timeout
}
type errorResponse struct { client, err := genai.NewClient(context.Background(), &genai.ClientConfig{
Error struct { APIKey: config.APIKey,
Message string `json:"message"` Backend: genai.BackendGeminiAPI,
Status string `json:"status"` HTTPClient: options.httpClient,
} `json:"error"` HTTPOptions: httpOptions,
})
if err != nil {
return nil, errors.Wrap(err, "failed to create Gemini client")
}
return &geminiTranscriber{client: client}, nil
} }
// Transcribe transcribes audio with Gemini generateContent. // Transcribe transcribes audio with Gemini generateContent.
func (t *Transcriber) Transcribe(ctx context.Context, request ai.TranscribeRequest) (*ai.TranscribeResponse, error) { func (t *geminiTranscriber) Transcribe(ctx context.Context, request TranscribeRequest) (*TranscribeResponse, error) {
if strings.TrimSpace(request.Model) == "" { if strings.TrimSpace(request.Model) == "" {
return nil, errors.New("model is required") return nil, errors.New("model is required")
} }
...@@ -85,85 +84,68 @@ func (t *Transcriber) Transcribe(ctx context.Context, request ai.TranscribeReque ...@@ -85,85 +84,68 @@ func (t *Transcriber) Transcribe(ctx context.Context, request ai.TranscribeReque
if len(audio) == 0 { if len(audio) == 0 {
return nil, errors.New("audio is required") return nil, errors.New("audio is required")
} }
if len(audio) > maxInlineAudioSizeBytes { if len(audio) > maxGeminiInlineAudioSize {
return nil, errors.Errorf("audio is too large for Gemini inline transcription; maximum size is %d bytes", maxInlineAudioSizeBytes) return nil, errors.Errorf("audio is too large for Gemini inline transcription; maximum size is %d bytes", maxGeminiInlineAudioSize)
} }
contentType, err := normalizeContentType(request.ContentType) contentType, err := normalizeGeminiContentType(request.ContentType)
if err != nil { if err != nil {
return nil, err return nil, err
} }
prompt := buildTranscriptionPrompt(request.Prompt, request.Language) prompt := buildGeminiTranscriptionPrompt(request.Prompt, request.Language)
body, err := json.Marshal(generateContentRequest{ temperature := geminiDefaultTemperature
Contents: []content{ response, err := t.client.Models.GenerateContent(ctx, normalizeGeminiModelName(request.Model), []*genai.Content{
{ genai.NewContentFromParts([]*genai.Part{
Role: "user", genai.NewPartFromBytes(audio, contentType),
Parts: []part{ genai.NewPartFromText(prompt),
{InlineData: &inlineData{ }, genai.RoleUser),
MIMEType: contentType, }, &genai.GenerateContentConfig{
Data: base64.StdEncoding.EncodeToString(audio), Temperature: &temperature,
}},
{Text: prompt},
},
},
},
GenerationConfig: map[string]json.Number{
"temperature": json.Number("0"),
},
}) })
if err != nil {
return nil, errors.Wrap(err, "failed to marshal Gemini transcription request")
}
httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, t.endpoint+"/models/"+url.PathEscape(normalizeModelName(request.Model))+":generateContent", bytes.NewReader(body))
if err != nil {
return nil, errors.Wrap(err, "failed to create Gemini transcription request")
}
httpRequest.Header.Set("Content-Type", "application/json")
httpRequest.Header.Set("x-goog-api-key", t.apiKey)
httpResponse, err := t.httpClient.Do(httpRequest)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to send Gemini transcription request") return nil, errors.Wrap(err, "failed to send Gemini transcription request")
} }
defer httpResponse.Body.Close() text := strings.TrimSpace(response.Text())
responseBody, err := io.ReadAll(httpResponse.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to read Gemini transcription response")
}
if httpResponse.StatusCode < http.StatusOK || httpResponse.StatusCode >= http.StatusMultipleChoices {
return nil, errors.Errorf("Gemini transcription request failed with status %d: %s", httpResponse.StatusCode, extractErrorMessage(responseBody))
}
var response generateContentResponse
if err := json.Unmarshal(responseBody, &response); err != nil {
return nil, errors.Wrap(err, "failed to unmarshal Gemini transcription response")
}
text := extractText(response)
if text == "" { if text == "" {
return nil, errors.New("Gemini transcription response did not include text") return nil, errors.New("Gemini transcription response did not include text")
} }
return &ai.TranscribeResponse{ return &TranscribeResponse{
Text: text, Text: text,
}, nil }, nil
} }
func normalizeContentType(contentType string) (string, error) { func normalizeGeminiEndpoint(endpoint string) (string, string, error) {
parsed, err := url.Parse(endpoint)
if err != nil {
return "", "", errors.Wrap(err, "invalid Gemini endpoint")
}
path := strings.TrimRight(parsed.Path, "/")
apiVersion := defaultGeminiAPIVersion
for _, supportedVersion := range []string{"v1alpha", "v1beta", "v1"} {
if path == "/"+supportedVersion || strings.HasSuffix(path, "/"+supportedVersion) {
apiVersion = supportedVersion
parsed.Path = strings.TrimSuffix(path, "/"+supportedVersion)
break
}
}
return strings.TrimRight(parsed.String(), "/"), apiVersion, nil
}
func normalizeGeminiContentType(contentType string) (string, error) {
mediaType, _, err := mime.ParseMediaType(strings.TrimSpace(contentType)) mediaType, _, err := mime.ParseMediaType(strings.TrimSpace(contentType))
if err != nil { if err != nil {
return "", errors.Wrap(err, "invalid audio content type") return "", errors.Wrap(err, "invalid audio content type")
} }
mediaType = strings.ToLower(mediaType) mediaType = strings.ToLower(mediaType)
normalized, ok := supportedContentTypes[mediaType] normalized, ok := geminiSupportedContentTypes[mediaType]
if !ok { if !ok {
return "", errors.Errorf("audio content type %q is not supported by Gemini", mediaType) return "", errors.Errorf("audio content type %q is not supported by Gemini", mediaType)
} }
return normalized, nil return normalized, nil
} }
func buildTranscriptionPrompt(prompt string, language string) string { func buildGeminiTranscriptionPrompt(prompt string, language string) string {
parts := []string{transcriptionInstruction} parts := []string{geminiTranscriptionPrompt}
language = strings.TrimSpace(language) language = strings.TrimSpace(language)
if language != "" { if language != "" {
parts = append(parts, "The input language is "+language+".") parts = append(parts, "The input language is "+language+".")
...@@ -175,27 +157,6 @@ func buildTranscriptionPrompt(prompt string, language string) string { ...@@ -175,27 +157,6 @@ func buildTranscriptionPrompt(prompt string, language string) string {
return strings.Join(parts, "\n\n") return strings.Join(parts, "\n\n")
} }
func normalizeModelName(model string) string { func normalizeGeminiModelName(model string) string {
return strings.TrimPrefix(strings.TrimSpace(model), "models/") return strings.TrimPrefix(strings.TrimSpace(model), "models/")
} }
// extractText concatenates the text of every part of every candidate in
// response. Each part is trimmed of surrounding whitespace, empty parts are
// skipped, and the remaining pieces are separated by a newline.
func extractText(response generateContentResponse) string {
	var joined strings.Builder
	for i := range response.Candidates {
		parts := response.Candidates[i].Content.Parts
		for j := range parts {
			piece := strings.TrimSpace(parts[j].Text)
			if piece == "" {
				continue
			}
			if joined.Len() > 0 {
				joined.WriteString("\n")
			}
			joined.WriteString(piece)
		}
	}
	return joined.String()
}
// extractErrorMessage returns the error.message field of a JSON error body.
// When the body is not valid JSON for errorResponse, or the message is empty,
// the raw body is returned as a string instead.
func extractErrorMessage(responseBody []byte) string {
	var parsed errorResponse
	if err := json.Unmarshal(responseBody, &parsed); err != nil {
		return string(responseBody)
	}
	if message := parsed.Error.Message; message != "" {
		return message
	}
	return string(responseBody)
}
package gemini
import (
"net/http"
"net/url"
"strings"
"time"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/ai"
)
// defaultEndpoint is the Gemini API base URL used by NewTranscriber when the
// provider config does not specify an endpoint.
const defaultEndpoint = "https://generativelanguage.googleapis.com/v1beta"
// Transcriber transcribes audio with Gemini audio understanding.
type Transcriber struct {
	// endpoint is the API base URL; NewTranscriber stores it without trailing slashes.
	endpoint string
	// apiKey is the Gemini API key taken from the provider config.
	apiKey string
	// httpClient is the client used for requests; WithHTTPClient can replace it.
	httpClient *http.Client
}
// NewTranscriber creates a new Gemini transcriber.
//
// An empty (or whitespace-only) config.Endpoint falls back to defaultEndpoint.
// The endpoint must parse as a request URI and config.APIKey must be
// non-empty. The transcriber starts with an HTTP client that has a two-minute
// timeout; options are applied afterwards, in order.
func NewTranscriber(config ai.ProviderConfig, options ...Option) (*Transcriber, error) {
	endpoint := defaultEndpoint
	if trimmed := strings.TrimSpace(config.Endpoint); trimmed != "" {
		endpoint = trimmed
	}
	if _, err := url.ParseRequestURI(endpoint); err != nil {
		return nil, errors.Wrap(err, "invalid Gemini endpoint")
	}
	if config.APIKey == "" {
		return nil, errors.New("Gemini API key is required")
	}

	client := &http.Client{Timeout: 2 * time.Minute}
	transcriber := &Transcriber{
		endpoint:   strings.TrimRight(endpoint, "/"),
		apiKey:     config.APIKey,
		httpClient: client,
	}
	for i := range options {
		options[i](transcriber)
	}
	return transcriber, nil
}
// Option configures a Transcriber. NewTranscriber calls each option with the
// transcriber it has just built.
type Option func(*Transcriber)
// WithHTTPClient returns an option that replaces the transcriber's HTTP
// client with client. A nil client leaves the existing client untouched.
func WithHTTPClient(client *http.Client) Option {
	return func(target *Transcriber) {
		if client == nil {
			return
		}
		target.httpClient = client
	}
}
package gemini package ai
import ( import (
"context" "context"
...@@ -11,11 +11,9 @@ import ( ...@@ -11,11 +11,9 @@ import (
"time" "time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/usememos/memos/internal/ai"
) )
func TestTranscribe(t *testing.T) { func TestGeminiTranscribe(t *testing.T) {
t.Parallel() t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
...@@ -61,7 +59,8 @@ func TestTranscribe(t *testing.T) { ...@@ -61,7 +59,8 @@ func TestTranscribe(t *testing.T) {
})) }))
defer server.Close() defer server.Close()
transcriber, err := NewTranscriber(ai.ProviderConfig{ transcriber, err := NewTranscriber(ProviderConfig{
Type: ProviderGemini,
Endpoint: server.URL + "/v1beta", Endpoint: server.URL + "/v1beta",
APIKey: "test-key", APIKey: "test-key",
}) })
...@@ -69,7 +68,7 @@ func TestTranscribe(t *testing.T) { ...@@ -69,7 +68,7 @@ func TestTranscribe(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
response, err := transcriber.Transcribe(ctx, ai.TranscribeRequest{ response, err := transcriber.Transcribe(ctx, TranscribeRequest{
Model: "models/gemini-2.5-flash", Model: "models/gemini-2.5-flash",
ContentType: "audio/mpeg", ContentType: "audio/mpeg",
Audio: strings.NewReader("audio bytes"), Audio: strings.NewReader("audio bytes"),
...@@ -80,16 +79,17 @@ func TestTranscribe(t *testing.T) { ...@@ -80,16 +79,17 @@ func TestTranscribe(t *testing.T) {
require.Equal(t, "hello from gemini", response.Text) require.Equal(t, "hello from gemini", response.Text)
} }
func TestTranscribeRejectsUnsupportedContentType(t *testing.T) { func TestGeminiTranscribeRejectsUnsupportedContentType(t *testing.T) {
t.Parallel() t.Parallel()
transcriber, err := NewTranscriber(ai.ProviderConfig{ transcriber, err := NewTranscriber(ProviderConfig{
Type: ProviderGemini,
Endpoint: "https://example.com/v1beta", Endpoint: "https://example.com/v1beta",
APIKey: "test-key", APIKey: "test-key",
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = transcriber.Transcribe(context.Background(), ai.TranscribeRequest{ _, err = transcriber.Transcribe(context.Background(), TranscribeRequest{
Model: "gemini-2.5-flash", Model: "gemini-2.5-flash",
ContentType: "video/mp4", ContentType: "video/mp4",
Audio: strings.NewReader("video bytes"), Audio: strings.NewReader("video bytes"),
......
package ai
import (
"context"
"mime"
"strings"
openaisdk "github.com/openai/openai-go/v3"
openaioption "github.com/openai/openai-go/v3/option"
"github.com/pkg/errors"
)
// defaultOpenAIEndpoint is the base URL passed to normalizeEndpoint as the
// fallback when the provider config has no endpoint.
const defaultOpenAIEndpoint = "https://api.openai.com/v1"
// openAITranscriber transcribes audio through the OpenAI Go SDK client.
type openAITranscriber struct {
	// client is the SDK client configured by newOpenAITranscriber.
	client openaisdk.Client
}
// newOpenAITranscriber builds an OpenAI transcriber from config.
//
// The endpoint is normalized (falling back to defaultOpenAIEndpoint) and the
// API key must be present. The SDK client is configured with that key, the
// normalized base URL, and the HTTP client carried in options.
func newOpenAITranscriber(config ProviderConfig, options transcriberOptions) (*openAITranscriber, error) {
	const providerName = "OpenAI"

	baseURL, err := normalizeEndpoint(config.Endpoint, defaultOpenAIEndpoint, providerName)
	if err != nil {
		return nil, err
	}
	if err = requireAPIKey(config.APIKey, providerName); err != nil {
		return nil, err
	}

	client := openaisdk.NewClient(
		openaioption.WithAPIKey(config.APIKey),
		openaioption.WithBaseURL(baseURL),
		openaioption.WithHTTPClient(options.httpClient),
	)
	return &openAITranscriber{client: client}, nil
}
// Transcribe transcribes audio with the OpenAI /audio/transcriptions endpoint.
//
// request.Model and request.Audio are required. The model name is trimmed of
// surrounding whitespace before it is sent, so the value that passes
// validation is the value the API receives. Prompt and language are forwarded
// only when non-empty. The response is requested in JSON format and its text,
// language and duration are copied into the result.
//
// Errors from file-metadata normalization are returned as-is; SDK errors are
// wrapped with a description of the failed request.
func (t *openAITranscriber) Transcribe(ctx context.Context, request TranscribeRequest) (*TranscribeResponse, error) {
	model := strings.TrimSpace(request.Model)
	if model == "" {
		return nil, errors.New("model is required")
	}
	if request.Audio == nil {
		return nil, errors.New("audio is required")
	}

	filename, contentType, err := normalizeOpenAIAudioFileMetadata(request)
	if err != nil {
		return nil, err
	}

	params := openaisdk.AudioTranscriptionNewParams{
		File:           openaisdk.File(request.Audio, filename, contentType),
		Model:          openaisdk.AudioModel(model),
		ResponseFormat: openaisdk.AudioResponseFormatJSON,
	}
	if request.Prompt != "" {
		params.Prompt = openaisdk.String(request.Prompt)
	}
	if request.Language != "" {
		params.Language = openaisdk.String(request.Language)
	}

	response, err := t.client.Audio.Transcriptions.New(ctx, params)
	if err != nil {
		return nil, errors.Wrap(err, "failed to send OpenAI transcription request")
	}
	return &TranscribeResponse{
		Text:     response.Text,
		Language: response.Language,
		Duration: response.Duration,
	}, nil
}
// normalizeOpenAIAudioFileMetadata derives the upload filename and content
// type for request.
//
// A blank filename becomes "audio", and the result is passed through
// sanitizeFilename. A blank content type becomes "application/octet-stream";
// otherwise it is parsed as a media type and reduced to the bare type without
// parameters. An unparsable content type yields an error and empty strings.
func normalizeOpenAIAudioFileMetadata(request TranscribeRequest) (string, string, error) {
	contentType := "application/octet-stream"
	if raw := strings.TrimSpace(request.ContentType); raw != "" {
		mediaType, _, err := mime.ParseMediaType(raw)
		if err != nil {
			return "", "", errors.Wrap(err, "invalid audio content type")
		}
		contentType = mediaType
	}

	filename := strings.TrimSpace(request.Filename)
	if filename == "" {
		filename = "audio"
	}
	return sanitizeFilename(filename), contentType, nil
}
// sanitizeFilename replaces every carriage return and line feed in filename
// with an underscore. If the result is empty or whitespace-only, "audio" is
// returned instead; otherwise the result is returned untrimmed.
func sanitizeFilename(filename string) string {
	cleaned := strings.ReplaceAll(filename, "\r", "_")
	cleaned = strings.ReplaceAll(cleaned, "\n", "_")
	if strings.TrimSpace(cleaned) != "" {
		return cleaned
	}
	return "audio"
}
package openai
import (
"net/http"
"net/url"
"strings"
"time"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/ai"
)
// defaultEndpoint is the OpenAI API base URL used by NewTranscriber when the
// provider config does not specify an endpoint.
const defaultEndpoint = "https://api.openai.com/v1"
// Transcriber transcribes audio with OpenAI-compatible transcription APIs.
type Transcriber struct {
	// endpoint is the API base URL as validated by NewTranscriber.
	endpoint string
	// apiKey is the API key taken from the provider config.
	apiKey string
	// httpClient is the client used for requests; WithHTTPClient can replace it.
	httpClient *http.Client
}
// NewTranscriber creates a new OpenAI-compatible transcriber.
//
// An empty (or whitespace-only) config.Endpoint falls back to defaultEndpoint.
// The endpoint must parse as a request URI and config.APIKey must be
// non-empty. The transcriber starts with an HTTP client that has a two-minute
// timeout; options are applied afterwards, in order.
func NewTranscriber(config ai.ProviderConfig, options ...Option) (*Transcriber, error) {
	endpoint := defaultEndpoint
	if trimmed := strings.TrimSpace(config.Endpoint); trimmed != "" {
		endpoint = trimmed
	}
	if _, err := url.ParseRequestURI(endpoint); err != nil {
		return nil, errors.Wrap(err, "invalid OpenAI endpoint")
	}
	if config.APIKey == "" {
		return nil, errors.New("OpenAI API key is required")
	}

	client := &http.Client{Timeout: 2 * time.Minute}
	transcriber := &Transcriber{
		endpoint:   endpoint,
		apiKey:     config.APIKey,
		httpClient: client,
	}
	for i := range options {
		options[i](transcriber)
	}
	return transcriber, nil
}
// Option configures a Transcriber. NewTranscriber calls each option with the
// transcriber it has just built.
type Option func(*Transcriber)
// WithHTTPClient returns an option that replaces the transcriber's HTTP
// client with client. A nil client leaves the existing client untouched.
func WithHTTPClient(client *http.Client) Option {
	return func(target *Transcriber) {
		if client == nil {
			return
		}
		target.httpClient = client
	}
}
package openai
import (
"bytes"
"context"
"encoding/json"
"io"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/ai"
)
// transcriptionResponse is the JSON body decoded from a successful
// transcription response.
type transcriptionResponse struct {
	Text     string  `json:"text"`
	Language string  `json:"language"`
	Duration float64 `json:"duration"`
}
// errorResponse is the JSON error envelope that extractErrorMessage decodes
// to obtain a readable message from a failed response.
type errorResponse struct {
	Error struct {
		Message string `json:"message"`
		Type    string `json:"type"`
		Code    string `json:"code"`
	} `json:"error"`
}
// Transcribe transcribes audio with the /audio/transcriptions endpoint.
//
// The request is sent as multipart form data: the audio file part first, then
// the model, the response format ("json"), and the prompt and language when
// they are non-empty. Any status outside 2xx is reported as an error carrying
// the message extracted from the response body.
func (t *Transcriber) Transcribe(ctx context.Context, request ai.TranscribeRequest) (*ai.TranscribeResponse, error) {
	switch {
	case strings.TrimSpace(request.Model) == "":
		return nil, errors.New("model is required")
	case request.Audio == nil:
		return nil, errors.New("audio is required")
	}

	var payload bytes.Buffer
	form := multipart.NewWriter(&payload)
	if err := writeAudioFilePart(form, request); err != nil {
		return nil, err
	}
	// Fields are written in this order; empty values are omitted (model and
	// response_format are never empty at this point).
	textFields := []struct {
		name, value, failure string
	}{
		{"model", request.Model, "failed to write model field"},
		{"response_format", "json", "failed to write response format field"},
		{"prompt", request.Prompt, "failed to write prompt field"},
		{"language", request.Language, "failed to write language field"},
	}
	for _, field := range textFields {
		if field.value == "" {
			continue
		}
		if err := form.WriteField(field.name, field.value); err != nil {
			return nil, errors.Wrap(err, field.failure)
		}
	}
	if err := form.Close(); err != nil {
		return nil, errors.Wrap(err, "failed to close multipart writer")
	}

	target := strings.TrimRight(t.endpoint, "/") + "/audio/transcriptions"
	outgoing, err := http.NewRequestWithContext(ctx, http.MethodPost, target, &payload)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create transcription request")
	}
	outgoing.Header.Set("Authorization", "Bearer "+t.apiKey)
	outgoing.Header.Set("Content-Type", form.FormDataContentType())

	incoming, err := t.httpClient.Do(outgoing)
	if err != nil {
		return nil, errors.Wrap(err, "failed to send transcription request")
	}
	defer incoming.Body.Close()

	raw, err := io.ReadAll(incoming.Body)
	if err != nil {
		return nil, errors.Wrap(err, "failed to read transcription response")
	}
	if code := incoming.StatusCode; code < http.StatusOK || code >= http.StatusMultipleChoices {
		return nil, errors.Errorf("transcription request failed with status %d: %s", code, extractErrorMessage(raw))
	}

	var decoded transcriptionResponse
	if err := json.Unmarshal(raw, &decoded); err != nil {
		return nil, errors.Wrap(err, "failed to unmarshal transcription response")
	}
	result := &ai.TranscribeResponse{
		Text:     decoded.Text,
		Language: decoded.Language,
		Duration: decoded.Duration,
	}
	return result, nil
}
// writeAudioFilePart adds the audio as the "file" part of the multipart form.
//
// A blank filename becomes "audio" and is passed through sanitizeFilename. A
// blank content type becomes "application/octet-stream"; otherwise it is
// parsed as a media type and reduced to the bare type without parameters. The
// audio reader is then copied into the part.
func writeAudioFilePart(writer *multipart.Writer, request ai.TranscribeRequest) error {
	filename := strings.TrimSpace(request.Filename)
	if filename == "" {
		filename = "audio"
	}

	contentType := "application/octet-stream"
	if raw := strings.TrimSpace(request.ContentType); raw != "" {
		mediaType, _, err := mime.ParseMediaType(raw)
		if err != nil {
			return errors.Wrap(err, "invalid audio content type")
		}
		contentType = mediaType
	}

	disposition := mime.FormatMediaType("form-data", map[string]string{
		"name":     "file",
		"filename": sanitizeFilename(filename),
	})
	header := textproto.MIMEHeader{}
	header.Set("Content-Disposition", disposition)
	header.Set("Content-Type", contentType)

	filePart, err := writer.CreatePart(header)
	if err != nil {
		return errors.Wrap(err, "failed to create audio file part")
	}
	_, err = io.Copy(filePart, request.Audio)
	if err != nil {
		return errors.Wrap(err, "failed to write audio file part")
	}
	return nil
}
// extractErrorMessage returns the error.message field of a JSON error body.
// When the body is not valid JSON for errorResponse, or the message is empty,
// the raw body is returned as a string instead.
func extractErrorMessage(responseBody []byte) string {
	var parsed errorResponse
	if err := json.Unmarshal(responseBody, &parsed); err != nil {
		return string(responseBody)
	}
	if message := parsed.Error.Message; message != "" {
		return message
	}
	return string(responseBody)
}
// sanitizeFilename replaces every carriage return and line feed in filename
// with an underscore. If the result is empty or whitespace-only, "audio" is
// returned instead; otherwise the result is returned untrimmed.
func sanitizeFilename(filename string) string {
	cleaned := strings.ReplaceAll(filename, "\r", "_")
	cleaned = strings.ReplaceAll(cleaned, "\n", "_")
	if strings.TrimSpace(cleaned) != "" {
		return cleaned
	}
	return "audio"
}
package openai package ai
import ( import (
"context" "context"
...@@ -10,11 +10,9 @@ import ( ...@@ -10,11 +10,9 @@ import (
"time" "time"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/usememos/memos/internal/ai"
) )
func TestTranscribe(t *testing.T) { func TestOpenAITranscribe(t *testing.T) {
t.Parallel() t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
...@@ -42,7 +40,8 @@ func TestTranscribe(t *testing.T) { ...@@ -42,7 +40,8 @@ func TestTranscribe(t *testing.T) {
})) }))
defer server.Close() defer server.Close()
transcriber, err := NewTranscriber(ai.ProviderConfig{ transcriber, err := NewTranscriber(ProviderConfig{
Type: ProviderOpenAI,
Endpoint: server.URL, Endpoint: server.URL,
APIKey: "test-key", APIKey: "test-key",
}) })
...@@ -50,7 +49,7 @@ func TestTranscribe(t *testing.T) { ...@@ -50,7 +49,7 @@ func TestTranscribe(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
response, err := transcriber.Transcribe(ctx, ai.TranscribeRequest{ response, err := transcriber.Transcribe(ctx, TranscribeRequest{
Model: "gpt-4o-transcribe", Model: "gpt-4o-transcribe",
Filename: "voice.wav", Filename: "voice.wav",
ContentType: "audio/wav", ContentType: "audio/wav",
......
...@@ -7,13 +7,10 @@ import ( ...@@ -7,13 +7,10 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/pkg/errors"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/usememos/memos/internal/ai" "github.com/usememos/memos/internal/ai"
"github.com/usememos/memos/internal/ai/gemini"
"github.com/usememos/memos/internal/ai/openai"
v1pb "github.com/usememos/memos/proto/gen/api/v1" v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store" storepb "github.com/usememos/memos/proto/gen/store"
) )
...@@ -97,7 +94,7 @@ func (s *APIV1Service) Transcribe(ctx context.Context, request *v1pb.TranscribeR ...@@ -97,7 +94,7 @@ func (s *APIV1Service) Transcribe(ctx context.Context, request *v1pb.TranscribeR
if err != nil { if err != nil {
return nil, err return nil, err
} }
transcriber, err := newAITranscriber(provider) transcriber, err := ai.NewTranscriber(provider)
if err != nil { if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "failed to create AI transcriber: %v", err) return nil, status.Errorf(codes.InvalidArgument, "failed to create AI transcriber: %v", err)
} }
...@@ -165,17 +162,6 @@ func convertAIProviderTypeFromStore(providerType storepb.AIProviderType) ai.Prov ...@@ -165,17 +162,6 @@ func convertAIProviderTypeFromStore(providerType storepb.AIProviderType) ai.Prov
} }
} }
// newAITranscriber builds the transcriber that matches provider.Type.
//
// An unknown provider type yields an error wrapping
// ai.ErrCapabilityUnsupported. On any failure the returned ai.Transcriber is
// the nil interface value: the provider constructors return concrete pointer
// types, so their results are checked before conversion instead of being
// forwarded directly, which would wrap a nil pointer in a non-nil interface.
func newAITranscriber(provider ai.ProviderConfig) (ai.Transcriber, error) {
	switch provider.Type {
	case ai.ProviderOpenAI:
		transcriber, err := openai.NewTranscriber(provider)
		if err != nil {
			return nil, err
		}
		return transcriber, nil
	case ai.ProviderGemini:
		transcriber, err := gemini.NewTranscriber(provider)
		if err != nil {
			return nil, err
		}
		return transcriber, nil
	default:
		return nil, errors.Wrapf(ai.ErrCapabilityUnsupported, "provider type %q", provider.Type)
	}
}
func isSupportedTranscriptionContentType(contentType string) bool { func isSupportedTranscriptionContentType(contentType string) bool {
mediaType, _, err := mime.ParseMediaType(strings.TrimSpace(contentType)) mediaType, _, err := mime.ParseMediaType(strings.TrimSpace(contentType))
if err != nil { if err != nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment