Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 56 additions & 3 deletions services/account/backend.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package account

import (
"encoding/base64"
"errors"
"fmt"
"slices"
"strings"
"sync"
)

var (
errNoAlternateContact = errors.New("ResourceNotFoundException: no alternate contact found")
errNoContactInfo = errors.New("ResourceNotFoundException: no contact information set")
// errInvalidNextToken is returned when ListRegions receives an undecodable cursor.
errInvalidNextToken = errors.New("ValidationException: invalid nextToken")
)

// RegionOptStatus represents the opt-in status of an AWS region.
Expand Down Expand Up @@ -146,8 +150,14 @@ func (b *InMemoryBackend) DescribeAccount() (*Details, error) {
}, nil
}

// ListRegions returns regions filtered by opt-in status.
func (b *InMemoryBackend) ListRegions(statusFilter []RegionOptStatus, _ int, _ string) ([]*Region, string, error) {
// ListRegions returns regions filtered by opt-in status, honouring AWS's
// maxResults/nextToken pagination. nextToken is an opaque cursor (base64 of the
// exclusive-start RegionName); when maxResults <= 0 the full filtered list is returned.
func (b *InMemoryBackend) ListRegions(
statusFilter []RegionOptStatus,
maxResults int,
nextToken string,
) ([]*Region, string, error) {
b.mu.RLock()
defer b.mu.RUnlock()

Expand All @@ -159,7 +169,50 @@ func (b *InMemoryBackend) ListRegions(statusFilter []RegionOptStatus, _ int, _ s
}
}

return filtered, "", nil
// Order deterministically by RegionName so the name-based pagination cursor is
// stable across pages (AWS returns regions in alphabetical order).
slices.SortFunc(filtered, func(a, b *Region) int {
return strings.Compare(a.RegionName, b.RegionName)
})

// Apply the exclusive-start cursor: skip everything up to and including the
// region named by the decoded token.
if nextToken != "" {
start, decErr := decodeRegionToken(nextToken)
if decErr != nil {
return nil, "", errInvalidNextToken
}

idx := 0
for idx < len(filtered) && filtered[idx].RegionName <= start {
idx++
}

filtered = filtered[idx:]
}

if maxResults <= 0 || maxResults >= len(filtered) {
return filtered, "", nil
}

page := filtered[:maxResults]

return page, encodeRegionToken(page[len(page)-1].RegionName), nil
}

// encodeRegionToken produces an opaque pagination cursor for the given RegionName.
func encodeRegionToken(regionName string) string {
return base64.StdEncoding.EncodeToString([]byte(regionName))
}

// decodeRegionToken reverses encodeRegionToken.
func decodeRegionToken(token string) (string, error) {
raw, err := base64.StdEncoding.DecodeString(token)
if err != nil {
return "", err
}

return string(raw), nil
}

// GetAlternateContact retrieves an alternate contact by type.
Expand Down
6 changes: 6 additions & 0 deletions services/account/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"

"github.com/labstack/echo/v5"
Expand Down Expand Up @@ -203,6 +204,11 @@ func (h *Handler) handleListRegions(c *echo.Context, q interface{ Get(string) st
}

maxResults := 0
if raw := q.Get(queryMaxResults); raw != "" {
if n, convErr := strconv.Atoi(raw); convErr == nil && n > 0 {
maxResults = n
}
}
nextToken := q.Get(queryNextToken)

regions, next, err := h.Backend.ListRegions(statusFilter, maxResults, nextToken)
Expand Down
64 changes: 64 additions & 0 deletions services/account/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,70 @@ func TestHandler_ListRegions(t *testing.T) {
}
}

