Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEXT_CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New Features and Improvements

### Bug Fixes
* Fixed Databricks M2M OAuth to correctly use Databricks OIDC endpoints instead of incorrectly using Azure endpoints when `ARM_CLIENT_ID` is set. Added new `getDatabricksOidcEndpoints` method that returns only Databricks OIDC endpoints, and updated all Databricks OAuth flows to use it. The old `getOidcEndpoints` property is deprecated but maintained for backward compatibility.

### Security Vulnerabilities

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -758,7 +758,46 @@ public ClientType getClientType() {
}
}

/**
* @deprecated Use {@link #getDatabricksOidcEndpoints()} instead. This method incorrectly returns
* Azure OIDC endpoints when azure_client_id is set, even for Databricks OAuth flows that
* don't use Azure authentication. This caused bugs where Databricks M2M OAuth would fail when
* ARM_CLIENT_ID was set for other purposes. Use instead: - getDatabricksOidcEndpoints(): For
* Databricks OAuth (oauth-m2m, external-browser, etc.). -
* getAzureEntraIdWorkspaceEndpoints(): For Azure Entra ID OIDC endpoints.
* @return The OIDC endpoints. This method dinamically returns the OIDC endpoints based on the
* config.
*/
@Deprecated
public OpenIDConnectEndpoints getOidcEndpoints() throws IOException {
if (isAzure() && getAzureClientId() != null) {
return getAzureEntraIdWorkspaceEndpoints();
}
return getDatabricksOidcEndpoints();
}

/**
* @return The Azure Entra ID OIDC endpoints.
*/
public OpenIDConnectEndpoints getAzureEntraIdWorkspaceEndpoints() throws IOException {
if (isAzure() && getAzureClientId() != null) {
Request request = new Request("GET", getHost() + "/oidc/oauth2/v2.0/authorize");
request.setRedirectionBehavior(false);
Response resp = getHttpClient().execute(request);
String realAuthUrl = resp.getFirstHeader("location");
if (realAuthUrl == null) {
return null;
}
return new OpenIDConnectEndpoints(
realAuthUrl.replaceAll("/authorize", "/token"), realAuthUrl);
}
return null;
}

