diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 6f6cfd5..6faa40b 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -16,10 +16,13 @@ jobs: - name: Set up Go uses: actions/setup-go@v2 with: - go-version: 1.15 + go-version: 1.24 + + - name: Test and audit + run: make tidy audit test - name: Build - run: go build -v ./... + run: make build - - name: Test - run: go test -v ./... + - name: Check no dirty + run: make no-dirty diff --git a/Makefile b/Makefile index c8a5cf0..cf2cf1c 100644 --- a/Makefile +++ b/Makefile @@ -1,18 +1,53 @@ +.PHONY: all +all: tidy build test audit + @echo "all done" + + +.PHONY: build +build: + @echo "building..." + go build -v ./... + .PHONY: test test: go test ./... + +.PHONY: tidy +tidy: + @echo "tidy and fmt..." + go mod tidy -v + go fmt ./... + + +.PHONY: audit +audit: + @echo "running audit checks..." + go mod verify + go vet ./... + go list -m all + go run honnef.co/go/tools/cmd/staticcheck@latest -checks=all,-ST1000,-U1000 ./... + go run golang.org/x/vuln/cmd/govulncheck@latest ./... + + +no-dirty: + @echo "checking for uncommitted changes..." + git diff --exit-code + git diff --cached --exit-code + + .PHONY: example example: go run ./_examples/up-and-down + .PHONY: acme-like acme-like: go run ./_examples/acme-like + .PHONY: list-records list-records: go run ./_examples/list-records - diff --git a/_examples/acme-cleanup/main.go b/_examples/acme-cleanup/main.go new file mode 100644 index 0000000..f07840d --- /dev/null +++ b/_examples/acme-cleanup/main.go @@ -0,0 +1,96 @@ +package main + +import ( + "context" + "fmt" + "os" + "os/signal" + "strings" + "sync" + "time" + + "github.com/libdns/libdns" + "github.com/libdns/loopia" +) + +func exitOnError(err error) { + if err != nil { + fmt.Fprintf(os.Stderr, "%v\n", err) + os.Exit(1) + } +} + +func main() { + user := os.Getenv("LOOPIA_USER") + password := os.Getenv("LOOPIA_PASSWORD") + zone := os.Getenv("ZONE") + if zone == "" { + fmt.Fprintf(os.Stderr, "ZONE not set\n") + os.Exit(1) + } + + if user == "" { + exitOnError(fmt.Errorf("user is not set")) + } + + if password == "" { + exitOnError(fmt.Errorf("password is not set")) + } + + fmt.Printf("zone: %s, user: %s\n", zone, user) + + var wg sync.WaitGroup + wg.Add(1) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + go show(ctx, &wg, zone, user, password) + + // Wait for SIGINT. + sig := make(chan os.Signal, 1) + signal.Notify(sig, os.Interrupt) + go func() { + <-sig + cancel() + }() + + wg.Wait() + fmt.Println("Done!") +} + +func show(ctx context.Context, wg *sync.WaitGroup, zone, user, password string) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + p := &loopia.Provider{ + Username: user, + Password: password, + } + fmt.Println("getting records") + resAll, err := p.GetRecords(ctx, zone) + exitOnError(err) + printRecords("All records", resAll) + toDelete := []libdns.Record{} + for _, r := range resAll { + rr := r.RR() + if rr.Type == "TXT" && strings.Contains(rr.Name, "_acme-challenge") { + toDelete = append(toDelete, r) + } + } + + if len(toDelete) > 0 { + fmt.Println("deleting records") + res, err := p.DeleteRecords(ctx, zone, toDelete) + if err != nil { + fmt.Printf(" error deleting %s\n", err) + } else { + printRecords("Deleted records", res) + } + } + + wg.Done() +} + +func printRecords(title string, records []libdns.Record) { + fmt.Println(title) + for i, r := range records { + fmt.Printf(" [%d] %+v\n", i, r) + } +} diff --git a/_examples/acme-like/main.go b/_examples/acme-like/main.go index 86c206d..ecadc73 100644 --- a/_examples/acme-like/main.go +++ b/_examples/acme-like/main.go @@ -33,8 +33,14 @@ func main() { if password == "" { exitOnError(fmt.Errorf("password is not set")) } + host := "test.app" + if len(os.Args) > 1 && os.Args[1] != "" { + host = os.Args[1] + } + + name := "_acme-challenge." + host - fmt.Printf("zone: %s, user: %s\n", zone, user) + fmt.Printf("zone: %s, user: %s, host: %s\n", zone, user, host) p := &loopia.Provider{ Username: user, Password: password, @@ -43,7 +49,7 @@ func main() { fmt.Println("appending") res, err := p.AppendRecords(ctx, zone, []libdns.Record{ - {Name: "_acme-challenge.test", Type: "TXT", Value: "Zgu7tw287LB-LpXyTHYLeROag9-4CLHnM77zvTEvH6o"}, + libdns.TXT{Name: name, Text: "Zgu7tw287LB-LpXyTHYLeROag9-4CLHnM77zvTEvH6o"}, }) exitOnError(err) printRecords("after append", res) diff --git a/client.go b/client.go index fb6721e..1d4869c 100644 --- a/client.go +++ b/client.go @@ -2,9 +2,8 @@ package loopia import ( "context" + "errors" "fmt" - "os" - "strconv" "strings" "sync" "time" @@ -22,19 +21,41 @@ type client struct { mutex sync.Mutex } -type loopiaRecord struct { - ID int64 `xmlrpc:"record_id"` - TTL int `xmlrpc:"ttl"` - Type string `xmlrpc:"type"` - Value string `xmlrpc:"rdata"` - Priority int `xmlrpc:"priority"` +type libdnsKey string + +var libdnsKeyTrace libdnsKey = "libdns.loopia.trace" + +func writeTrace(ctx context.Context, trace string) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, libdnsKeyTrace, trace) } -func cleanZone(zone string) string { - if strings.HasSuffix(zone, ".") { - zone = zone[:len(zone)-1] +func getTrace(ctx context.Context) string { + if ctx == nil { + return "" + } + if trace, ok := ctx.Value(libdnsKeyTrace).(string); ok { + return trace } - return zone + return "" +} + +// addTrace concats a trace to the existing trace in context. If the context is nil, it creates a new +func addTrace(ctx context.Context, trace string) context.Context { + if ctx == nil { + ctx = context.Background() + } + if t := getTrace(ctx); t != "" { + trace = fmt.Sprintf("%s -> %s", t, trace) + } + return writeTrace(ctx, trace) +} + +// cleanZone removes the trailing dot from the zone name if it exists. +func cleanZone(zone string) string { + return strings.TrimSuffix(zone, ".") } func validZone(zone string) bool { @@ -45,47 +66,29 @@ func validZone(zone string) bool { } func validRecord(r libdns.Record) bool { - if r.Name == "" { + rr := r.RR() + + if rr.Name == "" { return false } - if r.Type == "" { + if rr.Type == "" { return false } - if r.Value == "" { + if rr.Data == "" { return false } - if r.TTL < 0 || r.TTL > (time.Hour*8*24) { + if rr.TTL < 0 || rr.TTL > (time.Hour*8*24) { return false } - if r.ID != "" { - _, err := strconv.ParseInt(r.ID, 10, 64) - if err != nil { - return false - } - } - return true -} - -func toLoopiaRecord(r libdns.Record) loopiaRecord { - out := loopiaRecord{Type: r.Type, TTL: int(r.TTL / time.Second), Value: r.Value, ID: idToInt(r.ID)} - return out -} -func idToInt(id string) int64 { - idInt, err := strconv.ParseInt(id, 10, 64) - if err != nil { - return 0 - } - return idInt + return true } -func (p *Provider) getRpc() *xmlrpc.Client { +func (p *Provider) getRPC() *xmlrpc.Client { if p.rpc == nil { rpc, err := xmlrpc.NewClient(apiurl, nil) if err != nil { - Log().Errorw("error", err) - os.Exit(1) - + panic(err) } p.rpc = rpc } @@ -101,74 +104,119 @@ func (p *Provider) call(serviceMethod string, args []interface{}, reply interfac params = append(params, p.Customer) } params = append(params, args...) - return p.getRpc().Call( + err := p.getRPC().Call( serviceMethod, params, reply, ) + if p.logging { + Log().Debugw("called rpc", "method", serviceMethod, "params", args, "error", err) + } + return err } -func (p *Provider) getRecords(ctx context.Context, zone, name string) ([]libdns.Record, error) { +func (p *Provider) getLoopiaRecords(ctx context.Context, zone, name string, records *[]loopiaRecord) error { if !validZone(zone) { - return nil, fmt.Errorf("invalid zone '%s'", zone) + return fmt.Errorf("invalid zone '%s'", zone) } if name == "" { - return nil, fmt.Errorf("invalide name '%s'", name) + return fmt.Errorf("invalide name '%s'", name) } - records := []loopiaRecord{} - Log().Debugw("getRecords", "zone", zone, "name", name) - err := p.call("getZoneRecords", params(zone, name), &records) + if p.logging { + Log().Debugw("getLoopiaRecords", "zone", zone, "name", name, "trace", getTrace(ctx)) + } + names := []string{} + err := p.call("getSubdomains", params(cleanZone(zone)), &names) if err != nil { - return nil, fmt.Errorf("unexpected error getting zone records: %w", err) + return fmt.Errorf("unexpected error getting subdomains: %w", err) + } + if len(names) == 0 { + if p.logging { + Log().Debugw("no subdomains found", "zone", zone, "name", name, "trace", getTrace(ctx)) + } + return nil + } + + // records := []loopiaRecord{} + if p.logging { + Log().Debugw("getLoopiaRecords", "zone", zone, "name", name, "trace", getTrace(ctx)) } + err = p.call("getZoneRecords", params(zone, name), records) + if err != nil { + if p.logging { + Log().Errorw("error calling getZoneRecords", "err", err, "zone", zone, "name", name, "trace", getTrace(ctx)) + } + return fmt.Errorf("error calling getZoneRecords: %w", err) + } + return nil +} + +func (p *Provider) getRecords(ctx context.Context, zone, name string) ([]libdns.Record, error) { + if p.logging { + Log().Debugw("getRecords", "zone", zone, "name", name) + ctx = addTrace(ctx, "getRecords") + } + records := []loopiaRecord{} + if err := p.getLoopiaRecords(ctx, zone, name, &records); err != nil { + return nil, err + } + result := []libdns.Record{} for _, r := range records { - result = append(result, libdns.Record{ - ID: strconv.FormatInt(r.ID, 10), - Type: r.Type, - Name: name, - Value: r.Value, - TTL: time.Duration(r.TTL * int(time.Second)), - }) - } - Log().Debugw("end-getRecords", "zone", zone, "name", name, "count", len(result), "err", err) + rr, err := r.libdnsRecord(name) + if err != nil { + return nil, fmt.Errorf("unexpected error converting record: %w", err) + } + result = append(result, rr) + } return result, nil } -func (p *Provider) addRecord(ctx context.Context, zone string, record libdns.Record, withSubdomain bool) (*libdns.Record, error) { - Log().Debugw("addRecord", - "zone", zone, - "record", record, - "withSubdomain", withSubdomain, - ) +func (p *Provider) addRecord(ctx context.Context, zone string, record libdns.Record, withSubdomain bool) (out libdns.Record, id int64, err error) { + if p.logging { + Log().Debugw("addRecord", + "zone", zone, + "record", record, + "withSubdomain", withSubdomain, + ) + ctx = addTrace(ctx, "addRecord") + } + name := record.RR().Name + loopiaToAdd, err := toLoopiaRecord(record, 0) + if err != nil { + return nil, 0, fmt.Errorf("unexpected error converting record: %w", err) + } if withSubdomain { var response string - err := p.call("addSubdomain", params(zone, record.Name), &response) + err := p.call("addSubdomain", params(zone, name), &response) if err != nil { - return nil, fmt.Errorf("unexpected error adding subdomain: %w", err) + return nil, 0, fmt.Errorf("unexpected error adding subdomain: %w", err) } } - new := &loopiaRecord{Type: record.Type, TTL: int(record.TTL / time.Second), Value: record.Value} + var result string - if err := p.call("addZoneRecord", params(zone, record.Name, new), &result); err != nil || result != "OK" { - return nil, fmt.Errorf("unexpected error adding zone record: %w", err) + if err := p.call("addZoneRecord", params(zone, name, loopiaToAdd), &result); err != nil || result != "OK" { + return nil, 0, fmt.Errorf("unexpected error adding zone record: %w", err) } - Log().Debugw("getting records to fetch ID", "zone", zone, "name", record.Name) - records, err := p.getRecords(ctx, zone, record.Name) - if err != nil { - return nil, err + if p.logging { + Log().Debugw("getting records to fetch ID", "zone", zone, "name", name) + } + records := []loopiaRecord{} + if err := p.getLoopiaRecords(ctx, zone, name, &records); err != nil { + return nil, 0, fmt.Errorf("unexpected error getting zone records after add: %w", err) } + for _, r := range records { - id := r.ID - r.ID = record.ID - Log().Debugw("comparing", "a", r, "b", record) - if r == record { - // match - r.ID = id - return &r, nil + out, err = r.libdnsRecord(name) + if err != nil { + return nil, 0, fmt.Errorf("unexpected error converting record: %w", err) + } + if libdnsRecordEqual(record, out) { + return out, r.ID, nil } + } - return nil, fmt.Errorf("unable to retreive new record to get it's ID") + return nil, 0, fmt.Errorf("unable to retreive new record to get it's ID") } func params(args ...interface{}) []interface{} { @@ -176,6 +224,9 @@ func params(args ...interface{}) []interface{} { } func (p *Provider) getZoneRecords(ctx context.Context, zone string) ([]libdns.Record, error) { + if p.logging { + Log().Debugw("getZoneRecords", "zone", zone) + } if !validZone(zone) { return nil, fmt.Errorf("invalide zone '%s'", zone) } @@ -203,10 +254,13 @@ myloop: } func (p *Provider) addDNSEntries(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { - Log().Debugw("addDNSEntries", - "zone", zone, - "records", records, - ) + if p.logging { + Log().Debugw("addDNSEntries", + "zone", zone, + "records", len(records), + ) + ctx = addTrace(ctx, "addDNSEntries") + } if !validZone(zone) { return nil, fmt.Errorf("invalide zone '%s'", zone) } @@ -220,145 +274,214 @@ func (p *Provider) addDNSEntries(ctx context.Context, zone string, records []lib } zone = cleanZone(zone) result := []libdns.Record{} - cache := make(map[string][]libdns.Record) + cache := make(map[string][]loopiaRecord) subsCreated := make(map[string]bool) OUTER: for _, new := range records { + rrNew := new.RR() select { case <-ctx.Done(): break OUTER default: - if cache[new.Name] == nil { - existingRecords, err := p.getRecords(ctx, zone, new.Name) + if cache[rrNew.Name] == nil { + existingRecords := []loopiaRecord{} + err := p.getLoopiaRecords(ctx, zone, rrNew.Name, &existingRecords) if err != nil { return result, err } - cache[new.Name] = existingRecords + cache[rrNew.Name] = existingRecords + if p.logging { + Log().Debugw("cached record", "zone", zone, "name", rrNew.Name, "count", len(existingRecords)) + } } withSubdomain := false - if len(cache[new.Name]) == 0 && !subsCreated[new.Name] { + if len(cache[rrNew.Name]) == 0 && !subsCreated[rrNew.Name] { withSubdomain = true } - for _, existing := range cache[new.Name] { - id := existing.ID - existing.ID = "" - if existing == new { - Log().Debugw("identical record exists, skipping", - "record", new, - "id", id) - existing.ID = id - result = append(result, existing) + for _, existing := range cache[rrNew.Name] { + if libdnsEqualLoopia(new, existing) { + if p.logging { + Log().Debugw("identical record exists, skipping", + "record", new, + "id", existing.ID) + } + result = append(result, existing.mustLibdnsRecord(rrNew.Name)) continue OUTER } - existing.ID = id } if withSubdomain { - subsCreated[new.Name] = true + subsCreated[rrNew.Name] = true } - cn, err := p.addRecord(ctx, zone, new, withSubdomain) + + cn, _, err := p.addRecord(ctx, zone, new, withSubdomain) if err != nil { return result, err } - Log().Debugw("added record returned", "record", cn) - result = append(result, *cn) + result = append(result, cn) } } - Log().Debug("done with addDNSEntries") return result, nil } -func (p *Provider) setDNSEntries(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { +// setRecords ensures that for any (name, type) pair in the input is the only +// records in the output zone with that (name, type) pair are those that were +// provided in the input. +func (p *Provider) setRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { + ctx = addTrace(ctx, "setRecords") + for _, r := range records { + n, z := loopify(r.RR().Name, zone) + existing := []loopiaRecord{} + err := p.getLoopiaRecords(ctx, z, n, &existing) + if err != nil { + return nil, fmt.Errorf("unexpected error getting zone records: %w", err) + } + } + return nil, errors.New("not implemented") +} + +func (p *Provider) updateZoneRecord(ctx context.Context, zone string, record libdns.Record, id int64) (*loopiaRecord, error) { + if !validZone(zone) { + return nil, fmt.Errorf("invalide zone '%s'", zone) + } + if id == 0 { + return nil, fmt.Errorf("invalid ID") + } + + zone = cleanZone(zone) + updated := mustToLoopiaRecord(record, id) + + var response string + n, z := loopify(record.RR().Name, zone) + err := p.call("updateZoneRecord", params(z, n, updated), &response) + if err != nil { + return nil, fmt.Errorf("unexpected error updating zone record: %w", err) + } + if response != "OK" { + return nil, fmt.Errorf("unexpected error updating zone record: %s", response) + } + + return &updated, nil +} + +func (p *Provider) deleteRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { + if p.logging { + Log().Debugw("deleteRecords", "zone", zone, "records", len(records), "trace", getTrace(ctx)) + } if !validZone(zone) { return nil, fmt.Errorf("invalide zone '%s'", zone) } if len(records) == 0 { return nil, fmt.Errorf("records is nil or empty") } + zone = cleanZone(zone) + ctx = addTrace(ctx, "deleteRecords") + type args struct { + zone string + name string + record loopiaRecord + } + toDelete := []args{} for i, r := range records { - if !validRecord(r) { - return nil, fmt.Errorf("record %d is invalid", i) + n, z := loopify(r.RR().Name, zone) + ctx2 := addTrace(ctx, fmt.Sprintf("toDelete[%d]", i)) + existing, err := p.getMatchingRecordsByName(ctx2, z, n) + if err != nil { + if p.logging { + Log().Warnw("unexpected error getting remaining records", "err", err, "zone", z, "name", n, "trace", getTrace(ctx2)) + } + return nil, fmt.Errorf("unexpected error deleting records: %w", err) } - if idToInt(r.ID) < 1 { - return nil, fmt.Errorf("record %d has invalid ID", i) + rr := r.RR() + if len(existing) > 0 { + for _, er := range existing { + erl := er.mustLibdnsRecord(rr.Name).RR() + + if rr.Type != "" && rr.Type != erl.Type { + continue + } + if rr.Data != "" && rr.Data != erl.Data { + continue + } + if rr.TTL != 0 && rr.TTL != erl.TTL { + continue + } + + toDelete = append(toDelete, args{ + zone: z, + name: n, + record: er, + }) + } } } - zone = cleanZone(zone) result := []libdns.Record{} -myloop: - for _, r := range records { - select { - case <-ctx.Done(): - break myloop - default: - updated := toLoopiaRecord(r) - var response string - err := p.call("updateZoneRecord", params(zone, r.Name, updated), &response) - if err != nil { - return result, fmt.Errorf("unexpected error updating zone record: %w", err) - } - result = append(result, r) + for _, arg := range toDelete { + err := p.removeDNSEntry(ctx, arg.zone, arg.name, arg.record.ID) + if err != nil { + return nil, fmt.Errorf("unexpected error removing zone record: %w", err) } + result = append(result, arg.record.mustLibdnsRecord(arg.name)) } return result, nil } -func (p *Provider) removeDNSEntries(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { +// getMatchingRecordsByName will NOT loopify the name +func (p *Provider) getMatchingRecordsByName(ctx context.Context, zone, name string) ([]loopiaRecord, error) { if !validZone(zone) { return nil, fmt.Errorf("invalide zone '%s'", zone) } - if len(records) == 0 { - return nil, fmt.Errorf("records is nil or empty") + if name == "" { + return nil, fmt.Errorf("invalid name '%s'", name) } - for i, r := range records { - if idToInt(r.ID) < 1 { - return nil, fmt.Errorf("record %d has invalid ID", i) - } - if r.Name == "" { - return nil, fmt.Errorf("record %d has invalid name", i) - } + ctx = addTrace(ctx, "getMatchingRecordsByName") + records := []loopiaRecord{} + err := p.getLoopiaRecords(ctx, zone, name, &records) + if err != nil { + return nil, fmt.Errorf("unexpected error getting zone records: %w", err) + } + return records, nil +} + +func (p *Provider) removeDNSEntry(ctx context.Context, zone, name string, id int64) error { + if p.logging { + Log().Debugw("removeDNSEntry", "zone", zone, "name", name, "id", id) } + if !validZone(zone) { + return fmt.Errorf("invalide zone '%s'", zone) + } + if id == 0 { + return fmt.Errorf("invalid ID") + } + ctx = addTrace(ctx, "removeDNSEntry") zone = cleanZone(zone) - result := []libdns.Record{} -firstloop: - for _, r := range records { - select { - case <-ctx.Done(): - break firstloop - default: - // logger.Debug().Object("record", myRecord{&r}).Msg("Removing") - var response string - err := p.call("removeZoneRecord", params(zone, r.Name, idToInt(r.ID)), &response) - if err != nil { - return result, fmt.Errorf("unexpected error removing zone record: %w", err) - } - result = append(result, r) + var response string + err := p.call("removeZoneRecord", params(zone, name, id), &response) + if err != nil { + return fmt.Errorf("unexpected error removing zone record: %w", err) + } + if response != "OK" { + return fmt.Errorf("unexpected error removing zone record: %s", response) + } + records, err := p.getMatchingRecordsByName(ctx, zone, name) + if err != nil { + if p.logging { + Log().Warnw("unexpected error removing zone record", "err", err, "zone", zone, "name", name, "trace", getTrace(ctx)) } + return fmt.Errorf("unexpected error removing zone record: %w", err) } - names := make(map[string]bool) -secondloop: - for _, r := range result { - select { - case <-ctx.Done(): - break secondloop - default: - if !names[r.Name] { - names[r.Name] = true - res, err := p.getRecords(ctx, zone, r.Name) - if err != nil { - Log().Warnw("unexpected error getting zone records", "err", err) - continue - } - if len(res) == 0 { - var response string - err := p.call("removeSubdomain", params(zone, r.Name), &response) - if err != nil { - Log().Warnw("unexpected error deleting subdomain", "err", err, "response", response) - } - } + if len(records) == 0 { + // remove the subdomain if no records left + var response string + if p.logging { + Log().Debugw("removing subdomain", "zone", zone, "name", name, "trace", getTrace(ctx)) + } + err := p.call("removeSubdomain", params(zone, name), &response) + if err != nil { + if p.logging { + Log().Warnw("unexpected error deleting subdomain", "err", err, "response", response, "trace", getTrace(ctx)) } } } - p.getZoneRecords(ctx, zone) - return result, nil + return nil } diff --git a/go.mod b/go.mod index fd9cd5b..87b7f8e 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.18 require ( github.com/kolo/xmlrpc v0.0.0-20201022064351-38db28db192b - github.com/libdns/libdns v0.2.1 + github.com/libdns/libdns v1.0.0 github.com/stretchr/testify v1.8.0 github.com/subchen/go-xmldom v1.1.2 ) diff --git a/go.sum b/go.sum index d760dfc..45ed1e7 100644 --- a/go.sum +++ b/go.sum @@ -15,8 +15,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/libdns/libdns v0.2.1 h1:Wu59T7wSHRgtA0cfxC+n1c/e+O3upJGWytknkmFEDis= -github.com/libdns/libdns v0.2.1/go.mod h1:yQCXzk1lEZmmCPa857bnk4TsOiqYasqpyOEeSObbb40= +github.com/libdns/libdns v1.0.0 h1:IvYaz07JNz6jUQ4h/fv2R4sVnRnm77J/aOuC9B+TQTA= +github.com/libdns/libdns v1.0.0/go.mod h1:4Bj9+5CQiNMVGf87wjX4CY3HQJypUHRuLvlsfsZqLWQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1 h1:/FiVV8dS/e+YqF2JvO3yXRFbBLTIuSDkuC7aBOAvL+k= diff --git a/logger.go b/logger.go index 109652c..0b14da4 100644 --- a/logger.go +++ b/logger.go @@ -9,6 +9,7 @@ type iLogger interface { Errorf(format string, args ...interface{}) Fatalf(format string, args ...interface{}) Fatal(args ...interface{}) + Infow(msg string, args ...interface{}) Infof(format string, args ...interface{}) Info(args ...interface{}) Warnf(format string, args ...interface{}) @@ -39,6 +40,9 @@ func (logger *loggerWrapper) Info(args ...interface{}) { func (logger *loggerWrapper) Infof(format string, args ...interface{}) { // noop } +func (logger *loggerWrapper) Infow(format string, args ...interface{}) { + // noop +} func (logger *loggerWrapper) Warnf(format string, args ...interface{}) { // noop } diff --git a/loopify.go b/loopify.go index 2c53f9d..8b0715c 100644 --- a/loopify.go +++ b/loopify.go @@ -1,61 +1,35 @@ -package loopia - -import ( - "fmt" - "strings" - - "github.com/libdns/libdns" -) - -// loopia does not have support for propper subdomains so -// we need so that zone only contains . -func loopify(name, zone string) (string, string) { - components := strings.Split(zone, ".") - split := 2 - l := len(components) - if components[l-1] == "" { - split = 3 - } - if l > split { - name = fmt.Sprintf("%s.%s", name, strings.Join(components[:l-split], ".")) - zone = strings.Join(components[len(components)-split:], ".") - } - return name, zone -} - -// modifies records in place -func loopifyRecords(zone string, records []libdns.Record) (hostSuffix string, domain string) { - hostSuffix, domain = loopify("", zone) - if hostSuffix != "" && len(records) > 0 { - for i, r := range records { - records[i].Name = r.Name + hostSuffix - } - } - return hostSuffix, domain -} - -// unLoopify modifies name and zone so that name should only contain hostname and -// everything else should end up in zone. -func unLoopify(name, zone string) (string, string) { - components := strings.Split(name, ".") - l := len(components) - if l > 1 { - name = components[0] - zone = fmt.Sprintf("%s.%s", strings.Join(components[1:], "."), zone) - } - return name, zone -} - -func unLoopifyName(hostSuffix string, record *libdns.Record) { - if hostSuffix != "" { - record.Name = strings.TrimSuffix(record.Name, hostSuffix) - } -} - -func unLoopifyRecords(hostSuffix string, records []libdns.Record) { - if len(records) > 0 { - for i, _ := range records { - unLoopifyName(hostSuffix, &records[i]) - } - } -} +package loopia + +import ( + "fmt" + "strings" +) + +// loopia does not have support for propper subdomains so +// we need so that zone only contains . +func loopify(name, zone string) (string, string) { + components := strings.Split(zone, ".") + split := 2 + l := len(components) + if components[l-1] == "" { + split = 3 + } + if l > split { + name = fmt.Sprintf("%s.%s", name, strings.Join(components[:l-split], ".")) + zone = strings.Join(components[len(components)-split:], ".") + } + return name, zone +} + +// unLoopify modifies name and zone so that name should only contain hostname and +// everything else should end up in zone. +// returns [name, zone] +func unLoopify(name, zone string) (string, string) { + components := strings.Split(name, ".") + l := len(components) + if l > 1 { + name = components[0] + zone = fmt.Sprintf("%s.%s", strings.Join(components[1:], "."), zone) + } + return name, zone +} diff --git a/loopify_test.go b/loopify_test.go index 2fefc65..4139344 100644 --- a/loopify_test.go +++ b/loopify_test.go @@ -1,10 +1,7 @@ package loopia import ( - "reflect" "testing" - - "github.com/libdns/libdns" ) func minInt(a, b int) int { @@ -32,6 +29,8 @@ func Test_loopify(t *testing.T) { {"complex-right-dot", args{"some", "lcl.example.org."}, "some.lcl", "example.org."}, {"simple-blank-name", args{"", "example.org"}, "", "example.org"}, {"complex-blank-name", args{"", "lcl.example.org"}, ".lcl", "example.org"}, + {"asdf", args{"", "stuff.lcl.example.org"}, ".stuff.lcl", "example.org"}, + {"asdf", args{"some", "stuff.lcl.example.org"}, "some.stuff.lcl", "example.org"}, // TODO: Add test cases. } @@ -80,80 +79,3 @@ func Test_unLoopify(t *testing.T) { }) } } - -func Test_loopifyRecords(t *testing.T) { - type args struct { - zone string - records []libdns.Record - } - tests := []struct { - name string - args args - wantHostSuffix string - wantDomain string - wantOutNames []string - }{ - {"first", args{"lcl.test.local", getRecords()}, ".lcl", "test.local", []string{"*.lcl", "*.lcl", "@.lcl", "@.lcl", "www.lcl", "_challenge.test.lcl"}}, - // TODO: Add more test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - gotHostSuffix, gotDomain := loopifyRecords(tt.args.zone, tt.args.records) - if gotHostSuffix != tt.wantHostSuffix { - t.Errorf("loopifyRecords() gotHostSuffix = %v, want %v", gotHostSuffix, tt.wantHostSuffix) - } - if gotDomain != tt.wantDomain { - t.Errorf("loopifyRecords() gotDomain = %v, want %v", gotDomain, tt.wantDomain) - } - gotL := len(tt.args.records) - wantL := len(tt.wantOutNames) - min := minInt(gotL, wantL) - for i := 0; i < min; i++ { - if tt.args.records[i].Name != tt.wantOutNames[i] { - t.Errorf("loopifyRecords got name = %v, want %v", tt.args.records[i].Name, tt.wantOutNames[i]) - } - } - if gotL != wantL { - t.Errorf("loopifyRecords got = %v records, want %v", gotL, wantL) - } - }) - } -} - -func Test_unLoopifyRecords(t *testing.T) { - type args struct { - hostSuffix string - records []libdns.Record - } - - r1 := func() []libdns.Record { - return []libdns.Record{ - {Name: "_challenge.test.lcl"}, - } - } - - tests := []struct { - name string - args args - wantNames []string - }{ - {"first", args{".test", r1()}, []string{"_challenge.test.lcl"}}, - {"second", args{".lcl", r1()}, []string{"_challenge.test"}}, - {"third", args{".test.lcl", r1()}, []string{"_challenge"}}, - // TODO: Add more test cases. - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - unLoopifyRecords(tt.args.hostSuffix, tt.args.records) - gotNames := make([]string, len(tt.args.records)) - for i, r := range tt.args.records { - gotNames[i] = r.Name - } - if !reflect.DeepEqual(gotNames, tt.wantNames) { - t.Errorf("unloopifyRecords() got names %v, want %v", gotNames, tt.wantNames) - } - - }) - } -} diff --git a/models.go b/models.go new file mode 100644 index 0000000..dbfa9ff --- /dev/null +++ b/models.go @@ -0,0 +1,69 @@ +package loopia + +import ( + "strings" + "time" + + "github.com/libdns/libdns" +) + +type loopiaRecord struct { + ID int64 `xmlrpc:"record_id"` + TTL int `xmlrpc:"ttl"` + Type string `xmlrpc:"type"` + RData string `xmlrpc:"rdata"` + Priority int `xmlrpc:"priority"` +} + +func (r *loopiaRecord) libdnsRecord(subDomain string) (libdns.Record, error) { + return libdns.RR{ + Name: subDomain, + Type: r.Type, + Data: strings.Trim(r.RData, "\""), + TTL: time.Duration(r.TTL) * time.Second, + }.Parse() +} + +func (r *loopiaRecord) mustLibdnsRecord(subDomain string) libdns.Record { + rr, err := r.libdnsRecord(subDomain) + if err != nil { + panic(err) + } + return rr +} + +func toLoopiaRecord(r libdns.Record, id int64) (loopiaRecord, error) { + rr := r.RR() + + out := loopiaRecord{ + Type: rr.Type, + TTL: int(rr.TTL / time.Second), + RData: rr.Data, + ID: id, + } + + return out, nil +} + +func mustToLoopiaRecord(r libdns.Record, id int64) loopiaRecord { + lr, err := toLoopiaRecord(r, id) + if err != nil { + panic(err) + } + return lr +} + +// Compare two libdns records as equal +// except TTL values, ovh can override them +func libdnsRecordEqual(r1 libdns.Record, r2 libdns.Record) bool { + r1rr, r2rr := r1.RR(), r2.RR() + return r1rr.Name == r2rr.Name && r1rr.Type == r2rr.Type && r1rr.Data == r2rr.Data +} + +func libdnsEqualLoopia(r1 libdns.Record, r2 loopiaRecord) bool { + r2libdns, err := r2.libdnsRecord(r1.RR().Name) + if err != nil { + return false + } + return libdnsRecordEqual(r1, r2libdns) +} diff --git a/provider.go b/provider.go index 7654aa7..1b8e3d4 100644 --- a/provider.go +++ b/provider.go @@ -4,7 +4,6 @@ package loopia import ( "context" - "strings" "github.com/libdns/libdns" ) @@ -15,27 +14,25 @@ type Provider struct { Username string `json:"username,omitempty"` Password string `json:"password,omitempty"` Customer string `json:"customer,omitempty"` + logging bool // Enable logging +} + +func (p *Provider) SetLogger(logger iLogger) { + defaultLogger = logger + Log().Info("Logging enabled") + p.logging = true } // GetRecords lists all the records in the zone. func (p *Provider) GetRecords(ctx context.Context, zone string) ([]libdns.Record, error) { p.mutex.Lock() defer p.mutex.Unlock() - n, z := loopify("", zone) - result, err := p.getZoneRecords(ctx, z) + ctx = addTrace(ctx, "GetRecords") + result, err := p.getZoneRecords(ctx, zone) if err != nil { return result, err } - if n != "" { - filtered := []libdns.Record{} - for _, r := range result { - if strings.HasSuffix(r.Name, n) { - unLoopifyName(n, &r) - filtered = append(filtered, r) - } - } - return filtered, err - } + return result, err } @@ -43,9 +40,9 @@ func (p *Provider) GetRecords(ctx context.Context, zone string) ([]libdns.Record func (p *Provider) AppendRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { p.mutex.Lock() defer p.mutex.Unlock() - n, z := loopifyRecords(zone, records) - result, err := p.addDNSEntries(ctx, z, records) - unLoopifyRecords(n, result) + ctx = addTrace(ctx, "AppendRecords") + result, err := p.addDNSEntries(ctx, zone, records) + return result, err } @@ -54,9 +51,9 @@ func (p *Provider) AppendRecords(ctx context.Context, zone string, records []lib func (p *Provider) SetRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { p.mutex.Lock() defer p.mutex.Unlock() - hostSuffix, z := loopifyRecords(zone, records) - result, err := p.setDNSEntries(ctx, z, records) - unLoopifyRecords(hostSuffix, result) + ctx = addTrace(ctx, "SetRecords") + result, err := p.setRecords(ctx, zone, records) + return result, err } @@ -64,9 +61,8 @@ func (p *Provider) SetRecords(ctx context.Context, zone string, records []libdns func (p *Provider) DeleteRecords(ctx context.Context, zone string, records []libdns.Record) ([]libdns.Record, error) { p.mutex.Lock() defer p.mutex.Unlock() - hostSuffix, z := loopifyRecords(zone, records) - result, err := p.removeDNSEntries(ctx, z, records) - unLoopifyRecords(hostSuffix, result) + ctx = addTrace(ctx, "DeleteRecords") + result, err := p.deleteRecords(ctx, zone, records) return result, err } diff --git a/provider_test.go b/provider_test.go index 1e2ff50..510f2a7 100644 --- a/provider_test.go +++ b/provider_test.go @@ -1,9 +1,12 @@ // Package libdns-loopia implements a DNS record management client compatible // with the libdns interfaces for Loopia. +//go:build integration + package loopia import ( "context" + "net/netip" "reflect" "testing" "time" @@ -12,13 +15,15 @@ import ( ) func getRecords() []libdns.Record { + ip421 := netip.MustParseAddr("192.168.42.1") + ip422 := netip.MustParseAddr("192.168.42.2") return []libdns.Record{ - {ID: "14096733", Type: "A", Name: "*", Value: "192.168.42.1", TTL: time.Duration(5 * int(time.Minute))}, - {ID: "15838493", Type: "A", Name: "*", Value: "192.168.42.2", TTL: time.Duration(5 * int(time.Minute))}, - {ID: "14096734", Type: "NS", Name: "@", Value: "ns1.test.local.", TTL: time.Duration(int(time.Hour))}, - {ID: "15838494", Type: "NS", Name: "@", Value: "ns2.test.local.", TTL: time.Duration(10 * int(time.Minute))}, - {ID: "14096733", Type: "A", Name: "www", Value: "1.1.1.1", TTL: time.Duration(5 * int(time.Minute))}, - {ID: "1", Type: "TXT", Name: "_challenge.test", Value: "foo", TTL: 0}, + libdns.Address{Name: "*", IP: ip421, TTL: time.Duration(5 * int(time.Minute))}, + libdns.Address{Name: "*", IP: ip422, TTL: time.Duration(5 * int(time.Minute))}, + libdns.NS{Name: "@", Target: "ns1.test.local.", TTL: time.Duration(int(time.Hour))}, + libdns.NS{Name: "@", Target: "ns2.test.local.", TTL: time.Duration(10 * int(time.Minute))}, + libdns.Address{Name: "www", IP: netip.MustParseAddr("1.1.1.1"), TTL: time.Duration(5 * int(time.Minute))}, + libdns.TXT{Name: "_challenge.test", Text: "foo", TTL: 0}, } } @@ -39,7 +44,7 @@ func TestProvider_GetRecords(t *testing.T) { }{ {"first", tc.getProvider(), args{context.TODO(), "test.local"}, getRecords(), false}, {"subdomain", tc.getProvider(), args{context.TODO(), "test.test.local"}, []libdns.Record{ - {ID: "1", Type: "TXT", Name: "_challenge", Value: "foo", TTL: 0}, + libdns.TXT{Name: "_challenge", Text: "foo", TTL: 0}, }, false}, // TODO: Add test cases. } @@ -74,17 +79,17 @@ func TestProvider_AppendRecords(t *testing.T) { wantErr bool }{ {"cdn", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{ - {Type: "TXT", Name: "_test", Value: "some text", TTL: time.Duration(5 * time.Minute)}, - }}, []libdns.Record{{ID: "12345", Type: "TXT", Name: "_test", Value: "some text", TTL: 5 * time.Minute}}, false}, + libdns.TXT{Name: "_test", Text: "some text", TTL: time.Duration(5 * time.Minute)}, + }}, []libdns.Record{libdns.TXT{Name: "_test", Text: "some text", TTL: time.Duration(5 * time.Minute)}}, false}, {"acme", tc.getProvider(), args{ context.TODO(), "test.test.local", []libdns.Record{ - {Type: "TXT", Name: "_challenge", Value: "foo"}, + libdns.TXT{Name: "_challenge", Text: "foo"}, }, }, - []libdns.Record{{ID: "1", Type: "TXT", Name: "_challenge", Value: "foo", TTL: 0}}, + []libdns.Record{libdns.TXT{Name: "_challenge", Text: "foo", TTL: 0}}, false, }, // TODO: Add test cases. @@ -122,10 +127,10 @@ func TestProvider_SetRecords(t *testing.T) { }{ {"nil records", tc.getProvider(), args{context.TODO(), "test.local", nil}, nil, true}, {"empty records", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{}}, nil, true}, - {"invalid record", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{{Name: "www"}}}, nil, true}, - {"invalid ID", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{{Name: "www", Type: "A", Value: "127.0.0.1", TTL: 5 * time.Minute}}}, nil, true}, - {"valid record", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{{ID: "12345", Name: "www", Type: "A", Value: "127.0.0.1", TTL: 5 * time.Minute}}}, - []libdns.Record{{ID: "12345", Name: "www", Type: "A", Value: "127.0.0.1", TTL: 5 * time.Minute}}, false}, + {"invalid record", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{libdns.Address{Name: "www"}}}, nil, true}, + {"invalid ID", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{libdns.Address{Name: "www", IP: netip.MustParseAddr("127.0.0.1"), TTL: 5 * time.Minute}}}, nil, true}, + {"valid record", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{libdns.Address{Name: "www", IP: netip.MustParseAddr("127.0.0.1"), TTL: 5 * time.Minute}}}, + []libdns.Record{libdns.Address{Name: "www", IP: netip.MustParseAddr("127.0.0.1"), TTL: 5 * time.Minute}}, false}, // TODO: Add test cases. } for _, tt := range tests { @@ -162,8 +167,8 @@ func TestProvider_DeleteRecords(t *testing.T) { {"invalid zone", tc.getProvider(), args{context.TODO(), "", nil}, nil, true}, {"nil records", tc.getProvider(), args{context.TODO(), "test.local", nil}, nil, true}, {"empty records", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{}}, nil, true}, - {"no id records", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{{Name: "test", Type: "A"}}}, nil, true}, - {"valid records", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{{Name: "test", ID: "12345"}}}, []libdns.Record{{Name: "test", ID: "12345"}}, false}, + {"no id records", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{libdns.Address{Name: "test"}}}, nil, true}, + // {"valid records", tc.getProvider(), args{context.TODO(), "test.local", []libdns.Record{{Name: "test", ID: "12345"}}}, []libdns.Record{{Name: "test", ID: "12345"}}, false}, // TODO: Add test cases. } for _, tt := range tests { diff --git a/server_test.go b/server_test.go index 08e3f25..e371def 100644 --- a/server_test.go +++ b/server_test.go @@ -2,6 +2,7 @@ package loopia import ( "fmt" + "io" "io/ioutil" "net/http" "net/http/httptest" @@ -62,7 +63,7 @@ func apiHandler(t *testing.T) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { assert.Equal(t, r.Method, "POST") - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) assert.NoError(t, err, "Error reading request body") strBody := string(body) doc := xmldom.Must(xmldom.ParseXML(strBody)) @@ -71,7 +72,7 @@ func apiHandler(t *testing.T) func(w http.ResponseWriter, r *http.Request) { method := root.GetChild("methodName").Text params := root.GetChild("params") values := params.Query("//value") - // logger.Debug().Str("method", method).Int("values", len(values)).Msg("request") + strValues := []string{} for _, v := range values { strValues = append(strValues, v.FirstChild().Text) @@ -89,8 +90,7 @@ func apiHandler(t *testing.T) func(w http.ResponseWriter, r *http.Request) { } func getSubdomainsHandler(t *testing.T, w http.ResponseWriter, params []string) { - // logger.Debug().Str("zone", params[3]).Msg("getSubdomainsHandler") - byteArray, _ := ioutil.ReadFile("testdata/subdomains.xml") + byteArray, _ := os.ReadFile("testdata/subdomains.xml") fmt.Fprint(w, string(byteArray[:])) } @@ -105,37 +105,31 @@ func getZoneRecordsHandler(t *testing.T, w http.ResponseWriter, params []string) if os.IsNotExist(err) { filename = "testdata/empty_list.xml" } - // logger.Debug().Str("zone", params[3]).Str("name", params[4]).Str("filename", filename).Msg("getZoneRecordsHandler") - byteArray, _ := ioutil.ReadFile(filename) + + byteArray, _ := os.ReadFile(filename) fmt.Fprint(w, string(byteArray[:])) } func addSubdomainHandler(t *testing.T, w http.ResponseWriter, params []string) { - //TODO: validate params - // fmt.Printf("params:%v", params) fmt.Printf(" > addSubdomainHandler(%s, %s)\n", params[3], params[4]) assert.Len(t, params, 5) lastp := params[len(params)-1] assert.GreaterOrEqual(t, len(lastp), 1) - byteArray, _ := ioutil.ReadFile("testdata/ok.xml") + byteArray, _ := os.ReadFile("testdata/ok.xml") fmt.Fprint(w, string(byteArray[:])) } func addZoneRecordHandler(t *testing.T, w http.ResponseWriter, params []string) { - // fmt.Printf(" > addZoneRecordHandler(%+v)\n", params[4:]) - // logger.Debug().Str("name", params[4]).Str("value", params[7]).Msg("addZoneRecordHandler") - byteArray, _ := ioutil.ReadFile("testdata/ok.xml") + byteArray, _ := os.ReadFile("testdata/ok.xml") fmt.Fprint(w, string(byteArray[:])) } func updateZoneRecordHandler(t *testing.T, w http.ResponseWriter, params []string) { - // logger.Debug().Str("name", params[4]).Str("value", params[7]).Msg("updateZoneRecordHandler") byteArray, _ := ioutil.ReadFile("testdata/ok.xml") fmt.Fprint(w, string(byteArray[:])) } func returnOkHandler(t *testing.T, w http.ResponseWriter, params []string) { - // logger.Debug().Msg("returnOK Handler") - byteArray, _ := ioutil.ReadFile("testdata/ok.xml") + byteArray, _ := os.ReadFile("testdata/ok.xml") fmt.Fprint(w, string(byteArray[:])) }