Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions Lib/test/test_free_threading/test_decimal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import unittest

from test.support import threading_helper

import decimal

N_THREADS = 8
ITERATIONS = 100_000

@threading_helper.requires_working_threading()
class TestDecimal(unittest.TestCase):
def test_add_status(self):
# prec=4 makes "1.23456" Inexact|Rounded
shared_ctx = decimal.Context(prec=4)
def worker():
for _ in range(ITERATIONS):
shared_ctx.create_decimal("1.23456")
threading_helper.run_concurrently(worker, N_THREADS)


def test_clear_flags(self):
shared_ctx = decimal.Context(prec=4)

def producer():
for _ in range(ITERATIONS):
shared_ctx.create_decimal("1.23456")

def clearer():
for _ in range(ITERATIONS):
shared_ctx.clear_flags()

threading_helper.run_concurrently([producer]*N_THREADS + [clearer]*N_THREADS)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Fix data races in :meth:`decimal.Context.create_decimal` and
:meth:`decimal.Context.clear_flags` when updating
:class:`decimal.Context` status flags in the :term:`free-threaded build`.
36 changes: 34 additions & 2 deletions Modules/_decimal/_decimal.c
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ typedef struct {

typedef struct PyDecContextObject {
PyObject_HEAD
#ifdef Py_GIL_DISABLED
PyMutex ctx_lock;
#endif
mpd_context_t ctx;
PyObject *traps;
PyObject *flags;
Expand Down Expand Up @@ -248,6 +251,16 @@ typedef struct {
#define CTX(v) (&_PyDecContextObject_CAST(v)->ctx)
#define CtxCaps(v) (_PyDecContextObject_CAST(v)->capitals)

#ifdef Py_GIL_DISABLED
#define CTX_LOCK_INIT(v) _PyDecContextObject_CAST(v)->ctx_lock = (PyMutex){0}
#define CTX_LOCK(v) PyMutex_Lock(&_PyDecContextObject_CAST(v)->ctx_lock)
#define CTX_UNLOCK(v) PyMutex_Unlock(&_PyDecContextObject_CAST(v)->ctx_lock)
#else
#define CTX_LOCK_INIT(v) ((void)0)
#define CTX_LOCK(v) ((void)0)
#define CTX_UNLOCK(v) ((void)0)
#endif

static inline decimal_state *
get_module_state_from_ctx(PyObject *v)
{
Expand Down Expand Up @@ -613,7 +626,9 @@ dec_addstatus(PyObject *context, uint32_t status)
mpd_context_t *ctx = CTX(context);
decimal_state *state = get_module_state_from_ctx(context);

CTX_LOCK(context);
ctx->status |= status;
CTX_UNLOCK(context);
if (status & (ctx->traps|MPD_Malloc_error)) {
PyObject *ex, *siglist;

Expand Down Expand Up @@ -1418,7 +1433,9 @@ static PyObject *
_decimal_Context_clear_flags_impl(PyObject *self)
/*[clinic end generated code: output=c86719a70177d0b6 input=a06055e2f3e7edb1]*/
{
CTX_LOCK(self);
CTX(self)->status = 0;
CTX_UNLOCK(self);
Py_RETURN_NONE;
}

Expand Down Expand Up @@ -1449,6 +1466,7 @@ context_new(PyTypeObject *type,
if (self == NULL) {
return NULL;
}
CTX_LOCK_INIT(self);

self->traps = PyObject_CallObject((PyObject *)state->PyDecSignalDict_Type, NULL);
if (self->traps == NULL) {
Expand Down Expand Up @@ -1566,8 +1584,12 @@ context_repr(PyObject *self)
#endif
ctx = CTX(self);

CTX_LOCK(self);
uint32_t ctx_status = ctx->status;
CTX_UNLOCK(self);

mem = MPD_MAX_SIGNAL_LIST;
n = mpd_lsnprint_signals(flags, mem, ctx->status, dec_signal_string);
n = mpd_lsnprint_signals(flags, mem, ctx_status, dec_signal_string);
if (n < 0 || n >= mem) {
INTERNAL_ERROR_PTR("context_repr");
}
Expand All @@ -1594,6 +1616,7 @@ init_basic_context(PyObject *v)
ctx.round = MPD_ROUND_HALF_UP;

*CTX(v) = ctx;
CTX_LOCK_INIT(v);
CtxCaps(v) = 1;
}

Expand All @@ -1606,6 +1629,7 @@ init_extended_context(PyObject *v)
ctx.traps = 0;

*CTX(v) = ctx;
CTX_LOCK_INIT(v);
CtxCaps(v) = 1;
}

Expand Down Expand Up @@ -1719,7 +1743,11 @@ _decimal_Context___reduce___impl(PyObject *self, PyTypeObject *cls)

ctx = CTX(self);

flags = signals_as_list(state, ctx->status);
CTX_LOCK(self);
uint32_t ctx_status = ctx->status;
CTX_UNLOCK(self);

flags = signals_as_list(state, ctx_status);
if (flags == NULL) {
return NULL;
}
Expand Down Expand Up @@ -3495,7 +3523,9 @@ convert_op_cmp(PyObject **vcmp, PyObject **wcmp, PyObject *v, PyObject *w,
*wcmp = NULL;
}
else {
CTX_LOCK(context);
ctx->status |= MPD_Float_operation;
CTX_UNLOCK(context);
*wcmp = PyDec_FromFloatExact(state, w, context);
}
}
Expand All @@ -3510,7 +3540,9 @@ convert_op_cmp(PyObject **vcmp, PyObject **wcmp, PyObject *v, PyObject *w,
*wcmp = NULL;
}
else {
CTX_LOCK(context);
ctx->status |= MPD_Float_operation;
CTX_UNLOCK(context);
*wcmp = PyDec_FromFloatExact(state, tmp, context);
Py_DECREF(tmp);
}
Expand Down
Loading