From 81b18c83480ca8f9e14bdbb17cd7306f9e8dea89 Mon Sep 17 00:00:00 2001 From: Lars Ershammar Date: Mon, 9 Mar 2026 10:33:24 +0000 Subject: [PATCH] Add WebSocket signaling for WebRTC connections Add WebSocket-based signaling as a faster alternative to HTTP long-polling for WebRTC connections between the browser client and the Cuttlefish operator. Key changes: - Add WsClient type implementing the Client interface with goroutine-based read/write pumps and ping/pong keepalive - Add GET /devices/{deviceId}/connect WebSocket upgrade endpoint - Add WebSocketConnector JavaScript class that tries WebSocket first and falls back to HTTP polling if unavailable - Full backward compatibility: existing polling connections continue to work unchanged Uses the existing gorilla/websocket dependency (already used for the ADB WebSocket proxy). --- frontend/src/liboperator/operator/clients.go | 164 ++++++++++++++++++ .../src/liboperator/operator/clients_test.go | 78 +++++++++ frontend/src/liboperator/operator/operator.go | 49 ++++++ .../operator/intercept/js/server_connector.js | 122 +++++++++++++ 4 files changed, 413 insertions(+) diff --git a/frontend/src/liboperator/operator/clients.go b/frontend/src/liboperator/operator/clients.go index 2a243454772..8ec2c81b11f 100644 --- a/frontend/src/liboperator/operator/clients.go +++ b/frontend/src/liboperator/operator/clients.go @@ -15,10 +15,26 @@ package operator import ( + "encoding/json" + "log" "math/rand" "strings" "sync" "time" + + "github.com/gorilla/websocket" +) + +// WebSocket timing constants +const ( + // Time allowed to write a message to the peer + wsWriteWait = 10 * time.Second + // Time allowed to read the next pong message from the peer + wsPongWait = 60 * time.Second + // Send pings to peer with this period (must be less than wsPongWait) + wsPingPeriod = 30 * time.Second + // Maximum message size allowed from peer + wsMaxMessageSize = 64 * 1024 ) type Client interface { @@ -139,3 +155,151 @@ func randStr(l int, r *rand.Rand) string { } return s.String() } + +// WsClient implements Client for WebSocket connections. +type WsClient struct { + // The id given to this client by the device + clientId int + conn *websocket.Conn + device *Device + // Buffered channel of outbound messages + send chan interface{} + // Signals that the connection is closed + done chan struct{} + // Ensures Close() only runs once + closeOnce sync.Once +} + +// NewWsClient creates a new WebSocket client and starts its read/write pumps. +func NewWsClient(conn *websocket.Conn, device *Device) *WsClient { + c := &WsClient{ + conn: conn, + device: device, + send: make(chan interface{}, 256), + done: make(chan struct{}), + } + go c.writePump() + go c.readPump() + return c +} + +// Send queues a message for delivery to the WebSocket client. +func (c *WsClient) Send(msg interface{}) error { + select { + case c.send <- msg: + return nil + case <-c.done: + return nil + } +} + +// OnDeviceDisconnected notifies the client that the device disconnected. +func (c *WsClient) OnDeviceDisconnected() { + // Send an error message to the client before closing + c.Send(map[string]interface{}{ + "error": "Device disconnected", + }) + c.Close() +} + +// Close terminates the WebSocket connection. +func (c *WsClient) Close() { + c.closeOnce.Do(func() { + close(c.done) + if c.conn != nil { + c.conn.Close() + } + }) +} + +// Done returns a channel that is closed when the client disconnects. +func (c *WsClient) Done() <-chan struct{} { + return c.done +} + +// ClientId returns the id assigned by the device. +func (c *WsClient) ClientId() int { + return c.clientId +} + +// readPump reads messages from the WebSocket and forwards them to the device. +func (c *WsClient) readPump() { + defer func() { + c.device.Unregister(c.clientId) + c.Close() + }() + + c.conn.SetReadLimit(wsMaxMessageSize) + c.conn.SetReadDeadline(time.Now().Add(wsPongWait)) + c.conn.SetPongHandler(func(string) error { + c.conn.SetReadDeadline(time.Now().Add(wsPongWait)) + return nil + }) + + for { + _, message, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, + websocket.CloseGoingAway, + websocket.CloseNormalClosure) { + log.Printf("WebSocket read error: %v", err) + } + return + } + + var msg map[string]interface{} + if err := json.Unmarshal(message, &msg); err != nil { + log.Printf("Invalid JSON from WebSocket client: %v", err) + continue + } + + // Forward to device as client_msg + clientMsg := map[string]interface{}{ + "message_type": "client_msg", + "client_id": c.clientId, + "payload": msg, + } + if err := c.device.Send(clientMsg); err != nil { + log.Printf("Failed to forward message to device: %v", err) + return + } + } +} + +// writePump writes messages from the send channel to the WebSocket. +func (c *WsClient) writePump() { + ticker := time.NewTicker(wsPingPeriod) + defer func() { + ticker.Stop() + c.conn.Close() + }() + + for { + select { + case msg, ok := <-c.send: + c.conn.SetWriteDeadline(time.Now().Add(wsWriteWait)) + if !ok { + // Channel closed + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + if err := c.conn.WriteJSON(msg); err != nil { + log.Printf("WebSocket write error: %v", err) + return + } + + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(wsWriteWait)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + + case <-c.done: + return + } + } +} + +// Compile-time check that WsClient implements Client +var _ Client = (*WsClient)(nil) diff --git a/frontend/src/liboperator/operator/clients_test.go b/frontend/src/liboperator/operator/clients_test.go index a3b62b1b9d9..3580c5e1b25 100644 --- a/frontend/src/liboperator/operator/clients_test.go +++ b/frontend/src/liboperator/operator/clients_test.go @@ -100,3 +100,81 @@ func TestRandStr(t *testing.T) { t.Error("randStr returned two equal strings of len 100") } } + +// TestWsClientImplementsClient verifies WsClient implements the Client interface. +// This is a compile-time check, but we include it as a test for documentation. +func TestWsClientImplementsClient(t *testing.T) { + // This is checked at compile time via: var _ Client = (*WsClient)(nil) + // If WsClient doesn't implement Client, the code won't compile. +} + +// TestWsClientClose verifies that Close() can be called multiple times safely. +func TestWsClientClose(t *testing.T) { + // Create a minimal WsClient without a real connection + // to test the Close() sync.Once behavior + c := &WsClient{ + send: make(chan interface{}, 1), + done: make(chan struct{}), + } + + // First close should work + c.Close() + + // Verify done channel is closed + select { + case <-c.done: + // Expected + default: + t.Error("done channel not closed after Close()") + } + + // Second close should not panic + c.Close() +} + +// TestWsClientDone verifies the Done() channel behavior. +func TestWsClientDone(t *testing.T) { + c := &WsClient{ + send: make(chan interface{}, 1), + done: make(chan struct{}), + } + + // Done() should return the done channel + done := c.Done() + if done == nil { + t.Error("Done() returned nil") + } + + // Channel should be open initially + select { + case <-done: + t.Error("done channel closed before Close() called") + default: + // Expected + } + + // After Close(), channel should be closed + c.Close() + select { + case <-done: + // Expected + default: + t.Error("done channel not closed after Close()") + } +} + +// TestWsClientSendAfterClose verifies Send() returns without blocking after Close(). +func TestWsClientSendAfterClose(t *testing.T) { + c := &WsClient{ + send: make(chan interface{}, 1), + done: make(chan struct{}), + } + + c.Close() + + // Send should return immediately without blocking + err := c.Send("test message") + if err != nil { + t.Error("Send() returned error after Close()") + } +} diff --git a/frontend/src/liboperator/operator/operator.go b/frontend/src/liboperator/operator/operator.go index a43dfdc9b5b..441d3ee34b9 100644 --- a/frontend/src/liboperator/operator/operator.go +++ b/frontend/src/liboperator/operator/operator.go @@ -114,6 +114,8 @@ func SetupControlEndpoint(pool *DevicePool, path string) (func() error, error) { // GET /devices/{deviceId}/services/{serviceName}/{typeName}/type // GET /devices/{deviceId}/openwrt{path:/.*} // POST /devices/{deviceId}/openwrt{path:/.*} +// GET /devices/{deviceId}/adb (WebSocket) +// GET /devices/{deviceId}/connect (WebSocket signaling) // GET /polled_connections // GET /polled_connections/{connId}/messages // POST /polled_connections/{connId}/:forward @@ -157,6 +159,9 @@ func CreateHttpHandlers( router.HandleFunc("/devices/{deviceId}/adb", func(w http.ResponseWriter, r *http.Request) { adbProxy(w, r, pool) }).Methods("GET") + router.HandleFunc("/devices/{deviceId}/connect", func(w http.ResponseWriter, r *http.Request) { + connectWebSocket(w, r, pool) + }).Methods("GET") router.HandleFunc("/polled_connections/{connId}/:forward", func(w http.ResponseWriter, r *http.Request) { forward(w, r, polledSet) }).Methods("POST") @@ -550,6 +555,50 @@ func adbProxy(w http.ResponseWriter, r *http.Request, pool *DevicePool) { } } +// WebSocket upgrader for signaling connections +var wsSignalingUpgrader = websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + // Allow all origins; restrict in production via reverse proxy + return true + }, +} + +// WebSocket endpoint for signaling +func connectWebSocket(w http.ResponseWriter, r *http.Request, pool *DevicePool) { + vars := mux.Vars(r) + deviceId := vars["deviceId"] + + device := pool.GetDevice(deviceId) + if device == nil { + http.Error(w, "Device not found", http.StatusNotFound) + return + } + + conn, err := wsSignalingUpgrader.Upgrade(w, r, nil) + if err != nil { + log.Printf("WebSocket upgrade failed for device %s: %v", deviceId, err) + return + } + + client := NewWsClient(conn, device) + clientId := device.Register(client) + client.clientId = clientId + + log.Printf("WebSocket client %d connected to device %s", clientId, deviceId) + + // Send device info to the client + client.Send(map[string]interface{}{ + "device_info": device.privateData, + }) + + // Block until the client disconnects + <-client.Done() + + log.Printf("WebSocket client %d disconnected from device %s", clientId, deviceId) +} + // Wrapper for implementing io.ReadWriteCloser of websocket.Conn type wsIoWrapper struct { wsConn *websocket.Conn diff --git a/frontend/src/operator/intercept/js/server_connector.js b/frontend/src/operator/intercept/js/server_connector.js index 6dbc64f7622..e116f57872d 100644 --- a/frontend/src/operator/intercept/js/server_connector.js +++ b/frontend/src/operator/intercept/js/server_connector.js @@ -42,7 +42,19 @@ export function deviceId() { } // Creates a connector capable of communicating with the signaling server. +// Tries WebSocket first, falls back to HTTP polling if WebSocket is unavailable. export async function createConnector() { + // Check if WebSocket is supported + if (typeof WebSocket !== 'undefined') { + try { + const connector = new WebSocketConnector(); + console.log('Using WebSocket signaling'); + return connector; + } catch (e) { + console.warn('WebSocket connector creation failed, falling back to polling:', e); + } + } + console.log('Using HTTP polling signaling'); return new PollingConnector(); } @@ -274,6 +286,116 @@ class PollingConnector extends Connector { } } +// WebSocket connection timeout in milliseconds +const WS_CONNECT_TIMEOUT = 3000; + +// Implementation of the Connector interface using WebSocket +class WebSocketConnector extends Connector { + #configUrl = httpUrl('infra_config'); + #ws; + #config = undefined; + #deviceInfo = undefined; + #onDeviceMsgCb = msg => + console.error('Received device message without registered listener'); + + onDeviceMsg(cb) { + this.#onDeviceMsgCb = cb; + } + + constructor() { + super(); + } + + async requestDevice(deviceId) { + let config = await this.#getConfig(); + + return new Promise((resolve, reject) => { + const protocol = location.protocol === 'https:' ? 'wss:' : 'ws:'; + const wsUrl = `${protocol}//${location.host}/devices/${deviceId}/connect`; + + const timeout = setTimeout(() => { + this.#ws.close(); + reject(new Error('WebSocket connection timeout')); + }, WS_CONNECT_TIMEOUT); + + this.#ws = new WebSocket(wsUrl); + + this.#ws.onopen = () => { + console.debug('WebSocket signaling connection opened'); + }; + + this.#ws.onmessage = (event) => { + const message = JSON.parse(event.data); + + // First message should be device_info + if (this.#deviceInfo === undefined && message.device_info !== undefined) { + clearTimeout(timeout); + this.#deviceInfo = message.device_info; + resolve({ + deviceInfo: this.#deviceInfo, + infraConfig: config, + }); + return; + } + + // Handle error messages + if (message.error !== undefined) { + if (this.#deviceInfo === undefined) { + clearTimeout(timeout); + reject(new Error(message.error)); + } else { + console.error('Device error:', message.error); + } + return; + } + + // Normal message - forward to callback + // Messages from device come with payload wrapper + if (message.payload !== undefined) { + this.#onDeviceMsgCb(message.payload); + } else { + this.#onDeviceMsgCb(message); + } + }; + + this.#ws.onerror = (error) => { + clearTimeout(timeout); + reject(new Error('WebSocket connection failed')); + }; + + this.#ws.onclose = (event) => { + clearTimeout(timeout); + if (this.#deviceInfo === undefined) { + reject(new Error('WebSocket closed before receiving device info')); + } + }; + }); + } + + async sendToDevice(msg) { + if (this.#ws && this.#ws.readyState === WebSocket.OPEN) { + this.#ws.send(JSON.stringify(msg)); + } else { + throw new Error('WebSocket not connected'); + } + } + + async #getConfig() { + if (this.#config === undefined) { + this.#config = await (await fetch(this.#configUrl, { + method: 'GET', + redirect: 'follow', + })).json(); + } + return this.#config; + } + + // WebSocket is always "fast", so this is a no-op + expectMessagesSoon(durationMilliseconds) { + // No action needed for WebSocket - messages are delivered immediately + } +} + export class DisplayInfo { display_id = ''; width = 0;