Skip to content

Commit 31bbc40

Browse files
committed
fix: fix decimal.Context.status race in free-threading
1 parent 26696a6 commit 31bbc40

3 files changed

Lines changed: 81 additions & 4 deletions

File tree

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import unittest
2+
3+
from test.support import threading_helper
4+
5+
import decimal
6+
7+
N_THREADS = 8
8+
ITERATIONS = 100_000
9+
10+
@threading_helper.requires_working_threading()
11+
class TestDecimal(unittest.TestCase):
12+
def test_add_status(self):
13+
# prec=4 makes "1.23456" Inexact|Rounded
14+
shared_ctx = decimal.Context(prec=4)
15+
def worker():
16+
for _ in range(ITERATIONS):
17+
shared_ctx.create_decimal("1.23456")
18+
threading_helper.run_concurrently(worker, N_THREADS)
19+
20+
21+
def test_clear_flags(self):
22+
shared_ctx = decimal.Context(prec=4)
23+
24+
def producer():
25+
for _ in range(ITERATIONS):
26+
shared_ctx.create_decimal("1.23456")
27+
28+
def clearer():
29+
for _ in range(ITERATIONS):
30+
shared_ctx.clear_flags()
31+
32+
threading_helper.run_concurrently([producer]*N_THREADS + [clearer]*N_THREADS)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix a data race in :class:`decimal.Context` status flag updates in the
2+
:term:`free-threaded build`.

