diff --git a/lib/typeprof/core/ast/call.rb b/lib/typeprof/core/ast/call.rb index a2e97728..19153289 100644 --- a/lib/typeprof/core/ast/call.rb +++ b/lib/typeprof/core/ast/call.rb @@ -164,6 +164,15 @@ def install0(genv) @changes.add_edge(genv, allow_nil, ret) end + if @mid == :[]= && @recv.is_a?(LocalVariableReadNode) + key_node = @positional_args[0] + if key_node.is_a?(SymbolNode) + recv_vtx = @lenv.get_var(@recv.var) + nvtx = @lenv.new_var(@recv.var, self) + @changes.add_hash_aset_box(genv, recv_vtx, key_node.lit, ret, nvtx) + end + end + ret end @@ -188,6 +197,10 @@ def retrieve_at(pos, &blk) end def modified_vars(tbl, vars) + if @mid == :[]= && @recv.is_a?(LocalVariableReadNode) && tbl.include?(@recv.var) + key_node = @positional_args[0] + vars << @recv.var if key_node.is_a?(SymbolNode) + end subnodes.each do |key, subnode| next unless subnode if subnode.is_a?(AST::Node) diff --git a/lib/typeprof/core/builtin.rb b/lib/typeprof/core/builtin.rb index cf6e270e..8a3f45f6 100644 --- a/lib/typeprof/core/builtin.rb +++ b/lib/typeprof/core/builtin.rb @@ -111,9 +111,16 @@ def hash_aref(changes, node, ty, a_args, ret) def hash_aset(changes, node, ty, a_args, ret) if a_args.positionals.size == 2 + val = a_args.positionals[1] + + # Skip backflow for local variable receivers (handled by HashAsetBox) + if node.recv.is_a?(AST::LocalVariableReadNode) + changes.add_edge(@genv, val, ret) + return true + end + case ty when Type::Hash - val = a_args.positionals[1] idx = node.positional_args[0] if idx.is_a?(AST::SymbolNode) && ty.get_value(idx.lit) # TODO: how to handle new key? diff --git a/lib/typeprof/core/graph/box.rb b/lib/typeprof/core/graph/box.rb index 9c62aeb6..03c9bf6e 100644 --- a/lib/typeprof/core/graph/box.rb +++ b/lib/typeprof/core/graph/box.rb @@ -1105,4 +1105,79 @@ def run0(genv, changes) changes.add_edge(genv, source_vtx, @ret) end end + + class HashAsetBox < Box + def initialize(node, genv, recv, key_sym, val_vtx, out_vtx) + super(node) + @recv = recv + @key_sym = key_sym + @val_vtx = val_vtx + @out_vtx = out_vtx + @recv.add_edge(genv, self) + @val_vtx.add_edge(genv, self) + # Cache vertices to ensure convergence in loops. + # Without caching, run0 creates new Vertex objects each time, + # producing new Type objects that prevent the fixed-point from being reached. + @field_cache = {} + @unified_key = Vertex.new(node) + @unified_val = Vertex.new(node) + @merged_key = Vertex.new(node) + @merged_val = Vertex.new(node) + end + + attr_reader :recv, :key_sym, :val_vtx, :out_vtx + + def ret = @out_vtx + + def destroy(genv) + @recv.remove_edge(genv, self) + @val_vtx.remove_edge(genv, self) + super(genv) + end + + def run0(genv, changes) + @recv.each_type do |ty| + case ty + when Type::Record + new_fields = {} + ty.fields.each do |key, field_vtx| + @field_cache[key] ||= Vertex.new(@node) + changes.add_edge(genv, field_vtx, @field_cache[key]) unless field_vtx.equal?(@field_cache[key]) + new_fields[key] = @field_cache[key] + end + @field_cache[@key_sym] ||= Vertex.new(@node) + new_fields[@key_sym] = @field_cache[@key_sym] + changes.add_edge(genv, @val_vtx, @field_cache[@key_sym]) + new_fields.each do |key, vtx| + changes.add_edge(genv, Source.new(Type::Symbol.new(genv, key)), @unified_key) + changes.add_edge(genv, vtx, @unified_val) + end + base_type = genv.gen_hash_type(@unified_key, @unified_val) + new_record = Type::Record.new(genv, new_fields, base_type) + changes.add_edge(genv, Source.new(new_record), @out_vtx) + when Type::Hash + build_merged_hash_type(genv, changes, ty.get_key, ty.get_value) + when Type::Instance + if ty.mod == genv.mod_hash + build_merged_hash_type(genv, changes, ty.args[0], ty.args[1]) + else + changes.add_edge(genv, Source.new(ty), @out_vtx) + end + else + changes.add_edge(genv, Source.new(ty), @out_vtx) + end + end + end + + private + + def build_merged_hash_type(genv, changes, old_key_vtx, old_val_vtx) + changes.add_edge(genv, old_key_vtx, @merged_key) unless old_key_vtx.equal?(@merged_key) + changes.add_edge(genv, Source.new(Type::Symbol.new(genv, @key_sym)), @merged_key) + changes.add_edge(genv, old_val_vtx, @merged_val) unless old_val_vtx.equal?(@merged_val) + changes.add_edge(genv, @val_vtx, @merged_val) + new_hash_type = genv.gen_hash_type(@merged_key, @merged_val) + changes.add_edge(genv, Source.new(new_hash_type), @out_vtx) + end + end end diff --git a/lib/typeprof/core/graph/change_set.rb b/lib/typeprof/core/graph/change_set.rb index 8e7d6008..08a2415e 100644 --- a/lib/typeprof/core/graph/change_set.rb +++ b/lib/typeprof/core/graph/change_set.rb @@ -137,6 +137,11 @@ def add_instance_type_box(genv, singleton_ty_vtx) @new_boxes[key] ||= InstanceTypeBox.new(@node, genv, singleton_ty_vtx) end + def add_hash_aset_box(genv, recv, key_sym, val_vtx, out_vtx) + key = [:hash_aset, recv, key_sym, val_vtx, out_vtx] + @new_boxes[key] ||= HashAsetBox.new(@node, genv, recv, key_sym, val_vtx, out_vtx) + end + def add_diagnostic(meth, msg, node = @node) @new_diagnostics << TypeProf::Diagnostic.new(node, meth, msg) end diff --git a/scenario/hash/hash_aset.rb b/scenario/hash/hash_aset.rb new file mode 100644 index 00000000..cf397413 --- /dev/null +++ b/scenario/hash/hash_aset.rb @@ -0,0 +1,22 @@ +## update +def foo(options) + return if options[:skip] + + options[:name] = "str" + bar(options) + nil +end + +def bar(options) + options[:age] = 10 + nil +end + +args = Hash.new +foo(args) + +## assert +class Object + def foo: (Hash[:skip, untyped]) -> nil + def bar: (Hash[:name | :skip, String]) -> nil +end diff --git a/scenario/hash/hash_aset_loop.rb b/scenario/hash/hash_aset_loop.rb new file mode 100644 index 00000000..6c3f28af --- /dev/null +++ b/scenario/hash/hash_aset_loop.rb @@ -0,0 +1,14 @@ +## update +def foo(options) + while options[:flag] + options[:name] = "str" + end + nil +end + +foo(Hash.new) + +## assert +class Object + def foo: (Hash[:flag, untyped]) -> nil +end