From 3f924d46f374d9fd40de8fd84220cdc95083899a Mon Sep 17 00:00:00 2001 From: adrianwitas Date: Sat, 14 Feb 2026 11:28:34 -0800 Subject: [PATCH 1/6] - introduces shape pkg --- go.mod | 10 +- go.sum | 8 +- internal/inference/spec.go | 58 +- internal/inference/state.go | 8 +- internal/translator/function.go | 5 +- internal/translator/output.go | 6 + internal/translator/resource.go | 2 +- internal/translator/rule.go | 2 +- internal/translator/viewlets.go | 2 +- repository/component.go | 2 + repository/components.go | 242 ++++++- repository/option.go | 19 + repository/shape/README.md | 61 ++ repository/shape/column/detector.go | 237 +++++++ repository/shape/column/detector_test.go | 59 ++ repository/shape/compile/compiler.go | 110 +++ repository/shape/compile/compiler_test.go | 69 ++ repository/shape/compile/doc.go | 2 + repository/shape/doc.go | 3 + repository/shape/dql_engine_test.go | 42 ++ repository/shape/errors.go | 12 + repository/shape/load/doc.go | 2 + repository/shape/load/errors.go | 7 + repository/shape/load/loader.go | 224 ++++++ repository/shape/load/loader_test.go | 116 ++++ repository/shape/load/model.go | 21 + repository/shape/load/testdata/report.sql | 1 + repository/shape/model.go | 53 ++ repository/shape/options.go | 73 ++ repository/shape/parity_test.go | 67 ++ repository/shape/plan/doc.go | 2 + repository/shape/plan/model.go | 72 ++ repository/shape/plan/planner.go | 174 +++++ repository/shape/plan/planner_test.go | 86 +++ repository/shape/plan/testdata/report.sql | 1 + repository/shape/scan/doc.go | 2 + repository/shape/scan/model.go | 33 + repository/shape/scan/scanner.go | 166 +++++ repository/shape/scan/scanner_test.go | 83 +++ repository/shape/scan/testdata/report.sql | 1 + repository/shape/shape.go | 157 +++++ repository/shape/source.go | 39 ++ repository/shape/source_type.go | 56 ++ repository/shape/source_type_test.go | 33 + repository/shape/typectx/model.go | 29 + repository/shape/typectx/resolver.go | 293 ++++++++ 
.../shape/typectx/resolver_memfs_test.go | 116 ++++ repository/shape/typectx/resolver_test.go | 89 +++ repository/shape/typectx/source/resolver.go | 283 ++++++++ .../shape/typectx/source/resolver_test.go | 91 +++ repository/shape/validate/relation.go | 140 ++++ repository/shape/validate/relation_test.go | 70 ++ repository/shape/xgen/generator.go | 644 ++++++++++++++++++ repository/shape/xgen/generator_test.go | 305 +++++++++ repository/shape/xgen/io.go | 311 +++++++++ repository/shape/xgen/model.go | 70 ++ view/state/parameters.go | 4 +- 57 files changed, 4856 insertions(+), 17 deletions(-) create mode 100644 repository/shape/README.md create mode 100644 repository/shape/column/detector.go create mode 100644 repository/shape/column/detector_test.go create mode 100644 repository/shape/compile/compiler.go create mode 100644 repository/shape/compile/compiler_test.go create mode 100644 repository/shape/compile/doc.go create mode 100644 repository/shape/doc.go create mode 100644 repository/shape/dql_engine_test.go create mode 100644 repository/shape/errors.go create mode 100644 repository/shape/load/doc.go create mode 100644 repository/shape/load/errors.go create mode 100644 repository/shape/load/loader.go create mode 100644 repository/shape/load/loader_test.go create mode 100644 repository/shape/load/model.go create mode 100644 repository/shape/load/testdata/report.sql create mode 100644 repository/shape/model.go create mode 100644 repository/shape/options.go create mode 100644 repository/shape/parity_test.go create mode 100644 repository/shape/plan/doc.go create mode 100644 repository/shape/plan/model.go create mode 100644 repository/shape/plan/planner.go create mode 100644 repository/shape/plan/planner_test.go create mode 100644 repository/shape/plan/testdata/report.sql create mode 100644 repository/shape/scan/doc.go create mode 100644 repository/shape/scan/model.go create mode 100644 repository/shape/scan/scanner.go create mode 100644 
repository/shape/scan/scanner_test.go create mode 100644 repository/shape/scan/testdata/report.sql create mode 100644 repository/shape/shape.go create mode 100644 repository/shape/source.go create mode 100644 repository/shape/source_type.go create mode 100644 repository/shape/source_type_test.go create mode 100644 repository/shape/typectx/model.go create mode 100644 repository/shape/typectx/resolver.go create mode 100644 repository/shape/typectx/resolver_memfs_test.go create mode 100644 repository/shape/typectx/resolver_test.go create mode 100644 repository/shape/typectx/source/resolver.go create mode 100644 repository/shape/typectx/source/resolver_test.go create mode 100644 repository/shape/validate/relation.go create mode 100644 repository/shape/validate/relation_test.go create mode 100644 repository/shape/xgen/generator.go create mode 100644 repository/shape/xgen/generator_test.go create mode 100644 repository/shape/xgen/io.go create mode 100644 repository/shape/xgen/model.go diff --git a/go.mod b/go.mod index 22458b8b..baaae6b4 100644 --- a/go.mod +++ b/go.mod @@ -2,6 +2,12 @@ module github.com/viant/datly go 1.25.0 +replace github.com/viant/velty => ../velty + +replace github.com/viant/x => ../x + +replace github.com/viant/sqlparser => ../sqlparser + require ( github.com/aerospike/aerospike-client-go v4.5.2+incompatible github.com/aws/aws-lambda-go v1.31.0 @@ -15,7 +21,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.16 github.com/pkg/errors v0.9.1 github.com/stretchr/testify v1.11.1 - github.com/viant/afs v1.26.2 + github.com/viant/afs v1.29.0 github.com/viant/afsc v1.16.0 github.com/viant/assertly v0.9.1-0.20220620174148-bab013f93a60 github.com/viant/bigquery v0.4.1 @@ -53,6 +59,7 @@ require ( github.com/viant/mcp-protocol v0.9.0 github.com/viant/structology v0.8.0 github.com/viant/tagly v0.3.0 + github.com/viant/x v0.3.0 github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a github.com/viant/xdatly/extension v0.0.0-20231013204918-ecf3c2edf259 
github.com/viant/xdatly/handler v0.0.0-20251208172928-dd34b7f09fd5 @@ -151,7 +158,6 @@ require ( github.com/spiffe/go-spiffe/v2 v2.6.0 // indirect github.com/viant/gosh v0.2.1 // indirect github.com/viant/igo v0.2.0 // indirect - github.com/viant/x v0.3.0 // indirect github.com/xuri/efp v0.0.0-20230802181842-ad255f2331ca // indirect github.com/xuri/excelize/v2 v2.8.0 // indirect github.com/xuri/nfp v0.0.0-20230819163627-dc951e3ffe1a // indirect diff --git a/go.sum b/go.sum index ddb2bb56..d7923c12 100644 --- a/go.sum +++ b/go.sum @@ -1152,8 +1152,8 @@ github.com/tklauser/go-sysconf v0.3.9/go.mod h1:11DU/5sG7UexIrp/O6g35hrWzu0JxlwQ github.com/tklauser/numcpus v0.3.0/go.mod h1:yFGUr7TUHQRAhyqBcEg0Ge34zDBAsIvJJcyE6boqnA8= github.com/viant/aerospike v0.2.11-0.20241108195857-ed524b97800d h1:IRmoMmrWqkHDBy0tk9mbHRDK7+ynn0Gzwl+9WIiAtNs= github.com/viant/aerospike v0.2.11-0.20241108195857-ed524b97800d/go.mod h1:eRBywl0oTDM/oGhGLUeJjnC7XzmkTGuW9/og5YFy0K0= -github.com/viant/afs v1.26.2 h1:rOs/iFxFlEndhIRATJVXlNWhVU0cGdRQAGVTVJPdsc0= -github.com/viant/afs v1.26.2/go.mod h1:rScbFd9LJPGTM8HOI8Kjwee0AZ+MZMupAvFpPg+Qdj4= +github.com/viant/afs v1.29.0 h1:ndnn+PBQt5ep/bE1m5OvIvMjpoCCZbtl/UlJEubT9kE= +github.com/viant/afs v1.29.0/go.mod h1:rScbFd9LJPGTM8HOI8Kjwee0AZ+MZMupAvFpPg+Qdj4= github.com/viant/afsc v1.16.0 h1:/kOH/flNwme6h3oFrU/KPnMHkhbCZxQncTf1GSQIlBQ= github.com/viant/afsc v1.16.0/go.mod h1:Z6fP3VcmzS8Sg2lowctR6KkVEX7XxJ8aNaoHqhUiZkY= github.com/viant/assertly v0.4.8/go.mod h1:aGifi++jvCrUaklKEKT0BU95igDNaqkvz+49uaYMPRU= @@ -1208,10 +1208,6 @@ github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMI github.com/viant/toolbox v0.34.5/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/viant/toolbox v0.37.0 h1:+zwSdbQh6I6ZEyxokQJr+1gQKbLEw6erc+Av5dwKtLU= github.com/viant/toolbox v0.37.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= -github.com/viant/velty v0.2.1-0.20230927172116-ba56497b5c85 
h1:zKk+6hqUipkJXCPCHyFXzGtil1sfh80r6UZmloBNEDo= -github.com/viant/velty v0.2.1-0.20230927172116-ba56497b5c85/go.mod h1:Q/UXviI2Nli8WROEpYd/BELMCSvnulQeyNrbPmMiS/Y= -github.com/viant/x v0.3.0 h1:/3A0z/uySGxMo6ixH90VAcdjI00w5e3REC1zg5hzhJA= -github.com/viant/x v0.3.0/go.mod h1:54jP3qV+nnQdNDaWxEwGTAAzCu9sx9er9htiwTW/Mcw= github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a h1:7CLO2LjVnFgOwN0FL3Q4y5NrD7DpclS21AiW6tDLIc8= github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a/go.mod h1:lZKZHhVdCZ3U9TU6GUFxKoGN3dPtqt2HkDYzJPq5CEs= github.com/viant/xdatly/extension v0.0.0-20231013204918-ecf3c2edf259 h1:9Yry3PUBDzc4rWacOYvAq/TKrTV0agvMF0gwm2gaoHI= diff --git a/internal/inference/spec.go b/internal/inference/spec.go index 52fc6068..214857d1 100644 --- a/internal/inference/spec.go +++ b/internal/inference/spec.go @@ -246,7 +246,13 @@ func NewSpec(ctx context.Context, db *sql.DB, messages *msg.Messages, table stri var result = &Spec{Table: table, SQL: SQL, SQLArgs: SQLArgs, IsAuxiliary: isAuxiliary} columns, err := column.Discover(ctx, db, table, SQL, SQLArgs...) 
if err != nil { - return nil, err + columns = bestEffortColumnsFromSQL(SQL, columnsConfig) + if len(columns) == 0 { + return nil, err + } + if messages != nil { + messages.AddWarning(result.Table, "detection", fmt.Sprintf("using best-effort SQL column inference due to discovery error: %v", err)) + } } result.Columns = columns byName := result.Columns.ByName() @@ -285,6 +291,56 @@ func NewSpec(ctx context.Context, db *sql.DB, messages *msg.Messages, table stri return result, nil } +func bestEffortColumnsFromSQL(SQL string, columnsConfig view.ColumnConfigs) sqlparser.Columns { + if strings.TrimSpace(SQL) == "" { + return nil + } + query, err := sqlparser.ParseQuery(SQL) + if err != nil || query == nil { + return nil + } + queryColumns := sqlparser.NewColumns(query.List) + if len(queryColumns) == 0 { + return nil + } + cfgByLower := map[string]*view.ColumnConfig{} + for _, cfg := range columnsConfig { + if cfg == nil || cfg.Name == "" { + continue + } + cfgByLower[strings.ToLower(cfg.Name)] = cfg + } + var result sqlparser.Columns + for _, candidate := range queryColumns { + if candidate == nil { + continue + } + expression := strings.TrimSpace(candidate.Expression) + if expression == "*" || strings.HasSuffix(expression, ".*") { + continue + } + name := strings.TrimSpace(candidate.Alias) + if name == "" { + name = strings.TrimSpace(candidate.Name) + } + if name == "" { + continue + } + if candidate.Type == "" { + if cfg, ok := cfgByLower[strings.ToLower(name)]; ok && cfg.DataType != nil && *cfg.DataType != "" { + candidate.Type = *cfg.DataType + } else if cfg, ok = cfgByLower[strings.ToLower(candidate.Name)]; ok && cfg.DataType != nil && *cfg.DataType != "" { + candidate.Type = *cfg.DataType + } + } + if candidate.Type == "" { + candidate.Type = "string" + } + result = append(result, candidate) + } + return result +} + func isAuxiliary(SQL *string) bool { if *SQL == "" { return false diff --git a/internal/inference/state.go b/internal/inference/state.go index 
e309bc01..fa28c984 100644 --- a/internal/inference/state.go +++ b/internal/inference/state.go @@ -491,7 +491,13 @@ func (s State) EnsureReflectTypes(modulePath string, pkg string, registry *xrefl if err != nil { rType, err = types.LookupType(typeRegistry.Lookup, dataType, xreflect.WithPackage(pkg)) if err != nil { - return err + rType = reflect.TypeOf((*interface{})(nil)).Elem() + if param.Schema.DataType == "" { + param.Schema.DataType = "interface{}" + } + if param.Schema.Package == "" { + param.Schema.Package = pkg + } } } param.Schema.SetType(rType) diff --git a/internal/translator/function.go b/internal/translator/function.go index 923250d5..38ddb6a4 100644 --- a/internal/translator/function.go +++ b/internal/translator/function.go @@ -124,7 +124,10 @@ func (v *Viewlet) applyExplicitCast(column *sqlparser.Column, funcArgs []string) column.Type = funcArgs[1] rType, err := types.LookupType(v.Resource.typeRegistry.Lookup, column.Type) if err != nil { - return false, fmt.Errorf("unknown column %v type: %s, %w", column.Name, column.Type, err) + // Keep unresolved custom cast as metadata only. This preserves declared type + // (e.g. *fee.Fee) for IR/yaml parity without forcing runtime type resolution. + // Built-in and resolvable types still set RawType. 
+ return true, nil } column.RawType = rType return true, nil diff --git a/internal/translator/output.go b/internal/translator/output.go index 6c562649..dcadf766 100644 --- a/internal/translator/output.go +++ b/internal/translator/output.go @@ -452,6 +452,12 @@ func (s *Service) ensureOutputParameters(resource *Resource, outputState inferen } func (s *Service) updateParameterWithComponentOutputType(dataParameter *state.Parameter, rootViewlet *Viewlet) { + if rootViewlet == nil || rootViewlet.View == nil || rootViewlet.Resource == nil || rootViewlet.Resource.rule == nil { + return + } + if rootViewlet.View.Schema == nil { + rootViewlet.View.Schema = &state.Schema{} + } typeName := rootViewlet.View.Schema.Name if typeName == "" || typeName == "string" { typeName = view.DefaultTypeName(rootViewlet.Name) diff --git a/internal/translator/resource.go b/internal/translator/resource.go index db7e9712..00630f33 100644 --- a/internal/translator/resource.go +++ b/internal/translator/resource.go @@ -733,7 +733,7 @@ func (r *Resource) updatedObject(loadType func(typeName string) (reflect.Type, e schema := parameter.OutputSchema() wType := schema.Type() if wType == nil { - return fmt.Errorf("failed to get parameter auxiliary type: %s, %w", parameter.Name, schema.Name) + return fmt.Errorf("failed to get parameter auxiliary type: %s, %s", parameter.Name, schema.Name) } auxiliaryState := inference.State{} if err := r.extractState(loadType, wType, &auxiliaryState); err != nil { diff --git a/internal/translator/rule.go b/internal/translator/rule.go index ec42c538..b2cd35c9 100644 --- a/internal/translator/rule.go +++ b/internal/translator/rule.go @@ -193,7 +193,7 @@ func (r *Resource) initRule(ctx context.Context, fs afs.Service, dSQL *string) e rule := r.Rule rule.applyDefaults() if err := r.loadData(ctx, fs, rule.ConstURL, &rule.Const); err != nil { - r.messages.AddWarning(r.rule.RuleName(), "const", fmt.Sprintf("failed to load constant : %v %w", rule.ConstURL, err)) + 
r.messages.AddWarning(r.rule.RuleName(), "const", fmt.Sprintf("failed to load constant : %v %v", rule.ConstURL, err)) } r.State.AppendConst(rule.Const) return r.loadDocumentation(ctx, fs, rule) diff --git a/internal/translator/viewlets.go b/internal/translator/viewlets.go index 72707387..ddeb349a 100644 --- a/internal/translator/viewlets.go +++ b/internal/translator/viewlets.go @@ -76,7 +76,7 @@ func (n *Viewlets) Init(ctx context.Context, aQuery *query.Select, resource *Res if err := n.Each(func(viewlet *Viewlet) error { n.ensureConnector(viewlet, rootConnector) if err := initFn(ctx, viewlet); err != nil { - return fmt.Errorf("failed to init viewlet: %ns, %w", viewlet.Name, err) + return fmt.Errorf("failed to init viewlet: %s, %w", viewlet.Name, err) } return nil }); err != nil { diff --git a/repository/component.go b/repository/component.go index ec106e47..179ff7a9 100644 --- a/repository/component.go +++ b/repository/component.go @@ -18,6 +18,7 @@ import ( content "github.com/viant/datly/repository/content" "github.com/viant/datly/repository/contract" "github.com/viant/datly/repository/handler" + "github.com/viant/datly/repository/shape/typectx" "github.com/viant/datly/repository/version" "github.com/viant/datly/service" "github.com/viant/datly/shared" @@ -47,6 +48,7 @@ type ( View *view.View `json:",omitempty"` NamespacedView *view.NamespacedView Handler *handler.Handler `json:",omitempty"` + TypeContext *typectx.Context `json:",omitempty" yaml:",omitempty"` indexedView view.NamedViews SourceURL string diff --git a/repository/components.go b/repository/components.go index 536ad329..a431095a 100644 --- a/repository/components.go +++ b/repository/components.go @@ -13,6 +13,13 @@ import ( "github.com/viant/datly/internal/inference" "github.com/viant/datly/internal/translator/parser" "github.com/viant/datly/repository/codegen" + "github.com/viant/datly/repository/shape" + shapecolumn "github.com/viant/datly/repository/shape/column" + dqlparse 
"github.com/viant/datly/repository/shape/dql/parse" + shapeLoad "github.com/viant/datly/repository/shape/load" + shapePlan "github.com/viant/datly/repository/shape/plan" + shapeScan "github.com/viant/datly/repository/shape/scan" + "github.com/viant/datly/repository/shape/typectx" "github.com/viant/datly/repository/version" "github.com/viant/datly/utils/types" "github.com/viant/datly/view" @@ -24,6 +31,7 @@ import ( "gopkg.in/yaml.v3" "path" "reflect" + "strings" ) type Components struct { @@ -61,6 +69,9 @@ func (c *Components) Init(ctx context.Context) error { options = append(options, &view.Metrics{Method: c.Components[0].Method, Service: c.options.metrics}) } for _, component := range c.Components { + if c.options != nil && c.options.legacyTypeContext { + component.TypeContext = resolveComponentTypeContext(component) + } if len(component.with) > 0 { c.With = append(c.With, component.with...) } @@ -80,6 +91,9 @@ func (c *Components) Init(ctx context.Context) error { } c.ensureNamedViewType(ctx, embedFs, aComponent) + if err = c.mergeShapeViews(ctx, aComponent); err != nil { + return err + } if err = c.Resource.Init(ctx, options...); err != nil { return err @@ -106,6 +120,62 @@ func (c *Components) Init(ctx context.Context) error { return nil } +func (c *Components) mergeShapeViews(ctx context.Context, aComponent *Component) error { + if c.options == nil || !c.options.shapePipeline || aComponent == nil || aComponent.Output.Type.Schema == nil { + return nil + } + rType := c.ReflectType(aComponent.Output.Type.Schema) + if rType == nil { + return nil + } + engine := shape.New( + shape.WithScanner(shapeScan.New()), + shape.WithPlanner(shapePlan.New()), + shape.WithLoader(shapeLoad.New()), + shape.WithName(aComponent.Path.URI), + ) + source := zeroValue(rType) + if source == nil { + return nil + } + artifacts, err := engine.LoadViews(ctx, source) + if err != nil { + return fmt.Errorf("failed to load shape views for %s: %w", aComponent.Path.URI, err) + } + if artifacts 
== nil || artifacts.Resource == nil { + return nil + } + if c.Resource.FSEmbedder == nil && artifacts.Resource.FSEmbedder != nil { + c.Resource.FSEmbedder = artifacts.Resource.FSEmbedder + } + existing := c.Resource.Views.Index() + columnDetector := shapecolumn.New() + for _, candidate := range artifacts.Views { + if candidate == nil { + continue + } + if _, err = existing.Lookup(candidate.Name); err == nil { + continue + } + if candidate.Columns, err = columnDetector.Resolve(ctx, c.Resource, candidate); err != nil { + return fmt.Errorf("failed to resolve shape columns for %s: %w", candidate.Name, err) + } + c.Resource.Views = append(c.Resource.Views, candidate) + existing.Register(candidate) + } + return nil +} + +func zeroValue(rType reflect.Type) interface{} { + if rType == nil { + return nil + } + if rType.Kind() == reflect.Ptr { + return reflect.New(rType.Elem()).Interface() + } + return reflect.New(rType).Interface() +} + func (c *Components) ensureNamedViewType(ctx context.Context, embedFs *embed.FS, aComponent *Component) { inCodeGeneration := codegen.IsGeneratorContext(ctx) if rType := c.ReflectType(c.Components[0].Output.Type.Schema); rType != nil && !inCodeGeneration { @@ -374,7 +444,7 @@ func LoadComponents(ctx context.Context, URL string, opts ...Option) (*Component } } } - components, err := unmarshalComponent(data) + components, err := unmarshalComponent(data, options.legacyTypeContext) if err != nil { return nil, err } @@ -396,17 +466,47 @@ func LoadComponents(ctx context.Context, URL string, opts ...Option) (*Component return components, nil } -func unmarshalComponent(data []byte) (*Components, error) { +// LoadComponentsFromMap loads components directly from in-memory route/resource model. +// The input map is expected to follow the same shape as route YAML after unmarshalling. 
+func LoadComponentsFromMap(ctx context.Context, model map[string]any, opts ...Option) (*Components, error) { + if len(model) == 0 { + return nil, fmt.Errorf("components model was empty") + } + options := NewOptions(opts) + components, err := unmarshalComponentMap(model, options.legacyTypeContext) + if err != nil { + return nil, err + } + components.options = options + components.resources = options.resources + if components.Resource == nil { + return nil, fmt.Errorf("resources were empty") + } + if err = components.mergeResources(ctx); err != nil { + return nil, err + } + components.Resource.SetTypes(options.extensions.Types) + return components, nil +} + +func unmarshalComponent(data []byte, enableLegacyTypeContext bool) (*Components, error) { aMap := map[string]interface{}{} if err := yaml.Unmarshal(data, &aMap); err != nil { return nil, err } + return unmarshalComponentMap(aMap, enableLegacyTypeContext) +} + +func unmarshalComponentMap(aMap map[string]any, enableLegacyTypeContext bool) (*Components, error) { ensureComponents(aMap) components := &Components{} err := toolbox.DefaultConverter.AssignConverted(components, aMap) if err != nil { return nil, err } + if enableLegacyTypeContext { + applyLegacyTypeContext(aMap, components) + } return components, err } @@ -415,3 +515,141 @@ func ensureComponents(aMap map[string]interface{}) { aMap["Components"] = aMap["Routes"] } } + +func applyLegacyTypeContext(source map[string]any, components *Components) { + if len(components.Components) == 0 { + return + } + defaultTypeContext := asTypeContext(source["TypeContext"]) + items := asAnySlice(source["Components"]) + for i, component := range components.Components { + if component == nil { + continue + } + if component.TypeContext != nil { + continue + } + var resolved *typectx.Context + if i < len(items) { + if itemMap := asStringMap(items[i]); itemMap != nil { + resolved = asTypeContext(itemMap["TypeContext"]) + } + } + if resolved == nil { + resolved = defaultTypeContext 
+ } + if resolved != nil { + component.TypeContext = cloneTypeContext(resolved) + } + } +} + +func asTypeContext(raw any) *typectx.Context { + mapped := asStringMap(raw) + if mapped == nil { + return nil + } + ret := &typectx.Context{ + DefaultPackage: asString(mapped["DefaultPackage"]), + } + for _, item := range asAnySlice(mapped["Imports"]) { + itemMap := asStringMap(item) + if itemMap == nil { + continue + } + pkg := asString(itemMap["Package"]) + if pkg == "" { + continue + } + ret.Imports = append(ret.Imports, typectx.Import{ + Alias: asString(itemMap["Alias"]), + Package: pkg, + }) + } + if ret.DefaultPackage == "" && len(ret.Imports) == 0 { + return nil + } + return ret +} + +func resolveComponentTypeContext(component *Component) *typectx.Context { + if component == nil { + return nil + } + if normalized := normalizeTypeContext(component.TypeContext); normalized != nil { + return normalized + } + if component.View == nil || component.View.Template == nil { + return nil + } + source := strings.TrimSpace(component.View.Template.Source) + if source == "" { + return nil + } + parsed, err := dqlparse.New().Parse(source) + if err != nil || parsed == nil { + return nil + } + return normalizeTypeContext(parsed.TypeContext) +} + +func normalizeTypeContext(input *typectx.Context) *typectx.Context { + if input == nil { + return nil + } + ret := &typectx.Context{ + DefaultPackage: strings.TrimSpace(input.DefaultPackage), + } + for _, item := range input.Imports { + pkg := strings.TrimSpace(item.Package) + if pkg == "" { + continue + } + ret.Imports = append(ret.Imports, typectx.Import{ + Alias: strings.TrimSpace(item.Alias), + Package: pkg, + }) + } + if ret.DefaultPackage == "" && len(ret.Imports) == 0 { + return nil + } + return ret +} + +func cloneTypeContext(input *typectx.Context) *typectx.Context { + return normalizeTypeContext(input) +} + +func asAnySlice(raw any) []any { + switch actual := raw.(type) { + case []any: + return actual + default: + return nil + } 
+} + +func asStringMap(raw any) map[string]any { + switch actual := raw.(type) { + case map[string]any: + return actual + case map[interface{}]interface{}: + result := make(map[string]any, len(actual)) + for k, v := range actual { + result[fmt.Sprint(k)] = v + } + return result + default: + return nil + } +} + +func asString(raw any) string { + if raw == nil { + return "" + } + if value, ok := raw.(string); ok { + return value + } + return fmt.Sprint(raw) +} diff --git a/repository/option.go b/repository/option.go index c660a748..9c2b9b34 100644 --- a/repository/option.go +++ b/repository/option.go @@ -43,6 +43,8 @@ type Options struct { constants map[string]string substitutes map[string]view.Substitutes authConfig aconfig.Config + shapePipeline bool + legacyTypeContext bool } func (o *Options) UseColumn() bool { @@ -242,6 +244,23 @@ func WithPath(aPath *path.Path) Option { } } +// WithShapePipeline enables the repository/shape scan->plan->load pipeline +// during components initialization. +// The default is false to preserve existing behavior. +func WithShapePipeline(enabled bool) Option { + return func(o *Options) { + o.shapePipeline = enabled + } +} + +// WithLegacyTypeContext enables TypeContext enrichment in legacy repository runtime. +// Disabled by default for rollback safety. +func WithLegacyTypeContext(enabled bool) Option { + return func(o *Options) { + o.legacyTypeContext = enabled + } +} + func WithJWTSigner(aSigner *signer.Config) Option { return func(o *Options) { o.authConfig.JwtSigner = aSigner diff --git a/repository/shape/README.md b/repository/shape/README.md new file mode 100644 index 00000000..793b0404 --- /dev/null +++ b/repository/shape/README.md @@ -0,0 +1,61 @@ +# repository/shape + +`repository/shape` provides a dynamic, in-memory pipeline for building Datly runtime artifacts from either: + +- Go structs (`scan -> plan -> load`) +- DQL (`compile -> load`) + +without generating YAML route/resource files. 
+ +## Packages + +- `shape/scan`: discovers view/state tags from struct fields (Embedder-aware). +- `shape/plan`: normalizes scan output into a deterministic shape plan. +- `shape/load`: materializes `view.Resource`, `view.View`, and a runtime-neutral component artifact. +- `shape/compile`: compiles DQL into a shape plan for dynamic loading. + +## Facade API + +Use `shape.Engine` or package helpers: + +- `shape.LoadViews(ctx, src, opts...)` +- `shape.LoadComponent(ctx, src, opts...)` +- `shape.LoadDQLViews(ctx, dql, opts...)` +- `shape.LoadDQLComponent(ctx, dql, opts...)` + +## Minimal Struct Flow + +```go +engine := shape.New( + shape.WithScanner(scan.New()), + shape.WithPlanner(plan.New()), + shape.WithLoader(load.New()), + shape.WithName("/v1/api/report"), +) + +views, err := engine.LoadViews(ctx, &MyOutput{}) +``` + +## Minimal DQL Flow + +```go +engine := shape.New( + shape.WithCompiler(compile.New()), + shape.WithLoader(load.New()), + shape.WithName("/v1/api/report"), +) + +component, err := engine.LoadDQLComponent(ctx, "SELECT id FROM ORDERS t") +``` + +## Repository Integration + +`repository/components.go` can optionally merge views generated by the shape pipeline during init. + +Enable via: + +```go +repository.WithShapePipeline(true) +``` + +Default is disabled to preserve existing behavior. diff --git a/repository/shape/column/detector.go b/repository/shape/column/detector.go new file mode 100644 index 00000000..79b6c8d1 --- /dev/null +++ b/repository/shape/column/detector.go @@ -0,0 +1,237 @@ +package column + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/viant/datly/view" + viewcolumn "github.com/viant/datly/view/column" + "github.com/viant/sqlparser" + "github.com/viant/sqlx/io" +) + +// Detector resolves columns for shape-generated views. 
+// +// Rules: +// - schema field order is canonical order +// - wildcard SQL always performs DB discovery +// - newly discovered columns are appended at the end +// - matched columns keep schema order but refresh metadata from DB +type Detector struct{} + +func New() *Detector { + return &Detector{} +} + +func (d *Detector) Resolve(ctx context.Context, resource *view.Resource, aView *view.View) (view.Columns, error) { + if aView == nil { + return nil, fmt.Errorf("shape column detector: nil view") + } + + base := columnsFromSchema(aView) + if !usesWildcard(aView) { + return base, nil + } + + discovered, err := d.detect(ctx, resource, aView) + if err != nil { + return nil, err + } + if len(base) == 0 { + return discovered, nil + } + return mergePreservingOrder(base, discovered), nil +} + +func (d *Detector) detect(ctx context.Context, resource *view.Resource, aView *view.View) (view.Columns, error) { + connector, err := lookupConnector(ctx, resource, aView) + if err != nil { + return nil, err + } + db, err := connector.DB() + if err != nil { + return nil, fmt.Errorf("shape column detector: failed to open db for view %s: %w", aView.Name, err) + } + query := sourceSQL(aView) + sqlColumns, err := viewcolumn.Discover(ctx, db, aView.Table, query) + if err != nil { + return nil, fmt.Errorf("shape column detector: discover failed for view %s: %w", aView.Name, err) + } + return view.NewColumns(sqlColumns, aView.ColumnsConfig), nil +} + +func lookupConnector(ctx context.Context, resource *view.Resource, aView *view.View) (*view.Connector, error) { + if resource == nil { + return nil, fmt.Errorf("shape column detector: missing resource for view %s", aView.Name) + } + if aView.Connector == nil { + return nil, fmt.Errorf("shape column detector: missing connector for wildcard view %s", aView.Name) + } + connectors := view.ConnectorSlice(resource.Connectors).Index() + connector := aView.Connector + if connector.Ref != "" { + lookup, err := connectors.Lookup(connector.Ref) + if 
err != nil { + return nil, fmt.Errorf("shape column detector: connector ref %s for view %s: %w", connector.Ref, aView.Name, err) + } + connector = lookup + } + if err := connector.Init(ctx, connectors); err != nil { + return nil, fmt.Errorf("shape column detector: connector init for view %s: %w", aView.Name, err) + } + return connector, nil +} + +func sourceSQL(aView *view.View) string { + if aView.Template != nil && strings.TrimSpace(aView.Template.Source) != "" { + return aView.Template.Source + } + return aView.Source() +} + +func usesWildcard(aView *view.View) bool { + if aView != nil && aView.Template == nil && strings.TrimSpace(aView.Table) != "" { + return true + } + query := sourceSQL(aView) + trimmed := strings.TrimSpace(strings.ToLower(query)) + if trimmed == "" { + return false + } + if !strings.Contains(trimmed, "*") { + return false + } + if !strings.HasPrefix(trimmed, "select") && !strings.HasPrefix(trimmed, "with") { + return true + } + parsed, err := sqlparser.ParseQuery(query) + if err != nil { + return true + } + return sqlparser.NewColumns(parsed.List).IsStarExpr() +} + +func columnsFromSchema(aView *view.View) view.Columns { + if aView == nil || aView.Schema == nil { + return nil + } + rType := aView.Schema.Type() + if rType == nil { + return nil + } + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return nil + } + result := make(view.Columns, 0, rType.NumField()) + appendSchemaColumns(rType, "", &result) + return result +} + +func appendSchemaColumns(rType reflect.Type, ns string, columns *view.Columns) { + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + if field.PkgPath != "" { // unexported + continue + } + if field.Anonymous { + inner := field.Type + for inner.Kind() == reflect.Ptr { + inner = inner.Elem() + } + if inner.Kind() == reflect.Struct { + appendSchemaColumns(inner, ns, columns) + } + continue + } + + tag := 
io.ParseTag(field.Tag) + if tag != nil && tag.Transient { + continue + } + + name := field.Name + if tag != nil && tag.Column != "" { + name = tag.Column + } + if tag != nil && tag.Ns != "" { + name = tag.Ns + name + } else if ns != "" { + name = ns + name + } + + columnType := field.Type + nullable := false + if columnType.Kind() == reflect.Ptr { + nullable = true + columnType = columnType.Elem() + } + *columns = append(*columns, view.NewColumn(name, columnType.String(), columnType, nullable, view.WithColumnTag(string(field.Tag)))) + } +} + +func mergePreservingOrder(base, discovered view.Columns) view.Columns { + if len(base) == 0 { + return discovered + } + if len(discovered) == 0 { + return base + } + seen := map[string]*view.Column{} + for _, item := range discovered { + if item == nil { + continue + } + seen[strings.ToLower(item.Name)] = item + } + result := make(view.Columns, 0, len(base)+len(discovered)) + for _, item := range base { + if item == nil { + continue + } + if fresh, ok := seen[strings.ToLower(item.Name)]; ok { + delete(seen, strings.ToLower(item.Name)) + // Keep schema name/order but refresh discovered metadata. 
+ item.DataType = firstNonEmpty(fresh.DataType, item.DataType) + item.SetColumnType(firstType(fresh.ColumnType(), item.ColumnType())) + item.Nullable = fresh.Nullable + if item.DatabaseColumn == "" { + item.DatabaseColumn = fresh.DatabaseColumn + } + } + result = append(result, item) + } + for _, item := range discovered { + if item == nil { + continue + } + if _, ok := seen[strings.ToLower(item.Name)]; !ok { + continue + } + result = append(result, item) + delete(seen, strings.ToLower(item.Name)) + } + return result +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} + +func firstType(values ...reflect.Type) reflect.Type { + for _, value := range values { + if value != nil { + return value + } + } + return nil +} diff --git a/repository/shape/column/detector_test.go b/repository/shape/column/detector_test.go new file mode 100644 index 00000000..cfc834b1 --- /dev/null +++ b/repository/shape/column/detector_test.go @@ -0,0 +1,59 @@ +package column + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +type sampleOrder struct { + VendorID int `sqlx:"name=VENDOR_ID"` + Name string `sqlx:"name=NAME"` +} + +func TestUsesWildcard(t *testing.T) { + tests := []struct { + name string + view *view.View + want bool + }{ + {name: "select wildcard", view: &view.View{Template: view.NewTemplate("SELECT * FROM VENDOR")}, want: true}, + {name: "select explicit", view: &view.View{Template: view.NewTemplate("SELECT ID, NAME FROM VENDOR")}, want: false}, + {name: "table only", view: &view.View{Table: "VENDOR"}, want: true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + require.Equal(t, tc.want, usesWildcard(tc.view)) + }) + } +} + +func TestColumnsFromSchema_Order(t *testing.T) { + aView := &view.View{Schema: 
state.NewSchema(reflect.TypeOf(sampleOrder{}), state.WithMany())} + cols := columnsFromSchema(aView) + require.Len(t, cols, 2) + require.Equal(t, "VENDOR_ID", cols[0].Name) + require.Equal(t, "NAME", cols[1].Name) +} + +func TestMergePreservingOrder_AppendsNewDetectedColumns(t *testing.T) { + base := view.Columns{ + view.NewColumn("VENDOR_ID", "int", reflect.TypeOf(int(0)), false), + view.NewColumn("NAME", "varchar", reflect.TypeOf(""), false), + } + detected := view.Columns{ + view.NewColumn("NAME", "text", reflect.TypeOf(""), true), + view.NewColumn("VENDOR_ID", "bigint", reflect.TypeOf(int64(0)), false), + view.NewColumn("STATUS", "int", reflect.TypeOf(int(0)), true), + } + merged := mergePreservingOrder(base, detected) + require.Len(t, merged, 3) + require.Equal(t, "VENDOR_ID", merged[0].Name) + require.Equal(t, "NAME", merged[1].Name) + require.Equal(t, "STATUS", merged[2].Name) + require.Equal(t, "bigint", merged[0].DataType) + require.Equal(t, "text", merged[1].DataType) +} diff --git a/repository/shape/compile/compiler.go b/repository/shape/compile/compiler.go new file mode 100644 index 00000000..69647b60 --- /dev/null +++ b/repository/shape/compile/compiler.go @@ -0,0 +1,110 @@ +package compile + +import ( + "context" + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/viant/datly/internal/translator/parser" + "github.com/viant/datly/repository/shape" + dqlparse "github.com/viant/datly/repository/shape/dql/parse" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/sqlparser" +) + +// DQLCompiler compiles raw DQL into a shape plan that can be materialized by shape/load. +type DQLCompiler struct{} + +// New returns a DQL compiler implementation. +func New() *DQLCompiler { + return &DQLCompiler{} +} + +// Compile implements shape.DQLCompiler. 
+func (c *DQLCompiler) Compile(_ context.Context, source *shape.Source, _ ...shape.CompileOption) (*shape.PlanResult, error) { + if source == nil { + return nil, shape.ErrNilSource + } + dql := strings.TrimSpace(source.DQL) + if dql == "" { + return nil, shape.ErrNilDQL + } + + name, table, err := inferRoot(dql, source.Name) + if err != nil { + return nil, err + } + + result := &plan.Result{ + Views: []*plan.View{ + { + Path: name, + Holder: name, + Name: name, + Table: table, + SQL: dql, + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + if parsed, parseErr := dqlparse.New().Parse(dql); parseErr == nil && parsed != nil && parsed.TypeContext != nil { + result.TypeContext = parsed.TypeContext + } + result.ViewsByName[name] = result.Views[0] + return &shape.PlanResult{Source: source, Plan: result}, nil +} + +func inferRoot(dql string, fallback string) (string, string, error) { + query, err := sqlparser.ParseQuery(dql, parser.OnVeltyExpression()) + if err != nil { + name := sanitizeName(fallback) + if name == "" { + name = "DQLView" + } + return name, "", nil + } + + name := sanitizeName(query.From.Alias) + if name == "" { + name = sanitizeName(fallback) + } + if name == "" { + name = "DQLView" + } + + table := "" + if query != nil && query.From.X != nil { + table = strings.TrimSpace(sqlparser.Stringify(query.From.X)) + } + if table == "" || strings.HasPrefix(table, "(") { + table = name + } + if name == "" { + return "", "", fmt.Errorf("shape compile: failed to infer view name") + } + return name, table, nil +} + +var nonWord = regexp.MustCompile(`[^a-zA-Z0-9_]+`) + +func sanitizeName(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + value = nonWord.ReplaceAllString(value, "_") + value = strings.Trim(value, "_") + if value == "" { + return "" + } + 
if value[0] >= '0' && value[0] <= '9' { + value = "V_" + value + } + return value +} diff --git a/repository/shape/compile/compiler_test.go b/repository/shape/compile/compiler_test.go new file mode 100644 index 00000000..b539ab80 --- /dev/null +++ b/repository/shape/compile/compiler_test.go @@ -0,0 +1,69 @@ +package compile + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" +) + +func TestDQLCompiler_Compile(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "SELECT id FROM ORDERS t"}) + require.NoError(t, err) + require.NotNil(t, res) + + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.Len(t, planned.Views, 1) + view := planned.Views[0] + assert.Equal(t, "t", view.Name) + assert.Equal(t, "ORDERS", view.Table) + assert.Equal(t, "many", view.Cardinality) +} + +func TestDQLCompiler_Compile_EmptyDQL(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "x"}) + require.Error(t, err) + assert.ErrorIs(t, err, shape.ErrNilDQL) +} + +func TestDQLCompiler_Compile_WithPreamble_NoPanic(t *testing.T) { + compiler := New() + dql := ` +/* metadata */ +#set($_ = $A(query/a).Optional()) +SELECT id +` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "sample_report", DQL: dql}) + require.NoError(t, err) + require.NotNil(t, res) + + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.Len(t, planned.Views, 1) + assert.Equal(t, "sample_report", planned.Views[0].Name) + assert.Equal(t, "sample_report", planned.Views[0].Table) +} + +func TestDQLCompiler_Compile_PropagatesTypeContext(t *testing.T) { + compiler := New() + dql := ` +#set($_ = $package('mdp/performance')) +#set($_ = $import('perf', 'github.com/acme/mdp/performance')) 
+SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + require.NotNil(t, res) + + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotNil(t, planned.TypeContext) + assert.Equal(t, "mdp/performance", planned.TypeContext.DefaultPackage) + require.Len(t, planned.TypeContext.Imports, 1) + assert.Equal(t, "perf", planned.TypeContext.Imports[0].Alias) +} diff --git a/repository/shape/compile/doc.go b/repository/shape/compile/doc.go new file mode 100644 index 00000000..c5a996ba --- /dev/null +++ b/repository/shape/compile/doc.go @@ -0,0 +1,2 @@ +// Package compile provides DQL-to-shape compilation. +package compile diff --git a/repository/shape/doc.go b/repository/shape/doc.go new file mode 100644 index 00000000..730ab139 --- /dev/null +++ b/repository/shape/doc.go @@ -0,0 +1,3 @@ +// Package shape provides building blocks for dynamic repository loading from +// struct and DQL sources without requiring persisted YAML artifacts. 
+package shape diff --git a/repository/shape/dql_engine_test.go b/repository/shape/dql_engine_test.go new file mode 100644 index 00000000..fafe3f67 --- /dev/null +++ b/repository/shape/dql_engine_test.go @@ -0,0 +1,42 @@ +package shape_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + shape "github.com/viant/datly/repository/shape" + shapeCompile "github.com/viant/datly/repository/shape/compile" + shapeLoad "github.com/viant/datly/repository/shape/load" +) + +func TestEngine_LoadDQLViews(t *testing.T) { + engine := shape.New( + shape.WithCompiler(shapeCompile.New()), + shape.WithLoader(shapeLoad.New()), + shape.WithName("/v1/api/reports/orders"), + ) + artifacts, err := engine.LoadDQLViews(context.Background(), "SELECT id FROM ORDERS t") + require.NoError(t, err) + require.NotNil(t, artifacts) + require.Len(t, artifacts.Views, 1) + assert.Equal(t, "t", artifacts.Views[0].Name) +} + +func TestEngine_LoadDQLComponent(t *testing.T) { + engine := shape.New( + shape.WithCompiler(shapeCompile.New()), + shape.WithLoader(shapeLoad.New()), + shape.WithName("/v1/api/reports/orders"), + ) + artifact, err := engine.LoadDQLComponent(context.Background(), "SELECT id FROM ORDERS t") + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Component) + + component, ok := artifact.Component.(*shapeLoad.Component) + require.True(t, ok) + assert.Equal(t, "/v1/api/reports/orders", component.Name) + assert.Equal(t, "t", component.RootView) +} diff --git a/repository/shape/errors.go b/repository/shape/errors.go new file mode 100644 index 00000000..852313b6 --- /dev/null +++ b/repository/shape/errors.go @@ -0,0 +1,12 @@ +package shape + +import "errors" + +var ( + ErrNilSource = errors.New("shape: source was nil") + ErrNilDQL = errors.New("shape: dql was empty") + ErrScannerNotConfigured = errors.New("shape: scanner was not configured") + ErrPlannerNotConfigured = errors.New("shape: 
planner was not configured") + ErrLoaderNotConfigured = errors.New("shape: loader was not configured") + ErrCompilerNotConfigured = errors.New("shape: compiler was not configured") +) diff --git a/repository/shape/load/doc.go b/repository/shape/load/doc.go new file mode 100644 index 00000000..1800597c --- /dev/null +++ b/repository/shape/load/doc.go @@ -0,0 +1,2 @@ +// Package load defines materialization responsibilities for runtime artifacts. +package load diff --git a/repository/shape/load/errors.go b/repository/shape/load/errors.go new file mode 100644 index 00000000..51f15d6a --- /dev/null +++ b/repository/shape/load/errors.go @@ -0,0 +1,7 @@ +package load + +import "errors" + +var ( + ErrEmptyViewPlan = errors.New("shape load: no views available in plan") +) diff --git a/repository/shape/load/loader.go b/repository/shape/load/loader.go new file mode 100644 index 00000000..149117d2 --- /dev/null +++ b/repository/shape/load/loader.go @@ -0,0 +1,224 @@ +package load + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + shapevalidate "github.com/viant/datly/repository/shape/validate" + "github.com/viant/datly/shared" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +// Loader materializes runtime view artifacts from normalized shape plan. +type Loader struct{} + +// New returns shape loader implementation. +func New() *Loader { + return &Loader{} +} + +// LoadViews implements shape.Loader. 
+func (l *Loader) LoadViews(_ context.Context, planned *shape.PlanResult, _ ...shape.LoadOption) (*shape.ViewArtifacts, error) { + pResult, resource, err := l.materialize(planned) + if err != nil { + return nil, err + } + if len(pResult.Views) == 0 { + return nil, ErrEmptyViewPlan + } + return &shape.ViewArtifacts{Resource: resource, Views: resource.Views}, nil +} + +// LoadComponent implements shape.Loader. +func (l *Loader) LoadComponent(_ context.Context, planned *shape.PlanResult, _ ...shape.LoadOption) (*shape.ComponentArtifact, error) { + pResult, resource, err := l.materialize(planned) + if err != nil { + return nil, err + } + if len(pResult.Views) == 0 { + return nil, ErrEmptyViewPlan + } + component := buildComponent(planned.Source, pResult) + return &shape.ComponentArtifact{ + Resource: resource, + Component: component, + }, nil +} + +func (l *Loader) materialize(planned *shape.PlanResult) (*plan.Result, *view.Resource, error) { + if planned == nil || planned.Source == nil { + return nil, nil, shape.ErrNilSource + } + pResult, ok := planned.Plan.(*plan.Result) + if !ok || pResult == nil { + return nil, nil, fmt.Errorf("shape load: unsupported plan type %T", planned.Plan) + } + resource := view.EmptyResource() + if pResult.EmbedFS != nil { + resource.SetFSEmbedder(state.NewFSEmbedder(pResult.EmbedFS)) + } + for _, item := range pResult.Views { + aView, err := materializeView(item) + if err != nil { + return nil, nil, err + } + resource.AddViews(aView) + } + if err := shapevalidate.ValidateRelations(resource, resource.Views...); err != nil { + return nil, nil, err + } + return pResult, resource, nil +} + +func buildComponent(source *shape.Source, pResult *plan.Result) *Component { + ret := &Component{Method: "GET"} + if source != nil { + ret.Name = source.Name + ret.URI = source.Name + } + for _, aView := range pResult.Views { + if aView == nil { + continue + } + ret.Views = append(ret.Views, aView.Name) + } + rootView := pickRootView(pResult.Views) + if 
rootView != nil { + ret.RootView = rootView.Name + if ret.Name == "" { + ret.Name = rootView.Name + } + } + for _, item := range pResult.States { + if item == nil { + continue + } + if strings.TrimSpace(item.Kind) == "" && strings.TrimSpace(item.In) == "" { + ret.Other = append(ret.Other, item) + continue + } + switch strings.ToLower(item.Kind) { + case "query", "path", "header", "body", "form", "cookie", "request", "": + ret.Input = append(ret.Input, item) + case "output": + ret.Output = append(ret.Output, item) + case "meta": + ret.Meta = append(ret.Meta, item) + case "async": + ret.Async = append(ret.Async, item) + default: + ret.Other = append(ret.Other, item) + } + } + ret.TypeContext = cloneTypeContext(pResult.TypeContext) + return ret +} + +func cloneTypeContext(input *typectx.Context) *typectx.Context { + if input == nil { + return nil + } + ret := &typectx.Context{ + DefaultPackage: strings.TrimSpace(input.DefaultPackage), + } + for _, item := range input.Imports { + pkg := strings.TrimSpace(item.Package) + if pkg == "" { + continue + } + ret.Imports = append(ret.Imports, typectx.Import{ + Alias: strings.TrimSpace(item.Alias), + Package: pkg, + }) + } + if ret.DefaultPackage == "" && len(ret.Imports) == 0 { + return nil + } + return ret +} + +func pickRootView(views []*plan.View) *plan.View { + var selected *plan.View + minDepth := -1 + for _, candidate := range views { + if candidate == nil || candidate.Path == "" { + continue + } + depth := strings.Count(candidate.Path, ".") + if minDepth == -1 || depth < minDepth { + minDepth = depth + selected = candidate + } + } + if selected != nil { + return selected + } + for _, candidate := range views { + if candidate != nil { + return candidate + } + } + return nil +} + +func materializeView(item *plan.View) (*view.View, error) { + if item == nil { + return nil, fmt.Errorf("shape load: nil view plan item") + } + + schemaType := bestSchemaType(item) + if schemaType == nil { + return nil, fmt.Errorf("shape load: 
missing schema type for view %q", item.Name) + } + + schema := newSchema(schemaType, item.Cardinality) + opts := []view.Option{view.WithSchema(schema), view.WithMode(view.ModeQuery)} + + if item.Connector != "" { + opts = append(opts, view.WithConnectorRef(item.Connector)) + } + if item.SQL != "" || item.SQLURI != "" { + tmpl := view.NewTemplate(item.SQL) + tmpl.SourceURL = item.SQLURI + opts = append(opts, view.WithTemplate(tmpl)) + } + if item.CacheRef != "" { + opts = append(opts, view.WithCache(&view.Cache{Reference: shared.Reference{Ref: item.CacheRef}})) + } + if item.Partitioner != "" { + opts = append(opts, view.WithPartitioned(&view.Partitioned{ + DataType: item.Partitioner, + Concurrency: item.PartitionedConcurrency, + })) + } + + aView, err := view.New(item.Name, item.Table, opts...) + if err != nil { + return nil, err + } + aView.Ref = item.Ref + return aView, nil +} + +func bestSchemaType(item *plan.View) reflect.Type { + if item.FieldType != nil { + return item.FieldType + } + if item.ElementType != nil { + return item.ElementType + } + return nil +} + +func newSchema(rType reflect.Type, cardinality string) *state.Schema { + if cardinality == "many" && rType.Kind() != reflect.Slice { + return state.NewSchema(rType, state.WithMany()) + } + return state.NewSchema(rType) +} diff --git a/repository/shape/load/loader_test.go b/repository/shape/load/loader_test.go new file mode 100644 index 00000000..aab074ba --- /dev/null +++ b/repository/shape/load/loader_test.go @@ -0,0 +1,116 @@ +package load + +import ( + "context" + "embed" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/scan" + "github.com/viant/datly/repository/shape/typectx" +) + +//go:embed testdata/*.sql +var testFS embed.FS + +type embeddedFS struct{} + +func (embeddedFS) EmbedFS() *embed.FS { + return &testFS +} + 
+type reportRow struct { + ID int + Name string +} + +type reportSource struct { + embeddedFS + Rows []reportRow `view:"rows,table=REPORT,connector=dev,cache=c1" sql:"uri=testdata/report.sql"` + ID int `parameter:"id,kind=query,in=id"` + Status any `parameter:"status,kind=output,in=status"` + Job any `parameter:"job,kind=async,in=job"` + Meta any `parameter:"meta,kind=meta,in=view.name"` +} + +func TestLoader_LoadViews(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &reportSource{}}) + require.NoError(t, err) + + planner := plan.New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + + loader := New() + artifacts, err := loader.LoadViews(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + require.NotNil(t, artifacts.Resource) + require.Len(t, artifacts.Views, 1) + + aView := artifacts.Views[0] + assert.Equal(t, "rows", aView.Name) + assert.Equal(t, "REPORT", aView.Table) + require.NotNil(t, aView.Schema) + assert.Equal(t, "Many", string(aView.Schema.Cardinality)) + require.NotNil(t, aView.Template) + assert.Equal(t, "testdata/report.sql", aView.Template.SourceURL) + assert.Contains(t, aView.Template.Source, "SELECT ID, NAME FROM REPORT") + require.NotNil(t, aView.Connector) + assert.Equal(t, "dev", aView.Connector.Ref) + require.NotNil(t, aView.Cache) + assert.Equal(t, "c1", aView.Cache.Ref) + require.NotNil(t, artifacts.Resource.EmbedFS()) +} + +func TestLoader_LoadViews_InvalidPlanType(t *testing.T) { + loader := New() + _, err := loader.LoadViews(context.Background(), &shape.PlanResult{Source: &shape.Source{Name: "x"}, Plan: "invalid"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported plan type") +} + +func TestLoader_LoadComponent(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Name: "/v1/api/report", Struct: &reportSource{}}) + 
require.NoError(t, err) + + planner := plan.New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + actualPlan, ok := planned.Plan.(*plan.Result) + require.True(t, ok) + actualPlan.TypeContext = &typectx.Context{ + DefaultPackage: "mdp/performance", + Imports: []typectx.Import{ + {Alias: "perf", Package: "github.com/acme/mdp/performance"}, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifact) + require.NotNil(t, artifact.Resource) + require.NotNil(t, artifact.Component) + + component, ok := artifact.Component.(*Component) + require.True(t, ok) + assert.Equal(t, "/v1/api/report", component.Name) + assert.Equal(t, "/v1/api/report", component.URI) + assert.Equal(t, "GET", component.Method) + assert.Equal(t, "rows", component.RootView) + assert.Equal(t, []string{"rows"}, component.Views) + assert.Len(t, component.Input, 1) + assert.Len(t, component.Output, 1) + assert.Len(t, component.Async, 1) + assert.Len(t, component.Meta, 1) + require.NotNil(t, component.TypeContext) + assert.Equal(t, "mdp/performance", component.TypeContext.DefaultPackage) + require.Len(t, component.TypeContext.Imports, 1) + assert.Equal(t, "perf", component.TypeContext.Imports[0].Alias) +} diff --git a/repository/shape/load/model.go b/repository/shape/load/model.go new file mode 100644 index 00000000..8f5d384d --- /dev/null +++ b/repository/shape/load/model.go @@ -0,0 +1,21 @@ +package load + +import "github.com/viant/datly/repository/shape/plan" +import "github.com/viant/datly/repository/shape/typectx" + +// Component is a shape-loaded runtime-neutral component artifact. +// It intentionally avoids repository package coupling to keep shape/load reusable. 
+type Component struct { + Name string + URI string + Method string + RootView string + Views []string + TypeContext *typectx.Context + + Input []*plan.State + Output []*plan.State + Meta []*plan.State + Async []*plan.State + Other []*plan.State +} diff --git a/repository/shape/load/testdata/report.sql b/repository/shape/load/testdata/report.sql new file mode 100644 index 00000000..68f0f3b3 --- /dev/null +++ b/repository/shape/load/testdata/report.sql @@ -0,0 +1 @@ +SELECT ID, NAME FROM REPORT diff --git a/repository/shape/model.go b/repository/shape/model.go new file mode 100644 index 00000000..f71fd5c2 --- /dev/null +++ b/repository/shape/model.go @@ -0,0 +1,53 @@ +package shape + +import ( + "reflect" + + "github.com/viant/datly/view" + "github.com/viant/x" +) + +// Mode controls which execution flow is expected from the shape pipeline. +type Mode string + +const ( + ModeUnspecified Mode = "" + ModeStruct Mode = "struct" + ModeDQL Mode = "dql" +) + +// Source represents the caller-provided shape source. +type Source struct { + Name string + Struct any + Type reflect.Type + TypeName string + TypeRegistry *x.Registry + DQL string +} + +// ScanResult is the output produced by Scanner. +type ScanResult struct { + Source *Source + Descriptors any +} + +// PlanResult is the output produced by Planner. +type PlanResult struct { + Source *Source + Plan any +} + +// ViewArtifacts is the runtime view payload produced by Loader. +type ViewArtifacts struct { + Resource *view.Resource + Views view.Views +} + +// ComponentArtifact is the runtime component payload produced by Loader. +// Component stays untyped in the skeleton to avoid coupling shape package +// to repository internals before the implementation phase. 
+type ComponentArtifact struct { + Resource *view.Resource + Component any +} diff --git a/repository/shape/options.go b/repository/shape/options.go new file mode 100644 index 00000000..05b0a774 --- /dev/null +++ b/repository/shape/options.go @@ -0,0 +1,73 @@ +package shape + +// Options stores shape facade dependencies and behavior flags. +type Options struct { + Mode Mode + Strict bool + Name string + Scanner Scanner + Planner Planner + Loader Loader + Compiler DQLCompiler + Runtime RuntimeRegistrar +} + +// Option mutates Options. +type Option func(*Options) + +// NewOptions builds Options from varargs. +func NewOptions(opts ...Option) *Options { + ret := &Options{} + for _, opt := range opts { + opt(ret) + } + return ret +} + +func WithMode(mode Mode) Option { + return func(o *Options) { + o.Mode = mode + } +} + +func WithStrict(strict bool) Option { + return func(o *Options) { + o.Strict = strict + } +} + +func WithName(name string) Option { + return func(o *Options) { + o.Name = name + } +} + +func WithScanner(scanner Scanner) Option { + return func(o *Options) { + o.Scanner = scanner + } +} + +func WithPlanner(planner Planner) Option { + return func(o *Options) { + o.Planner = planner + } +} + +func WithLoader(loader Loader) Option { + return func(o *Options) { + o.Loader = loader + } +} + +func WithCompiler(compiler DQLCompiler) Option { + return func(o *Options) { + o.Compiler = compiler + } +} + +func WithRuntime(runtime RuntimeRegistrar) Option { + return func(o *Options) { + o.Runtime = runtime + } +} diff --git a/repository/shape/parity_test.go b/repository/shape/parity_test.go new file mode 100644 index 00000000..713bfd31 --- /dev/null +++ b/repository/shape/parity_test.go @@ -0,0 +1,67 @@ +package shape_test + +import ( + "context" + "embed" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + shape "github.com/viant/datly/repository/shape" + shapeLoad 
"github.com/viant/datly/repository/shape/load" + shapePlan "github.com/viant/datly/repository/shape/plan" + shapeScan "github.com/viant/datly/repository/shape/scan" +) + +//go:embed scan/testdata/*.sql +var parityFS embed.FS + +type parityEmbedded struct{} + +func (parityEmbedded) EmbedFS() *embed.FS { return &parityFS } + +type parityRow struct { + ID int + Name string +} + +type paritySource struct { + parityEmbedded + Rows []parityRow `view:"rows,table=REPORT,connector=dev" sql:"uri=scan/testdata/report.sql"` +} + +func TestEngineParity_StructPipeline(t *testing.T) { + source := &paritySource{} + scanner := shapeScan.New() + planner := shapePlan.New() + loader := shapeLoad.New() + + manualScan, err := scanner.Scan(context.Background(), &shape.Source{Name: "/v1/api/parity", Struct: source}) + require.NoError(t, err) + manualPlan, err := planner.Plan(context.Background(), manualScan) + require.NoError(t, err) + manualViews, err := loader.LoadViews(context.Background(), manualPlan) + require.NoError(t, err) + + engine := shape.New( + shape.WithName("/v1/api/parity"), + shape.WithScanner(scanner), + shape.WithPlanner(planner), + shape.WithLoader(loader), + ) + engineViews, err := engine.LoadViews(context.Background(), source) + require.NoError(t, err) + + require.Len(t, manualViews.Views, 1) + require.Len(t, engineViews.Views, 1) + + mv := manualViews.Views[0] + ev := engineViews.Views[0] + assert.Equal(t, mv.Name, ev.Name) + assert.Equal(t, mv.Table, ev.Table) + assert.Equal(t, mv.Template.Source, ev.Template.Source) + assert.Equal(t, mv.Template.SourceURL, ev.Template.SourceURL) + assert.Equal(t, mv.Schema.Cardinality, ev.Schema.Cardinality) + assert.Equal(t, reflect.TypeOf(mv.Schema.CompType()), reflect.TypeOf(ev.Schema.CompType())) +} diff --git a/repository/shape/plan/doc.go b/repository/shape/plan/doc.go new file mode 100644 index 00000000..57bb65fa --- /dev/null +++ b/repository/shape/plan/doc.go @@ -0,0 +1,2 @@ +// Package plan defines normalization and 
shape-planning responsibilities. +package plan diff --git a/repository/shape/plan/model.go b/repository/shape/plan/model.go new file mode 100644 index 00000000..8dacf2bb --- /dev/null +++ b/repository/shape/plan/model.go @@ -0,0 +1,72 @@ +package plan + +import ( + "embed" + "reflect" + + "github.com/viant/datly/repository/shape/typectx" +) + +// Result is normalized shape plan produced from scan descriptors. +type Result struct { + RootType reflect.Type + EmbedFS *embed.FS + + Fields []*Field + ByPath map[string]*Field + Views []*View + ViewsByName map[string]*View + States []*State + TypeContext *typectx.Context +} + +// Field is a normalized projection of scanned field metadata. +type Field struct { + Path string + Name string + Type reflect.Type + Index []int +} + +// View is a normalized view field plan. +type View struct { + Path string + Name string + Ref string + Table string + Connector string + CacheRef string + Partitioner string + PartitionedConcurrency int + RelationalConcurrency int + SQL string + SQLURI string + Summary string + Links []string + Holder string + + Cardinality string + ElementType reflect.Type + FieldType reflect.Type +} + +// State is a normalized parameter field plan. 
+type State struct { + Path string + Name string + Kind string + In string + When string + Scope string + DataType string + Required *bool + Async bool + Cacheable *bool + With string + URI string + ErrorCode int + ErrorMessage string + + TagType reflect.Type + EffectiveType reflect.Type +} diff --git a/repository/shape/plan/planner.go b/repository/shape/plan/planner.go new file mode 100644 index 00000000..ec66aea5 --- /dev/null +++ b/repository/shape/plan/planner.go @@ -0,0 +1,174 @@ +package plan + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/viant/datly/repository/locator/async/keys" + metakeys "github.com/viant/datly/repository/locator/meta/keys" + outputkeys "github.com/viant/datly/repository/locator/output/keys" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/scan" +) + +// Planner normalizes scan descriptors into shape plan. +type Planner struct{} + +// New returns shape planner implementation. +func New() *Planner { + return &Planner{} +} + +// Plan implements shape.Planner. 
+func (p *Planner) Plan(_ context.Context, scanned *shape.ScanResult, _ ...shape.PlanOption) (*shape.PlanResult, error) { + if scanned == nil || scanned.Source == nil { + return nil, shape.ErrNilSource + } + + scanResult, ok := scanned.Descriptors.(*scan.Result) + if !ok || scanResult == nil { + return nil, fmt.Errorf("shape plan: unsupported descriptors type %T", scanned.Descriptors) + } + + result := &Result{ + RootType: scanResult.RootType, + EmbedFS: scanResult.EmbedFS, + ByPath: map[string]*Field{}, + ViewsByName: map[string]*View{}, + } + + for _, item := range scanResult.Fields { + field := &Field{ + Path: item.Path, + Name: item.Name, + Type: item.Type, + Index: append([]int(nil), item.Index...), + } + result.Fields = append(result.Fields, field) + result.ByPath[field.Path] = field + } + + for _, item := range scanResult.ViewFields { + v := normalizeView(item) + result.Views = append(result.Views, v) + if v.Name != "" { + result.ViewsByName[v.Name] = v + } + } + + for _, item := range scanResult.StateFields { + result.States = append(result.States, normalizeState(item)) + } + + return &shape.PlanResult{Source: scanned.Source, Plan: result}, nil +} + +func normalizeView(field *scan.Field) *View { + result := &View{ + Path: field.Path, + Holder: field.Name, + FieldType: field.Type, + } + + if tag := field.ViewTag; tag != nil { + if tag.View != nil { + result.Name = tag.View.Name + result.Table = tag.View.Table + result.Connector = tag.View.Connector + result.CacheRef = tag.View.Cache + result.Partitioner = tag.View.PartitionerType + result.PartitionedConcurrency = tag.View.PartitionedConcurrency + result.RelationalConcurrency = tag.View.RelationalConcurrency + } + result.SQL = tag.SQL.SQL + result.SQLURI = tag.SQL.URI + result.Summary = tag.SummarySQL.SQL + if len(tag.LinkOn) > 0 { + result.Links = append(result.Links, tag.LinkOn...) 
+ } + result.Ref = strings.TrimSpace(tag.TypeName) + } + + if result.Name == "" { + result.Name = field.Name + } + + elem, cardinality := componentType(field.Type) + result.Cardinality = cardinality + result.ElementType = elem + return result +} + +func normalizeState(field *scan.Field) *State { + result := &State{Path: field.Path, TagType: field.Type} + if field.StateTag == nil || field.StateTag.Parameter == nil { + result.Name = field.Name + result.EffectiveType = field.Type + return result + } + + pTag := field.StateTag.Parameter + result.Name = firstNonEmpty(pTag.Name, field.Name) + result.Kind = strings.ToLower(strings.TrimSpace(pTag.Kind)) + result.In = strings.TrimSpace(pTag.In) + result.When = pTag.When + result.Scope = pTag.Scope + result.DataType = pTag.DataType + result.Required = pTag.Required + result.Async = pTag.Async + result.Cacheable = pTag.Cacheable + result.With = pTag.With + result.URI = pTag.URI + result.ErrorCode = pTag.ErrorCode + result.ErrorMessage = pTag.ErrorMessage + + result.EffectiveType = resolveStateType(result, field.Type) + return result +} + +func resolveStateType(item *State, fallback reflect.Type) reflect.Type { + key := strings.ToLower(strings.TrimSpace(firstNonEmpty(item.In, item.Name))) + switch item.Kind { + case "output": + if rType, ok := outputkeys.Types[key]; ok { + return rType + } + case "meta": + if rType, ok := metakeys.Types[key]; ok { + return rType + } + case "async": + if rType, ok := keys.Types[key]; ok { + return rType + } + } + return fallback +} + +func componentType(rType reflect.Type) (reflect.Type, string) { + if rType == nil { + return nil, "one" + } + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + if rType.Kind() == reflect.Slice { + elem := rType.Elem() + for elem.Kind() == reflect.Ptr { + elem = elem.Elem() + } + return elem, "many" + } + return rType, "one" +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); 
trimmed != "" { + return trimmed + } + } + return "" +} diff --git a/repository/shape/plan/planner_test.go b/repository/shape/plan/planner_test.go new file mode 100644 index 00000000..29bb1e79 --- /dev/null +++ b/repository/shape/plan/planner_test.go @@ -0,0 +1,86 @@ +package plan + +import ( + "context" + "embed" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + asynckeys "github.com/viant/datly/repository/locator/async/keys" + metakeys "github.com/viant/datly/repository/locator/meta/keys" + outputkeys "github.com/viant/datly/repository/locator/output/keys" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/scan" +) + +//go:embed testdata/*.sql +var testFS embed.FS + +type embeddedFS struct{} + +func (embeddedFS) EmbedFS() *embed.FS { + return &testFS +} + +type reportRow struct { + ID int +} + +type reportSource struct { + embeddedFS + Rows []reportRow `view:"rows,table=REPORT,connector=dev" sql:"uri=testdata/report.sql"` + Status interface{} `parameter:"status,kind=output,in=status"` + Job interface{} `parameter:"job,kind=async,in=job"` + VName interface{} `parameter:"viewName,kind=meta,in=view.name"` + ID int `parameter:"id,kind=query,in=id"` +} + +func TestPlanner_Plan(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &reportSource{}}) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + require.NotNil(t, planned) + + result, ok := planned.Plan.(*Result) + require.True(t, ok) + require.NotNil(t, result) + require.NotNil(t, result.EmbedFS) + + require.Len(t, result.Views, 1) + rows := result.Views[0] + assert.Equal(t, "rows", rows.Name) + assert.Equal(t, "REPORT", rows.Table) + assert.Equal(t, "dev", rows.Connector) + assert.Equal(t, "many", rows.Cardinality) + assert.Equal(t, "Rows", rows.Holder) + assert.Contains(t, rows.SQL, "SELECT 
ID") + + stateByPath := map[string]*State{} + for _, item := range result.States { + stateByPath[item.Path] = item + } + + require.NotNil(t, stateByPath["Status"]) + assert.Equal(t, outputkeys.Types["status"], stateByPath["Status"].EffectiveType) + require.NotNil(t, stateByPath["Job"]) + assert.Equal(t, asynckeys.Types["job"], stateByPath["Job"].EffectiveType) + require.NotNil(t, stateByPath["VName"]) + assert.Equal(t, metakeys.Types["view.name"], stateByPath["VName"].EffectiveType) + + require.NotNil(t, stateByPath["ID"]) + assert.Equal(t, "query", stateByPath["ID"].Kind) + assert.Equal(t, "id", stateByPath["ID"].In) + assert.Equal(t, stateByPath["ID"].TagType, stateByPath["ID"].EffectiveType) +} + +func TestPlanner_Plan_InvalidDescriptors(t *testing.T) { + planner := New() + _, err := planner.Plan(context.Background(), &shape.ScanResult{Source: &shape.Source{Name: "x"}, Descriptors: "invalid"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported descriptors type") +} diff --git a/repository/shape/plan/testdata/report.sql b/repository/shape/plan/testdata/report.sql new file mode 100644 index 00000000..7aab3a1f --- /dev/null +++ b/repository/shape/plan/testdata/report.sql @@ -0,0 +1 @@ +SELECT ID FROM REPORT diff --git a/repository/shape/scan/doc.go b/repository/shape/scan/doc.go new file mode 100644 index 00000000..e1f10577 --- /dev/null +++ b/repository/shape/scan/doc.go @@ -0,0 +1,2 @@ +// Package scan defines scanning responsibilities for struct/DQL inputs. +package scan diff --git a/repository/shape/scan/model.go b/repository/shape/scan/model.go new file mode 100644 index 00000000..35729925 --- /dev/null +++ b/repository/shape/scan/model.go @@ -0,0 +1,33 @@ +package scan + +import ( + "embed" + "reflect" + + "github.com/viant/datly/view/tags" +) + +// Result holds scan output produced from a struct source. 
+type Result struct { + RootType reflect.Type + EmbedFS *embed.FS + Fields []*Field + ByPath map[string]*Field + ViewFields []*Field + StateFields []*Field +} + +// Field describes one scanned struct field. +type Field struct { + Path string + Name string + Index []int + Type reflect.Type + Tag reflect.StructTag + Anonymous bool + + HasViewTag bool + HasStateTag bool + ViewTag *tags.Tag + StateTag *tags.Tag +} diff --git a/repository/shape/scan/scanner.go b/repository/shape/scan/scanner.go new file mode 100644 index 00000000..d15d34f3 --- /dev/null +++ b/repository/shape/scan/scanner.go @@ -0,0 +1,166 @@ +package scan + +import ( + "context" + "fmt" + "reflect" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/view/state" + "github.com/viant/datly/view/tags" +) + +// StructScanner scans arbitrary struct types and extracts Datly-relevant tags. +type StructScanner struct{} + +// New returns a Scanner implementation for shape facade. +func New() *StructScanner { + return &StructScanner{} +} + +// Scan implements shape.Scanner. 
+func (s *StructScanner) Scan(_ context.Context, source *shape.Source, _ ...shape.ScanOption) (*shape.ScanResult, error) { + if source == nil { + return nil, shape.ErrNilSource + } + source.EnsureTypeRegistry() + + root, err := resolveRootType(source) + if err != nil { + return nil, err + } + + embedder := resolveEmbedder(source) + result := &Result{ + RootType: root, + EmbedFS: embedder.EmbedFS(), + ByPath: map[string]*Field{}, + } + + if err = s.scanStruct(root, "", nil, embedder, result, map[reflect.Type]bool{}); err != nil { + return nil, err + } + + return &shape.ScanResult{Source: source, Descriptors: result}, nil +} + +func resolveRootType(source *shape.Source) (reflect.Type, error) { + rType, err := source.ResolveRootType() + if err != nil { + return nil, err + } + if rType == nil { + return nil, shape.ErrNilSource + } + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + if rType.Kind() != reflect.Struct { + return nil, fmt.Errorf("shape scan: unsupported source type %v, expected struct", rType) + } + return rType, nil +} + +func resolveEmbedder(source *shape.Source) *state.FSEmbedder { + embedder := state.NewFSEmbedder(nil) + if source.Type != nil { + rType := source.Type + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + embedder.SetType(rType) + return embedder + } + if source.Struct != nil { + rType := reflect.TypeOf(source.Struct) + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + embedder.SetType(rType) + } + return embedder +} + +func (s *StructScanner) scanStruct( + rType reflect.Type, + prefix string, + indexPrefix []int, + embedder *state.FSEmbedder, + result *Result, + visited map[reflect.Type]bool, +) error { + if visited[rType] { + return nil + } + visited[rType] = true + defer delete(visited, rType) + + for i := 0; i < rType.NumField(); i++ { + field := rType.Field(i) + path := field.Name + if prefix != "" { + path = prefix + "." 
+ field.Name + } + combinedIndex := append(append([]int{}, indexPrefix...), field.Index...) + + descriptor := &Field{ + Path: path, + Name: field.Name, + Index: combinedIndex, + Type: field.Type, + Tag: field.Tag, + Anonymous: field.Anonymous, + } + + if hasAny(field.Tag, tags.ViewTag, tags.SQLTag, tags.SQLSummaryTag, tags.LinkOnTag) { + parsed, err := tags.ParseViewTags(field.Tag, embedder.EmbedFS()) + if err != nil { + return fmt.Errorf("shape scan: failed to parse view tags on %s: %w", path, err) + } + descriptor.HasViewTag = true + descriptor.ViewTag = parsed + result.ViewFields = append(result.ViewFields, descriptor) + } + + if hasAny(field.Tag, tags.ParameterTag, tags.SQLTag, tags.PredicateTag, tags.CodecTag, tags.HandlerTag) { + parsed, err := tags.ParseStateTags(field.Tag, embedder.EmbedFS()) + if err != nil { + return fmt.Errorf("shape scan: failed to parse state tags on %s: %w", path, err) + } + descriptor.HasStateTag = true + descriptor.StateTag = parsed + result.StateFields = append(result.StateFields, descriptor) + } + + result.Fields = append(result.Fields, descriptor) + result.ByPath[path] = descriptor + + nextType := field.Type + for nextType.Kind() == reflect.Ptr { + nextType = nextType.Elem() + } + if field.Anonymous && nextType.Kind() == reflect.Struct && !isStdlib(nextType.PkgPath()) { + if err := s.scanStruct(nextType, path, combinedIndex, embedder, result, visited); err != nil { + return err + } + } + } + return nil +} + +func hasAny(tag reflect.StructTag, names ...string) bool { + for _, name := range names { + if _, ok := tag.Lookup(name); ok { + return true + } + } + return false +} + +func isStdlib(pkg string) bool { + if pkg == "" { + return true + } + return !strings.Contains(pkg, ".") +} diff --git a/repository/shape/scan/scanner_test.go b/repository/shape/scan/scanner_test.go new file mode 100644 index 00000000..7cce9cbc --- /dev/null +++ b/repository/shape/scan/scanner_test.go @@ -0,0 +1,83 @@ +package scan + +import ( + "context" + 
"embed" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/x" +) + +//go:embed testdata/*.sql +var testFS embed.FS + +type embeddedFS struct{} + +func (embeddedFS) EmbedFS() *embed.FS { + return &testFS +} + +type reportRow struct { + ID int + Name string +} + +type reportSource struct { + embeddedFS + Rows []reportRow `view:"rows,table=REPORT,connector=dev" sql:"uri=testdata/report.sql"` + ID int `parameter:"id,kind=query,in=id"` +} + +func TestStructScanner_Scan(t *testing.T) { + scanner := New() + result, err := scanner.Scan(context.Background(), &shape.Source{Struct: &reportSource{}}) + require.NoError(t, err) + require.NotNil(t, result) + + descriptors, ok := result.Descriptors.(*Result) + require.True(t, ok) + require.NotNil(t, descriptors) + require.NotNil(t, descriptors.EmbedFS) + assert.Equal(t, reflect.TypeOf(reportSource{}), descriptors.RootType) + + rows := descriptors.ByPath["Rows"] + require.NotNil(t, rows) + require.True(t, rows.HasViewTag) + require.NotNil(t, rows.ViewTag) + assert.Equal(t, "rows", rows.ViewTag.View.Name) + assert.Contains(t, rows.ViewTag.SQL.SQL, "SELECT ID, NAME FROM REPORT") + + idField := descriptors.ByPath["ID"] + require.NotNil(t, idField) + require.True(t, idField.HasStateTag) + require.NotNil(t, idField.StateTag) + require.NotNil(t, idField.StateTag.Parameter) + assert.Equal(t, "id", idField.StateTag.Parameter.Name) + assert.Equal(t, "query", idField.StateTag.Parameter.Kind) + assert.Equal(t, "id", idField.StateTag.Parameter.In) +} + +func TestStructScanner_Scan_InvalidSource(t *testing.T) { + scanner := New() + _, err := scanner.Scan(context.Background(), &shape.Source{Struct: 1}) + require.Error(t, err) + assert.Contains(t, err.Error(), "expected struct") +} + +func TestStructScanner_Scan_WithRegistryType(t *testing.T) { + scanner := New() + registry := x.NewRegistry() + 
registry.Register(x.NewType(reflect.TypeOf(reportSource{}))) + result, err := scanner.Scan(context.Background(), &shape.Source{ + TypeName: "github.com/viant/datly/repository/shape/scan.reportSource", + TypeRegistry: registry, + }) + require.NoError(t, err) + descriptors, ok := result.Descriptors.(*Result) + require.True(t, ok) + assert.Equal(t, reflect.TypeOf(reportSource{}), descriptors.RootType) +} diff --git a/repository/shape/scan/testdata/report.sql b/repository/shape/scan/testdata/report.sql new file mode 100644 index 00000000..68f0f3b3 --- /dev/null +++ b/repository/shape/scan/testdata/report.sql @@ -0,0 +1 @@ +SELECT ID, NAME FROM REPORT diff --git a/repository/shape/shape.go b/repository/shape/shape.go new file mode 100644 index 00000000..570a63d5 --- /dev/null +++ b/repository/shape/shape.go @@ -0,0 +1,157 @@ +package shape + +import "context" + +type ( + // Scanner discovers shape descriptors from Source. + Scanner interface { + Scan(ctx context.Context, source *Source, opts ...ScanOption) (*ScanResult, error) + } + + // Planner normalizes discovered descriptors into execution plan. + Planner interface { + Plan(ctx context.Context, scan *ScanResult, opts ...PlanOption) (*PlanResult, error) + } + + // Loader materializes runtime artifacts from normalized plan. + Loader interface { + LoadViews(ctx context.Context, plan *PlanResult, opts ...LoadOption) (*ViewArtifacts, error) + LoadComponent(ctx context.Context, plan *PlanResult, opts ...LoadOption) (*ComponentArtifact, error) + } + + // DQLCompiler compiles DQL source directly into a shape plan. + DQLCompiler interface { + Compile(ctx context.Context, source *Source, opts ...CompileOption) (*PlanResult, error) + } + + // RuntimeRegistrar optionally registers loaded artifacts in runtime services. 
+ RuntimeRegistrar interface { + RegisterViews(ctx context.Context, artifacts *ViewArtifacts) error + RegisterComponent(ctx context.Context, artifacts *ComponentArtifact) error + } + + ScanOptions struct{} + PlanOptions struct{} + LoadOptions struct{} + CompileOptions struct{} + + ScanOption func(*ScanOptions) + PlanOption func(*PlanOptions) + LoadOption func(*LoadOptions) + CompileOption func(*CompileOptions) +) + +// Engine is a thin facade over scan -> plan -> load pipeline. +type Engine struct { + options *Options +} + +// New creates an Engine facade. +func New(opts ...Option) *Engine { + return &Engine{options: NewOptions(opts...)} +} + +// LoadViews is a package-level helper for struct source view loading. +func LoadViews(ctx context.Context, src any, opts ...Option) (*ViewArtifacts, error) { + return New(opts...).LoadViews(ctx, src) +} + +// LoadComponent is a package-level helper for struct source component loading. +func LoadComponent(ctx context.Context, src any, opts ...Option) (*ComponentArtifact, error) { + return New(opts...).LoadComponent(ctx, src) +} + +// LoadDQLViews is a package-level helper for DQL source view loading. +func LoadDQLViews(ctx context.Context, dql string, opts ...Option) (*ViewArtifacts, error) { + return New(opts...).LoadDQLViews(ctx, dql) +} + +// LoadDQLComponent is a package-level helper for DQL source component loading. +func LoadDQLComponent(ctx context.Context, dql string, opts ...Option) (*ComponentArtifact, error) { + return New(opts...).LoadDQLComponent(ctx, dql) +} + +// LoadViews executes scan -> plan -> load for struct source. 
+func (e *Engine) LoadViews(ctx context.Context, src any) (*ViewArtifacts, error) { + source, err := e.structSource(src) + if err != nil { + return nil, err + } + plan, err := e.scanAndPlan(ctx, source) + if err != nil { + return nil, err + } + if e.options.Loader == nil { + return nil, ErrLoaderNotConfigured + } + return e.options.Loader.LoadViews(ctx, plan) +} + +// LoadComponent executes scan -> plan -> load for struct source. +func (e *Engine) LoadComponent(ctx context.Context, src any) (*ComponentArtifact, error) { + source, err := e.structSource(src) + if err != nil { + return nil, err + } + plan, err := e.scanAndPlan(ctx, source) + if err != nil { + return nil, err + } + if e.options.Loader == nil { + return nil, ErrLoaderNotConfigured + } + return e.options.Loader.LoadComponent(ctx, plan) +} + +// LoadDQLViews executes compile -> load for DQL source. +func (e *Engine) LoadDQLViews(ctx context.Context, dql string) (*ViewArtifacts, error) { + source, err := e.dqlSource(dql) + if err != nil { + return nil, err + } + plan, err := e.compile(ctx, source) + if err != nil { + return nil, err + } + if e.options.Loader == nil { + return nil, ErrLoaderNotConfigured + } + return e.options.Loader.LoadViews(ctx, plan) +} + +// LoadDQLComponent executes compile -> load for DQL source. 
+func (e *Engine) LoadDQLComponent(ctx context.Context, dql string) (*ComponentArtifact, error) { + source, err := e.dqlSource(dql) + if err != nil { + return nil, err + } + plan, err := e.compile(ctx, source) + if err != nil { + return nil, err + } + if e.options.Loader == nil { + return nil, ErrLoaderNotConfigured + } + return e.options.Loader.LoadComponent(ctx, plan) +} + +func (e *Engine) compile(ctx context.Context, source *Source) (*PlanResult, error) { + if e.options.Compiler == nil { + return nil, ErrCompilerNotConfigured + } + return e.options.Compiler.Compile(ctx, source) +} + +func (e *Engine) scanAndPlan(ctx context.Context, source *Source) (*PlanResult, error) { + if e.options.Scanner == nil { + return nil, ErrScannerNotConfigured + } + if e.options.Planner == nil { + return nil, ErrPlannerNotConfigured + } + scanResult, err := e.options.Scanner.Scan(ctx, source) + if err != nil { + return nil, err + } + return e.options.Planner.Plan(ctx, scanResult) +} diff --git a/repository/shape/source.go b/repository/shape/source.go new file mode 100644 index 00000000..e408c216 --- /dev/null +++ b/repository/shape/source.go @@ -0,0 +1,39 @@ +package shape + +import ( + "reflect" + "strings" + + "github.com/viant/x" +) + +func (e *Engine) structSource(src any) (*Source, error) { + if src == nil { + return nil, ErrNilSource + } + rType := reflect.TypeOf(src) + for rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + registry := x.NewRegistry() + registry.Register(x.NewType(rType)) + return &Source{ + Name: e.options.Name, + Struct: src, + Type: rType, + TypeName: x.NewType(rType).Key(), + TypeRegistry: registry, + DQL: "", + }, nil +} + +func (e *Engine) dqlSource(dql string) (*Source, error) { + dql = strings.TrimSpace(dql) + if dql == "" { + return nil, ErrNilDQL + } + return &Source{ + Name: e.options.Name, + DQL: dql, + }, nil +} diff --git a/repository/shape/source_type.go b/repository/shape/source_type.go new file mode 100644 index 00000000..51bc7132 --- 
/dev/null +++ b/repository/shape/source_type.go @@ -0,0 +1,56 @@ +package shape + +import ( + "fmt" + "reflect" + "strings" + + "github.com/viant/x" +) + +// ResolveRootType resolves source root type from explicit Type, Struct, or viant/x registry. +func (s *Source) ResolveRootType() (reflect.Type, error) { + if s == nil { + return nil, ErrNilSource + } + if s.Type != nil { + return unwrapPtr(s.Type), nil + } + if s.Struct != nil { + return unwrapPtr(reflect.TypeOf(s.Struct)), nil + } + key := strings.TrimSpace(s.TypeName) + if key == "" || s.TypeRegistry == nil { + return nil, ErrNilSource + } + aType := s.TypeRegistry.Lookup(key) + if aType == nil || aType.Type == nil { + return nil, fmt.Errorf("shape source: type %q not found in registry", key) + } + return unwrapPtr(aType.Type), nil +} + +// EnsureTypeRegistry returns source registry ensuring root type is registered when available. +func (s *Source) EnsureTypeRegistry() *x.Registry { + if s == nil { + return nil + } + if s.TypeRegistry == nil { + s.TypeRegistry = x.NewRegistry() + } + if rType, err := s.ResolveRootType(); err == nil && rType != nil { + t := x.NewType(rType) + if strings.TrimSpace(s.TypeName) == "" { + s.TypeName = t.Key() + } + s.TypeRegistry.Register(t) + } + return s.TypeRegistry +} + +func unwrapPtr(rType reflect.Type) reflect.Type { + for rType != nil && rType.Kind() == reflect.Ptr { + rType = rType.Elem() + } + return rType +} diff --git a/repository/shape/source_type_test.go b/repository/shape/source_type_test.go new file mode 100644 index 00000000..3118f8fe --- /dev/null +++ b/repository/shape/source_type_test.go @@ -0,0 +1,33 @@ +package shape + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/x" +) + +type sampleShape struct { + ID int +} + +func TestSource_ResolveRootType_FromRegistry(t *testing.T) { + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(sampleShape{}))) + src := &Source{ + TypeName: 
"github.com/viant/datly/repository/shape.sampleShape", + TypeRegistry: registry, + } + rType, err := src.ResolveRootType() + require.NoError(t, err) + require.Equal(t, reflect.TypeOf(sampleShape{}), rType) +} + +func TestSource_EnsureTypeRegistry_RegistersRoot(t *testing.T) { + src := &Source{Struct: &sampleShape{}} + registry := src.EnsureTypeRegistry() + require.NotNil(t, registry) + require.NotEmpty(t, src.TypeName) + require.NotNil(t, registry.Lookup(src.TypeName)) +} diff --git a/repository/shape/typectx/model.go b/repository/shape/typectx/model.go new file mode 100644 index 00000000..ae76febe --- /dev/null +++ b/repository/shape/typectx/model.go @@ -0,0 +1,29 @@ +package typectx + +// Import describes one package alias import for DQL/type resolution. +type Import struct { + Alias string `json:",omitempty" yaml:",omitempty"` + Package string `json:",omitempty" yaml:",omitempty"` +} + +// Context captures default package and imports used for type resolution. +type Context struct { + DefaultPackage string `json:",omitempty" yaml:",omitempty"` + Imports []Import `json:",omitempty" yaml:",omitempty"` +} + +// Provenance tracks where a resolved type came from. +type Provenance struct { + Package string `json:",omitempty" yaml:",omitempty"` + File string `json:",omitempty" yaml:",omitempty"` + Kind string `json:",omitempty" yaml:",omitempty"` // builtin, resource_type, registry, ast_type +} + +// Resolution captures one resolved type expression and its provenance. 
+type Resolution struct { + Expression string `json:",omitempty" yaml:",omitempty"` + Target string `json:",omitempty" yaml:",omitempty"` + ResolvedKey string `json:",omitempty" yaml:",omitempty"` + MatchKind string `json:",omitempty" yaml:",omitempty"` // exact, alias_import, qualified, default_package, import_package, global_unique + Provenance Provenance `json:",omitempty" yaml:",omitempty"` +} diff --git a/repository/shape/typectx/resolver.go b/repository/shape/typectx/resolver.go new file mode 100644 index 00000000..daccf3b4 --- /dev/null +++ b/repository/shape/typectx/resolver.go @@ -0,0 +1,293 @@ +package typectx + +import ( + "fmt" + "path" + "sort" + "strings" + + "github.com/viant/x" +) + +// AmbiguityError reports multiple matching type candidates for a type expression. +type AmbiguityError struct { + Expression string + Candidates []string +} + +func (e *AmbiguityError) Error() string { + return fmt.Sprintf("ambiguous type %q: candidates=%s", e.Expression, strings.Join(e.Candidates, ",")) +} + +// Resolver resolves cast/tag type expressions against viant/x registry using type context. +type Resolver struct { + registry *x.Registry + context *Context + provenance map[string]Provenance +} + +// NewResolver creates a type resolver. +func NewResolver(registry *x.Registry, context *Context) *Resolver { + return NewResolverWithProvenance(registry, context, nil) +} + +// NewResolverWithProvenance creates a type resolver with optional registry-key provenance map. +func NewResolverWithProvenance(registry *x.Registry, context *Context, provenance map[string]Provenance) *Resolver { + return &Resolver{ + registry: registry, + context: normalizeContext(context), + provenance: cloneProvenance(provenance), + } +} + +// Resolve resolves type expression to registry key. It returns ("", nil) when unresolved. 
+func (r *Resolver) Resolve(typeExpr string) (string, error) { + resolved, err := r.ResolveWithProvenance(typeExpr) + if err != nil || resolved == nil { + return "", err + } + return resolved.ResolvedKey, nil +} + +// ResolveWithProvenance resolves expression and returns provenance details. +// It returns (nil, nil) when unresolved. +func (r *Resolver) ResolveWithProvenance(typeExpr string) (*Resolution, error) { + if r == nil || r.registry == nil { + return nil, nil + } + base := normalizeLookupKey(typeExpr) + if base == "" { + return nil, nil + } + + // Exact type key (builtins or fully-qualified package.Type) + if r.registry.Lookup(base) != nil { + return r.newResolution(typeExpr, "", base, "exact"), nil + } + + prefix, baseName, alias, qualified := splitQualified(base) + if qualified { + if prefix == "" || baseName == "" { + return nil, nil + } + if alias { + pkg := r.aliasPackage(prefix) + if pkg == "" { + return nil, nil + } + candidate := pkg + "." + baseName + if r.registry.Lookup(candidate) == nil { + return nil, nil + } + return r.newResolution(typeExpr, "", candidate, "alias_import"), nil + } + // fully qualified package path.Type + if r.registry.Lookup(base) != nil { + return r.newResolution(typeExpr, "", base, "qualified"), nil + } + return nil, nil + } + + // Unqualified resolution: default package, then imports; if still unresolved, + // fallback to unique global name match. 
+ candidates := r.unqualifiedCandidates(baseName) + if len(candidates) == 1 { + return r.newResolution(typeExpr, "", candidates[0].key, candidates[0].matchKind), nil + } + if len(candidates) > 1 { + keys := make([]string, 0, len(candidates)) + for _, candidate := range candidates { + keys = append(keys, candidate.key) + } + sort.Strings(keys) + return nil, &AmbiguityError{Expression: typeExpr, Candidates: keys} + } + return nil, nil +} + +func (r *Resolver) aliasPackage(alias string) string { + alias = strings.TrimSpace(alias) + if alias == "" || r.context == nil { + return "" + } + for _, item := range r.context.Imports { + if item.Alias == alias { + return item.Package + } + } + return "" +} + +type candidate struct { + key string + matchKind string +} + +func (r *Resolver) unqualifiedCandidates(typeName string) []candidate { + if typeName == "" { + return nil + } + seen := map[string]bool{} + var result []candidate + + for _, scoped := range r.searchPackages() { + pkg := scoped.pkg + key := pkg + "." + typeName + if seen[key] { + continue + } + seen[key] = true + if r.registry.Lookup(key) != nil { + result = append(result, candidate{key: key, matchKind: scoped.matchKind}) + } + } + if len(result) > 0 { + return result + } + + // Global unique fallback by suffix ".TypeName" or exact built-in. 
+ for _, key := range r.registry.Keys() { + if key == typeName || strings.HasSuffix(key, "."+typeName) { + if seen[key] { + continue + } + seen[key] = true + result = append(result, candidate{key: key, matchKind: "global_unique"}) + } + } + return result +} + +type scopedPackage struct { + pkg string + matchKind string +} + +func (r *Resolver) searchPackages() []scopedPackage { + if r.context == nil { + return nil + } + seen := map[string]bool{} + var result []scopedPackage + appendPkg := func(pkg, matchKind string) { + pkg = strings.TrimSpace(pkg) + if pkg == "" || seen[pkg] { + return + } + seen[pkg] = true + result = append(result, scopedPackage{pkg: pkg, matchKind: matchKind}) + } + appendPkg(r.context.DefaultPackage, "default_package") + for _, item := range r.context.Imports { + appendPkg(item.Package, "import_package") + } + return result +} + +func (r *Resolver) newResolution(expression, target, key, matchKind string) *Resolution { + if key == "" { + return nil + } + resolution := &Resolution{ + Expression: strings.TrimSpace(expression), + Target: strings.TrimSpace(target), + ResolvedKey: key, + MatchKind: matchKind, + Provenance: r.lookupProvenance(key), + } + return resolution +} + +func (r *Resolver) lookupProvenance(key string) Provenance { + prov := Provenance{ + Package: packageOf(key), + Kind: "registry", + } + if built, ok := r.provenance[key]; ok { + if built.Package != "" { + prov.Package = built.Package + } + if built.File != "" { + prov.File = built.File + } + if built.Kind != "" { + prov.Kind = built.Kind + } + } + return prov +} + +func cloneProvenance(input map[string]Provenance) map[string]Provenance { + if len(input) == 0 { + return nil + } + result := make(map[string]Provenance, len(input)) + for k, v := range input { + result[k] = v + } + return result +} + +func packageOf(key string) string { + index := strings.LastIndex(key, ".") + if index == -1 { + return "" + } + return key[:index] +} + +func normalizeContext(input *Context) *Context 
{ + if input == nil { + return nil + } + ret := &Context{ + DefaultPackage: strings.TrimSpace(input.DefaultPackage), + } + for _, item := range input.Imports { + pkg := strings.TrimSpace(item.Package) + if pkg == "" { + continue + } + alias := strings.TrimSpace(item.Alias) + if alias == "" { + alias = path.Base(pkg) + } + ret.Imports = append(ret.Imports, Import{ + Alias: alias, + Package: pkg, + }) + } + if ret.DefaultPackage == "" && len(ret.Imports) == 0 { + return nil + } + return ret +} + +func splitQualified(value string) (prefix string, name string, alias bool, qualified bool) { + index := strings.LastIndex(value, ".") + if index == -1 { + return "", value, false, false + } + prefix = strings.TrimSpace(value[:index]) + name = strings.TrimSpace(value[index+1:]) + if prefix == "" || name == "" { + return "", "", false, false + } + qualified = true + alias = !strings.Contains(prefix, "/") + return prefix, name, alias, qualified +} + +func normalizeLookupKey(typeExpr string) string { + value := strings.TrimSpace(typeExpr) + for { + switch { + case strings.HasPrefix(value, "*"): + value = strings.TrimPrefix(value, "*") + case strings.HasPrefix(value, "[]"): + value = strings.TrimPrefix(value, "[]") + default: + return strings.TrimSpace(value) + } + } +} diff --git a/repository/shape/typectx/resolver_memfs_test.go b/repository/shape/typectx/resolver_memfs_test.go new file mode 100644 index 00000000..cc90b470 --- /dev/null +++ b/repository/shape/typectx/resolver_memfs_test.go @@ -0,0 +1,116 @@ +package typectx + +import ( + "context" + "path" + "testing" + "testing/fstest" + + "github.com/stretchr/testify/require" + "github.com/viant/x" + xast "github.com/viant/x/loader/ast" +) + +func TestResolver_MemFS_DefaultPackageResolution(t *testing.T) { + resolver := memFSResolver(t, baseTypeMapFS(), []string{"root/perf"}, &Context{ + DefaultPackage: "example.com/acme/perf", + }) + + key, err := resolver.Resolve("Order") + require.NoError(t, err) + require.Equal(t, 
"example.com/acme/perf.Order", key) +} + +func TestResolver_MemFS_AliasImportResolution(t *testing.T) { + resolver := memFSResolver(t, baseTypeMapFS(), []string{"root/perf"}, &Context{ + Imports: []Import{ + {Alias: "pf", Package: "example.com/acme/perf"}, + }, + }) + + key, err := resolver.Resolve("pf.Order") + require.NoError(t, err) + require.Equal(t, "example.com/acme/perf.Order", key) +} + +func TestResolver_MemFS_AmbiguityDetection(t *testing.T) { + resolver := memFSResolver(t, baseTypeMapFS(), []string{"root/perf", "root/shared"}, &Context{ + Imports: []Import{ + {Alias: "pf", Package: "example.com/acme/perf"}, + {Alias: "sh", Package: "example.com/acme/shared"}, + }, + }) + + key, err := resolver.Resolve("Fee") + require.Empty(t, key) + require.Error(t, err) + amb, ok := err.(*AmbiguityError) + require.True(t, ok) + require.Equal(t, []string{ + "example.com/acme/perf.Fee", + "example.com/acme/shared.Fee", + }, amb.Candidates) +} + +func TestResolver_MemFS_ProvenanceCapture(t *testing.T) { + resolver := memFSResolver(t, baseTypeMapFS(), []string{"root/perf"}, &Context{ + DefaultPackage: "example.com/acme/perf", + }) + + resolved, err := resolver.ResolveWithProvenance("Order") + require.NoError(t, err) + require.NotNil(t, resolved) + require.Equal(t, "example.com/acme/perf.Order", resolved.ResolvedKey) + require.Equal(t, "default_package", resolved.MatchKind) + require.Equal(t, "ast_type", resolved.Provenance.Kind) + require.Equal(t, "example.com/acme/perf", resolved.Provenance.Package) + require.Equal(t, "root/perf/types.go", resolved.Provenance.File) +} + +func memFSResolver(t *testing.T, fsys fstest.MapFS, packageDirs []string, ctx *Context) *Resolver { + t.Helper() + registry := x.NewRegistry() + provenance := map[string]Provenance{} + for _, dir := range packageDirs { + pkg, err := xast.LoadPackageFS(context.Background(), fsys, dir) + require.NoError(t, err) + + fileByType := map[string]string{} + for _, file := range pkg.Files { + if file == nil { + 
continue + } + for _, item := range file.Types { + if item == nil || item.Name == "" { + continue + } + fileByType[item.Name] = path.Join(dir, file.Name) + } + } + for _, item := range pkg.Types { + if item == nil || item.Name == "" { + continue + } + aType := &x.Type{ + Name: item.Name, + PkgPath: pkg.PkgPath, + } + registry.Register(aType) + provenance[aType.Key()] = Provenance{ + Package: pkg.PkgPath, + File: fileByType[item.Name], + Kind: "ast_type", + } + } + } + return NewResolverWithProvenance(registry, ctx, provenance) +} + +func baseTypeMapFS() fstest.MapFS { + return fstest.MapFS{ + "root/go.mod": &fstest.MapFile{Data: []byte("module example.com/acme\n\ngo 1.23\n")}, + "root/perf/types.go": &fstest.MapFile{Data: []byte("package perf\n\ntype Order struct{}\ntype Fee struct{}\n")}, + "root/shared/types.go": &fstest.MapFile{Data: []byte("package shared\n\ntype Fee struct{}\n")}, + "root/ignore/other.txt": &fstest.MapFile{Data: []byte("skip")}, + } +} diff --git a/repository/shape/typectx/resolver_test.go b/repository/shape/typectx/resolver_test.go new file mode 100644 index 00000000..f1e8e676 --- /dev/null +++ b/repository/shape/typectx/resolver_test.go @@ -0,0 +1,89 @@ +package typectx + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/x" +) + +type resolveFeeA struct{} +type resolveFeeB struct{} +type resolveOrder struct{} + +func TestResolver_Resolve_Unqualified_DefaultPackage(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), x.WithPkgPath("github.com/acme/mdp/performance"), x.WithName("Order"))) + resolver := NewResolver(reg, &Context{DefaultPackage: "github.com/acme/mdp/performance"}) + + key, err := resolver.Resolve("Order") + require.NoError(t, err) + require.Equal(t, "github.com/acme/mdp/performance.Order", key) +} + +func TestResolver_Resolve_AliasQualified(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), 
x.WithPkgPath("github.com/acme/mdp/performance"), x.WithName("Order"))) + resolver := NewResolver(reg, &Context{ + Imports: []Import{ + {Alias: "perf", Package: "github.com/acme/mdp/performance"}, + }, + }) + + key, err := resolver.Resolve("perf.Order") + require.NoError(t, err) + require.Equal(t, "github.com/acme/mdp/performance.Order", key) +} + +func TestResolver_Resolve_Unqualified_Ambiguous(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveFeeA{}), x.WithPkgPath("github.com/acme/alpha"), x.WithName("Fee"))) + reg.Register(x.NewType(reflect.TypeOf(resolveFeeB{}), x.WithPkgPath("github.com/acme/beta"), x.WithName("Fee"))) + resolver := NewResolver(reg, &Context{ + Imports: []Import{ + {Alias: "a", Package: "github.com/acme/alpha"}, + {Alias: "b", Package: "github.com/acme/beta"}, + }, + }) + + key, err := resolver.Resolve("Fee") + require.Empty(t, key) + require.Error(t, err) + amb, ok := err.(*AmbiguityError) + require.True(t, ok) + require.Equal(t, []string{ + "github.com/acme/alpha.Fee", + "github.com/acme/beta.Fee", + }, amb.Candidates) +} + +func TestResolver_Resolve_Unqualified_GlobalUniqueFallback(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), x.WithPkgPath("github.com/acme/shared"), x.WithName("Order"))) + resolver := NewResolver(reg, nil) + + key, err := resolver.Resolve("Order") + require.NoError(t, err) + require.Equal(t, "github.com/acme/shared.Order", key) +} + +func TestResolver_ResolveWithProvenance(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), x.WithPkgPath("github.com/acme/mdp/performance"), x.WithName("Order"))) + resolver := NewResolverWithProvenance(reg, &Context{DefaultPackage: "github.com/acme/mdp/performance"}, map[string]Provenance{ + "github.com/acme/mdp/performance.Order": { + Package: "github.com/acme/mdp/performance", + File: "/repo/mdp/performance/order.go", + Kind: "resource_type", + }, + }) + 
+ resolved, err := resolver.ResolveWithProvenance("Order") + require.NoError(t, err) + require.NotNil(t, resolved) + require.Equal(t, "github.com/acme/mdp/performance.Order", resolved.ResolvedKey) + require.Equal(t, "default_package", resolved.MatchKind) + require.Equal(t, "/repo/mdp/performance/order.go", resolved.Provenance.File) + require.Equal(t, "resource_type", resolved.Provenance.Kind) +} diff --git a/repository/shape/typectx/source/resolver.go b/repository/shape/typectx/source/resolver.go new file mode 100644 index 00000000..639787d9 --- /dev/null +++ b/repository/shape/typectx/source/resolver.go @@ -0,0 +1,283 @@ +package source + +import ( + "fmt" + "go/ast" + "go/build" + "go/parser" + "go/token" + "golang.org/x/mod/modfile" + "os" + "path/filepath" + "sort" + "strings" +) + +type Config struct { + ProjectDir string + AllowedSourceRoots []string + UseGoModuleResolve bool + UseGOPATHFallback bool +} + +type Resolver struct { + projectDir string + modulePath string + replacements map[string]string + roots []string + useModule bool + useGOPATH bool +} + +func New(cfg Config) (*Resolver, error) { + projectDir := strings.TrimSpace(cfg.ProjectDir) + if projectDir == "" { + return nil, fmt.Errorf("typectx source: project dir was empty") + } + projectDir, err := filepath.Abs(projectDir) + if err != nil { + return nil, err + } + modulePath, replacements := loadModuleConfig(projectDir) + roots := NormalizeRoots(projectDir, cfg.AllowedSourceRoots) + return &Resolver{ + projectDir: projectDir, + modulePath: modulePath, + replacements: replacements, + roots: roots, + useModule: cfg.UseGoModuleResolve, + useGOPATH: cfg.UseGOPATHFallback, + }, nil +} + +func (r *Resolver) ResolvePackageDir(importPath string) (string, error) { + importPath = strings.TrimSpace(importPath) + if importPath == "" { + return "", fmt.Errorf("typectx source: empty import path") + } + if r.useModule { + if resolved := r.resolveReplace(importPath); resolved != "" { + return 
filepath.Clean(resolved), nil + } + if resolved := r.resolveProjectModule(importPath); resolved != "" { + return filepath.Clean(resolved), nil + } + if resolved := r.resolveModuleCache(importPath); resolved != "" { + return filepath.Clean(resolved), nil + } + } + if r.useGOPATH { + if resolved := resolveGOPATH(importPath); resolved != "" { + return filepath.Clean(resolved), nil + } + } + return "", fmt.Errorf("typectx source: package %s not resolved", importPath) +} + +func (r *Resolver) ResolveTypeFile(importPath, typeName string) (string, error) { + dir, err := r.ResolvePackageDir(importPath) + if err != nil { + return "", err + } + ok, err := IsWithinAnyRoot(dir, r.roots) + if err != nil { + return "", err + } + if !ok { + return "", fmt.Errorf("typectx source: package dir %s outside trusted roots", dir) + } + entries, err := os.ReadDir(dir) + if err != nil { + return "", err + } + fset := token.NewFileSet() + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") { + continue + } + filePath := filepath.Join(dir, name) + parsed, parseErr := parser.ParseFile(fset, filePath, nil, parser.PackageClauseOnly|parser.ParseComments) + if parseErr != nil || parsed == nil { + continue + } + // Reparse full declaration only when package clause parsing succeeds. + parsed, parseErr = parser.ParseFile(fset, filePath, nil, 0) + if parseErr != nil || parsed == nil { + continue + } + for _, decl := range parsed.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.TYPE { + continue + } + for _, spec := range gen.Specs { + ts, ok := spec.(*ast.TypeSpec) + if ok && ts.Name != nil && ts.Name.Name == typeName { + return filePath, nil + } + } + } + } + return "", fmt.Errorf("typectx source: type %s not found in %s", typeName, importPath) +} + +func (r *Resolver) Roots() []string { + return append([]string(nil), r.roots...) 
+} + +func (r *Resolver) resolveReplace(importPath string) string { + oldPaths := make([]string, 0, len(r.replacements)) + for old := range r.replacements { + oldPaths = append(oldPaths, old) + } + sort.SliceStable(oldPaths, func(i, j int) bool { return len(oldPaths[i]) > len(oldPaths[j]) }) + for _, old := range oldPaths { + if importPath != old && !strings.HasPrefix(importPath, old+"/") { + continue + } + mapped := r.replacements[old] + suffix := strings.TrimPrefix(importPath, old) + suffix = strings.TrimPrefix(suffix, "/") + if suffix == "" { + return mapped + } + return filepath.Join(mapped, filepath.FromSlash(suffix)) + } + return "" +} + +func (r *Resolver) resolveProjectModule(importPath string) string { + if r.modulePath == "" { + return "" + } + if importPath != r.modulePath && !strings.HasPrefix(importPath, r.modulePath+"/") { + return "" + } + suffix := strings.TrimPrefix(importPath, r.modulePath) + suffix = strings.TrimPrefix(suffix, "/") + if suffix == "" { + return r.projectDir + } + return filepath.Join(r.projectDir, filepath.FromSlash(suffix)) +} + +func (r *Resolver) resolveModuleCache(importPath string) string { + modCache := strings.TrimSpace(os.Getenv("GOMODCACHE")) + if modCache == "" { + if out, err := os.UserCacheDir(); err == nil && out != "" { + modCache = filepath.Join(filepath.Dir(out), "pkg", "mod") + } + } + if modCache == "" { + return "" + } + pattern := filepath.Join(modCache, filepath.FromSlash(importPath)+"@*") + matches, _ := filepath.Glob(pattern) + if len(matches) == 0 { + return "" + } + sort.Strings(matches) + return matches[len(matches)-1] +} + +func resolveGOPATH(importPath string) string { + gopath := strings.TrimSpace(os.Getenv("GOPATH")) + if gopath == "" { + gopath = strings.TrimSpace(build.Default.GOPATH) + } + if gopath == "" { + return "" + } + for _, root := range filepath.SplitList(gopath) { + candidate := filepath.Join(root, "src", filepath.FromSlash(importPath)) + if info, err := os.Stat(candidate); err == nil && 
info.IsDir() { + return candidate + } + } + return "" +} + +func loadModuleConfig(projectDir string) (string, map[string]string) { + result := map[string]string{} + goModPath := filepath.Join(projectDir, "go.mod") + data, err := os.ReadFile(goModPath) + if err != nil { + return "", result + } + parsed, err := modfile.Parse(goModPath, data, nil) + if err != nil || parsed == nil { + return "", result + } + modulePath := "" + if parsed.Module != nil { + modulePath = strings.TrimSpace(parsed.Module.Mod.Path) + } + for _, replace := range parsed.Replace { + if replace == nil { + continue + } + oldPath := strings.TrimSpace(replace.Old.Path) + newPath := strings.TrimSpace(replace.New.Path) + if oldPath == "" || newPath == "" || replace.New.Version != "" { + continue + } + if !filepath.IsAbs(newPath) { + newPath = filepath.Join(projectDir, newPath) + } + result[oldPath] = filepath.Clean(newPath) + } + return modulePath, result +} + +func NormalizeRoots(projectDir string, allowed []string) []string { + seen := map[string]bool{} + var result []string + appendRoot := func(value string) { + value = strings.TrimSpace(value) + if value == "" { + return + } + if !filepath.IsAbs(value) { + value = filepath.Join(projectDir, value) + } + value = filepath.Clean(value) + if seen[value] { + return + } + seen[value] = true + result = append(result, value) + } + appendRoot(projectDir) + for _, item := range allowed { + appendRoot(item) + } + sort.Strings(result) + return result +} + +func IsWithinAnyRoot(candidate string, roots []string) (bool, error) { + candidate, err := filepath.Abs(candidate) + if err != nil { + return false, err + } + candidate = filepath.Clean(candidate) + for _, root := range roots { + root = filepath.Clean(root) + rel, err := filepath.Rel(root, candidate) + if err != nil { + return false, err + } + if rel == "." 
{ + return true, nil + } + rel = filepath.ToSlash(rel) + if !strings.HasPrefix(rel, "../") { + return true, nil + } + } + return false, nil +} diff --git a/repository/shape/typectx/source/resolver_test.go b/repository/shape/typectx/source/resolver_test.go new file mode 100644 index 00000000..541c11af --- /dev/null +++ b/repository/shape/typectx/source/resolver_test.go @@ -0,0 +1,91 @@ +package source + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestResolver_ResolvePackageDir_UsesLocalReplace(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "project") + modelsDir := filepath.Join(root, "shared-models") + require.NoError(t, os.MkdirAll(filepath.Join(projectDir, "internal"), 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(modelsDir, "mdp"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(modelsDir, "go.mod"), []byte("module github.com/acme/models\n\ngo 1.25\n"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte(`module example.com/project +go 1.25 +replace github.com/acme/models => ../shared-models +`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(modelsDir, "mdp", "types.go"), []byte("package mdp\ntype Order struct{}\n"), 0o644)) + + resolver, err := New(Config{ + ProjectDir: projectDir, + UseGoModuleResolve: true, + UseGOPATHFallback: false, + }) + require.NoError(t, err) + dir, err := resolver.ResolvePackageDir("github.com/acme/models/mdp") + require.NoError(t, err) + require.Equal(t, filepath.Join(modelsDir, "mdp"), dir) +} + +func TestResolver_ResolveTypeFile_RespectsTrustedRoots(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "project") + modelsDir := filepath.Join(root, "shared-models") + require.NoError(t, os.MkdirAll(filepath.Join(projectDir, "internal"), 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(modelsDir, "mdp"), 0o755)) + require.NoError(t, 
os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte(`module example.com/project +go 1.25 +replace github.com/acme/models => ../shared-models +`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(modelsDir, "mdp", "types.go"), []byte("package mdp\ntype Order struct{}\n"), 0o644)) + + denyResolver, err := New(Config{ + ProjectDir: projectDir, + UseGoModuleResolve: true, + UseGOPATHFallback: false, + }) + require.NoError(t, err) + _, err = denyResolver.ResolveTypeFile("github.com/acme/models/mdp", "Order") + require.Error(t, err) + + allowResolver, err := New(Config{ + ProjectDir: projectDir, + AllowedSourceRoots: []string{modelsDir}, + UseGoModuleResolve: true, + UseGOPATHFallback: false, + }) + require.NoError(t, err) + file, err := allowResolver.ResolveTypeFile("github.com/acme/models/mdp", "Order") + require.NoError(t, err) + require.Equal(t, filepath.Join(modelsDir, "mdp", "types.go"), file) +} + +func TestResolver_ResolvePackageDir_GOPATHFallback(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "project") + gopath := filepath.Join(root, "gopath") + require.NoError(t, os.MkdirAll(filepath.Join(projectDir, "internal"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/project\ngo 1.25\n"), 0o644)) + legacyDir := filepath.Join(gopath, "src", "github.com", "legacy", "models") + require.NoError(t, os.MkdirAll(legacyDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(legacyDir, "types.go"), []byte("package models\ntype Legacy struct{}\n"), 0o644)) + + orig := os.Getenv("GOPATH") + require.NoError(t, os.Setenv("GOPATH", gopath)) + defer func() { _ = os.Setenv("GOPATH", orig) }() + + resolver, err := New(Config{ + ProjectDir: projectDir, + UseGoModuleResolve: false, + UseGOPATHFallback: true, + }) + require.NoError(t, err) + dir, err := resolver.ResolvePackageDir("github.com/legacy/models") + require.NoError(t, err) + require.Equal(t, legacyDir, dir) +} diff --git 
a/repository/shape/validate/relation.go b/repository/shape/validate/relation.go new file mode 100644 index 00000000..31aee935 --- /dev/null +++ b/repository/shape/validate/relation.go @@ -0,0 +1,140 @@ +package validate + +import ( + "fmt" + "strings" + + "github.com/viant/datly/view" +) + +// ValidateRelations validates that relation link columns can be resolved on both +// parent and referenced views. It accepts alias/source/field variants and +// namespace-qualified forms (e.g. t.ID -> ID). +func ValidateRelations(resource *view.Resource, targets ...*view.View) error { + if resource == nil { + return nil + } + views := targets + if len(views) == 0 { + views = resource.Views + } + index := resource.Views.Index() + var issues []string + for _, parent := range views { + if parent == nil { + continue + } + parentIndex := view.Columns(parent.Columns).Index(parent.CaseFormat) + for _, rel := range parent.With { + if rel == nil || rel.Of == nil { + continue + } + ref := &rel.Of.View + if ref.Ref != "" { + if lookup, err := index.Lookup(ref.Ref); err == nil && lookup != nil { + ref = lookup + } + } + refIndex := view.Columns(ref.Columns).Index(ref.CaseFormat) + pairCount := len(rel.On) + if len(rel.Of.On) > pairCount { + pairCount = len(rel.Of.On) + } + for i := 0; i < pairCount; i++ { + var parentLink, refLink *view.Link + if i < len(rel.On) { + parentLink = rel.On[i] + } + if i < len(rel.Of.On) { + refLink = rel.Of.On[i] + } + + if missing := missingColumn(parentIndex, parentLink); missing != "" { + issues = append(issues, fmt.Sprintf("relation %q (parent=%q holder=%q link=%d): missing parent column %q", relName(rel, i), parent.Name, rel.Holder, i, missing)) + } + if missing := missingColumn(refIndex, refLink); missing != "" { + issues = append(issues, fmt.Sprintf("relation %q (parent=%q ref=%q holder=%q link=%d): missing ref column %q", relName(rel, i), parent.Name, ref.Name, rel.Holder, i, missing)) + } + } + } + } + if len(issues) == 0 { + return nil + } + return 
fmt.Errorf("shape relation validation failed:\n- %s", strings.Join(issues, "\n- ")) +} + +func missingColumn(index view.NamedColumns, link *view.Link) string { + if link == nil { + return "" + } + for _, candidate := range linkCandidates(link) { + if strings.TrimSpace(candidate) == "" { + continue + } + if _, err := index.Lookup(candidate); err == nil { + return "" + } + } + for _, candidate := range linkCandidates(link) { + if strings.TrimSpace(candidate) != "" { + return candidate + } + } + return "" +} + +func linkCandidates(link *view.Link) []string { + if link == nil { + return nil + } + var result []string + add := func(v string) { + v = strings.TrimSpace(trimIdentifier(v)) + if v == "" { + return + } + result = append(result, v) + if i := strings.LastIndex(v, "."); i != -1 && i < len(v)-1 { + result = append(result, v[i+1:]) + } + } + add(link.Column) + if link.Namespace != "" && link.Column != "" { + add(link.Namespace + "." + link.Column) + } + add(link.Field) + return dedupe(result) +} + +func trimIdentifier(value string) string { + value = strings.TrimSpace(value) + value = strings.Trim(value, "`") + value = strings.Trim(value, "\"") + value = strings.Trim(value, "'") + return value +} + +func dedupe(values []string) []string { + seen := map[string]bool{} + result := make([]string, 0, len(values)) + for _, value := range values { + key := strings.ToLower(strings.TrimSpace(value)) + if key == "" || seen[key] { + continue + } + seen[key] = true + result = append(result, value) + } + return result +} + +func relName(rel *view.Relation, idx int) string { + if rel == nil { + return fmt.Sprintf("#%d", idx) + } + if strings.TrimSpace(rel.Name) != "" { + return rel.Name + } + return fmt.Sprintf("#%d", idx) +} diff --git a/repository/shape/validate/relation_test.go b/repository/shape/validate/relation_test.go new file mode 100644 index 00000000..e1031788 --- /dev/null +++ b/repository/shape/validate/relation_test.go @@ -0,0 +1,70 @@ +package validate + +import ( 
+ "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +func TestValidateRelations_AllowsAliasSourceAndNamespace(t *testing.T) { + parent := &view.View{ + Name: "vendor", + Columns: view.Columns{ + view.NewColumn("ID", "int", nil, false), + }, + } + child := &view.View{ + Name: "products", + Columns: view.Columns{ + view.NewColumn("VendorID", "int", nil, false, view.WithColumnTag(`source:"VENDOR_ID"`)), + }, + } + parent.With = []*view.Relation{{ + Name: "products", + Cardinality: state.Many, + Holder: "Products", + On: view.Links{&view.Link{Column: "vendor.ID"}}, + Of: &view.ReferenceView{ + View: *child, + On: view.Links{&view.Link{Column: "VENDOR_ID"}}, + }, + }} + resource := view.EmptyResource() + resource.Views = append(resource.Views, parent, child) + require.NoError(t, ValidateRelations(resource, parent)) +} + +func TestValidateRelations_DetailedMissingError(t *testing.T) { + parent := &view.View{ + Name: "vendor", + Columns: view.Columns{ + view.NewColumn("ID", "int", nil, false), + }, + } + child := &view.View{ + Name: "products", + Columns: view.Columns{ + view.NewColumn("VendorID", "int", nil, false), + }, + } + parent.With = []*view.Relation{{ + Name: "products", + Cardinality: state.Many, + Holder: "Products", + On: view.Links{&view.Link{Column: "MISSING_PARENT"}}, + Of: &view.ReferenceView{ + View: *child, + On: view.Links{&view.Link{Column: "MISSING_CHILD"}}, + }, + }} + resource := view.EmptyResource() + resource.Views = append(resource.Views, parent, child) + err := ValidateRelations(resource, parent) + require.Error(t, err) + require.Contains(t, err.Error(), "missing parent column \"MISSING_PARENT\"") + require.Contains(t, err.Error(), "missing ref column \"MISSING_CHILD\"") + require.Contains(t, err.Error(), "parent=\"vendor\"") + require.Contains(t, err.Error(), "ref=\"products\"") +} diff --git a/repository/shape/xgen/generator.go b/repository/shape/xgen/generator.go new 
file mode 100644 index 00000000..89622576 --- /dev/null +++ b/repository/shape/xgen/generator.go @@ -0,0 +1,644 @@ +package xgen + +import ( + "fmt" + "go/ast" + "os" + "path/filepath" + "reflect" + "sort" + "strings" + + "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/datly/repository/shape/typectx/source" + "github.com/viant/x" + xreflectloader "github.com/viant/x/loader/xreflect" + "github.com/viant/x/syntetic" + "github.com/viant/x/syntetic/model" +) + +// GenerateFromDQLShape emits Go structs from DQL shape using viant/x registry. +func GenerateFromDQLShape(doc *shape.Document, cfg *Config) (*Result, error) { + if doc == nil || doc.Root == nil { + return nil, fmt.Errorf("shape xgen: nil document") + } + if cfg == nil { + cfg = &Config{} + } + applyDefaults(cfg) + projectDir, packageDir, err := resolvePaths(cfg.ProjectDir, cfg.PackageDir) + if err != nil { + return nil, err + } + packageName := resolvePackageName(cfg.PackageName, packageDir) + packagePath, err := resolvePackagePath(cfg.PackagePath, projectDir, packageDir) + if err != nil { + return nil, err + } + fileName := cfg.FileName + if strings.TrimSpace(fileName) == "" { + fileName = "shapes_gen.go" + } + registry := cfg.Registry + if registry == nil { + registry = x.NewRegistry() + } + views := extractViews(doc.Root) + routeTypes := extractRouteIO(doc.Root) + if len(views) == 0 && len(routeTypes) == 0 { + return nil, fmt.Errorf("shape xgen: no view or route io declarations") + } + typeNames := make([]string, 0, len(views)+len(routeTypes)) + registered := map[string]bool{} + for _, view := range views { + typeName := viewTypeName(cfg, view) + if registered[typeName] { + continue + } + registered[typeName] = true + if err = registerShapeType(registry, packagePath, typeName, buildStructType(view.columns)); err != nil { + return nil, err + } + typeNames = append(typeNames, typeName) + } + for _, ioType := range routeTypes { + 
typeName := routeTypeName(cfg, ioType) + if typeName == "" || registered[typeName] { + continue + } + registered[typeName] = true + if err = registerShapeType(registry, packagePath, typeName, buildStructType(ioType.fields)); err != nil { + return nil, err + } + typeNames = append(typeNames, typeName) + } + namespace, err := syntetic.FromRegistry(registry) + if err != nil { + return nil, err + } + namespace.PkgName = packageName + namespace.PkgPath = packagePath + files, err := namespace.BuildFiles(model.RenderOptions{}) + if err != nil { + return nil, err + } + goFile := files[packagePath] + if goFile == nil { + return nil, fmt.Errorf("shape xgen: missing generated package file for %s", packagePath) + } + source, err := goFile.Render() + if err != nil { + return nil, err + } + if err = os.MkdirAll(packageDir, 0o755); err != nil { + return nil, err + } + dest := filepath.Join(packageDir, fileName) + if exists, checkErr := fileExists(dest); checkErr != nil { + return nil, checkErr + } else if exists && !cfg.AllowUnsafeRewrite { + if issues := rewriteSafetyIssues(doc, cfg, projectDir); len(issues) > 0 && (cfg.StrictProvenance == nil || *cfg.StrictProvenance) { + return nil, fmt.Errorf("shape xgen: rewrite blocked by type provenance safety: %s", strings.Join(issues, "; ")) + } + merged, mergeErr := mergeGeneratedShapes(dest, []byte(source), typeNames) + if mergeErr != nil { + return nil, mergeErr + } + source = string(merged) + } + if err = writeAtomic(dest, []byte(source), 0o644); err != nil { + return nil, err + } + sort.Strings(typeNames) + return &Result{ + FilePath: dest, + PackagePath: packagePath, + PackageName: packageName, + Types: typeNames, + }, nil +} + +func rewriteSafetyIssues(doc *shape.Document, cfg *Config, projectDir string) []string { + if doc == nil || len(doc.TypeResolutions) == 0 { + return nil + } + policy := newRewritePolicy(cfg, projectDir) + srcResolver, _ := source.New(source.Config{ + ProjectDir: projectDir, + AllowedSourceRoots: 
policy.roots, + UseGoModuleResolve: policy.useModule, + UseGOPATHFallback: policy.useGOPATH, + }) + var issues []string + for _, resolution := range doc.TypeResolutions { + if srcResolver != nil && strings.TrimSpace(resolution.Provenance.File) == "" { + pkg := firstNonEmpty(strings.TrimSpace(resolution.Provenance.Package), packageOfKey(resolution.ResolvedKey)) + name := typeNameFromKey(resolution.ResolvedKey) + if pkg != "" && name != "" { + if file, err := srcResolver.ResolveTypeFile(pkg, name); err == nil { + resolution.Provenance.File = file + if resolution.Provenance.Kind == "" || strings.EqualFold(resolution.Provenance.Kind, "registry") { + resolution.Provenance.Kind = "ast_type" + } + } + } + } + if issue := resolutionSafetyIssue(resolution, policy); issue != "" { + issues = append(issues, issue) + } + } + sort.Strings(issues) + return uniqueStrings(issues) +} + +func resolutionSafetyIssue(resolution typectx.Resolution, policy rewritePolicy) string { + kind := strings.TrimSpace(strings.ToLower(resolution.Provenance.Kind)) + if kind == "" { + kind = "registry" + } + if !policy.allowedKinds[kind] { + return fmt.Sprintf("expression=%q kind=%q", resolution.Expression, resolution.Provenance.Kind) + } + + sourceFile := strings.TrimSpace(resolution.Provenance.File) + if sourceFile == "" { + return "" + } + if !filepath.IsAbs(sourceFile) { + sourceFile = filepath.Clean(filepath.Join(policy.projectDir, sourceFile)) + } + if safe, err := source.IsWithinAnyRoot(sourceFile, policy.roots); err != nil || !safe { + return fmt.Sprintf("expression=%q source=%q outside_trusted_roots", resolution.Expression, resolution.Provenance.File) + } + return "" +} + +type rewritePolicy struct { + projectDir string + allowedKinds map[string]bool + roots []string + useModule bool + useGOPATH bool +} + +func newRewritePolicy(cfg *Config, projectDir string) rewritePolicy { + allowedKinds := map[string]bool{ + "builtin": true, + "resource_type": true, + "ast_type": true, + } + if 
len(cfg.AllowedProvenanceKinds) > 0 { + allowedKinds = map[string]bool{} + for _, item := range cfg.AllowedProvenanceKinds { + item = strings.TrimSpace(strings.ToLower(item)) + if item != "" { + allowedKinds[item] = true + } + } + } + useModule := true + if cfg.UseGoModuleResolve != nil { + useModule = *cfg.UseGoModuleResolve + } + useGOPATH := true + if cfg.UseGOPATHFallback != nil { + useGOPATH = *cfg.UseGOPATHFallback + } + return rewritePolicy{ + projectDir: projectDir, + allowedKinds: allowedKinds, + roots: source.NormalizeRoots(projectDir, cfg.AllowedSourceRoots), + useModule: useModule, + useGOPATH: useGOPATH, + } +} + +func typeNameFromKey(key string) string { + index := strings.LastIndex(key, ".") + if index == -1 || index+1 >= len(key) { + return "" + } + return key[index+1:] +} + +func packageOfKey(key string) string { + index := strings.LastIndex(key, ".") + if index == -1 { + return "" + } + return key[:index] +} + +func uniqueStrings(items []string) []string { + if len(items) < 2 { + return items + } + result := items[:0] + var previous string + for i, item := range items { + if i == 0 || item != previous { + result = append(result, item) + } + previous = item + } + return result +} + +func registerShapeType(registry *x.Registry, packagePath string, typeName string, rType reflect.Type) error { + st, err := xreflectloader.BuildType(rType, + xreflectloader.WithPackagePath(packagePath), + xreflectloader.WithNamePolicy(func(reflect.Type) (string, bool) { + return typeName, false + })) + if err != nil { + return fmt.Errorf("shape xgen: build type %s failed: %w", typeName, err) + } + st.Name = typeName + st.PkgPath = packagePath + if st.TypeSpec != nil { + st.TypeSpec.Name = ast.NewIdent(typeName) + } + registry.Register(x.NewType(rType, + x.WithName(typeName), + x.WithPkgPath(packagePath), + x.WithSyntheticType(st))) + return nil +} + +type viewDescriptor struct { + name any + schemaName any + columns []columnDescriptor +} + +type ioTypeKind string + 
+const ( + ioTypeInput ioTypeKind = "input" + ioTypeOutput ioTypeKind = "output" +) + +type routeIODescriptor struct { + kind ioTypeKind + routeName string + routeURI string + routeRef string + typeName string + fields []columnDescriptor +} + +type columnDescriptor struct { + name string + dataType string +} + +func extractViews(root map[string]any) []viewDescriptor { + resource := asMap(root["Resource"]) + if resource == nil { + return nil + } + items := asSlice(resource["Views"]) + result := make([]viewDescriptor, 0, len(items)) + for _, item := range items { + view := asMap(item) + if view == nil { + continue + } + schema := asMap(view["Schema"]) + descriptor := viewDescriptor{ + name: view["Name"], + schemaName: nil, + } + if schema != nil { + descriptor.schemaName = schema["Name"] + } + descriptor.columns = extractColumns(view) + result = append(result, descriptor) + } + return result +} + +func extractColumns(view map[string]any) []columnDescriptor { + var result []columnDescriptor + if columns := asSlice(view["Columns"]); len(columns) > 0 { + for _, item := range columns { + column := asMap(item) + if column == nil { + continue + } + name := firstNonEmpty(asString(column["Name"]), asString(column["Column"])) + if name == "" { + continue + } + result = append(result, columnDescriptor{name: name, dataType: asString(column["DataType"])}) + } + } + if cfg := asMap(view["ColumnsConfig"]); len(cfg) > 0 { + keys := make([]string, 0, len(cfg)) + for k := range cfg { + keys = append(keys, k) + } + sort.Strings(keys) + for _, key := range keys { + item := asMap(cfg[key]) + if item == nil { + item = map[string]any{} + } + name := firstNonEmpty(asString(item["Name"]), key) + result = append(result, columnDescriptor{name: name, dataType: asString(item["DataType"])}) + } + } + if len(result) == 0 { + result = append(result, columnDescriptor{name: "ID", dataType: "int"}) + } + return result +} + +func extractRouteIO(root map[string]any) []routeIODescriptor { + var result 
[]routeIODescriptor + for _, item := range asSlice(root["Routes"]) { + route := asMap(item) + if route == nil { + continue + } + meta := routeIODescriptor{ + routeName: asString(route["Name"]), + routeURI: asString(route["URI"]), + } + if routeView := asMap(route["View"]); routeView != nil { + meta.routeRef = asString(routeView["Ref"]) + } + if input := asMap(route["Input"]); input != nil { + entry := meta + entry.kind = ioTypeInput + entry.typeName = nestedTypeName(input) + entry.fields = extractIOFields(input) + result = append(result, entry) + } + if output := asMap(route["Output"]); output != nil { + entry := meta + entry.kind = ioTypeOutput + entry.typeName = nestedTypeName(output) + entry.fields = extractIOFields(output) + result = append(result, entry) + } + } + return result +} + +func nestedTypeName(io map[string]any) string { + aType := asMap(io["Type"]) + if aType == nil { + return "" + } + return asString(aType["Name"]) +} + +func extractIOFields(io map[string]any) []columnDescriptor { + parameters := asSlice(io["Parameters"]) + if len(parameters) == 0 { + if t := asMap(io["Type"]); t != nil { + parameters = asSlice(t["Parameters"]) + } + } + fields := make([]columnDescriptor, 0, len(parameters)) + for _, item := range parameters { + param := asMap(item) + if param == nil { + continue + } + name := asString(param["Name"]) + if name == "" { + continue + } + dataType := "" + if schema := asMap(param["Schema"]); schema != nil { + dataType = asString(schema["DataType"]) + } + fields = append(fields, columnDescriptor{name: name, dataType: dataType}) + } + if len(fields) == 0 { + fields = append(fields, columnDescriptor{name: "ID", dataType: "int"}) + } + return fields +} + +func buildStructType(columns []columnDescriptor) reflect.Type { + if len(columns) == 0 { + columns = []columnDescriptor{{name: "ID", dataType: "int"}} + } + fields := make([]reflect.StructField, 0, len(columns)) + used := map[string]int{} + for _, column := range columns { + fieldName := 
exportedName(column.name) + if fieldName == "" { + fieldName = "Field" + } + if count := used[fieldName]; count > 0 { + fieldName = fmt.Sprintf("%s%d", fieldName, count+1) + } + used[fieldName]++ + fields = append(fields, reflect.StructField{ + Name: fieldName, + Type: parseType(column.dataType), + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s,omitempty" sqlx:"%s"`, strings.ToLower(fieldName), column.name)), + }) + } + return reflect.StructOf(fields) +} + +func parseType(dataType string) reflect.Type { + dataType = strings.TrimSpace(dataType) + if dataType == "" { + return reflect.TypeOf("") + } + switch { + case strings.HasPrefix(dataType, "[]"): + return reflect.SliceOf(parseType(strings.TrimPrefix(dataType, "[]"))) + case strings.HasPrefix(dataType, "*"): + return reflect.PointerTo(parseType(strings.TrimPrefix(dataType, "*"))) + } + lowered := strings.ToLower(dataType) + switch lowered { + case "string", "varchar", "text": + return reflect.TypeOf("") + case "bool", "boolean": + return reflect.TypeOf(true) + case "int", "integer": + return reflect.TypeOf(int(0)) + case "int64", "bigint": + return reflect.TypeOf(int64(0)) + case "int32": + return reflect.TypeOf(int32(0)) + case "float", "float64", "double", "decimal": + return reflect.TypeOf(float64(0)) + case "float32": + return reflect.TypeOf(float32(0)) + case "bytes", "[]byte", "blob": + return reflect.TypeOf([]byte{}) + default: + return reflect.TypeOf("") + } +} + +func exportedName(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + var parts []string + current := strings.Builder{} + flush := func() { + if current.Len() == 0 { + return + } + parts = append(parts, current.String()) + current.Reset() + } + for _, r := range name { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') { + current.WriteRune(r) + } else { + flush() + } + } + flush() + for i, item := range parts { + if item == strings.ToUpper(item) { + parts[i] = 
strings.ToUpper(item[:1]) + strings.ToLower(item[1:]) + } else { + parts[i] = strings.ToUpper(item[:1]) + item[1:] + } + } + result := strings.Join(parts, "") + if result == "" { + return "" + } + if result[0] >= '0' && result[0] <= '9' { + result = "N" + result + } + return result +} + +func applyDefaults(cfg *Config) { + if cfg.ViewSuffix == "" { + cfg.ViewSuffix = "View" + } + if cfg.InputSuffix == "" { + cfg.InputSuffix = "Input" + } + if cfg.OutputSuffix == "" { + cfg.OutputSuffix = "Output" + } + if cfg.UseGoModuleResolve == nil { + value := true + cfg.UseGoModuleResolve = &value + } + if cfg.UseGOPATHFallback == nil { + value := true + cfg.UseGOPATHFallback = &value + } + if cfg.StrictProvenance == nil { + value := true + cfg.StrictProvenance = &value + } +} + +func viewTypeName(cfg *Config, view viewDescriptor) string { + ctx := ViewTypeContext{ + ViewName: asString(view.name), + SchemaName: asString(view.schemaName), + } + if cfg.ViewTypeNamer != nil { + if name := strings.TrimSpace(cfg.ViewTypeNamer(ctx)); name != "" { + return cfg.TypePrefix + exportedName(name) + } + } + base := firstNonEmpty(ctx.SchemaName, ctx.ViewName) + if base == "" { + base = cfg.ViewSuffix + } else if !hasCaseInsensitiveSuffix(base, cfg.ViewSuffix) { + base += cfg.ViewSuffix + } + return cfg.TypePrefix + exportedName(base) +} + +func routeTypeName(cfg *Config, route routeIODescriptor) string { + ctx := RouteTypeContext{ + RouteName: route.routeName, + RouteURI: route.routeURI, + RouteRef: route.routeRef, + TypeName: route.typeName, + } + var custom string + switch route.kind { + case ioTypeInput: + if cfg.InputTypeNamer != nil { + custom = cfg.InputTypeNamer(ctx) + } + case ioTypeOutput: + if cfg.OutputTypeNamer != nil { + custom = cfg.OutputTypeNamer(ctx) + } + } + if strings.TrimSpace(custom) != "" { + return cfg.TypePrefix + exportedName(custom) + } + base := firstNonEmpty(ctx.TypeName, ctx.RouteName, ctx.RouteRef, "Route") + suffix := cfg.OutputSuffix + if route.kind == 
ioTypeInput { + suffix = cfg.InputSuffix + } + if !hasCaseInsensitiveSuffix(base, suffix) { + base += suffix + } + return cfg.TypePrefix + exportedName(base) +} + +func hasCaseInsensitiveSuffix(value, suffix string) bool { + if suffix == "" { + return true + } + return strings.HasSuffix(strings.ToLower(value), strings.ToLower(suffix)) +} + +func firstNonEmpty(values ...string) string { + for _, item := range values { + if strings.TrimSpace(item) != "" { + return item + } + } + return "" +} + +func asMap(raw any) map[string]any { + if value, ok := raw.(map[string]any); ok { + return value + } + if value, ok := raw.(map[any]any); ok { + out := map[string]any{} + for key, item := range value { + out[fmt.Sprint(key)] = item + } + return out + } + return nil +} + +func asSlice(raw any) []any { + if value, ok := raw.([]any); ok { + return value + } + return nil +} + +func asString(raw any) string { + if raw == nil { + return "" + } + if value, ok := raw.(string); ok { + return value + } + return fmt.Sprint(raw) +} diff --git a/repository/shape/xgen/generator_test.go b/repository/shape/xgen/generator_test.go new file mode 100644 index 00000000..3315f148 --- /dev/null +++ b/repository/shape/xgen/generator_test.go @@ -0,0 +1,305 @@ +package xgen + +import ( + "go/parser" + "go/token" + "os" + "path/filepath" + "strings" + "testing" + + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" +) + +func TestGenerateFromDQLShape(t *testing.T) { + projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + doc := &dqlshape.Document{Root: map[string]any{ + "Routes": []any{ + map[string]any{ + "Name": "orders", + "URI": "/orders", + "View": map[string]any{"Ref": "orders"}, + "Input": map[string]any{ + "Type": map[string]any{"Name": "OrdersFilter"}, + "Parameters": []any{ + 
map[string]any{ + "Name": "status", + "Schema": map[string]any{ + "DataType": "string", + }, + }, + }, + }, + "Output": map[string]any{ + "Type": map[string]any{"Name": "OrdersPayload"}, + "Parameters": []any{ + map[string]any{ + "Name": "total", + "Schema": map[string]any{ + "DataType": "int", + }, + }, + }, + }, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "orders", + "Schema": map[string]any{ + "Name": "OrderView", + }, + "ColumnsConfig": map[string]any{ + "ID": map[string]any{"Name": "ID", "DataType": "int"}, + "NAME": map[string]any{"Name": "NAME", "DataType": "string"}, + }, + }, + }, + }, + }} + result, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + TypePrefix: "DQL", + }) + if err != nil { + t.Fatalf("generate failed: %v", err) + } + if result == nil { + t.Fatalf("nil result") + } + if len(result.Types) == 0 { + t.Fatalf("expected generated types") + } + if _, err = os.Stat(result.FilePath); err != nil { + t.Fatalf("generated file missing: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file failed: %v", err) + } + source := string(data) + if !strings.Contains(source, "type DQLOrderView struct") { + t.Fatalf("expected generated type in source, got:\n%s", source) + } + if !strings.Contains(source, "type DQLOrdersFilterInput struct") || !strings.Contains(source, "type DQLOrdersPayloadOutput struct") { + t.Fatalf("expected io types in source, got:\n%s", source) + } + if !strings.Contains(source, "Id") || !strings.Contains(source, "Name") { + t.Fatalf("expected generated fields in source, got:\n%s", source) + } + fset := token.NewFileSet() + if _, err = parser.ParseFile(fset, result.FilePath, source, parser.AllErrors); err != nil { + t.Fatalf("generated file parse failed: %v", err) + } +} + +func TestGenerateFromDQLShape_CustomTypeNamers(t *testing.T) { + 
projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + doc := &dqlshape.Document{Root: map[string]any{ + "Routes": []any{ + map[string]any{ + "Name": "orders", + "Input": map[string]any{ + "Parameters": []any{map[string]any{"Name": "q", "Schema": map[string]any{"DataType": "string"}}}, + }, + "Output": map[string]any{ + "Parameters": []any{map[string]any{"Name": "count", "Schema": map[string]any{"DataType": "int"}}}, + }, + }, + }, + "Resource": map[string]any{ + "Views": []any{ + map[string]any{"Name": "orders", "ColumnsConfig": map[string]any{"ID": map[string]any{"Name": "ID", "DataType": "int"}}}, + }, + }, + }} + result, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + ViewTypeNamer: func(ctx ViewTypeContext) string { + return "DataOrders" + }, + InputTypeNamer: func(ctx RouteTypeContext) string { + return "ReqOrders" + }, + OutputTypeNamer: func(ctx RouteTypeContext) string { + return "ResOrders" + }, + }) + if err != nil { + t.Fatalf("generate failed: %v", err) + } + data, err := os.ReadFile(result.FilePath) + if err != nil { + t.Fatalf("read generated file failed: %v", err) + } + source := string(data) + if !strings.Contains(source, "type DataOrders struct") { + t.Fatalf("missing custom view type: %s", source) + } + if !strings.Contains(source, "type ReqOrders struct") { + t.Fatalf("missing custom input type: %s", source) + } + if !strings.Contains(source, "type ResOrders struct") { + t.Fatalf("missing custom output type: %s", source) + } +} + +func TestGenerateFromDQLShape_BlocksUnsafeRewriteByProvenance(t *testing.T) { + projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + packageDir := 
filepath.Join(projectDir, "internal", "gen") + if err := os.MkdirAll(packageDir, 0o755); err != nil { + t.Fatalf("mkdir failed: %v", err) + } + dest := filepath.Join(packageDir, "shapes_gen.go") + if err := os.WriteFile(dest, []byte("package gen\n"), 0o644); err != nil { + t.Fatalf("seed file failed: %v", err) + } + + doc := &dqlshape.Document{ + Root: map[string]any{ + "Resource": map[string]any{ + "Views": []any{ + map[string]any{"Name": "orders", "ColumnsConfig": map[string]any{"ID": map[string]any{"Name": "ID", "DataType": "int"}}}, + }, + }, + }, + TypeResolutions: []typectx.Resolution{ + { + Expression: "Fee", + Provenance: typectx.Provenance{Kind: "registry"}, + }, + }, + } + _, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + }) + if err == nil || !strings.Contains(err.Error(), "rewrite blocked") { + t.Fatalf("expected rewrite blocked error, got: %v", err) + } +} + +func TestGenerateFromDQLShape_AllowsUnsafeRewriteWithOverride(t *testing.T) { + projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + packageDir := filepath.Join(projectDir, "internal", "gen") + if err := os.MkdirAll(packageDir, 0o755); err != nil { + t.Fatalf("mkdir failed: %v", err) + } + dest := filepath.Join(packageDir, "shapes_gen.go") + if err := os.WriteFile(dest, []byte("package gen\n"), 0o644); err != nil { + t.Fatalf("seed file failed: %v", err) + } + + doc := &dqlshape.Document{ + Root: map[string]any{ + "Resource": map[string]any{ + "Views": []any{ + map[string]any{"Name": "orders", "ColumnsConfig": map[string]any{"ID": map[string]any{"Name": "ID", "DataType": "int"}}}, + }, + }, + }, + TypeResolutions: []typectx.Resolution{ + { + Expression: "Fee", + Provenance: typectx.Provenance{Kind: "registry"}, + }, + }, + } + result, err := 
GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + AllowUnsafeRewrite: true, + }) + if err != nil { + t.Fatalf("expected override rewrite success, got: %v", err) + } + if result == nil || result.FilePath == "" { + t.Fatalf("expected generated result") + } +} + +func TestGenerateFromDQLShape_MergesIntoExistingFile(t *testing.T) { + projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + packageDir := filepath.Join(projectDir, "internal", "gen") + if err := os.MkdirAll(packageDir, 0o755); err != nil { + t.Fatalf("mkdir failed: %v", err) + } + dest := filepath.Join(packageDir, "shapes_gen.go") + initial := `package gen + +type DQLOrderView struct { + Old string ` + "`json:\"old,omitempty\"`" + ` +} + +func KeepCustom() string { return "ok" } +` + if err := os.WriteFile(dest, []byte(initial), 0o644); err != nil { + t.Fatalf("seed file failed: %v", err) + } + + doc := &dqlshape.Document{Root: map[string]any{ + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "orders", + "Schema": map[string]any{ + "Name": "OrderView", + }, + "ColumnsConfig": map[string]any{ + "ID": map[string]any{"Name": "ID", "DataType": "int"}, + }, + }, + }, + }, + }} + _, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + TypePrefix: "DQL", + }) + if err != nil { + t.Fatalf("generate failed: %v", err) + } + + data, err := os.ReadFile(dest) + if err != nil { + t.Fatalf("read generated file failed: %v", err) + } + source := string(data) + if !strings.Contains(source, "func KeepCustom() string") { + t.Fatalf("expected custom function preserved, got:\n%s", source) + } + if strings.Contains(source, "Old string") { + t.Fatalf("expected old 
shape declaration replaced, got:\n%s", source) + } + if !strings.Contains(source, "type DQLOrderView struct") || !strings.Contains(source, "Id int") { + t.Fatalf("expected updated shape declaration, got:\n%s", source) + } +} diff --git a/repository/shape/xgen/io.go b/repository/shape/xgen/io.go new file mode 100644 index 00000000..395ea8e9 --- /dev/null +++ b/repository/shape/xgen/io.go @@ -0,0 +1,311 @@ +package xgen + +import ( + "bufio" + "bytes" + "fmt" + "go/ast" + "go/format" + "go/parser" + "go/token" + "os" + "path/filepath" + "sort" + "strings" +) + +func resolvePaths(projectDir, packageDir string) (string, string, error) { + if strings.TrimSpace(projectDir) == "" { + return "", "", fmt.Errorf("shape xgen: project dir was empty") + } + projectDir = filepath.Clean(projectDir) + if strings.TrimSpace(packageDir) == "" { + packageDir = projectDir + } else if !filepath.IsAbs(packageDir) { + packageDir = filepath.Join(projectDir, packageDir) + } + packageDir = filepath.Clean(packageDir) + return projectDir, packageDir, nil +} + +func resolvePackageName(name string, packageDir string) string { + name = strings.TrimSpace(name) + if name != "" { + return name + } + base := filepath.Base(packageDir) + if base == "." || base == string(filepath.Separator) || base == "" { + return "generated" + } + return sanitizePkg(base) +} + +func resolvePackagePath(packagePath, projectDir, packageDir string) (string, error) { + packagePath = strings.TrimSpace(packagePath) + if packagePath != "" { + return packagePath, nil + } + modulePath, err := readModulePath(filepath.Join(projectDir, "go.mod")) + if err != nil { + return "", err + } + rel, err := filepath.Rel(projectDir, packageDir) + if err != nil { + return "", err + } + rel = filepath.ToSlash(rel) + if rel == "." 
{ + return modulePath, nil + } + return strings.TrimRight(modulePath, "/") + "/" + strings.TrimLeft(rel, "/"), nil +} + +func readModulePath(goModPath string) (string, error) { + file, err := os.Open(goModPath) + if err != nil { + return "", fmt.Errorf("shape xgen: open go.mod failed: %w", err) + } + defer file.Close() + scanner := bufio.NewScanner(file) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if !strings.HasPrefix(line, "module ") { + continue + } + modulePath := strings.TrimSpace(strings.TrimPrefix(line, "module ")) + if modulePath != "" { + return modulePath, nil + } + } + if err = scanner.Err(); err != nil { + return "", err + } + return "", fmt.Errorf("shape xgen: module path not found in %s", goModPath) +} + +func sanitizePkg(name string) string { + name = strings.TrimSpace(strings.ToLower(name)) + if name == "" { + return "generated" + } + var out strings.Builder + for _, r := range name { + if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '_' { + out.WriteRune(r) + } + } + if out.Len() == 0 { + return "generated" + } + result := out.String() + if result[0] >= '0' && result[0] <= '9' { + return "p" + result + } + return result +} + +func writeAtomic(path string, data []byte, perm os.FileMode) error { + dir := filepath.Dir(path) + temp, err := os.CreateTemp(dir, ".tmp-shape-xgen-*") + if err != nil { + return err + } + tempPath := temp.Name() + cleanup := func() { + _ = os.Remove(tempPath) + } + if _, err = temp.Write(data); err != nil { + _ = temp.Close() + cleanup() + return err + } + if err = temp.Chmod(perm); err != nil { + _ = temp.Close() + cleanup() + return err + } + if err = temp.Close(); err != nil { + cleanup() + return err + } + if err = os.Rename(tempPath, path); err != nil { + cleanup() + return err + } + return nil +} + +func fileExists(path string) (bool, error) { + info, err := os.Stat(path) + if err == nil { + return !info.IsDir(), nil + } + if os.IsNotExist(err) { + return false, nil + } + return 
false, err +} + +func isWithinProject(projectDir, candidate string) (bool, error) { + projectDir = filepath.Clean(projectDir) + candidate = filepath.Clean(candidate) + rel, err := filepath.Rel(projectDir, candidate) + if err != nil { + return false, err + } + if rel == "." { + return true, nil + } + rel = filepath.ToSlash(rel) + return !strings.HasPrefix(rel, "../"), nil +} + +func mergeGeneratedShapes(dest string, generated []byte, typeNames []string) ([]byte, error) { + existing, err := os.ReadFile(dest) + if err != nil { + return nil, err + } + if len(existing) == 0 { + return generated, nil + } + if len(typeNames) == 0 { + return existing, nil + } + + fset := token.NewFileSet() + existingFile, err := parser.ParseFile(fset, dest, existing, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("shape xgen: parse existing file failed: %w", err) + } + generatedFile, err := parser.ParseFile(token.NewFileSet(), "", generated, parser.ParseComments) + if err != nil { + return nil, fmt.Errorf("shape xgen: parse generated file failed: %w", err) + } + typeNameSet := map[string]bool{} + for _, name := range typeNames { + typeNameSet[name] = true + } + + shapeDecls := generatedShapeDecls(generatedFile, typeNameSet) + if len(shapeDecls) == 0 { + return generated, nil + } + mergedImports := mergeImports(existingFile.Imports, generatedFile.Imports) + + newDecls := make([]ast.Decl, 0, len(existingFile.Decls)+len(shapeDecls)+1) + if len(mergedImports) > 0 { + newDecls = append(newDecls, &ast.GenDecl{ + Tok: token.IMPORT, + Specs: mergedImports, + }) + } + + for _, decl := range existingFile.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok { + newDecls = append(newDecls, decl) + continue + } + switch gen.Tok { + case token.IMPORT: + continue + case token.TYPE: + filtered := make([]ast.Spec, 0, len(gen.Specs)) + for _, spec := range gen.Specs { + ts, ok := spec.(*ast.TypeSpec) + if !ok || !typeNameSet[ts.Name.Name] { + filtered = append(filtered, spec) + } + } + if 
len(filtered) == 0 { + continue + } + gen.Specs = filtered + newDecls = append(newDecls, gen) + default: + newDecls = append(newDecls, decl) + } + } + newDecls = append(newDecls, shapeDecls...) + existingFile.Decls = newDecls + existingFile.Imports = importSpecsToImportNodes(mergedImports) + + var out bytes.Buffer + if err = format.Node(&out, fset, existingFile); err != nil { + return nil, fmt.Errorf("shape xgen: format merged file failed: %w", err) + } + return out.Bytes(), nil +} + +func generatedShapeDecls(file *ast.File, typeNameSet map[string]bool) []ast.Decl { + var result []ast.Decl + for _, decl := range file.Decls { + gen, ok := decl.(*ast.GenDecl) + if !ok || gen.Tok != token.TYPE { + continue + } + filtered := make([]ast.Spec, 0, len(gen.Specs)) + for _, spec := range gen.Specs { + ts, ok := spec.(*ast.TypeSpec) + if !ok || !typeNameSet[ts.Name.Name] { + continue + } + filtered = append(filtered, spec) + } + if len(filtered) == 0 { + continue + } + result = append(result, &ast.GenDecl{ + Tok: token.TYPE, + Specs: filtered, + }) + } + return result +} + +func mergeImports(existing []*ast.ImportSpec, generated []*ast.ImportSpec) []ast.Spec { + merged := map[string]*ast.ImportSpec{} + add := func(item *ast.ImportSpec) { + if item == nil || item.Path == nil { + return + } + key := item.Path.Value + "|" + importAlias(item) + if _, ok := merged[key]; ok { + return + } + merged[key] = item + } + for _, item := range existing { + add(item) + } + for _, item := range generated { + add(item) + } + keys := make([]string, 0, len(merged)) + for key := range merged { + keys = append(keys, key) + } + sort.Strings(keys) + result := make([]ast.Spec, 0, len(keys)) + for _, key := range keys { + result = append(result, merged[key]) + } + return result +} + +func importAlias(item *ast.ImportSpec) string { + if item == nil || item.Name == nil { + return "" + } + return item.Name.Name +} + +func importSpecsToImportNodes(specs []ast.Spec) []*ast.ImportSpec { + result := 
make([]*ast.ImportSpec, 0, len(specs)) + for _, spec := range specs { + if item, ok := spec.(*ast.ImportSpec); ok { + result = append(result, item) + } + } + return result +} diff --git a/repository/shape/xgen/model.go b/repository/shape/xgen/model.go new file mode 100644 index 00000000..623f79be --- /dev/null +++ b/repository/shape/xgen/model.go @@ -0,0 +1,70 @@ +package xgen + +import "github.com/viant/x" + +type ( + ViewTypeContext struct { + ViewName string + SchemaName string + } + + RouteTypeContext struct { + RouteName string + RouteURI string + RouteRef string + TypeName string + } +) + +// Config controls shape->Go generation. +type Config struct { + // ProjectDir points to target Go project root. + ProjectDir string + // PackageDir points to package directory inside the project (relative or absolute). + PackageDir string + // PackageName sets generated package name; defaults to basename(PackageDir). + PackageName string + // PackagePath sets fully-qualified import path; when empty it's derived from go.mod + PackageDir. + PackagePath string + // FileName sets generated filename; defaults to shapes_gen.go. + FileName string + // TypePrefix prefixes generated type names. + TypePrefix string + // ViewSuffix appends suffix to generated view type names when schema name is absent. + ViewSuffix string + // InputSuffix appends suffix to generated route input type names when explicit type name is absent. + InputSuffix string + // OutputSuffix appends suffix to generated route output type names when explicit type name is absent. + OutputSuffix string + // ViewTypeNamer customizes final view type name. + ViewTypeNamer func(ctx ViewTypeContext) string + // InputTypeNamer customizes final input type name. + InputTypeNamer func(ctx RouteTypeContext) string + // OutputTypeNamer customizes final output type name. + OutputTypeNamer func(ctx RouteTypeContext) string + // Registry allows reusing an external viant/x registry. 
+ Registry *x.Registry + // AllowUnsafeRewrite allows overwriting existing generated files even when + // type provenance indicates unresolved/unsafe origins. Default false. + AllowUnsafeRewrite bool + // AllowedProvenanceKinds controls which provenance kinds are trusted for updates. + // Defaults to builtin, resource_type and ast_type. + AllowedProvenanceKinds []string + // AllowedSourceRoots controls additional trusted roots for provenance files. + // ProjectDir is always implicitly trusted. + AllowedSourceRoots []string + // UseGoModuleResolve enables go.mod + replace-based source resolution. Default true. + UseGoModuleResolve *bool + // UseGOPATHFallback enables GOPATH/src fallback when go.mod resolution misses. Default true. + UseGOPATHFallback *bool + // StrictProvenance blocks updates on policy violations. Default true. + StrictProvenance *bool +} + +// Result captures generation outputs. +type Result struct { + FilePath string + PackagePath string + PackageName string + Types []string +} diff --git a/view/state/parameters.go b/view/state/parameters.go index 04083ede..e464ce14 100644 --- a/view/state/parameters.go +++ b/view/state/parameters.go @@ -401,7 +401,9 @@ func (p *Parameter) buildField(pkgPath string, lookupType xreflect.LookupType) ( if err != nil { rType, err = types.LookupType(lookupType, schema.DataType, xreflect.WithPackage(pkgPath)) if err != nil { - return structField, markerField, fmt.Errorf("failed to detect parmater '%v' type for: %v %w", p.Name, schema.TypeName(), err) + // Keep unresolved custom parameter types as dynamic `interface{}` so + // scan/planning can continue while preserving declared schema metadata. 
+ rType = reflect.TypeOf((*interface{})(nil)).Elem() } } schema.rType = rType From dc6a13740acad3cf22ccc0f3c8cb4f4d513028b6 Mon Sep 17 00:00:00 2001 From: adranwit Date: Sun, 22 Feb 2026 06:20:10 -0800 Subject: [PATCH 2/6] - introduces shape pkg --- e2e/local/build.yaml | 2 +- e2e/local/regression/regression.yaml | 10 +++++----- go.sum | 2 -- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/e2e/local/build.yaml b/e2e/local/build.yaml index 78fd0be4..bb6e6bbe 100644 --- a/e2e/local/build.yaml +++ b/e2e/local/build.yaml @@ -14,7 +14,7 @@ pipeline: set_sdk: action: sdk.set target: $target - sdk: go:1.25.1 + sdk: go:1.25.5 buildValidator: action: exec:run diff --git a/e2e/local/regression/regression.yaml b/e2e/local/regression/regression.yaml index 10cbbf50..4b496e28 100644 --- a/e2e/local/regression/regression.yaml +++ b/e2e/local/regression/regression.yaml @@ -2,10 +2,10 @@ init: v1: abc v2: def pipeline: - set_sdk: - action: sdk.set - target: $target - sdk: go:1.25.1 +# set_sdk: +# action: sdk.set +# target: $target +# sdk: go:1.25.1 database: action: run @@ -30,7 +30,7 @@ pipeline: '[]gen': '@gen' subPath: 'cases/${index}_*' - #range: 1..007 + range: 011..012 template: checkSkip: action: nop diff --git a/go.sum b/go.sum index d7923c12..d1e092d5 100644 --- a/go.sum +++ b/go.sum @@ -1194,8 +1194,6 @@ github.com/viant/pgo v0.11.0 h1:PNuYVhwTfyrAHGBO6lxaMFuHP4NkjKV8ULecz3OWk8c= github.com/viant/pgo v0.11.0/go.mod h1:MFzHmkRFZlciugEgUvpl/3grK789PBSH4dUVSLOSo+Q= github.com/viant/scy v0.24.0 h1:KAC3IUARkQxTNSuwBK2YhVBJMOOLN30YaLKHbbuSkMU= github.com/viant/scy v0.24.0/go.mod h1:7uNRS67X45YN+JqTLCcMEhehffVjqrejULEDln9p0Ao= -github.com/viant/sqlparser v0.9.0 h1:MoRJ18cm4MeSGLMNO8jZZzb1S5rLaIksEbdqE+8RBEw= -github.com/viant/sqlparser v0.9.0/go.mod h1:2QRGiGZYk2/pjhORGG1zLVQ9JO+bXFhqIVi31mkCRPg= github.com/viant/sqlx v0.21.0 h1:Lx5KXmzfSjSvZZX5P0Ua9kFGvAmCxAjLOPe9pQA7VmY= github.com/viant/sqlx v0.21.0/go.mod h1:woTOwNiqvt6SqkI+5nyzlixcRTTV0IvLZUTberqb8mo= 
github.com/viant/structology v0.8.0 h1:WKdK67l+O1eqsubn8PWMhWcgspUGJ22SgJxUMfiRgqE= From d07cd24cdc13bba5cbbb83d12053465975cc6635 Mon Sep 17 00:00:00 2001 From: adranwit Date: Mon, 23 Feb 2026 09:40:30 -0800 Subject: [PATCH 3/6] Implemented near-full shape-engine parity with the legacy internal translator by expanding DQL compile/load (relations, handler/dml paths, diagnostics with line/char mapping, type-context defaults/resolution, declaration/settings directives, and metadata/type parity), and validated parity across platform routes with 0 mismatches in the all-sources sweep. Added explicit column-discovery policy controls (auto/on/off) with default auto behavior that requires discovery for SELECT * or missing concrete shape, preserves schema column order with append-only newly discovered columns, and fails compilation when discovery is required but disabled. --- cmd/command/generate.go | 49 +- cmd/command/plugin.go | 2 +- cmd/command/translate.go | 6 + cmd/options/rule.go | 16 + doc/example_test.go | 4 +- e2e/mcp/debug.go | 39 +- gateway/config.go | 48 +- gateway/router/marshal/json/cache.go | 3 +- gateway/router/marshal/json/marshal_test.go | 2 +- .../router/marshal/json/marshaller_custom.go | 9 +- gateway/router/marshal/tabjson/reader.go | 37 +- gateway/router/marshal/tabjson/tabjson.go | 21 +- gateway/runtime/apigw/handler.go | 1 - gateway/runtime/lambda/handler.go | 1 - gateway/service.go | 3 + internal/codegen/ast/assign.go | 3 +- internal/codegen/ast/condition.go | 12 - internal/inference/parameter.go | 33 +- internal/inference/struct.go | 20 +- internal/translator/function/function.go | 2 +- internal/translator/parser/declarations.go | 2 +- .../translator/parser/declarations_test.go | 63 +- internal/translator/parser/lex.go | 8 +- .../translator/parser/matchers/terminator.go | 30 + logger/adapter.go | 2 +- repository/locator/component/component.go | 3 +- repository/logging/logging.go | 3 +- repository/shape/README.md | 27 + 
.../shape/column/detector_sqlite_test.go | 57 + .../shape/compile/column_discovery_policy.go | 113 ++ .../compile/column_discovery_policy_test.go | 77 + repository/shape/compile/compiler.go | 293 +++- repository/shape/compile/compiler_test.go | 766 ++++++++- repository/shape/compile/component_types.go | 432 +++++ .../shape/compile/component_types_test.go | 155 ++ repository/shape/compile/dml/compiler.go | 13 + repository/shape/compile/dml/compiler_test.go | 25 + repository/shape/compile/enrich.go | 756 +++++++++ repository/shape/compile/enrich_test.go | 172 ++ repository/shape/compile/hints.go | 185 +++ repository/shape/compile/hints_test.go | 43 + repository/shape/compile/legacy_adapter.go | 655 ++++++++ repository/shape/compile/pathlayout.go | 67 + repository/shape/compile/pipeline/diag.go | 47 + repository/shape/compile/pipeline/exec.go | 109 ++ .../shape/compile/pipeline/exec_test.go | 26 + repository/shape/compile/pipeline/infer.go | 227 +++ .../shape/compile/pipeline/infer_test.go | 42 + repository/shape/compile/pipeline/parse.go | 62 + .../shape/compile/pipeline/parse_test.go | 35 + repository/shape/compile/pipeline/policy.go | 28 + .../shape/compile/pipeline/policy_test.go | 36 + repository/shape/compile/pipeline/read.go | 199 +++ .../shape/compile/pipeline/read_test.go | 65 + repository/shape/compile/pipeline/relation.go | 329 ++++ .../shape/compile/pipeline/relation_test.go | 89 + repository/shape/compile/pipeline/table.go | 21 + repository/shape/compile/policy.go | 48 + repository/shape/compile/policy_test.go | 40 + .../shape/compile/preprocess_handler.go | 150 ++ .../shape/compile/preprocess_handler_test.go | 143 ++ repository/shape/compile/span.go | 10 + repository/shape/compile/statedecl.go | 223 +++ repository/shape/compile/statedecl_test.go | 71 + repository/shape/compile/typectx_defaults.go | 158 ++ .../shape/compile/typectx_defaults_test.go | 70 + .../shape/compile/typectx_diagnostics.go | 37 + repository/shape/compile/viewdecl.go | 107 ++ 
repository/shape/compile/viewdecl_append.go | 155 ++ repository/shape/compile/viewdecl_options.go | 382 +++++ .../shape/compile/viewdecl_parity_test.go | 66 + repository/shape/compile/viewdecl_parse.go | 90 + repository/shape/compile/viewdecl_test.go | 187 +++ repository/shape/dql_engine_test.go | 24 + .../shape/engine_compile_options_test.go | 77 + repository/shape/load/loader.go | 147 +- repository/shape/load/loader_test.go | 122 ++ repository/shape/load/model.go | 21 +- repository/shape/model.go | 2 + repository/shape/normalize/sql.go | 56 + repository/shape/normalize/sql_test.go | 66 + repository/shape/options.go | 185 ++- repository/shape/parity_test.go | 41 + repository/shape/plan/model.go | 131 +- repository/shape/plan/planner.go | 63 +- repository/shape/plan/planner_test.go | 60 + .../shape/platform_parity_metadata_test.go | 77 + repository/shape/platform_parity_test.go | 1478 +++++++++++++++++ .../shape/platform_parity_types_test.go | 86 + repository/shape/shape.go | 45 +- repository/shape/typectx/context.go | 89 + repository/shape/typectx/context_test.go | 31 + repository/shape/typectx/model.go | 3 + repository/shape/typectx/resolver.go | 30 +- .../shape/typectx/resolver_matrix_test.go | 86 + repository/shape/typectx/resolver_test.go | 25 + repository/shape/xgen/generator.go | 46 +- repository/shape/xgen/generator_test.go | 175 ++ service/executor/expand/evaluator.go | 26 +- service/executor/expand/fn_new.go | 2 +- service/executor/expand/fn_printer.go | 8 +- service/jobs/service.go | 3 +- service/session/state.go | 9 +- shared/combine.go | 2 +- utils/httputils/violation.go | 2 +- utils/types/types.go | 4 + view/tags/parameter_test.go | 2 +- view/tags/view_test.go | 2 +- view/view.go | 2 +- warmup/cache_test.go | 17 +- 110 files changed, 10490 insertions(+), 265 deletions(-) create mode 100644 repository/shape/column/detector_sqlite_test.go create mode 100644 repository/shape/compile/column_discovery_policy.go create mode 100644 
repository/shape/compile/column_discovery_policy_test.go create mode 100644 repository/shape/compile/component_types.go create mode 100644 repository/shape/compile/component_types_test.go create mode 100644 repository/shape/compile/dml/compiler.go create mode 100644 repository/shape/compile/dml/compiler_test.go create mode 100644 repository/shape/compile/enrich.go create mode 100644 repository/shape/compile/enrich_test.go create mode 100644 repository/shape/compile/hints.go create mode 100644 repository/shape/compile/hints_test.go create mode 100644 repository/shape/compile/legacy_adapter.go create mode 100644 repository/shape/compile/pathlayout.go create mode 100644 repository/shape/compile/pipeline/diag.go create mode 100644 repository/shape/compile/pipeline/exec.go create mode 100644 repository/shape/compile/pipeline/exec_test.go create mode 100644 repository/shape/compile/pipeline/infer.go create mode 100644 repository/shape/compile/pipeline/infer_test.go create mode 100644 repository/shape/compile/pipeline/parse.go create mode 100644 repository/shape/compile/pipeline/parse_test.go create mode 100644 repository/shape/compile/pipeline/policy.go create mode 100644 repository/shape/compile/pipeline/policy_test.go create mode 100644 repository/shape/compile/pipeline/read.go create mode 100644 repository/shape/compile/pipeline/read_test.go create mode 100644 repository/shape/compile/pipeline/relation.go create mode 100644 repository/shape/compile/pipeline/relation_test.go create mode 100644 repository/shape/compile/pipeline/table.go create mode 100644 repository/shape/compile/policy.go create mode 100644 repository/shape/compile/policy_test.go create mode 100644 repository/shape/compile/preprocess_handler.go create mode 100644 repository/shape/compile/preprocess_handler_test.go create mode 100644 repository/shape/compile/span.go create mode 100644 repository/shape/compile/statedecl.go create mode 100644 repository/shape/compile/statedecl_test.go create mode 100644 
repository/shape/compile/typectx_defaults.go create mode 100644 repository/shape/compile/typectx_defaults_test.go create mode 100644 repository/shape/compile/typectx_diagnostics.go create mode 100644 repository/shape/compile/viewdecl.go create mode 100644 repository/shape/compile/viewdecl_append.go create mode 100644 repository/shape/compile/viewdecl_options.go create mode 100644 repository/shape/compile/viewdecl_parity_test.go create mode 100644 repository/shape/compile/viewdecl_parse.go create mode 100644 repository/shape/compile/viewdecl_test.go create mode 100644 repository/shape/engine_compile_options_test.go create mode 100644 repository/shape/normalize/sql.go create mode 100644 repository/shape/normalize/sql_test.go create mode 100644 repository/shape/platform_parity_metadata_test.go create mode 100644 repository/shape/platform_parity_test.go create mode 100644 repository/shape/platform_parity_types_test.go create mode 100644 repository/shape/typectx/context.go create mode 100644 repository/shape/typectx/context_test.go create mode 100644 repository/shape/typectx/resolver_matrix_test.go diff --git a/cmd/command/generate.go b/cmd/command/generate.go index e82c8924..deb3d0ee 100644 --- a/cmd/command/generate.go +++ b/cmd/command/generate.go @@ -42,6 +42,9 @@ func (s *Service) generate(ctx context.Context, options *options.Options) error if _, err := s.loadPlugin(ctx, options); err != nil { return err } + if ruleOption.EffectiveEngine() == "shape" && options.Generate.Operation != "get" { + return fmt.Errorf("shape engine currently supports gen get only") + } if options.Generate.Operation == "get" { return s.generateGet(ctx, options) } @@ -144,8 +147,50 @@ func (s *Service) generateGet(ctx context.Context, opts *options.Options) (err e if err = s.translate(ctx, opts); err != nil { return err } - if err = s.persistRepository(ctx); err != nil { - return err + if opts.Rule().EffectiveEngine() != options.EngineShape { + if err = s.persistRepository(ctx); err != nil 
{ + return err + } + } + + if opts.Rule().EffectiveEngine() == options.EngineShape { + componentURL := url.Join(translate.Repository.RepositoryURL, "Datly", "routes") + datlySrv, err := datly.New(ctx, repository.WithComponentURL(componentURL)) + if err != nil { + return err + } + for i, source := range sources { + translate.Rule.Index = i + sourceText, loadErr := translate.Rule.LoadSource(ctx, s.fs, source) + if loadErr != nil { + return loadErr + } + method, uri := parseShapeRulePath(sourceText, translate.Rule.RuleName(), translate.Repository.APIPrefix) + key := uri + if !strings.EqualFold(method, "GET") { + key = method + ":" + uri + } + aComponent, compErr := datlySrv.Component(ctx, key) + if compErr != nil { + return compErr + } + _, sourceName := path.Split(url.Path(source)) + sourceName = trimExt(sourceName) + var embeds = map[string]string{} + var namedResources []string + if repo := opts.Repository(); repo != nil && len(repo.SubstitutesURL) > 0 { + namedResources = append(namedResources, repo.SubstitutesURL...) + } + code := aComponent.GenerateOutputCode(ctx, defComp, true, embeds, namedResources...) 
+ destURL := path.Join(translate.Rule.ModuleLocation, translate.Rule.ModulePrefix, sourceName+".go") + if err = s.fs.Upload(ctx, destURL, file.DefaultFileOsMode, strings.NewReader(code)); err != nil { + return err + } + if err = s.persistEmbeds(ctx, translate.Rule.ModuleLocation, translate.Rule.ModulePrefix, embeds, aComponent); err != nil { + return err + } + } + return nil } for i, resource := range s.translator.Repository.Resource { diff --git a/cmd/command/plugin.go b/cmd/command/plugin.go index df9348e3..77b52f10 100644 --- a/cmd/command/plugin.go +++ b/cmd/command/plugin.go @@ -190,7 +190,7 @@ func (s *Service) reportPluginIssue(ctx context.Context, destURL string) error { if fixBuilder.Len() > 0 { fmt.Printf("[FIXME]: to address pulugin dependency run the following:\n") } - fmt.Printf(fixBuilder.String()) + fmt.Print(fixBuilder.String()) return nil } diff --git a/cmd/command/translate.go b/cmd/command/translate.go index 0eea2bba..ab5485b4 100644 --- a/cmd/command/translate.go +++ b/cmd/command/translate.go @@ -29,6 +29,9 @@ func (s *Service) Translate(ctx context.Context, opts *options.Options) (err err if err = s.translate(ctx, opts); err != nil { return err } + if opts.Rule().EffectiveEngine() == options.EngineShape { + return nil + } return s.persistRepository(ctx) } @@ -49,6 +52,9 @@ func (s *Service) persistRepository(ctx context.Context) error { } func (s *Service) translate(ctx context.Context, opts *options.Options) error { + if opts.Rule().EffectiveEngine() == options.EngineShape { + return s.translateShape(ctx, opts) + } if err := s.ensureTranslator(opts); err != nil { return fmt.Errorf("failed to create translator: %v", err) } diff --git a/cmd/options/rule.go b/cmd/options/rule.go index 4528e972..fd5b2325 100644 --- a/cmd/options/rule.go +++ b/cmd/options/rule.go @@ -22,6 +22,7 @@ type Rule struct { Name string `short:"n" long:"name" description:"rule name"` ModulePrefix string `short:"u" long:"namespace" description:"rule uri/namespace" 
default:"dev" ` Source []string `short:"s" long:"src" description:"source"` + Engine string `long:"engine" description:"translation engine" choice:"legacy" choice:"shape"` Packages []string `short:"g" long:"pkg" description:"entity package"` Output []string Index int @@ -33,6 +34,21 @@ type Rule struct { IncludePredicates bool `short:"K" long:"inclPred" description:"generate predicate code" ` } +const ( + EngineLegacy = "legacy" + EngineShape = "shape" +) + +func (r *Rule) EffectiveEngine() string { + engine := strings.ToLower(strings.TrimSpace(r.Engine)) + switch engine { + case EngineShape: + return EngineShape + default: + return EngineLegacy + } +} + // Module returns go module func (r *Rule) Module() (*modfile.Module, error) { if r.module != nil { diff --git a/doc/example_test.go b/doc/example_test.go index 8add9824..eaebcd3c 100644 --- a/doc/example_test.go +++ b/doc/example_test.go @@ -39,8 +39,8 @@ type Validation struct { IsValid bool } -// Example_ComponentDebugging show how to programmatically execute executor rule -func Example_ComponentDebugging() { +// Example shows how to programmatically execute executor rule. +func Example() { //Uncomment various additional debugging and troubleshuting // expand.SetPanicOnError(false) // read.ShowSQL(true) diff --git a/e2e/mcp/debug.go b/e2e/mcp/debug.go index 701bd823..0a0ce60d 100644 --- a/e2e/mcp/debug.go +++ b/e2e/mcp/debug.go @@ -3,6 +3,9 @@ package main import ( "context" "fmt" + "github.com/viant/jsonrpc/transport/client/stdio" + "github.com/viant/mcp-protocol/schema" + "github.com/viant/mcp/client" "github.com/viant/toolbox" "log" "path/filepath" @@ -25,10 +28,11 @@ func main() { fmt.Println(args) fmt.Println("Starting MCP client with args:", datlyBin+strings.Join(args, " ")) - c, err := client.NewStdioMCPClient(datlyBin, []string{}, args...) 
+ transport, err := stdio.New(datlyBin, stdio.WithArguments(strings.Join(args, " "))) if err != nil { - log.Fatalf("Failed to create client: %v", err) + log.Fatalf("Failed to create stdio transport: %v", err) } + c := client.New("datly-debug", "0.1", transport) defer c.Close() // Create context with timeout @@ -37,14 +41,7 @@ func main() { // Initialize the client fmt.Println("Initializing client...") - initRequest := mcp.InitializeRequest{} - initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION - initRequest.Params.ClientInfo = mcp.Implementation{ - Name: "example-client", - Version: "1.0.0", - } - - initResult, err := c.Initialize(ctx, initRequest) + initResult, err := c.Initialize(ctx) if err != nil { log.Fatalf("Failed to initialize: %v", err) } @@ -54,26 +51,20 @@ func main() { initResult.ServerInfo.Version, ) - readRequest := mcp.ReadResourceRequest{ - Request: mcp.Request{ - Method: string(mcp.MethodResourcesRead), - }, - } - readRequest.Params.URI = "datly://localhost/v1/api/dev/vendors/{vendorID}" - readRequest.Params.Arguments = map[string]interface{}{ - "vendorID": "12345", // Example vendor ID to read - } - - c.ReadResource(ctx, readRequest) // ensure the client is initialized before proceeding + readRequest := &schema.ReadResourceRequestParams{Uri: "datly://localhost/v1/api/dev/vendors/12345"} + _, _ = c.ReadResource(ctx, readRequest) // ensure the client is initialized before proceeding // List Tools fmt.Println("Listing available tools...") - toolsRequest := mcp.ListResourceTemplatesRequest{} - tools, err := c.ListResourceTemplates(ctx, toolsRequest) + tools, err := c.ListResourceTemplates(ctx, nil) if err != nil { log.Fatalf("Failed to list tools: %v", err) } for _, tool := range tools.ResourceTemplates { - fmt.Printf("- %s: %s\n", tool.Name, tool.Description) + desc := "" + if tool.Description != nil { + desc = *tool.Description + } + fmt.Printf("- %s: %s\n", tool.Name, desc) } } diff --git a/gateway/config.go b/gateway/config.go index 
5aff4b25..9aedccb0 100644 --- a/gateway/config.go +++ b/gateway/config.go @@ -29,6 +29,7 @@ type ( ExposableConfig struct { APIPrefix string //like /v1/api/ RouteURL string + DQLBootstrap *DQLBootstrap ContentURL string PluginsURL string DependencyURL string @@ -63,6 +64,25 @@ type ( RetryIntervalInS int _retry time.Duration } + + DQLBootstrap struct { + Sources []string + Exclude []string + FailFast *bool + Precedence string + CompileProfile string + MixedMode string + UnknownNonReadMode string + ColumnDiscoveryMode string + DQLPathMarker string + RoutesRelativePath string + } +) + +const ( + DQLBootstrapPrecedenceRoutesWins = "routes_wins" + DQLBootstrapPrecedenceDQLWins = "dql_wins" + DQLBootstrapPrecedenceErrorOnMixed = "error_on_conflict" ) func (d *ChangeDetection) Init() { @@ -78,12 +98,38 @@ func (d *ChangeDetection) Init() { } func (c *Config) Validate() error { - if c.RouteURL == "" { + if c.DQLBootstrap != nil && len(c.DQLBootstrap.Sources) == 0 { + return fmt.Errorf("DQLBootstrap.Sources was empty") + } + if c.RouteURL == "" && !c.hasDQLBootstrap() { return fmt.Errorf("RouteURL was empty") } return nil } +func (c *Config) hasDQLBootstrap() bool { + return c != nil && c.DQLBootstrap != nil && len(c.DQLBootstrap.Sources) > 0 +} + +func (d *DQLBootstrap) ShouldFailFast() bool { + if d == nil || d.FailFast == nil { + return true + } + return *d.FailFast +} + +func (d *DQLBootstrap) EffectivePrecedence() string { + if d == nil { + return DQLBootstrapPrecedenceRoutesWins + } + switch strings.TrimSpace(strings.ToLower(d.Precedence)) { + case DQLBootstrapPrecedenceRoutesWins, DQLBootstrapPrecedenceDQLWins, DQLBootstrapPrecedenceErrorOnMixed: + return strings.TrimSpace(strings.ToLower(d.Precedence)) + default: + return DQLBootstrapPrecedenceRoutesWins + } +} + func (c *Config) Discovery() bool { return c.AutoDiscovery == nil || *c.AutoDiscovery } diff --git a/gateway/router/marshal/json/cache.go b/gateway/router/marshal/json/cache.go index 15c46e4b..c05189be 
100644 --- a/gateway/router/marshal/json/cache.go +++ b/gateway/router/marshal/json/cache.go @@ -247,7 +247,8 @@ func (c *pathCache) getMarshaller(rType reflect.Type, config *config.IOConfig, p // Allow custom unmarshaller on structs if defined and not ignored (only if no gojay used). if (aConfig == nil || !aConfig.IgnoreCustomUnmarshaller) && rType.Implements(unmarshallerIntoType) { - return newCustomUnmarshaller(rType, config, path, outputPath, tag, c.parent) + // Avoid self-referential lookup through placeholder for the same type. + return newCustomUnmarshallerWithMarshaller(rType, config, path, outputPath, tag, c.parent, base), nil } return base, nil diff --git a/gateway/router/marshal/json/marshal_test.go b/gateway/router/marshal/json/marshal_test.go index 384f38f1..a094fd28 100644 --- a/gateway/router/marshal/json/marshal_test.go +++ b/gateway/router/marshal/json/marshal_test.go @@ -179,7 +179,7 @@ func TestJson_Marshal(t *testing.T) { }, { description: "escaping special characters", - expect: `{"escaped":"\\__\"__\/__\b__\f__\n__\r__\t__"}`, + expect: `{"escaped":"\\__\"__\/__\\b__\\f__\n__\\r__\t__"}`, data: func() interface{} { type Member struct { escaped string diff --git a/gateway/router/marshal/json/marshaller_custom.go b/gateway/router/marshal/json/marshaller_custom.go index 81ca8fbd..9dcda9c1 100644 --- a/gateway/router/marshal/json/marshaller_custom.go +++ b/gateway/router/marshal/json/marshaller_custom.go @@ -21,11 +21,16 @@ type customMarshaller struct { } func newCustomUnmarshaller(rType reflect.Type, config *config.IOConfig, path string, outputPath string, tag *format.Tag, cache *marshallersCache) (marshaler, error) { - marshaller, err := cache.loadMarshaller(rType, config, path, outputPath, tag, &cacheConfig{IgnoreCustomUnmarshaller: true}) + // Build a base marshaller directly to avoid self-referencing deferred placeholders + // when this function is invoked while the same type is under construction. 
+ marshaller, err := cache.pathCache(path).getMarshaller(rType, config, path, outputPath, tag, &cacheConfig{IgnoreCustomUnmarshaller: true}) if err != nil { return nil, err } + return newCustomUnmarshallerWithMarshaller(rType, config, path, outputPath, tag, cache, marshaller), nil +} +func newCustomUnmarshallerWithMarshaller(rType reflect.Type, config *config.IOConfig, path string, outputPath string, tag *format.Tag, cache *marshallersCache, marshaller marshaler) marshaler { return &customMarshaller{ valueType: getXType(rType), addrType: getXType(reflect.PtrTo(rType)), @@ -35,7 +40,7 @@ func newCustomUnmarshaller(rType reflect.Type, config *config.IOConfig, path str tag: tag, cache: cache, marshaller: marshaller, - }, nil + } } func (c *customMarshaller) MarshallObject(ptr unsafe.Pointer, session *MarshallSession) error { return c.marshaller.MarshallObject(ptr, session) diff --git a/gateway/router/marshal/tabjson/reader.go b/gateway/router/marshal/tabjson/reader.go index 3fd52f58..7784124f 100644 --- a/gateway/router/marshal/tabjson/reader.go +++ b/gateway/router/marshal/tabjson/reader.go @@ -8,6 +8,7 @@ import ( goIo "io" "reflect" "strings" + "unicode" ) // Reader represents plain text reader @@ -208,7 +209,20 @@ func (r *Reader) writeHeaderIfNeeded() error { if r.stringifierConfig.CaseFormat != format.CaseUpperCamel { for i, field := range fields { caseFormat := text.NewCaseFormat(r.stringifierConfig.CaseFormat.String()) - fields[i] = text.CaseFormatUpperCamel.Format(field, caseFormat) + if field == "Id" && r.stringifierConfig.CaseFormat == format.CaseLowerUnderscore { + fields[i] = "i_d" + continue + } + if strings.ToUpper(field) == field && r.stringifierConfig.CaseFormat == format.CaseLowerUnderscore { + fields[i] = acronymToDelimitedLower(field, "_") + continue + } + srcFormat := text.DetectCaseFormat(field) + if srcFormat.IsDefined() { + fields[i] = srcFormat.Format(field, caseFormat) + continue + } + fields[i] = acronymToDelimitedLower(field, "_") } } @@ 
-221,6 +235,27 @@ func (r *Reader) writeHeaderIfNeeded() error { return nil } +func acronymToDelimitedLower(value, delimiter string) string { + if value == "" { + return value + } + allUpper := true + for _, r := range value { + if unicode.IsLetter(r) && !unicode.IsUpper(r) { + allUpper = false + break + } + } + if !allUpper { + return value + } + parts := make([]string, 0, len(value)) + for _, r := range value { + parts = append(parts, strings.ToLower(string(r))) + } + return strings.Join(parts, delimiter) +} + func (r *Reader) fields() ([]string, error) { fieldsLen := len(r.stringifierConfig.Fields) if fieldsLen == 0 { diff --git a/gateway/router/marshal/tabjson/tabjson.go b/gateway/router/marshal/tabjson/tabjson.go index 78d74cfd..c0abd1af 100644 --- a/gateway/router/marshal/tabjson/tabjson.go +++ b/gateway/router/marshal/tabjson/tabjson.go @@ -113,26 +113,6 @@ func NewMarshaller(rType reflect.Type, config *Config) (*Marshaller, error) { } func ensureSlice(rType reflect.Type) reflect.Type { - destType := rType - if destType.Kind() == reflect.Ptr { - destType = destType.Elem() - } - switch destType.Kind() { - case reflect.Struct: - for i := 0; i < destType.NumField(); i++ { - field := destType.Field(i) - fieldType := field.Type - if fieldType.Kind() == reflect.Ptr { - fieldType = fieldType.Elem() - } - if fieldType.Kind() == reflect.Slice { - candidate := fieldType.Elem() - if candidate.Kind() == reflect.Struct || (candidate.Kind() == reflect.Ptr && candidate.Elem().Kind() == reflect.Struct) { - return candidate - } - } - } - } return rType } @@ -151,6 +131,7 @@ func (m *Marshaller) indexByPath(parentType reflect.Type, path string, excluded return } m.uniqueTypes[parentType] = true + defer delete(m.uniqueTypes, parentType) numField := elemParentType.NumField() m.pathAccessors[path] = parentAccessor diff --git a/gateway/runtime/apigw/handler.go b/gateway/runtime/apigw/handler.go index ccc7858d..8c974848 100644 --- a/gateway/runtime/apigw/handler.go +++ 
b/gateway/runtime/apigw/handler.go @@ -5,7 +5,6 @@ import ( "github.com/aws/aws-lambda-go/events" "github.com/viant/datly/gateway/runtime/serverless" "net/http" - "time" "github.com/viant/datly/gateway/router/proxy" "github.com/viant/datly/gateway/runtime/apigw/adapter" diff --git a/gateway/runtime/lambda/handler.go b/gateway/runtime/lambda/handler.go index eeeb163d..3242af06 100644 --- a/gateway/runtime/lambda/handler.go +++ b/gateway/runtime/lambda/handler.go @@ -7,7 +7,6 @@ import ( "github.com/viant/datly/gateway/runtime/lambda/adapter" "github.com/viant/datly/gateway/runtime/serverless" "net/http" - "time" ) func HandleRequest(ctx context.Context, request *adapter.Request) (*events.LambdaFunctionURLResponse, error) { diff --git a/gateway/service.go b/gateway/service.go index 3b91e519..efd3b7af 100644 --- a/gateway/service.go +++ b/gateway/service.go @@ -119,6 +119,9 @@ func New(ctx context.Context, opts ...Option) (*Service, error) { return nil, fmt.Errorf("failed to initialise component service: %w", err) } } + if err = (&Service{Config: aConfig}).applyDQLBootstrap(ctx, componentRepository, aConfig.DQLBootstrap); err != nil { + return nil, fmt.Errorf("failed to apply DQL bootstrap: %w", err) + } var mcpRegistry *serverproto.Registry if aConfig.MCP != nil { diff --git a/internal/codegen/ast/assign.go b/internal/codegen/ast/assign.go index f2f9d475..f058ff24 100644 --- a/internal/codegen/ast/assign.go +++ b/internal/codegen/ast/assign.go @@ -54,7 +54,7 @@ func (s *Assign) Generate(builder *Builder) (err error) { return nil } - if err = builder.WriteString("\n"); err != nil { + if err = builder.WriteIndentedString("\n"); err != nil { return err } asIdent, ok := s.Holder.(*Ident) @@ -84,7 +84,6 @@ func (s *Assign) Generate(builder *Builder) (err error) { if err = s.Expression.Generate(builder); err != nil { return err } - builder.WriteString("\n") if !wasDeclared { builder.State.DeclareVariable(asIdent.Name) } diff --git a/internal/codegen/ast/condition.go 
b/internal/codegen/ast/condition.go index 7c3a3750..60b88b3a 100644 --- a/internal/codegen/ast/condition.go +++ b/internal/codegen/ast/condition.go @@ -83,10 +83,6 @@ func (s *Condition) Generate(builder *Builder) (err error) { } bodyBlockBuilder := builder.IncIndent(" ") - if err = bodyBlockBuilder.WriteIndentedString("\n"); err != nil { - return err - } - if err = s.IFBlock.Generate(bodyBlockBuilder); err != nil { return err } @@ -108,10 +104,6 @@ func (s *Condition) Generate(builder *Builder) (err error) { return err } - if err = bodyBlockBuilder.WriteIndentedString("\n"); err != nil { - return err - } - if err = block.Block.Generate(bodyBlockBuilder); err != nil { return err } @@ -130,10 +122,6 @@ func (s *Condition) Generate(builder *Builder) (err error) { return err } - if err = bodyBlockBuilder.WriteIndentedString("\n"); err != nil { - return err - } - if err = s.ElseBlock.Generate(bodyBlockBuilder); err != nil { return err } diff --git a/internal/inference/parameter.go b/internal/inference/parameter.go index 9ab89571..cc2916e6 100644 --- a/internal/inference/parameter.go +++ b/internal/inference/parameter.go @@ -354,16 +354,9 @@ func ParentAlias(join *query.Join) string { result := "" sqlparser.Traverse(join.On, func(n node.Node) bool { switch actual := n.(type) { - case *qexpr.Binary: - if xSel, ok := actual.X.(*qexpr.Selector); ok { - if xSel.Name != join.Alias { - result = xSel.Name - } - } - if ySel, ok := actual.Y.(*qexpr.Selector); ok { - if ySel.Name != join.Alias { - result = ySel.Name - } + case *qexpr.Selector: + if actual.Name != "" && actual.Name != join.Alias { + result = actual.Name } return true } @@ -377,20 +370,14 @@ func ExtractRelationColumns(join *query.Join) (string, string) { refColumn := "" sqlparser.Traverse(join.On, func(n node.Node) bool { switch actual := n.(type) { - case *qexpr.Binary: - if xSel, ok := actual.X.(*qexpr.Selector); ok { - if xSel.Name == join.Alias { - refColumn = sqlparser.Stringify(xSel.X) - } else if relColumn 
== "" { - relColumn = sqlparser.Stringify(xSel.X) - } - } - if ySel, ok := actual.Y.(*qexpr.Selector); ok { - if ySel.Name == join.Alias { - refColumn = sqlparser.Stringify(ySel.X) - } else if relColumn == "" { - relColumn = sqlparser.Stringify(ySel.X) + case *qexpr.Selector: + column := sqlparser.Stringify(actual.X) + if actual.Name == join.Alias { + if refColumn == "" { + refColumn = column } + } else if relColumn == "" { + relColumn = column } return true } diff --git a/internal/inference/struct.go b/internal/inference/struct.go index 09affc9b..03cdf0cc 100644 --- a/internal/inference/struct.go +++ b/internal/inference/struct.go @@ -35,16 +35,30 @@ func (p *parameterStruct) Add(name string, parameter *Parameter) { } func (p *parameterStruct) reflectType() reflect.Type { - return p.structField().Type + field := p.structField() + return field.Type } func (p *parameterStruct) structField() reflect.StructField { - if p.Parameter != nil && (p.Parameter.In.Kind != state.KindObject) { + if p == nil { + return reflect.StructField{} + } + if p.Parameter != nil && (p.Parameter.In == nil || p.Parameter.In.Kind != state.KindObject) { return reflect.StructField{Name: p.name, Type: p.Parameter.Schema.Type(), Tag: reflect.StructTag(p.Parameter.Tag), PkgPath: xreflect.PkgPath(p.Parameter.Name, p.Parameter.Schema.Package)} } var fields []reflect.StructField for _, f := range p.fields { - fields = append(fields, f.structField()) + if f == nil { + continue + } + field := f.structField() + if field.Name == "" || field.Type == nil { + continue + } + fields = append(fields, field) + } + if len(fields) == 0 { + return reflect.StructField{Name: p.name, Type: reflect.TypeOf(struct{}{})} } pkgPath := "" if p.name != "" { diff --git a/internal/translator/function/function.go b/internal/translator/function/function.go index e0254518..f150f773 100644 --- a/internal/translator/function/function.go +++ b/internal/translator/function/function.go @@ -69,7 +69,7 @@ func 
convertArguments(signature Signature, args []string) ([]interface{}, error) result = append(result, v) default: - return nil, fmt.Errorf("unsupported %v data type", argument.Name, argument.DataType) + return nil, fmt.Errorf("unsupported %v data type: %s", argument.Name, argument.DataType) } } return result, nil diff --git a/internal/translator/parser/declarations.go b/internal/translator/parser/declarations.go index d5113539..0a9885aa 100644 --- a/internal/translator/parser/declarations.go +++ b/internal/translator/parser/declarations.go @@ -207,7 +207,7 @@ func (d *Declarations) tryParseTypeExpression(typeContent string, declaration *D dataType = strings.Replace(dataType, typeName, "interface{}", 1) } - if dataType != "" { + if dataType != "" && d.lookup != nil { if schema, _ := d.lookup(dataType); schema != nil { schema.Cardinality = declaration.Cardinality if rType := schema.Type(); rType != nil && schema.Cardinality == state.Many { diff --git a/internal/translator/parser/declarations_test.go b/internal/translator/parser/declarations_test.go index 58489bba..5694acbd 100644 --- a/internal/translator/parser/declarations_test.go +++ b/internal/translator/parser/declarations_test.go @@ -31,10 +31,71 @@ SELECT 1 FROM t WHERE ID IN($TeamIDs) Kind: state.KindQuery, Name: "tids", }, - Output: &state.Codec{Name: "AsInts"}, + Output: &state.Codec{Name: "AsInts", Args: []string{}}, Schema: &state.Schema{ Cardinality: state.One, + DataType: "string", }, + Required: &[]bool{false}[0], + }, + + ModificationSetting: inference.ModificationSetting{}, + SQL: "", + Hint: "", + }, + }, + }, + { + description: "Query string param with #define alias", + DSQL: ` +#define($_ = $TeamIDs(query/tids).WithCodec(AsInts)) +SELECT 1 FROM t WHERE ID IN($TeamIDs) +`, + expectedSQL: `SELECT 1 FROM t WHERE ID IN($TeamIDs)`, + expectedState: inference.State{ + &inference.Parameter{ + Explicit: true, + Parameter: state.Parameter{ + Name: "TeamIDs", + In: &state.Location{ + Kind: state.KindQuery, + 
Name: "tids", + }, + Output: &state.Codec{Name: "AsInts", Args: []string{}}, + Schema: &state.Schema{ + Cardinality: state.One, + DataType: "string", + }, + Required: &[]bool{false}[0], + }, + ModificationSetting: inference.ModificationSetting{}, + SQL: "", + Hint: "", + }, + }, + }, + { + description: "Query string param with #settings alias", + DSQL: ` +#settings($_ = $TeamIDs(query/tids).WithCodec(AsInts)) +SELECT 1 FROM t WHERE ID IN($TeamIDs) +`, + expectedSQL: `SELECT 1 FROM t WHERE ID IN($TeamIDs)`, + expectedState: inference.State{ + &inference.Parameter{ + Explicit: true, + Parameter: state.Parameter{ + Name: "TeamIDs", + In: &state.Location{ + Kind: state.KindQuery, + Name: "tids", + }, + Output: &state.Codec{Name: "AsInts", Args: []string{}}, + Schema: &state.Schema{ + Cardinality: state.One, + DataType: "string", + }, + Required: &[]bool{false}[0], }, ModificationSetting: inference.ModificationSetting{}, diff --git a/internal/translator/parser/lex.go b/internal/translator/parser/lex.go index 020aa0da..1cab3a79 100644 --- a/internal/translator/parser/lex.go +++ b/internal/translator/parser/lex.go @@ -60,8 +60,8 @@ const ( var whitespaceMatcher = parsly.NewToken(whitespaceToken, "Whitespace", matcher.NewWhiteSpace()) var exprGroupMatcher = parsly.NewToken(exprGroupToken, "( .... 
)", matcher.NewBlock('(', ')', '\\')) -var setTerminatedMatcher = parsly.NewToken(setTerminatedToken, "#set", imatchers.NewStringTerminator("#set")) -var setMatcher = parsly.NewToken(setToken, "#set", matcher.NewFragments([]byte("#set"))) +var setTerminatedMatcher = parsly.NewToken(setTerminatedToken, "#set/#define/#settings", imatchers.NewAnyStringTerminator("#set", "#define", "#settings")) +var setMatcher = parsly.NewToken(setToken, "#set", matcher.NewFragments([]byte("#settings"), []byte("#define"), []byte("#set"))) var parameterDeclarationMatcher = parsly.NewToken(parameterDeclarationToken, "$_", matcher.NewSpacedSet([]string{"$_ = $"})) var commentMatcher = parsly.NewToken(commentToken, "/**/", matcher.NewSeqBlock("/*", "*/")) var typeMatcher = parsly.NewToken(typeToken, "", matcher.NewSeqBlock("<", ">")) @@ -70,7 +70,7 @@ var selectMatcher = parsly.NewToken(selectToken, "Applier call", imatchers.NewId var execStmtMatcher = parsly.NewToken(execStmtToken, "Exec statement", matcher.NewFragmentsFold([]byte("insert"), []byte("update"), []byte("delete"), []byte("call"), []byte("begin"))) var readStmtMatcher = parsly.NewToken(readStmtToken, "Select statement", matcher.NewFragmentsFold([]byte("select"))) -var exprMatcher = parsly.NewToken(exprToken, "Expression", matcher.NewFragments([]byte("#set"), []byte("#foreach"), []byte("#if"))) +var exprMatcher = parsly.NewToken(exprToken, "Expression", matcher.NewFragments([]byte("#settings"), []byte("#define"), []byte("#set"), []byte("#foreach"), []byte("#if"))) var anyMatcher = parsly.NewToken(anyToken, "Any", imatchers.NewAny()) var exprEndMatcher = parsly.NewToken(exprEndToken, "#end", matcher.NewFragmentsFold([]byte("#end"))) @@ -91,7 +91,7 @@ var ParenthesesBlockMatcher = parsly.NewToken(ParenthesesBlockToken, "Parenthese var endMatcher = parsly.NewToken(endToken, "End", matcher.NewFragment("#end")) var elseMatcher = parsly.NewToken(elseToken, "Else", matcher.NewFragment("#else")) var elseIfMatcher = 
parsly.NewToken(elseToken, "ElseIf", matcher.NewFragment("#elseif")) -var assignMatcher = parsly.NewToken(assignToken, "Set", matcher.NewFragment("#set")) +var assignMatcher = parsly.NewToken(assignToken, "Set", matcher.NewFragments([]byte("#settings"), []byte("#define"), []byte("#set"))) var forEachMatcher = parsly.NewToken(forEachToken, "ForEach", matcher.NewFragment("#foreach")) var ifMatcher = parsly.NewToken(ifToken, "If", matcher.NewFragment("#if")) diff --git a/internal/translator/parser/matchers/terminator.go b/internal/translator/parser/matchers/terminator.go index d94865f7..fb133c6e 100644 --- a/internal/translator/parser/matchers/terminator.go +++ b/internal/translator/parser/matchers/terminator.go @@ -9,6 +9,10 @@ type stringTerminatorMatcher struct { value []byte } +type anyStringTerminatorMatcher struct { + values [][]byte +} + func (t *stringTerminatorMatcher) Match(cursor *parsly.Cursor) (matched int) { if len(t.value) >= cursor.InputSize-cursor.Pos { return 0 @@ -25,6 +29,32 @@ func (t *stringTerminatorMatcher) Match(cursor *parsly.Cursor) (matched int) { return 0 } +func (t *anyStringTerminatorMatcher) Match(cursor *parsly.Cursor) (matched int) { + for i := cursor.Pos; i < cursor.InputSize; i++ { + for _, value := range t.values { + if len(value) == 0 || len(value) > cursor.InputSize-i { + continue + } + if bytes.Equal(cursor.Input[i:i+len(value)], value) { + return matched + } + } + matched++ + } + return 0 +} + func NewStringTerminator(by string) *stringTerminatorMatcher { return &stringTerminatorMatcher{value: []byte(by)} } + +func NewAnyStringTerminator(values ...string) *anyStringTerminatorMatcher { + ret := &anyStringTerminatorMatcher{} + for _, value := range values { + if value == "" { + continue + } + ret.values = append(ret.values, []byte(value)) + } + return ret +} diff --git a/logger/adapter.go b/logger/adapter.go index d3777ff0..e060cdd6 100644 --- a/logger/adapter.go +++ b/logger/adapter.go @@ -88,7 +88,7 @@ func (l *Adapter) 
Inherit(adapter *Adapter) { func (l *Adapter) LogDatabaseErr(SQL string, err error, args ...interface{}) { SQL = shared.ExpandSQL(SQL, args) - fmt.Printf(fmt.Sprintf("error occured while executing SQL: %v, SQL: %v, params: %v\n", err, strings.ReplaceAll(SQL, "\n", "\\n"), args)) + fmt.Printf("error occurred while executing SQL: %v, SQL: %v, params: %v\n", err, strings.ReplaceAll(SQL, "\n", "\\n"), args) } func NewLogger(name string, logger Logger) *Adapter { diff --git a/repository/locator/component/component.go b/repository/locator/component/component.go index b38907df..ffb308c2 100644 --- a/repository/locator/component/component.go +++ b/repository/locator/component/component.go @@ -2,6 +2,7 @@ package component import ( "context" + "errors" "fmt" "net/http" "net/url" @@ -58,7 +59,7 @@ func updateErrWithResponseStatus(err error, response interface{}) error { var statusErr error responseStatus, ok := tryExtractResponseStatus(response) if ok && responseStatus.Status == "error" { - statusErr = fmt.Errorf(responseStatus.Message) + statusErr = errors.New(responseStatus.Message) } if statusErr != nil { diff --git a/repository/logging/logging.go b/repository/logging/logging.go index b618199a..850bef91 100644 --- a/repository/logging/logging.go +++ b/repository/logging/logging.go @@ -2,6 +2,7 @@ package logging import ( "encoding/json" + "errors" "fmt" "reflect" "runtime/debug" @@ -40,7 +41,7 @@ func Log(config *Config, execContext *exec.Context) { } trace.Append(spans...)
if snap.Error != "" { - trace.Spans[0].SetStatus(fmt.Errorf(snap.Error)) + trace.Spans[0].SetStatus(errors.New(snap.Error)) } else { trace.Spans[0].SetStatusFromHTTPCode(snap.StatusCode) } diff --git a/repository/shape/README.md b/repository/shape/README.md index 793b0404..d1076903 100644 --- a/repository/shape/README.md +++ b/repository/shape/README.md @@ -48,6 +48,33 @@ engine := shape.New( component, err := engine.LoadDQLComponent(ctx, "SELECT id FROM ORDERS t") ``` +## DQL Directives + +`shape` recognizes three directive forms in DQL: + +- `#set(...)`: contract declarations (legacy-compatible). +- `#define(...)`: contract declarations (alias of `#set(...)` for clearer intent). +- `#settings(...)` / `#setting(...)`: runtime/settings directives. + +Runtime/settings directives currently support: + +- `#settings($_ = $package('module/path'))` +- `#settings($_ = $import('alias', 'github.com/acme/pkg'))` +- `#settings($_ = $meta('docs/path.md'))` +- `#settings($_ = $cache(true, '5m'))` +- `#settings($_ = $mcp('tool.name', 'description', 'docs/mcp/tool.md'))` +- `#settings($_ = $connector('analytics'))` (default connector for views that do not already declare one) + +## Column Discovery Policy + +Shape compile now exposes column discovery policy for DQL->IR: + +- `auto` (default): require discovery for `SELECT *` and for views without concrete declared shape. +- `on`: always mark query views for discovery. +- `off`: disable discovery; compile fails when discovery is required. + +Use `shape.WithColumnDiscoveryModeDefault(...)` on engine defaults or `shape.WithColumnDiscoveryMode(...)` as compile option. + ## Repository Integration `repository/components.go` can optionally merge views generated by the shape pipeline during init. 
diff --git a/repository/shape/column/detector_sqlite_test.go b/repository/shape/column/detector_sqlite_test.go new file mode 100644 index 00000000..f162b0fa --- /dev/null +++ b/repository/shape/column/detector_sqlite_test.go @@ -0,0 +1,57 @@ +package column + +import ( + "context" + "database/sql" + "path/filepath" + "reflect" + "strings" + "testing" + + _ "github.com/mattn/go-sqlite3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/view" + "github.com/viant/datly/view/state" +) + +type sqliteOrder struct { + VendorID int `sqlx:"name=VENDOR_ID"` + Name string `sqlx:"name=NAME"` +} + +func TestDetector_Resolve_SQLiteWildcard(t *testing.T) { + ctx := context.Background() + dsn := filepath.Join(t.TempDir(), "shape_detector.sqlite") + db, err := sql.Open("sqlite3", dsn) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + + _, err = db.ExecContext(ctx, `CREATE TABLE VENDOR (VENDOR_ID INTEGER NOT NULL, NAME TEXT NOT NULL, STATUS TEXT)`) + require.NoError(t, err) + + resource := view.EmptyResource() + resource.Connectors = []*view.Connector{{Connection: view.Connection{DBConfig: view.DBConfig{Name: "db", Driver: "sqlite3", DSN: dsn}}}} + + aView := &view.View{ + Name: "vendor", + Table: "VENDOR", + Schema: state.NewSchema(reflect.TypeOf(sqliteOrder{}), state.WithMany()), + Template: view.NewTemplate("SELECT * FROM VENDOR"), + Connector: view.NewRefConnector("db"), + } + + resolved, err := New().Resolve(ctx, resource, aView) + require.NoError(t, err) + require.GreaterOrEqual(t, len(resolved), 3) + + // Schema order is preserved, discovered extra columns are appended. 
+ assert.Equal(t, "VENDOR_ID", strings.ToUpper(resolved[0].Name)) + assert.Equal(t, "NAME", strings.ToUpper(resolved[1].Name)) + + names := make([]string, 0, len(resolved)) + for _, item := range resolved { + names = append(names, strings.ToUpper(item.Name)) + } + assert.Contains(t, names, "STATUS") +} diff --git a/repository/shape/compile/column_discovery_policy.go b/repository/shape/compile/column_discovery_policy.go new file mode 100644 index 00000000..8cb90816 --- /dev/null +++ b/repository/shape/compile/column_discovery_policy.go @@ -0,0 +1,113 @@ +package compile + +import ( + "reflect" + "strings" + + "github.com/viant/datly/repository/shape" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/sqlparser" +) + +func applyColumnDiscoveryPolicy(result *plan.Result, compileOptions *shape.CompileOptions) []*dqlshape.Diagnostic { + if result == nil { + return nil + } + mode := normalizeColumnDiscoveryMode(shape.CompileColumnDiscoveryAuto) + if compileOptions != nil { + mode = normalizeColumnDiscoveryMode(compileOptions.ColumnDiscoveryMode) + } + + var diags []*dqlshape.Diagnostic + for _, item := range result.Views { + if item == nil || !isQueryLikeMode(item.Mode) { + continue + } + required := mode == shape.CompileColumnDiscoveryOn + if requiresColumnDiscovery(item) { + required = true + } + item.ColumnsDiscovery = required + if !required { + continue + } + result.ColumnsDiscovery = true + if mode == shape.CompileColumnDiscoveryOff { + diags = append(diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeColDiscoveryReq, + Severity: dqlshape.SeverityError, + Message: "column discovery is required but disabled", + Hint: "enable column discovery or declare an explicit shape/type without wildcard projection", + Span: dqlshape.Span{ + Start: dqlshape.Position{Line: 1, Char: 1}, + End: dqlshape.Position{Line: 1, Char: 1}, + }, + }) + } 
+ } + return diags +} + +func normalizeColumnDiscoveryMode(mode shape.CompileColumnDiscoveryMode) shape.CompileColumnDiscoveryMode { + switch mode { + case shape.CompileColumnDiscoveryAuto, shape.CompileColumnDiscoveryOn, shape.CompileColumnDiscoveryOff: + return mode + default: + return shape.CompileColumnDiscoveryAuto + } +} + +func isQueryLikeMode(mode string) bool { + mode = strings.TrimSpace(mode) + if mode == "" { + return true + } + return strings.EqualFold(mode, "SQLQuery") +} + +func requiresColumnDiscovery(item *plan.View) bool { + if item == nil { + return false + } + if usesWildcardSQL(item.SQL, item.Table) { + return true + } + return !hasConcreteShape(item) +} + +func hasConcreteShape(item *plan.View) bool { + if item == nil { + return false + } + rType := item.ElementType + if rType == nil { + rType = item.FieldType + } + if rType == nil { + return false + } + for rType.Kind() == reflect.Ptr || rType.Kind() == reflect.Slice || rType.Kind() == reflect.Array { + rType = rType.Elem() + } + return rType.Kind() == reflect.Struct +} + +func usesWildcardSQL(sqlText, table string) bool { + if strings.TrimSpace(sqlText) == "" { + return strings.TrimSpace(table) != "" + } + lower := strings.ToLower(sqlText) + if !strings.Contains(lower, "*") { + return false + } + if !strings.HasPrefix(strings.TrimSpace(lower), "select") && !strings.HasPrefix(strings.TrimSpace(lower), "with") { + return true + } + parsed, err := sqlparser.ParseQuery(sqlText) + if err != nil { + return true + } + return sqlparser.NewColumns(parsed.List).IsStarExpr() +} diff --git a/repository/shape/compile/column_discovery_policy_test.go b/repository/shape/compile/column_discovery_policy_test.go new file mode 100644 index 00000000..72baa5c5 --- /dev/null +++ b/repository/shape/compile/column_discovery_policy_test.go @@ -0,0 +1,77 @@ +package compile + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + 
"github.com/viant/datly/repository/shape" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + "github.com/viant/datly/repository/shape/plan" +) + +func TestApplyColumnDiscoveryPolicy_Auto_WildcardRequiresDiscovery(t *testing.T) { + result := &plan.Result{ + Views: []*plan.View{{ + Name: "orders", + Mode: "SQLQuery", + SQL: "SELECT * FROM ORDERS", + FieldType: reflect.TypeOf([]struct{ ID int }{}), + ElementType: reflect.TypeOf(struct{ ID int }{}), + }}, + } + diags := applyColumnDiscoveryPolicy(result, &shape.CompileOptions{ColumnDiscoveryMode: shape.CompileColumnDiscoveryAuto}) + require.Empty(t, diags) + require.True(t, result.ColumnsDiscovery) + require.True(t, result.Views[0].ColumnsDiscovery) +} + +func TestApplyColumnDiscoveryPolicy_Auto_NoConcreteShapeRequiresDiscovery(t *testing.T) { + result := &plan.Result{ + Views: []*plan.View{{ + Name: "orders", + Mode: "SQLQuery", + SQL: "SELECT id FROM ORDERS", + FieldType: reflect.TypeOf([]map[string]any{}), + ElementType: reflect.TypeOf(map[string]any{}), + }}, + } + diags := applyColumnDiscoveryPolicy(result, &shape.CompileOptions{ColumnDiscoveryMode: shape.CompileColumnDiscoveryAuto}) + require.Empty(t, diags) + require.True(t, result.ColumnsDiscovery) + require.True(t, result.Views[0].ColumnsDiscovery) +} + +func TestApplyColumnDiscoveryPolicy_Off_EmitsErrorWhenRequired(t *testing.T) { + result := &plan.Result{ + Views: []*plan.View{{ + Name: "orders", + Mode: "SQLQuery", + SQL: "SELECT * FROM ORDERS", + FieldType: reflect.TypeOf([]map[string]any{}), + ElementType: reflect.TypeOf(map[string]any{}), + }}, + } + diags := applyColumnDiscoveryPolicy(result, &shape.CompileOptions{ColumnDiscoveryMode: shape.CompileColumnDiscoveryOff}) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeColDiscoveryReq, diags[0].Code) + assert.True(t, result.ColumnsDiscovery) + assert.True(t, result.Views[0].ColumnsDiscovery) +} + +func TestApplyColumnDiscoveryPolicy_On_AlwaysMarksQueryViews(t *testing.T) { + result := 
&plan.Result{ + Views: []*plan.View{{ + Name: "orders", + Mode: "SQLQuery", + SQL: "SELECT id FROM ORDERS", + FieldType: reflect.TypeOf([]struct{ ID int }{}), + ElementType: reflect.TypeOf(struct{ ID int }{}), + }}, + } + diags := applyColumnDiscoveryPolicy(result, &shape.CompileOptions{ColumnDiscoveryMode: shape.CompileColumnDiscoveryOn}) + require.Empty(t, diags) + assert.True(t, result.ColumnsDiscovery) + assert.True(t, result.Views[0].ColumnsDiscovery) +} diff --git a/repository/shape/compile/compiler.go b/repository/shape/compile/compiler.go index 69647b60..db57701a 100644 --- a/repository/shape/compile/compiler.go +++ b/repository/shape/compile/compiler.go @@ -3,15 +3,16 @@ package compile import ( "context" "fmt" - "reflect" - "regexp" "strings" - "github.com/viant/datly/internal/translator/parser" "github.com/viant/datly/repository/shape" - dqlparse "github.com/viant/datly/repository/shape/dql/parse" + "github.com/viant/datly/repository/shape/compile/dml" + "github.com/viant/datly/repository/shape/compile/pipeline" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" "github.com/viant/datly/repository/shape/plan" - "github.com/viant/sqlparser" ) // DQLCompiler compiles raw DQL into a shape plan that can be materialized by shape/load. @@ -22,89 +23,259 @@ func New() *DQLCompiler { return &DQLCompiler{} } +// CompileError represents one or more compilation diagnostics. 
+type CompileError struct { + Diagnostics []*dqlshape.Diagnostic +} + +func (e *CompileError) Error() string { + if e == nil || len(e.Diagnostics) == 0 { + return "shape compile failed" + } + first := e.Diagnostics[0] + if len(e.Diagnostics) == 1 { + return first.Error() + } + return fmt.Sprintf("%s (and %d more diagnostics)", first.Error(), len(e.Diagnostics)-1) +} + // Compile implements shape.DQLCompiler. -func (c *DQLCompiler) Compile(_ context.Context, source *shape.Source, _ ...shape.CompileOption) (*shape.PlanResult, error) { +func (c *DQLCompiler) Compile(_ context.Context, source *shape.Source, opts ...shape.CompileOption) (*shape.PlanResult, error) { if source == nil { return nil, shape.ErrNilSource } - dql := strings.TrimSpace(source.DQL) - if dql == "" { + compileOptions := applyCompileOptions(opts) + pathLayout := newCompilePathLayout(compileOptions) + compileProfile := normalizeCompileProfile(compileOptions.Profile) + enforceStrict := compileOptions.Strict || compileProfile == shape.CompileProfileStrict + if strings.TrimSpace(source.DQL) == "" { return nil, shape.ErrNilDQL } - name, table, err := inferRoot(dql, source.Name) + pre := dqlpre.Prepare(source.DQL) + pre.TypeCtx = applyTypeContextDefaults(pre.TypeCtx, source, compileOptions, pathLayout) + pre.Diagnostics = append(pre.Diagnostics, typeContextDiagnostics(pre.TypeCtx, enforceStrict)...) + allDiags := append([]*dqlshape.Diagnostic{}, pre.Diagnostics...) 
+ if hasErrorDiagnostics(allDiags) { + return nil, &CompileError{Diagnostics: allDiags} + } + + statements := dqlstmt.New(pre.SQL) + decision := pipeline.Classify(statements) + prepared := buildHandlerIfNeeded(source, pre, statements, decision, pathLayout) + pre = prepared.Pre + statements = prepared.Statements + decision = prepared.Decision + legacyFallbackViews := prepared.LegacyViews + effectiveSource := source + if prepared.EffectiveSource != nil { + effectiveSource = prepared.EffectiveSource + } + if strings.TrimSpace(pre.SQL) == "" && len(legacyFallbackViews) == 0 { + allDiags = append(allDiags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeParseEmpty, + Severity: dqlshape.SeverityError, + Message: "no SQL statement found", + Hint: "add SELECT/INSERT/UPDATE/DELETE statement after DQL directives", + Span: dqlshape.Span{ + Start: dqlshape.Position{Line: 1, Char: 1}, + End: dqlshape.Position{Line: 1, Char: 1}, + }, + }) + return nil, &CompileError{Diagnostics: allDiags} + } + var root *plan.View + var compileDiags []*dqlshape.Diagnostic + var err error + if len(legacyFallbackViews) > 0 { + root = legacyFallbackViews[0] + } else { + root, compileDiags, err = c.compileRoot(source.Name, pre.SQL, statements, decision, compileOptions.MixedMode, compileOptions.UnknownNonReadMode) + } if err != nil { return nil, err } + pre.Mapper.Remap(compileDiags) + allDiags = append(allDiags, compileDiags...) 
+ if root == nil { + return nil, &CompileError{Diagnostics: allDiags} + } - result := &plan.Result{ - Views: []*plan.View{ - { - Path: name, - Holder: name, - Name: name, - Table: table, - SQL: dql, - Cardinality: "many", - FieldType: reflect.TypeOf([]map[string]interface{}{}), - ElementType: reflect.TypeOf(map[string]interface{}{}), - }, - }, - ViewsByName: map[string]*plan.View{}, - ByPath: map[string]*plan.Field{}, + result := newPlanResult(root) + if len(legacyFallbackViews) > 1 { + for _, item := range legacyFallbackViews[1:] { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + if _, exists := result.ViewsByName[item.Name]; exists { + continue + } + result.Views = append(result.Views, item) + result.ViewsByName[item.Name] = item + } } - if parsed, parseErr := dqlparse.New().Parse(dql); parseErr == nil && parsed != nil && parsed.TypeContext != nil { - result.TypeContext = parsed.TypeContext + result.Diagnostics = allDiags + result.TypeContext = pre.TypeCtx + result.Directives = pre.Directives + applyDefaultConnectorDirective(result) + hints := extractViewHints(source.DQL) + appendRelationViews(result, root, hints) + appendDeclaredViews(source.DQL, result) + appendDeclaredStates(source.DQL, result) + if prepared.ForceLegacyContract && len(legacyFallbackViews) > 0 { + if legacyStates := resolveLegacyRouteStatesWithLayout(effectiveSource, pathLayout); len(legacyStates) > 0 { + result.States = legacyStates + } + if legacyTypes := resolveLegacyRouteTypesWithLayout(effectiveSource, pathLayout); len(legacyTypes) > 0 { + result.Types = legacyTypes + } + } + result.Diagnostics = append(result.Diagnostics, appendComponentTypesWithLayout(effectiveSource, result, pathLayout)...) 
+ mergeLegacyRouteStatesWithLayout(result, effectiveSource, pathLayout) + mergeLegacyRouteTypesWithLayout(result, effectiveSource, pathLayout) + applyViewHints(result, hints) + applySourceParityEnrichmentWithLayout(result, effectiveSource, pathLayout) + result.Diagnostics = append(result.Diagnostics, applyColumnDiscoveryPolicy(result, compileOptions)...) + if len(result.States) == 0 && len(legacyFallbackViews) > 0 { + result.States = resolveLegacyRouteStatesWithLayout(effectiveSource, pathLayout) + } + if len(result.Types) == 0 && len(legacyFallbackViews) > 0 { + result.Types = resolveLegacyRouteTypesWithLayout(effectiveSource, pathLayout) + } + + if enforceStrict && hasEscalationWarnings(result.Diagnostics) { + return nil, &CompileError{Diagnostics: filterEscalationDiagnostics(result.Diagnostics)} + } + if hasErrorDiagnostics(result.Diagnostics) { + return nil, &CompileError{Diagnostics: result.Diagnostics} } - result.ViewsByName[name] = result.Views[0] return &shape.PlanResult{Source: source, Plan: result}, nil } -func inferRoot(dql string, fallback string) (string, string, error) { - query, err := sqlparser.ParseQuery(dql, parser.OnVeltyExpression()) - if err != nil { - name := sanitizeName(fallback) - if name == "" { - name = "DQLView" - } - return name, "", nil +func applyDefaultConnectorDirective(result *plan.Result) { + if result == nil || result.Directives == nil { + return } - - name := sanitizeName(query.From.Alias) - if name == "" { - name = sanitizeName(fallback) + connector := strings.TrimSpace(result.Directives.DefaultConnector) + if connector == "" { + return } - if name == "" { - name = "DQLView" + for _, item := range result.Views { + if item == nil || strings.TrimSpace(item.Connector) != "" { + continue + } + item.Connector = connector } +} - table := "" - if query != nil && query.From.X != nil { - table = strings.TrimSpace(sqlparser.Stringify(query.From.X)) +func (c *DQLCompiler) compileRoot(sourceName, sqlText string, statements 
dqlstmt.Statements, decision pipeline.Decision, mode shape.CompileMixedMode, unknownMode shape.CompileUnknownNonReadMode) (*plan.View, []*dqlshape.Diagnostic, error) { + mode = normalizeMixedMode(mode) + unknownMode = normalizeUnknownNonReadMode(unknownMode) + if !decision.HasRead && !decision.HasExec && decision.HasUnknown { + diag := &dqlshape.Diagnostic{ + Code: dqldiag.CodeParseUnknownNonRead, + Severity: dqlshape.SeverityWarning, + Message: "no readable SELECT statement detected", + Hint: "use SELECT for read parsing or compile as DML/handler template", + Span: pipeline.StatementSpan(sqlText, statements[0]), + } + if unknownMode == shape.CompileUnknownNonReadError { + diag.Severity = dqlshape.SeverityError + return nil, []*dqlshape.Diagnostic{diag}, nil + } + view, execDiags := pipeline.BuildExec(sourceName, sqlText, statements) + return view, append([]*dqlshape.Diagnostic{diag}, execDiags...), nil } - if table == "" || strings.HasPrefix(table, "(") { - table = name + if decision.HasRead && decision.HasExec { + switch mode { + case shape.CompileMixedModeErrorOnMixed: + return nil, []*dqlshape.Diagnostic{ + { + Code: dqldiag.CodeDMLMixed, + Severity: dqlshape.SeverityError, + Message: "mixed read/exec script is not allowed by compile mixed mode", + Hint: "use WithMixedMode(shape.CompileMixedModeExecWins) or split handlers", + Span: pipeline.StatementSpan(sqlText, statements[0]), + }, + }, nil + case shape.CompileMixedModeReadWins: + readSQL := sqlText + for _, stmt := range statements { + if stmt != nil && stmt.Kind == dqlstmt.KindRead { + readSQL = sqlText[stmt.Start:stmt.End] + break + } + } + view, diags, err := pipeline.BuildRead(sourceName, readSQL) + diags = append(diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLMixed, + Severity: dqlshape.SeverityWarning, + Message: "mixed read/exec script detected; read compilation path selected", + Hint: "split SELECT and DML into separate handlers when possible", + Span: pipeline.StatementSpan(sqlText, 
statements[0]), + }) + return view, diags, err + } } - if name == "" { - return "", "", fmt.Errorf("shape compile: failed to infer view name") + if decision.HasExec { + view, diags := dml.Compile(sourceName, sqlText, statements) + if decision.HasRead { + diags = append(diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLMixed, + Severity: dqlshape.SeverityWarning, + Message: "mixed read/exec script detected; exec compilation path selected", + Hint: "split SELECT and DML into separate handlers when possible", + Span: pipeline.StatementSpan(sqlText, statements[0]), + }) + } + return view, diags, nil } - return name, table, nil + return pipeline.BuildRead(sourceName, sqlText) } -var nonWord = regexp.MustCompile(`[^a-zA-Z0-9_]+`) +func normalizeMixedMode(mode shape.CompileMixedMode) shape.CompileMixedMode { + switch mode { + case shape.CompileMixedModeExecWins, shape.CompileMixedModeReadWins, shape.CompileMixedModeErrorOnMixed: + return mode + default: + return shape.CompileMixedModeExecWins + } +} -func sanitizeName(value string) string { - value = strings.TrimSpace(value) - if value == "" { - return "" +func normalizeUnknownNonReadMode(mode shape.CompileUnknownNonReadMode) shape.CompileUnknownNonReadMode { + switch mode { + case shape.CompileUnknownNonReadWarn, shape.CompileUnknownNonReadError: + return mode + default: + return shape.CompileUnknownNonReadWarn } - value = nonWord.ReplaceAllString(value, "_") - value = strings.Trim(value, "_") - if value == "" { - return "" +} + +func normalizeCompileProfile(profile shape.CompileProfile) shape.CompileProfile { + switch profile { + case shape.CompileProfileCompat, shape.CompileProfileStrict: + return profile + default: + return shape.CompileProfileCompat } - if value[0] >= '0' && value[0] <= '9' { - value = "V_" + value +} + +func newPlanResult(root *plan.View) *plan.Result { + result := &plan.Result{ + Views: []*plan.View{root}, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + 
result.ViewsByName[root.Name] = root + return result +} + +func applyCompileOptions(opts []shape.CompileOption) *shape.CompileOptions { + ret := &shape.CompileOptions{} + for _, opt := range opts { + if opt != nil { + opt(ret) + } } - return value + return ret } diff --git a/repository/shape/compile/compiler_test.go b/repository/shape/compile/compiler_test.go index b539ab80..63156250 100644 --- a/repository/shape/compile/compiler_test.go +++ b/repository/shape/compile/compiler_test.go @@ -2,11 +2,16 @@ package compile import ( "context" + "os" + "path/filepath" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/viant/datly/repository/shape" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/plan" ) @@ -23,6 +28,8 @@ func TestDQLCompiler_Compile(t *testing.T) { assert.Equal(t, "t", view.Name) assert.Equal(t, "ORDERS", view.Table) assert.Equal(t, "many", view.Cardinality) + require.NotNil(t, view.FieldType) + assert.Contains(t, view.FieldType.String(), "Id") } func TestDQLCompiler_Compile_EmptyDQL(t *testing.T) { @@ -53,8 +60,8 @@ SELECT id func TestDQLCompiler_Compile_PropagatesTypeContext(t *testing.T) { compiler := New() dql := ` -#set($_ = $package('mdp/performance')) -#set($_ = $import('perf', 'github.com/acme/mdp/performance')) +#settings($_ = $package('mdp/performance')) +#settings($_ = $import('perf', 'github.com/acme/mdp/performance')) SELECT id FROM ORDERS t` res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) require.NoError(t, err) @@ -67,3 +74,758 @@ SELECT id FROM ORDERS t` require.Len(t, planned.TypeContext.Imports, 1) assert.Equal(t, "perf", planned.TypeContext.Imports[0].Alias) } + +func TestDQLCompiler_Compile_PropagatesSpecialDirectives(t *testing.T) { + compiler := New() + dql := ` +#settings($_ = $meta('docs/orders.md')) 
+#settings($_ = $connector('analytics')) +#settings($_ = $cache(true, '5m')) +#settings($_ = $mcp('orders.search', 'Search orders', 'docs/mcp/orders.md')) +SELECT id FROM ORDERS o +` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotNil(t, planned.Directives) + assert.Equal(t, "docs/orders.md", planned.Directives.Meta) + assert.Equal(t, "analytics", planned.Directives.DefaultConnector) + require.NotNil(t, planned.Directives.Cache) + assert.True(t, planned.Directives.Cache.Enabled) + assert.Equal(t, "5m", planned.Directives.Cache.TTL) + require.NotNil(t, planned.Directives.MCP) + assert.Equal(t, "orders.search", planned.Directives.MCP.Name) + assert.Equal(t, "Search orders", planned.Directives.MCP.Description) + assert.Equal(t, "docs/mcp/orders.md", planned.Directives.MCP.DescriptionPath) + require.NotEmpty(t, planned.Views) + assert.Equal(t, "analytics", planned.Views[0].Connector) +} + +func TestDQLCompiler_Compile_ColumnDiscoveryAutoForWildcard(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "SELECT * FROM ORDERS o"}) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.True(t, planned.ColumnsDiscovery) + require.NotEmpty(t, planned.Views) + assert.True(t, planned.Views[0].ColumnsDiscovery) +} + +func TestDQLCompiler_Compile_ColumnDiscoveryOffFailsWhenRequired(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "SELECT * FROM ORDERS o"}, + shape.WithColumnDiscoveryMode(shape.CompileColumnDiscoveryOff)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeColDiscoveryReq, compileErr.Diagnostics[0].Code) 
+} + +func TestDQLCompiler_Compile_TypeContextValidationWarnsInCompat(t *testing.T) { + compiler := New() + dql := ` +#settings($_ = $package('github.com/acme/perf')) +SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, shape.WithTypeContextPackageName("bad/name")) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeTypeCtxInvalid, planned.Diagnostics[0].Code) + assert.Equal(t, dqlshape.SeverityWarning, planned.Diagnostics[0].Severity) +} + +func TestDQLCompiler_Compile_TypeContextValidationFailsInStrict(t *testing.T) { + compiler := New() + dql := `SELECT id FROM ORDERS t` + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, + shape.WithCompileProfile(shape.CompileProfileStrict), + shape.WithTypeContextPackageName("bad/name")) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeTypeCtxInvalid, compileErr.Diagnostics[0].Code) + assert.Equal(t, dqlshape.SeverityError, compileErr.Diagnostics[0].Severity) +} + +func TestDQLCompiler_Compile_SyntaxError_HasLineAndChar(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "SELECT id FROM ORDERS WHERE ("}) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + d := compileErr.Diagnostics[0] + assert.Equal(t, dqldiag.CodeParseSyntax, d.Code) + assert.Equal(t, 1, d.Span.Start.Line) + assert.Equal(t, 29, d.Span.Start.Char) +} + +func TestDQLCompiler_Compile_SyntaxError_RemapsAfterSanitize(t *testing.T) { + compiler := New() + dql := "SELECT id FROM ORDERS t WHERE t.id = $Id AND (" + res, err := compiler.Compile(context.Background(), 
&shape.Source{Name: "orders_report", DQL: dql}) + var diagnostics []*dqlshape.Diagnostic + if err != nil { + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + diagnostics = compileErr.Diagnostics + } else { + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + diagnostics = planned.Diagnostics + } + var d *dqlshape.Diagnostic + for _, item := range diagnostics { + if item != nil && item.Code == dqldiag.CodeParseSyntax { + d = item + break + } + } + if d != nil { + assert.Equal(t, 1, d.Span.Start.Line) + assert.Greater(t, d.Span.Start.Char, 0) + assert.LessOrEqual(t, d.Span.Start.Char, len(dql)) + } +} + +func TestDQLCompiler_Compile_DirectiveOnly_HasLineAndChar(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "#settings($_ = $package('x'))"}) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + d := compileErr.Diagnostics[0] + assert.Equal(t, dqldiag.CodeParseEmpty, d.Code) + assert.Equal(t, 1, d.Span.Start.Line) + assert.Equal(t, 1, d.Span.Start.Char) +} + +func TestDQLCompiler_Compile_InvalidDirective_HasLineAndChar(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_report", + DQL: "SELECT id FROM ORDERS t\n#settings($_ = $import('alias'))\nSELECT id FROM ORDERS t", + }) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + d := compileErr.Diagnostics[0] + assert.Equal(t, dqldiag.CodeDirImport, d.Code) + assert.Equal(t, 2, d.Span.Start.Line) + assert.Equal(t, 1, d.Span.Start.Char) +} + +func TestDQLCompiler_Compile_ExtractsJoinLinks(t *testing.T) { + compiler := New() + dql := "SELECT o.id, i.sku FROM orders o JOIN order_items i ON o.id = i.order_id" + res, err := 
compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + root := planned.ViewsByName["o"] + require.NotNil(t, root) + require.Len(t, root.Relations, 1) + assert.Equal(t, "i", root.Relations[0].Ref) + require.Len(t, root.Relations[0].On, 1) + assert.Equal(t, "o.id=i.order_id", root.Relations[0].On[0].Expression) + assert.Equal(t, "id", root.Relations[0].On[0].ParentColumn) + assert.Equal(t, "order_id", root.Relations[0].On[0].RefColumn) + assert.Empty(t, planned.Diagnostics) +} + +func TestDQLCompiler_Compile_JoinDiagnostics(t *testing.T) { + compiler := New() + dql := "SELECT o.id FROM orders o JOIN order_items i ON o.id > i.order_id" + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeRelUnsupported, planned.Diagnostics[0].Code) +} + +func TestDQLCompiler_Compile_StrictRelationWarningsFail(t *testing.T) { + compiler := New() + dql := "SELECT o.id FROM orders o JOIN order_items i ON o.id > i.order_id" + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, shape.WithCompileStrict(true)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeRelUnsupported, compileErr.Diagnostics[0].Code) +} + +func TestDQLCompiler_Compile_ProfileStrictRelationWarningsFail(t *testing.T) { + compiler := New() + dql := "SELECT o.id FROM orders o JOIN order_items i ON o.id > i.order_id" + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, shape.WithCompileProfile(shape.CompileProfileStrict)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + 
require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeRelUnsupported, compileErr.Diagnostics[0].Code) +} + +func TestDQLCompiler_Compile_StrictAmbiguousLinkFail(t *testing.T) { + compiler := New() + dql := "SELECT o.id FROM orders o JOIN order_items i ON x.id = y.order_id" + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, shape.WithCompileStrict(true)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeRelAmbiguous, compileErr.Diagnostics[0].Code) +} + +func TestDQLCompiler_Compile_SQLInjectionDiagnostic(t *testing.T) { + compiler := New() + dql := "SELECT id FROM ORDERS t WHERE t.id = $Unsafe.Id" + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeSQLIRawSelector, planned.Diagnostics[0].Code) + assert.Equal(t, 1, planned.Diagnostics[0].Span.Start.Line) + assert.Greater(t, planned.Diagnostics[0].Span.Start.Char, 1) +} + +func TestDQLCompiler_Compile_SanitizesBindings(t *testing.T) { + compiler := New() + dql := "SELECT id FROM ORDERS t WHERE t.id = $Id" + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Contains(t, planned.Views[0].SQL, "$criteria.AppendBinding($Unsafe.Id)") +} + +func TestDQLCompiler_Compile_ParameterDerivedView(t *testing.T) { + compiler := New() + dql := ` +#set($_ = $Extra(view/extra_view) /* SELECT code FROM EXTRA e */) +SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + 
require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.Len(t, planned.Views, 2) + extra := planned.ViewsByName["e"] + require.NotNil(t, extra) + assert.Equal(t, "EXTRA", extra.Table) + assert.Contains(t, extra.SQL, "SELECT code FROM EXTRA e") +} + +func TestDQLCompiler_Compile_ParameterDerivedView_Options(t *testing.T) { + compiler := New() + dql := ` +#set($_ = $Extra(view/extra_view).WithURI('/v1/extra').WithConnector('analytics').Cardinality('one') /* SELECT code FROM EXTRA e */) +SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + extra := planned.ViewsByName["e"] + require.NotNil(t, extra) + assert.Equal(t, "/v1/extra", extra.SQLURI) + assert.Equal(t, "analytics", extra.Connector) + assert.Equal(t, "one", extra.Cardinality) +} + +func TestDQLCompiler_Compile_ParameterDerivedView_MissingSQLHint(t *testing.T) { + compiler := New() + dql := ` +#set($_ = $Extra(view/extra_view)) +SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeViewMissingSQL, planned.Diagnostics[len(planned.Diagnostics)-1].Code) +} + +func TestDQLCompiler_Compile_ParameterDerivedView_InvalidCardinalityDiagnostic(t *testing.T) { + compiler := New() + dql := ` +#set($_ = $Extra(view/extra_view).Cardinality('few') /* SELECT code FROM EXTRA e */) +SELECT id FROM ORDERS t` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeViewCardinality, 
planned.Diagnostics[len(planned.Diagnostics)-1].Code) +} + +func TestDQLCompiler_Compile_StrictSQLInjectionWarningsFail(t *testing.T) { + compiler := New() + dql := "SELECT id FROM ORDERS t WHERE t.id = $Unsafe.Id" + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, shape.WithCompileStrict(true)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeSQLIRawSelector, compileErr.Diagnostics[0].Code) +} + +func TestDQLCompiler_Compile_DMLInsert(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_exec", + DQL: "INSERT INTO ORDERS(id) VALUES (1)", + }) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.Len(t, planned.Views, 1) + assert.Equal(t, "ORDERS", planned.Views[0].Table) + assert.Equal(t, "many", planned.Views[0].Cardinality) +} + +func TestDQLCompiler_Compile_DMLServiceMissingArg(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_exec", + DQL: "$sql.Insert($rec)", + }) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + var target *dqlshape.Diagnostic + for _, item := range compileErr.Diagnostics { + if item != nil && item.Code == dqldiag.CodeDMLServiceArg { + target = item + break + } + } + require.NotNil(t, target) + assert.Equal(t, 1, target.Span.Start.Line) + assert.Equal(t, 1, target.Span.Start.Char) +} + +func TestDQLCompiler_Compile_DMLSyntaxError_HasLineAndChar(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_exec", + DQL: "#settings($_ = $package('x'))\nINSERT INTO ORDERS(id VALUES (1)", + }) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, 
ok) + require.NotEmpty(t, compileErr.Diagnostics) + var target *dqlshape.Diagnostic + for _, item := range compileErr.Diagnostics { + if item != nil && item.Code == dqldiag.CodeDMLInsert { + target = item + break + } + } + require.NotNil(t, target) + assert.Equal(t, 2, target.Span.Start.Line) + assert.Equal(t, 1, target.Span.Start.Char) +} + +func TestDQLCompiler_Compile_MixedReadExec_Warning(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "mixed_exec", + DQL: "SELECT id FROM ORDERS\nUPDATE ORDERS SET id = 2", + }) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeDMLMixed, planned.Diagnostics[len(planned.Diagnostics)-1].Code) +} + +func TestDQLCompiler_Compile_MixedMode_ExecWins(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "mixed_exec", + DQL: "SELECT o.id FROM ORDERS o\nUPDATE ORDERS SET id = 2", + }, shape.WithMixedMode(shape.CompileMixedModeExecWins)) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Equal(t, "ORDERS", planned.Views[0].Table) + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeDMLMixed, planned.Diagnostics[len(planned.Diagnostics)-1].Code) +} + +func TestDQLCompiler_Compile_MixedMode_ReadWins(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "mixed_exec", + DQL: "SELECT o.id FROM ORDERS o\nUPDATE ORDERS SET id = 2", + }, shape.WithMixedMode(shape.CompileMixedModeReadWins)) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Equal(t, "o", planned.Views[0].Name) + assert.Equal(t, "ORDERS", planned.Views[0].Table) + assert.Contains(t, planned.Views[0].SQL, 
"SELECT o.id FROM ORDERS o") + assert.NotContains(t, planned.Views[0].SQL, "UPDATE ORDERS") + require.NotEmpty(t, planned.Diagnostics) + assert.Equal(t, dqldiag.CodeDMLMixed, planned.Diagnostics[len(planned.Diagnostics)-1].Code) +} + +func TestDQLCompiler_Compile_MixedMode_ErrorOnMixed(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "mixed_exec", + DQL: "SELECT o.id FROM ORDERS o\nUPDATE ORDERS SET id = 2", + }, shape.WithMixedMode(shape.CompileMixedModeErrorOnMixed)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + assert.Equal(t, dqldiag.CodeDMLMixed, compileErr.Diagnostics[0].Code) + assert.Equal(t, dqlshape.SeverityError, compileErr.Diagnostics[0].Severity) +} + +func TestDQLCompiler_Compile_UnknownNonRead_Warn(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_report", + DQL: "$Foo.Bar($x)", + }) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Diagnostics) + var found *dqlshape.Diagnostic + for _, item := range planned.Diagnostics { + if item != nil && item.Code == dqldiag.CodeParseUnknownNonRead { + found = item + break + } + } + require.NotNil(t, found) + assert.Equal(t, dqlshape.SeverityWarning, found.Severity) + require.NotEmpty(t, planned.Views) +} + +func TestDQLCompiler_Compile_UnknownNonRead_ErrorMode(t *testing.T) { + compiler := New() + _, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "orders_report", + DQL: "$Foo.Bar($x)", + }, shape.WithUnknownNonReadMode(shape.CompileUnknownNonReadError)) + require.Error(t, err) + compileErr, ok := err.(*CompileError) + require.True(t, ok) + require.NotEmpty(t, compileErr.Diagnostics) + var found *dqlshape.Diagnostic + for _, item := range compileErr.Diagnostics { + if item != nil && item.Code == 
dqldiag.CodeParseUnknownNonRead { + found = item + break + } + } + require.NotNil(t, found) + assert.Equal(t, dqlshape.SeverityError, found.Severity) +} + +func TestResolveGeneratedCompanionDQL(t *testing.T) { + tempDir := t.TempDir() + dqlPath := filepath.Join(tempDir, "platform", "sitelist", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(dqlPath), 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(filepath.Dir(dqlPath), "gen"), 0o755)) + generatedPath := filepath.Join(filepath.Dir(dqlPath), "gen", "patch.sql") + require.NoError(t, os.WriteFile(generatedPath, []byte("SELECT id FROM SITE_LIST sl"), 0o644)) + source := &shape.Source{ + Path: dqlPath, + DQL: `/* {"Type":"sitelist/patch.Handler"} */`, + } + actual := resolveGeneratedCompanionDQL(source) + require.Contains(t, actual, "SELECT id FROM SITE_LIST") +} + +func TestDQLCompiler_Compile_UnknownNonRead_UsesGeneratedCompanion(t *testing.T) { + tempDir := t.TempDir() + dqlPath := filepath.Join(tempDir, "platform", "adorder", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Join(filepath.Dir(dqlPath), "gen", "adorder"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(filepath.Dir(dqlPath), "gen", "adorder", "patch.dql"), []byte("SELECT o.id FROM ORDERS o JOIN ORDER_ITEM i ON i.ORDER_ID = o.ID"), 0o644)) + source := &shape.Source{ + Name: "patch", + Path: dqlPath, + DQL: `/* {"Type":"adorder/patch.Handler"} */`, + } + + compiler := New() + res, err := compiler.Compile(context.Background(), source) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotNil(t, planned.ViewsByName["o"]) + require.NotNil(t, planned.ViewsByName["i"]) + var hasUnknownNonRead bool + for _, diag := range planned.Diagnostics { + if diag != nil && diag.Code == dqldiag.CodeParseUnknownNonRead { + hasUnknownNonRead = true + break + } + } + assert.False(t, hasUnknownNonRead) +} + +func TestResolveLegacyRouteViews(t *testing.T) { + tempDir := t.TempDir() + sourcePath := 
filepath.Join(tempDir, "dql", "platform", "campaign", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) + require.NoError(t, os.WriteFile(sourcePath, []byte(`/* {"Connector":"ci_ads"} */`), 0o644)) + + routeDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "platform", "campaign", "patch") + require.NoError(t, os.MkdirAll(routeDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(routeDir, "patch.sql"), []byte(`SELECT 1`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routeDir, "CurCampaign.sql"), []byte(`SELECT * FROM CI_CAMPAIGN`), 0o644)) + + views := resolveLegacyRouteViews(&shape.Source{Path: sourcePath, DQL: `/* {"Connector":"ci_ads"} */`}) + require.Len(t, views, 2) + assert.Equal(t, "patch", views[0].Name) + assert.Equal(t, "", views[0].Table) + assert.Equal(t, "patch/patch.sql", views[0].SQLURI) + assert.Equal(t, "CurCampaign", views[1].Name) + assert.Equal(t, "CI_CAMPAIGN", views[1].Table) + assert.Equal(t, "ci_ads", views[1].Connector) +} + +func TestResolveLegacyRouteViews_TypeStemSubfolder(t *testing.T) { + tempDir := t.TempDir() + sourcePath := filepath.Join(tempDir, "dql", "platform", "campaign", "post.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) + require.NoError(t, os.WriteFile(sourcePath, []byte(`/* {"Type":"campaign/patch.Handler","Connector":"ci_ads"} */`), 0o644)) + + routeDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "platform", "campaign", "patch", "post") + require.NoError(t, os.MkdirAll(routeDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(routeDir, "post.sql"), []byte(`SELECT 1`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routeDir, "CurCampaign.sql"), []byte(`SELECT * FROM CI_CAMPAIGN`), 0o644)) + + views := resolveLegacyRouteViews(&shape.Source{Path: sourcePath, DQL: `/* {"Type":"campaign/patch.Handler","Connector":"ci_ads"} */`}) + require.Len(t, views, 2) + assert.Equal(t, "post", views[0].Name) + 
assert.Equal(t, "CurCampaign", views[1].Name) + assert.Equal(t, "post/CurCampaign.sql", views[1].SQLURI) +} + +func TestDQLCompiler_Compile_HandlerNop_NoSQLiEscalation(t *testing.T) { + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "handler_nop", + DQL: "$Nop($Unsafe.Id)", + }, shape.WithCompileStrict(true)) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + for _, item := range planned.Diagnostics { + if item == nil { + continue + } + assert.NotEqual(t, dqldiag.CodeSQLIRawSelector, item.Code) + } +} + +func TestDQLCompiler_Compile_SubqueryJoin_BuildsRelatedViewsAndConnectorHints(t *testing.T) { + compiler := New() + dql := ` +#set($_ = $Jwt(header/Authorization).WithCodec(JwtClaim).WithStatusCode(401)) +SELECT session.*, +use_connector(session, system), +use_connector(attribute, system) +FROM (SELECT * FROM session WHERE user_id = $Jwt.UserID) session +JOIN (SELECT * FROM session/attributes) attribute ON attribute.user_id = session.user_id +` + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "system/session", DQL: dql}) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + root := planned.ViewsByName["session"] + require.NotNil(t, root) + assert.Equal(t, "system", root.Connector) + related := planned.ViewsByName["attribute"] + require.NotNil(t, related) + assert.Equal(t, "session/attributes", related.Table) + assert.Equal(t, "system", related.Connector) +} + +func TestDQLCompiler_Compile_GeneratedHandler_NoBodyInput_UsesLegacyContractStates(t *testing.T) { + tempDir := t.TempDir() + genPath := filepath.Join(tempDir, "dql", "system", "upload", "gen", "upload", "delete.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(genPath), 0o755)) + require.NoError(t, os.WriteFile(genPath, []byte(`/* {"Method":"DELETE","URI":"/v1/api/system/upload"} */`), 0o644)) + + legacySQLPath := filepath.Join(tempDir, "dql", "system", 
"upload", "delete.sql") + require.NoError(t, os.MkdirAll(filepath.Dir(legacySQLPath), 0o755)) + require.NoError(t, os.WriteFile(legacySQLPath, []byte(`/* {"Type":"upload/delete.Handler","Connector":"system"} */`), 0o644)) + + routesDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "system", "upload") + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "delete"), 0o755)) + routeYAML := `Resource: + Parameters: + - Name: Method + In: + Kind: http_request + Name: method + - Name: UploadId + In: + Kind: query + Name: uploadId + Views: + - Name: delete + Mode: SQLExec + Connector: + Ref: system + Template: + SourceURL: delete/delete.sql +` + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "delete.yaml"), []byte(routeYAML), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "delete", "delete.sql"), []byte(`$Nop($Unsafe.UploadId)`), 0o644)) + + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{Name: "delete", Path: genPath, DQL: `/* {"Method":"DELETE","URI":"/v1/api/system/upload"} */`}) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + + require.NotEmpty(t, planned.Views) + assert.Equal(t, "delete", planned.Views[0].Name) + assert.Equal(t, "SQLExec", planned.Views[0].Mode) + assert.Equal(t, "system", planned.Views[0].Connector) + + stateByName := map[string]*plan.State{} + for _, item := range planned.States { + if item == nil { + continue + } + stateByName[item.Name] = item + } + require.Contains(t, stateByName, "Method") + require.Contains(t, stateByName, "UploadId") + assert.Equal(t, "http_request", stateByName["Method"].Kind) + assert.Equal(t, "query", stateByName["UploadId"].Kind) + assert.NotContains(t, stateByName, "Body") +} + +func TestDQLCompiler_Compile_HandlerLegacyTypes_PreferredOverComponentNameCollisions(t *testing.T) { + tempDir := t.TempDir() + sourcePath := filepath.Join(tempDir, "dql", "platform", "campaign", "post.dql") + 
require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) + require.NoError(t, os.WriteFile(sourcePath, []byte(`/* {"URI":"/v1/api/platform/campaign","Method":"POST","Type":"campaign/patch.Handler"} */`), 0o644)) + + rootRouteDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "platform", "campaign", "patch") + require.NoError(t, os.MkdirAll(filepath.Join(rootRouteDir, "post"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(rootRouteDir, "post.yaml"), []byte(`Resource: + Parameters: + - Name: Auth + In: + Kind: component + Name: GET:/v1/api/platform/acl/auth + Views: + - Name: post + Mode: SQLExec + Connector: + Ref: ci_ads + Template: + SourceURL: post/post.sql + Types: + - Name: Input + DataType: "*Input" + Package: campaign/patch + ModulePath: github.vianttech.com/viant/platform/pkg/platform/campaign/patch + - Name: Handler + DataType: "*Handler" + Package: campaign/patch + ModulePath: github.vianttech.com/viant/platform/pkg/platform/campaign/patch +`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(rootRouteDir, "post", "post.sql"), []byte(`$Nop($Unsafe.Id)`), 0o644)) + + componentRouteDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "platform", "acl", "auth") + require.NoError(t, os.MkdirAll(componentRouteDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(componentRouteDir, "auth.yaml"), []byte(`Resource: + Types: + - Name: Input + DataType: "*Input" + Package: acl/auth + ModulePath: github.vianttech.com/viant/platform/pkg/platform/acl/auth + - Name: Handler + DataType: "*Handler" + Package: acl/auth + ModulePath: github.vianttech.com/viant/platform/pkg/platform/acl/auth +`), 0o644)) + + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "post", + Path: sourcePath, + DQL: `/* {"URI":"/v1/api/platform/campaign","Method":"POST","Type":"campaign/patch.Handler"} */`, + }) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + 
+ typeByName := map[string]*plan.Type{} + for _, item := range planned.Types { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + typeByName[strings.ToLower(strings.TrimSpace(item.Name))] = item + } + + inputType, ok := typeByName["input"] + require.True(t, ok) + assert.Equal(t, "campaign/patch", inputType.Package) + assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/campaign/patch", inputType.ModulePath) + + handlerType, ok := typeByName["handler"] + require.True(t, ok) + assert.Equal(t, "campaign/patch", handlerType.Package) + assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/campaign/patch", handlerType.ModulePath) +} + +func TestDQLCompiler_Compile_CustomPathLayout_HandlerFallback(t *testing.T) { + tempDir := t.TempDir() + sourcePath := filepath.Join(tempDir, "sqlsrc", "platform", "campaign", "post.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) + require.NoError(t, os.WriteFile(sourcePath, []byte(`/* {"URI":"/v1/api/platform/campaign","Method":"POST","Type":"campaign/patch.Handler","Connector":"ci_ads"} */`), 0o644)) + + routesDir := filepath.Join(tempDir, "config", "routes", "platform", "campaign", "patch") + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "post"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "post.yaml"), []byte(`Resource: + Views: + - Name: post + Mode: SQLExec + Connector: + Ref: ci_ads + Template: + SourceURL: post/post.sql +`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "post", "post.sql"), []byte(`$Nop($Unsafe.Id)`), 0o644)) + + compiler := New() + res, err := compiler.Compile(context.Background(), &shape.Source{ + Name: "post", + Path: sourcePath, + DQL: `/* {"URI":"/v1/api/platform/campaign","Method":"POST","Type":"campaign/patch.Handler","Connector":"ci_ads"} */`, + }, shape.WithDQLPathMarker("sqlsrc"), shape.WithRoutesRelativePath("config/routes")) + require.NoError(t, err) + planned, ok := 
res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Equal(t, "post", planned.Views[0].Name) + assert.Equal(t, "SQLExec", planned.Views[0].Mode) + assert.Equal(t, "ci_ads", planned.Views[0].Connector) + assert.Contains(t, planned.Views[0].SQL, "$Nop(") +} diff --git a/repository/shape/compile/component_types.go b/repository/shape/compile/component_types.go new file mode 100644 index 00000000..5c553bc9 --- /dev/null +++ b/repository/shape/compile/component_types.go @@ -0,0 +1,432 @@ +package compile + +import ( + "os" + "path/filepath" + "sort" + "strings" + + "github.com/viant/datly/repository/shape" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" + "gopkg.in/yaml.v3" +) + +type componentVisitState int + +const ( + componentVisitIdle componentVisitState = iota + componentVisitActive + componentVisitDone +) + +func appendComponentTypes(source *shape.Source, result *plan.Result) []*dqlshape.Diagnostic { + return appendComponentTypesWithLayout(source, result, defaultCompilePathLayout()) +} + +func appendComponentTypesWithLayout(source *shape.Source, result *plan.Result, layout compilePathLayout) []*dqlshape.Diagnostic { + if source == nil || result == nil { + return nil + } + _, routesRoot, dqlRoot, ok := sourceRootsWithLayout(source.Path, layout) + if !ok { + return nil + } + sourceNamespace, _ := dqlToRouteNamespaceWithLayout(source.Path, layout) + collector := &componentCollector{ + routesRoot: routesRoot, + visited: map[string]componentVisitState{}, + outputByRoute: map[string]string{}, + typesByName: map[string]*plan.Type{}, + } + if strings.TrimSpace(sourceNamespace) != "" { + collector.collect(sourceNamespace, relationSpan(source.DQL, 0), false) + } + + for _, stateItem := range result.States { + if stateItem == nil || !strings.EqualFold(strings.TrimSpace(stateItem.Kind), "component") { + continue + 
} + ref := strings.TrimSpace(stateItem.In) + if ref == "" { + continue + } + namespace := resolveComponentNamespaceWithNamespace(ref, source.Path, dqlRoot, sourceNamespace) + if namespace == "" { + collector.diags = append(collector.diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeCompRefInvalid, + Severity: dqlshape.SeverityWarning, + Message: "invalid component reference: " + ref, + Hint: "use ../component/ref or GET:/v1/api/... route reference", + Span: componentRefSpan(source.DQL, ref), + }) + continue + } + outputType, ok := collector.collect(namespace, componentRefSpan(source.DQL, ref), true) + if ok && strings.TrimSpace(stateItem.DataType) == "" { + stateItem.DataType = strings.TrimSpace(outputType) + } + } + + names := make([]string, 0, len(collector.typesByName)) + for name := range collector.typesByName { + names = append(names, name) + } + sort.Strings(names) + existing := map[string]bool{} + reportedCollision := map[string]bool{} + for _, item := range result.Types { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + existing[strings.ToLower(strings.TrimSpace(item.Name))] = true + } + for _, name := range names { + keyName := strings.ToLower(strings.TrimSpace(name)) + if existing[keyName] { + if !reportedCollision[keyName] { + collector.diags = append(collector.diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeCompTypeCollision, + Severity: dqlshape.SeverityWarning, + Message: "component type skipped due to existing type name: " + strings.TrimSpace(name), + Hint: "rename colliding type or keep route type as canonical source", + Span: relationSpan(source.DQL, 0), + }) + reportedCollision[keyName] = true + } + continue + } + item := collector.typesByName[name] + result.Types = append(result.Types, item) + existing[keyName] = true + } + return collector.diags +} + +type componentCollector struct { + routesRoot string + visited map[string]componentVisitState + outputByRoute map[string]string + typesByName map[string]*plan.Type + diags 
[]*dqlshape.Diagnostic +} + +func (c *componentCollector) collect(namespace string, span dqlshape.Span, required bool) (string, bool) { + key := strings.ToLower(strings.TrimSpace(namespace)) + if key == "" { + return "", false + } + switch c.visited[key] { + case componentVisitDone: + return c.outputByRoute[key], true + case componentVisitActive: + c.diags = append(c.diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeCompCycle, + Severity: dqlshape.SeverityWarning, + Message: "component reference cycle detected at " + namespace, + Hint: "break cyclic component references", + Span: span, + }) + return "", false + } + c.visited[key] = componentVisitActive + + payload, ok := loadRoutePayload(c.routesRoot, namespace) + if !ok { + c.visited[key] = componentVisitDone + if required { + c.diags = append(c.diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeCompRouteMissing, + Severity: dqlshape.SeverityWarning, + Message: "component route YAML not found: " + namespace, + Hint: "ensure matching route exists under repo/dev/Datly/routes", + Span: span, + }) + } + return "", false + } + + for _, item := range payload.Resource.Types { + name := strings.TrimSpace(item.Name) + if name == "" { + continue + } + keyName := strings.ToLower(name) + if _, exists := c.typesByName[keyName]; exists { + continue + } + c.typesByName[keyName] = &plan.Type{ + Name: name, + Alias: strings.TrimSpace(item.Alias), + DataType: strings.TrimSpace(item.DataType), + Cardinality: strings.TrimSpace(item.Cardinality), + Package: strings.TrimSpace(item.Package), + ModulePath: strings.TrimSpace(item.ModulePath), + } + } + + outputType := routeOutputType(payload) + c.outputByRoute[key] = outputType + + for _, param := range payload.Resource.Parameters { + if !strings.EqualFold(strings.TrimSpace(param.In.Kind), "component") { + continue + } + nextNS := resolveComponentNamespaceFromRoute(strings.TrimSpace(param.In.Name), namespace) + if nextNS == "" { + c.diags = append(c.diags, &dqlshape.Diagnostic{ + Code: 
dqldiag.CodeCompRefInvalid, + Severity: dqlshape.SeverityWarning, + Message: "invalid nested component reference: " + strings.TrimSpace(param.In.Name), + Hint: "use ../component/ref or GET:/v1/api/... route reference", + Span: span, + }) + continue + } + c.collect(nextNS, span, true) + } + + c.visited[key] = componentVisitDone + return outputType, true +} + +func sourceRoots(sourcePath string) (platformRoot, routesRoot, dqlRoot string, ok bool) { + return sourceRootsWithLayout(sourcePath, defaultCompilePathLayout()) +} + +func sourceRootsWithLayout(sourcePath string, layout compilePathLayout) (platformRoot, routesRoot, dqlRoot string, ok bool) { + path := filepath.Clean(strings.TrimSpace(sourcePath)) + if path == "" { + return "", "", "", false + } + normalized := filepath.ToSlash(path) + marker := layout.dqlMarker + if marker == "" { + marker = defaultCompilePathLayout().dqlMarker + } + idx := strings.Index(normalized, marker) + if idx == -1 { + return "", "", "", false + } + platformRoot = path[:idx] + dqlRoot = filepath.Join(platformRoot, filepath.FromSlash(strings.Trim(marker, "/"))) + routesRoot = joinRelativePath(platformRoot, layout.routesRelative) + return platformRoot, routesRoot, dqlRoot, true +} + +func dqlToRouteNamespace(sourcePath string) (string, bool) { + return dqlToRouteNamespaceWithLayout(sourcePath, defaultCompilePathLayout()) +} + +func dqlToRouteNamespaceWithLayout(sourcePath string, layout compilePathLayout) (string, bool) { + path := filepath.Clean(strings.TrimSpace(sourcePath)) + if path == "" { + return "", false + } + normalized := filepath.ToSlash(path) + marker := layout.dqlMarker + if marker == "" { + marker = defaultCompilePathLayout().dqlMarker + } + idx := strings.Index(normalized, marker) + if idx == -1 { + return "", false + } + relative := strings.TrimPrefix(normalized[idx+len(marker):], "/") + if relative == "" { + return "", false + } + return strings.Trim(strings.TrimSuffix(relative, filepath.Ext(relative)), "/"), true +} + 
+func resolveComponentNamespace(ref, sourcePath, dqlRoot string) string { + ref = strings.TrimSpace(ref) + ref = strings.TrimPrefix(ref, "GET:") + ref = strings.TrimPrefix(ref, "POST:") + ref = strings.TrimPrefix(ref, "PUT:") + ref = strings.TrimPrefix(ref, "PATCH:") + ref = strings.TrimPrefix(ref, "DELETE:") + ref = strings.TrimPrefix(ref, "OPTIONS:") + ref = strings.TrimSpace(ref) + if strings.HasPrefix(ref, "/v1/api/") { + return strings.Trim(strings.TrimPrefix(ref, "/v1/api/"), "/") + } + if strings.HasPrefix(ref, "v1/api/") { + return strings.Trim(strings.TrimPrefix(ref, "v1/api/"), "/") + } + if strings.HasPrefix(ref, "/") { + return strings.Trim(ref, "/") + } + if dqlRoot == "" || strings.TrimSpace(sourcePath) == "" { + return "" + } + base := filepath.Dir(filepath.Clean(sourcePath)) + target := filepath.Clean(filepath.Join(base, ref)) + rel, err := filepath.Rel(dqlRoot, target) + if err != nil { + return "" + } + rel = filepath.ToSlash(rel) + rel = strings.TrimSuffix(rel, filepath.Ext(rel)) + return strings.Trim(rel, "/") +} + +func resolveComponentNamespaceWithNamespace(ref, sourcePath, dqlRoot, sourceNamespace string) string { + if namespace := resolveComponentNamespace(ref, sourcePath, dqlRoot); namespace != "" { + return namespace + } + return resolveComponentNamespaceFromRoute(ref, sourceNamespace) +} + +func resolveComponentNamespaceFromRoute(ref, sourceNamespace string) string { + ref = strings.TrimSpace(ref) + if ref == "" { + return "" + } + if namespace := resolveComponentNamespace(ref, "", ""); namespace != "" { + return namespace + } + normalizedBase := strings.Trim(strings.TrimSpace(sourceNamespace), "/") + if normalizedBase == "" { + return "" + } + baseDir := pathDir(normalizedBase) + target := filepath.ToSlash(filepath.Clean(filepath.Join(baseDir, ref))) + target = strings.TrimSuffix(target, filepath.Ext(target)) + return strings.Trim(target, "/") +} + +func pathDir(path string) string { + if path == "" { + return "" + } + parts := 
strings.Split(strings.Trim(path, "/"), "/") + if len(parts) <= 1 { + return "" + } + return strings.Join(parts[:len(parts)-1], "/") +} + +type routePayload struct { + Resource struct { + Types []struct { + Name string `yaml:"Name"` + Alias string `yaml:"Alias"` + DataType string `yaml:"DataType"` + Cardinality string `yaml:"Cardinality"` + Package string `yaml:"Package"` + ModulePath string `yaml:"ModulePath"` + } `yaml:"Types"` + Parameters []struct { + Name string `yaml:"Name"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` + Schema struct { + DataType string `yaml:"DataType"` + Name string `yaml:"Name"` + Package string `yaml:"Package"` + Cardinality string `yaml:"Cardinality"` + } `yaml:"Schema"` + } `yaml:"Parameters"` + } `yaml:"Resource"` + Routes []struct { + Handler struct { + OutputType string `yaml:"OutputType"` + } `yaml:"Handler"` + Output struct { + Cardinality string `yaml:"Cardinality"` + Type struct { + Name string `yaml:"Name"` + Package string `yaml:"Package"` + } `yaml:"Type"` + } `yaml:"Output"` + } `yaml:"Routes"` +} + +func loadRoutePayload(routesRoot, namespace string) (*routePayload, bool) { + candidates := routeYAMLCandidates(routesRoot, namespace) + for _, candidate := range candidates { + data, err := os.ReadFile(candidate) + if err != nil { + continue + } + payload := &routePayload{} + if err = yaml.Unmarshal(data, payload); err != nil { + continue + } + return payload, true + } + return nil, false +} + +func routeOutputType(payload *routePayload) string { + if payload == nil { + return "" + } + for _, route := range payload.Routes { + if outputType := strings.TrimSpace(route.Handler.OutputType); outputType != "" { + leaf := outputType + if idx := strings.LastIndex(leaf, "."); idx >= 0 && idx+1 < len(leaf) { + leaf = leaf[idx+1:] + } + leaf = strings.Trim(strings.TrimSpace(leaf), "*") + if leaf != "" { + return "*" + leaf + } + } + if name := strings.TrimSpace(route.Output.Type.Name); name != "" { 
+ name = strings.Trim(name, "*") + if name != "" { + return "*" + name + } + } + } + for _, param := range payload.Resource.Parameters { + if strings.EqualFold(strings.TrimSpace(param.In.Kind), "output") { + if dataType := strings.TrimSpace(param.Schema.DataType); dataType != "" { + return dataType + } + if name := strings.TrimSpace(param.Schema.Name); name != "" { + name = strings.Trim(name, "*") + if name != "" { + return "*" + name + } + } + } + } + for _, item := range payload.Resource.Types { + if strings.EqualFold(strings.TrimSpace(item.Name), "output") { + if dataType := strings.TrimSpace(item.DataType); dataType != "" { + return dataType + } + return "*Output" + } + } + return "" +} + +func componentRefSpan(raw, ref string) dqlshape.Span { + offset := 0 + ref = strings.TrimSpace(ref) + if ref != "" { + if idx := strings.Index(raw, ref); idx >= 0 { + offset = idx + } + } + return relationSpan(raw, offset) +} + +func routeYAMLCandidates(routesRoot, namespace string) []string { + namespace = strings.Trim(namespace, "/") + if namespace == "" { + return nil + } + leaf := filepath.Base(namespace) + return []string{ + filepath.Join(routesRoot, filepath.FromSlash(namespace)+".yaml"), + filepath.Join(routesRoot, filepath.FromSlash(namespace), leaf+".yaml"), + } +} diff --git a/repository/shape/compile/component_types_test.go b/repository/shape/compile/component_types_test.go new file mode 100644 index 00000000..0e93ec71 --- /dev/null +++ b/repository/shape/compile/component_types_test.go @@ -0,0 +1,155 @@ +package compile + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + "github.com/viant/datly/repository/shape/plan" +) + +func TestResolveComponentNamespace(t *testing.T) { + dqlRoot := "/repo/dql" + source := "/repo/dql/platform/tvaffiliatestation/tvaffiliatestation.dql" + 
assert.Equal(t, "platform/acl/auth", resolveComponentNamespace("../acl/auth", source, dqlRoot)) + assert.Equal(t, "platform/acl/auth", resolveComponentNamespace("GET:/v1/api/platform/acl/auth", source, dqlRoot)) + assert.Equal(t, "platform/acl/auth", resolveComponentNamespace("v1/api/platform/acl/auth", source, dqlRoot)) +} + +func TestDQLToRouteNamespace(t *testing.T) { + ns, ok := dqlToRouteNamespace("/repo/dql/platform/tvaffiliatestation/tvaffiliatestation.dql") + require.True(t, ok) + assert.Equal(t, "platform/tvaffiliatestation/tvaffiliatestation", ns) +} + +func TestSourceRoots_CustomLayout(t *testing.T) { + layout := compilePathLayout{ + dqlMarker: "/sqlsrc/", + routesRelative: "config/routes", + } + platformRoot, routesRoot, dqlRoot, ok := sourceRootsWithLayout("/repo/sqlsrc/platform/agency/agency.dql", layout) + require.True(t, ok) + assert.Equal(t, "/repo", filepath.ToSlash(platformRoot)) + assert.Equal(t, "/repo/config/routes", filepath.ToSlash(routesRoot)) + assert.Equal(t, "/repo/sqlsrc", filepath.ToSlash(dqlRoot)) + + ns, ok := dqlToRouteNamespaceWithLayout("/repo/sqlsrc/platform/agency/agency.dql", layout) + require.True(t, ok) + assert.Equal(t, "platform/agency/agency", ns) +} + +func TestAppendComponentTypes(t *testing.T) { + temp := t.TempDir() + dqlDir := filepath.Join(temp, "dql", "platform", "tvaffiliatestation") + routesDir := filepath.Join(temp, "repo", "dev", "Datly", "routes", "platform", "acl") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "auth"), 0o755)) + require.NoError(t, os.MkdirAll(routesDir, 0o755)) + sourcePath := filepath.Join(dqlDir, "tvaffiliatestation.dql") + require.NoError(t, os.WriteFile(sourcePath, []byte("SELECT 1"), 0o644)) + + authYAML := `Resource: + Types: + - Name: Input + DataType: "*Input" + Package: acl/auth + ModulePath: github.vianttech.com/viant/platform/pkg/platform/acl/auth + Parameters: + - In: + Kind: component + Name: 
GET:/v1/api/platform/acl/user +Routes: + - Handler: + OutputType: acl/auth.Output +` + userYAML := `Resource: + Types: + - Name: UserView + DataType: "struct{Id int;}" + Package: acl + ModulePath: github.vianttech.com/viant/platform/pkg/platform/acl +` + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "auth", "auth.yaml"), []byte(authYAML), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "user.yaml"), []byte(userYAML), 0o644)) + + result := &plan.Result{ + States: []*plan.State{ + {Name: "Auth", Kind: "component", In: "../acl/auth"}, + }, + } + appendComponentTypes(&shape.Source{Path: sourcePath, DQL: "#set($Auth = $component<../acl/auth>())"}, result) + require.Len(t, result.Types, 2) + names := map[string]bool{} + for _, item := range result.Types { + names[item.Name] = true + } + assert.True(t, names["Input"]) + assert.True(t, names["UserView"]) + assert.Equal(t, "*Output", result.States[0].DataType) +} + +func TestAppendComponentTypes_MissingComponentRoute(t *testing.T) { + temp := t.TempDir() + dqlDir := filepath.Join(temp, "dql", "platform", "sample") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + sourcePath := filepath.Join(dqlDir, "sample.dql") + dql := "#set($Auth = $component<../acl/missing>())\nSELECT 1" + require.NoError(t, os.WriteFile(sourcePath, []byte(dql), 0o644)) + result := &plan.Result{ + States: []*plan.State{{Name: "Auth", Kind: "component", In: "../acl/missing"}}, + } + diags := appendComponentTypes(&shape.Source{Path: sourcePath, DQL: dql}, result) + require.NotEmpty(t, diags) + assert.Equal(t, "DQL-COMP-ROUTE-MISSING", diags[0].Code) + assert.GreaterOrEqual(t, diags[0].Span.Start.Line, 1) + assert.GreaterOrEqual(t, diags[0].Span.Start.Char, 1) +} + +func TestAppendComponentTypes_TypeCollisionEmitsDiagnostic(t *testing.T) { + temp := t.TempDir() + dqlDir := filepath.Join(temp, "dql", "platform", "tvaffiliatestation") + routesDir := filepath.Join(temp, "repo", "dev", "Datly", "routes", "platform", "acl") + 
require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "auth"), 0o755)) + sourcePath := filepath.Join(dqlDir, "tvaffiliatestation.dql") + require.NoError(t, os.WriteFile(sourcePath, []byte("SELECT 1"), 0o644)) + + authYAML := `Resource: + Types: + - Name: Input + DataType: "*Input" + Package: acl/auth + ModulePath: github.vianttech.com/viant/platform/pkg/platform/acl/auth +` + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "auth", "auth.yaml"), []byte(authYAML), 0o644)) + + result := &plan.Result{ + States: []*plan.State{ + {Name: "Auth", Kind: "component", In: "../acl/auth"}, + }, + Types: []*plan.Type{ + { + Name: "Input", + DataType: "*Input", + Package: "campaign/patch", + ModulePath: "github.vianttech.com/viant/platform/pkg/platform/campaign/patch", + }, + }, + } + diags := appendComponentTypes(&shape.Source{Path: sourcePath, DQL: "#set($Auth = $component<../acl/auth>())"}, result) + require.NotEmpty(t, diags) + var found bool + for _, item := range diags { + if item != nil && item.Code == dqldiag.CodeCompTypeCollision { + found = true + break + } + } + assert.True(t, found) + require.Len(t, result.Types, 1) + assert.Equal(t, "campaign/patch", result.Types[0].Package) +} diff --git a/repository/shape/compile/dml/compiler.go b/repository/shape/compile/dml/compiler.go new file mode 100644 index 00000000..8b6838fd --- /dev/null +++ b/repository/shape/compile/dml/compiler.go @@ -0,0 +1,13 @@ +package dml + +import ( + "github.com/viant/datly/repository/shape/compile/pipeline" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" + "github.com/viant/datly/repository/shape/plan" +) + +// Compile builds an exec-oriented view and validates DML statements. 
+func Compile(sourceName, sqlText string, statements dqlstmt.Statements) (*plan.View, []*dqlshape.Diagnostic) { + return pipeline.BuildExec(sourceName, sqlText, statements) +} diff --git a/repository/shape/compile/dml/compiler_test.go b/repository/shape/compile/dml/compiler_test.go new file mode 100644 index 00000000..7c1d907f --- /dev/null +++ b/repository/shape/compile/dml/compiler_test.go @@ -0,0 +1,25 @@ +package dml + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" +) + +func TestCompile_Insert(t *testing.T) { + sqlText := "INSERT INTO ORDERS(id) VALUES (1)" + view, diags := Compile("orders_exec", sqlText, dqlstmt.New(sqlText)) + require.NotNil(t, view) + assert.Equal(t, "ORDERS", view.Table) + assert.Empty(t, diags) +} + +func TestCompile_ServiceMissingArg(t *testing.T) { + sqlText := "$sql.Insert($rec)" + _, diags := Compile("orders_exec", sqlText, dqlstmt.New(sqlText)) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeDMLServiceArg, diags[0].Code) +} diff --git a/repository/shape/compile/enrich.go b/repository/shape/compile/enrich.go new file mode 100644 index 00000000..08b80570 --- /dev/null +++ b/repository/shape/compile/enrich.go @@ -0,0 +1,756 @@ +package compile + +import ( + "encoding/json" + "os" + "path/filepath" + "regexp" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/compile/pipeline" + "github.com/viant/datly/repository/shape/plan" + "gopkg.in/yaml.v3" +) + +var ( + ruleHeaderExpr = regexp.MustCompile(`(?s)^\s*/\*\s*(\{.*?\})\s*\*/`) + embedExpr = regexp.MustCompile(`(?is)\$\{\s*embed:\s*([^}]+)\}`) + fromTableExpr = regexp.MustCompile(`(?is)\bfrom\s+([a-zA-Z_$][a-zA-Z0-9_$.{}/]*)`) + summaryJoinExpr = regexp.MustCompile(`(?is)\bjoin\s*\((.*?)\)\s*summary\s+on\s+1\s*=\s*1`) + joinEmbedExpr = 
regexp.MustCompile(`(?is)\bjoin\s*\(\s*\$\{\s*embed:\s*([^}]+)\}\s*\)\s*(?:as\s+)?([a-zA-Z_][a-zA-Z0-9_]*)`) + joinBodyExpr = regexp.MustCompile(`(?is)\bjoin\s*\((.*?)\)\s*(?:as\s+)?([a-zA-Z_][a-zA-Z0-9_]*)\s+on\b`) +) + +type ruleSettings struct { + Connector string `json:"Connector"` + Name string `json:"Name"` + Type string `json:"Type"` + Method string `json:"Method"` + URI string `json:"URI"` +} + +func applySourceParityEnrichment(result *plan.Result, source *shape.Source) { + applySourceParityEnrichmentWithLayout(result, source, defaultCompilePathLayout()) +} + +func applySourceParityEnrichmentWithLayout(result *plan.Result, source *shape.Source, layout compilePathLayout) { + if result == nil || len(result.Views) == 0 { + return + } + settings := extractRuleSettings(source) + legacyViews := loadLegacyRouteViewAttrsWithLayout(source, settings, layout) + baseDir := sourceSQLBaseDir(source) + module := sourceModuleWithLayout(source, layout) + sourceName := pipeline.SanitizeName(source.Name) + joinEmbedRefs := map[string]string{} + joinSubqueryBodies := map[string]string{} + if len(result.Views) > 0 && result.Views[0] != nil { + sqlForJoinExtract := result.Views[0].SQL + if source != nil && strings.TrimSpace(source.DQL) != "" { + sqlForJoinExtract = source.DQL + } + joinEmbedRefs = extractJoinEmbedRefs(sqlForJoinExtract) + joinSubqueryBodies = extractJoinSubqueryBodies(sqlForJoinExtract) + } + for idx, item := range result.Views { + if item == nil { + continue + } + if legacy, ok := lookupLegacyRouteViewAttr(legacyViews, item.Name); ok { + if legacy.Mode != "" { + item.Mode = legacy.Mode + } + if legacy.Module != "" { + item.Module = legacy.Module + } + if legacy.AllowNulls != nil { + value := *legacy.AllowNulls + item.AllowNulls = &value + } + if legacy.SelectorNamespace != "" { + item.SelectorNamespace = legacy.SelectorNamespace + } + if legacy.SelectorNoLimit != nil { + value := *legacy.SelectorNoLimit + item.SelectorNoLimit = &value + } + if legacy.SchemaType 
!= "" { + item.SchemaType = legacy.SchemaType + } + if legacy.Cardinality != "" { + item.Cardinality = legacy.Cardinality + } + if legacy.HasSummary != nil && *legacy.HasSummary && strings.TrimSpace(item.Summary) == "" { + item.Summary = "legacy-summary" + } + } + if item.SQLURI == "" && baseDir != "" { + item.SQLURI = baseDir + "/" + item.Name + ".sql" + } + if item.Module == "" { + item.Module = module + } + if item.SelectorNamespace == "" { + item.SelectorNamespace = defaultSelectorNamespace(item.Name) + } + if item.SchemaType == "" { + item.SchemaType = defaultSchemaType(item.Name, settings, idx == 0) + } + if shouldInferTable(item) { + candidateSQL := item.SQL + if strings.TrimSpace(candidateSQL) == "" { + candidateSQL = item.Table + } + if table := inferTableFromSQL(candidateSQL, source); table != "" { + item.Table = table + } + } + if strings.HasPrefix(strings.TrimSpace(item.Table), "(") || normalizedTemplatePlaceholderTable(strings.TrimSpace(item.Table)) { + if ref, ok := joinEmbedRefs[item.Name]; ok { + if table := inferTableFromEmbedRef(source, ref); table != "" { + item.Table = table + } + } + if body, ok := joinSubqueryBodies[item.Name]; ok { + if table := inferTableFromSQL(body, source); table != "" { + item.Table = table + } + } + if table := inferTableFromSiblingSQL(item.Name, source); table != "" { + item.Table = table + } + } + if item.Connector == "" && settings.Connector != "" { + item.Connector = settings.Connector + } + if item.Connector == "" && source != nil && strings.TrimSpace(source.Connector) != "" { + item.Connector = strings.TrimSpace(source.Connector) + } + if item.Connector == "" { + item.Connector = inferConnector(item, source) + } + if item.Summary == "" { + item.Summary = extractSummarySQL(item.SQL) + if item.Summary == "" && source != nil { + item.Summary = extractSummarySQL(source.DQL) + } + } + } + if source != nil && strings.TrimSpace(source.Path) != "" { + normalizeRootViewName(result, sourceName, settings) + } +} + +type 
legacyRouteViewAttr struct { + Name string + Mode string + Module string + AllowNulls *bool + SelectorNamespace string + SelectorNoLimit *bool + SchemaType string + Cardinality string + HasSummary *bool +} + +func loadLegacyRouteViewAttrs(source *shape.Source, settings *ruleSettings) []legacyRouteViewAttr { + return loadLegacyRouteViewAttrsWithLayout(source, settings, defaultCompilePathLayout()) +} + +func loadLegacyRouteViewAttrsWithLayout(source *shape.Source, settings *ruleSettings, layout compilePathLayout) []legacyRouteViewAttr { + if source == nil || strings.TrimSpace(source.Path) == "" { + return nil + } + platformRoot, relativeDir, stem, ok := platformPathParts(source.Path, layout) + if !ok { + return nil + } + typeExpr := "" + if settings != nil { + typeExpr = strings.TrimSpace(settings.Type) + } + typeExpr = strings.Trim(typeExpr, `"'`) + typeExpr = strings.TrimSuffix(typeExpr, ".Handler") + typeStem := "" + if typeExpr != "" { + typeStem = filepath.Base(filepath.FromSlash(typeExpr)) + } + routesRoot := joinRelativePath(platformRoot, layout.routesRelative) + routesBase := filepath.Join(routesRoot, filepath.FromSlash(relativeDir)) + candidates := legacyRouteYAMLCandidates(routesBase, stem, typeStem) + for _, candidate := range candidates { + if attrs := parseLegacyRouteViewAttrs(candidate); len(attrs) > 0 { + return attrs + } + } + return nil +} + +func parseLegacyRouteViewAttrs(path string) []legacyRouteViewAttr { + data, err := os.ReadFile(path) + if err != nil { + return nil + } + var payload struct { + Resource struct { + Views []struct { + Name string `yaml:"Name"` + Mode string `yaml:"Mode"` + Module string `yaml:"Module"` + AllowNulls *bool `yaml:"AllowNulls"` + Selector struct { + Namespace string `yaml:"Namespace"` + NoLimit *bool `yaml:"NoLimit"` + } `yaml:"Selector"` + Template struct { + Summary *struct{} `yaml:"Summary"` + } `yaml:"Template"` + Schema struct { + Cardinality string `yaml:"Cardinality"` + DataType string `yaml:"DataType"` + Name 
string `yaml:"Name"` + } `yaml:"Schema"` + } `yaml:"Views"` + } `yaml:"Resource"` + } + if err = yaml.Unmarshal(data, &payload); err != nil { + return nil + } + result := make([]legacyRouteViewAttr, 0, len(payload.Resource.Views)) + for _, item := range payload.Resource.Views { + cardinality := strings.TrimSpace(item.Schema.Cardinality) + if cardinality != "" { + cardinality = strings.ToLower(cardinality) + } + result = append(result, legacyRouteViewAttr{ + Name: strings.TrimSpace(item.Name), + Mode: strings.TrimSpace(item.Mode), + Module: strings.TrimSpace(item.Module), + AllowNulls: item.AllowNulls, + SelectorNamespace: strings.TrimSpace(item.Selector.Namespace), + SelectorNoLimit: item.Selector.NoLimit, + SchemaType: firstNonEmptyString(strings.TrimSpace(item.Schema.DataType), strings.TrimSpace(item.Schema.Name)), + Cardinality: cardinality, + HasSummary: func() *bool { + if item.Template.Summary == nil { + return nil + } + value := true + return &value + }(), + }) + } + return result +} + +func lookupLegacyRouteViewAttr(items []legacyRouteViewAttr, name string) (legacyRouteViewAttr, bool) { + name = strings.TrimSpace(name) + if name == "" { + return legacyRouteViewAttr{}, false + } + for _, item := range items { + if strings.EqualFold(strings.TrimSpace(item.Name), name) { + return item, true + } + } + return legacyRouteViewAttr{}, false +} + +func firstNonEmptyString(values ...string) string { + for _, value := range values { + value = strings.TrimSpace(value) + if value != "" { + return value + } + } + return "" +} + +func extractSummarySQL(sqlText string) string { + sqlText = strings.TrimSpace(sqlText) + if sqlText == "" || !strings.Contains(sqlText, "$View.") { + return "" + } + matches := summaryJoinExpr.FindStringSubmatch(sqlText) + if len(matches) < 2 { + return "" + } + return strings.TrimSpace(matches[1]) +} + +func extractRuleSettings(source *shape.Source) *ruleSettings { + if source == nil || strings.TrimSpace(source.DQL) == "" { + return 
&ruleSettings{} + } + matches := ruleHeaderExpr.FindStringSubmatch(source.DQL) + if len(matches) < 2 { + return &ruleSettings{} + } + rawJSON := strings.TrimSpace(matches[1]) + ret := &ruleSettings{} + _ = json.Unmarshal([]byte(rawJSON), ret) + return ret +} + +func sourceSQLBaseDir(source *shape.Source) string { + if source == nil { + return "" + } + path := strings.TrimSpace(source.Path) + if path == "" { + return "" + } + base := strings.TrimSpace(filepath.Base(path)) + if base == "" { + return "" + } + stem := strings.TrimSpace(strings.TrimSuffix(base, filepath.Ext(base))) + if stem == "" || stem == "." || stem == string(filepath.Separator) { + return "" + } + return stem +} + +func sourceModule(source *shape.Source) string { + return sourceModuleWithLayout(source, defaultCompilePathLayout()) +} + +func sourceModuleWithLayout(source *shape.Source, layout compilePathLayout) string { + if source == nil || strings.TrimSpace(source.Path) == "" { + return "" + } + normalized := filepath.ToSlash(source.Path) + marker := layout.dqlMarker + if marker == "" { + marker = defaultCompilePathLayout().dqlMarker + } + idx := strings.Index(normalized, marker) + if idx == -1 { + return "" + } + relative := strings.TrimPrefix(normalized[idx+len(marker):], "/") + dir := strings.TrimSpace(filepath.ToSlash(filepath.Dir(relative))) + if dir == "." 
|| dir == "/" { + return "" + } + return dir +} + +func defaultSelectorNamespace(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + var b strings.Builder + for i := 0; i < len(name); i++ { + ch := name[i] + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') { + b.WriteByte(byte(strings.ToLower(string(ch))[0])) + } + } + value := b.String() + switch { + case len(value) >= 2: + return value[:2] + case len(value) == 1: + return value + default: + return "" + } +} + +func defaultSchemaType(name string, settings *ruleSettings, root bool) string { + if root && settings != nil && strings.TrimSpace(settings.Name) != "" { + return "*" + strings.TrimSpace(settings.Name) + "View" + } + name = strings.TrimSpace(name) + if name == "" { + return "" + } + return "*" + toExportedTypeName(name) + "View" +} + +func toExportedTypeName(name string) string { + name = strings.TrimSpace(name) + if name == "" { + return "" + } + parts := strings.FieldsFunc(name, func(r rune) bool { + return r == '_' || r == '-' || r == ' ' || r == '.' 
+ }) + if len(parts) == 0 { + return "" + } + var b strings.Builder + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + b.WriteString(strings.ToUpper(part[:1])) + if len(part) > 1 { + b.WriteString(part[1:]) + } + } + return b.String() +} + +func shouldInferTable(item *plan.View) bool { + if item == nil { + return false + } + name := strings.TrimSpace(item.Name) + table := strings.TrimSpace(item.Table) + if table == "" { + return true + } + if strings.HasPrefix(table, "(") { + return true + } + if normalizedTemplatePlaceholderTable(table) { + return true + } + return strings.EqualFold(name, table) +} + +func normalizedTemplatePlaceholderTable(table string) bool { + if table == "" { + return false + } + parts := strings.Split(table, ".") + if len(parts) < 3 { + return false + } + for i := 0; i < len(parts)-1; i++ { + part := strings.TrimSpace(parts[i]) + if part == "" { + return false + } + for _, ch := range part { + if ch < '0' || ch > '9' { + return false + } + } + } + return true +} + +func inferTableFromSQL(sqlText string, source *shape.Source) string { + sqlText = strings.TrimSpace(sqlText) + if sqlText == "" { + return "" + } + if expr := topLevelFromExpr(sqlText); expr != "" { + if table := tableFromFromExpr(expr, source); table != "" { + return table + } + } + if table := pipeline.InferTableFromSQL(sqlText); table != "" { + if !strings.EqualFold(table, "DQLView") { + return table + } + } + cleaned := embedExpr.ReplaceAllString(sqlText, " ") + match := fromTableExpr.FindStringSubmatch(cleaned) + if len(match) >= 2 { + return strings.Trim(match[1], "`\"") + } + if table := inferFromEmbeddedSQL(sqlText, source); table != "" { + return table + } + return "" +} + +func inferFromEmbeddedSQL(sqlText string, source *shape.Source) string { + matches := embedExpr.FindStringSubmatch(sqlText) + if len(matches) < 2 { + return "" + } + ref := strings.TrimSpace(matches[1]) + ref = strings.Trim(ref, `"'`) + if ref == "" { + 
return "" + } + resolved := resolveEmbedPath(source, ref) + if resolved == "" { + return "" + } + embedded, err := os.ReadFile(resolved) + if err != nil { + return "" + } + queryNode, _, err := pipeline.ParseSelectWithDiagnostic(string(embedded)) + if err != nil || queryNode == nil { + fallback := fromTableExpr.FindStringSubmatch(string(embedded)) + if len(fallback) < 2 { + return "" + } + return strings.Trim(fallback[1], "`\"") + } + _, table, err := pipeline.InferRoot(queryNode, "") + if err != nil || strings.TrimSpace(table) == "" { + return "" + } + if strings.EqualFold(strings.TrimSpace(table), "DQLView") { + return "" + } + return strings.Trim(table, "`\"") +} + +func resolveEmbedPath(source *shape.Source, ref string) string { + if filepath.IsAbs(ref) { + return ref + } + if source == nil || strings.TrimSpace(source.Path) == "" { + return "" + } + base := source.Path + if fi, err := os.Stat(base); err == nil && fi.IsDir() { + return filepath.Clean(filepath.Join(base, ref)) + } + return filepath.Clean(filepath.Join(filepath.Dir(base), ref)) +} + +func inferTableFromSiblingSQL(viewName string, source *shape.Source) string { + viewName = strings.TrimSpace(viewName) + if viewName == "" || source == nil || strings.TrimSpace(source.Path) == "" { + return "" + } + sibling := filepath.Join(filepath.Dir(source.Path), viewName+".sql") + data, err := os.ReadFile(sibling) + if err != nil { + sibling = filepath.Join(filepath.Dir(source.Path), strings.ToLower(viewName)+".sql") + data, err = os.ReadFile(sibling) + } + if err != nil { + return "" + } + return inferTableFromSQL(string(data), source) +} + +func inferTableFromEmbedRef(source *shape.Source, ref string) string { + ref = strings.Trim(strings.TrimSpace(ref), `"'`) + if ref == "" { + return "" + } + resolved := resolveEmbedPath(source, ref) + if resolved == "" { + return "" + } + data, err := os.ReadFile(resolved) + if err != nil { + return "" + } + return pipeline.InferTableFromSQL(string(data)) +} + +func 
topLevelFromExpr(sqlText string) string { + lower := strings.ToLower(sqlText) + depth := 0 + inSingle := false + inDouble := false + inBacktick := false + for i := 0; i < len(sqlText); i++ { + ch := sqlText[i] + switch ch { + case '\'': + if !inDouble && !inBacktick { + inSingle = !inSingle + } + case '"': + if !inSingle && !inBacktick { + inDouble = !inDouble + } + case '`': + if !inSingle && !inDouble { + inBacktick = !inBacktick + } + case '(': + if !inSingle && !inDouble && !inBacktick { + depth++ + } + case ')': + if !inSingle && !inDouble && !inBacktick && depth > 0 { + depth-- + } + } + if inSingle || inDouble || inBacktick || depth != 0 { + continue + } + if i+6 > len(sqlText) { + break + } + if lower[i:i+4] != "from" { + continue + } + if i > 0 { + prev := lower[i-1] + if (prev >= 'a' && prev <= 'z') || (prev >= '0' && prev <= '9') || prev == '_' { + continue + } + } + j := i + 4 + for j < len(sqlText) && (sqlText[j] == ' ' || sqlText[j] == '\n' || sqlText[j] == '\t' || sqlText[j] == '\r') { + j++ + } + if j >= len(sqlText) { + return "" + } + if sqlText[j] == '(' { + start := j + d := 0 + for ; j < len(sqlText); j++ { + if sqlText[j] == '(' { + d++ + } else if sqlText[j] == ')' { + d-- + if d == 0 { + j++ + break + } + } + } + for j < len(sqlText) && (sqlText[j] == ' ' || sqlText[j] == '\n' || sqlText[j] == '\t' || sqlText[j] == '\r') { + j++ + } + for j < len(sqlText) { + c := sqlText[j] + if !(c == '_' || c == '.' || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9')) { + break + } + j++ + } + return strings.TrimSpace(sqlText[start:j]) + } + start := j + for j < len(sqlText) { + c := sqlText[j] + if !(c == '_' || c == '.' 
|| c == '/' || c == '{' || c == '}' || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '$') { + break + } + j++ + } + return strings.TrimSpace(sqlText[start:j]) + } + return "" +} + +func tableFromFromExpr(fromExpr string, source *shape.Source) string { + fromExpr = strings.TrimSpace(fromExpr) + if fromExpr == "" { + return "" + } + if strings.HasPrefix(fromExpr, "(") { + if table := inferFromEmbeddedSQL(fromExpr, source); table != "" { + return table + } + inner := fromExpr + if idx := strings.LastIndex(inner, ")"); idx > 0 { + inner = strings.TrimSpace(inner[1:idx]) + } + return inferTableFromSQL(inner, source) + } + return strings.Trim(fromExpr, "`\"") +} + +func inferConnector(item *plan.View, source *shape.Source) string { + if item == nil { + return "" + } + path := "" + if source != nil { + path = strings.ToLower(strings.ReplaceAll(source.Path, "\\", "/")) + } + table := strings.ToUpper(strings.TrimSpace(item.Table)) + switch { + case strings.Contains(path, "/dql/system/"): + return "system" + case strings.HasPrefix(table, "CI_") || strings.Contains(table, ".CI_"): + return "ci_ads" + case strings.Contains(path, "/dql/ui/"): + return "sitemgmt" + case strings.Contains(table, "SITE"): + return "sitemgmt" + default: + return "" + } +} + +func normalizeRootViewName(result *plan.Result, sourceName string, settings *ruleSettings) { + if result == nil || len(result.Views) == 0 { + return + } + root := result.Views[0] + if root == nil { + return + } + desired := sourceName + if desired == "" { + return + } + _ = settings + current := strings.TrimSpace(root.Name) + if current == "" { + root.Name = desired + root.Path = desired + root.Holder = desired + return + } + if strings.EqualFold(current, desired) { + return + } + suspicious := map[string]bool{ + "and": true, "or": true, "status": true, "value": true, "watching": true, + } + if !suspicious[strings.ToLower(current)] { + return + } + if result.ViewsByName != nil { + 
delete(result.ViewsByName, root.Name) + } else { + result.ViewsByName = map[string]*plan.View{} + } + root.Name = desired + root.Path = desired + root.Holder = desired + result.ViewsByName[root.Name] = root +} + +func extractJoinEmbedRefs(sqlText string) map[string]string { + result := map[string]string{} + if strings.TrimSpace(sqlText) == "" { + return result + } + for _, m := range joinEmbedExpr.FindAllStringSubmatch(sqlText, -1) { + if len(m) < 3 { + continue + } + ref := strings.TrimSpace(m[1]) + alias := strings.TrimSpace(m[2]) + if ref == "" || alias == "" { + continue + } + result[alias] = ref + } + return result +} + +func extractJoinSubqueryBodies(sqlText string) map[string]string { + result := map[string]string{} + if strings.TrimSpace(sqlText) == "" { + return result + } + for _, m := range joinBodyExpr.FindAllStringSubmatch(sqlText, -1) { + if len(m) < 3 { + continue + } + body := strings.TrimSpace(m[1]) + alias := strings.TrimSpace(m[2]) + if body == "" || alias == "" { + continue + } + result[alias] = body + } + return result +} diff --git a/repository/shape/compile/enrich_test.go b/repository/shape/compile/enrich_test.go new file mode 100644 index 00000000..1c532d39 --- /dev/null +++ b/repository/shape/compile/enrich_test.go @@ -0,0 +1,172 @@ +package compile + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" +) + +func TestApplySourceParityEnrichment_RuleConnectorAndSQLURI(t *testing.T) { + source := &shape.Source{ + Path: "/repo/dql/platform/timezone/timezone.dql", + DQL: `/* {"Connector":"ci_ads"} */ SELECT * FROM CI_TIME_ZONE t`, + } + result := &plan.Result{ + Views: []*plan.View{ + {Name: "timezone", Table: "timezone", SQL: "SELECT * FROM CI_TIME_ZONE t"}, + }, + } + + applySourceParityEnrichment(result, source) + + require.Equal(t, "ci_ads", result.Views[0].Connector) + require.Equal(t, "timezone/timezone.sql", 
result.Views[0].SQLURI) + require.Equal(t, "CI_TIME_ZONE", result.Views[0].Table) +} + +func TestApplySourceParityEnrichment_InferTableFromSubquery(t *testing.T) { + source := &shape.Source{ + Path: "/repo/dql/platform/advertiser/advertiser.dql", + DQL: `SELECT x.* FROM (SELECT a.* FROM CI_ADVERTISER a) x`, + } + result := &plan.Result{ + Views: []*plan.View{ + {Name: "advertiser", Table: "advertiser", SQL: `SELECT x.* FROM (SELECT a.* FROM CI_ADVERTISER a) x`}, + }, + } + + applySourceParityEnrichment(result, source) + + require.Equal(t, "CI_ADVERTISER", result.Views[0].Table) + require.Equal(t, "advertiser/advertiser.sql", result.Views[0].SQLURI) +} + +func TestApplySourceParityEnrichment_InferTableFromEmbed(t *testing.T) { + tempDir := t.TempDir() + dqlDir := filepath.Join(tempDir, "dql", "platform", "timezone") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + embedded := filepath.Join(dqlDir, "timezone.sql") + require.NoError(t, os.WriteFile(embedded, []byte(`SELECT tz.ID FROM CI_TIME_ZONE tz`), 0o644)) + source := &shape.Source{ + Path: filepath.Join(dqlDir, "timezone.dql"), + DQL: `SELECT timezone.* FROM (${embed: timezone.sql}) timezone`, + } + result := &plan.Result{ + Views: []*plan.View{ + {Name: "timezone", Table: "timezone", SQL: `SELECT timezone.* FROM (${embed: timezone.sql}) timezone`}, + }, + } + + applySourceParityEnrichment(result, source) + + require.Equal(t, "CI_TIME_ZONE", result.Views[0].Table) + require.Equal(t, "timezone/timezone.sql", result.Views[0].SQLURI) +} + +func TestTopLevelFromExpr_IgnoresNestedFrom(t *testing.T) { + sqlText := `SELECT a.*, EXISTS(SELECT 1 FROM CI_ENTITY_WATCHLIST w WHERE w.ENTITY_ID = a.ID) AS watching FROM (SELECT x.* FROM CI_ADVERTISER x) a` + require.Equal(t, "(SELECT x.* FROM CI_ADVERTISER x) a", topLevelFromExpr(sqlText)) +} + +func TestInferConnector(t *testing.T) { + require.Equal(t, "system", inferConnector(&plan.View{Table: "session"}, &shape.Source{Path: "/repo/dql/system/session/session.dql"})) + 
require.Equal(t, "ci_ads", inferConnector(&plan.View{Table: "CI_ADVERTISER"}, &shape.Source{Path: "/repo/dql/platform/advertiser/advertiser.dql"})) + require.Equal(t, "sitemgmt", inferConnector(&plan.View{Table: "SITE_MAP"}, &shape.Source{Path: "/repo/dql/ui/agency/detail/campaign.dql"})) +} + +func TestExtractSummarySQL(t *testing.T) { + sqlText := `SELECT b.* FROM CI_BROWSER b +JOIN ( + SELECT COUNT(1) AS CNT + FROM ($View.browser.SQL) t +) summary ON 1=1` + require.Contains(t, extractSummarySQL(sqlText), "COUNT(1)") +} + +func TestInferTableFromSQL_PreservesTemplateQualifiedTable(t *testing.T) { + sqlText := `SELECT SITE_ID FROM ${sitemgmt_project}.${sitemgmt_dataset}.SITE_LIST_MATCH slm` + require.Equal(t, "${sitemgmt_project}.${sitemgmt_dataset}.SITE_LIST_MATCH", inferTableFromSQL(sqlText, nil)) +} + +func TestShouldInferTable_NormalizedTemplatePlaceholderTable(t *testing.T) { + require.True(t, shouldInferTable(&plan.View{Name: "match", Table: "1.1.SITE_LIST_MATCH"})) + require.False(t, shouldInferTable(&plan.View{Name: "match", Table: "SITE_LIST_MATCH"})) +} + +func TestInferTableFromSQL_PathLikeTable(t *testing.T) { + sqlText := `SELECT user_id FROM session/attributes WHERE user_id = 1` + require.Equal(t, "session/attributes", inferTableFromSQL(sqlText, nil)) +} + +func TestApplySourceParityEnrichment_InferTableFromSiblingSQLOnPlaceholderTable(t *testing.T) { + tempDir := t.TempDir() + dqlDir := filepath.Join(tempDir, "dql", "platform", "sitelist") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(dqlDir, "match.sql"), []byte(`SELECT SITE_ID FROM ${sitemgmt_project}.${sitemgmt_dataset}.SITE_LIST_MATCH slm`), 0o644)) + source := &shape.Source{ + Path: filepath.Join(dqlDir, "match.dql"), + DQL: `SELECT 1`, + } + result := &plan.Result{ + Views: []*plan.View{ + {Name: "match", Table: "1.1.SITE_LIST_MATCH"}, + }, + } + + applySourceParityEnrichment(result, source) + + require.Equal(t, 
"${sitemgmt_project}.${sitemgmt_dataset}.SITE_LIST_MATCH", result.Views[0].Table) +} + +func TestExtractJoinSubqueryBodies(t *testing.T) { + sqlText := `SELECT sl.* FROM SITE_LIST sl +JOIN ( + SELECT SITE_ID, SITE_LIST_ID FROM ${sitemgmt_project}.${sitemgmt_dataset}.SITE_LIST_MATCH +) match ON match.SITE_LIST_ID = sl.ID +JOIN ( + ${embed: match_rules.sql} + ${predicate.Builder().CombineOr($predicate.FilterGroup(1, "AND")).Build("WHERE")} +) matchRules ON matchRules.SITE_LIST_ID = sl.ID` + bodies := extractJoinSubqueryBodies(sqlText) + require.Contains(t, bodies, "match") + require.Contains(t, bodies["match"], "SITE_LIST_MATCH") + require.Contains(t, bodies, "matchRules") + require.Contains(t, bodies["matchRules"], "${embed: match_rules.sql}") +} + +func TestApplySourceParityEnrichment_Metadata(t *testing.T) { + source := &shape.Source{ + Path: "/repo/dql/platform/tvaffiliatestation/tvaffiliatestation.dql", + DQL: `/* {"Name":"TvAffiliateStation"} */ +SELECT use_connector(tvAffiliateStation, 'ci_ads'), + allow_nulls(tvAffiliateStation), + set_limit(tvAffiliateStation, 0) +FROM CI_TV_AFFILIATE_STATION tvAffiliateStation +JOIN ( + SELECT COUNT(1) AS CNT FROM ($View.tvAffiliateStation.SQL) t +) summary ON 1=1`, + } + result := &plan.Result{ + Views: []*plan.View{ + {Name: "tvAffiliateStation", Table: "CI_TV_AFFILIATE_STATION", SQL: "SELECT * FROM CI_TV_AFFILIATE_STATION tvAffiliateStation"}, + }, + } + hints := extractViewHints(source.DQL) + applyViewHints(result, hints) + applySourceParityEnrichment(result, source) + + require.Len(t, result.Views, 1) + actual := result.Views[0] + require.NotNil(t, actual.AllowNulls) + require.True(t, *actual.AllowNulls) + require.NotNil(t, actual.SelectorNoLimit) + require.True(t, *actual.SelectorNoLimit) + require.Equal(t, "tv", actual.SelectorNamespace) + require.Equal(t, "platform/tvaffiliatestation", actual.Module) + require.Equal(t, "*TvAffiliateStationView", actual.SchemaType) + require.NotEmpty(t, actual.Summary) +} diff --git 
a/repository/shape/compile/hints.go b/repository/shape/compile/hints.go new file mode 100644 index 00000000..93a6bfb3 --- /dev/null +++ b/repository/shape/compile/hints.go @@ -0,0 +1,185 @@ +package compile + +import ( + "reflect" + "regexp" + "strconv" + "strings" + + "github.com/viant/datly/repository/shape/plan" +) + +var ( + useConnectorExpr = regexp.MustCompile(`(?i)use_connector\s*\(\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*,\s*(?:'([a-zA-Z_][a-zA-Z0-9_]*)'|"([a-zA-Z_][a-zA-Z0-9_]*)"|([a-zA-Z_][a-zA-Z0-9_]*))\s*\)`) + allowNullsExpr = regexp.MustCompile(`(?i)allow_nulls\s*\(\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*\)`) + setLimitExpr = regexp.MustCompile(`(?i)set_limit\s*\(\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*,\s*(-?[0-9]+)\s*\)`) +) + +type viewHint struct { + Connector string + AllowNulls *bool + NoLimit *bool +} + +func extractViewHints(dql string) map[string]viewHint { + result := map[string]viewHint{} + for _, match := range useConnectorExpr.FindAllStringSubmatch(dql, -1) { + if len(match) < 5 { + continue + } + alias := strings.TrimSpace(match[1]) + connector := strings.TrimSpace(match[2]) + if connector == "" { + connector = strings.TrimSpace(match[3]) + } + if connector == "" { + connector = strings.TrimSpace(match[4]) + } + if alias == "" || connector == "" { + continue + } + hint := result[alias] + hint.Connector = connector + result[alias] = hint + } + for _, match := range allowNullsExpr.FindAllStringSubmatch(dql, -1) { + if len(match) < 2 { + continue + } + alias := strings.TrimSpace(match[1]) + if alias == "" { + continue + } + hint := result[alias] + value := true + hint.AllowNulls = &value + result[alias] = hint + } + for _, match := range setLimitExpr.FindAllStringSubmatch(dql, -1) { + if len(match) < 3 { + continue + } + alias := strings.TrimSpace(match[1]) + limitRaw := strings.TrimSpace(match[2]) + if alias == "" || limitRaw == "" { + continue + } + limit, err := strconv.Atoi(limitRaw) + if err != nil { + continue + } + hint := result[alias] + noLimit := limit == 0 + 
hint.NoLimit = &noLimit + result[alias] = hint + } + return result +} + +func appendRelationViews(result *plan.Result, root *plan.View, hints map[string]viewHint) { + if result == nil || root == nil || len(root.Relations) == 0 { + return + } + for _, relation := range root.Relations { + if relation == nil { + continue + } + name := strings.TrimSpace(relation.Ref) + if name == "" { + name = strings.TrimSpace(relation.Name) + } + if name == "" { + continue + } + if len(relation.On) == 0 { + continue + } + if _, exists := result.ViewsByName[name]; exists { + continue + } + table := strings.TrimSpace(relation.Table) + if table == "" { + table = name + } + table = normalizeRelationTable(table) + view := &plan.View{ + Path: name, + Holder: name, + Name: name, + Table: table, + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + } + result.Views = append(result.Views, view) + result.ViewsByName[name] = view + } +} + +func applyViewHints(result *plan.Result, hints map[string]viewHint) { + if result == nil || len(result.Views) == 0 { + return + } + if len(hints) == 0 { + return + } + for _, item := range result.Views { + if item == nil { + continue + } + for _, key := range []string{item.Name, item.Holder} { + key = strings.TrimSpace(key) + if key == "" { + continue + } + hint, ok := hints[key] + if !ok { + continue + } + if item.Connector == "" && hint.Connector != "" { + item.Connector = hint.Connector + } + if item.AllowNulls == nil && hint.AllowNulls != nil { + value := *hint.AllowNulls + item.AllowNulls = &value + } + if item.SelectorNoLimit == nil && hint.NoLimit != nil { + value := *hint.NoLimit + item.SelectorNoLimit = &value + } + } + } +} + +func normalizeRelationTable(table string) string { + table = strings.TrimSpace(table) + if table == "" { + return table + } + lower := strings.ToLower(table) + fromIdx := strings.Index(lower, " from ") + if fromIdx == -1 { + return table + } + 
tail := strings.TrimSpace(table[fromIdx+6:]) + if tail == "" { + return table + } + stop := len(tail) + for i := 0; i < len(tail); i++ { + switch tail[i] { + case ' ', '\t', '\n', '\r', ')': + stop = i + i = len(tail) + } + } + if stop == 0 { + return table + } + normalized := strings.TrimSpace(tail[:stop]) + normalized = strings.Trim(normalized, "`\"") + if normalized == "" { + return table + } + return normalized +} diff --git a/repository/shape/compile/hints_test.go b/repository/shape/compile/hints_test.go new file mode 100644 index 00000000..768f8286 --- /dev/null +++ b/repository/shape/compile/hints_test.go @@ -0,0 +1,43 @@ +package compile + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape/plan" +) + +func TestExtractViewHints_WithQuotedConnector(t *testing.T) { + dql := "SELECT use_connector(match, 'bq_sitemgmt_match'), use_connector(site, \"ci_ads\"), allow_nulls(match), set_limit(match, 0)" + hints := extractViewHints(dql) + require.Len(t, hints, 2) + assert.Equal(t, "bq_sitemgmt_match", hints["match"].Connector) + assert.Equal(t, "ci_ads", hints["site"].Connector) + require.NotNil(t, hints["match"].AllowNulls) + assert.True(t, *hints["match"].AllowNulls) + require.NotNil(t, hints["match"].NoLimit) + assert.True(t, *hints["match"].NoLimit) +} + +func TestApplyViewHints_Metadata(t *testing.T) { + trueValue := true + result := &plan.Result{ + Views: []*plan.View{ + {Name: "match", Table: "MATCH"}, + }, + } + applyViewHints(result, map[string]viewHint{ + "match": { + Connector: "ci_ads", + AllowNulls: &trueValue, + NoLimit: &trueValue, + }, + }) + require.Len(t, result.Views, 1) + assert.Equal(t, "ci_ads", result.Views[0].Connector) + require.NotNil(t, result.Views[0].AllowNulls) + assert.True(t, *result.Views[0].AllowNulls) + require.NotNil(t, result.Views[0].SelectorNoLimit) + assert.True(t, *result.Views[0].SelectorNoLimit) +} diff --git 
a/repository/shape/compile/legacy_adapter.go b/repository/shape/compile/legacy_adapter.go new file mode 100644 index 00000000..d26be3c8 --- /dev/null +++ b/repository/shape/compile/legacy_adapter.go @@ -0,0 +1,655 @@ +package compile + +import ( + "os" + "path/filepath" + "reflect" + "sort" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" + "gopkg.in/yaml.v3" +) + +func resolveGeneratedCompanionDQL(source *shape.Source) string { + if source == nil || strings.TrimSpace(source.Path) == "" { + return "" + } + settings := extractRuleSettings(source) + typeExpr := strings.TrimSpace(settings.Type) + if typeExpr == "" { + return "" + } + typeExpr = strings.TrimSuffix(typeExpr, ".Handler") + typeExpr = strings.Trim(typeExpr, `"'`) + if typeExpr == "" { + return "" + } + dir := filepath.Dir(source.Path) + baseTypePath := filepath.FromSlash(typeExpr) + stem := filepath.Base(baseTypePath) + candidates := []string{ + filepath.Join(dir, "gen", baseTypePath+".dql"), + filepath.Join(dir, "gen", baseTypePath+".sql"), + filepath.Join(dir, "gen", stem+".dql"), + filepath.Join(dir, "gen", stem+".sql"), + } + for _, candidate := range candidates { + data, err := os.ReadFile(candidate) + if err != nil { + continue + } + content := strings.TrimSpace(string(data)) + if content != "" { + return content + } + } + return "" +} + +func resolveLegacyRouteViews(source *shape.Source) []*plan.View { + return resolveLegacyRouteViewsWithLayout(source, defaultCompilePathLayout()) +} + +func resolveLegacyRouteViewsWithLayout(source *shape.Source, layout compilePathLayout) []*plan.View { + if source == nil || strings.TrimSpace(source.Path) == "" { + return nil + } + platformRoot, relativeDir, stem, ok := platformPathParts(source.Path, layout) + if !ok { + return nil + } + settings := extractRuleSettings(source) + typeExpr := strings.TrimSpace(settings.Type) + typeExpr = strings.Trim(typeExpr, `"'`) + typeExpr = strings.TrimSuffix(typeExpr, 
".Handler") + typeStem := "" + if typeExpr != "" { + typeStem = filepath.Base(filepath.FromSlash(typeExpr)) + } + routesRoot := joinRelativePath(platformRoot, layout.routesRelative) + routesBase := filepath.Join(routesRoot, filepath.FromSlash(relativeDir)) + legacyMeta := []legacyViewMeta(nil) + for _, candidateYAML := range legacyRouteYAMLCandidates(routesBase, stem, typeStem) { + legacyMeta = loadLegacyRouteViewMeta(candidateYAML) + if len(legacyMeta) > 0 { + break + } + } + searchDirs := []string{ + filepath.Join(routesBase, typeStem, stem), + filepath.Join(routesBase, typeStem), + filepath.Join(routesBase, stem, stem), + filepath.Join(routesBase, stem), + routesBase, + } + var sqlFiles []string + for _, dir := range searchDirs { + entries, err := os.ReadDir(dir) + if err != nil { + continue + } + for _, entry := range entries { + if entry.IsDir() || !strings.HasSuffix(strings.ToLower(entry.Name()), ".sql") { + continue + } + sqlFiles = append(sqlFiles, filepath.Join(dir, entry.Name())) + } + if len(sqlFiles) > 0 { + break + } + } + if len(sqlFiles) == 0 { + return nil + } + sort.Strings(sqlFiles) + result := make([]*plan.View, 0, len(sqlFiles)) + rootIndex := -1 + for _, sqlFile := range sqlFiles { + name := strings.TrimSuffix(filepath.Base(sqlFile), filepath.Ext(sqlFile)) + if name == "" { + continue + } + data, err := os.ReadFile(sqlFile) + if err != nil { + continue + } + sqlText := string(data) + table := "" + if name != stem { + table = inferTableFromSQL(sqlText, source) + } + connector := strings.TrimSpace(settings.Connector) + if connector == "" { + connector = strings.TrimSpace(source.Connector) + } + if connector == "" { + connector = inferConnector(&plan.View{Table: table}, source) + } + viewItem := &plan.View{ + Path: name, + Holder: name, + Name: name, + Table: table, + SQL: sqlText, + SQLURI: filepath.ToSlash(filepath.Join(stem, name+".sql")), + Connector: connector, + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + 
ElementType: reflect.TypeOf(map[string]interface{}{}), + } + if meta, ok := lookupLegacyViewMeta(legacyMeta, name); ok { + if strings.TrimSpace(meta.Table) != "" { + viewItem.Table = strings.TrimSpace(meta.Table) + } + if strings.TrimSpace(meta.Connector) != "" { + viewItem.Connector = strings.TrimSpace(meta.Connector) + } + if strings.TrimSpace(meta.SQLURI) != "" { + viewItem.SQLURI = strings.TrimSpace(meta.SQLURI) + } + } + if name == stem { + rootIndex = len(result) + } + result = append(result, viewItem) + } + if len(result) == 0 { + return nil + } + if rootIndex > 0 { + root := result[rootIndex] + copy(result[1:rootIndex+1], result[0:rootIndex]) + result[0] = root + } + if result[0].Name != stem { + rootConnector := result[0].Connector + result = append([]*plan.View{{ + Path: stem, + Holder: stem, + Name: stem, + Table: "", + SQLURI: filepath.ToSlash(filepath.Join(stem, stem+".sql")), + Connector: rootConnector, + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + }}, result...) 
+ } + result[0].Table = "" + result[0].Name = stem + result[0].Holder = stem + result[0].Path = stem + if meta, ok := lookupLegacyViewMeta(legacyMeta, stem); ok { + if strings.TrimSpace(meta.Table) != "" { + result[0].Table = strings.TrimSpace(meta.Table) + } + if strings.TrimSpace(meta.Connector) != "" { + result[0].Connector = strings.TrimSpace(meta.Connector) + } + if strings.TrimSpace(meta.SQLURI) != "" { + result[0].SQLURI = strings.TrimSpace(meta.SQLURI) + } + } + if result[0].SQLURI == "" { + result[0].SQLURI = filepath.ToSlash(filepath.Join(stem, stem+".sql")) + } + return result +} + +type legacyViewMeta struct { + Name string + Table string + Connector string + SQLURI string +} + +func loadLegacyRouteViewMeta(yamlPath string) []legacyViewMeta { + data, err := os.ReadFile(yamlPath) + if err != nil { + return nil + } + var payload struct { + Resource struct { + Views []struct { + Name string `yaml:"Name"` + Table string `yaml:"Table"` + Connector struct { + Ref string `yaml:"Ref"` + } `yaml:"Connector"` + Template struct { + SourceURL string `yaml:"SourceURL"` + } `yaml:"Template"` + } `yaml:"Views"` + } `yaml:"Resource"` + } + if err = yaml.Unmarshal(data, &payload); err != nil { + return nil + } + result := make([]legacyViewMeta, 0, len(payload.Resource.Views)) + for _, item := range payload.Resource.Views { + result = append(result, legacyViewMeta{ + Name: strings.TrimSpace(item.Name), + Table: strings.TrimSpace(item.Table), + Connector: strings.TrimSpace(item.Connector.Ref), + SQLURI: strings.TrimSpace(item.Template.SourceURL), + }) + } + return result +} + +func lookupLegacyViewMeta(items []legacyViewMeta, name string) (legacyViewMeta, bool) { + name = strings.TrimSpace(name) + if name == "" { + return legacyViewMeta{}, false + } + for _, item := range items { + if strings.EqualFold(strings.TrimSpace(item.Name), name) { + return item, true + } + } + return legacyViewMeta{}, false +} + +func resolveLegacyRouteStates(source *shape.Source) []*plan.State { 
+ return resolveLegacyRouteStatesWithLayout(source, defaultCompilePathLayout()) +} + +func resolveLegacyRouteStatesWithLayout(source *shape.Source, layout compilePathLayout) []*plan.State { + if source == nil || strings.TrimSpace(source.Path) == "" { + return nil + } + platformRoot, relativeDir, stem, ok := platformPathParts(source.Path, layout) + if !ok { + return nil + } + settings := extractRuleSettings(source) + typeExpr := strings.TrimSpace(settings.Type) + typeExpr = strings.Trim(typeExpr, `"'`) + typeExpr = strings.TrimSuffix(typeExpr, ".Handler") + typeStem := "" + if typeExpr != "" { + typeStem = filepath.Base(filepath.FromSlash(typeExpr)) + } + routesRoot := joinRelativePath(platformRoot, layout.routesRelative) + routesBase := filepath.Join(routesRoot, filepath.FromSlash(relativeDir)) + yamlCandidates := legacyRouteYAMLCandidates(routesBase, stem, typeStem) + var payload struct { + Resource struct { + Parameters []struct { + Name string `yaml:"Name"` + URI string `yaml:"URI"` + Value string `yaml:"Value"` + Required *bool `yaml:"Required"` + Cacheable *bool `yaml:"Cacheable"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` + Predicates []struct { + Group int `yaml:"Group"` + Name string `yaml:"Name"` + Ensure bool `yaml:"Ensure"` + Args []string `yaml:"Args"` + } `yaml:"Predicates"` + } `yaml:"Parameters"` + Views []struct { + Name string `yaml:"Name"` + Selector struct { + LimitParameter struct { + Name string `yaml:"Name"` + Cacheable *bool `yaml:"Cacheable"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` + } `yaml:"LimitParameter"` + OffsetParameter struct { + Name string `yaml:"Name"` + Cacheable *bool `yaml:"Cacheable"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` + } `yaml:"OffsetParameter"` + PageParameter struct { + Name string `yaml:"Name"` + Cacheable *bool `yaml:"Cacheable"` + In struct { + Kind string `yaml:"Kind"` + Name 
string `yaml:"Name"` + } `yaml:"In"` + } `yaml:"PageParameter"` + FieldsParameter struct { + Name string `yaml:"Name"` + Cacheable *bool `yaml:"Cacheable"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` + } `yaml:"FieldsParameter"` + OrderByParameter struct { + Name string `yaml:"Name"` + Cacheable *bool `yaml:"Cacheable"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` + } `yaml:"OrderByParameter"` + } `yaml:"Selector"` + } `yaml:"Views"` + } `yaml:"Resource"` + } + loaded := false + for _, candidate := range yamlCandidates { + data, err := os.ReadFile(candidate) + if err != nil { + continue + } + if err = yaml.Unmarshal(data, &payload); err != nil { + continue + } + loaded = true + break + } + if !loaded || len(payload.Resource.Parameters) == 0 { + return nil + } + result := make([]*plan.State, 0, len(payload.Resource.Parameters)) + for _, item := range payload.Resource.Parameters { + stateItem := &plan.State{ + Name: strings.TrimSpace(item.Name), + Path: strings.TrimSpace(item.Name), + Kind: strings.TrimSpace(item.In.Kind), + In: strings.TrimSpace(item.In.Name), + URI: strings.TrimSpace(item.URI), + Value: strings.TrimSpace(item.Value), + Required: item.Required, + Cacheable: item.Cacheable, + } + for _, predicate := range item.Predicates { + stateItem.Predicates = append(stateItem.Predicates, &plan.StatePredicate{ + Group: predicate.Group, + Name: strings.TrimSpace(predicate.Name), + Ensure: predicate.Ensure, + Arguments: append([]string{}, predicate.Args...), + }) + } + result = append(result, stateItem) + } + seen := map[string]bool{} + for _, item := range result { + if item == nil { + continue + } + key := strings.ToLower(strings.TrimSpace(item.Name)) + "|" + strings.ToLower(strings.TrimSpace(item.Kind)) + "|" + strings.ToLower(strings.TrimSpace(item.In)) + seen[key] = true + } + for _, viewItem := range payload.Resource.Views { + selectorName := strings.TrimSpace(viewItem.Name) + 
for _, param := range []struct { + Name string + Cacheable *bool + InKind string + InName string + }{ + { + Name: strings.TrimSpace(viewItem.Selector.LimitParameter.Name), + Cacheable: viewItem.Selector.LimitParameter.Cacheable, + InKind: strings.TrimSpace(viewItem.Selector.LimitParameter.In.Kind), + InName: strings.TrimSpace(viewItem.Selector.LimitParameter.In.Name), + }, + { + Name: strings.TrimSpace(viewItem.Selector.OffsetParameter.Name), + Cacheable: viewItem.Selector.OffsetParameter.Cacheable, + InKind: strings.TrimSpace(viewItem.Selector.OffsetParameter.In.Kind), + InName: strings.TrimSpace(viewItem.Selector.OffsetParameter.In.Name), + }, + { + Name: strings.TrimSpace(viewItem.Selector.PageParameter.Name), + Cacheable: viewItem.Selector.PageParameter.Cacheable, + InKind: strings.TrimSpace(viewItem.Selector.PageParameter.In.Kind), + InName: strings.TrimSpace(viewItem.Selector.PageParameter.In.Name), + }, + { + Name: strings.TrimSpace(viewItem.Selector.FieldsParameter.Name), + Cacheable: viewItem.Selector.FieldsParameter.Cacheable, + InKind: strings.TrimSpace(viewItem.Selector.FieldsParameter.In.Kind), + InName: strings.TrimSpace(viewItem.Selector.FieldsParameter.In.Name), + }, + { + Name: strings.TrimSpace(viewItem.Selector.OrderByParameter.Name), + Cacheable: viewItem.Selector.OrderByParameter.Cacheable, + InKind: strings.TrimSpace(viewItem.Selector.OrderByParameter.In.Kind), + InName: strings.TrimSpace(viewItem.Selector.OrderByParameter.In.Name), + }, + } { + if param.Name == "" { + continue + } + kind := firstNonEmptyString(strings.ToLower(param.InKind), "query") + inName := firstNonEmptyString(param.InName, strings.ToLower(param.Name)) + key := strings.ToLower(param.Name) + "|" + kind + "|" + strings.ToLower(inName) + if seen[key] { + continue + } + item := &plan.State{ + Name: param.Name, + Path: param.Name, + Kind: kind, + In: inName, + QuerySelector: selectorName, + Cacheable: param.Cacheable, + } + result = append(result, item) + seen[key] = true + } 
+ } + return result +} + +func resolveLegacyRouteTypes(source *shape.Source) []*plan.Type { + return resolveLegacyRouteTypesWithLayout(source, defaultCompilePathLayout()) +} + +func resolveLegacyRouteTypesWithLayout(source *shape.Source, layout compilePathLayout) []*plan.Type { + if source == nil || strings.TrimSpace(source.Path) == "" { + return nil + } + platformRoot, relativeDir, stem, ok := platformPathParts(source.Path, layout) + if !ok { + return nil + } + settings := extractRuleSettings(source) + typeExpr := strings.TrimSpace(settings.Type) + typeExpr = strings.Trim(typeExpr, `"'`) + typeExpr = strings.TrimSuffix(typeExpr, ".Handler") + typeStem := "" + if typeExpr != "" { + typeStem = filepath.Base(filepath.FromSlash(typeExpr)) + } + routesRoot := joinRelativePath(platformRoot, layout.routesRelative) + routesBase := filepath.Join(routesRoot, filepath.FromSlash(relativeDir)) + yamlCandidates := legacyRouteYAMLCandidates(routesBase, stem, typeStem) + var payload struct { + Resource struct { + Types []struct { + Name string `yaml:"Name"` + Alias string `yaml:"Alias"` + DataType string `yaml:"DataType"` + Cardinality string `yaml:"Cardinality"` + Package string `yaml:"Package"` + ModulePath string `yaml:"ModulePath"` + } `yaml:"Types"` + } `yaml:"Resource"` + } + loaded := false + for _, candidate := range yamlCandidates { + data, err := os.ReadFile(candidate) + if err != nil { + continue + } + if err = yaml.Unmarshal(data, &payload); err != nil { + continue + } + loaded = true + break + } + if !loaded || len(payload.Resource.Types) == 0 { + return nil + } + result := make([]*plan.Type, 0, len(payload.Resource.Types)) + seen := map[string]bool{} + for _, item := range payload.Resource.Types { + name := strings.TrimSpace(item.Name) + if name == "" { + continue + } + key := strings.ToLower(name) + if seen[key] { + continue + } + seen[key] = true + result = append(result, &plan.Type{ + Name: name, + Alias: strings.TrimSpace(item.Alias), + DataType: 
strings.TrimSpace(item.DataType), + Cardinality: strings.TrimSpace(item.Cardinality), + Package: strings.TrimSpace(item.Package), + ModulePath: strings.TrimSpace(item.ModulePath), + }) + } + return result +} + +func mergeLegacyRouteStates(result *plan.Result, source *shape.Source) { + mergeLegacyRouteStatesWithLayout(result, source, defaultCompilePathLayout()) +} + +func mergeLegacyRouteStatesWithLayout(result *plan.Result, source *shape.Source, layout compilePathLayout) { + if result == nil { + return + } + legacy := resolveLegacyRouteStatesWithLayout(source, layout) + if len(legacy) == 0 { + return + } + existing := map[string]bool{} + for _, item := range result.States { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + key := strings.ToLower(strings.TrimSpace(item.Name)) + "|" + strings.ToLower(strings.TrimSpace(item.Kind)) + "|" + strings.ToLower(strings.TrimSpace(item.In)) + existing[key] = true + } + for _, item := range legacy { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + key := strings.ToLower(strings.TrimSpace(item.Name)) + "|" + strings.ToLower(strings.TrimSpace(item.Kind)) + "|" + strings.ToLower(strings.TrimSpace(item.In)) + if existing[key] { + continue + } + result.States = append(result.States, item) + existing[key] = true + } +} + +func mergeLegacyRouteTypes(result *plan.Result, source *shape.Source) { + mergeLegacyRouteTypesWithLayout(result, source, defaultCompilePathLayout()) +} + +func mergeLegacyRouteTypesWithLayout(result *plan.Result, source *shape.Source, layout compilePathLayout) { + if result == nil { + return + } + legacy := resolveLegacyRouteTypesWithLayout(source, layout) + if len(legacy) == 0 { + return + } + existing := map[string]bool{} + for _, item := range result.Types { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + existing[strings.ToLower(strings.TrimSpace(item.Name))] = true + } + for _, item := range legacy { + if item == nil || 
strings.TrimSpace(item.Name) == "" { + continue + } + key := strings.ToLower(strings.TrimSpace(item.Name)) + if existing[key] { + continue + } + result.Types = append(result.Types, item) + existing[key] = true + } +} + +func legacyRouteYAMLCandidates(routesBase, stem, typeStem string) []string { + stemFileVariants := routeStemAlternatives(stem) + stemDirVariants := routeStemAlternatives(stem) + typeVariants := routeStemAlternatives(typeStem) + var result []string + seen := map[string]bool{} + appendCandidate := func(path string) { + path = filepath.Clean(path) + if path == "." || path == "" || seen[path] { + return + } + seen[path] = true + result = append(result, path) + } + for _, fileStem := range stemFileVariants { + appendCandidate(filepath.Join(routesBase, fileStem+".yaml")) + for _, dirStem := range stemDirVariants { + appendCandidate(filepath.Join(routesBase, dirStem, fileStem+".yaml")) + } + for _, itemTypeStem := range typeVariants { + if strings.TrimSpace(itemTypeStem) == "" { + continue + } + appendCandidate(filepath.Join(routesBase, itemTypeStem, fileStem+".yaml")) + } + } + return result +} + +func routeStemAlternatives(value string) []string { + value = strings.TrimSpace(value) + if value == "" { + return nil + } + alts := []string{value} + dashed := strings.ReplaceAll(value, "_", "-") + if dashed != value { + alts = append(alts, dashed) + } + return alts +} + +func platformPathParts(sourcePath string, layout compilePathLayout) (platformRoot, relativeDir, stem string, ok bool) { + sourcePath = filepath.Clean(strings.TrimSpace(sourcePath)) + if sourcePath == "" { + return "", "", "", false + } + normalized := filepath.ToSlash(sourcePath) + marker := layout.dqlMarker + if marker == "" { + marker = defaultCompilePathLayout().dqlMarker + } + idx := strings.Index(normalized, marker) + if idx == -1 { + return "", "", "", false + } + platformRoot = sourcePath[:idx] + relative := strings.TrimPrefix(normalized[idx+len(marker):], "/") + relativeDir = 
filepath.Dir(relative) + stem = strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) + if strings.TrimSpace(stem) == "" { + return "", "", "", false + } + return platformRoot, relativeDir, stem, true +} diff --git a/repository/shape/compile/pathlayout.go b/repository/shape/compile/pathlayout.go new file mode 100644 index 00000000..a1bc081e --- /dev/null +++ b/repository/shape/compile/pathlayout.go @@ -0,0 +1,67 @@ +package compile + +import ( + "path/filepath" + "strings" + + "github.com/viant/datly/repository/shape" +) + +type compilePathLayout struct { + dqlMarker string + routesRelative string +} + +func defaultCompilePathLayout() compilePathLayout { + return compilePathLayout{ + dqlMarker: "/dql/", + routesRelative: "repo/dev/Datly/routes", + } +} + +func newCompilePathLayout(opts *shape.CompileOptions) compilePathLayout { + ret := defaultCompilePathLayout() + if opts == nil { + return ret + } + if marker := normalizeDQLMarker(opts.DQLPathMarker); marker != "" { + ret.dqlMarker = marker + } + if rel := normalizeRoutesRelative(opts.RoutesRelativePath); rel != "" { + ret.routesRelative = rel + } + return ret +} + +func normalizeDQLMarker(input string) string { + input = strings.TrimSpace(strings.ReplaceAll(input, "\\", "/")) + if input == "" { + return "" + } + input = strings.Trim(input, "/") + if input == "" { + return "" + } + return "/" + input + "/" +} + +func normalizeRoutesRelative(input string) string { + input = strings.TrimSpace(strings.ReplaceAll(input, "\\", "/")) + input = strings.Trim(input, "/") + if input == "" { + return "" + } + return input +} + +func joinRelativePath(base string, rel string) string { + rel = normalizeRoutesRelative(rel) + if rel == "" { + return base + } + parts := strings.Split(rel, "/") + args := make([]string, 0, len(parts)+1) + args = append(args, base) + args = append(args, parts...) + return filepath.Join(args...) 
+} diff --git a/repository/shape/compile/pipeline/diag.go b/repository/shape/compile/pipeline/diag.go new file mode 100644 index 00000000..45bf7856 --- /dev/null +++ b/repository/shape/compile/pipeline/diag.go @@ -0,0 +1,47 @@ +package pipeline + +import ( + "unicode/utf8" + + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" +) + +func StatementSpan(sqlText string, stmt *dqlstmt.Statement) dqlshape.Span { + if stmt == nil { + return pointSpan(sqlText, 0) + } + return pointSpan(sqlText, stmt.Start) +} + +func pointSpan(text string, offset int) dqlshape.Span { + start := positionAt(text, offset) + end := start + return dqlshape.Span{Start: start, End: end} +} + +func positionAt(text string, offset int) dqlshape.Position { + if offset < 0 { + offset = 0 + } + if offset > len(text) { + offset = len(text) + } + line := 1 + char := 1 + index := 0 + for index < offset { + r, width := utf8.DecodeRuneInString(text[index:]) + if width <= 0 { + break + } + index += width + if r == '\n' { + line++ + char = 1 + } else { + char++ + } + } + return dqlshape.Position{Offset: offset, Line: line, Char: char} +} diff --git a/repository/shape/compile/pipeline/exec.go b/repository/shape/compile/pipeline/exec.go new file mode 100644 index 00000000..6bf14488 --- /dev/null +++ b/repository/shape/compile/pipeline/exec.go @@ -0,0 +1,109 @@ +package pipeline + +import ( + "reflect" + "strings" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/sqlparser" +) + +func BuildExec(sourceName, sqlText string, statements dqlstmt.Statements) (*plan.View, []*dqlshape.Diagnostic) { + name := SanitizeName(sourceName) + if name == "" { + name = "DQLView" + } + tables := statements.DMLTables(sqlText) + table := name + if 
len(tables) > 0 { + table = tables[0] + } + fieldType := reflect.TypeOf([]map[string]interface{}{}) + elementType := reflect.TypeOf(map[string]interface{}{}) + view := &plan.View{ + Path: name, + Holder: name, + Name: name, + Mode: "SQLExec", + Table: table, + SQL: sqlText, + Cardinality: "many", + FieldType: fieldType, + ElementType: elementType, + } + return view, ValidateExecStatements(sqlText, statements) +} + +func ValidateExecStatements(sqlText string, statements dqlstmt.Statements) []*dqlshape.Diagnostic { + var result []*dqlshape.Diagnostic + for _, stmt := range statements { + if stmt == nil || !stmt.IsExec { + continue + } + body := strings.TrimSpace(sqlText[stmt.Start:stmt.End]) + if body == "" { + continue + } + lower := strings.ToLower(body) + span := StatementSpan(sqlText, stmt) + switch { + case stmt.Kind == dqlstmt.KindService: + if firstQuoted(body) == "" { + result = append(result, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLServiceArg, + Severity: dqlshape.SeverityError, + Message: "service DML call is missing quoted table argument", + Hint: "use $sql.Insert(\"TABLE\", ...) 
or $sql.Update(\"TABLE\", ...)", + Span: span, + }) + } + case strings.HasPrefix(lower, "insert"): + if _, err := sqlparser.ParseInsert(body); err != nil { + result = append(result, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLInsert, + Severity: dqlshape.SeverityError, + Message: strings.TrimSpace(err.Error()), + Hint: "fix INSERT statement syntax", + Span: span, + }) + } + case strings.HasPrefix(lower, "update"): + if _, err := sqlparser.ParseUpdate(body); err != nil { + result = append(result, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLUpdate, + Severity: dqlshape.SeverityError, + Message: strings.TrimSpace(err.Error()), + Hint: "fix UPDATE statement syntax", + Span: span, + }) + } + case strings.HasPrefix(lower, "delete"): + if _, err := sqlparser.ParseDelete(body); err != nil { + result = append(result, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDMLDelete, + Severity: dqlshape.SeverityError, + Message: strings.TrimSpace(err.Error()), + Hint: "fix DELETE statement syntax", + Span: span, + }) + } + } + } + return result +} + +func firstQuoted(input string) string { + index := strings.Index(input, `"`) + if index == -1 { + return "" + } + tail := input[index+1:] + end := strings.Index(tail, `"`) + if end == -1 { + return "" + } + return strings.TrimSpace(tail[:end]) +} diff --git a/repository/shape/compile/pipeline/exec_test.go b/repository/shape/compile/pipeline/exec_test.go new file mode 100644 index 00000000..8c70fc74 --- /dev/null +++ b/repository/shape/compile/pipeline/exec_test.go @@ -0,0 +1,26 @@ +package pipeline + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" +) + +func TestBuildExec(t *testing.T) { + sqlText := "INSERT INTO ORDERS(id) VALUES (1)" + view, diags := BuildExec("orders_exec", sqlText, dqlstmt.New(sqlText)) + require.NotNil(t, view) + assert.Equal(t, "ORDERS", 
view.Table) + assert.Equal(t, "many", view.Cardinality) + assert.Empty(t, diags) +} + +func TestValidateExecStatements_ServiceArg(t *testing.T) { + sqlText := "$sql.Insert($rec)" + diags := ValidateExecStatements(sqlText, dqlstmt.New(sqlText)) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeDMLServiceArg, diags[0].Code) +} diff --git a/repository/shape/compile/pipeline/infer.go b/repository/shape/compile/pipeline/infer.go new file mode 100644 index 00000000..c19bcb79 --- /dev/null +++ b/repository/shape/compile/pipeline/infer.go @@ -0,0 +1,227 @@ +package pipeline + +import ( + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/query" +) + +var nonWord = regexp.MustCompile(`[^a-zA-Z0-9_]+`) + +func InferRoot(queryNode *query.Select, fallback string) (string, string, error) { + name := SanitizeName(fallback) + if name == "" { + name = "DQLView" + } + if queryNode == nil { + return name, name, nil + } + if alias := SanitizeName(queryNode.From.Alias); alias != "" { + name = alias + } + table := "" + if queryNode.From.X != nil { + table = strings.TrimSpace(sqlparser.Stringify(queryNode.From.X)) + } + if name == "" || name == SanitizeName(fallback) { + if subAlias := inferSubqueryAlias(table); subAlias != "" { + name = subAlias + } + } + if table == "" || strings.HasPrefix(table, "(") { + if inferred := inferSubqueryTable(table); inferred != "" { + table = inferred + } else { + table = name + } + } + if name == "" { + return "", "", fmt.Errorf("shape compile: failed to infer view name") + } + return name, table, nil +} + +func inferSubqueryAlias(fromExpr string) string { + fromExpr = strings.TrimSpace(fromExpr) + if fromExpr == "" || !strings.HasPrefix(fromExpr, "(") { + return "" + } + depth := 0 + closeIdx := -1 + for i := 0; i < len(fromExpr); i++ { + switch fromExpr[i] { + case '(': + depth++ + case ')': + depth-- + if depth == 0 { + closeIdx = i + i = len(fromExpr) + } + } + } + if closeIdx == -1 || 
// extractSubqueryBody returns the text inside the outermost parentheses of a
// derived-table expression, e.g. "(SELECT 1) x" -> "SELECT 1". The second
// result is false when the input is not parenthesised, is empty "()", or has
// unbalanced parentheses.
func extractSubqueryBody(fromExpr string) (string, bool) {
	trimmed := strings.TrimSpace(fromExpr)
	if !strings.HasPrefix(trimmed, "(") {
		return "", false
	}
	nesting := 0
	for index := 0; index < len(trimmed); index++ {
		switch trimmed[index] {
		case '(':
			nesting++
		case ')':
			nesting--
			if nesting != 0 {
				continue
			}
			if index <= 1 { // "()" — nothing inside the outer pair
				return "", false
			}
			return strings.TrimSpace(trimmed[1:index]), true
		}
	}
	return "", false // closing parenthesis never found
}
} + fieldName := ExportedName(columnName) + if fieldName == "" { + fieldName = fmt.Sprintf("Col%d", index+1) + } + if count := used[fieldName]; count > 0 { + fieldName = fmt.Sprintf("%s%d", fieldName, count+1) + } + used[fieldName]++ + + typ := parseColumnType(column.Type) + fields = append(fields, reflect.StructField{ + Name: fieldName, + Type: typ, + Tag: reflect.StructTag(fmt.Sprintf(`json:"%s,omitempty" sqlx:"name=%s"`, strings.ToLower(fieldName), columnName)), + }) + } + element := reflect.StructOf(fields) + return reflect.SliceOf(element), element, "many" +} + +func SanitizeName(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + if value == strings.ToUpper(value) { + value = strings.ToLower(value) + } + value = nonWord.ReplaceAllString(value, "_") + value = strings.Trim(value, "_") + if value == "" { + return "" + } + if value[0] >= '0' && value[0] <= '9' { + value = "V_" + value + } + return value +} + +func ExportedName(value string) string { + value = nonWord.ReplaceAllString(strings.TrimSpace(value), "_") + value = strings.Trim(value, "_") + if value == "" { + return "" + } + parts := strings.Split(strings.ToLower(value), "_") + for i, item := range parts { + if item == "" { + continue + } + parts[i] = strings.ToUpper(item[:1]) + item[1:] + } + name := strings.Join(parts, "") + if name == "" { + return "" + } + if name[0] >= '0' && name[0] <= '9' { + name = "N" + name + } + return name +} + +func parseColumnType(dataType string) reflect.Type { + switch strings.ToLower(strings.TrimSpace(dataType)) { + case "", "string", "text", "varchar", "char", "uuid", "json", "jsonb": + return reflect.TypeOf("") + case "bool", "boolean": + return reflect.TypeOf(false) + case "int", "int32", "smallint", "integer": + return reflect.TypeOf(int(0)) + case "int64", "bigint": + return reflect.TypeOf(int64(0)) + case "float", "float32", "real": + return reflect.TypeOf(float32(0)) + case "float64", "double", "numeric", "decimal": + 
return reflect.TypeOf(float64(0)) + default: + return reflect.TypeOf("") + } +} diff --git a/repository/shape/compile/pipeline/infer_test.go b/repository/shape/compile/pipeline/infer_test.go new file mode 100644 index 00000000..748fcded --- /dev/null +++ b/repository/shape/compile/pipeline/infer_test.go @@ -0,0 +1,42 @@ +package pipeline + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/viant/sqlparser" +) + +func TestInferSubqueryAlias(t *testing.T) { + assert.Equal(t, "session", inferSubqueryAlias("(SELECT * FROM session) session JOIN (SELECT * FROM attr) attribute ON attribute.id = session.id")) + assert.Equal(t, "x", inferSubqueryAlias("(SELECT 1) AS x")) + assert.Equal(t, "publisherglobaloverride", inferSubqueryAlias(`( + SELECT MIN(g.BUSINESS_MODEL_ID) AS BUSINESS_MODEL_ID + FROM CI_GLOBAL_PUBLISHER_OVERRIDE g +) publisherglobaloverride`)) + assert.Equal(t, "", inferSubqueryAlias("orders o")) +} + +func TestSanitizeName_AllCapsToLower(t *testing.T) { + assert.Equal(t, "value", SanitizeName("VALUE")) + assert.Equal(t, "status", SanitizeName("STATUS")) +} + +func TestInferSubqueryTable(t *testing.T) { + assert.Equal(t, "CI_ADVERTISER", inferSubqueryTable("(SELECT a.* FROM CI_ADVERTISER a) advertiser")) + assert.Equal(t, "", inferSubqueryTable("orders o")) +} + +func TestInferRoot_SubqueryFrom(t *testing.T) { + queryNode, err := sqlparser.ParseQuery(`SELECT advertiser.* FROM (SELECT a.* FROM CI_ADVERTISER a) advertiser`) + assert.NoError(t, err) + name, table, err := InferRoot(queryNode, "fallback") + assert.NoError(t, err) + assert.Equal(t, "advertiser", name) + assert.Equal(t, "CI_ADVERTISER", table) +} + +func TestInferTableFromSQL_ResolvesTopLevelFrom(t *testing.T) { + sqlText := `SELECT a.*, EXISTS(SELECT 1 FROM CI_ENTITY_WATCHLIST w WHERE w.ENTITY_ID = a.ID) AS watching FROM (SELECT x.* FROM CI_ADVERTISER x) a` + assert.Equal(t, "CI_ADVERTISER", InferTableFromSQL(sqlText)) +} diff --git 
// trimLeadingBlockComments strips any run of /* ... */ comments, together
// with surrounding whitespace, from the start of sqlText. An unterminated
// comment is left in place so the parser can surface it.
func trimLeadingBlockComments(sqlText string) string {
	const cutset = " \t\r\n"
	rest := strings.TrimLeft(sqlText, cutset)
	for strings.HasPrefix(rest, "/*") {
		_, tail, closed := strings.Cut(rest, "*/")
		if !closed {
			break
		}
		rest = strings.TrimLeft(tail, cutset)
	}
	return rest
}
// TestParseSelectWithDiagnostic_OK verifies that a well-formed SELECT parses
// with no error and no diagnostic, and that the FROM alias survives parsing.
func TestParseSelectWithDiagnostic_OK(t *testing.T) {
	queryNode, diag, err := ParseSelectWithDiagnostic("SELECT id FROM orders o")
	require.NoError(t, err)
	require.Nil(t, diag)
	require.NotNil(t, queryNode)
	assert.Equal(t, "o", queryNode.From.Alias)
}

// TestParseSelectWithDiagnostic_Syntax verifies that a syntax error yields a
// CodeParseSyntax diagnostic whose span points past the first character of
// line 1 (i.e. near the offending token, not at offset zero).
func TestParseSelectWithDiagnostic_Syntax(t *testing.T) {
	queryNode, diag, err := ParseSelectWithDiagnostic("SELECT id FROM orders WHERE (")
	require.Error(t, err)
	require.Nil(t, queryNode)
	require.NotNil(t, diag)
	assert.Equal(t, dqldiag.CodeParseSyntax, diag.Code)
	assert.Equal(t, 1, diag.Span.Start.Line)
	assert.Greater(t, diag.Span.Start.Char, 1)
}

// TestParseSelectWithDiagnostic_LeadingBlockComment verifies that a leading
// /* ... */ header (here a JSON route hint) is trimmed before parsing.
func TestParseSelectWithDiagnostic_LeadingBlockComment(t *testing.T) {
	queryNode, diag, err := ParseSelectWithDiagnostic("/* {\"URI\":\"/x\"} */\nSELECT id FROM orders o")
	require.NoError(t, err)
	require.Nil(t, diag)
	require.NotNil(t, queryNode)
	assert.Equal(t, "o", queryNode.From.Alias)
}
// TestClassify_ReadOnly: a single SELECT sets only HasRead.
func TestClassify_ReadOnly(t *testing.T) {
	decision := Classify(dqlstmt.New("SELECT id FROM orders"))
	assert.True(t, decision.HasRead)
	assert.False(t, decision.HasExec)
	assert.False(t, decision.HasUnknown)
}

// TestClassify_ExecOnly: a single UPDATE sets only HasExec.
func TestClassify_ExecOnly(t *testing.T) {
	decision := Classify(dqlstmt.New("UPDATE orders SET id = 1"))
	assert.False(t, decision.HasRead)
	assert.True(t, decision.HasExec)
	assert.False(t, decision.HasUnknown)
}

// TestClassify_Mixed: a SELECT followed by an UPDATE sets both flags.
func TestClassify_Mixed(t *testing.T) {
	decision := Classify(dqlstmt.New("SELECT id FROM orders\nUPDATE orders SET id = 1"))
	assert.True(t, decision.HasRead)
	assert.True(t, decision.HasExec)
	assert.False(t, decision.HasUnknown)
}

// TestClassify_UnknownTemplateOnly: a bare template call that is neither a
// read nor an exec/service statement is reported as unknown.
func TestClassify_UnknownTemplateOnly(t *testing.T) {
	decision := Classify(dqlstmt.New("$Foo.Bar($x)"))
	assert.False(t, decision.HasRead)
	assert.False(t, decision.HasExec)
	assert.True(t, decision.HasUnknown)
}
// BuildRead compiles a read (SELECT) DQL source into a plan.View. It first
// parses a normalized copy of the SQL (template expressions replaced with
// parseable placeholders); when that fails it retries against the raw text,
// and when a usable parse still lacks expected joins it retries the
// normalized form. Template-bearing SQL that cannot be parsed degrades to a
// "loose" view rather than an error; parse diagnostics are demoted to
// warnings in that case.
func BuildRead(sourceName, sqlText string) (*plan.View, []*dqlshape.Diagnostic, error) {
	parserSQL := normalizeParserSQL(sqlText)
	queryNode, parseDiag, err := ParseSelectWithDiagnostic(parserSQL)
	// Retry 1: normalization itself may have broken the SQL — try raw text.
	if err != nil && parserSQL != sqlText {
		if rawNode, _, rawErr := ParseSelectWithDiagnostic(sqlText); rawErr == nil && isUsableQuery(rawNode) {
			queryNode = rawNode
			parserSQL = sqlText
			parseDiag = nil
			err = nil
		}
	}
	// Retry 2: a "successful" parse may still be unusable (no FROM target,
	// or JOINs present in text but absent from the AST) — try normalized form.
	if err == nil && needsFallbackParse(sqlText, queryNode) {
		fallbackSQL := normalizeParserSQL(sqlText)
		if fallbackNode, _, fallbackErr := ParseSelectWithDiagnostic(fallbackSQL); fallbackErr == nil && isUsableQuery(fallbackNode) {
			queryNode = fallbackNode
			parserSQL = fallbackSQL
			parseDiag = nil
			err = nil
		}
	}
	// Templated SQL that still fails to parse becomes a loose view with the
	// parse diagnostic demoted to a warning.
	if hasTemplateSignals(sqlText) && (err != nil || parseDiag != nil) {
		if parseDiag != nil {
			parseDiag.Severity = dqlshape.SeverityWarning
		}
		var diags []*dqlshape.Diagnostic
		if parseDiag != nil {
			diags = append(diags, parseDiag)
		}
		return buildLooseRead(sourceName, sqlText), diags, nil
	}
	var diags []*dqlshape.Diagnostic
	if parseDiag != nil {
		diags = append(diags, parseDiag)
	}
	if err != nil {
		// NOTE(review): this hasTemplateSignals branch appears unreachable —
		// when err != nil and the SQL has template signals, the branch above
		// has already returned. Confirm before removing.
		if hasTemplateSignals(sqlText) {
			if parseDiag != nil {
				parseDiag.Severity = dqlshape.SeverityWarning
			}
			return buildLooseRead(sourceName, sqlText), diags, nil
		}
		return nil, diags, nil
	}
	relations, relationDiags := ExtractJoinRelations(parserSQL, queryNode)
	diags = append(diags, relationDiags...)
	name, table, inferErr := InferRoot(queryNode, sourceName)
	if inferErr != nil {
		return nil, nil, inferErr
	}
	// If inference produced nothing better than the sanitized source name,
	// derive the root from the parent namespace of the first relation link.
	fallback := SanitizeName(sourceName)
	if name == fallback && table == fallback {
		if derived := inferRootFromRelations(relations); derived != "" {
			name = derived
			table = derived
		}
	}
	fieldType, elementType, cardinality := InferProjectionType(queryNode)
	if fieldType == nil || elementType == nil {
		fieldType = reflect.TypeOf([]map[string]interface{}{})
		elementType = reflect.TypeOf(map[string]interface{}{})
		cardinality = "many"
	}
	view := &plan.View{
		Path:        name,
		Holder:      name,
		Name:        name,
		Mode:        "SQLQuery",
		Table:       table,
		SQL:         sqlText,
		Cardinality: cardinality,
		FieldType:   fieldType,
		ElementType: elementType,
		Relations:   relations,
	}
	return view, diags, nil
}

// buildLooseRead builds a permissive view (untyped map rows, cardinality
// "many") for SQL that could not be parsed, e.g. heavily templated sources.
func buildLooseRead(sourceName, sqlText string) *plan.View {
	name, table := inferLooseRoot(sourceName, sqlText)
	fieldType := reflect.TypeOf([]map[string]interface{}{})
	elementType := reflect.TypeOf(map[string]interface{}{})
	return &plan.View{
		Path:        name,
		Holder:      name,
		Name:        name,
		Mode:        "SQLQuery",
		Table:       table,
		SQL:         sqlText,
		Cardinality: "many",
		FieldType:   fieldType,
		ElementType: elementType,
	}
}

// inferLooseRoot derives a (name, table) pair without a parse: the name comes
// from the sanitized source name, the table from a textual "FROM <ident>"
// match when one exists, otherwise it mirrors the name.
func inferLooseRoot(sourceName, sqlText string) (string, string) {
	name := SanitizeName(sourceName)
	if name == "" {
		name = "DQLView"
	}
	if matches := fromTableSimpleExpr.FindStringSubmatch(sqlText); len(matches) > 1 {
		table := strings.Trim(matches[1], "`\"")
		return name, table
	}
	return name, name
}

// hasTemplateSignals reports whether sqlText contains Velty/template markers
// (#if/#else/#end directives, ${...} blocks, or $unsafe./$view./$predicate.
// selectors) that make it unparseable as plain SQL.
func hasTemplateSignals(sqlText string) bool {
	lower := strings.ToLower(sqlText)
	return strings.Contains(lower, "#if(") || strings.Contains(lower, "#elseif(") || strings.Contains(lower, "#else") ||
		strings.Contains(lower, "#end") || strings.Contains(lower, "${") || strings.Contains(lower, "$unsafe.") ||
		strings.Contains(lower, "$view.") || strings.Contains(lower, "$predicate.")
}

// isUsableQuery reports whether the parse produced a query with a FROM target.
func isUsableQuery(queryNode *query.Select) bool {
	return queryNode != nil && queryNode.From.X != nil
}

// needsFallbackParse reports whether a reparse is warranted: the query is
// unusable, or the raw SQL mentions " join " while the AST captured no joins.
func needsFallbackParse(rawSQL string, queryNode *query.Select) bool {
	if !isUsableQuery(queryNode) {
		return true
	}
	lower := strings.ToLower(rawSQL)
	if strings.Contains(lower, " join ") && len(queryNode.Joins) == 0 {
		return true
	}
	return false
}

// normalizeParserSQL rewrites template expressions into parseable SQL:
// $criteria.AppendBinding(...) becomes "1"; ${...} blocks become "1" unless
// they are sql.Insert/sql.Update/Nop calls (kept) or Build("WHERE")/
// Build("AND") predicate builders (replaced with " WHERE 1 "/" AND 1 ");
// remaining $selector references become "1" with the same exceptions.
func normalizeParserSQL(sqlText string) string {
	if sqlText == "" {
		return sqlText
	}
	normalized := criteriaBindingExpr.ReplaceAllString(sqlText, "1")
	normalized = veltyExpr.ReplaceAllStringFunc(normalized, func(match string) string {
		if strings.Contains(match, "sql.Insert") || strings.Contains(match, "sql.Update") || strings.Contains(match, "Nop") {
			return match
		}
		lower := strings.ToLower(match)
		if strings.Contains(lower, `build("where")`) || strings.Contains(lower, "build('where')") {
			return " WHERE 1 "
		}
		if strings.Contains(lower, `build("and")`) || strings.Contains(lower, "build('and')") {
			return " AND 1 "
		}
		return "1"
	})
	normalized = selectorExpr.ReplaceAllStringFunc(normalized, func(match string) string {
		lower := match
		if len(match) > 0 && match[0] == '$' {
			lower = match[1:]
		}
		lower = braceExpr.ReplaceAllString(lower, "")
		switch lower {
		case "sql.Insert", "sql.Update", "Nop":
			return match
		default:
			return "1"
		}
	})
	return normalized
}

// inferRootFromRelations returns the sanitized parent namespace of the first
// relation link that has one, or "" when none do.
func inferRootFromRelations(relations []*plan.Relation) string {
	for _, relation := range relations {
		if relation == nil {
			continue
		}
		for _, link := range relation.On {
			if link == nil {
				continue
			}
			name := SanitizeName(link.ParentNamespace)
			if name != "" {
				return name
			}
		}
	}
	return ""
}
// TestBuildRead: a plain two-table join yields the root alias as view name,
// the root table, "many" cardinality, and one relation for the joined table.
func TestBuildRead(t *testing.T) {
	view, diags, err := BuildRead("orders_report", "SELECT o.id, i.sku FROM orders o JOIN items i ON o.id = i.order_id")
	require.NoError(t, err)
	require.NotNil(t, view)
	assert.Equal(t, "o", view.Name)
	assert.Equal(t, "orders", view.Table)
	assert.Equal(t, "many", view.Cardinality)
	require.Len(t, view.Relations, 1)
	assert.Equal(t, "i", view.Relations[0].Ref)
	assert.Empty(t, diags)
}

// TestBuildRead_SubqueryJoin_UsesParentNamespaceAsRoot: when both sides of
// the join are derived tables, the root name/table come from the parent
// namespace of the join link rather than from the source name.
func TestBuildRead_SubqueryJoin_UsesParentNamespaceAsRoot(t *testing.T) {
	sqlText := `SELECT session.*
FROM (SELECT * FROM session WHERE user_id = $criteria.AppendBinding($Unsafe.Jwt.UserID)) session
JOIN (SELECT * FROM session/attributes) attribute ON attribute.user_id = session.user_id`
	view, _, err := BuildRead("system/session", sqlText)
	require.NoError(t, err)
	require.NotNil(t, view)
	assert.Equal(t, "session", view.Name)
	assert.Equal(t, "session", view.Table)
	require.NotEmpty(t, view.Relations)
	assert.Equal(t, "attribute", view.Relations[0].Ref)
}

// TestNormalizeParserSQL: template selectors are replaced with "1" so the
// result is parseable SQL.
func TestNormalizeParserSQL(t *testing.T) {
	input := "SELECT * FROM session WHERE user_id = $criteria.AppendBinding($Unsafe.Jwt.UserID) AND x = $Jwt.UserID"
	actual := normalizeParserSQL(input)
	assert.NotContains(t, actual, "$criteria.AppendBinding")
	assert.NotContains(t, actual, "$Jwt.UserID")
	assert.Contains(t, actual, "user_id = 1")
}

// TestNormalizeParserSQL_VeltyBlockExpression: a ${...Build("WHERE")} block
// collapses to a literal "WHERE 1" placeholder.
func TestNormalizeParserSQL_VeltyBlockExpression(t *testing.T) {
	input := `SELECT b.* FROM CI_BROWSER b ${predicate.Builder().CombineOr($predicate.FilterGroup(0, "AND")).Build("WHERE")} AND b.ARCHIVED = 0`
	actual := normalizeParserSQL(input)
	assert.NotContains(t, actual, "${predicate.Builder()")
	assert.Contains(t, actual, "SELECT b.* FROM CI_BROWSER b WHERE 1 AND b.ARCHIVED = 0")
}
// ExtractJoinRelations converts each JOIN of queryNode into a plan.Relation,
// deriving column links from equality predicates in the ON clause. Joins that
// lack an ON clause, or whose predicates cannot be translated, still produce
// a relation — annotated with warnings — plus a warning diagnostic anchored
// at the join's position in raw.
func ExtractJoinRelations(raw string, queryNode *query.Select) ([]*plan.Relation, []*dqlshape.Diagnostic) {
	if queryNode == nil || len(queryNode.Joins) == 0 {
		return nil, nil
	}
	rootAlias := rootNamespace(queryNode)
	var relations []*plan.Relation
	var diagnostics []*dqlshape.Diagnostic

	for idx, join := range queryNode.Joins {
		if join == nil {
			continue
		}
		offset := relationOffset(raw, join)
		span := pointSpan(raw, offset)
		ref, table := relationRef(join, idx+1)
		relation := &plan.Relation{
			Name:   ref,
			Holder: ExportedName(ref),
			Ref:    ref,
			Table:  table,
			Kind:   strings.TrimSpace(join.Kind),
			Raw:    strings.TrimSpace(join.Raw),
		}
		if relation.Holder == "" {
			relation.Holder = fmt.Sprintf("Rel%d", idx+1)
		}
		// No ON clause: keep the relation but flag it.
		if join.On == nil || join.On.X == nil {
			diagnostics = append(diagnostics, &dqlshape.Diagnostic{
				Code:     dqldiag.CodeRelMissingON,
				Severity: dqlshape.SeverityWarning,
				Message:  "join is missing ON condition",
				Hint:     "use explicit ON condition to derive relation links",
				Span:     span,
			})
			relation.Warnings = append(relation.Warnings, "missing ON condition")
			relations = append(relations, relation)
			continue
		}
		pairs := collectJoinPairs(join.On.X)
		// AST extraction failed: fall back to a regex scan of the stringified
		// ON expression, but only for simple selector-equality text.
		if len(pairs) == 0 {
			onExpr := strings.TrimSpace(sqlparser.Stringify(join.On.X))
			if shouldFallbackToRawJoinPairs(onExpr) {
				pairs = collectJoinPairsFromRaw(onExpr)
			}
		}
		if len(pairs) == 0 {
			diagnostics = append(diagnostics, &dqlshape.Diagnostic{
				Code:     dqldiag.CodeRelUnsupported,
				Severity: dqlshape.SeverityWarning,
				Message:  "join ON condition could not be translated into relation links",
				Hint:     "use equality predicates between concrete columns, e.g. a.id = b.ref_id",
				Span:     span,
			})
			relation.Warnings = append(relation.Warnings, "unsupported ON predicate")
			relations = append(relations, relation)
			continue
		}
		// Orient each equality pair as parent(root) -> ref(joined table);
		// ambiguous pairs produce a warning and may yield no link.
		for _, pair := range pairs {
			link, warning := orientJoinPair(pair, rootAlias, ref)
			if warning != "" {
				diagnostics = append(diagnostics, &dqlshape.Diagnostic{
					Code:     dqldiag.CodeRelAmbiguous,
					Severity: dqlshape.SeverityWarning,
					Message:  warning,
					Hint:     "use explicit aliases so one side belongs to root and the other to joined table",
					Span:     span,
				})
				relation.Warnings = append(relation.Warnings, warning)
			}
			if link == nil {
				continue
			}
			relation.On = append(relation.On, link)
		}
		if len(relation.On) == 0 {
			diagnostics = append(diagnostics, &dqlshape.Diagnostic{
				Code:     dqldiag.CodeRelNoLinks,
				Severity: dqlshape.SeverityWarning,
				Message:  "join ON condition does not expose extractable column links",
				Hint:     "ensure both sides of '=' are concrete column references",
				Span:     span,
			})
			relation.Warnings = append(relation.Warnings, "no extractable links")
		}
		relations = append(relations, relation)
	}
	return relations, diagnostics
}

// collectJoinPairsFromRaw extracts "alias.column = alias.column" pairs from
// the textual ON expression via regex — the last-resort path when the AST
// walk yields nothing.
func collectJoinPairsFromRaw(input string) []joinPair {
	input = strings.TrimSpace(input)
	if input == "" {
		return nil
	}
	var result []joinPair
	for _, m := range joinSelectorEqExpr.FindAllStringSubmatch(input, -1) {
		if len(m) < 5 {
			continue
		}
		left := strings.TrimSpace(m[1] + "." + m[2])
		right := strings.TrimSpace(m[3] + "." + m[4])
		if left == "" || right == "" {
			continue
		}
		result = append(result, joinPair{left: left, right: right})
	}
	return result
}
// joinPair holds the two sides of an equality predicate from a JOIN ON
// clause, as raw "namespace.column" selector text.
type joinPair struct {
	left  string
	right string
}

// collectJoinPairs walks an ON-clause expression tree and returns the
// selector-equality pairs it contains. AND/OR nodes are descended on both
// sides; any comparison other than "=" between two selectors yields nothing.
func collectJoinPairs(n node.Node) []joinPair {
	switch actual := n.(type) {
	case *expr.Binary:
		op := strings.ToUpper(strings.TrimSpace(actual.Op))
		if op == "AND" || op == "OR" {
			left := collectJoinPairs(actual.X)
			right := collectJoinPairs(actual.Y)
			return append(left, right...)
		}
		if op != "=" {
			return nil
		}
		left := selectorName(actual.X)
		right := selectorName(actual.Y)
		if left == "" || right == "" {
			return nil
		}
		return []joinPair{{left: left, right: right}}
	case *expr.Parenthesis:
		return collectJoinPairs(actual.X)
	default:
		return nil
	}
}

// selectorName stringifies a selector node (unwrapping parentheses); any
// other node kind — literals, function calls — yields "".
func selectorName(n node.Node) string {
	switch actual := n.(type) {
	case *expr.Selector:
		return strings.TrimSpace(sqlparser.Stringify(actual))
	case *expr.Parenthesis:
		return selectorName(actual.X)
	default:
		return ""
	}
}
pair.left + "=" + pair.right, + }, "" + case leftNS == "" && rightNS == "": + return &plan.RelationLink{ + ParentNamespace: rootAlias, + ParentColumn: leftCol, + RefNamespace: refAlias, + RefColumn: rightCol, + Expression: pair.left + "=" + pair.right, + }, "join columns lack namespaces, relation orientation was inferred" + case leftNS == refAlias: + parentNS := rightNS + if parentNS == "" { + parentNS = rootAlias + } + return &plan.RelationLink{ + ParentNamespace: parentNS, + ParentColumn: rightCol, + RefNamespace: leftNS, + RefColumn: leftCol, + Expression: pair.left + "=" + pair.right, + }, "" + case rightNS == refAlias: + parentNS := leftNS + if parentNS == "" { + parentNS = rootAlias + } + return &plan.RelationLink{ + ParentNamespace: parentNS, + ParentColumn: leftCol, + RefNamespace: rightNS, + RefColumn: rightCol, + Expression: pair.left + "=" + pair.right, + }, "" + default: + return nil, fmt.Sprintf("ambiguous join link %q cannot be oriented between root=%q and ref=%q", pair.left+"="+pair.right, rootAlias, refAlias) + } +} + +func relationOffset(raw string, join *query.Join) int { + if strings.TrimSpace(raw) == "" { + return 0 + } + if join != nil && join.On != nil && join.On.X != nil { + if onExpr := strings.TrimSpace(sqlparser.Stringify(join.On.X)); onExpr != "" { + if idx := strings.Index(strings.ToLower(raw), strings.ToLower(onExpr)); idx >= 0 { + return idx + } + } + } + if join != nil && strings.TrimSpace(join.Raw) != "" { + if idx := strings.Index(strings.ToLower(raw), strings.ToLower(strings.TrimSpace(join.Raw))); idx >= 0 { + return idx + } + } + return 0 +} + +func rootNamespace(queryNode *query.Select) string { + if queryNode == nil { + return "" + } + if alias := strings.TrimSpace(queryNode.From.Alias); alias != "" { + return alias + } + if queryNode.From.X == nil { + return "" + } + root := strings.TrimSpace(sqlparser.Stringify(queryNode.From.X)) + root = strings.Trim(root, "`\"") + if root == "" { + return "" + } + if idx := 
strings.LastIndex(root, "."); idx != -1 { + root = root[idx+1:] + } + return root +} + +func relationRef(join *query.Join, ordinal int) (string, string) { + if join == nil { + return fmt.Sprintf("join_%d", ordinal), "" + } + ref := strings.TrimSpace(join.Alias) + table := "" + if join.With != nil { + table = strings.TrimSpace(sqlparser.Stringify(join.With)) + } + if ref == "" { + ref = table + if idx := strings.LastIndex(ref, "."); idx != -1 { + ref = ref[idx+1:] + } + } + ref = SanitizeName(strings.Trim(ref, "`\"")) + if ref == "" { + ref = fmt.Sprintf("join_%d", ordinal) + } + return ref, table +} + +func splitSelector(selector string) (string, string) { + selector = strings.TrimSpace(selector) + if selector == "" { + return "", "" + } + selector = strings.Trim(selector, "`\"") + if idx := strings.Index(selector, "."); idx != -1 { + return strings.Trim(selector[:idx], "`\""), strings.Trim(selector[idx+1:], "`\"") + } + return "", selector +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} diff --git a/repository/shape/compile/pipeline/relation_test.go b/repository/shape/compile/pipeline/relation_test.go new file mode 100644 index 00000000..62f5c9d3 --- /dev/null +++ b/repository/shape/compile/pipeline/relation_test.go @@ -0,0 +1,89 @@ +package pipeline + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + "github.com/viant/sqlparser" +) + +func TestExtractJoinRelations(t *testing.T) { + sqlText := "SELECT o.id FROM orders o JOIN order_items i ON o.id = i.order_id" + queryNode, err := sqlparser.ParseQuery(sqlText) + require.NoError(t, err) + relations, diags := ExtractJoinRelations(sqlText, queryNode) + require.Len(t, relations, 1) + assert.Equal(t, "i", relations[0].Ref) + require.Len(t, relations[0].On, 1) + assert.Equal(t, 
"id", relations[0].On[0].ParentColumn) + assert.Equal(t, "order_id", relations[0].On[0].RefColumn) + assert.Empty(t, diags) +} + +func TestExtractJoinRelations_UnsupportedPredicate(t *testing.T) { + sqlText := "SELECT o.id FROM orders o JOIN order_items i ON o.id > i.order_id" + queryNode, err := sqlparser.ParseQuery(sqlText) + require.NoError(t, err) + _, diags := ExtractJoinRelations(sqlText, queryNode) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeRelUnsupported, diags[0].Code) +} + +func TestExtractJoinRelations_WithAndLiteral(t *testing.T) { + sqlText := "SELECT t.id FROM taxonomy t LEFT JOIN provider p ON p.id = t.provider_id AND 1=1" + queryNode, err := sqlparser.ParseQuery(sqlText) + require.NoError(t, err) + relations, diags := ExtractJoinRelations(sqlText, queryNode) + require.Len(t, relations, 1) + require.Len(t, relations[0].On, 1) + assert.Equal(t, "provider_id", relations[0].On[0].ParentColumn) + assert.Equal(t, "id", relations[0].On[0].RefColumn) + assert.Empty(t, diags) +} + +func TestExtractJoinRelations_NonRootParentChain(t *testing.T) { + sqlText := "SELECT sl.id FROM site_list sl JOIN site_list_match m ON m.site_list_id = sl.id JOIN ci_site s ON s.id = m.site_id JOIN ci_publisher p ON p.id = s.publisher_id" + queryNode, err := sqlparser.ParseQuery(sqlText) + require.NoError(t, err) + relations, diags := ExtractJoinRelations(sqlText, queryNode) + require.Len(t, relations, 3) + + require.Len(t, relations[0].On, 1) + assert.Equal(t, "sl", relations[0].On[0].ParentNamespace) + assert.Equal(t, "id", relations[0].On[0].ParentColumn) + assert.Equal(t, "m", relations[0].On[0].RefNamespace) + assert.Equal(t, "site_list_id", relations[0].On[0].RefColumn) + + require.Len(t, relations[1].On, 1) + assert.Equal(t, "m", relations[1].On[0].ParentNamespace) + assert.Equal(t, "site_id", relations[1].On[0].ParentColumn) + assert.Equal(t, "s", relations[1].On[0].RefNamespace) + assert.Equal(t, "id", relations[1].On[0].RefColumn) + + require.Len(t, 
relations[2].On, 1) + assert.Equal(t, "s", relations[2].On[0].ParentNamespace) + assert.Equal(t, "publisher_id", relations[2].On[0].ParentColumn) + assert.Equal(t, "p", relations[2].On[0].RefNamespace) + assert.Equal(t, "id", relations[2].On[0].RefColumn) + assert.Empty(t, diags) +} + +func TestExtractJoinRelations_DoesNotFallbackForComplexRawPredicate(t *testing.T) { + sqlText := "SELECT o.id FROM orders o JOIN order_items i ON COALESCE(o.id, 0) = i.order_id" + queryNode, err := sqlparser.ParseQuery(sqlText) + require.NoError(t, err) + relations, diags := ExtractJoinRelations(sqlText, queryNode) + require.Len(t, relations, 1) + assert.Empty(t, relations[0].On) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeRelUnsupported, diags[0].Code) +} + +func TestShouldFallbackToRawJoinPairs(t *testing.T) { + assert.True(t, shouldFallbackToRawJoinPairs("o.id = i.order_id")) + assert.False(t, shouldFallbackToRawJoinPairs("COALESCE(o.id, 0) = i.order_id")) + assert.False(t, shouldFallbackToRawJoinPairs("`o`.`id` = `i`.`order_id`")) + assert.False(t, shouldFallbackToRawJoinPairs(`"o"."id" = "i"."order_id"`)) +} diff --git a/repository/shape/compile/pipeline/table.go b/repository/shape/compile/pipeline/table.go new file mode 100644 index 00000000..5888aeae --- /dev/null +++ b/repository/shape/compile/pipeline/table.go @@ -0,0 +1,21 @@ +package pipeline + +import "strings" + +// InferTableFromSQL infers root table from SQL text using parser-first strategy. 
+func InferTableFromSQL(sqlText string) string { + sqlText = strings.TrimSpace(sqlText) + if sqlText == "" { + return "" + } + normalized := normalizeParserSQL(sqlText) + queryNode, _, err := ParseSelectWithDiagnostic(normalized) + if err != nil || queryNode == nil { + return "" + } + _, table, err := InferRoot(queryNode, "") + if err != nil { + return "" + } + return strings.TrimSpace(strings.Trim(table, "`\"")) +} diff --git a/repository/shape/compile/policy.go b/repository/shape/compile/policy.go new file mode 100644 index 00000000..1bd02ca6 --- /dev/null +++ b/repository/shape/compile/policy.go @@ -0,0 +1,48 @@ +package compile + +import ( + "strings" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" +) + +func hasEscalationWarnings(diags []*dqlshape.Diagnostic) bool { + for _, item := range diags { + if item == nil { + continue + } + if item.Severity != dqlshape.SeverityWarning { + continue + } + if strings.HasPrefix(item.Code, dqldiag.PrefixRel) || strings.HasPrefix(item.Code, dqldiag.PrefixSQLI) { + return true + } + } + return false +} + +func hasErrorDiagnostics(diags []*dqlshape.Diagnostic) bool { + for _, item := range diags { + if item == nil { + continue + } + if item.Severity == dqlshape.SeverityError { + return true + } + } + return false +} + +func filterEscalationDiagnostics(diags []*dqlshape.Diagnostic) []*dqlshape.Diagnostic { + var result []*dqlshape.Diagnostic + for _, item := range diags { + if item == nil { + continue + } + if strings.HasPrefix(item.Code, dqldiag.PrefixRel) || strings.HasPrefix(item.Code, dqldiag.PrefixSQLI) { + result = append(result, item) + } + } + return result +} diff --git a/repository/shape/compile/policy_test.go b/repository/shape/compile/policy_test.go new file mode 100644 index 00000000..a15e9f88 --- /dev/null +++ b/repository/shape/compile/policy_test.go @@ -0,0 +1,40 @@ +package compile + +import ( + "testing" + + 
"github.com/stretchr/testify/assert" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" +) + +func TestPolicy_HasEscalationWarnings(t *testing.T) { + diags := []*dqlshape.Diagnostic{ + {Code: dqldiag.CodeRelAmbiguous, Severity: dqlshape.SeverityWarning}, + } + assert.True(t, hasEscalationWarnings(diags)) + assert.False(t, hasEscalationWarnings([]*dqlshape.Diagnostic{ + {Code: dqldiag.CodeViewMissingSQL, Severity: dqlshape.SeverityWarning}, + })) +} + +func TestPolicy_HasErrorDiagnostics(t *testing.T) { + assert.True(t, hasErrorDiagnostics([]*dqlshape.Diagnostic{ + {Code: dqldiag.CodeParseSyntax, Severity: dqlshape.SeverityError}, + })) + assert.False(t, hasErrorDiagnostics([]*dqlshape.Diagnostic{ + {Code: dqldiag.CodeRelAmbiguous, Severity: dqlshape.SeverityWarning}, + })) +} + +func TestPolicy_FilterEscalationDiagnostics(t *testing.T) { + diags := []*dqlshape.Diagnostic{ + {Code: dqldiag.CodeViewMissingSQL, Severity: dqlshape.SeverityWarning}, + {Code: dqldiag.CodeSQLIRawSelector, Severity: dqlshape.SeverityWarning}, + {Code: dqldiag.CodeRelNoLinks, Severity: dqlshape.SeverityWarning}, + } + filtered := filterEscalationDiagnostics(diags) + assert.Len(t, filtered, 2) + assert.Equal(t, dqldiag.CodeSQLIRawSelector, filtered[0].Code) + assert.Equal(t, dqldiag.CodeRelNoLinks, filtered[1].Code) +} diff --git a/repository/shape/compile/preprocess_handler.go b/repository/shape/compile/preprocess_handler.go new file mode 100644 index 00000000..bea319f7 --- /dev/null +++ b/repository/shape/compile/preprocess_handler.go @@ -0,0 +1,150 @@ +package compile + +import ( + "os" + "path/filepath" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/compile/pipeline" + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" + "github.com/viant/datly/repository/shape/plan" +) + +type 
handlerPreprocessResult struct { + Pre *dqlpre.Result + Statements dqlstmt.Statements + Decision pipeline.Decision + LegacyViews []*plan.View + EffectiveSource *shape.Source + ForceLegacyContract bool +} + +func buildHandlerIfNeeded(source *shape.Source, pre *dqlpre.Result, statements dqlstmt.Statements, decision pipeline.Decision, layout compilePathLayout) *handlerPreprocessResult { + ret := &handlerPreprocessResult{ + Pre: pre, + Statements: statements, + Decision: decision, + EffectiveSource: source, + } + if source == nil { + return ret + } + unknownOnly := decision.HasUnknown && !decision.HasRead && !decision.HasExec + if !unknownOnly && !isHandlerSignal(source) { + return ret + } + if buildHandlerFromContractIfNeeded(ret, source, layout) { + return ret + } + if buildGeneratedFallbackIfNeeded(ret, source, layout) { + return ret + } + return ret +} + +func buildHandlerFromContractIfNeeded(ret *handlerPreprocessResult, source *shape.Source, layout compilePathLayout) bool { + if ret == nil || source == nil { + return false + } + return buildLegacyRouteFallbackIfNeeded(ret, source, layout) +} + +func buildGeneratedFallbackIfNeeded(ret *handlerPreprocessResult, source *shape.Source, layout compilePathLayout) bool { + if ret == nil || source == nil { + return false + } + if alternate := resolveGeneratedLegacySource(source); alternate != nil { + if buildLegacyRouteFallbackIfNeeded(ret, alternate, layout) { + return true + } + } + generated := strings.TrimSpace(resolveGeneratedCompanionDQL(source)) + if generated == "" { + return false + } + candidate := dqlpre.Prepare(generated) + if strings.TrimSpace(candidate.SQL) == "" { + return false + } + candidateStatements := dqlstmt.New(candidate.SQL) + candidateDecision := pipeline.Classify(candidateStatements) + if !candidateDecision.HasRead && !candidateDecision.HasExec { + return false + } + ret.Pre = candidate + ret.Statements = candidateStatements + ret.Decision = candidateDecision + return true +} + +func 
buildLegacyRouteFallbackIfNeeded(ret *handlerPreprocessResult, source *shape.Source, layout compilePathLayout) bool { + if ret == nil || source == nil { + return false + } + legacyFallbackViews := resolveLegacyRouteViewsWithLayout(source, layout) + if len(legacyFallbackViews) == 0 { + return false + } + ret.LegacyViews = legacyFallbackViews + ret.EffectiveSource = source + ret.ForceLegacyContract = true + return true +} + +func resolveGeneratedLegacySource(source *shape.Source) *shape.Source { + if source == nil || strings.TrimSpace(source.Path) == "" { + return nil + } + path := filepath.Clean(source.Path) + normalized := filepath.ToSlash(path) + genIdx := strings.Index(normalized, "/gen/") + if genIdx == -1 { + return nil + } + prefix := normalized[:genIdx] + suffix := strings.TrimPrefix(normalized[genIdx+len("/gen/"):], "/") + parts := strings.Split(suffix, "/") + if len(parts) < 2 { + return nil + } + fileName := parts[len(parts)-1] + stem := strings.TrimSuffix(fileName, filepath.Ext(fileName)) + candidates := []string{ + filepath.FromSlash(prefix + "/" + fileName), + filepath.FromSlash(prefix + "/" + stem + ".sql"), + filepath.FromSlash(prefix + "/" + stem + ".dql"), + } + for _, candidate := range candidates { + data, err := os.ReadFile(candidate) + if err != nil { + continue + } + clone := *source + clone.Path = candidate + clone.DQL = string(data) + return &clone + } + return nil +} + +func isHandlerSignal(source *shape.Source) bool { + if source == nil { + return false + } + settings := extractRuleSettings(source) + if settings != nil { + if strings.TrimSpace(settings.Type) != "" { + return true + } + if method := strings.TrimSpace(strings.ToUpper(settings.Method)); method != "" && method != "GET" { + return true + } + if strings.Contains(strings.ToLower(strings.TrimSpace(settings.URI)), "/proxy") { + return true + } + } + raw := strings.ToLower(strings.TrimSpace(source.DQL)) + return strings.Contains(raw, "$nop(") || strings.Contains(raw, "$proxy(") +} 
diff --git a/repository/shape/compile/preprocess_handler_test.go b/repository/shape/compile/preprocess_handler_test.go new file mode 100644 index 00000000..7d4e5782 --- /dev/null +++ b/repository/shape/compile/preprocess_handler_test.go @@ -0,0 +1,143 @@ +package compile + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/compile/pipeline" + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" +) + +func TestIsHandlerSignal(t *testing.T) { + assert.True(t, isHandlerSignal(&shape.Source{DQL: `/* {"Type":"campaign/patch.Handler"} */`})) + assert.True(t, isHandlerSignal(&shape.Source{DQL: `$Nop($Data)`})) + assert.True(t, isHandlerSignal(&shape.Source{DQL: `$Proxy($Data)`})) + assert.False(t, isHandlerSignal(&shape.Source{DQL: `SELECT id FROM proxy_audit`})) + assert.False(t, isHandlerSignal(&shape.Source{DQL: `/* proxy disabled */ SELECT 1`})) + assert.False(t, isHandlerSignal(&shape.Source{DQL: `SELECT 1`})) +} + +func TestBuildHandlerFromContractIfNeeded_LegacyFallbackViews(t *testing.T) { + tempDir := t.TempDir() + sourcePath := filepath.Join(tempDir, "dql", "platform", "campaign", "post.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) + dql := `/* {"Type":"campaign/patch.Handler","Connector":"ci_ads"} */` + require.NoError(t, os.WriteFile(sourcePath, []byte(dql), 0o644)) + + routeDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "platform", "campaign", "patch", "post") + require.NoError(t, os.MkdirAll(routeDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(routeDir, "post.sql"), []byte(`SELECT 1`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routeDir, "CurCampaign.sql"), []byte(`SELECT * FROM CI_CAMPAIGN`), 0o644)) + + source := &shape.Source{Path: sourcePath, DQL: 
dql} + pre := dqlpre.Prepare(source.DQL) + statements := dqlstmt.New(pre.SQL) + decision := pipeline.Classify(statements) + result := &handlerPreprocessResult{Pre: pre, Statements: statements, Decision: decision, EffectiveSource: source} + applied := buildHandlerFromContractIfNeeded(result, source, defaultCompilePathLayout()) + require.True(t, applied) + require.NotNil(t, result) + require.NotEmpty(t, result.LegacyViews) + assert.Equal(t, "post", result.LegacyViews[0].Name) +} + +func TestBuildGeneratedFallbackIfNeeded_GeneratedCompanion(t *testing.T) { + tempDir := t.TempDir() + dqlPath := filepath.Join(tempDir, "platform", "adorder", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Join(filepath.Dir(dqlPath), "gen", "adorder"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(filepath.Dir(dqlPath), "gen", "adorder", "patch.dql"), []byte("SELECT o.id FROM ORDERS o"), 0o644)) + source := &shape.Source{ + Name: "patch", + Path: dqlPath, + DQL: `/* {"Type":"adorder/patch.Handler"} */`, + } + pre := dqlpre.Prepare(source.DQL) + statements := dqlstmt.New(pre.SQL) + decision := pipeline.Classify(statements) + result := &handlerPreprocessResult{Pre: pre, Statements: statements, Decision: decision, EffectiveSource: source} + applied := buildGeneratedFallbackIfNeeded(result, source, defaultCompilePathLayout()) + require.True(t, applied) + require.NotNil(t, result) + assert.Empty(t, result.LegacyViews) + assert.Contains(t, result.Pre.SQL, "SELECT o.id FROM ORDERS o") + assert.True(t, result.Decision.HasRead) +} + +func TestResolveGeneratedLegacySource(t *testing.T) { + tempDir := t.TempDir() + genPath := filepath.Join(tempDir, "dql", "system", "session", "gen", "session", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(genPath), 0o755)) + require.NoError(t, os.WriteFile(genPath, []byte(`/* {"Method":"PATCH","URI":"/v1/api/system/session"} */`), 0o644)) + legacySQL := filepath.Join(tempDir, "dql", "system", "session", "patch.sql") + require.NoError(t, 
os.MkdirAll(filepath.Dir(legacySQL), 0o755)) + require.NoError(t, os.WriteFile(legacySQL, []byte(`/* {"Type":"session/patch.Handler"} */`), 0o644)) + + source := &shape.Source{Path: genPath, DQL: `/* {"Method":"PATCH","URI":"/v1/api/system/session"} */`} + actual := resolveGeneratedLegacySource(source) + require.NotNil(t, actual) + assert.Equal(t, legacySQL, actual.Path) + assert.Contains(t, actual.DQL, `"Type":"session/patch.Handler"`) +} + +func TestBuildGeneratedFallbackIfNeeded_GeneratedLegacyRoute(t *testing.T) { + tempDir := t.TempDir() + genPath := filepath.Join(tempDir, "dql", "system", "session", "gen", "session", "patch.dql") + require.NoError(t, os.MkdirAll(filepath.Dir(genPath), 0o755)) + require.NoError(t, os.WriteFile(genPath, []byte(`/* {"Method":"PATCH","URI":"/v1/api/system/session"} */`), 0o644)) + legacySQL := filepath.Join(tempDir, "dql", "system", "session", "patch.sql") + require.NoError(t, os.MkdirAll(filepath.Dir(legacySQL), 0o755)) + require.NoError(t, os.WriteFile(legacySQL, []byte(`/* {"Type":"session/patch.Handler","Connector":"system"} */`), 0o644)) + + routesDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "system", "session", "patch") + require.NoError(t, os.MkdirAll(routesDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(filepath.Dir(routesDir), "patch.yaml"), []byte(`Resource: + Views: + - Name: patch + Mode: SQLExec + Connector: + Ref: system + Template: + SourceURL: patch/patch.sql + Parameters: + - Name: Session + In: + Kind: body + Name: data + Types: + - Name: Input + DataType: "*Input" + Package: session/patch +`), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "patch.sql"), []byte(`$Nop($Unsafe.Session)`), 0o644)) + + source := &shape.Source{Path: genPath, DQL: `/* {"Method":"PATCH","URI":"/v1/api/system/session"} */`} + pre := dqlpre.Prepare(source.DQL) + statements := dqlstmt.New(pre.SQL) + decision := pipeline.Classify(statements) + result := &handlerPreprocessResult{Pre: pre, 
Statements: statements, Decision: decision, EffectiveSource: source} + applied := buildGeneratedFallbackIfNeeded(result, source, defaultCompilePathLayout()) + require.True(t, applied) + require.NotNil(t, result) + require.True(t, result.ForceLegacyContract) + require.NotNil(t, result.EffectiveSource) + assert.Equal(t, legacySQL, result.EffectiveSource.Path) + require.NotEmpty(t, result.LegacyViews) + assert.Equal(t, "patch", result.LegacyViews[0].Name) +} + +func TestBuildLegacyRouteFallbackIfNeeded_NoLegacyRoute(t *testing.T) { + source := &shape.Source{Path: filepath.Join(t.TempDir(), "dql", "x", "y", "z.dql"), DQL: `SELECT 1`} + pre := dqlpre.Prepare(source.DQL) + statements := dqlstmt.New(pre.SQL) + decision := pipeline.Classify(statements) + result := &handlerPreprocessResult{Pre: pre, Statements: statements, Decision: decision, EffectiveSource: source} + applied := buildLegacyRouteFallbackIfNeeded(result, source, defaultCompilePathLayout()) + assert.False(t, applied) + assert.Empty(t, result.LegacyViews) + assert.False(t, result.ForceLegacyContract) +} diff --git a/repository/shape/compile/span.go b/repository/shape/compile/span.go new file mode 100644 index 00000000..154ff9b2 --- /dev/null +++ b/repository/shape/compile/span.go @@ -0,0 +1,10 @@ +package compile + +import ( + dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" +) + +func relationSpan(raw string, offset int) dqlshape.Span { + return dqlpre.PointSpan(raw, offset) +} diff --git a/repository/shape/compile/statedecl.go b/repository/shape/compile/statedecl.go new file mode 100644 index 00000000..eab76c64 --- /dev/null +++ b/repository/shape/compile/statedecl.go @@ -0,0 +1,223 @@ +package compile + +import ( + "strconv" + "strings" + + "github.com/viant/datly/repository/shape/plan" +) + +func appendDeclaredStates(rawDQL string, result *plan.Result) { + if result == nil || strings.TrimSpace(rawDQL) == "" { + return + } + seen 
:= map[string]bool{} + for _, block := range extractSetBlocks(rawDQL) { + holder, kind, location, tail, ok := parseSetDeclarationBody(block.Body) + if !ok { + continue + } + if kind == "view" || kind == "data_view" { + continue + } + key := declaredStateKey(holder, kind, location) + if seen[key] { + continue + } + state := &plan.State{ + Path: holder, + Name: holder, + Kind: kind, + In: location, + } + switch strings.ToLower(kind) { + case "query": + required := false + state.Required = &required + case "header": + required := true + state.Required = &required + } + applyDeclaredStateOptions(state, tail) + result.States = append(result.States, state) + seen[key] = true + } +} + +func declaredStateKey(name, kind, in string) string { + return strings.ToLower(strings.TrimSpace(name)) + "|" + + strings.ToLower(strings.TrimSpace(kind)) + "|" + + strings.ToLower(strings.TrimSpace(in)) +} + +func applyDeclaredStateOptions(state *plan.State, tail string) { + if state == nil || strings.TrimSpace(tail) == "" { + return + } + cursor := newOptionCursor(tail) + for cursor.next() { + name, args := cursor.option() + switch { + case strings.EqualFold(name, "WithURI"): + if len(args) == 1 { + state.URI = trimQuote(args[0]) + } + case strings.EqualFold(name, "Optional"): + required := false + state.Required = &required + case strings.EqualFold(name, "Required"): + required := true + state.Required = &required + case strings.EqualFold(name, "Cacheable"): + if len(args) == 1 { + if value, err := strconv.ParseBool(strings.TrimSpace(trimQuote(args[0]))); err == nil { + state.Cacheable = &value + } + } + case strings.EqualFold(name, "QuerySelector"): + if len(args) == 1 { + state.QuerySelector = trimQuote(args[0]) + if state.Cacheable == nil { + cacheable := false + state.Cacheable = &cacheable + } + } + case strings.EqualFold(name, "WithPredicate"), strings.EqualFold(name, "Predicate"): + appendStatePredicate(state, args, false) + case strings.EqualFold(name, "EnsurePredicate"): + 
appendStatePredicate(state, args, true) + case strings.EqualFold(name, "When"): + if len(args) == 1 { + state.When = trimQuote(args[0]) + } + case strings.EqualFold(name, "Scope"): + if len(args) == 1 { + state.Scope = trimQuote(args[0]) + } + case strings.EqualFold(name, "WithType"): + if len(args) == 1 { + state.DataType = trimQuote(args[0]) + } + case strings.EqualFold(name, "Value"): + if len(args) == 1 { + state.Value = trimQuote(args[0]) + } + case strings.EqualFold(name, "Async"): + state.Async = true + } + } +} + +func appendStatePredicate(state *plan.State, args []string, ensure bool) { + if state == nil || len(args) == 0 { + return + } + group := 0 + nameIdx := 0 + if len(args) >= 2 { + if parsed, err := strconv.Atoi(strings.TrimSpace(trimQuote(args[0]))); err == nil { + group = parsed + nameIdx = 1 + } + } + if len(args) <= nameIdx { + return + } + predicate := &plan.StatePredicate{ + Group: group, + Name: trimQuote(args[nameIdx]), + Ensure: ensure, + Arguments: []string{}, + } + for _, arg := range args[nameIdx+1:] { + predicate.Arguments = append(predicate.Arguments, trimQuote(arg)) + } + state.Predicates = append(state.Predicates, predicate) +} + +type optionCursor struct { + raw string + cursor int + name string + args []string +} + +func newOptionCursor(raw string) *optionCursor { + return &optionCursor{raw: raw} +} + +func (o *optionCursor) next() bool { + o.name = "" + o.args = nil + for o.cursor < len(o.raw) && (o.raw[o.cursor] == ' ' || o.raw[o.cursor] == '\n' || o.raw[o.cursor] == '\t' || o.raw[o.cursor] == '\r') { + o.cursor++ + } + if o.cursor >= len(o.raw) || o.raw[o.cursor] != '.' 
{ + return false + } + o.cursor++ + start := o.cursor + for o.cursor < len(o.raw) { + ch := o.raw[o.cursor] + if (ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '_' { + o.cursor++ + continue + } + break + } + if o.cursor == start { + return false + } + o.name = strings.TrimSpace(o.raw[start:o.cursor]) + for o.cursor < len(o.raw) && (o.raw[o.cursor] == ' ' || o.raw[o.cursor] == '\n' || o.raw[o.cursor] == '\t' || o.raw[o.cursor] == '\r') { + o.cursor++ + } + if o.cursor >= len(o.raw) || o.raw[o.cursor] != '(' { + return false + } + groupStart := o.cursor + depth := 0 + inSingle := false + inDouble := false + escape := false + for o.cursor < len(o.raw) { + ch := o.raw[o.cursor] + if escape { + escape = false + o.cursor++ + continue + } + switch ch { + case '\\': + escape = true + case '\'': + if !inDouble { + inSingle = !inSingle + } + case '"': + if !inSingle { + inDouble = !inDouble + } + case '(': + if !inSingle && !inDouble { + depth++ + } + case ')': + if !inSingle && !inDouble { + depth-- + if depth == 0 { + o.cursor++ + content := o.raw[groupStart+1 : o.cursor-1] + o.args = splitArgs(content) + return true + } + } + } + o.cursor++ + } + return false +} + +func (o *optionCursor) option() (string, []string) { + return o.name, o.args +} diff --git a/repository/shape/compile/statedecl_test.go b/repository/shape/compile/statedecl_test.go new file mode 100644 index 00000000..a538e7c3 --- /dev/null +++ b/repository/shape/compile/statedecl_test.go @@ -0,0 +1,71 @@ +package compile + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape/plan" +) + +func TestAppendDeclaredStates(t *testing.T) { + dql := ` +#set($_ = $Jwt(header/Authorization).WithCodec(JwtClaim).WithStatusCode(401)) +#set($_ = $Name(query/name).WithPredicate(0,'contains','sl','NAME').Optional()) +#set($_ = $Fields<[]string>(query/fields).QuerySelector(site_list)) +#set($_ 
= $Meta(output/summary)) +SELECT id FROM SITE_LIST sl` + result := &plan.Result{} + appendDeclaredStates(dql, result) + require.NotEmpty(t, result.States) + + byName := map[string]*plan.State{} + for _, item := range result.States { + if item != nil { + byName[item.Name] = item + } + } + require.NotNil(t, byName["Jwt"]) + assert.Equal(t, "header", byName["Jwt"].Kind) + require.NotNil(t, byName["Jwt"].Required) + assert.True(t, *byName["Jwt"].Required) + + require.NotNil(t, byName["Name"]) + assert.Equal(t, "query", byName["Name"].Kind) + require.NotNil(t, byName["Name"].Required) + assert.False(t, *byName["Name"].Required) + require.Len(t, byName["Name"].Predicates, 1) + assert.Equal(t, "contains", byName["Name"].Predicates[0].Name) + assert.Equal(t, 0, byName["Name"].Predicates[0].Group) + + require.NotNil(t, byName["Fields"]) + assert.Equal(t, "site_list", byName["Fields"].QuerySelector) + require.NotNil(t, byName["Fields"].Cacheable) + assert.False(t, *byName["Fields"].Cacheable) +} + +func TestAppendDeclaredStates_DuplicateDeclaration_FirstWins(t *testing.T) { + dql := ` +#set($_ = $Active(query/active).WithPredicate(0,'equal','tas','IS_TARGETABLE').Optional()) +#set($_ = $Active(query/active).WithPredicate(0,'equal','tas','ACTIVE').Optional()) +SELECT id FROM CI_TV_AFFILIATE_STATION tas` + result := &plan.Result{} + appendDeclaredStates(dql, result) + require.Len(t, result.States, 1) + require.Len(t, result.States[0].Predicates, 1) + assert.Equal(t, "Active", result.States[0].Name) + assert.Equal(t, "IS_TARGETABLE", result.States[0].Predicates[0].Arguments[1]) +} + +func TestAppendDeclaredStates_SupportsDefineDirective(t *testing.T) { + dql := ` +#define($_ = $Auth(header/Authorization).Required()) +SELECT id FROM USERS u` + result := &plan.Result{} + appendDeclaredStates(dql, result) + require.Len(t, result.States, 1) + assert.Equal(t, "Auth", result.States[0].Name) + assert.Equal(t, "header", result.States[0].Kind) + require.NotNil(t, 
result.States[0].Required) + assert.True(t, *result.States[0].Required) +} diff --git a/repository/shape/compile/typectx_defaults.go b/repository/shape/compile/typectx_defaults.go new file mode 100644 index 00000000..5bc0a9d9 --- /dev/null +++ b/repository/shape/compile/typectx_defaults.go @@ -0,0 +1,158 @@ +package compile + +import ( + "os" + "path" + "path/filepath" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/typectx" + "golang.org/x/mod/modfile" +) + +func applyTypeContextDefaults(ctx *typectx.Context, source *shape.Source, opts *shape.CompileOptions, layout compilePathLayout) *typectx.Context { + ret := cloneTypeContext(ctx) + if shouldInferTypeContext(opts) { + ret = mergeTypeContext(ret, inferDatlyGenTypeContext(source, layout)) + } + if opts != nil { + ret = ensureTypeContext(ret) + if ret != nil { + if value := strings.TrimSpace(opts.TypePackageDir); value != "" { + ret.PackageDir = value + } + if value := strings.TrimSpace(opts.TypePackageName); value != "" { + ret.PackageName = value + } + if value := strings.TrimSpace(opts.TypePackagePath); value != "" { + ret.PackagePath = value + } + } + } + return normalizeTypeContext(ret) +} + +func shouldInferTypeContext(opts *shape.CompileOptions) bool { + if opts == nil || opts.InferTypeContext == nil { + return true + } + return *opts.InferTypeContext +} + +func mergeTypeContext(dst *typectx.Context, src *typectx.Context) *typectx.Context { + if src == nil { + return dst + } + dst = ensureTypeContext(dst) + if strings.TrimSpace(dst.DefaultPackage) == "" { + dst.DefaultPackage = strings.TrimSpace(src.DefaultPackage) + } + if len(dst.Imports) == 0 && len(src.Imports) > 0 { + dst.Imports = append([]typectx.Import{}, src.Imports...) 
+ } + if strings.TrimSpace(dst.PackageDir) == "" { + dst.PackageDir = strings.TrimSpace(src.PackageDir) + } + if strings.TrimSpace(dst.PackageName) == "" { + dst.PackageName = strings.TrimSpace(src.PackageName) + } + if strings.TrimSpace(dst.PackagePath) == "" { + dst.PackagePath = strings.TrimSpace(src.PackagePath) + } + return dst +} + +func inferDatlyGenTypeContext(source *shape.Source, layout compilePathLayout) *typectx.Context { + if source == nil { + return nil + } + sourcePath := strings.TrimSpace(source.Path) + if sourcePath == "" { + return nil + } + normalizedPath := filepath.ToSlash(filepath.Clean(sourcePath)) + idx := strings.Index(normalizedPath, layout.dqlMarker) + if idx == -1 { + return nil + } + projectRoot := filepath.FromSlash(strings.TrimSuffix(normalizedPath[:idx], "/")) + rel := strings.TrimPrefix(normalizedPath[idx+len(layout.dqlMarker):], "/") + if rel == "" { + return nil + } + routeDir := strings.Trim(path.Dir(rel), "/") + if routeDir == "." { + routeDir = "" + } + packageDir := "pkg" + if routeDir != "" { + packageDir = path.Join(packageDir, routeDir) + } + packageName := "main" + if routeDir != "" { + packageName = path.Base(routeDir) + } + packagePath := "" + if module := detectModulePath(projectRoot); module != "" { + packagePath = path.Join(module, packageDir) + } + return normalizeTypeContext(&typectx.Context{ + PackageDir: packageDir, + PackageName: packageName, + PackagePath: packagePath, + }) +} + +func detectModulePath(projectRoot string) string { + if strings.TrimSpace(projectRoot) == "" { + return "" + } + goModPath := filepath.Join(projectRoot, "go.mod") + data, err := os.ReadFile(goModPath) + if err != nil { + return "" + } + parsed, err := modfile.Parse(goModPath, data, nil) + if err != nil || parsed == nil || parsed.Module == nil { + return "" + } + return strings.TrimSpace(parsed.Module.Mod.Path) +} + +func ensureTypeContext(ctx *typectx.Context) *typectx.Context { + if ctx != nil { + return ctx + } + return 
&typectx.Context{} +} + +func cloneTypeContext(ctx *typectx.Context) *typectx.Context { + if ctx == nil { + return nil + } + ret := &typectx.Context{ + DefaultPackage: strings.TrimSpace(ctx.DefaultPackage), + PackageDir: strings.TrimSpace(ctx.PackageDir), + PackageName: strings.TrimSpace(ctx.PackageName), + PackagePath: strings.TrimSpace(ctx.PackagePath), + } + if len(ctx.Imports) > 0 { + ret.Imports = append([]typectx.Import{}, ctx.Imports...) + } + return ret +} + +func normalizeTypeContext(ctx *typectx.Context) *typectx.Context { + if ctx == nil { + return nil + } + if strings.TrimSpace(ctx.DefaultPackage) == "" && + len(ctx.Imports) == 0 && + strings.TrimSpace(ctx.PackageDir) == "" && + strings.TrimSpace(ctx.PackageName) == "" && + strings.TrimSpace(ctx.PackagePath) == "" { + return nil + } + return ctx +} diff --git a/repository/shape/compile/typectx_defaults_test.go b/repository/shape/compile/typectx_defaults_test.go new file mode 100644 index 00000000..4f0d01a3 --- /dev/null +++ b/repository/shape/compile/typectx_defaults_test.go @@ -0,0 +1,70 @@ +package compile + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/typectx" +) + +func TestApplyTypeContextDefaults_Matrix(t *testing.T) { + layout := defaultCompilePathLayout() + + projectDir := t.TempDir() + err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module github.vianttech.com/viant/platform\n\ngo 1.23\n"), 0o644) + require.NoError(t, err) + source := &shape.Source{ + Path: filepath.Join(projectDir, "dql", "platform", "taxonomy", "taxonomy.dql"), + } + + t.Run("inferred only", func(t *testing.T) { + got := applyTypeContextDefaults(nil, source, nil, layout) + require.NotNil(t, got) + require.Equal(t, "pkg/platform/taxonomy", got.PackageDir) + require.Equal(t, "taxonomy", got.PackageName) + require.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/taxonomy", 
got.PackagePath) + }) + + t.Run("directive context wins over inferred", func(t *testing.T) { + input := &typectx.Context{ + DefaultPackage: "github.com/acme/manual", + PackageDir: "pkg/manual", + PackageName: "manual", + PackagePath: "github.com/acme/manual", + } + got := applyTypeContextDefaults(input, source, nil, layout) + require.NotNil(t, got) + require.Equal(t, "pkg/manual", got.PackageDir) + require.Equal(t, "manual", got.PackageName) + require.Equal(t, "github.com/acme/manual", got.PackagePath) + require.Equal(t, "github.com/acme/manual", got.DefaultPackage) + }) + + t.Run("compile override wins over both", func(t *testing.T) { + input := &typectx.Context{ + PackageDir: "pkg/manual", + PackageName: "manual", + PackagePath: "github.com/acme/manual", + } + got := applyTypeContextDefaults(input, source, &shape.CompileOptions{ + TypePackageDir: "pkg/override", + TypePackageName: "override", + TypePackagePath: "github.com/acme/override", + }, layout) + require.NotNil(t, got) + require.Equal(t, "pkg/override", got.PackageDir) + require.Equal(t, "override", got.PackageName) + require.Equal(t, "github.com/acme/override", got.PackagePath) + }) + + t.Run("explicitly disable inference", func(t *testing.T) { + disabled := false + got := applyTypeContextDefaults(nil, source, &shape.CompileOptions{ + InferTypeContext: &disabled, + }, layout) + require.Nil(t, got) + }) +} diff --git a/repository/shape/compile/typectx_diagnostics.go b/repository/shape/compile/typectx_diagnostics.go new file mode 100644 index 00000000..36cc701f --- /dev/null +++ b/repository/shape/compile/typectx_diagnostics.go @@ -0,0 +1,37 @@ +package compile + +import ( + "fmt" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/typectx" +) + +func typeContextDiagnostics(ctx *typectx.Context, strict bool) []*dqlshape.Diagnostic { + issues := typectx.Validate(ctx) + if len(issues) == 0 { + 
return nil + } + severity := dqlshape.SeverityWarning + if strict { + severity = dqlshape.SeverityError + } + diags := make([]*dqlshape.Diagnostic, 0, len(issues)) + for _, issue := range issues { + if issue.Field == "" || issue.Message == "" { + continue + } + diags = append(diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeTypeCtxInvalid, + Severity: severity, + Message: fmt.Sprintf("type context %s: %s", issue.Field, issue.Message), + Hint: "set consistent TypeContext package fields or use compile overrides", + Span: dqlshape.Span{ + Start: dqlshape.Position{Line: 1, Char: 1}, + End: dqlshape.Position{Line: 1, Char: 1}, + }, + }) + } + return diags +} diff --git a/repository/shape/compile/viewdecl.go b/repository/shape/compile/viewdecl.go new file mode 100644 index 00000000..9e6c14c8 --- /dev/null +++ b/repository/shape/compile/viewdecl.go @@ -0,0 +1,107 @@ +package compile + +import ( + "fmt" + "strings" + + "github.com/viant/datly/repository/shape/compile/pipeline" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/parsly" + "github.com/viant/parsly/matcher" +) + +type declaredView struct { + Name string + SQL string + URI string + Connector string + Cardinality string + Tag string + Codec string + CodecArgs []string + HandlerName string + HandlerArgs []string + StatusCode *int + ErrorMessage string + QuerySelector string + CacheRef string + Limit *int + Cacheable *bool + When string + Scope string + DataType string + Of string + Value string + Async bool + Output bool + Predicates []declaredPredicate +} + +type declaredPredicate struct { + Name string + Source string + Ensure bool + Arguments []string +} + +const ( + vdWhitespaceToken = iota + vdSetToken + vdDefineToken + vdExprGroupToken + vdCommentToken + vdParamDeclToken + vdTypeToken + vdDotToken +) + +var ( + vdWhitespaceMatcher = parsly.NewToken(vdWhitespaceToken, "Whitespace", matcher.NewWhiteSpace()) + 
vdSetMatcher = parsly.NewToken(vdSetToken, "#set", matcher.NewFragment("#set")) + vdDefineMatcher = parsly.NewToken(vdDefineToken, "#define", matcher.NewFragment("#define")) + vdExprGroupMatcher = parsly.NewToken(vdExprGroupToken, "( ... )", matcher.NewBlock('(', ')', '\\')) + vdCommentMatcher = parsly.NewToken(vdCommentToken, "Comment", matcher.NewSeqBlock("/*", "*/")) + vdParamDeclMatcher = parsly.NewToken(vdParamDeclToken, "$_ = $", matcher.NewSpacedSet([]string{"$_ = $"})) + vdTypeMatcher = parsly.NewToken(vdTypeToken, "< ... >", matcher.NewSeqBlock("<", ">")) + vdDotMatcher = parsly.NewToken(vdDotToken, ".", matcher.NewByte('.')) +) + +func extractDeclaredViews(dql string) ([]*declaredView, []*dqlshape.Diagnostic) { + if strings.TrimSpace(dql) == "" { + return nil, nil + } + var views []*declaredView + var diags []*dqlshape.Diagnostic + for _, block := range extractSetBlocks(dql) { + holder, kind, location, tail, ok := parseSetDeclarationBody(block.Body) + if !ok { + continue + } + if kind != "view" && kind != "data_view" { + continue + } + sqlText := extractDeclarationSQL(tail) + if sqlText == "" { + diags = append(diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeViewMissingSQL, + Severity: dqlshape.SeverityWarning, + Message: fmt.Sprintf("view declaration %q has no inline SQL hint", location), + Hint: "use /* SELECT ... 
*/ in declaration comment to derive an additional view", + Span: relationSpan(dql, block.Offset), + }) + continue + } + name := pipeline.SanitizeName(location) + if name == "" { + name = pipeline.SanitizeName(holder) + } + if name == "" { + continue + } + view := &declaredView{Name: name, SQL: strings.TrimSpace(sqlText)} + applyDeclaredViewOptions(view, tail, dql, block.Offset, &diags) + views = append(views, view) + } + return views, diags +} diff --git a/repository/shape/compile/viewdecl_append.go b/repository/shape/compile/viewdecl_append.go new file mode 100644 index 00000000..dabf0fd2 --- /dev/null +++ b/repository/shape/compile/viewdecl_append.go @@ -0,0 +1,155 @@ +package compile + +import ( + "reflect" + "regexp" + "strings" + + "github.com/viant/datly/repository/shape/compile/pipeline" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/sqlparser" +) + +var summaryParentRefExpr = regexp.MustCompile(`(?i)\$View\.([a-zA-Z_][a-zA-Z0-9_]*)\.SQL\b`) + +func appendDeclaredViews(rawDQL string, result *plan.Result) { + if result == nil { + return + } + declared, diags := extractDeclaredViews(rawDQL) + if len(diags) > 0 { + result.Diagnostics = append(result.Diagnostics, diags...) 
+ } + for _, item := range declared { + if item == nil || strings.TrimSpace(item.Name) == "" || strings.TrimSpace(item.SQL) == "" { + continue + } + if parent := lookupSummaryParentView(result, item.SQL); parent != nil { + if strings.TrimSpace(parent.Summary) == "" { + parent.Summary = strings.TrimSpace(item.SQL) + } + continue + } + if _, exists := result.ViewsByName[item.Name]; exists { + continue + } + view := &plan.View{ + Path: item.Name, + Holder: item.Name, + Name: item.Name, + Table: item.Name, + SQL: item.SQL, + SQLURI: item.URI, + Connector: item.Connector, + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + Declaration: buildViewDeclaration(item), + } + if item.Cardinality != "" { + view.Cardinality = item.Cardinality + } + if queryNode, err := sqlparser.ParseQuery(item.SQL); err == nil && queryNode != nil { + if inferredName, inferredTable, err := pipeline.InferRoot(queryNode, item.Name); err == nil { + view.Name = inferredName + view.Holder = inferredName + view.Path = inferredName + view.Table = inferredTable + } + if fType, eType, card := pipeline.InferProjectionType(queryNode); fType != nil && eType != nil { + view.FieldType = fType + view.ElementType = eType + if item.Cardinality == "" { + view.Cardinality = card + } + } + } + result.Views = append(result.Views, view) + result.ViewsByName[view.Name] = view + } +} + +func lookupSummaryParentView(result *plan.Result, sqlText string) *plan.View { + if result == nil || strings.TrimSpace(sqlText) == "" { + return nil + } + matches := summaryParentRefExpr.FindStringSubmatch(sqlText) + if len(matches) < 2 { + return nil + } + parent := strings.TrimSpace(matches[1]) + if parent == "" { + return nil + } + if view, ok := result.ViewsByName[parent]; ok && view != nil { + return view + } + sanitized := pipeline.SanitizeName(parent) + if sanitized != "" { + if view, ok := result.ViewsByName[sanitized]; ok && view != nil { + 
return view + } + } + for name, view := range result.ViewsByName { + if view == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(name), parent) || (sanitized != "" && strings.EqualFold(strings.TrimSpace(name), sanitized)) { + return view + } + } + for _, view := range result.Views { + if view == nil { + continue + } + if strings.EqualFold(strings.TrimSpace(view.Name), parent) || (sanitized != "" && strings.EqualFold(strings.TrimSpace(view.Name), sanitized)) { + return view + } + } + return nil +} + +func buildViewDeclaration(item *declaredView) *plan.ViewDeclaration { + if item == nil { + return nil + } + ret := &plan.ViewDeclaration{ + Tag: item.Tag, + Codec: item.Codec, + CodecArgs: append([]string{}, item.CodecArgs...), + HandlerName: item.HandlerName, + HandlerArgs: append([]string{}, item.HandlerArgs...), + StatusCode: item.StatusCode, + ErrorMessage: item.ErrorMessage, + QuerySelector: item.QuerySelector, + CacheRef: item.CacheRef, + Limit: item.Limit, + Cacheable: item.Cacheable, + When: item.When, + Scope: item.Scope, + DataType: item.DataType, + Of: item.Of, + Value: item.Value, + Async: item.Async, + Output: item.Output, + } + if len(item.Predicates) > 0 { + ret.Predicates = make([]*plan.ViewPredicate, 0, len(item.Predicates)) + for _, predicate := range item.Predicates { + ret.Predicates = append(ret.Predicates, &plan.ViewPredicate{ + Name: predicate.Name, + Source: predicate.Source, + Ensure: predicate.Ensure, + Arguments: append([]string{}, predicate.Arguments...), + }) + } + } + if ret.Tag == "" && ret.Codec == "" && len(ret.CodecArgs) == 0 && ret.HandlerName == "" && + len(ret.HandlerArgs) == 0 && ret.StatusCode == nil && ret.ErrorMessage == "" && + ret.QuerySelector == "" && ret.CacheRef == "" && ret.Limit == nil && ret.Cacheable == nil && + ret.When == "" && ret.Scope == "" && ret.DataType == "" && ret.Of == "" && ret.Value == "" && + !ret.Async && !ret.Output && len(ret.Predicates) == 0 { + return nil + } + return ret +} diff --git 
a/repository/shape/compile/viewdecl_options.go b/repository/shape/compile/viewdecl_options.go new file mode 100644 index 00000000..dd8ea2fb --- /dev/null +++ b/repository/shape/compile/viewdecl_options.go @@ -0,0 +1,382 @@ +package compile + +import ( + "fmt" + "strconv" + "strings" + + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/parsly" +) + +func extractDeclarationSQL(fragment string) string { + cursor := parsly.NewCursor("", []byte(fragment), 0) + for cursor.Pos < cursor.InputSize { + match := cursor.MatchAfterOptional(vdWhitespaceMatcher, vdCommentMatcher) + if match.Code == vdCommentToken { + text := match.Text(cursor) + if len(text) < 4 { + return "" + } + return normalizeHintSQL(text[2 : len(text)-2]) + } + cursor.Pos++ + } + return "" +} + +func normalizeHintSQL(body string) string { + body = strings.TrimSpace(body) + if body == "" { + return "" + } + if strings.HasPrefix(body, "{") { + if closeIdx := strings.Index(body, "}"); closeIdx != -1 { + body = strings.TrimSpace(body[closeIdx+1:]) + } + } + if body == "" { + return "" + } + switch body[0] { + case '?': + body = strings.TrimSpace(body[1:]) + case '!': + body = strings.TrimSpace(body[1:]) + if strings.HasPrefix(body, "!") { + body = strings.TrimSpace(body[1:]) + } + if len(body) >= 3 { + var status int + if _, err := fmt.Sscanf(body[:3], "%d", &status); err == nil { + body = strings.TrimSpace(body[3:]) + } + } + } + return strings.TrimSpace(body) +} + +func applyDeclaredViewOptions(view *declaredView, tail, dql string, offset int, diags *[]*dqlshape.Diagnostic) { + if view == nil || strings.TrimSpace(tail) == "" { + return + } + cursor := parsly.NewCursor("", []byte(tail), 0) + for cursor.Pos < cursor.InputSize { + _ = cursor.MatchOne(vdWhitespaceMatcher) + if cursor.MatchOne(vdDotMatcher).Code != vdDotToken { + cursor.Pos++ + continue + } + _ = cursor.MatchOne(vdWhitespaceMatcher) + name, ok := 
readIdentifier(cursor) + if !ok { + continue + } + _ = cursor.MatchOne(vdWhitespaceMatcher) + group := cursor.MatchOne(vdExprGroupMatcher) + if group.Code != vdExprGroupToken { + continue + } + content := group.Text(cursor) + if len(content) < 2 { + continue + } + args := splitArgs(content[1 : len(content)-1]) + switch { + case strings.EqualFold(name, "WithURI"): + if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + continue + } + view.URI = trimQuote(args[0]) + case strings.EqualFold(name, "WithConnector"), strings.EqualFold(name, "Connector"): + if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + continue + } + view.Connector = trimQuote(args[0]) + case strings.EqualFold(name, "Cardinality"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + card := strings.ToLower(strings.TrimSpace(trimQuote(args[0]))) + switch card { + case "one", "many": + view.Cardinality = card + default: + *diags = append(*diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeViewCardinality, + Severity: dqlshape.SeverityWarning, + Message: fmt.Sprintf("unsupported cardinality %q for declared view %q", args[0], view.Name), + Hint: "use Cardinality('one') or Cardinality('many')", + Span: relationSpan(dql, offset), + }) + } + case strings.EqualFold(name, "WithTag"), strings.EqualFold(name, "Tag"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.Tag = trimQuote(args[0]) + case strings.EqualFold(name, "WithCodec"), strings.EqualFold(name, "Codec"): + if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + continue + } + view.Codec = trimQuote(args[0]) + view.CodecArgs = nil + for _, arg := range args[1:] { + view.CodecArgs = append(view.CodecArgs, strings.TrimSpace(arg)) + } + case strings.EqualFold(name, "WithHandler"), strings.EqualFold(name, "Handler"): + if !expectArgs(view, name, args, 1, -1, dql, offset, diags) { + continue + } + view.HandlerName = trimQuote(args[0]) + view.HandlerArgs = nil + 
for _, arg := range args[1:] { + view.HandlerArgs = append(view.HandlerArgs, strings.TrimSpace(arg)) + } + case strings.EqualFold(name, "WithStatusCode"), strings.EqualFold(name, "StatusCode"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + statusCode, err := strconv.Atoi(strings.TrimSpace(trimQuote(args[0]))) + if err != nil { + *diags = append(*diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDeclOptionArgs, + Severity: dqlshape.SeverityWarning, + Message: fmt.Sprintf("invalid status code %q for declared view %q", args[0], view.Name), + Hint: "use numeric status code, e.g. StatusCode(400)", + Span: relationSpan(dql, offset), + }) + continue + } + view.StatusCode = &statusCode + case strings.EqualFold(name, "WithErrorMessage"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.ErrorMessage = trimQuote(args[0]) + case strings.EqualFold(name, "WithPredicate"), strings.EqualFold(name, "Predicate"): + if !expectArgs(view, name, args, 2, -1, dql, offset, diags) { + continue + } + view.Predicates = append(view.Predicates, declaredPredicate{ + Name: trimQuote(args[0]), + Source: trimQuote(args[1]), + Arguments: append([]string{}, args[2:]...), + }) + case strings.EqualFold(name, "EnsurePredicate"): + if !expectArgs(view, name, args, 2, -1, dql, offset, diags) { + continue + } + view.Predicates = append(view.Predicates, declaredPredicate{ + Name: trimQuote(args[0]), + Source: trimQuote(args[1]), + Ensure: true, + Arguments: append([]string{}, args[2:]...), + }) + case strings.EqualFold(name, "QuerySelector"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.QuerySelector = trimQuote(args[0]) + if !isAllowedQuerySelector(strings.ToLower(view.Name)) { + *diags = append(*diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDeclQuerySelector, + Severity: dqlshape.SeverityWarning, + Message: fmt.Sprintf("query selector %q can only be used with limit, offset, page, fields, 
orderby", view.QuerySelector), + Hint: "use QuerySelector on declarations named limit/offset/page/fields/orderby", + Span: relationSpan(dql, offset), + }) + } + case strings.EqualFold(name, "WithCache"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.CacheRef = trimQuote(args[0]) + case strings.EqualFold(name, "WithLimit"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + limit, err := strconv.Atoi(strings.TrimSpace(trimQuote(args[0]))) + if err != nil { + appendOptionArgDiagnostic(view, name, fmt.Sprintf("invalid integer limit %q", args[0]), dql, offset, diags) + continue + } + view.Limit = &limit + case strings.EqualFold(name, "Cacheable"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + value, err := strconv.ParseBool(strings.TrimSpace(trimQuote(args[0]))) + if err != nil { + appendOptionArgDiagnostic(view, name, fmt.Sprintf("invalid bool cacheable %q", args[0]), dql, offset, diags) + continue + } + view.Cacheable = &value + case strings.EqualFold(name, "When"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.When = trimQuote(args[0]) + case strings.EqualFold(name, "Scope"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.Scope = trimQuote(args[0]) + case strings.EqualFold(name, "WithType"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.DataType = trimQuote(args[0]) + case strings.EqualFold(name, "Of"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.Of = trimQuote(args[0]) + case strings.EqualFold(name, "Value"): + if !expectArgs(view, name, args, 1, 1, dql, offset, diags) { + continue + } + view.Value = trimQuote(args[0]) + case strings.EqualFold(name, "Async"): + if !expectArgs(view, name, args, 0, 0, dql, offset, diags) { + continue + } + view.Async = true + case strings.EqualFold(name, "Output"): + if 
!expectArgs(view, name, args, 0, 0, dql, offset, diags) { + continue + } + view.Output = true + } + } +} + +func splitArgs(raw string) []string { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + var result []string + var current strings.Builder + inSingle := false + inDouble := false + escape := false + parens := 0 + brackets := 0 + braces := 0 + for i := 0; i < len(raw); i++ { + ch := raw[i] + if escape { + current.WriteByte(ch) + escape = false + continue + } + switch ch { + case '\\': + current.WriteByte(ch) + escape = true + case '\'': + if !inDouble { + inSingle = !inSingle + } + current.WriteByte(ch) + case '"': + if !inSingle { + inDouble = !inDouble + } + current.WriteByte(ch) + case '(': + if !inSingle && !inDouble { + parens++ + } + current.WriteByte(ch) + case ')': + if !inSingle && !inDouble && parens > 0 { + parens-- + } + current.WriteByte(ch) + case '[': + if !inSingle && !inDouble { + brackets++ + } + current.WriteByte(ch) + case ']': + if !inSingle && !inDouble && brackets > 0 { + brackets-- + } + current.WriteByte(ch) + case '{': + if !inSingle && !inDouble { + braces++ + } + current.WriteByte(ch) + case '}': + if !inSingle && !inDouble && braces > 0 { + braces-- + } + current.WriteByte(ch) + case ',': + if inSingle || inDouble || parens > 0 || brackets > 0 || braces > 0 { + current.WriteByte(ch) + continue + } + part := strings.TrimSpace(current.String()) + if part != "" { + result = append(result, part) + } + current.Reset() + default: + current.WriteByte(ch) + } + } + if tail := strings.TrimSpace(current.String()); tail != "" { + result = append(result, tail) + } + return result +} + +func trimQuote(v string) string { + v = strings.TrimSpace(v) + if len(v) >= 2 { + if (v[0] == '\'' && v[len(v)-1] == '\'') || (v[0] == '"' && v[len(v)-1] == '"') { + return v[1 : len(v)-1] + } + } + return v +} + +func expectArgs(view *declaredView, option string, args []string, min, max int, dql string, offset int, diags *[]*dqlshape.Diagnostic) 
bool { + if len(args) < min { + appendOptionArgDiagnostic(view, option, fmt.Sprintf("expected at least %d args, got %d", min, len(args)), dql, offset, diags) + return false + } + if max >= 0 && len(args) > max { + appendOptionArgDiagnostic(view, option, fmt.Sprintf("expected at most %d args, got %d", max, len(args)), dql, offset, diags) + return false + } + return true +} + +func appendOptionArgDiagnostic(view *declaredView, option, detail, dql string, offset int, diags *[]*dqlshape.Diagnostic) { + viewName := "" + if view != nil { + viewName = view.Name + } + *diags = append(*diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeDeclOptionArgs, + Severity: dqlshape.SeverityWarning, + Message: fmt.Sprintf("invalid %s declaration for view %q: %s", option, viewName, detail), + Hint: "check option arity and argument formatting", + Span: relationSpan(dql, offset), + }) +} + +func isAllowedQuerySelector(name string) bool { + switch strings.ToLower(strings.TrimSpace(name)) { + case "limit", "offset", "page", "fields", "orderby": + return true + default: + return false + } +} diff --git a/repository/shape/compile/viewdecl_parity_test.go b/repository/shape/compile/viewdecl_parity_test.go new file mode 100644 index 00000000..03d45552 --- /dev/null +++ b/repository/shape/compile/viewdecl_parity_test.go @@ -0,0 +1,66 @@ +package compile + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" +) + +func TestViewDecl_ParityFixtures(t *testing.T) { + testCases := []struct { + name string + viewName string + tail string + expectDiag string + expectTag string + expectCodec string + expectHandler string + expectPreds int + }{ + { + name: "tag/codec/handler", + viewName: "limit", + tail: ".WithTag('json:\"id\"').WithCodec(AsJSON).WithHandler('Build')", + expectTag: `json:"id"`, + expectCodec: "AsJSON", + 
expectHandler: "Build", + }, + { + name: "status arg validation", + viewName: "limit", + tail: ".WithStatusCode('x')", + expectDiag: dqldiag.CodeDeclOptionArgs, + }, + { + name: "query selector validation", + viewName: "customer_id", + tail: ".QuerySelector('items')", + expectDiag: dqldiag.CodeDeclQuerySelector, + }, + { + name: "predicate forms", + viewName: "limit", + tail: ".WithPredicate('ByID','id=?',1).EnsurePredicate('Tenant','tenant=?',2)", + expectPreds: 2, + }, + } + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + view := &declaredView{Name: testCase.viewName} + var diags []*dqlshape.Diagnostic + applyDeclaredViewOptions(view, testCase.tail, "SELECT 1", 0, &diags) + if testCase.expectDiag != "" { + require.NotEmpty(t, diags) + assert.Equal(t, testCase.expectDiag, diags[0].Code) + return + } + assert.Equal(t, testCase.expectTag, view.Tag) + assert.Equal(t, testCase.expectCodec, view.Codec) + assert.Equal(t, testCase.expectHandler, view.HandlerName) + assert.Len(t, view.Predicates, testCase.expectPreds) + }) + } +} diff --git a/repository/shape/compile/viewdecl_parse.go b/repository/shape/compile/viewdecl_parse.go new file mode 100644 index 00000000..51fd45a9 --- /dev/null +++ b/repository/shape/compile/viewdecl_parse.go @@ -0,0 +1,90 @@ +package compile + +import ( + "strings" + "unicode" + + "github.com/viant/parsly" +) + +type setBlock struct { + Offset int + Body string +} + +func extractSetBlocks(dql string) []setBlock { + cursor := parsly.NewCursor("", []byte(dql), 0) + var result []setBlock + for cursor.Pos < cursor.InputSize { + matched := cursor.MatchAfterOptional(vdWhitespaceMatcher, vdSetMatcher, vdDefineMatcher) + if matched.Code != vdSetToken && matched.Code != vdDefineToken { + cursor.Pos++ + continue + } + offset := cursor.Pos - len(matched.Text(cursor)) + group := cursor.MatchAfterOptional(vdWhitespaceMatcher, vdExprGroupMatcher) + if group.Code != vdExprGroupToken { + continue + } + body := 
group.Text(cursor) + if len(body) < 2 { + continue + } + result = append(result, setBlock{ + Offset: offset, + Body: body[1 : len(body)-1], + }) + } + return result +} + +func parseSetDeclarationBody(body string) (holder, kind, location, tail string, ok bool) { + cursor := parsly.NewCursor("", []byte(body), 0) + if cursor.MatchAfterOptional(vdWhitespaceMatcher, vdParamDeclMatcher).Code != vdParamDeclToken { + return "", "", "", "", false + } + id, matched := readIdentifier(cursor) + if !matched { + return "", "", "", "", false + } + holder = id + _ = cursor.MatchOne(vdWhitespaceMatcher) + _ = cursor.MatchOne(vdTypeMatcher) + _ = cursor.MatchOne(vdWhitespaceMatcher) + kindLoc := cursor.MatchOne(vdExprGroupMatcher) + if kindLoc.Code != vdExprGroupToken { + return "", "", "", "", false + } + inGroup := kindLoc.Text(cursor) + if len(inGroup) < 2 { + return "", "", "", "", false + } + raw := strings.TrimSpace(inGroup[1 : len(inGroup)-1]) + slash := strings.Index(raw, "/") + if slash == -1 { + return "", "", "", "", false + } + kind = strings.ToLower(strings.TrimSpace(raw[:slash])) + location = strings.TrimSpace(raw[slash+1:]) + tail = strings.TrimSpace(string(cursor.Input[cursor.Pos:])) + return holder, kind, location, tail, true +} + +func readIdentifier(cursor *parsly.Cursor) (string, bool) { + if cursor.Pos >= cursor.InputSize { + return "", false + } + start := cursor.Pos + for cursor.Pos < cursor.InputSize { + ch := rune(cursor.Input[cursor.Pos]) + if ch == '_' || ch == '$' || unicode.IsLetter(ch) || unicode.IsDigit(ch) { + cursor.Pos++ + continue + } + break + } + if cursor.Pos == start { + return "", false + } + return string(cursor.Input[start:cursor.Pos]), true +} diff --git a/repository/shape/compile/viewdecl_test.go b/repository/shape/compile/viewdecl_test.go new file mode 100644 index 00000000..0136c64a --- /dev/null +++ b/repository/shape/compile/viewdecl_test.go @@ -0,0 +1,187 @@ +package compile + +import ( + "testing" + + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + dqldiag "github.com/viant/datly/repository/shape/dql/diag" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + "github.com/viant/datly/repository/shape/plan" +) + +func TestViewDecl_ExtractSetBlocks(t *testing.T) { + dql := "#set($_ = $Extra(view/extra_view) /* SELECT id FROM EXTRA e */)\n" + + "#define($_ = $Extra2(view/extra_view_2) /* SELECT id FROM EXTRA2 e */)\n" + + "SELECT id FROM ORDERS o" + blocks := extractSetBlocks(dql) + require.Len(t, blocks, 2) + assert.Contains(t, blocks[0].Body, "$Extra") + assert.Contains(t, blocks[1].Body, "$Extra2") +} + +func TestViewDecl_ParseSetDeclarationBody(t *testing.T) { + holder, kind, location, tail, ok := parseSetDeclarationBody("$_ = $Extra(view/extra_view).WithURI('/x')") + require.True(t, ok) + assert.Equal(t, "Extra", holder) + assert.Equal(t, "view", kind) + assert.Equal(t, "extra_view", location) + assert.Contains(t, tail, ".WithURI('/x')") +} + +func TestViewDecl_ApplyOptions_InvalidCardinality(t *testing.T) { + view := &declaredView{Name: "extra"} + var diags []*dqlshape.Diagnostic + applyDeclaredViewOptions(view, ".Cardinality('few')", "SELECT 1", 0, &diags) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeViewCardinality, diags[0].Code) +} + +func TestViewDecl_AppendDeclaredViews(t *testing.T) { + dql := "#set($_ = $Extra(view/extra_view).WithURI('/x') /* SELECT code FROM EXTRA e */)" + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + appendDeclaredViews(dql, result) + require.NotEmpty(t, result.Views) + found := false + for _, item := range result.Views { + if item != nil && item.SQLURI == "/x" { + found = true + break + } + } + assert.True(t, found) +} + +func TestViewDecl_ApplyOptions_Extended(t *testing.T) { + view := &declaredView{Name: "limit"} + var diags []*dqlshape.Diagnostic + tail := 
".WithTag('json:\"id\"').WithCodec(AsJSON,'x').WithHandler('Build',a,b)." + + "WithStatusCode(422).WithErrorMessage('bad req').WithPredicate('ByID','id = ?', 101)." + + "EnsurePredicate('Tenant','tenant_id = ?', 7).QuerySelector('qs').WithCache('c1').WithLimit(10)." + + "Cacheable(true).When('x > 1').Scope('team').WithType('[]Order').Of('list').Value('abc').Async().Output()" + applyDeclaredViewOptions(view, tail, "SELECT 1", 0, &diags) + + require.Empty(t, diags) + assert.Equal(t, `json:"id"`, view.Tag) + assert.Equal(t, "AsJSON", view.Codec) + require.Len(t, view.CodecArgs, 1) + assert.Equal(t, "'x'", view.CodecArgs[0]) + assert.Equal(t, "Build", view.HandlerName) + require.Len(t, view.HandlerArgs, 2) + assert.Equal(t, "a", view.HandlerArgs[0]) + assert.Equal(t, "b", view.HandlerArgs[1]) + require.NotNil(t, view.StatusCode) + assert.Equal(t, 422, *view.StatusCode) + assert.Equal(t, "bad req", view.ErrorMessage) + require.Len(t, view.Predicates, 2) + assert.Equal(t, "ByID", view.Predicates[0].Name) + assert.False(t, view.Predicates[0].Ensure) + assert.Equal(t, "Tenant", view.Predicates[1].Name) + assert.True(t, view.Predicates[1].Ensure) + assert.Equal(t, "qs", view.QuerySelector) + assert.Equal(t, "c1", view.CacheRef) + require.NotNil(t, view.Limit) + assert.Equal(t, 10, *view.Limit) + require.NotNil(t, view.Cacheable) + assert.True(t, *view.Cacheable) + assert.Equal(t, "x > 1", view.When) + assert.Equal(t, "team", view.Scope) + assert.Equal(t, "[]Order", view.DataType) + assert.Equal(t, "list", view.Of) + assert.Equal(t, "abc", view.Value) + assert.True(t, view.Async) + assert.True(t, view.Output) +} + +func TestViewDecl_ApplyOptions_QuerySelectorValidation(t *testing.T) { + view := &declaredView{Name: "customer_id"} + var diags []*dqlshape.Diagnostic + applyDeclaredViewOptions(view, ".QuerySelector('q')", "SELECT 1", 0, &diags) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeDeclQuerySelector, diags[0].Code) +} + +func TestViewDecl_SplitArgs_Nested(t 
*testing.T) { + args := splitArgs(`'a', fn(1,2), {'k': [1,2]}, "x,y"`) + require.Len(t, args, 4) + assert.Equal(t, "'a'", args[0]) + assert.Equal(t, "fn(1,2)", args[1]) + assert.Equal(t, "{'k': [1,2]}", args[2]) + assert.Equal(t, `"x,y"`, args[3]) +} + +func TestViewDecl_AppendDeclaredViews_ExtendedDeclarationMetadata(t *testing.T) { + dql := "#set($_ = $limit(view/limit).WithTag('json:\"id\"').WithCodec(AsJSON).WithHandler('Build',a)." + + "WithStatusCode(409).WithErrorMessage('conflict').WithPredicate('ByID','id=?',1)." + + "EnsurePredicate('Tenant','tenant=?',2).QuerySelector('items').WithCache('c1').WithLimit(5)." + + "Cacheable(false).When('x').Scope('s').WithType('Order').Of('o').Value('v').Async().Output() /* SELECT id FROM EXTRA e */)" + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + appendDeclaredViews(dql, result) + require.NotEmpty(t, result.Views) + var target *plan.View + for _, item := range result.Views { + if item != nil && item.Name == "e" { + target = item + break + } + } + require.NotNil(t, target) + require.NotNil(t, target.Declaration) + assert.Equal(t, `json:"id"`, target.Declaration.Tag) + assert.Equal(t, "AsJSON", target.Declaration.Codec) + assert.Equal(t, "Build", target.Declaration.HandlerName) + require.NotNil(t, target.Declaration.StatusCode) + assert.Equal(t, 409, *target.Declaration.StatusCode) + assert.Equal(t, "conflict", target.Declaration.ErrorMessage) + assert.Equal(t, "items", target.Declaration.QuerySelector) + assert.Equal(t, "c1", target.Declaration.CacheRef) + require.NotNil(t, target.Declaration.Limit) + assert.Equal(t, 5, *target.Declaration.Limit) + require.NotNil(t, target.Declaration.Cacheable) + assert.False(t, *target.Declaration.Cacheable) + assert.Equal(t, "x", target.Declaration.When) + assert.Equal(t, "s", target.Declaration.Scope) + assert.Equal(t, "Order", target.Declaration.DataType) + assert.Equal(t, "o", target.Declaration.Of) + assert.Equal(t, "v", 
target.Declaration.Value) + assert.True(t, target.Declaration.Async) + assert.True(t, target.Declaration.Output) + require.Len(t, target.Declaration.Predicates, 2) +} + +func TestViewDecl_AppendDeclaredViews_AttachSummaryFromMetaViewSQL(t *testing.T) { + root := &plan.View{Name: "Browser", Path: "Browser", Holder: "Browser"} + result := &plan.Result{ + Views: []*plan.View{root}, + ViewsByName: map[string]*plan.View{"Browser": root}, + ByPath: map[string]*plan.Field{}, + } + dql := "#set($_ = $Summary(view/summary) /* SELECT COUNT(1) CNT FROM ($View.browser.SQL) t */)" + + appendDeclaredViews(dql, result) + + require.Len(t, result.Views, 1) + require.NotNil(t, root) + assert.Contains(t, root.Summary, "COUNT(1)") + assert.Contains(t, root.Summary, "$View.browser.SQL") +} + +func TestViewDecl_AppendDeclaredViews_MetaViewSQL_NoParentFallbackToView(t *testing.T) { + result := &plan.Result{ + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + } + dql := "#set($_ = $Summary(view/summary) /* SELECT COUNT(1) CNT FROM ($View.browser.SQL) t */)" + + appendDeclaredViews(dql, result) + + require.Len(t, result.Views, 1) + assert.Empty(t, result.Views[0].Summary) + assert.NotEmpty(t, result.Views[0].Name) +} diff --git a/repository/shape/dql_engine_test.go b/repository/shape/dql_engine_test.go index fafe3f67..529748ea 100644 --- a/repository/shape/dql_engine_test.go +++ b/repository/shape/dql_engine_test.go @@ -40,3 +40,27 @@ func TestEngine_LoadDQLComponent(t *testing.T) { assert.Equal(t, "/v1/api/reports/orders", component.Name) assert.Equal(t, "t", component.RootView) } + +func TestEngine_LoadDQLComponent_DeclarationMetadata(t *testing.T) { + engine := shape.New( + shape.WithCompiler(shapeCompile.New()), + shape.WithLoader(shapeLoad.New()), + shape.WithName("/v1/api/reports/orders"), + ) + dql := ` +#set($_ = $limit(view/limit).WithPredicate('ByID','id = ?', 1).QuerySelector('items') /* SELECT id FROM ORDERS o */) +SELECT id FROM ORDERS t` + artifact, 
err := engine.LoadDQLComponent(context.Background(), dql) + require.NoError(t, err) + require.NotNil(t, artifact) + component, ok := artifact.Component.(*shapeLoad.Component) + require.True(t, ok) + require.NotNil(t, component.Declarations) + require.NotNil(t, component.QuerySelectors) + require.NotNil(t, component.Predicates) + assert.Equal(t, []string{"o"}, component.QuerySelectors["items"]) + require.NotNil(t, component.Declarations["o"]) + assert.Equal(t, "items", component.Declarations["o"].QuerySelector) + require.NotEmpty(t, component.Predicates["o"]) + assert.Equal(t, "ByID", component.Predicates["o"][0].Name) +} diff --git a/repository/shape/engine_compile_options_test.go b/repository/shape/engine_compile_options_test.go new file mode 100644 index 00000000..28de1a8f --- /dev/null +++ b/repository/shape/engine_compile_options_test.go @@ -0,0 +1,77 @@ +package shape + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type captureCompiler struct { + last CompileOptions +} + +func (c *captureCompiler) Compile(_ context.Context, source *Source, opts ...CompileOption) (*PlanResult, error) { + compiled := &CompileOptions{} + for _, opt := range opts { + if opt != nil { + opt(compiled) + } + } + c.last = *compiled + return &PlanResult{Source: source}, nil +} + +func TestEngine_Compile_UsesLegacyParityDefaults(t *testing.T) { + compiler := &captureCompiler{} + engine := New(WithCompiler(compiler)) + + _, err := engine.compile(context.Background(), &Source{Name: "orders", DQL: "SELECT 1"}) + require.NoError(t, err) + assert.False(t, compiler.last.Strict) + assert.Equal(t, CompileProfileCompat, compiler.last.Profile) + assert.Equal(t, CompileMixedModeExecWins, compiler.last.MixedMode) + assert.Equal(t, CompileUnknownNonReadWarn, compiler.last.UnknownNonReadMode) + assert.Equal(t, CompileColumnDiscoveryAuto, compiler.last.ColumnDiscoveryMode) +} + +func TestEngine_Compile_ForwardsCustomDefaults(t 
*testing.T) { + compiler := &captureCompiler{} + engine := New( + WithCompiler(compiler), + WithStrict(true), + WithCompileProfileDefault(CompileProfileStrict), + WithMixedModeDefault(CompileMixedModeReadWins), + WithUnknownNonReadModeDefault(CompileUnknownNonReadError), + WithColumnDiscoveryModeDefault(CompileColumnDiscoveryOff), + ) + + _, err := engine.compile(context.Background(), &Source{Name: "orders", DQL: "SELECT 1"}) + require.NoError(t, err) + assert.True(t, compiler.last.Strict) + assert.Equal(t, CompileProfileStrict, compiler.last.Profile) + assert.Equal(t, CompileMixedModeReadWins, compiler.last.MixedMode) + assert.Equal(t, CompileUnknownNonReadError, compiler.last.UnknownNonReadMode) + assert.Equal(t, CompileColumnDiscoveryOff, compiler.last.ColumnDiscoveryMode) +} + +func TestEngine_Compile_LegacyDefaultsOption(t *testing.T) { + compiler := &captureCompiler{} + engine := New( + WithCompiler(compiler), + WithStrict(true), + WithCompileProfileDefault(CompileProfileStrict), + WithMixedModeDefault(CompileMixedModeReadWins), + WithUnknownNonReadModeDefault(CompileUnknownNonReadError), + WithLegacyTranslatorDefaults(), + ) + + _, err := engine.compile(context.Background(), &Source{Name: "orders", DQL: "SELECT 1"}) + require.NoError(t, err) + assert.False(t, compiler.last.Strict) + assert.Equal(t, CompileProfileCompat, compiler.last.Profile) + assert.Equal(t, CompileMixedModeExecWins, compiler.last.MixedMode) + assert.Equal(t, CompileUnknownNonReadWarn, compiler.last.UnknownNonReadMode) + assert.Equal(t, CompileColumnDiscoveryAuto, compiler.last.ColumnDiscoveryMode) +} diff --git a/repository/shape/load/loader.go b/repository/shape/load/loader.go index 149117d2..55528601 100644 --- a/repository/shape/load/loader.go +++ b/repository/shape/load/loader.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/viant/datly/repository/shape" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/plan" 
"github.com/viant/datly/repository/shape/typectx" shapevalidate "github.com/viant/datly/repository/shape/validate" @@ -87,6 +88,28 @@ func buildComponent(source *shape.Source, pResult *plan.Result) *Component { continue } ret.Views = append(ret.Views, aView.Name) + if aView.Declaration != nil { + if ret.Declarations == nil { + ret.Declarations = map[string]*plan.ViewDeclaration{} + } + ret.Declarations[aView.Name] = aView.Declaration + if selector := strings.TrimSpace(aView.Declaration.QuerySelector); selector != "" { + if ret.QuerySelectors == nil { + ret.QuerySelectors = map[string][]string{} + } + ret.QuerySelectors[selector] = append(ret.QuerySelectors[selector], aView.Name) + } + if len(aView.Declaration.Predicates) > 0 { + if ret.Predicates == nil { + ret.Predicates = map[string][]*plan.ViewPredicate{} + } + ret.Predicates[aView.Name] = append(ret.Predicates[aView.Name], aView.Declaration.Predicates...) + } + } + if len(aView.Relations) > 0 { + ret.Relations = append(ret.Relations, aView.Relations...) + ret.ViewRelations = append(ret.ViewRelations, toViewRelations(aView.Relations)...) 
+ } } rootView := pickRootView(pResult.Views) if rootView != nil { @@ -117,6 +140,8 @@ func buildComponent(source *shape.Source, pResult *plan.Result) *Component { } } ret.TypeContext = cloneTypeContext(pResult.TypeContext) + ret.Directives = cloneDirectives(pResult.Directives) + ret.ColumnsDiscovery = pResult.ColumnsDiscovery return ret } @@ -126,6 +151,9 @@ func cloneTypeContext(input *typectx.Context) *typectx.Context { } ret := &typectx.Context{ DefaultPackage: strings.TrimSpace(input.DefaultPackage), + PackageDir: strings.TrimSpace(input.PackageDir), + PackageName: strings.TrimSpace(input.PackageName), + PackagePath: strings.TrimSpace(input.PackagePath), } for _, item := range input.Imports { pkg := strings.TrimSpace(item.Package) @@ -137,7 +165,38 @@ func cloneTypeContext(input *typectx.Context) *typectx.Context { Package: pkg, }) } - if ret.DefaultPackage == "" && len(ret.Imports) == 0 { + if ret.DefaultPackage == "" && + len(ret.Imports) == 0 && + ret.PackageDir == "" && + ret.PackageName == "" && + ret.PackagePath == "" { + return nil + } + return ret +} + +func cloneDirectives(input *dqlshape.Directives) *dqlshape.Directives { + if input == nil { + return nil + } + ret := &dqlshape.Directives{ + Meta: strings.TrimSpace(input.Meta), + DefaultConnector: strings.TrimSpace(input.DefaultConnector), + } + if input.Cache != nil { + ret.Cache = &dqlshape.CacheDirective{ + Enabled: input.Cache.Enabled, + TTL: strings.TrimSpace(input.Cache.TTL), + } + } + if input.MCP != nil { + ret.MCP = &dqlshape.MCPDirective{ + Name: strings.TrimSpace(input.MCP.Name), + Description: strings.TrimSpace(input.MCP.Description), + DescriptionPath: strings.TrimSpace(input.MCP.DescriptionPath), + } + } + if ret.Meta == "" && ret.DefaultConnector == "" && ret.Cache == nil && ret.MCP == nil { return nil } return ret @@ -178,7 +237,16 @@ func materializeView(item *plan.View) (*view.View, error) { } schema := newSchema(schemaType, item.Cardinality) - opts := 
[]view.Option{view.WithSchema(schema), view.WithMode(view.ModeQuery)} + mode := view.ModeQuery + switch strings.TrimSpace(item.Mode) { + case string(view.ModeExec): + mode = view.ModeExec + case string(view.ModeHandler): + mode = view.ModeHandler + case string(view.ModeQuery): + mode = view.ModeQuery + } + opts := []view.Option{view.WithSchema(schema), view.WithMode(mode)} if item.Connector != "" { opts = append(opts, view.WithConnectorRef(item.Connector)) @@ -186,6 +254,13 @@ func materializeView(item *plan.View) (*view.View, error) { if item.SQL != "" || item.SQLURI != "" { tmpl := view.NewTemplate(item.SQL) tmpl.SourceURL = item.SQLURI + if strings.TrimSpace(item.Summary) != "" { + tmpl.Summary = &view.TemplateSummary{ + Name: "Summary", + Source: item.Summary, + Kind: view.MetaKindRecord, + } + } opts = append(opts, view.WithTemplate(tmpl)) } if item.CacheRef != "" { @@ -203,6 +278,27 @@ func materializeView(item *plan.View) (*view.View, error) { return nil, err } aView.Ref = item.Ref + aView.Module = item.Module + aView.AllowNulls = item.AllowNulls + if strings.TrimSpace(item.SelectorNamespace) != "" || item.SelectorNoLimit != nil { + if aView.Selector == nil { + aView.Selector = &view.Config{} + } + if strings.TrimSpace(item.SelectorNamespace) != "" { + aView.Selector.Namespace = strings.TrimSpace(item.SelectorNamespace) + } + if item.SelectorNoLimit != nil { + aView.Selector.NoLimit = *item.SelectorNoLimit + } + } + if aView.Schema != nil && strings.TrimSpace(item.SchemaType) != "" { + if aView.Schema.DataType == "" { + aView.Schema.DataType = strings.TrimSpace(item.SchemaType) + } + if aView.Schema.Name == "" { + aView.Schema.Name = strings.Trim(strings.TrimSpace(item.SchemaType), "*") + } + } return aView, nil } @@ -216,6 +312,53 @@ func bestSchemaType(item *plan.View) reflect.Type { return nil } +func toViewRelations(input []*plan.Relation) []*view.Relation { + if len(input) == 0 { + return nil + } + result := make([]*view.Relation, 0, len(input)) + for 
_, item := range input { + if item == nil { + continue + } + relation := &view.Relation{ + Name: item.Name, + Holder: item.Holder, + On: toViewLinks(item.On, true), + Of: view.NewReferenceView( + toViewLinks(item.On, false), + view.NewView(item.Ref, item.Table), + ), + } + result = append(result, relation) + } + return result +} + +func toViewLinks(input []*plan.RelationLink, parent bool) view.Links { + if len(input) == 0 { + return nil + } + result := make(view.Links, 0, len(input)) + for _, item := range input { + if item == nil { + continue + } + link := &view.Link{} + if parent { + link.Field = item.ParentField + link.Namespace = item.ParentNamespace + link.Column = item.ParentColumn + } else { + link.Field = item.RefField + link.Namespace = item.RefNamespace + link.Column = item.RefColumn + } + result = append(result, link) + } + return result +} + func newSchema(rType reflect.Type, cardinality string) *state.Schema { if cardinality == "many" && rType.Kind() != reflect.Slice { return state.NewSchema(rType, state.WithMany()) diff --git a/repository/shape/load/loader_test.go b/repository/shape/load/loader_test.go index aab074ba..e2d45d3e 100644 --- a/repository/shape/load/loader_test.go +++ b/repository/shape/load/loader_test.go @@ -3,11 +3,13 @@ package load import ( "context" "embed" + "reflect" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/viant/datly/repository/shape" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/plan" "github.com/viant/datly/repository/shape/scan" "github.com/viant/datly/repository/shape/typectx" @@ -74,6 +76,47 @@ func TestLoader_LoadViews_InvalidPlanType(t *testing.T) { assert.Contains(t, err.Error(), "unsupported plan type") } +func TestLoader_LoadViews_Metadata(t *testing.T) { + noLimit := true + allowNulls := true + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "meta"}, + Plan: &plan.Result{ + Views: []*plan.View{ 
+ { + Name: "items", + Table: "ITEMS", + Module: "platform/items", + AllowNulls: &allowNulls, + SelectorNamespace: "it", + SelectorNoLimit: &noLimit, + SchemaType: "*ItemView", + Cardinality: "many", + FieldType: reflect.TypeOf([]map[string]interface{}{}), + ElementType: reflect.TypeOf(map[string]interface{}{}), + SQL: "SELECT * FROM ITEMS", + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + loader := New() + artifacts, err := loader.LoadViews(context.Background(), planned) + require.NoError(t, err) + require.NotNil(t, artifacts) + require.Len(t, artifacts.Views, 1) + actual := artifacts.Views[0] + assert.Equal(t, "platform/items", actual.Module) + require.NotNil(t, actual.AllowNulls) + assert.True(t, *actual.AllowNulls) + require.NotNil(t, actual.Selector) + assert.Equal(t, "it", actual.Selector.Namespace) + assert.True(t, actual.Selector.NoLimit) + require.NotNil(t, actual.Schema) + assert.Equal(t, "*ItemView", actual.Schema.DataType) +} + func TestLoader_LoadComponent(t *testing.T) { scanner := scan.New() scanned, err := scanner.Scan(context.Background(), &shape.Source{Name: "/v1/api/report", Struct: &reportSource{}}) @@ -84,12 +127,26 @@ func TestLoader_LoadComponent(t *testing.T) { require.NoError(t, err) actualPlan, ok := planned.Plan.(*plan.Result) require.True(t, ok) + actualPlan.ColumnsDiscovery = true actualPlan.TypeContext = &typectx.Context{ DefaultPackage: "mdp/performance", Imports: []typectx.Import{ {Alias: "perf", Package: "github.com/acme/mdp/performance"}, }, } + actualPlan.Directives = &dqlshape.Directives{ + Meta: "docs/report.md", + DefaultConnector: "analytics", + Cache: &dqlshape.CacheDirective{ + Enabled: true, + TTL: "5m", + }, + MCP: &dqlshape.MCPDirective{ + Name: "report.list", + Description: "List report rows", + DescriptionPath: "docs/mcp/report.md", + }, + } loader := New() artifact, err := loader.LoadComponent(context.Background(), planned) @@ -113,4 +170,69 @@ func 
TestLoader_LoadComponent(t *testing.T) { assert.Equal(t, "mdp/performance", component.TypeContext.DefaultPackage) require.Len(t, component.TypeContext.Imports, 1) assert.Equal(t, "perf", component.TypeContext.Imports[0].Alias) + require.NotNil(t, component.Directives) + assert.Equal(t, "docs/report.md", component.Directives.Meta) + assert.Equal(t, "analytics", component.Directives.DefaultConnector) + require.NotNil(t, component.Directives.Cache) + assert.True(t, component.Directives.Cache.Enabled) + assert.Equal(t, "5m", component.Directives.Cache.TTL) + require.NotNil(t, component.Directives.MCP) + assert.Equal(t, "report.list", component.Directives.MCP.Name) + assert.True(t, component.ColumnsDiscovery) +} + +func TestLoader_LoadComponent_RelationFieldsPreserved(t *testing.T) { + planned := &shape.PlanResult{ + Source: &shape.Source{Name: "/v1/api/report"}, + Plan: &plan.Result{ + Views: []*plan.View{ + { + Path: "Rows", + Name: "rows", + Table: "REPORT", + Cardinality: "many", + FieldType: reflect.TypeOf([]reportRow{}), + ElementType: reflect.TypeOf(reportRow{}), + Relations: []*plan.Relation{ + { + Name: "detail", + Holder: "Detail", + Ref: "detail", + Table: "REPORT_DETAIL", + On: []*plan.RelationLink{ + { + ParentField: "ReportID", + ParentNamespace: "rows", + ParentColumn: "REPORT_ID", + RefField: "ID", + RefNamespace: "detail", + RefColumn: "ID", + }, + }, + }, + }, + }, + }, + ViewsByName: map[string]*plan.View{}, + ByPath: map[string]*plan.Field{}, + }, + } + + loader := New() + artifact, err := loader.LoadComponent(context.Background(), planned) + require.NoError(t, err) + component, ok := artifact.Component.(*Component) + require.True(t, ok) + require.Len(t, component.ViewRelations, 1) + require.Len(t, component.ViewRelations[0].On, 1) + require.Len(t, component.ViewRelations[0].Of.On, 1) + + parent := component.ViewRelations[0].On[0] + ref := component.ViewRelations[0].Of.On[0] + assert.Equal(t, "ReportID", parent.Field) + assert.Equal(t, "rows", 
parent.Namespace) + assert.Equal(t, "REPORT_ID", parent.Column) + assert.Equal(t, "ID", ref.Field) + assert.Equal(t, "detail", ref.Namespace) + assert.Equal(t, "ID", ref.Column) } diff --git a/repository/shape/load/model.go b/repository/shape/load/model.go index 8f5d384d..6459f57a 100644 --- a/repository/shape/load/model.go +++ b/repository/shape/load/model.go @@ -1,17 +1,26 @@ package load import "github.com/viant/datly/repository/shape/plan" +import dqlshape "github.com/viant/datly/repository/shape/dql/shape" import "github.com/viant/datly/repository/shape/typectx" +import "github.com/viant/datly/view" // Component is a shape-loaded runtime-neutral component artifact. // It intentionally avoids repository package coupling to keep shape/load reusable. type Component struct { - Name string - URI string - Method string - RootView string - Views []string - TypeContext *typectx.Context + Name string + URI string + Method string + RootView string + Views []string + Relations []*plan.Relation + ViewRelations []*view.Relation + Declarations map[string]*plan.ViewDeclaration + QuerySelectors map[string][]string + Predicates map[string][]*plan.ViewPredicate + TypeContext *typectx.Context + Directives *dqlshape.Directives + ColumnsDiscovery bool Input []*plan.State Output []*plan.State diff --git a/repository/shape/model.go b/repository/shape/model.go index f71fd5c2..4e0bde7b 100644 --- a/repository/shape/model.go +++ b/repository/shape/model.go @@ -19,6 +19,8 @@ const ( // Source represents the caller-provided shape source. 
type Source struct { Name string + Path string + Connector string Struct any Type reflect.Type TypeName string diff --git a/repository/shape/normalize/sql.go b/repository/shape/normalize/sql.go new file mode 100644 index 00000000..945840dd --- /dev/null +++ b/repository/shape/normalize/sql.go @@ -0,0 +1,56 @@ +package normalize + +import ( + "github.com/viant/sqlparser" + "github.com/viant/sqlparser/expr" + "github.com/viant/sqlparser/node" + "github.com/viant/sqlparser/query" + "github.com/viant/tagly/format/text" +) + +type mapper map[string]string + +func (m mapper) Map(name string) string { + ret, ok := m[name] + if ok { + return ret + } + return name +} + +func SQL(input string, generated bool, option func() sqlparser.Option) string { + if !generated { + return input + } + sqlQuery, err := sqlparser.ParseQuery(input, option()) + if err != nil { + return input + } + ns := mapper{} + if sqlQuery.From.Alias != "" { + ns[sqlQuery.From.Alias] = normalizeName(sqlQuery.From.Alias) + } + for _, join := range sqlQuery.Joins { + ns[join.Alias] = normalizeName(join.Alias) + } + + sqlparser.Traverse(sqlQuery, func(n node.Node) bool { + switch actual := n.(type) { + case *expr.Selector: + actual.Name = ns.Map(actual.Name) + case *query.Join: + actual.Alias = ns.Map(actual.Alias) + case *query.Item: + actual.Alias = ns.Map(actual.Alias) + case *query.From: + actual.Alias = ns.Map(actual.Alias) + } + return true + }) + return sqlparser.Stringify(sqlQuery) +} + +func normalizeName(k string) string { + caseFormat := text.DetectCaseFormat(k) + return caseFormat.Format(k, text.CaseFormatUpperCamel) +} diff --git a/repository/shape/normalize/sql_test.go b/repository/shape/normalize/sql_test.go new file mode 100644 index 00000000..aaba9af7 --- /dev/null +++ b/repository/shape/normalize/sql_test.go @@ -0,0 +1,66 @@ +package normalize + +import ( + "testing" + + "github.com/stretchr/testify/require" + legacy "github.com/viant/datly/cmd/options" + "github.com/viant/sqlparser" +) + 
+func parserOption() sqlparser.Option { + return sqlparser.WithErrorHandler(nil) +} + +func TestSQL_ParityWithLegacyNormalizer(t *testing.T) { + type normalizeCase struct { + Name string + Generated bool + SQL string + } + cases := []normalizeCase{ + { + Name: "skip normalization when not generated", + Generated: false, + SQL: "SELECT a.id FROM users a JOIN orders b ON a.id = b.user_id", + }, + { + Name: "invalid sql returns input", + Generated: true, + SQL: "SELECT * FROM (", + }, + { + Name: "normalize from and join aliases in selectors and alias nodes", + Generated: true, + SQL: "SELECT a.id, b.user_id FROM users a JOIN orders b ON a.id = b.user_id", + }, + { + Name: "keep alias that is already normalized", + Generated: true, + SQL: "SELECT UserAlias.id FROM users UserAlias", + }, + { + Name: "normalize snake_case alias", + Generated: true, + SQL: "SELECT order_item.id FROM users order_item", + }, + } + for _, testCase := range cases { + t.Run(testCase.Name, func(t *testing.T) { + expected := (&legacy.Rule{Generated: testCase.Generated}).NormalizeSQL(testCase.SQL, parserOption) + actual := SQL(testCase.SQL, testCase.Generated, parserOption) + require.Equal(t, expected, actual) + }) + } +} + +func TestMapper_Map(t *testing.T) { + m := mapper{"a": "A"} + require.Equal(t, "A", m.Map("a")) + require.Equal(t, "b", m.Map("b")) +} + +func TestNormalizeName(t *testing.T) { + require.Equal(t, "UserAlias", normalizeName("user_alias")) + require.Equal(t, "UserAlias", normalizeName("UserAlias")) +} diff --git a/repository/shape/options.go b/repository/shape/options.go index 05b0a774..27b970fa 100644 --- a/repository/shape/options.go +++ b/repository/shape/options.go @@ -2,14 +2,18 @@ package shape // Options stores shape facade dependencies and behavior flags. 
type Options struct { - Mode Mode - Strict bool - Name string - Scanner Scanner - Planner Planner - Loader Loader - Compiler DQLCompiler - Runtime RuntimeRegistrar + Mode Mode + Strict bool + Name string + Scanner Scanner + Planner Planner + Loader Loader + Compiler DQLCompiler + Runtime RuntimeRegistrar + CompileProfile CompileProfile + CompileMixedMode CompileMixedMode + UnknownNonReadMode CompileUnknownNonReadMode + ColumnDiscoveryMode CompileColumnDiscoveryMode } // Option mutates Options. @@ -17,7 +21,12 @@ type Option func(*Options) // NewOptions builds Options from varargs. func NewOptions(opts ...Option) *Options { - ret := &Options{} + ret := &Options{ + CompileProfile: CompileProfileCompat, + CompileMixedMode: CompileMixedModeExecWins, + UnknownNonReadMode: CompileUnknownNonReadWarn, + ColumnDiscoveryMode: CompileColumnDiscoveryAuto, + } for _, opt := range opts { opt(ret) } @@ -71,3 +80,161 @@ func WithRuntime(runtime RuntimeRegistrar) Option { o.Runtime = runtime } } + +// WithCompileProfileDefault sets default compiler profile used by Engine DQL compile path. +func WithCompileProfileDefault(profile CompileProfile) Option { + return func(o *Options) { + o.CompileProfile = profile + } +} + +// WithMixedModeDefault sets default compiler mixed read/exec mode used by Engine DQL compile path. +func WithMixedModeDefault(mode CompileMixedMode) Option { + return func(o *Options) { + o.CompileMixedMode = mode + } +} + +// WithUnknownNonReadModeDefault sets default unknown non-read mode used by Engine DQL compile path. +func WithUnknownNonReadModeDefault(mode CompileUnknownNonReadMode) Option { + return func(o *Options) { + o.UnknownNonReadMode = mode + } +} + +// WithColumnDiscoveryModeDefault sets default column discovery policy used by Engine DQL compile path. 
+func WithColumnDiscoveryModeDefault(mode CompileColumnDiscoveryMode) Option { + return func(o *Options) { + o.ColumnDiscoveryMode = mode + } +} + +// WithLegacyTranslatorDefaults configures Engine compile defaults to legacy-compatible behavior. +func WithLegacyTranslatorDefaults() Option { + return func(o *Options) { + o.Strict = false + o.CompileProfile = CompileProfileCompat + o.CompileMixedMode = CompileMixedModeExecWins + o.UnknownNonReadMode = CompileUnknownNonReadWarn + o.ColumnDiscoveryMode = CompileColumnDiscoveryAuto + } +} + +func WithCompileStrict(strict bool) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.Strict = strict + } +} + +func WithMixedMode(mode CompileMixedMode) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.MixedMode = mode + } +} + +func WithUnknownNonReadMode(mode CompileUnknownNonReadMode) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.UnknownNonReadMode = mode + } +} + +func WithCompileProfile(profile CompileProfile) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.Profile = profile + } +} + +func WithColumnDiscoveryMode(mode CompileColumnDiscoveryMode) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.ColumnDiscoveryMode = mode + } +} + +// WithDQLPathMarker overrides the path marker used to locate platform root from source path. +// Default is "/dql/". +func WithDQLPathMarker(marker string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.DQLPathMarker = marker + } +} + +// WithRoutesRelativePath overrides routes path relative to detected platform root. +// Default is "repo/dev/Datly/routes". 
+func WithRoutesRelativePath(path string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.RoutesRelativePath = path + } +} + +// WithTypeContextPackageDir sets default type-context package directory (for xgen parity). +func WithTypeContextPackageDir(dir string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.TypePackageDir = dir + } +} + +// WithTypeContextPackageName sets default type-context package name (for xgen parity). +func WithTypeContextPackageName(name string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.TypePackageName = name + } +} + +// WithTypeContextPackagePath sets default type-context package import path (for xgen parity). +func WithTypeContextPackagePath(path string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.TypePackagePath = path + } +} + +// WithTypeContextPackageDefaults sets package dir/name/path in one call. +func WithTypeContextPackageDefaults(dir, name, path string) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.TypePackageDir = dir + o.TypePackageName = name + o.TypePackagePath = path + } +} + +// WithInferTypeContextDefaults enables/disables source-path based type context defaults. 
+func WithInferTypeContextDefaults(enabled bool) CompileOption { + return func(o *CompileOptions) { + if o == nil { + return + } + o.InferTypeContext = &enabled + } +} diff --git a/repository/shape/parity_test.go b/repository/shape/parity_test.go index 713bfd31..725dbe63 100644 --- a/repository/shape/parity_test.go +++ b/repository/shape/parity_test.go @@ -31,6 +31,15 @@ type paritySource struct { Rows []parityRow `view:"rows,table=REPORT,connector=dev" sql:"uri=scan/testdata/report.sql"` } +type parityJoinRow struct { + ReportID int `source:"REPORT_ID"` +} + +type parityJoinSource struct { + parityEmbedded + Rows []parityJoinRow `view:"rows,table=REPORT,connector=dev" sql:"uri=scan/testdata/report.sql" on:"ReportID:rows.REPORT_ID=ID:detail.ID"` +} + func TestEngineParity_StructPipeline(t *testing.T) { source := &paritySource{} scanner := shapeScan.New() @@ -65,3 +74,35 @@ func TestEngineParity_StructPipeline(t *testing.T) { assert.Equal(t, mv.Schema.Cardinality, ev.Schema.Cardinality) assert.Equal(t, reflect.TypeOf(mv.Schema.CompType()), reflect.TypeOf(ev.Schema.CompType())) } + +func TestEngineParity_Component_SourceTagFieldJoin(t *testing.T) { + source := &parityJoinSource{} + scanner := shapeScan.New() + planner := shapePlan.New() + loader := shapeLoad.New() + + engine := shape.New( + shape.WithName("/v1/api/parity"), + shape.WithScanner(scanner), + shape.WithPlanner(planner), + shape.WithLoader(loader), + ) + artifact, err := engine.LoadComponent(context.Background(), source) + require.NoError(t, err) + require.NotNil(t, artifact) + + component, ok := artifact.Component.(*shapeLoad.Component) + require.True(t, ok) + require.Len(t, component.ViewRelations, 1) + require.Len(t, component.ViewRelations[0].On, 1) + require.Len(t, component.ViewRelations[0].Of.On, 1) + + parent := component.ViewRelations[0].On[0] + ref := component.ViewRelations[0].Of.On[0] + assert.Equal(t, "ReportID", parent.Field) + assert.Equal(t, "rows", parent.Namespace) + assert.Equal(t, 
"REPORT_ID", parent.Column) + assert.Equal(t, "ID", ref.Field) + assert.Equal(t, "detail", ref.Namespace) + assert.Equal(t, "ID", ref.Column) +} diff --git a/repository/shape/plan/model.go b/repository/shape/plan/model.go index 8dacf2bb..8935786a 100644 --- a/repository/shape/plan/model.go +++ b/repository/shape/plan/model.go @@ -4,6 +4,7 @@ import ( "embed" "reflect" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" "github.com/viant/datly/repository/shape/typectx" ) @@ -12,12 +13,26 @@ type Result struct { RootType reflect.Type EmbedFS *embed.FS - Fields []*Field - ByPath map[string]*Field - Views []*View - ViewsByName map[string]*View - States []*State - TypeContext *typectx.Context + Fields []*Field + ByPath map[string]*Field + Views []*View + ViewsByName map[string]*View + States []*State + Types []*Type + ColumnsDiscovery bool + TypeContext *typectx.Context + Directives *dqlshape.Directives + Diagnostics []*dqlshape.Diagnostic +} + +// Type is normalized type metadata collected during compile. +type Type struct { + Name string + Alias string + DataType string + Cardinality string + Package string + ModulePath string } // Field is a normalized projection of scanned field metadata. @@ -33,7 +48,9 @@ type View struct { Path string Name string Ref string + Mode string Table string + Module string Connector string CacheRef string Partitioner string @@ -42,31 +59,103 @@ type View struct { SQL string SQLURI string Summary string - Links []string + Relations []*Relation Holder string + AllowNulls *bool + SelectorNamespace string + SelectorNoLimit *bool + SchemaType string + ColumnsDiscovery bool + Cardinality string ElementType reflect.Type FieldType reflect.Type + Declaration *ViewDeclaration +} + +// ViewDeclaration captures declaration options used to derive a view from DQL directives. 
+type ViewDeclaration struct { + Tag string + Codec string + CodecArgs []string + HandlerName string + HandlerArgs []string + StatusCode *int + ErrorMessage string + QuerySelector string + CacheRef string + Limit *int + Cacheable *bool + When string + Scope string + DataType string + Of string + Value string + Async bool + Output bool + Predicates []*ViewPredicate +} + +// ViewPredicate captures WithPredicate / EnsurePredicate metadata. +type ViewPredicate struct { + Name string + Source string + Ensure bool + Arguments []string +} + +// Relation is normalized relation metadata extracted from DQL joins. +type Relation struct { + Name string + Holder string + Ref string + Table string + Kind string + Raw string + On []*RelationLink + Warnings []string +} + +// RelationLink represents one parent/ref join predicate. +type RelationLink struct { + ParentField string + ParentNamespace string + ParentColumn string + RefField string + RefNamespace string + RefColumn string + Expression string } // State is a normalized parameter field plan. type State struct { - Path string - Name string - Kind string - In string - When string - Scope string - DataType string - Required *bool - Async bool - Cacheable *bool - With string - URI string - ErrorCode int - ErrorMessage string + Path string + Name string + Kind string + In string + QuerySelector string + When string + Scope string + DataType string + Value string + Required *bool + Async bool + Cacheable *bool + With string + URI string + ErrorCode int + ErrorMessage string + Predicates []*StatePredicate TagType reflect.Type EffectiveType reflect.Type } + +// StatePredicate captures parameter predicate semantics from DQL declarations. 
+type StatePredicate struct { + Group int + Name string + Ensure bool + Arguments []string +} diff --git a/repository/shape/plan/planner.go b/repository/shape/plan/planner.go index ec66aea5..2c3735dc 100644 --- a/repository/shape/plan/planner.go +++ b/repository/shape/plan/planner.go @@ -86,7 +86,7 @@ func normalizeView(field *scan.Field) *View { result.SQLURI = tag.SQL.URI result.Summary = tag.SummarySQL.SQL if len(tag.LinkOn) > 0 { - result.Links = append(result.Links, tag.LinkOn...) + result.Relations = append(result.Relations, relationFromTagLinks(field.Name, tag.LinkOn)) } result.Ref = strings.TrimSpace(tag.TypeName) } @@ -101,6 +101,67 @@ func normalizeView(field *scan.Field) *View { return result } +func relationFromTagLinks(holder string, links []string) *Relation { + relation := &Relation{ + Name: strings.TrimSpace(holder), + Holder: strings.TrimSpace(holder), + Ref: strings.TrimSpace(holder), + } + for _, linkExpr := range links { + linkExpr = strings.TrimSpace(linkExpr) + if linkExpr == "" { + continue + } + left, right, ok := strings.Cut(linkExpr, "=") + if !ok { + continue + } + leftField, leftNS, leftCol := splitTagSelector(left) + rightField, rightNS, rightCol := splitTagSelector(right) + if leftCol == "" || rightCol == "" { + continue + } + relation.On = append(relation.On, &RelationLink{ + ParentField: leftField, + ParentNamespace: leftNS, + ParentColumn: leftCol, + RefField: rightField, + RefNamespace: rightNS, + RefColumn: rightCol, + Expression: strings.TrimSpace(left) + "=" + strings.TrimSpace(right), + }) + } + if relation.Ref == "" { + relation.Ref = "relation" + } + if relation.Holder == "" { + relation.Holder = relation.Ref + } + if relation.Name == "" { + relation.Name = relation.Holder + } + return relation +} + +func splitTagSelector(value string) (string, string, string) { + value = strings.TrimSpace(value) + value = strings.TrimSuffix(value, "(true)") + value = strings.TrimSuffix(value, "(false)") + field := "" + if idx := 
strings.Index(value, ":"); idx >= 0 { + field = strings.TrimSpace(value[:idx]) + value = value[idx+1:] + } + value = strings.Trim(value, "`\"") + if value == "" { + return field, "", "" + } + if idx := strings.Index(value, "."); idx >= 0 { + return field, strings.TrimSpace(value[:idx]), strings.TrimSpace(value[idx+1:]) + } + return field, "", strings.TrimSpace(value) +} + func normalizeState(field *scan.Field) *State { result := &State{Path: field.Path, TagType: field.Type} if field.StateTag == nil || field.StateTag.Parameter == nil { diff --git a/repository/shape/plan/planner_test.go b/repository/shape/plan/planner_test.go index 29bb1e79..7dc1edb1 100644 --- a/repository/shape/plan/planner_test.go +++ b/repository/shape/plan/planner_test.go @@ -36,6 +36,18 @@ type reportSource struct { ID int `parameter:"id,kind=query,in=id"` } +type relationRow struct { + ID int +} + +type relationSource struct { + Rows []relationRow `view:"rows,table=REPORT" on:"rows.report_id=report.id"` +} + +type relationSourceWithFields struct { + Rows []relationRow `view:"rows,table=REPORT" on:"ReportID:rows.report_id=ID:report.id"` +} + func TestPlanner_Plan(t *testing.T) { scanner := scan.New() scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &reportSource{}}) @@ -78,6 +90,54 @@ func TestPlanner_Plan(t *testing.T) { assert.Equal(t, stateByPath["ID"].TagType, stateByPath["ID"].EffectiveType) } +func TestPlanner_Plan_LinkOnProducesStructuredRelations(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &relationSource{}}) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + require.NotNil(t, planned) + + result, ok := planned.Plan.(*Result) + require.True(t, ok) + require.Len(t, result.Views, 1) + viewPlan := result.Views[0] + require.Len(t, viewPlan.Relations, 1) + relation := viewPlan.Relations[0] + require.Len(t, 
relation.On, 1) + assert.Equal(t, "rows", relation.On[0].ParentNamespace) + assert.Equal(t, "report_id", relation.On[0].ParentColumn) + assert.Equal(t, "report", relation.On[0].RefNamespace) + assert.Equal(t, "id", relation.On[0].RefColumn) +} + +func TestPlanner_Plan_LinkOnPreservesFieldSelectors(t *testing.T) { + scanner := scan.New() + scanned, err := scanner.Scan(context.Background(), &shape.Source{Struct: &relationSourceWithFields{}}) + require.NoError(t, err) + + planner := New() + planned, err := planner.Plan(context.Background(), scanned) + require.NoError(t, err) + require.NotNil(t, planned) + + result, ok := planned.Plan.(*Result) + require.True(t, ok) + require.Len(t, result.Views, 1) + viewPlan := result.Views[0] + require.Len(t, viewPlan.Relations, 1) + relation := viewPlan.Relations[0] + require.Len(t, relation.On, 1) + assert.Equal(t, "ReportID", relation.On[0].ParentField) + assert.Equal(t, "rows", relation.On[0].ParentNamespace) + assert.Equal(t, "report_id", relation.On[0].ParentColumn) + assert.Equal(t, "ID", relation.On[0].RefField) + assert.Equal(t, "report", relation.On[0].RefNamespace) + assert.Equal(t, "id", relation.On[0].RefColumn) +} + func TestPlanner_Plan_InvalidDescriptors(t *testing.T) { planner := New() _, err := planner.Plan(context.Background(), &shape.ScanResult{Source: &shape.Source{Name: "x"}, Descriptors: "invalid"}) diff --git a/repository/shape/platform_parity_metadata_test.go b/repository/shape/platform_parity_metadata_test.go new file mode 100644 index 00000000..8ad4a6d8 --- /dev/null +++ b/repository/shape/platform_parity_metadata_test.go @@ -0,0 +1,77 @@ +package shape_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestCompareMetadataParity(t *testing.T) { + trueValue := true + falseValue := false + + legacyMeta := &resourceMetaIR{ColumnsDiscovery: &trueValue} + shapeMeta := &resourceMetaIR{ColumnsDiscovery: &trueValue} + + legacyViews := []viewMetaIR{ + { + Name: "items", + Mode: 
"SQLQuery", + Module: "platform/items", + AllowNulls: &trueValue, + SelectorNamespace: "item", + SelectorNoLimit: &falseValue, + SchemaCardinality: "Many", + SchemaType: "*ItemView", + HasSummary: &trueValue, + }, + } + shapeViews := []viewMetaIR{ + { + Name: "items", + Mode: "SQLQuery", + Module: "platform/items", + AllowNulls: &trueValue, + SelectorNamespace: "item", + SelectorNoLimit: &falseValue, + SchemaCardinality: "Many", + SchemaType: "*ItemView", + HasSummary: &trueValue, + }, + } + + assert.Empty(t, compareMetadataParity(legacyMeta, shapeMeta, legacyViews, shapeViews)) +} + +func TestCompareMetadataParity_DetectsMismatches(t *testing.T) { + trueValue := true + falseValue := false + + legacyMeta := &resourceMetaIR{ColumnsDiscovery: &trueValue} + shapeMeta := &resourceMetaIR{ColumnsDiscovery: &falseValue} + + legacyViews := []viewMetaIR{{ + Name: "items", + Mode: "SQLQuery", + Module: "platform/items", + AllowNulls: &trueValue, + SelectorNoLimit: &trueValue, + SchemaType: "*ItemView", + }} + shapeViews := []viewMetaIR{{ + Name: "items", + Mode: "SQLExec", + Module: "platform/items2", + AllowNulls: &falseValue, + SelectorNoLimit: &falseValue, + SchemaType: "*OtherView", + }} + + mismatches := compareMetadataParity(legacyMeta, shapeMeta, legacyViews, shapeViews) + assert.Contains(t, mismatches, "resource columnsDiscovery mismatch") + assert.Contains(t, mismatches, "view mode mismatch: items") + assert.Contains(t, mismatches, "view module mismatch: items") + assert.Contains(t, mismatches, "view allowNulls mismatch: items") + assert.Contains(t, mismatches, "view selector noLimit mismatch: items") + assert.Contains(t, mismatches, "view schema type mismatch: items") +} diff --git a/repository/shape/platform_parity_test.go b/repository/shape/platform_parity_test.go new file mode 100644 index 00000000..f6837b03 --- /dev/null +++ b/repository/shape/platform_parity_test.go @@ -0,0 +1,1478 @@ +package shape_test + +import ( + "context" + "os" + "path/filepath" + 
"regexp" + "sort" + "strconv" + "strings" + "testing" + + shape "github.com/viant/datly/repository/shape" + shapecompile "github.com/viant/datly/repository/shape/compile" + dqlshape "github.com/viant/datly/repository/shape/dql/shape" + dqlstmt "github.com/viant/datly/repository/shape/dql/statement" + shapeload "github.com/viant/datly/repository/shape/load" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/view" + "gopkg.in/yaml.v3" +) + +type parityRule struct { + Mode string `yaml:"mode"` + Namespace string `yaml:"namespace"` + Source string `yaml:"source"` + Connector string `yaml:"connector,omitempty"` +} + +type legacyYAML struct { + ColumnsDiscovery *bool `yaml:"ColumnsDiscovery"` + TypeContext struct { + DefaultPackage string `yaml:"DefaultPackage"` + PackageDir string `yaml:"PackageDir"` + PackageName string `yaml:"PackageName"` + PackagePath string `yaml:"PackagePath"` + } `yaml:"TypeContext"` + Resource struct { + Views []struct { + Name string `yaml:"Name"` + Table string `yaml:"Table"` + Mode string `yaml:"Mode"` + Module string `yaml:"Module"` + AllowNulls *bool `yaml:"AllowNulls"` + Connector struct { + Ref string `yaml:"Ref"` + } `yaml:"Connector"` + Schema struct { + Cardinality string `yaml:"Cardinality"` + DataType string `yaml:"DataType"` + Name string `yaml:"Name"` + } `yaml:"Schema"` + Template struct { + SourceURL string `yaml:"SourceURL"` + Summary *struct { + Name string `yaml:"Name"` + Kind string `yaml:"Kind"` + } `yaml:"Summary"` + } `yaml:"Template"` + Selector struct { + Namespace string `yaml:"Namespace"` + NoLimit *bool `yaml:"NoLimit"` + LimitParameter selectorParam `yaml:"LimitParameter"` + OffsetParameter selectorParam `yaml:"OffsetParameter"` + PageParameter selectorParam `yaml:"PageParameter"` + FieldsParameter selectorParam `yaml:"FieldsParameter"` + OrderByParameter selectorParam `yaml:"OrderByParameter"` + } `yaml:"Selector"` + } `yaml:"Views"` + Parameters []struct { + Name string `yaml:"Name"` + URI 
string `yaml:"URI"` + Value string `yaml:"Value"` + Required *bool `yaml:"Required"` + Cacheable *bool `yaml:"Cacheable"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` + Predicates []struct { + Group int `yaml:"Group"` + Name string `yaml:"Name"` + Ensure bool `yaml:"Ensure"` + Args []string `yaml:"Args"` + } `yaml:"Predicates"` + } `yaml:"Parameters"` + Types []struct { + Name string `yaml:"Name"` + Alias string `yaml:"Alias"` + DataType string `yaml:"DataType"` + Cardinality string `yaml:"Cardinality"` + Package string `yaml:"Package"` + ModulePath string `yaml:"ModulePath"` + } `yaml:"Types"` + } `yaml:"Resource"` + Routes []struct { + Method string `yaml:"Method"` + URI string `yaml:"URI"` + View struct { + Ref string `yaml:"Ref"` + } `yaml:"View"` + } `yaml:"Routes"` +} + +type viewIR struct { + Name string `yaml:"name"` + Table string `yaml:"table"` + Connector string `yaml:"connector,omitempty"` + SQLURI string `yaml:"sqlUri,omitempty"` +} + +type routeIR struct { + Method string `yaml:"method,omitempty"` + URI string `yaml:"uri,omitempty"` + View string `yaml:"view,omitempty"` +} + +type resourceMetaIR struct { + ColumnsDiscovery *bool `yaml:"columnsDiscovery,omitempty"` +} + +type viewMetaIR struct { + Name string `yaml:"name"` + Mode string `yaml:"mode,omitempty"` + Module string `yaml:"module,omitempty"` + AllowNulls *bool `yaml:"allowNulls,omitempty"` + SelectorNamespace string `yaml:"selectorNamespace,omitempty"` + SelectorNoLimit *bool `yaml:"selectorNoLimit,omitempty"` + SchemaCardinality string `yaml:"schemaCardinality,omitempty"` + SchemaType string `yaml:"schemaType,omitempty"` + HasSummary *bool `yaml:"hasSummary,omitempty"` +} + +type parityOutput struct { + Namespace string `yaml:"namespace"` + Source string `yaml:"source"` + LegacyYAML string `yaml:"legacyYaml"` + LegacyMeta *resourceMetaIR `yaml:"legacyMeta,omitempty"` + LegacyViews []viewIR `yaml:"legacyViews,omitempty"` + LegacyViewMeta []viewMetaIR 
`yaml:"legacyViewMeta,omitempty"` + LegacyParams []paramIR `yaml:"legacyParams,omitempty"` + LegacyRoutes []routeIR `yaml:"legacyRoutes,omitempty"` + LegacyTypes []typeIR `yaml:"legacyTypes,omitempty"` + LegacyTypeCtx *typeCtxIR `yaml:"legacyTypeContext,omitempty"` + ShapeMeta *resourceMetaIR `yaml:"shapeMeta,omitempty"` + ShapeViews []viewIR `yaml:"shapeViews,omitempty"` + ShapeViewMeta []viewMetaIR `yaml:"shapeViewMeta,omitempty"` + ShapeParams []paramIR `yaml:"shapeParams,omitempty"` + ShapeTypes []typeIR `yaml:"shapeTypes,omitempty"` + ShapeTypeCtx *typeCtxIR `yaml:"shapeTypeContext,omitempty"` + ShapeDiags []string `yaml:"shapeDiagnostics,omitempty"` + Mismatches []string `yaml:"mismatches,omitempty"` + CompileFailed bool `yaml:"compileFailed,omitempty"` + RawDiagnostics []*dqlshape.Diagnostic `yaml:"-"` +} + +type parityReport struct { + Total int `yaml:"total"` + Compared int `yaml:"compared"` + WithDiff int `yaml:"withDiff"` + MissingYAML int `yaml:"missingYaml"` + Failures int `yaml:"failures"` + TopIssues []string `yaml:"topIssues,omitempty"` +} + +type selectorParam struct { + Name string `yaml:"Name"` + Cacheable *bool `yaml:"Cacheable"` + In struct { + Kind string `yaml:"Kind"` + Name string `yaml:"Name"` + } `yaml:"In"` +} + +type paramIR struct { + Name string `yaml:"name"` + Kind string `yaml:"kind,omitempty"` + In string `yaml:"in,omitempty"` + Required *bool `yaml:"required,omitempty"` + Cacheable *bool `yaml:"cacheable,omitempty"` + URI string `yaml:"uri,omitempty"` + Value string `yaml:"value,omitempty"` + QuerySelector string `yaml:"querySelector,omitempty"` + Predicates []string `yaml:"predicates,omitempty"` +} + +type typeIR struct { + Name string `yaml:"name"` + Alias string `yaml:"alias,omitempty"` + DataType string `yaml:"dataType,omitempty"` + Cardinality string `yaml:"cardinality,omitempty"` + Package string `yaml:"package,omitempty"` + ModulePath string `yaml:"modulePath,omitempty"` +} + +type typeCtxIR struct { + DefaultPackage string 
`yaml:"defaultPackage,omitempty"` + PackageDir string `yaml:"packageDir,omitempty"` + PackageName string `yaml:"packageName,omitempty"` + PackagePath string `yaml:"packagePath,omitempty"` +} + +type parityEntryEval struct { + Output parityOutput + SourceReadable bool + MissingLegacyYAML bool +} + +func TestPlatform_DQLToRoute_ParityIR_SmokeHandlers(t *testing.T) { + platformRoot := os.Getenv("PLATFORM_ROOT") + if platformRoot == "" { + platformRoot = "/Users/awitas/go/src/github.vianttech.com/viant/platform" + } + rulesRoot := filepath.Join(platformRoot, "e2e", "rule") + routesRoot := filepath.Join(platformRoot, "repo", "dev", "Datly", "routes") + if _, err := os.Stat(rulesRoot); err != nil { + if os.Getenv("PLATFORM_PARITY_SMOKE_REQUIRED") == "1" { + t.Fatalf("platform rules not found at %s", rulesRoot) + } + t.Skipf("platform rules not found at %s", rulesRoot) + } + entries, err := collectRuleMappings(rulesRoot) + if err != nil { + t.Fatalf("collect mappings: %v", err) + } + if len(entries) == 0 { + t.Fatalf("no dql->route mappings found under %s", rulesRoot) + } + entryBySource := map[string]parityRule{} + for _, entry := range entries { + entryBySource[entry.Source] = entry + } + highRiskHandlers := collectSmokeHandlerSources(entries, routesRoot) + if len(highRiskHandlers) < 5 { + t.Fatalf("smoke handler discovery returned too few sources: %d", len(highRiskHandlers)) + } + + compiler := shapecompile.New() + for _, source := range highRiskHandlers { + entry, ok := entryBySource[source] + if !ok { + t.Fatalf("smoke source not found in rule mappings: %s", source) + } + eval := evaluateParityEntry(platformRoot, routesRoot, entry, compiler) + if !eval.SourceReadable { + t.Fatalf("unable to read source for smoke source: %s", source) + } + if eval.MissingLegacyYAML { + t.Fatalf("missing legacy yaml for smoke source: %s", source) + } + out := eval.Output + if out.CompileFailed { + t.Fatalf("shape compile failed for %s: %v", source, out.ShapeDiags) + } + if 
len(out.Mismatches) > 0 { + t.Fatalf("parity mismatches for %s: %v", source, out.Mismatches) + } + } +} + +func TestPlatform_DQLToRoute_ParityIR(t *testing.T) { + platformRoot := os.Getenv("PLATFORM_ROOT") + if platformRoot == "" { + platformRoot = "/Users/awitas/go/src/github.vianttech.com/viant/platform" + } + rulesRoot := filepath.Join(platformRoot, "e2e", "rule") + routesRoot := filepath.Join(platformRoot, "repo", "dev", "Datly", "routes") + if _, err := os.Stat(rulesRoot); err != nil { + t.Skipf("platform rules not found at %s", rulesRoot) + } + entries, err := collectRuleMappings(rulesRoot) + if err != nil { + t.Fatalf("collect mappings: %v", err) + } + if len(entries) == 0 { + t.Fatalf("no dql->route mappings found under %s", rulesRoot) + } + targetSource := strings.TrimSpace(os.Getenv("PLATFORM_PARITY_SOURCE")) + runAll := strings.EqualFold(targetSource, "all") || targetSource == "*" || strings.EqualFold(strings.TrimSpace(os.Getenv("PLATFORM_PARITY_ALL")), "1") + if targetSource == "" && !runAll { + t.Skip("set PLATFORM_PARITY_SOURCE to run transient platform parity check") + } + if !runAll { + var filtered []parityRule + for _, entry := range entries { + if entry.Source == targetSource { + filtered = append(filtered, entry) + } + } + if len(filtered) == 0 { + t.Fatalf("target source not found in rules: %s", targetSource) + } + entries = filtered + } + + compiler := shapecompile.New() + report := parityReport{Total: len(entries)} + issueCounts := map[string]int{} + + for _, entry := range entries { + eval := evaluateParityEntry(platformRoot, routesRoot, entry, compiler) + if !eval.SourceReadable { + continue + } + if eval.MissingLegacyYAML { + report.MissingYAML++ + continue + } + report.Compared++ + out := eval.Output + routeYAMLPath := out.LegacyYAML + if out.CompileFailed { + issueCounts["shape compile failed"]++ + report.Failures++ + writeIRFile(routeYAMLPath+".shape.ir.yaml", out) + report.WithDiff++ + continue + } + if len(out.Mismatches) > 0 { + 
report.WithDiff++ + for _, m := range out.Mismatches { + issueCounts[m]++ + } + } + writeIRFile(routeYAMLPath+".shape.ir.yaml", out) + } + + report.TopIssues = topIssues(issueCounts, 10) + reportPath := filepath.Join(routesRoot, "_shape_parity_report.yaml") + writeYAML(reportPath, report) + t.Logf("parity report: %s", reportPath) + t.Logf("total=%d compared=%d withDiff=%d missingYaml=%d failures=%d", report.Total, report.Compared, report.WithDiff, report.MissingYAML, report.Failures) +} + +func collectSmokeHandlerSources(entries []parityRule, routesRoot string) []string { + excluded := map[string]bool{} + var result []string + for _, entry := range entries { + source := strings.TrimSpace(entry.Source) + if !isHandlerLikeSource(source) { + continue + } + if excluded[source] { + continue + } + routeYAMLPath := filepath.Join(routesRoot, entry.Namespace, routeYAMLName(source)) + if _, err := os.Stat(routeYAMLPath); err != nil { + continue + } + result = append(result, source) + } + sort.Strings(result) + return dedupe(result) +} + +func isHandlerLikeSource(source string) bool { + source = strings.ToLower(strings.TrimSpace(source)) + if source == "" { + return false + } + if strings.Contains(source, "/gen/") && (strings.HasSuffix(source, ".dql") || strings.HasSuffix(source, ".sql")) { + return true + } + return strings.HasSuffix(source, "/patch.dql") || + strings.HasSuffix(source, "/patch.sql") || + strings.HasSuffix(source, "/post.dql") || + strings.HasSuffix(source, "/post.sql") || + strings.HasSuffix(source, "/put.dql") || + strings.HasSuffix(source, "/put.sql") || + strings.HasSuffix(source, "/delete.dql") || + strings.HasSuffix(source, "/delete.sql") || + strings.HasSuffix(source, "/upload.dql") || + strings.HasSuffix(source, "/upload.sql") || + strings.HasSuffix(source, "/export.dql") || + strings.HasSuffix(source, "/export.sql") || + strings.HasSuffix(source, "/action.dql") || + strings.HasSuffix(source, "/action.sql") +} + +func evaluateParityEntry(platformRoot, 
routesRoot string, entry parityRule, compiler *shapecompile.DQLCompiler) parityEntryEval { + sourcePath := filepath.Join(platformRoot, entry.Source) + routeYAMLPath, _ := resolveLegacyRouteYAMLPath(routesRoot, entry.Namespace, entry.Source) + if routeYAMLPath == "" { + routeYAMLPath = filepath.Join(routesRoot, entry.Namespace, routeYAMLName(entry.Source)) + } + out := parityEntryEval{Output: parityOutput{ + Namespace: entry.Namespace, + Source: entry.Source, + LegacyYAML: routeYAMLPath, + }} + sourceBytes, readErr := os.ReadFile(sourcePath) + if readErr != nil { + return out + } + out.SourceReadable = true + legacyBytes, legacyErr := os.ReadFile(routeYAMLPath) + if legacyErr != nil { + out.MissingLegacyYAML = true + return out + } + sourceName := strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) + if sourceName == "" { + sourceName = entry.Namespace + } + + var legacy legacyYAML + if err := yaml.Unmarshal(legacyBytes, &legacy); err == nil { + out.Output.LegacyMeta = &resourceMetaIR{ColumnsDiscovery: legacy.ColumnsDiscovery} + out.Output.LegacyViews = make([]viewIR, 0, len(legacy.Resource.Views)) + out.Output.LegacyViewMeta = make([]viewMetaIR, 0, len(legacy.Resource.Views)) + for _, v := range legacy.Resource.Views { + out.Output.LegacyViews = append(out.Output.LegacyViews, viewIR{ + Name: v.Name, + Table: v.Table, + Connector: v.Connector.Ref, + SQLURI: v.Template.SourceURL, + }) + var hasSummary *bool + if v.Template.Summary != nil { + value := true + hasSummary = &value + } + out.Output.LegacyViewMeta = append(out.Output.LegacyViewMeta, viewMetaIR{ + Name: strings.TrimSpace(v.Name), + Mode: strings.TrimSpace(v.Mode), + Module: strings.TrimSpace(v.Module), + AllowNulls: v.AllowNulls, + SelectorNamespace: strings.TrimSpace(v.Selector.Namespace), + SelectorNoLimit: v.Selector.NoLimit, + SchemaCardinality: strings.TrimSpace(v.Schema.Cardinality), + SchemaType: firstNonEmpty(strings.TrimSpace(v.Schema.DataType), 
strings.TrimSpace(v.Schema.Name)), + HasSummary: hasSummary, + }) + } + for _, r := range legacy.Routes { + out.Output.LegacyRoutes = append(out.Output.LegacyRoutes, routeIR{ + Method: r.Method, + URI: r.URI, + View: r.View.Ref, + }) + } + out.Output.LegacyTypeCtx = normalizeTypeContextIR( + legacy.TypeContext.DefaultPackage, + legacy.TypeContext.PackageDir, + legacy.TypeContext.PackageName, + legacy.TypeContext.PackagePath, + ) + out.Output.LegacyParams = normalizeLegacyParams(legacy) + out.Output.LegacyTypes = normalizeLegacyTypes(legacy) + } + + planResult, compileErr := compiler.Compile(context.Background(), &shape.Source{ + Name: sourceName, + Path: sourcePath, + Connector: entry.Connector, + DQL: string(sourceBytes), + }) + if compileErr != nil { + out.Output.CompileFailed = true + if cErr, ok := compileErr.(*shapecompile.CompileError); ok { + out.Output.RawDiagnostics = cErr.Diagnostics + for _, d := range cErr.Diagnostics { + if d == nil { + continue + } + out.Output.ShapeDiags = append(out.Output.ShapeDiags, d.Error()) + } + } else { + out.Output.ShapeDiags = append(out.Output.ShapeDiags, compileErr.Error()) + } + out.Output.Mismatches = append(out.Output.Mismatches, "shape compile failed") + return out + } + + planned, _ := planResult.Plan.(*plan.Result) + if planned != nil { + out.Output.ShapeMeta = &resourceMetaIR{} + if sourcePath != "" { + value := true + out.Output.ShapeMeta.ColumnsDiscovery = &value + } + out.Output.ShapeViews = make([]viewIR, 0, len(planned.Views)) + out.Output.ShapeViewMeta = make([]viewMetaIR, 0, len(planned.Views)) + for _, v := range planned.Views { + if v == nil { + continue + } + out.Output.ShapeViews = append(out.Output.ShapeViews, viewIR{ + Name: v.Name, + Table: v.Table, + Connector: v.Connector, + SQLURI: v.SQLURI, + }) + var hasSummary *bool + if strings.TrimSpace(v.Summary) != "" { + value := true + hasSummary = &value + } + out.Output.ShapeViewMeta = append(out.Output.ShapeViewMeta, viewMetaIR{ + Name: 
strings.TrimSpace(v.Name), + Mode: inferShapeViewMode(v.SQL), + Module: strings.TrimSpace(v.Module), + AllowNulls: v.AllowNulls, + SelectorNamespace: strings.TrimSpace(v.SelectorNamespace), + SelectorNoLimit: v.SelectorNoLimit, + SchemaCardinality: normalizeCardinality(strings.TrimSpace(v.Cardinality)), + SchemaType: strings.TrimSpace(v.SchemaType), + HasSummary: hasSummary, + }) + } + for _, d := range planned.Diagnostics { + if d == nil { + continue + } + out.Output.ShapeDiags = append(out.Output.ShapeDiags, d.Error()) + } + loader := shapeload.New() + if artifacts, err := loader.LoadViews(context.Background(), planResult); err == nil && artifacts != nil && artifacts.Resource != nil { + mergeShapeViewMetadata(out.Output.ShapeViewMeta, artifacts.Resource.Views) + } + out.Output.ShapeParams = normalizeShapeParams(planned) + out.Output.ShapeTypes = normalizeShapeTypes(planned, sourcePath) + if planned.TypeContext != nil { + out.Output.ShapeTypeCtx = normalizeTypeContextIR( + planned.TypeContext.DefaultPackage, + planned.TypeContext.PackageDir, + planned.TypeContext.PackageName, + planned.TypeContext.PackagePath, + ) + } + } + + out.Output.Mismatches = compareParity(out.Output.LegacyViews, out.Output.ShapeViews) + out.Output.Mismatches = append(out.Output.Mismatches, compareMetadataParity(out.Output.LegacyMeta, out.Output.ShapeMeta, out.Output.LegacyViewMeta, out.Output.ShapeViewMeta)...) + out.Output.Mismatches = append(out.Output.Mismatches, compareParamParity(out.Output.LegacyParams, out.Output.ShapeParams)...) + out.Output.Mismatches = append(out.Output.Mismatches, compareTypeParity(out.Output.LegacyTypes, out.Output.ShapeTypes)...) + out.Output.Mismatches = append(out.Output.Mismatches, compareTypeContextParity(out.Output.LegacyTypeCtx, out.Output.ShapeTypeCtx)...) 
+ out.Output.Mismatches = dedupe(out.Output.Mismatches) + return out +} + +func resolveLegacyRouteYAMLPath(routesRoot, namespace, source string) (string, bool) { + candidates := legacyRouteYAMLCandidatePaths(routesRoot, namespace, source) + for _, candidate := range candidates { + if _, err := os.Stat(candidate); err == nil { + return candidate, true + } + } + return "", false +} + +func legacyRouteYAMLCandidatePaths(routesRoot, namespace, source string) []string { + namespace = strings.Trim(strings.TrimSpace(namespace), "/") + stem := strings.TrimSuffix(filepath.Base(strings.TrimSpace(source)), filepath.Ext(strings.TrimSpace(source))) + if stem == "" { + stem = "route" + } + fileName := stem + ".yaml" + nsPath := filepath.FromSlash(namespace) + leaf := filepath.Base(nsPath) + parent := filepath.Dir(nsPath) + + appendUnique := func(items *[]string, seen map[string]bool, path string) { + path = filepath.Clean(path) + if path == "." || path == "" || seen[path] { + return + } + seen[path] = true + *items = append(*items, path) + } + + seen := map[string]bool{} + result := make([]string, 0, 8) + appendUnique(&result, seen, filepath.Join(routesRoot, nsPath, fileName)) + appendUnique(&result, seen, filepath.Join(routesRoot, nsPath, stem, fileName)) + if leaf != "" && leaf != "." { + appendUnique(&result, seen, filepath.Join(routesRoot, nsPath, leaf+".yaml")) + } + if parent != "" && parent != "." { + appendUnique(&result, seen, filepath.Join(routesRoot, parent, fileName)) + appendUnique(&result, seen, filepath.Join(routesRoot, parent, stem, fileName)) + parentLeaf := filepath.Base(parent) + if parentLeaf != "" && parentLeaf != "." 
{ + appendUnique(&result, seen, filepath.Join(routesRoot, parent, parentLeaf+".yaml")) + } + } + if strings.Contains(strings.ToLower(source), "/gen/") { + appendUnique(&result, seen, filepath.Join(routesRoot, nsPath, "patch", "patch.yaml")) + } + return result +} + +func collectRuleMappings(rulesRoot string) ([]parityRule, error) { + var files []string + if err := filepath.WalkDir(rulesRoot, func(path string, d os.DirEntry, err error) error { + if err != nil { + return err + } + if d.IsDir() || !strings.HasSuffix(path, ".yaml") { + return nil + } + files = append(files, path) + return nil + }); err != nil { + return nil, err + } + re := regexp.MustCompile(`\$appPath/bin/datly\s+(gen|translate)\s+.*-u=([^\s]+)\s+-s='([^']+)'(.*)`) + seen := map[string]bool{} + var result []parityRule + for _, file := range files { + data, err := os.ReadFile(file) + if err != nil { + continue + } + lines := strings.Split(string(data), "\n") + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + m := re.FindStringSubmatch(line) + if len(m) < 4 { + continue + } + src := strings.TrimSpace(m[3]) + if !(strings.HasSuffix(src, ".dql") || strings.HasSuffix(src, ".sql")) { + continue + } + connector := inferRuleConnector("") + if len(m) >= 5 { + connector = inferRuleConnector(m[4]) + } + key := m[2] + "|" + src + if seen[key] { + continue + } + seen[key] = true + result = append(result, parityRule{ + Mode: strings.TrimSpace(m[1]), + Namespace: strings.TrimSpace(m[2]), + Source: src, + Connector: connector, + }) + } + } + sort.Slice(result, func(i, j int) bool { + if result[i].Namespace == result[j].Namespace { + return result[i].Source < result[j].Source + } + return result[i].Namespace < result[j].Namespace + }) + return result, nil +} + +func inferRuleConnector(tail string) string { + lower := strings.ToLower(tail) + switch { + case strings.Contains(lower, "$optionsaero"): + return "system" + case 
strings.Contains(lower, "$optionssitemgmt"): + return "sitemgmt" + case strings.Contains(lower, "$options"): + return "ci_ads" + default: + return "" + } +} + +func routeYAMLName(source string) string { + base := filepath.Base(source) + ext := filepath.Ext(base) + return strings.TrimSuffix(base, ext) + ".yaml" +} + +func compareParity(legacy, shapeViews []viewIR) []string { + var result []string + if len(legacy) != len(shapeViews) { + result = append(result, "view count mismatch") + } + legacyByName := map[string]viewIR{} + for _, v := range legacy { + legacyByName[strings.ToLower(v.Name)] = v + } + for _, s := range shapeViews { + l, ok := legacyByName[strings.ToLower(s.Name)] + if !ok { + result = append(result, "missing view in legacy: "+s.Name) + continue + } + if l.Table != "" && s.Table != "" && !strings.EqualFold(l.Table, s.Table) { + result = append(result, "table mismatch: "+s.Name) + } + if l.Connector != "" && s.Connector == "" { + result = append(result, "connector missing in shape: "+s.Name) + } + if l.Connector != "" && s.Connector != "" && !strings.EqualFold(strings.TrimSpace(l.Connector), strings.TrimSpace(s.Connector)) { + result = append(result, "connector mismatch: "+s.Name) + } + if l.SQLURI != "" && s.SQLURI == "" { + result = append(result, "sql uri missing in shape: "+s.Name) + } + if l.SQLURI != "" && s.SQLURI != "" && !equalSQLURI(l.SQLURI, s.SQLURI) { + result = append(result, "sql uri mismatch: "+s.Name) + } + } + return dedupe(result) +} + +func equalSQLURI(legacy, shape string) bool { + normalize := func(v string) string { + v = strings.ReplaceAll(strings.TrimSpace(v), "\\", "/") + return strings.TrimPrefix(v, "./") + } + return strings.EqualFold(normalize(legacy), normalize(shape)) +} + +func normalizeCardinality(value string) string { + switch strings.ToLower(strings.TrimSpace(value)) { + case "one": + return "One" + case "many": + return "Many" + default: + return strings.TrimSpace(value) + } +} + +func inferShapeViewMode(sql string) 
string { + sql = strings.TrimSpace(sql) + if sql == "" { + return "" + } + statements := dqlstmt.New(sql) + hasRead := false + hasExec := false + for _, item := range statements { + if item == nil { + continue + } + switch item.Kind { + case dqlstmt.KindRead: + hasRead = true + case dqlstmt.KindExec: + hasExec = true + } + } + switch { + case hasRead && !hasExec: + return "SQLQuery" + case hasExec && !hasRead: + return "SQLExec" + case hasRead && hasExec: + return "SQLExec" + } + stmt := strings.ToLower(sql) + if strings.HasPrefix(stmt, "select") || strings.HasPrefix(stmt, "with") { + return "SQLQuery" + } + return "" +} + +func mergeShapeViewMetadata(meta []viewMetaIR, views view.Views) { + if len(meta) == 0 || len(views) == 0 { + return + } + index := map[string]int{} + for i, item := range meta { + index[strings.ToLower(strings.TrimSpace(item.Name))] = i + } + for _, candidate := range views { + if candidate == nil { + continue + } + key := strings.ToLower(strings.TrimSpace(candidate.Name)) + pos, ok := index[key] + if !ok { + continue + } + if mode := strings.TrimSpace(string(candidate.Mode)); mode != "" { + meta[pos].Mode = mode + } + if meta[pos].Module == "" { + meta[pos].Module = strings.TrimSpace(candidate.Module) + } + if meta[pos].AllowNulls == nil { + meta[pos].AllowNulls = candidate.AllowNulls + } + if candidate.Selector != nil { + if meta[pos].SelectorNamespace == "" { + meta[pos].SelectorNamespace = strings.TrimSpace(candidate.Selector.Namespace) + } + if meta[pos].SelectorNoLimit == nil { + meta[pos].SelectorNoLimit = &candidate.Selector.NoLimit + } + } + if candidate.Schema != nil { + if meta[pos].SchemaCardinality == "" { + meta[pos].SchemaCardinality = strings.TrimSpace(string(candidate.Schema.Cardinality)) + } + if meta[pos].SchemaType == "" { + meta[pos].SchemaType = firstNonEmpty(strings.TrimSpace(candidate.Schema.DataType), strings.TrimSpace(candidate.Schema.Name)) + } + } + if candidate.Template != nil && candidate.Template.Summary != nil { 
+ value := true + meta[pos].HasSummary = &value + } + } +} + +func compareMetadataParity(legacyMeta, shapeMeta *resourceMetaIR, legacyViews, shapeViews []viewMetaIR) []string { + var result []string + if legacyMeta != nil && legacyMeta.ColumnsDiscovery != nil { + if shapeMeta == nil || shapeMeta.ColumnsDiscovery == nil { + result = append(result, "resource columnsDiscovery missing in shape") + } else if *legacyMeta.ColumnsDiscovery != *shapeMeta.ColumnsDiscovery { + result = append(result, "resource columnsDiscovery mismatch") + } + } + legacyByName := map[string]viewMetaIR{} + for _, item := range legacyViews { + legacyByName[strings.ToLower(strings.TrimSpace(item.Name))] = item + } + for _, shapeItem := range shapeViews { + key := strings.ToLower(strings.TrimSpace(shapeItem.Name)) + legacyItem, ok := legacyByName[key] + if !ok { + continue + } + if legacyItem.Mode != "" { + if shapeItem.Mode == "" { + result = append(result, "view mode missing in shape: "+shapeItem.Name) + } else if !strings.EqualFold(legacyItem.Mode, shapeItem.Mode) { + result = append(result, "view mode mismatch: "+shapeItem.Name) + } + } + if legacyItem.Module != "" { + if shapeItem.Module == "" { + result = append(result, "view module missing in shape: "+shapeItem.Name) + } else if !strings.EqualFold(strings.TrimSpace(legacyItem.Module), strings.TrimSpace(shapeItem.Module)) { + result = append(result, "view module mismatch: "+shapeItem.Name) + } + } + if legacyItem.AllowNulls != nil { + if shapeItem.AllowNulls == nil { + result = append(result, "view allowNulls missing in shape: "+shapeItem.Name) + } else if *legacyItem.AllowNulls != *shapeItem.AllowNulls { + result = append(result, "view allowNulls mismatch: "+shapeItem.Name) + } + } + if legacyItem.SelectorNamespace != "" { + if shapeItem.SelectorNamespace == "" { + result = append(result, "view selector namespace missing in shape: "+shapeItem.Name) + } else if !strings.EqualFold(strings.TrimSpace(legacyItem.SelectorNamespace), 
strings.TrimSpace(shapeItem.SelectorNamespace)) { + result = append(result, "view selector namespace mismatch: "+shapeItem.Name) + } + } + if legacyItem.SelectorNoLimit != nil { + if shapeItem.SelectorNoLimit == nil { + result = append(result, "view selector noLimit missing in shape: "+shapeItem.Name) + } else if *legacyItem.SelectorNoLimit != *shapeItem.SelectorNoLimit { + result = append(result, "view selector noLimit mismatch: "+shapeItem.Name) + } + } + if legacyItem.SchemaCardinality != "" { + if shapeItem.SchemaCardinality == "" { + result = append(result, "view schema cardinality missing in shape: "+shapeItem.Name) + } else if !strings.EqualFold(strings.TrimSpace(legacyItem.SchemaCardinality), strings.TrimSpace(shapeItem.SchemaCardinality)) { + result = append(result, "view schema cardinality mismatch: "+shapeItem.Name) + } + } + if legacyItem.SchemaType != "" { + if shapeItem.SchemaType == "" { + result = append(result, "view schema type missing in shape: "+shapeItem.Name) + } else if !strings.EqualFold(strings.TrimSpace(legacyItem.SchemaType), strings.TrimSpace(shapeItem.SchemaType)) { + result = append(result, "view schema type mismatch: "+shapeItem.Name) + } + } + if legacyItem.HasSummary != nil { + if shapeItem.HasSummary == nil { + result = append(result, "view template summary missing in shape: "+shapeItem.Name) + } else if *legacyItem.HasSummary != *shapeItem.HasSummary { + result = append(result, "view template summary mismatch: "+shapeItem.Name) + } + } + } + return dedupe(result) +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + value = strings.TrimSpace(value) + if value != "" { + return value + } + } + return "" +} + +func normalizeLegacyParams(legacy legacyYAML) []paramIR { + querySelectors := map[string]string{} + querySelectorCacheable := map[string]*bool{} + querySelectorIn := map[string]string{} + for _, v := range legacy.Resource.Views { + viewName := strings.TrimSpace(v.Name) + for _, param := range 
[]selectorParam{v.Selector.LimitParameter, v.Selector.OffsetParameter, v.Selector.PageParameter, v.Selector.FieldsParameter, v.Selector.OrderByParameter} { + name := strings.TrimSpace(param.Name) + if name == "" || viewName == "" { + continue + } + querySelectors[strings.ToLower(name)] = viewName + querySelectorIn[strings.ToLower(name)] = strings.TrimSpace(param.In.Name) + if param.Cacheable != nil { + value := *param.Cacheable + querySelectorCacheable[strings.ToLower(name)] = &value + } + } + } + result := make([]paramIR, 0, len(legacy.Resource.Parameters)) + seen := map[string]bool{} + for _, p := range legacy.Resource.Parameters { + name := strings.TrimSpace(p.Name) + item := paramIR{ + Name: name, + Kind: strings.TrimSpace(p.In.Kind), + In: strings.TrimSpace(p.In.Name), + Required: p.Required, + Cacheable: p.Cacheable, + URI: strings.TrimSpace(p.URI), + Value: strings.TrimSpace(p.Value), + } + if selector, ok := querySelectors[strings.ToLower(name)]; ok { + item.QuerySelector = selector + if item.Cacheable == nil { + item.Cacheable = querySelectorCacheable[strings.ToLower(name)] + } + } + for _, pred := range p.Predicates { + item.Predicates = append(item.Predicates, normalizePredicateSig(pred.Group, pred.Name, pred.Ensure, pred.Args)) + } + sort.Strings(item.Predicates) + result = append(result, item) + seen[strings.ToLower(name)] = true + } + for key, selector := range querySelectors { + if seen[key] { + continue + } + name := strings.TrimSpace(key) + if name == "" { + continue + } + legacyName := name + for _, v := range legacy.Resource.Views { + for _, param := range []selectorParam{v.Selector.LimitParameter, v.Selector.OffsetParameter, v.Selector.PageParameter, v.Selector.FieldsParameter, v.Selector.OrderByParameter} { + if strings.EqualFold(strings.TrimSpace(param.Name), key) { + legacyName = strings.TrimSpace(param.Name) + break + } + } + } + result = append(result, paramIR{ + Name: legacyName, + Kind: "query", + In: 
strings.TrimSpace(querySelectorIn[key]), + QuerySelector: selector, + Cacheable: querySelectorCacheable[key], + }) + } + sort.Slice(result, func(i, j int) bool { + if strings.EqualFold(result[i].Name, result[j].Name) { + if strings.EqualFold(result[i].Kind, result[j].Kind) { + return strings.ToLower(result[i].In) < strings.ToLower(result[j].In) + } + return strings.ToLower(result[i].Kind) < strings.ToLower(result[j].Kind) + } + return strings.ToLower(result[i].Name) < strings.ToLower(result[j].Name) + }) + return result +} + +func normalizeLegacyTypes(legacy legacyYAML) []typeIR { + if len(legacy.Resource.Types) == 0 { + return nil + } + result := make([]typeIR, 0, len(legacy.Resource.Types)) + seen := map[string]bool{} + for _, item := range legacy.Resource.Types { + name := strings.TrimSpace(item.Name) + if name == "" { + continue + } + key := strings.ToLower(name) + if seen[key] { + continue + } + seen[key] = true + result = append(result, typeIR{ + Name: name, + Alias: strings.TrimSpace(item.Alias), + DataType: normalizeTypeSignature(item.DataType), + Cardinality: normalizeCardinality(strings.TrimSpace(item.Cardinality)), + Package: strings.TrimSpace(item.Package), + ModulePath: strings.TrimSpace(item.ModulePath), + }) + } + sort.Slice(result, func(i, j int) bool { + return strings.ToLower(result[i].Name) < strings.ToLower(result[j].Name) + }) + return result +} + +func normalizeShapeTypes(planned *plan.Result, sourcePath string) []typeIR { + if planned == nil { + return nil + } + modulePrefix := inferModulePrefix(sourcePath) + typeImportByAlias, typeImportByPkg := typeImports(planned) + byName := map[string]typeIR{} + + register := func(item typeIR, overwrite bool) { + name := strings.TrimSpace(item.Name) + if name == "" { + return + } + key := strings.ToLower(name) + if existing, ok := byName[key]; ok { + if (overwrite || existing.DataType == "") && item.DataType != "" { + existing.DataType = item.DataType + } + if (overwrite || existing.Cardinality == "") && 
item.Cardinality != "" { + existing.Cardinality = item.Cardinality + } + if (overwrite || existing.Package == "") && item.Package != "" { + existing.Package = item.Package + } + if (overwrite || existing.ModulePath == "") && item.ModulePath != "" { + existing.ModulePath = item.ModulePath + } + if overwrite && item.Alias != "" { + existing.Alias = item.Alias + } + byName[key] = existing + return + } + byName[key] = item + } + + for _, item := range planned.Views { + if item == nil { + continue + } + dataType := strings.TrimSpace(item.SchemaType) + name := typeNameFromDataType(dataType) + if name == "" && item.ElementType != nil { + name = strings.TrimSpace(item.ElementType.Name()) + if dataType == "" && name != "" { + dataType = "*" + name + } + } + if name == "" { + continue + } + pkg := packageFromDataType(dataType) + modulePath := "" + if strings.TrimSpace(item.Module) != "" && modulePrefix != "" { + modulePath = modulePrefix + strings.Trim(strings.TrimSpace(item.Module), "/") + } + if modulePath == "" && pkg != "" { + modulePath = firstNonEmpty(typeImportByAlias[strings.ToLower(pkg)], typeImportByPkg[strings.ToLower(pkg)]) + } + register(typeIR{ + Name: name, + DataType: normalizeTypeSignature(dataType), + Cardinality: normalizeCardinality(strings.TrimSpace(item.Cardinality)), + Package: pkg, + ModulePath: modulePath, + }, false) + } + + for _, item := range planned.States { + if item == nil || strings.TrimSpace(item.DataType) == "" { + continue + } + dataType := strings.TrimSpace(item.DataType) + name := typeNameFromDataType(dataType) + if name == "" { + continue + } + pkg := packageFromDataType(dataType) + modulePath := firstNonEmpty(typeImportByAlias[strings.ToLower(pkg)], typeImportByPkg[strings.ToLower(pkg)]) + register(typeIR{ + Name: name, + DataType: normalizeTypeSignature(dataType), + Package: pkg, + ModulePath: modulePath, + }, false) + } + for _, item := range planned.Types { + if item == nil || strings.TrimSpace(item.Name) == "" { + continue + } + 
register(typeIR{ + Name: strings.TrimSpace(item.Name), + Alias: strings.TrimSpace(item.Alias), + DataType: normalizeTypeSignature(item.DataType), + Cardinality: normalizeCardinality(strings.TrimSpace(item.Cardinality)), + Package: strings.TrimSpace(item.Package), + ModulePath: strings.TrimSpace(item.ModulePath), + }, true) + } + + result := make([]typeIR, 0, len(byName)) + for _, item := range byName { + result = append(result, item) + } + sort.Slice(result, func(i, j int) bool { + return strings.ToLower(result[i].Name) < strings.ToLower(result[j].Name) + }) + return result +} + +func compareTypeParity(legacy, shapeTypes []typeIR) []string { + var result []string + if len(legacy) == 0 { + return nil + } + shapeByName := map[string]typeIR{} + for _, item := range shapeTypes { + shapeByName[strings.ToLower(strings.TrimSpace(item.Name))] = item + } + for _, legacyType := range legacy { + key := strings.ToLower(strings.TrimSpace(legacyType.Name)) + shapeType, ok := shapeByName[key] + if !ok { + result = append(result, "missing type in shape: "+legacyType.Name) + continue + } + if legacyType.DataType != "" && shapeType.DataType != "" && legacyType.DataType != shapeType.DataType { + result = append(result, "type dataType mismatch: "+legacyType.Name) + } + if legacyType.Cardinality != "" && shapeType.Cardinality != "" && !strings.EqualFold(legacyType.Cardinality, shapeType.Cardinality) { + result = append(result, "type cardinality mismatch: "+legacyType.Name) + } + if legacyType.Package != "" && shapeType.Package != "" && !strings.EqualFold(legacyType.Package, shapeType.Package) { + result = append(result, "type package mismatch: "+legacyType.Name) + } + if legacyType.ModulePath != "" && shapeType.ModulePath != "" && !strings.EqualFold(legacyType.ModulePath, shapeType.ModulePath) { + result = append(result, "type module path mismatch: "+legacyType.Name) + } + if legacyType.Alias != "" && shapeType.Alias != "" && !strings.EqualFold(legacyType.Alias, shapeType.Alias) { + 
result = append(result, "type alias mismatch: "+legacyType.Name) + } + } + return dedupe(result) +} + +func normalizeTypeContextIR(defaultPackage, packageDir, packageName, packagePath string) *typeCtxIR { + ret := &typeCtxIR{ + DefaultPackage: strings.TrimSpace(defaultPackage), + PackageDir: strings.TrimSpace(packageDir), + PackageName: strings.TrimSpace(packageName), + PackagePath: strings.TrimSpace(packagePath), + } + if ret.DefaultPackage == "" && ret.PackageDir == "" && ret.PackageName == "" && ret.PackagePath == "" { + return nil + } + return ret +} + +func compareTypeContextParity(legacy, shape *typeCtxIR) []string { + if legacy == nil { + return nil + } + if shape == nil { + return []string{"missing type context in shape"} + } + var result []string + if legacy.DefaultPackage != "" && shape.DefaultPackage != "" && !strings.EqualFold(legacy.DefaultPackage, shape.DefaultPackage) { + result = append(result, "type context default package mismatch") + } + if legacy.PackageDir != "" && shape.PackageDir != "" && !strings.EqualFold(legacy.PackageDir, shape.PackageDir) { + result = append(result, "type context package dir mismatch") + } + if legacy.PackageName != "" && shape.PackageName != "" && !strings.EqualFold(legacy.PackageName, shape.PackageName) { + result = append(result, "type context package name mismatch") + } + if legacy.PackagePath != "" && shape.PackagePath != "" && !strings.EqualFold(legacy.PackagePath, shape.PackagePath) { + result = append(result, "type context package path mismatch") + } + return dedupe(result) +} + +func normalizeTypeSignature(value string) string { + value = strings.TrimSpace(value) + if value == "" { + return "" + } + parts := strings.Fields(value) + return strings.Join(parts, " ") +} + +func typeNameFromDataType(dataType string) string { + dataType = strings.TrimSpace(dataType) + if dataType == "" { + return "" + } + dataType = strings.TrimLeft(dataType, "*[]") + if dataType == "" { + return "" + } + if idx := 
// typeNameFromDataType extracts the bare type name from a Go type literal,
// stripping pointer/slice decoration ("*", "[]"), any package qualifier and
// an inline struct body (everything from "{").
func typeNameFromDataType(dataType string) string {
	dataType = strings.TrimLeft(strings.TrimSpace(dataType), "*[]")
	if dataType == "" {
		return ""
	}
	if idx := strings.LastIndex(dataType, "."); idx != -1 {
		dataType = dataType[idx+1:]
	}
	if idx := strings.Index(dataType, "{"); idx != -1 {
		dataType = dataType[:idx]
	}
	return strings.TrimSpace(dataType)
}

// packageFromDataType extracts the package qualifier of a (possibly
// pointer/slice decorated) type literal; "" when the type is unqualified.
func packageFromDataType(dataType string) string {
	dataType = strings.TrimLeft(strings.TrimSpace(dataType), "*[]")
	if idx := strings.LastIndex(dataType, "."); idx != -1 {
		return strings.TrimSpace(dataType[:idx])
	}
	return ""
}

// inferModulePrefix derives the "<repo-root>/pkg/" module prefix from a DQL
// source path laid out as .../src/<repo-root>/dql/...; it returns "" when
// either layout marker is absent.
func inferModulePrefix(sourcePath string) string {
	normalized := filepath.ToSlash(strings.TrimSpace(sourcePath))
	const marker = "/src/"
	idx := strings.Index(normalized, marker)
	if idx == -1 {
		return ""
	}
	root := normalized[idx+len(marker):]
	if cut := strings.Index(root, "/dql/"); cut != -1 {
		root = root[:cut]
	}
	if root = strings.Trim(root, "/"); root == "" {
		return ""
	}
	return root + "/pkg/"
}
byAlias[alias] = pkg + } + appendPkg(pkg) + } + return byAlias, byPkg +} + +func normalizeShapeParams(planned *plan.Result) []paramIR { + if planned == nil || len(planned.States) == 0 { + return nil + } + result := make([]paramIR, 0, len(planned.States)) + for _, s := range planned.States { + if s == nil { + continue + } + item := paramIR{ + Name: strings.TrimSpace(s.Name), + Kind: strings.TrimSpace(s.Kind), + In: strings.TrimSpace(s.In), + Required: s.Required, + Cacheable: s.Cacheable, + URI: strings.TrimSpace(s.URI), + Value: strings.TrimSpace(s.Value), + QuerySelector: strings.TrimSpace(s.QuerySelector), + } + for _, pred := range s.Predicates { + if pred == nil { + continue + } + item.Predicates = append(item.Predicates, normalizePredicateSig(pred.Group, pred.Name, pred.Ensure, pred.Arguments)) + } + sort.Strings(item.Predicates) + result = append(result, item) + } + sort.Slice(result, func(i, j int) bool { + if strings.EqualFold(result[i].Name, result[j].Name) { + if strings.EqualFold(result[i].Kind, result[j].Kind) { + return strings.ToLower(result[i].In) < strings.ToLower(result[j].In) + } + return strings.ToLower(result[i].Kind) < strings.ToLower(result[j].Kind) + } + return strings.ToLower(result[i].Name) < strings.ToLower(result[j].Name) + }) + return result +} + +func normalizePredicateSig(group int, name string, ensure bool, args []string) string { + parts := make([]string, 0, len(args)) + for _, arg := range args { + parts = append(parts, strings.TrimSpace(arg)) + } + return strings.ToLower(strings.TrimSpace(name)) + "|" + strconv.Itoa(group) + "|" + strconv.FormatBool(ensure) + "|" + strings.Join(parts, ",") +} + +func compareParamParity(legacy, shapeParams []paramIR) []string { + var result []string + legacyByKey := map[string]paramIR{} + for _, item := range filterComparableParams(legacy) { + legacyByKey[paramKey(item)] = item + } + shapeByKey := map[string]paramIR{} + for _, item := range filterComparableParams(shapeParams) { + 
shapeByKey[paramKey(item)] = item + } + if len(legacyByKey) != len(shapeByKey) { + result = append(result, "parameter count mismatch") + } + for key, legacyItem := range legacyByKey { + shapeItem, ok := shapeByKey[key] + if !ok { + result = append(result, "missing parameter in shape: "+legacyItem.Name) + continue + } + if legacyItem.Required != nil && shapeItem.Required != nil && *legacyItem.Required != *shapeItem.Required { + result = append(result, "parameter required mismatch: "+legacyItem.Name) + } + if legacyItem.Cacheable != nil && shapeItem.Cacheable != nil && *legacyItem.Cacheable != *shapeItem.Cacheable { + result = append(result, "parameter cacheable mismatch: "+legacyItem.Name) + } + if legacyItem.QuerySelector != "" && !strings.EqualFold(legacyItem.QuerySelector, shapeItem.QuerySelector) { + result = append(result, "parameter query selector mismatch: "+legacyItem.Name) + } + if legacyItem.URI != "" && !strings.EqualFold(strings.TrimSpace(legacyItem.URI), strings.TrimSpace(shapeItem.URI)) { + result = append(result, "parameter uri mismatch: "+legacyItem.Name) + } + if len(legacyItem.Predicates) != len(shapeItem.Predicates) { + result = append(result, "parameter predicates count mismatch: "+legacyItem.Name) + continue + } + for i := range legacyItem.Predicates { + if legacyItem.Predicates[i] != shapeItem.Predicates[i] { + result = append(result, "parameter predicate mismatch: "+legacyItem.Name) + break + } + } + } + return dedupe(result) +} + +func paramKey(item paramIR) string { + kind := strings.ToLower(strings.TrimSpace(item.Kind)) + in := strings.ToLower(strings.TrimSpace(item.In)) + if kind == "component" { + in = normalizeComponentRef(in) + } + return strings.ToLower(strings.TrimSpace(item.Name)) + "|" + kind + "|" + in +} + +func normalizeComponentRef(in string) string { + in = strings.TrimSpace(strings.TrimPrefix(in, "get:")) + if in == "" { + return in + } + in = strings.TrimPrefix(in, "../") + in = strings.TrimPrefix(in, "./") + in = 
strings.TrimPrefix(in, "/") + if idx := strings.LastIndex(in, "/"); idx != -1 { + return in[idx+1:] + } + return in +} + +func filterComparableParams(items []paramIR) []paramIR { + if len(items) == 0 { + return nil + } + result := make([]paramIR, 0, len(items)) + for _, item := range items { + kind := strings.ToLower(strings.TrimSpace(item.Kind)) + switch kind { + case "output", "meta", "async": + continue + default: + result = append(result, item) + } + } + return result +} + +func dedupe(items []string) []string { + if len(items) == 0 { + return nil + } + seen := map[string]bool{} + var ret []string + for _, item := range items { + if item == "" || seen[item] { + continue + } + seen[item] = true + ret = append(ret, item) + } + sort.Strings(ret) + return ret +} + +func topIssues(counter map[string]int, limit int) []string { + type pair struct { + Issue string + Count int + } + var list []pair + for issue, count := range counter { + list = append(list, pair{Issue: issue, Count: count}) + } + sort.Slice(list, func(i, j int) bool { + if list[i].Count == list[j].Count { + return list[i].Issue < list[j].Issue + } + return list[i].Count > list[j].Count + }) + if len(list) > limit { + list = list[:limit] + } + var ret []string + for _, item := range list { + ret = append(ret, item.Issue) + } + return ret +} + +func writeIRFile(path string, v parityOutput) { + _ = os.MkdirAll(filepath.Dir(path), 0o755) + writeYAML(path, v) +} + +func writeYAML(path string, v interface{}) { + data, err := yaml.Marshal(v) + if err != nil { + return + } + _ = os.WriteFile(path, data, 0o644) +} diff --git a/repository/shape/platform_parity_types_test.go b/repository/shape/platform_parity_types_test.go new file mode 100644 index 00000000..76341f3f --- /dev/null +++ b/repository/shape/platform_parity_types_test.go @@ -0,0 +1,86 @@ +package shape_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/viant/datly/repository/shape/plan" + 
"github.com/viant/datly/repository/shape/typectx" +) + +func TestNormalizeTypeSignature(t *testing.T) { + assert.Equal(t, "struct{ Id int; Name string }", normalizeTypeSignature(" struct{ Id int; Name string } ")) +} + +func TestTypeNameFromDataType(t *testing.T) { + assert.Equal(t, "TvAffiliateStationView", typeNameFromDataType("*tvaffiliatestation.TvAffiliateStationView")) + assert.Equal(t, "Output", typeNameFromDataType("*Output")) + assert.Equal(t, "struct", typeNameFromDataType("struct{Id int}")) +} + +func TestCompareTypeParity(t *testing.T) { + legacy := []typeIR{{ + Name: "TvAffiliateStationView", + DataType: "*tvaffiliatestation.TvAffiliateStationView", + Package: "tvaffiliatestation", + ModulePath: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + }} + shapeTypes := []typeIR{{ + Name: "TvAffiliateStationView", + DataType: "*tvaffiliatestation.TvAffiliateStationView", + Package: "tvaffiliatestation", + ModulePath: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + }} + assert.Empty(t, compareTypeParity(legacy, shapeTypes)) +} + +func TestNormalizeShapeTypes(t *testing.T) { + planned := &plan.Result{ + TypeContext: &typectx.Context{ + PackagePath: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + PackageName: "tvaffiliatestation", + }, + Views: []*plan.View{ + { + Name: "tvAffiliateStation", + Module: "platform/tvaffiliatestation", + SchemaType: "*tvaffiliatestation.TvAffiliateStationView", + Cardinality: "many", + }, + }, + } + actual := normalizeShapeTypes(planned, "/Users/awitas/go/src/github.vianttech.com/viant/platform/dql/platform/tvaffiliatestation/tvaffiliatestation.dql") + if assert.Len(t, actual, 1) { + assert.Equal(t, "TvAffiliateStationView", actual[0].Name) + assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", actual[0].ModulePath) + assert.Equal(t, "Many", actual[0].Cardinality) + } +} + +func TestTypeImports_UsesTypeContextPackagePath(t 
*testing.T) { + planned := &plan.Result{ + TypeContext: &typectx.Context{ + PackagePath: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + PackageName: "tvaffiliatestation", + }, + } + byAlias, byPkg := typeImports(planned) + assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", byAlias["tvaffiliatestation"]) + assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", byPkg["tvaffiliatestation"]) +} + +func TestCompareTypeContextParity(t *testing.T) { + legacy := &typeCtxIR{ + DefaultPackage: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + PackageDir: "pkg/platform/tvaffiliatestation", + PackageName: "tvaffiliatestation", + PackagePath: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + } + shape := &typeCtxIR{ + DefaultPackage: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + PackageDir: "pkg/platform/tvaffiliatestation", + PackageName: "tvaffiliatestation", + PackagePath: "github.vianttech.com/viant/platform/pkg/platform/tvaffiliatestation", + } + assert.Empty(t, compareTypeContextParity(legacy, shape)) +} diff --git a/repository/shape/shape.go b/repository/shape/shape.go index 570a63d5..5f7f766d 100644 --- a/repository/shape/shape.go +++ b/repository/shape/shape.go @@ -3,6 +3,11 @@ package shape import "context" type ( + CompileMixedMode string + CompileUnknownNonReadMode string + CompileProfile string + CompileColumnDiscoveryMode string + // Scanner discovers shape descriptors from Source. 
Scanner interface { Scan(ctx context.Context, source *Source, opts ...ScanOption) (*ScanResult, error) @@ -33,7 +38,19 @@ type ( ScanOptions struct{} PlanOptions struct{} LoadOptions struct{} - CompileOptions struct{} + CompileOptions struct { + Strict bool + Profile CompileProfile + MixedMode CompileMixedMode + UnknownNonReadMode CompileUnknownNonReadMode + ColumnDiscoveryMode CompileColumnDiscoveryMode + DQLPathMarker string + RoutesRelativePath string + TypePackageDir string + TypePackageName string + TypePackagePath string + InferTypeContext *bool + } ScanOption func(*ScanOptions) PlanOption func(*PlanOptions) @@ -41,6 +58,22 @@ type ( CompileOption func(*CompileOptions) ) +const ( + CompileMixedModeExecWins CompileMixedMode = "exec_wins" + CompileMixedModeReadWins CompileMixedMode = "read_wins" + CompileMixedModeErrorOnMixed CompileMixedMode = "error_on_mixed" + + CompileUnknownNonReadWarn CompileUnknownNonReadMode = "warn" + CompileUnknownNonReadError CompileUnknownNonReadMode = "error" + + CompileProfileCompat CompileProfile = "compat" + CompileProfileStrict CompileProfile = "strict" + + CompileColumnDiscoveryAuto CompileColumnDiscoveryMode = "auto" + CompileColumnDiscoveryOn CompileColumnDiscoveryMode = "on" + CompileColumnDiscoveryOff CompileColumnDiscoveryMode = "off" +) + // Engine is a thin facade over scan -> plan -> load pipeline. 
type Engine struct { options *Options @@ -139,7 +172,15 @@ func (e *Engine) compile(ctx context.Context, source *Source) (*PlanResult, erro if e.options.Compiler == nil { return nil, ErrCompilerNotConfigured } - return e.options.Compiler.Compile(ctx, source) + return e.options.Compiler.Compile( + ctx, + source, + WithCompileStrict(e.options.Strict), + WithCompileProfile(e.options.CompileProfile), + WithMixedMode(e.options.CompileMixedMode), + WithUnknownNonReadMode(e.options.UnknownNonReadMode), + WithColumnDiscoveryMode(e.options.ColumnDiscoveryMode), + ) } func (e *Engine) scanAndPlan(ctx context.Context, source *Source) (*PlanResult, error) { diff --git a/repository/shape/typectx/context.go b/repository/shape/typectx/context.go new file mode 100644 index 00000000..072e22b9 --- /dev/null +++ b/repository/shape/typectx/context.go @@ -0,0 +1,89 @@ +package typectx + +import ( + "path" + "strings" +) + +// ValidationIssue captures context consistency problems. +type ValidationIssue struct { + Field string + Message string +} + +// Normalize trims and canonicalizes context fields. 
+func Normalize(input *Context) *Context { + if input == nil { + return nil + } + ret := &Context{ + DefaultPackage: strings.TrimSpace(input.DefaultPackage), + PackageDir: cleanSlashes(strings.TrimSpace(input.PackageDir)), + PackageName: strings.TrimSpace(input.PackageName), + PackagePath: cleanSlashes(strings.TrimSpace(input.PackagePath)), + } + if ret.PackageName == "" { + if ret.PackagePath != "" { + ret.PackageName = path.Base(ret.PackagePath) + } else if ret.PackageDir != "" { + ret.PackageName = path.Base(ret.PackageDir) + } + } + if ret.DefaultPackage == "" && ret.PackagePath != "" { + ret.DefaultPackage = ret.PackagePath + } + for _, item := range input.Imports { + pkg := cleanSlashes(strings.TrimSpace(item.Package)) + if pkg == "" { + continue + } + alias := strings.TrimSpace(item.Alias) + if alias == "" { + alias = path.Base(pkg) + } + ret.Imports = append(ret.Imports, Import{ + Alias: alias, + Package: pkg, + }) + } + if ret.DefaultPackage == "" && + len(ret.Imports) == 0 && + ret.PackageDir == "" && + ret.PackageName == "" && + ret.PackagePath == "" { + return nil + } + return ret +} + +// Validate checks context consistency. 
+func Validate(ctx *Context) []ValidationIssue { + ctx = Normalize(ctx) + if ctx == nil { + return nil + } + var result []ValidationIssue + if strings.Contains(ctx.PackageName, "/") { + result = append(result, ValidationIssue{ + Field: "PackageName", + Message: "package name must not contain path separators", + }) + } + if ctx.PackagePath != "" && strings.Contains(ctx.PackagePath, ".") { + base := path.Base(ctx.PackagePath) + if ctx.PackageName != "" && base != ctx.PackageName { + result = append(result, ValidationIssue{ + Field: "PackagePath", + Message: "package path basename differs from package name", + }) + } + } + return result +} + +func cleanSlashes(value string) string { + value = strings.ReplaceAll(value, "\\", "/") + value = strings.TrimSpace(value) + value = strings.Trim(value, "/") + return value +} diff --git a/repository/shape/typectx/context_test.go b/repository/shape/typectx/context_test.go new file mode 100644 index 00000000..32570944 --- /dev/null +++ b/repository/shape/typectx/context_test.go @@ -0,0 +1,31 @@ +package typectx + +import "testing" + +func TestNormalize_FillsPackageFields(t *testing.T) { + ctx := Normalize(&Context{ + PackageDir: "pkg/platform/taxonomy", + PackagePath: "github.vianttech.com/viant/platform/pkg/platform/taxonomy", + }) + if ctx == nil { + t.Fatalf("expected normalized context") + } + if ctx.PackageName != "taxonomy" { + t.Fatalf("expected package name taxonomy, got %q", ctx.PackageName) + } + if ctx.DefaultPackage != "github.vianttech.com/viant/platform/pkg/platform/taxonomy" { + t.Fatalf("expected default package from package path, got %q", ctx.DefaultPackage) + } +} + +func TestValidate_DetectsInvalidPackageName(t *testing.T) { + issues := Validate(&Context{ + PackageName: "platform/taxonomy", + }) + if len(issues) == 0 { + t.Fatalf("expected validation issue") + } + if issues[0].Field != "PackageName" { + t.Fatalf("expected PackageName issue, got %q", issues[0].Field) + } +} diff --git 
a/repository/shape/typectx/model.go b/repository/shape/typectx/model.go index ae76febe..acc03ca1 100644 --- a/repository/shape/typectx/model.go +++ b/repository/shape/typectx/model.go @@ -10,6 +10,9 @@ type Import struct { type Context struct { DefaultPackage string `json:",omitempty" yaml:",omitempty"` Imports []Import `json:",omitempty" yaml:",omitempty"` + PackageDir string `json:",omitempty" yaml:",omitempty"` + PackageName string `json:",omitempty" yaml:",omitempty"` + PackagePath string `json:",omitempty" yaml:",omitempty"` } // Provenance tracks where a resolved type came from. diff --git a/repository/shape/typectx/resolver.go b/repository/shape/typectx/resolver.go index daccf3b4..892c9ef6 100644 --- a/repository/shape/typectx/resolver.go +++ b/repository/shape/typectx/resolver.go @@ -2,7 +2,6 @@ package typectx import ( "fmt" - "path" "sort" "strings" @@ -115,6 +114,9 @@ func (r *Resolver) aliasPackage(alias string) string { return item.Package } } + if r.context.PackageName != "" && r.context.PackagePath != "" && r.context.PackageName == alias { + return r.context.PackagePath + } return "" } @@ -177,6 +179,7 @@ func (r *Resolver) searchPackages() []scopedPackage { seen[pkg] = true result = append(result, scopedPackage{pkg: pkg, matchKind: matchKind}) } + appendPkg(r.context.PackagePath, "package_path") appendPkg(r.context.DefaultPackage, "default_package") for _, item := range r.context.Imports { appendPkg(item.Package, "import_package") @@ -237,30 +240,7 @@ func packageOf(key string) string { } func normalizeContext(input *Context) *Context { - if input == nil { - return nil - } - ret := &Context{ - DefaultPackage: strings.TrimSpace(input.DefaultPackage), - } - for _, item := range input.Imports { - pkg := strings.TrimSpace(item.Package) - if pkg == "" { - continue - } - alias := strings.TrimSpace(item.Alias) - if alias == "" { - alias = path.Base(pkg) - } - ret.Imports = append(ret.Imports, Import{ - Alias: alias, - Package: pkg, - }) - } - if 
ret.DefaultPackage == "" && len(ret.Imports) == 0 { - return nil - } - return ret + return Normalize(input) } func splitQualified(value string) (prefix string, name string, alias bool, qualified bool) { diff --git a/repository/shape/typectx/resolver_matrix_test.go b/repository/shape/typectx/resolver_matrix_test.go new file mode 100644 index 00000000..473ba368 --- /dev/null +++ b/repository/shape/typectx/resolver_matrix_test.go @@ -0,0 +1,86 @@ +package typectx + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/x" +) + +type matrixOrderDefault struct{} +type matrixOrderImport struct{} +type matrixOrderPkgPath struct{} +type matrixOrderAliasImport struct{} + +func TestResolver_ResolutionMatrix(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(matrixOrderDefault{}), x.WithPkgPath("github.com/acme/default"), x.WithName("Order"))) + reg.Register(x.NewType(reflect.TypeOf(matrixOrderImport{}), x.WithPkgPath("github.com/acme/imported"), x.WithName("ImportedOrder"))) + reg.Register(x.NewType(reflect.TypeOf(matrixOrderPkgPath{}), x.WithPkgPath("github.com/acme/pkgpath"), x.WithName("Order"))) + reg.Register(x.NewType(reflect.TypeOf(matrixOrderAliasImport{}), x.WithPkgPath("github.com/acme/alias/import"), x.WithName("Order"))) + + testCases := []struct { + name string + context *Context + expr string + wantKey string + ambiguous bool + }{ + { + name: "only default/imports", + context: &Context{ + DefaultPackage: "github.com/acme/default", + Imports: []Import{{Alias: "imp", Package: "github.com/acme/imported"}}, + }, + expr: "Order", + wantKey: "github.com/acme/default.Order", + }, + { + name: "only package triple", + context: &Context{ + PackagePath: "github.com/acme/pkgpath", + PackageName: "pkgpath", + PackageDir: "pkg/pkgpath", + }, + expr: "Order", + wantKey: "github.com/acme/pkgpath.Order", + }, + { + name: "default and package path conflict", + context: &Context{ + DefaultPackage: 
"github.com/acme/default", + PackagePath: "github.com/acme/pkgpath", + PackageName: "pkgpath", + }, + expr: "Order", + ambiguous: true, + }, + { + name: "alias import wins over package-name fallback", + context: &Context{ + PackagePath: "github.com/acme/pkgpath", + PackageName: "same", + Imports: []Import{{Alias: "same", Package: "github.com/acme/alias/import"}}, + }, + expr: "same.Order", + wantKey: "github.com/acme/alias/import.Order", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + resolver := NewResolver(reg, testCase.context) + key, err := resolver.Resolve(testCase.expr) + if testCase.ambiguous { + require.Error(t, err) + _, ok := err.(*AmbiguityError) + require.True(t, ok) + require.Empty(t, key) + return + } + require.NoError(t, err) + require.Equal(t, testCase.wantKey, key) + }) + } +} diff --git a/repository/shape/typectx/resolver_test.go b/repository/shape/typectx/resolver_test.go index f1e8e676..632a785d 100644 --- a/repository/shape/typectx/resolver_test.go +++ b/repository/shape/typectx/resolver_test.go @@ -87,3 +87,28 @@ func TestResolver_ResolveWithProvenance(t *testing.T) { require.Equal(t, "/repo/mdp/performance/order.go", resolved.Provenance.File) require.Equal(t, "resource_type", resolved.Provenance.Kind) } + +func TestResolver_Resolve_Unqualified_PackagePath(t *testing.T) { + reg := x.NewRegistry() + reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), x.WithPkgPath("github.com/acme/mdp/performance"), x.WithName("Order"))) + resolver := NewResolver(reg, &Context{PackagePath: "github.com/acme/mdp/performance"}) + + resolved, err := resolver.ResolveWithProvenance("Order") + require.NoError(t, err) + require.NotNil(t, resolved) + require.Equal(t, "github.com/acme/mdp/performance.Order", resolved.ResolvedKey) + require.Equal(t, "package_path", resolved.MatchKind) +} + +func TestResolver_Resolve_Qualified_PackageNameFallback(t *testing.T) { + reg := x.NewRegistry() + 
reg.Register(x.NewType(reflect.TypeOf(resolveOrder{}), x.WithPkgPath("github.com/acme/mdp/performance"), x.WithName("Order"))) + resolver := NewResolver(reg, &Context{ + PackageName: "performance", + PackagePath: "github.com/acme/mdp/performance", + }) + + key, err := resolver.Resolve("performance.Order") + require.NoError(t, err) + require.Equal(t, "github.com/acme/mdp/performance.Order", key) +} diff --git a/repository/shape/xgen/generator.go b/repository/shape/xgen/generator.go index 89622576..d0fc419a 100644 --- a/repository/shape/xgen/generator.go +++ b/repository/shape/xgen/generator.go @@ -26,6 +26,7 @@ func GenerateFromDQLShape(doc *shape.Document, cfg *Config) (*Result, error) { if cfg == nil { cfg = &Config{} } + hydrateConfigFromTypeContext(doc, cfg) applyDefaults(cfg) projectDir, packageDir, err := resolvePaths(cfg.ProjectDir, cfg.PackageDir) if err != nil { @@ -131,10 +132,14 @@ func rewriteSafetyIssues(doc *shape.Document, cfg *Config, projectDir string) [] UseGOPATHFallback: policy.useGOPATH, }) var issues []string - for _, resolution := range doc.TypeResolutions { + for i := range doc.TypeResolutions { + resolution := &doc.TypeResolutions[i] if srcResolver != nil && strings.TrimSpace(resolution.Provenance.File) == "" { - pkg := firstNonEmpty(strings.TrimSpace(resolution.Provenance.Package), packageOfKey(resolution.ResolvedKey)) + pkg := inferResolutionPackage(*resolution, doc.TypeContext) name := typeNameFromKey(resolution.ResolvedKey) + if name == "" { + name = strings.TrimSpace(resolution.Expression) + } if pkg != "" && name != "" { if file, err := srcResolver.ResolveTypeFile(pkg, name); err == nil { resolution.Provenance.File = file @@ -144,7 +149,7 @@ func rewriteSafetyIssues(doc *shape.Document, cfg *Config, projectDir string) [] } } } - if issue := resolutionSafetyIssue(resolution, policy); issue != "" { + if issue := resolutionSafetyIssue(*resolution, policy); issue != "" { issues = append(issues, issue) } } @@ -152,6 +157,41 @@ func 
rewriteSafetyIssues(doc *shape.Document, cfg *Config, projectDir string) [] return uniqueStrings(issues) } +func hydrateConfigFromTypeContext(doc *shape.Document, cfg *Config) { + if doc == nil || cfg == nil || doc.TypeContext == nil { + return + } + if cfg.PackageDir == "" { + cfg.PackageDir = strings.TrimSpace(doc.TypeContext.PackageDir) + } + if cfg.PackageName == "" { + cfg.PackageName = strings.TrimSpace(doc.TypeContext.PackageName) + } + if cfg.PackagePath == "" { + cfg.PackagePath = strings.TrimSpace(doc.TypeContext.PackagePath) + } +} + +func inferResolutionPackage(resolution typectx.Resolution, ctx *typectx.Context) string { + pkg := strings.TrimSpace(resolution.Provenance.Package) + if pkg != "" { + return pkg + } + pkg = packageOfKey(resolution.ResolvedKey) + if pkg != "" { + return pkg + } + if ctx != nil { + if pkg = strings.TrimSpace(ctx.PackagePath); pkg != "" { + return pkg + } + if pkg = strings.TrimSpace(ctx.DefaultPackage); pkg != "" { + return pkg + } + } + return "" +} + func resolutionSafetyIssue(resolution typectx.Resolution, policy rewritePolicy) string { kind := strings.TrimSpace(strings.ToLower(resolution.Provenance.Kind)) if kind == "" { diff --git a/repository/shape/xgen/generator_test.go b/repository/shape/xgen/generator_test.go index 3315f148..eb8f35b1 100644 --- a/repository/shape/xgen/generator_test.go +++ b/repository/shape/xgen/generator_test.go @@ -256,6 +256,181 @@ type DQLOrderView struct { Old string ` + "`json:\"old,omitempty\"`" + ` } +func TestGenerateFromDQLShape_UsesTypeContextPackageDefaults(t *testing.T) { + projectDir := t.TempDir() + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/demo\n\ngo 1.25.0\n"), 0o644); err != nil { + t.Fatalf("write go.mod failed: %v", err) + } + doc := &dqlshape.Document{ + TypeContext: &typectx.Context{ + PackageDir: "pkg/platform/taxonomy", + PackageName: "taxonomy", + PackagePath: "example.com/demo/pkg/platform/taxonomy", + }, + Root: map[string]any{ 
+ "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "orders", + "ColumnsConfig": map[string]any{ + "ID": map[string]any{"Name": "ID", "DataType": "int"}, + }, + }, + }, + }, + }, + } + result, err := GenerateFromDQLShape(doc, &Config{ProjectDir: projectDir}) + if err != nil { + t.Fatalf("generate failed: %v", err) + } + if result == nil { + t.Fatalf("expected result") + } + if result.PackageName != "taxonomy" { + t.Fatalf("expected package name taxonomy, got %q", result.PackageName) + } + if result.PackagePath != "example.com/demo/pkg/platform/taxonomy" { + t.Fatalf("expected package path from type context, got %q", result.PackagePath) + } + if !strings.Contains(filepath.ToSlash(result.FilePath), "/pkg/platform/taxonomy/") { + t.Fatalf("expected file under type-context package dir, got %s", result.FilePath) + } +} + +func TestGenerateFromDQLShape_ProvenanceEnrichment_WithReplaceAndTypeContextPackagePath(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "project") + modelsDir := filepath.Join(root, "shared-models") + if err := os.MkdirAll(filepath.Join(projectDir, "internal", "gen"), 0o755); err != nil { + t.Fatalf("mkdir project failed: %v", err) + } + if err := os.MkdirAll(filepath.Join(modelsDir, "mdp"), 0o755); err != nil { + t.Fatalf("mkdir models failed: %v", err) + } + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/project\n\ngo 1.25\nreplace github.com/acme/models => ../shared-models\n"), 0o644); err != nil { + t.Fatalf("write project go.mod failed: %v", err) + } + if err := os.WriteFile(filepath.Join(modelsDir, "go.mod"), []byte("module github.com/acme/models\n\ngo 1.25\n"), 0o644); err != nil { + t.Fatalf("write models go.mod failed: %v", err) + } + if err := os.WriteFile(filepath.Join(modelsDir, "mdp", "types.go"), []byte("package mdp\ntype Order struct{}\n"), 0o644); err != nil { + t.Fatalf("write types.go failed: %v", err) + } + dest := filepath.Join(projectDir, 
"internal", "gen", "shapes_gen.go") + if err := os.WriteFile(dest, []byte("package gen\n"), 0o644); err != nil { + t.Fatalf("seed file failed: %v", err) + } + + doc := &dqlshape.Document{ + TypeContext: &typectx.Context{ + PackagePath: "github.com/acme/models/mdp", + }, + Root: map[string]any{ + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "orders", + "ColumnsConfig": map[string]any{ + "ID": map[string]any{"Name": "ID", "DataType": "int"}, + }, + }, + }, + }, + }, + TypeResolutions: []typectx.Resolution{ + { + Expression: "Order", + ResolvedKey: "Order", + Provenance: typectx.Provenance{ + Kind: "registry", + }, + }, + }, + } + + _, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + AllowedSourceRoots: []string{modelsDir}, + }) + if err != nil { + t.Fatalf("expected provenance enrichment to allow rewrite, got: %v", err) + } +} + +func TestGenerateFromDQLShape_ProvenanceEnrichment_WithGOPATHFallback(t *testing.T) { + root := t.TempDir() + projectDir := filepath.Join(root, "project") + gopath := filepath.Join(root, "gopath") + modelsDir := filepath.Join(gopath, "src", "github.com", "legacy", "models") + if err := os.MkdirAll(filepath.Join(projectDir, "internal", "gen"), 0o755); err != nil { + t.Fatalf("mkdir project failed: %v", err) + } + if err := os.MkdirAll(modelsDir, 0o755); err != nil { + t.Fatalf("mkdir models failed: %v", err) + } + if err := os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module example.com/project\n\ngo 1.25\n"), 0o644); err != nil { + t.Fatalf("write project go.mod failed: %v", err) + } + if err := os.WriteFile(filepath.Join(modelsDir, "types.go"), []byte("package models\ntype Legacy struct{}\n"), 0o644); err != nil { + t.Fatalf("write types.go failed: %v", err) + } + dest := filepath.Join(projectDir, "internal", "gen", "shapes_gen.go") + if err := os.WriteFile(dest, []byte("package gen\n"), 0o644); err 
!= nil { + t.Fatalf("seed file failed: %v", err) + } + + orig := os.Getenv("GOPATH") + if err := os.Setenv("GOPATH", gopath); err != nil { + t.Fatalf("set GOPATH failed: %v", err) + } + defer func() { _ = os.Setenv("GOPATH", orig) }() + + doc := &dqlshape.Document{ + TypeContext: &typectx.Context{ + PackagePath: "github.com/legacy/models", + }, + Root: map[string]any{ + "Resource": map[string]any{ + "Views": []any{ + map[string]any{ + "Name": "legacy", + "ColumnsConfig": map[string]any{ + "ID": map[string]any{"Name": "ID", "DataType": "int"}, + }, + }, + }, + }, + }, + TypeResolutions: []typectx.Resolution{ + { + Expression: "Legacy", + ResolvedKey: "Legacy", + Provenance: typectx.Provenance{Kind: "registry"}, + }, + }, + } + _, err := GenerateFromDQLShape(doc, &Config{ + ProjectDir: projectDir, + PackageDir: "internal/gen", + PackageName: "gen", + FileName: "shapes_gen.go", + AllowedSourceRoots: []string{filepath.Join(gopath, "src")}, + UseGoModuleResolve: boolPtr(false), + UseGOPATHFallback: boolPtr(true), + }) + if err != nil { + t.Fatalf("expected GOPATH provenance enrichment to allow rewrite, got: %v", err) + } +} + +func boolPtr(value bool) *bool { + return &value +} + func KeepCustom() string { return "ok" } ` if err := os.WriteFile(dest, []byte(initial), 0o644); err != nil { diff --git a/service/executor/expand/evaluator.go b/service/executor/expand/evaluator.go index c733aca7..0c6dee47 100644 --- a/service/executor/expand/evaluator.go +++ b/service/executor/expand/evaluator.go @@ -35,7 +35,12 @@ type ( func WithCustomContexts(ctx ...*Variable) EvaluatorOption { return func(c *config) { - c.embededTypes = append(c.embededTypes, ctx...) 
+ for _, item := range ctx { + if item == nil { + continue + } + c.embededTypes = append(c.embededTypes, item) + } } } @@ -47,7 +52,12 @@ func WithContext(ctx context.Context) EvaluatorOption { func WithVariable(namedVariable ...*NamedVariable) EvaluatorOption { return func(c *config) { - c.namedVariables = append(c.namedVariables, namedVariable...) + for _, item := range namedVariable { + if item == nil { + continue + } + c.namedVariables = append(c.namedVariables, item) + } } } @@ -65,6 +75,9 @@ func WithSetLiteral(setLiterals func(state *structology.State) error) EvaluatorO func WithTypeLookup(lookup xreflect.LookupType) EvaluatorOption { return func(c *config) { + if lookup == nil { + return + } c.typeLookup = lookup } } @@ -141,12 +154,18 @@ func NewEvaluator(template string, options ...EvaluatorOption) (*Evaluator, erro } for _, valueType := range aConfig.embededTypes { + if valueType == nil { + continue + } if err = evaluator.planner.EmbedVariable(valueType.Type); err != nil { return nil, err } } for _, variable := range aConfig.namedVariables { + if variable == nil { + continue + } if err = evaluator.planner.DefineVariable(variable.Name, variable.Type); err != nil { return nil, err } @@ -181,6 +200,9 @@ func NewEvaluator(template string, options ...EvaluatorOption) (*Evaluator, erro func createConfig(options []EvaluatorOption) *config { instance := newConfig() for _, option := range options { + if option == nil { + continue + } option(instance) } diff --git a/service/executor/expand/fn_new.go b/service/executor/expand/fn_new.go index 6f5b7828..6cbdf10e 100644 --- a/service/executor/expand/fn_new.go +++ b/service/executor/expand/fn_new.go @@ -41,7 +41,7 @@ func (n *newer) NewResultType(call *expr.Call) (reflect.Type, error) { expression, ok := call.Args[0].(*expr.Literal) if !ok { - return nil, fmt.Errorf("expected arg to be type of %T but was %T", expression, call.Args[1]) + return nil, fmt.Errorf("expected arg to be type of %T but was %T", expression, 
call.Args[0]) } return types.LookupType(n.lookup, expression.Value) diff --git a/service/executor/expand/fn_printer.go b/service/executor/expand/fn_printer.go index 3eaabe2c..620e7ef5 100644 --- a/service/executor/expand/fn_printer.go +++ b/service/executor/expand/fn_printer.go @@ -61,7 +61,7 @@ func (p *Printer) Println(args ...interface{}) string { func (p *Printer) Printf(format string, args ...interface{}) string { p.derefArgs(args) - fmt.Printf(p.Sprintf(format, args...)) + fmt.Print(p.Sprintf(format, args...)) return "" } @@ -107,12 +107,12 @@ func (p *Printer) Fatal(any interface{}, args ...interface{}) (string, error) { format, ok := any.(string) if ok { - return "", fmt.Errorf(p.Sprintf(format, args...)) + return "", fmt.Errorf("%s", p.Sprintf(format, args...)) } if err, ok := any.(error); ok { return "", err } - return "", fmt.Errorf(p.Sprintf("%+v", any)) + return "", fmt.Errorf("%s", p.Sprintf("%+v", any)) } // Fatalf fatal with formatting @@ -124,7 +124,7 @@ func (p *Printer) Fatalf(any interface{}, args ...interface{}) (string, error) { func (p *Printer) FatalfWithCode(code int, any interface{}, args ...interface{}) (string, error) { format, ok := any.(string) if ok { - return "", response.NewError(code, fmt.Sprintf(p.Sprintf(format, args...))) + return "", response.NewError(code, p.Sprintf(format, args...)) } if err, ok := any.(error); ok { return "", response.NewError(code, err.Error(), response.WithError(err)) diff --git a/service/jobs/service.go b/service/jobs/service.go index c2e8ac53..caaee6a1 100644 --- a/service/jobs/service.go +++ b/service/jobs/service.go @@ -2,6 +2,7 @@ package jobs import ( "context" + "errors" "fmt" "github.com/viant/datly/service/dbms" "github.com/viant/datly/service/reader" @@ -44,7 +45,7 @@ func (s *Service) matchFailedJob(matchKey string) (*async.Job, error) { if candidate.MatchKey == matchKey { var err error if candidate.Error != nil { - err = fmt.Errorf(*candidate.Error) + err = errors.New(*candidate.Error) } else { 
err = fmt.Errorf("job has status %s", candidate.Status) } diff --git a/service/session/state.go b/service/session/state.go index a629db3c..5779250f 100644 --- a/service/session/state.go +++ b/service/session/state.go @@ -471,13 +471,14 @@ func (s *Session) ensureValidValue(value interface{}, parameter *state.Parameter rawType = rawType.Elem() } - if rawType.Kind() != reflect.Struct { - break - } - if elem.Kind() == reflect.Interface && !elem.IsNil() { elem = elem.Elem() } + if rawType.Kind() != reflect.Struct { + value = elem.Interface() + valueType = reflect.TypeOf(value) + break + } if elem.Kind() == reflect.Ptr { value = elem.Interface() valueType = elem.Type() diff --git a/shared/combine.go b/shared/combine.go index b63329fa..67cbbecc 100644 --- a/shared/combine.go +++ b/shared/combine.go @@ -7,7 +7,7 @@ func CombineErrors(header string, errors []error) error { return nil } - outputErr := fmt.Errorf(header) + outputErr := fmt.Errorf("%s", header) for _, err := range errors { outputErr = fmt.Errorf("%w; %v", outputErr, err.Error()) } diff --git a/utils/httputils/violation.go b/utils/httputils/violation.go index 7ff1c0a3..912c7f19 100644 --- a/utils/httputils/violation.go +++ b/utils/httputils/violation.go @@ -50,7 +50,7 @@ func (v Violations) MergeErrors(errors []*response.Error) validator.Violations { aViolation := &validator.Violation{ Location: anError.View + "/" + anError.Parameter, Value: anError.Object, - Check: fmt.Sprint("%T", anError.Error()), + Check: fmt.Sprintf("%T", anError.Error()), Message: anError.Message, } ret = append(ret, aViolation) diff --git a/utils/types/types.go b/utils/types/types.go index dc29f123..8e7ccd78 100644 --- a/utils/types/types.go +++ b/utils/types/types.go @@ -1,6 +1,7 @@ package types import ( + "fmt" "github.com/viant/sqlx/io" "github.com/viant/xreflect" "reflect" @@ -11,6 +12,9 @@ func LookupType(lookup xreflect.LookupType, typeName string, opts ...xreflect.Op if ok { return rType, nil } + if lookup == nil { + return nil, 
fmt.Errorf("type %q was not found and no lookup resolver is configured", typeName) + } return lookup(typeName, opts...) } diff --git a/view/tags/parameter_test.go b/view/tags/parameter_test.go index 6cb56222..aa27bb52 100644 --- a/view/tags/parameter_test.go +++ b/view/tags/parameter_test.go @@ -24,7 +24,7 @@ func TestTag_updateParameter(t *testing.T) { { description: "async Parameter", tag: `parameter:"p1,kind=query,in=qp1,scope=async"`, - expect: &Parameter{Name: "p1", Kind: "query", In: "qp1", Scope: "myscope"}, + expect: &Parameter{Name: "p1", Kind: "query", In: "qp1", Scope: "async"}, }, } diff --git a/view/tags/view_test.go b/view/tags/view_test.go index e85c3f2d..1127cf66 100644 --- a/view/tags/view_test.go +++ b/view/tags/view_test.go @@ -29,7 +29,7 @@ func TestTag_updateView(t *testing.T) { description: "basic view", tag: `view:"foo,connector=dev" sql:"uri=testdata/foo.sql"`, expectView: &View{Name: "foo", Connector: "dev"}, - expectSQL: ViewSQL{SQL: "SELECT * FROM FOO"}, + expectSQL: ViewSQL{SQL: "SELECT * FROM FOO", URI: "testdata/foo.sql"}, expectTag: "foo,connector=dev", }, { diff --git a/view/view.go b/view/view.go index 9274c805..b297489f 100644 --- a/view/view.go +++ b/view/view.go @@ -1298,7 +1298,7 @@ func (v *View) markColumnsAsFilterable() error { for _, colName := range v.Selector.Constraints.Filterable { column, err := v._columns.Lookup(colName) if err != nil { - return fmt.Errorf("criteria column %v, on view has not been defined, %w", colName, v.Name, err) + return fmt.Errorf("criteria column %v on view %v has not been defined: %w", colName, v.Name, err) } column.Filterable = true } diff --git a/warmup/cache_test.go b/warmup/cache_test.go index e691c835..408a4a31 100644 --- a/warmup/cache_test.go +++ b/warmup/cache_test.go @@ -2,14 +2,13 @@ package warmup import ( "context" + "path" + "testing" + "github.com/stretchr/testify/assert" - "github.com/viant/afs" - "github.com/viant/datly/gateway/router" "github.com/viant/datly/internal/tests" 
"github.com/viant/datly/service/reader" "github.com/viant/datly/view" - "path" - "testing" ) func TestPopulateCache(t *testing.T) { @@ -59,14 +58,14 @@ func TestPopulateCache(t *testing.T) { resourcePath := path.Join("testdata", testCase.URL, "resource.yaml") - resource, err := router.NewResourceFromURL(context.TODO(), afs.New(), resourcePath, false) + resource, err := view.NewResourceFromURL(context.TODO(), resourcePath, nil, nil) if !assert.Nil(t, err, testCase.description) { continue } var views []*view.View - for _, route := range resource.Routes { - views = append(views, route.View) + for _, item := range resource.Views { + views = append(views, item) } inserted, err := PopulateCache(views) @@ -100,7 +99,7 @@ func checkIfCached(t *testing.T, cache *view.Cache, ctx context.Context, testCas builder := reader.NewBuilder() for _, cacheInput := range input { - build, err := builder.CacheSQL(aView, cacheInput.Selector) + build, err := builder.CacheSQL(ctx, aView, cacheInput.Selector) if err != nil { return err } @@ -116,7 +115,7 @@ func checkIfCached(t *testing.T, cache *view.Cache, ctx context.Context, testCas } if cacheInput.IndexMeta && aView.Template.Summary != nil { - metaIndex, err := builder.CacheMetaSQL(aView, cacheInput.Selector, &view.BatchData{ + metaIndex, err := builder.CacheMetaSQL(ctx, aView, cacheInput.Selector, &view.BatchData{ ValuesBatch: testCase.metaIndexed, Values: testCase.metaIndexed, }, nil, nil) From e7fcb5c34355196ed51603af6b4ddba91d046813 Mon Sep 17 00:00:00 2001 From: adranwit Date: Mon, 23 Feb 2026 09:42:13 -0800 Subject: [PATCH 4/6] Implemented near-full shape-engine parity with the legacy internal translator by expanding DQL compile/load (relations, handler/dml paths, diagnostics with line/char mapping, type-context defaults/resolution, declaration/settings directives, and metadata/type parity), and validated parity across platform routes with 0 mismatches in the all-sources sweep. 
Added explicit column-discovery policy controls (auto/on/off) with default auto behavior that requires discovery for SELECT * or missing concrete shape, preserves schema column order with append-only newly discovered columns, and fails compilation when discovery is required but disabled. --- cmd/command/translate_shape.go | 197 ++++++++ cmd/command/translate_shape_test.go | 30 ++ cmd/options/query_test.go | 34 ++ cmd/options/rule_engine_test.go | 21 + gateway/dql_bootstrap.go | 453 +++++++++++++++++++ gateway/dql_bootstrap_test.go | 122 +++++ internal/inference/join_test.go | 56 +++ internal/testutil/sqlnormalizer/cases.go | 43 ++ internal/translator/parser/sanitizer_test.go | 222 +++++++++ service/executor/expand/evaluator_test.go | 78 ++++ service/session/selector_injector_test.go | 89 ++++ 11 files changed, 1345 insertions(+) create mode 100644 cmd/command/translate_shape.go create mode 100644 cmd/command/translate_shape_test.go create mode 100644 cmd/options/query_test.go create mode 100644 cmd/options/rule_engine_test.go create mode 100644 gateway/dql_bootstrap.go create mode 100644 gateway/dql_bootstrap_test.go create mode 100644 internal/inference/join_test.go create mode 100644 internal/testutil/sqlnormalizer/cases.go create mode 100644 internal/translator/parser/sanitizer_test.go create mode 100644 service/executor/expand/evaluator_test.go create mode 100644 service/session/selector_injector_test.go diff --git a/cmd/command/translate_shape.go b/cmd/command/translate_shape.go new file mode 100644 index 00000000..109e9963 --- /dev/null +++ b/cmd/command/translate_shape.go @@ -0,0 +1,197 @@ +package command + +import ( + "context" + "encoding/json" + "fmt" + "path" + "path/filepath" + "strings" + + "github.com/viant/afs/file" + "github.com/viant/afs/url" + "github.com/viant/datly/cmd/options" + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/repository/shape" + shapeCompile 
"github.com/viant/datly/repository/shape/compile" + shapeLoad "github.com/viant/datly/repository/shape/load" + datlyservice "github.com/viant/datly/service" + "github.com/viant/datly/shared" + "github.com/viant/datly/view" + "gopkg.in/yaml.v3" +) + +func (s *Service) translateShape(ctx context.Context, opts *options.Options) error { + rule := opts.Rule() + compiler := shapeCompile.New() + loader := shapeLoad.New() + for rule.Index = 0; rule.Index < len(rule.Source); rule.Index++ { + sourceURL := rule.SourceURL() + _, name := url.Split(sourceURL, file.Scheme) + fmt.Printf("translating %v (shape)\n", name) + dql, err := rule.LoadSource(ctx, s.fs, sourceURL) + if err != nil { + return err + } + dql = strings.TrimSpace(dql) + if dql == "" { + return fmt.Errorf("source %s was empty", sourceURL) + } + shapeSource := &shape.Source{ + Name: strings.TrimSuffix(name, path.Ext(name)), + Path: url.Path(sourceURL), + DQL: dql, + Connector: strings.TrimSpace(rule.Connector), + } + planResult, err := compiler.Compile(ctx, shapeSource) + if err != nil { + return fmt.Errorf("failed to compile %s: %w", sourceURL, err) + } + componentArtifact, err := loader.LoadComponent(ctx, planResult) + if err != nil { + return fmt.Errorf("failed to load %s: %w", sourceURL, err) + } + component, ok := componentArtifact.Component.(*shapeLoad.Component) + if !ok || component == nil { + return fmt.Errorf("unexpected component artifact for %s", sourceURL) + } + if err = s.persistShapeRoute(ctx, opts, sourceURL, dql, componentArtifact.Resource, component); err != nil { + return err + } + } + paths := url.Join(opts.Repository().RepositoryURL, "Datly", "routes", "paths.yaml") + if ok, _ := s.fs.Exists(ctx, paths); ok { + _ = s.fs.Delete(ctx, paths) + } + return nil +} + +type shapeRuleFile struct { + Resource *view.Resource `yaml:"Resource,omitempty"` + Routes []*repository.Component `yaml:"Routes,omitempty"` + TypeContext any `yaml:"TypeContext,omitempty"` +} + +func (s *Service) persistShapeRoute(ctx 
context.Context, opts *options.Options, sourceURL, dql string, resource *view.Resource, component *shapeLoad.Component) error { + rule := opts.Rule() + routeYAML, routeRoot, relDir, stem, err := routePathForShape(rule, opts.Repository().RepositoryURL, sourceURL) + if err != nil { + return err + } + if resource != nil { + for _, item := range resource.Views { + if item == nil || item.Template == nil { + continue + } + if strings.TrimSpace(item.Template.Source) == "" { + continue + } + sqlRel := strings.TrimSpace(item.Template.SourceURL) + if sqlRel == "" { + sqlRel = path.Join(stem, item.Name+".sql") + } + sqlDest := path.Join(routeRoot, relDir, filepath.ToSlash(sqlRel)) + if err = s.fs.Upload(ctx, sqlDest, file.DefaultFileOsMode, strings.NewReader(item.Template.Source)); err != nil { + return fmt.Errorf("failed to persist sql %s: %w", sqlDest, err) + } + item.Template.SourceURL = sqlRel + } + } + rootView := "" + if component != nil { + rootView = strings.TrimSpace(component.RootView) + } + if rootView == "" && resource != nil && len(resource.Views) > 0 && resource.Views[0] != nil { + rootView = resource.Views[0].Name + } + method, uri := parseShapeRulePath(dql, rule.RuleName(), opts.Repository().APIPrefix) + route := &repository.Component{ + Path: contract.Path{ + Method: method, + URI: uri, + }, + Contract: contract.Contract{ + Service: serviceForMethod(method), + }, + View: &view.View{Reference: shared.Reference{Ref: rootView}}, + } + if component != nil { + route.TypeContext = component.TypeContext + if component.Directives != nil && component.Directives.MCP != nil { + route.Name = strings.TrimSpace(component.Directives.MCP.Name) + route.Description = strings.TrimSpace(component.Directives.MCP.Description) + route.DescriptionURI = strings.TrimSpace(component.Directives.MCP.DescriptionPath) + } + } + payload := &shapeRuleFile{ + Resource: resource, + Routes: []*repository.Component{route}, + } + if component != nil && component.TypeContext != nil { + 
payload.TypeContext = component.TypeContext + } + data, err := yaml.Marshal(payload) + if err != nil { + return err + } + if err = s.fs.Upload(ctx, routeYAML, file.DefaultFileOsMode, strings.NewReader(string(data))); err != nil { + return fmt.Errorf("failed to persist route yaml %s: %w", routeYAML, err) + } + return nil +} + +func routePathForShape(rule *options.Rule, repoURL, sourceURL string) (routeYAML string, routeRoot string, relDir string, stem string, err error) { + sourcePath := filepath.Clean(url.Path(sourceURL)) + basePath := filepath.Clean(rule.BaseRuleURL()) + relative, relErr := filepath.Rel(basePath, sourcePath) + if relErr != nil || strings.HasPrefix(relative, "..") { + relative = filepath.Base(sourcePath) + } + relative = filepath.ToSlash(relative) + relDir = filepath.ToSlash(path.Dir(relative)) + if relDir == "." { + relDir = "" + } + stem = strings.TrimSuffix(path.Base(relative), path.Ext(relative)) + routeRoot = url.Join(repoURL, "Datly", "routes") + routeYAML = url.Join(routeRoot, relDir, stem+".yaml") + return routeYAML, routeRoot, relDir, stem, nil +} + +type shapeRuleHeader struct { + Method string `json:"Method"` + URI string `json:"URI"` +} + +func parseShapeRulePath(dql, ruleName, apiPrefix string) (string, string) { + method := "GET" + uri := "/" + strings.Trim(strings.TrimSpace(ruleName), "/") + if prefix := strings.TrimSpace(apiPrefix); prefix != "" { + uri = strings.TrimRight(prefix, "/") + uri + } + start := strings.Index(dql, "/*") + end := strings.Index(dql, "*/") + if start != -1 && end > start+2 { + raw := strings.TrimSpace(dql[start+2 : end]) + if strings.HasPrefix(raw, "{") && strings.HasSuffix(raw, "}") { + header := &shapeRuleHeader{} + if err := json.Unmarshal([]byte(raw), header); err == nil { + if candidate := strings.TrimSpace(strings.ToUpper(header.Method)); candidate != "" { + method = candidate + } + if candidate := strings.TrimSpace(header.URI); candidate != "" { + uri = candidate + } + } + } + } + return method, uri 
+} + +func serviceForMethod(method string) datlyservice.Type { + if strings.EqualFold(method, "GET") { + return datlyservice.TypeReader + } + return datlyservice.TypeExecutor +} diff --git a/cmd/command/translate_shape_test.go b/cmd/command/translate_shape_test.go new file mode 100644 index 00000000..b76fac4e --- /dev/null +++ b/cmd/command/translate_shape_test.go @@ -0,0 +1,30 @@ +package command + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/cmd/options" +) + +func TestParseShapeRulePath(t *testing.T) { + method, uri := parseShapeRulePath(`/* {"Method":"POST","URI":"/v1/api/orders"} */ SELECT 1`, "orders", "/v1/api") + assert.Equal(t, "POST", method) + assert.Equal(t, "/v1/api/orders", uri) + + method, uri = parseShapeRulePath(`SELECT 1`, "orders", "/v1/api") + assert.Equal(t, "GET", method) + assert.Equal(t, "/v1/api/orders", uri) +} + +func TestRoutePathForShape(t *testing.T) { + rule := &options.Rule{Project: "/repo", Source: []string{"/repo/dql/platform/campaign/post.dql"}} + routeYAML, routeRoot, relDir, stem, err := routePathForShape(rule, "/repo/dev", "/repo/dql/platform/campaign/post.dql") + require.NoError(t, err) + assert.Equal(t, "/repo/dev/Datly/routes/platform/campaign/post.yaml", routeYAML) + assert.Equal(t, "/repo/dev/Datly/routes", routeRoot) + assert.Equal(t, filepath.ToSlash("platform/campaign"), relDir) + assert.Equal(t, "post", stem) +} diff --git a/cmd/options/query_test.go b/cmd/options/query_test.go new file mode 100644 index 00000000..055d3915 --- /dev/null +++ b/cmd/options/query_test.go @@ -0,0 +1,34 @@ +package options + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/datly/internal/testutil/sqlnormalizer" + "github.com/viant/sqlparser" +) + +func parserOption() sqlparser.Option { + return sqlparser.WithErrorHandler(nil) +} + +func TestRule_NormalizeSQL(t *testing.T) { + for _, testCase := range 
sqlnormalizer.Cases() { + t.Run(testCase.Name, func(t *testing.T) { + rule := &Rule{Generated: testCase.Generated} + actual := rule.NormalizeSQL(testCase.SQL, parserOption) + require.Equal(t, testCase.Expect, actual) + }) + } +} + +func TestMapper_Map(t *testing.T) { + m := mapper{"a": "A"} + require.Equal(t, "A", m.Map("a")) + require.Equal(t, "b", m.Map("b")) +} + +func TestNormalizeName(t *testing.T) { + require.Equal(t, "UserAlias", normalizeName("user_alias")) + require.Equal(t, "UserAlias", normalizeName("UserAlias")) +} diff --git a/cmd/options/rule_engine_test.go b/cmd/options/rule_engine_test.go new file mode 100644 index 00000000..bac95c6c --- /dev/null +++ b/cmd/options/rule_engine_test.go @@ -0,0 +1,21 @@ +package options + +import "testing" + +func TestRule_EffectiveEngine(t *testing.T) { + testCases := []struct { + name string + engine string + want string + }{ + {name: "default", engine: "", want: EngineLegacy}, + {name: "shape", engine: "shape", want: EngineShape}, + {name: "invalid", engine: "other", want: EngineLegacy}, + } + for _, testCase := range testCases { + rule := &Rule{Engine: testCase.engine} + if got := rule.EffectiveEngine(); got != testCase.want { + t.Fatalf("%s: got %s, want %s", testCase.name, got, testCase.want) + } + } +} diff --git a/gateway/dql_bootstrap.go b/gateway/dql_bootstrap.go new file mode 100644 index 00000000..fe7d62db --- /dev/null +++ b/gateway/dql_bootstrap.go @@ -0,0 +1,453 @@ +package gateway + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path" + "path/filepath" + "sort" + "strings" + + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/repository/shape" + shapeCompile "github.com/viant/datly/repository/shape/compile" + shapeLoad "github.com/viant/datly/repository/shape/load" + datlyservice "github.com/viant/datly/service" + "github.com/viant/datly/view" +) + +func (r *Service) applyDQLBootstrap(ctx context.Context, repo *repository.Service, 
cfg *DQLBootstrap) error { + if cfg == nil || len(cfg.Sources) == 0 { + return nil + } + sources, err := discoverDQLBootstrapSources(cfg.Sources, cfg.Exclude) + if err != nil { + return err + } + if len(sources) == 0 { + return fmt.Errorf("no DQL bootstrap sources matched") + } + compiler := shapeCompile.New() + loader := shapeLoad.New() + precedence := cfg.EffectivePrecedence() + var errors []error + for _, sourcePath := range sources { + component, err := compileBootstrapComponent(ctx, compiler, loader, repo, sourcePath, cfg, r.Config.APIPrefix) + if err != nil { + if cfg.ShouldFailFast() { + return err + } + errors = append(errors, err) + continue + } + exists, lookupErr := hasRepositoryProvider(ctx, repo, &component.Path) + if lookupErr != nil { + if cfg.ShouldFailFast() { + return lookupErr + } + errors = append(errors, lookupErr) + continue + } + if exists { + switch precedence { + case DQLBootstrapPrecedenceRoutesWins: + continue + case DQLBootstrapPrecedenceErrorOnMixed: + err = fmt.Errorf("DQL bootstrap conflict for %s:%s", component.Method, component.URI) + if cfg.ShouldFailFast() { + return err + } + errors = append(errors, err) + continue + } + } + repo.Register(component) + } + if len(errors) > 0 { + return fmt.Errorf("DQL bootstrap completed with %d errors: %w", len(errors), errors[0]) + } + return nil +} + +func compileBootstrapComponent(ctx context.Context, compiler *shapeCompile.DQLCompiler, loader *shapeLoad.Loader, repo *repository.Service, sourcePath string, cfg *DQLBootstrap, apiPrefix string) (*repository.Component, error) { + data, err := os.ReadFile(sourcePath) + if err != nil { + return nil, fmt.Errorf("failed to read DQL bootstrap source %s: %w", sourcePath, err) + } + dql := strings.TrimSpace(string(data)) + if dql == "" { + return nil, fmt.Errorf("empty DQL bootstrap source: %s", sourcePath) + } + sourceName := strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) + source := &shape.Source{ + Name: sourceName, + Path: 
sourcePath, + DQL: dql, + } + planResult, err := compiler.Compile(ctx, source, compileOptionsFromBootstrap(cfg)...) + if err != nil { + return nil, fmt.Errorf("failed to compile DQL bootstrap source %s: %w", sourcePath, err) + } + componentArtifact, err := loader.LoadComponent(ctx, planResult) + if err != nil { + return nil, fmt.Errorf("failed to load DQL bootstrap source %s: %w", sourcePath, err) + } + normalizeBootstrapInlineSQL(componentArtifact.Resource) + mergeBootstrapSharedResources(componentArtifact.Resource, repo) + loaded, ok := componentArtifact.Component.(*shapeLoad.Component) + if !ok || loaded == nil { + return nil, fmt.Errorf("unexpected shape component artifact for %s", sourcePath) + } + rootView := lookupRootView(componentArtifact.Resource, loaded.RootView) + if rootView == nil { + return nil, fmt.Errorf("missing root view %q for %s", loaded.RootView, sourcePath) + } + method, uri := resolvePathSettings(sourcePath, dql, apiPrefix) + componentModel := &repository.Component{ + Path: contract.Path{ + Method: method, + URI: uri, + }, + Contract: contract.Contract{ + Service: defaultServiceForMethod(method, rootView), + }, + View: rootView, + TypeContext: loaded.TypeContext, + } + loadOptions := []repository.Option{} + if repo != nil { + loadOptions = append(loadOptions, repository.WithResources(repo.Resources())) + loadOptions = append(loadOptions, repository.WithExtensions(repo.Extensions())) + } + components, err := repository.LoadComponentsFromMap(ctx, map[string]any{ + "Resource": componentArtifact.Resource, + "Components": []*repository.Component{componentModel}, + }, loadOptions...) 
+ if err != nil { + return nil, fmt.Errorf("failed to materialize bootstrap component for %s: %w", sourcePath, err) + } + if err = components.Init(ctx); err != nil { + return nil, fmt.Errorf("failed to initialize bootstrap component for %s: %w", sourcePath, err) + } + if len(components.Components) == 0 || components.Components[0] == nil { + return nil, fmt.Errorf("empty initialized bootstrap component for %s", sourcePath) + } + return components.Components[0], nil +} + +func mergeBootstrapSharedResources(target *view.Resource, repo *repository.Service) { + if target == nil || repo == nil || repo.Resources() == nil { + return + } + if connectors, err := repo.Resources().Lookup(view.ResourceConnectors); err == nil && connectors != nil && connectors.Resource != nil { + target.MergeFrom(connectors.Resource, nil) + } + if constants, err := repo.Resources().Lookup(view.ResourceConstants); err == nil && constants != nil && constants.Resource != nil { + target.MergeFrom(constants.Resource, nil) + } +} + +func normalizeBootstrapInlineSQL(resource *view.Resource) { + if resource == nil { + return + } + for _, item := range resource.Views { + if item == nil || item.Template == nil { + continue + } + // DQL bootstrap compiles from in-memory source; keep SQL inline and avoid filesystem lookups. 
+ item.Template.SourceURL = "" + } +} + +func defaultServiceForMethod(method string, rootView *view.View) datlyservice.Type { + if strings.EqualFold(method, "GET") { + return datlyservice.TypeReader + } + if rootView != nil && rootView.Mode == view.ModeQuery { + return datlyservice.TypeReader + } + return datlyservice.TypeExecutor +} + +func hasRepositoryProvider(ctx context.Context, repo *repository.Service, path *contract.Path) (bool, error) { + if repo == nil || repo.Registry() == nil || path == nil { + return false, nil + } + _, err := repo.Registry().LookupProvider(ctx, path) + if err != nil { + message := strings.ToLower(strings.TrimSpace(err.Error())) + if strings.Contains(message, "not found") { + return false, nil + } + return false, err + } + return true, nil +} + +func compileOptionsFromBootstrap(cfg *DQLBootstrap) []shape.CompileOption { + if cfg == nil { + return nil + } + var result []shape.CompileOption + switch strings.ToLower(strings.TrimSpace(cfg.CompileProfile)) { + case string(shape.CompileProfileStrict): + result = append(result, shape.WithCompileProfile(shape.CompileProfileStrict)) + case string(shape.CompileProfileCompat): + result = append(result, shape.WithCompileProfile(shape.CompileProfileCompat)) + } + switch strings.ToLower(strings.TrimSpace(cfg.MixedMode)) { + case string(shape.CompileMixedModeExecWins): + result = append(result, shape.WithMixedMode(shape.CompileMixedModeExecWins)) + case string(shape.CompileMixedModeReadWins): + result = append(result, shape.WithMixedMode(shape.CompileMixedModeReadWins)) + case string(shape.CompileMixedModeErrorOnMixed): + result = append(result, shape.WithMixedMode(shape.CompileMixedModeErrorOnMixed)) + } + switch strings.ToLower(strings.TrimSpace(cfg.UnknownNonReadMode)) { + case string(shape.CompileUnknownNonReadWarn): + result = append(result, shape.WithUnknownNonReadMode(shape.CompileUnknownNonReadWarn)) + case string(shape.CompileUnknownNonReadError): + result = append(result, 
shape.WithUnknownNonReadMode(shape.CompileUnknownNonReadError)) + } + switch strings.ToLower(strings.TrimSpace(cfg.ColumnDiscoveryMode)) { + case string(shape.CompileColumnDiscoveryAuto): + result = append(result, shape.WithColumnDiscoveryMode(shape.CompileColumnDiscoveryAuto)) + case string(shape.CompileColumnDiscoveryOn): + result = append(result, shape.WithColumnDiscoveryMode(shape.CompileColumnDiscoveryOn)) + case string(shape.CompileColumnDiscoveryOff): + result = append(result, shape.WithColumnDiscoveryMode(shape.CompileColumnDiscoveryOff)) + } + if marker := strings.TrimSpace(cfg.DQLPathMarker); marker != "" { + result = append(result, shape.WithDQLPathMarker(marker)) + } + if rel := strings.TrimSpace(cfg.RoutesRelativePath); rel != "" { + result = append(result, shape.WithRoutesRelativePath(rel)) + } + return result +} + +func discoverDQLBootstrapSources(includes, excludes []string) ([]string, error) { + seen := map[string]struct{}{} + var result []string + for _, include := range includes { + include = strings.TrimSpace(include) + if include == "" { + continue + } + expanded, err := expandBootstrapPattern(include) + if err != nil { + return nil, err + } + for _, candidate := range expanded { + if !isDQLSourceFile(candidate) { + continue + } + if matchesAnyPattern(candidate, excludes) { + continue + } + if _, ok := seen[candidate]; ok { + continue + } + seen[candidate] = struct{}{} + result = append(result, candidate) + } + } + sort.Strings(result) + return result, nil +} + +func expandBootstrapPattern(pattern string) ([]string, error) { + pattern = filepath.Clean(pattern) + if strings.Contains(pattern, "**") { + return expandDoubleStarPattern(pattern) + } + if hasGlobMeta(pattern) { + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, err + } + return flattenPaths(matches) + } + return flattenPaths([]string{pattern}) +} + +func flattenPaths(items []string) ([]string, error) { + var result []string + for _, item := range items { + item = 
strings.TrimSpace(item) + if item == "" { + continue + } + info, err := os.Stat(item) + if err != nil { + if os.IsNotExist(err) { + continue + } + return nil, err + } + if !info.IsDir() { + result = append(result, item) + continue + } + err = filepath.WalkDir(item, func(candidate string, d os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if d.IsDir() { + return nil + } + if isDQLSourceFile(candidate) { + result = append(result, candidate) + } + return nil + }) + if err != nil { + return nil, err + } + } + return result, nil +} + +func expandDoubleStarPattern(pattern string) ([]string, error) { + slash := filepath.ToSlash(pattern) + index := strings.Index(slash, "**") + root := strings.TrimSuffix(slash[:index], "/") + if root == "" { + root = "." + } + rootPath := filepath.FromSlash(root) + var result []string + err := filepath.WalkDir(rootPath, func(candidate string, d os.DirEntry, walkErr error) error { + if walkErr != nil { + return walkErr + } + if d.IsDir() { + return nil + } + normalized := filepath.ToSlash(candidate) + if !globMatch(slash, normalized) { + return nil + } + result = append(result, candidate) + return nil + }) + return result, err +} + +func hasGlobMeta(pattern string) bool { + return strings.ContainsAny(pattern, "*?[") +} + +func matchesAnyPattern(candidate string, patterns []string) bool { + for _, pattern := range patterns { + pattern = strings.TrimSpace(pattern) + if pattern == "" { + continue + } + if globMatch(filepath.ToSlash(pattern), filepath.ToSlash(candidate)) { + return true + } + } + return false +} + +func globMatch(pattern, candidate string) bool { + pattern = filepath.ToSlash(pattern) + candidate = filepath.ToSlash(candidate) + if strings.Contains(pattern, "**") { + return matchDoubleStar(strings.Split(pattern, "/"), strings.Split(candidate, "/")) + } + ok, _ := path.Match(pattern, candidate) + return ok +} + +func matchDoubleStar(pattern, candidate []string) bool { + if len(pattern) == 0 { + 
return len(candidate) == 0 + } + head := pattern[0] + if head == "**" { + if matchDoubleStar(pattern[1:], candidate) { + return true + } + if len(candidate) > 0 { + return matchDoubleStar(pattern, candidate[1:]) + } + return false + } + if len(candidate) == 0 { + return false + } + ok, _ := path.Match(head, candidate[0]) + if !ok { + return false + } + return matchDoubleStar(pattern[1:], candidate[1:]) +} + +func isDQLSourceFile(path string) bool { + ext := strings.ToLower(strings.TrimSpace(filepath.Ext(path))) + return ext == ".dql" || ext == ".sql" +} + +func lookupRootView(resource *view.Resource, root string) *view.View { + if resource == nil { + return nil + } + name := strings.TrimSpace(root) + if name != "" { + if candidate, _ := resource.View(name); candidate != nil { + return candidate + } + } + if len(resource.Views) > 0 { + return resource.Views[0] + } + return nil +} + +type bootstrapRuleSettings struct { + Method string `json:"Method"` + URI string `json:"URI"` +} + +func resolvePathSettings(sourcePath, dql, apiPrefix string) (string, string) { + method := "GET" + uri := "" + settings := parseBootstrapRuleSettings(dql) + if settings != nil { + if candidate := strings.TrimSpace(strings.ToUpper(settings.Method)); candidate != "" { + method = candidate + } + uri = strings.TrimSpace(settings.URI) + } + if uri == "" { + stem := strings.TrimSuffix(filepath.Base(sourcePath), filepath.Ext(sourcePath)) + uri = "/" + strings.Trim(stem, "/") + if prefix := strings.TrimSpace(apiPrefix); prefix != "" { + uri = strings.TrimRight(prefix, "/") + uri + } + } + return method, uri +} + +func parseBootstrapRuleSettings(dql string) *bootstrapRuleSettings { + start := strings.Index(dql, "/*") + end := strings.Index(dql, "*/") + if start == -1 || end == -1 || end <= start+2 { + return nil + } + raw := strings.TrimSpace(dql[start+2 : end]) + if !strings.HasPrefix(raw, "{") || !strings.HasSuffix(raw, "}") { + return nil + } + ret := &bootstrapRuleSettings{} + if err := 
json.Unmarshal([]byte(raw), ret); err != nil { + return nil + } + return ret +} diff --git a/gateway/dql_bootstrap_test.go b/gateway/dql_bootstrap_test.go new file mode 100644 index 00000000..b36714cd --- /dev/null +++ b/gateway/dql_bootstrap_test.go @@ -0,0 +1,122 @@ +package gateway + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository" + "github.com/viant/datly/repository/contract" + "github.com/viant/datly/view" +) + +func TestConfigValidate_AllowsEmptyRouteURLWithDQLBootstrap(t *testing.T) { + cfg := &Config{ + ExposableConfig: ExposableConfig{ + DQLBootstrap: &DQLBootstrap{ + Sources: []string{"./testdata/*.dql"}, + }, + }, + } + require.NoError(t, cfg.Validate()) +} + +func TestConfigValidate_FailsWithoutRouteAndBootstrap(t *testing.T) { + cfg := &Config{} + require.ErrorContains(t, cfg.Validate(), "RouteURL was empty") +} + +func TestConfigValidate_FailsForEmptyBootstrapSources(t *testing.T) { + cfg := &Config{ + ExposableConfig: ExposableConfig{ + DQLBootstrap: &DQLBootstrap{}, + }, + } + require.ErrorContains(t, cfg.Validate(), "DQLBootstrap.Sources was empty") +} + +func TestDiscoverDQLBootstrapSources(t *testing.T) { + root := t.TempDir() + require.NoError(t, os.MkdirAll(filepath.Join(root, "sql", "nested"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(root, "sql", "a.dql"), []byte("SELECT 1"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(root, "sql", "nested", "b.sql"), []byte("SELECT 2"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(root, "sql", "nested", "skip.dql"), []byte("SELECT 3"), 0o644)) + + sources, err := discoverDQLBootstrapSources( + []string{filepath.Join(root, "sql", "**", "*")}, + []string{filepath.Join(root, "sql", "**", "skip.dql")}, + ) + require.NoError(t, err) + require.Len(t, sources, 2) + assert.Contains(t, sources, filepath.Join(root, "sql", "a.dql")) + assert.Contains(t, 
sources, filepath.Join(root, "sql", "nested", "b.sql")) +} + +func TestResolvePathSettings(t *testing.T) { + method, uri := resolvePathSettings("/tmp/orders/get.dql", `/* {"Method":"POST","URI":"/v1/api/orders"} */ SELECT 1`, "/v1/api") + assert.Equal(t, "POST", method) + assert.Equal(t, "/v1/api/orders", uri) + + method, uri = resolvePathSettings("/tmp/orders/get.dql", `SELECT 1`, "/v1/api") + assert.Equal(t, "GET", method) + assert.Equal(t, "/v1/api/get", uri) +} + +func TestDQLBootstrapEffectivePrecedence(t *testing.T) { + assert.Equal(t, DQLBootstrapPrecedenceRoutesWins, (&DQLBootstrap{}).EffectivePrecedence()) + assert.Equal(t, DQLBootstrapPrecedenceDQLWins, (&DQLBootstrap{Precedence: "dql_wins"}).EffectivePrecedence()) + assert.Equal(t, DQLBootstrapPrecedenceRoutesWins, (&DQLBootstrap{Precedence: "unknown"}).EffectivePrecedence()) +} + +func TestApplyDQLBootstrap_Precedence(t *testing.T) { + ctx := context.Background() + repo, err := repository.New(ctx, repository.WithComponentURL(""), repository.WithNoPlugin()) + require.NoError(t, err) + + route := contract.Path{Method: "GET", URI: "/v1/api/test"} + repo.Register(&repository.Component{Path: route}) + connectors, err := repo.Resources().Lookup(view.ResourceConnectors) + require.NoError(t, err) + connectors.Connectors = append(connectors.Connectors, &view.Connector{ + Connection: view.Connection{ + DBConfig: view.DBConfig{ + Name: "test_conn", + Driver: "sqlite3", + DSN: "sqlite:./test.db", + }, + }, + }) + + root := t.TempDir() + source := filepath.Join(root, "test.dql") + require.NoError(t, os.WriteFile(source, []byte(`/* {"Method":"GET","URI":"/v1/api/test","Connector":"test_conn"} */ SELECT 1 AS id`), 0o644)) + srv := &Service{Config: &Config{ExposableConfig: ExposableConfig{APIPrefix: "/v1/api"}}} + + routesWins := &DQLBootstrap{ + Sources: []string{source}, + Precedence: DQLBootstrapPrecedenceRoutesWins, + } + require.NoError(t, srv.applyDQLBootstrap(ctx, repo, routesWins)) + provider, err := 
repo.Registry().LookupProvider(ctx, &route) + require.NoError(t, err) + require.NotNil(t, provider) + component, err := provider.Component(ctx) + require.NoError(t, err) + assert.Nil(t, component.View) + + dqlWins := &DQLBootstrap{ + Sources: []string{source}, + Precedence: DQLBootstrapPrecedenceDQLWins, + } + require.NoError(t, srv.applyDQLBootstrap(ctx, repo, dqlWins)) + provider, err = repo.Registry().LookupProvider(ctx, &route) + require.NoError(t, err) + require.NotNil(t, provider) + component, err = provider.Component(ctx) + require.NoError(t, err) + require.NotNil(t, component.View) + assert.Equal(t, "test", component.View.Name) +} diff --git a/internal/inference/join_test.go b/internal/inference/join_test.go new file mode 100644 index 00000000..ea30ddb9 --- /dev/null +++ b/internal/inference/join_test.go @@ -0,0 +1,56 @@ +package inference + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/viant/sqlparser" +) + +func TestJoinRelationExtraction(t *testing.T) { + testCases := []struct { + description string + sql string + wantParent string + wantRelCol string + wantRefCol string + }{ + { + description: "simple join", + sql: "SELECT * FROM a a JOIN b b ON a.brand = b.b_brand", + wantParent: "a", + wantRelCol: "brand", + wantRefCol: "b_brand", + }, + { + description: "join with function on parent", + sql: "SELECT * FROM a a JOIN b b ON lower(a.brand) = b.b_brand", + wantParent: "a", + wantRelCol: "brand", + wantRefCol: "b_brand", + }, + { + description: "join with collate and multiple conditions", + sql: "SELECT * FROM a a JOIN b b ON " + + "a.brand COLLATE utf8mb4_bin = b.b_brand COLLATE utf8mb4_bin AND " + + "a.model COLLATE utf8mb4_bin = b.b_model COLLATE utf8mb4_bin", + wantParent: "a", + wantRelCol: "brand", + wantRefCol: "b_brand", + }, + } + + for _, testCase := range testCases { + q, err := sqlparser.ParseQuery(testCase.sql) + require.NoError(t, err, testCase.description) + require.NotEmpty(t, q.Joins, testCase.description) 
+ + join := q.Joins[0] + parent := ParentAlias(join) + require.Equal(t, testCase.wantParent, parent, testCase.description) + + relCol, refCol := ExtractRelationColumns(join) + require.Equal(t, testCase.wantRelCol, relCol, testCase.description) + require.Equal(t, testCase.wantRefCol, refCol, testCase.description) + } +} diff --git a/internal/testutil/sqlnormalizer/cases.go b/internal/testutil/sqlnormalizer/cases.go new file mode 100644 index 00000000..73569e24 --- /dev/null +++ b/internal/testutil/sqlnormalizer/cases.go @@ -0,0 +1,43 @@ +package sqlnormalizer + +type Case struct { + Name string + Generated bool + SQL string + Expect string +} + +func Cases() []Case { + return []Case{ + { + Name: "skip normalization when not generated", + Generated: false, + SQL: "SELECT a.id FROM users a JOIN orders b ON a.id = b.user_id", + Expect: "SELECT a.id FROM users a JOIN orders b ON a.id = b.user_id", + }, + { + Name: "invalid sql returns input", + Generated: true, + SQL: "SELECT * FROM (", + Expect: "SELECT * FROM (", + }, + { + Name: "normalize from and join aliases in selectors and alias nodes", + Generated: true, + SQL: "SELECT a.id, b.user_id FROM users a JOIN orders b ON a.id = b.user_id", + Expect: "SELECT A.id, B.user_id FROM users A JOIN orders B ON A.id = B.user_id", + }, + { + Name: "keep alias that is already normalized", + Generated: true, + SQL: "SELECT UserAlias.id FROM users UserAlias", + Expect: "SELECT UserAlias.id FROM users UserAlias", + }, + { + Name: "normalize snake_case alias", + Generated: true, + SQL: "SELECT order_item.id FROM users order_item", + Expect: "SELECT OrderItem.id FROM users OrderItem", + }, + } +} diff --git a/internal/translator/parser/sanitizer_test.go b/internal/translator/parser/sanitizer_test.go new file mode 100644 index 00000000..38a9c8d1 --- /dev/null +++ b/internal/translator/parser/sanitizer_test.go @@ -0,0 +1,222 @@ +package parser + +import ( + "testing" + + "github.com/stretchr/testify/assert" + 
"github.com/stretchr/testify/require" + "github.com/viant/datly/internal/inference" + "github.com/viant/datly/view/keywords" + "github.com/viant/velty/functions" +) + +func TestTemplate_Sanitize(t *testing.T) { + state := inference.State{} + tmpl, err := NewTemplate("#set($x = 1) SELECT * FROM t WHERE id = $x AND name = $Name", &state) + require.NoError(t, err) + actual := tmpl.Sanitize() + assert.Contains(t, actual, "#set($x = 1)") + assert.Contains(t, actual, "$criteria.AppendBinding($x)") + assert.Contains(t, actual, "$criteria.AppendBinding($Unsafe.Name)") +} + +func TestSanitize_SkipsFirstSetVariableOccurrence(t *testing.T) { + iter := newIterable(map[string]bool{"x": true}) + expr := &Expression{ + IsVariable: true, + OccurrenceIndex: 0, + Context: SetContext, + FullName: "$x", + Start: 0, + End: 2, + } + dst := []byte("$x") + actual, _ := sanitize(iter, expr, dst, 0, 0) + assert.Equal(t, "$x", string(actual)) +} + +func TestUnwrapBrackets(t *testing.T) { + raw, had := unwrapBrackets("${Foo}") + assert.Equal(t, "$Foo", raw) + assert.True(t, had) + + raw, had = unwrapBrackets("$Foo") + assert.Equal(t, "$Foo", raw) + assert.False(t, had) +} + +func TestSanitizeContent(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Start: 0, End: 10} + assert.Equal(t, "$A", sanitizeContent(iter, expr, "$A")) + + iter = newIterable(nil) + parent := &Expression{Start: 0, End: 13, FullName: "$Fn($A, $B)"} + argA := &Expression{Start: 4, End: 6, FullName: "$A", Holder: "A"} + argB := &Expression{Start: 8, End: 10, FullName: "$B", Holder: "B"} + next := &Expression{Start: 20, End: 22, FullName: "$C", Holder: "C"} + iter.expressions = Expressions{argA, argB, next} + actual := sanitizeContent(iter, parent, parent.FullName) + assert.Equal(t, "$Fn($criteria.AppendBinding($Unsafe.A), $criteria.AppendBinding($Unsafe.B))", actual) +} + +func TestSanitizeParameter(t *testing.T) { + t.Run("standalone fn entry is preserved", func(t *testing.T) { + name := 
"TestStandaloneSanitize" + keywords.Add(name, functions.NewEntry(nil, &keywords.StandaloneFn{})) + iter := newIterable(nil) + expr := &Expression{Holder: name, FullName: "$" + name + "(1)"} + assert.Equal(t, "$"+name+"(1)", sanitizeParameter(expr, "$"+name+"(1)", iter, nil, 0)) + }) + + t.Run("set marker prefix preserved", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "Value", Prefix: keywords.SetMarkerKey} + assert.Equal(t, "$Value", sanitizeParameter(expr, "$Value", iter, nil, 0)) + }) + + t.Run("namespace metadata preserved", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + Holder: "Any", + Entry: functions.NewEntry(nil, keywords.NewNamespace()), + } + assert.Equal(t, "$Any", sanitizeParameter(expr, "$Any", iter, nil, 0)) + }) + + t.Run("const parameter gets Unsafe prefix", func(t *testing.T) { + iter := newIterable(nil, inference.NewConstParameter("ConstX", 1)) + expr := &Expression{Holder: "ConstX"} + assert.Equal(t, "$Unsafe.ConstX", sanitizeParameter(expr, "$ConstX", iter, nil, 0)) + }) + + t.Run("func context with variable and Params prefix strips prefix", func(t *testing.T) { + iter := newIterable(map[string]bool{"X": true}) + expr := &Expression{Holder: "X", Prefix: keywords.ParamsKey, Context: FuncContext} + assert.Equal(t, "$X", sanitizeParameter(expr, "$Unsafe.X", iter, nil, 0)) + }) + + t.Run("func context with non variable and empty prefix adds Unsafe", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "X", Prefix: "", Context: FuncContext} + assert.Equal(t, "$Unsafe.X", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) + + t.Run("func context with variable and custom prefix keeps raw", func(t *testing.T) { + iter := newIterable(map[string]bool{"X": true}) + expr := &Expression{Holder: "X", Prefix: keywords.AndPrefix, Context: ForEachContext} + assert.Equal(t, "$X", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) + + t.Run("func context with non variable and non 
empty prefix keeps raw", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "X", Prefix: keywords.OrPrefix, Context: SetContext} + assert.Equal(t, "$X", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) + + t.Run("func context with expression entry preserves raw", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "X", Context: IfContext, Entry: functions.NewEntry(nil, nil)} + assert.Equal(t, "$X", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) + + t.Run("append context variable with Params prefix strips prefix", func(t *testing.T) { + iter := newIterable(map[string]bool{"X": true}) + expr := &Expression{Holder: "X", Prefix: keywords.ParamsKey} + assert.Equal(t, "$X", sanitizeParameter(expr, "$Unsafe.X", iter, nil, 0)) + }) + + t.Run("append context variable placeholder", func(t *testing.T) { + iter := newIterable(map[string]bool{"X": true}) + expr := &Expression{Holder: "X"} + assert.Equal(t, "$criteria.AppendBinding($X)", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) + + t.Run("append context params prefix preserved", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "X", Prefix: keywords.ParamsKey} + assert.Equal(t, "$Unsafe.X", sanitizeParameter(expr, "$Unsafe.X", iter, nil, 0)) + }) + + t.Run("context metadata unexpand raw preserved", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + Holder: "Ctx", + Entry: functions.NewEntry(nil, keywords.NewContextMetadata("ctx", nil, true)), + } + assert.Equal(t, "$Ctx", sanitizeParameter(expr, "$Ctx", iter, nil, 0)) + }) + + t.Run("context metadata expandable becomes placeholder", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + Holder: "Ctx", + Entry: functions.NewEntry(nil, keywords.NewContextMetadata("ctx", nil, false)), + } + assert.Equal(t, "$criteria.AppendBinding($Ctx)", sanitizeParameter(expr, "$Ctx", iter, nil, 0)) + }) + + t.Run("non context metadata entry becomes placeholder", 
func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + Holder: "Ctx", + Entry: functions.NewEntry(nil, struct{}{}), + } + assert.Equal(t, "$criteria.AppendBinding($Ctx)", sanitizeParameter(expr, "$Ctx", iter, nil, 0)) + }) + + t.Run("default path adds Unsafe and placeholder", func(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{Holder: "X"} + assert.Equal(t, "$criteria.AppendBinding($Unsafe.X)", sanitizeParameter(expr, "$X", iter, nil, 0)) + }) +} + +func TestSanitizeAsPlaceholder(t *testing.T) { + assert.Equal(t, "$criteria.AppendBinding($X)", sanitizeAsPlaceholder("$X")) +} + +func TestSanitize_WithBracketsWrapping(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + FullName: "${X}", + Holder: "X", + Start: 0, + End: 4, + } + dst := []byte("${X}") + actual, _ := sanitize(iter, expr, dst, 0, 0) + assert.Equal(t, "${criteria.AppendBinding($Unsafe.X)}", string(actual)) +} + +func TestSanitize_NoChangePathAndCursorOffset(t *testing.T) { + iter := newIterable(nil) + expr := &Expression{ + FullName: "$Unsafe.X", + Holder: "X", + Prefix: keywords.ParamsKey, + Start: 8, + End: 17, + } + dst := []byte("SELECT " + expr.FullName) + actual, offset := sanitize(iter, expr, dst, 0, 7) + assert.Equal(t, "SELECT $Unsafe.X", string(actual)) + assert.Equal(t, 0, offset) +} + +func newIterable(declared map[string]bool, params ...*inference.Parameter) *iterables { + if declared == nil { + declared = map[string]bool{} + } + state := inference.State{} + for _, param := range params { + if param != nil { + state.Append(param) + } + } + tmpl := &Template{ + Declared: declared, + State: &state, + } + return &iterables{expressionMatcher: &expressionMatcher{Template: tmpl}} +} diff --git a/service/executor/expand/evaluator_test.go b/service/executor/expand/evaluator_test.go new file mode 100644 index 00000000..b6e0a7e2 --- /dev/null +++ b/service/executor/expand/evaluator_test.go @@ -0,0 +1,78 @@ +package expand_test + +import ( + "testing" 
+ + "github.com/viant/datly/service/executor/expand" +) + +func TestNewEvaluator_DefaultTypeLookup(t *testing.T) { + evaluator, err := expand.NewEvaluator(`#set($x = $New("int"))$x`) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if _, err := evaluator.Evaluate(nil); err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestNewEvaluator_WithNilTypeLookupOption(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("expected no panic, got %v", r) + } + }() + + evaluator, err := expand.NewEvaluator(`#set($x = $New("int"))$x`, expand.WithTypeLookup(nil)) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if _, err := evaluator.Evaluate(nil); err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestNewEvaluator_UnknownTypeReturnsError(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("expected no panic, got %v", r) + } + }() + + _, err := expand.NewEvaluator(`#set($x = $New("DefinitelyNotAType"))$x`, expand.WithTypeLookup(nil)) + if err == nil { + t.Fatalf("expected error") + } +} + +func TestNewEvaluator_WithNilNamedVariableOption(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("expected no panic, got %v", r) + } + }() + + evaluator, err := expand.NewEvaluator(`ok`, expand.WithVariable(nil)) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if _, err := evaluator.Evaluate(nil); err != nil { + t.Fatalf("expected no error, got %v", err) + } +} + +func TestNewEvaluator_WithNilCustomContextOption(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("expected no panic, got %v", r) + } + }() + + evaluator, err := expand.NewEvaluator(`ok`, expand.WithCustomContexts(nil)) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if _, err := evaluator.Evaluate(nil); err != nil { + t.Fatalf("expected no error, got %v", err) + } +} diff --git 
a/service/session/selector_injector_test.go b/service/session/selector_injector_test.go new file mode 100644 index 00000000..7f8275fb --- /dev/null +++ b/service/session/selector_injector_test.go @@ -0,0 +1,89 @@ +package session + +import ( + "context" + "net/http" + "reflect" + "testing" + + "github.com/viant/datly/repository" + "github.com/viant/datly/view" + vstate "github.com/viant/datly/view/state" + hstate "github.com/viant/xdatly/handler/state" +) + +func TestSessionBind_QuerySelectorOverride_PageComputesOffset(t *testing.T) { + ctx := context.Background() + + resource := view.NewResource(nil) + trueValue := true + aView := &view.View{ + Name: "v", + Mode: view.ModeQuery, + Selector: func() *view.Config { + cfg := view.QueryStateParameters.Clone() + cfg.Limit = 5 + cfg.Constraints = &view.Constraints{ + Criteria: true, + OrderBy: true, + Limit: true, + Offset: true, + Projection: true, + Page: &trueValue, + } + return cfg + }(), + } + aView.SetResource(resource) + aView.Template = &view.Template{Schema: vstate.NewSchema(reflect.TypeOf(struct{ Dummy int }{}))} + if err := aView.Template.Init(ctx, resource, aView); err != nil { + t.Fatalf("failed to init template: %v", err) + } + if err := aView.Selector.Init(ctx, resource, aView); err != nil { + t.Fatalf("failed to init selector: %v", err) + } + + component := &repository.Component{View: aView} + outputType, err := vstate.NewType( + vstate.WithSchema(vstate.NewSchema(reflect.TypeOf(struct{ X int }{}))), + vstate.WithResource(aView.Resource()), + ) + if err != nil { + t.Fatalf("failed to build component output type: %v", err) + } + component.Output.Type = *outputType + + sess := New(aView, WithComponent(component)) + var dest struct{} + + // request supplies different selector values; injected selector should take precedence + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1/?_page=1&_limit=1", nil) + if err != nil { + t.Fatalf("failed to build request: %v", err) + } + + err = sess.Bind(ctx, 
&dest, hstate.WithQuerySelector(&hstate.NamedQuerySelector{ + Name: "v", + QuerySelector: hstate.QuerySelector{ + Page: 2, + }, + }), hstate.WithHttpRequest(req)) + if err != nil { + t.Fatalf("Bind() error: %v", err) + } + + if err := sess.SetViewState(ctx, aView); err != nil { + t.Fatalf("SetViewState() error: %v", err) + } + + selector := sess.State().Lookup(aView) + if selector.Page != 2 { + t.Fatalf("expected Page=2, got %d", selector.Page) + } + if selector.Limit != 5 { + t.Fatalf("expected Limit=5, got %d", selector.Limit) + } + if selector.Offset != 5 { + t.Fatalf("expected Offset=5, got %d", selector.Offset) + } +} From f16184b1403ce9f088c49aae53b7b1bb3af52538 Mon Sep 17 00:00:00 2001 From: adranwit Date: Tue, 24 Feb 2026 07:27:18 -0800 Subject: [PATCH 5/6] shape/compile: add type support helpers; refine preprocessing and type defaults; update tests --- go.mod | 12 +- go.sum | 6 + repository/shape/compile/compiler.go | 46 +-- repository/shape/compile/compiler_test.go | 74 ++--- repository/shape/compile/component_types.go | 75 ++++- .../shape/compile/component_types_test.go | 49 +++ repository/shape/compile/enrich.go | 301 ++++++------------ .../shape/compile/preprocess_handler.go | 39 +-- .../shape/compile/preprocess_handler_test.go | 52 +-- repository/shape/compile/strings_util.go | 13 + repository/shape/compile/type_support.go | 238 ++++++++++++++ repository/shape/compile/type_support_test.go | 70 ++++ repository/shape/compile/typectx_defaults.go | 102 +++++- .../shape/compile/typectx_defaults_test.go | 16 + repository/shape/platform_parity_test.go | 3 + warmup/cache_test.go | 5 + 16 files changed, 706 insertions(+), 395 deletions(-) create mode 100644 repository/shape/compile/strings_util.go create mode 100644 repository/shape/compile/type_support.go create mode 100644 repository/shape/compile/type_support_test.go diff --git a/go.mod b/go.mod index baaae6b4..60d51b51 100644 --- a/go.mod +++ b/go.mod @@ -2,12 +2,6 @@ module github.com/viant/datly go 
1.25.0 -replace github.com/viant/velty => ../velty - -replace github.com/viant/x => ../x - -replace github.com/viant/sqlparser => ../sqlparser - require ( github.com/aerospike/aerospike-client-go v4.5.2+incompatible github.com/aws/aws-lambda-go v1.31.0 @@ -37,7 +31,7 @@ require ( github.com/viant/sqlx v0.21.0 github.com/viant/structql v0.5.4 github.com/viant/toolbox v0.37.0 - github.com/viant/velty v0.2.1-0.20230927172116-ba56497b5c85 + github.com/viant/velty v0.4.0 github.com/viant/xreflect v0.7.3 github.com/viant/xunsafe v0.10.3 golang.org/x/mod v0.28.0 @@ -48,7 +42,7 @@ require ( require ( github.com/viant/govalidator v0.3.1 - github.com/viant/sqlparser v0.9.0 + github.com/viant/sqlparser v0.11.0 ) require ( @@ -59,7 +53,7 @@ require ( github.com/viant/mcp-protocol v0.9.0 github.com/viant/structology v0.8.0 github.com/viant/tagly v0.3.0 - github.com/viant/x v0.3.0 + github.com/viant/x v0.4.0 github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a github.com/viant/xdatly/extension v0.0.0-20231013204918-ecf3c2edf259 github.com/viant/xdatly/handler v0.0.0-20251208172928-dd34b7f09fd5 diff --git a/go.sum b/go.sum index d1e092d5..5164a250 100644 --- a/go.sum +++ b/go.sum @@ -1194,6 +1194,8 @@ github.com/viant/pgo v0.11.0 h1:PNuYVhwTfyrAHGBO6lxaMFuHP4NkjKV8ULecz3OWk8c= github.com/viant/pgo v0.11.0/go.mod h1:MFzHmkRFZlciugEgUvpl/3grK789PBSH4dUVSLOSo+Q= github.com/viant/scy v0.24.0 h1:KAC3IUARkQxTNSuwBK2YhVBJMOOLN30YaLKHbbuSkMU= github.com/viant/scy v0.24.0/go.mod h1:7uNRS67X45YN+JqTLCcMEhehffVjqrejULEDln9p0Ao= +github.com/viant/sqlparser v0.11.0 h1:RVmAsEieZlnRO33DWWvDXJOTY+sXJGTymPaC1iWnkOc= +github.com/viant/sqlparser v0.11.0/go.mod h1:2QRGiGZYk2/pjhORGG1zLVQ9JO+bXFhqIVi31mkCRPg= github.com/viant/sqlx v0.21.0 h1:Lx5KXmzfSjSvZZX5P0Ua9kFGvAmCxAjLOPe9pQA7VmY= github.com/viant/sqlx v0.21.0/go.mod h1:woTOwNiqvt6SqkI+5nyzlixcRTTV0IvLZUTberqb8mo= github.com/viant/structology v0.8.0 h1:WKdK67l+O1eqsubn8PWMhWcgspUGJ22SgJxUMfiRgqE= @@ -1206,6 +1208,10 @@ 
github.com/viant/toolbox v0.24.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMI github.com/viant/toolbox v0.34.5/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= github.com/viant/toolbox v0.37.0 h1:+zwSdbQh6I6ZEyxokQJr+1gQKbLEw6erc+Av5dwKtLU= github.com/viant/toolbox v0.37.0/go.mod h1:OxMCG57V0PXuIP2HNQrtJf2CjqdmbrOx5EkMILuUhzM= +github.com/viant/velty v0.4.0 h1:eesQES/vCpcoPbM+gQLUBuLEL2sEO+A6s6lPpl8eKc4= +github.com/viant/velty v0.4.0/go.mod h1:Q/UXviI2Nli8WROEpYd/BELMCSvnulQeyNrbPmMiS/Y= +github.com/viant/x v0.4.0 h1:n2xuxQdw4lYtMdi59IAQEZHPioNT9InENGGbapyz+P4= +github.com/viant/x v0.4.0/go.mod h1:1TvsnpZFqI9dYVzIkaSYJyJ/UkfxW7fnk0YFafWXrPg= github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a h1:7CLO2LjVnFgOwN0FL3Q4y5NrD7DpclS21AiW6tDLIc8= github.com/viant/xdatly v0.5.4-0.20251113181159-0ac8b8b0ff3a/go.mod h1:lZKZHhVdCZ3U9TU6GUFxKoGN3dPtqt2HkDYzJPq5CEs= github.com/viant/xdatly/extension v0.0.0-20231013204918-ecf3c2edf259 h1:9Yry3PUBDzc4rWacOYvAq/TKrTV0agvMF0gwm2gaoHI= diff --git a/repository/shape/compile/compiler.go b/repository/shape/compile/compiler.go index db57701a..d283d7c5 100644 --- a/repository/shape/compile/compiler.go +++ b/repository/shape/compile/compiler.go @@ -66,12 +66,7 @@ func (c *DQLCompiler) Compile(_ context.Context, source *shape.Source, opts ...s pre = prepared.Pre statements = prepared.Statements decision = prepared.Decision - legacyFallbackViews := prepared.LegacyViews - effectiveSource := source - if prepared.EffectiveSource != nil { - effectiveSource = prepared.EffectiveSource - } - if strings.TrimSpace(pre.SQL) == "" && len(legacyFallbackViews) == 0 { + if strings.TrimSpace(pre.SQL) == "" { allDiags = append(allDiags, &dqlshape.Diagnostic{ Code: dqldiag.CodeParseEmpty, Severity: dqlshape.SeverityError, @@ -87,11 +82,7 @@ func (c *DQLCompiler) Compile(_ context.Context, source *shape.Source, opts ...s var root *plan.View var compileDiags []*dqlshape.Diagnostic var err error - if len(legacyFallbackViews) > 0 { - root = 
legacyFallbackViews[0] - } else { - root, compileDiags, err = c.compileRoot(source.Name, pre.SQL, statements, decision, compileOptions.MixedMode, compileOptions.UnknownNonReadMode) - } + root, compileDiags, err = c.compileRoot(source.Name, pre.SQL, statements, decision, compileOptions.MixedMode, compileOptions.UnknownNonReadMode) if err != nil { return nil, err } @@ -102,18 +93,6 @@ func (c *DQLCompiler) Compile(_ context.Context, source *shape.Source, opts ...s } result := newPlanResult(root) - if len(legacyFallbackViews) > 1 { - for _, item := range legacyFallbackViews[1:] { - if item == nil || strings.TrimSpace(item.Name) == "" { - continue - } - if _, exists := result.ViewsByName[item.Name]; exists { - continue - } - result.Views = append(result.Views, item) - result.ViewsByName[item.Name] = item - } - } result.Diagnostics = allDiags result.TypeContext = pre.TypeCtx result.Directives = pre.Directives @@ -122,26 +101,11 @@ func (c *DQLCompiler) Compile(_ context.Context, source *shape.Source, opts ...s appendRelationViews(result, root, hints) appendDeclaredViews(source.DQL, result) appendDeclaredStates(source.DQL, result) - if prepared.ForceLegacyContract && len(legacyFallbackViews) > 0 { - if legacyStates := resolveLegacyRouteStatesWithLayout(effectiveSource, pathLayout); len(legacyStates) > 0 { - result.States = legacyStates - } - if legacyTypes := resolveLegacyRouteTypesWithLayout(effectiveSource, pathLayout); len(legacyTypes) > 0 { - result.Types = legacyTypes - } - } - result.Diagnostics = append(result.Diagnostics, appendComponentTypesWithLayout(effectiveSource, result, pathLayout)...) 
- mergeLegacyRouteStatesWithLayout(result, effectiveSource, pathLayout) - mergeLegacyRouteTypesWithLayout(result, effectiveSource, pathLayout) + _ = prepared applyViewHints(result, hints) - applySourceParityEnrichmentWithLayout(result, effectiveSource, pathLayout) + applySourceParityEnrichmentWithLayout(result, source, pathLayout) + applyLinkedTypeSupport(result, source) result.Diagnostics = append(result.Diagnostics, applyColumnDiscoveryPolicy(result, compileOptions)...) - if len(result.States) == 0 && len(legacyFallbackViews) > 0 { - result.States = resolveLegacyRouteStatesWithLayout(effectiveSource, pathLayout) - } - if len(result.Types) == 0 && len(legacyFallbackViews) > 0 { - result.Types = resolveLegacyRouteTypesWithLayout(effectiveSource, pathLayout) - } if enforceStrict && hasEscalationWarnings(result.Diagnostics) { return nil, &CompileError{Diagnostics: filterEscalationDiagnostics(result.Diagnostics)} diff --git a/repository/shape/compile/compiler_test.go b/repository/shape/compile/compiler_test.go index 63156250..f1448785 100644 --- a/repository/shape/compile/compiler_test.go +++ b/repository/shape/compile/compiler_test.go @@ -4,7 +4,6 @@ import ( "context" "os" "path/filepath" - "strings" "testing" "github.com/stretchr/testify/assert" @@ -60,8 +59,8 @@ SELECT id func TestDQLCompiler_Compile_PropagatesTypeContext(t *testing.T) { compiler := New() dql := ` -#settings($_ = $package('mdp/performance')) -#settings($_ = $import('perf', 'github.com/acme/mdp/performance')) +#package('mdp/performance') +#import('perf', 'github.com/acme/mdp/performance') SELECT id FROM ORDERS t` res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}) require.NoError(t, err) @@ -75,6 +74,26 @@ SELECT id FROM ORDERS t` assert.Equal(t, "perf", planned.TypeContext.Imports[0].Alias) } +func TestDQLCompiler_Compile_PropagatesImportedTypeContextWithModuleNormalization(t *testing.T) { + compiler := New() + projectDir := t.TempDir() + err := 
os.WriteFile(filepath.Join(projectDir, "go.mod"), []byte("module github.vianttech.com/viant/platform\n\ngo 1.23\n"), 0o644) + require.NoError(t, err) + source := &shape.Source{ + Name: "orders_report", + Path: filepath.Join(projectDir, "dql", "platform", "taxonomy", "get.dql"), + DQL: "#import('session','pkg/platform/system/session')\nSELECT id FROM ORDERS t", + } + res, err := compiler.Compile(context.Background(), source) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotNil(t, planned.TypeContext) + require.Len(t, planned.TypeContext.Imports, 1) + assert.Equal(t, "session", planned.TypeContext.Imports[0].Alias) + assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/system/session", planned.TypeContext.Imports[0].Package) +} + func TestDQLCompiler_Compile_PropagatesSpecialDirectives(t *testing.T) { compiler := New() dql := ` @@ -127,7 +146,7 @@ func TestDQLCompiler_Compile_ColumnDiscoveryOffFailsWhenRequired(t *testing.T) { func TestDQLCompiler_Compile_TypeContextValidationWarnsInCompat(t *testing.T) { compiler := New() dql := ` -#settings($_ = $package('github.com/acme/perf')) +#package('github.com/acme/perf') SELECT id FROM ORDERS t` res, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: dql}, shape.WithTypeContextPackageName("bad/name")) require.NoError(t, err) @@ -196,7 +215,7 @@ func TestDQLCompiler_Compile_SyntaxError_RemapsAfterSanitize(t *testing.T) { func TestDQLCompiler_Compile_DirectiveOnly_HasLineAndChar(t *testing.T) { compiler := New() - _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "#settings($_ = $package('x'))"}) + _, err := compiler.Compile(context.Background(), &shape.Source{Name: "orders_report", DQL: "#package('x')"}) require.Error(t, err) compileErr, ok := err.(*CompileError) require.True(t, ok) @@ -211,7 +230,7 @@ func TestDQLCompiler_Compile_InvalidDirective_HasLineAndChar(t *testing.T) { 
compiler := New() _, err := compiler.Compile(context.Background(), &shape.Source{ Name: "orders_report", - DQL: "SELECT id FROM ORDERS t\n#settings($_ = $import('alias'))\nSELECT id FROM ORDERS t", + DQL: "SELECT id FROM ORDERS t\n#import('alias')\nSELECT id FROM ORDERS t", }) require.Error(t, err) compileErr, ok := err.(*CompileError) @@ -418,7 +437,7 @@ func TestDQLCompiler_Compile_DMLSyntaxError_HasLineAndChar(t *testing.T) { compiler := New() _, err := compiler.Compile(context.Background(), &shape.Source{ Name: "orders_exec", - DQL: "#settings($_ = $package('x'))\nINSERT INTO ORDERS(id VALUES (1)", + DQL: "#package('x')\nINSERT INTO ORDERS(id VALUES (1)", }) require.Error(t, err) compileErr, ok := err.(*CompileError) @@ -661,7 +680,7 @@ JOIN (SELECT * FROM session/attributes) attribute ON attribute.user_id = session assert.Equal(t, "system", related.Connector) } -func TestDQLCompiler_Compile_GeneratedHandler_NoBodyInput_UsesLegacyContractStates(t *testing.T) { +func TestDQLCompiler_Compile_GeneratedHandler_NoBodyInput_DoesNotLoadLegacyContractStates(t *testing.T) { tempDir := t.TempDir() genPath := filepath.Join(tempDir, "dql", "system", "upload", "gen", "upload", "delete.dql") require.NoError(t, os.MkdirAll(filepath.Dir(genPath), 0o755)) @@ -705,21 +724,10 @@ func TestDQLCompiler_Compile_GeneratedHandler_NoBodyInput_UsesLegacyContractStat assert.Equal(t, "SQLExec", planned.Views[0].Mode) assert.Equal(t, "system", planned.Views[0].Connector) - stateByName := map[string]*plan.State{} - for _, item := range planned.States { - if item == nil { - continue - } - stateByName[item.Name] = item - } - require.Contains(t, stateByName, "Method") - require.Contains(t, stateByName, "UploadId") - assert.Equal(t, "http_request", stateByName["Method"].Kind) - assert.Equal(t, "query", stateByName["UploadId"].Kind) - assert.NotContains(t, stateByName, "Body") + assert.Empty(t, planned.States) } -func 
TestDQLCompiler_Compile_HandlerLegacyTypes_PreferredOverComponentNameCollisions(t *testing.T) { +func TestDQLCompiler_Compile_HandlerLegacyTypes_NotLoadedFromLegacyRouteYAML(t *testing.T) { tempDir := t.TempDir() sourcePath := filepath.Join(tempDir, "dql", "platform", "campaign", "post.dql") require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) @@ -776,26 +784,10 @@ func TestDQLCompiler_Compile_HandlerLegacyTypes_PreferredOverComponentNameCollis planned, ok := res.Plan.(*plan.Result) require.True(t, ok) - typeByName := map[string]*plan.Type{} - for _, item := range planned.Types { - if item == nil || strings.TrimSpace(item.Name) == "" { - continue - } - typeByName[strings.ToLower(strings.TrimSpace(item.Name))] = item - } - - inputType, ok := typeByName["input"] - require.True(t, ok) - assert.Equal(t, "campaign/patch", inputType.Package) - assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/campaign/patch", inputType.ModulePath) - - handlerType, ok := typeByName["handler"] - require.True(t, ok) - assert.Equal(t, "campaign/patch", handlerType.Package) - assert.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/campaign/patch", handlerType.ModulePath) + assert.Empty(t, planned.Types) } -func TestDQLCompiler_Compile_CustomPathLayout_HandlerFallback(t *testing.T) { +func TestDQLCompiler_Compile_CustomPathLayout_NoLegacyHandlerFallback(t *testing.T) { tempDir := t.TempDir() sourcePath := filepath.Join(tempDir, "sqlsrc", "platform", "campaign", "post.dql") require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) @@ -825,7 +817,5 @@ func TestDQLCompiler_Compile_CustomPathLayout_HandlerFallback(t *testing.T) { require.True(t, ok) require.NotEmpty(t, planned.Views) assert.Equal(t, "post", planned.Views[0].Name) - assert.Equal(t, "SQLExec", planned.Views[0].Mode) - assert.Equal(t, "ci_ads", planned.Views[0].Connector) - assert.Contains(t, planned.Views[0].SQL, "$Nop(") + assert.NotContains(t, planned.Views[0].SQL, "$Nop(") } diff --git 
a/repository/shape/compile/component_types.go b/repository/shape/compile/component_types.go index 5c553bc9..2cd3d705 100644 --- a/repository/shape/compile/component_types.go +++ b/repository/shape/compile/component_types.go @@ -39,6 +39,8 @@ func appendComponentTypesWithLayout(source *shape.Source, result *plan.Result, l visited: map[string]componentVisitState{}, outputByRoute: map[string]string{}, typesByName: map[string]*plan.Type{}, + payloadCache: map[string]routePayloadLookup{}, + reportedDiag: map[string]bool{}, } if strings.TrimSpace(sourceNamespace) != "" { collector.collect(sourceNamespace, relationSpan(source.DQL, 0), false) @@ -109,9 +111,19 @@ type componentCollector struct { visited map[string]componentVisitState outputByRoute map[string]string typesByName map[string]*plan.Type + payloadCache map[string]routePayloadLookup + reportedDiag map[string]bool diags []*dqlshape.Diagnostic } +type routePayloadLookup struct { + payload *routePayload + found bool + malformed bool + malformedAt string + detail string +} + func (c *componentCollector) collect(namespace string, span dqlshape.Span, required bool) (string, bool) { key := strings.ToLower(strings.TrimSpace(namespace)) if key == "" { @@ -132,10 +144,11 @@ func (c *componentCollector) collect(namespace string, span dqlshape.Span, requi } c.visited[key] = componentVisitActive - payload, ok := loadRoutePayload(c.routesRoot, namespace) + payload, ok := c.loadRoutePayload(namespace, span) if !ok { c.visited[key] = componentVisitDone - if required { + if required && !c.hasReported("missing:"+key) { + c.reportedDiag["missing:"+key] = true c.diags = append(c.diags, &dqlshape.Diagnostic{ Code: dqldiag.CodeCompRouteMissing, Severity: dqlshape.SeverityWarning, @@ -347,7 +360,13 @@ type routePayload struct { } func loadRoutePayload(routesRoot, namespace string) (*routePayload, bool) { + lookup := readRoutePayload(routesRoot, namespace) + return lookup.payload, lookup.found +} + +func readRoutePayload(routesRoot, 
namespace string) routePayloadLookup { candidates := routeYAMLCandidates(routesRoot, namespace) + lookup := routePayloadLookup{} for _, candidate := range candidates { data, err := os.ReadFile(candidate) if err != nil { @@ -355,11 +374,59 @@ func loadRoutePayload(routesRoot, namespace string) (*routePayload, bool) { } payload := &routePayload{} if err = yaml.Unmarshal(data, payload); err != nil { + if !lookup.malformed { + lookup.malformed = true + lookup.malformedAt = candidate + lookup.detail = strings.TrimSpace(err.Error()) + } continue } - return payload, true + lookup.payload = payload + lookup.found = true + lookup.malformed = false + lookup.malformedAt = "" + lookup.detail = "" + return lookup + } + return lookup +} + +func (c *componentCollector) loadRoutePayload(namespace string, span dqlshape.Span) (*routePayload, bool) { + key := strings.ToLower(strings.TrimSpace(namespace)) + if key == "" { + return nil, false + } + lookup, ok := c.payloadCache[key] + if !ok { + lookup = readRoutePayload(c.routesRoot, namespace) + c.payloadCache[key] = lookup + } + if lookup.malformed && !lookup.found && !c.hasReported("invalid:"+key) { + c.reportedDiag["invalid:"+key] = true + message := "component route YAML malformed: " + namespace + if strings.TrimSpace(lookup.malformedAt) != "" { + message += " (" + lookup.malformedAt + ")" + } + hint := "fix route YAML format" + if strings.TrimSpace(lookup.detail) != "" { + hint += ": " + lookup.detail + } + c.diags = append(c.diags, &dqlshape.Diagnostic{ + Code: dqldiag.CodeCompRouteInvalid, + Severity: dqlshape.SeverityWarning, + Message: message, + Hint: hint, + Span: span, + }) + } + return lookup.payload, lookup.found +} + +func (c *componentCollector) hasReported(key string) bool { + if c == nil || c.reportedDiag == nil { + return false } - return nil, false + return c.reportedDiag[key] } func routeOutputType(payload *routePayload) string { diff --git a/repository/shape/compile/component_types_test.go 
b/repository/shape/compile/component_types_test.go index 0e93ec71..51570a12 100644 --- a/repository/shape/compile/component_types_test.go +++ b/repository/shape/compile/component_types_test.go @@ -153,3 +153,52 @@ func TestAppendComponentTypes_TypeCollisionEmitsDiagnostic(t *testing.T) { require.Len(t, result.Types, 1) assert.Equal(t, "campaign/patch", result.Types[0].Package) } + +func TestAppendComponentTypes_InvalidRouteYAMLEmitsDiagnostic(t *testing.T) { + temp := t.TempDir() + dqlDir := filepath.Join(temp, "dql", "platform", "sample") + routesDir := filepath.Join(temp, "repo", "dev", "Datly", "routes", "platform", "acl") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "auth"), 0o755)) + + sourcePath := filepath.Join(dqlDir, "sample.dql") + dql := "#set($Auth = $component<../acl/auth>())\nSELECT 1" + require.NoError(t, os.WriteFile(sourcePath, []byte(dql), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(routesDir, "auth", "auth.yaml"), []byte("Resource:\n Types: ["), 0o644)) + + result := &plan.Result{ + States: []*plan.State{{Name: "Auth", Kind: "component", In: "../acl/auth"}}, + } + diags := appendComponentTypes(&shape.Source{Path: sourcePath, DQL: dql}, result) + require.NotEmpty(t, diags) + assert.Equal(t, dqldiag.CodeCompRouteInvalid, diags[0].Code) +} + +func TestAppendComponentTypes_InvalidRouteYAMLDedupedForRepeatedStates(t *testing.T) { + temp := t.TempDir() + dqlDir := filepath.Join(temp, "dql", "platform", "sample") + routesDir := filepath.Join(temp, "repo", "dev", "Datly", "routes", "platform", "acl") + require.NoError(t, os.MkdirAll(dqlDir, 0o755)) + require.NoError(t, os.MkdirAll(filepath.Join(routesDir, "auth"), 0o755)) + + sourcePath := filepath.Join(dqlDir, "sample.dql") + dql := "#set($Auth1 = $component<../acl/auth>())\n#set($Auth2 = $component<../acl/auth>())\nSELECT 1" + require.NoError(t, os.WriteFile(sourcePath, []byte(dql), 0o644)) + require.NoError(t, 
os.WriteFile(filepath.Join(routesDir, "auth", "auth.yaml"), []byte("Resource:\n Types: ["), 0o644)) + + result := &plan.Result{ + States: []*plan.State{ + {Name: "Auth1", Kind: "component", In: "../acl/auth"}, + {Name: "Auth2", Kind: "component", In: "../acl/auth"}, + }, + } + diags := appendComponentTypes(&shape.Source{Path: sourcePath, DQL: dql}, result) + require.NotEmpty(t, diags) + invalidCount := 0 + for _, item := range diags { + if item != nil && item.Code == dqldiag.CodeCompRouteInvalid { + invalidCount++ + } + } + assert.Equal(t, 1, invalidCount) +} diff --git a/repository/shape/compile/enrich.go b/repository/shape/compile/enrich.go index 08b80570..04806404 100644 --- a/repository/shape/compile/enrich.go +++ b/repository/shape/compile/enrich.go @@ -10,7 +10,6 @@ import ( "github.com/viant/datly/repository/shape" "github.com/viant/datly/repository/shape/compile/pipeline" "github.com/viant/datly/repository/shape/plan" - "gopkg.in/yaml.v3" ) var ( @@ -30,6 +29,16 @@ type ruleSettings struct { URI string `json:"URI"` } +type parityEnrichmentContext struct { + source *shape.Source + settings *ruleSettings + baseDir string + module string + sourceName string + joinEmbedRefs map[string]string + joinSubqueryBodies map[string]string +} + func applySourceParityEnrichment(result *plan.Result, source *shape.Source) { applySourceParityEnrichmentWithLayout(result, source, defaultCompilePathLayout()) } @@ -38,233 +47,114 @@ func applySourceParityEnrichmentWithLayout(result *plan.Result, source *shape.So if result == nil || len(result.Views) == 0 { return } - settings := extractRuleSettings(source) - legacyViews := loadLegacyRouteViewAttrsWithLayout(source, settings, layout) - baseDir := sourceSQLBaseDir(source) - module := sourceModuleWithLayout(source, layout) - sourceName := pipeline.SanitizeName(source.Name) - joinEmbedRefs := map[string]string{} - joinSubqueryBodies := map[string]string{} - if len(result.Views) > 0 && result.Views[0] != nil { - sqlForJoinExtract := 
result.Views[0].SQL - if source != nil && strings.TrimSpace(source.DQL) != "" { - sqlForJoinExtract = source.DQL - } - joinEmbedRefs = extractJoinEmbedRefs(sqlForJoinExtract) - joinSubqueryBodies = extractJoinSubqueryBodies(sqlForJoinExtract) - } + ctx := buildParityEnrichmentContext(result, source, layout) for idx, item := range result.Views { if item == nil { continue } - if legacy, ok := lookupLegacyRouteViewAttr(legacyViews, item.Name); ok { - if legacy.Mode != "" { - item.Mode = legacy.Mode - } - if legacy.Module != "" { - item.Module = legacy.Module - } - if legacy.AllowNulls != nil { - value := *legacy.AllowNulls - item.AllowNulls = &value - } - if legacy.SelectorNamespace != "" { - item.SelectorNamespace = legacy.SelectorNamespace - } - if legacy.SelectorNoLimit != nil { - value := *legacy.SelectorNoLimit - item.SelectorNoLimit = &value - } - if legacy.SchemaType != "" { - item.SchemaType = legacy.SchemaType - } - if legacy.Cardinality != "" { - item.Cardinality = legacy.Cardinality - } - if legacy.HasSummary != nil && *legacy.HasSummary && strings.TrimSpace(item.Summary) == "" { - item.Summary = "legacy-summary" - } - } - if item.SQLURI == "" && baseDir != "" { - item.SQLURI = baseDir + "/" + item.Name + ".sql" - } - if item.Module == "" { - item.Module = module - } - if item.SelectorNamespace == "" { - item.SelectorNamespace = defaultSelectorNamespace(item.Name) - } - if item.SchemaType == "" { - item.SchemaType = defaultSchemaType(item.Name, settings, idx == 0) - } - if shouldInferTable(item) { - candidateSQL := item.SQL - if strings.TrimSpace(candidateSQL) == "" { - candidateSQL = item.Table - } - if table := inferTableFromSQL(candidateSQL, source); table != "" { - item.Table = table - } - } - if strings.HasPrefix(strings.TrimSpace(item.Table), "(") || normalizedTemplatePlaceholderTable(strings.TrimSpace(item.Table)) { - if ref, ok := joinEmbedRefs[item.Name]; ok { - if table := inferTableFromEmbedRef(source, ref); table != "" { - item.Table = table - } 
- } - if body, ok := joinSubqueryBodies[item.Name]; ok { - if table := inferTableFromSQL(body, source); table != "" { - item.Table = table - } - } - if table := inferTableFromSiblingSQL(item.Name, source); table != "" { - item.Table = table - } - } - if item.Connector == "" && settings.Connector != "" { - item.Connector = settings.Connector - } - if item.Connector == "" && source != nil && strings.TrimSpace(source.Connector) != "" { - item.Connector = strings.TrimSpace(source.Connector) - } - if item.Connector == "" { - item.Connector = inferConnector(item, source) - } - if item.Summary == "" { - item.Summary = extractSummarySQL(item.SQL) - if item.Summary == "" && source != nil { - item.Summary = extractSummarySQL(source.DQL) - } - } + applyViewDefaults(item, idx == 0, ctx) + applyTableInference(item, ctx) + applyConnectorInference(item, ctx) + applySummaryInference(item, ctx) } if source != nil && strings.TrimSpace(source.Path) != "" { - normalizeRootViewName(result, sourceName, settings) + normalizeRootViewName(result, ctx.sourceName) } } -type legacyRouteViewAttr struct { - Name string - Mode string - Module string - AllowNulls *bool - SelectorNamespace string - SelectorNoLimit *bool - SchemaType string - Cardinality string - HasSummary *bool +func buildParityEnrichmentContext(result *plan.Result, source *shape.Source, layout compilePathLayout) *parityEnrichmentContext { + ctx := &parityEnrichmentContext{ + source: source, + settings: extractRuleSettings(source), + baseDir: sourceSQLBaseDir(source), + module: sourceModuleWithLayout(source, layout), + sourceName: pipeline.SanitizeName(source.Name), + joinEmbedRefs: map[string]string{}, + joinSubqueryBodies: map[string]string{}, + } + if len(result.Views) == 0 || result.Views[0] == nil { + return ctx + } + sqlForJoinExtract := result.Views[0].SQL + if source != nil && strings.TrimSpace(source.DQL) != "" { + sqlForJoinExtract = source.DQL + } + ctx.joinEmbedRefs = extractJoinEmbedRefs(sqlForJoinExtract) + 
ctx.joinSubqueryBodies = extractJoinSubqueryBodies(sqlForJoinExtract) + return ctx } -func loadLegacyRouteViewAttrs(source *shape.Source, settings *ruleSettings) []legacyRouteViewAttr { - return loadLegacyRouteViewAttrsWithLayout(source, settings, defaultCompilePathLayout()) +func applyViewDefaults(item *plan.View, root bool, ctx *parityEnrichmentContext) { + if item == nil || ctx == nil { + return + } + if item.SQLURI == "" && ctx.baseDir != "" { + item.SQLURI = ctx.baseDir + "/" + item.Name + ".sql" + } + if item.Module == "" { + item.Module = ctx.module + } + if item.SelectorNamespace == "" { + item.SelectorNamespace = defaultSelectorNamespace(item.Name) + } + if item.SchemaType == "" { + item.SchemaType = defaultSchemaType(item.Name, ctx.settings, root) + } } -func loadLegacyRouteViewAttrsWithLayout(source *shape.Source, settings *ruleSettings, layout compilePathLayout) []legacyRouteViewAttr { - if source == nil || strings.TrimSpace(source.Path) == "" { - return nil - } - platformRoot, relativeDir, stem, ok := platformPathParts(source.Path, layout) - if !ok { - return nil - } - typeExpr := "" - if settings != nil { - typeExpr = strings.TrimSpace(settings.Type) - } - typeExpr = strings.Trim(typeExpr, `"'`) - typeExpr = strings.TrimSuffix(typeExpr, ".Handler") - typeStem := "" - if typeExpr != "" { - typeStem = filepath.Base(filepath.FromSlash(typeExpr)) - } - routesRoot := joinRelativePath(platformRoot, layout.routesRelative) - routesBase := filepath.Join(routesRoot, filepath.FromSlash(relativeDir)) - candidates := legacyRouteYAMLCandidates(routesBase, stem, typeStem) - for _, candidate := range candidates { - if attrs := parseLegacyRouteViewAttrs(candidate); len(attrs) > 0 { - return attrs +func applyTableInference(item *plan.View, ctx *parityEnrichmentContext) { + if item == nil || ctx == nil { + return + } + if shouldInferTable(item) { + candidateSQL := item.SQL + if strings.TrimSpace(candidateSQL) == "" { + candidateSQL = item.Table + } + if table := 
inferTableFromSQL(candidateSQL, ctx.source); table != "" { + item.Table = table } } - return nil -} - -func parseLegacyRouteViewAttrs(path string) []legacyRouteViewAttr { - data, err := os.ReadFile(path) - if err != nil { - return nil - } - var payload struct { - Resource struct { - Views []struct { - Name string `yaml:"Name"` - Mode string `yaml:"Mode"` - Module string `yaml:"Module"` - AllowNulls *bool `yaml:"AllowNulls"` - Selector struct { - Namespace string `yaml:"Namespace"` - NoLimit *bool `yaml:"NoLimit"` - } `yaml:"Selector"` - Template struct { - Summary *struct{} `yaml:"Summary"` - } `yaml:"Template"` - Schema struct { - Cardinality string `yaml:"Cardinality"` - DataType string `yaml:"DataType"` - Name string `yaml:"Name"` - } `yaml:"Schema"` - } `yaml:"Views"` - } `yaml:"Resource"` - } - if err = yaml.Unmarshal(data, &payload); err != nil { - return nil - } - result := make([]legacyRouteViewAttr, 0, len(payload.Resource.Views)) - for _, item := range payload.Resource.Views { - cardinality := strings.TrimSpace(item.Schema.Cardinality) - if cardinality != "" { - cardinality = strings.ToLower(cardinality) + if strings.HasPrefix(strings.TrimSpace(item.Table), "(") || normalizedTemplatePlaceholderTable(strings.TrimSpace(item.Table)) { + if ref, ok := ctx.joinEmbedRefs[item.Name]; ok { + if table := inferTableFromEmbedRef(ctx.source, ref); table != "" { + item.Table = table + } + } + if body, ok := ctx.joinSubqueryBodies[item.Name]; ok { + if table := inferTableFromSQL(body, ctx.source); table != "" { + item.Table = table + } + } + if table := inferTableFromSiblingSQL(item.Name, ctx.source); table != "" { + item.Table = table } - result = append(result, legacyRouteViewAttr{ - Name: strings.TrimSpace(item.Name), - Mode: strings.TrimSpace(item.Mode), - Module: strings.TrimSpace(item.Module), - AllowNulls: item.AllowNulls, - SelectorNamespace: strings.TrimSpace(item.Selector.Namespace), - SelectorNoLimit: item.Selector.NoLimit, - SchemaType: 
firstNonEmptyString(strings.TrimSpace(item.Schema.DataType), strings.TrimSpace(item.Schema.Name)), - Cardinality: cardinality, - HasSummary: func() *bool { - if item.Template.Summary == nil { - return nil - } - value := true - return &value - }(), - }) } - return result } -func lookupLegacyRouteViewAttr(items []legacyRouteViewAttr, name string) (legacyRouteViewAttr, bool) { - name = strings.TrimSpace(name) - if name == "" { - return legacyRouteViewAttr{}, false +func applyConnectorInference(item *plan.View, ctx *parityEnrichmentContext) { + if item == nil || ctx == nil || item.Connector != "" { + return } - for _, item := range items { - if strings.EqualFold(strings.TrimSpace(item.Name), name) { - return item, true - } + if ctx.settings != nil && ctx.settings.Connector != "" { + item.Connector = ctx.settings.Connector + } + if item.Connector == "" && ctx.source != nil && strings.TrimSpace(ctx.source.Connector) != "" { + item.Connector = strings.TrimSpace(ctx.source.Connector) + } + if item.Connector == "" { + item.Connector = inferConnector(item, ctx.source) } - return legacyRouteViewAttr{}, false } -func firstNonEmptyString(values ...string) string { - for _, value := range values { - value = strings.TrimSpace(value) - if value != "" { - return value - } +func applySummaryInference(item *plan.View, ctx *parityEnrichmentContext) { + if item == nil || ctx == nil || item.Summary != "" { + return + } + item.Summary = extractSummarySQL(item.SQL) + if item.Summary == "" && ctx.source != nil { + item.Summary = extractSummarySQL(ctx.source.DQL) } - return "" } func extractSummarySQL(sqlText string) string { @@ -677,7 +567,7 @@ func inferConnector(item *plan.View, source *shape.Source) string { } } -func normalizeRootViewName(result *plan.Result, sourceName string, settings *ruleSettings) { +func normalizeRootViewName(result *plan.Result, sourceName string) { if result == nil || len(result.Views) == 0 { return } @@ -689,7 +579,6 @@ func normalizeRootViewName(result 
*plan.Result, sourceName string, settings *rul if desired == "" { return } - _ = settings current := strings.TrimSpace(root.Name) if current == "" { root.Name = desired diff --git a/repository/shape/compile/preprocess_handler.go b/repository/shape/compile/preprocess_handler.go index bea319f7..22f8c4df 100644 --- a/repository/shape/compile/preprocess_handler.go +++ b/repository/shape/compile/preprocess_handler.go @@ -9,16 +9,13 @@ import ( "github.com/viant/datly/repository/shape/compile/pipeline" dqlpre "github.com/viant/datly/repository/shape/dql/preprocess" dqlstmt "github.com/viant/datly/repository/shape/dql/statement" - "github.com/viant/datly/repository/shape/plan" ) type handlerPreprocessResult struct { - Pre *dqlpre.Result - Statements dqlstmt.Statements - Decision pipeline.Decision - LegacyViews []*plan.View - EffectiveSource *shape.Source - ForceLegacyContract bool + Pre *dqlpre.Result + Statements dqlstmt.Statements + Decision pipeline.Decision + EffectiveSource *shape.Source } func buildHandlerIfNeeded(source *shape.Source, pre *dqlpre.Result, statements dqlstmt.Statements, decision pipeline.Decision, layout compilePathLayout) *handlerPreprocessResult { @@ -45,21 +42,17 @@ func buildHandlerIfNeeded(source *shape.Source, pre *dqlpre.Result, statements d } func buildHandlerFromContractIfNeeded(ret *handlerPreprocessResult, source *shape.Source, layout compilePathLayout) bool { - if ret == nil || source == nil { - return false - } - return buildLegacyRouteFallbackIfNeeded(ret, source, layout) + _ = ret + _ = source + _ = layout + return false } func buildGeneratedFallbackIfNeeded(ret *handlerPreprocessResult, source *shape.Source, layout compilePathLayout) bool { if ret == nil || source == nil { return false } - if alternate := resolveGeneratedLegacySource(source); alternate != nil { - if buildLegacyRouteFallbackIfNeeded(ret, alternate, layout) { - return true - } - } + _ = layout generated := strings.TrimSpace(resolveGeneratedCompanionDQL(source)) if 
generated == "" { return false @@ -79,20 +72,6 @@ func buildGeneratedFallbackIfNeeded(ret *handlerPreprocessResult, source *shape. return true } -func buildLegacyRouteFallbackIfNeeded(ret *handlerPreprocessResult, source *shape.Source, layout compilePathLayout) bool { - if ret == nil || source == nil { - return false - } - legacyFallbackViews := resolveLegacyRouteViewsWithLayout(source, layout) - if len(legacyFallbackViews) == 0 { - return false - } - ret.LegacyViews = legacyFallbackViews - ret.EffectiveSource = source - ret.ForceLegacyContract = true - return true -} - func resolveGeneratedLegacySource(source *shape.Source) *shape.Source { if source == nil || strings.TrimSpace(source.Path) == "" { return nil diff --git a/repository/shape/compile/preprocess_handler_test.go b/repository/shape/compile/preprocess_handler_test.go index 7d4e5782..1f8c8d08 100644 --- a/repository/shape/compile/preprocess_handler_test.go +++ b/repository/shape/compile/preprocess_handler_test.go @@ -22,28 +22,21 @@ func TestIsHandlerSignal(t *testing.T) { assert.False(t, isHandlerSignal(&shape.Source{DQL: `SELECT 1`})) } -func TestBuildHandlerFromContractIfNeeded_LegacyFallbackViews(t *testing.T) { +func TestBuildHandlerFromContractIfNeeded_Disabled(t *testing.T) { tempDir := t.TempDir() sourcePath := filepath.Join(tempDir, "dql", "platform", "campaign", "post.dql") require.NoError(t, os.MkdirAll(filepath.Dir(sourcePath), 0o755)) dql := `/* {"Type":"campaign/patch.Handler","Connector":"ci_ads"} */` require.NoError(t, os.WriteFile(sourcePath, []byte(dql), 0o644)) - routeDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "platform", "campaign", "patch", "post") - require.NoError(t, os.MkdirAll(routeDir, 0o755)) - require.NoError(t, os.WriteFile(filepath.Join(routeDir, "post.sql"), []byte(`SELECT 1`), 0o644)) - require.NoError(t, os.WriteFile(filepath.Join(routeDir, "CurCampaign.sql"), []byte(`SELECT * FROM CI_CAMPAIGN`), 0o644)) - source := &shape.Source{Path: sourcePath, DQL: 
dql} pre := dqlpre.Prepare(source.DQL) statements := dqlstmt.New(pre.SQL) decision := pipeline.Classify(statements) result := &handlerPreprocessResult{Pre: pre, Statements: statements, Decision: decision, EffectiveSource: source} applied := buildHandlerFromContractIfNeeded(result, source, defaultCompilePathLayout()) - require.True(t, applied) + require.False(t, applied) require.NotNil(t, result) - require.NotEmpty(t, result.LegacyViews) - assert.Equal(t, "post", result.LegacyViews[0].Name) } func TestBuildGeneratedFallbackIfNeeded_GeneratedCompanion(t *testing.T) { @@ -63,7 +56,6 @@ func TestBuildGeneratedFallbackIfNeeded_GeneratedCompanion(t *testing.T) { applied := buildGeneratedFallbackIfNeeded(result, source, defaultCompilePathLayout()) require.True(t, applied) require.NotNil(t, result) - assert.Empty(t, result.LegacyViews) assert.Contains(t, result.Pre.SQL, "SELECT o.id FROM ORDERS o") assert.True(t, result.Decision.HasRead) } @@ -84,36 +76,11 @@ func TestResolveGeneratedLegacySource(t *testing.T) { assert.Contains(t, actual.DQL, `"Type":"session/patch.Handler"`) } -func TestBuildGeneratedFallbackIfNeeded_GeneratedLegacyRoute(t *testing.T) { +func TestBuildGeneratedFallbackIfNeeded_NoGeneratedCompanionWithoutTypeHeader(t *testing.T) { tempDir := t.TempDir() genPath := filepath.Join(tempDir, "dql", "system", "session", "gen", "session", "patch.dql") require.NoError(t, os.MkdirAll(filepath.Dir(genPath), 0o755)) require.NoError(t, os.WriteFile(genPath, []byte(`/* {"Method":"PATCH","URI":"/v1/api/system/session"} */`), 0o644)) - legacySQL := filepath.Join(tempDir, "dql", "system", "session", "patch.sql") - require.NoError(t, os.MkdirAll(filepath.Dir(legacySQL), 0o755)) - require.NoError(t, os.WriteFile(legacySQL, []byte(`/* {"Type":"session/patch.Handler","Connector":"system"} */`), 0o644)) - - routesDir := filepath.Join(tempDir, "repo", "dev", "Datly", "routes", "system", "session", "patch") - require.NoError(t, os.MkdirAll(routesDir, 0o755)) - require.NoError(t, 
os.WriteFile(filepath.Join(filepath.Dir(routesDir), "patch.yaml"), []byte(`Resource: - Views: - - Name: patch - Mode: SQLExec - Connector: - Ref: system - Template: - SourceURL: patch/patch.sql - Parameters: - - Name: Session - In: - Kind: body - Name: data - Types: - - Name: Input - DataType: "*Input" - Package: session/patch -`), 0o644)) - require.NoError(t, os.WriteFile(filepath.Join(routesDir, "patch.sql"), []byte(`$Nop($Unsafe.Session)`), 0o644)) source := &shape.Source{Path: genPath, DQL: `/* {"Method":"PATCH","URI":"/v1/api/system/session"} */`} pre := dqlpre.Prepare(source.DQL) @@ -121,23 +88,16 @@ func TestBuildGeneratedFallbackIfNeeded_GeneratedLegacyRoute(t *testing.T) { decision := pipeline.Classify(statements) result := &handlerPreprocessResult{Pre: pre, Statements: statements, Decision: decision, EffectiveSource: source} applied := buildGeneratedFallbackIfNeeded(result, source, defaultCompilePathLayout()) - require.True(t, applied) + require.False(t, applied) require.NotNil(t, result) - require.True(t, result.ForceLegacyContract) - require.NotNil(t, result.EffectiveSource) - assert.Equal(t, legacySQL, result.EffectiveSource.Path) - require.NotEmpty(t, result.LegacyViews) - assert.Equal(t, "patch", result.LegacyViews[0].Name) } -func TestBuildLegacyRouteFallbackIfNeeded_NoLegacyRoute(t *testing.T) { +func TestBuildGeneratedFallbackIfNeeded_NoGeneratedCompanion(t *testing.T) { source := &shape.Source{Path: filepath.Join(t.TempDir(), "dql", "x", "y", "z.dql"), DQL: `SELECT 1`} pre := dqlpre.Prepare(source.DQL) statements := dqlstmt.New(pre.SQL) decision := pipeline.Classify(statements) result := &handlerPreprocessResult{Pre: pre, Statements: statements, Decision: decision, EffectiveSource: source} - applied := buildLegacyRouteFallbackIfNeeded(result, source, defaultCompilePathLayout()) + applied := buildGeneratedFallbackIfNeeded(result, source, defaultCompilePathLayout()) assert.False(t, applied) - assert.Empty(t, result.LegacyViews) - assert.False(t, 
result.ForceLegacyContract) } diff --git a/repository/shape/compile/strings_util.go b/repository/shape/compile/strings_util.go new file mode 100644 index 00000000..5c18c9d8 --- /dev/null +++ b/repository/shape/compile/strings_util.go @@ -0,0 +1,13 @@ +package compile + +import "strings" + +func firstNonEmptyString(values ...string) string { + for _, value := range values { + value = strings.TrimSpace(value) + if value != "" { + return value + } + } + return "" +} diff --git a/repository/shape/compile/type_support.go b/repository/shape/compile/type_support.go new file mode 100644 index 00000000..44a8b9bc --- /dev/null +++ b/repository/shape/compile/type_support.go @@ -0,0 +1,238 @@ +package compile + +import ( + "reflect" + "strings" + + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/datly/repository/shape/typectx" + "github.com/viant/x" +) + +func applyLinkedTypeSupport(result *plan.Result, source *shape.Source) { + if result == nil || source == nil { + return + } + registry := source.EnsureTypeRegistry() + if registry == nil || len(registry.Keys()) == 0 { + return + } + resolver := typectx.NewResolver(registry, result.TypeContext) + rootTypeKey := resolveRootTypeKey(source, resolver, registry) + existing := existingTypesByName(result.Types) + + for idx, item := range result.Views { + if item == nil { + continue + } + resolvedKey := resolveViewTypeKey(item, idx == 0, rootTypeKey, resolver, registry) + if resolvedKey == "" { + continue + } + resolvedType := registry.Lookup(resolvedKey) + if resolvedType == nil || resolvedType.Type == nil { + continue + } + rType := unwrapResolvedType(resolvedType.Type) + if rType == nil { + continue + } + typeExpr, typePkg := schemaTypeExpression(rType, result.TypeContext) + if shouldSetSchemaType(item) && typeExpr != "" { + item.SchemaType = typeExpr + } + name := strings.TrimSpace(rType.Name()) + if name == "" { + continue + } + key := strings.ToLower(name) + if 
existing[key] { + continue + } + result.Types = append(result.Types, &plan.Type{ + Name: name, + DataType: typeExpr, + Cardinality: strings.TrimSpace(item.Cardinality), + Package: typePkg, + ModulePath: strings.TrimSpace(rType.PkgPath()), + }) + existing[key] = true + } +} + +func resolveRootTypeKey(source *shape.Source, resolver *typectx.Resolver, registry *x.Registry) string { + if source == nil || registry == nil { + return "" + } + if key := resolveTypeKey(strings.TrimSpace(source.TypeName), resolver, registry); key != "" { + return key + } + rType, err := source.ResolveRootType() + if err != nil || rType == nil { + return "" + } + return resolveTypeKey(x.NewType(rType).Key(), resolver, registry) +} + +func resolveViewTypeKey(item *plan.View, root bool, rootTypeKey string, resolver *typectx.Resolver, registry *x.Registry) string { + if item == nil || registry == nil { + return "" + } + candidates := make([]string, 0, 8) + seen := map[string]bool{} + appendCandidate := func(value string) { + value = strings.TrimSpace(value) + if value == "" { + return + } + if seen[value] { + return + } + seen[value] = true + candidates = append(candidates, value) + } + + if root && rootTypeKey != "" { + appendCandidate(rootTypeKey) + } + if item.Declaration != nil { + appendCandidate(item.Declaration.DataType) + appendCandidate(item.Declaration.Of) + } + appendCandidate(item.SchemaType) + name := toExportedTypeName(item.Name) + if name != "" { + appendCandidate(name + "View") + appendCandidate(name) + } + for _, candidate := range candidates { + if key := resolveTypeKey(candidate, resolver, registry); key != "" { + return key + } + } + return "" +} + +func resolveTypeKey(typeExpr string, resolver *typectx.Resolver, registry *x.Registry) string { + if registry == nil { + return "" + } + base := normalizeTypeLookupKey(typeExpr) + if base == "" { + return "" + } + if registry.Lookup(base) != nil { + return base + } + if resolver == nil { + return "" + } + resolved, err := 
resolver.Resolve(base) + if err != nil || resolved == "" { + return "" + } + if registry.Lookup(resolved) == nil { + return "" + } + return resolved +} + +func normalizeTypeLookupKey(typeExpr string) string { + value := strings.TrimSpace(typeExpr) + for { + switch { + case strings.HasPrefix(value, "*"): + value = strings.TrimPrefix(value, "*") + case strings.HasPrefix(value, "[]"): + value = strings.TrimPrefix(value, "[]") + default: + return strings.TrimSpace(value) + } + } +} + +func shouldSetSchemaType(item *plan.View) bool { + if item == nil { + return false + } + current := strings.TrimSpace(item.SchemaType) + if current == "" { + return true + } + expectedDefault := "*" + toExportedTypeName(item.Name) + "View" + return current == expectedDefault +} + +func existingTypesByName(input []*plan.Type) map[string]bool { + result := map[string]bool{} + for _, item := range input { + if item == nil { + continue + } + name := strings.ToLower(strings.TrimSpace(item.Name)) + if name == "" { + continue + } + result[name] = true + } + return result +} + +func schemaTypeExpression(rType reflect.Type, ctx *typectx.Context) (string, string) { + rType = unwrapResolvedType(rType) + if rType == nil { + return "", "" + } + typeName := strings.TrimSpace(rType.Name()) + if typeName == "" { + return "", "" + } + pkgPath := strings.TrimSpace(rType.PkgPath()) + if pkgPath == "" { + return "*" + typeName, "" + } + pkgAlias := packageAlias(pkgPath, ctx) + if pkgAlias == "" { + return "*" + typeName, "" + } + return "*" + pkgAlias + "." 
+ typeName, pkgAlias +} + +func packageAlias(pkgPath string, ctx *typectx.Context) string { + pkgPath = strings.TrimSpace(pkgPath) + if pkgPath == "" { + return "" + } + if ctx != nil { + for _, item := range ctx.Imports { + if strings.TrimSpace(item.Package) != pkgPath { + continue + } + alias := strings.TrimSpace(item.Alias) + if alias != "" { + return alias + } + } + if strings.TrimSpace(ctx.PackagePath) == pkgPath && strings.TrimSpace(ctx.PackageName) != "" { + return strings.TrimSpace(ctx.PackageName) + } + } + index := strings.LastIndex(pkgPath, "/") + if index == -1 || index+1 >= len(pkgPath) { + return pkgPath + } + return pkgPath[index+1:] +} + +func unwrapResolvedType(rType reflect.Type) reflect.Type { + for rType != nil { + switch rType.Kind() { + case reflect.Ptr, reflect.Slice, reflect.Array: + rType = rType.Elem() + default: + return rType + } + } + return nil +} diff --git a/repository/shape/compile/type_support_test.go b/repository/shape/compile/type_support_test.go new file mode 100644 index 00000000..b3c376c5 --- /dev/null +++ b/repository/shape/compile/type_support_test.go @@ -0,0 +1,70 @@ +package compile + +import ( + "context" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/viant/datly/repository/shape" + "github.com/viant/datly/repository/shape/plan" + "github.com/viant/x" +) + +type linkedRootType struct { + ID int +} + +type OrdersView struct { + ID int +} + +func TestDQLCompiler_Compile_UsesLinkedRootTypeForSchemaType(t *testing.T) { + compiler := New() + source := &shape.Source{ + Name: "orders_report", + Type: reflect.TypeOf(linkedRootType{}), + TypeName: x.NewType(reflect.TypeOf(linkedRootType{})).Key(), + DQL: "SELECT t.id FROM ORDERS t", + } + + res, err := compiler.Compile(context.Background(), source) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Equal(t, 
"*compile.linkedRootType", planned.Views[0].SchemaType) + require.NotEmpty(t, planned.Types) + assert.Equal(t, "linkedRootType", planned.Types[0].Name) + assert.Equal(t, "*compile.linkedRootType", planned.Types[0].DataType) +} + +func TestDQLCompiler_Compile_UsesLinkedRegistryTypeForNamedView(t *testing.T) { + compiler := New() + registry := x.NewRegistry() + registry.Register(x.NewType(reflect.TypeOf(OrdersView{}))) + source := &shape.Source{ + Name: "orders", + TypeRegistry: registry, + DQL: "SELECT orders.id FROM ORDERS orders", + } + + res, err := compiler.Compile(context.Background(), source) + require.NoError(t, err) + planned, ok := res.Plan.(*plan.Result) + require.True(t, ok) + require.NotEmpty(t, planned.Views) + assert.Equal(t, "*compile.OrdersView", planned.Views[0].SchemaType) + + var found *plan.Type + for _, item := range planned.Types { + if item != nil && item.Name == "OrdersView" { + found = item + break + } + } + require.NotNil(t, found) + assert.Equal(t, "*compile.OrdersView", found.DataType) + assert.Equal(t, "compile", found.Package) +} diff --git a/repository/shape/compile/typectx_defaults.go b/repository/shape/compile/typectx_defaults.go index 5bc0a9d9..5561d36a 100644 --- a/repository/shape/compile/typectx_defaults.go +++ b/repository/shape/compile/typectx_defaults.go @@ -30,6 +30,7 @@ func applyTypeContextDefaults(ctx *typectx.Context, source *shape.Source, opts * } } } + ret = normalizeRelativeImports(ret, source, layout) return normalizeTypeContext(ret) } @@ -64,24 +65,11 @@ func mergeTypeContext(dst *typectx.Context, src *typectx.Context) *typectx.Conte } func inferDatlyGenTypeContext(source *shape.Source, layout compilePathLayout) *typectx.Context { - if source == nil { - return nil - } - sourcePath := strings.TrimSpace(source.Path) - if sourcePath == "" { - return nil - } - normalizedPath := filepath.ToSlash(filepath.Clean(sourcePath)) - idx := strings.Index(normalizedPath, layout.dqlMarker) - if idx == -1 { - return nil - } - 
projectRoot := filepath.FromSlash(strings.TrimSuffix(normalizedPath[:idx], "/")) - rel := strings.TrimPrefix(normalizedPath[idx+len(layout.dqlMarker):], "/") - if rel == "" { + parsed, ok := parseSourceLayout(source, layout) + if !ok { return nil } - routeDir := strings.Trim(path.Dir(rel), "/") + routeDir := strings.Trim(path.Dir(parsed.relativePath), "/") if routeDir == "." { routeDir = "" } @@ -94,7 +82,7 @@ func inferDatlyGenTypeContext(source *shape.Source, layout compilePathLayout) *t packageName = path.Base(routeDir) } packagePath := "" - if module := detectModulePath(projectRoot); module != "" { + if module := detectModulePath(parsed.projectRoot); module != "" { packagePath = path.Join(module, packageDir) } return normalizeTypeContext(&typectx.Context{ @@ -156,3 +144,83 @@ func normalizeTypeContext(ctx *typectx.Context) *typectx.Context { } return ctx } + +func normalizeRelativeImports(ctx *typectx.Context, source *shape.Source, layout compilePathLayout) *typectx.Context { + if ctx == nil || len(ctx.Imports) == 0 { + return ctx + } + modulePath := modulePathForSource(source, layout) + if modulePath == "" { + return ctx + } + for i, item := range ctx.Imports { + pkg := strings.TrimSpace(item.Package) + if pkg == "" { + continue + } + ctx.Imports[i].Package = normalizeImportPackage(pkg, modulePath) + } + return ctx +} + +func modulePathForSource(source *shape.Source, layout compilePathLayout) string { + parsed, ok := parseSourceLayout(source, layout) + if !ok { + return "" + } + return detectModulePath(parsed.projectRoot) +} + +func normalizeImportPackage(pkg, modulePath string) string { + pkg = strings.Trim(strings.ReplaceAll(strings.TrimSpace(pkg), "\\", "/"), "/") + if pkg == "" { + return "" + } + if !strings.Contains(pkg, "/") { + return pkg + } + if strings.HasPrefix(pkg, modulePath+"/") || pkg == modulePath { + return pkg + } + first := pkg + if index := strings.Index(first, "/"); index != -1 { + first = first[:index] + } + if strings.Contains(first, 
".") { + return pkg + } + return path.Join(modulePath, pkg) +} + +type sourceLayout struct { + projectRoot string + relativePath string +} + +func parseSourceLayout(source *shape.Source, layout compilePathLayout) (*sourceLayout, bool) { + if source == nil { + return nil, false + } + sourcePath := strings.TrimSpace(source.Path) + if sourcePath == "" { + return nil, false + } + marker := strings.TrimSpace(layout.dqlMarker) + if marker == "" { + marker = defaultCompilePathLayout().dqlMarker + } + normalizedPath := filepath.ToSlash(filepath.Clean(sourcePath)) + idx := strings.Index(normalizedPath, marker) + if idx == -1 { + return nil, false + } + projectRoot := filepath.FromSlash(strings.TrimSuffix(normalizedPath[:idx], "/")) + relativePath := strings.TrimPrefix(normalizedPath[idx+len(marker):], "/") + if relativePath == "" { + return nil, false + } + return &sourceLayout{ + projectRoot: projectRoot, + relativePath: relativePath, + }, true +} diff --git a/repository/shape/compile/typectx_defaults_test.go b/repository/shape/compile/typectx_defaults_test.go index 4f0d01a3..aa3d01d8 100644 --- a/repository/shape/compile/typectx_defaults_test.go +++ b/repository/shape/compile/typectx_defaults_test.go @@ -67,4 +67,20 @@ func TestApplyTypeContextDefaults_Matrix(t *testing.T) { }, layout) require.Nil(t, got) }) + + t.Run("relative imports are normalized to module path", func(t *testing.T) { + input := &typectx.Context{ + Imports: []typectx.Import{ + {Alias: "sess", Package: "pkg/platform/system/session"}, + {Alias: "perf", Package: "github.com/acme/perf"}, + {Alias: "time", Package: "time"}, + }, + } + got := applyTypeContextDefaults(input, source, nil, layout) + require.NotNil(t, got) + require.Len(t, got.Imports, 3) + require.Equal(t, "github.vianttech.com/viant/platform/pkg/platform/system/session", got.Imports[0].Package) + require.Equal(t, "github.com/acme/perf", got.Imports[1].Package) + require.Equal(t, "time", got.Imports[2].Package) + }) } diff --git 
a/repository/shape/platform_parity_test.go b/repository/shape/platform_parity_test.go index f6837b03..17fa1f1f 100644 --- a/repository/shape/platform_parity_test.go +++ b/repository/shape/platform_parity_test.go @@ -207,6 +207,9 @@ type parityEntryEval struct { } func TestPlatform_DQLToRoute_ParityIR_SmokeHandlers(t *testing.T) { + if !strings.EqualFold(strings.TrimSpace(os.Getenv("PLATFORM_PARITY_SMOKE")), "1") { + t.Skip("set PLATFORM_PARITY_SMOKE=1 to run legacy parity smoke handlers") + } platformRoot := os.Getenv("PLATFORM_ROOT") if platformRoot == "" { platformRoot = "/Users/awitas/go/src/github.vianttech.com/viant/platform" diff --git a/warmup/cache_test.go b/warmup/cache_test.go index 408a4a31..9196529d 100644 --- a/warmup/cache_test.go +++ b/warmup/cache_test.go @@ -2,6 +2,7 @@ package warmup import ( "context" + "os" "path" "testing" @@ -12,6 +13,10 @@ import ( ) func TestPopulateCache(t *testing.T) { + if os.Getenv("DATLY_RUN_WARMUP_TESTS") == "" { + t.Skip("set DATLY_RUN_WARMUP_TESTS=1 to run warmup integration test") + } + testCases := []struct { description string URL string From b218524a5681acbb6969e2f0df7c325d5b312308 Mon Sep 17 00:00:00 2001 From: adranwit Date: Tue, 24 Feb 2026 07:27:51 -0800 Subject: [PATCH 6/6] shape/compile: add type support helpers; refine preprocessing and type defaults; update tests --- Version | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Version b/Version index 014ec619..fcc9d59a 100644 --- a/Version +++ b/Version @@ -1 +1 @@ -v0.20.2 \ No newline at end of file +v0.21.0 \ No newline at end of file