Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions convgen.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,14 @@ func Module(opts ...moduleOption) module {
panic("convgen: not generated")
}

// ImportModule imports configurations and registered functions from another module.
//
// var core = convgen.Module(...)
// var ext = convgen.Module(convgen.ImportModule(core), ...)
func ImportModule(mod module) moduleOption {
panic("convgen: not generated")
}

// Struct directive generates a converter function between two struct types
// without error:
//
Expand Down
53 changes: 47 additions & 6 deletions internal/convgen/parse/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,29 @@ func (cfg Config) ForkForEnum() Config {
return c
}

func (cfg *Config) UpdateImport(other Config) {
cfg.Update(other)

if other.ForStruct != nil {
if cfg.ForStruct == nil {
cfg.ForStruct = &Config{}
}
cfg.ForStruct.UpdateImport(*other.ForStruct)
}
if other.ForUnion != nil {
if cfg.ForUnion == nil {
cfg.ForUnion = &Config{}
}
cfg.ForUnion.UpdateImport(*other.ForUnion)
}
if other.ForEnum != nil {
if cfg.ForEnum == nil {
cfg.ForEnum = &Config{}
}
cfg.ForEnum.UpdateImport(*other.ForEnum)
}
}

type parsers interface {
ParsePathX(p *Parser, expr ast.Expr) (*Path, error)
ParsePathY(p *Parser, expr ast.Expr) (*Path, error)
Expand All @@ -169,7 +192,7 @@ type parsers interface {
ParsePkgY(p *Parser, expr ast.Expr) (*types.Package, error)
}

func (p *Parser) ParseConfig(cfg *Config, args []ast.Expr, parsers parsers) error {
func (p *Parser) ParseConfig(cfg *Config, args []ast.Expr, parsers parsers, fetchMod func(token.Pos) (*Module, error)) error {
var errs error
for _, arg := range args {
if _, ok := arg.(*ast.Ident); ok {
Expand All @@ -195,43 +218,45 @@ func (p *Parser) ParseConfig(cfg *Config, args []ast.Expr, parsers parsers) erro
if cfg.ForStruct == nil {
cfg.ForStruct = &Config{}
}
if err := p.ParseConfig(cfg.ForStruct, call.Args, parsers); err != nil {
if err := p.ParseConfig(cfg.ForStruct, call.Args, parsers, fetchMod); err != nil {
errs = errors.Join(errs, err)
}
continue
case p.IsDirective(call, "ForUnion"):
if cfg.ForUnion == nil {
cfg.ForUnion = &Config{}
}
if err := p.ParseConfig(cfg.ForUnion, call.Args, parsers); err != nil {
if err := p.ParseConfig(cfg.ForUnion, call.Args, parsers, fetchMod); err != nil {
errs = errors.Join(errs, err)
}
continue
case p.IsDirective(call, "ForEnum"):
if cfg.ForEnum == nil {
cfg.ForEnum = &Config{}
}
if err := p.ParseConfig(cfg.ForEnum, call.Args, parsers); err != nil {
if err := p.ParseConfig(cfg.ForEnum, call.Args, parsers, fetchMod); err != nil {
errs = errors.Join(errs, err)
}
continue
}

if err := p.ParseOption(cfg, call, parsers); err != nil {
if err := p.ParseOption(cfg, call, parsers, fetchMod); err != nil {
errs = errors.Join(errs, err)
}
}
return errs
}

func (p *Parser) ParseOption(cfg *Config, call *ast.CallExpr, ps parsers) error { // nolint: gocyclo
func (p *Parser) ParseOption(cfg *Config, call *ast.CallExpr, ps parsers, fetchMod func(token.Pos) (*Module, error)) error { // nolint: gocyclo
callee := typeutil.Callee(p.Pkg().TypesInfo, call)
if callee == nil || !IsConvgenImport(callee.Pkg().Path()) {
return codefmt.Errorf(p, call, "option must be convgen directive")
}

name := callee.Name()
switch name {
case "ImportModule":
return p.ParseOptionImportModule(cfg, call, fetchMod)
case "ImportFunc":
return p.ParseOptionImportFunc(cfg, call, false)
case "ImportFuncErr":
Expand Down Expand Up @@ -290,6 +315,22 @@ func (p *Parser) ParseOption(cfg *Config, call *ast.CallExpr, ps parsers) error
return codefmt.Errorf(p, call.Fun, "%s is not supported option", name)
}

func (p *Parser) ParseOptionImportModule(c *Config, call *ast.CallExpr, fetchMod func(token.Pos) (*Module, error)) error {
expr, err := needArgs1(p, call)
if err != nil {
return err
}

mod, err := p.ParseModuleArg(expr, fetchMod)
if err != nil {
return err
}
if mod != nil {
c.UpdateImport(mod.Config)
}
return nil
}

func (p *Parser) ParseOptionImportFunc(c *Config, call *ast.CallExpr, hasErr bool) error {
expr, err := needArgs1(p, call)
if err != nil {
Expand Down
11 changes: 9 additions & 2 deletions internal/convgen/parse/injector.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,14 @@ func (p *Parser) parseInjector(id *ast.Ident, call *ast.CallExpr, doc, comment *
inj.Func = fn
errs = errors.Join(errs, err)

mod, err := p.ParseModuleArg(call.Args[0], mods)
fetchMod := func(pos token.Pos) (*Module, error) {
if mod, ok := mods[pos]; ok {
return mod, nil
}
return nil, nil // not found
}

mod, err := p.ParseModuleArg(call.Args[0], fetchMod)
if err != nil {
mod = NilModule() // Prevent nil panic to collect as many errors as possible
}
Expand Down Expand Up @@ -266,7 +273,7 @@ func (p *Parser) parseInjector(id *ast.Ident, call *ast.CallExpr, doc, comment *
// Parse config
cfg.DiscoverBySamplePkgX = inj.X().Pkg()
cfg.DiscoverBySamplePkgY = inj.Y().Pkg()
errs = errors.Join(errs, p.ParseConfig(&cfg, opts, parsers))
errs = errors.Join(errs, p.ParseConfig(&cfg, opts, parsers, fetchMod))
inj.Config = cfg

// Register into the module
Expand Down
61 changes: 49 additions & 12 deletions internal/convgen/parse/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@ type Module struct {
// ParseModules finds and parses all convgen.Module calls in the parsed files.
func (p *Parser) ParseModules() (map[token.Pos]*Module, error) {
var errs error
mods := make(map[token.Pos]*Module)

type modDecl struct {
Name string
Call *ast.CallExpr
}
decls := make(map[token.Pos]modDecl)

for _, file := range p.ConvgenGoFiles() {
for id, call := range p.FindModules(file) {
Expand All @@ -46,10 +51,39 @@ func (p *Parser) ParseModules() (map[token.Pos]*Module, error) {
name = ""
}

mod, err := p.ParseModule(call, name)
mods[id.Pos()] = mod
decls[id.Pos()] = modDecl{Name: name, Call: call}
}
}

mods := make(map[token.Pos]*Module)
visiting := make(map[token.Pos]bool)

var fetchMod func(token.Pos) (*Module, error)
fetchMod = func(pos token.Pos) (*Module, error) {
if mod, ok := mods[pos]; ok {
return mod, nil
}
if visiting[pos] {
return nil, errors.New("import cycle detected")
}
decl, ok := decls[pos]
if !ok {
return nil, nil
}

visiting[pos] = true
mod, err := p.ParseModule(decl.Call, decl.Name, fetchMod)
if err != nil {
errs = errors.Join(errs, err)
}
// Mark as not visiting even if there is an error to avoid blocking other modules that depend on this module.
visiting[pos] = false
mods[pos] = mod
return mod, err
}

for pos := range decls {
_, _ = fetchMod(pos)
}

return mods, errs
Expand Down Expand Up @@ -92,7 +126,7 @@ func (p *Parser) FindModules(file *ast.File) iter.Seq2[*ast.Ident, *ast.CallExpr

// ParseModule parses a [convgen.Module] call expression and returns a new
// module.
func (p *Parser) ParseModule(call *ast.CallExpr, name string) (*Module, error) {
func (p *Parser) ParseModule(call *ast.CallExpr, name string, fetchMod func(token.Pos) (*Module, error)) (*Module, error) {
// Chain of For* after NewModule
calls := []*ast.CallExpr{call}
for {
Expand All @@ -113,23 +147,23 @@ func (p *Parser) ParseModule(call *ast.CallExpr, name string) (*Module, error) {

var cfg Config
var errs error
if err := p.ParseConfig(&cfg, calls[0].Args, nil); err != nil {
if err := p.ParseConfig(&cfg, calls[0].Args, nil, fetchMod); err != nil {
errs = errors.Join(errs, err)
}

for _, call := range calls[1:] {
switch call.Fun.(*ast.SelectorExpr).Sel.Name {
case "ForStruct":
cfg.ForStruct = &Config{}
err := p.ParseConfig(cfg.ForStruct, call.Args, nil)
err := p.ParseConfig(cfg.ForStruct, call.Args, nil, fetchMod)
errs = errors.Join(errs, err)
case "ForUnion":
cfg.ForUnion = &Config{}
err := p.ParseConfig(cfg.ForUnion, call.Args, nil)
err := p.ParseConfig(cfg.ForUnion, call.Args, nil, fetchMod)
errs = errors.Join(errs, err)
case "ForEnum":
cfg.ForEnum = &Config{}
err := p.ParseConfig(cfg.ForEnum, call.Args, nil)
err := p.ParseConfig(cfg.ForEnum, call.Args, nil, fetchMod)
errs = errors.Join(errs, err)
default:
panic("unexpected module chain")
Expand Down Expand Up @@ -181,7 +215,7 @@ func (p *Parser) newModuleLookup(cfg Config, old *typeinfo.Lookup[typeinfo.Func]

// ParseModuleArg parses a Convgen module type argument from the given
// expression.
func (p *Parser) ParseModuleArg(expr ast.Expr, mods map[token.Pos]*Module) (*Module, error) {
func (p *Parser) ParseModuleArg(expr ast.Expr, fetchMod func(token.Pos) (*Module, error)) (*Module, error) {
expr = ast.Unparen(expr)

// Inline Module Declaration
Expand All @@ -194,7 +228,7 @@ func (p *Parser) ParseModuleArg(expr ast.Expr, mods map[token.Pos]*Module) (*Mod
// implicit converters. The implicit converters will inherit the module's
// configuration.
if call, ok := expr.(*ast.CallExpr); ok && p.IsDirective(call, "Module") {
return p.ParseModule(call, "")
return p.ParseModule(call, "", fetchMod)
}

// Validate identifier
Expand Down Expand Up @@ -230,8 +264,11 @@ func (p *Parser) ParseModuleArg(expr ast.Expr, mods map[token.Pos]*Module) (*Mod
// This is the most common way to declare and use a module. Multiple
// converters can belong to the same package-level module.
modPos := p.Pkg().TypesInfo.ObjectOf(id).Pos()
mod, ok := mods[modPos]
if !ok {
mod, err := fetchMod(modPos)
if err != nil {
return nil, codefmt.Errorf(p, expr, "cannot import module %q: %v", id.Name, err)
}
if mod == nil {
return nil, codefmt.Errorf(p, expr, "cannot find %q module declared by convgen.Module", id.Name)
}
return mod, nil
Expand Down
43 changes: 43 additions & 0 deletions testdata/program/ModuleImport/main/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
//go:build convgen

package main

import (
"fmt"
"strconv"

"github.com/sublee/convgen"
)

type User struct {
Id int
Name string
}

type UserDTO struct {
ID string
NAME string
}

func IntToString(i int) string {
return strconv.Itoa(i)
}

var mod1 = convgen.Module(
convgen.ImportFunc(IntToString),
)

var mod2 = convgen.Module(
convgen.ImportModule(mod1),
convgen.RenameToLower(true, true),
)

var UserToDTO = convgen.Struct[User, UserDTO](mod2)

func main() {
dto := UserToDTO(User{
Id: 42,
Name: "Alice",
})
fmt.Println(dto.ID, dto.NAME)
}
1 change: 1 addition & 0 deletions testdata/program/ModuleImport/want/program_output.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
42 Alice