diff --git a/awscli/customizations/sso/login.py b/awscli/customizations/sso/login.py index a16a946ed260..f2ef0aeb7558 100644 --- a/awscli/customizations/sso/login.py +++ b/awscli/customizations/sso/login.py @@ -57,6 +57,10 @@ def _run_main(self, parsed_args, parsed_globals): session_name=sso_config.get('session_name'), registration_scopes=sso_config.get('registration_scopes'), use_device_code=parsed_args.use_device_code, + redirect_address=( + parsed_args.redirect_host, + parsed_args.redirect_port, + ), ) success_msg = 'Successfully logged into Start URL: %s\n' uni_print(success_msg % sso_config['sso_start_url']) diff --git a/awscli/customizations/sso/utils.py b/awscli/customizations/sso/utils.py index 94062460f55d..199bc9eb34d1 100644 --- a/awscli/customizations/sso/utils.py +++ b/awscli/customizations/sso/utils.py @@ -61,6 +61,23 @@ 'instead of the Authorization Code flow.' ), }, + { + 'name': 'redirect-host', + 'action': 'store', + 'default': '0.0.0.0', + 'help_text': ( + 'Overrides OAuth callback address host instead of binding on all available local interfaces' + ), + }, + { + 'name': 'redirect-port', + 'action': 'store', + 'cli_type_name': 'integer', + 'default': 0, + 'help_text': ( + 'Overrides OAuth callback address port instead of arbitrary unused port' + ), + }, ] @@ -85,6 +102,7 @@ def do_sso_login( registration_scopes=None, session_name=None, use_device_code=False, + redirect_address=None, ): if token_cache is None: token_cache = JSONFileCache(SSO_TOKEN_DIR, dumps_func=_sso_json_dumps) @@ -100,7 +118,7 @@ def do_sso_login( sso_region=sso_region, client_creator=session.create_client, parsed_globals=parsed_globals, - auth_code_fetcher=AuthCodeFetcher(), + auth_code_fetcher=AuthCodeFetcher(server_address=redirect_address), cache=token_cache, on_pending_authorization=on_pending_authorization, ) @@ -217,7 +235,7 @@ class AuthCodeFetcher: # How long we wait overall for the callback _OVERALL_TIMEOUT = 60 * 10 - def __init__(self): + def __init__(self, server_address=('0.0.0.0', 0)): self._auth_code = None self._state = None self._is_done = False @@ -226,7 +244,7 @@ def __init__(self): # AuthCodeFetcher so that it can pass back the state and auth code try: handler = partial(OAuthCallbackHandler, self) - self.http_server = HTTPServer(('', 0), handler) + self.http_server = HTTPServer(server_address, handler) self.http_server.timeout = self._REQUEST_TIMEOUT except OSError as e: raise AuthCodeFetcherError(error_msg=e) diff --git a/tests/functional/sso/__init__.py b/tests/functional/sso/__init__.py index fb6d9d2b0f61..238902fb3cfe 100644 --- a/tests/functional/sso/__init__.py +++ b/tests/functional/sso/__init__.py @@ -159,6 +159,11 @@ def assert_device_browser_handler_called_with( verificationUriComplete, kwargs['verificationUriComplete'] ) + def assert_auth_code_fetcher_called_with(self, server_address): + self.fetcher_mock.assert_called_once() + _, kwargs = self.fetcher_mock.call_args + self.assertEqual(server_address, kwargs["server_address"]) + def assert_auth_browser_handler_called_with(self, expected_scopes): # The endpoint is subject to the endpoint rules, and the # code_challenge is not fixed so assert against the rest of the url diff --git a/tests/functional/sso/test_login.py b/tests/functional/sso/test_login.py index 5cad3a314651..8d8ca368478f 100644 --- a/tests/functional/sso/test_login.py +++ b/tests/functional/sso/test_login.py @@ -332,6 +332,17 @@ def test_login_device_sso_with_explicit_sso_session_arg(self): expected_token=self.access_token, ) + def test_login_auth_sso_with_explicit_redirect_port(self): + content = self.get_sso_session_config('test-session') + self.set_config_file_content(content=content) + self.add_oidc_auth_code_responses(self.access_token) + self.run_cmd( + 'sso login --redirect-port 5050 --redirect-host 50.50.50.50' + ) + self.assert_auth_code_fetcher_called_with( + server_address=('50.50.50.50', 5050) + ) + def test_login_auth_sso_with_explicit_sso_session_arg(self): content = self.get_sso_session_config( 'test-session', include_profile=False