Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 90 additions & 16 deletions Lib/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@
CAN_GET_SELECTED_OPENSSL_SIGALG = ssl.OPENSSL_VERSION_INFO >= (3, 5)
PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')

HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
requires_keylog = unittest.skipUnless(
HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback')
CAN_SET_KEYLOG = HAS_KEYLOG and os.name != "nt"
requires_keylog_setter = unittest.skipUnless(
CAN_SET_KEYLOG,
"cannot set 'keylog_filename' on Windows"
)


PROTOCOL_TO_TLS_VERSION = {}
for proto, ver in (
("PROTOCOL_SSLv3", "SSLv3"),
Expand Down Expand Up @@ -265,34 +275,69 @@ def utc_offset(): #NOTE: ignore issues like #1647654
)


def test_wrap_socket(sock, *,
cert_reqs=ssl.CERT_NONE, ca_certs=None,
ciphers=None, ciphersuites=None,
min_version=None, max_version=None,
certfile=None, keyfile=None,
**kwargs):
if not kwargs.get("server_side"):
kwargs["server_hostname"] = SIGNED_CERTFILE_HOSTNAME
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
else:
def make_test_context(
*,
server_side=False,
check_hostname=None,
cert_reqs=ssl.CERT_NONE,
ca_certs=None, certfile=None, keyfile=None,
ciphers=None, ciphersuites=None,
min_version=None, max_version=None,
):
if server_side:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
if cert_reqs is not None:
else:
context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)

if check_hostname is None:
if cert_reqs == ssl.CERT_NONE:
context.check_hostname = False
else:
context.check_hostname = check_hostname

if cert_reqs is not None:
context.verify_mode = cert_reqs

if ca_certs is not None:
context.load_verify_locations(ca_certs)
if certfile is not None or keyfile is not None:
context.load_cert_chain(certfile, keyfile)

if ciphers is not None:
context.set_ciphers(ciphers)
if ciphersuites is not None:
context.set_ciphersuites(ciphersuites)

if min_version is not None:
context.minimum_version = min_version
if max_version is not None:
context.maximum_version = max_version
return context.wrap_socket(sock, **kwargs)

return context


def test_wrap_socket(
sock,
*,
server_side=False,
check_hostname=None,
cert_reqs=ssl.CERT_NONE,
ca_certs=None, certfile=None, keyfile=None,
ciphers=None, ciphersuites=None,
min_version=None, max_version=None,
**kwargs,
):
context = make_test_context(
server_side=server_side,
check_hostname=check_hostname,
cert_reqs=cert_reqs,
ca_certs=ca_certs, certfile=certfile, keyfile=keyfile,
ciphers=ciphers, ciphersuites=ciphersuites,
min_version=min_version, max_version=max_version,
)
if not server_side:
kwargs.setdefault("server_hostname", SIGNED_CERTFILE_HOSTNAME)
return context.wrap_socket(sock, server_side=server_side, **kwargs)


USE_SAME_TEST_CONTEXT = False
Expand Down Expand Up @@ -1730,6 +1775,39 @@ def test_num_tickest(self):
with self.assertRaises(ValueError):
ctx.num_tickets = 1

@support.cpython_only
def test_refcycle_msg_callback(self):
# See https://github.com/python/cpython/issues/142516.
ctx = make_test_context()
def msg_callback(*args, _=ctx, **kwargs): ...
ctx._msg_callback = msg_callback

@support.cpython_only
@requires_keylog_setter
def test_refcycle_keylog_filename(self):
# See https://github.com/python/cpython/issues/142516.
self.addCleanup(os_helper.unlink, os_helper.TESTFN)
ctx = make_test_context()
class KeylogFilename(str): ...
ctx.keylog_filename = KeylogFilename(os_helper.TESTFN)
ctx.keylog_filename._ = ctx

@support.cpython_only
@unittest.skipUnless(ssl.HAS_PSK, 'requires TLS-PSK')
def test_refcycle_psk_client_callback(self):
# See https://github.com/python/cpython/issues/142516.
ctx = make_test_context()
def psk_client_callback(*args, _=ctx, **kwargs): ...
ctx.set_psk_client_callback(psk_client_callback)

@support.cpython_only
@unittest.skipUnless(ssl.HAS_PSK, 'requires TLS-PSK')
def test_refcycle_psk_server_callback(self):
# See https://github.com/python/cpython/issues/142516.
ctx = make_test_context(server_side=True)
def psk_server_callback(*args, _=ctx, **kwargs): ...
ctx.set_psk_server_callback(psk_server_callback)


class SSLErrorTests(unittest.TestCase):

Expand Down Expand Up @@ -5163,10 +5241,6 @@ def test_internal_chain_server(self):
self.assertEqual(res, b'\x02\n')


HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
requires_keylog = unittest.skipUnless(
HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback')

class TestSSLDebug(unittest.TestCase):

def keylog_lines(self, fname=os_helper.TESTFN):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
:mod:`ssl`: fix reference leaks in :class:`ssl.SSLContext` objects. Patch by
Bénédikt Tran.
18 changes: 12 additions & 6 deletions Modules/_ssl.c
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ typedef struct {
int post_handshake_auth;
#endif
PyObject *msg_cb;
PyObject *keylog_filename;
PyObject *keylog_filename; // can be anything accepted by Py_fopen()
BIO *keylog_bio;
/* Cached module state, also used in SSLSocket and SSLSession code. */
_sslmodulestate *state;
Expand Down Expand Up @@ -358,7 +358,7 @@ typedef struct {
PySSLContext *ctx; /* weakref to SSL context */
char shutdown_seen_zero;
enum py_ssl_server_or_client socket_type;
PyObject *owner; /* Python level "owner" passed to servername callback */
PyObject *owner; /* weakref to Python level "owner" passed to servername callback */
PyObject *server_hostname;
/* Some SSL callbacks don't have error reporting. Callback wrappers
* store exception information on the socket. The handshake, read, write,
Expand Down Expand Up @@ -2444,6 +2444,10 @@ static int
PySSL_clear(PyObject *op)
{
PySSLSocket *self = PySSLSocket_CAST(op);
Py_CLEAR(self->Socket);
Py_CLEAR(self->ctx);
Py_CLEAR(self->owner);
Py_CLEAR(self->server_hostname);
Py_CLEAR(self->exc);
return 0;
}
Expand All @@ -2468,10 +2472,7 @@ PySSL_dealloc(PyObject *op)
SSL_set_shutdown(self->ssl, SSL_SENT_SHUTDOWN | SSL_get_shutdown(self->ssl));
SSL_free(self->ssl);
}
Py_XDECREF(self->Socket);
Py_XDECREF(self->ctx);
Py_XDECREF(self->server_hostname);
Py_XDECREF(self->owner);
(void)PySSL_clear(op);
PyObject_GC_Del(self);
Py_DECREF(tp);
}
Expand Down Expand Up @@ -3594,6 +3595,11 @@ context_traverse(PyObject *op, visitproc visit, void *arg)
PySSLContext *self = PySSLContext_CAST(op);
Py_VISIT(self->set_sni_cb);
Py_VISIT(self->msg_cb);
Py_VISIT(self->keylog_filename);
#ifndef OPENSSL_NO_PSK
Py_VISIT(self->psk_client_callback);
Py_VISIT(self->psk_server_callback);
#endif
Py_VISIT(Py_TYPE(self));
return 0;
}
Expand Down
Loading