/**
* @return The Databricks OIDC endpoints.
*/
public OpenIDConnectEndpoints getDatabricksOidcEndpoints() throws IOException {
if (discoveryUrl == null) {
return fetchDefaultOidcEndpoints();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ private void addOIDCCredentialsProviders(DatabricksConfig config) {
// This would also need to be updated to support unified hosts.
OpenIDConnectEndpoints endpoints = null;
try {
endpoints = config.getOidcEndpoints();
endpoints = config.getDatabricksOidcEndpoints();
} catch (Exception e) {
LOG.warn("Failed to get OpenID Connect endpoints", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ public static String resolveClientSecret(DatabricksConfig config) {
public static OpenIDConnectEndpoints resolveOidcEndpoints(DatabricksConfig config)
throws IOException {
if (config.getClientId() != null && config.getClientSecret() != null) {
return config.getOidcEndpoints();
return config.getDatabricksOidcEndpoints();
} else if (config.getAzureClientId() != null && config.getAzureClientSecret() != null) {
Request request = new Request("GET", config.getHost() + "/oidc/oauth2/v2.0/authorize");
request.setRedirectionBehavior(false);
Expand All @@ -69,6 +69,6 @@ public static OpenIDConnectEndpoints resolveOidcEndpoints(DatabricksConfig confi
return new OpenIDConnectEndpoints(
realAuthUrl.replaceAll("/authorize", "/token"), realAuthUrl);
}
return config.getOidcEndpoints();
return config.getDatabricksOidcEndpoints();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public OAuthHeaderFactory configure(DatabricksConfig config) {
// TODO: Azure returns 404 for UC workspace after redirecting to
// https://login.microsoftonline.com/{cfg.azure_tenant_id}/.well-known/oauth-authorization-server
try {
OpenIDConnectEndpoints jsonResponse = config.getOidcEndpoints();
OpenIDConnectEndpoints jsonResponse = config.getDatabricksOidcEndpoints();
ClientCredentials clientCredentials =
new ClientCredentials.Builder()
.withHttpClient(config.getHttpClient())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public void testWorkspaceLevelOidcEndpointsWithAccountId() throws IOException {
c.resolve(
new Environment(new HashMap<>(), new ArrayList<String>(), System.getProperty("os.name")));
assertEquals(
c.getOidcEndpoints().getAuthorizationEndpoint(),
c.getDatabricksOidcEndpoints().getAuthorizationEndpoint(),
"https://test-workspace.cloud.databricks.com/oidc/v1/authorize");
}
}
Expand All @@ -128,7 +128,7 @@ public void testWorkspaceLevelOidcEndpointsRetries() throws IOException {
c.resolve(
new Environment(new HashMap<>(), new ArrayList<String>(), System.getProperty("os.name")));
assertEquals(
c.getOidcEndpoints().getAuthorizationEndpoint(),
c.getDatabricksOidcEndpoints().getAuthorizationEndpoint(),
"https://test-workspace.cloud.databricks.com/oidc/v1/authorize");
}
}
Expand All @@ -139,7 +139,7 @@ public void testAccountLevelOidcEndpoints() throws IOException {
new DatabricksConfig()
.setHost("https://accounts.cloud.databricks.com")
.setAccountId("1234567890")
.getOidcEndpoints()
.getDatabricksOidcEndpoints()
.getAuthorizationEndpoint(),
"https://accounts.cloud.databricks.com/oidc/accounts/1234567890/v1/authorize");
}
Expand All @@ -163,7 +163,7 @@ public void testDiscoveryEndpoint() throws IOException {
.setHost(server.getUrl())
.setDiscoveryUrl(discoveryUrl)
.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build())
.getOidcEndpoints();
.getDatabricksOidcEndpoints();

assertEquals(
oidcEndpoints.getAuthorizationEndpoint(), "https://test.auth.endpoint/oidc/v1/authorize");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public void testOidcEndpointsForUnifiedHost() throws IOException {
.setExperimentalIsUnifiedHost(true)
.setAccountId("test-account-123");

OpenIDConnectEndpoints endpoints = config.getOidcEndpoints();
OpenIDConnectEndpoints endpoints = config.getDatabricksOidcEndpoints();

assertEquals(
"https://unified.databricks.com/oidc/accounts/test-account-123/v1/authorize",
Expand All @@ -138,7 +138,7 @@ public void testOidcEndpointsForUnifiedHostMissingAccountId() {
// No account ID set

DatabricksException exception =
assertThrows(DatabricksException.class, () -> config.getOidcEndpoints());
assertThrows(DatabricksException.class, () -> config.getDatabricksOidcEndpoints());
assertTrue(exception.getMessage().contains("account_id is required"));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,16 @@ void clientAndConsentTest() throws IOException {
.setHttpClient(new CommonsHttpClient.Builder().withTimeoutSeconds(30).build());
config.resolve();

assertEquals("tokenEndPointFromServer", config.getOidcEndpoints().getTokenEndpoint());
assertEquals(
"tokenEndPointFromServer", config.getDatabricksOidcEndpoints().getTokenEndpoint());

OAuthClient testClient =
new OAuthClient.Builder()
.withHttpClient(config.getHttpClient())
.withClientId(config.getClientId())
.withClientSecret(config.getClientSecret())
.withHost(config.getHost())
.withOpenIDConnectEndpoints(config.getOidcEndpoints())
.withOpenIDConnectEndpoints(config.getDatabricksOidcEndpoints())
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
.withScopes(config.getScopes())
.build();
Expand Down Expand Up @@ -92,15 +93,16 @@ void clientAndConsentTestWithCustomRedirectUrl() throws IOException {
.setScopes(Arrays.asList("sql"));
config.resolve();

assertEquals("tokenEndPointFromServer", config.getOidcEndpoints().getTokenEndpoint());
assertEquals(
"tokenEndPointFromServer", config.getDatabricksOidcEndpoints().getTokenEndpoint());

OAuthClient testClient =
new OAuthClient.Builder()
.withHttpClient(config.getHttpClient())
.withClientId(config.getClientId())
.withClientSecret(config.getClientSecret())
.withHost(config.getHost())
.withOpenIDConnectEndpoints(config.getOidcEndpoints())
.withOpenIDConnectEndpoints(config.getDatabricksOidcEndpoints())
.withRedirectUrl(config.getEffectiveOAuthRedirectUrl())
.withScopes(config.getScopes())
.build();
Expand Down Expand Up @@ -129,8 +131,9 @@ void openIDConnectEndPointsTestAccounts() throws IOException {
config.resolve();

String prefix = "https://accounts.cloud.databricks.com/oidc/accounts/" + config.getAccountId();
assertEquals(prefix + "/v1/token", config.getOidcEndpoints().getTokenEndpoint());
assertEquals(prefix + "/v1/authorize", config.getOidcEndpoints().getAuthorizationEndpoint());
assertEquals(prefix + "/v1/token", config.getDatabricksOidcEndpoints().getTokenEndpoint());
assertEquals(
prefix + "/v1/authorize", config.getDatabricksOidcEndpoints().getAuthorizationEndpoint());
}

@Test
Expand Down Expand Up @@ -278,7 +281,7 @@ void cacheWithValidRefreshableTokenTest() throws IOException {

// Spy on the config to inject the endpoints.
DatabricksConfig spyConfig = Mockito.spy(config);
Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints();
Mockito.doReturn(endpoints).when(spyConfig).getDatabricksOidcEndpoints();

// Configure provider.
HeaderFactory headerFactory = provider.configure(spyConfig);
Expand Down Expand Up @@ -343,7 +346,7 @@ void cacheWithValidNonRefreshableTokenTest() throws IOException {

// Spy on the config to inject the endpoints.
DatabricksConfig spyConfig = Mockito.spy(config);
Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints();
Mockito.doReturn(endpoints).when(spyConfig).getDatabricksOidcEndpoints();

// Configure provider.
HeaderFactory headerFactory = provider.configure(spyConfig);
Expand Down Expand Up @@ -415,7 +418,7 @@ void cacheWithInvalidAccessTokenValidRefreshTest() throws IOException {

// Spy on the config to inject the endpoints
DatabricksConfig spyConfig = Mockito.spy(config);
Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints();
Mockito.doReturn(endpoints).when(spyConfig).getDatabricksOidcEndpoints();

// Configure provider
HeaderFactory headerFactory = provider.configure(spyConfig);
Expand Down Expand Up @@ -524,7 +527,7 @@ void cacheWithInvalidAccessTokenRefreshFailingTest() throws IOException {

// Spy on the config to inject the endpoints
DatabricksConfig spyConfig = Mockito.spy(config);
Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints();
Mockito.doReturn(endpoints).when(spyConfig).getDatabricksOidcEndpoints();

// Configure provider
HeaderFactory headerFactory = provider.configure(spyConfig);
Expand Down Expand Up @@ -610,7 +613,7 @@ void cacheWithInvalidTokensTest() throws IOException {
"https://test.databricks.com/oidc/v1/token",
"https://test.databricks.com/oidc/v1/authorize");
DatabricksConfig spyConfig = Mockito.spy(config);
Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints();
Mockito.doReturn(endpoints).when(spyConfig).getDatabricksOidcEndpoints();

// Configure provider
HeaderFactory headerFactory = provider.configure(spyConfig);
Expand Down Expand Up @@ -738,7 +741,7 @@ void externalBrowserAuthWithAzureClientIdTest() throws IOException {
"https://test.azuredatabricks.net/oidc/v1/token",
"https://test.azuredatabricks.net/oidc/v1/authorize");
DatabricksConfig spyConfig = Mockito.spy(config);
Mockito.doReturn(endpoints).when(spyConfig).getOidcEndpoints();
Mockito.doReturn(endpoints).when(spyConfig).getDatabricksOidcEndpoints();

// Configure provider
HeaderFactory headerFactory = provider.configure(spyConfig);
Expand Down
Loading