diff --git a/apps/api-backend/src/index.ts b/apps/api-backend/src/index.ts index 9e907ea..50320f3 100644 --- a/apps/api-backend/src/index.ts +++ b/apps/api-backend/src/index.ts @@ -3,115 +3,277 @@ import { prisma } from "db"; import { Elysia, t } from "elysia"; import { Conversation } from "./types"; import { Gemini } from "./llms/Gemini"; -import { OpenAi } from "./llms/Openai"; +import { OpenAi } from "./llms/OpenAi"; import { Claude } from "./llms/Claude"; -import { LlmResponse } from "./llms/Base"; +import { LlmResponse, StreamChunk } from "./llms/Base"; +import { generateCompletionId } from "./utils/generate-completion"; const app = new Elysia() -.use(bearer()) -.use(openapi()); -.post("/api/v1/chat/completions", async ({ status, bearer: apiKey, body }) => { - const model = body.model; - const [_companyName, providerModelName] = model.split("/"); - const apiKeyDb = await prisma.apiKey.findFirst({ - where: { - apiKey, - disabled: false, - deleted: false - }, - select: { - user: true - } - }) - - if (!apiKeyDb) { - return status(403, { - message: "Invalid api key" - }) - } - - if (apiKeyDb?.user.credits <= 0) { - return status(403, { - message: "You dont have enough credits in your db" - }) - } - - const modelDb = await prisma.model.findFirst({ - where: { - slug: model - } - }) - - if (!modelDb) { - return status(403, { - message: "This is an invalid model we dont support" - }) - } - - const providers = await prisma.modelProviderMapping.findMany({ - where: { - modelId: modelDb.id - }, - include: { - provider: true - } - }) - - const provider = providers[Math.floor(Math.random() * providers.length)]; - - let response: LlmResponse | null = null - if (provider.provider.name === "Google API") { - response = await Gemini.chat(providerModelName, body.messages) - } - - if (provider.provider.name === "Google Vertex") { - response = await Gemini.chat(providerModelName, body.messages) - } - - if (provider.provider.name === "OpenAI") { - response = await OpenAi.chat(providerModelName, body.messages) - } - - if (provider.provider.name === "Claude API") { - response = await Claude.chat(providerModelName, body.messages) - } - - if (!response) { - return status(403, { - message: "No provider found for this model" - }) - } - - const creditsUsed = (response.inputTokensConsumed * provider.inputTokenCost + response.outputTokensConsumed * provider.outputTokenCost) / 10; - console.log(creditsUsed); - const res = await prisma.user.update({ - where: { - id: apiKeyDb.user.id - }, - data: { - credits: { - decrement: creditsUsed + .use(bearer()) + .post( + "/api/v1/chat/completions", + async ({ status, bearer: apiKey, body }) => { + const model = body.model; + const [_companyName, providerModelName] = model.split("/"); + const apiKeyDb = await prisma.apiKey.findFirst({ + where: { + apiKey, + disabled: false, + deleted: false, + }, + select: { + id: true, + user: true, + }, + }); + + if (!apiKeyDb) { + return status(403, { + message: "Invalid api key", + }); + } + + if (apiKeyDb?.user.credits <= 0) { + return status(403, { + message: "You dont have enough credits in your db", + }); + } + + const modelDb = await prisma.model.findFirst({ + where: { + slug: model, + }, + }); + + if (!modelDb) { + return status(403, { + message: "This is an invalid model we dont support", + }); + } + + const providers = await prisma.modelProviderMapping.findMany({ + where: { + modelId: modelDb.id, + }, + include: { + provider: true, + }, + }); + + const provider = providers[Math.floor(Math.random() * providers.length)]; + + if 
(body.stream) {
+        const encoder = new TextEncoder();
+        const stream = new TransformStream({
+          transform(chunk: StreamChunk, controller) {
+            const data = `data: ${JSON.stringify(chunk)}\n\n`;
+            controller.enqueue(encoder.encode(data));
+          },
+          flush(controller) {
+            controller.enqueue(encoder.encode("data: [DONE]\n\n"));
+          },
+        });
+        const writer = stream.writable.getWriter();
+
+        (async () => {
+          let outputTokens = 0;
+          let inputTokens = 0;
+          let fullOutput = "";
+
+          try {
+            let streamGenerator: AsyncGenerator<StreamChunk> | null = null;
+            const completionId = generateCompletionId();
+            if (
+              provider.provider.name === "Google API" ||
+              provider.provider.name === "Google Vertex"
+            ) {
+              streamGenerator = Gemini.streamChat(
+                completionId,
+                providerModelName,
+                body.messages,
+              );
+            } else if (provider.provider.name === "OpenAI") {
+              streamGenerator = OpenAi.streamChat(
+                completionId,
+                providerModelName,
+                body.messages,
+              );
+            } else if (provider.provider.name === "Claude API") {
+              streamGenerator = Claude.streamChat(
+                completionId,
+                providerModelName,
+                body.messages,
+              );
+            }
+
+            if (!streamGenerator) {
+              await writer.write({
+                choices: [
+                  {
+                    delta: {
+                      content: "No provider found for this model",
+                    },
+                  },
+                ],
+              } as any);
+              await writer.close();
+              return;
+            }
+
+            inputTokens = body.messages.reduce(
+              (acc, msg) => acc + Math.ceil(msg.content.length / 4),
+              0,
+            );
+
+            for await (const chunk of streamGenerator) {
+              await writer.write(chunk);
+              if (chunk.choices[0]?.delta?.content) {
+                outputTokens += Math.ceil(
+                  chunk.choices[0].delta.content.length / 4,
+                );
+                fullOutput += chunk.choices[0].delta.content;
+              }
+            }
+            const creditsUsed =
+              (inputTokens * provider.inputTokenCost +
+                outputTokens * provider.outputTokenCost) /
+              10;
+
+            await writer.write({
+              id: completionId,
+              object: "chat.completion.chunk",
+              created: Math.floor(Date.now() / 1000),
+              model: providerModelName,
+              choices: [
+                {
+                  index: 0,
+                  delta: {},
+                  finish_reason: "stop",
+                },
+              ],
+              inputTokensConsumed: inputTokens,
+              outputTokensConsumed: outputTokens,
+            });
+
+            await prisma.conversation.create({
+              data: {
+                userId: apiKeyDb.user.id,
+                apiKeyId: apiKeyDb.id,
+                modelProviderMappingId: provider.id,
+                input: JSON.stringify(body.messages),
+                output: fullOutput,
+                inputTokenCount: inputTokens,
+                outputTokenCount: outputTokens,
+              },
+            });
+
+            await prisma.user.update({
+              where: { id: apiKeyDb.user.id },
+              data: { credits: { decrement: creditsUsed } },
+            });
+            await prisma.apiKey.update({
+              where: { apiKey },
+              data: { creditsConsumed: { increment: creditsUsed } },
+            });
+          } catch (error) {
+            console.error("Streaming error:", error);
+          } finally {
+            await writer.close();
+          }
+        })();
+
+        return new Response(stream.readable, {
+          headers: {
+            "Content-Type": "text/event-stream",
+            "Cache-Control": "no-cache",
+            Connection: "keep-alive",
+          },
+        });
+      }
+
+      let response: LlmResponse | null = null;
+      if (provider.provider.name === "Google API") {
+        const completionId = generateCompletionId();
+        response = await Gemini.chat(completionId, providerModelName, body.messages);
+      }
+
+      if (provider.provider.name === "Google Vertex") {
+        const completionId = generateCompletionId();
+        response = await Gemini.chat(completionId, providerModelName, body.messages);
+      }
+
+      if (provider.provider.name === "OpenAI") {
+        const completionId = generateCompletionId();
+        response = await OpenAi.chat(completionId, providerModelName, body.messages);
       }
-  }
-  });
-  console.log(res)
-  const res2 = await prisma.apiKey.update({
-    where: {
-      apiKey: apiKey
-    },
-    data: {
-      creditsConsumed: {
-        increment: creditsUsed
+
+      if (provider.provider.name === "Claude API") {
+        const completionId = generateCompletionId();
+        response = await Claude.chat(completionId, providerModelName, body.messages);
       }
-    }
-  })
-  console.log(res2)
-  return response;
-}, {
-  body: Conversation
-}).listen(4000);
+      if (!response) {
+        return status(403, {
+          message: "No provider found for this model",
+        });
+      }
+
+      const creditsUsed =
+        (response.inputTokensConsumed * provider.inputTokenCost +
+          response.outputTokensConsumed * provider.outputTokenCost) /
+        10;
+
+      const outputText = response.choices
+        .map((choice) => choice.message.content)
+        .join("");
+
+      await prisma.conversation.create({
+        data: {
+          userId: apiKeyDb.user.id,
+          apiKeyId: apiKeyDb.id,
+          modelProviderMappingId: provider.id,
+          input: JSON.stringify(body.messages),
+          output: outputText,
+          inputTokenCount: response.inputTokensConsumed,
+          outputTokenCount: response.outputTokensConsumed,
+        },
+      });
+
+      console.log(creditsUsed);
+      const res = await prisma.user.update({
+        where: {
+          id: apiKeyDb.user.id,
+        },
+        data: {
+          credits: {
+            decrement: creditsUsed,
+          },
+        },
+      });
+      console.log(res);
+      const res2 = await prisma.apiKey.update({
+        where: {
+          apiKey: apiKey,
+        },
+        data: {
+          creditsConsumed: {
+            increment: creditsUsed,
+          },
+        },
+      });
+      console.log(res2);
+
+      return response;
+    },
+    {
+      body: Conversation,
+    },
+  )
+  .listen(4000);
 
 console.log(
-  `🦊 Elysia is running at ${app.server?.hostname}:${app.server?.port}`
+  `🦊 Elysia is running at ${app.server?.hostname}:${app.server?.port}`,
 );
diff --git a/apps/api-backend/src/llms/Base.ts b/apps/api-backend/src/llms/Base.ts
index 83aa463..a9cf480 100644
--- a/apps/api-backend/src/llms/Base.ts
+++ b/apps/api-backend/src/llms/Base.ts
@@ -1,19 +1,52 @@
 import { Messages } from "../types";
 
 export type LlmResponse = {
-  completions: {
-    choices: {
-      message: {
-        content: string
-      }
-    }[]
-  },
-  inputTokensConsumed: number,
-  outputTokensConsumed: number
-}
+  id: string;
+  object: "chat.completion";
+  model: string;
+  choices: {
+    index: number;
+    message: {
+      content: string;
+      role: "assistant";
+    };
+    finish_reason: "stop";
+  }[];
+  inputTokensConsumed: number;
+  outputTokensConsumed: number;
+  created: number;
+};
+
+export type StreamChunk = {
+  id: string;
+  object: "chat.completion.chunk";
+  created: number;
+  model: string;
+  choices: {
+    index: number;
+    delta: {
+      content?: string;
+    };
+    finish_reason: null | "stop";
+  }[];
+  inputTokensConsumed?: number;
+  outputTokensConsumed?: number;
+};
 
 export class BaseLlm {
-  static async chat(model: string, messages: Messages): Promise<LlmResponse> {
-    throw new Error("Not implemented chat function")
-  }
-}
\ No newline at end of file
+  static async chat(
+    completionId: string,
+    model: string,
+    messages: Messages,
+  ): Promise<LlmResponse> {
+    throw new Error("Not implemented chat function");
+  }
+
+  static streamChat(
+    completionId: string,
+    model: string,
+    messages: Messages,
+  ): AsyncGenerator<StreamChunk> {
+    throw new Error("Not implemented streamChat function");
+  }
+}
diff --git a/apps/api-backend/src/llms/Claude.ts b/apps/api-backend/src/llms/Claude.ts
index cfcc604..305e4ef 100644
--- a/apps/api-backend/src/llms/Claude.ts
+++ b/apps/api-backend/src/llms/Claude.ts
@@ -1,36 +1,77 @@
 import Anthropic from "@anthropic-ai/sdk";
 import { Messages } from "../types";
-import { BaseLlm, LlmResponse } from "./Base";
+import { BaseLlm, LlmResponse, StreamChunk } from "./Base";
 import { TextBlock } from "@anthropic-ai/sdk/resources";
 
 const client = new Anthropic({
-  apiKey: process.env.ANTHROPIC_API_KEY
+  apiKey: process.env.ANTHROPIC_API_KEY,
 });
 
-
 export class Claude extends BaseLlm {
-  static async chat(model: string, messages: Messages): Promise<LlmResponse> {
-
-    const response = await client.messages.create({
-      max_tokens: 2048,
-      messages: messages.map(message => ({
-        role: message.role,
-        content: message.content
-      })),
-      model: model
-    });
+  static async chat(completionId: string, model: string, messages: Messages): Promise<LlmResponse> {
+    const response = await client.messages.create({
+      max_tokens: 2048,
+      messages: messages.map((message) => ({
+        role: message.role,
+        content: message.content,
+      })),
+      model: model,
+    });
+
+    return {
+      outputTokensConsumed: response.usage.output_tokens,
+      inputTokensConsumed: response.usage.input_tokens,
+      object: "chat.completion",
+      id: completionId,
+      created: Math.floor(Date.now() / 1000),
+      model,
+      choices: response.content.map((content) => ({
+        index: 0,
+        message: {
+          content: (content as TextBlock).text,
+          role: "assistant",
+        },
+        finish_reason: "stop",
+      })),
+    };
+  }
+
+  static async *streamChat(
+    completionId: string,
+    model: string,
+    messages: Messages,
+  ): AsyncGenerator<StreamChunk> {
+    const stream = await client.messages.create({
+      max_tokens: 2048,
+      messages: messages.map((message) => ({
+        role: message.role,
+        content: message.content,
+      })),
+      model: model,
+      stream: true,
+    });
 
-    return {
-      outputTokensConsumed: response.usage.output_tokens,
-      inputTokensConsumed: response.usage.input_tokens,
-      completions: {
-        choices: response.content.map(content => ({
-          message: {
-            content: (content as TextBlock).text
-          }
-        }))
-      }
+    for await (const chunk of stream) {
+      if (chunk.type === "content_block_delta") {
+        const delta = chunk.delta as { text?: string };
+        if (delta.text) {
+          yield {
+            id: completionId,
+            object: "chat.completion.chunk",
+            model,
+            created: Math.floor(Date.now() / 1000),
+            choices: [
+              {
+                index: 0,
+                delta: {
+                  content: delta.text,
+                },
+                finish_reason: null,
+              },
+            ],
+          };
         }
-
+      }
     }
-}
\ No newline at end of file
+  }
+}
diff --git a/apps/api-backend/src/llms/Gemini.ts b/apps/api-backend/src/llms/Gemini.ts
index 8e910ca..bdd01af 100644
--- a/apps/api-backend/src/llms/Gemini.ts
+++ b/apps/api-backend/src/llms/Gemini.ts
@@ -1,32 +1,73 @@
 import { Messages } from "../types";
-import { BaseLlm, LlmResponse } from "./Base";
+import { BaseLlm, LlmResponse, StreamChunk } from "./Base";
 import { GoogleGenAI } from "@google/genai";
 
 const ai = new GoogleGenAI({
-  apiKey: process.env.GOOGLE_API_KEY
+  apiKey: process.env.GOOGLE_API_KEY,
 });
 
-
 export class Gemini extends BaseLlm {
-  static async chat(model: string, messages: Messages): Promise<LlmResponse> {
-    const response = await ai.models.generateContent({
-      model: model,
-      contents: messages.map(message => ({
-        text: message.content,
-        role: message.role
-      }))
-    });
+  static async chat(completionId: string, model: string, messages: Messages): Promise<LlmResponse> {
+    const response = await ai.models.generateContent({
+      model: model,
+      contents: messages.map((message) => ({
+        text: message.content,
+        role: message.role,
+      })),
+    });
+
+    return {
+      outputTokensConsumed: response.usageMetadata?.candidatesTokenCount ?? 0,
+      inputTokensConsumed: response.usageMetadata?.promptTokenCount ?? 0,
+      id: completionId,
+      model,
+      created: Math.floor(Date.now() / 1000),
+      object: "chat.completion",
+      choices: [
+        {
+          index: 0,
+          message: {
+            role: "assistant",
+            content: response.text ?? "",
"", + }, + finish_reason: "stop" + }, + ], + }; + } + + static async *streamChat( + completionId: string, + model: string, + messages: Messages, + ): AsyncGenerator { + const response = await ai.models.generateContentStream({ + model: model, + contents: messages.map((message) => ({ + text: message.content, + role: message.role, + })), + }); - return { - outputTokensConsumed: response.usageMetadata?.candidatesTokenCount!, - inputTokensConsumed: response.usageMetadata?.promptTokenCount!, - completions: { - choices: [{ - message: { - content: response.text! - } - }] - } - } + for await (const chunk of response) { + const text = chunk.text; + if (text) { + yield { + id: completionId, + object: "chat.completion.chunk", + model, + created: Math.floor(Date.now() / 1000), + choices: [ + { + index: 0, + delta: { + content: text, + }, + finish_reason: null + }, + ], + }; + } } -} \ No newline at end of file + } +} diff --git a/apps/api-backend/src/llms/OpenAi.ts b/apps/api-backend/src/llms/OpenAi.ts index 86ae53c..fed5e45 100644 --- a/apps/api-backend/src/llms/OpenAi.ts +++ b/apps/api-backend/src/llms/OpenAi.ts @@ -1,30 +1,72 @@ import { Messages } from "../types"; -import { BaseLlm, LlmResponse } from "./Base"; +import { BaseLlm, LlmResponse, StreamChunk } from "./Base"; import OpenAI from "openai"; const client = new OpenAI({ - apiKey: process.env.OPENAI_API_KEY + apiKey: process.env.OPENAI_API_KEY, }); export class OpenAi extends BaseLlm { - static async chat(model: string, messages: Messages): Promise { - const response = await client.responses.create({ - model: model, - input: messages.map(message => ({ - role: message.role, - content: message.content - })) - }); + static async chat(completionId: string, model: string, messages: Messages): Promise { + const response = await client.responses.create({ + model: model, + input: messages.map((message) => ({ + role: message.role, + content: message.content, + })), + }); - return { - inputTokensConsumed: response.usage?.input_tokens!, - outputTokensConsumed: response.usage?.output_tokens!, - completions: { - choices: [{ - message: { - content: response.output_text - } - }] - } - } + return { + inputTokensConsumed: response.usage?.input_tokens ?? 0, + outputTokensConsumed: response.usage?.output_tokens ?? 
+      model,
+      created: Math.floor(Date.now() / 1000),
+      id: completionId,
+      object: "chat.completion",
+      choices: [
+        {
+          index: 0,
+          message: {
+            role: "assistant",
+            content: response.output_text,
+          },
+          finish_reason: "stop",
+        },
+      ],
+    };
+  }
+
+  static async *streamChat(
+    completionId: string,
+    model: string,
+    messages: Messages,
+  ): AsyncGenerator<StreamChunk> {
+    const stream = await client.responses.create({
+      model: model,
+      input: messages.map((message) => ({
+        role: message.role,
+        content: message.content,
+      })),
+      stream: true,
+    });
+
+    for await (const chunk of stream) {
+      if (chunk.type === "response.output_text.delta") {
+        yield {
+          id: completionId,
+          object: "chat.completion.chunk",
+          model,
+          created: Math.floor(Date.now() / 1000),
+          choices: [
+            {
+              index: 0,
+              delta: {
+                content: chunk.delta,
+              },
+              finish_reason: null,
+            },
+          ],
+        };
+      }
     }
-}
\ No newline at end of file
+  }
+}
diff --git a/apps/api-backend/src/types.ts b/apps/api-backend/src/types.ts
index fc70dea..20e398a 100644
--- a/apps/api-backend/src/types.ts
+++ b/apps/api-backend/src/types.ts
@@ -1,18 +1,19 @@
 import { t } from "elysia";
-
-export const Messages = t.Array(t.Object({
+export const Messages = t.Array(
+  t.Object({
     role: t.Enum({
      user: "user",
-      assistant: "assistant"
+      assistant: "assistant",
    }),
-    content: t.String()
-}))
+    content: t.String(),
+  }),
+);
 
 export type Messages = typeof Messages.static;
 
 export const Conversation = t.Object({
-  model: t.String(),
-  messages: Messages
-})
-
+  model: t.String(),
+  messages: Messages,
+  stream: t.Optional(t.Boolean()),
+});
diff --git a/apps/api-backend/src/utils/generate-completion.ts b/apps/api-backend/src/utils/generate-completion.ts
new file mode 100644
index 0000000..d06820d
--- /dev/null
+++ b/apps/api-backend/src/utils/generate-completion.ts
@@ -0,0 +1,3 @@
+export function generateCompletionId(): string {
+  return `chatcmpl-${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 10)}`;
+}
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..b8844c8
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,32 @@
+services:
+  postgres:
+    image: postgres:15-alpine
+    ports:
+      - "5439:5432"
+    environment:
+      POSTGRES_USER: postgres
+      POSTGRES_PASSWORD: postgres
+      POSTGRES_DB: postgres
+    volumes:
+      - ./packages/db/prisma/data:/var/lib/postgresql/data
+    healthcheck:
+      test: ["CMD-SHELL", "pg_isready -U postgres"]
+      interval: 5s
+      timeout: 5s
+      retries: 5
+
+  seed:
+    build:
+      context: .
+ dockerfile: packages/db/Dockerfile.seed + environment: + DATABASE_URL: postgresql://postgres:postgres@postgres:5432/postgres + depends_on: + postgres: + condition: service_healthy + profiles: + - seed + +volumes: + prisma: + driver: local diff --git a/packages/db/.gitignore b/packages/db/.gitignore index 1a386ec..e582a76 100644 --- a/packages/db/.gitignore +++ b/packages/db/.gitignore @@ -34,3 +34,5 @@ report.[0-9]_.[0-9]_.[0-9]_.[0-9]_.json .DS_Store /generated/prisma + +prisma/data diff --git a/packages/db/Dockerfile.seed b/packages/db/Dockerfile.seed new file mode 100644 index 0000000..b20c834 --- /dev/null +++ b/packages/db/Dockerfile.seed @@ -0,0 +1,25 @@ +FROM oven/bun:latest + +WORKDIR /app + +# Install dependencies for Prisma +RUN apt-get update && apt-get install -y openssl + +# Copy root package files +COPY package.json bun.lock ./ + +# Copy the db package +COPY packages/db ./packages/db + +# Install dependencies +RUN bun install + +# Generate Prisma client +WORKDIR /app/packages/db +RUN bun prisma generate + +# Set environment variable for database URL (will be overridden at runtime) +ENV DATABASE_URL="" + +# Command to run the seed script +CMD ["bun", "prisma/seed.ts"] diff --git a/packages/db/package.json b/packages/db/package.json index 35a8810..4db145d 100644 --- a/packages/db/package.json +++ b/packages/db/package.json @@ -6,6 +6,9 @@ ".": "./index.ts" }, "private": true, + "prisma": { + "seed": "bun prisma/seed.ts" + }, "devDependencies": { "@types/bun": "latest", "@types/pg": "^8.16.0", diff --git a/packages/db/prisma.config.ts b/packages/db/prisma.config.ts index 831a20f..d0279d3 100644 --- a/packages/db/prisma.config.ts +++ b/packages/db/prisma.config.ts @@ -7,6 +7,7 @@ export default defineConfig({ schema: "prisma/schema.prisma", migrations: { path: "prisma/migrations", + seed: "bun ./prisma/seed.ts", }, datasource: { url: process.env["DATABASE_URL"], diff --git a/packages/db/prisma/seed.ts b/packages/db/prisma/seed.ts new file mode 100644 index 0000000..c1957a1 --- /dev/null +++ b/packages/db/prisma/seed.ts @@ -0,0 +1,162 @@ +import { PrismaPg } from "@prisma/adapter-pg"; +import { PrismaClient } from "../generated/prisma/client"; + +const adapter = new PrismaPg({ + connectionString: process.env.DATABASE_URL!, +}); + +const prisma = new PrismaClient({ + adapter, +}); + +async function main() { + console.log("Starting database seed..."); + + // Seed Companies + const companies = await Promise.all([ + prisma.company.upsert({ + where: { id: 1 }, + update: {}, + create: { + id: 1, + name: "OpenAI", + website: "https://chat.com", + }, + }), + prisma.company.upsert({ + where: { id: 2 }, + update: {}, + create: { + id: 2, + name: "anthropic", + website: "https://claud.ai", + }, + }), + prisma.company.upsert({ + where: { id: 3 }, + update: {}, + create: { + id: 3, + name: "google", + website: "https://gemini.google.com", + }, + }), + ]); + console.log(`Created ${companies.length} companies`); + + // Seed Providers + const providers = await Promise.all([ + prisma.provider.upsert({ + where: { id: 1 }, + update: {}, + create: { + id: 1, + name: "Google API", + website: "https://aistudio.google.com", + }, + }), + prisma.provider.upsert({ + where: { id: 2 }, + update: {}, + create: { + id: 2, + name: "Claude API", + website: "https://claude.ai", + }, + }), + prisma.provider.upsert({ + where: { id: 3 }, + update: {}, + create: { + id: 3, + name: "OpenAI", + website: "https://api.chat.com", + }, + }), + ]); + console.log(`Created ${providers.length} providers`); + + // Seed Models 
+ const models = await Promise.all([ + prisma.model.upsert({ + where: { id: 1 }, + update: {}, + create: { + id: 1, + name: "Google: Gemini 3 Flash Preview", + slug: "google/gemini-3-flash-preview", + companyId: 3, + }, + }), + prisma.model.upsert({ + where: { id: 2 }, + update: {}, + create: { + id: 2, + name: "Google: Gemini 2.5 Pro", + slug: "google/gemini-2.5-pro", + companyId: 3, + }, + }), + prisma.model.upsert({ + where: { id: 3 }, + update: {}, + create: { + id: 3, + name: "Google: Gemini 2.5 Flash", + slug: "google/gemini-2.5-flash", + companyId: 3, + }, + }), + ]); + console.log(`Created ${models.length} models`); + + // Seed Model Provider Mappings + const mappings = await Promise.all([ + prisma.modelProviderMapping.upsert({ + where: { id: 1 }, + update: {}, + create: { + id: 1, + modelId: 1, + providerId: 1, + inputTokenCost: 0, + outputTokenCost: 0, + }, + }), + prisma.modelProviderMapping.upsert({ + where: { id: 2 }, + update: {}, + create: { + id: 2, + modelId: 2, + providerId: 1, + inputTokenCost: 0, + outputTokenCost: 0, + }, + }), + prisma.modelProviderMapping.upsert({ + where: { id: 3 }, + update: {}, + create: { + id: 3, + modelId: 3, + providerId: 1, + inputTokenCost: 0, + outputTokenCost: 0, + }, + }), + ]); + console.log(`Created ${mappings.length} model-provider mappings`); + + console.log("Database seed completed successfully!"); +} + +main() + .catch((e) => { + console.error("Error seeding database:", e); + process.exit(1); + }) + .finally(async () => { + await prisma.$disconnect(); + });
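Usage note (not part of the patch): a minimal client-side sketch of how the new streaming branch of `POST /api/v1/chat/completions` could be consumed, based on the SSE framing the handler emits (`data: <json>` events, terminated by `data: [DONE]`). It assumes the server from this diff is running locally on port 4000, that the seeded `google/gemini-2.5-flash` model is available, and that a valid, non-disabled API key is exported as `LLM_ROUTER_API_KEY` (the environment variable name is illustrative only).

```ts
// Hypothetical consumer of the streaming endpoint added in index.ts.
// Assumes: server on localhost:4000, seeded model, API key in LLM_ROUTER_API_KEY.
const response = await fetch("http://localhost:4000/api/v1/chat/completions", {
  method: "POST",
  headers: {
    Authorization: `Bearer ${process.env.LLM_ROUTER_API_KEY}`,
    "Content-Type": "application/json",
  },
  body: JSON.stringify({
    model: "google/gemini-2.5-flash",
    stream: true,
    messages: [{ role: "user", content: "Say hello in one sentence." }],
  }),
});

// The handler writes Server-Sent Events: one `data: <json>` line per chunk,
// closed with `data: [DONE]`.
const reader = response.body!.getReader();
const decoder = new TextDecoder();
let buffer = "";

while (true) {
  const { done, value } = await reader.read();
  if (done) break;
  buffer += decoder.decode(value, { stream: true });

  let boundary: number;
  while ((boundary = buffer.indexOf("\n\n")) !== -1) {
    const event = buffer.slice(0, boundary).trim();
    buffer = buffer.slice(boundary + 2);
    if (!event.startsWith("data: ") || event === "data: [DONE]") continue;
    const chunk = JSON.parse(event.slice("data: ".length));
    // The final usage chunk has an empty delta, so fall back to "".
    process.stdout.write(chunk.choices[0]?.delta?.content ?? "");
  }
}
```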