Modules/_decimal/_decimal.c

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ typedef struct {
218218

219219
typedef struct PyDecContextObject {
220220
PyObject_HEAD
221+
PyMutex ctx_lock;
221222
mpd_context_t ctx;
222223
PyObject *traps;
223224
PyObject *flags;
@@ -227,6 +228,7 @@ typedef struct PyDecContextObject {
227228
} PyDecContextObject;
228229

229230
#define _PyDecContextObject_CAST(op) ((PyDecContextObject *)(op))
231+
#define _PyDecContextObject_LOCKED_CAST(op) ((PyDecContextObject *)(op))
230232

231233
typedef struct {
232234
PyObject_HEAD
@@ -246,6 +248,7 @@ typedef struct {
246248
#define SdFlagAddr(v) (_PyDecSignalDictObject_CAST(v)->flags)
247249
#define SdFlags(v) (*_PyDecSignalDictObject_CAST(v)->flags)
248250
#define CTX(v) (&_PyDecContextObject_CAST(v)->ctx)
251+
#define CTX_LOCK(v) (&_PyDecContextObject_CAST(v)->ctx_lock)
249252
#define CtxCaps(v) (_PyDecContextObject_CAST(v)->capitals)
250253

251254
static inline decimal_state *
@@ -611,9 +614,12 @@ static int
611614
dec_addstatus(PyObject *context, uint32_t status)
612615
{
613616
mpd_context_t *ctx = CTX(context);
617+
PyMutex* ctx_lock = CTX_LOCK(context);
614618
decimal_state *state = get_module_state_from_ctx(context);
615619

620+
PyMutex_Lock(ctx_lock);
616621
ctx->status |= status;
622+
PyMutex_Unlock(ctx_lock);
617623
if (status & (ctx->traps|MPD_Malloc_error)) {
618624
PyObject *ex, *siglist;
619625

@@ -1418,7 +1424,10 @@ static PyObject *
14181424
_decimal_Context_clear_flags_impl(PyObject *self)
14191425
/*[clinic end generated code: output=c86719a70177d0b6 input=a06055e2f3e7edb1]*/
14201426
{
1427+
PyMutex* ctx_lock = CTX_LOCK(self);
1428+
PyMutex_Lock(ctx_lock);
14211429
CTX(self)->status = 0;
1430+
PyMutex_Unlock(ctx_lock);
14221431
Py_RETURN_NONE;
14231432
}
14241433

@@ -1437,6 +1446,7 @@ context_new(PyTypeObject *type,
14371446
{
14381447
PyDecContextObject *self = NULL;
14391448
mpd_context_t *ctx;
1449+
PyMutex* ctx_lock;
14401450

14411451
decimal_state *state = get_module_state_by_def(type);
14421452
if (type == state->PyDecContext_Type) {
@@ -1449,6 +1459,7 @@ context_new(PyTypeObject *type,
14491459
if (self == NULL) {
14501460
return NULL;
14511461
}
1462+
self->ctx_lock = (PyMutex){0};
14521463

14531464
self->traps = PyObject_CallObject((PyObject *)state->PyDecSignalDict_Type, NULL);
14541465
if (self->traps == NULL) {
@@ -1471,8 +1482,12 @@ context_new(PyTypeObject *type,
14711482
*ctx = dflt_ctx;
14721483
}
14731484

1485+
ctx_lock = CTX_LOCK(self);
1486+
14741487
SdFlagAddr(self->traps) = &ctx->traps;
1488+
PyMutex_Lock(ctx_lock);
14751489
SdFlagAddr(self->flags) = &ctx->status;
1490+
PyMutex_Unlock(ctx_lock);
14761491

14771492
CtxCaps(self) = 1;
14781493
self->tstate = NULL;
@@ -1556,6 +1571,7 @@ static PyObject *
15561571
context_repr(PyObject *self)
15571572
{
15581573
mpd_context_t *ctx;
1574+
PyMutex* ctx_lock;
15591575
char flags[MPD_MAX_SIGNAL_LIST];
15601576
char traps[MPD_MAX_SIGNAL_LIST];
15611577
int n, mem;
@@ -1566,8 +1582,13 @@ context_repr(PyObject *self)
15661582
#endif
15671583
ctx = CTX(self);
15681584

1585+
ctx_lock = CTX_LOCK(self);
1586+
PyMutex_Lock(ctx_lock);
1587+
uint32_t ctx_status = ctx->status;
1588+
PyMutex_Unlock(ctx_lock);
1589+
15691590
mem = MPD_MAX_SIGNAL_LIST;
1570-
n = mpd_lsnprint_signals(flags, mem, ctx->status, dec_signal_string);
1591+
n = mpd_lsnprint_signals(flags, mem, ctx_status, dec_signal_string);
15711592
if (n < 0 || n >= mem) {
15721593
INTERNAL_ERROR_PTR("context_repr");
15731594
}
@@ -1594,6 +1615,7 @@ init_basic_context(PyObject *v)
15941615
ctx.round = MPD_ROUND_HALF_UP;
15951616

15961617
*CTX(v) = ctx;
1618+
*CTX_LOCK(v) = (PyMutex){0};
15971619
CtxCaps(v) = 1;
15981620
}
15991621

@@ -1606,6 +1628,7 @@ init_extended_context(PyObject *v)
16061628
ctx.traps = 0;
16071629

16081630
*CTX(v) = ctx;
1631+
*CTX_LOCK(v) = (PyMutex){0};
16091632
CtxCaps(v) = 1;
16101633
}
16111634

@@ -1715,11 +1738,16 @@ _decimal_Context___reduce___impl(PyObject *self, PyTypeObject *cls)
17151738
PyObject *traps;
17161739
PyObject *ret;
17171740
mpd_context_t *ctx;
1741+
PyMutex* ctx_lock;
17181742
decimal_state *state = PyType_GetModuleState(cls);
17191743

17201744
ctx = CTX(self);
1745+
ctx_lock = CTX_LOCK(self);
17211746

1722-
flags = signals_as_list(state, ctx->status);
1747+
PyMutex_Lock(ctx_lock);
1748+
uint32_t ctx_status = ctx->status;
1749+
PyMutex_Unlock(ctx_lock);
1750+
flags = signals_as_list(state, ctx_status);
17231751
if (flags == NULL) {
17241752
return NULL;
17251753
}
@@ -1917,11 +1945,17 @@ PyDec_SetCurrentContext(PyObject *self, PyObject *v)
19171945
static PyObject *
19181946
init_current_context(decimal_state *state)
19191947
{
1948+
mpd_context_t* ctx;
1949+
PyMutex* ctx_lock;
19201950
PyObject *tl_context = context_copy(state, state->default_context_template);
19211951
if (tl_context == NULL) {
19221952
return NULL;
19231953
}
1924-
CTX(tl_context)->status = 0;
1954+
ctx = CTX(tl_context);
1955+
ctx_lock = CTX_LOCK(tl_context);
1956+
PyMutex_Lock(ctx_lock);
1957+
ctx->status = 0;
1958+
PyMutex_Unlock(ctx_lock);
19251959

19261960
PyObject *tok = PyContextVar_Set(state->current_context_var, tl_context);
19271961
if (tok == NULL) {
@@ -1982,7 +2016,11 @@ PyDec_SetCurrentContext(PyObject *self, PyObject *v)
19822016
if (v == NULL) {
19832017
return NULL;
19842018
}
1985-
CTX(v)->status = 0;
2019+
mpd_context_t* ctx = CTX(v);
2020+
PyMutex* ctx_lock = CTX_LOCK(v);
2021+
PyMutex_Lock(ctx_lock);
2022+
ctx->status = 0;
2023+
PyMutex_Unlock(ctx_lock);
19862024
}
19872025
else {
19882026
Py_INCREF(v);
@@ -3479,6 +3517,7 @@ convert_op_cmp(PyObject **vcmp, PyObject **wcmp, PyObject *v, PyObject *w,
34793517
int op, PyObject *context)
34803518
{
34813519
mpd_context_t *ctx = CTX(context);
3520+
PyMutex* ctx_lock = CTX_LOCK(context);
34823521

34833522
*vcmp = v;
34843523

@@ -3495,7 +3534,9 @@ convert_op_cmp(PyObject **vcmp, PyObject **wcmp, PyObject *v, PyObject *w,
34953534
*wcmp = NULL;
34963535
}
34973536
else {
3537+
PyMutex_Lock(ctx_lock);
34983538
ctx->status |= MPD_Float_operation;
3539+
PyMutex_Unlock(ctx_lock);
34993540
*wcmp = PyDec_FromFloatExact(state, w, context);
35003541
}
35013542
}
@@ -3510,7 +3551,9 @@ convert_op_cmp(PyObject **vcmp, PyObject **wcmp, PyObject *v, PyObject *w,
35103551
*wcmp = NULL;
35113552
}
35123553
else {
3554+
PyMutex_Lock(ctx_lock);
35133555
ctx->status |= MPD_Float_operation;
3556+
PyMutex_Unlock(ctx_lock);
35143557
*wcmp = PyDec_FromFloatExact(state, tmp, context);
35153558
Py_DECREF(tmp);
35163559
}

0 commit comments

Comments
 (0)