Skip to content

Commit ea7ed47

Browse files
Merge pull request #757 from samanthajayasinghe/sesion-policy-prio
[SREP-1313] feat: Move the inline session policy to the last hop of the assume role sequence
2 parents bbcffa9 + f9b81db commit ea7ed47

4 files changed

Lines changed: 217 additions & 39 deletions

File tree

cmd/ocm-backplane/cloud/common.go

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ var GetCallerIdentity = func(client *sts.Client) error {
4646
return err
4747
}
4848

49+
// CheckEgressIP checks the egress IP of the client
50+
// This is a wrapper around checkEgressIPImpl to allow for easy mocking
51+
var CheckEgressIP = checkEgressIPImpl
52+
4953
// QueryConfig Wrapper for the configuration needed for cloud requests
5054
type QueryConfig struct {
5155
config.BackplaneConfiguration
@@ -313,6 +317,11 @@ func (cfg *QueryConfig) getIsolatedCredentials(ocmToken string) (aws.Credentials
313317
return aws.Credentials{}, fmt.Errorf("failed to unmarshal response: %w", err)
314318
}
315319

320+
inlinePolicy, err := verifyTrustedIPAndGetPolicy(cfg)
321+
if err != nil {
322+
return aws.Credentials{}, err
323+
}
324+
316325
assumeRoleArnSessionSequence := make([]awsutil.RoleArnSession, 0, len(roleChainResponse.AssumptionSequence))
317326
for _, namedRoleArnEntry := range roleChainResponse.AssumptionSequence {
318327
roleArnSession := awsutil.RoleArnSession{RoleArn: namedRoleArnEntry.Arn}
@@ -325,6 +334,7 @@ func (cfg *QueryConfig) getIsolatedCredentials(ocmToken string) (aws.Credentials
325334
roleArnSession.PolicyARNs = []types.PolicyDescriptorType{}
326335
if namedRoleArnEntry.Name == CustomerRoleArnName {
327336
roleArnSession.IsCustomerRole = true
337+
328338
// Add the session policy ARN for selected roles
329339
if roleChainResponse.SessionPolicyArn != "" {
330340
logger.Debugf("Adding session policy ARN for role %s: %s", namedRoleArnEntry.Name, roleChainResponse.SessionPolicyArn)
@@ -333,7 +343,10 @@ func (cfg *QueryConfig) getIsolatedCredentials(ocmToken string) (aws.Credentials
333343
Arn: aws.String(roleChainResponse.SessionPolicyArn),
334344
},
335345
}
346+
} else {
347+
roleArnSession.Policy = &inlinePolicy
336348
}
349+
337350
} else {
338351
roleArnSession.IsCustomerRole = false
339352
}
@@ -347,53 +360,60 @@ func (cfg *QueryConfig) getIsolatedCredentials(ocmToken string) (aws.Credentials
347360
Credentials: NewStaticCredentialsProvider(seedCredentials.AccessKeyID, seedCredentials.SecretAccessKey, seedCredentials.SessionToken),
348361
})
349362

363+
targetCredentials, err := AssumeRoleSequence(
364+
seedClient,
365+
assumeRoleArnSessionSequence,
366+
cfg.BackplaneConfiguration.ProxyURL,
367+
awsutil.DefaultSTSClientProviderFunc,
368+
)
369+
if err != nil {
370+
return aws.Credentials{}, fmt.Errorf("failed to assume role sequence: %w", err)
371+
}
372+
return targetCredentials, nil
373+
}
374+
375+
// verifyTrustedIPAndGetPolicy verifies that the client IP is in the trusted IP range
376+
// and returns the inline policy for the trusted IPs
377+
func verifyTrustedIPAndGetPolicy(cfg *QueryConfig) (awsutil.PolicyDocument, error) {
350378
var proxyURL *url.URL
351379

352380
if cfg.BackplaneConfiguration.ProxyURL != nil {
381+
var err error
353382
proxyURL, err = url.Parse(*cfg.BackplaneConfiguration.ProxyURL)
354383
if err != nil {
355-
return aws.Credentials{}, fmt.Errorf("failed to parse proxy URL: %w", err)
384+
return awsutil.PolicyDocument{}, fmt.Errorf("failed to parse proxy URL: %w", err)
356385
}
357386
}
358387
httpClient := &http.Client{
359388
Transport: &http.Transport{
360389
Proxy: http.ProxyURL(proxyURL),
361390
},
362391
}
363-
clientIP, err := checkEgressIP(httpClient, "https://checkip.amazonaws.com/")
392+
clientIP, err := CheckEgressIP(httpClient, "https://checkip.amazonaws.com/")
364393
if err != nil {
365-
return aws.Credentials{}, fmt.Errorf("failed to determine client IP: %w", err)
394+
return awsutil.PolicyDocument{}, fmt.Errorf("failed to determine client IP: %w", err)
366395
}
367396

368397
trustedRange, err := getTrustedIPList(cfg.OcmConnection)
369398
if err != nil {
370-
return aws.Credentials{}, err
399+
return awsutil.PolicyDocument{}, err
371400
}
372401

373402
err = verifyIPTrusted(clientIP, trustedRange)
374403
if err != nil {
375-
return aws.Credentials{}, err
404+
return awsutil.PolicyDocument{}, err
376405
}
377406

378407
inlinePolicy, err := getTrustedIPInlinePolicy(trustedRange)
379408
if err != nil {
380-
return aws.Credentials{}, fmt.Errorf("failed to build inline policy: %w", err)
409+
return awsutil.PolicyDocument{}, fmt.Errorf("failed to build inline policy: %w", err)
381410
}
382411

383-
targetCredentials, err := AssumeRoleSequence(
384-
seedClient,
385-
assumeRoleArnSessionSequence,
386-
cfg.BackplaneConfiguration.ProxyURL,
387-
awsutil.DefaultSTSClientProviderFunc,
388-
&inlinePolicy,
389-
)
390-
if err != nil {
391-
return aws.Credentials{}, fmt.Errorf("failed to assume role sequence: %w", err)
392-
}
393-
return targetCredentials, nil
412+
return inlinePolicy, nil
394413
}
395414

