diff --git a/compiler/cpp/src/thrift/generate/t_py_generator.cc b/compiler/cpp/src/thrift/generate/t_py_generator.cc index 02b5c4e99f4..ee0f4c51917 100644 --- a/compiler/cpp/src/thrift/generate/t_py_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_py_generator.cc @@ -1119,6 +1119,9 @@ void t_py_generator::generate_py_struct_reader(ostream& out, t_struct* tstruct) } indent_down(); + indent(out) << "iprot.increment_recursion_depth()" << '\n'; + indent(out) << "try:" << '\n'; + indent_up(); indent(out) << "iprot.readStructBegin()" << '\n'; if (is_immutable(tstruct)) { @@ -1193,6 +1196,12 @@ void t_py_generator::generate_py_struct_reader(ostream& out, t_struct* tstruct) indent(out) << ")" << '\n'; } + indent_down(); + indent(out) << "finally:" << '\n'; + indent_up(); + indent(out) << "iprot.decrement_recursion_depth()" << '\n'; + indent_down(); + indent_down(); out << '\n'; } @@ -1214,6 +1223,9 @@ void t_py_generator::generate_py_struct_writer(ostream& out, t_struct* tstruct) indent(out) << "return" << '\n'; indent_down(); + indent(out) << "oprot.increment_recursion_depth()" << '\n'; + indent(out) << "try:" << '\n'; + indent_up(); indent(out) << "oprot.writeStructBegin('" << name << "')" << '\n'; for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) { @@ -1237,6 +1249,12 @@ void t_py_generator::generate_py_struct_writer(ostream& out, t_struct* tstruct) out << indent() << "oprot.writeFieldStop()" << '\n' << indent() << "oprot.writeStructEnd()" << '\n'; + indent_down(); + indent(out) << "finally:" << '\n'; + indent_up(); + indent(out) << "oprot.decrement_recursion_depth()" << '\n'; + indent_down(); + out << '\n'; indent_down(); diff --git a/lib/py/Makefile.am b/lib/py/Makefile.am index 81170e2165e..9932d89ffd5 100644 --- a/lib/py/Makefile.am +++ b/lib/py/Makefile.am @@ -31,6 +31,7 @@ py3-test: py3-build $(PYTHON3) test/thrift_TCompactProtocol.py $(PYTHON3) test/thrift_TNonblockingServer.py $(PYTHON3) test/thrift_TSerializer.py + $(PYTHON3) test/test_recursion_depth.py else py3-build: py3-test: @@ -40,6 +41,7 @@ all-local: py3-build $(PYTHON) setup.py build ${THRIFT} --gen py test/test_thrift_file/TestServer.thrift ${THRIFT} --gen py ../../test/v0.16/FuzzTestNoUuid.thrift + ${THRIFT} --gen py ../../test/Recursive.thrift # We're ignoring prefix here because site-packages seems to be # the equivalent of /usr/local/lib in Python land. @@ -59,6 +61,7 @@ check-local: all py3-test $(PYTHON) test/thrift_TNonblockingServer.py $(PYTHON) test/thrift_TSerializer.py $(PYTHON) test/test_compiler/test_keyword_escape.py + $(PYTHON) test/test_recursion_depth.py clean-local: diff --git a/lib/py/src/ext/protocol.h b/lib/py/src/ext/protocol.h index c0cd43724ac..20911c89724 100644 --- a/lib/py/src/ext/protocol.h +++ b/lib/py/src/ext/protocol.h @@ -35,6 +35,7 @@ class ProtocolBase { ProtocolBase() : stringLimit_((std::numeric_limits::max)()), containerLimit_((std::numeric_limits::max)()), + recursionDepth_(0), output_(nullptr) {} inline virtual ~ProtocolBase(); @@ -54,6 +55,10 @@ class ProtocolBase { long containerLimit() const { return containerLimit_; } void setContainerLengthLimit(long limit) { containerLimit_ = limit; } + static const int32_t kDefaultRecursionDepth = 64; + bool checkDepthLimit(); + void decrementDepth() { recursionDepth_--; } + protected: bool readBytes(char** output, int len); @@ -84,12 +89,13 @@ class ProtocolBase { long stringLimit_; long containerLimit_; + int32_t recursionDepth_; EncodeBuffer* output_; DecodeBuffer input_; }; -} -} -} +} // namespace py +} // namespace thrift +} // namespace apache #include "ext/protocol.tcc" diff --git a/lib/py/src/ext/protocol.tcc b/lib/py/src/ext/protocol.tcc index 123ca69ea83..448fc6f105e 100644 --- a/lib/py/src/ext/protocol.tcc +++ b/lib/py/src/ext/protocol.tcc @@ -63,7 +63,7 @@ inline int read_buffer(PyObject* buf, char** output, int len) { } return PycStringIO->cread(buf, output, len); } -} +} // namespace detail template inline ProtocolBase::~ProtocolBase() { @@ -147,7 +147,7 @@ inline int read_buffer(PyObject* buf, char** output, int len) { buf2->pos = (std::min)(buf2->pos + static_cast(len), buf2->string_size); return static_cast(buf2->pos - pos0); } -} +} // namespace detail template inline ProtocolBase::~ProtocolBase() { @@ -207,6 +207,18 @@ DECLARE_OP_SCOPE(WriteStruct, writeStruct) DECLARE_OP_SCOPE(ReadStruct, readStruct) #undef DECLARE_OP_SCOPE +template +struct RecursionGuard { + ProtocolBase* proto; + bool valid; + explicit RecursionGuard(ProtocolBase* p) : proto(p), valid(p->checkDepthLimit()) {} + ~RecursionGuard() { + if (valid) + proto->decrementDepth(); + } + operator bool() const { return valid; } +}; + inline bool check_ssize_t_32(Py_ssize_t len) { // error from getting the int if (INT_CONV_ERROR_OCCURRED(len)) { @@ -218,7 +230,7 @@ inline bool check_ssize_t_32(Py_ssize_t len) { } return true; } -} +} // namespace detail template bool parse_pyint(PyObject* o, T* ret, int32_t min, int32_t max) { @@ -258,6 +270,31 @@ bool ProtocolBase::checkLengthLimit(int32_t len, long limit) { return true; } +template +bool ProtocolBase::checkDepthLimit() { + recursionDepth_++; + if (recursionDepth_ > kDefaultRecursionDepth) { + recursionDepth_--; + static PyObject* TProtocolExceptionCls = nullptr; + if (!TProtocolExceptionCls) { + PyObject* mod = PyImport_ImportModule("thrift.protocol.TProtocol"); + if (!mod) + return false; + TProtocolExceptionCls = PyObject_GetAttrString(mod, "TProtocolException"); + Py_DECREF(mod); + if (!TProtocolExceptionCls) + return false; + } + ScopedPyObject exc( + PyObject_CallFunction(TProtocolExceptionCls, "is", 6, "Maximum recursion depth exceeded")); + if (!exc) + return false; + PyErr_SetObject(TProtocolExceptionCls, exc.get()); + return false; + } + return true; +} + template bool ProtocolBase::readBytes(char** output, int len) { if (len < 0) { @@ -502,6 +539,11 @@ bool ProtocolBase::encodeValue(PyObject* value, TType type, PyObject* type return false; } + detail::RecursionGuard rec(this); + if (!rec) { + return false; + } + Py_ssize_t nspec = PyTuple_Size(parsedargs.spec); if (nspec == -1) { PyErr_SetString(PyExc_TypeError, "spec is not a tuple"); @@ -545,17 +587,17 @@ bool ProtocolBase::encodeValue(PyObject* value, TType type, PyObject* type case T_UUID: { ScopedPyObject instval(PyObject_GetAttr(value, INTERN_STRING(bytes))); if (!instval) { - return false; + return false; } Py_ssize_t size; char* buffer; if (PyBytes_AsStringAndSize(instval.get(), &buffer, &size) < 0) { - return false; + return false; } if (size != 16) { - PyErr_SetString(PyExc_TypeError, "uuid.bytes must be exactly 16 bytes long"); - return false; + PyErr_SetString(PyExc_TypeError, "uuid.bytes must be exactly 16 bytes long"); + return false; } impl()->writeUuid(buffer); return true; @@ -836,11 +878,11 @@ PyObject* ProtocolBase::decodeValue(TType type, PyObject* typeargs) { case T_UUID: { char* buf = nullptr; - if(!impl()->readUuid(&buf)) { + if (!impl()->readUuid(&buf)) { return nullptr; } - if(!UuidModule) { + if (!UuidModule) { UuidModule = PyImport_ImportModule("uuid"); if (!UuidModule) return nullptr; @@ -848,12 +890,12 @@ PyObject* ProtocolBase::decodeValue(TType type, PyObject* typeargs) { ScopedPyObject cls(PyObject_GetAttr(UuidModule, INTERN_STRING(UUID))); if (!cls) { - return nullptr; + return nullptr; } ScopedPyObject pyBytes(PyBytes_FromStringAndSize(buf, 16)); if (!pyBytes) { - return nullptr; + return nullptr; } ScopedPyObject args(PyTuple_New(0)); @@ -900,7 +942,7 @@ PyObject* ProtocolBase::readStruct(PyObject* output, PyObject* klass, PyOb // 1. "frozen2" mode: classes inherit from TFrozenBase // 2. "python.immutable" annotation: classes get a __setattr__ that raises TypeError immutable = PyObject_IsSubclass(klass, TFrozenBase) - || reinterpret_cast(klass)->tp_setattro != PyObject_GenericSetAttr; + || reinterpret_cast(klass)->tp_setattro != PyObject_GenericSetAttr; if (immutable) { kwargs.reset(PyDict_New()); @@ -917,6 +959,11 @@ PyObject* ProtocolBase::readStruct(PyObject* output, PyObject* klass, PyOb } } + detail::RecursionGuard rec(this); + if (!rec) { + return nullptr; + } + detail::ReadStructScope scope = detail::readStructScope(this); if (!scope) { return nullptr; @@ -980,7 +1027,7 @@ PyObject* ProtocolBase::readStruct(PyObject* output, PyObject* klass, PyOb Py_INCREF(output); return output; } -} -} -} +} // namespace py +} // namespace thrift +} // namespace apache #endif // THRIFT_PY_PROTOCOL_H diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py index a32e7778721..c0e5d7cb76b 100644 --- a/lib/py/src/protocol/TProtocol.py +++ b/lib/py/src/protocol/TProtocol.py @@ -43,10 +43,22 @@ def __init__(self, type=UNKNOWN, message=None): class TProtocolBase(object): """Base class for Thrift protocol driver.""" + DEFAULT_RECURSION_DEPTH = 64 + def __init__(self, trans): self.trans = trans self._fast_decode = None self._fast_encode = None + self._recursion_depth = 0 + + def increment_recursion_depth(self): + self._recursion_depth += 1 + if self._recursion_depth > self.DEFAULT_RECURSION_DEPTH: + raise TProtocolException(TProtocolException.DEPTH_LIMIT, + "Maximum recursion depth exceeded") + + def decrement_recursion_depth(self): + self._recursion_depth -= 1 @staticmethod def _check_length(limit, length): diff --git a/lib/py/test/test_recursion_depth.py b/lib/py/test/test_recursion_depth.py new file mode 100644 index 00000000000..fd54b1ed194 --- /dev/null +++ b/lib/py/test/test_recursion_depth.py @@ -0,0 +1,177 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# Round-trip tests for the struct/exception read/write recursion depth limit. +# Covers the pure-Python path and, when fastbinary is available, the C extension. +# + +import os +import sys +import unittest + +gen_path = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "gen-py" +) +sys.path.insert(0, gen_path) + +import _import_local_thrift # noqa +from thrift.protocol.TBinaryProtocol import TBinaryProtocol +from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated +from thrift.protocol.TCompactProtocol import TCompactProtocol +from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated +from thrift.protocol.TJSONProtocol import TJSONProtocol +from thrift.protocol.TProtocol import TProtocolException +from thrift.transport import TTransport + +from Recursive.ttypes import CoError, CoError2, RecTree + +LIMIT = 64 # TProtocolBase.DEFAULT_RECURSION_DEPTH + + +def make_chain(depth): + """RecTree chain where writing increments the counter 'depth' times.""" + node = RecTree() + for _ in range(depth - 1): + node = RecTree(children=[node]) + return node + + +def make_error_chain(depth): + """CoError/CoError2 alternating chain of total 'depth' struct levels.""" + leaf = CoError() + node = leaf + for _ in range(depth - 1): + if isinstance(node, CoError): + node = CoError2(other=node) + else: + node = CoError(other=node) + return node + + +def make_binary_payload(depth): + """Raw TBinaryProtocol payload for a chain of 'depth' nested RecTree nodes.""" + payload = b"\x00" # leaf: STOP + for _ in range(depth - 1): + # field id=1 type=LIST, list elem=STRUCT count=1, then STOP for outer struct + payload = b"\x0f\x00\x01\x0c\x00\x00\x00\x01" + payload + b"\x00" + return payload + + +def roundtrip_binary(struct): + buf = TTransport.TMemoryBuffer() + struct.write(TBinaryProtocol(buf)) + result = RecTree() + result.read(TBinaryProtocol(TTransport.TMemoryBuffer(buf.getvalue()))) + return result + + +def roundtrip_accel(struct): + buf = TTransport.TMemoryBuffer() + struct.write(TBinaryProtocolAccelerated(buf)) + result = RecTree() + result.read(TBinaryProtocolAccelerated(TTransport.TMemoryBuffer(buf.getvalue()))) + return result + + +class RecursionDepthBinaryTest(unittest.TestCase): + + def test_roundtrip_at_limit(self): + self.assertIsNotNone(roundtrip_binary(make_chain(LIMIT))) + + def test_write_over_limit(self): + with self.assertRaises(TProtocolException) as ctx: + make_chain(LIMIT + 1).write(TBinaryProtocol(TTransport.TMemoryBuffer())) + self.assertEqual(ctx.exception.type, TProtocolException.DEPTH_LIMIT) + + def test_read_over_limit(self): + with self.assertRaises(TProtocolException) as ctx: + result = RecTree() + result.read(TBinaryProtocol(TTransport.TMemoryBuffer(make_binary_payload(LIMIT + 1)))) + self.assertEqual(ctx.exception.type, TProtocolException.DEPTH_LIMIT) + + def test_exception_type_over_limit(self): + with self.assertRaises(TProtocolException) as ctx: + make_error_chain(LIMIT + 1).write(TBinaryProtocol(TTransport.TMemoryBuffer())) + self.assertEqual(ctx.exception.type, TProtocolException.DEPTH_LIMIT) + + +class RecursionDepthAcceleratedTest(unittest.TestCase): + + def setUp(self): + try: + import thrift.protocol.fastbinary # noqa + self.has_fastbinary = True + except ImportError: + self.has_fastbinary = False + + def test_roundtrip_at_limit(self): + self.assertIsNotNone(roundtrip_accel(make_chain(LIMIT))) + + def test_write_over_limit(self): + with self.assertRaises(TProtocolException) as ctx: + make_chain(LIMIT + 1).write(TBinaryProtocolAccelerated(TTransport.TMemoryBuffer())) + self.assertEqual(ctx.exception.type, TProtocolException.DEPTH_LIMIT) + + def test_read_over_limit(self): + with self.assertRaises(TProtocolException) as ctx: + result = RecTree() + result.read( + TBinaryProtocolAccelerated( + TTransport.TMemoryBuffer(make_binary_payload(LIMIT + 1)) + ) + ) + self.assertEqual(ctx.exception.type, TProtocolException.DEPTH_LIMIT) + + +class RecursionDepthCompactTest(unittest.TestCase): + + def test_roundtrip_at_limit(self): + buf = TTransport.TMemoryBuffer() + make_chain(LIMIT).write(TCompactProtocol(buf)) + result = RecTree() + result.read(TCompactProtocol(TTransport.TMemoryBuffer(buf.getvalue()))) + self.assertIsNotNone(result) + + def test_write_over_limit(self): + with self.assertRaises(TProtocolException) as ctx: + make_chain(LIMIT + 1).write(TCompactProtocol(TTransport.TMemoryBuffer())) + self.assertEqual(ctx.exception.type, TProtocolException.DEPTH_LIMIT) + + def test_write_over_limit_accelerated(self): + with self.assertRaises(TProtocolException) as ctx: + make_chain(LIMIT + 1).write(TCompactProtocolAccelerated(TTransport.TMemoryBuffer())) + self.assertEqual(ctx.exception.type, TProtocolException.DEPTH_LIMIT) + + +class RecursionDepthJSONTest(unittest.TestCase): + + def test_roundtrip_at_limit(self): + buf = TTransport.TMemoryBuffer() + make_chain(LIMIT).write(TJSONProtocol(buf)) + result = RecTree() + result.read(TJSONProtocol(TTransport.TMemoryBuffer(buf.getvalue()))) + self.assertIsNotNone(result) + + def test_write_over_limit(self): + with self.assertRaises(TProtocolException) as ctx: + make_chain(LIMIT + 1).write(TJSONProtocol(TTransport.TMemoryBuffer())) + self.assertEqual(ctx.exception.type, TProtocolException.DEPTH_LIMIT) + + +if __name__ == "__main__": + unittest.main(verbosity=2)