diff --git a/src/ArrayFire/Data.hs b/src/ArrayFire/Data.hs index 9a183ce..263f310 100644 --- a/src/ArrayFire/Data.hs +++ b/src/ArrayFire/Data.hs @@ -41,7 +41,6 @@ import Foreign.Marshal hiding (void) import Foreign.Ptr (Ptr) import Foreign.Storable import System.IO.Unsafe -import Unsafe.Coerce import Data.Bits @@ -60,7 +59,7 @@ import ArrayFire.Arith -- [1 1 1 1] -- -1 bitNot - :: (AFType a, Bits a) + :: forall a. (AFType a, Bits a, Integral a) => Array a -> Array a bitNot arr = arr `bitXor` ones @@ -72,148 +71,78 @@ bitNot arr = arr `bitXor` ones , fromIntegral d2 , fromIntegral d3 ] - (complement zeroBits) + (fromIntegral (complement (zeroBits :: a))) --- | Creates an 'Array' from a scalar value from given dimensions --- --- >>> constant @Double [2,2] 2.0 --- ArrayFire Array --- [2 2 1 1] --- 2.0000 2.0000 --- 2.0000 2.0000 +-- | Creates a constant 'Array' filled with a 'Double' scalar. +-- ArrayFire converts the value to the element type internally. +-- Use 'constantComplex' for complex arrays, 'constantLong' / 'constantULong' +-- for 64-bit integer arrays where the value exceeds 2^53. constant - :: forall a . AFType a - => [Int] - -- ^ Dimensions - -> a - -- ^ Scalar value + :: forall a. AFType a + => [Int] -- ^ Dimensions + -> Double -- ^ Scalar value -> Array a {-# NOINLINE constant #-} constant dims val = - case dtyp of - x | x == c64 -> - cast $ constantComplex dims (unsafeCoerce val :: Complex Double) - | x == c32 -> - cast $ constantComplex dims (unsafeCoerce val :: Complex Float) - | x == s64 -> - cast $ constantLong dims (unsafeCoerce val :: Int) - | x == u64 -> - cast $ constantULong dims (unsafeCoerce val :: Word64) - | x == s32 -> - constant' dims (fromIntegral (unsafeCoerce val :: Int32) :: Double) - | x == s16 -> - constant' dims (fromIntegral (unsafeCoerce val :: Int16) :: Double) - | x == u32 -> - constant' dims (fromIntegral (unsafeCoerce val :: Word32) :: Double) - | x == u8 -> - constant' dims (fromIntegral (unsafeCoerce val :: Word8) :: Double) - | x == u16 -> - constant' dims (fromIntegral (unsafeCoerce val :: Word16) :: Double) - | x == f64 -> - constant' dims (unsafeCoerce val :: Double) - | x == b8 -> - constant' dims (fromIntegral (unsafeCoerce val :: CBool) :: Double) - | x == f32 -> - constant' dims (realToFrac (unsafeCoerce val :: Float)) - | otherwise -> error "constant: Invalid array fire type" + unsafePerformIO . mask_ $ do + ptr <- calloca $ \ptrPtr -> do + withArray (fromIntegral <$> dims) $ \dimArray -> do + throwAFError =<< af_constant ptrPtr val n dimArray dtyp + peek ptrPtr + Array <$> newForeignPtr af_release_array_finalizer ptr where + n = fromIntegral (length dims) dtyp = afType (Proxy @a) - -- Creates the array directly with the target dtype: @af_constant@ takes - -- the value as a C double for every non-complex, non-64-bit-integral - -- dtype. Routing through an f64 array and casting (as this used to do) - -- fails with AF_ERR_NO_DBL on OpenCL devices without fp64 support and - -- changes b8 semantics (the cast normalises non-zero values to 1). - constant' - :: [Int] - -- ^ Dimensions - -> Double - -- ^ Scalar value - -> Array a - constant' dims' val' = - unsafePerformIO . mask_ $ do - ptr <- calloca $ \ptrPtr -> do - withArray (fromIntegral <$> dims') $ \dimArray -> do - throwAFError =<< af_constant ptrPtr val' n dimArray dtyp - peek ptrPtr - Array <$> - newForeignPtr - af_release_array_finalizer - ptr - where - n = fromIntegral (length dims') - - -- | Creates an 'Array (Complex Double)' from a scalar val'ue - -- - -- @ - -- >>> constantComplex [2,2] (2.0 :+ 2.0) - -- @ - -- - constantComplex - :: forall arr . (Real arr, AFType (Complex arr)) - => [Int] - -- ^ Dimensions - -> Complex arr - -- ^ Scalar val'ue - -> Array (Complex arr) - constantComplex dims' ((realToFrac -> x) :+ (realToFrac -> y)) = unsafePerformIO . mask_ $ do - ptr <- calloca $ \ptrPtr -> do - withArray (fromIntegral <$> dims') $ \dimArray -> do - throwAFError =<< af_constant_complex ptrPtr x y n dimArray typ - peek ptrPtr - Array <$> - newForeignPtr - af_release_array_finalizer - ptr - where - n = fromIntegral (length dims') - typ = afType (Proxy @(Complex arr)) +-- | Creates a constant complex 'Array' from a 'Complex' scalar. +constantComplex + :: forall r. (Real r, AFType (Complex r)) + => [Int] -- ^ Dimensions + -> Complex r -- ^ Scalar value + -> Array (Complex r) +{-# NOINLINE constantComplex #-} +constantComplex dims ((realToFrac -> x) :+ (realToFrac -> y)) = + unsafePerformIO . mask_ $ do + ptr <- calloca $ \ptrPtr -> do + withArray (fromIntegral <$> dims) $ \dimArray -> do + throwAFError =<< af_constant_complex ptrPtr x y n dimArray typ + peek ptrPtr + Array <$> newForeignPtr af_release_array_finalizer ptr + where + n = fromIntegral (length dims) + typ = afType (Proxy @(Complex r)) - -- | Creates an 'Array Int64' from a scalar val'ue - -- - -- @ - -- >>> constantLong [2,2] 2.0 - -- @ - -- - constantLong - :: [Int] - -- ^ Dimensions - -> Int - -- ^ Scalar val'ue - -> Array Int - constantLong dims' val' = unsafePerformIO . mask_ $ do - ptr <- calloca $ \ptrPtr -> do - withArray (fromIntegral <$> dims') $ \dimArray -> do - throwAFError =<< af_constant_long ptrPtr (fromIntegral val') n dimArray - peek ptrPtr - Array <$> - newForeignPtr - af_release_array_finalizer - ptr - where - n = fromIntegral (length dims') +-- | Creates a constant 'Array' of 64-bit signed integers. +-- Preserves the full integer value without 'Double' rounding. +constantLong + :: [Int] -- ^ Dimensions + -> Int -- ^ Scalar value + -> Array Int +{-# NOINLINE constantLong #-} +constantLong dims val = + unsafePerformIO . mask_ $ do + ptr <- calloca $ \ptrPtr -> do + withArray (fromIntegral <$> dims) $ \dimArray -> do + throwAFError =<< af_constant_long ptrPtr (fromIntegral val) n dimArray + peek ptrPtr + Array <$> newForeignPtr af_release_array_finalizer ptr + where n = fromIntegral (length dims) - -- | Creates an 'Array Word64' from a scalar val'ue - -- - -- @ - -- >>> constantULong [2,2] 2.0 - -- @ - -- - constantULong - :: [Int] - -> Word64 - -> Array Word64 - constantULong dims' val' = unsafePerformIO . mask_ $ do - ptr <- calloca $ \ptrPtr -> do - withArray (fromIntegral <$> dims') $ \dimArray -> do - throwAFError =<< af_constant_ulong ptrPtr (fromIntegral val') n dimArray - peek ptrPtr - Array <$> - newForeignPtr - af_release_array_finalizer - ptr - where - n = fromIntegral (length dims') +-- | Creates a constant 'Array' of 64-bit unsigned integers. +-- Preserves the full integer value without 'Double' rounding. +constantULong + :: [Int] -- ^ Dimensions + -> Word64 -- ^ Scalar value + -> Array Word64 +{-# NOINLINE constantULong #-} +constantULong dims val = + unsafePerformIO . mask_ $ do + ptr <- calloca $ \ptrPtr -> do + withArray (fromIntegral <$> dims) $ \dimArray -> do + throwAFError =<< af_constant_ulong ptrPtr (fromIntegral val) n dimArray + peek ptrPtr + Array <$> newForeignPtr af_release_array_finalizer ptr + where n = fromIntegral (length dims) -- | Creates a range of values in an Array -- diff --git a/src/ArrayFire/Internal/Types.hsc b/src/ArrayFire/Internal/Types.hsc index 4c0082a..d361ed9 100644 --- a/src/ArrayFire/Internal/Types.hsc +++ b/src/ArrayFire/Internal/Types.hsc @@ -126,47 +126,20 @@ newtype Window = Window (ForeignPtr ()) class Storable a => AFType a where afType :: Proxy a -> AFDtype -instance AFType Double where - afType Proxy = f64 - -instance AFType Float where - afType Proxy = f32 - -instance AFType (Complex Double) where - afType Proxy = c64 - -instance AFType (Complex Float) where - afType Proxy = c32 - -instance AFType CBool where - afType Proxy = b8 - -instance AFType Int32 where - afType Proxy = s32 - -instance AFType Word32 where - afType Proxy = u32 - -instance AFType Word8 where - afType Proxy = u8 - -instance AFType Int64 where - afType Proxy = s64 - -instance AFType Int where - afType Proxy = s64 - -instance AFType Int16 where - afType Proxy = s16 - -instance AFType Word16 where - afType Proxy = u16 - -instance AFType Word64 where - afType Proxy = u64 - -instance AFType Word where - afType Proxy = u64 +instance AFType Double where afType Proxy = f64 +instance AFType Float where afType Proxy = f32 +instance AFType (Complex Double) where afType Proxy = c64 +instance AFType (Complex Float) where afType Proxy = c32 +instance AFType CBool where afType Proxy = b8 +instance AFType Int32 where afType Proxy = s32 +instance AFType Word32 where afType Proxy = u32 +instance AFType Word8 where afType Proxy = u8 +instance AFType Int64 where afType Proxy = s64 +instance AFType Int where afType Proxy = s64 +instance AFType Int16 where afType Proxy = s16 +instance AFType Word16 where afType Proxy = u16 +instance AFType Word64 where afType Proxy = u64 +instance AFType Word where afType Proxy = u64 -- | Maps an ArrayFire element type to the scalar type returned by whole-array -- reductions (e.g. 'meanAll', 'det'). Real and integral element types yield diff --git a/test/ArrayFire/DataSpec.hs b/test/ArrayFire/DataSpec.hs index 7d185a7..db4619b 100644 --- a/test/ArrayFire/DataSpec.hs +++ b/test/ArrayFire/DataSpec.hs @@ -18,24 +18,67 @@ import ArrayFire hiding (not) spec :: Spec spec = describe "Data tests" $ do - it "Should create constant Array" $ do - constant @Float [1] 1 `shouldBe` 1 - constant @Double [1] 1 `shouldBe` 1 - constant @Int16 [1] 1 `shouldBe` 1 - constant @Int32 [1] 1 `shouldBe` 1 - constant @Int64 [1] 1 `shouldBe` 1 - constant @Int [1] 1 `shouldBe` 1 - constant @Word16 [1] 1 `shouldBe` 1 - constant @Word32 [1] 1 `shouldBe` 1 - constant @Word64 [1] 1 `shouldBe` 1 - constant @Word [1] 1 `shouldBe` 1 - constant @CBool [1] 1 `shouldBe` 1 - constant @(Complex Double) [1] (1.0 :+ 1.0) - `shouldBe` - constant @(Complex Double) [1] (1.0 :+ 1.0) - constant @(Complex Float) [1] (1.0 :+ 1.0) - `shouldBe` - constant @(Complex Float) [1] (1.0 :+ 1.0) + describe "constant" $ do + it "creates a scalar Float array" $ + constant @Float [1] 1 `shouldBe` scalar @Float 1 + it "creates a scalar Double array" $ + constant @Double [1] 2.5 `shouldBe` scalar @Double 2.5 + it "creates a scalar Int16 array" $ + constant @Int16 [1] 42 `shouldBe` scalar @Int16 42 + it "creates a scalar Int32 array" $ + constant @Int32 [1] (-7) `shouldBe` scalar @Int32 (-7) + it "creates a scalar Word8 array" $ + constant @Word8 [1] 255 `shouldBe` scalar @Word8 255 + it "creates a scalar Word16 array" $ + constant @Word16 [1] 1000 `shouldBe` scalar @Word16 1000 + it "creates a scalar Word32 array" $ + constant @Word32 [1] 999 `shouldBe` scalar @Word32 999 + it "creates a CBool array" $ + constant @CBool [1] 1 `shouldBe` scalar @CBool 1 + it "creates a multi-element array with correct shape" $ do + let a = constant @Double [3,3] 0 + getDims a `shouldBe` (3,3,1,1) + it "all elements equal the scalar value" $ + constant @Float [4] 3.14 `shouldBe` vector @Float 4 [3.14, 3.14, 3.14, 3.14] + + describe "constantComplex" $ do + it "creates a Complex Double array preserving imaginary part" $ + constantComplex [1] (1.0 :+ 2.0) + `shouldBe` scalar @(Complex Double) (1.0 :+ 2.0) + it "creates a Complex Float array preserving imaginary part" $ + constantComplex [1] (3.0 :+ 4.0 :: Complex Float) + `shouldBe` scalar @(Complex Float) (3.0 :+ 4.0) + it "creates a zero complex array" $ + constantComplex [2] (0 :+ 0 :: Complex Double) + `shouldBe` vector @(Complex Double) 2 [0, 0] + it "handles purely real complex values" $ + constantComplex [1] (5.0 :+ 0.0 :: Complex Double) + `shouldBe` scalar @(Complex Double) (5.0 :+ 0.0) + it "handles purely imaginary complex values" $ + constantComplex [1] (0.0 :+ 7.0 :: Complex Double) + `shouldBe` scalar @(Complex Double) (0.0 :+ 7.0) + + describe "constantLong" $ do + it "creates an Int array with value 1" $ + constantLong [1] 1 `shouldBe` scalar @Int 1 + it "creates an Int array with a negative value" $ + constantLong [1] (-42) `shouldBe` scalar @Int (-42) + it "preserves maxBound :: Int without rounding" $ + constantLong [1] maxBound `shouldBe` scalar @Int maxBound + it "preserves minBound :: Int without rounding" $ + constantLong [1] minBound `shouldBe` scalar @Int minBound + it "creates a multi-element array" $ + constantLong [3] 7 `shouldBe` vector @Int 3 [7, 7, 7] + + describe "constantULong" $ do + it "creates a Word64 array with value 1" $ + constantULong [1] 1 `shouldBe` scalar @Word64 1 + it "creates a Word64 array with value 0" $ + constantULong [1] 0 `shouldBe` scalar @Word64 0 + it "preserves maxBound :: Word64 without rounding" $ + constantULong [1] maxBound `shouldBe` scalar @Word64 maxBound + it "creates a multi-element array" $ + constantULong [3] 100 `shouldBe` vector @Word64 3 [100, 100, 100] describe "arange" $ do it "generates a sequence along dim 0 for a 1D array" $ do