diff --git a/src/rust/src/backend/dsa.rs b/src/rust/src/backend/dsa.rs index c398c7faad6b..5eb47d106731 100644 --- a/src/rust/src/backend/dsa.rs +++ b/src/rust/src/backend/dsa.rs @@ -38,18 +38,21 @@ struct DsaParameters { pub(crate) fn private_key_from_pkey( pkey: &openssl::pkey::PKeyRef, -) -> DsaPrivateKey { - DsaPrivateKey { +) -> CryptographyResult { + Ok(DsaPrivateKey { pkey: pkey.to_owned(), - } + }) } pub(crate) fn public_key_from_pkey( + py: pyo3::Python<'_>, pkey: &openssl::pkey::PKeyRef, -) -> DsaPublicKey { - DsaPublicKey { +) -> CryptographyResult { + let key = DsaPublicKey { pkey: pkey.to_owned(), - } + }; + check_dsa_public_numbers(py, &key.public_numbers(py)?)?; + Ok(key) } #[pyo3::pyfunction] @@ -305,12 +308,40 @@ fn check_dsa_parameters( Ok(()) } +fn check_dsa_public_numbers( + py: pyo3::Python<'_>, + numbers: &DsaPublicNumbers, +) -> CryptographyResult<()> { + let params = numbers.parameter_numbers.get(); + check_dsa_parameters(py, params)?; + + if numbers.y.bind(py).le(1)? || numbers.y.bind(py).ge(params.p.bind(py))? { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("y must be > 1 and < p."), + )); + } + + if numbers + .y + .bind(py) + .pow(params.q.bind(py), Some(params.p.bind(py)))? + .ne(1)? + { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("y ** q mod p must be 1."), + )); + } + + Ok(()) +} + fn check_dsa_private_numbers( py: pyo3::Python<'_>, numbers: &DsaPrivateNumbers, ) -> CryptographyResult<()> { - let params = numbers.public_numbers.get().parameter_numbers.get(); - check_dsa_parameters(py, params)?; + let public_numbers = numbers.public_numbers.get(); + let params = public_numbers.parameter_numbers.get(); + check_dsa_public_numbers(py, public_numbers)?; if numbers.x.bind(py).le(0)? || numbers.x.bind(py).ge(params.q.bind(py))? { return Err(CryptographyError::from( @@ -318,7 +349,7 @@ fn check_dsa_private_numbers( )); } - if (**numbers.public_numbers.get().y.bind(py)).ne(params + if (**public_numbers.y.bind(py)).ne(params .g .bind(py) .pow(numbers.x.bind(py), Some(params.p.bind(py)))?)? @@ -440,7 +471,7 @@ impl DsaPublicNumbers { let parameter_numbers = self.parameter_numbers.get(); - check_dsa_parameters(py, parameter_numbers)?; + check_dsa_public_numbers(py, self)?; let dsa = openssl::dsa::Dsa::from_public_components( utils::py_int_to_bn(py, parameter_numbers.p.bind(py))?, diff --git a/src/rust/src/backend/keys.rs b/src/rust/src/backend/keys.rs index 5acebea690b1..7015a7adb58d 100644 --- a/src/rust/src/backend/keys.rs +++ b/src/rust/src/backend/keys.rs @@ -165,7 +165,7 @@ fn private_key_from_pkey<'p>( openssl::pkey::Id::ED448 => Ok(crate::backend::ed448::private_key_from_pkey(pkey) .into_pyobject(py)? .into_any()), - openssl::pkey::Id::DSA => Ok(crate::backend::dsa::private_key_from_pkey(pkey) + openssl::pkey::Id::DSA => Ok(crate::backend::dsa::private_key_from_pkey(pkey)? .into_pyobject(py)? .into_any()), openssl::pkey::Id::DH => Ok(crate::backend::dh::private_key_from_pkey(pkey) @@ -363,7 +363,7 @@ fn public_key_from_pkey<'p>( .into_pyobject(py)? .into_any()), - openssl::pkey::Id::DSA => Ok(crate::backend::dsa::public_key_from_pkey(pkey) + openssl::pkey::Id::DSA => Ok(crate::backend::dsa::public_key_from_pkey(py, pkey)? .into_pyobject(py)? .into_any()), openssl::pkey::Id::DH => Ok(crate::backend::dh::public_key_from_pkey(pkey) diff --git a/tests/hazmat/primitives/test_dsa.py b/tests/hazmat/primitives/test_dsa.py index 94e25eef8cd4..c3ecee63b138 100644 --- a/tests/hazmat/primitives/test_dsa.py +++ b/tests/hazmat/primitives/test_dsa.py @@ -352,6 +352,24 @@ def test_invalid_dsa_private_key_arguments(self, p, q, g, y, x, backend): 2**1200, DSA_KEY_1024.public_numbers.y, ), + ( + DSA_KEY_1024.public_numbers.parameter_numbers.p, + DSA_KEY_1024.public_numbers.parameter_numbers.q, + DSA_KEY_1024.public_numbers.parameter_numbers.g, + 1, + ), + ( + DSA_KEY_1024.public_numbers.parameter_numbers.p, + DSA_KEY_1024.public_numbers.parameter_numbers.q, + DSA_KEY_1024.public_numbers.parameter_numbers.g, + DSA_KEY_1024.public_numbers.parameter_numbers.p + 1, + ), + ( + DSA_KEY_1024.public_numbers.parameter_numbers.p, + DSA_KEY_1024.public_numbers.parameter_numbers.q, + DSA_KEY_1024.public_numbers.parameter_numbers.g, + DSA_KEY_1024.public_numbers.parameter_numbers.p - 1, + ), ], ) def test_invalid_dsa_public_key_arguments(self, p, q, g, y, backend): @@ -478,12 +496,17 @@ def test_dsa_verification(self, backend, subtests): backend, algorithm, vector["p"], vector["q"], vector["g"] ) - public_key = dsa.DSAPublicNumbers( - parameter_numbers=dsa.DSAParameterNumbers( - vector["p"], vector["q"], vector["g"] - ), - y=vector["y"], - ).public_key(backend) + try: + public_key = dsa.DSAPublicNumbers( + parameter_numbers=dsa.DSAParameterNumbers( + vector["p"], vector["q"], vector["g"] + ), + y=vector["y"], + ).public_key(backend) + except ValueError: + assert vector["result"] == "F" + continue + sig = encode_dss_signature(vector["r"], vector["s"]) if vector["result"] == "F":