Unverified commit 83ed32f1, authored by memoclaw, committed by GitHub

feat(ai): add instance AI providers and transcription (#5829)

Co-authored-by: memoclaw <265580040+memoclaw@users.noreply.github.com>
parent 40fd700f
package ai
// ProviderType identifies an AI provider implementation.
//
// The string values match the value names of the AIProviderType enum declared
// in the proto definitions (OPENAI, OPENAI_COMPATIBLE, ANTHROPIC, GEMINI).
type ProviderType string
const (
// ProviderOpenAI is OpenAI's hosted API.
ProviderOpenAI ProviderType = "OPENAI"
// ProviderOpenAICompatible is an OpenAI-compatible API endpoint.
ProviderOpenAICompatible ProviderType = "OPENAI_COMPATIBLE"
// ProviderAnthropic is Anthropic's API.
ProviderAnthropic ProviderType = "ANTHROPIC"
// ProviderGemini is Google's Gemini API.
ProviderGemini ProviderType = "GEMINI"
)
// ProviderConfig configures a callable AI provider connection.
type ProviderConfig struct {
// ID is the identifier FindProvider compares against when looking up a provider.
ID string
// Title is presumably a human-readable display name — confirm against callers.
Title string
// Type selects the provider implementation.
Type ProviderType
// Endpoint is the base URL of the provider API. The openai transcriber falls
// back to the public OpenAI endpoint when this is blank.
Endpoint string
// APIKey is the credential sent to the provider. The openai transcriber
// refuses to construct without one.
APIKey string
// Models lists model names associated with this provider.
// NOTE(review): how the list is enforced is not visible here — confirm against callers.
Models []string
// DefaultModel is presumably the model used when a request names none — confirm against callers.
DefaultModel string
}
package ai
import "github.com/pkg/errors"
// Sentinel errors shared by AI provider lookups and provider implementations.
var (
// ErrProviderNotFound indicates that a requested provider ID does not exist.
// FindProvider returns it wrapped with the offending ID, or with a note that
// the ID was empty.
ErrProviderNotFound = errors.New("AI provider not found")
// ErrCapabilityUnsupported indicates that the provider does not support the requested capability.
ErrCapabilityUnsupported = errors.New("AI provider capability unsupported")
)
package openai
import (
"net/http"
"net/url"
"strings"
"time"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/ai"
)
// defaultEndpoint is the base URL NewTranscriber uses when the provider
// configuration leaves Endpoint blank.
const defaultEndpoint = "https://api.openai.com/v1"
// Transcriber transcribes audio with OpenAI-compatible transcription APIs.
type Transcriber struct {
// endpoint is the API base URL; Transcribe appends "/audio/transcriptions" to it.
endpoint string
// apiKey is sent as the Bearer token on every transcription request.
apiKey string
// httpClient performs the requests. NewTranscriber defaults it to a client
// with a two-minute timeout; WithHTTPClient replaces it.
httpClient *http.Client
}
// NewTranscriber creates a new OpenAI-compatible transcriber.
//
// A blank config.Endpoint falls back to the public OpenAI API. The endpoint
// must be an absolute http or https URL with a host, so that a malformed value
// is rejected here instead of surfacing later as a failed request.
// config.APIKey is required; surrounding whitespace is stripped before it is
// stored because it would otherwise produce an invalid Authorization header.
//
// Options are applied in order after the defaults are in place; nil options
// are ignored.
func NewTranscriber(config ai.ProviderConfig, options ...Option) (*Transcriber, error) {
	endpoint := strings.TrimSpace(config.Endpoint)
	if endpoint == "" {
		endpoint = defaultEndpoint
	}
	parsedEndpoint, err := url.ParseRequestURI(endpoint)
	if err != nil {
		return nil, errors.Wrap(err, "invalid OpenAI endpoint")
	}
	// ParseRequestURI also accepts bare paths ("/v1") and scheme-less
	// host:port strings ("localhost:8080/v1" parses with scheme "localhost"),
	// neither of which net/http can send a request to.
	if parsedEndpoint.Scheme != "http" && parsedEndpoint.Scheme != "https" {
		return nil, errors.Errorf("invalid OpenAI endpoint: unsupported scheme %q", parsedEndpoint.Scheme)
	}
	if parsedEndpoint.Host == "" {
		return nil, errors.New("invalid OpenAI endpoint: host is required")
	}

	apiKey := strings.TrimSpace(config.APIKey)
	if apiKey == "" {
		return nil, errors.New("OpenAI API key is required")
	}

	transcriber := &Transcriber{
		endpoint: endpoint,
		apiKey:   apiKey,
		httpClient: &http.Client{
			Timeout: 2 * time.Minute,
		},
	}
	for _, option := range options {
		// A nil Option would panic when invoked; treat it as "no change".
		if option == nil {
			continue
		}
		option(transcriber)
	}
	return transcriber, nil
}
// Option configures a Transcriber. NewTranscriber applies options after the
// endpoint, API key, and default HTTP client have been set, so an option can
// override any of them.
type Option func(*Transcriber)
// WithHTTPClient returns an Option that makes the transcriber send its
// requests through client. A nil client yields an Option that changes nothing,
// leaving whatever HTTP client the transcriber already has.
func WithHTTPClient(client *http.Client) Option {
	if client == nil {
		return func(*Transcriber) {}
	}
	return func(target *Transcriber) {
		target.httpClient = client
	}
}
package openai
import (
"bytes"
"context"
"encoding/json"
"io"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/ai"
)
// transcriptionResponse is the JSON body Transcribe decodes from a successful
// transcription response.
type transcriptionResponse struct {
// Text is the transcribed text.
Text string `json:"text"`
// Language is the language reported by the provider, if any.
Language string `json:"language"`
// Duration is the audio duration reported by the provider.
// NOTE(review): presumably seconds — confirm against the provider API.
Duration float64 `json:"duration"`
}
// errorResponse is the JSON error envelope that extractErrorMessage tries to
// decode from a failed response body. Only Error.Message is read by the code
// in this file; Type and Code are decoded but unused here.
type errorResponse struct {
Error struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code"`
} `json:"error"`
}
// Transcribe transcribes audio with the /audio/transcriptions endpoint.
//
// The audio and options are sent as a multipart/form-data POST authenticated
// with the configured API key as a Bearer token. request.Model and
// request.Audio are required; Prompt and Language are only sent when set. The
// response is always requested in the "json" format.
//
// It returns an error when the request cannot be built or sent, when the
// provider answers with a non-2xx status (the provider's error message is
// included when it can be decoded), or when the response body is not valid
// JSON or is larger than the buffering limit.
func (t *Transcriber) Transcribe(ctx context.Context, request ai.TranscribeRequest) (*ai.TranscribeResponse, error) {
	// Upper bound on how much of the provider response is buffered, so a
	// misbehaving endpoint cannot exhaust memory.
	const maxResponseBodyBytes = 10 << 20

	// Send the same trimmed value that is validated.
	model := strings.TrimSpace(request.Model)
	if model == "" {
		return nil, errors.New("model is required")
	}
	if request.Audio == nil {
		return nil, errors.New("audio is required")
	}

	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)
	if err := writeAudioFilePart(writer, request); err != nil {
		return nil, err
	}
	if err := writer.WriteField("model", model); err != nil {
		return nil, errors.Wrap(err, "failed to write model field")
	}
	if err := writer.WriteField("response_format", "json"); err != nil {
		return nil, errors.Wrap(err, "failed to write response format field")
	}
	if request.Prompt != "" {
		if err := writer.WriteField("prompt", request.Prompt); err != nil {
			return nil, errors.Wrap(err, "failed to write prompt field")
		}
	}
	if request.Language != "" {
		if err := writer.WriteField("language", request.Language); err != nil {
			return nil, errors.Wrap(err, "failed to write language field")
		}
	}
	// Close writes the terminating boundary; the body is incomplete without it.
	if err := writer.Close(); err != nil {
		return nil, errors.Wrap(err, "failed to close multipart writer")
	}

	httpRequest, err := http.NewRequestWithContext(ctx, http.MethodPost, strings.TrimRight(t.endpoint, "/")+"/audio/transcriptions", body)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create transcription request")
	}
	httpRequest.Header.Set("Authorization", "Bearer "+t.apiKey)
	httpRequest.Header.Set("Content-Type", writer.FormDataContentType())

	httpResponse, err := t.httpClient.Do(httpRequest)
	if err != nil {
		return nil, errors.Wrap(err, "failed to send transcription request")
	}
	defer httpResponse.Body.Close()

	// Read one byte past the limit so an oversized body can be told apart from
	// one that is exactly at the limit.
	responseBody, err := io.ReadAll(io.LimitReader(httpResponse.Body, maxResponseBodyBytes+1))
	if err != nil {
		return nil, errors.Wrap(err, "failed to read transcription response")
	}
	truncated := len(responseBody) > maxResponseBodyBytes
	if truncated {
		responseBody = responseBody[:maxResponseBodyBytes]
	}

	// Report the status first: it is the more useful diagnosis even when the
	// error body was cut short.
	if httpResponse.StatusCode < http.StatusOK || httpResponse.StatusCode >= http.StatusMultipleChoices {
		return nil, errors.Errorf("transcription request failed with status %d: %s", httpResponse.StatusCode, extractErrorMessage(responseBody))
	}
	if truncated {
		return nil, errors.Errorf("transcription response exceeds %d bytes", maxResponseBodyBytes)
	}

	var response transcriptionResponse
	if err := json.Unmarshal(responseBody, &response); err != nil {
		return nil, errors.Wrap(err, "failed to unmarshal transcription response")
	}
	return &ai.TranscribeResponse{
		Text:     response.Text,
		Language: response.Language,
		Duration: response.Duration,
	}, nil
}
// writeAudioFilePart appends the audio as the "file" part of the multipart
// body and streams request.Audio into it.
//
// A blank filename becomes "audio" and a blank content type becomes
// "application/octet-stream". A non-blank content type is parsed and reduced
// to its bare media type (parameters are dropped); a value that does not parse
// is reported as an error.
func writeAudioFilePart(writer *multipart.Writer, request ai.TranscribeRequest) error {
	name := strings.TrimSpace(request.Filename)
	if name == "" {
		name = "audio"
	}

	partType := "application/octet-stream"
	if declared := strings.TrimSpace(request.ContentType); declared != "" {
		parsed, _, err := mime.ParseMediaType(declared)
		if err != nil {
			return errors.Wrap(err, "invalid audio content type")
		}
		partType = parsed
	}

	// FormatMediaType takes care of quoting/encoding the filename parameter.
	disposition := mime.FormatMediaType("form-data", map[string]string{
		"name":     "file",
		"filename": sanitizeFilename(name),
	})
	partHeader := textproto.MIMEHeader{}
	partHeader.Set("Content-Disposition", disposition)
	partHeader.Set("Content-Type", partType)

	filePart, err := writer.CreatePart(partHeader)
	if err != nil {
		return errors.Wrap(err, "failed to create audio file part")
	}
	_, err = io.Copy(filePart, request.Audio)
	if err != nil {
		return errors.Wrap(err, "failed to write audio file part")
	}
	return nil
}
// extractErrorMessage returns the provider's error message when responseBody
// decodes as an errorResponse with a non-empty message; otherwise it returns
// the raw body as a string.
func extractErrorMessage(responseBody []byte) string {
	var parsed errorResponse
	if json.Unmarshal(responseBody, &parsed) != nil || parsed.Error.Message == "" {
		return string(responseBody)
	}
	return parsed.Error.Message
}
// sanitizeFilename replaces every carriage return and line feed in filename
// with an underscore. If nothing but whitespace remains, it returns "audio".
func sanitizeFilename(filename string) string {
	cleaned := strings.ReplaceAll(filename, "\r", "_")
	cleaned = strings.ReplaceAll(cleaned, "\n", "_")
	if strings.TrimSpace(cleaned) != "" {
		return cleaned
	}
	return "audio"
}
package openai
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/internal/ai"
)
func TestTranscribe(t *testing.T) {
t.Parallel()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/audio/transcriptions", r.URL.Path)
require.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
require.NoError(t, r.ParseMultipartForm(10<<20))
require.Equal(t, "gpt-4o-transcribe", r.FormValue("model"))
require.Equal(t, "json", r.FormValue("response_format"))
require.Equal(t, "domain words", r.FormValue("prompt"))
require.Equal(t, "en", r.FormValue("language"))
file, header, err := r.FormFile("file")
require.NoError(t, err)
defer file.Close()
require.Equal(t, "voice.wav", header.Filename)
require.Equal(t, "audio/wav", header.Header.Get("Content-Type"))
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]any{
"text": "hello world",
"language": "en",
"duration": 1.5,
}))
}))
defer server.Close()
transcriber, err := NewTranscriber(ai.ProviderConfig{
Endpoint: server.URL,
APIKey: "test-key",
})
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
response, err := transcriber.Transcribe(ctx, ai.TranscribeRequest{
Model: "gpt-4o-transcribe",
Filename: "voice.wav",
ContentType: "audio/wav",
Audio: strings.NewReader("RIFF"),
Prompt: "domain words",
Language: "en",
})
require.NoError(t, err)
require.Equal(t, "hello world", response.Text)
require.Equal(t, "en", response.Language)
require.Equal(t, 1.5, response.Duration)
}
package ai
import "github.com/pkg/errors"
// FindProvider looks up the first provider whose ID equals providerID and
// returns a pointer to a copy of it, so changes made through the result do not
// touch the providers slice.
//
// An empty providerID, or one that matches no provider, yields an error that
// wraps ErrProviderNotFound.
func FindProvider(providers []ProviderConfig, providerID string) (*ProviderConfig, error) {
	if providerID == "" {
		return nil, errors.Wrap(ErrProviderNotFound, "provider ID is required")
	}
	for i := range providers {
		if providers[i].ID != providerID {
			continue
		}
		match := providers[i]
		return &match, nil
	}
	return nil, errors.Wrapf(ErrProviderNotFound, "provider ID %q", providerID)
}
package ai
import (
"context"
"io"
)
// Transcriber transcribes audio into text.
//
// The openai package's Transcriber has a Transcribe method with this signature.
type Transcriber interface {
// Transcribe converts the audio described by request into text, or returns an error.
Transcribe(ctx context.Context, request TranscribeRequest) (*TranscribeResponse, error)
}
// TranscribeRequest contains an audio transcription request.
type TranscribeRequest struct {
// Model is the provider model to use; the openai transcriber rejects a blank value.
Model string
// Filename is the name reported for the uploaded audio; the openai
// transcriber substitutes "audio" when it is blank.
Filename string
// ContentType is the MIME type of the audio; the openai transcriber
// substitutes "application/octet-stream" when it is blank.
ContentType string
// Audio is the audio data to transcribe; the openai transcriber rejects nil.
Audio io.Reader
// Size is presumably the audio length in bytes — confirm against callers.
// The openai transcriber in this change does not read it.
Size int64
// Prompt is optional guidance text; the openai transcriber sends it only when non-empty.
Prompt string
// Language is the optional language of the audio; the openai transcriber
// sends it only when non-empty.
Language string
}
// TranscribeResponse contains an audio transcription response.
type TranscribeResponse struct {
// Text is the transcribed text.
Text string
// Language is the language reported by the provider; it may be empty.
Language string
// Duration is the audio duration reported by the provider.
// NOTE(review): presumably seconds — confirm against the provider API.
Duration float64
}
syntax = "proto3";
package memos.api.v1;
import "google/api/annotations.proto";
import "google/api/client.proto";
import "google/api/field_behavior.proto";
option go_package = "gen/api/v1";
// AIService exposes AI operations that run against the instance's configured AI providers.
service AIService {
// Transcribe transcribes an audio file using an instance AI provider.
rpc Transcribe(TranscribeRequest) returns (TranscribeResponse) {
// Exposed over HTTP as POST /api/v1/ai:transcribe, with the whole request
// message carried in the body.
option (google.api.http) = {
post: "/api/v1/ai:transcribe"
body: "*"
};
option (google.api.method_signature) = "provider_id,config,audio";
}
}
// TranscribeRequest is the request message for AIService.Transcribe.
message TranscribeRequest {
// Required. The instance AI provider ID to use.
string provider_id = 1 [(google.api.field_behavior) = REQUIRED];
// Required. Transcription options.
TranscriptionConfig config = 2 [(google.api.field_behavior) = REQUIRED];
// Required. Audio input.
TranscriptionAudio audio = 3 [(google.api.field_behavior) = REQUIRED];
}
// TranscriptionConfig holds the per-request transcription options.
message TranscriptionConfig {
// Optional. The model to use. If empty, the provider's default model is used.
string model = 1 [(google.api.field_behavior) = OPTIONAL];
// Optional. A prompt to improve transcription quality.
string prompt = 2 [(google.api.field_behavior) = OPTIONAL];
// Optional. The language of the input audio.
string language = 3 [(google.api.field_behavior) = OPTIONAL];
}
// TranscriptionAudio describes the audio to transcribe.
message TranscriptionAudio {
// source is where the audio comes from; at most one member can be set.
oneof source {
// Inline audio bytes.
bytes content = 1 [(google.api.field_behavior) = INPUT_ONLY];
// URI for audio content. Reserved for future use.
string uri = 2;
}
// Optional. The uploaded filename.
string filename = 3 [(google.api.field_behavior) = OPTIONAL];
// Optional. The MIME type of the input audio.
string content_type = 4 [(google.api.field_behavior) = OPTIONAL];
}
// TranscribeResponse is the response message for AIService.Transcribe.
message TranscribeResponse {
// The transcribed text.
string text = 1;
}
......@@ -72,6 +72,7 @@ message InstanceSetting {
MemoRelatedSetting memo_related_setting = 4;
TagsSetting tags_setting = 5;
NotificationSetting notification_setting = 6;
AISetting ai_setting = 7;
}
// Enumeration of instance setting keys.
......@@ -87,6 +88,8 @@ message InstanceSetting {
TAGS = 4;
// NOTIFICATION is the key for notification transport settings.
NOTIFICATION = 5;
// AI is the key for AI provider settings.
AI = 6;
}
// General instance settings configuration.
......@@ -201,6 +204,37 @@ message InstanceSetting {
bool use_ssl = 10;
}
}
// AISetting holds the instance-wide AI provider configuration.
message AISetting {
// providers is the list of AI provider configurations available instance-wide.
repeated AIProviderConfig providers = 1;
}
// AIProviderConfig represents one callable AI provider connection.
message AIProviderConfig {
// id identifies this provider.
// NOTE(review): presumably the value clients pass as TranscribeRequest.provider_id — confirm.
string id = 1;
// title is presumably a human-readable display name — confirm.
string title = 2;
// type is the provider implementation type.
AIProviderType type = 3;
// endpoint is the provider API endpoint.
// NOTE(review): whether it may be left empty is decided server-side — confirm.
string endpoint = 4;
// api_key is write-only and is never returned by GetInstanceSetting.
string api_key = 5 [(google.api.field_behavior) = INPUT_ONLY];
// models lists the model names configured for this provider.
repeated string models = 6;
// default_model is presumably the model used when a request names none — confirm.
string default_model = 7;
// api_key_set indicates whether an API key is stored for this provider.
bool api_key_set = 8 [(google.api.field_behavior) = OUTPUT_ONLY];
// api_key_hint is a masked hint for the stored API key.
string api_key_hint = 9 [(google.api.field_behavior) = OUTPUT_ONLY];
}
// AIProviderType is the provider implementation type.
enum AIProviderType {
// Unset value.
AI_PROVIDER_TYPE_UNSPECIFIED = 0;
// OpenAI's hosted API.
OPENAI = 1;
// An OpenAI-compatible API endpoint.
OPENAI_COMPATIBLE = 2;
// Anthropic's API.
ANTHROPIC = 3;
// Google's Gemini API.
GEMINI = 4;
}
}
// Request message for GetInstanceSetting method.
......
This diff is collapsed.
// Code generated by protoc-gen-grpc-gateway. DO NOT EDIT.
// source: api/v1/ai_service.proto
/*
Package apiv1 is a reverse proxy.
It translates gRPC into RESTful JSON APIs.
*/
package apiv1
import (
"context"
"errors"
"io"
"net/http"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/grpc-ecosystem/grpc-gateway/v2/utilities"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
// Suppress "imported and not used" errors
var (
_ codes.Code
_ io.Reader
_ status.Status
_ = errors.New
_ = runtime.String
_ = utilities.NewDoubleArray
_ = metadata.Join
)
func request_AIService_Transcribe_0(ctx context.Context, marshaler runtime.Marshaler, client AIServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq TranscribeRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
if req.Body != nil {
_, _ = io.Copy(io.Discard, req.Body)
}
msg, err := client.Transcribe(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD))
return msg, metadata, err
}
func local_request_AIService_Transcribe_0(ctx context.Context, marshaler runtime.Marshaler, server AIServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) {
var (
protoReq TranscribeRequest
metadata runtime.ServerMetadata
)
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) {
return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err)
}
msg, err := server.Transcribe(ctx, &protoReq)
return msg, metadata, err
}
// RegisterAIServiceHandlerServer registers the http handlers for service AIService to "mux".
// UnaryRPC :call AIServiceServer directly.
// StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906.
// Note that using this registration option will cause many gRPC library features to stop working. Consider using RegisterAIServiceHandlerFromEndpoint instead.
// GRPC interceptors will not work for this type of registration. To use interceptors, you must use the "runtime.WithMiddlewares" option in the "runtime.NewServeMux" call.
func RegisterAIServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server AIServiceServer) error {
mux.Handle(http.MethodPost, pattern_AIService_Transcribe_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
var stream runtime.ServerTransportStream
ctx = grpc.NewContextWithServerTransportStream(ctx, &stream)
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/memos.api.v1.AIService/Transcribe", runtime.WithHTTPPathPattern("/api/v1/ai:transcribe"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := local_request_AIService_Transcribe_0(annotatedContext, inboundMarshaler, server, req, pathParams)
md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer())
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_AIService_Transcribe_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
return nil
}
// RegisterAIServiceHandlerFromEndpoint is same as RegisterAIServiceHandler but
// automatically dials to "endpoint" and closes the connection when "ctx" gets done.
func RegisterAIServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) {
conn, err := grpc.NewClient(endpoint, opts...)
if err != nil {
return err
}
defer func() {
if err != nil {
if cerr := conn.Close(); cerr != nil {
grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr)
}
return
}
go func() {
<-ctx.Done()
if cerr := conn.Close(); cerr != nil {
grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr)
}
}()
}()
return RegisterAIServiceHandler(ctx, mux, conn)
}
// RegisterAIServiceHandler registers the http handlers for service AIService to "mux".
// The handlers forward requests to the grpc endpoint over "conn".
func RegisterAIServiceHandler(ctx context.Context, mux *runtime.ServeMux, conn *grpc.ClientConn) error {
return RegisterAIServiceHandlerClient(ctx, mux, NewAIServiceClient(conn))
}
// RegisterAIServiceHandlerClient registers the http handlers for service AIService
// to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "AIServiceClient".
// Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "AIServiceClient"
// doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in
// "AIServiceClient" to call the correct interceptors. This client ignores the HTTP middlewares.
func RegisterAIServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux, client AIServiceClient) error {
mux.Handle(http.MethodPost, pattern_AIService_Transcribe_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) {
ctx, cancel := context.WithCancel(req.Context())
defer cancel()
inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req)
annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/memos.api.v1.AIService/Transcribe", runtime.WithHTTPPathPattern("/api/v1/ai:transcribe"))
if err != nil {
runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err)
return
}
resp, md, err := request_AIService_Transcribe_0(annotatedContext, inboundMarshaler, client, req, pathParams)
annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md)
if err != nil {
runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err)
return
}
forward_AIService_Transcribe_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
})
return nil
}
var (
pattern_AIService_Transcribe_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "ai"}, "transcribe"))
)
var (
forward_AIService_Transcribe_0 = runtime.ForwardResponseMessage
)
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.6.1
// - protoc (unknown)
// source: api/v1/ai_service.proto
package apiv1
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.64.0 or later.
const _ = grpc.SupportPackageIsVersion9
const (
AIService_Transcribe_FullMethodName = "/memos.api.v1.AIService/Transcribe"
)
// AIServiceClient is the client API for AIService service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type AIServiceClient interface {
// Transcribe transcribes an audio file using an instance AI provider.
Transcribe(ctx context.Context, in *TranscribeRequest, opts ...grpc.CallOption) (*TranscribeResponse, error)
}
type aIServiceClient struct {
cc grpc.ClientConnInterface
}
func NewAIServiceClient(cc grpc.ClientConnInterface) AIServiceClient {
return &aIServiceClient{cc}
}
func (c *aIServiceClient) Transcribe(ctx context.Context, in *TranscribeRequest, opts ...grpc.CallOption) (*TranscribeResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(TranscribeResponse)
err := c.cc.Invoke(ctx, AIService_Transcribe_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
// AIServiceServer is the server API for AIService service.
// All implementations must embed UnimplementedAIServiceServer
// for forward compatibility.
type AIServiceServer interface {
// Transcribe transcribes an audio file using an instance AI provider.
Transcribe(context.Context, *TranscribeRequest) (*TranscribeResponse, error)
mustEmbedUnimplementedAIServiceServer()
}
// UnimplementedAIServiceServer must be embedded to have
// forward compatible implementations.
//
// NOTE: this should be embedded by value instead of pointer to avoid a nil
// pointer dereference when methods are called.
type UnimplementedAIServiceServer struct{}
func (UnimplementedAIServiceServer) Transcribe(context.Context, *TranscribeRequest) (*TranscribeResponse, error) {
return nil, status.Error(codes.Unimplemented, "method Transcribe not implemented")
}
func (UnimplementedAIServiceServer) mustEmbedUnimplementedAIServiceServer() {}
func (UnimplementedAIServiceServer) testEmbeddedByValue() {}
// UnsafeAIServiceServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to AIServiceServer will
// result in compilation errors.
type UnsafeAIServiceServer interface {
mustEmbedUnimplementedAIServiceServer()
}
func RegisterAIServiceServer(s grpc.ServiceRegistrar, srv AIServiceServer) {
// If the following call panics, it indicates UnimplementedAIServiceServer was
// embedded by pointer and is nil. This will cause panics if an
// unimplemented method is ever invoked, so we test this at initialization
// time to prevent it from happening at runtime later due to I/O.
if t, ok := srv.(interface{ testEmbeddedByValue() }); ok {
t.testEmbeddedByValue()
}
s.RegisterService(&AIService_ServiceDesc, srv)
}
func _AIService_Transcribe_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(TranscribeRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(AIServiceServer).Transcribe(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: AIService_Transcribe_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(AIServiceServer).Transcribe(ctx, req.(*TranscribeRequest))
}
return interceptor(ctx, in, info, handler)
}
// AIService_ServiceDesc is the grpc.ServiceDesc for AIService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var AIService_ServiceDesc = grpc.ServiceDesc{
ServiceName: "memos.api.v1.AIService",
HandlerType: (*AIServiceServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "Transcribe",
Handler: _AIService_Transcribe_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "api/v1/ai_service.proto",
}
// Code generated by protoc-gen-connect-go. DO NOT EDIT.
//
// Source: api/v1/ai_service.proto
package apiv1connect
import (
connect "connectrpc.com/connect"
context "context"
errors "errors"
v1 "github.com/usememos/memos/proto/gen/api/v1"
http "net/http"
strings "strings"
)
// This is a compile-time assertion to ensure that this generated file and the connect package are
// compatible. If you get a compiler error that this constant is not defined, this code was
// generated with a version of connect newer than the one compiled into your binary. You can fix the
// problem by either regenerating this code with an older version of connect or updating the connect
// version compiled into your binary.
const _ = connect.IsAtLeastVersion1_13_0
const (
// AIServiceName is the fully-qualified name of the AIService service.
AIServiceName = "memos.api.v1.AIService"
)
// These constants are the fully-qualified names of the RPCs defined in this package. They're
// exposed at runtime as Spec.Procedure and as the final two segments of the HTTP route.
//
// Note that these are different from the fully-qualified method names used by
// google.golang.org/protobuf/reflect/protoreflect. To convert from these constants to
// reflection-formatted method names, remove the leading slash and convert the remaining slash to a
// period.
const (
// AIServiceTranscribeProcedure is the fully-qualified name of the AIService's Transcribe RPC.
AIServiceTranscribeProcedure = "/memos.api.v1.AIService/Transcribe"
)
// AIServiceClient is a client for the memos.api.v1.AIService service.
type AIServiceClient interface {
// Transcribe transcribes an audio file using an instance AI provider.
Transcribe(context.Context, *connect.Request[v1.TranscribeRequest]) (*connect.Response[v1.TranscribeResponse], error)
}
// NewAIServiceClient constructs a client for the memos.api.v1.AIService service. By default, it
// uses the Connect protocol with the binary Protobuf Codec, asks for gzipped responses, and sends
// uncompressed requests. To use the gRPC or gRPC-Web protocols, supply the connect.WithGRPC() or
// connect.WithGRPCWeb() options.
//
// The URL supplied here should be the base URL for the Connect or gRPC server (for example,
// http://api.acme.com or https://acme.com/grpc).
func NewAIServiceClient(httpClient connect.HTTPClient, baseURL string, opts ...connect.ClientOption) AIServiceClient {
baseURL = strings.TrimRight(baseURL, "/")
aIServiceMethods := v1.File_api_v1_ai_service_proto.Services().ByName("AIService").Methods()
return &aIServiceClient{
transcribe: connect.NewClient[v1.TranscribeRequest, v1.TranscribeResponse](
httpClient,
baseURL+AIServiceTranscribeProcedure,
connect.WithSchema(aIServiceMethods.ByName("Transcribe")),
connect.WithClientOptions(opts...),
),
}
}
// aIServiceClient implements AIServiceClient.
type aIServiceClient struct {
transcribe *connect.Client[v1.TranscribeRequest, v1.TranscribeResponse]
}
// Transcribe calls memos.api.v1.AIService.Transcribe.
func (c *aIServiceClient) Transcribe(ctx context.Context, req *connect.Request[v1.TranscribeRequest]) (*connect.Response[v1.TranscribeResponse], error) {
return c.transcribe.CallUnary(ctx, req)
}
// AIServiceHandler is an implementation of the memos.api.v1.AIService service.
type AIServiceHandler interface {
// Transcribe transcribes an audio file using an instance AI provider.
Transcribe(context.Context, *connect.Request[v1.TranscribeRequest]) (*connect.Response[v1.TranscribeResponse], error)
}
// NewAIServiceHandler builds an HTTP handler from the service implementation. It returns the path
// on which to mount the handler and the handler itself.
//
// By default, handlers support the Connect, gRPC, and gRPC-Web protocols with the binary Protobuf
// and JSON codecs. They also support gzip compression.
func NewAIServiceHandler(svc AIServiceHandler, opts ...connect.HandlerOption) (string, http.Handler) {
aIServiceMethods := v1.File_api_v1_ai_service_proto.Services().ByName("AIService").Methods()
aIServiceTranscribeHandler := connect.NewUnaryHandler(
AIServiceTranscribeProcedure,
svc.Transcribe,
connect.WithSchema(aIServiceMethods.ByName("Transcribe")),
connect.WithHandlerOptions(opts...),
)
return "/memos.api.v1.AIService/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case AIServiceTranscribeProcedure:
aIServiceTranscribeHandler.ServeHTTP(w, r)
default:
http.NotFound(w, r)
}
})
}
// UnimplementedAIServiceHandler returns CodeUnimplemented from all methods.
type UnimplementedAIServiceHandler struct{}
func (UnimplementedAIServiceHandler) Transcribe(context.Context, *connect.Request[v1.TranscribeRequest]) (*connect.Response[v1.TranscribeResponse], error) {
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("memos.api.v1.AIService.Transcribe is not implemented"))
}
This diff is collapsed.
......@@ -6,6 +6,31 @@ info:
title: ""
version: 0.0.1
paths:
/api/v1/ai:transcribe:
post:
tags:
- AIService
description: Transcribe transcribes an audio file using an instance AI provider.
operationId: AIService_Transcribe
requestBody:
content:
application/json:
schema:
$ref: '#/components/schemas/TranscribeRequest'
required: true
responses:
"200":
description: OK
content:
application/json:
schema:
$ref: '#/components/schemas/TranscribeResponse'
default:
description: Default error response
content:
application/json:
schema:
$ref: '#/components/schemas/Status'
/api/v1/attachments:
get:
tags:
......@@ -2380,7 +2405,55 @@ components:
$ref: '#/components/schemas/InstanceSetting_TagsSetting'
notificationSetting:
$ref: '#/components/schemas/InstanceSetting_NotificationSetting'
aiSetting:
$ref: '#/components/schemas/InstanceSetting_AISetting'
description: An instance setting resource.
InstanceSetting_AIProviderConfig:
type: object
properties:
id:
type: string
title:
type: string
type:
enum:
- AI_PROVIDER_TYPE_UNSPECIFIED
- OPENAI
- OPENAI_COMPATIBLE
- ANTHROPIC
- GEMINI
type: string
format: enum
endpoint:
type: string
apiKey:
writeOnly: true
type: string
description: api_key is write-only and is never returned by GetInstanceSetting.
models:
type: array
items:
type: string
defaultModel:
type: string
apiKeySet:
readOnly: true
type: boolean
description: api_key_set indicates whether an API key is stored for this provider.
apiKeyHint:
readOnly: true
type: string
description: api_key_hint is a masked hint for the stored API key.
description: AIProviderConfig represents one callable AI provider connection.
InstanceSetting_AISetting:
type: object
properties:
providers:
type: array
items:
$ref: '#/components/schemas/InstanceSetting_AIProviderConfig'
description: providers is the list of AI provider configurations available instance-wide.
description: AI provider configuration settings.
InstanceSetting_GeneralSetting:
type: object
properties:
......@@ -3144,6 +3217,59 @@ components:
description: |-
S3 configuration for cloud storage backend.
Reference: https://developers.cloudflare.com/r2/examples/aws/aws-sdk-go/
TranscribeRequest:
required:
- providerId
- config
- audio
type: object
properties:
providerId:
type: string
description: Required. The instance AI provider ID to use.
config:
allOf:
- $ref: '#/components/schemas/TranscriptionConfig'
description: Required. Transcription options.
audio:
allOf:
- $ref: '#/components/schemas/TranscriptionAudio'
description: Required. Audio input.
TranscribeResponse:
type: object
properties:
text:
type: string
description: The transcribed text.
TranscriptionAudio:
type: object
properties:
content:
writeOnly: true
type: string
description: Inline audio bytes.
format: bytes
uri:
type: string
description: URI for audio content. Reserved for future use.
filename:
type: string
description: Optional. The uploaded filename.
contentType:
type: string
description: Optional. The MIME type of the input audio.
TranscriptionConfig:
type: object
properties:
model:
type: string
description: Optional. The model to use. If empty, the provider's default model is used.
prompt:
type: string
description: Optional. A prompt to improve transcription quality.
language:
type: string
description: Optional. The language of the input audio.
UpsertMemoReactionRequest:
required:
- name
......@@ -3419,6 +3545,7 @@ components:
format: date-time
description: UserWebhook represents a webhook owned by a user.
tags:
- name: AIService
- name: AttachmentService
- name: AuthService
- name: IdentityProviderService
......
This diff is collapsed.
......@@ -20,6 +20,8 @@ enum InstanceSettingKey {
TAGS = 5;
// NOTIFICATION is the key for notification transport settings.
NOTIFICATION = 6;
// AI is the key for AI provider settings.
AI = 7;
}
message InstanceSetting {
......@@ -31,6 +33,7 @@ message InstanceSetting {
InstanceMemoRelatedSetting memo_related_setting = 5;
InstanceTagsSetting tags_setting = 6;
InstanceNotificationSetting notification_setting = 7;
InstanceAISetting ai_setting = 8;
}
}
......@@ -142,3 +145,27 @@ message InstanceNotificationSetting {
bool use_ssl = 10;
}
}
// InstanceAISetting is the value stored under the AI instance setting key.
message InstanceAISetting {
// providers is the list of AI provider configurations available instance-wide.
repeated AIProviderConfig providers = 1;
}
// AIProviderConfig describes one callable AI provider connection.
message AIProviderConfig {
// id identifies the provider; the API layer generates one when it is left empty on update
// and rejects duplicates within a setting.
string id = 1;
// title is the display name; the API layer rejects an empty title.
string title = 2;
// type selects the provider implementation; the API layer rejects AI_PROVIDER_TYPE_UNSPECIFIED.
AIProviderType type = 3;
// endpoint is the provider base URL. On update the API layer fills in
// https://api.openai.com/v1 for OPENAI when empty and requires it for OPENAI_COMPATIBLE.
string endpoint = 4;
// api_key is write-only at the API layer and is required by the server to call providers.
string api_key = 5;
// models lists the model names that may be requested; at least one is required on update.
repeated string models = 6;
// default_model is used when a request names no model; on update it defaults to the
// first entry of models and must be one of them.
string default_model = 7;
}
// AIProviderType enumerates the supported AI provider implementations.
// The store values are cast numerically to and from the API-level provider type enum,
// so the numbering of both enums must stay aligned.
enum AIProviderType {
// Unset; rejected when an AI setting is updated.
AI_PROVIDER_TYPE_UNSPECIFIED = 0;
// OpenAI's hosted API.
OPENAI = 1;
// Any endpoint that speaks the OpenAI API.
OPENAI_COMPATIBLE = 2;
// Anthropic's API.
ANTHROPIC = 3;
// Google's Gemini API.
GEMINI = 4;
}
package v1
import (
"bytes"
"context"
"mime"
"net/http"
"strings"
"github.com/pkg/errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/usememos/memos/internal/ai"
"github.com/usememos/memos/internal/ai/openai"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
)
// Limits applied to a transcription request before any provider is contacted.
const (
// maxTranscriptionAudioSizeBytes caps the inline audio payload (reported to callers as 25 MiB).
// MebiByte is declared elsewhere in this package — presumably 1024*1024; confirm.
maxTranscriptionAudioSizeBytes = 25 * MebiByte
// maxTranscriptionPromptLength caps the length of the trimmed prompt.
maxTranscriptionPromptLength = 4096
// maxTranscriptionLanguageLength caps the length of the trimmed language value.
maxTranscriptionLanguageLength = 32
// maxTranscriptionFilenameLength caps the length of the trimmed filename.
maxTranscriptionFilenameLength = 255
)
// supportedTranscriptionContentTypes is the allow-list of media types accepted
// as transcription input. Keys are lower-case and carry no parameters.
//
// "audio/wave" is listed alongside "audio/wav" and "audio/x-wav" because it is
// the name http.DetectContentType reports for RIFF/WAVE data; without it a WAV
// upload that omits its content type would be sniffed and then rejected.
var supportedTranscriptionContentTypes = map[string]bool{
	"audio/mpeg":  true,
	"audio/mp4":   true,
	"audio/mpga":  true,
	"audio/wav":   true,
	"audio/wave":  true,
	"audio/x-wav": true,
	"audio/webm":  true,
	"audio/x-m4a": true,
	"video/mp4":   true,
	"video/mpeg":  true,
	"video/webm":  true,
}
// Transcribe transcribes inline audio using an instance AI provider.
//
// The caller must be authenticated. The request is validated before the
// provider is resolved: provider_id, config and audio are required, audio must
// be supplied inline (a URI is rejected), and the prompt, language, filename,
// payload size and content type are all bounded. When the request carries no
// content type, it is sniffed from the audio bytes. Validation failures are
// returned as InvalidArgument; a failing provider call is returned as Internal.
func (s *APIV1Service) Transcribe(ctx context.Context, request *v1pb.TranscribeRequest) (*v1pb.TranscribeResponse, error) {
	user, err := s.fetchCurrentUser(ctx)
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
	}
	if user == nil {
		return nil, status.Errorf(codes.Unauthenticated, "user not authenticated")
	}

	if strings.TrimSpace(request.ProviderId) == "" {
		return nil, status.Errorf(codes.InvalidArgument, "provider_id is required")
	}
	if request.Config == nil {
		return nil, status.Errorf(codes.InvalidArgument, "config is required")
	}

	// The limits below are reported to callers in characters, so count runes
	// rather than bytes; otherwise multi-byte text is rejected far below the
	// advertised limit.
	prompt := strings.TrimSpace(request.Config.GetPrompt())
	if len([]rune(prompt)) > maxTranscriptionPromptLength {
		return nil, status.Errorf(codes.InvalidArgument, "prompt is too long; maximum length is %d characters", maxTranscriptionPromptLength)
	}
	language := strings.TrimSpace(request.Config.GetLanguage())
	if len([]rune(language)) > maxTranscriptionLanguageLength {
		return nil, status.Errorf(codes.InvalidArgument, "language is too long; maximum length is %d characters", maxTranscriptionLanguageLength)
	}

	if request.Audio == nil {
		return nil, status.Errorf(codes.InvalidArgument, "audio is required")
	}
	if request.Audio.GetUri() != "" {
		return nil, status.Errorf(codes.InvalidArgument, "audio uri is not supported")
	}
	content := request.Audio.GetContent()
	if len(content) == 0 {
		return nil, status.Errorf(codes.InvalidArgument, "audio content is required")
	}
	if len(content) > maxTranscriptionAudioSizeBytes {
		return nil, status.Errorf(codes.InvalidArgument, "audio file is too large; maximum size is 25 MiB")
	}

	filename := strings.TrimSpace(request.Audio.GetFilename())
	if len([]rune(filename)) > maxTranscriptionFilenameLength {
		return nil, status.Errorf(codes.InvalidArgument, "filename is too long; maximum length is %d characters", maxTranscriptionFilenameLength)
	}

	// Fall back to sniffing the payload when the client did not declare a type.
	contentType := strings.TrimSpace(request.Audio.GetContentType())
	if contentType == "" {
		contentType = http.DetectContentType(content)
	}
	if !isSupportedTranscriptionContentType(contentType) {
		return nil, status.Errorf(codes.InvalidArgument, "audio content type %q is not supported", contentType)
	}

	// resolveAIProviderForTranscription already returns gRPC status errors.
	provider, model, err := s.resolveAIProviderForTranscription(ctx, request.ProviderId, request.Config.GetModel())
	if err != nil {
		return nil, err
	}
	transcriber, err := newAITranscriber(provider)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "failed to create AI transcriber: %v", err)
	}

	transcription, err := transcriber.Transcribe(ctx, ai.TranscribeRequest{
		Model:       model,
		Filename:    filename,
		ContentType: contentType,
		Audio:       bytes.NewReader(content),
		Size:        int64(len(content)),
		Prompt:      prompt,
		Language:    language,
	})
	if err != nil {
		return nil, status.Errorf(codes.Internal, "failed to transcribe audio: %v", err)
	}
	return &v1pb.TranscribeResponse{
		Text: transcription.Text,
	}, nil
}
// resolveAIProviderForTranscription looks up the instance AI provider with the
// given ID and picks the model to use: the requested model when non-blank,
// otherwise the provider's default. The chosen model must be one of the
// provider's configured models. Every returned error is a gRPC status error.
func (s *APIV1Service) resolveAIProviderForTranscription(ctx context.Context, providerID string, model string) (ai.ProviderConfig, string, error) {
	var none ai.ProviderConfig

	aiSetting, err := s.Store.GetInstanceAISetting(ctx)
	if err != nil {
		return none, "", status.Errorf(codes.Internal, "failed to get AI setting: %v", err)
	}

	// Convert the stored providers, dropping nil entries.
	stored := aiSetting.GetProviders()
	configs := make([]ai.ProviderConfig, 0, len(stored))
	for _, item := range stored {
		if item != nil {
			configs = append(configs, convertAIProviderConfigFromStore(item))
		}
	}

	match, err := ai.FindProvider(configs, providerID)
	if err != nil {
		return none, "", status.Errorf(codes.NotFound, "AI provider not found")
	}

	chosen := strings.TrimSpace(model)
	if chosen == "" {
		chosen = match.DefaultModel
	}
	switch {
	case chosen == "":
		return none, "", status.Errorf(codes.InvalidArgument, "model is required")
	case !containsString(match.Models, chosen):
		return none, "", status.Errorf(codes.InvalidArgument, "model %q is not configured for provider %q", chosen, match.ID)
	}
	return *match, chosen, nil
}
// convertAIProviderConfigFromStore copies a stored provider configuration into
// the ai package's ProviderConfig, translating the provider type on the way.
func convertAIProviderConfigFromStore(provider *storepb.AIProviderConfig) ai.ProviderConfig {
	var config ai.ProviderConfig
	config.ID = provider.GetId()
	config.Title = provider.GetTitle()
	config.Type = convertAIProviderTypeFromStore(provider.GetType())
	config.Endpoint = provider.GetEndpoint()
	config.APIKey = provider.GetApiKey()
	config.Models = provider.GetModels()
	config.DefaultModel = provider.GetDefaultModel()
	return config
}
// convertAIProviderTypeFromStore maps a stored provider type enum to the ai
// package's ProviderType. Unknown or unspecified values map to the empty type.
func convertAIProviderTypeFromStore(providerType storepb.AIProviderType) ai.ProviderType {
	var converted ai.ProviderType
	switch providerType {
	case storepb.AIProviderType_OPENAI:
		converted = ai.ProviderOpenAI
	case storepb.AIProviderType_OPENAI_COMPATIBLE:
		converted = ai.ProviderOpenAICompatible
	case storepb.AIProviderType_ANTHROPIC:
		converted = ai.ProviderAnthropic
	case storepb.AIProviderType_GEMINI:
		converted = ai.ProviderGemini
	default:
		// Leave the result empty for values with no counterpart.
	}
	return converted
}
// newAITranscriber builds a transcriber for the given provider. Only OpenAI and
// OpenAI-compatible providers are handled; any other type yields an error
// wrapping ai.ErrCapabilityUnsupported.
func newAITranscriber(provider ai.ProviderConfig) (ai.Transcriber, error) {
	if provider.Type == ai.ProviderOpenAI || provider.Type == ai.ProviderOpenAICompatible {
		return openai.NewTranscriber(provider)
	}
	return nil, errors.Wrapf(ai.ErrCapabilityUnsupported, "provider type %q", provider.Type)
}
// containsString reports whether target is an element of values.
func containsString(values []string, target string) bool {
	for i := range values {
		if values[i] == target {
			return true
		}
	}
	return false
}
// isSupportedTranscriptionContentType reports whether contentType, once its
// parameters are stripped and its media type lower-cased, is on the
// transcription allow-list. A value that fails to parse is unsupported.
func isSupportedTranscriptionContentType(contentType string) bool {
	parsed, _, err := mime.ParseMediaType(strings.TrimSpace(contentType))
	if err == nil {
		return supportedTranscriptionContentTypes[strings.ToLower(parsed)]
	}
	return false
}
......@@ -39,6 +39,7 @@ func (s *ConnectServiceHandler) RegisterConnectHandlers(mux *http.ServeMux, opts
wrap(apiv1connect.NewUserServiceHandler(s, opts...)),
wrap(apiv1connect.NewMemoServiceHandler(s, opts...)),
wrap(apiv1connect.NewAttachmentServiceHandler(s, opts...)),
wrap(apiv1connect.NewAIServiceHandler(s, opts...)),
wrap(apiv1connect.NewShortcutServiceHandler(s, opts...)),
wrap(apiv1connect.NewIdentityProviderServiceHandler(s, opts...)),
}
......
......@@ -435,6 +435,16 @@ func (s *ConnectServiceHandler) BatchDeleteAttachments(ctx context.Context, req
return connect.NewResponse(resp), nil
}
// AIService
// Transcribe adapts the Connect request to APIV1Service.Transcribe and converts
// any gRPC status error it returns into a Connect error.
func (s *ConnectServiceHandler) Transcribe(ctx context.Context, req *connect.Request[v1pb.TranscribeRequest]) (*connect.Response[v1pb.TranscribeResponse], error) {
	result, err := s.APIV1Service.Transcribe(ctx, req.Msg)
	if err == nil {
		return connect.NewResponse(result), nil
	}
	return nil, convertGRPCError(err)
}
// ShortcutService
func (s *ConnectServiceHandler) ListShortcuts(ctx context.Context, req *connect.Request[v1pb.ListShortcutsRequest]) (*connect.Response[v1pb.ListShortcutsResponse], error) {
......
......@@ -5,8 +5,10 @@ import (
"fmt"
"math"
"regexp"
"slices"
"strings"
"github.com/lithammer/shortuuid/v4"
"github.com/pkg/errors"
colorpb "google.golang.org/genproto/googleapis/type/color"
"google.golang.org/grpc/codes"
......@@ -54,6 +56,8 @@ func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.Get
_, err = s.Store.GetInstanceTagsSetting(ctx)
case storepb.InstanceSettingKey_NOTIFICATION:
_, err = s.Store.GetInstanceNotificationSetting(ctx)
case storepb.InstanceSettingKey_AI:
_, err = s.Store.GetInstanceAISetting(ctx)
default:
return nil, status.Errorf(codes.InvalidArgument, "unsupported instance setting key: %v", instanceSettingKey)
}
......@@ -71,9 +75,10 @@ func (s *APIV1Service) GetInstanceSetting(ctx context.Context, request *v1pb.Get
return nil, status.Errorf(codes.NotFound, "instance setting not found")
}
// Storage and notification settings contain credentials; restrict to admins only.
// Storage, notification, and AI settings contain credentials; restrict to admins only.
if instanceSetting.Key == storepb.InstanceSettingKey_STORAGE ||
instanceSetting.Key == storepb.InstanceSettingKey_NOTIFICATION {
instanceSetting.Key == storepb.InstanceSettingKey_NOTIFICATION ||
instanceSetting.Key == storepb.InstanceSettingKey_AI {
user, err := s.fetchCurrentUser(ctx)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err)
......@@ -127,6 +132,10 @@ func (s *APIV1Service) UpdateInstanceSetting(ctx context.Context, request *v1pb.
storage.S3Config.AccessKeySecret = existing.S3Config.AccessKeySecret
}
}
case storepb.InstanceSettingKey_AI:
if err := s.prepareInstanceAISettingForUpdate(ctx, updateSetting.GetAiSetting()); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid AI setting: %v", err)
}
default:
// No credential preservation needed for other setting types.
}
......@@ -164,6 +173,10 @@ func convertInstanceSettingFromStore(setting *storepb.InstanceSetting) *v1pb.Ins
instanceSetting.Value = &v1pb.InstanceSetting_NotificationSetting_{
NotificationSetting: convertInstanceNotificationSettingFromStore(setting.GetNotificationSetting()),
}
case *storepb.InstanceSetting_AiSetting:
instanceSetting.Value = &v1pb.InstanceSetting_AiSetting{
AiSetting: convertInstanceAISettingFromStore(setting.GetAiSetting()),
}
default:
// Leave Value unset for unsupported setting variants.
}
......@@ -199,6 +212,10 @@ func convertInstanceSettingToStore(setting *v1pb.InstanceSetting) *storepb.Insta
instanceSetting.Value = &storepb.InstanceSetting_NotificationSetting{
NotificationSetting: convertInstanceNotificationSettingToStore(setting.GetNotificationSetting()),
}
case storepb.InstanceSettingKey_AI:
instanceSetting.Value = &storepb.InstanceSetting_AiSetting{
AiSetting: convertInstanceAISettingToStore(setting.GetAiSetting()),
}
default:
// Keep the default GeneralSetting value
}
......@@ -398,6 +415,58 @@ func convertInstanceNotificationSettingToStore(setting *v1pb.InstanceSetting_Not
return notificationSetting
}
// convertInstanceAISettingFromStore builds the API view of the stored AI
// setting. The API key itself is never copied: each provider only reports
// whether a key is set, plus a masked hint. Nil providers are skipped and a nil
// setting yields nil.
func convertInstanceAISettingFromStore(setting *storepb.InstanceAISetting) *v1pb.InstanceSetting_AISetting {
	if setting == nil {
		return nil
	}

	converted := make([]*v1pb.InstanceSetting_AIProviderConfig, 0, len(setting.Providers))
	for _, src := range setting.Providers {
		if src == nil {
			continue
		}
		key := src.GetApiKey()
		dst := &v1pb.InstanceSetting_AIProviderConfig{}
		dst.Id = src.GetId()
		dst.Title = src.GetTitle()
		dst.Type = v1pb.InstanceSetting_AIProviderType(src.GetType())
		dst.Endpoint = src.GetEndpoint()
		dst.Models = src.GetModels()
		dst.DefaultModel = src.GetDefaultModel()
		dst.ApiKeySet = key != ""
		dst.ApiKeyHint = maskAPIKey(key)
		converted = append(converted, dst)
	}
	return &v1pb.InstanceSetting_AISetting{Providers: converted}
}
// convertInstanceAISettingToStore builds the stored form of an API-level AI
// setting, copying every provider field including the API key. Nil providers
// are skipped and a nil setting yields nil.
func convertInstanceAISettingToStore(setting *v1pb.InstanceSetting_AISetting) *storepb.InstanceAISetting {
	if setting == nil {
		return nil
	}

	converted := make([]*storepb.AIProviderConfig, 0, len(setting.Providers))
	for _, src := range setting.Providers {
		if src == nil {
			continue
		}
		dst := &storepb.AIProviderConfig{}
		dst.Id = src.GetId()
		dst.Title = src.GetTitle()
		dst.Type = storepb.AIProviderType(src.GetType())
		dst.Endpoint = src.GetEndpoint()
		dst.ApiKey = src.GetApiKey()
		dst.Models = src.GetModels()
		dst.DefaultModel = src.GetDefaultModel()
		converted = append(converted, dst)
	}
	return &storepb.InstanceAISetting{Providers: converted}
}
func validateInstanceSetting(setting *v1pb.InstanceSetting) error {
key, err := ExtractInstanceSettingKeyFromName(setting.Name)
if err != nil {
......@@ -409,6 +478,104 @@ func validateInstanceSetting(setting *v1pb.InstanceSetting) error {
return validateInstanceTagsSetting(setting.GetTagsSetting())
}
// prepareInstanceAISettingForUpdate validates and normalizes an incoming AI
// setting in place before it is persisted.
//
// For each provider it trims the text fields, generates an ID when none is
// given, rejects duplicate IDs, requires a title and a type, applies the
// default OpenAI endpoint, requires an endpoint for OpenAI-compatible
// providers, de-duplicates the model list, defaults the default model to the
// first model, and keeps the stored API key when the request leaves it empty.
// It returns an error describing the first rule that is violated.
func (s *APIV1Service) prepareInstanceAISettingForUpdate(ctx context.Context, setting *storepb.InstanceAISetting) error {
	if setting == nil {
		return errors.New("AI setting is required")
	}

	current, err := s.Store.GetInstanceAISetting(ctx)
	if err != nil {
		return errors.Wrap(err, "failed to get existing AI setting")
	}

	// Index the stored API keys by provider ID so an empty key can be preserved.
	storedKeys := make(map[string]string)
	for _, stored := range current.GetProviders() {
		if stored != nil && stored.Id != "" {
			storedKeys[stored.Id] = stored.ApiKey
		}
	}

	usedIDs := make(map[string]struct{}, len(setting.Providers))
	for _, p := range setting.Providers {
		if p == nil {
			return errors.New("provider cannot be nil")
		}

		if p.Id = strings.TrimSpace(p.Id); p.Id == "" {
			p.Id = shortuuid.New()
		}
		if _, duplicate := usedIDs[p.Id]; duplicate {
			return errors.Errorf("duplicate provider ID %q", p.Id)
		}
		usedIDs[p.Id] = struct{}{}

		if p.Title = strings.TrimSpace(p.Title); p.Title == "" {
			return errors.New("provider title is required")
		}
		if p.Type == storepb.AIProviderType_AI_PROVIDER_TYPE_UNSPECIFIED {
			return errors.Errorf("provider %q type is required", p.Id)
		}

		p.Endpoint = strings.TrimSpace(p.Endpoint)
		if p.Endpoint == "" {
			switch p.Type {
			case storepb.AIProviderType_OPENAI:
				p.Endpoint = "https://api.openai.com/v1"
			case storepb.AIProviderType_OPENAI_COMPATIBLE:
				return errors.Errorf("provider %q endpoint is required", p.Id)
			default:
				// Other provider types may leave the endpoint empty.
			}
		}

		if p.Models = normalizeAIModels(p.Models); len(p.Models) == 0 {
			return errors.Errorf("provider %q must define at least one model", p.Id)
		}
		if p.DefaultModel = strings.TrimSpace(p.DefaultModel); p.DefaultModel == "" {
			p.DefaultModel = p.Models[0]
		}
		if !slices.Contains(p.Models, p.DefaultModel) {
			return errors.Errorf("provider %q default model %q must be included in models", p.Id, p.DefaultModel)
		}

		// An empty key means "keep what is stored"; a missing entry yields "".
		if p.ApiKey == "" {
			p.ApiKey = storedKeys[p.Id]
		}
		if p.ApiKey == "" {
			return errors.Errorf("provider %q API key is required", p.Id)
		}
	}
	return nil
}
// normalizeAIModels trims every model name, drops blank names and repeats, and
// returns the remaining names in first-seen order. The result is never nil.
func normalizeAIModels(models []string) []string {
	unique := make(map[string]struct{}, len(models))
	result := make([]string, 0, len(models))
	for i := range models {
		name := strings.TrimSpace(models[i])
		if name == "" {
			continue
		}
		if _, duplicate := unique[name]; duplicate {
			continue
		}
		unique[name] = struct{}{}
		result = append(result, name)
	}
	return result
}
// maskAPIKey returns a display hint for an API key: the first and last four
// characters joined by "...". An empty key yields an empty hint, and a key of
// eight characters or fewer yields just "..." so that short keys are not
// revealed.
//
// The key is sliced by rune rather than by byte so that a multi-byte character
// is never cut in half; a hint containing invalid UTF-8 could not be carried in
// a protobuf string field.
func maskAPIKey(apiKey string) string {
	if apiKey == "" {
		return ""
	}
	runes := []rune(apiKey)
	if len(runes) <= 8 {
		return "..."
	}
	return string(runes[:4]) + "..." + string(runes[len(runes)-4:])
}
func validateInstanceTagsSetting(setting *v1pb.InstanceSetting_TagsSetting) error {
if setting == nil {
return errors.New("tags setting is required")
......
package test
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
)
// TestTranscribe exercises APIV1Service.Transcribe: authentication, a successful
// call against a fake OpenAI-compatible server, and the two pre-flight rejections
// (unconfigured model, unsupported content type).
func TestTranscribe(t *testing.T) {
ctx := context.Background()
// A bare context carries no user, so the call must be rejected.
t.Run("requires authentication", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
_, err := ts.Service.Transcribe(ctx, &v1pb.TranscribeRequest{
ProviderId: "openai-main",
Config: &v1pb.TranscriptionConfig{
Model: "gpt-4o-transcribe",
},
Audio: &v1pb.TranscriptionAudio{
Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")},
Filename: "voice.wav",
ContentType: "audio/wav",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "user not authenticated")
})
// The request omits the model, so the provider's default model must be sent upstream.
t.Run("transcribes audio file with configured provider", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "alice")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Fake provider: checks the path, bearer token and multipart fields, then returns fixed text.
// NOTE(review): these require.* calls run on the httptest server goroutine rather than the
// test goroutine; require aborts via FailNow, which is documented as test-goroutine only —
// consider recording the values here and asserting after the Transcribe call. Confirm.
openAIServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, "/audio/transcriptions", r.URL.Path)
require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
require.NoError(t, r.ParseMultipartForm(10<<20))
require.Equal(t, "gpt-4o-transcribe", r.FormValue("model"))
require.Equal(t, "names: Alice", r.FormValue("prompt"))
file, header, err := r.FormFile("file")
require.NoError(t, err)
defer file.Close()
require.Equal(t, "voice.wav", header.Filename)
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(map[string]string{
"text": "transcribed text",
}))
}))
defer openAIServer.Close()
// Seed the provider directly in the store, pointing at the fake server.
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_AI,
Value: &storepb.InstanceSetting_AiSetting{
AiSetting: &storepb.InstanceAISetting{
Providers: []*storepb.AIProviderConfig{
{
Id: "openai-main",
Title: "OpenAI",
Type: storepb.AIProviderType_OPENAI_COMPATIBLE,
Endpoint: openAIServer.URL,
ApiKey: "sk-test",
Models: []string{"gpt-4o-transcribe"},
DefaultModel: "gpt-4o-transcribe",
},
},
},
},
})
require.NoError(t, err)
resp, err := ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{
ProviderId: "openai-main",
Config: &v1pb.TranscriptionConfig{
Prompt: "names: Alice",
},
Audio: &v1pb.TranscriptionAudio{
Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")},
Filename: "voice.wav",
ContentType: "audio/wav",
},
})
require.NoError(t, err)
require.Equal(t, "transcribed text", resp.Text)
})
// A model that is not in the provider's model list must be rejected.
t.Run("rejects unconfigured model", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "bob")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_AI,
Value: &storepb.InstanceSetting_AiSetting{
AiSetting: &storepb.InstanceAISetting{
Providers: []*storepb.AIProviderConfig{
{
Id: "openai-main",
Title: "OpenAI",
Type: storepb.AIProviderType_OPENAI_COMPATIBLE,
Endpoint: "https://example.com/v1",
ApiKey: "sk-test",
Models: []string{"gpt-4o-transcribe"},
DefaultModel: "gpt-4o-transcribe",
},
},
},
},
})
require.NoError(t, err)
_, err = ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{
ProviderId: "openai-main",
Config: &v1pb.TranscriptionConfig{
Model: "other-model",
},
Audio: &v1pb.TranscriptionAudio{
Source: &v1pb.TranscriptionAudio_Content{Content: []byte("RIFF")},
Filename: "voice.wav",
ContentType: "audio/wav",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not configured")
})
// A declared content type outside the allow-list must be rejected.
t.Run("rejects non-audio content before provider call", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
user, err := ts.CreateRegularUser(ctx, "charlie")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_AI,
Value: &storepb.InstanceSetting_AiSetting{
AiSetting: &storepb.InstanceAISetting{
Providers: []*storepb.AIProviderConfig{
{
Id: "openai-main",
Title: "OpenAI",
Type: storepb.AIProviderType_OPENAI_COMPATIBLE,
Endpoint: "https://example.com/v1",
ApiKey: "sk-test",
Models: []string{"gpt-4o-transcribe"},
DefaultModel: "gpt-4o-transcribe",
},
},
},
},
})
require.NoError(t, err)
_, err = ts.Service.Transcribe(userCtx, &v1pb.TranscribeRequest{
ProviderId: "openai-main",
Config: &v1pb.TranscriptionConfig{
Model: "gpt-4o-transcribe",
},
Audio: &v1pb.TranscriptionAudio{
Source: &v1pb.TranscriptionAudio_Content{Content: []byte("not audio")},
Filename: "notes.txt",
ContentType: "text/plain",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not supported")
})
}
......@@ -238,6 +238,34 @@ func TestGetInstanceSetting(t *testing.T) {
"SmtpPassword must never be returned in responses")
})
t.Run("GetInstanceSetting - AI setting requires admin", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
admin, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
adminCtx := ts.CreateUserContext(ctx, admin.ID)
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
req := &v1pb.GetInstanceSettingRequest{Name: "instance/settings/AI"}
_, err = ts.Service.GetInstanceSetting(ctx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "not authenticated")
_, err = ts.Service.GetInstanceSetting(userCtx, req)
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
resp, err := ts.Service.GetInstanceSetting(adminCtx, req)
require.NoError(t, err)
require.NotNil(t, resp.GetAiSetting())
require.Empty(t, resp.GetAiSetting().GetProviders())
})
t.Run("GetInstanceSetting - invalid setting name", func(t *testing.T) {
// Create test service for this specific test
ts := NewTestService(t)
......@@ -258,6 +286,41 @@ func TestGetInstanceSetting(t *testing.T) {
func TestUpdateInstanceSetting(t *testing.T) {
ctx := context.Background()
t.Run("UpdateInstanceSetting - AI setting requires admin", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, regularUser.ID)
setting := &v1pb.InstanceSetting{
Name: "instance/settings/AI",
Value: &v1pb.InstanceSetting_AiSetting{
AiSetting: &v1pb.InstanceSetting_AISetting{
Providers: []*v1pb.InstanceSetting_AIProviderConfig{
{
Id: "openai-main",
Title: "OpenAI",
Type: v1pb.InstanceSetting_OPENAI,
ApiKey: "sk-test",
Models: []string{"gpt-4o-transcribe"},
DefaultModel: "gpt-4o-transcribe",
},
},
},
},
}
_, err = ts.Service.UpdateInstanceSetting(ctx, &v1pb.UpdateInstanceSettingRequest{Setting: setting})
require.Error(t, err)
require.Contains(t, err.Error(), "not authenticated")
_, err = ts.Service.UpdateInstanceSetting(userCtx, &v1pb.UpdateInstanceSettingRequest{Setting: setting})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("UpdateInstanceSetting - tags setting", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
......@@ -490,4 +553,75 @@ func TestUpdateInstanceSetting(t *testing.T) {
"existing AccessKeySecret must be preserved when an empty value is sent")
require.Equal(t, "s3-v2.example.com", stored.GetS3Config().GetEndpoint())
})
t.Run("UpdateInstanceSetting - AI provider keys are write-only and preserved on empty", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
adminCtx := ts.CreateUserContext(ctx, hostUser.ID)
_, err = ts.Service.UpdateInstanceSetting(adminCtx, &v1pb.UpdateInstanceSettingRequest{
Setting: &v1pb.InstanceSetting{
Name: "instance/settings/AI",
Value: &v1pb.InstanceSetting_AiSetting{
AiSetting: &v1pb.InstanceSetting_AISetting{
Providers: []*v1pb.InstanceSetting_AIProviderConfig{
{
Id: "openai-main",
Title: "OpenAI",
Type: v1pb.InstanceSetting_OPENAI,
ApiKey: "sk-original",
Models: []string{"gpt-5.4", "gpt-5.4-mini"},
DefaultModel: "gpt-5.4",
},
},
},
},
},
})
require.NoError(t, err)
resp, err := ts.Service.GetInstanceSetting(adminCtx, &v1pb.GetInstanceSettingRequest{
Name: "instance/settings/AI",
})
require.NoError(t, err)
require.Len(t, resp.GetAiSetting().GetProviders(), 1)
provider := resp.GetAiSetting().GetProviders()[0]
require.Empty(t, provider.GetApiKey(), "AI provider API key must never be returned in responses")
require.True(t, provider.GetApiKeySet())
require.Equal(t, "sk-o...inal", provider.GetApiKeyHint())
require.Equal(t, "https://api.openai.com/v1", provider.GetEndpoint())
_, err = ts.Service.UpdateInstanceSetting(adminCtx, &v1pb.UpdateInstanceSettingRequest{
Setting: &v1pb.InstanceSetting{
Name: "instance/settings/AI",
Value: &v1pb.InstanceSetting_AiSetting{
AiSetting: &v1pb.InstanceSetting_AISetting{
Providers: []*v1pb.InstanceSetting_AIProviderConfig{
{
Id: "openai-main",
Title: "OpenAI primary",
Type: v1pb.InstanceSetting_OPENAI,
ApiKey: "",
Models: []string{"gpt-5.4-mini", "gpt-5.4-mini", "gpt-5.4"},
DefaultModel: "",
},
},
},
},
},
})
require.NoError(t, err)
stored, err := ts.Store.GetInstanceAISetting(ctx)
require.NoError(t, err)
require.Len(t, stored.GetProviders(), 1)
require.Equal(t, "sk-original", stored.GetProviders()[0].GetApiKey(),
"existing AI provider API key must be preserved when an empty value is sent")
require.Equal(t, "OpenAI primary", stored.GetProviders()[0].GetTitle())
require.Equal(t, []string{"gpt-5.4-mini", "gpt-5.4"}, stored.GetProviders()[0].GetModels())
require.Equal(t, "gpt-5.4-mini", stored.GetProviders()[0].GetDefaultModel())
})
}
......@@ -23,6 +23,7 @@ type APIV1Service struct {
v1pb.UnimplementedUserServiceServer
v1pb.UnimplementedMemoServiceServer
v1pb.UnimplementedAttachmentServiceServer
v1pb.UnimplementedAIServiceServer
v1pb.UnimplementedShortcutServiceServer
v1pb.UnimplementedIdentityProviderServiceServer
......@@ -104,6 +105,9 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech
if err := v1pb.RegisterAttachmentServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterAIServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
if err := v1pb.RegisterShortcutServiceHandlerServer(ctx, gwMux, s); err != nil {
return err
}
......
......@@ -41,6 +41,8 @@ func (s *Store) UpsertInstanceSetting(ctx context.Context, upsert *storepb.Insta
valueBytes, err = protojson.Marshal(upsert.GetTagsSetting())
} else if upsert.Key == storepb.InstanceSettingKey_NOTIFICATION {
valueBytes, err = protojson.Marshal(upsert.GetNotificationSetting())
} else if upsert.Key == storepb.InstanceSettingKey_AI {
valueBytes, err = protojson.Marshal(upsert.GetAiSetting())
} else {
return nil, errors.Errorf("unsupported instance setting key: %v", upsert.Key)
}
......@@ -216,6 +218,26 @@ func (s *Store) GetInstanceNotificationSetting(ctx context.Context) (*storepb.In
return instanceNotificationSetting, nil
}
// GetInstanceAISetting returns the instance-wide AI provider setting. When no
// setting has been stored it returns an empty one. The value returned is also
// written to the instance setting cache under the AI key.
func (s *Store) GetInstanceAISetting(ctx context.Context) (*storepb.InstanceAISetting, error) {
	settingName := storepb.InstanceSettingKey_AI.String()

	stored, err := s.GetInstanceSetting(ctx, &FindInstanceSetting{Name: settingName})
	if err != nil {
		return nil, errors.Wrap(err, "failed to get instance AI setting")
	}

	var result *storepb.InstanceAISetting
	if stored == nil {
		result = &storepb.InstanceAISetting{}
	} else {
		result = stored.GetAiSetting()
	}

	cached := &storepb.InstanceSetting{
		Key:   storepb.InstanceSettingKey_AI,
		Value: &storepb.InstanceSetting_AiSetting{AiSetting: result},
	}
	s.instanceSettingCache.Set(ctx, settingName, cached)
	return result, nil
}
const (
defaultInstanceStorageType = storepb.InstanceStorageSetting_LOCAL
defaultInstanceUploadSizeLimitMb = 30
......@@ -291,6 +313,12 @@ func convertInstanceSettingFromRaw(instanceSettingRaw *InstanceSetting) (*storep
return nil, err
}
instanceSetting.Value = &storepb.InstanceSetting_NotificationSetting{NotificationSetting: notificationSetting}
case storepb.InstanceSettingKey_AI.String():
aiSetting := &storepb.InstanceAISetting{}
if err := protojsonUnmarshaler.Unmarshal([]byte(instanceSettingRaw.Value), aiSetting); err != nil {
return nil, err
}
instanceSetting.Value = &storepb.InstanceSetting_AiSetting{AiSetting: aiSetting}
default:
// Skip unsupported instance setting key.
return nil, nil
......
......@@ -326,6 +326,55 @@ func TestInstanceSettingNotificationSetting(t *testing.T) {
ts.Close()
}
// TestInstanceSettingAISetting checks that the AI setting defaults to an empty,
// non-nil value and that stored providers round-trip through the store,
// including their API keys and order.
func TestInstanceSettingAISetting(t *testing.T) {
t.Parallel()
ctx := context.Background()
ts := NewTestingStore(ctx, t)
// Nothing stored yet: expect an empty setting rather than nil.
aiSetting, err := ts.GetInstanceAISetting(ctx)
require.NoError(t, err)
require.NotNil(t, aiSetting)
require.Empty(t, aiSetting.Providers)
// Store two providers of different types.
_, err = ts.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_AI,
Value: &storepb.InstanceSetting_AiSetting{
AiSetting: &storepb.InstanceAISetting{
Providers: []*storepb.AIProviderConfig{
{
Id: "openai-main",
Title: "OpenAI",
Type: storepb.AIProviderType_OPENAI,
Endpoint: "https://api.openai.com/v1",
ApiKey: "sk-test",
Models: []string{"gpt-5.4", "gpt-5.4-mini"},
DefaultModel: "gpt-5.4",
},
{
Id: "company-gateway",
Title: "Company Gateway",
Type: storepb.AIProviderType_OPENAI_COMPATIBLE,
Endpoint: "https://llm.example.com/v1",
ApiKey: "gw-test",
Models: []string{"qwen-plus"},
DefaultModel: "qwen-plus",
},
},
},
},
})
require.NoError(t, err)
// Read back: the store layer returns the API key unmasked.
aiSetting, err = ts.GetInstanceAISetting(ctx)
require.NoError(t, err)
require.Len(t, aiSetting.Providers, 2)
require.Equal(t, "openai-main", aiSetting.Providers[0].Id)
require.Equal(t, "sk-test", aiSetting.Providers[0].ApiKey)
require.Equal(t, "company-gateway", aiSetting.Providers[1].Id)
// Close is not deferred, so it is only reached when every assertion above passes.
ts.Close()
}
func TestInstanceSettingListAll(t *testing.T) {
t.Parallel()
ctx := context.Background()
......
This diff is collapsed.
......@@ -2,6 +2,7 @@ import { timestampDate } from "@bufbuild/protobuf/wkt";
import { Code, ConnectError, createClient, type Interceptor } from "@connectrpc/connect";
import { createConnectTransport } from "@connectrpc/connect-web";
import { getAccessToken, hasStoredToken, isTokenExpired, REQUEST_TOKEN_EXPIRY_BUFFER_MS, setAccessToken } from "./auth-state";
import { AIService } from "./types/proto/api/v1/ai_service_pb";
import { AttachmentService } from "./types/proto/api/v1/attachment_service_pb";
import { AuthService } from "./types/proto/api/v1/auth_service_pb";
import { IdentityProviderService } from "./types/proto/api/v1/idp_service_pb";
......@@ -195,6 +196,7 @@ export const userServiceClient = createClient(UserService, transport);
// Content service clients
export const memoServiceClient = createClient(MemoService, transport);
export const attachmentServiceClient = createClient(AttachmentService, transport);
export const aiServiceClient = createClient(AIService, transport);
export const shortcutServiceClient = createClient(ShortcutService, transport);
// Configuration service clients
......
......@@ -5,6 +5,8 @@ import {
InstanceProfile,
InstanceProfileSchema,
InstanceSetting,
InstanceSetting_AISetting,
InstanceSetting_AISettingSchema,
InstanceSetting_GeneralSetting,
InstanceSetting_GeneralSettingSchema,
InstanceSetting_Key,
......@@ -39,6 +41,7 @@ interface InstanceContextValue extends InstanceState {
memoRelatedSetting: InstanceSetting_MemoRelatedSetting;
storageSetting: InstanceSetting_StorageSetting;
tagsSetting: InstanceSetting_TagsSetting;
aiSetting: InstanceSetting_AISetting;
initialize: () => Promise<void>;
fetchSetting: (key: InstanceSetting_Key) => Promise<void>;
updateSetting: (setting: InstanceSetting) => Promise<void>;
......@@ -88,6 +91,14 @@ export function InstanceProvider({ children }: { children: ReactNode }) {
return create(InstanceSetting_TagsSettingSchema, {});
}, [state.settings]);
// Resolve the AI setting from the loaded instance settings. When no entry named
// `${instanceSettingNamePrefix}AI` carries an "aiSetting" value, a message created
// from the schema with an empty initializer is returned instead.
const aiSetting = useMemo((): InstanceSetting_AISetting => {
  const aiSettingName = `${instanceSettingNamePrefix}AI`;
  const entry = state.settings.find((candidate) => candidate.name === aiSettingName);
  return entry?.value.case === "aiSetting" ? entry.value.value : create(InstanceSetting_AISettingSchema, {});
}, [state.settings]);
const initialize = useCallback(async () => {
setState((prev) => ({ ...prev, isLoading: true }));
try {
......@@ -142,11 +153,12 @@ export function InstanceProvider({ children }: { children: ReactNode }) {
memoRelatedSetting,
storageSetting,
tagsSetting,
aiSetting,
initialize,
fetchSetting,
updateSetting,
}),
[state, generalSetting, memoRelatedSetting, storageSetting, tagsSetting, initialize, fetchSetting, updateSetting],
[state, generalSetting, memoRelatedSetting, storageSetting, tagsSetting, aiSetting, initialize, fetchSetting, updateSetting],
);
return <InstanceContext.Provider value={value}>{children}</InstanceContext.Provider>;
......
......@@ -386,6 +386,32 @@
"update-information": "Update Information",
"username-note": "Used to sign in"
},
"ai": {
"add-provider": "Add provider",
"api-key": "API key",
"api-key-required": "API key is required.",
"configured": "Configured",
"current-key": "Current key: {{key}}",
"default-model": "Default model",
"default-model-required": "Default model must be listed in models.",
"delete-provider": "Delete AI provider `{{title}}`?",
"description": "Configure instance-wide AI providers available to server-side AI features.",
"dialog-description": "Models are entered manually. Leave the API key blank while editing to keep the stored key.",
"edit-provider": "Edit provider",
"endpoint": "Endpoint",
"endpoint-required": "Endpoint is required for OpenAI-compatible providers.",
"keep-api-key": "Leave blank to keep the existing key",
"label": "AI",
"model-count": "{{count}} models",
"models": "Models",
"models-hint": "Enter one model per line.",
"models-required": "At least one model is required.",
"no-providers": "No AI providers configured.",
"provider-title": "Provider name",
"provider-title-required": "Provider name is required.",
"provider-type": "Provider type",
"providers": "Providers"
},
"instance": {
"disallow-change-nickname": "Disallow changing nickname",
"disallow-change-username": "Disallow changing username",
......
import {
BotIcon,
CogIcon,
DatabaseIcon,
KeyIcon,
......@@ -13,6 +14,7 @@ import {
import { useEffect, useMemo, useState } from "react";
import { useLocation } from "react-router-dom";
import MobileHeader from "@/components/MobileHeader";
import AISection from "@/components/Settings/AISection";
import InstanceSection from "@/components/Settings/InstanceSection";
import MemberSection from "@/components/Settings/MemberSection";
import MemoRelatedSettings from "@/components/Settings/MemoRelatedSettings";
......@@ -31,10 +33,10 @@ import { InstanceSetting_Key } from "@/types/proto/api/v1/instance_service_pb";
import { User_Role } from "@/types/proto/api/v1/user_service_pb";
import { useTranslate } from "@/utils/i18n";
type SettingSection = "my-account" | "preference" | "webhook" | "member" | "system" | "memo" | "storage" | "sso" | "tags";
type SettingSection = "my-account" | "preference" | "webhook" | "member" | "system" | "memo" | "storage" | "sso" | "tags" | "ai";
const BASIC_SECTIONS: SettingSection[] = ["my-account", "preference", "webhook"];
const ADMIN_SECTIONS: SettingSection[] = ["member", "system", "memo", "tags", "storage", "sso"];
const ADMIN_SECTIONS: SettingSection[] = ["member", "system", "memo", "tags", "storage", "sso", "ai"];
const SECTION_ICON_MAP: Record<SettingSection, LucideIcon> = {
"my-account": UserIcon,
......@@ -46,6 +48,7 @@ const SECTION_ICON_MAP: Record<SettingSection, LucideIcon> = {
storage: DatabaseIcon,
tags: TagsIcon,
sso: KeyIcon,
ai: BotIcon,
};
const SECTION_COMPONENT_MAP: Record<SettingSection, React.ComponentType> = {
......@@ -58,6 +61,7 @@ const SECTION_COMPONENT_MAP: Record<SettingSection, React.ComponentType> = {
storage: StorageSection,
tags: TagsSection,
sso: SSOSection,
ai: AISection,
};
const Setting = () => {
......@@ -86,6 +90,7 @@ const Setting = () => {
// Fetch admin-only settings that are not eagerly loaded by InstanceContext.
fetchSetting(InstanceSetting_Key.STORAGE);
fetchSetting(InstanceSetting_Key.TAGS);
fetchSetting(InstanceSetting_Key.AI);
}, [isHost, fetchSetting]);
const handleSectionSelectorItemClick = (section: SettingSection) => {
......
// @generated by protoc-gen-es v2.11.0 with parameter "target=ts"
// @generated from file api/v1/ai_service.proto (package memos.api.v1, syntax proto3)
/* eslint-disable */
import type { GenFile, GenMessage, GenService } from "@bufbuild/protobuf/codegenv2";
import { fileDesc, messageDesc, serviceDesc } from "@bufbuild/protobuf/codegenv2";
import { file_google_api_annotations } from "../../google/api/annotations_pb";
import { file_google_api_client } from "../../google/api/client_pb";
import { file_google_api_field_behavior } from "../../google/api/field_behavior_pb";
import type { Message } from "@bufbuild/protobuf";
/**
* Describes the file api/v1/ai_service.proto.
*/
export const file_api_v1_ai_service: GenFile = /*@__PURE__*/
fileDesc("ChdhcGkvdjEvYWlfc2VydmljZS5wcm90bxIMbWVtb3MuYXBpLnYxIpsBChFUcmFuc2NyaWJlUmVxdWVzdBIYCgtwcm92aWRlcl9pZBgBIAEoCUID4EECEjYKBmNvbmZpZxgCIAEoCzIhLm1lbW9zLmFwaS52MS5UcmFuc2NyaXB0aW9uQ29uZmlnQgPgQQISNAoFYXVkaW8YAyABKAsyIC5tZW1vcy5hcGkudjEuVHJhbnNjcmlwdGlvbkF1ZGlvQgPgQQIiVQoTVHJhbnNjcmlwdGlvbkNvbmZpZxISCgVtb2RlbBgBIAEoCUID4EEBEhMKBnByb21wdBgCIAEoCUID4EEBEhUKCGxhbmd1YWdlGAMgASgJQgPgQQEidwoSVHJhbnNjcmlwdGlvbkF1ZGlvEhYKB2NvbnRlbnQYASABKAxCA+BBBEgAEg0KA3VyaRgCIAEoCUgAEhUKCGZpbGVuYW1lGAMgASgJQgPgQQESGQoMY29udGVudF90eXBlGAQgASgJQgPgQQFCCAoGc291cmNlIiIKElRyYW5zY3JpYmVSZXNwb25zZRIMCgR0ZXh0GAEgASgJMpoBCglBSVNlcnZpY2USjAEKClRyYW5zY3JpYmUSHy5tZW1vcy5hcGkudjEuVHJhbnNjcmliZVJlcXVlc3QaIC5tZW1vcy5hcGkudjEuVHJhbnNjcmliZVJlc3BvbnNlIjvaQRhwcm92aWRlcl9pZCxjb25maWcsYXVkaW+C0+STAho6ASoiFS9hcGkvdjEvYWk6dHJhbnNjcmliZUKmAQoQY29tLm1lbW9zLmFwaS52MUIOQWlTZXJ2aWNlUHJvdG9QAVowZ2l0aHViLmNvbS91c2VtZW1vcy9tZW1vcy9wcm90by9nZW4vYXBpL3YxO2FwaXYxogIDTUFYqgIMTWVtb3MuQXBpLlYxygIMTWVtb3NcQXBpXFYx4gIYTWVtb3NcQXBpXFYxXEdQQk1ldGFkYXRh6gIOTWVtb3M6OkFwaTo6VjFiBnByb3RvMw", [file_google_api_annotations, file_google_api_client, file_google_api_field_behavior]);
/**
* Input message of the AIService `transcribe` method: identifies the provider to
* use and carries the transcription options and the audio input.
*
* @generated from message memos.api.v1.TranscribeRequest
*/
export type TranscribeRequest = Message<"memos.api.v1.TranscribeRequest"> & {
/**
* Required. The instance AI provider ID to use.
*
* @generated from field: string provider_id = 1;
*/
providerId: string;
/**
* Required. Transcription options.
*
* Note: although documented as required, the property is typed as optional
* (`config?`), so the type itself does not enforce that it is set.
*
* @generated from field: memos.api.v1.TranscriptionConfig config = 2;
*/
config?: TranscriptionConfig;
/**
* Required. Audio input.
*
* Note: although documented as required, the property is typed as optional
* (`audio?`), so the type itself does not enforce that it is set.
*
* @generated from field: memos.api.v1.TranscriptionAudio audio = 3;
*/
audio?: TranscriptionAudio;
};
/**
* Describes the message memos.api.v1.TranscribeRequest.
* Use `create(TranscribeRequestSchema)` to create a new message.
*
* NOTE(review): the index 0 presumably selects this message's position within
* the file descriptor; confirm against the protobuf-es codegen documentation.
*/
export const TranscribeRequestSchema: GenMessage<TranscribeRequest> = /*@__PURE__*/
messageDesc(file_api_v1_ai_service, 0);
/**
* Options for a transcription: model, prompt and language. All three are
* documented as optional and are typed as plain (non-optional) strings.
*
* @generated from message memos.api.v1.TranscriptionConfig
*/
export type TranscriptionConfig = Message<"memos.api.v1.TranscriptionConfig"> & {
/**
* Optional. The model to use. If empty, the provider's default model is used.
*
* @generated from field: string model = 1;
*/
model: string;
/**
* Optional. A prompt to improve transcription quality.
*
* @generated from field: string prompt = 2;
*/
prompt: string;
/**
* Optional. The language of the input audio.
*
* NOTE(review): the expected format of this value (e.g. a language code) is
* not stated here; confirm against the server-side handling.
*
* @generated from field: string language = 3;
*/
language: string;
};
/**
* Describes the message memos.api.v1.TranscriptionConfig.
* Use `create(TranscriptionConfigSchema)` to create a new message.
*
* NOTE(review): the index 1 presumably selects this message's position within
* the file descriptor; confirm against the protobuf-es codegen documentation.
*/
export const TranscriptionConfigSchema: GenMessage<TranscriptionConfig> = /*@__PURE__*/
messageDesc(file_api_v1_ai_service, 1);
/**
* Audio input for a transcription, plus an optional filename and MIME type.
*
* @generated from message memos.api.v1.TranscriptionAudio
*/
export type TranscriptionAudio = Message<"memos.api.v1.TranscriptionAudio"> & {
/**
* The audio source, modelled as a union discriminated by `case`:
* "content" carries a Uint8Array, "uri" carries a string, and
* `case: undefined` is the variant that carries no value.
*
* @generated from oneof memos.api.v1.TranscriptionAudio.source
*/
source: {
/**
* Inline audio bytes.
*
* @generated from field: bytes content = 1;
*/
value: Uint8Array;
case: "content";
} | {
/**
* URI for audio content. Reserved for future use.
*
* @generated from field: string uri = 2;
*/
value: string;
case: "uri";
} | { case: undefined; value?: undefined };
/**
* Optional. The uploaded filename.
*
* @generated from field: string filename = 3;
*/
filename: string;
/**
* Optional. The MIME type of the input audio.
*
* @generated from field: string content_type = 4;
*/
contentType: string;
};
/**
* Describes the message memos.api.v1.TranscriptionAudio.
* Use `create(TranscriptionAudioSchema)` to create a new message.
*
* NOTE(review): the index 2 presumably selects this message's position within
* the file descriptor; confirm against the protobuf-es codegen documentation.
*/
export const TranscriptionAudioSchema: GenMessage<TranscriptionAudio> = /*@__PURE__*/
messageDesc(file_api_v1_ai_service, 2);
/**
* Output message of the AIService `transcribe` method; its only field is the
* transcribed text.
*
* @generated from message memos.api.v1.TranscribeResponse
*/
export type TranscribeResponse = Message<"memos.api.v1.TranscribeResponse"> & {
/**
* The transcribed text.
*
* @generated from field: string text = 1;
*/
text: string;
};
/**
* Describes the message memos.api.v1.TranscribeResponse.
* Use `create(TranscribeResponseSchema)` to create a new message.
*
* NOTE(review): the index 3 presumably selects this message's position within
* the file descriptor; confirm against the protobuf-es codegen documentation.
*/
export const TranscribeResponseSchema: GenMessage<TranscribeResponse> = /*@__PURE__*/
messageDesc(file_api_v1_ai_service, 3);
/**
* Service descriptor for memos.api.v1.AIService. It declares a single unary
* method, `transcribe`, whose input is TranscribeRequest and whose output is
* TranscribeResponse.
*
* @generated from service memos.api.v1.AIService
*/
export const AIService: GenService<{
/**
* Transcribe transcribes an audio file using an instance AI provider.
*
* @generated from rpc memos.api.v1.AIService.Transcribe
*/
transcribe: {
methodKind: "unary";
input: typeof TranscribeRequestSchema;
output: typeof TranscribeResponseSchema;
},
}> = /*@__PURE__*/
// NOTE(review): the index 0 presumably selects this service's position within
// the file descriptor; confirm against the protobuf-es codegen documentation.
serviceDesc(file_api_v1_ai_service, 0);
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment