Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions packages/client/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,15 +165,38 @@ export class WebSocketTransport implements BAPTransport {
/** Called when reconnection succeeds */
onReconnected: (() => void) | null = null;

private baseUrl: string;
private token: string | undefined;

constructor(
private readonly url: string,
url: string,
options: WebSocketTransportOptions = {}
) {
this.baseUrl = url;
this.maxReconnectAttempts = options.maxReconnectAttempts ?? 5;
this.reconnectDelay = options.reconnectDelay ?? 1000;
this.autoReconnect = options.autoReconnect ?? false;
}

/**
* Get the current connection URL, including token if set
*/
private getConnectionUrl(): string {
  const { baseUrl, token } = this;

  // With a token present, attach it as the "token" query parameter.
  // URLSearchParams.set() overwrites an existing "token" entry instead of
  // appending a second one.
  if (token) {
    const withToken = new URL(baseUrl);
    withToken.searchParams.set("token", token);
    return withToken.toString();
  }

  // No token (undefined or empty string): hand back the base URL untouched,
  // without running it through the URL parser.
  return baseUrl;
}

/**
* Update the authentication token. Takes effect on the next connection/reconnection.
*/
updateToken(newToken: string): void {
// Only records the value. getConnectionUrl() reads this.token when it builds
// the URL for a new WebSocket, so nothing is sent or reconnected by this call.
this.token = newToken;
}

/**
* Connect to the WebSocket server
*/
Expand All @@ -189,7 +212,7 @@ export class WebSocketTransport implements BAPTransport {
this.ws = null;
}

this.ws = new WebSocket(this.url);
this.ws = new WebSocket(this.getConnectionUrl());

this.ws.on("open", () => {
this.reconnectAttempts = 0;
Expand Down Expand Up @@ -382,13 +405,11 @@ export class BAPClient extends EventEmitter {
};

if (typeof urlOrTransport === "string") {
let url = urlOrTransport;
const wsTransport = new WebSocketTransport(urlOrTransport);
if (options.token) {
const urlObj = new URL(url);
urlObj.searchParams.set("token", options.token);
url = urlObj.toString();
wsTransport.updateToken(options.token);
}
this.transport = new WebSocketTransport(url);
this.transport = wsTransport;
} else {
this.transport = urlOrTransport;
}
Expand All @@ -398,6 +419,18 @@ export class BAPClient extends EventEmitter {
this.transport.onError = (error) => this.emit("error", error);
}

/**
* Update the authentication token.
* Takes effect on the next connection or reconnection attempt.
* Only works when the transport is a WebSocketTransport.
*/
updateToken(newToken: string): void {
  // Always remember the token on the client options, whatever the transport is.
  this.options.token = newToken;

  const transport = this.transport;
  if (!(transport instanceof WebSocketTransport)) {
    // Other transports are left untouched; the call is a silent no-op for them.
    return;
  }

  // Forward to the WebSocket transport so it can use the new token when it
  // next builds its connection URL.
  transport.updateToken(newToken);
}

// ===========================================================================
// Connection Management
// ===========================================================================
Expand Down
169 changes: 133 additions & 36 deletions packages/mcp/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
* Exposes Browser Agent Protocol as an MCP (Model Context Protocol) server.
* Allows AI agents to control browsers through standardized MCP tools.
*
* TODO (MEDIUM): Add input validation on tool arguments before passing to BAP client
* TODO (MEDIUM): Add input validation on tool arguments before passing to BAP client — DONE (Zod schemas for top 10 tools)
* TODO (MEDIUM): Enforce session timeout (maxSessionDuration) - currently unused
* TODO (MEDIUM): Add resource cleanup on partial failure in ensureClient() — DONE (v0.2.0)
* TODO (LOW): parseSelector should validate empty/whitespace-only strings
Expand Down Expand Up @@ -41,6 +41,99 @@ import {
type AriaRole,
type AgentObserveResult,
} from "@browseragentprotocol/protocol";
import { z } from "zod";

// =============================================================================
// Input Validation Schemas
// =============================================================================

/** Validates a non-empty string argument */
const nonEmptyString = z.string().min(1, { message: "must be a non-empty string" });

/**
 * Non-empty string that the `URL` constructor can parse.
 * Any input that makes `new URL(...)` throw is rejected.
 */
const urlString = nonEmptyString.refine((candidate) => {
  try {
    void new URL(candidate);
  } catch {
    return false;
  }
  return true;
}, "must be a valid URL (include protocol, e.g. https://)");

/** Integer strictly greater than zero */
const positiveInt = z.number().int().positive();

/** Number greater than or equal to zero (fractions allowed) */
const nonNegativeNumber = z.number().nonnegative();

/** Validation schemas for tool arguments that take user-facing inputs */
// One zod object schema per tool name; the tool handlers pass the matching
// entry to validateArgs() before touching the BAP client.
// Keys are the tool names used in the handler's `switch (name)`.
const ToolArgSchemas = {
// navigate: `url` must be parseable by `new URL`. The handler falls back to
// waitUntil "load" and observeMaxElements 50 when these are omitted.
navigate: z.object({
url: urlString,
waitUntil: z.enum(["load", "domcontentloaded", "networkidle"]).optional(),
observe: z.boolean().optional(),
observeMaxElements: positiveInt.optional(),
}),
// click: `clickCount`, when given, must be a positive integer.
click: z.object({
selector: nonEmptyString,
clickCount: positiveInt.optional(),
}),
// type: `text` is a plain string, so the empty string is accepted;
// `delay` must be >= 0 when present.
type: z.object({
selector: nonEmptyString,
text: z.string(),
delay: nonNegativeNumber.optional(),
}),
// fill: `value` may be the empty string.
fill: z.object({
selector: nonEmptyString,
value: z.string(),
}),
// press: `key` is required; `selector` is optional but must be non-empty if given.
press: z.object({
key: nonEmptyString,
selector: nonEmptyString.optional(),
}),
// select: unlike fill/type, `value` must be non-empty.
// NOTE(review): this rejects value "" — confirm that choosing an option whose
// value is the empty string is never needed.
select: z.object({
selector: nonEmptyString,
value: nonEmptyString,
}),
hover: z.object({
selector: nonEmptyString,
}),
// element: `properties` is only checked to be an array of strings; the handler
// casts it to ElementProperty[] without checking the individual names.
element: z.object({
selector: nonEmptyString,
properties: z.array(z.string()).optional(),
}),
activate_page: z.object({
pageId: nonEmptyString,
}),
// extract: `schema` must be an object with a string `type`; passthrough() is
// used so (per zod's documented behavior — verify for the installed version)
// keys other than `type` are kept rather than dropped.
extract: z.object({
instruction: nonEmptyString,
schema: z.object({ type: z.string() }).passthrough(),
mode: z.enum(["single", "list", "table"]).optional(),
selector: nonEmptyString.optional(),
}),
} as const;

/**
* Validate tool arguments against a Zod schema.
* Returns the parsed (and typed) arguments, or throws a descriptive error.
*/
function validateArgs<T extends z.ZodType>(
  toolName: string,
  schema: T,
  args: Record<string, unknown>
): z.infer<T> {
  // safeParse never throws; it reports failure through `success: false`.
  const result = schema.safeParse(args);
  if (result.success) {
    return result.data;
  }

  // Render each issue as "<path>: <message>". An issue raised on the argument
  // object itself (e.g. `args` is missing or is not an object) has an empty
  // path; label it "(root)" so the message does not contain a dangling ": ".
  const issues = result.error.issues
    .map((issue: z.ZodIssue) => {
      const location = issue.path.length > 0 ? issue.path.join(".") : "(root)";
      return `${location}: ${issue.message}`;
    })
    .join("; ");

  // A plain Error whose message names the tool and lists every failed field.
  throw new Error(`Invalid arguments for '${toolName}': ${issues}`);
}

// =============================================================================
// Types
Expand Down Expand Up @@ -939,7 +1032,8 @@ export class BAPMCPServer {
switch (name) {
// Navigation
case "navigate": {
const url = args.url as string;
const validated = validateArgs("navigate", ToolArgSchemas.navigate, args);
const url = validated.url;

// Security check
if (!this.isAllowedDomain(url)) {
Expand All @@ -963,17 +1057,17 @@ export class BAPMCPServer {
};
}

const waitUntil = (args.waitUntil as WaitUntilState) ?? "load";
const waitUntil = (validated.waitUntil as WaitUntilState) ?? "load";

// Fusion: navigate-observe kernel
const observeFlag = args.observe as boolean | undefined;
const observeFlag = validated.observe;
const result = await client.navigate(url, {
waitUntil,
...(observeFlag ? {
observe: {
includeMetadata: true,
includeInteractiveElements: true,
maxElements: (args.observeMaxElements as number) ?? 50,
maxElements: validated.observeMaxElements ?? 50,
},
} : {}),
});
Expand Down Expand Up @@ -1005,48 +1099,48 @@ export class BAPMCPServer {

// Element Interaction
case "click": {
const selector = parseSelector(args.selector as string);
const options = args.clickCount ? { clickCount: args.clickCount as number } : undefined;
const validated = validateArgs("click", ToolArgSchemas.click, args);
const selector = parseSelector(validated.selector);
const options = validated.clickCount ? { clickCount: validated.clickCount } : undefined;
await client.click(selector, options);
return {
content: [{ type: "text", text: `Clicked: ${args.selector}` }],
content: [{ type: "text", text: `Clicked: ${validated.selector}` }],
};
}

case "type": {
const selector = parseSelector(args.selector as string);
const text = args.text as string;
const delay = args.delay as number | undefined;
await client.type(selector, text, { delay });
const validated = validateArgs("type", ToolArgSchemas.type, args);
const selector = parseSelector(validated.selector);
await client.type(selector, validated.text, { delay: validated.delay });
return {
content: [{ type: "text", text: `Typed "${text}" into: ${args.selector}` }],
content: [{ type: "text", text: `Typed "${validated.text}" into: ${validated.selector}` }],
};
}

case "fill": {
const selector = parseSelector(args.selector as string);
const value = args.value as string;
await client.fill(selector, value);
const validated = validateArgs("fill", ToolArgSchemas.fill, args);
const selector = parseSelector(validated.selector);
await client.fill(selector, validated.value);
return {
content: [{ type: "text", text: `Filled "${value}" into: ${args.selector}` }],
content: [{ type: "text", text: `Filled "${validated.value}" into: ${validated.selector}` }],
};
}

case "press": {
const key = args.key as string;
const selector = args.selector ? parseSelector(args.selector as string) : undefined;
await client.press(key, selector);
const validated = validateArgs("press", ToolArgSchemas.press, args);
const selector = validated.selector ? parseSelector(validated.selector) : undefined;
await client.press(validated.key, selector);
return {
content: [{ type: "text", text: `Pressed: ${key}` }],
content: [{ type: "text", text: `Pressed: ${validated.key}` }],
};
}

case "select": {
const selector = parseSelector(args.selector as string);
const value = args.value as string;
await client.select(selector, value);
const validated = validateArgs("select", ToolArgSchemas.select, args);
const selector = parseSelector(validated.selector);
await client.select(selector, validated.value);
return {
content: [{ type: "text", text: `Selected "${value}" in: ${args.selector}` }],
content: [{ type: "text", text: `Selected "${validated.value}" in: ${validated.selector}` }],
};
}

Expand All @@ -1061,10 +1155,11 @@ export class BAPMCPServer {
}

case "hover": {
const selector = parseSelector(args.selector as string);
const validated = validateArgs("hover", ToolArgSchemas.hover, args);
const selector = parseSelector(validated.selector);
await client.hover(selector);
return {
content: [{ type: "text", text: `Hovered over: ${args.selector}` }],
content: [{ type: "text", text: `Hovered over: ${validated.selector}` }],
};
}

Expand Down Expand Up @@ -1120,8 +1215,9 @@ export class BAPMCPServer {
}

case "element": {
const selector = parseSelector(args.selector as string);
const properties = (args.properties as ElementProperty[]) ?? ["visible", "enabled"];
const validated = validateArgs("element", ToolArgSchemas.element, args);
const selector = parseSelector(validated.selector);
const properties = (validated.properties as ElementProperty[]) ?? ["visible", "enabled"];
const result = await client.element(selector, properties);
return {
content: [
Expand All @@ -1145,10 +1241,10 @@ export class BAPMCPServer {
}

case "activate_page": {
const pageId = args.pageId as string;
await client.activatePage(pageId);
const validated = validateArgs("activate_page", ToolArgSchemas.activate_page, args);
await client.activatePage(validated.pageId);
return {
content: [{ type: "text", text: `Activated page: ${pageId}` }],
content: [{ type: "text", text: `Activated page: ${validated.pageId}` }],
};
}

Expand Down Expand Up @@ -1373,11 +1469,12 @@ export class BAPMCPServer {
}

case "extract": {
const validated = validateArgs("extract", ToolArgSchemas.extract, args);
const result = await client.extract({
instruction: args.instruction as string,
schema: args.schema as ExtractionSchema,
mode: args.mode as "single" | "list" | "table" | undefined,
selector: args.selector ? parseSelector(args.selector as string) : undefined,
instruction: validated.instruction,
schema: validated.schema as ExtractionSchema,
mode: validated.mode,
selector: validated.selector ? parseSelector(validated.selector) : undefined,
});

if (result.success) {
Expand Down
Loading
Loading