396-
func checkEgressIP(client *http.Client, url string) (net.IP, error) {
415+
// checkEgressIPImpl checks the egress IP of the client
416+
func checkEgressIPImpl(client *http.Client, url string) (net.IP, error) {
397417
resp, err := client.Get(url)
398418
if err != nil {
399419
return nil, fmt.Errorf("failed to fetch IP: %w", err)

cmd/ocm-backplane/cloud/common_test.go

Lines changed: 166 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -285,15 +285,15 @@ var _ = Describe("getIsolatedCredentials", func() {
285285
defer server.Close()
286286

287287
client := &http.Client{}
288-
ip, err := checkEgressIP(client, server.URL)
288+
ip, err := CheckEgressIP(client, server.URL)
289289

290290
Expect(err).NotTo(HaveOccurred())
291291
Expect(ip).To(Equal(net.ParseIP(mockIP)))
292292
})
293293
It("should return an error when the HTTP GET fails", func() {
294294
client = &http.Client{}
295295
// Invalid URL to force error
296-
ip, err := checkEgressIP(client, "http://invalid_url")
296+
ip, err := CheckEgressIP(client, "http://invalid_url")
297297
Expect(err).To(HaveOccurred())
298298
Expect(ip).To(BeNil())
299299
})
@@ -303,7 +303,7 @@ var _ = Describe("getIsolatedCredentials", func() {
303303
}))
304304
client = server.Client()
305305

306-
ip, err := checkEgressIP(client, server.URL)
306+
ip, err := CheckEgressIP(client, server.URL)
307307
Expect(err).To(MatchError(ContainSubstring("failed to parse IP")))
308308
Expect(ip).To(BeNil())
309309
})
@@ -359,6 +359,85 @@ var _ = Describe("getIsolatedCredentials", func() {
359359
Expect(err).To(BeNil())
360360
})
361361
})
362+
363+
Context("Execute verifyTrustedIPAndGetPolicy", func() {
364+
It("should successfully verify IP and return policy when IP is in trusted range", func() {
365+
// Mock the IP check to return a valid IP
366+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
367+
fmt.Fprint(w, "209.10.10.10") // IP that matches our test trusted range
368+
}))
369+
defer server.Close()
370+
371+
// Override the checkEgressIP function to use our test server
372+
originalCheckEgressIP := CheckEgressIP
373+
CheckEgressIP = func(client *http.Client, url string) (net.IP, error) {
374+
return originalCheckEgressIP(client, server.URL)
375+
}
376+
defer func() {
377+
CheckEgressIP = originalCheckEgressIP
378+
}()
379+
380+
// Set up expected trusted IP list
381+
ip1 := cmv1.NewTrustedIp().ID("209.10.10.10").Enabled(true)
382+
expectedIPList, err := cmv1.NewTrustedIpList().Items(ip1).Build()
383+
Expect(err).To(BeNil())
384+
mockOcmInterface.EXPECT().GetTrustedIPList(gomock.Any()).Return(expectedIPList, nil)
385+
386+
// Call the function
387+
policy, err := verifyTrustedIPAndGetPolicy(&testQueryConfig)
388+
389+
// Verify success
390+
Expect(err).To(BeNil())
391+
Expect(policy.Version).To(Equal("2012-10-17"))
392+
Expect(len(policy.Statement)).To(BeNumerically(">", 0))
393+
})
394+
395+
It("should fail when client IP is not in trusted range", func() {
396+
// Mock the IP check to return an untrusted IP
397+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
398+
fmt.Fprint(w, "192.168.1.1") // IP that doesn't match our test trusted range
399+
}))
400+
defer server.Close()
401+
402+
// Override the checkEgressIP function to use our test server
403+
originalCheckEgressIP := CheckEgressIP
404+
CheckEgressIP = func(client *http.Client, url string) (net.IP, error) {
405+
return originalCheckEgressIP(client, server.URL)
406+
}
407+
defer func() {
408+
CheckEgressIP = originalCheckEgressIP
409+
}()
410+
411+
// Set up expected trusted IP list (only has 209.x.x.x IPs)
412+
ip1 := cmv1.NewTrustedIp().ID("209.10.10.10").Enabled(true)
413+
expectedIPList, err := cmv1.NewTrustedIpList().Items(ip1).Build()
414+
Expect(err).To(BeNil())
415+
mockOcmInterface.EXPECT().GetTrustedIPList(gomock.Any()).Return(expectedIPList, nil)
416+
417+
// Call the function
418+
_, err = verifyTrustedIPAndGetPolicy(&testQueryConfig)
419+
420+
// Verify failure
421+
Expect(err).To(HaveOccurred())
422+
Expect(err.Error()).To(ContainSubstring("client IP 192.168.1.1 is not in the trusted IP range"))
423+
})
424+
425+
It("should fail when proxy URL is invalid", func() {
426+
// Set an invalid proxy URL
427+
invalidProxyURL := "://invalid-url"
428+
testQueryConfig.BackplaneConfiguration.ProxyURL = &invalidProxyURL
429+
430+
// Call the function
431+
_, err := verifyTrustedIPAndGetPolicy(&testQueryConfig)
432+
433+
// Verify failure
434+
Expect(err).To(HaveOccurred())
435+
Expect(err.Error()).To(ContainSubstring("failed to parse proxy URL"))
436+
437+
// Reset proxy URL
438+
testQueryConfig.BackplaneConfiguration.ProxyURL = nil
439+
})
440+
})
362441
})
363442