// TestHandler_ListRegions_Pagination verifies that ListRegions honours maxResults and
// returns an opaque nextToken, and that paging through with the token yields every
// region exactly once with no overlap.
func TestHandler_ListRegions_Pagination(t *testing.T) {
t.Parallel()

h := newTestHandler(t)

// Discover the full unpaged region set first.
allRec := doRequest(t, h, http.MethodGet, "/regions", nil)
require.Equal(t, http.StatusOK, allRec.Code)

var all struct {
NextToken string `json:"NextToken"`
Regions []account.Region `json:"Regions"`
}
require.NoError(t, json.NewDecoder(allRec.Body).Decode(&all))
require.Empty(t, all.NextToken, "unpaged listing must not return a token")
require.Greater(t, len(all.Regions), 2, "need several regions to exercise paging")

const pageSize = 2
seen := make([]string, 0, len(all.Regions))
token := ""

for {
path := "/regions?maxResults=2"
if token != "" {
path += "&nextToken=" + token
}

rec := doRequest(t, h, http.MethodGet, path, nil)
require.Equal(t, http.StatusOK, rec.Code)

var page struct {
NextToken string `json:"NextToken"`
Regions []account.Region `json:"Regions"`
}
require.NoError(t, json.NewDecoder(rec.Body).Decode(&page))
require.LessOrEqual(t, len(page.Regions), pageSize, "page must not exceed maxResults")

for _, r := range page.Regions {
seen = append(seen, r.RegionName)
}

if page.NextToken == "" {
break
}
token = page.NextToken
}

// Every region appears exactly once across the pages — no overlap, no gaps.
assert.Len(t, seen, len(all.Regions))
assert.ElementsMatch(t, regionNames(all.Regions), seen)
}

func regionNames(regions []account.Region) []string {
names := make([]string, 0, len(regions))
for _, r := range regions {
names = append(names, r.RegionName)
}

return names
}

func TestHandler_AlternateContact_PutGetDelete(t *testing.T) {
t.Parallel()

Expand Down
12 changes: 10 additions & 2 deletions services/apigatewayv2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -1257,15 +1257,23 @@ func (h *Handler) handleExportAPI(c *echo.Context, apiID, specification string)
}

func (h *Handler) handleGetModelTemplate(c *echo.Context, apiID, modelID string) error {
if _, err := h.Backend.GetModel(apiID, modelID); err != nil {
model, err := h.Backend.GetModel(apiID, modelID)
if err != nil {
if errors.Is(err, ErrAPINotFound) || errors.Is(err, ErrModelNotFound) {
return c.JSON(http.StatusNotFound, notFoundResponse{Message: msgNotFound})
}

return c.JSON(http.StatusInternalServerError, notFoundResponse{Message: err.Error()})
}

return c.JSON(http.StatusOK, map[string]string{"value": emptyModelTemplate})
// AWS returns the model's schema as the template value; fall back to an empty
// object only when the model has no schema defined.
value := model.Schema
if value == "" {
value = emptyModelTemplate
}

return c.JSON(http.StatusOK, map[string]string{"value": value})
}

func (h *Handler) handleDeleteAccessLogSettings(c *echo.Context, apiID, stageName string) error {
Expand Down
55 changes: 55 additions & 0 deletions services/apigatewayv2/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2131,6 +2131,61 @@ func TestHandler_CreateModel(t *testing.T) {
}
}

// TestHandler_GetModelTemplate verifies that GetModelTemplate returns the model's
// schema as the template value, falling back to an empty object only when the model
// has no schema defined.
func TestHandler_GetModelTemplate(t *testing.T) {
t.Parallel()

tests := []struct {
name string
schema string
wantValue string
}{
{
name: "returns_schema",
schema: `{"type":"object","properties":{"id":{"type":"string"}}}`,
wantValue: `{"type":"object","properties":{"id":{"type":"string"}}}`,
},
{
name: "empty_schema_falls_back_to_object",
schema: "",
wantValue: "{}",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

h := newTestHandler()
apiID := createAPI(t, h, "tmpl-api")

body := map[string]any{"name": "TmplModel", "contentType": "application/json"}
if tt.schema != "" {
body["schema"] = tt.schema
}

createRec := doRequest(t, h, http.MethodPost,
fmt.Sprintf("/v2/apis/%s/models", apiID), body)
require.Equal(t, http.StatusCreated, createRec.Code)

var model apigatewayv2.Model
require.NoError(t, json.Unmarshal(createRec.Body.Bytes(), &model))

rec := doRequest(t, h, http.MethodGet,
fmt.Sprintf("/v2/apis/%s/models/%s/template", apiID, model.ModelID), nil)
require.Equal(t, http.StatusOK, rec.Code)

var out struct {
Value string `json:"value"`
}
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &out))
assert.Equal(t, tt.wantValue, out.Value)
})
}
}

func TestHandler_CreateRouteResponse(t *testing.T) {
t.Parallel()

Expand Down
Loading
Loading