Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 79 additions & 57 deletions src/routes/graphql/ws.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,72 +3,94 @@ import { Duplex } from 'stream';
import WebSocket from 'ws';

import { logger } from '~helpers';
import { validateSubscription } from '../../subscriptionPermissions';

// The host of the AppSync API (non-realtime)
const WS_HOST = new URL('/', process.env.APPSYNC_API).host;
// In production the Amplify WebSocket API is secure and requires a different endpoint
const WS_ENDPOINT = process.env.NODE_ENV === 'dev' ? `ws://${WS_HOST}` : process.env.APPSYNC_WSS_API;
const WS_ENDPOINT =
process.env.NODE_ENV === 'dev'
? `ws://${WS_HOST}`
: process.env.APPSYNC_WSS_API;

const wss = new WebSocket.Server({ noServer: true });

// Custom websocker upgrade handler
// This proxies websocket requests and adds the necessary headers for Amplify authorization if applicable
export const handleWsUpgrade = (req: InstanceType<typeof IncomingMessage>, socket: Duplex, head: Buffer) => {
// localhost is fine here, we're just using the path and the query string
const url = new URL(req.url || '', 'http://localhost');
const authHeaders = {
host: WS_HOST,
// Creates a date in the format YYYYMMDDTHHMMSSZ
'x-amz-date': new Date().toISOString().replace(/[\-:]/g,'').replace(/\.\d{3}/,''),
'x-api-key': process.env.APPSYNC_API_KEY,
}
// Add "header" query string parameter (default is {})
url.searchParams.set('header', btoa((JSON.stringify(authHeaders))));
const proxyPath = `${url.pathname}?${url.searchParams.toString()}`
// Establish a websocket connection to Amplify
const targetWs = new WebSocket(
`${WS_ENDPOINT}${proxyPath}`,
req.headers['sec-websocket-protocol'] || 'graphql-ws',
);
targetWs.on('open', () => {
wss.handleUpgrade(req, socket, head, (ws) => {
// Add authorization headers to incoming client messages
ws.on('message', (data) => {
let parsed;
try {
parsed = JSON.parse(data.toString());
} catch (e) {
logger('Failed to parse websocket message', e);
return;
}
if (parsed.payload?.extensions?.authorization) {
parsed.payload.extensions.authorization = {
...parsed.payload.extensions.authorization,
...authHeaders,
};
return targetWs.send(JSON.stringify(parsed));
}
targetWs.send(data.toString());
});
ws.on('close', () => {
targetWs.close();
});
ws.on('error', (err) => {
logger('WebSocket error from client: ', err);
targetWs.close();
});
export const handleWsUpgrade = (
req: InstanceType<typeof IncomingMessage>,
socket: Duplex,
head: Buffer,
) => {
// localhost is fine here, we're just using the path and the query string
const url = new URL(req.url || '', 'http://localhost');
const authHeaders = {
host: WS_HOST,
// Creates a date in the format YYYYMMDDTHHMMSSZ
'x-amz-date': new Date()
.toISOString()
.replace(/[\-:]/g, '')
.replace(/\.\d{3}/, ''),
'x-api-key': process.env.APPSYNC_API_KEY,
};
// Add "header" query string parameter (default is {})
url.searchParams.set('header', btoa(JSON.stringify(authHeaders)));
const proxyPath = `${url.pathname}?${url.searchParams.toString()}`;
// Establish a websocket connection to Amplify
const targetWs = new WebSocket(
`${WS_ENDPOINT}${proxyPath}`,
req.headers['sec-websocket-protocol'] || 'graphql-ws',
);
targetWs.on('open', () => {
wss.handleUpgrade(req, socket, head, (ws) => {
// Add authorization headers to incoming client messages
ws.on('message', (data) => {
let parsed;
try {
parsed = JSON.parse(data.toString());
} catch (e) {
logger('Failed to parse websocket message', e);
return;
}

// Pass through messages from Amplify to the client
targetWs.on('message', (data) => {
ws.send(data.toString());
});
targetWs.on('close', () => {
ws.close();
});
targetWs.on('error', (err) => {
logger('WebSocket error from Amplify: ', err);
ws.close();
});
});
// Validate subscription queries on start
if (parsed.type === 'start' && parsed.payload?.data) {
const parsedData = JSON.parse(parsed.payload.data);
if (!validateSubscription(parsedData.query)) {
ws.close();
targetWs.close();
return;
}
}

if (parsed.payload?.extensions?.authorization) {
parsed.payload.extensions.authorization = {
...parsed.payload.extensions.authorization,
...authHeaders,
};
return targetWs.send(JSON.stringify(parsed));
}
targetWs.send(data.toString());
});
ws.on('close', () => {
targetWs.close();
});
ws.on('error', (err) => {
logger('WebSocket error from client: ', err);
targetWs.close();
});

// Pass through messages from Amplify to the client
targetWs.on('message', (data) => {
ws.send(data.toString());
});
targetWs.on('close', () => {
ws.close();
});
targetWs.on('error', (err) => {
logger('WebSocket error from Amplify: ', err);
ws.close();
});
});
});
};
81 changes: 81 additions & 0 deletions src/subscriptionPermissions.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import { parse, visit, visitWithTypeInfo, TypeInfo } from 'graphql';

import { getSchema } from './schema';
import { logger } from '~helpers';

const allowedSubscriptions = [
'onCreateColonyActionMetadata',
'onUpdateColony',
'onCreateColonyContributor',
'onUpdateColonyContributor',
];

const blockedFields: Record<string, string[]> = {
Profile: ['email'],
User: [
'bridgeCustomerId',
'privateBetaInviteCode',
'userPrivateBetaInviteCodeId',
],
Colony: ['colonyMemberInvite', 'colonyMemberInviteCode'],
};

export const validateSubscription = (query: string): boolean => {
let document;
try {
document = parse(query);
} catch {
logger('Subscription rejected: Invalid query');
return false;
}

// Check if the subscription is allowed
for (const def of document.definitions) {
if (def.kind !== 'OperationDefinition') continue;

if (def.operation !== 'subscription') {
logger('Subscription rejected: Non-subscription operation in document');
return false;
}

const firstField = def.selectionSet.selections[0];
if (!firstField || firstField.kind !== 'Field') {
logger('Subscription rejected: No field selected');
return false;
}

const subscriptionName = firstField.name.value;
if (!allowedSubscriptions.includes(subscriptionName)) {
logger('Subscription rejected:', subscriptionName);
return false;
}
}

// Check for blocked fields
const typeInfo = new TypeInfo(getSchema());
let blocked = false;

visit(
document,
visitWithTypeInfo(typeInfo, {
Field: {
enter(node) {
const parentType = typeInfo.getParentType();
if (!parentType) return;

const blockedList = blockedFields[parentType.name];
if (blockedList?.includes(node.name.value)) {
logger(
'Subscription rejected due to blocked field:',
`${parentType.name}.${node.name.value}`,
);
blocked = true;
return false;
}
},
},
}),
);

return !blocked;
};
Loading