Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,8 @@ def update_features_impl(op: OpKey):
torch.ops.aten.sym_size.int,
operator.add,
operator.sub,
operator.floordiv,
operator.mul,
operator.lt,
operator.gt,
operator.ge,
Expand Down Expand Up @@ -279,6 +281,26 @@ def register_bitwise_and():
)


@update_features(exir_ops.edge.aten.bitwise_not.default)
def register_bitwise_not():
    """Declare Vulkan backend support for aten.bitwise_not on boolean tensors.

    Any storage layout is accepted; dynamic shapes (resize) and
    high-dimensional tensors are supported.
    """
    features = OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=utils.BOOL_T,
        supports_resize=True,
        supports_highdim=True,
    )
    return features


@update_features(exir_ops.edge.aten.logical_and.default)
def register_logical_and():
    """Declare Vulkan backend support for aten.logical_and on boolean tensors.

    Any storage layout is accepted; dynamic shapes (resize) and
    high-dimensional tensors are supported.
    """
    features = OpFeatures(
        inputs_storage=utils.ANY_STORAGE,
        inputs_dtypes=utils.BOOL_T,
        supports_resize=True,
        supports_highdim=True,
    )
    return features


# =============================================================================
# BinaryScalarOp.cpp
# =============================================================================
Expand All @@ -301,16 +323,22 @@ def register_pow_tensor_scalar():

@update_features(exir_ops.edge.aten._to_copy.default)
def register_to_copy():
def check_to_copy_node(node: torch.fx.Node) -> bool:
# Only single-arg _to_copy is supported
return len(node.args) == 1
def pick_to_copy_storage(
node: torch.fx.Node,
) -> Tuple[utils.TensorRepSet, utils.TensorRepSet]:
in_dtype = node.args[0].meta["val"].dtype # type: ignore[union-attr]
out_dtype = node.meta["val"].dtype
fp_types = {torch.float16, torch.float32}
if in_dtype in fp_types and out_dtype in fp_types:
return utils.ANY_STORAGE, utils.ANY_STORAGE
return utils.CONTIGUOUS_BUFFER, utils.CONTIGUOUS_BUFFER

return OpFeatures(
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_INT_T,
outputs_dtypes=utils.FP_INT_T,
inputs_dtypes=utils.FP_INT_BOOL_T,
outputs_dtypes=utils.FP_INT_BOOL_T,
supports_resize=True,
are_node_inputs_supported_fn=check_to_copy_node,
pick_io_storage_fn=pick_to_copy_storage,
)


Expand Down Expand Up @@ -1301,6 +1329,7 @@ def register_scalar_tensor():
return OpFeatures(
inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
inputs_dtypes=utils.FP_INT_T,
supports_resize=True,
)


Expand Down
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/unary_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,6 @@ unary_op:
OPERATOR: leaky_relu(X, A)
- NAME: round
OPERATOR: round(X)
- NAME: bitwise_not_uint8
OPERATOR: 1 - X
DTYPE: uint8
1 change: 1 addition & 0 deletions backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(aten.gt.Tensor, gt);
VK_REGISTER_OP(aten.ge.Tensor, ge);
VK_REGISTER_OP(aten.bitwise_and.Tensor, bitwise_and);
VK_REGISTER_OP(aten.logical_and.default, bitwise_and);
}

} // namespace vkcompute
5 changes: 3 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ void resize_split_node(

const ValueListPtr out_list = graph->get_value_list(out_list_ref);
const std::vector<int64_t> split_sizes =
*(graph->get_int_list(split_sizes_ref));
graph->extract_int_or_symint_list(split_sizes_ref);
const int64_t dim = graph->extract_scalar<int64_t>(dim_ref);

const int64_t input_ndim = graph->dim_of(input);
Expand Down Expand Up @@ -125,7 +125,8 @@ void split_with_sizes_copy_default(
ValueRef out_list_ref = args[3];

int64_t dim = graph.extract_scalar<int64_t>(dim_ref);
std::vector<int64_t> split_sizes = *(graph.get_int_list(split_sizes_ref));
std::vector<int64_t> split_sizes =
graph.extract_int_or_symint_list(split_sizes_ref);

add_split_with_sizes_node(graph, input, split_sizes, dim, out_list_ref);
}
Expand Down
87 changes: 87 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/SymIntOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,90 @@ void sym_add(ComputeGraph& graph, const std::vector<ValueRef>& args) {
new ExecuteNode(resize_sym_add_node, args));
}

// Evaluates out = a - b over symbolic integer values and stores the result.
void sym_sub_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
  const ValueRef lhs_ref = args.at(0);
  const ValueRef rhs_ref = args.at(1);
  const ValueRef out_ref = args.at(2);

  const int32_t lhs = graph->read_symint(lhs_ref);
  const int32_t rhs = graph->read_symint(rhs_ref);

  graph->set_symint(out_ref, lhs - rhs);
}

// Resize hook: re-evaluates the subtraction whenever upstream symints change.
// The ArgGroup list is unused; the symint refs are carried in resize_args.
void resize_sym_sub_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& unused_args,
    const std::vector<ValueRef>& resize_args) {
  (void)unused_args;
  sym_sub_impl(graph, resize_args);
}

// Registers the sym_sub op: computes the initial value eagerly, then installs
// a resize node so the result stays current across dynamic-shape updates.
void sym_sub(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  sym_sub_impl(&graph, args);

  graph.execute_nodes().emplace_back(
      new ExecuteNode(resize_sym_sub_node, args));
}

