diff --git a/cmd/cloudflared/common_service.go b/cmd/cloudflared/common_service.go index db7338c0a86..c7337f606b9 100644 --- a/cmd/cloudflared/common_service.go +++ b/cmd/cloudflared/common_service.go @@ -1,10 +1,13 @@ package main import ( + "strconv" + "github.com/rs/zerolog" "github.com/urfave/cli/v2" "github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil" + cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags" "github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel" ) @@ -28,3 +31,81 @@ func getServiceExtraArgsFromCliArgs(c *cli.Context, log *zerolog.Logger) ([]stri return make([]string, 0), nil } } + +type serviceFlagSerializer struct { + name string + serialize func(*cli.Context) (string, bool) +} + +var serviceFlagSerializers = []serviceFlagSerializer{ + {name: cfdflags.Region, serialize: serializeStringServiceFlag(cfdflags.Region)}, + {name: cfdflags.EdgeIpVersion, serialize: serializeStringServiceFlag(cfdflags.EdgeIpVersion)}, + {name: cfdflags.EdgeBindAddress, serialize: serializeStringServiceFlag(cfdflags.EdgeBindAddress)}, + {name: cfdflags.Protocol, serialize: serializeStringServiceFlag(cfdflags.Protocol)}, + {name: cfdflags.Retries, serialize: serializeIntServiceFlag(cfdflags.Retries)}, + {name: cfdflags.LogLevel, serialize: serializeStringServiceFlag(cfdflags.LogLevel)}, + {name: cfdflags.TransportLogLevel, serialize: serializeStringServiceFlag(cfdflags.TransportLogLevel)}, + {name: cfdflags.LogFile, serialize: serializeStringServiceFlag(cfdflags.LogFile)}, + {name: cfdflags.LogDirectory, serialize: serializeStringServiceFlag(cfdflags.LogDirectory)}, + {name: cfdflags.TraceOutput, serialize: serializeStringServiceFlag(cfdflags.TraceOutput)}, + {name: cfdflags.Metrics, serialize: serializeStringServiceFlag(cfdflags.Metrics)}, + {name: cfdflags.MetricsUpdateFreq, serialize: serializeDurationServiceFlag(cfdflags.MetricsUpdateFreq)}, + {name: cfdflags.GracePeriod, serialize: serializeDurationServiceFlag(cfdflags.GracePeriod)}, + {name: cfdflags.MaxActiveFlows, serialize: serializeIntServiceFlag(cfdflags.MaxActiveFlows)}, + {name: cfdflags.PostQuantum, serialize: serializeBoolServiceFlag(cfdflags.PostQuantum)}, +} + +func buildServiceFlagArgs(c *cli.Context) []string { + args := make([]string, 0, len(serviceFlagSerializers)) + for _, flag := range serviceFlagSerializers { + if arg, ok := flag.serialize(c); ok { + args = append(args, arg) + } + } + return args +} + +func buildServiceRunArgs(c *cli.Context, runArgs []string) []string { + args := buildServiceFlagArgs(c) + return append(args, runArgs...) +} + +func serializeStringServiceFlag(name string) func(*cli.Context) (string, bool) { + return func(c *cli.Context) (string, bool) { + if !c.IsSet(name) { + return "", false + } + value := c.String(name) + if value == "" { + return "", false + } + return "--" + name + "=" + value, true + } +} + +func serializeIntServiceFlag(name string) func(*cli.Context) (string, bool) { + return func(c *cli.Context) (string, bool) { + if !c.IsSet(name) { + return "", false + } + return "--" + name + "=" + strconv.Itoa(c.Int(name)), true + } +} + +func serializeDurationServiceFlag(name string) func(*cli.Context) (string, bool) { + return func(c *cli.Context) (string, bool) { + if !c.IsSet(name) { + return "", false + } + return "--" + name + "=" + c.Duration(name).String(), true + } +} + +func serializeBoolServiceFlag(name string) func(*cli.Context) (string, bool) { + return func(c *cli.Context) (string, bool) { + if !c.IsSet(name) { + return "", false + } + return "--" + name + "=" + strconv.FormatBool(c.Bool(name)), true + } +} diff --git a/cmd/cloudflared/common_service_test.go b/cmd/cloudflared/common_service_test.go new file mode 100644 index 00000000000..220679bda2d --- /dev/null +++ b/cmd/cloudflared/common_service_test.go @@ -0,0 +1,72 @@ +package main + +import ( + "flag" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/urfave/cli/v2" + + cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags" +) + +func TestBuildServiceFlagArgs(t *testing.T) { + cliCtx := newServiceTestContext(t) + + require.NoError(t, cliCtx.Set(cfdflags.Region, "us")) + require.NoError(t, cliCtx.Set(cfdflags.EdgeIpVersion, "6")) + require.NoError(t, cliCtx.Set(cfdflags.Retries, "3")) + require.NoError(t, cliCtx.Set(cfdflags.GracePeriod, "10s")) + require.NoError(t, cliCtx.Set(cfdflags.PostQuantum, "true")) + + require.Equal(t, []string{ + "--region=us", + "--edge-ip-version=6", + "--retries=3", + "--grace-period=10s", + "--post-quantum=true", + }, buildServiceFlagArgs(cliCtx)) +} + +func TestBuildServiceRunArgsAppendsTunnelCommand(t *testing.T) { + cliCtx := newServiceTestContext(t) + + require.NoError(t, cliCtx.Set(cfdflags.EdgeIpVersion, "6")) + require.NoError(t, cliCtx.Set(cfdflags.GracePeriod, "15s")) + require.NoError(t, cliCtx.Set(cfdflags.NoAutoUpdate, "true")) + require.NoError(t, cliCtx.Set(cfdflags.AutoUpdateFreq, "24h")) + + got := buildServiceRunArgs(cliCtx, []string{"--config", "/etc/cloudflared/config.yml", "tunnel", "run"}) + + require.Equal(t, []string{ + "--edge-ip-version=6", + "--grace-period=15s", + "--config", "/etc/cloudflared/config.yml", "tunnel", "run", + }, got) +} + +func newServiceTestContext(t *testing.T) *cli.Context { + t.Helper() + + flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError) + flagSet.String(cfdflags.Region, "", "") + flagSet.String(cfdflags.EdgeIpVersion, "", "") + flagSet.String(cfdflags.EdgeBindAddress, "", "") + flagSet.String(cfdflags.Protocol, "", "") + flagSet.Int(cfdflags.Retries, 0, "") + flagSet.String(cfdflags.LogLevel, "", "") + flagSet.String(cfdflags.TransportLogLevel, "", "") + flagSet.String(cfdflags.LogFile, "", "") + flagSet.String(cfdflags.LogDirectory, "", "") + flagSet.String(cfdflags.TraceOutput, "", "") + flagSet.String(cfdflags.Metrics, "", "") + flagSet.Duration(cfdflags.MetricsUpdateFreq, 0, "") + flagSet.Duration(cfdflags.GracePeriod, 0, "") + flagSet.Int(cfdflags.MaxActiveFlows, 0, "") + flagSet.Bool(cfdflags.PostQuantum, false, "") + flagSet.Bool(cfdflags.NoAutoUpdate, false, "") + flagSet.Duration(cfdflags.AutoUpdateFreq, 24*time.Hour, "") + + return cli.NewContext(cli.NewApp(), flagSet, nil) +} diff --git a/cmd/cloudflared/linux_service.go b/cmd/cloudflared/linux_service.go index c544465aa66..3c9d34205a5 100644 --- a/cmd/cloudflared/linux_service.go +++ b/cmd/cloudflared/linux_service.go @@ -225,7 +225,7 @@ func installLinuxService(c *cli.Context) error { return err } - templateArgs.ExtraArgs = extraArgs + templateArgs.ExtraArgs = buildServiceRunArgs(c, extraArgs) switch { case isSystemd():