Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions engine/internal/rdsrefresh/dblab.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ type SourceConfigUpdate struct {
Password string
// RDSIAMDBInstance is the RDS DB instance identifier for IAM auth. When empty, this field is omitted from the config update.
RDSIAMDBInstance string
// DumpParallelJobs sets the -j flag for pg_dump. When zero, the existing value is preserved.
DumpParallelJobs int
// RestoreParallelJobs sets the -j flag for pg_restore. When zero, the existing value is preserved.
RestoreParallelJobs int
}

// UpdateSourceConfig updates the source database connection in DBLab config.
Expand All @@ -198,6 +202,16 @@ func (c *DBLabClient) UpdateSourceConfig(ctx context.Context, update SourceConfi
proj.RDSIAMDBInstance = &update.RDSIAMDBInstance
}

if update.DumpParallelJobs > 0 {
dumpJobs := int64(update.DumpParallelJobs)
proj.DumpParallelJobs = &dumpJobs
}

if update.RestoreParallelJobs > 0 {
restoreJobs := int64(update.RestoreParallelJobs)
proj.RestoreParallelJobs = &restoreJobs
}

nested := map[string]interface{}{}

// defensive error check: StoreJSON only fails if target is not an addressable struct,
Expand Down
64 changes: 64 additions & 0 deletions engine/internal/rdsrefresh/dblab_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,70 @@ func TestDBLabClientUpdateSourceConfig(t *testing.T) {
assert.Nil(t, receivedConfig.RDSIAMDBInstance)
})

t.Run("successful with parallelism settings", func(t *testing.T) {
var receivedConfig models.ConfigProjection

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var nested map[string]interface{}
err := json.NewDecoder(r.Body).Decode(&nested)
require.NoError(t, err)

err = projection.LoadJSON(&receivedConfig, nested, projection.LoadOptions{
Groups: []string{"default", "sensitive"},
})
require.NoError(t, err)

w.WriteHeader(http.StatusOK)
}))
defer server.Close()

client, err := NewDBLabClient(&DBLabConfig{APIEndpoint: server.URL, Token: "test-token"})
require.NoError(t, err)

err = client.UpdateSourceConfig(context.Background(), SourceConfigUpdate{
Host: "clone-host.rds.amazonaws.com", Port: 5432, DBName: "postgres",
Username: "dbuser", Password: "dbpass",
DumpParallelJobs: 4, RestoreParallelJobs: 8,
})
require.NoError(t, err)

require.NotNil(t, receivedConfig.DumpParallelJobs)
assert.Equal(t, int64(4), *receivedConfig.DumpParallelJobs)
require.NotNil(t, receivedConfig.RestoreParallelJobs)
assert.Equal(t, int64(8), *receivedConfig.RestoreParallelJobs)
})

t.Run("omits parallelism when zero", func(t *testing.T) {
var receivedConfig models.ConfigProjection

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var nested map[string]interface{}
err := json.NewDecoder(r.Body).Decode(&nested)
require.NoError(t, err)

err = projection.LoadJSON(&receivedConfig, nested, projection.LoadOptions{
Groups: []string{"default", "sensitive"},
})
require.NoError(t, err)

w.WriteHeader(http.StatusOK)
}))
defer server.Close()

client, err := NewDBLabClient(&DBLabConfig{APIEndpoint: server.URL, Token: "test-token"})
require.NoError(t, err)

err = client.UpdateSourceConfig(context.Background(), SourceConfigUpdate{
Host: "host.rds.amazonaws.com", Port: 5432, DBName: "postgres",
Username: "dbuser", Password: "dbpass",
DumpParallelJobs: 0, RestoreParallelJobs: 0,
})
require.NoError(t, err)

assert.Nil(t, receivedConfig.DumpParallelJobs)
assert.Nil(t, receivedConfig.RestoreParallelJobs)
})

