Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"@zenstackhq/orm": "workspace:*",
"@zenstackhq/schema": "workspace:*",
"@zenstackhq/sdk": "workspace:*",
"@zenstackhq/plugin-policy": "workspace:*",
Comment thread
coderabbitai[bot] marked this conversation as resolved.
"@zenstackhq/server": "workspace:*",
"chokidar": "^5.0.0",
"colors": "1.4.0",
Expand Down Expand Up @@ -98,4 +99,4 @@
"node": ">=20"
},
"funding": "https://github.com/sponsors/zenstackhq"
}
}
235 changes: 228 additions & 7 deletions packages/cli/src/actions/proxy.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import { PostgresDialect } from '@zenstackhq/orm/dialects/postgres';
import { SqliteDialect } from '@zenstackhq/orm/dialects/sqlite';
import type { SchemaDef } from '@zenstackhq/orm/schema';
import { PolicyPlugin } from '@zenstackhq/plugin-policy';
import { RPCApiHandler } from '@zenstackhq/server/api';
import { ZenStackMiddleware } from '@zenstackhq/server/express';
import type BetterSqlite3 from 'better-sqlite3';
Expand All @@ -20,21 +21,72 @@
import express from 'express';
import { createJiti } from 'jiti';
import type { createPool as MysqlCreatePool } from 'mysql2';
import { verify } from 'node:crypto';
import path from 'node:path';
import type { Pool as PgPoolType } from 'pg';
import { CliError } from '../cli-error';
import { getVersion } from '../utils/version-utils';
import { getOutputPath, getSchemaFile, loadSchemaDocument } from './action-utils';
import { z } from 'zod';

type Options = {
output?: string;
schema?: string;
port?: number;
port: number;
logLevel?: string[];
databaseUrl?: string;
studioAuthKey?: string;
signatureToleranceSecs: number;
};

export const ProxyAuthError = {
MISSING_SIGNATURE_HEADER: 'Missing x-zenstack-signature header',
INVALID_TIMESTAMP: 'Request timestamp is expired or invalid',
INVALID_SIGNATURE_FORMAT: 'Invalid x-zenstack-signature format',
} as const;

type ProxyAuthErrorCode = keyof typeof ProxyAuthError;

function rejectAuth(res: express.Response, code: ProxyAuthErrorCode) {
return res.status(401).json({ code, message: ProxyAuthError[code] });
}
/**
* Represents the identity claim embedded in the Authorization header.
* The bearer token is a plain base64-encoded JSON string.
*/
const UserClaimSchema = z.discriminatedUnion('type', [
z.object({ type: z.literal('superUser') }),
z.object({ type: z.literal('user'), data: z.record(z.string(), z.unknown()) }),
]);

type UserClaim = z.infer<typeof UserClaimSchema>;

/**
* Accepts a public key in either PEM format or as a raw base64 / base64url DER string
* (without the `-----BEGIN PUBLIC KEY-----` markers) and always returns a PEM string.
*/
function normalizePublicKey(key: string): string {
key = key.trim();
if (key.startsWith('-----BEGIN PUBLIC KEY-----')) {
return key;
}
// Convert base64url → standard base64, then wrap in PEM markers.
const b64 = key.replace(/-/g, '+').replace(/_/g, '/');
return `-----BEGIN PUBLIC KEY-----\n${b64}\n-----END PUBLIC KEY-----`;
}

export async function run(options: Options) {
// Resolve public key: CLI arg takes precedence, then ZENSTACK_STUDIO_AUTH_KEY env var.
options = { ...options, studioAuthKey: options.studioAuthKey ?? process.env['ZENSTACK_STUDIO_AUTH_KEY'] };
if (!options.studioAuthKey) {
console.warn(
colors.yellow(
'Warning: This proxy has no authentication. Do not expose it to the public network.\n' +
'To secure it, get an API key from ZenStack Studio and set it via the ZENSTACK_STUDIO_AUTH_KEY environment variable.',
),
);
}

const allowedLogLevels = ['error', 'query'] as const;
const log = options.logLevel?.filter((level): level is (typeof allowedLogLevels)[number] =>
allowedLogLevels.includes(level as any),
Expand Down Expand Up @@ -104,7 +156,14 @@
throw new CliError(`Failed to connect to the database: ${err instanceof Error ? err.message : String(err)}`);
}

startServer(db, schemaModule.schema, options);
// If a studioAuthKey is provided, create an authDb with the policy plugin
let authDb: ClientContract<SchemaDef> | undefined;
if (options.studioAuthKey) {
authDb = db.$use(new PolicyPlugin()) as ClientContract<SchemaDef>;
console.log(colors.gray('Access policy plugin enabled for authorization.'));
}

startServer(db, schemaModule.schema, options, authDb);
}

