Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions src/rust/src/backend/dh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,20 @@ fn generate_parameters(

pub(crate) fn private_key_from_pkey(
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Private>,
) -> DHPrivateKey {
DHPrivateKey {
) -> CryptographyResult<DHPrivateKey> {
check_dh_parameters(&pkey.dh()?)?;
Ok(DHPrivateKey {
pkey: pkey.to_owned(),
}
})
}

pub(crate) fn public_key_from_pkey(
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Public>,
) -> DHPublicKey {
DHPublicKey {
) -> CryptographyResult<DHPublicKey> {
check_dh_parameters(&pkey.dh()?)?;
Ok(DHPublicKey {
pkey: pkey.to_owned(),
}
})
}

#[pyo3::pyfunction]
Expand All @@ -85,9 +87,9 @@ fn from_der_parameters(
.transpose()?;
let g = openssl::bn::BigNum::from_slice(asn1_params.g.as_bytes())?;

Ok(DHParameters {
dh: openssl::dh::Dh::from_pqg(p, q, g)?,
})
let dh = openssl::dh::Dh::from_pqg(p, q, g)?;
check_dh_parameters(&dh)?;
Ok(DHParameters { dh })
}

#[pyo3::pyfunction]
Expand Down Expand Up @@ -119,13 +121,19 @@ fn dh_parameters_from_numbers(
let g = utils::py_int_to_bn(py, numbers.g.bind(py))?;

let dh = openssl::dh::Dh::from_pqg(p, q, g)?;
check_dh_parameters(&dh)?;
Ok(dh)
}

fn check_dh_parameters<T: openssl::pkey::HasParams>(
dh: &openssl::dh::Dh<T>,
) -> CryptographyResult<()> {
if !dh.check_key()? {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("Invalid DH parameters"),
));
}
Ok(dh)
Ok(())
}

fn clone_dh<T: openssl::pkey::HasParams>(
Expand Down
51 changes: 41 additions & 10 deletions src/rust/src/backend/dsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,21 @@ struct DsaParameters {

pub(crate) fn private_key_from_pkey(
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Private>,
) -> DsaPrivateKey {
DsaPrivateKey {
) -> CryptographyResult<DsaPrivateKey> {
Ok(DsaPrivateKey {
pkey: pkey.to_owned(),
}
})
}

pub(crate) fn public_key_from_pkey(
py: pyo3::Python<'_>,
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Public>,
) -> DsaPublicKey {
DsaPublicKey {
) -> CryptographyResult<DsaPublicKey> {
let key = DsaPublicKey {
pkey: pkey.to_owned(),
}
};
check_dsa_public_numbers(py, &key.public_numbers(py)?)?;
Ok(key)
}

#[pyo3::pyfunction]
Expand Down Expand Up @@ -305,20 +308,48 @@ fn check_dsa_parameters(
Ok(())
}

fn check_dsa_public_numbers(
py: pyo3::Python<'_>,
numbers: &DsaPublicNumbers,
) -> CryptographyResult<()> {
let params = numbers.parameter_numbers.get();
check_dsa_parameters(py, params)?;

if numbers.y.bind(py).le(1)? || numbers.y.bind(py).ge(params.p.bind(py))? {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("y must be > 1 and < p."),
));
}

if numbers
.y
.bind(py)
.pow(params.q.bind(py), Some(params.p.bind(py)))?
.ne(1)?
{
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("y ** q mod p must be 1."),
));
}

Ok(())
}

fn check_dsa_private_numbers(
py: pyo3::Python<'_>,
numbers: &DsaPrivateNumbers,
) -> CryptographyResult<()> {
let params = numbers.public_numbers.get().parameter_numbers.get();
check_dsa_parameters(py, params)?;
let public_numbers = numbers.public_numbers.get();
let params = public_numbers.parameter_numbers.get();
check_dsa_public_numbers(py, public_numbers)?;

if numbers.x.bind(py).le(0)? || numbers.x.bind(py).ge(params.q.bind(py))? {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("x must be > 0 and < q."),
));
}

