diff --git a/dataframe.cabal b/dataframe.cabal index c99a60a..f7f2a2f 100644 --- a/dataframe.cabal +++ b/dataframe.cabal @@ -106,7 +106,6 @@ library process ^>= 1.6, snappy-hs ^>= 0.1, random >= 1.2 && < 1.3, - random-shuffle >= 0.0.4 && < 1, regex-tdfa >= 1.3.0 && < 2, scientific >=0.3.1 && <0.4, template-haskell >= 2.0 && < 3, diff --git a/src/DataFrame/Operations/Permutation.hs b/src/DataFrame/Operations/Permutation.hs index a381f98..e799bfd 100644 --- a/src/DataFrame/Operations/Permutation.hs +++ b/src/DataFrame/Operations/Permutation.hs @@ -9,16 +9,18 @@ import qualified Data.List as L import qualified Data.Text as T import qualified Data.Vector as V import qualified Data.Vector.Unboxed as VU +import qualified Data.Vector.Unboxed.Mutable as VUM import Control.Exception (throw) +import Control.Monad.ST (runST) +import Data.Vector.Internal.Check (HasCallStack) import DataFrame.Errors (DataFrameException (..)) -import DataFrame.Internal.Column +import DataFrame.Internal.Column (Columnable, atIndicesStable) import DataFrame.Internal.DataFrame (DataFrame (..)) -import DataFrame.Internal.Expression -import DataFrame.Internal.Row -import DataFrame.Operations.Core -import System.Random -import System.Random.Shuffle (shuffle') +import DataFrame.Internal.Expression (Expr (Col)) +import DataFrame.Internal.Row (sortedIndexes', toRowVector) +import DataFrame.Operations.Core (columnNames, dimensions) +import System.Random (Random (randomR), RandomGen) -- | Sort order taken as a parameter by the 'sortBy' function. data SortOrder where @@ -75,5 +77,23 @@ shuffle pureGen df = in df{columns = V.map (atIndicesStable indexes) (columns df)} -shuffledIndices :: (RandomGen g) => g -> Int -> VU.Vector Int -shuffledIndices pureGen k = VU.fromList (shuffle' [0 .. 
(k - 1)] k pureGen)
+-- | Fisher-Yates shuffle of the indices @[0 .. k - 1]@; calls 'error' when @k < 0@.
+shuffledIndices :: (HasCallStack, RandomGen g) => g -> Int -> VU.Vector Int
+shuffledIndices pureGen k
+    | k < 0 = error $ "Vector index may not be a negative number: " <> show k
+    | k == 0 = VU.empty
+    | otherwise = shuffleVec pureGen
+  where
+    shuffleVec :: (RandomGen g) => g -> VU.Vector Int
+    shuffleVec g = runST $ do
+        vm <- VUM.generate k id
+        go vm (k - 1) g
+        VU.unsafeFreeze vm
+
+    -- @go v maxInd gen@ shuffles the slice @v@, whose last valid index is @maxInd@.
+    go v maxInd gen
+        | maxInd <= 0 = pure () -- zero or one element left: nothing to swap
+        | otherwise =
+            -- Pick from the whole slice (0 included) so an element may stay put.
+            let (n, nextGen) = randomR (0, maxInd) gen
+             in VUM.swap v 0 n *> go (VUM.tail v) (maxInd - 1) nextGen
diff --git a/tests/Operations/Shuffle.hs b/tests/Operations/Shuffle.hs
index f1c52f4..9b609c3 100644
--- a/tests/Operations/Shuffle.hs
+++ b/tests/Operations/Shuffle.hs
@@ -5,7 +5,9 @@
 module Operations.Shuffle where
 
 import qualified DataFrame as D
-import DataFrame.Operations.Permutation (shuffle)
+import qualified Data.Set as Set
+import qualified Data.Vector.Unboxed as VU
+import DataFrame.Operations.Permutation (shuffle, shuffledIndices)
 import System.Random (mkStdGen)
 import Test.HUnit (Test (..), assertEqual)
 
@@ -74,6 +76,21 @@ shuffleDifferentSeedIsDifferent =
             (shuffled1 == shuffled2)
         )
 
+-- Test that shuffledIndices does not drop, add, or repeat any index
+shuffleDoesNotAddOrDropIndices :: Test
+shuffleDoesNotAddOrDropIndices =
+    let
+        gen = mkStdGen 42
+        actual = (Set.fromList [0 .. 
10])
+        computedVector = shuffledIndices gen 11 -- computed once, reused by both checks
+        computed = Set.fromList (VU.toList computedVector)
+    in
+        TestList
+            [ TestCase -- HUnit's assertEqual takes the expected value first, then the actual one
+                (assertEqual "Indices are not dropped or added" 11 (VU.length computedVector))
+            , TestCase (assertEqual "There are no repeated indices" actual computed)
+            ]
+
 tests :: [Test]
 tests =
     [ TestLabel "shuffleShuffles" shuffleShuffles
@@ -81,4 +98,5 @@ tests =
     , TestLabel "shufflePreservesColumnNames" shufflePreservesColumnNames
     , TestLabel "shuffleSameSeedIsSameShuffle" shuffleSameSeedIsSameShuffle
     , TestLabel "shuffleDifferentSeedIsDifferent" shuffleDifferentSeedIsDifferent
+    , TestLabel "shuffleDoesNotAddOrDropIndices" shuffleDoesNotAddOrDropIndices
     ]