// Evaluates out = a // b (Python-style floor division) over symbolic ints.
void sym_floordiv_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
  const ValueRef a = args.at(0);
  const ValueRef b = args.at(1);
  const ValueRef out = args.at(2);

  const int32_t a_val = graph->read_symint(a);
  const int32_t b_val = graph->read_symint(b);
  // C++ '/' and '%' are undefined behavior for a zero divisor; fail loudly
  // instead of silently producing garbage.
  VK_CHECK_COND(b_val != 0, "sym_floordiv: division by zero");

  // C++ integer division truncates toward zero; when the operands have
  // opposite signs and the division is inexact, subtract 1 so the result
  // rounds toward negative infinity (matching Python's // semantics).
  int32_t result = a_val / b_val;
  if (a_val % b_val != 0 && (a_val < 0) != (b_val < 0)) {
    result -= 1;
  }

  graph->set_symint(out, result);
}

// Resize hook: re-evaluates the floor division when upstream symints change.
// The ArgGroup list is unused; the symint refs are carried in resize_args.
void resize_sym_floordiv_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& unused_args,
    const std::vector<ValueRef>& resize_args) {
  (void)unused_args;
  sym_floordiv_impl(graph, resize_args);
}

// Registers the sym_floordiv op: computes the initial value eagerly, then
// installs a resize node so the result stays current across shape updates.
void sym_floordiv(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  sym_floordiv_impl(&graph, args);

  graph.execute_nodes().emplace_back(
      new ExecuteNode(resize_sym_floordiv_node, args));
}

// Evaluates out = a * b over symbolic integer values and stores the result.
void sym_mul_impl(ComputeGraph* graph, const std::vector<ValueRef>& args) {
  const ValueRef lhs_ref = args.at(0);
  const ValueRef rhs_ref = args.at(1);
  const ValueRef out_ref = args.at(2);

  const int32_t lhs = graph->read_symint(lhs_ref);
  const int32_t rhs = graph->read_symint(rhs_ref);

  graph->set_symint(out_ref, lhs * rhs);
}

// Resize hook: re-evaluates the multiplication when upstream symints change.
// The ArgGroup list is unused; the symint refs are carried in resize_args.
void resize_sym_mul_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& unused_args,
    const std::vector<ValueRef>& resize_args) {
  (void)unused_args;
  sym_mul_impl(graph, resize_args);
}

// Registers the sym_mul op: computes the initial value eagerly, then installs
// a resize node so the result stays current across dynamic-shape updates.
void sym_mul(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  sym_mul_impl(&graph, args);

  graph.execute_nodes().emplace_back(
      new ExecuteNode(resize_sym_mul_node, args));
}

void select_as_symint_impl(
ComputeGraph* graph,
const std::vector<ArgGroup>& unused,
Expand Down Expand Up @@ -132,6 +216,9 @@ void select_as_symint(ComputeGraph& graph, const std::vector<ValueRef>& args) {
REGISTER_OPERATORS {
VK_REGISTER_OP(sym_size.int, sym_size_int);
VK_REGISTER_OP(add, sym_add);
VK_REGISTER_OP(sub, sym_sub);
VK_REGISTER_OP(floordiv, sym_floordiv);
VK_REGISTER_OP(mul, sym_mul);
VK_REGISTER_OP(et_vk.select_as_symint.default, select_as_symint);
}

Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/ToCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ bool is_float_type(vkapi::ScalarType dtype) {
}

void add_to_copy_node(ComputeGraph& graph, ValueRef in, ValueRef out) {
vkapi::ScalarType in_dtype = graph.dtype_of(in);
vkapi::ScalarType out_dtype = graph.dtype_of(out);
const vkapi::ScalarType in_dtype = graph.dtype_of(in);
const vkapi::ScalarType out_dtype = graph.dtype_of(out);

// Same-dtype or float<->half conversions can use BlitNode
if (in_dtype == out_dtype ||
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/impl/UnaryOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ DEFINE_ACTIVATION_FN(hardswish);
DEFINE_ACTIVATION_FN(hardsigmoid);
DEFINE_LEAKY_RELU_FN(leaky_relu);
DEFINE_ACTIVATION_FN(round);
DEFINE_ACTIVATION_FN(bitwise_not);

REGISTER_OPERATORS {
VK_REGISTER_OP(aten.abs.default, abs);
Expand All @@ -179,6 +180,7 @@ REGISTER_OPERATORS {
VK_REGISTER_OP(aten.hardsigmoid.default, hardsigmoid);
VK_REGISTER_OP(aten.leaky_relu.default, leaky_relu);
VK_REGISTER_OP(aten.round.default, round);
VK_REGISTER_OP(aten.bitwise_not.default, bitwise_not);
}

} // namespace vkcompute
18 changes: 15 additions & 3 deletions backends/vulkan/runtime/graph/ops/impl/Where.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,22 @@ void resize_where_node(
const std::vector<ValueRef>& extra_args) {
(void)extra_args;
const ValueRef out = args.at(0).refs.at(0);
const ValueRef self = args.at(1).refs.at(1);

const std::vector<int64_t> self_sizes = graph->sizes_of(self);
graph->virtual_resize(out, self_sizes);
std::vector<int64_t> out_sizes;
for (const ValueRef ref : args.at(1).refs) {
if (!graph->val_is_tensor(ref)) {
continue;
}
const std::vector<int64_t> s = graph->sizes_of(ref);
if (s.size() > out_sizes.size()) {
out_sizes.resize(s.size(), 1);
}
const size_t offset = out_sizes.size() - s.size();
for (size_t i = 0; i < s.size(); i++) {
out_sizes[offset + i] = std::max(out_sizes[offset + i], s[i]);
}
}
graph->virtual_resize(out, out_sizes);
}

void add_where_node(
Expand Down
Loading