Unverified commit 1a80b0a1 authored by Moaaz Siddiqui, committed by GitHub

feat: tool-calling support across providers (#3)

parent 19beb5cb
......@@ -69,6 +69,7 @@ The problem is that stacking them by hand is painful: fourteen different SDKs, f
- **OpenAI-compatible** — `POST /v1/chat/completions` and `GET /v1/models` work with the official OpenAI SDKs and any OpenAI-compatible client (LangChain, LlamaIndex, Continue, Hermes, etc.). Just change `base_url`.
- **Streaming and non-streaming** — Server-Sent Events for `stream: true`, JSON response otherwise. Every provider adapter implements both.
- **Tool calling** — OpenAI-style `tools` / `tool_choice` requests are passed through, and assistant `tool_calls` + `tool`-role follow-up messages round-trip across providers (see the sketch after this list).
- **Automatic failover** — If the chosen provider returns a 429, a 5xx, or times out, the router skips it, puts the key on a short cooldown, and retries the next model in your fallback chain (up to 20 attempts).
- **Per-key rate tracking** — RPM, RPD, TPM, and TPD counters per `(platform, model, key)` so the router always picks a key that's under its caps.
- **Sticky sessions** — Multi-turn conversations keep talking to the same model for 30 minutes to avoid the hallucination spike that comes from mid-conversation model switches.
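As a quick illustration of the pass-through, here is a minimal tool-calling sketch against the router using the official `openai` SDK. The base URL, port, and placeholder API key are assumptions for the example; `model: 'auto'` and the `tools` shape match the proxy's request schema.

```ts
// A minimal sketch, not part of the repo: assumes the router is running on
// http://localhost:3000 (host and port are assumptions) and the `openai` npm package.
import OpenAI from 'openai';

const client = new OpenAI({
  baseURL: 'http://localhost:3000/v1', // point the SDK at the router
  apiKey: 'anything', // the router manages upstream provider keys itself
});

const res = await client.chat.completions.create({
  model: 'auto', // let the router pick a model
  messages: [{ role: 'user', content: 'What is the weather in Karachi?' }],
  tools: [{
    type: 'function',
    function: {
      name: 'get_weather',
      parameters: {
        type: 'object',
        properties: { city: { type: 'string' } },
        required: ['city'],
      },
    },
  }],
  tool_choice: 'auto',
});

// When the model elects to call the tool, finish_reason is 'tool_calls'
// and the arguments arrive as a JSON string.
console.log(res.choices[0].finish_reason, res.choices[0].message.tool_calls);
```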
......@@ -86,7 +87,6 @@ The scope is deliberately narrow. If a feature isn't on this list and isn't belo
- **Embeddings** (`/v1/embeddings`)
- **Image generation** (`/v1/images/*`)
- **Audio / speech** (`/v1/audio/*`)
- **Function / tool calling** — the request schema doesn't pass `tools` through yet
- **Vision / multimodal inputs** — message content is text-only
- **Legacy completions** (`/v1/completions`) — only the chat endpoint is implemented
- **Moderation** (`/v1/moderations`)
......
......@@ -16,13 +16,22 @@ describe('CloudflareProvider', () => {
it('should parse account_id:token key format', async () => {
let capturedUrl = '';
let capturedHeaders: Record<string, string> = {};
let capturedBody: any = null;
vi.spyOn(global, 'fetch').mockImplementation(async (url, init) => {
capturedUrl = url as string;
capturedHeaders = (init as any).headers;
capturedBody = JSON.parse((init as any).body);
return {
ok: true,
json: () => Promise.resolve({ result: { response: 'Hello from CF!' } }),
json: () => Promise.resolve({
id: 'chatcmpl-cf',
object: 'chat.completion',
created: 123,
model: '@cf/meta/llama-3.1-70b-instruct',
choices: [{ index: 0, message: { role: 'assistant', content: 'Hello from CF!' }, finish_reason: 'stop' }],
usage: { prompt_tokens: 1, completion_tokens: 2, total_tokens: 3 },
}),
} as any;
});
......@@ -33,8 +42,9 @@ describe('CloudflareProvider', () => {
);
expect(capturedUrl).toContain('abc123');
expect(capturedUrl).toContain('@cf/meta/llama-3.1-70b-instruct');
expect(capturedUrl).toContain('/ai/v1/chat/completions');
expect(capturedHeaders['Authorization']).toBe('Bearer my-token-here');
expect(capturedBody.model).toBe('@cf/meta/llama-3.1-70b-instruct');
expect(result.choices[0].message.content).toBe('Hello from CF!');
});
......
......@@ -13,23 +13,45 @@ describe('CohereProvider', () => {
expect(provider.name).toBe('Cohere');
});
it('should translate response to OpenAI format', async () => {
vi.spyOn(global, 'fetch').mockResolvedValueOnce({
it('should call compatibility API and return OpenAI response', async () => {
let capturedUrl = '';
let capturedBody: any = null;
vi.spyOn(global, 'fetch').mockImplementationOnce(async (url, init) => {
capturedUrl = String(url);
capturedBody = JSON.parse((init as any).body);
return {
ok: true,
json: () => Promise.resolve({
id: 'cohere-123',
message: { content: [{ type: 'text', text: 'Hello from Cohere!' }] },
finish_reason: 'COMPLETE',
usage: { tokens: { input_tokens: 10, output_tokens: 5 } },
object: 'chat.completion',
created: 123,
model: 'command-a-03-2025',
choices: [{ index: 0, message: { role: 'assistant', content: 'Hello from Cohere!' }, finish_reason: 'stop' }],
usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
}),
} as any);
} as any;
});
const result = await provider.chatCompletion(
'test-key',
[{ role: 'user', content: 'Hi' }],
'command-r-plus-08-2024',
{
tools: [{
type: 'function',
function: {
name: 'get_weather',
parameters: {
type: 'object',
properties: { city: { type: 'string' } },
},
},
}],
},
);
expect(capturedUrl).toContain('/compatibility/v1/chat/completions');
expect(capturedBody.tools).toHaveLength(1);
expect(result.object).toBe('chat.completion');
expect(result.choices[0].message.content).toBe('Hello from Cohere!');
expect(result.usage.prompt_tokens).toBe(10);
......
......@@ -92,4 +92,83 @@ describe('GoogleProvider', () => {
expect(capturedBody.contents).toHaveLength(1);
expect(capturedBody.contents[0].role).toBe('user');
});
it('should translate OpenAI tools/tool_choice to Gemini tools/toolConfig', async () => {
let capturedBody: any;
vi.spyOn(global, 'fetch').mockImplementation(async (_url, init) => {
capturedBody = JSON.parse((init as any).body);
return {
ok: true,
json: () => Promise.resolve({
candidates: [{ content: { parts: [{ text: 'ok' }] }, finishReason: 'STOP' }],
usageMetadata: { promptTokenCount: 1, candidatesTokenCount: 1, totalTokenCount: 2 },
}),
} as any;
});
await provider.chatCompletion(
'test-key',
[{ role: 'user', content: 'Weather in Karachi?' }],
'gemini-2.5-pro',
{
tools: [{
type: 'function',
function: {
name: 'get_weather',
description: 'Get weather for a city',
parameters: {
type: 'object',
properties: { city: { type: 'string' } },
required: ['city'],
},
},
}],
tool_choice: {
type: 'function',
function: { name: 'get_weather' },
},
},
);
expect(capturedBody.tools[0].functionDeclarations[0].name).toBe('get_weather');
expect(capturedBody.toolConfig.functionCallingConfig.mode).toBe('ANY');
expect(capturedBody.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(['get_weather']);
});
it('should translate Gemini functionCall response to OpenAI tool_calls', async () => {
vi.spyOn(global, 'fetch').mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
candidates: [{
content: {
parts: [{
functionCall: {
id: 'call_123',
name: 'get_weather',
args: { city: 'Lahore' },
},
}],
},
finishReason: 'STOP',
}],
usageMetadata: {
promptTokenCount: 12,
candidatesTokenCount: 3,
totalTokenCount: 15,
},
}),
} as any);
const result = await provider.chatCompletion(
'test-key',
[{ role: 'user', content: 'What is the weather?' }],
'gemini-2.5-pro',
);
expect(result.choices[0].finish_reason).toBe('tool_calls');
expect(result.choices[0].message.content).toBeNull();
expect(result.choices[0].message.tool_calls?.[0].id).toBe('call_123');
expect(result.choices[0].message.tool_calls?.[0].function.name).toBe('get_weather');
expect(result.choices[0].message.tool_calls?.[0].function.arguments).toBe('{"city":"Lahore"}');
});
});
......@@ -21,10 +21,12 @@ describe('OpenAICompatProvider', () => {
it('should call API with correct URL and headers', async () => {
let capturedUrl = '';
let capturedHeaders: Record<string, string> = {};
let capturedBody: any = null;
vi.spyOn(global, 'fetch').mockImplementation(async (url, init) => {
capturedUrl = url as string;
capturedHeaders = (init as any).headers;
capturedBody = JSON.parse((init as any).body);
return {
ok: true,
json: () => Promise.resolve({
......@@ -43,6 +45,51 @@ describe('OpenAICompatProvider', () => {
expect(capturedUrl).toBe('https://api.test.com/v1/chat/completions');
expect(capturedHeaders['Authorization']).toBe('Bearer my-key');
expect(capturedHeaders['X-Custom']).toBe('test');
expect(capturedBody.messages[0].role).toBe('user');
});
it('should pass tool-calling params through untouched', async () => {
let capturedBody: any = null;
vi.spyOn(global, 'fetch').mockImplementation(async (_url, init) => {
capturedBody = JSON.parse((init as any).body);
return {
ok: true,
json: () => Promise.resolve({
id: 'test-id',
object: 'chat.completion',
created: 123,
model: 'test-model',
choices: [{ index: 0, message: { role: 'assistant', content: null, tool_calls: [] }, finish_reason: 'stop' }],
usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 },
}),
} as any;
});
await provider.chatCompletion(
'my-key',
[{ role: 'user', content: 'what is weather?' }],
'test-model',
{
tools: [{
type: 'function',
function: {
name: 'get_weather',
description: 'Get weather',
parameters: {
type: 'object',
properties: { city: { type: 'string' } },
required: ['city'],
},
},
}],
tool_choice: 'required',
parallel_tool_calls: true,
},
);
expect(capturedBody.tools).toHaveLength(1);
expect(capturedBody.tool_choice).toBe('required');
expect(capturedBody.parallel_tool_calls).toBe(true);
});
it('should throw on error response', async () => {
......
import { describe, it, expect, beforeAll, beforeEach, afterEach, vi } from 'vitest';
import type { Express } from 'express';
import { createApp } from '../../app.js';
import { initDb, getDb } from '../../db/index.js';
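// Minimal HTTP helper: bind the Express app to an ephemeral port for each
// request so tests exercise the real network stack, then tear it down.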
async function request(app: Express, method: string, path: string, body?: any) {
const server = app.listen(0);
const addr = server.address() as any;
const url = `http://127.0.0.1:${addr.port}${path}`;
const res = await fetch(url, {
method,
headers: body ? { 'Content-Type': 'application/json' } : {},
body: body ? JSON.stringify(body) : undefined,
});
const data = await res.text();
server.close();
let json: any = null;
try { json = JSON.parse(data); } catch {}
return { status: res.status, body: json, headers: res.headers, raw: data };
}
describe('Proxy tool-calling support', () => {
let app: Express;
beforeAll(() => {
process.env.ENCRYPTION_KEY = '0'.repeat(64);
initDb(':memory:');
app = createApp();
});
beforeEach(async () => {
const db = getDb();
db.prepare('DELETE FROM api_keys').run();
db.prepare('DELETE FROM requests').run();
const addKey = await request(app, 'POST', '/api/keys', {
platform: 'groq',
key: 'gsk_proxy_tool_test',
label: 'proxy-tools',
});
expect(addKey.status).toBe(201);
});
afterEach(() => {
vi.restoreAllMocks();
});
it('passes tools/tool_choice to provider and returns tool_calls', async () => {
const origFetch = global.fetch;
let providerBody: any = null;
vi.spyOn(global, 'fetch').mockImplementation(async (url, init) => {
const urlStr = typeof url === 'string' ? url : url.toString();
if (urlStr.includes('api.groq.com/openai/v1/chat/completions')) {
providerBody = JSON.parse((init as any).body);
return {
ok: true,
json: () => Promise.resolve({
id: 'chatcmpl-tool',
object: 'chat.completion',
created: 123,
model: 'openai/gpt-oss-120b',
choices: [{
index: 0,
message: {
role: 'assistant',
content: null,
tool_calls: [{
id: 'call_weather',
type: 'function',
function: {
name: 'get_weather',
arguments: '{"city":"Karachi"}',
},
}],
},
finish_reason: 'tool_calls',
}],
usage: { prompt_tokens: 12, completion_tokens: 4, total_tokens: 16 },
}),
} as any;
}
return origFetch(url, init);
});
const { status, body } = await request(app, 'POST', '/v1/chat/completions', {
model: 'auto',
messages: [{ role: 'user', content: 'What is the weather in Karachi?' }],
tools: [{
type: 'function',
function: {
name: 'get_weather',
description: 'Get current weather',
parameters: {
type: 'object',
properties: { city: { type: 'string' } },
required: ['city'],
},
},
}],
tool_choice: 'required',
});
expect(status).toBe(200);
expect(providerBody.tools).toHaveLength(1);
expect(providerBody.tool_choice).toBe('required');
expect(body.choices[0].finish_reason).toBe('tool_calls');
expect(body.choices[0].message.tool_calls[0].function.name).toBe('get_weather');
});
it('accepts assistant tool_calls + tool messages in follow-up turns', async () => {
const origFetch = global.fetch;
let providerBody: any = null;
vi.spyOn(global, 'fetch').mockImplementation(async (url, init) => {
const urlStr = typeof url === 'string' ? url : url.toString();
if (urlStr.includes('api.groq.com/openai/v1/chat/completions')) {
providerBody = JSON.parse((init as any).body);
return {
ok: true,
json: () => Promise.resolve({
id: 'chatcmpl-final',
object: 'chat.completion',
created: 123,
model: 'openai/gpt-oss-120b',
choices: [{
index: 0,
message: {
role: 'assistant',
content: 'It is 30C in Karachi.',
},
finish_reason: 'stop',
}],
usage: { prompt_tokens: 18, completion_tokens: 6, total_tokens: 24 },
}),
} as any;
}
return origFetch(url, init);
});
const { status, body } = await request(app, 'POST', '/v1/chat/completions', {
messages: [
{ role: 'user', content: 'Weather in Karachi?' },
{
role: 'assistant',
content: null,
tool_calls: [{
id: 'call_weather_1',
type: 'function',
function: {
name: 'get_weather',
arguments: '{"city":"Karachi"}',
},
}],
},
{
role: 'tool',
tool_call_id: 'call_weather_1',
content: '{"temp_c":30}',
},
],
});
expect(status).toBe(200);
expect(providerBody.messages[1].role).toBe('assistant');
expect(providerBody.messages[1].content).toBeNull();
expect(providerBody.messages[1].tool_calls).toHaveLength(1);
expect(providerBody.messages[2].role).toBe('tool');
expect(providerBody.messages[2].tool_call_id).toBe('call_weather_1');
expect(body.choices[0].message.content).toContain('30C');
});
});
......@@ -2,6 +2,8 @@ import type {
ChatMessage,
ChatCompletionResponse,
ChatCompletionChunk,
ChatToolDefinition,
ChatToolChoice,
Platform,
} from '@freellmapi/shared/types.js';
......@@ -10,6 +12,9 @@ export interface CompletionOptions {
temperature?: number;
max_tokens?: number;
top_p?: number;
tools?: ChatToolDefinition[];
tool_choice?: ChatToolChoice;
parallel_tool_calls?: boolean;
}
export abstract class BaseProvider {
......
......@@ -27,7 +27,7 @@ export class CloudflareProvider extends BaseProvider {
options?: CompletionOptions,
): Promise<ChatCompletionResponse> {
const { accountId, token } = this.parseKey(apiKey);
const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${modelId}`;
const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/v1/chat/completions`;
const res = await this.fetchWithTimeout(url, {
method: 'POST',
......@@ -36,38 +36,25 @@ export class CloudflareProvider extends BaseProvider {
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: modelId,
messages,
max_tokens: options?.max_tokens,
temperature: options?.temperature,
max_tokens: options?.max_tokens,
top_p: options?.top_p,
tools: options?.tools,
tool_choice: options?.tool_choice,
parallel_tool_calls: options?.parallel_tool_calls,
}),
});
if (!res.ok) {
const err = await res.json().catch(() => ({}));
const errors = (err as any).errors;
throw new Error(`Cloudflare API error ${res.status}: ${errors?.[0]?.message ?? res.statusText}`);
throw new Error(`Cloudflare API error ${res.status}: ${(err as any).error?.message ?? (err as any).errors?.[0]?.message ?? res.statusText}`);
}
const data = await res.json() as any;
const text = data.result?.response ?? '';
return {
id: this.makeId(),
object: 'chat.completion',
created: Math.floor(Date.now() / 1000),
model: modelId,
choices: [{
index: 0,
message: { role: 'assistant', content: text },
finish_reason: 'stop',
}],
usage: {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
_routed_via: { platform: 'cloudflare', model: modelId },
};
const data = await res.json() as ChatCompletionResponse;
data._routed_via = { platform: 'cloudflare', model: modelId };
return data;
}
async *streamChatCompletion(
......@@ -77,7 +64,7 @@ export class CloudflareProvider extends BaseProvider {
options?: CompletionOptions,
): AsyncGenerator<ChatCompletionChunk> {
const { accountId, token } = this.parseKey(apiKey);
const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/run/${modelId}`;
const url = `https://api.cloudflare.com/client/v4/accounts/${accountId}/ai/v1/chat/completions`;
const res = await this.fetchWithTimeout(url, {
method: 'POST',
......@@ -86,23 +73,27 @@ export class CloudflareProvider extends BaseProvider {
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: modelId,
messages,
max_tokens: options?.max_tokens,
temperature: options?.temperature,
max_tokens: options?.max_tokens,
top_p: options?.top_p,
tools: options?.tools,
tool_choice: options?.tool_choice,
parallel_tool_calls: options?.parallel_tool_calls,
stream: true,
}),
});
if (!res.ok) {
const err = await res.json().catch(() => ({}));
throw new Error(`Cloudflare API error ${res.status}: ${(err as any).errors?.[0]?.message ?? res.statusText}`);
throw new Error(`Cloudflare API error ${res.status}: ${(err as any).error?.message ?? (err as any).errors?.[0]?.message ?? res.statusText}`);
}
const reader = res.body?.getReader();
if (!reader) throw new Error('No response body');
const decoder = new TextDecoder();
const id = this.makeId();
let buffer = '';
while (true) {
......@@ -119,17 +110,10 @@ export class CloudflareProvider extends BaseProvider {
const data = trimmed.slice(6);
if (data === '[DONE]') return;
try {
const parsed = JSON.parse(data);
if (parsed.response) {
yield {
id,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: modelId,
choices: [{ index: 0, delta: { content: parsed.response }, finish_reason: null }],
};
yield JSON.parse(data) as ChatCompletionChunk;
} catch {
// Skip malformed chunks
}
} catch { /* skip */ }
}
}
}
......
......@@ -5,16 +5,7 @@ import type {
} from '@freellmapi/shared/types.js';
import { BaseProvider, type CompletionOptions } from './base.js';
const API_BASE = 'https://api.cohere.com/v2';
interface CohereResponse {
id: string;
message?: { content?: { type: string; text: string }[] };
finish_reason?: string;
usage?: {
tokens?: { input_tokens?: number; output_tokens?: number };
};
}
const API_BASE = 'https://api.cohere.ai/compatibility/v1';
export class CohereProvider extends BaseProvider {
readonly platform = 'cohere' as const;
......@@ -26,51 +17,33 @@ export class CohereProvider extends BaseProvider {
modelId: string,
options?: CompletionOptions,
): Promise<ChatCompletionResponse> {
const cohereMessages = messages.map(m => ({
role: m.role === 'system' ? 'system' as const : m.role === 'assistant' ? 'assistant' as const : 'user' as const,
content: m.content,
}));
const body: Record<string, unknown> = {
model: modelId,
messages,
temperature: options?.temperature,
max_tokens: options?.max_tokens,
top_p: options?.top_p,
tools: options?.tools,
tool_choice: options?.tool_choice,
};
const res = await this.fetchWithTimeout(`${API_BASE}/chat`, {
const res = await this.fetchWithTimeout(`${API_BASE}/chat/completions`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: modelId,
messages: cohereMessages,
temperature: options?.temperature,
max_tokens: options?.max_tokens,
p: options?.top_p,
}),
body: JSON.stringify(body),
});
if (!res.ok) {
const err = await res.json().catch(() => ({}));
throw new Error(`Cohere API error ${res.status}: ${(err as any).message ?? res.statusText}`);
throw new Error(`Cohere API error ${res.status}: ${(err as any).error?.message ?? res.statusText}`);
}
const data = await res.json() as CohereResponse;
const text = data.message?.content?.[0]?.text ?? '';
return {
id: data.id ?? this.makeId(),
object: 'chat.completion',
created: Math.floor(Date.now() / 1000),
model: modelId,
choices: [{
index: 0,
message: { role: 'assistant', content: text },
finish_reason: data.finish_reason ?? 'stop',
}],
usage: {
prompt_tokens: data.usage?.tokens?.input_tokens ?? 0,
completion_tokens: data.usage?.tokens?.output_tokens ?? 0,
total_tokens: (data.usage?.tokens?.input_tokens ?? 0) + (data.usage?.tokens?.output_tokens ?? 0),
},
_routed_via: { platform: 'cohere', model: modelId },
};
const data = await res.json() as ChatCompletionResponse;
data._routed_via = { platform: 'cohere', model: modelId };
return data;
}
async *streamChatCompletion(
......@@ -79,36 +52,35 @@ export class CohereProvider extends BaseProvider {
modelId: string,
options?: CompletionOptions,
): AsyncGenerator<ChatCompletionChunk> {
const cohereMessages = messages.map(m => ({
role: m.role === 'system' ? 'system' as const : m.role === 'assistant' ? 'assistant' as const : 'user' as const,
content: m.content,
}));
const body: Record<string, unknown> = {
model: modelId,
messages,
temperature: options?.temperature,
max_tokens: options?.max_tokens,
top_p: options?.top_p,
tools: options?.tools,
tool_choice: options?.tool_choice,
stream: true,
};
const res = await this.fetchWithTimeout(`${API_BASE}/chat`, {
const res = await this.fetchWithTimeout(`${API_BASE}/chat/completions`, {
method: 'POST',
headers: {
'Authorization': `Bearer ${apiKey}`,
'Content-Type': 'application/json',
},
body: JSON.stringify({
model: modelId,
messages: cohereMessages,
temperature: options?.temperature,
max_tokens: options?.max_tokens,
stream: true,
}),
body: JSON.stringify(body),
});
if (!res.ok) {
const err = await res.json().catch(() => ({}));
throw new Error(`Cohere API error ${res.status}: ${(err as any).message ?? res.statusText}`);
throw new Error(`Cohere API error ${res.status}: ${(err as any).error?.message ?? res.statusText}`);
}
const reader = res.body?.getReader();
if (!reader) throw new Error('No response body');
const decoder = new TextDecoder();
const id = this.makeId();
let buffer = '';
while (true) {
......@@ -121,31 +93,13 @@ export class CohereProvider extends BaseProvider {
for (const line of lines) {
const trimmed = line.trim();
if (!trimmed) continue;
if (!trimmed || !trimmed.startsWith('data: ')) continue;
const data = trimmed.slice(6);
if (data === '[DONE]') return;
try {
const event = JSON.parse(trimmed);
if (event.type === 'content-delta') {
const text = event.delta?.message?.content?.text ?? '';
if (text) {
yield {
id,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: modelId,
choices: [{ index: 0, delta: { content: text }, finish_reason: null }],
};
}
} else if (event.type === 'message-end') {
yield {
id,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: modelId,
choices: [{ index: 0, delta: {}, finish_reason: 'stop' }],
};
}
yield JSON.parse(data) as ChatCompletionChunk;
} catch {
// Skip malformed lines
// Skip malformed chunks
}
}
}
......
......@@ -2,40 +2,205 @@ import type {
ChatMessage,
ChatCompletionResponse,
ChatCompletionChunk,
ChatToolCall,
ChatToolChoice,
ChatToolDefinition,
TokenUsage,
} from '@freellmapi/shared/types.js';
import { BaseProvider, type CompletionOptions } from './base.js';
const API_BASE = 'https://generativelanguage.googleapis.com/v1beta';
interface GeminiPart {
text?: string;
functionCall?: {
id?: string;
name?: string;
args?: unknown;
};
functionResponse?: {
id?: string;
name?: string;
response?: unknown;
};
}
interface GeminiCandidate {
content?: { parts?: GeminiPart[] };
finishReason?: string;
}
interface GeminiResponse {
candidates?: GeminiCandidate[];
usageMetadata?: {
promptTokenCount?: number;
candidatesTokenCount?: number;
totalTokenCount?: number;
};
}
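// Coerce a tool-call arguments string into an object; non-object JSON and
// unparseable strings are wrapped as { value: ... } so callers always hand
// Gemini an object for functionCall.args / functionResponse.response.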
function safeParseObject(raw: string): Record<string, unknown> {
try {
const parsed = JSON.parse(raw) as unknown;
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed)) {
return parsed as Record<string, unknown>;
}
return { value: parsed };
} catch {
return { value: raw };
}
}
function normalizeGeminiArgs(args: unknown): string {
if (typeof args === 'string') return args;
return JSON.stringify(args ?? {});
}
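// Map Gemini finishReason values onto OpenAI finish_reason strings;
// anything unrecognized falls back to 'stop'.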
function toGeminiFinishReason(finishReason?: string): string {
const r = (finishReason ?? '').toUpperCase();
if (!r) return 'stop';
if (r === 'MAX_TOKENS') return 'length';
if (r === 'SAFETY' || r === 'RECITATION' || r === 'BLOCKLIST' || r === 'PROHIBITED_CONTENT' || r === 'SPII') {
return 'content_filter';
}
return 'stop';
}
function toGeminiTools(tools?: ChatToolDefinition[]): Array<{ functionDeclarations: Array<Record<string, unknown>> }> | undefined {
if (!tools || tools.length === 0) return undefined;
return [{
functionDeclarations: tools.map(t => ({
name: t.function.name,
description: t.function.description,
parameters: t.function.parameters,
})),
}];
}
function toGeminiToolConfig(toolChoice?: ChatToolChoice): { functionCallingConfig: Record<string, unknown> } | undefined {
if (!toolChoice) return undefined;
if (typeof toolChoice === 'string') {
const mode =
toolChoice === 'none'
? 'NONE'
: toolChoice === 'required'
? 'ANY'
: 'AUTO';
return { functionCallingConfig: { mode } };
}
return {
functionCallingConfig: {
mode: 'ANY',
allowedFunctionNames: [toolChoice.function.name],
},
};
}
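// e.g. tool_choice = { type: 'function', function: { name: 'get_weather' } }
// becomes { functionCallingConfig: { mode: 'ANY', allowedFunctionNames: ['get_weather'] } }.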
// Translate OpenAI messages to Gemini format
function toGeminiContents(messages: ChatMessage[]) {
const systemInstruction = messages.find(m => m.role === 'system');
const systemMessages = messages
.filter(m => m.role === 'system' && typeof m.content === 'string' && m.content.length > 0)
.map(m => m.content as string);
const toolNameByCallId = new Map<string, string>();
for (const m of messages) {
for (const tc of m.tool_calls ?? []) {
toolNameByCallId.set(tc.id, tc.function.name);
}
}
const contents = messages
.filter(m => m.role !== 'system')
.map(m => ({
role: m.role === 'assistant' ? 'model' : 'user',
parts: [{ text: m.content }],
}));
.map((m): { role: 'user' | 'model'; parts: GeminiPart[] } | null => {
if (m.role === 'assistant') {
const parts: GeminiPart[] = [];
if (typeof m.content === 'string' && m.content.length > 0) {
parts.push({ text: m.content });
}
for (const call of m.tool_calls ?? []) {
parts.push({
functionCall: {
id: call.id,
name: call.function.name,
args: safeParseObject(call.function.arguments),
},
});
}
if (parts.length === 0) return null;
return {
role: 'model',
parts,
};
}
if (m.role === 'tool') {
const toolCallId = m.tool_call_id;
if (!toolCallId) return null;
const toolName = m.name ?? toolNameByCallId.get(toolCallId) ?? 'tool';
const response = safeParseObject(typeof m.content === 'string' ? m.content : '');
return {
role: 'user',
parts: [{
functionResponse: {
id: toolCallId,
name: toolName,
response,
},
}],
};
}
return {
role: 'user',
parts: [{ text: typeof m.content === 'string' ? m.content : '' }],
};
})
.filter((entry): entry is { role: 'user' | 'model'; parts: GeminiPart[] } => entry !== null);
return {
contents,
systemInstruction: systemInstruction
? { parts: [{ text: systemInstruction.content }] }
systemInstruction: systemMessages.length > 0
? { parts: [{ text: systemMessages.join('\n\n') }] }
: undefined,
};
}
interface GeminiResponse {
candidates?: {
content?: { parts?: { text?: string }[] };
finishReason?: string;
}[];
usageMetadata?: {
promptTokenCount?: number;
candidatesTokenCount?: number;
totalTokenCount?: number;
};
function extractToolCalls(parts: GeminiPart[] | undefined): ChatToolCall[] {
const calls: ChatToolCall[] = [];
if (!parts) return calls;
let fallbackIndex = 0;
for (const part of parts) {
if (!part.functionCall?.name) continue;
const id = part.functionCall.id ?? `call_${Date.now()}_${fallbackIndex++}`;
calls.push({
id,
type: 'function',
function: {
name: part.functionCall.name,
arguments: normalizeGeminiArgs(part.functionCall.args),
},
});
}
return calls;
}
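// Concatenate all text parts; returns null (not '') when the model produced
// only function calls, matching OpenAI's content: null convention.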
function extractText(parts: GeminiPart[] | undefined): string | null {
if (!parts) return null;
const text = parts
.map(p => p.text ?? '')
.join('');
return text.length > 0 ? text : null;
}
export class GoogleProvider extends BaseProvider {
......@@ -57,6 +222,8 @@ export class GoogleProvider extends BaseProvider {
maxOutputTokens: options?.max_tokens,
topP: options?.top_p,
},
tools: toGeminiTools(options?.tools),
toolConfig: toGeminiToolConfig(options?.tool_choice),
};
if (systemInstruction) body.systemInstruction = systemInstruction;
......@@ -73,8 +240,11 @@ export class GoogleProvider extends BaseProvider {
}
const data = await res.json() as GeminiResponse;
const candidate = data.candidates?.[0];
const parts = candidate?.content?.parts;
const toolCalls = extractToolCalls(parts);
const text = extractText(parts);
const text = data.candidates?.[0]?.content?.parts?.[0]?.text ?? '';
const usage: TokenUsage = {
prompt_tokens: data.usageMetadata?.promptTokenCount ?? 0,
completion_tokens: data.usageMetadata?.candidatesTokenCount ?? 0,
......@@ -88,8 +258,12 @@ export class GoogleProvider extends BaseProvider {
model: modelId,
choices: [{
index: 0,
message: { role: 'assistant', content: text },
finish_reason: data.candidates?.[0]?.finishReason?.toLowerCase() === 'stop' ? 'stop' : 'stop',
message: {
role: 'assistant',
content: text,
...(toolCalls.length > 0 ? { tool_calls: toolCalls } : {}),
},
finish_reason: toolCalls.length > 0 ? 'tool_calls' : toGeminiFinishReason(candidate?.finishReason),
}],
usage,
_routed_via: { platform: 'google', model: modelId },
......@@ -111,6 +285,8 @@ export class GoogleProvider extends BaseProvider {
maxOutputTokens: options?.max_tokens,
topP: options?.top_p,
},
tools: toGeminiTools(options?.tools),
toolConfig: toGeminiToolConfig(options?.tool_choice),
};
if (systemInstruction) body.systemInstruction = systemInstruction;
......@@ -132,6 +308,10 @@ export class GoogleProvider extends BaseProvider {
const decoder = new TextDecoder();
const id = this.makeId();
let buffer = '';
let emittedFinish = false;
let sawToolCalls = false;
const seenToolCallKeys = new Set<string>();
while (true) {
const { done, value } = await reader.read();
......@@ -145,12 +325,38 @@ export class GoogleProvider extends BaseProvider {
const trimmed = line.trim();
if (!trimmed || !trimmed.startsWith('data: ')) continue;
const raw = trimmed.slice(6);
if (raw === '[DONE]') return;
if (raw === '[DONE]') {
if (!emittedFinish) {
emittedFinish = true;
yield {
id,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: modelId,
choices: [{
index: 0,
delta: {},
finish_reason: sawToolCalls ? 'tool_calls' : 'stop',
}],
};
}
return;
}
const chunk = JSON.parse(raw) as GeminiResponse;
const text = chunk.candidates?.[0]?.content?.parts?.[0]?.text ?? '';
if (!text) continue;
const candidate = chunk.candidates?.[0];
const parts = candidate?.content?.parts ?? [];
const text = extractText(parts);
const toolCalls = extractToolCalls(parts).filter(call => {
const key = `${call.id}:${call.function.name}:${call.function.arguments}`;
if (seenToolCallKeys.has(key)) return false;
seenToolCallKeys.add(key);
return true;
});
if ((text && text.length > 0) || toolCalls.length > 0) {
sawToolCalls = sawToolCalls || toolCalls.length > 0;
yield {
id,
object: 'chat.completion.chunk',
......@@ -158,14 +364,34 @@ export class GoogleProvider extends BaseProvider {
model: modelId,
choices: [{
index: 0,
delta: { content: text },
delta: {
...(text ? { content: text } : {}),
...(toolCalls.length > 0 ? { tool_calls: toolCalls } : {}),
},
finish_reason: null,
}],
};
}
if (candidate?.finishReason && !emittedFinish) {
emittedFinish = true;
yield {
id,
object: 'chat.completion.chunk',
created: Math.floor(Date.now() / 1000),
model: modelId,
choices: [{
index: 0,
delta: {},
finish_reason: sawToolCalls ? 'tool_calls' : toGeminiFinishReason(candidate.finishReason),
}],
};
return;
}
}
}
// Final chunk
if (!emittedFinish) {
yield {
id,
object: 'chat.completion.chunk',
......@@ -174,10 +400,11 @@ export class GoogleProvider extends BaseProvider {
choices: [{
index: 0,
delta: {},
finish_reason: 'stop',
finish_reason: sawToolCalls ? 'tool_calls' : 'stop',
}],
};
}
}
async validateKey(apiKey: string): Promise<boolean> {
try {
......
......@@ -30,6 +30,9 @@ export class HuggingFaceProvider extends BaseProvider {
temperature: options?.temperature,
max_tokens: options?.max_tokens,
top_p: options?.top_p,
tools: options?.tools,
tool_choice: options?.tool_choice,
parallel_tool_calls: options?.parallel_tool_calls,
}),
});
......@@ -60,6 +63,10 @@ export class HuggingFaceProvider extends BaseProvider {
messages,
temperature: options?.temperature,
max_tokens: options?.max_tokens,
top_p: options?.top_p,
tools: options?.tools,
tool_choice: options?.tool_choice,
parallel_tool_calls: options?.parallel_tool_calls,
stream: true,
}),
});
......
......@@ -68,10 +68,10 @@ register(new OpenAICompatProvider({
baseUrl: 'https://models.inference.ai.azure.com',
}));
// Cohere - unique API format
// Cohere - OpenAI-compatible via Cohere compatibility endpoint
register(new CohereProvider());
// Cloudflare Workers AI - unique API format (key = "account_id:token")
// Cloudflare Workers AI - OpenAI-compatible endpoint (key = "account_id:token")
register(new CloudflareProvider());
// Hugging Face - OpenAI-compatible per-model endpoint
......
......@@ -52,6 +52,9 @@ export class OpenAICompatProvider extends BaseProvider {
temperature: options?.temperature,
max_tokens: options?.max_tokens,
top_p: options?.top_p,
tools: options?.tools,
tool_choice: options?.tool_choice,
parallel_tool_calls: options?.parallel_tool_calls,
}),
});
......@@ -84,6 +87,9 @@ export class OpenAICompatProvider extends BaseProvider {
temperature: options?.temperature,
max_tokens: options?.max_tokens,
top_p: options?.top_p,
tools: options?.tools,
tool_choice: options?.tool_choice,
parallel_tool_calls: options?.parallel_tool_calls,
stream: true,
}),
});
......
import { Router } from 'express';
import type { Request, Response } from 'express';
import { z } from 'zod';
import type { ChatMessage } from '@freellmapi/shared/types.js';
import { routeRequest, recordRateLimitHit, recordSuccess, type RouteResult } from '../services/router.js';
import { recordRequest, recordTokens, setCooldown } from '../services/ratelimit.js';
import { getDb, getUnifiedApiKey } from '../db/index.js';
......@@ -13,16 +14,16 @@ export const proxyRouter = Router();
const stickySessionMap = new Map<string, { modelDbId: number; lastUsed: number }>();
const STICKY_TTL_MS = 30 * 60 * 1000; // 30 min session TTL
function getSessionKey(messages: { role: string; content: string }[]): string {
function getSessionKey(messages: ChatMessage[]): string {
// Use the first user message as session identifier
// Hermes sends the full conversation each time, so first user msg is stable
const firstUser = messages.find(m => m.role === 'user');
if (!firstUser) return '';
if (!firstUser || typeof firstUser.content !== 'string') return '';
// Hash: first 100 chars of first user message + message count
return `${firstUser.content.slice(0, 100)}:${messages.length > 2 ? 'multi' : 'single'}`;
}
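// e.g. a three-message conversation opening with 'Weather in Karachi?'
// maps to the key 'Weather in Karachi?:multi'.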
function getStickyModel(messages: { role: string; content: string }[]): number | undefined {
function getStickyModel(messages: ChatMessage[]): number | undefined {
// Only apply sticky for multi-turn (has assistant messages = continuation)
const hasAssistant = messages.some(m => m.role === 'assistant');
if (!hasAssistant) return undefined;
......@@ -40,7 +41,7 @@ function getStickyModel(messages: { role: string; content: string }[]): number |
return entry.modelDbId;
}
function setStickyModel(messages: { role: string; content: string }[], modelDbId: number) {
function setStickyModel(messages: ChatMessage[], modelDbId: number) {
const key = getSessionKey(messages);
if (!key) return;
stickySessionMap.set(key, { modelDbId, lastUsed: Date.now() });
......@@ -73,16 +74,82 @@ proxyRouter.get('/models', (_req: Request, res: Response) => {
const MAX_RETRIES = 20;
const chatCompletionSchema = z.object({
messages: z.array(z.object({
role: z.enum(['system', 'user', 'assistant']),
const toolCallSchema = z.object({
id: z.string().min(1),
type: z.literal('function'),
function: z.object({
name: z.string().min(1),
arguments: z.string(),
}),
});
const systemMessageSchema = z.object({
role: z.literal('system'),
content: z.string(),
name: z.string().optional(),
});
const userMessageSchema = z.object({
role: z.literal('user'),
content: z.string(),
})).min(1),
name: z.string().optional(),
});
const assistantMessageSchema = z.object({
role: z.literal('assistant'),
content: z.string().nullable().optional(),
name: z.string().optional(),
tool_calls: z.array(toolCallSchema).optional(),
}).refine((msg) => {
const hasContent = typeof msg.content === 'string' && msg.content.length > 0;
const hasToolCalls = (msg.tool_calls?.length ?? 0) > 0;
return hasContent || hasToolCalls;
}, {
message: 'assistant messages must include non-empty content or tool_calls',
});
const toolMessageSchema = z.object({
role: z.literal('tool'),
content: z.string(),
tool_call_id: z.string().min(1),
name: z.string().optional(),
});
const toolDefinitionSchema = z.object({
type: z.literal('function'),
function: z.object({
name: z.string().min(1),
description: z.string().optional(),
parameters: z.record(z.string(), z.unknown()).optional(),
strict: z.boolean().optional(),
}),
});
const toolChoiceSchema = z.union([
z.enum(['none', 'auto', 'required']),
z.object({
type: z.literal('function'),
function: z.object({
name: z.string().min(1),
}),
}),
]);
const chatCompletionSchema = z.object({
messages: z.array(z.union([
systemMessageSchema,
userMessageSchema,
assistantMessageSchema,
toolMessageSchema,
])).min(1),
model: z.string().optional(),
temperature: z.number().min(0).max(2).optional(),
max_tokens: z.number().int().positive().optional(),
top_p: z.number().min(0).max(1).optional(),
stream: z.boolean().optional(),
tools: z.array(toolDefinitionSchema).optional(),
tool_choice: toolChoiceSchema.optional(),
parallel_tool_calls: z.boolean().optional(),
});
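// Example body this schema accepts:
// { "messages": [{ "role": "user", "content": "hi" }],
//   "tools": [{ "type": "function", "function": { "name": "get_weather" } }],
//   "tool_choice": "auto" }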
function isRetryableError(err: any): boolean {
......@@ -124,8 +191,37 @@ proxyRouter.post('/chat/completions', async (req: Request, res: Response) => {
return;
}
const { messages, temperature, max_tokens, top_p, stream } = parsed.data;
const estimatedInputTokens = messages.reduce((sum, m) => sum + Math.ceil(m.content.length / 4), 0);
const { temperature, max_tokens, top_p, stream, tools, tool_choice, parallel_tool_calls } = parsed.data;
const messages: ChatMessage[] = parsed.data.messages.map((m): ChatMessage => {
if (m.role === 'assistant') {
return {
role: 'assistant',
content: m.content ?? null,
...(m.name ? { name: m.name } : {}),
...(m.tool_calls ? { tool_calls: m.tool_calls } : {}),
};
}
if (m.role === 'tool') {
return {
role: 'tool',
content: m.content,
tool_call_id: m.tool_call_id,
...(m.name ? { name: m.name } : {}),
};
}
return {
role: m.role,
content: m.content,
...(m.name ? { name: m.name } : {}),
};
});
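// Rough heuristic: ~4 characters per token; tool-call-only assistant
// messages (content === null) contribute nothing to the estimate.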
const estimatedInputTokens = messages.reduce((sum, m) => {
if (typeof m.content !== 'string') return sum;
return sum + Math.ceil(m.content.length / 4);
}, 0);
const estimatedTotal = estimatedInputTokens + (max_tokens ?? 1000);
// Sticky session: prefer the same model for multi-turn conversations
......@@ -170,7 +266,7 @@ proxyRouter.post('/chat/completions', async (req: Request, res: Response) => {
let totalOutputTokens = 0;
const gen = route.provider.streamChatCompletion(
route.apiKey, messages, route.modelId,
{ temperature, max_tokens, top_p },
{ temperature, max_tokens, top_p, tools, tool_choice, parallel_tool_calls },
);
for await (const chunk of gen) {
......@@ -190,7 +286,7 @@ proxyRouter.post('/chat/completions', async (req: Request, res: Response) => {
} else {
const result = await route.provider.chatCompletion(
route.apiKey, messages, route.modelId,
{ temperature, max_tokens, top_p },
{ temperature, max_tokens, top_p, tools, tool_choice, parallel_tool_calls },
);
const totalTokens = result.usage?.total_tokens ?? 0;
......
......@@ -66,9 +66,46 @@ export interface FallbackEntry {
// ---- OpenAI-Compatible Types ----
export interface ChatToolCallFunction {
name: string;
arguments: string;
}
export interface ChatToolCall {
id: string;
type: 'function';
function: ChatToolCallFunction;
}
export interface ChatToolFunctionDefinition {
name: string;
description?: string;
parameters?: Record<string, unknown>;
strict?: boolean;
}
export interface ChatToolDefinition {
type: 'function';
function: ChatToolFunctionDefinition;
}
export type ChatToolChoice =
| 'none'
| 'auto'
| 'required'
| {
type: 'function';
function: {
name: string;
};
};
export interface ChatMessage {
role: 'system' | 'user' | 'assistant';
content: string;
role: 'system' | 'user' | 'assistant' | 'tool';
content: string | null;
name?: string;
tool_call_id?: string;
tool_calls?: ChatToolCall[];
}
export interface ChatCompletionRequest {
......@@ -78,6 +115,9 @@ export interface ChatCompletionRequest {
max_tokens?: number;
stream?: boolean;
top_p?: number;
tools?: ChatToolDefinition[];
tool_choice?: ChatToolChoice;
parallel_tool_calls?: boolean;
}
export interface ChatCompletionChoice {
......@@ -112,7 +152,11 @@ export interface ChatCompletionChunk {
model: string;
choices: {
index: number;
delta: Partial<ChatMessage>;
delta: {
role?: 'assistant';
content?: string;
tool_calls?: ChatToolCall[];
};
finish_reason: string | null;
}[];
}
......