t.Run("error on non-2xx status", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
Expand Down
148 changes: 148 additions & 0 deletions engine/internal/rdsrefresh/parallelism.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
/*
2026 © PostgresAI
*/

package rdsrefresh

import (
"fmt"
"runtime"
"strconv"
"strings"

"gitlab.com/postgres-ai/database-lab/v3/pkg/log"
)

const (
// rdsInstanceClassPrefix is stripped to derive the instance size.
rdsInstanceClassPrefix = "db."

// minParallelJobs is the minimum parallelism level.
minParallelJobs = 1
)

// instanceSizeVCPUs maps AWS instance size suffixes to their typical vCPU count.
// this mapping is consistent across most instance families (m5, m6g, r5, r6g, c5, etc.).
// graviton and intel/amd variants of the same size have the same vCPU count.
var instanceSizeVCPUs = map[string]int{
"micro": 1,
"small": 1,
"medium": 2,
"large": 2,
"xlarge": 4,
"2xlarge": 8,
"3xlarge": 12,
"4xlarge": 16,
"6xlarge": 24,
"8xlarge": 32,
"9xlarge": 36,
"10xlarge": 40,
"12xlarge": 48,
"16xlarge": 64,
"18xlarge": 72,
"24xlarge": 96,
"32xlarge": 128,
"48xlarge": 192,
"metal": 96,
}

// ParallelismConfig holds the computed parallelism levels for dump and restore.
type ParallelismConfig struct {
DumpJobs int
RestoreJobs int
}

// ResolveParallelism determines the optimal parallelism levels for pg_dump and pg_restore.
// dump parallelism is based on the vCPU count of the RDS clone instance class.
// restore parallelism is based on the vCPU count of the local machine.
// local vCPU detection uses runtime.NumCPU(), which works on Linux
// (the target platform for DBLab Engine).
func ResolveParallelism(cfg *Config) (*ParallelismConfig, error) {
dumpJobs, err := resolveRDSInstanceVCPUs(cfg.RDSClone.InstanceClass)
if err != nil {
return nil, fmt.Errorf("failed to resolve RDS instance vCPUs: %w", err)
}

restoreJobs := resolveLocalVCPUs()

log.Msg("auto-parallelism: dump jobs =", dumpJobs, "(RDS clone vCPUs), restore jobs =", restoreJobs, "(local vCPUs)")

return &ParallelismConfig{
DumpJobs: dumpJobs,
RestoreJobs: restoreJobs,
}, nil
}

// resolveRDSInstanceVCPUs estimates the vCPU count for the given RDS instance class
// by parsing the instance size suffix (e.g. "xlarge" from "db.m5.xlarge").
// the mapping covers standard AWS size naming used across RDS instance families.
// if the size is not recognized, it attempts to parse a numeric multiplier prefix
// (e.g. "2xlarge" → 8 vCPUs).
func resolveRDSInstanceVCPUs(instanceClass string) (int, error) {
size, err := extractInstanceSize(instanceClass)
if err != nil {
return 0, err
}

if vcpus, ok := instanceSizeVCPUs[size]; ok {
return vcpus, nil
}

// handle unlisted NUMxlarge sizes by parsing the multiplier
vcpus, err := parseXlargeMultiplier(size)
if err != nil {
return 0, fmt.Errorf("unknown instance size %q in class %q", size, instanceClass)
}

return vcpus, nil
}

// extractInstanceSize extracts the size component from an RDS instance class.
// for example, "db.m5.xlarge" → "xlarge", "db.r6g.2xlarge" → "2xlarge".
func extractInstanceSize(instanceClass string) (string, error) {
if !strings.HasPrefix(instanceClass, rdsInstanceClassPrefix) {
return "", fmt.Errorf("invalid RDS instance class %q: expected %q prefix", instanceClass, rdsInstanceClassPrefix)
}

withoutPrefix := strings.TrimPrefix(instanceClass, rdsInstanceClassPrefix)

// format is "family.size", e.g. "m5.xlarge" or "r6g.2xlarge"
parts := strings.SplitN(withoutPrefix, ".", 2)

const expectedParts = 2
if len(parts) != expectedParts || parts[1] == "" {
return "", fmt.Errorf("invalid RDS instance class %q: expected format db.<family>.<size>", instanceClass)
}

return parts[1], nil
}

// parseXlargeMultiplier handles NUMxlarge patterns not in the static map.
// for example, "5xlarge" → 5 * 4 = 20 vCPUs.
func parseXlargeMultiplier(size string) (int, error) {
idx := strings.Index(size, "xlarge")
if idx <= 0 {
return 0, fmt.Errorf("not an xlarge variant: %q", size)
}

multiplier, err := strconv.Atoi(size[:idx])
if err != nil {
return 0, fmt.Errorf("invalid multiplier in %q: %w", size, err)
}

const vcpusPerXlarge = 4

return multiplier * vcpusPerXlarge, nil
}

// resolveLocalVCPUs returns the number of logical CPUs available on the local machine.
// uses runtime.NumCPU() which reads from /proc/cpuinfo on Linux
// (the target platform for DBLab Engine).
func resolveLocalVCPUs() int {
cpus := runtime.NumCPU()
if cpus < minParallelJobs {
return minParallelJobs
}

return cpus
}
140 changes: 140 additions & 0 deletions engine/internal/rdsrefresh/parallelism_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
2026 © PostgresAI
*/

package rdsrefresh

import (
"runtime"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestExtractInstanceSize(t *testing.T) {
testCases := []struct {
instanceClass string
expectedSize string
expectErr bool
}{
{instanceClass: "db.m5.xlarge", expectedSize: "xlarge"},
{instanceClass: "db.t3.medium", expectedSize: "medium"},
{instanceClass: "db.r6g.2xlarge", expectedSize: "2xlarge"},
{instanceClass: "db.m5.metal", expectedSize: "metal"},
{instanceClass: "db.t3.micro", expectedSize: "micro"},
{instanceClass: "db.r6g.16xlarge", expectedSize: "16xlarge"},
{instanceClass: "m5.xlarge", expectErr: true},
{instanceClass: "db.m5", expectErr: true},
{instanceClass: "db.", expectErr: true},
{instanceClass: "", expectErr: true},
}

for _, tc := range testCases {
t.Run(tc.instanceClass, func(t *testing.T) {
size, err := extractInstanceSize(tc.instanceClass)

if tc.expectErr {
require.Error(t, err)
return
}

require.NoError(t, err)
assert.Equal(t, tc.expectedSize, size)
})
}
}

func TestResolveRDSInstanceVCPUs(t *testing.T) {
testCases := []struct {
instanceClass string
expectedVCPUs int
expectErr bool
}{
{instanceClass: "db.t3.micro", expectedVCPUs: 1},
{instanceClass: "db.t3.small", expectedVCPUs: 1},
{instanceClass: "db.t3.medium", expectedVCPUs: 2},
{instanceClass: "db.m5.large", expectedVCPUs: 2},
{instanceClass: "db.m5.xlarge", expectedVCPUs: 4},
{instanceClass: "db.r6g.2xlarge", expectedVCPUs: 8},
{instanceClass: "db.r6g.4xlarge", expectedVCPUs: 16},
{instanceClass: "db.r6g.8xlarge", expectedVCPUs: 32},
{instanceClass: "db.r6g.16xlarge", expectedVCPUs: 64},
{instanceClass: "db.m5.24xlarge", expectedVCPUs: 96},
{instanceClass: "db.m5.metal", expectedVCPUs: 96},
{instanceClass: "db.m5.5xlarge", expectedVCPUs: 20},
{instanceClass: "invalid", expectErr: true},
{instanceClass: "db.m5", expectErr: true},
{instanceClass: "db.m5.unknown", expectErr: true},
}

for _, tc := range testCases {
t.Run(tc.instanceClass, func(t *testing.T) {
vcpus, err := resolveRDSInstanceVCPUs(tc.instanceClass)

if tc.expectErr {
require.Error(t, err)
return
}

require.NoError(t, err)
assert.Equal(t, tc.expectedVCPUs, vcpus)
})
}
}

func TestParseXlargeMultiplier(t *testing.T) {
testCases := []struct {
size string
expectedVCPUs int
expectErr bool
}{
{size: "2xlarge", expectedVCPUs: 8},
{size: "4xlarge", expectedVCPUs: 16},
{size: "5xlarge", expectedVCPUs: 20},
{size: "xlarge", expectErr: true},
{size: "large", expectErr: true},
{size: "abcxlarge", expectErr: true},
}

for _, tc := range testCases {
t.Run(tc.size, func(t *testing.T) {
vcpus, err := parseXlargeMultiplier(tc.size)

if tc.expectErr {
require.Error(t, err)
return
}

require.NoError(t, err)
assert.Equal(t, tc.expectedVCPUs, vcpus)
})
}
}

func TestResolveLocalVCPUs(t *testing.T) {
vcpus := resolveLocalVCPUs()

assert.Equal(t, runtime.NumCPU(), vcpus)
assert.GreaterOrEqual(t, vcpus, minParallelJobs)
}

func TestResolveParallelism(t *testing.T) {
t.Run("resolves both dump and restore jobs", func(t *testing.T) {
cfg := &Config{RDSClone: RDSCloneConfig{InstanceClass: "db.m5.xlarge"}}

result, err := ResolveParallelism(cfg)

require.NoError(t, err)
assert.Equal(t, 4, result.DumpJobs)
assert.Equal(t, runtime.NumCPU(), result.RestoreJobs)
})

t.Run("returns error for invalid instance class", func(t *testing.T) {
cfg := &Config{RDSClone: RDSCloneConfig{InstanceClass: "invalid"}}

_, err := ResolveParallelism(cfg)

require.Error(t, err)
})
}
Loading
Loading