Skip to content

Commit 3401540

Browse files
majosmCopilot
andcommitted
omit dtype specification for floating point CSR matmul args
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent d156f00 commit 3401540

File tree

1 file changed

+0
-3
lines changed

1 file changed

+0
-3
lines changed

arraycontext/context.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,8 +485,6 @@ def make_csr_matrix(
485485

486486
@memoize_method
487487
def _get_csr_matmul_prg(self, out_ndim: int) -> loopy.TranslationUnit:
488-
import numpy as np
489-
490488
import loopy as lp
491489

492490
out_extra_inames = tuple(f"i{n}" for n in range(1, out_ndim))
@@ -595,7 +593,6 @@ def _get_csr_matmul_prg(self, out_ndim: int) -> loopy.TranslationUnit:
595593
",".join([
596594
"ncols", "nrows", "nels",
597595
*out_extra_shape_comp_names]): idx_dtype,
598-
"elem_values,array,out": np.float64,
599596
"elem_col_indices,row_starts": idx_dtype})
600597

601598
def sparse_matmul(

0 commit comments

Comments
 (0)