Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 72 additions & 5 deletions app.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package main

import (
"errors"
"flag"
"fmt"
"io"
"os"
"runtime/debug"
"time"

"github.com/gotify/server/v2/config"
"github.com/gotify/server/v2/config/migrate"
"github.com/gotify/server/v2/database"
"github.com/gotify/server/v2/mode"
"github.com/gotify/server/v2/model"
Expand All @@ -27,7 +33,53 @@ var (
)

func main() {
os.Exit(run(os.Args[1:], os.Stdout, os.Stderr))
}

func run(args []string, stdout, stderr io.Writer) int {
vInfo := &model.VersionInfo{Version: Version, Commit: Commit, BuildDate: BuildDate}
fs := flag.NewFlagSet("gotify", flag.ContinueOnError)
fs.SetOutput(stderr)
fs.Usage = func() { printUsage(stderr) }
if err := fs.Parse(args); err != nil {
if errors.Is(err, flag.ErrHelp) {
return 0
}
return 2
}

command := fs.Arg(0)
switch command {
case "serve", "":
return serve(vInfo)
case "version":
fmt.Fprintln(stdout, "Version:", vInfo.Version)
fmt.Fprintln(stdout, "Commit:", vInfo.Commit)
fmt.Fprintln(stdout, "Build Date:", vInfo.BuildDate)
fmt.Fprintln(stdout, "Go Build Info:")
b, ok := debug.ReadBuildInfo()
if ok {
fmt.Fprintln(stdout, b)
}
return 0
case "migrate-config":
content, err := migrate.Config(fs.Arg(1))
if err != nil {
fmt.Fprintln(stderr, err)
return 1
}
fmt.Fprintln(stdout, content)
return 0
default:
if command != "" {
fmt.Fprintf(stderr, "gotify: unknown command %q\n\n", command)
}
printUsage(stderr)
return 2
}
}

func serve(vInfo *model.VersionInfo) int {
mode.Set(Mode)

conf, futureLogs := config.Get()
Expand All @@ -40,21 +92,24 @@ func main() {
exit = exit || futureLog.Level == zerolog.FatalLevel || futureLog.Level == zerolog.PanicLevel
}
if exit {
os.Exit(1)
return 1
}

if conf.PluginsDir != "" {
if err := os.MkdirAll(conf.PluginsDir, 0o755); err != nil {
panic(err)
log.Error().Err(err).Str("dir", conf.PluginsDir).Msg("Cannot create plugins directory")
return 1
}
}
if err := os.MkdirAll(conf.UploadedImagesDir, 0o755); err != nil {
panic(err)
log.Error().Err(err).Str("dir", conf.UploadedImagesDir).Msg("Cannot create uploaded images directory")
return 1
}

db, err := database.New(conf.Database.Dialect, conf.Database.Connection, conf.DefaultUser.Name, conf.DefaultUser.Pass, conf.PassStrength, true, time.Now)
if err != nil {
panic(err)
log.Error().Err(err).Msg("Cannot initialize database")
return 1
}
defer db.Close()

Expand All @@ -63,8 +118,20 @@ func main() {

if err := runner.Run(engine, conf); err != nil {
log.Error().Err(err).Msg("Server error")
os.Exit(1)
return 1
}
return 0
}

func printUsage(w io.Writer) {
fmt.Fprint(w, `Usage: gotify [flags] <command> [arguments]

Commands:
serve Start the Gotify server.
migrate-config <file.yml> Convert an old YAML config file to the new env
format and print it to stdout.
version Show version information
`)
}

func noColor(noColorEnv string) bool {
Expand Down
31 changes: 31 additions & 0 deletions app_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package main

import (
"bytes"
"testing"

"github.com/stretchr/testify/assert"
)

func TestRun(t *testing.T) {
cases := []struct {
name string
args []string
wantCode int
stdout string // substring expected on stdout
stderr string // substring expected on stderr
}{
{"version", []string{"version"}, 0, "Version: ", ""},
{"unknown command", []string{"bogus"}, 2, "", "unknown command"},
{"unknown flag", []string{"--nope"}, 2, "", "not defined"},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
var stdout, stderr bytes.Buffer
code := run(c.args, &stdout, &stderr)
assert.Equal(t, c.wantCode, code)
assert.Contains(t, stdout.String(), c.stdout)
assert.Contains(t, stderr.String(), c.stderr)
})
}
}
185 changes: 185 additions & 0 deletions config/migrate/migrate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package migrate

import (
"encoding/csv"
"encoding/json"
"errors"
"fmt"
"os"
"strconv"
"strings"

"github.com/gotify/server/v2/config"
"github.com/joho/godotenv"
"gopkg.in/yaml.v3"
)

type oldConfig struct {
Server struct {
KeepAlivePeriodSeconds *int
ListenAddr *string
Port *int
SSL struct {
Enabled *bool
RedirectToHTTPS *bool
ListenAddr *string
Port *int
CertFile *string
CertKey *string
LetsEncrypt struct {
Enabled *bool
AcceptTOS *bool
Cache *string
DirectoryURL *string
Hosts []string
}
}
ResponseHeaders map[string]string
Stream struct {
PingPeriodSeconds *int
AllowedOrigins []string
}
Cors struct {
AllowOrigins []string
AllowMethods []string
AllowHeaders []string
}
TrustedProxies []string
SecureCookie *bool
}
Database struct {
Dialect *string
Connection *string
}
DefaultUser struct {
Name *string
Pass *string
}
PassStrength *int
UploadedImagesDir *string
PluginsDir *string
Registration *bool
OIDC struct {
Enabled *bool
Issuer *string
ClientID *string
ClientSecret *string
UsernameClaim *string
RedirectURL *string
AutoRegister *bool
Scopes []string
}
}

func Config(file string) (string, error) {
if file == "" {
return "", errors.New("migrate-config requires one argument: the path to the old config.yml")
}
data, err := os.ReadFile(file)
if err != nil {
return "", fmt.Errorf("cannot read config file %s: %w", file, err)
}

var migrated oldConfig
if err := yaml.Unmarshal(data, &migrated); err != nil {
return "", fmt.Errorf("cannot parse config file %s: %w", file, err)
}

content, err := godotenv.Marshal(buildEnv(migrated))
if err != nil {
return "", fmt.Errorf("cannot render config: %w", err)
}

return content, nil
}

func buildEnv(c oldConfig) map[string]string {
out := map[string]string{}
str := func(key string, value *string) {
if value != nil {
out[key] = *value
}
}
num := func(key string, value *int) {
if value != nil {
out[key] = strconv.Itoa(*value)
}
}
boolean := func(key string, value *bool) {
if value != nil {
out[key] = strconv.FormatBool(*value)
}
}
list := func(key string, value []string) {
if value != nil {
out[key] = marshalList(value)
}
}
headers := func(key string, value map[string]string) {
if value != nil {
out[key] = marshalMap(value)
}
}

num(config.EnvServerKeepAlivePeriodSeconds, c.Server.KeepAlivePeriodSeconds)
str(config.EnvServerListenAddr, c.Server.ListenAddr)
num(config.EnvServerPort, c.Server.Port)
boolean(config.EnvServerSSLEnabled, c.Server.SSL.Enabled)
boolean(config.EnvServerSSLRedirectToHTTPS, c.Server.SSL.RedirectToHTTPS)
str(config.EnvServerSSLListenAddr, c.Server.SSL.ListenAddr)
num(config.EnvServerSSLPort, c.Server.SSL.Port)
str(config.EnvServerSSLCertFile, c.Server.SSL.CertFile)
str(config.EnvServerSSLCertKey, c.Server.SSL.CertKey)
boolean(config.EnvServerSSLLetsEncryptEnabled, c.Server.SSL.LetsEncrypt.Enabled)
boolean(config.EnvServerSSLLetsEncryptAcceptTOS, c.Server.SSL.LetsEncrypt.AcceptTOS)
str(config.EnvServerSSLLetsEncryptCache, c.Server.SSL.LetsEncrypt.Cache)
str(config.EnvServerSSLLetsEncryptDirectoryURL, c.Server.SSL.LetsEncrypt.DirectoryURL)
list(config.EnvServerSSLLetsEncryptHosts, c.Server.SSL.LetsEncrypt.Hosts)
headers(config.EnvServerResponseHeaders, c.Server.ResponseHeaders)
num(config.EnvServerStreamPingPeriodSeconds, c.Server.Stream.PingPeriodSeconds)
list(config.EnvServerStreamAllowedOrigins, c.Server.Stream.AllowedOrigins)
list(config.EnvServerCorsAllowOrigins, c.Server.Cors.AllowOrigins)
list(config.EnvServerCorsAllowMethods, c.Server.Cors.AllowMethods)
list(config.EnvServerCorsAllowHeaders, c.Server.Cors.AllowHeaders)
list(config.EnvServerTrustedProxies, c.Server.TrustedProxies)
boolean(config.EnvServerSecureCookie, c.Server.SecureCookie)
str(config.EnvDatabaseDialect, c.Database.Dialect)
str(config.EnvDatabaseConnection, c.Database.Connection)
str(config.EnvDefaultUserName, c.DefaultUser.Name)
str(config.EnvDefaultUserPass, c.DefaultUser.Pass)
num(config.EnvPassStrength, c.PassStrength)
str(config.EnvUploadedImagesDir, c.UploadedImagesDir)
str(config.EnvPluginsDir, c.PluginsDir)
boolean(config.EnvRegistration, c.Registration)
boolean(config.EnvOIDCEnabled, c.OIDC.Enabled)
str(config.EnvOIDCIssuer, c.OIDC.Issuer)
str(config.EnvOIDCClientID, c.OIDC.ClientID)
str(config.EnvOIDCClientSecret, c.OIDC.ClientSecret)
str(config.EnvOIDCUsernameClaim, c.OIDC.UsernameClaim)
str(config.EnvOIDCRedirectURL, c.OIDC.RedirectURL)
boolean(config.EnvOIDCAutoRegister, c.OIDC.AutoRegister)
list(config.EnvOIDCScopes, c.OIDC.Scopes)
return out
}

func marshalMap(m map[string]string) string {
if len(m) == 0 {
return ""
}
data, err := json.Marshal(m)
if err != nil {
return ""
}
return string(data)
}

func marshalList(values []string) string {
Comment thread
jmattheis marked this conversation as resolved.
var sb strings.Builder
writer := csv.NewWriter(&sb)
writer.UseCRLF = false
if err := writer.Write(values); err != nil {
return ""
}
writer.Flush()
return strings.TrimRight(sb.String(), "\n")
}
Loading
Loading