Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions compiler/cpp/src/thrift/generate/t_py_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,9 @@ void t_py_generator::generate_py_struct_reader(ostream& out, t_struct* tstruct)
}
indent_down();

indent(out) << "iprot.increment_recursion_depth()" << '\n';
indent(out) << "try:" << '\n';
indent_up();
indent(out) << "iprot.readStructBegin()" << '\n';

if (is_immutable(tstruct)) {
Expand Down Expand Up @@ -1193,6 +1196,12 @@ void t_py_generator::generate_py_struct_reader(ostream& out, t_struct* tstruct)
indent(out) << ")" << '\n';
}

indent_down();
indent(out) << "finally:" << '\n';
indent_up();
indent(out) << "iprot.decrement_recursion_depth()" << '\n';
indent_down();

indent_down();
out << '\n';
}
Expand All @@ -1214,6 +1223,9 @@ void t_py_generator::generate_py_struct_writer(ostream& out, t_struct* tstruct)
indent(out) << "return" << '\n';
indent_down();

indent(out) << "oprot.increment_recursion_depth()" << '\n';
indent(out) << "try:" << '\n';
indent_up();
indent(out) << "oprot.writeStructBegin('" << name << "')" << '\n';

for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
Expand All @@ -1237,6 +1249,12 @@ void t_py_generator::generate_py_struct_writer(ostream& out, t_struct* tstruct)
out << indent() << "oprot.writeFieldStop()" << '\n' << indent() << "oprot.writeStructEnd()"
<< '\n';

indent_down();
indent(out) << "finally:" << '\n';
indent_up();
indent(out) << "oprot.decrement_recursion_depth()" << '\n';
indent_down();

out << '\n';

indent_down();
Expand Down
3 changes: 3 additions & 0 deletions lib/py/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ py3-test: py3-build
$(PYTHON3) test/thrift_TCompactProtocol.py
$(PYTHON3) test/thrift_TNonblockingServer.py
$(PYTHON3) test/thrift_TSerializer.py
$(PYTHON3) test/test_recursion_depth.py
else
py3-build:
py3-test:
Expand All @@ -40,6 +41,7 @@ all-local: py3-build
$(PYTHON) setup.py build
${THRIFT} --gen py test/test_thrift_file/TestServer.thrift
${THRIFT} --gen py ../../test/v0.16/FuzzTestNoUuid.thrift
${THRIFT} --gen py ../../test/Recursive.thrift

# We're ignoring prefix here because site-packages seems to be
# the equivalent of /usr/local/lib in Python land.
Expand All @@ -59,6 +61,7 @@ check-local: all py3-test
$(PYTHON) test/thrift_TNonblockingServer.py
$(PYTHON) test/thrift_TSerializer.py
$(PYTHON) test/test_compiler/test_keyword_escape.py
$(PYTHON) test/test_recursion_depth.py


clean-local:
Expand Down
12 changes: 9 additions & 3 deletions lib/py/src/ext/protocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class ProtocolBase {
ProtocolBase()
: stringLimit_((std::numeric_limits<int32_t>::max)()),
containerLimit_((std::numeric_limits<int32_t>::max)()),
recursionDepth_(0),
output_(nullptr) {}
inline virtual ~ProtocolBase();

Expand All @@ -54,6 +55,10 @@ class ProtocolBase {
long containerLimit() const { return containerLimit_; }
void setContainerLengthLimit(long limit) { containerLimit_ = limit; }

static const int32_t kDefaultRecursionDepth = 64;
bool checkDepthLimit();
void decrementDepth() { recursionDepth_--; }

protected:
bool readBytes(char** output, int len);

Expand Down Expand Up @@ -84,12 +89,13 @@ class ProtocolBase {

long stringLimit_;
long containerLimit_;
int32_t recursionDepth_;
EncodeBuffer* output_;
DecodeBuffer input_;
};
}
}
}
} // namespace py
} // namespace thrift
} // namespace apache

#include "ext/protocol.tcc"

Expand Down
77 changes: 62 additions & 15 deletions lib/py/src/ext/protocol.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ inline int read_buffer(PyObject* buf, char** output, int len) {
}
return PycStringIO->cread(buf, output, len);
}
}
} // namespace detail

template <typename Impl>
inline ProtocolBase<Impl>::~ProtocolBase() {
Expand Down Expand Up @@ -147,7 +147,7 @@ inline int read_buffer(PyObject* buf, char** output, int len) {
buf2->pos = (std::min)(buf2->pos + static_cast<Py_ssize_t>(len), buf2->string_size);
return static_cast<int>(buf2->pos - pos0);
}
}
} // namespace detail

template <typename Impl>
inline ProtocolBase<Impl>::~ProtocolBase() {
Expand Down Expand Up @@ -207,6 +207,18 @@ DECLARE_OP_SCOPE(WriteStruct, writeStruct)
DECLARE_OP_SCOPE(ReadStruct, readStruct)
#undef DECLARE_OP_SCOPE

template <typename Impl>
struct RecursionGuard {
ProtocolBase<Impl>* proto;
bool valid;
explicit RecursionGuard(ProtocolBase<Impl>* p) : proto(p), valid(p->checkDepthLimit()) {}
~RecursionGuard() {
if (valid)
proto->decrementDepth();
}
operator bool() const { return valid; }
};

inline bool check_ssize_t_32(Py_ssize_t len) {
// error from getting the int
if (INT_CONV_ERROR_OCCURRED(len)) {
Expand All @@ -218,7 +230,7 @@ inline bool check_ssize_t_32(Py_ssize_t len) {
}
return true;
}
}
} // namespace detail