if (**numbers.public_numbers.get().y.bind(py)).ne(params
if (**public_numbers.y.bind(py)).ne(params
.g
.bind(py)
.pow(numbers.x.bind(py), Some(params.p.bind(py)))?)?
Expand Down Expand Up @@ -440,7 +471,7 @@ impl DsaPublicNumbers {

let parameter_numbers = self.parameter_numbers.get();

check_dsa_parameters(py, parameter_numbers)?;
check_dsa_public_numbers(py, self)?;

let dsa = openssl::dsa::Dsa::from_public_components(
utils::py_int_to_bn(py, parameter_numbers.p.bind(py))?,
Expand Down
14 changes: 7 additions & 7 deletions src/rust/src/backend/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,10 @@ fn private_key_from_pkey<'p>(
openssl::pkey::Id::ED448 => Ok(crate::backend::ed448::private_key_from_pkey(pkey)
.into_pyobject(py)?
.into_any()),
openssl::pkey::Id::DSA => Ok(crate::backend::dsa::private_key_from_pkey(pkey)
openssl::pkey::Id::DSA => Ok(crate::backend::dsa::private_key_from_pkey(pkey)?
.into_pyobject(py)?
.into_any()),
openssl::pkey::Id::DH => Ok(crate::backend::dh::private_key_from_pkey(pkey)
openssl::pkey::Id::DH => Ok(crate::backend::dh::private_key_from_pkey(pkey)?
.into_pyobject(py)?
.into_any()),

Expand All @@ -177,7 +177,7 @@ fn private_key_from_pkey<'p>(
CRYPTOGRAPHY_IS_BORINGSSL,
CRYPTOGRAPHY_IS_AWSLC
)))]
openssl::pkey::Id::DHX => Ok(crate::backend::dh::private_key_from_pkey(pkey)
openssl::pkey::Id::DHX => Ok(crate::backend::dh::private_key_from_pkey(pkey)?
.into_pyobject(py)?
.into_any()),
#[cfg(any(
Expand Down Expand Up @@ -333,7 +333,7 @@ fn public_key_from_pkey<'p>(
// `id` is a separate argument so we can test this while passing something
// unsupported.
match id {
openssl::pkey::Id::RSA => Ok(crate::backend::rsa::public_key_from_pkey(pkey)
openssl::pkey::Id::RSA => Ok(crate::backend::rsa::public_key_from_pkey(pkey)?
.into_pyobject(py)?
.into_any()),
openssl::pkey::Id::EC => Ok(crate::backend::ec::public_key_from_pkey(py, pkey)?
Expand Down Expand Up @@ -363,10 +363,10 @@ fn public_key_from_pkey<'p>(
.into_pyobject(py)?
.into_any()),

openssl::pkey::Id::DSA => Ok(crate::backend::dsa::public_key_from_pkey(pkey)
openssl::pkey::Id::DSA => Ok(crate::backend::dsa::public_key_from_pkey(py, pkey)?
.into_pyobject(py)?
.into_any()),
openssl::pkey::Id::DH => Ok(crate::backend::dh::public_key_from_pkey(pkey)
openssl::pkey::Id::DH => Ok(crate::backend::dh::public_key_from_pkey(pkey)?
.into_pyobject(py)?
.into_any()),

Expand All @@ -375,7 +375,7 @@ fn public_key_from_pkey<'p>(
CRYPTOGRAPHY_IS_BORINGSSL,
CRYPTOGRAPHY_IS_AWSLC
)))]
openssl::pkey::Id::DHX => Ok(crate::backend::dh::public_key_from_pkey(pkey)
openssl::pkey::Id::DHX => Ok(crate::backend::dh::public_key_from_pkey(pkey)?
.into_pyobject(py)?
.into_any()),
#[cfg(any(
Expand Down
42 changes: 34 additions & 8 deletions src/rust/src/backend/rsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ pub(crate) fn private_key_from_pkey(

pub(crate) fn public_key_from_pkey(
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Public>,
) -> RsaPublicKey {
RsaPublicKey {
) -> CryptographyResult<RsaPublicKey> {
let rsa = pkey.rsa()?;
check_public_key_components(rsa.e(), rsa.n())?;
Ok(RsaPublicKey {
pkey: pkey.to_owned(),
}
})
}

#[pyo3::pyfunction]
Expand Down Expand Up @@ -795,23 +797,47 @@ impl RsaPrivateNumbers {
}
}

fn check_public_key_components(
fn check_public_key_components_from_py(
py: pyo3::Python<'_>,
e: &pyo3::Bound<'_, pyo3::types::PyInt>,
n: &pyo3::Bound<'_, pyo3::types::PyInt>,
) -> CryptographyResult<()> {
if n.lt(3)? {
if n.lt(0)? {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("n must be >= 3."),
));
}

if e.lt(0)? {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("e must be >= 3 and < n."),
));
}

let e = utils::py_int_to_bn(py, e)?;
let n = utils::py_int_to_bn(py, n)?;
check_public_key_components(e.as_ref(), n.as_ref())
}

fn check_public_key_components(
e: &openssl::bn::BigNumRef,
n: &openssl::bn::BigNumRef,
) -> CryptographyResult<()> {
let three = openssl::bn::BigNum::from_u32(3)?;

if n.cmp(three.as_ref()).is_lt() {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("n must be >= 3."),
));
}

