From eb895beda19d7a7675780149d56134ae6d0c1d2e Mon Sep 17 00:00:00 2001 From: BeceHQ Date: Sat, 17 Jan 2026 17:19:47 +0100 Subject: [PATCH 1/3] Update --- .../test/component/tensor/PermuteTest.java | 560 ++++++++++++++++++ .../tensor/TransposeLinDataTest.java | 197 ++++++ 2 files changed, 757 insertions(+) create mode 100644 src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java create mode 100644 src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java diff --git a/src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java b/src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java new file mode 100644 index 00000000000..69c4be1b4cf --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java @@ -0,0 +1,560 @@ +package org.apache.sysds.test.component.tensor; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.data.DenseBlock; +//import org.apache.sysds.runtime.data.DenseBlockFactory; +import org.mockito.Mockito; +import java.util.Arrays; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.List; +import java.util.ArrayList; + + +public class PermuteTest { + + @Test + public void TestMatrixBlockPermute() { + + int[] shape = {2, 3, 4}; + + MatrixBlock tensor = TensorUtils.createArangeMatrixBlock(shape); + Assert.assertEquals(24, tensor.getNumRows() * tensor.getNumColumns()); + + double[] data = tensor.getDenseBlockValues(); + Assert.assertEquals(23.0, data[1 * 4 * 3 + 2 * 4 + 3], 0.001); + Assert.assertEquals( 0.0, data[0 * 4 * 3 + 0 * 4 + 0], 0.001); + + TensorUtils.printMatrixTensor(tensor, shape); + + int[] permutation = {1, 0, 2}; + + MatrixBlock outTensor = PermuteIt.permute(tensor, shape, permutation); + int[] outShape = {3, 2, 4}; + + TensorUtils.printMatrixTensor(outTensor, outShape); + + 
double[] outData = outTensor.getDenseBlockValues(); + Assert.assertEquals(24, 1 * outTensor.getNumColumns()); + Assert.assertEquals(24, outData.length); + Assert.assertEquals(4.0, outData[8], 0.001); + Assert.assertEquals(15.0, outData[7], 0.001); + } + + @Test + public void testPermute2D_Transpose() { + int[] shape = {10, 5}; + int[] perm = {1, 0}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = PermuteIt.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testPermute3D_Simple() { + int[] shape = {2, 3, 4}; + int[] perm = {1, 0, 2}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = PermuteIt.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testPermute3D_Identity() { + int[] shape = {5, 5, 5}; + int[] perm = {0, 1, 2}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = PermuteIt.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testPermute4D_Reverse() { + int[] shape = {2, 3, 4, 5}; + int[] perm = {3, 2, 1, 0}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = PermuteIt.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testPermuteHighRank() { + int[] shape = {2, 2, 2, 2, 2, 2}; + int[] perm = {5, 0, 4, 1, 3, 2}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = PermuteIt.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + + @Test + public void testLargeBlockLogic_Mocked() { + int[] shape = {10, 10, 10}; + int[] perm = {2, 0, 1}; + + MatrixBlock in = generateMatrixBlock(shape); + + DenseBlock originalDB = in.getDenseBlock(); + DenseBlock spyDB = Mockito.spy(originalDB); + Mockito.when(spyDB.numBlocks()).thenReturn(2); + + in.setDenseBlock(spyDB); + + MatrixBlock out = PermuteIt.permute(in, shape, perm); + + MatrixBlock originalIn = 
generateMatrixBlock(shape); + verifyPermutation(originalIn, out, shape, perm); + } + + @Test + public void testLargeBlockLogic_Mocked_InputAndOutput() { + int[] shape = {4, 4, 4}; + int[] perm = {2, 1, 0}; + + MatrixBlock in = generateMatrixBlock(shape); + DenseBlock spyIn = Mockito.spy(in.getDenseBlock()); + Mockito.when(spyIn.numBlocks()).thenReturn(5); + in.setDenseBlock(spyIn); + + MatrixBlock out = PermuteIt.permute(in, shape, perm); + + MatrixBlock originalIn = generateMatrixBlock(shape); + verifyPermutation(originalIn, out, shape, perm); + } + + @Test + public void testPermute3D_Parallel() { + int[] shape = {100, 100, 100}; + int[] perm = {2, 0, 1}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = PermuteIt.permute(in, shape, perm, -1); + + verifyPermutation(in, out, shape, perm); + } + + + private MatrixBlock generateMatrixBlock(int[] shape) { + long len = 1; + for (int d : shape) len *= d; + + MatrixBlock mb = new MatrixBlock(1, (int)len, false); + mb.allocateDenseBlock(); + double[] data = mb.getDenseBlockValues(); + for(int i = 0; i < data.length; i++) { + data[i] = (double)i; + } + return mb; + } + + + private void verifyPermutation(MatrixBlock in, MatrixBlock out, int[] inShape, int[] perm) { + + double[] inData = new double[(int)(in.getNumRows() * in.getNumColumns())]; + double[] outData = new double[(int)(out.getNumRows() * out.getNumColumns())]; + + DenseBlock inDB = in.getDenseBlock(); + DenseBlock outDB = out.getDenseBlock(); + + if (inDB != null) { + int inBlockSize = inDB.blockSize(); + for (int i = 0; i < inDB.numBlocks(); i++) { + double[] block = inDB.valuesAt(i); + int offset = i * inBlockSize; + int len = Math.min(inBlockSize, inData.length - offset); + System.arraycopy(block, 0, inData, offset, len); + } + } + + if (outDB != null) { + int outBlockSize = outDB.blockSize(); + for (int i = 0; i < outDB.numBlocks(); i++) { + double[] block = outDB.valuesAt(i); + int offset = i * outBlockSize; + int len = 
Math.min(outBlockSize, outData.length - offset); + System.arraycopy(block, 0, outData, offset, len); + } + } + + int rank = inShape.length; + int[] outShape = new int[rank]; + for(int i=0; i 0.0001) { + Assert.fail("Mismatch at linear output index " + i + + ". Output coords " + Arrays.toString(currentCoords) + + ". Input coords " + Arrays.toString(inCoords) + + ". Expected " + expectedValue + " but got " + actualValue); + } + } + } + + private long[] getStrides(int[] dims) { + long[] strides = new long[dims.length]; + long stride = 1; + for (int i = dims.length - 1; i >= 0; i--) { + strides[i] = stride; + stride *= dims[i]; + } + return strides; + } + + + public static class TensorUtils { + + public static MatrixBlock createArangeMatrixBlock(int[] shape) { + long length = 1; + for (int d : shape) length *= d; + + MatrixBlock mb = new MatrixBlock(1, (int)length, false); + mb.allocateDenseBlock(); + + double[] data = mb.getDenseBlockValues(); + for (int i = 0; i < data.length; i++) { + data[i] = (double) i; + } + return mb; + } + + public static void printMatrixTensor(MatrixBlock mb, int[] shape) { + double[] data = mb.getDenseBlockValues(); + StringBuilder sb = new StringBuilder(); + sb.append("MatrixBlock-Tensor(").append(Arrays.toString(shape)).append("):\n"); + printRecursive(data, shape, 0, 0, sb, 0); + System.out.println(sb.toString()); + } + + private static void printRecursive(double[] data, int[] shape, int dim, int offset, StringBuilder sb, int indent) { + int stride = 1; + for (int i = dim + 1; i < shape.length; i++) stride *= shape[i]; + + for (int k = 0; k < indent; k++) sb.append(" "); + sb.append("["); + + if (dim == shape.length - 1) { + for (int i = 0; i < shape[dim]; i++) { + sb.append(String.format("%.1f", data[offset + i])); + if (i < shape[dim] - 1) sb.append(", "); + } + sb.append("]"); + } else { + sb.append("\n"); + for (int i = 0; i < shape[dim]; i++) { + printRecursive(data, shape, dim + 1, offset + i * stride, sb, indent + 2); + if (i < 
shape[dim] - 1) { + sb.append(","); + sb.append("\n"); + if (shape.length - dim > 2) sb.append("\n"); + } + } + sb.append("\n"); + for (int k = 0; k < indent; k++) sb.append(" "); + sb.append("]"); + } + } + } + + + public static class PermuteIt { + + // blocking according to typical L2 cache sizes + private static final int BLOCK_SIZE = 128; + private static final int PAR_NUMCELL_THRESHOLD = 1024; //1024*1024 + + //Aus LibMatrixReorg + static void transposeRow(double[] a, double[] c, int aix, int cix, int n2, int len) { + final int bn = len % 8; + for (int j = 0; j < bn; j++, aix++, cix += n2) + c[cix] = a[aix]; + for (int j = bn; j < len; j += 8, aix += 8, cix += 8 * n2) { + c[cix + 0 * n2] = a[aix + 0]; + c[cix + 1 * n2] = a[aix + 1]; + c[cix + 2 * n2] = a[aix + 2]; + c[cix + 3 * n2] = a[aix + 3]; + c[cix + 4 * n2] = a[aix + 4]; + c[cix + 5 * n2] = a[aix + 5]; + c[cix + 6 * n2] = a[aix + 6]; + c[cix + 7 * n2] = a[aix + 7]; + } + } + + private static long[] getStrides(int[] dims) { + long[] strides = new long[dims.length]; + long stride = 1; + for (int i = dims.length - 1; i >= 0; i--) { + strides[i] = stride; + stride *= dims[i]; + } + return strides; + } + + public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm) { + return permute(in, inDims, perm, 1); + } + + public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm, int k) { + int rank = inDims.length; + + //Early opt out + boolean isIdentity = true; + for (int i = 0; i < rank; i++) { + if (perm[i] != i) { + isIdentity = false; + break; + } + } + if (isIdentity) { + return new MatrixBlock(in); + } + + int[] outDims = new int[rank]; + for (int i = 0; i < rank; i++) + outDims[i] = inDims[perm[i]]; + + long length = 1; + for (int d : outDims) length *= d; + + MatrixBlock out = new MatrixBlock(1, (int)length, false); + out.allocateDenseBlock(); + + DenseBlock inDB = in.getDenseBlock(); + DenseBlock outDB = out.getDenseBlock(); + + long[] inStrides = getStrides(inDims); + 
long[] outStrides = getStrides(outDims); + + long[] permutedStrides = new long[rank]; + for (int i = 0; i < rank; i++) { + permutedStrides[i] = outStrides[perm[i]]; + } + + boolean useParallel = (k > 1 || k == -1) && length >= PAR_NUMCELL_THRESHOLD; + int numThreads = k == -1 ? Runtime.getRuntime().availableProcessors() : k; + + if (inDB.numBlocks() == 1 && outDB.numBlocks() == 1) { + double[] inData = inDB.valuesAt(0); + double[] outData = outDB.valuesAt(0); + + if (useParallel && rank > 0) { + parallelPermuteSingleBlock(inData, outData, inDims, inStrides, + permutedStrides, numThreads); + } else { + recursivePermuteSingleBlock(inData, outData, inDims, inStrides, + permutedStrides, 0, 0, 0); + } + } + else { + if (useParallel && rank > 0) { + parallelPermuteMultiBlock(inDB, outDB, inDims, inStrides, + permutedStrides, numThreads); + } else { + recursivePermuteMultiBlock(inDB, outDB, inDims, inStrides, + permutedStrides, 0, 0L, 0L); + } + } + return out; + } + + private static void recursivePermuteSingleBlock( + double[] inData, double[] outData, + int[] inDims, long[] inStrides, long[] permutedStrides, + int dim, int inOffset, int outOffset) { + + if (dim == inDims.length - 1) { + int len = inDims[dim]; + int outStride = (int) permutedStrides[dim]; + + if (outStride == 1) + System.arraycopy(inData, inOffset, outData, outOffset, len); + else + transposeRow(inData, outData, inOffset, outOffset, outStride, len); + return; + } + + int dimSize = inDims[dim]; + long inStep = inStrides[dim]; + long outStep = permutedStrides[dim]; + + for (int bi = 0; bi < dimSize; bi += BLOCK_SIZE) { + int bimin = Math.min(bi + BLOCK_SIZE, dimSize); + for (int i = bi; i < bimin; i++) { + recursivePermuteSingleBlock( + inData, outData, inDims, inStrides, permutedStrides, + dim + 1, + inOffset + (int)(i * inStep), + outOffset + (int)(i * outStep) + ); + } + } + } + + private static void parallelPermuteSingleBlock( + double[] inData, double[] outData, + int[] inDims, long[] inStrides, 
long[] permutedStrides, + int numThreads) { + + int dimSize = inDims[0]; + int tasksPerThread = Math.max(1, dimSize / numThreads); + + ExecutorService pool = Executors.newFixedThreadPool(numThreads); + List> futures = new ArrayList<>(); + + for (int t = 0; t < numThreads; t++) { + final int start = t * tasksPerThread; + final int end = (t == numThreads - 1) ? dimSize : (t + 1) * tasksPerThread; + + if (start >= dimSize) break; + + futures.add(pool.submit(() -> { + for (int i = start; i < end; i++) { + recursivePermuteSingleBlock( + inData, outData, inDims, inStrides, permutedStrides, + 1, + (int)(i * inStrides[0]), + (int)(i * permutedStrides[0]) + ); + } + })); + } + + for (Future f : futures) { + try { + f.get(); + } catch (Exception e) { + throw new RuntimeException("Parallel permute failed", e); + } + } + pool.shutdown(); + } + + private static void recursivePermuteMultiBlock( + DenseBlock inDB, DenseBlock outDB, + int[] inDims, long[] inStrides, long[] permutedStrides, + int dim, long inOffset, long outOffset) { + + if (dim == inDims.length - 1) { + int len = inDims[dim]; + long outStride = permutedStrides[dim]; + + int inBlockSize = inDB.blockSize(); + int outBlockSize = outDB.blockSize(); + + for (int i = 0; i < len; i++) { + long currentInAbs = inOffset + i * inStrides[dim]; + long currentOutAbs = outOffset + i * outStride; + + int inBlockIdx = (int) (currentInAbs / inBlockSize); + int inRelIdx = (int) (currentInAbs % inBlockSize); + + int outBlockIdx = (int) (currentOutAbs / outBlockSize); + int outRelIdx = (int) (currentOutAbs % outBlockSize); + + double[] inArr = inDB.valuesAt(inBlockIdx); + double[] outArr = outDB.valuesAt(outBlockIdx); + + if (inArr != null && outArr != null && + inRelIdx < inArr.length && outRelIdx < outArr.length) { + outArr[outRelIdx] = inArr[inRelIdx]; + } + } + return; + } + + int dimSize = inDims[dim]; + long inStep = inStrides[dim]; + long outStep = permutedStrides[dim]; + + for (int bi = 0; bi < dimSize; bi += BLOCK_SIZE) { + 
int bimin = Math.min(bi + BLOCK_SIZE, dimSize); + for (int i = bi; i < bimin; i++) { + recursivePermuteMultiBlock( + inDB, outDB, inDims, inStrides, permutedStrides, + dim + 1, + inOffset + i * inStep, + outOffset + i * outStep + ); + } + } + } + + private static void parallelPermuteMultiBlock( + DenseBlock inDB, DenseBlock outDB, + int[] inDims, long[] inStrides, long[] permutedStrides, + int numThreads) { + + int dimSize = inDims[0]; + int tasksPerThread = Math.max(1, dimSize / numThreads); + + ExecutorService pool = Executors.newFixedThreadPool(numThreads); + List> futures = new ArrayList<>(); + + for (int t = 0; t < numThreads; t++) { + final int start = t * tasksPerThread; + final int end = (t == numThreads - 1) ? dimSize : (t + 1) * tasksPerThread; + + if (start >= dimSize) break; + + futures.add(pool.submit(() -> { + for (int i = start; i < end; i++) { + recursivePermuteMultiBlock( + inDB, outDB, inDims, inStrides, permutedStrides, + 1, + i * inStrides[0], + i * permutedStrides[0] + ); + } + })); + } + + for (Future f : futures) { + try { + f.get(); + } catch (Exception e) { + throw new RuntimeException("Parallel permute failed", e); + } + } + pool.shutdown(); + } + } +} \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java b/src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java new file mode 100644 index 00000000000..d7e13a8b562 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/tensor/TransposeLinDataTest.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.tensor; + +import org.junit.Assert; +import org.junit.Test; + +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.data.TensorBlock; + import java.util.Arrays; + +public class TransposeLinDataTest { + + @Test + public void Testrightelem(){ + int[] shape = {2, 3, 4}; + TensorBlock tensor = TensorUtils.createArangeTensor(shape); + + Assert.assertArrayEquals(new int[]{2, 3, 4}, tensor.getDims()); + Assert.assertEquals(0.0, tensor.get(new int[]{0, 0, 0})); + Assert.assertEquals(23.0, tensor.get(new int[]{1, 2, 3})); + Assert.assertEquals(6.0, tensor.get(new int[]{0, 1, 2})); + Assert.assertEquals(12.0, tensor.get(new int[]{1, 0, 0})); + printTensor(tensor); + + + int[] permutation = {1, 0, 2}; + TensorBlock outTensor = PermuteIt.permute(tensor, permutation); + printTensor(outTensor); + + Assert.assertArrayEquals(new int[]{3, 2, 4}, outTensor.getDims()); + Assert.assertEquals(0.0, outTensor.get(new int[]{0,0,0})); + Assert.assertEquals(23.0, outTensor.get(new int[]{2, 1, 3})); + Assert.assertEquals(12.0, outTensor.get(new int[]{0, 1, 0})); + Assert.assertEquals(17.0, outTensor.get(new int[]{1, 1, 1})); + + + int[] second_permutation = {2, 1, 0}; + TensorBlock perm2Block = PermuteIt.permute(tensor, second_permutation); + printTensor(perm2Block); + + Assert.assertArrayEquals(new int[]{4, 3, 2}, perm2Block.getDims()); + Assert.assertEquals(0.0, perm2Block.get(new int[]{0, 0, 0})); + Assert.assertEquals(12.0, perm2Block.get(new int[]{0, 0, 1})); + Assert.assertEquals(11.0, 
perm2Block.get(new int[]{3, 2, 0})); + Assert.assertEquals(23.0, perm2Block.get(new int[]{3, 2, 1})); + + } + + + + + public class TensorUtils { + + public static TensorBlock createArangeTensor(int[] shape) { + TensorBlock tb = new TensorBlock(ValueType.FP64, shape); + tb.allocateBlock(); + double[] counter = { 0.0 }; + int[] currentIndices = new int[shape.length]; + + fillRecursively(tb, shape, 0, currentIndices, counter); + + return tb; + } + + private static void fillRecursively(TensorBlock tb, int[] shape, int dim, int[] currentIndices, double[] counter) { + if (dim == shape.length) { + tb.set(currentIndices, counter[0]); + counter[0]++; + return; + } + + for (int i = 0; i < shape[dim]; i++) { + currentIndices[dim] = i; + + fillRecursively(tb, shape, dim + 1, currentIndices, counter); + } + } + } + + + + public class PermuteIt { + + + public static TensorBlock permute(TensorBlock tensor, int[] permute_dims) { + + int anz_dims = tensor.getNumDims(); + int[] dims = tensor.getDims(); + ValueType tensorType = tensor.getValueType(); + + int[] out_shape = new int[anz_dims]; + + for (int idx = 0; idx < anz_dims; idx++){ + out_shape[idx] = dims[permute_dims[idx]]; + } + + TensorBlock outTensor = new TensorBlock(tensorType, out_shape); + outTensor.allocateBlock(); + + int[] inIndex = new int[anz_dims]; + int[] outIndex = new int[anz_dims]; + + rekursion(tensor, outTensor, permute_dims, dims, 0, inIndex, outIndex); + return outTensor; + } + + public static void rekursion(TensorBlock inTensor, + TensorBlock outTensor, + int[] permutation, + int[] inShape, + int dim, + int[] inIndex, + int[]outIndex + ){ + + if (dim == inShape.length) { + for(int idx = 0; idx < permutation.length; idx++){ + outIndex[idx] = inIndex[permutation[idx]]; + } + double val = (double) inTensor.get(inIndex); + outTensor.set(outIndex, val); + return; + } + + for(int idx = 0; idx < inShape[dim]; idx++){ + inIndex[dim] = idx; + rekursion(inTensor, outTensor, permutation, inShape, dim+1, inIndex, 
outIndex); + } + + } + + } + + + public static void printTensor(TensorBlock tb) { + StringBuilder sb = new StringBuilder(); + int[] shape = tb.getDims(); + int[] currentIndices = new int[shape.length]; + + sb.append("Tensor(").append(Arrays.toString(shape)).append("):\n"); + printRecursive(tb, shape, 0, currentIndices, sb, 0); + + System.out.println(sb.toString()); + } + + private static void printRecursive(TensorBlock tb, int[] shape, int dim, int[] indices, StringBuilder sb, int indent) { + for (int k = 0; k < indent; k++) sb.append(" "); + + sb.append("["); + + if (dim == shape.length - 1) { + for (int i = 0; i < shape[dim]; i++) { + indices[dim] = i; + double val = (double) tb.get(indices); + sb.append(String.format("%.1f", val)); + if (i < shape[dim] - 1) sb.append(", "); + } + sb.append("]"); + } + + else { + sb.append("\n"); + for (int i = 0; i < shape[dim]; i++) { + indices[dim] = i; + printRecursive(tb, shape, dim + 1, indices, sb, indent + 2); + + if (i < shape[dim] - 1) { + sb.append(","); + sb.append("\n"); + if (shape.length - dim > 2) sb.append("\n"); + } + } + sb.append("\n"); + for (int k = 0; k < indent; k++) sb.append(" "); + sb.append("]"); + } + } + +} \ No newline at end of file From adffba1f43a2a9e9a8c2e3d8602720032ec3687b Mon Sep 17 00:00:00 2001 From: BeceHQ Date: Thu, 22 Jan 2026 17:51:39 +0100 Subject: [PATCH 2/3] Update --- .../test/component/tensor/PermuteTest.java | 201 ++++++++++-------- 1 file changed, 109 insertions(+), 92 deletions(-) diff --git a/src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java b/src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java index 69c4be1b4cf..384c6e57c0e 100644 --- a/src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java +++ b/src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java @@ -8,9 +8,8 @@ import org.mockito.Mockito; import java.util.Arrays; import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; +import 
org.apache.sysds.runtime.util.CommonThreadPool; import java.util.concurrent.Future; -import java.util.List; import java.util.ArrayList; @@ -322,7 +321,7 @@ static void transposeRow(double[] a, double[] c, int aix, int cix, int n2, int l private static long[] getStrides(int[] dims) { long[] strides = new long[dims.length]; long stride = 1; - for (int i = dims.length - 1; i >= 0; i--) { + for( int i = dims.length - 1; i >= 0; i-- ) { strides[i] = stride; stride *= dims[i]; } @@ -336,24 +335,25 @@ public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm) { public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm, int k) { int rank = inDims.length; - //Early opt out + // Early opt out boolean isIdentity = true; - for (int i = 0; i < rank; i++) { - if (perm[i] != i) { + for( int i = 0; i < rank; i++ ) { + if( perm[i] != i ) { isIdentity = false; break; } } - if (isIdentity) { + + if( isIdentity ) { return new MatrixBlock(in); } int[] outDims = new int[rank]; - for (int i = 0; i < rank; i++) + for( int i = 0; i < rank; i++ ) outDims[i] = inDims[perm[i]]; long length = 1; - for (int d : outDims) length *= d; + for( int d : outDims ) length *= d; MatrixBlock out = new MatrixBlock(1, (int)length, false); out.allocateDenseBlock(); @@ -365,27 +365,26 @@ public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm, int long[] outStrides = getStrides(outDims); long[] permutedStrides = new long[rank]; - for (int i = 0; i < rank; i++) { + for( int i = 0; i < rank; i++ ) { permutedStrides[i] = outStrides[perm[i]]; } boolean useParallel = (k > 1 || k == -1) && length >= PAR_NUMCELL_THRESHOLD; int numThreads = k == -1 ? 
Runtime.getRuntime().availableProcessors() : k; - if (inDB.numBlocks() == 1 && outDB.numBlocks() == 1) { + if( inDB.numBlocks() == 1 && outDB.numBlocks() == 1 ) { double[] inData = inDB.valuesAt(0); double[] outData = outDB.valuesAt(0); - if (useParallel && rank > 0) { + if( useParallel && rank > 0 ) { parallelPermuteSingleBlock(inData, outData, inDims, inStrides, permutedStrides, numThreads); } else { recursivePermuteSingleBlock(inData, outData, inDims, inStrides, permutedStrides, 0, 0, 0); } - } - else { - if (useParallel && rank > 0) { + } else { + if( useParallel && rank > 0 ) { parallelPermuteMultiBlock(inDB, outDB, inDims, inStrides, permutedStrides, numThreads); } else { @@ -401,14 +400,15 @@ private static void recursivePermuteSingleBlock( int[] inDims, long[] inStrides, long[] permutedStrides, int dim, int inOffset, int outOffset) { - if (dim == inDims.length - 1) { + if( dim == inDims.length - 1 ) { int len = inDims[dim]; int outStride = (int) permutedStrides[dim]; - if (outStride == 1) + if( outStride == 1 ) { System.arraycopy(inData, inOffset, outData, outOffset, len); - else + } else { transposeRow(inData, outData, inOffset, outOffset, outStride, len); + } return; } @@ -416,9 +416,10 @@ private static void recursivePermuteSingleBlock( long inStep = inStrides[dim]; long outStep = permutedStrides[dim]; - for (int bi = 0; bi < dimSize; bi += BLOCK_SIZE) { + //blocked execution + for( int bi = 0; bi < dimSize; bi += BLOCK_SIZE ) { int bimin = Math.min(bi + BLOCK_SIZE, dimSize); - for (int i = bi; i < bimin; i++) { + for( int i = bi; i < bimin; i++ ) { recursivePermuteSingleBlock( inData, outData, inDims, inStrides, permutedStrides, dim + 1, @@ -432,40 +433,43 @@ private static void recursivePermuteSingleBlock( private static void parallelPermuteSingleBlock( double[] inData, double[] outData, int[] inDims, long[] inStrides, long[] permutedStrides, - int numThreads) { - - int dimSize = inDims[0]; - int tasksPerThread = Math.max(1, dimSize / numThreads); + 
int k) { - ExecutorService pool = Executors.newFixedThreadPool(numThreads); - List> futures = new ArrayList<>(); + final int dimSize = inDims[0]; + final int tasksPerThread = Math.max(1, dimSize / k); - for (int t = 0; t < numThreads; t++) { - final int start = t * tasksPerThread; - final int end = (t == numThreads - 1) ? dimSize : (t + 1) * tasksPerThread; - - if (start >= dimSize) break; - - futures.add(pool.submit(() -> { - for (int i = start; i < end; i++) { - recursivePermuteSingleBlock( - inData, outData, inDims, inStrides, permutedStrides, - 1, - (int)(i * inStrides[0]), - (int)(i * permutedStrides[0]) - ); - } - })); - } - - for (Future f : futures) { - try { - f.get(); - } catch (Exception e) { - throw new RuntimeException("Parallel permute failed", e); + // Set up thread pool + final ExecutorService pool = CommonThreadPool.get(k); + try { + final ArrayList> tasks = new ArrayList<>(); + + for( int t = 0; t < k; t++ ) { + final int start = t * tasksPerThread; + final int end = (t == k - 1) ? 
dimSize : (t + 1) * tasksPerThread; + + if( start >= dimSize ) break; + + tasks.add(pool.submit(() -> { + for( int i = start; i < end; i++ ) { + recursivePermuteSingleBlock( + inData, outData, inDims, inStrides, permutedStrides, + 1, + (int)(i * inStrides[0]), + (int)(i * permutedStrides[0]) + ); + } + })); } + + // Wait for all threads + for (Future task : tasks){ //pool.invokeAll(tasks)) { + task.get(); + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } finally { + pool.shutdown(); } - pool.shutdown(); } private static void recursivePermuteMultiBlock( @@ -473,14 +477,14 @@ private static void recursivePermuteMultiBlock( int[] inDims, long[] inStrides, long[] permutedStrides, int dim, long inOffset, long outOffset) { - if (dim == inDims.length - 1) { + if (dim == inDims.length - 1 ) { int len = inDims[dim]; long outStride = permutedStrides[dim]; - + int inBlockSize = inDB.blockSize(); int outBlockSize = outDB.blockSize(); - for (int i = 0; i < len; i++) { + for( int i = 0; i < len; i++ ) { long currentInAbs = inOffset + i * inStrides[dim]; long currentOutAbs = outOffset + i * outStride; @@ -493,8 +497,8 @@ private static void recursivePermuteMultiBlock( double[] inArr = inDB.valuesAt(inBlockIdx); double[] outArr = outDB.valuesAt(outBlockIdx); - if (inArr != null && outArr != null && - inRelIdx < inArr.length && outRelIdx < outArr.length) { + if( inArr != null && outArr != null && + inRelIdx < inArr.length && outRelIdx < outArr.length ) { outArr[outRelIdx] = inArr[inRelIdx]; } } @@ -505,14 +509,15 @@ private static void recursivePermuteMultiBlock( long inStep = inStrides[dim]; long outStep = permutedStrides[dim]; - for (int bi = 0; bi < dimSize; bi += BLOCK_SIZE) { + //blocked execution + for( int bi = 0; bi < dimSize; bi += BLOCK_SIZE ) { int bimin = Math.min(bi + BLOCK_SIZE, dimSize); - for (int i = bi; i < bimin; i++) { + for( int i = bi; i < bimin; i++ ) { recursivePermuteMultiBlock( - inDB, outDB, inDims, inStrides, permutedStrides, - dim + 
1, - inOffset + i * inStep, - outOffset + i * outStep + inDB, outDB, inDims, inStrides, permutedStrides, + dim + 1, + inOffset + i * inStep, + outOffset + i * outStep ); } } @@ -521,40 +526,52 @@ private static void recursivePermuteMultiBlock( private static void parallelPermuteMultiBlock( DenseBlock inDB, DenseBlock outDB, int[] inDims, long[] inStrides, long[] permutedStrides, - int numThreads) { - - int dimSize = inDims[0]; - int tasksPerThread = Math.max(1, dimSize / numThreads); + int k) { - ExecutorService pool = Executors.newFixedThreadPool(numThreads); - List> futures = new ArrayList<>(); + final int dimSize = inDims[0]; + final int tasksPerThread = Math.max(1, dimSize / k); - for (int t = 0; t < numThreads; t++) { - final int start = t * tasksPerThread; - final int end = (t == numThreads - 1) ? dimSize : (t + 1) * tasksPerThread; + // Set up thread pool + final ExecutorService pool = CommonThreadPool.get(k); + try { + final ArrayList> tasks = new ArrayList<>(); - if (start >= dimSize) break; - - futures.add(pool.submit(() -> { - for (int i = start; i < end; i++) { - recursivePermuteMultiBlock( - inDB, outDB, inDims, inStrides, permutedStrides, - 1, - i * inStrides[0], - i * permutedStrides[0] - ); - } - })); - } - - for (Future f : futures) { - try { - f.get(); - } catch (Exception e) { - throw new RuntimeException("Parallel permute failed", e); + for (int t = 0; t < k; t++) { + final int start = t * tasksPerThread; + final int end = (t == k - 1) ? 
dimSize : (t + 1) * tasksPerThread; + + if (start >= dimSize) break; + + tasks.add(pool.submit(() -> { + for (int i = start; i < end; i++) { + recursivePermuteMultiBlock( + inDB, outDB, inDims, inStrides, permutedStrides, + 1, + i * inStrides[0], + i * permutedStrides[0] + ); + } + })); } - } + + // Wait for all threads + for (Future task : tasks) { + task.get(); + } + } catch (Exception ex) { + throw new RuntimeException(ex); + } finally { pool.shutdown(); + } } - } -} \ No newline at end of file + } +} + + +//callable +//naming +//class vs func +//Wo reinfügen +//Genauer wie abgeben +//Bezüglich phd + From da4473dd2efbe210f62df89a3d0ea51d096f5272 Mon Sep 17 00:00:00 2001 From: BeceHQ Date: Fri, 30 Jan 2026 21:20:26 +0100 Subject: [PATCH 3/3] Final version --- .../runtime/matrix/data/LibMatrixReorg.java | 338 ++++++++++ .../matrix/libMatrixReorg/PermuteTest.java | 418 +++++++++++++ .../test/component/tensor/PermuteTest.java | 577 ------------------ 3 files changed, 756 insertions(+), 577 deletions(-) create mode 100644 src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java delete mode 100644 src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index 90ea445be8d..2e2a874741d 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -4515,4 +4515,342 @@ private static int eulerTotient(int[] primes, int[] exponents, int[] iExponents, } return count; } + + + + + + + + private static long[] getStridesForPermutation(int[] dims) { + long[] strides = new long[dims.length]; + long stride = 1; + for( int i = dims.length - 1; i >= 0; i-- ) { + strides[i] = stride; + stride *= dims[i]; + } + return strides; + } + + public static MatrixBlock permute(MatrixBlock in, int[] 
inDims, int[] perm) { + return permute(in, inDims, perm, 1); + } + + public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm, int k) { + int rank = inDims.length; + + boolean isIdentity = true; + for( int i = 0; i < rank; i++ ) { + if( perm[i] != i ) { + isIdentity = false; + break; + } + } + + if( isIdentity ) { + return new MatrixBlock(in); + } + + int[] outDims = new int[rank]; + for( int i = 0; i < rank; i++ ) { + outDims[i] = inDims[perm[i]]; + } + + long length = 1; + for( int d : outDims ) { + length *= d; + } + + MatrixBlock out = new MatrixBlock(1, (int)length, false); + out.allocateDenseBlock(); + + DenseBlock inDB = in.getDenseBlock(); + DenseBlock outDB = out.getDenseBlock(); + + long[] inStrides = getStridesForPermutation(inDims); + long[] outStrides = getStridesForPermutation(outDims); + + long[] permutedStrides = new long[rank]; + for( int i = 0; i < rank; i++ ) { + permutedStrides[i] = outStrides[perm[i]]; + } + + boolean useParallel = (k > 1 || k == -1) && length >= PAR_NUMCELL_THRESHOLD; + int numThreads = k == -1 ? 
Runtime.getRuntime().availableProcessors() : k; + + if( inDB.numBlocks() == 1 && outDB.numBlocks() == 1 ) { + double[] inData = inDB.valuesAt(0); + double[] outData = outDB.valuesAt(0); + + if( useParallel && rank > 0 ) { + permuteSingleBlockParallel(inData, outData, inDims, inStrides, + permutedStrides, numThreads, length); + } else { + permuteSingleBlock(inData, outData, inDims, inStrides, + permutedStrides, 0, 0, 0); + } + } else { + if( useParallel && rank > 0 ) { + permuteMultiBlockParallel(inDB, outDB, inDims, inStrides, + permutedStrides, numThreads, length); + } else { + permuteMultiBlock(inDB, outDB, inDims, inStrides, + permutedStrides, 0, 0L, 0L); + } + } + return out; + } + + private static void permuteSingleBlock( + double[] inData, double[] outData, + int[] inDims, long[] inStrides, long[] permutedStrides, + int dim, int inOffset, int outOffset) { + + if( dim == inDims.length - 1 ) { + int len = inDims[dim]; + int outStride = (int) permutedStrides[dim]; + + if( outStride == 1 ) { + System.arraycopy(inData, inOffset, outData, outOffset, len); + } else { + transposeRow(inData, outData, inOffset, outOffset, outStride, len); + } + return; + } + + int dimSize = inDims[dim]; + long inStep = inStrides[dim]; + long outStep = permutedStrides[dim]; + + final int BLOCK_SIZE = 128; + for( int bi = 0; bi < dimSize; bi += BLOCK_SIZE ) { + int bimin = Math.min(bi + BLOCK_SIZE, dimSize); + for( int i = bi; i < bimin; i++ ) { + permuteSingleBlock( + inData, outData, inDims, inStrides, permutedStrides, + dim + 1, + inOffset + (int)(i * inStep), + outOffset + (int)(i * outStep) + ); + } + } + } + + private static void permuteSingleBlockParallel( + double[] inData, double[] outData, + int[] inDims, long[] inStrides, long[] permutedStrides, + int k, long totalElements) { + + final long elementsPerThread = Math.max(1024, (totalElements + k - 1) / k); + final int actualThreads = (int) Math.min(k, (totalElements + elementsPerThread - 1) / elementsPerThread); + + final 
ExecutorService pool = CommonThreadPool.get(actualThreads); + try { + final ArrayList tasks = new ArrayList<>(); + + for( int t = 0; t < actualThreads; t++ ) { + final long start = t * elementsPerThread; + final long end = Math.min(start + elementsPerThread, totalElements); + + if( start >= totalElements ) { + break; + } + + tasks.add(new PermuteSingleBlockTask(inData, outData, inDims, + inStrides, permutedStrides, start, end)); + } + + for( Future task : pool.invokeAll(tasks) ) { + task.get(); + } + } catch (Exception ex) { + throw new DMLRuntimeException(ex); + } finally { + pool.shutdown(); + } + } + + private static void permuteMultiBlock( + DenseBlock inDB, DenseBlock outDB, + int[] inDims, long[] inStrides, long[] permutedStrides, + int dim, long inOffset, long outOffset) { + + if( dim == inDims.length - 1 ) { + int len = inDims[dim]; + long outStride = permutedStrides[dim]; + + int inBlockSize = inDB.blockSize(); + int outBlockSize = outDB.blockSize(); + + for( int i = 0; i < len; i++ ) { + long currentInAbs = inOffset + i * inStrides[dim]; + long currentOutAbs = outOffset + i * outStride; + + int inBlockIdx = (int) (currentInAbs / inBlockSize); + int inRelIdx = (int) (currentInAbs % inBlockSize); + + int outBlockIdx = (int) (currentOutAbs / outBlockSize); + int outRelIdx = (int) (currentOutAbs % outBlockSize); + + double[] inArr = inDB.valuesAt(inBlockIdx); + double[] outArr = outDB.valuesAt(outBlockIdx); + + if( inArr != null && outArr != null && + inRelIdx < inArr.length && outRelIdx < outArr.length ) { + outArr[outRelIdx] = inArr[inRelIdx]; + } + } + return; + } + + int dimSize = inDims[dim]; + long inStep = inStrides[dim]; + long outStep = permutedStrides[dim]; + + final int BLOCK_SIZE = 128; + for( int bi = 0; bi < dimSize; bi += BLOCK_SIZE ) { + int bimin = Math.min(bi + BLOCK_SIZE, dimSize); + for( int i = bi; i < bimin; i++ ) { + permuteMultiBlock( + inDB, outDB, inDims, inStrides, permutedStrides, + dim + 1, + inOffset + i * inStep, + outOffset + i 
* outStep + ); + } + } + } + + private static void permuteMultiBlockParallel( + DenseBlock inDB, DenseBlock outDB, + int[] inDims, long[] inStrides, long[] permutedStrides, + int k, long totalElements) { + + final long elementsPerThread = Math.max(1024, (totalElements + k - 1) / k); + final int actualThreads = (int) Math.min(k, (totalElements + elementsPerThread - 1) / elementsPerThread); + + final ExecutorService pool = CommonThreadPool.get(actualThreads); + try { + final ArrayList tasks = new ArrayList<>(); + + for( int t = 0; t < actualThreads; t++ ) { + final long start = t * elementsPerThread; + final long end = Math.min(start + elementsPerThread, totalElements); + + if( start >= totalElements ) { + break; + } + + tasks.add(new PermuteMultiBlockTask(inDB, outDB, inDims, + inStrides, permutedStrides, start, end)); + } + + for( Future task : pool.invokeAll(tasks) ) { + task.get(); + } + + } catch (Exception ex) { + throw new DMLRuntimeException(ex); + } finally { + pool.shutdown(); + } + } + + private static class PermuteSingleBlockTask implements Callable { + private final double[] inData; + private final double[] outData; + private final int[] inDims; + private final long[] inStrides; + private final long[] permutedStrides; + private final long start; + private final long end; + + protected PermuteSingleBlockTask(double[] inData, double[] outData, + int[] inDims, long[] inStrides, long[] permutedStrides, + long start, long end) { + this.inData = inData; + this.outData = outData; + this.inDims = inDims; + this.inStrides = inStrides; + this.permutedStrides = permutedStrides; + this.start = start; + this.end = end; + } + + @Override + public Object call() { + for( long idx = start; idx < end; idx++ ) { + long inIdx = 0; + long outIdx = 0; + long remaining = idx; + + for( int d = 0; d < inDims.length; d++ ) { + long coord = remaining / inStrides[d]; + remaining = remaining % inStrides[d]; + inIdx += coord * inStrides[d]; + outIdx += coord * permutedStrides[d]; + } 
+ + outData[(int)outIdx] = inData[(int)inIdx]; + } + return null; + } + } + + private static class PermuteMultiBlockTask implements Callable { + private final DenseBlock inDB; + private final DenseBlock outDB; + private final int[] inDims; + private final long[] inStrides; + private final long[] permutedStrides; + private final long start; + private final long end; + + protected PermuteMultiBlockTask(DenseBlock inDB, DenseBlock outDB, + int[] inDims, long[] inStrides, long[] permutedStrides, + long start, long end) { + this.inDB = inDB; + this.outDB = outDB; + this.inDims = inDims; + this.inStrides = inStrides; + this.permutedStrides = permutedStrides; + this.start = start; + this.end = end; + } + + @Override + public Object call() { + int inBlockSize = inDB.blockSize(); + int outBlockSize = outDB.blockSize(); + + for( long idx = start; idx < end; idx++ ) { + long inIdx = 0; + long outIdx = 0; + long remaining = idx; + + for( int d = 0; d < inDims.length; d++ ) { + long coord = remaining / inStrides[d]; + remaining = remaining % inStrides[d]; + inIdx += coord * inStrides[d]; + outIdx += coord * permutedStrides[d]; + } + + int inBlockIdx = (int) (inIdx / inBlockSize); + int inRelIdx = (int) (inIdx % inBlockSize); + + int outBlockIdx = (int) (outIdx / outBlockSize); + int outRelIdx = (int) (outIdx % outBlockSize); + + double[] inArr = inDB.valuesAt(inBlockIdx); + double[] outArr = outDB.valuesAt(outBlockIdx); + + if( inArr != null && outArr != null && + inRelIdx < inArr.length && outRelIdx < outArr.length ) { + outArr[outRelIdx] = inArr[inRelIdx]; + } + } + return null; + } + } } + diff --git a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java new file mode 100644 index 00000000000..2b09c0ca0ae --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/PermuteTest.java @@ -0,0 +1,418 @@ +package 
org.apache.sysds.test.component.matrix.libMatrixReorg; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.data.LibMatrixReorg; +import org.apache.sysds.runtime.data.DenseBlock; +import org.mockito.Mockito; +import java.util.Arrays; + +public class PermuteTest { + + @Test + public void testBasicPermute() { + int[] shape = {2, 3, 4}; + MatrixBlock tensor = generateMatrixBlock(shape); + + Assert.assertEquals(24, tensor.getNumRows() * tensor.getNumColumns()); + + double[] data = tensor.getDenseBlockValues(); + Assert.assertEquals(23.0, data[1 * 4 * 3 + 2 * 4 + 3], 0.001); + Assert.assertEquals(0.0, data[0 * 4 * 3 + 0 * 4 + 0], 0.001); + + int[] permutation = {1, 0, 2}; + MatrixBlock outTensor = LibMatrixReorg.permute(tensor, shape, permutation); + + double[] outData = outTensor.getDenseBlockValues(); + Assert.assertEquals(24, outData.length); + Assert.assertEquals(4.0, outData[8], 0.001); + Assert.assertEquals(15.0, outData[7], 0.001); + } + + @Test + public void testPermute2D_Transpose() { + int[] shape = {10, 5}; + int[] perm = {1, 0}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testPermute3D_Simple() { + int[] shape = {2, 3, 4}; + int[] perm = {1, 0, 2}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testPermute3D_Identity() { + int[] shape = {5, 5, 5}; + int[] perm = {0, 1, 2}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testPermute4D_Reverse() { + int[] shape = {2, 3, 4, 5}; + int[] perm = {3, 2, 1, 0}; + + MatrixBlock in = generateMatrixBlock(shape); + 
MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testPermuteHighRank() { + int[] shape = {2, 2, 2, 2, 2, 2}; + int[] perm = {5, 0, 4, 1, 3, 2}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testLargeBlockLogic_Mocked() { + int[] shape = {10, 10, 10}; + int[] perm = {2, 0, 1}; + + MatrixBlock in = generateMatrixBlock(shape); + DenseBlock originalDB = in.getDenseBlock(); + DenseBlock spyDB = Mockito.spy(originalDB); + Mockito.when(spyDB.numBlocks()).thenReturn(2); + in.setDenseBlock(spyDB); + + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + MatrixBlock originalIn = generateMatrixBlock(shape); + verifyPermutation(originalIn, out, shape, perm); + } + + @Test + public void testLargeBlockLogic_Mocked_InputAndOutput() { + int[] shape = {4, 4, 4}; + int[] perm = {2, 1, 0}; + + MatrixBlock in = generateMatrixBlock(shape); + DenseBlock spyIn = Mockito.spy(in.getDenseBlock()); + Mockito.when(spyIn.numBlocks()).thenReturn(5); + in.setDenseBlock(spyIn); + + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + MatrixBlock originalIn = generateMatrixBlock(shape); + verifyPermutation(originalIn, out, shape, perm); + } + + @Test + public void testPermute3D_Parallel() { + int[] shape = {100, 100, 100}; + int[] perm = {2, 0, 1}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm, -1); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testPerformance_SingleVsMultiThreaded() { + int size = 100; + int[] shape = {size, size, size}; + int[] perm = {2, 0, 1}; + + MatrixBlock in = generateMatrixBlock(shape); + + long startSingle = System.nanoTime(); + MatrixBlock outSingle = LibMatrixReorg.permute(in, shape, perm, 1); + long timeSingle = System.nanoTime() - startSingle; 
+ + long startMulti = System.nanoTime(); + MatrixBlock outMulti = LibMatrixReorg.permute(in, shape, perm, -1); + long timeMulti = System.nanoTime() - startMulti; + + verifyPermutation(in, outSingle, shape, perm); + verifyPermutation(in, outMulti, shape, perm); + + System.out.println("Large Matrix (" + size + "x" + size + "x" + size + "):"); + System.out.println("Single-threaded: " + timeSingle / 1_000_000 + " ms"); + System.out.println("Multi-threaded: " + timeMulti / 1_000_000 + " ms"); + System.out.println("Speedup: " + String.format("%.2fx", (double)timeSingle / timeMulti)); + + Assert.assertTrue("Multi-threaded should be faster for large matrices", timeMulti < timeSingle); + } + + @Test + public void testPerformance_LargeMatrix_SingleVsMulti() { + int[] shape = {1, 10000, 10000}; + int[] perm = {0, 2, 1}; + + MatrixBlock in = generateMatrixBlock(shape); + + long startSingle = System.nanoTime(); + MatrixBlock outSingle = LibMatrixReorg.permute(in, shape, perm, 1); + long timeSingle = System.nanoTime() - startSingle; + + long startMulti = System.nanoTime(); + MatrixBlock outMulti = LibMatrixReorg.permute(in, shape, perm, -1); + long timeMulti = System.nanoTime() - startMulti; + + System.out.println("Large Matrix (" + 1 + "x" + 10000 + "x" + 100000 + "):"); + System.out.println("Single-threaded: " + timeSingle / 1_000_000 + " ms"); + System.out.println("Multi-threaded: " + timeMulti / 1_000_000 + " ms"); + System.out.println("Speedup: " + String.format("%.2fx", (double)timeSingle / timeMulti)); + + Assert.assertTrue("Multi-threaded should be faster for large matrices", timeMulti < timeSingle); + } + + @Test + public void testPerformance_PermuteVsNativeTranspose() { + int size = 1000; + MatrixBlock in = new MatrixBlock(size, size, false); + in.allocateDenseBlock(); + double[] data = in.getDenseBlockValues(); + for (int i = 0; i < size; i++) { + for (int j = 0; j < size; j++) { + data[i * size + j] = i * size + j; + } + } + + int[] shape = {size, size}; + int[] perm 
= {1, 0}; + + long startPermute = System.nanoTime(); + MatrixBlock outPermute = LibMatrixReorg.permute(in, shape, perm, -1); + long timePermute = System.nanoTime() - startPermute; + + long startTranspose = System.nanoTime(); + MatrixBlock outTranspose = LibMatrixReorg.transpose(in); + long timeTranspose = System.nanoTime() - startTranspose; + + System.out.println("Transpose Performance (" + size + "x" + size + "):"); + System.out.println("Permute function: " + timePermute / 1_000_000 + " ms"); + System.out.println("Native transpose: " + timeTranspose / 1_000_000 + " ms"); + System.out.println("Ratio: " + String.format("%.2fx", (double)timePermute / timeTranspose)); + + double[] permuteData = outPermute.getDenseBlockValues(); + + for (int i = 0; i < size; i++) { + for (int j = 0; j < size; j++) { + double expected = in.get(j, i); + double actual = permuteData[i * size + j]; + Assert.assertEquals("Mismatch at (" + i + "," + j + ")", expected, actual, 0.0001); + } + } + } + + @Test + public void testEdgeCase_SingleElement() { + int[] shape = {1, 1, 1}; + int[] perm = {2, 1, 0}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testEdgeCase_OneDimensionOne() { + int[] shape = {5, 1, 10}; + int[] perm = {2, 0, 1}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testEdgeCase_TwoDimensionsOne() { + int[] shape = {1, 1, 100}; + int[] perm = {2, 1, 0}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testConsecutivePermutations() { + int[] shape = {3, 4, 5}; + int[] perm1 = {1, 0, 2}; + int[] perm2 = {2, 0, 1}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock temp = 
LibMatrixReorg.permute(in, shape, perm1); + + int[] tempShape = {shape[perm1[0]], shape[perm1[1]], shape[perm1[2]]}; + MatrixBlock out = LibMatrixReorg.permute(temp, tempShape, perm2); + + int[] finalShape = {tempShape[perm2[0]], tempShape[perm2[1]], tempShape[perm2[2]]}; + + verifyPermutation(temp, out, tempShape, perm2); + } + + @Test + public void testDifferentThreadCounts() { + int[] shape = {50, 50, 50}; + int[] perm = {2, 0, 1}; + + MatrixBlock in = generateMatrixBlock(shape); + + MatrixBlock out1 = LibMatrixReorg.permute(in, shape, perm, 1); + MatrixBlock out2 = LibMatrixReorg.permute(in, shape, perm, 2); + MatrixBlock out4 = LibMatrixReorg.permute(in, shape, perm, 4); + MatrixBlock out8 = LibMatrixReorg.permute(in, shape, perm, 8); + + double[] data1 = out1.getDenseBlockValues(); + double[] data2 = out2.getDenseBlockValues(); + double[] data4 = out4.getDenseBlockValues(); + double[] data8 = out8.getDenseBlockValues(); + + for (int i = 0; i < data1.length; i++) { + Assert.assertEquals(data1[i], data2[i], 0.0001); + Assert.assertEquals(data1[i], data4[i], 0.0001); + Assert.assertEquals(data1[i], data8[i], 0.0001); + } + } + + @Test + public void testPermute_AllDimensionsCyclic() { + int[] shape = {3, 4, 5, 2}; + int[] perm = {1, 2, 3, 0}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testPermute_NonContiguousStrides() { + int[] shape = {7, 11, 13}; + int[] perm = {2, 0, 1}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + @Test + public void testPermute_LargePrimeStrides() { + int[] shape = {17, 19}; + int[] perm = {1, 0}; + + MatrixBlock in = generateMatrixBlock(shape); + MatrixBlock out = LibMatrixReorg.permute(in, shape, perm); + + verifyPermutation(in, out, shape, perm); + } + + private MatrixBlock 
generateMatrixBlock(int[] shape) { + long len = 1; + for (int d : shape) len *= d; + + MatrixBlock mb = new MatrixBlock(1, (int)len, false); + mb.allocateDenseBlock(); + double[] data = mb.getDenseBlockValues(); + for (int i = 0; i < data.length; i++) { + data[i] = (double) i; + } + return mb; + } + + private void verifyPermutation(MatrixBlock in, MatrixBlock out, int[] inShape, int[] perm) { + double[] inData = new double[(int)(in.getNumRows() * in.getNumColumns())]; + double[] outData = new double[(int)(out.getNumRows() * out.getNumColumns())]; + + DenseBlock inDB = in.getDenseBlock(); + DenseBlock outDB = out.getDenseBlock(); + + if (inDB != null) { + int inBlockSize = inDB.blockSize(); + for (int i = 0; i < inDB.numBlocks(); i++) { + double[] block = inDB.valuesAt(i); + int offset = i * inBlockSize; + int len = Math.min(inBlockSize, inData.length - offset); + System.arraycopy(block, 0, inData, offset, len); + } + } + + if (outDB != null) { + int outBlockSize = outDB.blockSize(); + for (int i = 0; i < outDB.numBlocks(); i++) { + double[] block = outDB.valuesAt(i); + int offset = i * outBlockSize; + int len = Math.min(outBlockSize, outData.length - offset); + System.arraycopy(block, 0, outData, offset, len); + } + } + + int rank = inShape.length; + int[] outShape = new int[rank]; + for (int i = 0; i < rank; i++) + outShape[i] = inShape[perm[i]]; + + long[] outStrides = getStrides(outShape); + long[] inStrides = getStrides(inShape); + + long len = 1; + for (int d : outShape) len *= d; + + for (long i = 0; i < len; i++) { + int[] outCoords = new int[rank]; + long temp = i; + for (int d = 0; d < rank; d++) { + outCoords[d] = (int)(temp / outStrides[d]); + temp = temp % outStrides[d]; + } + + int[] inCoords = new int[rank]; + for (int d = 0; d < rank; d++) { + inCoords[perm[d]] = outCoords[d]; + } + + long inIndex = 0; + for (int d = 0; d < rank; d++) { + inIndex += inCoords[d] * inStrides[d]; + } + + double expectedValue = inData[(int)inIndex]; + double actualValue 
= outData[(int)i]; + + if (Math.abs(expectedValue - actualValue) > 0.0001) { + Assert.fail("Mismatch at linear output index " + i + + ". Output coords " + Arrays.toString(outCoords) + + ". Input coords " + Arrays.toString(inCoords) + + ". Expected " + expectedValue + " but got " + actualValue); + } + } + } + + private long[] getStrides(int[] dims) { + long[] strides = new long[dims.length]; + long stride = 1; + for (int i = dims.length - 1; i >= 0; i--) { + strides[i] = stride; + stride *= dims[i]; + } + return strides; + } +} \ No newline at end of file diff --git a/src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java b/src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java deleted file mode 100644 index 384c6e57c0e..00000000000 --- a/src/test/java/org/apache/sysds/test/component/tensor/PermuteTest.java +++ /dev/null @@ -1,577 +0,0 @@ -package org.apache.sysds.test.component.tensor; - -import org.junit.Assert; -import org.junit.Test; -import org.apache.sysds.runtime.matrix.data.MatrixBlock; -import org.apache.sysds.runtime.data.DenseBlock; -//import org.apache.sysds.runtime.data.DenseBlockFactory; -import org.mockito.Mockito; -import java.util.Arrays; -import java.util.concurrent.ExecutorService; -import org.apache.sysds.runtime.util.CommonThreadPool; -import java.util.concurrent.Future; -import java.util.ArrayList; - - -public class PermuteTest { - - @Test - public void TestMatrixBlockPermute() { - - int[] shape = {2, 3, 4}; - - MatrixBlock tensor = TensorUtils.createArangeMatrixBlock(shape); - Assert.assertEquals(24, tensor.getNumRows() * tensor.getNumColumns()); - - double[] data = tensor.getDenseBlockValues(); - Assert.assertEquals(23.0, data[1 * 4 * 3 + 2 * 4 + 3], 0.001); - Assert.assertEquals( 0.0, data[0 * 4 * 3 + 0 * 4 + 0], 0.001); - - TensorUtils.printMatrixTensor(tensor, shape); - - int[] permutation = {1, 0, 2}; - - MatrixBlock outTensor = PermuteIt.permute(tensor, shape, permutation); - int[] outShape = {3, 2, 4}; - - 
TensorUtils.printMatrixTensor(outTensor, outShape); - - double[] outData = outTensor.getDenseBlockValues(); - Assert.assertEquals(24, 1 * outTensor.getNumColumns()); - Assert.assertEquals(24, outData.length); - Assert.assertEquals(4.0, outData[8], 0.001); - Assert.assertEquals(15.0, outData[7], 0.001); - } - - @Test - public void testPermute2D_Transpose() { - int[] shape = {10, 5}; - int[] perm = {1, 0}; - - MatrixBlock in = generateMatrixBlock(shape); - MatrixBlock out = PermuteIt.permute(in, shape, perm); - - verifyPermutation(in, out, shape, perm); - } - - @Test - public void testPermute3D_Simple() { - int[] shape = {2, 3, 4}; - int[] perm = {1, 0, 2}; - - MatrixBlock in = generateMatrixBlock(shape); - MatrixBlock out = PermuteIt.permute(in, shape, perm); - - verifyPermutation(in, out, shape, perm); - } - - @Test - public void testPermute3D_Identity() { - int[] shape = {5, 5, 5}; - int[] perm = {0, 1, 2}; - - MatrixBlock in = generateMatrixBlock(shape); - MatrixBlock out = PermuteIt.permute(in, shape, perm); - - verifyPermutation(in, out, shape, perm); - } - - @Test - public void testPermute4D_Reverse() { - int[] shape = {2, 3, 4, 5}; - int[] perm = {3, 2, 1, 0}; - - MatrixBlock in = generateMatrixBlock(shape); - MatrixBlock out = PermuteIt.permute(in, shape, perm); - - verifyPermutation(in, out, shape, perm); - } - - @Test - public void testPermuteHighRank() { - int[] shape = {2, 2, 2, 2, 2, 2}; - int[] perm = {5, 0, 4, 1, 3, 2}; - - MatrixBlock in = generateMatrixBlock(shape); - MatrixBlock out = PermuteIt.permute(in, shape, perm); - - verifyPermutation(in, out, shape, perm); - } - - - @Test - public void testLargeBlockLogic_Mocked() { - int[] shape = {10, 10, 10}; - int[] perm = {2, 0, 1}; - - MatrixBlock in = generateMatrixBlock(shape); - - DenseBlock originalDB = in.getDenseBlock(); - DenseBlock spyDB = Mockito.spy(originalDB); - Mockito.when(spyDB.numBlocks()).thenReturn(2); - - in.setDenseBlock(spyDB); - - MatrixBlock out = PermuteIt.permute(in, shape, 
perm); - - MatrixBlock originalIn = generateMatrixBlock(shape); - verifyPermutation(originalIn, out, shape, perm); - } - - @Test - public void testLargeBlockLogic_Mocked_InputAndOutput() { - int[] shape = {4, 4, 4}; - int[] perm = {2, 1, 0}; - - MatrixBlock in = generateMatrixBlock(shape); - DenseBlock spyIn = Mockito.spy(in.getDenseBlock()); - Mockito.when(spyIn.numBlocks()).thenReturn(5); - in.setDenseBlock(spyIn); - - MatrixBlock out = PermuteIt.permute(in, shape, perm); - - MatrixBlock originalIn = generateMatrixBlock(shape); - verifyPermutation(originalIn, out, shape, perm); - } - - @Test - public void testPermute3D_Parallel() { - int[] shape = {100, 100, 100}; - int[] perm = {2, 0, 1}; - - MatrixBlock in = generateMatrixBlock(shape); - MatrixBlock out = PermuteIt.permute(in, shape, perm, -1); - - verifyPermutation(in, out, shape, perm); - } - - - private MatrixBlock generateMatrixBlock(int[] shape) { - long len = 1; - for (int d : shape) len *= d; - - MatrixBlock mb = new MatrixBlock(1, (int)len, false); - mb.allocateDenseBlock(); - double[] data = mb.getDenseBlockValues(); - for(int i = 0; i < data.length; i++) { - data[i] = (double)i; - } - return mb; - } - - - private void verifyPermutation(MatrixBlock in, MatrixBlock out, int[] inShape, int[] perm) { - - double[] inData = new double[(int)(in.getNumRows() * in.getNumColumns())]; - double[] outData = new double[(int)(out.getNumRows() * out.getNumColumns())]; - - DenseBlock inDB = in.getDenseBlock(); - DenseBlock outDB = out.getDenseBlock(); - - if (inDB != null) { - int inBlockSize = inDB.blockSize(); - for (int i = 0; i < inDB.numBlocks(); i++) { - double[] block = inDB.valuesAt(i); - int offset = i * inBlockSize; - int len = Math.min(inBlockSize, inData.length - offset); - System.arraycopy(block, 0, inData, offset, len); - } - } - - if (outDB != null) { - int outBlockSize = outDB.blockSize(); - for (int i = 0; i < outDB.numBlocks(); i++) { - double[] block = outDB.valuesAt(i); - int offset = i * 
outBlockSize; - int len = Math.min(outBlockSize, outData.length - offset); - System.arraycopy(block, 0, outData, offset, len); - } - } - - int rank = inShape.length; - int[] outShape = new int[rank]; - for(int i=0; i 0.0001) { - Assert.fail("Mismatch at linear output index " + i + - ". Output coords " + Arrays.toString(currentCoords) + - ". Input coords " + Arrays.toString(inCoords) + - ". Expected " + expectedValue + " but got " + actualValue); - } - } - } - - private long[] getStrides(int[] dims) { - long[] strides = new long[dims.length]; - long stride = 1; - for (int i = dims.length - 1; i >= 0; i--) { - strides[i] = stride; - stride *= dims[i]; - } - return strides; - } - - - public static class TensorUtils { - - public static MatrixBlock createArangeMatrixBlock(int[] shape) { - long length = 1; - for (int d : shape) length *= d; - - MatrixBlock mb = new MatrixBlock(1, (int)length, false); - mb.allocateDenseBlock(); - - double[] data = mb.getDenseBlockValues(); - for (int i = 0; i < data.length; i++) { - data[i] = (double) i; - } - return mb; - } - - public static void printMatrixTensor(MatrixBlock mb, int[] shape) { - double[] data = mb.getDenseBlockValues(); - StringBuilder sb = new StringBuilder(); - sb.append("MatrixBlock-Tensor(").append(Arrays.toString(shape)).append("):\n"); - printRecursive(data, shape, 0, 0, sb, 0); - System.out.println(sb.toString()); - } - - private static void printRecursive(double[] data, int[] shape, int dim, int offset, StringBuilder sb, int indent) { - int stride = 1; - for (int i = dim + 1; i < shape.length; i++) stride *= shape[i]; - - for (int k = 0; k < indent; k++) sb.append(" "); - sb.append("["); - - if (dim == shape.length - 1) { - for (int i = 0; i < shape[dim]; i++) { - sb.append(String.format("%.1f", data[offset + i])); - if (i < shape[dim] - 1) sb.append(", "); - } - sb.append("]"); - } else { - sb.append("\n"); - for (int i = 0; i < shape[dim]; i++) { - printRecursive(data, shape, dim + 1, offset + i * stride, sb, 
indent + 2);
                    if (i < shape[dim] - 1) {
                        sb.append(",");
                        sb.append("\n");
                        if (shape.length - dim > 2) sb.append("\n");
                    }
                }
                sb.append("\n");
                for (int k = 0; k < indent; k++) sb.append(" ");
                sb.append("]");
            }
        }
    }


    /**
     * Axis permutation for dense tensors whose cells are stored linearized in
     * row-major order (last axis has stride 1) inside a {@link MatrixBlock}.
     *
     * <p>Convention: output axis {@code i} is input axis {@code perm[i]}, i.e.
     * {@code outDims[i] == inDims[perm[i]]}.
     */
    public static class PermuteIt {

        // blocking according to typical L2 cache sizes
        private static final int BLOCK_SIZE = 128;
        // minimum number of output cells before the multi-threaded path is taken
        private static final int PAR_NUMCELL_THRESHOLD = 1024; //1024*1024

        /**
         * Scatters {@code len} consecutive values {@code a[aix .. aix+len-1]} into
         * {@code c} at positions {@code cix, cix+n2, cix+2*n2, ...}.
         * Taken from LibMatrixReorg.
         *
         * @param a   source array
         * @param c   destination array
         * @param aix first source index
         * @param cix first destination index
         * @param n2  distance between two consecutive destination positions
         * @param len number of values to copy
         */
        static void transposeRow(double[] a, double[] c, int aix, int cix, int n2, int len) {
            // copy the (len % 8) leftover values first so the main loop can be unrolled by 8
            final int bn = len % 8;
            for (int j = 0; j < bn; j++, aix++, cix += n2)
                c[cix] = a[aix];
            for (int j = bn; j < len; j += 8, aix += 8, cix += 8 * n2) {
                c[cix + 0 * n2] = a[aix + 0];
                c[cix + 1 * n2] = a[aix + 1];
                c[cix + 2 * n2] = a[aix + 2];
                c[cix + 3 * n2] = a[aix + 3];
                c[cix + 4 * n2] = a[aix + 4];
                c[cix + 5 * n2] = a[aix + 5];
                c[cix + 6 * n2] = a[aix + 6];
                c[cix + 7 * n2] = a[aix + 7];
            }
        }

        /**
         * Computes row-major strides: the last axis has stride 1 and every other
         * axis has the product of all later extents.
         */
        private static long[] getStrides(int[] dims) {
            long[] strides = new long[dims.length];
            long stride = 1;
            for (int i = dims.length - 1; i >= 0; i--) {
                strides[i] = stride;
                stride *= dims[i];
            }
            return strides;
        }

        /**
         * Rejects anything that is not a permutation of {@code 0 .. rank-1}.
         *
         * @throws IllegalArgumentException on wrong length, out-of-range or duplicate axes
         */
        private static void checkPermutation(int[] perm, int rank) {
            if (perm.length != rank) {
                throw new IllegalArgumentException("Permutation length " + perm.length
                    + " does not match tensor rank " + rank);
            }
            boolean[] seen = new boolean[rank];
            for (int axis : perm) {
                if (axis < 0 || axis >= rank || seen[axis]) {
                    throw new IllegalArgumentException("Not a permutation of the axes 0.."
                        + (rank - 1) + ": " + Arrays.toString(perm));
                }
                seen[axis] = true;
            }
        }

        /**
         * Returns the number of cells described by {@code dims}.
         *
         * @throws IllegalArgumentException if an extent is negative or the product
         *         does not fit into an {@code int}
         */
        private static int cellCount(int[] dims) {
            for (int d : dims) {
                if (d < 0)
                    throw new IllegalArgumentException("Negative extent in " + Arrays.toString(dims));
                if (d == 0)
                    return 0;
            }
            long cells = 1;
            for (int d : dims) {
                // cells <= Integer.MAX_VALUE before the multiply, so the long cannot overflow
                cells *= d;
                if (cells > Integer.MAX_VALUE) {
                    throw new IllegalArgumentException("Tensor with shape " + Arrays.toString(dims)
                        + " exceeds the maximum number of cells (" + Integer.MAX_VALUE + ")");
                }
            }
            return (int) cells;
        }

        /**
         * Single-threaded variant of {@link #permute(MatrixBlock, int[], int[], int)}.
         */
        public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm) {
            return permute(in, inDims, perm, 1);
        }

        /**
         * Permutes the axes of a linearized tensor.
         *
         * @param in     input block holding the tensor cells
         * @param inDims extents of the input tensor
         * @param perm   permutation of {@code 0 .. inDims.length-1}; output axis
         *               {@code i} is input axis {@code perm[i]}
         * @param k      degree of parallelism: {@code -1} uses
         *               {@code Runtime.availableProcessors()}, values greater than 1 use
         *               that many threads, everything else runs single-threaded; the
         *               parallel path is only taken for at least
         *               {@code PAR_NUMCELL_THRESHOLD} cells
         * @return for the identity permutation {@code new MatrixBlock(in)}, otherwise a
         *         newly allocated dense block with one row and as many columns as cells
         * @throws IllegalArgumentException if {@code perm} is not a valid permutation,
         *         the shape is invalid/too large, or {@code in} has no dense block
         */
        public static MatrixBlock permute(MatrixBlock in, int[] inDims, int[] perm, int k) {
            final int rank = inDims.length;
            checkPermutation(perm, rank);

            // Early opt out
            boolean isIdentity = true;
            for (int i = 0; i < rank; i++) {
                if (perm[i] != i) {
                    isIdentity = false;
                    break;
                }
            }
            if (isIdentity) {
                return new MatrixBlock(in);
            }

            final int[] outDims = new int[rank];
            for (int i = 0; i < rank; i++)
                outDims[i] = inDims[perm[i]];

            final int length = cellCount(outDims);

            MatrixBlock out = new MatrixBlock(1, length, false);
            out.allocateDenseBlock();

            DenseBlock inDB = in.getDenseBlock();
            DenseBlock outDB = out.getDenseBlock();
            if (inDB == null) {
                throw new IllegalArgumentException("Input MatrixBlock has no allocated dense block");
            }

            long[] inStrides = getStrides(inDims);
            long[] outStrides = getStrides(outDims);

            // permutedStrides[a] = output distance covered by one step along INPUT axis a.
            // Output axis i enumerates input axis perm[i], hence the inverse indexing.
            // (Indexing with outStrides[perm[i]] is only right if perm is its own inverse.)
            long[] permutedStrides = new long[rank];
            for (int i = 0; i < rank; i++) {
                permutedStrides[perm[i]] = outStrides[i];
            }

            boolean useParallel = (k > 1 || k == -1) && length >= PAR_NUMCELL_THRESHOLD;
            int numThreads = k == -1 ? Runtime.getRuntime().availableProcessors() : k;

            // rank > 0 is guaranteed here: a rank-0 permutation is always the identity
            if (inDB.numBlocks() == 1 && outDB.numBlocks() == 1) {
                double[] inData = inDB.valuesAt(0);
                double[] outData = outDB.valuesAt(0);

                if (useParallel) {
                    parallelPermuteSingleBlock(inData, outData, inDims, inStrides,
                        permutedStrides, numThreads);
                } else {
                    recursivePermuteSingleBlock(inData, outData, inDims, inStrides,
                        permutedStrides, 0, 0, 0);
                }
            } else {
                if (useParallel) {
                    parallelPermuteMultiBlock(inDB, outDB, inDims, inStrides,
                        permutedStrides, numThreads);
                } else {
                    recursivePermuteMultiBlock(inDB, outDB, inDims, inStrides,
                        permutedStrides, 0, 0L, 0L);
                }
            }
            return out;
        }

        /**
         * Recursively walks the input axes starting at {@code dim}; the innermost axis
         * is copied as one contiguous run of the input.
         *
         * @param inOffset  linear input offset accumulated over the axes before {@code dim}
         * @param outOffset linear output offset accumulated over the axes before {@code dim}
         */
        private static void recursivePermuteSingleBlock(
                double[] inData, double[] outData,
                int[] inDims, long[] inStrides, long[] permutedStrides,
                int dim, int inOffset, int outOffset) {

            if (dim == inDims.length - 1) {
                int len = inDims[dim];
                int outStride = (int) permutedStrides[dim];

                if (outStride == 1) {
                    // innermost axis stays innermost: plain contiguous copy
                    System.arraycopy(inData, inOffset, outData, outOffset, len);
                } else {
                    transposeRow(inData, outData, inOffset, outOffset, outStride, len);
                }
                return;
            }

            int dimSize = inDims[dim];
            long inStep = inStrides[dim];
            long outStep = permutedStrides[dim];

            //blocked execution
            for (int bi = 0; bi < dimSize; bi += BLOCK_SIZE) {
                int bimin = Math.min(bi + BLOCK_SIZE, dimSize);
                for (int i = bi; i < bimin; i++) {
                    recursivePermuteSingleBlock(
                        inData, outData, inDims, inStrides, permutedStrides,
                        dim + 1,
                        inOffset + (int) (i * inStep),
                        outOffset + (int) (i * outStep));
                }
            }
        }

        /**
         * Splits the index range {@code [0, dimSize)} of the outermost axis into at most
         * {@code k} contiguous chunks, runs {@code sliceTask} for every index of a chunk
         * on a pool obtained from {@code CommonThreadPool.get(k)}, and waits for all chunks.
         * The last chunk absorbs the remainder of the division.
         *
         * @throws RuntimeException wrapping any failure of a task or of the wait
         */
        private static void runSlicesInParallel(int dimSize, int k,
                java.util.function.IntConsumer sliceTask) {
            final int slicesPerTask = Math.max(1, dimSize / k);

            // Set up thread pool
            final ExecutorService pool = CommonThreadPool.get(k);
            try {
                final List<Future<?>> tasks = new ArrayList<>();

                for (int t = 0; t < k; t++) {
                    final int start = t * slicesPerTask;
                    if (start >= dimSize) break;
                    final int end = (t == k - 1) ? dimSize : (t + 1) * slicesPerTask;

                    tasks.add(pool.submit(() -> {
                        for (int i = start; i < end; i++) {
                            sliceTask.accept(i);
                        }
                    }));
                }

                // Wait for all threads
                for (Future<?> task : tasks) {
                    task.get();
                }
            } catch (InterruptedException ex) {
                // keep the interrupt visible to the caller's thread
                Thread.currentThread().interrupt();
                throw new RuntimeException(ex);
            } catch (Exception ex) {
                throw new RuntimeException(ex);
            } finally {
                pool.shutdown();
            }
        }

        /**
         * Parallel single-block permutation: the outermost input axis is distributed
         * over {@code k} tasks, each slice is handled by
         * {@code recursivePermuteSingleBlock} starting at axis 1.
         */
        private static void parallelPermuteSingleBlock(
                double[] inData, double[] outData,
                int[] inDims, long[] inStrides, long[] permutedStrides,
                int k) {
            runSlicesInParallel(inDims[0], k, i ->
                recursivePermuteSingleBlock(
                    inData, outData, inDims, inStrides, permutedStrides,
                    1,
                    (int) (i * inStrides[0]),
                    (int) (i * permutedStrides[0])));
        }

        /**
         * Cell-by-cell variant used when input or output report more than one block.
         * Offsets are kept as {@code long} and translated into (block, position) pairs.
         */
        private static void recursivePermuteMultiBlock(
                DenseBlock inDB, DenseBlock outDB,
                int[] inDims, long[] inStrides, long[] permutedStrides,
                int dim, long inOffset, long outOffset) {

            if (dim == inDims.length - 1) {
                int len = inDims[dim];
                long outStride = permutedStrides[dim];

                // NOTE(review): the offsets below are flat cell offsets, so blockSize() is
                // used as "cells per block" here. The SystemDS DenseBlock API appears to
                // define blockSize() as rows per block - confirm this is the intended unit.
                int inBlockSize = inDB.blockSize();
                int outBlockSize = outDB.blockSize();

                for (int i = 0; i < len; i++) {
                    long currentInAbs = inOffset + i * inStrides[dim];
                    long currentOutAbs = outOffset + i * outStride;

                    int inBlockIdx = (int) (currentInAbs / inBlockSize);
                    int inRelIdx = (int) (currentInAbs % inBlockSize);

                    int outBlockIdx = (int) (currentOutAbs / outBlockSize);
                    int outRelIdx = (int) (currentOutAbs % outBlockSize);

                    double[] inArr = inDB.valuesAt(inBlockIdx);
                    double[] outArr = outDB.valuesAt(outBlockIdx);

                    // NOTE(review): cells whose computed position falls outside the block
                    // arrays are skipped silently, which can hide addressing errors.
                    if (inArr != null && outArr != null
                        && inRelIdx < inArr.length && outRelIdx < outArr.length) {
                        outArr[outRelIdx] = inArr[inRelIdx];
                    }
                }
                return;
            }

            int dimSize = inDims[dim];
            long inStep = inStrides[dim];
            long outStep = permutedStrides[dim];

            //blocked execution
            for (int bi = 0; bi < dimSize; bi += BLOCK_SIZE) {
                int bimin = Math.min(bi + BLOCK_SIZE, dimSize);
                for (int i = bi; i < bimin; i++) {
                    recursivePermuteMultiBlock(
                        inDB, outDB, inDims, inStrides, permutedStrides,
                        dim + 1,
                        inOffset + i * inStep,
                        outOffset + i * outStep);
                }
            }
        }

        /**
         * Parallel multi-block permutation: the outermost input axis is distributed
         * over {@code k} tasks, each slice is handled by
         * {@code recursivePermuteMultiBlock} starting at axis 1.
         */
        private static void parallelPermuteMultiBlock(
                DenseBlock inDB, DenseBlock outDB,
                int[] inDims, long[] inStrides, long[] permutedStrides,
                int k) {
            runSlicesInParallel(inDims[0], k, i ->
                recursivePermuteMultiBlock(
                    inDB, outDB, inDims, inStrides, permutedStrides,
                    1,
                    i * inStrides[0],
                    i * permutedStrides[0]));
        }
    }
}


// Open questions:
//  - callable
//  - naming
//  - class vs. function
//  - where to insert this code
//  - how exactly to submit it
//  - regarding PhD