Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ module Simplex.Messaging.Crypto.SNTRUP761.Bindings
) where

import Control.Concurrent.STM
import Control.Exception (throwIO)
import Crypto.Random (ChaChaDRG)
import Data.Aeson (FromJSON (..), ToJSON (..))
import Data.Bifunctor (bimap)
Expand Down Expand Up @@ -55,7 +56,8 @@ sntrup761Keypair drg =
)

sntrup761Enc :: TVar ChaChaDRG -> KEMPublicKey -> IO (KEMCiphertext, KEMSharedKey)
sntrup761Enc drg (KEMPublicKey pk) =
sntrup761Enc drg (KEMPublicKey pk) = do
requireByteArrayLength "SNTRUP761 public key" c_SNTRUP761_PUBLICKEY_SIZE pk
BA.withByteArray pk $ \pkPtr ->
bimap KEMCiphertext KEMSharedKey
<$> BA.allocRet
Expand All @@ -66,39 +68,71 @@ sntrup761Enc drg (KEMPublicKey pk) =
)

sntrup761Dec :: KEMCiphertext -> KEMSecretKey -> IO KEMSharedKey
sntrup761Dec (KEMCiphertext c) (KEMSecretKey sk) =
sntrup761Dec (KEMCiphertext c) (KEMSecretKey sk) = do
requireByteArrayLength "SNTRUP761 ciphertext" c_SNTRUP761_CIPHERTEXT_SIZE c
requireByteArrayLength "SNTRUP761 secret key" c_SNTRUP761_SECRETKEY_SIZE sk
BA.withByteArray sk $ \skPtr ->
BA.withByteArray c $ \cPtr ->
KEMSharedKey
<$> BA.alloc c_SNTRUP761_SIZE (\kPtr -> c_sntrup761_dec kPtr cPtr skPtr)

requireByteArrayLength :: BA.ByteArrayAccess bytes => String -> Int -> bytes -> IO ()
requireByteArrayLength valueName expected bytes =
either (throwIO . userError) (const $ pure ()) $
validateByteArrayLength valueName expected bytes

validateByteArrayLength :: BA.ByteArrayAccess bytes => String -> Int -> bytes -> Either String bytes
validateByteArrayLength valueName expected bytes
| actual == expected = Right bytes
| otherwise = Left $ valueName <> " must be " <> show expected <> " bytes, got " <> show actual
where
actual = BA.length bytes

parseKEMPublicKey :: ByteString -> Either String KEMPublicKey
parseKEMPublicKey =
fmap KEMPublicKey . validateByteArrayLength "SNTRUP761 public key" c_SNTRUP761_PUBLICKEY_SIZE

parseKEMSecretKey :: ScrubbedBytes -> Either String KEMSecretKey
parseKEMSecretKey =
fmap KEMSecretKey . validateByteArrayLength "SNTRUP761 secret key" c_SNTRUP761_SECRETKEY_SIZE

parseKEMCiphertext :: ByteString -> Either String KEMCiphertext
parseKEMCiphertext =
fmap KEMCiphertext . validateByteArrayLength "SNTRUP761 ciphertext" c_SNTRUP761_CIPHERTEXT_SIZE

instance Encoding KEMSecretKey where
smpEncode (KEMSecretKey c) = smpEncode . Large $ BA.convert c
smpP = KEMSecretKey . BA.convert . unLarge <$> smpP
smpP = do
Large bytes <- smpP
either fail pure $ parseKEMSecretKey (BA.convert bytes)

instance StrEncoding KEMSecretKey where
strEncode (KEMSecretKey pk) = strEncode (BA.convert pk :: ByteString)
strP = KEMSecretKey . BA.convert <$> strP @ByteString
strP = either fail pure . parseKEMSecretKey . BA.convert =<< strP @ByteString

instance Encoding KEMPublicKey where
smpEncode (KEMPublicKey pk) = smpEncode . Large $ BA.convert pk
smpP = KEMPublicKey . BA.convert . unLarge <$> smpP
smpP = do
Large bytes <- smpP
either fail pure $ parseKEMPublicKey bytes

instance StrEncoding KEMPublicKey where
strEncode (KEMPublicKey pk) = strEncode (BA.convert pk :: ByteString)
strP = KEMPublicKey . BA.convert <$> strP @ByteString
strP = either fail pure . parseKEMPublicKey =<< strP @ByteString

instance Encoding KEMCiphertext where
smpEncode (KEMCiphertext c) = smpEncode . Large $ BA.convert c
smpP = KEMCiphertext . BA.convert . unLarge <$> smpP
smpP = do
Large bytes <- smpP
either fail pure $ parseKEMCiphertext bytes

instance Encoding KEMSharedKey where
smpEncode (KEMSharedKey c) = smpEncode (BA.convert c :: ByteString)
smpP = KEMSharedKey . BA.convert <$> smpP @ByteString