function evaluateUrl(schemaUrl: ConfigExpr) {
Expand Down Expand Up @@ -198,17 +257,41 @@
}
}

export function createProxyApp(client: ClientContract<SchemaDef>, schema: SchemaDef): express.Application {
export function createProxyApp(
client: ClientContract<SchemaDef>,
schema: SchemaDef,
auth?: {
studioAuthKey: string;
authDb: ClientContract<SchemaDef>;
/** Seconds within which a signed request is considered valid. Defaults to 60. */
signatureToleranceSecs: number;
},
): express.Application {
const app = express();
app.use(cors());
app.use(express.json({ limit: '5mb' }));
app.use(
express.json({
limit: '5mb',
verify: (req, _res, buf) => {
// Capture the raw body string for use in signature verification.
(req as express.Request & { rawBody?: string }).rawBody = buf.toString('utf8');
},
}),
);
app.use(express.urlencoded({ extended: true, limit: '5mb' }));

if (auth?.studioAuthKey) {
// Apply signature-verification middleware to all authenticated endpoints.
const toleranceSecs = auth.signatureToleranceSecs;
const normalizedKey = normalizePublicKey(auth.studioAuthKey);
app.use(['/api/model', '/api/schema'], createSignatureMiddleware(normalizedKey, toleranceSecs));

Check failure

Code scanning / CodeQL

Missing rate limiting High

This route handler performs
authorization
, but is not rate-limited.
This route handler performs
authorization
, but is not rate-limited.
This route handler performs
authorization
, but is not rate-limited.
This route handler performs
authorization
, but is not rate-limited.
This route handler performs
authorization
, but is not rate-limited.
This route handler performs
authorization
, but is not rate-limited.
}

app.use(
'/api/model',
ZenStackMiddleware({
apiHandler: new RPCApiHandler({ schema }),
getClient: () => client,
getClient: (req) => resolveClient(client, auth?.authDb, req),
}),
);

Expand All @@ -219,8 +302,146 @@
return app;
}

function startServer(client: ClientContract<SchemaDef>, schema: any, options: Options) {
const app = createProxyApp(client, schema);
/**
* Creates an Express middleware that verifies the ed25519 signature on every request.
*
* The signature header format is: `t=<unix-timestamp>,v1=<base64url-signature>`
*
* The signed message is constructed as:
* - GET requests: `<raw-query-string><timestamp>[<authorizationToken>]`
* - Other methods: `<raw-body><timestamp>[<authorizationToken>]`
*
* `authorizationToken` is the bearer token value from the `Authorization` header (if present).
*/
function createSignatureMiddleware(publicKey: string, toleranceSeconds: number) {
// Throttle invalid-signature warnings to at most once per 60 seconds.
let lastInvalidSigWarnAt = 0;
const WARN_THROTTLE_SECS = 60;

function warnInvalidSignature() {
const now = Math.floor(Date.now() / 1000);
if (now - lastInvalidSigWarnAt >= WARN_THROTTLE_SECS) {
lastInvalidSigWarnAt = now;
console.warn(
colors.yellow(
'Warning: Received a request with an invalid signature. ' +
'Please double-check whether you have the correct public API key configured.',
),
);
}
}

return (req: express.Request, res: express.Response, next: express.NextFunction) => {
const signatureHeader = req.headers['x-zenstack-signature'];
if (!signatureHeader || typeof signatureHeader !== 'string') {
return rejectAuth(res, 'MISSING_SIGNATURE_HEADER');
}

const parts = signatureHeader.split(',');
const timestampPart = parts.find((p) => p.startsWith('t='));
const sigPart = parts.find((p) => p.startsWith('v1='));
if (!timestampPart || !sigPart) {
return rejectAuth(res, 'INVALID_SIGNATURE_FORMAT');
}
const timestamp = timestampPart.substring(2);
const sig = sigPart.substring(3);

// Replay-attack prevention: reject requests whose timestamp deviates
// from server time by more than the configured tolerance.
const requestTime = parseInt(timestamp, 10);
const now = Math.floor(Date.now() / 1000);
if (isNaN(requestTime) || Math.abs(now - requestTime) > toleranceSeconds) {
return rejectAuth(res, 'INVALID_TIMESTAMP');
}

// Payload: raw query string for GET/DELETE, raw body for other methods.
let payload: string;
if (req.method === 'GET' || req.method === 'DELETE') {
const qMark = req.originalUrl.indexOf('?');
payload = qMark >= 0 ? req.originalUrl.substring(qMark + 1) : '';
} else {
payload = (req as express.Request & { rawBody?: string }).rawBody ?? '';
}

// authorizationToken is the bearer token value (if present).
const authHeader = req.headers['authorization'];
const authorizationToken = authHeader && authHeader.startsWith('Bearer ') ? authHeader.substring(7) : undefined;

Comment thread
jiashengguo marked this conversation as resolved.
const message = authorizationToken ? `${payload}${timestamp}${authorizationToken}` : `${payload}${timestamp}`;

try {
const isValid = verify(null, Buffer.from(message, 'utf8'), publicKey, Buffer.from(sig, 'base64url'));
if (!isValid) {
warnInvalidSignature();
return rejectAuth(res, 'INVALID_SIGNATURE_FORMAT');
}
} catch {
warnInvalidSignature();
return rejectAuth(res, 'INVALID_SIGNATURE_FORMAT');
}

return next();
};
}

/**
* Resolves the appropriate client for a request based on the Authorization header.
*
* - No studioAuthKey configured (authDb is undefined): always return the base client.
* - SuperUser claim: return the base client (full access, no policy enforcement).
* - Regular user claim: return authDb with the user identity set via $setAuth.
* - No / invalid token: return the base client.
*/
function resolveClient(
client: ClientContract<SchemaDef>,
authDb: ClientContract<SchemaDef> | undefined,
req: express.Request,
): ClientContract<SchemaDef> {
if (!authDb) {
return client;
}

const authHeader = req.headers['authorization'];
if (!authHeader?.startsWith('Bearer ')) {
return authDb;
}
Comment thread
jiashengguo marked this conversation as resolved.

const token = authHeader.substring(7);
let claim: UserClaim;
try {
claim = UserClaimSchema.parse(JSON.parse(Buffer.from(token, 'base64').toString('utf8')));
} catch (err) {
console.error(
colors.red(`Failed to parse user claim from token: ${err instanceof Error ? err.message : String(err)}`),
);
return authDb;
}
Comment thread
jiashengguo marked this conversation as resolved.

if (claim.type === 'superUser') {
// SuperUser has full access without policy enforcement, so we return the base client directly.
return client;
} else {
return authDb.$setAuth(claim.data) as ClientContract<SchemaDef>;
}
}

function startServer(
client: ClientContract<SchemaDef>,
schema: any,
options: Options,
authDb?: ClientContract<SchemaDef>,
) {
const app = createProxyApp(
client,
schema,
options.studioAuthKey
? {
studioAuthKey: options.studioAuthKey,
authDb: authDb!,
signatureToleranceSecs: options.signatureToleranceSecs,
}
: undefined,
);

const server = app.listen(options.port, () => {
console.log(`ZenStack proxy server is running on port: ${options.port}`);
Expand Down
20 changes: 20 additions & 0 deletions packages/cli/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,26 @@ Arguments following -- are passed to the seed script. E.g.: "zen db seed -- --us
.addOption(new Option('-o, --output <path>', 'output directory for `zen generate` command'))
.addOption(new Option('-d, --databaseUrl <url>', 'database connection URL'))
.addOption(new Option('-l, --logLevel <level...>', 'Query log levels (e.g., query, error)'))
.addOption(
new Option(
'--studioAuthKey <key>',
'Authentication key from ZenStack Studio. When set, the proxy will only accept requests signed by your Studio project.\nCan also be set via the ZENSTACK_STUDIO_AUTH_KEY environment variable. ',
),
)
.addOption(
new Option(
'--signatureToleranceSecs <seconds>',
'Maximum age (in seconds) of a signed request before it is rejected as a replay. Defaults to 60.',
)
.default(60)
.argParser((v) => {
const parsed = parseInt(v, 10);
if (isNaN(parsed) || parsed < 0) {
throw new CliError(`--signatureToleranceSecs must be a positive integer, got: ${v}`);
}
return parsed;
}),
)
Comment thread
jiashengguo marked this conversation as resolved.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
.addOption(noVersionCheckOption)
.action(proxyAction);

Expand Down
Loading
Loading