Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions cmd/cloudflared/common_service.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package main

import (
"strconv"

"github.com/rs/zerolog"
"github.com/urfave/cli/v2"

"github.com/cloudflare/cloudflared/cmd/cloudflared/cliutil"
cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
"github.com/cloudflare/cloudflared/cmd/cloudflared/tunnel"
)

Expand All @@ -28,3 +31,81 @@ func getServiceExtraArgsFromCliArgs(c *cli.Context, log *zerolog.Logger) ([]stri
return make([]string, 0), nil
}
}

type serviceFlagSerializer struct {
name string
serialize func(*cli.Context) (string, bool)
}

var serviceFlagSerializers = []serviceFlagSerializer{
{name: cfdflags.Region, serialize: serializeStringServiceFlag(cfdflags.Region)},
{name: cfdflags.EdgeIpVersion, serialize: serializeStringServiceFlag(cfdflags.EdgeIpVersion)},
{name: cfdflags.EdgeBindAddress, serialize: serializeStringServiceFlag(cfdflags.EdgeBindAddress)},
{name: cfdflags.Protocol, serialize: serializeStringServiceFlag(cfdflags.Protocol)},
{name: cfdflags.Retries, serialize: serializeIntServiceFlag(cfdflags.Retries)},
{name: cfdflags.LogLevel, serialize: serializeStringServiceFlag(cfdflags.LogLevel)},
{name: cfdflags.TransportLogLevel, serialize: serializeStringServiceFlag(cfdflags.TransportLogLevel)},
{name: cfdflags.LogFile, serialize: serializeStringServiceFlag(cfdflags.LogFile)},
{name: cfdflags.LogDirectory, serialize: serializeStringServiceFlag(cfdflags.LogDirectory)},
{name: cfdflags.TraceOutput, serialize: serializeStringServiceFlag(cfdflags.TraceOutput)},
{name: cfdflags.Metrics, serialize: serializeStringServiceFlag(cfdflags.Metrics)},
{name: cfdflags.MetricsUpdateFreq, serialize: serializeDurationServiceFlag(cfdflags.MetricsUpdateFreq)},
{name: cfdflags.GracePeriod, serialize: serializeDurationServiceFlag(cfdflags.GracePeriod)},
{name: cfdflags.MaxActiveFlows, serialize: serializeIntServiceFlag(cfdflags.MaxActiveFlows)},
{name: cfdflags.PostQuantum, serialize: serializeBoolServiceFlag(cfdflags.PostQuantum)},
}

func buildServiceFlagArgs(c *cli.Context) []string {
args := make([]string, 0, len(serviceFlagSerializers))
for _, flag := range serviceFlagSerializers {
if arg, ok := flag.serialize(c); ok {
args = append(args, arg)
}
}
return args
}

func buildServiceRunArgs(c *cli.Context, runArgs []string) []string {
args := buildServiceFlagArgs(c)
return append(args, runArgs...)
}

func serializeStringServiceFlag(name string) func(*cli.Context) (string, bool) {
return func(c *cli.Context) (string, bool) {
if !c.IsSet(name) {
return "", false
}
value := c.String(name)
if value == "" {
return "", false
}
return "--" + name + "=" + value, true
}
}

func serializeIntServiceFlag(name string) func(*cli.Context) (string, bool) {
return func(c *cli.Context) (string, bool) {
if !c.IsSet(name) {
return "", false
}
return "--" + name + "=" + strconv.Itoa(c.Int(name)), true
}
}

func serializeDurationServiceFlag(name string) func(*cli.Context) (string, bool) {
return func(c *cli.Context) (string, bool) {
if !c.IsSet(name) {
return "", false
}
return "--" + name + "=" + c.Duration(name).String(), true
}
}

func serializeBoolServiceFlag(name string) func(*cli.Context) (string, bool) {
return func(c *cli.Context) (string, bool) {
if !c.IsSet(name) {
return "", false
}
return "--" + name + "=" + strconv.FormatBool(c.Bool(name)), true
}
}
72 changes: 72 additions & 0 deletions cmd/cloudflared/common_service_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package main

import (
"flag"
"testing"
"time"

"github.com/stretchr/testify/require"
"github.com/urfave/cli/v2"

cfdflags "github.com/cloudflare/cloudflared/cmd/cloudflared/flags"
)

func TestBuildServiceFlagArgs(t *testing.T) {
cliCtx := newServiceTestContext(t)

require.NoError(t, cliCtx.Set(cfdflags.Region, "us"))
require.NoError(t, cliCtx.Set(cfdflags.EdgeIpVersion, "6"))
require.NoError(t, cliCtx.Set(cfdflags.Retries, "3"))
require.NoError(t, cliCtx.Set(cfdflags.GracePeriod, "10s"))
require.NoError(t, cliCtx.Set(cfdflags.PostQuantum, "true"))

require.Equal(t, []string{
"--region=us",
"--edge-ip-version=6",
"--retries=3",
"--grace-period=10s",
"--post-quantum=true",
}, buildServiceFlagArgs(cliCtx))
}

func TestBuildServiceRunArgsAppendsTunnelCommand(t *testing.T) {
cliCtx := newServiceTestContext(t)

require.NoError(t, cliCtx.Set(cfdflags.EdgeIpVersion, "6"))
require.NoError(t, cliCtx.Set(cfdflags.GracePeriod, "15s"))
require.NoError(t, cliCtx.Set(cfdflags.NoAutoUpdate, "true"))
require.NoError(t, cliCtx.Set(cfdflags.AutoUpdateFreq, "24h"))

got := buildServiceRunArgs(cliCtx, []string{"--config", "/etc/cloudflared/config.yml", "tunnel", "run"})

require.Equal(t, []string{
"--edge-ip-version=6",
"--grace-period=15s",
"--config", "/etc/cloudflared/config.yml", "tunnel", "run",
}, got)
}

func newServiceTestContext(t *testing.T) *cli.Context {
t.Helper()

flagSet := flag.NewFlagSet(t.Name(), flag.PanicOnError)
flagSet.String(cfdflags.Region, "", "")
flagSet.String(cfdflags.EdgeIpVersion, "", "")
flagSet.String(cfdflags.EdgeBindAddress, "", "")
flagSet.String(cfdflags.Protocol, "", "")
flagSet.Int(cfdflags.Retries, 0, "")
flagSet.String(cfdflags.LogLevel, "", "")
flagSet.String(cfdflags.TransportLogLevel, "", "")
flagSet.String(cfdflags.LogFile, "", "")
flagSet.String(cfdflags.LogDirectory, "", "")
flagSet.String(cfdflags.TraceOutput, "", "")
flagSet.String(cfdflags.Metrics, "", "")
flagSet.Duration(cfdflags.MetricsUpdateFreq, 0, "")
flagSet.Duration(cfdflags.GracePeriod, 0, "")
flagSet.Int(cfdflags.MaxActiveFlows, 0, "")
flagSet.Bool(cfdflags.PostQuantum, false, "")
flagSet.Bool(cfdflags.NoAutoUpdate, false, "")
flagSet.Duration(cfdflags.AutoUpdateFreq, 24*time.Hour, "")

return cli.NewContext(cli.NewApp(), flagSet, nil)
}
2 changes: 1 addition & 1 deletion cmd/cloudflared/linux_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func installLinuxService(c *cli.Context) error {
return err
}

templateArgs.ExtraArgs = extraArgs
templateArgs.ExtraArgs = buildServiceRunArgs(c, extraArgs)

switch {
case isSystemd():
Expand Down