if e.lt(3)? || e.ge(n)? {
if e.cmp(three.as_ref()).is_lt() || e >= n {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("e must be >= 3 and < n."),
));
}

if e.bitand(1)?.eq(0)? {
if e.is_even() {
return Err(CryptographyError::from(
pyo3::exceptions::PyValueError::new_err("e must be odd."),
));
Expand All @@ -835,7 +861,7 @@ impl RsaPublicNumbers {
) -> CryptographyResult<RsaPublicKey> {
let _ = backend;

check_public_key_components(self.e.bind(py), self.n.bind(py))?;
check_public_key_components_from_py(py, self.e.bind(py), self.n.bind(py))?;

let rsa = openssl::rsa::Rsa::from_public_components(
utils::py_int_to_bn(py, self.n.bind(py))?,
Expand Down
35 changes: 29 additions & 6 deletions tests/hazmat/primitives/test_dsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,24 @@ def test_invalid_dsa_private_key_arguments(self, p, q, g, y, x, backend):
2**1200,
DSA_KEY_1024.public_numbers.y,
),
(
DSA_KEY_1024.public_numbers.parameter_numbers.p,
DSA_KEY_1024.public_numbers.parameter_numbers.q,
DSA_KEY_1024.public_numbers.parameter_numbers.g,
1,
),
(
DSA_KEY_1024.public_numbers.parameter_numbers.p,
DSA_KEY_1024.public_numbers.parameter_numbers.q,
DSA_KEY_1024.public_numbers.parameter_numbers.g,
DSA_KEY_1024.public_numbers.parameter_numbers.p + 1,
),
(
DSA_KEY_1024.public_numbers.parameter_numbers.p,
DSA_KEY_1024.public_numbers.parameter_numbers.q,
DSA_KEY_1024.public_numbers.parameter_numbers.g,
DSA_KEY_1024.public_numbers.parameter_numbers.p - 1,
),
],
)
def test_invalid_dsa_public_key_arguments(self, p, q, g, y, backend):
Expand Down Expand Up @@ -478,12 +496,17 @@ def test_dsa_verification(self, backend, subtests):
backend, algorithm, vector["p"], vector["q"], vector["g"]
)

public_key = dsa.DSAPublicNumbers(
parameter_numbers=dsa.DSAParameterNumbers(
vector["p"], vector["q"], vector["g"]
),
y=vector["y"],
).public_key(backend)
try:
public_key = dsa.DSAPublicNumbers(
parameter_numbers=dsa.DSAParameterNumbers(
vector["p"], vector["q"], vector["g"]
),
y=vector["y"],
).public_key(backend)
except ValueError:
assert vector["result"] == "F"
continue

sig = encode_dss_signature(vector["r"], vector["s"])

if vector["result"] == "F":
Expand Down
2 changes: 2 additions & 0 deletions tests/hazmat/primitives/test_rsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -2328,7 +2328,9 @@ def test_private_numbers_invalid_types(
@pytest.mark.parametrize(
("e", "n"),
[
(-1, 15), # public_exponent < 3
(7, 2), # modulus < 3
(7, -1), # modulus < 3
(1, 15), # public_exponent < 3
(17, 15), # public_exponent > modulus
(14, 15), # public_exponent not odd
Expand Down
Loading
Loading