364443
// newTestCluster assembles a *cmv1.Cluster while handling the error to help out with inline test-case generation
@@ -537,6 +616,19 @@ var _ = Describe("PolicyARNs Integration", func() {
537616

538617
// Helper function to simulate the getIsolatedCredentials logic
539618
simulateGetIsolatedCredentialsLogic := func(roleChainResponse assumeChainResponse) []awsutil.RoleArnSession {
619+
// Create a mock inline policy for testing
620+
mockInlinePolicy := &awsutil.PolicyDocument{
621+
Version: "2012-10-17",
622+
Statement: []awsutil.PolicyStatement{
623+
{
624+
Sid: "TestPolicy",
625+
Effect: "Allow",
626+
Action: []string{"s3:GetObject"},
627+
Resource: aws.String("*"),
628+
},
629+
},
630+
}
631+
540632
assumeRoleArnSessionSequence := make([]awsutil.RoleArnSession, 0, len(roleChainResponse.AssumptionSequence))
541633
for _, namedRoleArnEntry := range roleChainResponse.AssumptionSequence {
542634
roleArnSession := awsutil.RoleArnSession{RoleArn: namedRoleArnEntry.Arn}
@@ -550,14 +642,18 @@ var _ = Describe("PolicyARNs Integration", func() {
550642
roleArnSession.PolicyARNs = []types.PolicyDescriptorType{}
551643
if namedRoleArnEntry.Name == CustomerRoleArnName {
552644
roleArnSession.IsCustomerRole = true
645+
553646
// Add the session policy ARN for selected roles
554647
if roleChainResponse.SessionPolicyArn != "" {
555648
roleArnSession.PolicyARNs = []types.PolicyDescriptorType{
556649
{
557650
Arn: aws.String(roleChainResponse.SessionPolicyArn),
558651
},
559652
}
653+
} else {
654+
roleArnSession.Policy = mockInlinePolicy
560655
}
656+
561657
} else {
562658
roleArnSession.IsCustomerRole = false
563659
}
@@ -591,6 +687,8 @@ var _ = Describe("PolicyARNs Integration", func() {
591687
Expect(customerRole.Name).To(Equal(CustomerRoleArnName))
592688
Expect(len(customerRole.PolicyARNs)).To(Equal(1))
593689
Expect(*customerRole.PolicyARNs[0].Arn).To(Equal(testSessionPolicyArn))
690+
// Verify that Policy is nil when SessionPolicyArn is used
691+
Expect(customerRole.Policy).To(BeNil())
594692
})
595693

596694
It("should not set PolicyARNs for non-customer roles", func() {
@@ -613,6 +711,8 @@ var _ = Describe("PolicyARNs Integration", func() {
613711
Expect(supportRole.IsCustomerRole).To(BeFalse())
614712
Expect(supportRole.Name).To(Equal("Support-Role-Arn"))
615713
Expect(len(supportRole.PolicyARNs)).To(Equal(0))
714+
// Verify that Policy is nil for non-customer roles
715+
Expect(supportRole.Policy).To(BeNil())
616716
})
617717

618718
// Generated by Cursor
@@ -636,6 +736,69 @@ var _ = Describe("PolicyARNs Integration", func() {
636736
Expect(customerRole.IsCustomerRole).To(BeTrue())
637737
Expect(customerRole.Name).To(Equal(CustomerRoleArnName))
638738
Expect(len(customerRole.PolicyARNs)).To(Equal(0))
739+
// Verify that Policy is set when SessionPolicyArn is empty
740+
Expect(customerRole.Policy).ToNot(BeNil())
741+
})
742+
})
743+
744+
Context("when verifying roleArnSession.Policy field behavior", func() {
745+
It("should set Policy only for customer roles without SessionPolicyArn", func() {
746+
// Test customer role with SessionPolicyArn - Policy should be nil
747+
roleChainResponseWithArn := assumeChainResponse{
748+
AssumptionSequence: []namedRoleArn{
749+
{
750+
Name: CustomerRoleArnName,
751+
Arn: "arn:aws:iam::123456789012:role/customer-role",
752+
},
753+
},
754+
CustomerRoleSessionName: "customer-session",
755+
SessionPolicyArn: testSessionPolicyArn,
756+
}
757+
758+
assumeRoleArnSessionSequence := simulateGetIsolatedCredentialsLogic(roleChainResponseWithArn)
759+
customerRoleWithArn := assumeRoleArnSessionSequence[0]
760+
761+
Expect(customerRoleWithArn.IsCustomerRole).To(BeTrue())
762+
Expect(customerRoleWithArn.Policy).To(BeNil()) // Policy should be nil when SessionPolicyArn is used
763+
Expect(len(customerRoleWithArn.PolicyARNs)).To(Equal(1))
764+
765+
// Test customer role without SessionPolicyArn - Policy should be set
766+
roleChainResponseWithoutArn := assumeChainResponse{
767+
AssumptionSequence: []namedRoleArn{
768+
{
769+
Name: CustomerRoleArnName,
770+
Arn: "arn:aws:iam::123456789012:role/customer-role",
771+
},
772+
},
773+
CustomerRoleSessionName: "customer-session",
774+
SessionPolicyArn: "", // Empty SessionPolicyArn
775+
}
776+
777+
assumeRoleArnSessionSequence = simulateGetIsolatedCredentialsLogic(roleChainResponseWithoutArn)
778+
customerRoleWithoutArn := assumeRoleArnSessionSequence[0]
779+
780+
Expect(customerRoleWithoutArn.IsCustomerRole).To(BeTrue())
781+
Expect(customerRoleWithoutArn.Policy).ToNot(BeNil()) // Policy should be set when SessionPolicyArn is empty
782+
Expect(len(customerRoleWithoutArn.PolicyARNs)).To(Equal(0))
783+
784+
// Test non-customer role - Policy should always be nil
785+
roleChainResponseNonCustomer := assumeChainResponse{
786+
AssumptionSequence: []namedRoleArn{
787+
{
788+
Name: "Support-Role-Arn",
789+
Arn: "arn:aws:iam::123456789012:role/support-role",
790+
},
791+
},
792+
CustomerRoleSessionName: "customer-session",
793+
SessionPolicyArn: "", // Empty SessionPolicyArn
794+
}
795+
796+
assumeRoleArnSessionSequence = simulateGetIsolatedCredentialsLogic(roleChainResponseNonCustomer)
797+
nonCustomerRole := assumeRoleArnSessionSequence[0]
798+
799+
Expect(nonCustomerRole.IsCustomerRole).To(BeFalse())
800+
Expect(nonCustomerRole.Policy).To(BeNil()) // Policy should always be nil for non-customer roles
801+
Expect(len(nonCustomerRole.PolicyARNs)).To(Equal(0))
639802
})
640803
})
641804

pkg/awsutil/sts.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ type RoleArnSession struct {
132132
RoleSessionName string
133133
RoleArn string
134134
IsCustomerRole bool
135+
Policy *PolicyDocument
135136
PolicyARNs []types.PolicyDescriptorType
136137
}
137138

@@ -140,7 +141,6 @@ func AssumeRoleSequence(
140141
roleArnSessionSequence []RoleArnSession,
141142
proxyURL *string,
142143
stsClientProviderFunc STSClientProviderFunc,
143-
inlinePolicy *PolicyDocument,
144144
) (aws.Credentials, error) {
145145
if len(roleArnSessionSequence) == 0 {
146146
return aws.Credentials{}, errors.New("role ARN sequence cannot be empty")
@@ -157,7 +157,7 @@ func AssumeRoleSequence(
157157
roleArnSession.RoleSessionName,
158158
roleArnSession.IsCustomerRole,
159159
)
160-
result, err := AssumeRole(nextClient, roleArnSession.RoleSessionName, roleArnSession.RoleArn, inlinePolicy, roleArnSession.PolicyARNs)
160+
result, err := AssumeRole(nextClient, roleArnSession.RoleSessionName, roleArnSession.RoleArn, roleArnSession.Policy, roleArnSession.PolicyARNs)
161161
retryCount := 0
162162
for err != nil {
163163
// IAM policy updates can take a few seconds to resolve, and the sts.Client in AWS' Go SDK doesn't refresh itself on retries.
@@ -170,7 +170,7 @@ func AssumeRoleSequence(
170170
return aws.Credentials{}, fmt.Errorf("failed to create client with credentials for role %v: %w", roleArnSession.RoleArn, err)
171171
}
172172

173-
result, err = AssumeRole(nextClient, roleArnSession.RoleSessionName, roleArnSession.RoleArn, inlinePolicy, roleArnSession.PolicyARNs)
173+
result, err = AssumeRole(nextClient, roleArnSession.RoleSessionName, roleArnSession.RoleArn, roleArnSession.Policy, roleArnSession.PolicyARNs)
174174
if err != nil {
175175
logger.Debugf("failed to create client with credentials for role %s: name:%s %v", roleArnSession.RoleArn, roleArnSession.Name, err)
176176
}

0 commit comments

Comments
 (0)