From 05dfcbe5af163ba135499fd168269ba9eb7c756e Mon Sep 17 00:00:00 2001 From: Robin Karlsson Date: Wed, 20 May 2026 14:49:08 +0300 Subject: [PATCH] Add support for Unix sockets in the --host arg Add the possibility of passing a path to a Unix socket as unix://<path> in the --host parameter to allow running TensorBoard securely on multi-user hosts. --- tensorboard/plugins/core/core_plugin.py | 14 ++++++-- tensorboard/plugins/core/core_plugin_test.py | 18 ++++++++++ tensorboard/program.py | 35 ++++++++++++++++---- tensorboard/program_test.py | 21 ++++++++++++ 4 files changed, 80 insertions(+), 8 deletions(-) diff --git a/tensorboard/plugins/core/core_plugin.py b/tensorboard/plugins/core/core_plugin.py index 3ef86419f8f..aeddb9f9ca7 100644 --- a/tensorboard/plugins/core/core_plugin.py +++ b/tensorboard/plugins/core/core_plugin.py @@ -351,7 +351,8 @@ def define_flags(self, parser): help="""\ What host to listen to (default: localhost). To serve to the entire local network on both IPv4 and IPv6, see `--bind_all`, with which this option is -mutually exclusive. +mutually exclusive. May also be set to `unix://` (e.g. +`unix:///tmp/tb.sock`) to listen on a Unix domain socket. """, ) @@ -390,7 +391,8 @@ def define_flags(self, parser): Enables the SO_REUSEPORT option on the socket opened by TensorBoard's HTTP server, for platforms that support it. This is useful in cases when a parent process has obtained the port already and wants to delegate access to the -port to TensorBoard as a subprocess.(default: %(default)s).\ +port to TensorBoard as a subprocess. Ignored when `--host` is set to a +`unix://` socket. (default: %(default)s).\ """, ) @@ -707,6 +709,14 @@ def fix_flags(self, flags): ) elif flags.host is not None and flags.bind_all: raise FlagsError("Must not specify both --host and --bind_all.") + elif ( + flags.host is not None + and flags.host.startswith("unix://") + and flags.port is not None + ): + raise FlagsError( + "--host=unix://... must not be combined with --port." 
+ ) elif ( flags.load_fast == "true" and flags.detect_file_replacement is True ): diff --git a/tensorboard/plugins/core/core_plugin_test.py b/tensorboard/plugins/core/core_plugin_test.py index a81393741bb..30fab97999b 100644 --- a/tensorboard/plugins/core/core_plugin_test.py +++ b/tensorboard/plugins/core/core_plugin_test.py @@ -58,6 +58,7 @@ def __init__( logdir="", logdir_spec="", path_prefix="", + port=None, reuse_port=False, version_tb=False, ): @@ -72,6 +73,7 @@ def __init__( self.logdir = logdir self.logdir_spec = logdir_spec self.path_prefix = path_prefix + self.port = port self.reuse_port = reuse_port self.version_tb = version_tb @@ -132,6 +134,22 @@ def testPathPrefix_mustStartWithSlash(self): self.assertIn("must start with slash", msg) self.assertIn(repr("noslash"), msg) + def testHostUnixSocket_alone_isAccepted(self): + loader = core_plugin.CorePluginLoader() + for value in ("unix:///tmp/tb.sock", "unix://tb.sock"): + loader.fix_flags(FakeFlags(logdir="/tmp", host=value)) + + def testHostUnixSocket_conflictsWithPort(self): + loader = core_plugin.CorePluginLoader() + flag = FakeFlags( + logdir="/tmp", + host="unix:///tmp/tb.sock", + port=6006, + ) + with self.assertRaises(base_plugin.FlagsError) as cm: + loader.fix_flags(flag) + self.assertIn("unix://", str(cm.exception)) + class CorePluginTest(tf.test.TestCase): def setUp(self): diff --git a/tensorboard/program.py b/tensorboard/program.py index e945c309549..e4b0b23692f 100644 --- a/tensorboard/program.py +++ b/tensorboard/program.py @@ -139,7 +139,7 @@ def __init__( ) from e assets_zip_provider = assets.get_default_assets_zip_provider() if server_class is None: - server_class = create_port_scanning_werkzeug_server + server_class = _default_server_class if subcommands is None: subcommands = [] self.plugin_loaders = [ @@ -333,7 +333,8 @@ def _register_info(self, server): info = manager.TensorBoardInfo( version=version.VERSION, start_time=int(time.time()), - port=server_url.port, + # For Unix sockets, 
server_url.port is None. + port=server_url.port or 0, pid=os.getpid(), path_prefix=self.flags.path_prefix, logdir=self.flags.logdir or self.flags.logdir_spec, @@ -476,6 +477,15 @@ def _make_server(self): return self.server_class(app, self.flags) +def _default_server_class(wsgi_app, flags): + """Default server factory.""" + + # Skip port scanning for Unix-socket servers. + if flags.host is not None and flags.host.startswith("unix://"): + return WerkzeugServer(wsgi_app, flags) + return create_port_scanning_werkzeug_server(wsgi_app, flags) + + def _should_use_data_server(flags): if flags.logdir_spec and not flags.logdir: logger.info( @@ -696,9 +706,13 @@ def __init__(self, wsgi_app, flags): self._flags = flags host = flags.host port = flags.port + self._unix_socket = host is not None and host.startswith("unix://") - self._auto_wildcard = flags.bind_all - if self._auto_wildcard: + self._auto_wildcard = flags.bind_all and not self._unix_socket + if self._unix_socket: + # Werkzeug accepts host="unix://" directly, port is ignored. + port = 0 + elif self._auto_wildcard: # Serve on all interfaces, and attempt to serve both IPv4 and IPv6 # traffic through one socket. 
host = self._get_wildcard_address(port) @@ -715,12 +729,14 @@ def is_port_in_use(port): return s.connect_ex(("localhost", port)) == 0 try: - if is_port_in_use(port): + if not self._unix_socket and is_port_in_use(port): raise TensorBoardPortInUseError( "TensorBoard could not bind to port %d, it was already in use" % port ) super().__init__(host, port, wsgi_app, _WSGIRequestHandler) + if self._unix_socket: + os.chmod(self.server_address, 0o700) except socket.error as e: if hasattr(errno, "EACCES") and e.errno == errno.EACCES: raise TensorBoardServerException( @@ -801,7 +817,7 @@ def _get_wildcard_address(self, port): def server_bind(self): """Override to set custom options on the socket.""" - if self._flags.reuse_port: + if self._flags.reuse_port and not self._unix_socket: try: socket.SO_REUSEPORT except AttributeError: @@ -856,6 +872,13 @@ def handle_error(self, request, client_address): def get_url(self): if not self._url: + if self._unix_socket: + self._url = "%s%s/" % ( + self._host, + self._flags.path_prefix.rstrip("/"), + ) + return self._url + if self._auto_wildcard: display_host = socket.getfqdn() # Confirm that the connection is open, otherwise change to `localhost` diff --git a/tensorboard/program_test.py b/tensorboard/program_test.py index 82ace265bef..27b89149dc2 100644 --- a/tensorboard/program_test.py +++ b/tensorboard/program_test.py @@ -16,7 +16,10 @@ import argparse import io +import os +import socket import sys +import tempfile from unittest import mock from tensorboard import program @@ -113,6 +116,7 @@ def make_flags(self, **kwargs): flags = argparse.Namespace() kwargs.setdefault("host", None) kwargs.setdefault("bind_all", kwargs["host"] is None) + kwargs.setdefault("port", None) kwargs.setdefault("reuse_port", False) for k, v in kwargs.items(): setattr(flags, k, v) @@ -168,6 +172,23 @@ def testSpecifiedHost(self): "Neither IPv4 (127.0.0.1) nor IPv6 (::1) could be bound.", ) + def testUnixSocketHost(self): + if not hasattr(socket, "AF_UNIX"): + 
self.skipTest("AF_UNIX not supported on this platform") + with tempfile.TemporaryDirectory() as tmpdir: + sock_path = os.path.join(tmpdir, "tb.sock") + server = program.WerkzeugServer( + self._StubApplication(), + self.make_flags(host="unix://" + sock_path, path_prefix=""), + ) + try: + self.assertTrue(os.path.exists(sock_path)) + url = server.get_url() + self.assertStartsWith(url, "unix://") + self.assertIn("tb.sock", url) + finally: + server.server_close() + class SubcommandTest(tb_test.TestCase): def setUp(self):