instance StrEncoding KEMCiphertext where
strEncode (KEMCiphertext pk) = strEncode (BA.convert pk :: ByteString)
strP = KEMCiphertext . BA.convert <$> strP @ByteString
strP = either fail pure . parseKEMCiphertext =<< strP @ByteString

instance StrEncoding KEMSharedKey where
strEncode (KEMSharedKey pk) = strEncode (BA.convert pk :: ByteString)
Expand Down
54 changes: 51 additions & 3 deletions tests/CoreTests/CryptoTests.hs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module CoreTests.CryptoTests (cryptoTests) where

import Control.Concurrent.STM
import Control.Exception (SomeException)
import Control.Monad.Except
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Either (isRight)
import Data.Either (isLeft, isRight)
import Data.Int (Int64)
import Data.List (isInfixOf)
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import qualified Data.Text.Lazy as LT
Expand All @@ -23,10 +26,13 @@ import qualified SMPClient
import qualified Simplex.Messaging.Crypto as C
import qualified Simplex.Messaging.Crypto.Lazy as LC
import Simplex.Messaging.Crypto.SNTRUP761.Bindings
import Simplex.Messaging.Crypto.SNTRUP761.Bindings.Defines
import Simplex.Messaging.Encoding (Large (..), smpDecode, smpEncode)
import Simplex.Messaging.Encoding.String (strDecode, strEncode)
import Simplex.Messaging.Transport.Client
import Test.Hspec hiding (fit, it)
import Test.Hspec.QuickCheck (modifyMaxSuccess)
import Test.QuickCheck
import Test.QuickCheck hiding (Large)
import Util

cryptoTests :: Spec
Expand Down Expand Up @@ -99,8 +105,10 @@ cryptoTests = do
describe "X448" $ testEncoding C.SX448
describe "X509 chains" $ do
it "should validate certificates" testValidateX509
describe "sntrup761" $
describe "sntrup761" $ do
it "should enc/dec key" testSNTRUP761
it "should reject malformed KEM encodings" testSNTRUP761RejectsMalformedEncodings
it "should reject malformed KEM FFI inputs" testSNTRUP761RejectsMalformedFFIInputs

instance Eq C.APublicKey where
C.APublicKey a k == C.APublicKey a' k' = case testEquality a a' of
Expand Down Expand Up @@ -271,3 +279,43 @@ testSNTRUP761 = do
(c, KEMSharedKey k) <- sntrup761Enc drg pk
KEMSharedKey k' <- sntrup761Dec c sk
k' `shouldBe` k

testSNTRUP761RejectsMalformedEncodings :: IO ()
testSNTRUP761RejectsMalformedEncodings = do
smpDecode @KEMPublicKey (smpEncode $ Large shortPublicKey) `shouldSatisfy` isLeft
strDecode @KEMPublicKey (strEncode shortPublicKey) `shouldSatisfy` isLeft
smpDecode @KEMPublicKey (smpEncode $ Large validPublicKey) `shouldSatisfy` isRight
strDecode @KEMPublicKey (strEncode validPublicKey) `shouldSatisfy` isRight
smpDecode @KEMCiphertext (smpEncode $ Large shortCiphertext) `shouldSatisfy` isLeft
strDecode @KEMCiphertext (strEncode shortCiphertext) `shouldSatisfy` isLeft
smpDecode @KEMCiphertext (smpEncode $ Large validCiphertext) `shouldSatisfy` isRight
strDecode @KEMCiphertext (strEncode validCiphertext) `shouldSatisfy` isRight
smpDecode @KEMSecretKey (smpEncode $ Large shortSecretKey) `shouldSatisfy` isLeft
strDecode @KEMSecretKey (strEncode shortSecretKey) `shouldSatisfy` isLeft

testSNTRUP761RejectsMalformedFFIInputs :: IO ()
testSNTRUP761RejectsMalformedFFIInputs = do
drg <- C.newRandom
(_, sk) <- sntrup761Keypair drg
sntrup761Enc drg (KEMPublicKey shortPublicKey)
`shouldThrow` kemLengthException "public key"
sntrup761Dec (KEMCiphertext shortCiphertext) sk
`shouldThrow` kemLengthException "ciphertext"

kemLengthException :: String -> SomeException -> Bool
kemLengthException valueName e = valueName `isInfixOf` show e

shortPublicKey :: B.ByteString
shortPublicKey = B.replicate (c_SNTRUP761_PUBLICKEY_SIZE - 1) 'p'

validPublicKey :: B.ByteString
validPublicKey = B.replicate c_SNTRUP761_PUBLICKEY_SIZE 'p'

shortCiphertext :: B.ByteString
shortCiphertext = B.replicate (c_SNTRUP761_CIPHERTEXT_SIZE - 1) 'c'

validCiphertext :: B.ByteString
validCiphertext = B.replicate c_SNTRUP761_CIPHERTEXT_SIZE 'c'

shortSecretKey :: B.ByteString
shortSecretKey = B.replicate (c_SNTRUP761_SECRETKEY_SIZE - 1) 's'