diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 13bd4acbfef..2849a1550a3 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -159,6 +159,8 @@ def update_features_impl(op: OpKey): torch.ops.aten.sym_size.int, operator.add, operator.sub, + operator.floordiv, + operator.mul, operator.lt, operator.gt, operator.ge, @@ -279,6 +281,26 @@ def register_bitwise_and(): ) +@update_features(exir_ops.edge.aten.bitwise_not.default) +def register_bitwise_not(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + inputs_dtypes=utils.BOOL_T, + supports_resize=True, + supports_highdim=True, + ) + + +@update_features(exir_ops.edge.aten.logical_and.default) +def register_logical_and(): + return OpFeatures( + inputs_storage=utils.ANY_STORAGE, + inputs_dtypes=utils.BOOL_T, + supports_resize=True, + supports_highdim=True, + ) + + # ============================================================================= # BinaryScalarOp.cpp # ============================================================================= @@ -301,16 +323,22 @@ def register_pow_tensor_scalar(): @update_features(exir_ops.edge.aten._to_copy.default) def register_to_copy(): - def check_to_copy_node(node: torch.fx.Node) -> bool: - # Only single-arg _to_copy is supported - return len(node.args) == 1 + def pick_to_copy_storage( + node: torch.fx.Node, + ) -> Tuple[utils.TensorRepSet, utils.TensorRepSet]: + in_dtype = node.args[0].meta["val"].dtype # type: ignore[union-attr] + out_dtype = node.meta["val"].dtype + fp_types = {torch.float16, torch.float32} + if in_dtype in fp_types and out_dtype in fp_types: + return utils.ANY_STORAGE, utils.ANY_STORAGE + return utils.CONTIGUOUS_BUFFER, utils.CONTIGUOUS_BUFFER return OpFeatures( inputs_storage=utils.ANY_STORAGE, - inputs_dtypes=utils.FP_INT_T, - outputs_dtypes=utils.FP_INT_T, + inputs_dtypes=utils.FP_INT_BOOL_T, + outputs_dtypes=utils.FP_INT_BOOL_T, supports_resize=True, - are_node_inputs_supported_fn=check_to_copy_node, + pick_io_storage_fn=pick_to_copy_storage, ) @@ -1301,6 +1329,7 @@ def register_scalar_tensor(): return OpFeatures( inputs_storage=utils.CHANNELS_PACKED_TEXTURE, inputs_dtypes=utils.FP_INT_T, + supports_resize=True, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml index 47f538aee6c..1763f975058 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml @@ -46,3 +46,6 @@ unary_op: OPERATOR: leaky_relu(X, A) - NAME: round OPERATOR: round(X) + - NAME: bitwise_not_uint8 + OPERATOR: 1 - X + DTYPE: uint8 diff --git a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp index 92c2fa218ec..fa4c75463b7 100644 --- a/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp @@ -214,6 +214,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.gt.Tensor, gt); VK_REGISTER_OP(aten.ge.Tensor, ge); VK_REGISTER_OP(aten.bitwise_and.Tensor, bitwise_and); + VK_REGISTER_OP(aten.logical_and.default, bitwise_and); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Split.cpp b/backends/vulkan/runtime/graph/ops/impl/Split.cpp index 4e62ae8806d..3931606751c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Split.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Split.cpp @@ -35,7 +35,7 @@ void resize_split_node( const ValueListPtr out_list = graph->get_value_list(out_list_ref); const std::vector split_sizes = - *(graph->get_int_list(split_sizes_ref)); + graph->extract_int_or_symint_list(split_sizes_ref); const int64_t dim = graph->extract_scalar(dim_ref); const int64_t input_ndim = graph->dim_of(input); @@ -125,7 +125,8 @@ void split_with_sizes_copy_default( ValueRef out_list_ref = args[3]; int64_t dim = graph.extract_scalar(dim_ref); - std::vector split_sizes = *(graph.get_int_list(split_sizes_ref)); + std::vector split_sizes = + graph.extract_int_or_symint_list(split_sizes_ref); add_split_with_sizes_node(graph, input, split_sizes, dim, out_list_ref); } diff --git a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp index eb03639abf1..3aef6bc988d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp @@ -81,6 +81,90 @@ void sym_add(ComputeGraph& graph, const std::vector& args) { new ExecuteNode(resize_sym_add_node, args)); } +void sym_sub_impl(ComputeGraph* graph, const std::vector& args) { + const ValueRef a = args.at(0); + const ValueRef b = args.at(1); + const ValueRef out = args.at(2); + + const int32_t a_val = graph->read_symint(a); + const int32_t b_val = graph->read_symint(b); + const int32_t result = a_val - b_val; + + graph->set_symint(out, result); +} + +void resize_sym_sub_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + sym_sub_impl(graph, resize_args); +} + +void sym_sub(ComputeGraph& graph, const std::vector& args) { + sym_sub_impl(&graph, args); + + graph.execute_nodes().emplace_back( + new ExecuteNode(resize_sym_sub_node, args)); +} + +void sym_floordiv_impl(ComputeGraph* graph, const std::vector& args) { + const ValueRef a = args.at(0); + const ValueRef b = args.at(1); + const ValueRef out = args.at(2); + + const int32_t a_val = graph->read_symint(a); + const int32_t b_val = graph->read_symint(b); + // Floor division: round towards negative infinity + const int32_t result = (a_val ^ b_val) < 0 && a_val % b_val != 0 + ? a_val / b_val - 1 + : a_val / b_val; + + graph->set_symint(out, result); +} + +void resize_sym_floordiv_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + sym_floordiv_impl(graph, resize_args); +} + +void sym_floordiv(ComputeGraph& graph, const std::vector& args) { + sym_floordiv_impl(&graph, args); + + graph.execute_nodes().emplace_back( + new ExecuteNode(resize_sym_floordiv_node, args)); +} + +void sym_mul_impl(ComputeGraph* graph, const std::vector& args) { + const ValueRef a = args.at(0); + const ValueRef b = args.at(1); + const ValueRef out = args.at(2); + + const int32_t a_val = graph->read_symint(a); + const int32_t b_val = graph->read_symint(b); + const int32_t result = a_val * b_val; + + graph->set_symint(out, result); +} + +void resize_sym_mul_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& resize_args) { + (void)args; + sym_mul_impl(graph, resize_args); +} + +void sym_mul(ComputeGraph& graph, const std::vector& args) { + sym_mul_impl(&graph, args); + + graph.execute_nodes().emplace_back( + new ExecuteNode(resize_sym_mul_node, args)); +} + void select_as_symint_impl( ComputeGraph* graph, const std::vector& unused, @@ -132,6 +216,9 @@ void select_as_symint(ComputeGraph& graph, const std::vector& args) { REGISTER_OPERATORS { VK_REGISTER_OP(sym_size.int, sym_size_int); VK_REGISTER_OP(add, sym_add); + VK_REGISTER_OP(sub, sym_sub); + VK_REGISTER_OP(floordiv, sym_floordiv); + VK_REGISTER_OP(mul, sym_mul); VK_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint); } diff --git a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp index 275023faa59..2de4a555860 100644 --- a/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp @@ -30,8 +30,8 @@ bool is_float_type(vkapi::ScalarType dtype) { } void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) { - vkapi::ScalarType in_dtype = graph.dtype_of(in); - vkapi::ScalarType out_dtype = graph.dtype_of(out); + const vkapi::ScalarType in_dtype = graph.dtype_of(in); + const vkapi::ScalarType out_dtype = graph.dtype_of(out); // Same-dtype or float<->half conversions can use BlitNode if (in_dtype == out_dtype || diff --git a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp index 9830a8e8784..de6172da2b9 100644 --- a/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp @@ -158,6 +158,7 @@ DEFINE_ACTIVATION_FN(hardswish); DEFINE_ACTIVATION_FN(hardsigmoid); DEFINE_LEAKY_RELU_FN(leaky_relu); DEFINE_ACTIVATION_FN(round); +DEFINE_ACTIVATION_FN(bitwise_not); REGISTER_OPERATORS { VK_REGISTER_OP(aten.abs.default, abs); @@ -179,6 +180,7 @@ REGISTER_OPERATORS { VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid); VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu); VK_REGISTER_OP(aten.round.default, round); + VK_REGISTER_OP(aten.bitwise_not.default, bitwise_not); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Where.cpp b/backends/vulkan/runtime/graph/ops/impl/Where.cpp index adb7fb1beca..c52a0c277cd 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Where.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Where.cpp @@ -21,10 +21,22 @@ void resize_where_node( const std::vector& extra_args) { (void)extra_args; const ValueRef out = args.at(0).refs.at(0); - const ValueRef self = args.at(1).refs.at(1); - const std::vector self_sizes = graph->sizes_of(self); - graph->virtual_resize(out, self_sizes); + std::vector out_sizes; + for (const ValueRef ref : args.at(1).refs) { + if (!graph->val_is_tensor(ref)) { + continue; + } + const std::vector s = graph->sizes_of(ref); + if (s.size() > out_sizes.size()) { + out_sizes.resize(s.size(), 1); + } + const size_t offset = out_sizes.size() - s.size(); + for (size_t i = 0; i < s.size(); i++) { + out_sizes[offset + i] = std::max(out_sizes[offset + i], s[i]); + } + } + graph->virtual_resize(out, out_sizes); } void add_where_node(