template <typename T>
bool parse_pyint(PyObject* o, T* ret, int32_t min, int32_t max) {
Expand Down Expand Up @@ -258,6 +270,31 @@ bool ProtocolBase<Impl>::checkLengthLimit(int32_t len, long limit) {
return true;
}

template <typename Impl>
bool ProtocolBase<Impl>::checkDepthLimit() {
recursionDepth_++;
if (recursionDepth_ > kDefaultRecursionDepth) {
recursionDepth_--;
static PyObject* TProtocolExceptionCls = nullptr;
if (!TProtocolExceptionCls) {
PyObject* mod = PyImport_ImportModule("thrift.protocol.TProtocol");
if (!mod)
return false;
TProtocolExceptionCls = PyObject_GetAttrString(mod, "TProtocolException");
Py_DECREF(mod);
if (!TProtocolExceptionCls)
return false;
}
ScopedPyObject exc(
PyObject_CallFunction(TProtocolExceptionCls, "is", 6, "Maximum recursion depth exceeded"));
if (!exc)
return false;
PyErr_SetObject(TProtocolExceptionCls, exc.get());
return false;
}
return true;
}

template <typename Impl>
bool ProtocolBase<Impl>::readBytes(char** output, int len) {
if (len < 0) {
Expand Down Expand Up @@ -502,6 +539,11 @@ bool ProtocolBase<Impl>::encodeValue(PyObject* value, TType type, PyObject* type
return false;
}

detail::RecursionGuard<Impl> rec(this);
if (!rec) {
return false;
}

Py_ssize_t nspec = PyTuple_Size(parsedargs.spec);
if (nspec == -1) {
PyErr_SetString(PyExc_TypeError, "spec is not a tuple");
Expand Down Expand Up @@ -545,17 +587,17 @@ bool ProtocolBase<Impl>::encodeValue(PyObject* value, TType type, PyObject* type
case T_UUID: {
ScopedPyObject instval(PyObject_GetAttr(value, INTERN_STRING(bytes)));
if (!instval) {
return false;
return false;
}

Py_ssize_t size;
char* buffer;
if (PyBytes_AsStringAndSize(instval.get(), &buffer, &size) < 0) {
return false;
return false;
}
if (size != 16) {
PyErr_SetString(PyExc_TypeError, "uuid.bytes must be exactly 16 bytes long");
return false;
PyErr_SetString(PyExc_TypeError, "uuid.bytes must be exactly 16 bytes long");
return false;
}
impl()->writeUuid(buffer);
return true;
Expand Down Expand Up @@ -836,24 +878,24 @@ PyObject* ProtocolBase<Impl>::decodeValue(TType type, PyObject* typeargs) {

case T_UUID: {
char* buf = nullptr;
if(!impl()->readUuid(&buf)) {
if (!impl()->readUuid(&buf)) {
return nullptr;
}

if(!UuidModule) {
if (!UuidModule) {
UuidModule = PyImport_ImportModule("uuid");
if (!UuidModule)
return nullptr;
}

ScopedPyObject cls(PyObject_GetAttr(UuidModule, INTERN_STRING(UUID)));
if (!cls) {
return nullptr;
return nullptr;
}

ScopedPyObject pyBytes(PyBytes_FromStringAndSize(buf, 16));
if (!pyBytes) {
return nullptr;
return nullptr;
}

ScopedPyObject args(PyTuple_New(0));
Expand Down Expand Up @@ -900,7 +942,7 @@ PyObject* ProtocolBase<Impl>::readStruct(PyObject* output, PyObject* klass, PyOb
// 1. "frozen2" mode: classes inherit from TFrozenBase
// 2. "python.immutable" annotation: classes get a __setattr__ that raises TypeError
immutable = PyObject_IsSubclass(klass, TFrozenBase)
|| reinterpret_cast<PyTypeObject*>(klass)->tp_setattro != PyObject_GenericSetAttr;
|| reinterpret_cast<PyTypeObject*>(klass)->tp_setattro != PyObject_GenericSetAttr;

if (immutable) {
kwargs.reset(PyDict_New());
Expand All @@ -917,6 +959,11 @@ PyObject* ProtocolBase<Impl>::readStruct(PyObject* output, PyObject* klass, PyOb
}
}

detail::RecursionGuard<Impl> rec(this);
if (!rec) {
return nullptr;
}

detail::ReadStructScope<Impl> scope = detail::readStructScope(this);
if (!scope) {
return nullptr;
Expand Down Expand Up @@ -980,7 +1027,7 @@ PyObject* ProtocolBase<Impl>::readStruct(PyObject* output, PyObject* klass, PyOb
Py_INCREF(output);
return output;
}
}
}
}
} // namespace py
} // namespace thrift
} // namespace apache
#endif // THRIFT_PY_PROTOCOL_H
12 changes: 12 additions & 0 deletions lib/py/src/protocol/TProtocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,22 @@ def __init__(self, type=UNKNOWN, message=None):
class TProtocolBase(object):
"""Base class for Thrift protocol driver."""

DEFAULT_RECURSION_DEPTH = 64

def __init__(self, trans):
self.trans = trans
self._fast_decode = None
self._fast_encode = None
self._recursion_depth = 0

def increment_recursion_depth(self):
self._recursion_depth += 1
if self._recursion_depth > self.DEFAULT_RECURSION_DEPTH:
raise TProtocolException(TProtocolException.DEPTH_LIMIT,
"Maximum recursion depth exceeded")

def decrement_recursion_depth(self):
self._recursion_depth -= 1

@staticmethod
def _check_length(limit, length):
Expand Down
Loading
Loading