From b846a0b8350a5287357b8f9afe9ab5af4ebe0ae8 Mon Sep 17 00:00:00 2001 From: Paul Bottinelli Date: Thu, 4 Jun 2026 11:32:09 -0400 Subject: [PATCH] Validate SNTRUP761 KEM input lengths --- .../Messaging/Crypto/SNTRUP761/Bindings.hs | 50 ++++++++++++++--- tests/CoreTests/CryptoTests.hs | 54 +++++++++++++++++-- 2 files changed, 93 insertions(+), 11 deletions(-) diff --git a/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs b/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs index 73d4c68851..c70cd41b28 100644 --- a/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs +++ b/src/Simplex/Messaging/Crypto/SNTRUP761/Bindings.hs @@ -13,6 +13,7 @@ module Simplex.Messaging.Crypto.SNTRUP761.Bindings ) where import Control.Concurrent.STM +import Control.Exception (throwIO) import Crypto.Random (ChaChaDRG) import Data.Aeson (FromJSON (..), ToJSON (..)) import Data.Bifunctor (bimap) @@ -55,7 +56,8 @@ sntrup761Keypair drg = ) sntrup761Enc :: TVar ChaChaDRG -> KEMPublicKey -> IO (KEMCiphertext, KEMSharedKey) -sntrup761Enc drg (KEMPublicKey pk) = +sntrup761Enc drg (KEMPublicKey pk) = do + requireByteArrayLength "SNTRUP761 public key" c_SNTRUP761_PUBLICKEY_SIZE pk BA.withByteArray pk $ \pkPtr -> bimap KEMCiphertext KEMSharedKey <$> BA.allocRet @@ -66,31 +68,63 @@ sntrup761Enc drg (KEMPublicKey pk) = ) sntrup761Dec :: KEMCiphertext -> KEMSecretKey -> IO KEMSharedKey -sntrup761Dec (KEMCiphertext c) (KEMSecretKey sk) = +sntrup761Dec (KEMCiphertext c) (KEMSecretKey sk) = do + requireByteArrayLength "SNTRUP761 ciphertext" c_SNTRUP761_CIPHERTEXT_SIZE c + requireByteArrayLength "SNTRUP761 secret key" c_SNTRUP761_SECRETKEY_SIZE sk BA.withByteArray sk $ \skPtr -> BA.withByteArray c $ \cPtr -> KEMSharedKey <$> BA.alloc c_SNTRUP761_SIZE (\kPtr -> c_sntrup761_dec kPtr cPtr skPtr) +requireByteArrayLength :: BA.ByteArrayAccess bytes => String -> Int -> bytes -> IO () +requireByteArrayLength valueName expected bytes = + either (throwIO . userError) (const $ pure ()) $ + validateByteArrayLength valueName expected bytes + +validateByteArrayLength :: BA.ByteArrayAccess bytes => String -> Int -> bytes -> Either String bytes +validateByteArrayLength valueName expected bytes + | actual == expected = Right bytes + | otherwise = Left $ valueName <> " must be " <> show expected <> " bytes, got " <> show actual + where + actual = BA.length bytes + +parseKEMPublicKey :: ByteString -> Either String KEMPublicKey +parseKEMPublicKey = + fmap KEMPublicKey . validateByteArrayLength "SNTRUP761 public key" c_SNTRUP761_PUBLICKEY_SIZE + +parseKEMSecretKey :: ScrubbedBytes -> Either String KEMSecretKey +parseKEMSecretKey = + fmap KEMSecretKey . validateByteArrayLength "SNTRUP761 secret key" c_SNTRUP761_SECRETKEY_SIZE + +parseKEMCiphertext :: ByteString -> Either String KEMCiphertext +parseKEMCiphertext = + fmap KEMCiphertext . validateByteArrayLength "SNTRUP761 ciphertext" c_SNTRUP761_CIPHERTEXT_SIZE + instance Encoding KEMSecretKey where smpEncode (KEMSecretKey c) = smpEncode . Large $ BA.convert c - smpP = KEMSecretKey . BA.convert . unLarge <$> smpP + smpP = do + Large bytes <- smpP + either fail pure $ parseKEMSecretKey (BA.convert bytes) instance StrEncoding KEMSecretKey where strEncode (KEMSecretKey pk) = strEncode (BA.convert pk :: ByteString) - strP = KEMSecretKey . BA.convert <$> strP @ByteString + strP = either fail pure . parseKEMSecretKey . BA.convert =<< strP @ByteString instance Encoding KEMPublicKey where smpEncode (KEMPublicKey pk) = smpEncode . Large $ BA.convert pk - smpP = KEMPublicKey . BA.convert . unLarge <$> smpP + smpP = do + Large bytes <- smpP + either fail pure $ parseKEMPublicKey bytes instance StrEncoding KEMPublicKey where strEncode (KEMPublicKey pk) = strEncode (BA.convert pk :: ByteString) - strP = KEMPublicKey . BA.convert <$> strP @ByteString + strP = either fail pure . parseKEMPublicKey =<< strP @ByteString instance Encoding KEMCiphertext where smpEncode (KEMCiphertext c) = smpEncode . Large $ BA.convert c - smpP = KEMCiphertext . BA.convert . unLarge <$> smpP + smpP = do + Large bytes <- smpP + either fail pure $ parseKEMCiphertext bytes instance Encoding KEMSharedKey where smpEncode (KEMSharedKey c) = smpEncode (BA.convert c :: ByteString) @@ -98,7 +132,7 @@ instance Encoding KEMSharedKey where instance StrEncoding KEMCiphertext where strEncode (KEMCiphertext pk) = strEncode (BA.convert pk :: ByteString) - strP = KEMCiphertext . BA.convert <$> strP @ByteString + strP = either fail pure . parseKEMCiphertext =<< strP @ByteString instance StrEncoding KEMSharedKey where strEncode (KEMSharedKey pk) = strEncode (BA.convert pk :: ByteString) diff --git a/tests/CoreTests/CryptoTests.hs b/tests/CoreTests/CryptoTests.hs index 8e4d9a2582..39868219f6 100644 --- a/tests/CoreTests/CryptoTests.hs +++ b/tests/CoreTests/CryptoTests.hs @@ -1,16 +1,19 @@ {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -Wno-orphans #-} module CoreTests.CryptoTests (cryptoTests) where import Control.Concurrent.STM +import Control.Exception (SomeException) import Control.Monad.Except import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Lazy.Char8 as LB -import Data.Either (isRight) +import Data.Either (isLeft, isRight) import Data.Int (Int64) +import Data.List (isInfixOf) import qualified Data.Text as T import Data.Text.Encoding (encodeUtf8) import qualified Data.Text.Lazy as LT @@ -23,10 +26,13 @@ import qualified SMPClient import qualified Simplex.Messaging.Crypto as C import qualified Simplex.Messaging.Crypto.Lazy as LC import Simplex.Messaging.Crypto.SNTRUP761.Bindings +import Simplex.Messaging.Crypto.SNTRUP761.Bindings.Defines +import Simplex.Messaging.Encoding (Large (..), smpDecode, smpEncode) +import Simplex.Messaging.Encoding.String (strDecode, strEncode) import Simplex.Messaging.Transport.Client import Test.Hspec hiding (fit, it) import Test.Hspec.QuickCheck (modifyMaxSuccess) -import Test.QuickCheck +import Test.QuickCheck hiding (Large) import Util cryptoTests :: Spec @@ -99,8 +105,10 @@ cryptoTests = do describe "X448" $ testEncoding C.SX448 describe "X509 chains" $ do it "should validate certificates" testValidateX509 - describe "sntrup761" $ + describe "sntrup761" $ do it "should enc/dec key" testSNTRUP761 + it "should reject malformed KEM encodings" testSNTRUP761RejectsMalformedEncodings + it "should reject malformed KEM FFI inputs" testSNTRUP761RejectsMalformedFFIInputs instance Eq C.APublicKey where C.APublicKey a k == C.APublicKey a' k' = case testEquality a a' of @@ -271,3 +279,43 @@ testSNTRUP761 = do (c, KEMSharedKey k) <- sntrup761Enc drg pk KEMSharedKey k' <- sntrup761Dec c sk k' `shouldBe` k + +testSNTRUP761RejectsMalformedEncodings :: IO () +testSNTRUP761RejectsMalformedEncodings = do + smpDecode @KEMPublicKey (smpEncode $ Large shortPublicKey) `shouldSatisfy` isLeft + strDecode @KEMPublicKey (strEncode shortPublicKey) `shouldSatisfy` isLeft + smpDecode @KEMPublicKey (smpEncode $ Large validPublicKey) `shouldSatisfy` isRight + strDecode @KEMPublicKey (strEncode validPublicKey) `shouldSatisfy` isRight + smpDecode @KEMCiphertext (smpEncode $ Large shortCiphertext) `shouldSatisfy` isLeft + strDecode @KEMCiphertext (strEncode shortCiphertext) `shouldSatisfy` isLeft + smpDecode @KEMCiphertext (smpEncode $ Large validCiphertext) `shouldSatisfy` isRight + strDecode @KEMCiphertext (strEncode validCiphertext) `shouldSatisfy` isRight + smpDecode @KEMSecretKey (smpEncode $ Large shortSecretKey) `shouldSatisfy` isLeft + strDecode @KEMSecretKey (strEncode shortSecretKey) `shouldSatisfy` isLeft + +testSNTRUP761RejectsMalformedFFIInputs :: IO () +testSNTRUP761RejectsMalformedFFIInputs = do + drg <- C.newRandom + (_, sk) <- sntrup761Keypair drg + sntrup761Enc drg (KEMPublicKey shortPublicKey) + `shouldThrow` kemLengthException "public key" + sntrup761Dec (KEMCiphertext shortCiphertext) sk + `shouldThrow` kemLengthException "ciphertext" + +kemLengthException :: String -> SomeException -> Bool +kemLengthException valueName e = valueName `isInfixOf` show e + +shortPublicKey :: B.ByteString +shortPublicKey = B.replicate (c_SNTRUP761_PUBLICKEY_SIZE - 1) 'p' + +validPublicKey :: B.ByteString +validPublicKey = B.replicate c_SNTRUP761_PUBLICKEY_SIZE 'p' + +shortCiphertext :: B.ByteString +shortCiphertext = B.replicate (c_SNTRUP761_CIPHERTEXT_SIZE - 1) 'c' + +validCiphertext :: B.ByteString +validCiphertext = B.replicate c_SNTRUP761_CIPHERTEXT_SIZE 'c' + +shortSecretKey :: B.ByteString +shortSecretKey = B.replicate (c_SNTRUP761_SECRETKEY_SIZE - 1) 's'