diff --git a/checker/checker.go b/checker/checker.go index 3620f207..63425af1 100644 --- a/checker/checker.go +++ b/checker/checker.go @@ -578,6 +578,15 @@ func (v *Checker) memberNode(node *ast.MemberNode) Nature { } return base.Elem(&v.config.NtCache) + case reflect.Interface: + // For non-any interface types, we don't know the concrete type at + // compile time. Allow field (non-method) access and defer resolution + // to runtime, where the concrete type can be inspected. + if name, ok := node.Property.(*ast.StringNode); ok && node.Method { + return v.error(node, "type %v has no method %v", base.String(), name.Value) + } + return Nature{} + case reflect.Struct: if name, ok := node.Property.(*ast.StringNode); ok { propertyName := name.Value diff --git a/test/issues/951/issue_test.go b/test/issues/951/issue_test.go new file mode 100644 index 00000000..7c735210 --- /dev/null +++ b/test/issues/951/issue_test.go @@ -0,0 +1,118 @@ +package issue951 + +import ( + "testing" + + "github.com/expr-lang/expr" + "github.com/expr-lang/expr/internal/testify/require" +) + +type Node interface { + ID() string +} + +type Base struct { + Name string +} + +func (b Base) ID() string { return b.Name } + +type Container struct { + Base + Items []*Item +} + +type Item struct { + Kind string + Value string +} + +type Wrapper struct { + Node // embedded interface +} + +type Proxy struct { + *Wrapper +} + +type Nodes []Node + +func (ns Nodes) GetByID(id string) Node { + for _, n := range ns { + if n.ID() == id { + return n + } + } + return nil +} + +func TestFieldAccessThroughEmbeddedInterface(t *testing.T) { + container := &Container{ + Base: Base{Name: "test"}, + Items: []*Item{ + {Kind: "card", Value: "some_value"}, + }, + } + proxy := &Proxy{ + Wrapper: &Wrapper{ + Node: container, + }, + } + + tests := []struct { + name string + expr string + env any + expect any + }{ + { + name: "field through GetByID returning interface", + expr: `data.GetByID("test").Items[0].Value`, + env: map[string]any{"data": Nodes{proxy}}, + expect: "some_value", + }, + { + name: "optional chaining with embedded interface", + expr: `data.GetByID("test")?.Items[0].Value`, + env: map[string]any{"data": Nodes{proxy}}, + expect: "some_value", + }, + { + name: "optional chaining nil result", + expr: `data.GetByID("missing")?.Items`, + env: map[string]any{"data": Nodes{proxy}}, + expect: nil, + }, + { + name: "promoted field through interface", + expr: `data.GetByID("test").Name`, + env: map[string]any{"data": Nodes{proxy}}, + expect: "test", + }, + { + name: "method on interface still works", + expr: `data.GetByID("test").ID()`, + env: map[string]any{"data": Nodes{proxy}}, + expect: "test", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := expr.Eval(tt.expr, tt.env) + require.NoError(t, err) + require.Equal(t, tt.expect, result) + }) + } +} + +func TestFieldAccessEmbeddedInterfaceNil(t *testing.T) { + proxy := &Proxy{ + Wrapper: &Wrapper{ + Node: nil, + }, + } + + _, err := expr.Eval(`Items[0].Value`, proxy) + require.Error(t, err) +} diff --git a/vm/runtime/runtime.go b/vm/runtime/runtime.go index bc6f2b4d..169d66d6 100644 --- a/vm/runtime/runtime.go +++ b/vm/runtime/runtime.go @@ -79,23 +79,15 @@ func Fetch(from, i any) any { if cv, ok := fieldCache.Load(key); ok { return v.FieldByIndex(cv.([]int)).Interface() } - field, ok := t.FieldByNameFunc(func(name string) bool { - field, _ := t.FieldByName(name) - switch field.Tag.Get("expr") { - case "-": - return false - case fieldName: - return true - default: - return name == fieldName - } - }) - if ok && field.IsExported() { - value := v.FieldByIndex(field.Index) - if value.IsValid() { - fieldCache.Store(key, field.Index) - return value.Interface() - } + if value, field, ok := findStructField(v, fieldName); ok { + fieldCache.Store(key, field.Index) + return value.Interface() + } + // Field isn't found via standard Go promotion. Try to find it + // by traversing embedded interface values whose concrete types + // may contain the requested field. + if result, found := fetchFromEmbeddedInterfaces(v, fieldName); found { + return result } } panic(fmt.Sprintf("cannot fetch %v from %T", i, from)) @@ -143,6 +135,82 @@ func fieldByIndex(v reflect.Value, field *Field) reflect.Value { return v } +func findStructField(v reflect.Value, fieldName string) (reflect.Value, reflect.StructField, bool) { + t := v.Type() + field, ok := t.FieldByNameFunc(func(name string) bool { + sf, _ := t.FieldByName(name) + switch sf.Tag.Get("expr") { + case "-": + return false + case fieldName: + return true + default: + return name == fieldName + } + }) + if ok && field.IsExported() { + value := v.FieldByIndex(field.Index) + if value.IsValid() { + return value, field, true + } + } + return reflect.Value{}, reflect.StructField{}, false +} + +func fetchFromEmbeddedInterfaces(v reflect.Value, fieldName string) (any, bool) { + t := v.Type() + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.Anonymous { + continue + } + fv := v.Field(i) + fk := f.Type.Kind() + + // Dereference pointers to get to the underlying type. + for fk == reflect.Ptr { + if fv.IsNil() { + break + } + fv = fv.Elem() + fk = fv.Kind() + } + + switch fk { + case reflect.Interface: + if fv.IsNil() { + continue + } + // Unwrap interface and dereference pointers to reach the + // concrete struct value. + concrete := fv.Elem() + for concrete.Kind() == reflect.Ptr { + if concrete.IsNil() { + break + } + concrete = concrete.Elem() + } + if concrete.Kind() != reflect.Struct { + continue + } + if value, _, ok := findStructField(concrete, fieldName); ok { + return value.Interface(), true + } + // The concrete type itself may have embedded interfaces. + if result, found := fetchFromEmbeddedInterfaces(concrete, fieldName); found { + return result, found + } + + case reflect.Struct: + // Recurse into embedded structs to find embedded interfaces. + if result, found := fetchFromEmbeddedInterfaces(fv, fieldName); found { + return result, found + } + } + } + return nil, false +} + type Method struct { Index int Name string