Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/websocket/routing/message-router.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,28 @@ describe('MessageRouter', () => {
});
});

describe('circular reference protection', () => {
it('should throw when registering a handler with a circular pattern', () => {
const pattern: Record<string, unknown> = { cmd: 'test' };
pattern.self = pattern;

expect(() => {
router.registerHandlers([createHandler(pattern, () => 'result')]);
}).toThrow('Circular reference detected');
});

it('should throw when routing a circular event object', () => {
router.registerHandlers([createHandler('safe', () => 'ok')]);

const circular: Record<string, unknown> = { cmd: 'test' };
circular.self = circular;

expect(() => {
router.hasHandler(circular);
}).toThrow('Circular reference detected');
});
});

describe('edge cases with JSON serialization', () => {
it('should treat undefined values as omitted (JSON behavior)', async () => {
const pattern = { cmd: 'test', value: 1 };
Expand Down
29 changes: 2 additions & 27 deletions src/websocket/routing/message-router.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Logger } from '@nestjs/common';
import { MessageHandler } from './metadata-scanner';
import { sortObjectKeys } from './pattern-key';

/**
* Represents an incoming WebSocket message
Expand Down Expand Up @@ -157,33 +158,7 @@ export class MessageRouter {
return pattern;
}
// For object patterns, create a stable JSON string key with sorted keys (recursively)
return JSON.stringify(this.sortObjectKeys(pattern));
}

/**
* Recursively sorts object keys for stable serialization
* @private
*/
private sortObjectKeys(obj: Record<string, unknown>): Record<string, unknown> {
const sorted: Record<string, unknown> = {};
for (const key of Object.keys(obj).sort()) {
sorted[key] = this.sortValue(obj[key]);
}
return sorted;
}

/**
* Recursively sorts values (handles objects, arrays, and primitives)
* @private
*/
private sortValue(value: unknown): unknown {
if (value === null || typeof value !== 'object') {
return value;
}
if (Array.isArray(value)) {
return value.map((item) => this.sortValue(item));
}
return this.sortObjectKeys(value as Record<string, unknown>);
return JSON.stringify(sortObjectKeys(pattern));
}

/**
Expand Down
24 changes: 24 additions & 0 deletions src/websocket/routing/metadata-scanner.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,30 @@ describe('MetadataScanner', () => {
).toBe('handleMessage');
});

it('should match object patterns with arrays containing out-of-order keys', () => {
const pattern = {
cmd: 'batch',
items: [
{ b: 1, a: 2 },
{ d: 4, c: 3 },
],
};
addMetadata('handleMessage', pattern);

scanner.scanForMessageHandlers(gateway);

// Same pattern but object keys in different order inside arrays
expect(
scanner.getMethodNameForEvent(gateway, {
cmd: 'batch',
items: [
{ a: 2, b: 1 },
{ c: 3, d: 4 },
],
})
).toBe('handleMessage');
});

it('should return null for non-existent event', () => {
addMetadata('handleMessage', 'message');

Expand Down
30 changes: 3 additions & 27 deletions src/websocket/routing/metadata-scanner.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { Logger } from '@nestjs/common';
import 'reflect-metadata';
import { sortObjectKeys } from './pattern-key';

/**
* Metadata keys used by @SubscribeMessage decorator
Expand Down Expand Up @@ -103,7 +104,7 @@ export class MetadataScanner {
message: messagePattern,
messageKey:
typeof messagePattern === 'object'
? JSON.stringify(this.sortObjectKeys(messagePattern))
? JSON.stringify(sortObjectKeys(messagePattern))
: undefined,
methodName,
callback: method.bind(instance),
Expand Down Expand Up @@ -156,7 +157,7 @@ export class MetadataScanner {
// Pre-compute eventKey for object patterns to avoid repeated serialization
const eventKey =
typeof event === 'object' && event !== null
? JSON.stringify(this.sortObjectKeys(event))
? JSON.stringify(sortObjectKeys(event))
: undefined;

const handler = handlers.find((h) => {
Expand All @@ -178,29 +179,4 @@ export class MetadataScanner {

return handler ? handler.methodName : null;
}

/**
* Recursively sorts object keys for stable serialization
* @private
*/
private sortObjectKeys(
obj: Record<string, unknown>,
seen = new WeakSet<object>()
): Record<string, unknown> {
if (seen.has(obj)) {
throw new Error('Circular reference detected in message pattern');
}
seen.add(obj);

const sorted: Record<string, unknown> = {};
for (const key of Object.keys(obj).sort()) {
const value = obj[key];
sorted[key] =
value !== null && typeof value === 'object' && !Array.isArray(value)
? this.sortObjectKeys(value as Record<string, unknown>, seen)
: value;
}
seen.delete(obj);
return sorted;
}
}
72 changes: 72 additions & 0 deletions src/websocket/routing/pattern-key.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import { sortObjectKeys } from './pattern-key';

describe('sortObjectKeys', () => {
it('should sort top-level keys alphabetically', () => {
const result = sortObjectKeys({ b: 2, a: 1, c: 3 });
expect(Object.keys(result)).toEqual(['a', 'b', 'c']);
expect(result).toEqual({ a: 1, b: 2, c: 3 });
});

it('should return empty object for empty input', () => {
expect(sortObjectKeys({})).toEqual({});
});

it('should sort nested object keys recursively', () => {
const result = sortObjectKeys({ z: { y: 2, x: 1 }, a: 0 });
expect(Object.keys(result)).toEqual(['a', 'z']);
expect(Object.keys(result.z as object)).toEqual(['x', 'y']);
});

it('should handle arrays without modifying order', () => {
const result = sortObjectKeys({ items: [3, 1, 2] });
expect(result.items).toEqual([3, 1, 2]);
});

it('should sort keys of objects nested inside arrays', () => {
const result = sortObjectKeys({
items: [
{ b: 1, a: 2 },
{ d: 3, c: 4 },
],
});
const items = result.items as Record<string, unknown>[];
expect(Object.keys(items[0])).toEqual(['a', 'b']);
expect(Object.keys(items[1])).toEqual(['c', 'd']);
});

it('should handle nested arrays', () => {
const result = sortObjectKeys({ matrix: [[{ z: 1, a: 2 }]] });
const inner = (result.matrix as unknown[][])[0][0] as Record<string, unknown>;
expect(Object.keys(inner)).toEqual(['a', 'z']);
});

it('should pass through primitives as-is', () => {
const result = sortObjectKeys({ str: 'hello', num: 42, bool: true, nil: null });
expect(result).toEqual({ bool: true, nil: null, num: 42, str: 'hello' });
});

it('should throw on circular references', () => {
const obj: Record<string, unknown> = { a: 1 };
obj.self = obj;
expect(() => sortObjectKeys(obj)).toThrow('Circular reference detected');
});

it('should throw on deeply nested circular references', () => {
const inner: Record<string, unknown> = { value: 1 };
const outer: Record<string, unknown> = { nested: inner };
inner.parent = outer;
expect(() => sortObjectKeys(outer)).toThrow('Circular reference detected');
});

it('should allow the same object in multiple non-circular branches', () => {
const shared = { x: 1, y: 2 };
const result = sortObjectKeys({ a: shared, b: shared });
expect(Object.keys(result.a as object)).toEqual(['x', 'y']);
expect(Object.keys(result.b as object)).toEqual(['x', 'y']);
});

it('should handle undefined and function values like JSON.stringify', () => {
const result = sortObjectKeys({ a: 1, b: undefined, c: () => {} });
expect(result).toEqual({ a: 1, b: undefined, c: expect.any(Function) });
});
});
31 changes: 31 additions & 0 deletions src/websocket/routing/pattern-key.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
* Recursively sorts object keys for stable JSON serialization.
* Handles nested objects, arrays, and primitives.
* Throws on circular references.
*/
export function sortObjectKeys(
obj: Record<string, unknown>,
seen = new WeakSet<object>()
): Record<string, unknown> {
if (seen.has(obj)) {
throw new Error('Circular reference detected in message pattern');
}
seen.add(obj);

const sorted: Record<string, unknown> = {};
for (const key of Object.keys(obj).sort()) {
sorted[key] = sortValue(obj[key], seen);
}
seen.delete(obj);
return sorted;
}

function sortValue(value: unknown, seen: WeakSet<object>): unknown {
if (value === null || typeof value !== 'object') {
return value;
}
if (Array.isArray(value)) {
return value.map((item) => sortValue(item, seen));
}
return sortObjectKeys(value as Record<string, unknown>, seen);
}