diff --git a/.agents/languages/java.md b/.agents/languages/java.md index 163c19da8c..27594ae004 100644 --- a/.agents/languages/java.md +++ b/.agents/languages/java.md @@ -22,6 +22,10 @@ Load this file when changing anything under `java/` or when Java drives a cross- - Do not add normal-JVM process-global caches keyed by user classes, generated classes, serializer classes, classloaders, or class-bound method handles. Prefer per-runtime state, immutable shared metadata, or build-time-only template data. - Concrete serializers may opt into sharing only after auditing retained fields. Treat serializers retaining `TypeResolver`, `RefResolver`, mutable scratch buffers, runtime state, or classloader-sensitive state as non-shareable unless that state is externalized. - Resolver and serializer hot paths should keep the fast-path/null-slow-path shape obvious. Hoist repeated buffer or cache-state access into locals for multi-step operations and keep rebuild/restoration logic cold. +- Do not use `instanceof` in Java hot paths, including per-value, per-field, per-element, + read/write/copy, resolver, serializer, codec, and buffer paths. Choose concrete + implementations during cold setup or code generation, cache final/static-final shape decisions, + or move type checks behind cold one-time dispatch instead. - Hot-path feature gates that are runtime constants must be `static final` fields read directly in the branch. Do not hide them behind helper methods such as `jdkInternalFieldAccess()`, because that obscures branch folding and can leave avoidable call/inlining work in hot serializers. diff --git a/AGENTS.md b/AGENTS.md index 5c62f3f76d..8b4d2644f3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -9,6 +9,7 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - `.agents/docs-and-formatting.md`: documentation, specification, and markdown rules. - `.agents/ci-and-pr.md`: CI triage, PR expectations, and commit conventions. 
- `.agents/testing/integration-tests.md`: `integration_tests/` prerequisites, regeneration rules, and commands. +- `docs/security/deserialization.md`: security boundaries for untrusted deserialization classification. - `.agents/languages/java.md` - `.agents/languages/csharp.md` - `.agents/languages/cpp.md` @@ -27,10 +28,11 @@ This is the entry point for AI guidance in Apache Fory. Read this file first, th - Preserve architecture. Do not introduce new layers, parallel flows, or public APIs unless explicitly requested; prefer local repair in the existing owner over shared-infra expansion, and stop if a fix conflicts with an ADR, spec, or invariant. - Respect ownership. Keep logic, state, and helpers in their natural owner, and do not move serializer-local, context-local, runtime-type-local, or protocol-local problems into global utilities. - Check the spec before implementation. For wire behavior and xlang mapping, use the specs as the source of truth and never copy one runtime's bug into another runtime just to make tests pass. +- For untrusted deserialization, read `docs/security/deserialization.md` before changing allocation, stream filling, skip, reference, metadata, or policy validation behavior. Variable-length deserialization must not allocate or reserve from attacker-declared lengths or counts before the byte owner has proven proportional readable bytes with `checkReadableBytes` or the runtime equivalent. - Reject semantic hacks. Do not bypass broken semantics by deleting cases, simplifying callers, adding coercion hooks, or using workaround fallbacks; fix the underlying bug and prove it with focused tests. - Protect hot paths. Avoid per-call allocations, callback objects, result tuples or records, unnecessary runtime branches, and wrapper-class substitutions in hot codec/runtime paths; prefer conditional imports and allocation-free concrete implementations where they fit the language. - Keep public APIs minimal. 
Public APIs must match user ownership and mental model, not internal implementation details; generated flows stay type-owned, while manual serializer registration stays explicit. -- Use semantic naming only. Name things after protocol or domain concepts, not history, runtime origin, or workaround style; avoid vague names such as `Internal`, `java_style_*`, `Runtime`, `Session`, `Plan`, or `Binding` when they do not name the real concept. Keep class, method, function, and variable names concise; do not encode the whole scenario or implementation history into one identifier. Never name a class or method with a `Plan` suffix; use the real domain concept instead. +- Use semantic naming only. Name things after protocol or domain concepts, not history, runtime origin, or workaround style; avoid vague names such as `Internal`, `java_style_*`, `Runtime`, `Session`, `Plan`, `Payload`, or `Binding` when they do not name the real concept. Keep class, method, function, and variable names concise; do not encode the whole scenario or implementation history into one identifier. Never name a class or method with a `Plan` suffix; use the real domain concept instead. For Fory codec/read APIs, do not use generic `payload` naming; name the exact owner and data shape, such as bytes, body, frame, field, string, list, map, compressed bytes, or primitive-array encoding. - Keep one implementation path. Do not keep parallel helpers, serializers, harnesses, wrappers, or registration flows for the same concept; extend the existing owner path instead of inventing another one. - Follow current scope exactly. The latest explicit user instruction overrides earlier plans, and when scope narrows, remove leaked out-of-scope edits immediately. - Preserve user corrections. 
When a user corrects code behavior, ownership, invariants, or review feedback in a way that should prevent repeat mistakes, encode the corrected rule where future agents will see it: prefer the nearest source comment for non-obvious code invariants, or the owning docs/spec for user-visible or protocol behavior. If the correction changes API usage, defaults, generated output, tests, or cross-runtime behavior, update the matching docs, examples, or source comments in the same task so future agents do not repeat the violation. Keep the note concise, English-only, and avoid comments that merely restate obvious code. diff --git a/cpp/fory/meta/meta_string.cc b/cpp/fory/meta/meta_string.cc index a548558d47..c6a0b18c4e 100644 --- a/cpp/fory/meta/meta_string.cc +++ b/cpp/fory/meta/meta_string.cc @@ -273,8 +273,11 @@ MetaStringTable::read_string(Buffer &buffer, const MetaStringDecoder &decoder) { return Unexpected(std::move(error)); } (void)hash_code; // hash_code is only used for Java-side caching. 
- bytes.resize(len); if (len > 0) { + if (FORY_PREDICT_FALSE(!buffer.ensure_readable(len, error))) { + return Unexpected(std::move(error)); + } + bytes.resize(len); buffer.read_bytes(bytes.data(), len, error); if (FORY_PREDICT_FALSE(!error.ok())) { return Unexpected(std::move(error)); @@ -294,6 +297,9 @@ MetaStringTable::read_string(Buffer &buffer, const MetaStringDecoder &decoder) { uint8_t enc_byte = static_cast(enc_byte_res); FORY_TRY(enc, to_meta_encoding(enc_byte)); encoding = enc; + if (FORY_PREDICT_FALSE(!buffer.ensure_readable(len, error))) { + return Unexpected(std::move(error)); + } bytes.resize(len); buffer.read_bytes(bytes.data(), len, error); if (FORY_PREDICT_FALSE(!error.ok())) { diff --git a/cpp/fory/serialization/array_serializer.h b/cpp/fory/serialization/array_serializer.h index b86976306d..25580f18a7 100644 --- a/cpp/fory/serialization/array_serializer.h +++ b/cpp/fory/serialization/array_serializer.h @@ -145,11 +145,12 @@ struct Serializer< return std::array(); } - uint32_t length = size_bytes / sizeof(T); - if (length != N) { - ctx.set_error(Error::invalid_data("Array size mismatch: expected " + - std::to_string(N) + " but got " + - std::to_string(length))); + constexpr size_t expected_bytes = N * sizeof(T); + if (static_cast(size_bytes) != expected_bytes) { + ctx.set_error(Error::invalid_data("Array byte size mismatch: expected " + + std::to_string(expected_bytes) + + " but got " + + std::to_string(size_bytes))); return std::array(); } @@ -368,11 +369,12 @@ template struct Serializer> { if (FORY_PREDICT_FALSE(ctx.has_error())) { return std::array(); } - uint32_t length = size_bytes / sizeof(float16_t); - if (length != N) { - ctx.set_error(Error::invalid_data("Array size mismatch: expected " + - std::to_string(N) + " but got " + - std::to_string(length))); + constexpr size_t expected_bytes = N * sizeof(float16_t); + if (static_cast(size_bytes) != expected_bytes) { + ctx.set_error(Error::invalid_data("Array byte size mismatch: expected " + + 
std::to_string(expected_bytes) + + " but got " + + std::to_string(size_bytes))); return std::array(); } std::array arr; @@ -480,11 +482,12 @@ template struct Serializer> { if (FORY_PREDICT_FALSE(ctx.has_error())) { return std::array(); } - uint32_t length = size_bytes / sizeof(bfloat16_t); - if (length != N) { - ctx.set_error(Error::invalid_data("Array size mismatch: expected " + - std::to_string(N) + " but got " + - std::to_string(length))); + constexpr size_t expected_bytes = N * sizeof(bfloat16_t); + if (static_cast(size_bytes) != expected_bytes) { + ctx.set_error(Error::invalid_data("Array byte size mismatch: expected " + + std::to_string(expected_bytes) + + " but got " + + std::to_string(size_bytes))); return std::array(); } std::array arr; diff --git a/cpp/fory/serialization/collection_serializer.h b/cpp/fory/serialization/collection_serializer.h index 4856b084b2..473f9d6950 100644 --- a/cpp/fory/serialization/collection_serializer.h +++ b/cpp/fory/serialization/collection_serializer.h @@ -23,6 +23,7 @@ #include "fory/serialization/serializer.h" #include #include +#include #include #include #include @@ -379,6 +380,23 @@ struct has_reserve inline constexpr bool has_reserve_v = has_reserve::value; +template +inline bool reserve_collection(Container &result, ReadContext &ctx, + uint32_t length) { + // Lazy error propagation may continue into later readers; do not let that + // path retain attacker-controlled capacity after an earlier read failure. + if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { + return false; + } + if constexpr (has_reserve_v) { + result.reserve(length); + } + return true; +} + // Helper to insert element into container (vector or set) template inline void collection_insert(Container &result, T &&elem) { @@ -392,21 +410,13 @@ inline void collection_insert(Container &result, T &&elem) { /// Read collection data for polymorphic or shared-ref elements. 
template inline Container read_collection_data_slow(ReadContext &ctx, uint32_t length) { - // Guardrail: Enforce max_collection_size for collection reads - if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return Container{}; - } - Container result; - if constexpr (has_reserve_v) { - result.reserve(length); - } - if (length == 0) { return result; } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } constexpr bool elem_is_polymorphic = is_polymorphic_v; @@ -620,11 +630,6 @@ struct Serializer< if (FORY_PREDICT_FALSE(ctx.has_error())) { return std::vector(); } - // Guardrail: Enforce max_binary_size for binary byte-length reads - if (FORY_PREDICT_FALSE(total_bytes_u32 > ctx.config().max_binary_size)) { - ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); - return std::vector(); - } if (sizeof(T) == 0) { return std::vector(); } @@ -637,10 +642,16 @@ struct Serializer< " not aligned with element size " + std::to_string(sizeof(T)))); return std::vector(); } + if (FORY_PREDICT_FALSE( + !ctx.buffer().ensure_readable(total_bytes_u32, ctx.error()))) { + return std::vector(); + } std::vector result(elem_count); if (total_bytes_u32 > 0) { - ctx.read_bytes(result.data(), static_cast(total_bytes_u32), - ctx.error()); + Buffer &buffer = ctx.buffer(); + std::memcpy(result.data(), buffer.data() + buffer.reader_index(), + total_bytes_u32); + buffer.unsafe_increase_reader_index(total_bytes_u32); } return result; } @@ -734,20 +745,22 @@ template struct Serializer> { if (FORY_PREDICT_FALSE(ctx.has_error())) { return std::vector(); } - if (FORY_PREDICT_FALSE(total_bytes_u32 > ctx.config().max_binary_size)) { - ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); - return std::vector(); - } size_t elem_count = total_bytes_u32 / sizeof(float16_t); if (total_bytes_u32 % sizeof(float16_t) != 0) { 
ctx.set_error(Error::invalid_data( "Vector byte size not aligned with float16_t element size")); return std::vector(); } + if (FORY_PREDICT_FALSE( + !ctx.buffer().ensure_readable(total_bytes_u32, ctx.error()))) { + return std::vector(); + } std::vector result(elem_count); if (total_bytes_u32 > 0) { - ctx.read_bytes(result.data(), static_cast(total_bytes_u32), - ctx.error()); + Buffer &buffer = ctx.buffer(); + std::memcpy(result.data(), buffer.data() + buffer.reader_index(), + total_bytes_u32); + buffer.unsafe_increase_reader_index(total_bytes_u32); } return result; } @@ -839,20 +852,22 @@ template struct Serializer> { if (FORY_PREDICT_FALSE(ctx.has_error())) { return std::vector(); } - if (FORY_PREDICT_FALSE(total_bytes_u32 > ctx.config().max_binary_size)) { - ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); - return std::vector(); - } size_t elem_count = total_bytes_u32 / sizeof(bfloat16_t); if (total_bytes_u32 % sizeof(bfloat16_t) != 0) { ctx.set_error(Error::invalid_data( "Vector byte size not aligned with bfloat16_t element size")); return std::vector(); } + if (FORY_PREDICT_FALSE( + !ctx.buffer().ensure_readable(total_bytes_u32, ctx.error()))) { + return std::vector(); + } std::vector result(elem_count); if (total_bytes_u32 > 0) { - ctx.read_bytes(result.data(), static_cast(total_bytes_u32), - ctx.error()); + Buffer &buffer = ctx.buffer(); + std::memcpy(result.data(), buffer.data() + buffer.reader_index(), + total_bytes_u32); + buffer.unsafe_increase_reader_index(total_bytes_u32); } return result; } @@ -907,12 +922,6 @@ struct Serializer< return std::vector(); } - if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return std::vector(); - } - // Per xlang spec: header and type_info are omitted when length is 0 if (length == 0) { return std::vector(); @@ -953,7 +962,9 @@ struct Serializer< } std::vector result; - 
result.reserve(length); + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; + } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { @@ -1045,14 +1056,13 @@ struct Serializer< return std::vector(); } - if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return std::vector(); - } - std::vector result; - result.reserve(size); + if (size == 0) { + return result; + } + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -1141,32 +1151,18 @@ template struct Serializer> { return std::vector(); } - if (FORY_PREDICT_FALSE(size > ctx.config().max_binary_size)) { - ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); + if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(size, ctx.error()))) { return std::vector(); } std::vector result(size); - // Fast path: bulk read all bytes at once if we have enough buffer Buffer &buffer = ctx.buffer(); - if (size > 0 && buffer.reader_index() + size <= buffer.size()) { + if (size > 0) { const uint8_t *src = buffer.data() + buffer.reader_index(); for (uint32_t i = 0; i < size; ++i) { result[i] = (src[i] != 0); } - buffer.increase_reader_index(size, ctx.error()); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::vector(); - } - } else { - // Fallback: read byte-by-byte with bounds checking - for (uint32_t i = 0; i < size; ++i) { - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return result; - } - uint8_t byte = ctx.read_uint8(ctx.error()); - result[i] = (byte != 0); - } + buffer.unsafe_increase_reader_index(size); } return result; } @@ -1221,12 +1217,6 @@ template struct Serializer> { return std::list(); } - if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) { - 
ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return std::list(); - } - // Per xlang spec: header and type_info are omitted when length is 0 if (length == 0) { return std::list(); @@ -1358,12 +1348,6 @@ template struct Serializer> { return std::list(); } - if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return std::list(); - } - std::list result; for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -1425,12 +1409,6 @@ template struct Serializer> { return std::deque(); } - if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return std::deque(); - } - // Per xlang spec: header and type_info are omitted when length is 0 if (length == 0) { return std::deque(); @@ -1562,12 +1540,6 @@ template struct Serializer> { return std::deque(); } - if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return std::deque(); - } - std::deque result; for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -1630,23 +1602,15 @@ struct Serializer> { return std::forward_list(); } - if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return std::forward_list(); - } - // Per xlang spec: header and type_info are omitted when length is 0 if (length == 0) { return std::forward_list(); } + // Dispatch to slow path for polymorphic/shared-ref elements // Read elements into a temporary vector then build forward_list // (forward_list doesn't have push_back, only push_front) std::vector temp; - temp.reserve(length); - - // Dispatch to slow path for polymorphic/shared-ref 
elements constexpr bool is_slow_path = is_polymorphic_v || is_shared_ref_v; if constexpr (is_slow_path) { temp = read_collection_data_slow>(ctx, length); @@ -1680,6 +1644,9 @@ struct Serializer> { } } + if (FORY_PREDICT_FALSE(!reserve_collection(temp, ctx, length))) { + return std::forward_list(); + } // Fast path: no tracking, no nulls, elements have declared type if (!track_ref && !has_null && is_same_type) { for (uint32_t i = 0; i < length; ++i) { @@ -2001,14 +1968,10 @@ struct Serializer> { return std::forward_list(); } - if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); + std::vector temp; + if (FORY_PREDICT_FALSE(!reserve_collection(temp, ctx, size))) { return std::forward_list(); } - - std::vector temp; - temp.reserve(size); for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { break; @@ -2106,12 +2069,6 @@ struct Serializer> { return std::set(); } - if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return std::set(); - } - // Per xlang spec: header and type_info are omitted when length is 0 if (size == 0) { return std::set(); @@ -2193,12 +2150,6 @@ struct Serializer> { return std::set(); } - if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return std::set(); - } - std::set result; for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -2293,12 +2244,6 @@ struct Serializer> { return std::unordered_set(); } - if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return std::unordered_set(); - } - // Per xlang spec: header and type_info are omitted when length is 0 if (size == 0) { 
return std::unordered_set(); @@ -2336,7 +2281,9 @@ struct Serializer> { } std::unordered_set result; - result.reserve(size); + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } if (!track_ref && !has_null && is_same_type) { for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -2382,14 +2329,10 @@ struct Serializer> { return std::unordered_set(); } - if (FORY_PREDICT_FALSE(size > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return std::unordered_set(); - } - std::unordered_set result; - result.reserve(size); + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, size))) { + return result; + } for (uint32_t i = 0; i < size; ++i) { if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; diff --git a/cpp/fory/serialization/collection_serializer_test.cc b/cpp/fory/serialization/collection_serializer_test.cc index 1eb6e692c8..c40f717e33 100644 --- a/cpp/fory/serialization/collection_serializer_test.cc +++ b/cpp/fory/serialization/collection_serializer_test.cc @@ -620,50 +620,6 @@ TEST(CollectionSerializerTest, ForwardListEmptyRoundTrip) { EXPECT_TRUE(deserialized.strings.empty()); } -// Test max_collection_size using objects (e.g., strings) -TEST(CollectionSerializerTest, MaxCollectionSizeNativeGuardrail) { - auto fory = Fory::builder() - .xlang(false) - .max_collection_size(2) - .compatible(false) - .build(); - fory.register_struct(200); - - VectorStringHolder original; - original.strings = {"A", "B", "C"}; - - auto bytes_result = fory.serialize(original); - ASSERT_TRUE(bytes_result.ok()); - - auto deserialize_result = fory.deserialize( - bytes_result->data(), bytes_result->size()); - - ASSERT_FALSE(deserialize_result.ok()); - EXPECT_TRUE(deserialize_result.error().message().find( - "exceeds max_collection_size") != std::string::npos); -} - -// Test max_binary_size using primitive numbers 
-TEST(CollectionSerializerTest, MaxBinarySizeNativeGuardrail) { - auto fory = Fory::builder() - .xlang(false) - .max_binary_size(10) - .compatible(false) - .build(); - - std::vector large_data = {1, 2, 3, 4, 5}; - - auto bytes_result = fory.serialize(large_data); - ASSERT_TRUE(bytes_result.ok()); - - auto deserialize_result = fory.deserialize>( - bytes_result->data(), bytes_result->size()); - - ASSERT_FALSE(deserialize_result.ok()); - EXPECT_TRUE(deserialize_result.error().message().find( - "exceeds max_binary_size") != std::string::npos); -} - } // namespace } // namespace serialization } // namespace fory diff --git a/cpp/fory/serialization/config.h b/cpp/fory/serialization/config.h index 63062c7ee5..d471c39074 100644 --- a/cpp/fory/serialization/config.h +++ b/cpp/fory/serialization/config.h @@ -52,12 +52,6 @@ struct Config { /// When enabled, avoids duplicating shared objects and handles cycles. bool track_ref = true; - /// Maximum allowed size for binary data in bytes. - uint32_t max_binary_size = 64 * 1024 * 1024; // 64MB default - - /// Maximum allowed number of elements in a collection or entries in a map. 
- uint32_t max_collection_size = 1024 * 1024; // 1M elements default - /// Default constructor with sensible defaults Config() = default; }; diff --git a/cpp/fory/serialization/decimal_serializers.h b/cpp/fory/serialization/decimal_serializers.h index 64d99f49d6..740e657391 100644 --- a/cpp/fory/serialization/decimal_serializers.h +++ b/cpp/fory/serialization/decimal_serializers.h @@ -22,6 +22,7 @@ #include "fory/serialization/serializer.h" #include +#include #include #include #include @@ -263,10 +264,6 @@ template <> struct Serializer { ctx.set_error(Error::invalid_data("Invalid decimal magnitude length 0")); return Decimal(); } - if (length64 > ctx.config().max_binary_size) { - ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); - return Decimal(); - } if (length64 > std::numeric_limits::max()) { ctx.set_error(Error::invalid_data("Invalid decimal magnitude length " + std::to_string(length64))); @@ -274,18 +271,22 @@ template <> struct Serializer { } uint32_t length = static_cast(length64); - std::vector payload(length); - ctx.buffer().read_bytes(payload.data(), length, ctx.error()); - if (FORY_PREDICT_FALSE(ctx.has_error())) { + if (FORY_PREDICT_FALSE( + !ctx.buffer().ensure_readable(length, ctx.error()))) { return Decimal(); } - if (payload.back() == 0) { + std::vector magnitude(length); + Buffer &buffer = ctx.buffer(); + std::memcpy(magnitude.data(), buffer.data() + buffer.reader_index(), + length); + buffer.unsafe_increase_reader_index(length); + if (magnitude.back() == 0) { ctx.set_error(Error::invalid_data( - "Non-canonical decimal payload: trailing zero byte")); + "Non-canonical decimal magnitude: trailing zero byte")); return Decimal(); } - return Decimal(scale, (meta & 1ULL) != 0, std::move(payload)); + return Decimal(scale, (meta & 1ULL) != 0, std::move(magnitude)); } static inline Decimal read_data_generic(ReadContext &ctx, bool has_generics) { diff --git a/cpp/fory/serialization/fory.h b/cpp/fory/serialization/fory.h index 
ca99c3f908..3361cb621d 100644 --- a/cpp/fory/serialization/fory.h +++ b/cpp/fory/serialization/fory.h @@ -121,19 +121,6 @@ class ForyBuilder { /// Build a thread-safe Fory instance (uses context pools). ThreadSafeFory build_thread_safe(); - /// Set the maximum allowed size for binary data in bytes. - inline ForyBuilder &max_binary_size(uint32_t size) { - config_.max_binary_size = size; - return *this; - } - - /// Set the maximum allowed number of elements in a collection or entries in a - /// map. - inline ForyBuilder &max_collection_size(uint32_t size) { - config_.max_collection_size = size; - return *this; - } - private: const Config &normalized_config() { if (!compatible_set_) { diff --git a/cpp/fory/serialization/map_serializer.h b/cpp/fory/serialization/map_serializer.h index 8554627350..830e5fbae5 100644 --- a/cpp/fory/serialization/map_serializer.h +++ b/cpp/fory/serialization/map_serializer.h @@ -81,6 +81,20 @@ struct MapReserver +inline bool reserve_map(MapType &map, ReadContext &ctx, uint32_t length) { + // Lazy error propagation may continue into later readers; do not let that + // path retain attacker-controlled capacity after an earlier read failure. 
+ if (FORY_PREDICT_FALSE(ctx.has_error())) { + return false; + } + if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(length, ctx.error()))) { + return false; + } + MapReserver::reserve(map, length); + return true; +} + /// write chunk size at header offset inline void write_chunk_size(WriteContext &ctx, size_t header_offset, uint8_t size) { @@ -551,19 +565,13 @@ inline MapType read_map_data_fast(ReadContext &ctx, uint32_t length) { static_assert(!is_shared_ref_v && !is_shared_ref_v, "Fast path is for non-shared-ref types only"); - // Guardrail: Enforce max_collection_size for map reads (entry count) - if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Map entry count exceeds max_collection_size")); - return MapType{}; - } - MapType result; - MapReserver::reserve(result, length); - if (length == 0) { return result; } + if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { + return result; + } uint32_t len_counter = 0; @@ -689,19 +697,13 @@ inline MapType read_map_data_fast(ReadContext &ctx, uint32_t length) { /// Read map data for polymorphic or shared-ref maps template inline MapType read_map_data_slow(ReadContext &ctx, uint32_t length) { - // Guardrail: Enforce max_collection_size for map reads (entry count) - if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Map entry count exceeds max_collection_size")); - return MapType{}; - } - MapType result; - MapReserver::reserve(result, length); - if (length == 0) { return result; } + if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { + return result; + } constexpr bool key_is_polymorphic = is_polymorphic_v; constexpr bool val_is_polymorphic = is_polymorphic_v; diff --git a/cpp/fory/serialization/map_serializer_test.cc b/cpp/fory/serialization/map_serializer_test.cc index 2d16697f89..5f547abadd 100644 --- a/cpp/fory/serialization/map_serializer_test.cc +++ 
b/cpp/fory/serialization/map_serializer_test.cc @@ -824,26 +824,6 @@ TEST(MapSerializerTest, LargeMapWithPolymorphicValues) { EXPECT_EQ(deserialized[299]->name, "value_y_299"); } -TEST(MapSerializerTest, MaxMapSizeGuardrail) { - auto fory = Fory::builder() - .xlang(true) - .compatible(false) - .max_collection_size(2) - .build(); - - std::map large_map = {{"a", 1}, {"b", 2}, {"c", 3}}; - - auto serialize_result = fory.serialize(large_map); - ASSERT_TRUE(serialize_result.ok()); - - auto deserialize_result = fory.deserialize>( - serialize_result->data(), serialize_result->size()); - - ASSERT_FALSE(deserialize_result.ok()); - EXPECT_TRUE(deserialize_result.error().message().find( - "exceeds max_collection_size") != std::string::npos); -} - int main(int argc, char **argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); diff --git a/cpp/fory/serialization/serialization_test.cc b/cpp/fory/serialization/serialization_test.cc index c78618f1ea..840a522a66 100644 --- a/cpp/fory/serialization/serialization_test.cc +++ b/cpp/fory/serialization/serialization_test.cc @@ -22,10 +22,12 @@ #include "fory/serialization/skip.h" #include "fory/thirdparty/MurmurHash3.h" #include "gtest/gtest.h" +#include #include #include #include #include +#include #include #include #include @@ -77,6 +79,16 @@ FORY_ENUM(SparseStatus, UNKNOWN, OK); enum OldStatus : int32_t { OLD_NEG = -7, OLD_ZERO = 0, OLD_POS = 13 }; FORY_ENUM(::OldStatus, OLD_NEG, OLD_ZERO, OLD_POS); +static std::atomic g_ext_destructor_calls{0}; + +struct ExtWithDestructor { + int32_t value = 0; + + ~ExtWithDestructor() { g_ext_destructor_calls.fetch_add(1); } + + FORY_STRUCT(ExtWithDestructor, value); +}; + namespace fory { namespace serialization { namespace test { @@ -179,6 +191,21 @@ TEST(SerializationTest, BoolRoundtrip) { test_roundtrip(false); } +TEST(SerializationTest, HarnessDestroyRunsRegisteredDestructor) { + auto fory = + Fory::builder().xlang(true).compatible(false).track_ref(false).build(); + auto 
register_result = fory.register_extension_type<::ExtWithDestructor>(1); + ASSERT_TRUE(register_result.ok()) << register_result.error().message(); + auto type_info_result = + fory.type_resolver().get_type_info<::ExtWithDestructor>(); + ASSERT_TRUE(type_info_result.ok()) << type_info_result.error().message(); + + g_ext_destructor_calls.store(0); + void *ptr = new ExtWithDestructor(); + type_info_result.value()->harness.destroy_fn(ptr); + EXPECT_EQ(g_ext_destructor_calls.load(), 1); +} + TEST(SerializationTest, Int8Roundtrip) { test_roundtrip(0); test_roundtrip(127); @@ -228,6 +255,106 @@ TEST(SerializationTest, StringRoundtrip) { test_roundtrip(std::string("UTF-8: 你好世界")); } +TEST(SerializationTest, StringReadsCheckBodyBeforeAllocation) { + auto fory = + Fory::builder().xlang(true).compatible(false).track_ref(false).build(); + constexpr uint32_t huge_length = std::numeric_limits::max() - 1; + + for (StringEncoding encoding : + {StringEncoding::LATIN1, StringEncoding::UTF16, StringEncoding::UTF8}) { + Buffer buffer; + buffer.write_var_uint36_small((static_cast(huge_length) << 2) | + static_cast(encoding)); + ReadContext read_ctx(fory.config(), fory.type_resolver().clone()); + read_ctx.attach(buffer); + + auto result = detail::read_string_data(read_ctx); + + EXPECT_TRUE(result.empty()); + EXPECT_TRUE(read_ctx.has_error()); + } +} + +TEST(SerializationTest, U16StringReadsCheckBodyBeforeAllocation) { + auto fory = + Fory::builder().xlang(true).compatible(false).track_ref(false).build(); + constexpr uint32_t huge_length = std::numeric_limits::max() - 1; + + for (StringEncoding encoding : + {StringEncoding::LATIN1, StringEncoding::UTF16, StringEncoding::UTF8}) { + Buffer buffer; + buffer.write_var_uint36_small((static_cast(huge_length) << 2) | + static_cast(encoding)); + ReadContext read_ctx(fory.config(), fory.type_resolver().clone()); + read_ctx.attach(buffer); + + auto result = detail::read_u16string_data(read_ctx); + + EXPECT_TRUE(result.empty()); + 
EXPECT_TRUE(read_ctx.has_error()); + } +} + +TEST(SerializationTest, PrimitiveVectorReadsCheckBodyBeforeAllocation) { + auto fory = + Fory::builder().xlang(true).compatible(false).track_ref(false).build(); + Buffer buffer; + buffer.write_var_uint32(4096); + ReadContext read_ctx(fory.config(), fory.type_resolver().clone()); + read_ctx.attach(buffer); + + auto result = Serializer>::read_data(read_ctx); + + EXPECT_TRUE(result.empty()); + EXPECT_TRUE(read_ctx.has_error()); +} + +TEST(SerializationTest, BoolVectorReadsCheckBodyBeforeAllocation) { + auto fory = + Fory::builder().xlang(true).compatible(false).track_ref(false).build(); + Buffer buffer; + buffer.write_var_uint32(4096); + ReadContext read_ctx(fory.config(), fory.type_resolver().clone()); + read_ctx.attach(buffer); + + auto result = Serializer>::read_data(read_ctx); + + EXPECT_TRUE(result.empty()); + EXPECT_TRUE(read_ctx.has_error()); +} + +TEST(SerializationTest, FixedPrimitiveArrayRejectsWrongByteSize) { + auto fory = + Fory::builder().xlang(true).compatible(false).track_ref(false).build(); + Buffer buffer; + buffer.write_var_uint32(sizeof(int32_t) + 1); + buffer.write_uint32(1); + buffer.write_uint8(0); + ReadContext read_ctx(fory.config(), fory.type_resolver().clone()); + read_ctx.attach(buffer); + + auto result = Serializer>::read_data(read_ctx); + + EXPECT_EQ(result[0], 0); + ASSERT_TRUE(read_ctx.has_error()); + EXPECT_EQ(read_ctx.error().code(), ErrorCode::InvalidData); +} + +TEST(SerializationTest, DecimalReadsCheckBodyBeforeAllocation) { + auto fory = + Fory::builder().xlang(true).compatible(false).track_ref(false).build(); + Buffer buffer; + buffer.write_var_int32(0); + buffer.write_var_uint64((static_cast(4096) << 2) | 1ULL); + ReadContext read_ctx(fory.config(), fory.type_resolver().clone()); + read_ctx.attach(buffer); + + auto result = Serializer::read_data(read_ctx); + + EXPECT_TRUE(result.is_zero()); + EXPECT_TRUE(read_ctx.has_error()); +} + TEST(SerializationTest, DurationRoundtrip) { auto 
fory = Fory::builder().xlang(true).compatible(false).track_ref(false).build(); diff --git a/cpp/fory/serialization/skip.cc b/cpp/fory/serialization/skip.cc index e18f723bcd..ec75d88e0e 100644 --- a/cpp/fory/serialization/skip.cc +++ b/cpp/fory/serialization/skip.cc @@ -40,6 +40,12 @@ constexpr uint8_t MAP_TRACKING_VALUE_REF = 0b001000; constexpr uint8_t MAP_VALUE_NULL = 0b010000; constexpr uint8_t MAP_DECL_VALUE_TYPE = 0b100000; +void destroy_harness_value(const TypeInfo &type_info, void *ptr) { + if (ptr != nullptr) { + type_info.harness.destroy_fn(ptr); + } +} + bool consume_ref_flag(ReadContext &ctx, bool tracking_ref, bool null_only) { if (!tracking_ref && !null_only) { return true; @@ -98,9 +104,10 @@ void skip_ext_data(ReadContext &ctx, const TypeInfo &type_info) { DynDepthGuard dyn_depth_guard(ctx); void *ptr = type_info.harness.read_data_fn(ctx); if (FORY_PREDICT_FALSE(ctx.has_error())) { + destroy_harness_value(type_info, ptr); return; } - ::operator delete(ptr); + destroy_harness_value(type_info, ptr); } void skip_data_with_type_info(ReadContext &ctx, const TypeInfo *type_info) { @@ -533,20 +540,14 @@ void skip_ext(ReadContext &ctx, const FieldType &) { } DynDepthGuard dyn_depth_guard(ctx); - // Call the harness read_data_fn to skip the data - // The result is a pointer we need to delete + // The harness allocates with the registered concrete type, so skipped values + // must be destroyed through the paired harness hook. void *ptr = type_info->harness.read_data_fn(ctx); if (FORY_PREDICT_FALSE(ctx.has_error())) { + destroy_harness_value(*type_info, ptr); return; } - if (ptr) { - // We just wanted to skip, but harness allocates memory - need to clean up - // This is not ideal but works for now. A better approach would be to - // have a dedicated skip_data function in harness. - // For now, we use operator delete which works for POD types. 
- // TODO: Consider adding a harness.skip_data_fn or harness.delete_fn - ::operator delete(ptr); - } + destroy_harness_value(*type_info, ptr); } void skip_unknown(ReadContext &ctx) { @@ -776,10 +777,6 @@ void skip_field_value(ReadContext &ctx, const FieldType &field_type, ctx.set_error(Error::invalid_data("Invalid decimal magnitude length 0")); return; } - if (length64 > ctx.config().max_binary_size) { - ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); - return; - } if (length64 > std::numeric_limits::max()) { ctx.set_error(Error::invalid_data("Invalid decimal magnitude length " + std::to_string(length64))); @@ -818,7 +815,7 @@ void skip_field_value(ReadContext &ctx, const FieldType &field_type, return; } } - // Typed primitive arrays encode payload size in bytes, not element count. + // Typed primitive arrays encode byte size, not element count. uint32_t payload_size = ctx.read_var_uint32(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { return; diff --git a/cpp/fory/serialization/string_serializer.h b/cpp/fory/serialization/string_serializer.h index c570e0c25f..fe1363cab1 100644 --- a/cpp/fory/serialization/string_serializer.h +++ b/cpp/fory/serialization/string_serializer.h @@ -25,6 +25,7 @@ #include "fory/util/error.h" #include "fory/util/string_util.h" #include +#include #include #include #include @@ -101,9 +102,6 @@ inline void write_u16string_data(const char16_t *data, size_t size, inline std::string read_string_data(ReadContext &ctx) { // Read size with encoding using varuint36small uint64_t size_with_encoding = ctx.read_var_uint36_small(ctx.error()); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::string(); - } // Extract size and encoding from lower 2 bits uint64_t length = size_with_encoding >> 2; @@ -120,36 +118,34 @@ inline std::string read_string_data(ReadContext &ctx) { } const uint32_t length_u32 = static_cast(length); + if (FORY_PREDICT_FALSE( + !ctx.buffer().ensure_readable(length_u32, ctx.error()))) { + 
return std::string(); + } + + Buffer &buffer = ctx.buffer(); + const uint8_t *data = buffer.data() + buffer.reader_index(); + // Handle different encodings switch (encoding) { case StringEncoding::LATIN1: { - std::vector bytes(length_u32); - ctx.read_bytes(bytes.data(), length_u32, ctx.error()); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::string(); - } - return latin1_to_utf8(bytes.data(), length_u32); + std::string result = latin1_to_utf8(data, length_u32); + buffer.unsafe_increase_reader_index(length_u32); + return result; } case StringEncoding::UTF16: { - if ((length_u32 & 1) != 0) { + if (FORY_PREDICT_FALSE((length_u32 & 1) != 0)) { ctx.set_error(Error::invalid_data("UTF-16 length must be even")); return std::string(); } std::vector utf16_chars(length_u32 / 2); - ctx.read_bytes(reinterpret_cast(utf16_chars.data()), length_u32, - ctx.error()); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::string(); - } + std::memcpy(utf16_chars.data(), data, length_u32); + buffer.unsafe_increase_reader_index(length_u32); return utf16_to_utf8(utf16_chars.data(), utf16_chars.size()); } case StringEncoding::UTF8: { - // UTF-8: read bytes directly - std::string result(length_u32, '\0'); - ctx.read_bytes(&result[0], length_u32, ctx.error()); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::string(); - } + std::string result(reinterpret_cast(data), length_u32); + buffer.unsafe_increase_reader_index(length_u32); return result; } default: @@ -164,9 +160,6 @@ inline std::string read_string_data(ReadContext &ctx) { inline std::u16string read_u16string_data(ReadContext &ctx) { // Read size with encoding using varuint36small uint64_t size_with_encoding = ctx.read_var_uint36_small(ctx.error()); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::u16string(); - } // Extract size and encoding from lower 2 bits uint64_t length = size_with_encoding >> 2; @@ -183,39 +176,40 @@ inline std::u16string read_u16string_data(ReadContext &ctx) { } const 
uint32_t length_u32 = static_cast(length); + if (FORY_PREDICT_FALSE( + !ctx.buffer().ensure_readable(length_u32, ctx.error()))) { + return std::u16string(); + } + + Buffer &buffer = ctx.buffer(); + const uint8_t *data = buffer.data() + buffer.reader_index(); + // Handle different encodings switch (encoding) { case StringEncoding::LATIN1: { // Latin1 bytes map directly to char16_t (codepoints 0-255) std::u16string result(length_u32, u'\0'); for (size_t i = 0; i < length_u32; ++i) { - result[i] = static_cast(ctx.read_uint8(ctx.error())); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::u16string(); - } + result[i] = static_cast(data[i]); } + buffer.unsafe_increase_reader_index(length_u32); return result; } case StringEncoding::UTF16: { - if ((length_u32 & 1) != 0) { + if (FORY_PREDICT_FALSE((length_u32 & 1) != 0)) { ctx.set_error(Error::invalid_data("UTF-16 length must be even")); return std::u16string(); } std::u16string result(length_u32 / 2, u'\0'); - ctx.read_bytes(reinterpret_cast(&result[0]), length_u32, - ctx.error()); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::u16string(); - } + std::memcpy(&result[0], data, length_u32); + buffer.unsafe_increase_reader_index(length_u32); return result; } case StringEncoding::UTF8: { // Read UTF-8 bytes and convert to UTF-16 std::string utf8(length_u32, '\0'); - ctx.read_bytes(&utf8[0], length_u32, ctx.error()); - if (FORY_PREDICT_FALSE(ctx.has_error())) { - return std::u16string(); - } + std::memcpy(&utf8[0], data, length_u32); + buffer.unsafe_increase_reader_index(length_u32); return utf8_to_utf16(utf8, true /* little endian */); } default: diff --git a/cpp/fory/serialization/struct_serializer.h b/cpp/fory/serialization/struct_serializer.h index 967d5775d7..d8a6e2dd24 100644 --- a/cpp/fory/serialization/struct_serializer.h +++ b/cpp/fory/serialization/struct_serializer.h @@ -897,11 +897,11 @@ Container read_configured_list_data(ReadContext &ctx) { using Elem = element_type_t; uint32_t length = 
ctx.read_var_uint32(ctx.error()); Container result; - if (FORY_PREDICT_FALSE(ctx.has_error()) || length == 0) { + if (length == 0) { return result; } - if constexpr (has_reserve_v) { - result.reserve(length); + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { @@ -943,11 +943,6 @@ FORY_NOINLINE Container read_configured_list_data_as_array_field( if (FORY_PREDICT_FALSE(ctx.has_error()) || length == 0) { return result; } - if (FORY_PREDICT_FALSE(length > ctx.config().max_collection_size)) { - ctx.set_error( - Error::invalid_data("Collection length exceeds max_collection_size")); - return result; - } uint8_t bitmap = ctx.read_uint8(ctx.error()); if (FORY_PREDICT_FALSE(ctx.has_error())) { return result; @@ -971,8 +966,8 @@ FORY_NOINLINE Container read_configured_list_data_as_array_field( "compatible list to array field requires declared elements")); return result; } - if constexpr (has_reserve_v) { - result.reserve(length); + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { + return result; } for (uint32_t i = 0; i < length; ++i) { if constexpr (is_raw_primitive_v) { @@ -1056,7 +1051,12 @@ MapType read_configured_map_data(ReadContext &ctx) { using Value = mapped_type_t; uint32_t length = ctx.read_var_uint32(ctx.error()); MapType result; - MapReserver::reserve(result, length); + if (length == 0) { + return result; + } + if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { + return result; + } uint32_t read_count = 0; while (read_count < length && !ctx.has_error()) { uint8_t header = ctx.read_uint8(ctx.error()); diff --git a/cpp/fory/serialization/type_info.h b/cpp/fory/serialization/type_info.h index f31b14cc0f..d2ee3d4bc1 100644 --- a/cpp/fory/serialization/type_info.h +++ b/cpp/fory/serialization/type_info.h @@ -64,6 +64,7 @@ struct Harness { using WriteDataFn = void (*)(const void *value, WriteContext &ctx, bool 
has_generics); using ReadDataFn = void *(*)(ReadContext &ctx); + using DestroyFn = void (*)(void *value); using ReadCompatibleFn = void *(*)(ReadContext &ctx, const struct TypeInfo *type_info); using SortedFieldInfosFn = @@ -73,22 +74,25 @@ struct Harness { Harness() = default; Harness(WriteFn write, ReadFn read, WriteDataFn write_data, - ReadDataFn read_data, SortedFieldInfosFn sorted_fields, + ReadDataFn read_data, DestroyFn destroy, + SortedFieldInfosFn sorted_fields, ReadCompatibleFn read_compatible = nullptr) : write_fn(write), read_fn(read), write_data_fn(write_data), - read_data_fn(read_data), sorted_field_infos_fn(sorted_fields), + read_data_fn(read_data), destroy_fn(destroy), + sorted_field_infos_fn(sorted_fields), read_compatible_fn(read_compatible) {} bool valid() const { return write_fn != nullptr && read_fn != nullptr && write_data_fn != nullptr && read_data_fn != nullptr && - sorted_field_infos_fn != nullptr; + destroy_fn != nullptr && sorted_field_infos_fn != nullptr; } WriteFn write_fn = nullptr; ReadFn read_fn = nullptr; WriteDataFn write_data_fn = nullptr; ReadDataFn read_data_fn = nullptr; + DestroyFn destroy_fn = nullptr; SortedFieldInfosFn sorted_field_infos_fn = nullptr; ReadCompatibleFn read_compatible_fn = nullptr; AnyWriteFn any_write_fn = nullptr; diff --git a/cpp/fory/serialization/type_resolver.cc b/cpp/fory/serialization/type_resolver.cc index 92971affd8..45db2d6541 100644 --- a/cpp/fory/serialization/type_resolver.cc +++ b/cpp/fory/serialization/type_resolver.cc @@ -244,6 +244,16 @@ Result FieldInfo::from_bytes(Buffer &buffer) { // special characters (same as Encoders.FIELD_NAME_DECODER). 
const size_t name_size = size_field + 1; + if (FORY_PREDICT_FALSE( + name_size > + static_cast(std::numeric_limits::max()))) { + return Unexpected( + Error::invalid_data("Field name size exceeds uint32 range")); + } + if (FORY_PREDICT_FALSE( + !buffer.ensure_readable(static_cast(name_size), error))) { + return Unexpected(std::move(error)); + } std::vector name_bytes(name_size); buffer.read_bytes(name_bytes.data(), static_cast(name_size), error); if (FORY_PREDICT_FALSE(!error.ok())) { @@ -486,6 +496,15 @@ read_meta_name(Buffer &buffer, const MetaStringDecoder &decoder, length = BIG_NAME_THRESHOLD + static_cast(extra); } + if (FORY_PREDICT_FALSE( + length > static_cast(std::numeric_limits::max()))) { + return Unexpected( + Error::invalid_data("Meta string size exceeds uint32 range")); + } + if (FORY_PREDICT_FALSE( + !buffer.ensure_readable(static_cast(length), error))) { + return Unexpected(std::move(error)); + } std::vector bytes(length); if (length > 0) { buffer.read_bytes(bytes.data(), static_cast(length), error); @@ -618,6 +637,9 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta *local_type_info) { FORY_TRY(meta_size, read_type_meta_size(buffer, header_bits, &header_size)); int64_t meta_hash = static_cast(header_bits >> TYPE_META_HASH_SHIFT); uint32_t body_start = static_cast(start_pos + header_size); + if (FORY_PREDICT_FALSE(!buffer.ensure_readable(meta_size, error))) { + return Unexpected(std::move(error)); + } // Read meta header uint8_t meta_header = buffer.read_uint8(error); if (FORY_PREDICT_FALSE(!error.ok())) { @@ -684,6 +706,10 @@ TypeMeta::from_bytes(Buffer &buffer, const TypeMeta *local_type_info) { } // Read field infos + if (FORY_PREDICT_FALSE(num_fields > buffer.remaining_size())) { + return Unexpected( + Error::invalid_data("TypeMeta field count exceeds remaining metadata")); + } std::vector field_infos; field_infos.reserve(num_fields); for (size_t i = 0; i < num_fields; ++i) { @@ -734,6 +760,9 @@ TypeMeta::from_bytes_with_header(Buffer 
&buffer, int64_t header) { uint32_t start_pos = buffer.reader_index(); Error error; + if (FORY_PREDICT_FALSE(!buffer.ensure_readable(meta_size, error))) { + return Unexpected(std::move(error)); + } // Read meta header uint8_t meta_header = buffer.read_uint8(error); @@ -801,6 +830,10 @@ TypeMeta::from_bytes_with_header(Buffer &buffer, int64_t header) { } // Read field infos + if (FORY_PREDICT_FALSE(num_fields > buffer.remaining_size())) { + return Unexpected( + Error::invalid_data("TypeMeta field count exceeds remaining metadata")); + } std::vector field_infos; field_infos.reserve(num_fields); for (size_t i = 0; i < num_fields; ++i) { diff --git a/cpp/fory/serialization/type_resolver.h b/cpp/fory/serialization/type_resolver.h index e51d535a05..85774d1689 100644 --- a/cpp/fory/serialization/type_resolver.h +++ b/cpp/fory/serialization/type_resolver.h @@ -1456,6 +1456,10 @@ class TypeResolver { template static void *harness_read_data_adapter_abstract(ReadContext &ctx); + template static void harness_destroy_adapter(void *ptr); + + static void harness_destroy_adapter_noop(void *ptr); + template static Result, Error> harness_struct_sorted_fields(TypeResolver &resolver); @@ -2044,6 +2048,7 @@ Harness TypeResolver::make_struct_harness_impl(std::true_type) { &TypeResolver::harness_read_adapter_abstract, &TypeResolver::harness_write_data_adapter, &TypeResolver::harness_read_data_adapter_abstract, + &TypeResolver::harness_destroy_adapter_noop, &TypeResolver::harness_struct_sorted_fields, &TypeResolver::harness_read_compatible_adapter_abstract); harness.any_write_fn = &detail::any_write_adapter; @@ -2057,6 +2062,7 @@ Harness TypeResolver::make_struct_harness_impl(std::false_type) { &TypeResolver::harness_read_adapter, &TypeResolver::harness_write_data_adapter, &TypeResolver::harness_read_data_adapter, + &TypeResolver::harness_destroy_adapter, &TypeResolver::harness_struct_sorted_fields, &TypeResolver::harness_read_compatible_adapter); harness.any_write_fn = 
&detail::any_write_adapter; @@ -2069,6 +2075,7 @@ template Harness TypeResolver::make_serializer_harness() { &TypeResolver::harness_read_adapter, &TypeResolver::harness_write_data_adapter, &TypeResolver::harness_read_data_adapter, + &TypeResolver::harness_destroy_adapter, &TypeResolver::harness_empty_sorted_fields); harness.any_write_fn = &detail::any_write_adapter; harness.any_read_fn = &detail::any_read_adapter; @@ -2124,6 +2131,12 @@ void *TypeResolver::harness_read_data_adapter_abstract(ReadContext &ctx) { return nullptr; } +template void TypeResolver::harness_destroy_adapter(void *ptr) { + delete static_cast(ptr); +} + +inline void TypeResolver::harness_destroy_adapter_noop(void *ptr) { (void)ptr; } + template void *TypeResolver::harness_read_compatible_adapter(ReadContext &ctx, const TypeInfo *ti) { diff --git a/cpp/fory/serialization/union_serializer.h b/cpp/fory/serialization/union_serializer.h index 2fda11fac6..d5247d431f 100644 --- a/cpp/fory/serialization/union_serializer.h +++ b/cpp/fory/serialization/union_serializer.h @@ -465,10 +465,10 @@ Container read_union_configured_list_data(ReadContext &ctx) { using Elem = element_type_t; uint32_t length = ctx.read_var_uint32(ctx.error()); Container result; - if constexpr (has_reserve_v) { - result.reserve(length); + if (length == 0) { + return result; } - if (FORY_PREDICT_FALSE(ctx.has_error()) || length == 0) { + if (FORY_PREDICT_FALSE(!reserve_collection(result, ctx, length))) { return result; } uint8_t bitmap = ctx.read_uint8(ctx.error()); @@ -552,7 +552,12 @@ MapType read_union_configured_map_data(ReadContext &ctx) { using Value = mapped_type_t; uint32_t length = ctx.read_var_uint32(ctx.error()); MapType result; - MapReserver::reserve(result, length); + if (length == 0) { + return result; + } + if (FORY_PREDICT_FALSE(!reserve_map(result, ctx, length))) { + return result; + } uint32_t read_count = 0; while (read_count < length && !ctx.has_error()) { uint8_t header = ctx.read_uint8(ctx.error()); diff --git 
a/cpp/fory/serialization/unsigned_serializer.h b/cpp/fory/serialization/unsigned_serializer.h index 29d99b330d..57c895b715 100644 --- a/cpp/fory/serialization/unsigned_serializer.h +++ b/cpp/fory/serialization/unsigned_serializer.h @@ -25,6 +25,7 @@ #include "fory/util/error.h" #include #include +#include #include #include @@ -362,13 +363,18 @@ template struct Serializer> { static inline std::array read_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); - if (FORY_PREDICT_FALSE(length != N || length * sizeof(uint8_t) > - ctx.buffer().remaining_size())) { + if (FORY_PREDICT_FALSE(length != N)) { ctx.set_error(Error::invalid_data("Array size mismatch: expected " + std::to_string(N) + " but got " + std::to_string(length))); return std::array(); } + if constexpr (N > 0) { + if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(N * sizeof(uint8_t), + ctx.error()))) { + return std::array(); + } + } std::array arr; if constexpr (N > 0) { ctx.read_bytes(arr.data(), N * sizeof(uint8_t), ctx.error()); @@ -446,13 +452,18 @@ template struct Serializer> { static inline std::array read_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); - if (FORY_PREDICT_FALSE(length != N || length * sizeof(uint16_t) > - ctx.buffer().remaining_size())) { + if (FORY_PREDICT_FALSE(length != N)) { ctx.set_error(Error::invalid_data("Array size mismatch: expected " + std::to_string(N) + " but got " + std::to_string(length))); return std::array(); } + if constexpr (N > 0) { + if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(N * sizeof(uint16_t), + ctx.error()))) { + return std::array(); + } + } std::array arr; if constexpr (N > 0) { ctx.read_bytes(arr.data(), N * sizeof(uint16_t), ctx.error()); @@ -530,13 +541,18 @@ template struct Serializer> { static inline std::array read_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); - if (FORY_PREDICT_FALSE(length != N || length * sizeof(uint32_t) > - 
ctx.buffer().remaining_size())) { + if (FORY_PREDICT_FALSE(length != N)) { ctx.set_error(Error::invalid_data("Array size mismatch: expected " + std::to_string(N) + " but got " + std::to_string(length))); return std::array(); } + if constexpr (N > 0) { + if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(N * sizeof(uint32_t), + ctx.error()))) { + return std::array(); + } + } std::array arr; if constexpr (N > 0) { ctx.read_bytes(arr.data(), N * sizeof(uint32_t), ctx.error()); @@ -614,13 +630,18 @@ template struct Serializer> { static inline std::array read_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); - if (FORY_PREDICT_FALSE(length != N || length * sizeof(uint64_t) > - ctx.buffer().remaining_size())) { + if (FORY_PREDICT_FALSE(length != N)) { ctx.set_error(Error::invalid_data("Array size mismatch: expected " + std::to_string(N) + " but got " + std::to_string(length))); return std::array(); } + if constexpr (N > 0) { + if (FORY_PREDICT_FALSE(!ctx.buffer().ensure_readable(N * sizeof(uint64_t), + ctx.error()))) { + return std::array(); + } + } std::array arr; if constexpr (N > 0) { ctx.read_bytes(arr.data(), N * sizeof(uint64_t), ctx.error()); @@ -704,19 +725,15 @@ template <> struct Serializer> { static inline std::vector read_data(ReadContext &ctx) { uint32_t length = ctx.read_var_uint32(ctx.error()); - if (FORY_PREDICT_FALSE(length > ctx.config().max_binary_size)) { - ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); - return std::vector(); - } - - if (FORY_PREDICT_FALSE(length > ctx.buffer().remaining_size())) { - ctx.set_error( - Error::invalid_data("Invalid length: " + std::to_string(length))); + if (FORY_PREDICT_FALSE( + !ctx.buffer().ensure_readable(length, ctx.error()))) { return std::vector(); } std::vector vec(length); if (length > 0) { - ctx.read_bytes(vec.data(), length, ctx.error()); + Buffer &buffer = ctx.buffer(); + std::memcpy(vec.data(), buffer.data() + buffer.reader_index(), length); + 
buffer.unsafe_increase_reader_index(length); } return vec; } @@ -805,25 +822,22 @@ template <> struct Serializer> { return std::vector(); } - if (FORY_PREDICT_FALSE(total_bytes > ctx.config().max_binary_size)) { - ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); - return std::vector(); - } - if (total_bytes % sizeof(uint16_t) != 0) { ctx.set_error(Error::invalid_data("Invalid length: " + std::to_string(total_bytes))); return std::vector(); } - if (FORY_PREDICT_FALSE(total_bytes > ctx.buffer().remaining_size())) { - ctx.set_error(Error::invalid_data("Invalid length: " + - std::to_string(total_bytes))); + if (FORY_PREDICT_FALSE( + !ctx.buffer().ensure_readable(total_bytes, ctx.error()))) { return std::vector(); } size_t length = total_bytes / sizeof(uint16_t); std::vector vec(length); if (total_bytes > 0) { - ctx.read_bytes(vec.data(), total_bytes, ctx.error()); + Buffer &buffer = ctx.buffer(); + std::memcpy(vec.data(), buffer.data() + buffer.reader_index(), + total_bytes); + buffer.unsafe_increase_reader_index(total_bytes); } return vec; } @@ -913,25 +927,22 @@ template <> struct Serializer> { return std::vector(); } - if (FORY_PREDICT_FALSE(total_bytes > ctx.config().max_binary_size)) { - ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); - return std::vector(); - } - if (total_bytes % sizeof(uint32_t) != 0) { ctx.set_error(Error::invalid_data("Invalid length: " + std::to_string(total_bytes))); return std::vector(); } - if (FORY_PREDICT_FALSE(total_bytes > ctx.buffer().remaining_size())) { - ctx.set_error(Error::invalid_data("Invalid length: " + - std::to_string(total_bytes))); + if (FORY_PREDICT_FALSE( + !ctx.buffer().ensure_readable(total_bytes, ctx.error()))) { return std::vector(); } size_t length = total_bytes / sizeof(uint32_t); std::vector vec(length); if (total_bytes > 0) { - ctx.read_bytes(vec.data(), total_bytes, ctx.error()); + Buffer &buffer = ctx.buffer(); + std::memcpy(vec.data(), buffer.data() + 
buffer.reader_index(), + total_bytes); + buffer.unsafe_increase_reader_index(total_bytes); } return vec; } @@ -1021,25 +1032,22 @@ template <> struct Serializer> { return std::vector(); } - if (FORY_PREDICT_FALSE(total_bytes > ctx.config().max_binary_size)) { - ctx.set_error(Error::invalid_data("Binary size exceeds max_binary_size")); - return std::vector(); - } - if (total_bytes % sizeof(uint64_t) != 0) { ctx.set_error(Error::invalid_data("Invalid length: " + std::to_string(total_bytes))); return std::vector(); } - if (FORY_PREDICT_FALSE(total_bytes > ctx.buffer().remaining_size())) { - ctx.set_error(Error::invalid_data("Invalid length: " + - std::to_string(total_bytes))); + if (FORY_PREDICT_FALSE( + !ctx.buffer().ensure_readable(total_bytes, ctx.error()))) { return std::vector(); } size_t length = total_bytes / sizeof(uint64_t); std::vector vec(length); if (total_bytes > 0) { - ctx.read_bytes(vec.data(), total_bytes, ctx.error()); + Buffer &buffer = ctx.buffer(); + std::memcpy(vec.data(), buffer.data() + buffer.reader_index(), + total_bytes); + buffer.unsafe_increase_reader_index(total_bytes); } return vec; } diff --git a/cpp/fory/serialization/unsigned_serializer_test.cc b/cpp/fory/serialization/unsigned_serializer_test.cc index 729973cb0e..9c2e728d86 100644 --- a/cpp/fory/serialization/unsigned_serializer_test.cc +++ b/cpp/fory/serialization/unsigned_serializer_test.cc @@ -274,28 +274,6 @@ TEST(UnsignedSerializerTest, UnsignedArrayTypeIdsAreDistinct) { static_cast(TypeId::BINARY)); } -TEST(UnsignedSerializerTest, MaxBinarySizeNativeGuardrail) { - // Set limit to 10 bytes - auto fory = Fory::builder() - .xlang(false) - .max_binary_size(10) - .compatible(false) - .build(); - - // 10 elements of uint32_t = 40 bytes > 10 byte limit - std::vector large_data(10, 42); - - auto bytes_result = fory.serialize(large_data); - ASSERT_TRUE(bytes_result.ok()); - - auto deserialize_result = fory.deserialize>( - bytes_result->data(), bytes_result->size()); - - 
ASSERT_FALSE(deserialize_result.ok()); - EXPECT_TRUE(deserialize_result.error().message().find( - "exceeds max_binary_size") != std::string::npos); -} - } // namespace test } // namespace serialization } // namespace fory diff --git a/cpp/fory/util/buffer.h b/cpp/fory/util/buffer.h index 68aa795332..c86a31ec8e 100644 --- a/cpp/fory/util/buffer.h +++ b/cpp/fory/util/buffer.h @@ -219,6 +219,10 @@ class Buffer { } // Unsafe methods don't check bound + FORY_ALWAYS_INLINE void unsafe_increase_reader_index(uint32_t diff) { + reader_index_ += diff; + } + template FORY_ALWAYS_INLINE void unsafe_put(uint32_t offset, T value) { store_unaligned(data_ + offset, value); diff --git a/cpp/fory/util/buffer_test.cc b/cpp/fory/util/buffer_test.cc index 97043a4fd7..5a2d5e4a1c 100644 --- a/cpp/fory/util/buffer_test.cc +++ b/cpp/fory/util/buffer_test.cc @@ -36,8 +36,11 @@ class OneByteStreamBuf : public std::streambuf { explicit OneByteStreamBuf(std::vector data) : data_(std::move(data)), pos_(0) {} + const std::vector &read_sizes() const { return read_sizes_; } + protected: std::streamsize xsgetn(char *s, std::streamsize count) override { + read_sizes_.push_back(count); if (pos_ >= data_.size() || count <= 0) { return 0; } @@ -59,6 +62,7 @@ class OneByteStreamBuf : public std::streambuf { std::vector data_; size_t pos_; char current_ = 0; + std::vector read_sizes_; }; class OneByteIStream : public std::istream { @@ -68,6 +72,10 @@ class OneByteIStream : public std::istream { rdbuf(&buf_); } + const std::vector &read_sizes() const { + return buf_.read_sizes(); + } + private: OneByteStreamBuf buf_; }; @@ -370,6 +378,19 @@ TEST(Buffer, StreamReadErrorWhenInsufficientData) { EXPECT_EQ(error.code(), ErrorCode::BufferOutOfBound); } +TEST(Buffer, StreamFillDoubleGrowsFromBufferedBytes) { + std::vector raw(17, 0x7); + OneByteIStream one_byte_stream(raw); + StdInputStream stream(one_byte_stream, 4); + + auto fill_result = stream.fill_buffer(100); + EXPECT_FALSE(fill_result.ok()); + 
ASSERT_FALSE(one_byte_stream.read_sizes().empty()); + EXPECT_EQ(one_byte_stream.read_sizes().front(), 4); + EXPECT_LT(stream.get_buffer().size(), 100U); + EXPECT_LE(stream.get_buffer().size(), 32U); +} + TEST(Buffer, OutputStreamThresholdFlushOnWriteBytes) { CountingOutputStream writer; Buffer *buffer = writer.get_buffer(); diff --git a/cpp/fory/util/stream.cc b/cpp/fory/util/stream.cc index d694c3f842..22aa3e0f9d 100644 --- a/cpp/fory/util/stream.cc +++ b/cpp/fory/util/stream.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include "fory/util/buffer.h" #include "fory/util/logging.h" @@ -127,21 +128,13 @@ Result StdInputStream::fill_buffer(uint32_t min_fill_size) { } const uint32_t read_pos = buffer_->reader_index_; - const uint32_t deficit = min_fill_size - remaining_size(); constexpr uint64_t k_max_u32 = std::numeric_limits::max(); - const uint64_t required = static_cast(buffer_->size_) + deficit; - if (required > k_max_u32) { + const uint64_t target = + static_cast(read_pos) + static_cast(min_fill_size); + if (target > k_max_u32) { return Unexpected( Error::out_of_bound("stream buffer size exceeds uint32 range")); } - if (required > data_.size()) { - uint64_t new_size = - std::max(required, static_cast(data_.size()) * 2); - if (new_size > k_max_u32) { - new_size = k_max_u32; - } - reserve(static_cast(new_size)); - } std::streambuf *source = stream_->rdbuf(); if (source == nullptr) { @@ -149,6 +142,23 @@ Result StdInputStream::fill_buffer(uint32_t min_fill_size) { } uint32_t write_pos = buffer_->size_; while (remaining_size() < min_fill_size) { + if (write_pos == data_.size()) { + // min_fill_size can come from attacker-controlled wire lengths. Do not + // query stream availability here: the virtual probe is not part of the + // correctness contract and would add an extra hot-path call for a rare + // fast path. Grow only from bytes already buffered so truncated streams + // fail before reserving the declared body size. 
+ uint64_t new_size = + std::max(static_cast(data_.size()) * 2, + static_cast(initial_buffer_size_)); + if (new_size <= data_.size()) { + new_size = static_cast(data_.size()) + 1; + } + if (new_size > target) { + new_size = target; + } + reserve(static_cast(new_size)); + } uint32_t writable = static_cast(data_.size()) - write_pos; const std::streamsize read_bytes = source->sgetn(reinterpret_cast(data_.data() + write_pos), diff --git a/csharp/src/Fory.Generator/ForyModelGenerator.cs b/csharp/src/Fory.Generator/ForyModelGenerator.cs index 2e6ad7d4b7..fd1bfefba0 100644 --- a/csharp/src/Fory.Generator/ForyModelGenerator.cs +++ b/csharp/src/Fory.Generator/ForyModelGenerator.cs @@ -1010,7 +1010,7 @@ private static void EmitCompatibleFieldCodecMethod( sb.AppendLine(" if (remoteFieldType.TypeId == (uint)global::Apache.Fory.TypeId.Binary)"); sb.AppendLine(" {"); EmitReadNullOnlyPrefix(sb, member, 4); - EmitReadBinaryPayload(sb, codec, $"__{memberId}BinaryValue", 4); + EmitReadBinaryField(sb, codec, $"__{memberId}BinaryValue", 4); sb.AppendLine($" return __{memberId}BinaryValue;"); sb.AppendLine(" }"); } @@ -1037,7 +1037,7 @@ private static void EmitReadNullOnlyPrefix(StringBuilder sb, MemberModel member, sb.AppendLine($"{indent}}}"); } - private static void EmitReadBinaryPayload( + private static void EmitReadBinaryField( StringBuilder sb, FieldCodecModel codec, string targetVar, @@ -1053,6 +1053,7 @@ private static void EmitReadBinaryPayload( if (codec.CarrierKind == CarrierKind.List) { + sb.AppendLine($"{indent}context.Reader.CheckBound(__foryLength);"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new(__foryLength);"); sb.AppendLine($"{indent}for (int __foryIndex = 0; __foryIndex < __foryLength; __foryIndex++)"); sb.AppendLine($"{indent}{{"); @@ -1154,6 +1155,10 @@ private static void EmitReadCompatibleListArrayPayload( sb.AppendLine($"{innerIndent} }}"); sb.AppendLine($"{innerIndent}}}"); sb.AppendLine($"{indent}}}"); + sb.AppendLine($"{indent}if ({lengthVar} 
!= 0)"); + sb.AppendLine($"{indent}{{"); + sb.AppendLine($"{indent} context.Reader.CheckBound({lengthVar});"); + sb.AppendLine($"{indent}}}"); string elementTypeName = codec.CarrierKind == CarrierKind.Array ? ElementTypeName(codec.TypeName) : PackedArrayElementTypeName(codec.TypeId); uint elementTypeId = PackedArrayElementTypeId(codec.TypeId); if (codec.CarrierKind == CarrierKind.Array) @@ -1488,19 +1493,20 @@ private static void EmitReadPackedArrayPayload( string indent = new(' ', indentLevel * 4); int width = PackedArrayElementWidth(codec.TypeId); uint elementTypeId = PackedArrayElementTypeId(codec.TypeId); - string payloadSizeVar = $"__foryPayloadSize{id++}"; + string byteSizeVar = $"__foryByteSize{id++}"; string countVar = $"__foryPackedCount{id++}"; - sb.AppendLine($"{indent}int {payloadSizeVar} = checked((int)context.Reader.ReadVarUInt32());"); + sb.AppendLine($"{indent}int {byteSizeVar} = checked((int)context.Reader.ReadVarUInt32());"); if (width > 1) { int mask = width - 1; - sb.AppendLine($"{indent}if (({payloadSizeVar} & {mask}) != 0)"); + sb.AppendLine($"{indent}if (({byteSizeVar} & {mask}) != 0)"); sb.AppendLine($"{indent}{{"); - sb.AppendLine($"{indent} throw new global::Apache.Fory.InvalidDataException(\"packed array payload size mismatch\");"); + sb.AppendLine($"{indent} throw new global::Apache.Fory.InvalidDataException(\"packed array byte size mismatch\");"); sb.AppendLine($"{indent}}}"); } - sb.AppendLine($"{indent}int {countVar} = {payloadSizeVar}{(width == 1 ? string.Empty : $" / {width}")};"); + sb.AppendLine($"{indent}context.Reader.CheckBound({byteSizeVar});"); + sb.AppendLine($"{indent}int {countVar} = {byteSizeVar}{(width == 1 ? 
string.Empty : $" / {width}")};"); if (codec.CarrierKind == CarrierKind.Array) { sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new {ElementTypeName(codec.TypeName)}[{countVar}];"); @@ -1546,6 +1552,10 @@ private static void EmitReadCollectionPayload( string sameTypeVar = $"__forySameType{id++}"; string declaredVar = $"__foryDeclared{id++}"; sb.AppendLine($"{indent}int {lengthVar} = checked((int)context.Reader.ReadVarUInt32());"); + sb.AppendLine($"{indent}if ({lengthVar} != 0)"); + sb.AppendLine($"{indent}{{"); + sb.AppendLine($"{indent} context.Reader.CheckBound({lengthVar});"); + sb.AppendLine($"{indent}}}"); if (isSet) { sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new();"); @@ -1647,6 +1657,10 @@ private static void EmitReadMapPayload( FieldCodecModel value = codec.Generics[1]; string totalVar = $"__foryTotal{id++}"; sb.AppendLine($"{indent}int {totalVar} = checked((int)context.Reader.ReadVarUInt32());"); + sb.AppendLine($"{indent}if ({totalVar} != 0)"); + sb.AppendLine($"{indent}{{"); + sb.AppendLine($"{indent} context.Reader.CheckBound({totalVar});"); + sb.AppendLine($"{indent}}}"); sb.AppendLine($"{indent}{codec.TypeName} {targetVar} = new({totalVar});"); sb.AppendLine($"{indent}int __foryRead = 0;"); sb.AppendLine($"{indent}while (__foryRead < {totalVar})"); @@ -1724,7 +1738,7 @@ private static void EmitReadInlineTypeInfo( if (!CanValidateInlineTypeInfo(codec.TypeId)) { sb.AppendLine( - $"{indent}throw new global::Apache.Fory.InvalidDataException(\"generated field payload requires declared nested user type metadata\");"); + $"{indent}throw new global::Apache.Fory.InvalidDataException(\"generated field value requires declared nested user type metadata\");"); return; } diff --git a/csharp/src/Fory/ByteBuffer.cs b/csharp/src/Fory/ByteBuffer.cs index eb3afc8276..84c7aab399 100644 --- a/csharp/src/Fory/ByteBuffer.cs +++ b/csharp/src/Fory/ByteBuffer.cs @@ -482,7 +482,7 @@ public void MoveBack(int amount) public void CheckBound(int need) { - if 
(_cursor + need > _length) + if (need < 0 || need > _length - _cursor) { throw new OutOfBoundsException(_cursor, need, _length); } diff --git a/csharp/src/Fory/CollectionSerializers.cs b/csharp/src/Fory/CollectionSerializers.cs index e48a57d349..c407153fd5 100644 --- a/csharp/src/Fory/CollectionSerializers.cs +++ b/csharp/src/Fory/CollectionSerializers.cs @@ -213,6 +213,7 @@ public static List ReadCollectionData(Serializer elementSerializer, Rea bool hasNull = (header & CollectionBits.HasNull) != 0; bool declared = (header & CollectionBits.DeclaredElementType) != 0; bool sameType = (header & CollectionBits.SameType) != 0; + context.Reader.CheckBound(length); List values = new(length); if (!sameType) { diff --git a/csharp/src/Fory/DecimalSerializer.cs b/csharp/src/Fory/DecimalSerializer.cs index 07208a53df..a300a213cf 100644 --- a/csharp/src/Fory/DecimalSerializer.cs +++ b/csharp/src/Fory/DecimalSerializer.cs @@ -115,11 +115,11 @@ public static void Write(ByteWriter buffer, int scale, BigInteger unscaled) throw new InvalidDataException("zero must use the small decimal encoding"); } - byte[] payload = magnitude.ToByteArray(isUnsigned: true, isBigEndian: false); - ulong meta = ((ulong)payload.Length << 1) | (unscaled.Sign < 0 ? 1UL : 0UL); + byte[] magnitudeBytes = magnitude.ToByteArray(isUnsigned: true, isBigEndian: false); + ulong meta = ((ulong)magnitudeBytes.Length << 1) | (unscaled.Sign < 0 ? 
1UL : 0UL); ulong header = (meta << 1) | 1UL; buffer.WriteVarUInt64(header); - buffer.WriteBytes(payload); + buffer.WriteBytes(magnitudeBytes); } public static (int Scale, BigInteger Unscaled) Read(ByteReader buffer) @@ -139,13 +139,13 @@ public static (int Scale, BigInteger Unscaled) Read(ByteReader buffer) } int length = checked((int)lenLong); - byte[] payload = buffer.ReadBytes(length); - if (payload[^1] == 0) + byte[] magnitudeBytes = buffer.ReadBytes(length); + if (magnitudeBytes[^1] == 0) { - throw new InvalidDataException("non-canonical decimal payload: trailing zero byte"); + throw new InvalidDataException("non-canonical decimal magnitude bytes: trailing zero byte"); } - BigInteger magnitude = new(payload, isUnsigned: true, isBigEndian: false); + BigInteger magnitude = new(magnitudeBytes, isUnsigned: true, isBigEndian: false); if (magnitude.IsZero) { throw new InvalidDataException("big decimal encoding must not represent zero"); diff --git a/csharp/src/Fory/DictionarySerializers.cs b/csharp/src/Fory/DictionarySerializers.cs index 3d2e91e3ea..5aa49dfa75 100644 --- a/csharp/src/Fory/DictionarySerializers.cs +++ b/csharp/src/Fory/DictionarySerializers.cs @@ -217,6 +217,7 @@ public override TDictionary ReadData(ReadContext context) return CreateMap(0); } + context.Reader.CheckBound(totalLength); TDictionary map = CreateMap(totalLength); bool keyDynamicType = keyTypeInfo.IsDynamicType; bool valueDynamicType = valueTypeInfo.IsDynamicType; diff --git a/csharp/src/Fory/FieldSkipper.cs b/csharp/src/Fory/FieldSkipper.cs index 010fbefb73..aaca46ec23 100644 --- a/csharp/src/Fory/FieldSkipper.cs +++ b/csharp/src/Fory/FieldSkipper.cs @@ -319,8 +319,8 @@ private static void SkipPayload(ReadContext context, TypeMetaFieldType fieldType private static void SkipPackedArray(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - context.Reader.Skip(payloadSize); + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + 
context.Reader.Skip(byteSize); } private static void SkipListOrSet(ReadContext context, TypeMetaFieldType fieldType) diff --git a/csharp/src/Fory/NullableKeyDictionary.cs b/csharp/src/Fory/NullableKeyDictionary.cs index 2a297b438f..d6c8caab47 100644 --- a/csharp/src/Fory/NullableKeyDictionary.cs +++ b/csharp/src/Fory/NullableKeyDictionary.cs @@ -540,7 +540,8 @@ public override NullableKeyDictionary ReadData(ReadContext context return new NullableKeyDictionary(); } - NullableKeyDictionary map = new(); + context.Reader.CheckBound(totalLength); + NullableKeyDictionary map = new(totalLength); bool keyDynamicType = keyTypeInfo.IsDynamicType; bool valueDynamicType = valueTypeInfo.IsDynamicType; int readCount = 0; diff --git a/csharp/src/Fory/PrimitiveArraySerializers.cs b/csharp/src/Fory/PrimitiveArraySerializers.cs index 0f618da9a5..2e063db7c5 100644 --- a/csharp/src/Fory/PrimitiveArraySerializers.cs +++ b/csharp/src/Fory/PrimitiveArraySerializers.cs @@ -37,9 +37,10 @@ public override void WriteData(WriteContext context, in bool[] value, bool hasGe public override bool[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - bool[] values = new bool[payloadSize]; - for (int i = 0; i < payloadSize; i++) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + context.Reader.CheckBound(byteSize); + bool[] values = new bool[byteSize]; + for (int i = 0; i < byteSize; i++) { values[i] = context.Reader.ReadUInt8() != 0; } @@ -68,9 +69,10 @@ public override void WriteData(WriteContext context, in sbyte[] value, bool hasG public override sbyte[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - sbyte[] values = new sbyte[payloadSize]; - for (int i = 0; i < payloadSize; i++) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + context.Reader.CheckBound(byteSize); + sbyte[] values = new sbyte[byteSize]; + for (int i = 0; i < byteSize; i++) { values[i] = 
context.Reader.ReadInt8(); } @@ -99,13 +101,14 @@ public override void WriteData(WriteContext context, in short[] value, bool hasG public override short[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 1) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 1) != 0) { - throw new InvalidDataException("int16 array payload size mismatch"); + throw new InvalidDataException("int16 array byte size mismatch"); } - short[] values = new short[payloadSize / 2]; + context.Reader.CheckBound(byteSize); + short[] values = new short[byteSize / 2]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadInt16(); @@ -135,13 +138,14 @@ public override void WriteData(WriteContext context, in int[] value, bool hasGen public override int[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 3) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 3) != 0) { - throw new InvalidDataException("int32 array payload size mismatch"); + throw new InvalidDataException("int32 array byte size mismatch"); } - int[] values = new int[payloadSize / 4]; + context.Reader.CheckBound(byteSize); + int[] values = new int[byteSize / 4]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadInt32(); @@ -171,13 +175,14 @@ public override void WriteData(WriteContext context, in long[] value, bool hasGe public override long[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 7) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 7) != 0) { - throw new InvalidDataException("int64 array payload size mismatch"); + throw new InvalidDataException("int64 array byte size mismatch"); } - long[] values = new long[payloadSize / 8]; + context.Reader.CheckBound(byteSize); + long[] 
values = new long[byteSize / 8]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadInt64(); @@ -207,13 +212,14 @@ public override void WriteData(WriteContext context, in ushort[] value, bool has public override ushort[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 1) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 1) != 0) { - throw new InvalidDataException("uint16 array payload size mismatch"); + throw new InvalidDataException("uint16 array byte size mismatch"); } - ushort[] values = new ushort[payloadSize / 2]; + context.Reader.CheckBound(byteSize); + ushort[] values = new ushort[byteSize / 2]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadUInt16(); @@ -243,13 +249,14 @@ public override void WriteData(WriteContext context, in uint[] value, bool hasGe public override uint[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 3) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 3) != 0) { - throw new InvalidDataException("uint32 array payload size mismatch"); + throw new InvalidDataException("uint32 array byte size mismatch"); } - uint[] values = new uint[payloadSize / 4]; + context.Reader.CheckBound(byteSize); + uint[] values = new uint[byteSize / 4]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadUInt32(); @@ -279,13 +286,14 @@ public override void WriteData(WriteContext context, in ulong[] value, bool hasG public override ulong[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 7) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 7) != 0) { - throw new InvalidDataException("uint64 array payload size mismatch"); + throw new InvalidDataException("uint64 array byte size 
mismatch"); } - ulong[] values = new ulong[payloadSize / 8]; + context.Reader.CheckBound(byteSize); + ulong[] values = new ulong[byteSize / 8]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadUInt64(); @@ -315,13 +323,14 @@ public override void WriteData(WriteContext context, in Half[] value, bool hasGe public override Half[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 1) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 1) != 0) { - throw new InvalidDataException("float16 array payload size mismatch"); + throw new InvalidDataException("float16 array byte size mismatch"); } - Half[] values = new Half[payloadSize / 2]; + context.Reader.CheckBound(byteSize); + Half[] values = new Half[byteSize / 2]; for (int i = 0; i < values.Length; i++) { values[i] = BitConverter.UInt16BitsToHalf(context.Reader.ReadUInt16()); @@ -351,13 +360,14 @@ public override void WriteData(WriteContext context, in BFloat16[] value, bool h public override BFloat16[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 1) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 1) != 0) { - throw new InvalidDataException("bfloat16 array payload size mismatch"); + throw new InvalidDataException("bfloat16 array byte size mismatch"); } - BFloat16[] values = new BFloat16[payloadSize / 2]; + context.Reader.CheckBound(byteSize); + BFloat16[] values = new BFloat16[byteSize / 2]; for (int i = 0; i < values.Length; i++) { values[i] = BFloat16.FromBits(context.Reader.ReadUInt16()); @@ -387,13 +397,14 @@ public override void WriteData(WriteContext context, in float[] value, bool hasG public override float[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 3) != 0) + int byteSize = 
checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 3) != 0) { - throw new InvalidDataException("float32 array payload size mismatch"); + throw new InvalidDataException("float32 array byte size mismatch"); } - float[] values = new float[payloadSize / 4]; + context.Reader.CheckBound(byteSize); + float[] values = new float[byteSize / 4]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadFloat32(); @@ -423,13 +434,14 @@ public override void WriteData(WriteContext context, in double[] value, bool has public override double[] ReadData(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 7) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 7) != 0) { - throw new InvalidDataException("float64 array payload size mismatch"); + throw new InvalidDataException("float64 array byte size mismatch"); } - double[] values = new double[payloadSize / 8]; + context.Reader.CheckBound(byteSize); + double[] values = new double[byteSize / 8]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadFloat64(); diff --git a/csharp/src/Fory/PrimitiveDictionarySerializers.cs b/csharp/src/Fory/PrimitiveDictionarySerializers.cs index 05173961e9..a136bd57bd 100644 --- a/csharp/src/Fory/PrimitiveDictionarySerializers.cs +++ b/csharp/src/Fory/PrimitiveDictionarySerializers.cs @@ -670,12 +670,13 @@ public static TMap ReadMap( where TMapOps : struct, IPrimitiveMapReadOps { int totalLength = checked((int)context.Reader.ReadVarUInt32()); - TMap map = TMapOps.Create(totalLength); if (totalLength == 0) { - return map; + return TMapOps.Create(0); } + context.Reader.CheckBound(totalLength); + TMap map = TMapOps.Create(totalLength); TypeId keyTypeId = TKeyCodec.WireTypeId; TypeId valueTypeId = TValueCodec.WireTypeId; bool keyNullable = TKeyCodec.IsNullable; diff --git a/csharp/src/Fory/TypeMeta.cs b/csharp/src/Fory/TypeMeta.cs index 97d3de1177..31ea8e29aa 100644 
--- a/csharp/src/Fory/TypeMeta.cs +++ b/csharp/src/Fory/TypeMeta.cs @@ -541,6 +541,7 @@ public static TypeMeta Decode(ByteReader reader) typeName = MetaString.Empty('$', '_'); } + bodyReader.CheckBound(numFields); List fields = new(numFields); for (int i = 0; i < numFields; i++) { diff --git a/csharp/src/Fory/TypeResolver.cs b/csharp/src/Fory/TypeResolver.cs index 000db8f9db..8313fed758 100644 --- a/csharp/src/Fory/TypeResolver.cs +++ b/csharp/src/Fory/TypeResolver.cs @@ -1125,6 +1125,7 @@ private static byte[] ReadBinary(ReadContext context) private static bool[] ReadBoolArray(ReadContext context) { int count = checked((int)context.Reader.ReadVarUInt32()); + context.Reader.CheckBound(count); bool[] values = new bool[count]; for (int i = 0; i < values.Length; i++) { @@ -1137,6 +1138,7 @@ private static bool[] ReadBoolArray(ReadContext context) private static sbyte[] ReadInt8Array(ReadContext context) { int count = checked((int)context.Reader.ReadVarUInt32()); + context.Reader.CheckBound(count); sbyte[] values = new sbyte[count]; for (int i = 0; i < values.Length; i++) { @@ -1148,13 +1150,14 @@ private static sbyte[] ReadInt8Array(ReadContext context) private static short[] ReadInt16Array(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 1) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 1) != 0) { - throw new InvalidDataException("int16 array payload size mismatch"); + throw new InvalidDataException("int16 array byte size mismatch"); } - short[] values = new short[payloadSize / 2]; + context.Reader.CheckBound(byteSize); + short[] values = new short[byteSize / 2]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadInt16(); @@ -1165,13 +1168,14 @@ private static short[] ReadInt16Array(ReadContext context) private static int[] ReadInt32Array(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 
3) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 3) != 0) { - throw new InvalidDataException("int32 array payload size mismatch"); + throw new InvalidDataException("int32 array byte size mismatch"); } - int[] values = new int[payloadSize / 4]; + context.Reader.CheckBound(byteSize); + int[] values = new int[byteSize / 4]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadInt32(); @@ -1182,13 +1186,14 @@ private static int[] ReadInt32Array(ReadContext context) private static long[] ReadInt64Array(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 7) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 7) != 0) { - throw new InvalidDataException("int64 array payload size mismatch"); + throw new InvalidDataException("int64 array byte size mismatch"); } - long[] values = new long[payloadSize / 8]; + context.Reader.CheckBound(byteSize); + long[] values = new long[byteSize / 8]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadInt64(); @@ -1199,13 +1204,14 @@ private static long[] ReadInt64Array(ReadContext context) private static ushort[] ReadUInt16Array(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 1) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 1) != 0) { - throw new InvalidDataException("uint16 array payload size mismatch"); + throw new InvalidDataException("uint16 array byte size mismatch"); } - ushort[] values = new ushort[payloadSize / 2]; + context.Reader.CheckBound(byteSize); + ushort[] values = new ushort[byteSize / 2]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadUInt16(); @@ -1216,13 +1222,14 @@ private static ushort[] ReadUInt16Array(ReadContext context) private static uint[] ReadUInt32Array(ReadContext context) { - int payloadSize = 
checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 3) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 3) != 0) { - throw new InvalidDataException("uint32 array payload size mismatch"); + throw new InvalidDataException("uint32 array byte size mismatch"); } - uint[] values = new uint[payloadSize / 4]; + context.Reader.CheckBound(byteSize); + uint[] values = new uint[byteSize / 4]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadUInt32(); @@ -1233,13 +1240,14 @@ private static uint[] ReadUInt32Array(ReadContext context) private static ulong[] ReadUInt64Array(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 7) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 7) != 0) { - throw new InvalidDataException("uint64 array payload size mismatch"); + throw new InvalidDataException("uint64 array byte size mismatch"); } - ulong[] values = new ulong[payloadSize / 8]; + context.Reader.CheckBound(byteSize); + ulong[] values = new ulong[byteSize / 8]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadUInt64(); @@ -1250,13 +1258,14 @@ private static ulong[] ReadUInt64Array(ReadContext context) private static float[] ReadFloat32Array(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 3) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 3) != 0) { - throw new InvalidDataException("float32 array payload size mismatch"); + throw new InvalidDataException("float32 array byte size mismatch"); } - float[] values = new float[payloadSize / 4]; + context.Reader.CheckBound(byteSize); + float[] values = new float[byteSize / 4]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadFloat32(); @@ -1267,13 +1276,14 @@ private static float[] ReadFloat32Array(ReadContext context) private 
static double[] ReadFloat64Array(ReadContext context) { - int payloadSize = checked((int)context.Reader.ReadVarUInt32()); - if ((payloadSize & 7) != 0) + int byteSize = checked((int)context.Reader.ReadVarUInt32()); + if ((byteSize & 7) != 0) { - throw new InvalidDataException("float64 array payload size mismatch"); + throw new InvalidDataException("float64 array byte size mismatch"); } - double[] values = new double[payloadSize / 8]; + context.Reader.CheckBound(byteSize); + double[] values = new double[byteSize / 8]; for (int i = 0; i < values.Length; i++) { values[i] = context.Reader.ReadFloat64(); diff --git a/csharp/tests/Fory.Tests/ByteBufferTests.cs b/csharp/tests/Fory.Tests/ByteBufferTests.cs index 9da9705aca..d2e0846cfe 100644 --- a/csharp/tests/Fory.Tests/ByteBufferTests.cs +++ b/csharp/tests/Fory.Tests/ByteBufferTests.cs @@ -116,6 +116,20 @@ public void PrimitiveReadWriteRoundTrip() Assert.Equal(0, reader.Remaining); } + [Fact] + public void CheckBoundRejectsNegativeAndOverflowingNeed() + { + ByteReader reader = new([0x01]); + + Assert.Throws(() => reader.CheckBound(-1)); + Assert.Throws(() => reader.ReadBytes(-1)); + Assert.Throws(() => reader.ReadSpan(-1)); + Assert.Throws(() => reader.Skip(-1)); + + reader.SetCursor(1); + Assert.Throws(() => reader.CheckBound(int.MaxValue)); + } + [Theory] [MemberData(nameof(VarUInt32Cases))] public void VarUInt32RoundTripAndSize(uint value, int expectedBytes) diff --git a/csharp/tests/Fory.Tests/ForyGeneratorTests.cs b/csharp/tests/Fory.Tests/ForyGeneratorTests.cs index 5b6bb4a665..0538463413 100644 --- a/csharp/tests/Fory.Tests/ForyGeneratorTests.cs +++ b/csharp/tests/Fory.Tests/ForyGeneratorTests.cs @@ -172,6 +172,43 @@ public sealed class Shape Assert.DoesNotContain("if (remoteField.FieldType.TypeId ==", generated, StringComparison.Ordinal); } + [Fact] + public void CompatibleBinaryListChecksBeforeCapacity() + { + const string source = """ + using System.Collections.Generic; + using Apache.Fory; + using S = 
Apache.Fory.Schema.Types; + + namespace GeneratedDiagnostics; + + [ForyStruct] + public sealed class BinaryListShape + { + [ForyField(1, Type = typeof(S.Array))] + public List Value { get; set; } = []; + } + """; + + string generated = GenerateSource(source); + + int lengthIndex = generated.IndexOf( + "int __foryLength = checked((int)context.Reader.ReadVarUInt32());", + StringComparison.Ordinal); + int checkIndex = generated.IndexOf( + "context.Reader.CheckBound(__foryLength);", + lengthIndex, + StringComparison.Ordinal); + int allocationIndex = generated.IndexOf( + "new(__foryLength);", + lengthIndex, + StringComparison.Ordinal); + + Assert.True(lengthIndex >= 0); + Assert.True(checkIndex > lengthIndex); + Assert.True(allocationIndex > checkIndex); + } + private static string GenerateSource(string source) { CSharpCompilation compilation = CreateCompilation(source); diff --git a/dart/packages/fory/README.md b/dart/packages/fory/README.md index d54fe1f98a..2a0e6d301c 100644 --- a/dart/packages/fory/README.md +++ b/dart/packages/fory/README.md @@ -123,18 +123,14 @@ Keep the same registration identity on every peer that exchanges the type. 
```dart final fory = Fory( maxDepth: 256, - maxCollectionSize: 1 << 20, - maxBinarySize: 64 * 1024 * 1024, ); ``` -| Option | Default | Description | -| -------------------- | ---------- | ------------------------------------------------------- | -| `compatible` | `true` | Enables compatible struct encoding for schema evolution | -| `checkStructVersion` | `false` | Validates struct version for same-schema payloads | -| `maxDepth` | `256` | Maximum nesting depth per operation | -| `maxCollectionSize` | `1 << 20` | Maximum collection and map payload size | -| `maxBinarySize` | `64 << 20` | Maximum binary payload size | +| Option | Default | Description | +| -------------------- | ------- | ------------------------------------------------------- | +| `compatible` | `true` | Enables compatible struct encoding for schema evolution | +| `checkStructVersion` | `false` | Validates struct version for same-schema payloads | +| `maxDepth` | `256` | Maximum nesting depth per operation | ## Reference Tracking diff --git a/dart/packages/fory/lib/src/codegen/generated_support.dart b/dart/packages/fory/lib/src/codegen/generated_support.dart index 709f9cd3e1..b41c2fa1ca 100644 --- a/dart/packages/fory/lib/src/codegen/generated_support.dart +++ b/dart/packages/fory/lib/src/codegen/generated_support.dart @@ -267,6 +267,7 @@ void writeGeneratedBoolArrayValue(WriteContext context, BoolList value) { BoolList readGeneratedBoolArrayValue(ReadContext context) { final buffer = context.buffer; final size = buffer.readVarUint32(); + buffer.checkReadableBytes(size); return BoolList.arrayStorage(buffer.readInt8Bytes(size)); } diff --git a/dart/packages/fory/lib/src/config.dart b/dart/packages/fory/lib/src/config.dart index dbe6999f5f..2ab848e57f 100644 --- a/dart/packages/fory/lib/src/config.dart +++ b/dart/packages/fory/lib/src/config.dart @@ -19,19 +19,12 @@ /// Fory instance configuration for the Dart xlang implementation. 
/// -/// The defaults favor compatible mode with conservative safety limits. +/// The defaults favor compatible mode with conservative structural limits. final class Config { /// Default maximum nesting depth for a single serialization or /// deserialization operation. static const int defaultMaxDepth = 256; - /// Default maximum number of collection entries accepted in one collection or - /// map payload. - static const int defaultMaxCollectionSize = 1 << 20; - - /// Default maximum number of bytes accepted for a binary payload. - static const int defaultMaxBinarySize = 64 * 1024 * 1024; - /// Enables compatible struct encoding and decoding. /// /// In compatible mode Fory shares TypeDef metadata and disables @@ -46,12 +39,6 @@ final class Config { /// Maximum allowed read or write nesting depth. final int maxDepth; - /// Maximum allowed collection or map size. - final int maxCollectionSize; - - /// Maximum allowed binary payload size in bytes. - final int maxBinarySize; - /// Creates an immutable configuration object. /// /// Invalid numeric limits fail fast. When [compatible] is `true`, @@ -60,10 +47,6 @@ final class Config { this.compatible = true, bool checkStructVersion = true, this.maxDepth = defaultMaxDepth, - this.maxCollectionSize = defaultMaxCollectionSize, - this.maxBinarySize = defaultMaxBinarySize, }) : checkStructVersion = compatible ? 
false : checkStructVersion, - assert(maxDepth > 0, 'maxDepth must be positive'), - assert(maxCollectionSize > 0, 'maxCollectionSize must be positive'), - assert(maxBinarySize > 0, 'maxBinarySize must be positive'); + assert(maxDepth > 0, 'maxDepth must be positive'); } diff --git a/dart/packages/fory/lib/src/context/meta_string_reader.dart b/dart/packages/fory/lib/src/context/meta_string_reader.dart index 1b112155f5..d0b70b6c4d 100644 --- a/dart/packages/fory/lib/src/context/meta_string_reader.dart +++ b/dart/packages/fory/lib/src/context/meta_string_reader.dart @@ -24,13 +24,8 @@ import 'package:fory/src/meta/meta_string.dart'; import 'package:fory/src/resolver/type_resolver.dart'; import 'package:fory/src/types/int64.dart'; -typedef _MetaStringWords = ({ - int length, - int word0, - int word1, - int word2, - int word3 -}); +typedef _MetaStringWords = + ({int length, int word0, int word1, int word2, int word3}); /// Read-side state for meta-string references in one deserialization stream. final class MetaStringReader { @@ -61,9 +56,10 @@ final class MetaStringReader { if ((header & 1) == 1) { return _dynamicReadMetaStrings[length - 1]; } - final encoded = length > metaStringSmallThreshold - ? _readBigMetaString(buffer, length, expected) - : _readSmallMetaString(buffer, length, expected); + final encoded = + length > metaStringSmallThreshold + ? _readBigMetaString(buffer, length, expected) + : _readSmallMetaString(buffer, length, expected); _dynamicReadMetaStrings.add(encoded); return encoded; } @@ -74,6 +70,7 @@ final class MetaStringReader { EncodedMetaString? 
expected, ) { final hash = buffer.readInt64(); + buffer.checkReadableBytes(length); if (expected != null && expected.hash == hash) { buffer.skip(length); return expected; @@ -100,20 +97,14 @@ final class MetaStringReader { return EncodedMetaString.empty; } final encoding = buffer.readByte() & 0xff; + buffer.checkReadableBytes(length); final words = _readMetaStringWords(buffer, length); final word0 = words.word0; final word1 = words.word1; final word2 = words.word2; final word3 = words.word3; if (expected != null && - expected.matchesPacked( - encoding, - length, - word0, - word1, - word2, - word3, - )) { + expected.matchesPacked(encoding, length, word0, word1, word2, word3)) { return expected; } final hash = _smallMetaStringHash( @@ -128,7 +119,13 @@ final class MetaStringReader { if (bucket != null) { for (final cached in bucket) { if (cached.matchesPacked( - encoding, length, word0, word1, word2, word3)) { + encoding, + length, + word0, + word1, + word2, + word3, + )) { return cached; } } @@ -137,9 +134,7 @@ final class MetaStringReader { _materializeMetaStringWords(words), encoding: encoding, ); - (bucket ?? (_smallMetaStrings[hash] = [])).add( - encoded, - ); + (bucket ?? 
(_smallMetaStrings[hash] = [])).add(encoded); return encoded; } } diff --git a/dart/packages/fory/lib/src/fory.dart b/dart/packages/fory/lib/src/fory.dart index 741f647026..6daaa072e6 100644 --- a/dart/packages/fory/lib/src/fory.dart +++ b/dart/packages/fory/lib/src/fory.dart @@ -57,15 +57,11 @@ final class Fory { bool compatible = true, bool checkStructVersion = true, int maxDepth = Config.defaultMaxDepth, - int maxCollectionSize = Config.defaultMaxCollectionSize, - int maxBinarySize = Config.defaultMaxBinarySize, }) { final config = Config( compatible: compatible, checkStructVersion: checkStructVersion, maxDepth: maxDepth, - maxCollectionSize: maxCollectionSize, - maxBinarySize: maxBinarySize, ); _readBuffer = Buffer(); _writeBuffer = Buffer(); diff --git a/dart/packages/fory/lib/src/memory/buffer_mixin.dart b/dart/packages/fory/lib/src/memory/buffer_mixin.dart index 200ca38314..2c02064809 100644 --- a/dart/packages/fory/lib/src/memory/buffer_mixin.dart +++ b/dart/packages/fory/lib/src/memory/buffer_mixin.dart @@ -42,6 +42,13 @@ mixin _BufferMixin { /// Number of unread bytes between the reader and writer indices. int get readableBytes => _writerIndex - _readerIndex; + /// Fails if [length] bytes are not currently readable. + void checkReadableBytes(int length) { + if (length < 0 || length > readableBytes) { + throw StateError('Insufficient readable bytes: $length.'); + } + } + /// Returns the written portion of the underlying storage. /// /// The returned view shares memory with the buffer. 
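Note on the `checkReadableBytes` hunk above: the read paths in this diff validate an attacker-declared length against the bytes actually present before allocating or copying. The following is a minimal TypeScript sketch of that pattern, not Fory's implementation. `ByteReader`, `readUint8`, `copyBytes`, and `readLengthPrefixed` are hypothetical names chosen for illustration, and the one-byte length prefix is an assumption made to keep the example small (the Dart code reads a varint).

```typescript
// Hypothetical minimal reader; it mirrors the shape of the Dart helper in the
// diff but is an illustrative sketch, not part of any Fory runtime.
class ByteReader {
  private readerIndex = 0;

  constructor(private readonly data: Uint8Array) {}

  // Number of unread bytes left in the underlying storage.
  get readableBytes(): number {
    return this.data.length - this.readerIndex;
  }

  // Throws unless `length` bytes are currently readable.
  checkReadableBytes(length: number): void {
    if (length < 0 || length > this.readableBytes) {
      throw new RangeError(`Insufficient readable bytes: ${length}.`);
    }
  }

  // Reads one unsigned byte and advances the reader index.
  readUint8(): number {
    this.checkReadableBytes(1);
    return this.data[this.readerIndex++];
  }

  // Copies `length` bytes; the copy is allocated only after the bounds check.
  copyBytes(length: number): Uint8Array {
    this.checkReadableBytes(length);
    const out = this.data.slice(this.readerIndex, this.readerIndex + length);
    this.readerIndex += length;
    return out;
  }
}

// Reads a one-byte declared length (an assumption of this sketch) followed by
// that many payload bytes. A declared length larger than the remaining input
// fails before any buffer is allocated for it.
function readLengthPrefixed(reader: ByteReader): Uint8Array {
  const declared = reader.readUint8();
  return reader.copyBytes(declared);
}
```

With this shape, the cost of rejecting a malformed message is bounded by the bytes the sender actually supplied rather than by the length it claims, which appears to be the reason the diff drops the fixed `maxCollectionSize` and `maxBinarySize` limits.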
diff --git a/dart/packages/fory/lib/src/resolver/type_resolver.dart b/dart/packages/fory/lib/src/resolver/type_resolver.dart index 4ec28c03cd..3ad6486316 100644 --- a/dart/packages/fory/lib/src/resolver/type_resolver.dart +++ b/dart/packages/fory/lib/src/resolver/type_resolver.dart @@ -1144,6 +1144,7 @@ final class TypeResolver { TypeInfo _readTypeDefWithHeader(Buffer buffer, TypeHeader header) { header.validateGlobal(); final metaSize = header.readMetaSize(buffer); + buffer.checkReadableBytes(metaSize); final metaBody = buffer.readBytes(metaSize); final metaBytes = Buffer.wrap(metaBody); final classHeader = metaBytes.readUint8(); @@ -1245,6 +1246,7 @@ final class TypeResolver { if (size == typeDefBigNameThreshold) { size += source.readVarUint32Small7(); } + source.checkReadableBytes(size); return internEncodedMetaString( Uint8List.fromList(source.readBytes(size)), encoding: decodeEncoding(compactEncoding), @@ -1269,10 +1271,13 @@ final class TypeResolver { nullable: fieldNullable, ref: fieldRef, ); - final identifier = - isTag - ? 
tagId.toString() - : decodeFieldName(source.readBytes(size), encoding); + final String identifier; + if (isTag) { + identifier = tagId.toString(); + } else { + source.checkReadableBytes(size); + identifier = decodeFieldName(source.readBytes(size), encoding); + } return FieldInfo( name: identifier, identifier: identifier, diff --git a/dart/packages/fory/lib/src/serializer/collection_serializers.dart b/dart/packages/fory/lib/src/serializer/collection_serializers.dart index 98d33ffc23..66ab8da9ed 100644 --- a/dart/packages/fory/lib/src/serializer/collection_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/collection_serializers.dart @@ -264,11 +264,6 @@ final class ListSerializer extends Serializer { required bool trackRef, }) { final size = values.length; - if (size > context.config.maxCollectionSize) { - throw StateError( - 'Collection size $size exceeds ${context.config.maxCollectionSize}.', - ); - } context.buffer.writeVarUint32(size); if (size == 0) { return; @@ -339,6 +334,7 @@ final class ListSerializer extends Serializer { bool hasPreservedRef = false, }) { final state = _prepareListRead(context, elementFieldType); + context.buffer.checkReadableBytes(state.size); final result = List.filled(state.size, null, growable: false); if (hasPreservedRef) { context.reference(result); @@ -512,11 +508,6 @@ Object _readCompatibleListAsArrayField( String fieldName, ) { final size = context.buffer.readVarUint32(); - if (size > context.config.maxCollectionSize) { - throw StateError( - 'Collection size $size exceeds ${context.config.maxCollectionSize}.', - ); - } if (size == 0) { return _newArrayValue(arrayTypeId, 0); } @@ -537,6 +528,7 @@ Object _readCompatibleListAsArrayField( ); } final elementResolved = context.typeResolver.resolveFieldType(elementType); + context.buffer.checkReadableBytes(size); final result = _newArrayValue(arrayTypeId, size); for (var index = 0; index < size; index += 1) { _setArrayValue( @@ -663,6 +655,7 @@ List readTypedListPayload( if 
(directTypeInfo.type == T && directTypeInfo.kind == RegistrationKind.struct) { final structSerializer = directTypeInfo.structSerializer!; + context.buffer.checkReadableBytes(state.size); final result = directTypeInfo.remoteTypeDef == null ? List.generate( @@ -686,6 +679,7 @@ List readTypedListPayload( return result; } if (directTypeInfo.type == T && directTypeInfo.typeId == TypeIds.string) { + context.buffer.checkReadableBytes(state.size); final result = List.generate( state.size, (_) => StringSerializer.readPayload(context) as T, @@ -696,6 +690,7 @@ List readTypedListPayload( } return result; } + context.buffer.checkReadableBytes(state.size); final result = List.generate( state.size, (_) => @@ -707,6 +702,7 @@ List readTypedListPayload( } return result; } + context.buffer.checkReadableBytes(state.size); final result = List.generate( state.size, (_) => convert(_readPreparedListItem(context, state)), @@ -732,11 +728,6 @@ void writeTypedListPayload( FieldType elementFieldType, ) { final size = values.length; - if (size > context.config.maxCollectionSize) { - throw StateError( - 'Collection size $size exceeds ${context.config.maxCollectionSize}.', - ); - } context.buffer.writeVarUint32(size); if (size == 0) return; final declaredTypeInfo = context.typeResolver.resolveFieldType( @@ -919,11 +910,6 @@ _PreparedListRead _prepareListRead( FieldType? 
elementFieldType, ) { final size = context.buffer.readVarUint32(); - if (size > context.config.maxCollectionSize) { - throw StateError( - 'Collection size $size exceeds ${context.config.maxCollectionSize}.', - ); - } if (size == 0) { return _PreparedListRead( size: 0, diff --git a/dart/packages/fory/lib/src/serializer/map_serializers.dart b/dart/packages/fory/lib/src/serializer/map_serializers.dart index a2b669c329..18d511412d 100644 --- a/dart/packages/fory/lib/src/serializer/map_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/map_serializers.dart @@ -40,13 +40,7 @@ final class MapSerializer extends Serializer { @override void write(WriteContext context, Map value) { - writePayload( - context, - value, - null, - null, - trackRef: context.rootTrackRef, - ); + writePayload(context, value, null, null, trackRef: context.rootTrackRef); } @override @@ -61,26 +55,24 @@ final class MapSerializer extends Serializer { FieldType? valueFieldType, { required bool trackRef, }) { - if (values.length > context.config.maxCollectionSize) { - throw StateError( - 'Map size ${values.length} exceeds ${context.config.maxCollectionSize}.', - ); - } context.buffer.writeVarUint32(values.length); - final declaredKeyTypeInfo = keyFieldType == null || keyFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(keyFieldType); + final declaredKeyTypeInfo = + keyFieldType == null || keyFieldType.isDynamic + ? null + : context.typeResolver.resolveFieldType(keyFieldType); final declaredValueTypeInfo = valueFieldType == null || valueFieldType.isDynamic ? 
null : context.typeResolver.resolveFieldType(valueFieldType); - final keyDeclared = declaredKeyTypeInfo != null && + final keyDeclared = + declaredKeyTypeInfo != null && usesDeclaredTypeInfo( context.config.compatible, keyFieldType!, declaredKeyTypeInfo, ); - final valueDeclared = declaredValueTypeInfo != null && + final valueDeclared = + declaredValueTypeInfo != null && usesDeclaredTypeInfo( context.config.compatible, valueFieldType!, @@ -109,14 +101,16 @@ final class MapSerializer extends Serializer { final key = entry.key; final value = entry.value; if (key == null || value == null) { - final keyTrackRef = keyRequestedRef && + final keyTrackRef = + keyRequestedRef && (keyDeclared ? declaredKeyTypeInfo.supportsRef : (key == null || context.typeResolver .resolveValue(key as Object) .supportsRef)); - final valueTrackRef = valueRequestedRef && + final valueTrackRef = + valueRequestedRef && (valueDeclared ? declaredValueTypeInfo.supportsRef : (value == null || @@ -138,12 +132,14 @@ final class MapSerializer extends Serializer { ); continue; } - final chunkKeyTypeInfo = keyDeclared - ? declaredKeyTypeInfo - : context.typeResolver.resolveValue(key as Object); - final chunkValueTypeInfo = valueDeclared - ? declaredValueTypeInfo - : context.typeResolver.resolveValue(value as Object); + final chunkKeyTypeInfo = + keyDeclared + ? declaredKeyTypeInfo + : context.typeResolver.resolveValue(key as Object); + final chunkValueTypeInfo = + valueDeclared + ? 
declaredValueTypeInfo + : context.typeResolver.resolveValue(value as Object); final chunkKeyTrackRef = keyRequestedRef && chunkKeyTypeInfo.supportsRef; final chunkValueTrackRef = valueRequestedRef && chunkValueTypeInfo.supportsRef; @@ -164,9 +160,8 @@ final class MapSerializer extends Serializer { context.writeTypeMetaValue(chunkValueTypeInfo, value); } var chunkLength = 1; - final tracksDepth = tracksNestedPayloadDepth( - chunkKeyTypeInfo, - ) || + final tracksDepth = + tracksNestedPayloadDepth(chunkKeyTypeInfo) || tracksNestedPayloadDepth(chunkValueTypeInfo); if (tracksDepth) { context.increaseDepth(); @@ -194,12 +189,14 @@ final class MapSerializer extends Serializer { pendingEntry = nextEntry; break; } - final nextKeyTypeInfo = keyDeclared - ? declaredKeyTypeInfo - : context.typeResolver.resolveValue(nextKey as Object); - final nextValueTypeInfo = valueDeclared - ? declaredValueTypeInfo - : context.typeResolver.resolveValue(nextValue as Object); + final nextKeyTypeInfo = + keyDeclared + ? declaredKeyTypeInfo + : context.typeResolver.resolveValue(nextKey as Object); + final nextValueTypeInfo = + valueDeclared + ? 
declaredValueTypeInfo + : context.typeResolver.resolveValue(nextValue as Object); final nextKeyTrackRef = keyRequestedRef && nextKeyTypeInfo.supportsRef; final nextValueTrackRef = valueRequestedRef && nextValueTypeInfo.supportsRef; @@ -208,10 +205,7 @@ final class MapSerializer extends Serializer { (!keyDeclared && !sameTypeInfo(chunkKeyTypeInfo, nextKeyTypeInfo)) || (!valueDeclared && - !sameTypeInfo( - chunkValueTypeInfo, - nextValueTypeInfo, - ))) { + !sameTypeInfo(chunkValueTypeInfo, nextValueTypeInfo))) { pendingEntry = nextEntry; break; } @@ -263,14 +257,10 @@ Map readTypedMapPayload( bool hasPreservedRef = false, }) { var remaining = context.buffer.readVarUint32(); - if (remaining > context.config.maxCollectionSize) { - throw StateError( - 'Map size $remaining exceeds ${context.config.maxCollectionSize}.', - ); - } - final declaredKeyTypeInfo = keyFieldType == null || keyFieldType.isDynamic - ? null - : context.typeResolver.resolveFieldType(keyFieldType); + final declaredKeyTypeInfo = + keyFieldType == null || keyFieldType.isDynamic + ? null + : context.typeResolver.resolveFieldType(keyFieldType); final declaredValueTypeInfo = valueFieldType == null || valueFieldType.isDynamic ? null @@ -285,12 +275,7 @@ Map readTypedMapPayload( final valueHasNull = (header & MapFlags.valueHasNull) != 0; if (keyHasNull || valueHasNull) { result[convertKey( - _readNullChunkKey( - context, - header, - keyFieldType, - declaredKeyTypeInfo, - ), + _readNullChunkKey(context, header, keyFieldType, declaredKeyTypeInfo), )] = convertValue( _readNullChunkValue( context, @@ -316,43 +301,45 @@ Map readTypedMapPayload( final valueTypeInfo = valueDeclared ? null : context.readTypeMetaValue(); final tracksDepth = ((keyDeclared ? declaredKeyTypeInfo : keyTypeInfo) != null && - tracksNestedPayloadDepth( - keyDeclared ? declaredKeyTypeInfo! : keyTypeInfo!, - )) || - ((valueDeclared ? declaredValueTypeInfo : valueTypeInfo) != null && - tracksNestedPayloadDepth( - valueDeclared ? 
declaredValueTypeInfo! : valueTypeInfo!, - )); + tracksNestedPayloadDepth( + keyDeclared ? declaredKeyTypeInfo! : keyTypeInfo!, + )) || + ((valueDeclared ? declaredValueTypeInfo : valueTypeInfo) != null && + tracksNestedPayloadDepth( + valueDeclared ? declaredValueTypeInfo! : valueTypeInfo!, + )); if (tracksDepth) { context.increaseDepth(); } for (var index = 0; index < chunkSize; index += 1) { - final key = keyDeclared - ? _readDeclaredMapValue( - context, - keyFieldType!, - declaredKeyTypeInfo!, - trackRef: keyTrackRef, - ) - : _readResolvedMapValue( - context, - keyTypeInfo!, - null, - trackRef: keyTrackRef, - ); - final value = valueDeclared - ? _readDeclaredMapValue( - context, - valueFieldType!, - declaredValueTypeInfo!, - trackRef: valueTrackRef, - ) - : _readResolvedMapValue( - context, - valueTypeInfo!, - null, - trackRef: valueTrackRef, - ); + final key = + keyDeclared + ? _readDeclaredMapValue( + context, + keyFieldType!, + declaredKeyTypeInfo!, + trackRef: keyTrackRef, + ) + : _readResolvedMapValue( + context, + keyTypeInfo!, + null, + trackRef: keyTrackRef, + ); + final value = + valueDeclared + ? _readDeclaredMapValue( + context, + valueFieldType!, + declaredValueTypeInfo!, + trackRef: valueTrackRef, + ) + : _readResolvedMapValue( + context, + valueTypeInfo!, + null, + trackRef: valueTrackRef, + ); result[convertKey(key)] = convertValue(value); } if (tracksDepth) { @@ -499,10 +486,7 @@ Object? _readNullChunkValue( return trackRef ? 
context.readRef() : context.readNonRef(); } -FieldType _declaredMapFieldType( - FieldType fieldType, { - required bool trackRef, -}) { +FieldType _declaredMapFieldType(FieldType fieldType, {required bool trackRef}) { return fieldType.withRootOverrides(nullable: false, ref: trackRef); } diff --git a/dart/packages/fory/lib/src/serializer/scalar_serializers.dart b/dart/packages/fory/lib/src/serializer/scalar_serializers.dart index 053319e40f..05e4604c04 100644 --- a/dart/packages/fory/lib/src/serializer/scalar_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/scalar_serializers.dart @@ -50,10 +50,10 @@ Uint8List _decimalMagnitudeToCanonicalLittleEndian(BigInt magnitude) { return Uint8List.fromList(bytes); } -BigInt _decimalMagnitudeFromCanonicalLittleEndian(Uint8List payload) { +BigInt _decimalMagnitudeFromCanonicalLittleEndian(Uint8List magnitudeBytes) { var magnitude = BigInt.zero; - for (var index = payload.length - 1; index >= 0; index -= 1) { - magnitude = (magnitude << 8) | BigInt.from(payload[index]); + for (var index = magnitudeBytes.length - 1; index >= 0; index -= 1) { + magnitude = (magnitude << 8) | BigInt.from(magnitudeBytes[index]); } return magnitude; } @@ -132,22 +132,13 @@ final class BinarySerializer extends Serializer { } static void writePayload(WriteContext context, Uint8List value) { - if (value.length > context.config.maxBinarySize) { - throw StateError( - 'Binary payload exceeds ${context.config.maxBinarySize} bytes.', - ); - } context.buffer.writeVarUint32(value.length); context.buffer.writeBytes(value); } static Uint8List readPayload(ReadContext context) { final size = context.buffer.readVarUint32(); - if (size > context.config.maxBinarySize) { - throw StateError( - 'Binary payload exceeds ${context.config.maxBinarySize} bytes.', - ); - } + context.buffer.checkReadableBytes(size); return context.buffer.copyBytes(size); } } @@ -178,11 +169,13 @@ final class DecimalSerializer extends Serializer { return; } - final payload = 
_decimalMagnitudeToCanonicalLittleEndian(unscaled.abs()); + final magnitudeBytes = _decimalMagnitudeToCanonicalLittleEndian( + unscaled.abs(), + ); final sign = unscaled.isNegative ? 1 : 0; - final meta = (payload.length << 1) | sign; + final meta = (magnitudeBytes.length << 1) | sign; buffer.writeVarUint64(Uint64((meta << 1) | 1)); - buffer.writeBytes(payload); + buffer.writeBytes(magnitudeBytes); } static Decimal readPayload(ReadContext context) { @@ -198,13 +191,16 @@ final class DecimalSerializer extends Serializer { if (length <= 0) { throw StateError('Invalid decimal magnitude length $length.'); } - final payload = context.buffer.copyBytes(length); - if (payload[length - 1] == 0) { + context.buffer.checkReadableBytes(length); + final magnitudeBytes = context.buffer.copyBytes(length); + if (magnitudeBytes[length - 1] == 0) { throw StateError( - 'Non-canonical decimal payload: trailing zero byte.', + 'Non-canonical decimal magnitude bytes: trailing zero byte.', ); } - final magnitude = _decimalMagnitudeFromCanonicalLittleEndian(payload); + final magnitude = _decimalMagnitudeFromCanonicalLittleEndian( + magnitudeBytes, + ); if (magnitude == BigInt.zero) { throw StateError('Big decimal encoding must not represent zero.'); } diff --git a/dart/packages/fory/lib/src/serializer/typed_array_serializers.dart b/dart/packages/fory/lib/src/serializer/typed_array_serializers.dart index 9f429399b8..360a61581c 100644 --- a/dart/packages/fory/lib/src/serializer/typed_array_serializers.dart +++ b/dart/packages/fory/lib/src/serializer/typed_array_serializers.dart @@ -29,32 +29,30 @@ import 'package:fory/src/types/float16.dart'; import 'package:fory/src/types/int64.dart'; import 'package:fory/src/types/uint64.dart'; -void writeTypedArrayBytes( - WriteContext context, - Object values, -) { +void writeTypedArrayBytes(WriteContext context, Object values) { final bytes = switch (values) { Int64List typed => typed.buffer.asUint8List( - typed.offsetInBytes, - typed.lengthInBytes, - ), 
+ typed.offsetInBytes, + typed.lengthInBytes, + ), Uint64List typed => typed.buffer.asUint8List( - typed.offsetInBytes, - typed.lengthInBytes, - ), + typed.offsetInBytes, + typed.lengthInBytes, + ), Float16List typed => typed.buffer.asUint8List( - typed.offsetInBytes, - typed.lengthInBytes, - ), + typed.offsetInBytes, + typed.lengthInBytes, + ), Bfloat16List typed => typed.buffer.asUint8List( - typed.offsetInBytes, - typed.lengthInBytes, - ), + typed.offsetInBytes, + typed.lengthInBytes, + ), td.TypedData typed => typed.buffer.asUint8List( - typed.offsetInBytes, - typed.lengthInBytes, - ), - _ => throw ArgumentError.value( + typed.offsetInBytes, + typed.lengthInBytes, + ), + _ => + throw ArgumentError.value( values, 'values', 'Expected a supported contiguous typed array value.', @@ -75,6 +73,7 @@ T readTypedArrayBytes( 'Typed array byte size $byteSize is not aligned to element size $elementSize.', ); } + context.buffer.checkReadableBytes(byteSize); var bytes = context.buffer.readBytes(byteSize); if (bytes.offsetInBytes % elementSize != 0) { bytes = td.Uint8List.fromList(bytes); @@ -97,6 +96,7 @@ final class BoolArraySerializer extends Serializer { @override BoolList read(ReadContext context) { final size = context.buffer.readVarUint32(); + context.buffer.checkReadableBytes(size); return BoolList.arrayStorage(context.buffer.readInt8Bytes(size)); } } @@ -106,11 +106,7 @@ final class TypedArraySerializer extends Serializer { final int elementSize; final T Function(td.Uint8List bytes) viewBuilder; - const TypedArraySerializer( - this.typeId, - this.elementSize, - this.viewBuilder, - ); + const TypedArraySerializer(this.typeId, this.elementSize, this.viewBuilder); @override bool get supportsRef => false; @@ -130,6 +126,7 @@ final class TypedArraySerializer extends Serializer { T read(ReadContext context) { if (typeId == TypeIds.int8Array) { final size = context.buffer.readVarUint32(); + context.buffer.checkReadableBytes(size); return 
td.Int8List.fromList(context.buffer.readBytes(size)) as T; } return readTypedArrayBytes(context, elementSize, viewBuilder); @@ -139,115 +136,78 @@ final class TypedArraySerializer extends Serializer { const BoolArraySerializer boolArraySerializer = BoolArraySerializer(); const TypedArraySerializer int8ArraySerializer = TypedArraySerializer( - TypeIds.int8Array, - 1, - td.Int8List.fromList, -); + TypeIds.int8Array, + 1, + td.Int8List.fromList, + ); const TypedArraySerializer int16ArraySerializer = - TypedArraySerializer( - TypeIds.int16Array, - 2, - _asInt16List, -); + TypedArraySerializer(TypeIds.int16Array, 2, _asInt16List); const TypedArraySerializer int32ArraySerializer = - TypedArraySerializer( - TypeIds.int32Array, - 4, - _asInt32List, -); + TypedArraySerializer(TypeIds.int32Array, 4, _asInt32List); const TypedArraySerializer int64ArraySerializer = - TypedArraySerializer( - TypeIds.int64Array, - 8, - _asInt64List, -); + TypedArraySerializer(TypeIds.int64Array, 8, _asInt64List); const TypedArraySerializer uint16ArraySerializer = - TypedArraySerializer( - TypeIds.uint16Array, - 2, - _asUint16List, -); + TypedArraySerializer(TypeIds.uint16Array, 2, _asUint16List); const TypedArraySerializer uint32ArraySerializer = - TypedArraySerializer( - TypeIds.uint32Array, - 4, - _asUint32List, -); + TypedArraySerializer(TypeIds.uint32Array, 4, _asUint32List); const TypedArraySerializer uint64ArraySerializer = - TypedArraySerializer( - TypeIds.uint64Array, - 8, - _asUint64List, -); + TypedArraySerializer(TypeIds.uint64Array, 8, _asUint64List); const TypedArraySerializer float16ArraySerializer = - TypedArraySerializer( - TypeIds.float16Array, - 2, - _asFloat16List, -); + TypedArraySerializer(TypeIds.float16Array, 2, _asFloat16List); const TypedArraySerializer bfloat16ArraySerializer = TypedArraySerializer( - TypeIds.bfloat16Array, - 2, - _asBfloat16List, -); + TypeIds.bfloat16Array, + 2, + _asBfloat16List, + ); const TypedArraySerializer float32ArraySerializer = 
TypedArraySerializer( - TypeIds.float32Array, - 4, - _asFloat32List, -); + TypeIds.float32Array, + 4, + _asFloat32List, + ); const TypedArraySerializer float64ArraySerializer = TypedArraySerializer( - TypeIds.float64Array, - 8, - _asFloat64List, -); - -td.Int16List _asInt16List(td.Uint8List bytes) => bytes.buffer.asInt16List( - bytes.offsetInBytes, - bytes.lengthInBytes ~/ 2, + TypeIds.float64Array, + 8, + _asFloat64List, ); -td.Int32List _asInt32List(td.Uint8List bytes) => bytes.buffer.asInt32List( - bytes.offsetInBytes, - bytes.lengthInBytes ~/ 4, - ); +td.Int16List _asInt16List(td.Uint8List bytes) => + bytes.buffer.asInt16List(bytes.offsetInBytes, bytes.lengthInBytes ~/ 2); + +td.Int32List _asInt32List(td.Uint8List bytes) => + bytes.buffer.asInt32List(bytes.offsetInBytes, bytes.lengthInBytes ~/ 4); Int64List _asInt64List(td.Uint8List bytes) => Int64List.view(bytes.buffer, bytes.offsetInBytes, bytes.lengthInBytes ~/ 8); -td.Uint16List _asUint16List(td.Uint8List bytes) => bytes.buffer.asUint16List( - bytes.offsetInBytes, - bytes.lengthInBytes ~/ 2, - ); +td.Uint16List _asUint16List(td.Uint8List bytes) => + bytes.buffer.asUint16List(bytes.offsetInBytes, bytes.lengthInBytes ~/ 2); Float16List _asFloat16List(td.Uint8List bytes) => Float16List.view( - bytes.buffer, bytes.offsetInBytes, bytes.lengthInBytes ~/ 2); + bytes.buffer, + bytes.offsetInBytes, + bytes.lengthInBytes ~/ 2, +); Bfloat16List _asBfloat16List(td.Uint8List bytes) => Bfloat16List.view( - bytes.buffer, - bytes.offsetInBytes, - bytes.lengthInBytes ~/ 2, - ); + bytes.buffer, + bytes.offsetInBytes, + bytes.lengthInBytes ~/ 2, +); -td.Uint32List _asUint32List(td.Uint8List bytes) => bytes.buffer.asUint32List( - bytes.offsetInBytes, - bytes.lengthInBytes ~/ 4, - ); +td.Uint32List _asUint32List(td.Uint8List bytes) => + bytes.buffer.asUint32List(bytes.offsetInBytes, bytes.lengthInBytes ~/ 4); Uint64List _asUint64List(td.Uint8List bytes) => Uint64List.view( - bytes.buffer, - bytes.offsetInBytes, - 
bytes.lengthInBytes ~/ 8, - ); + bytes.buffer, + bytes.offsetInBytes, + bytes.lengthInBytes ~/ 8, +); -td.Float32List _asFloat32List(td.Uint8List bytes) => bytes.buffer.asFloat32List( - bytes.offsetInBytes, - bytes.lengthInBytes ~/ 4, - ); +td.Float32List _asFloat32List(td.Uint8List bytes) => + bytes.buffer.asFloat32List(bytes.offsetInBytes, bytes.lengthInBytes ~/ 4); -td.Float64List _asFloat64List(td.Uint8List bytes) => bytes.buffer.asFloat64List( - bytes.offsetInBytes, - bytes.lengthInBytes ~/ 8, - ); +td.Float64List _asFloat64List(td.Uint8List bytes) => + bytes.buffer.asFloat64List(bytes.offsetInBytes, bytes.lengthInBytes ~/ 8); diff --git a/dart/packages/fory/lib/src/util/string_util.dart b/dart/packages/fory/lib/src/util/string_util.dart index b274264b80..82b3484b71 100644 --- a/dart/packages/fory/lib/src/util/string_util.dart +++ b/dart/packages/fory/lib/src/util/string_util.dart @@ -68,14 +68,11 @@ String decodeString(Uint8List bytes, int encoding) { } } -String readStringFromBuffer( - Buffer buffer, - int byteLength, - int encoding, -) { +String readStringFromBuffer(Buffer buffer, int byteLength, int encoding) { if (byteLength == 0) { return ''; } + buffer.checkReadableBytes(byteLength); final start = bufferReaderIndex(buffer); buffer.skip(byteLength); final bytes = bufferBytes(buffer); @@ -84,9 +81,7 @@ String readStringFromBuffer( return String.fromCharCodes(bytes, start, start + byteLength); case stringUtf16Encoding: if (byteLength.isOdd) { - throw StateError( - 'Invalid UTF-16 string payload length $byteLength.', - ); + throw StateError('Invalid UTF-16 string payload length $byteLength.'); } final codeUnitCount = byteLength ~/ 2; if (Endian.host == Endian.little && start.isEven) { diff --git a/dart/packages/fory/test/codegen_conversion_expression_test.dart b/dart/packages/fory/test/codegen_conversion_expression_test.dart index 7982787cc7..c28fff6854 100644 --- a/dart/packages/fory/test/codegen_conversion_expression_test.dart +++ 
b/dart/packages/fory/test/codegen_conversion_expression_test.dart @@ -74,6 +74,6 @@ void main() { 'rawValue', nullExpression: 'fallback()', ); - expect(expression, 'rawValue == null ? fallback() : rawValue'); + expect(expression, 'rawValue ?? fallback()'); }); } diff --git a/dart/packages/fory/test/collection_serializer_test.dart b/dart/packages/fory/test/collection_serializer_test.dart index 08eac611d5..a42e5d701c 100644 --- a/dart/packages/fory/test/collection_serializer_test.dart +++ b/dart/packages/fory/test/collection_serializer_test.dart @@ -339,23 +339,6 @@ void main() { _expectNumericContainerEqual(roundTrip, value); } }); - - test('enforces maxCollectionSize for list set and map', () { - final fory = Fory(maxCollectionSize: 2); - - expect( - () => fory.serialize([1, 2, 3]), - throwsA(isA()), - ); - expect( - () => fory.serialize({1, 2, 3}), - throwsA(isA()), - ); - expect( - () => fory.serialize({'a': 1, 'b': 2, 'c': 3}), - throwsA(isA()), - ); - }); }); } diff --git a/dart/packages/fory/test/manual_registration_test.dart b/dart/packages/fory/test/manual_registration_test.dart index dccfca1ad6..af87e16d23 100644 --- a/dart/packages/fory/test/manual_registration_test.dart +++ b/dart/packages/fory/test/manual_registration_test.dart @@ -188,12 +188,7 @@ void main() { }); test('constructor forwards direct config parameters', () { - final fory = Fory( - compatible: true, - maxDepth: 64, - maxCollectionSize: 1024, - maxBinarySize: 4096, - ); + final fory = Fory(compatible: true, maxDepth: 64); final bytes = fory.serialize((42)); expect(fory.deserialize(bytes), equals((42))); }); diff --git a/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart b/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart index a1ad837677..8469cdd87f 100644 --- a/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart +++ b/dart/packages/fory/test/scalar_and_typed_array_serializer_test.dart @@ -1512,32 +1512,5 @@ void main() { isTrue, ); }); - - 
test('enforces maxBinarySize on write and read', () { - final oversized = Uint8List.fromList([1, 2, 3, 4]); - - expect( - () => Fory(maxBinarySize: 3).serialize(oversized), - throwsA( - isA().having( - (error) => error.toString(), - 'message', - contains('Binary payload exceeds 3 bytes.'), - ), - ), - ); - - final bytes = Fory().serialize(oversized); - expect( - () => Fory(maxBinarySize: 3).deserialize(bytes), - throwsA( - isA().having( - (error) => error.toString(), - 'message', - contains('Binary payload exceeds 3 bytes.'), - ), - ), - ); - }); }); } diff --git a/docs/README.md b/docs/README.md index 678c60ffb7..27500ed942 100644 --- a/docs/README.md +++ b/docs/README.md @@ -10,6 +10,7 @@ [Kotlin](guide/kotlin/index.md) guides. - For row format, see the [row format spec](specification/row_format_spec.md). - For using Apache Fory™ with GraalVM native image, see [graalvm support](guide/java/graalvm-support.md) doc. +- For deserialization security boundaries, see the [security model](security/deserialization.md). ## Fory IDL Schema diff --git a/docs/guide/dart/configuration.md b/docs/guide/dart/configuration.md index 6ed942a689..f84ff49c2a 100644 --- a/docs/guide/dart/configuration.md +++ b/docs/guide/dart/configuration.md @@ -82,31 +82,13 @@ Limits how deeply nested an object graph can be. Increase this if you have legit final fory = Fory(maxDepth: 128); ``` -### `maxCollectionSize` - -Maximum number of elements accepted in any single list, set, or map field. Prevents runaway memory allocation from malformed messages. - -```dart -final fory = Fory(maxCollectionSize: 100000); -``` - -### `maxBinarySize` - -Maximum number of bytes accepted for any single binary blob field. 
- -```dart -final fory = Fory(maxBinarySize: 8 * 1024 * 1024); -``` - ## Defaults -| Option | Default | -| -------------------- | --------- | -| `compatible` | `true` | -| `checkStructVersion` | `false` | -| `maxDepth` | 256 | -| `maxCollectionSize` | 1 048 576 | -| `maxBinarySize` | 64 MiB | +| Option | Default | +| -------------------- | ------- | +| `compatible` | `true` | +| `checkStructVersion` | `false` | +| `maxDepth` | 256 | ## Xlang Notes @@ -122,7 +104,7 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkStructVersion: true` with `compatible: false` for intentional same-schema payloads. -- Set `maxDepth`, `maxCollectionSize`, and `maxBinarySize` to reject unexpectedly large payloads. +- Set `maxDepth` to reject unexpectedly deep payload shapes. - Prefer generated schemas and explicit field metadata over broad dynamic fields for untrusted input. ## Related Topics diff --git a/docs/guide/javascript/configuration.md b/docs/guide/javascript/configuration.md index 896f599b90..e4a0fb3bd9 100644 --- a/docs/guide/javascript/configuration.md +++ b/docs/guide/javascript/configuration.md @@ -43,22 +43,18 @@ const fory = new Fory({ ref: true, compatible: true, maxDepth: 100, - maxBinarySize: 64 * 1024 * 1024, - maxCollectionSize: 1_000_000, hps, }); ``` -| Option | Default | Description | -| -------------------------- | ----------- | ------------------------------------------------------------------------------------- | -| `ref` | `false` | Enable reference tracking for shared or circular object graphs | -| `compatible` | `true` | Allow field additions/removals without breaking existing messages | -| `maxDepth` | `50` | Maximum nesting depth. Must be `>= 2`. 
Increase for deeply nested structures | -| `maxBinarySize` | 64 MiB | Maximum bytes accepted for any single binary field | -| `maxCollectionSize` | `1_000_000` | Maximum elements accepted in any list, set, or map | -| `useSliceString` | `false` | Optional string-reading optimization for Node.js. Leave at default unless benchmarked | -| `hps` | unset | Optional fast string helper from `@apache-fory/hps` (Node.js 20+) | -| `hooks.afterCodeGenerated` | unset | Callback to inspect the generated serializer code, useful for debugging | +| Option | Default | Description | +| -------------------------- | ------- | ------------------------------------------------------------------------------------- | +| `ref` | `false` | Enable reference tracking for shared or circular object graphs | +| `compatible` | `true` | Allow field additions/removals without breaking existing messages | +| `maxDepth` | `50` | Maximum nesting depth. Must be `>= 2`. Increase for deeply nested structures | +| `useSliceString` | `false` | Optional string-reading optimization for Node.js. Leave at default unless benchmarked | +| `hps` | unset | Optional fast string helper from `@apache-fory/hps` (Node.js 20+) | +| `hooks.afterCodeGenerated` | unset | Callback to inspect the generated serializer code, useful for debugging | ## Reference Tracking @@ -105,8 +101,7 @@ Leave this unset unless you run on Node.js 20+ and have benchmarked your workloa Security-related configuration: - Register only the expected schemas before deserializing untrusted payloads. -- Set `maxDepth`, `maxBinarySize`, and `maxCollectionSize` for the maximum payload shape your - service accepts. +- Set `maxDepth` for the maximum nesting depth your service accepts. - Prefer explicit `Type.struct(...)` schemas over `Type.any()` for untrusted input. - Pass `hps` only from the official package version you deploy with Fory. 
diff --git a/docs/guide/javascript/index.md b/docs/guide/javascript/index.md index 29615de78b..2423ef5690 100644 --- a/docs/guide/javascript/index.md +++ b/docs/guide/javascript/index.md @@ -31,7 +31,7 @@ Fory-supported languages. - **Fast**: serializer code is generated and cached the first time you register a schema, not on every call - **Reference-aware**: shared references and circular object graphs are supported when enabled - **Explicit schemas**: field types, nullability, and polymorphism are declared once with `Type.*` builders or TypeScript decorators -- **Safe defaults**: configurable depth, binary size, and collection size limits reject unexpectedly large or deep payloads +- **Safe defaults**: configurable depth checks reject unexpectedly deep payloads - **Modern types**: `bigint`, typed arrays, `Map`, `Set`, `Date`, `float16`, and `bfloat16` are supported ## Installation @@ -103,7 +103,7 @@ Create one `Fory` instance per application and reuse it — creating a new one f ## Configuration Fory JavaScript is xlang-only. `new Fory()` uses compatible schema evolution by default. Configure -reference tracking, size limits, and optional Node.js string acceleration through constructor +reference tracking, maximum read depth, and optional Node.js string acceleration through constructor options; see [Configuration](configuration.md). ## Documentation diff --git a/docs/guide/javascript/troubleshooting.md b/docs/guide/javascript/troubleshooting.md index e1ef6bd29f..bba32aeef7 100644 --- a/docs/guide/javascript/troubleshooting.md +++ b/docs/guide/javascript/troubleshooting.md @@ -38,22 +38,6 @@ new Fory({ maxDepth: 100 }); Increase this only if your data is legitimately deeply nested. -## `Binary size ... exceeds maxBinarySize` - -A binary field or the overall message exceeded the safety limit. If the size is expected and the source is trusted, increase the limit: - -```ts -new Fory({ maxBinarySize: 128 * 1024 * 1024 }); -``` - -## `Collection size ... 
exceeds maxCollectionSize` - -A list, set, or map has more elements than the configured limit. This often means the data is unexpectedly large. If it is legitimate, increase the limit: - -```ts -new Fory({ maxCollectionSize: 2_000_000 }); -``` - ## `Field "..." is not nullable` You are passing `null` to a field that was not declared nullable. Fix: add `.setNullable(true)` to the field schema: diff --git a/docs/guide/javascript/xlang-serialization.md b/docs/guide/javascript/xlang-serialization.md index d96c4df424..9442fe2647 100644 --- a/docs/guide/javascript/xlang-serialization.md +++ b/docs/guide/javascript/xlang-serialization.md @@ -146,7 +146,7 @@ Use the same type ID or type name in every peer. ## Safety Limits -The `maxDepth`, `maxBinarySize`, and `maxCollectionSize` options protect the JavaScript peer from overly large payloads. They do not change the binary format; they only control what the local `Fory` instance accepts. +The `maxDepth` option bounds nested payloads. It does not change the binary format; it only controls what the local `Fory` instance accepts. ## Related Topics diff --git a/docs/guide/python/type-registration.md b/docs/guide/python/type-registration.md index 302093369f..9e5589595b 100644 --- a/docs/guide/python/type-registration.md +++ b/docs/guide/python/type-registration.md @@ -81,5 +81,5 @@ same registration IDs or names on every peer that shares those payloads. 
## Related Topics - [Configuration](configuration.md) - Fory parameters -- [Configuration](configuration.md#security) - Strict mode, deserialization policies, and size limits +- [Configuration](configuration.md#security) - Strict mode, deserialization policies, and maximum read depth - [Custom Serializers](custom-serializers.md) - Custom serialization diff --git a/docs/guide/swift/configuration.md b/docs/guide/swift/configuration.md index 450503f804..85c1e33257 100644 --- a/docs/guide/swift/configuration.md +++ b/docs/guide/swift/configuration.md @@ -30,8 +30,6 @@ public struct Config { public let trackRef: Bool public let compatible: Bool public let checkClassVersion: Bool - public let maxCollectionSize: Int - public let maxBinarySize: Int public let maxDepth: Int } ``` @@ -88,11 +86,10 @@ let fory = Fory(compatible: false, checkClassVersion: true) ### Size and Depth Limits -`maxCollectionSize`, `maxBinarySize`, and `maxDepth` bound decoded payload size -and nesting depth. +`maxDepth` bounds decoded payload nesting depth. ```swift -let fory = Fory(maxCollectionSize: 1_000_000, maxBinarySize: 64 * 1024 * 1024, maxDepth: 5) +let fory = Fory(maxDepth: 5) ``` ## Recommended Presets @@ -123,5 +120,4 @@ Security-related configuration: - Register only the expected generated models before deserializing untrusted payloads. - Use `checkClassVersion` with `compatible: false` for intentional same-schema payloads. -- Set `maxCollectionSize`, `maxBinarySize`, and `maxDepth` for the largest payload shape your - service accepts. +- Set `maxDepth` for the largest nesting depth your service accepts. 
diff --git a/docs/security/_category_.json b/docs/security/_category_.json new file mode 100644 index 0000000000..eeee0bd5e0 --- /dev/null +++ b/docs/security/_category_.json @@ -0,0 +1,6 @@ +{ + "position": 5, + "label": "Security", + "collapsible": true, + "collapsed": true +} diff --git a/docs/security/deserialization.md b/docs/security/deserialization.md new file mode 100644 index 0000000000..390b1855b4 --- /dev/null +++ b/docs/security/deserialization.md @@ -0,0 +1,281 @@ +--- +title: Deserialization Security Model +sidebar_position: 2 +--- + +This document defines the security model for Apache Fory deserialization. It is +a public security reference for classifying deserialization behavior and +deciding where validation is required. It is not a vulnerability disclosure, +does not describe exploit techniques, and does not document implementation +history. + +The model is intentionally narrow. Fory should prevent resource and policy +failures caused by untrusted input, but it should not add hot-path validation +that only enforces byte-form strictness when doing so does not protect a Fory +security boundary. + +## Scope + +This model applies to deserializing Fory binary data from untrusted or +partially trusted sources. + +It does not treat the semantic content of a successfully deserialized value as a +Fory security boundary. A sender can always construct protocol-valid data whose +value is chosen by that sender. Application authorization, object-level business +rules, and domain-specific validation remain application responsibilities. + +This model also does not cover trusted in-memory formats. Row format and other +memory-format paths are trusted-data paths unless a runtime explicitly exposes +them as untrusted deserialization APIs. + +## Trust Boundaries + +Fory deserialization should treat the encoded input as untrusted at API +boundaries that accept external bytes or streams. 
+ +Fory security boundaries include: + +- Resource ownership, such as memory, CPU progress, stream buffering, file + handles, native allocations, callbacks, and retained read-side tables. +- Runtime safety, such as avoiding crashes, panics, undefined behavior, and + out-of-bounds reads or writes. +- Explicit Fory policy checks, such as type, function, method, class, or + registration policies that are intended to restrict what may be materialized. +- Cleanup boundaries, where state created during a failed read must be released + or reset before the next root operation. + +Fory security boundaries do not include: + +- The business meaning of a protocol-valid value. +- Which protocol-allowed byte form was used for a value. +- Whether a map, set, object, or metadata value uses one specific encoding + shape, unless rejecting other shapes is an explicit owner policy or protects + one of the boundaries above. + +## Security Invariants + +Deserialization code must prevent the following outcomes for untrusted input: + +- Crash, panic, undefined behavior, or out-of-bounds memory access. +- OOM or disproportionate allocation compared with bytes that are already + supplied or proven readable. +- No-progress loops, including loops where neither logical progress nor byte + progress is guaranteed after malformed input. +- Stream-buffer growth to an attacker-declared size before the corresponding + bytes have been read or skipped exactly. +- Resource leaks, including native allocations, handles, callbacks, or + registered cleanup work that cannot run. +- Retained attacker-controlled state after failure when that state can affect a + later root operation or grow across operations. +- Successful bypass of an explicit Fory policy boundary. + +When a path cannot produce one of these outcomes, earlier rejection of malformed +bytes is normally a correctness or interoperability choice, not a security +requirement. 
+ +## Non-Security Semantics + +The following patterns are not vulnerabilities by default: + +- Protocol-allowed collection chunking, map chunking, and field ordering. +- Duplicate keys, set elements, or compatible fields that collapse according to + the target data structure or owning serializer semantics. +- Malformed ref, null, or type flags that eventually produce a read error. +- Malformed scalar bytes that are consumed linearly and eventually produce a + read error. +- Reading an encoded body before later shape validation when the operation + ultimately returns an error and does not create a security-invariant failure. + +Fory may still reject malformed forms for specification strictness or +interoperability. That validation should be added only when it is required by +the protocol owner, is effectively free on the relevant path, or protects a +security invariant listed above. Do not add protocol-layer validation solely to +reject scalar byte forms whose only effect is extra decode cost. + +### Value-bearing ref flags + +Some read paths intentionally share handling for multiple value-bearing flags. +For example, when both `NotNullValue` and `RefValue` mean that an encoded value +follows, a reader may merge their hot-path handling. This is not a malformed +flag bug by itself. Treat it as a bug only if the merged handling loses required +reference semantics, returns success across an explicit owner policy, or creates +a resource or runtime-safety failure. + +## Allocation And Byte Availability + +Fory should not make large allocations from attacker-declared lengths before +the required bytes are available or have been read exactly. + +For buffer-backed input: + +- Fixed-size binary values and primitive dense arrays should call the byte + owner's readability check for the required encoded byte size before allocating + the destination. For buffer-backed input this is normally a remaining-byte + comparison. 
+- Multi-byte element arrays should compute the required byte size with overflow + checks before allocation. +- Container readers that allocate, reserve, or size-hint from a declared + logical element count should first call the byte owner's readability check for + that count. This is not a full container-body validation; it is the allocation + proof that the sender has supplied at least proportional input bytes before + the reader preallocates from the count. + +For stream-backed input: + +- Reading or skipping a large byte region is the proof that the bytes exist. +- Byte-counted variable-length result allocation should use the byte owner's + readability check before allocation. Skip paths may use bounded skip without + materializing the skipped value. +- A stream-backed buffer may hold the full requested encoded body after that + body has been read from the stream. It must not reserve the attacker-declared + length before input bytes prove that length exists. +- Stream-backed fill buffers should grow from the current proven buffer size, + such as by doubling current capacity, and cap only to the immediate target + when the next bounded growth step reaches it. A byte owner may use an + owner-local availability signal as a one-shot growth hint when the stream + implementation itself is caller-owned trusted code; if that hint is absent or + insufficient, the reader must fall back to bounded growth from already + buffered bytes. Serializers should not add their own availability branches. +- A truncated stream should fail before allocating the final deserialized value + and should allocate only for bytes actually read plus bounded spare capacity. + +The byte owner should stay byte-oriented. Buffer, reader, or read-context APIs +may expose byte read and byte skip operations, but string decoding, decimal +parsing, primitive-array encoding, compression modes, and collection capacity +policy belong to the owning serializers. 
+ +## Collection And Map Capacity + +Large valid collection inputs are allowed. If the input contains many encoded +elements, proportional deserialization is expected. + +The security requirement is to avoid disproportionate preallocation from a +declared logical count before enough input bytes justify that capacity. For a +non-empty container, a reader that will allocate or reserve from the declared +count should call `checkReadableBytes(logicalCount)` or the runtime equivalent +before that allocation. The check remains byte-owner-only: it does not decode +the whole container, validate element semantics, or replace chunk validation. +Readers that do not preallocate from the logical count may still grow +proportionally as elements are actually read. + +Map or collection chunk validation is security-relevant only when missing +validation can cause a no-progress loop, unbounded resource growth, retained +state, or success across a Fory policy boundary. Protocol-allowed chunk +segmentation is normal input and is not a security issue by itself. + +## Skip Semantics + +Skipping unknown or incompatible data is classified by concrete impact, not by +whether the runtime materializes a temporary value. + +Directly consuming encoded contents is useful when it is simple and owned by the +current runtime path. It is not a security requirement for complex fields such +as lists, sets, and maps. A runtime may materialize a value and discard it when +that preserves the existing serializer ownership model. + +For extension, dynamic, or user-owned types, the owning runtime may not always +have enough information to skip without invoking a registered serializer. In +that case, classify the behavior by concrete impact: + +- Resource leak, retained state, no-progress loop, or policy bypass is + security-relevant. +- Bounded materialization followed by an error or discard is allowed unless it + creates meaningful memory or CPU pressure. 
+- Pure strictness about whether a skipped value used one specific encoding shape + is not a security issue. + +## Metadata And Type Resolution + +Metadata parsing is security-sensitive when it affects retained read-side state, +type dispatch, or policy decisions. + +Metadata readers should: + +- Avoid unbounded recursion in nested metadata structures. +- Avoid unbounded table growth from attacker-controlled metadata streams. +- Validate metadata bodies before using them to bypass or replace existing + policy decisions. +- Reset or release metadata state at the correct root-operation boundary. + +Metadata byte-form strictness alone is not a security requirement. Rejecting a +metadata shape is useful only when the owner wants that strictness or when the +shape changes type identity, retained state, resource use, or policy behavior. + +## Reference Tracking + +Reference tracking is part of the wire protocol and is performance-sensitive. +Readers may use sentinel values and shared value-bearing branches to keep hot +paths compact. + +Reference tracking validation is security-relevant when malformed input can: + +- Access an out-of-range reference without reporting an error. +- Leave retained reference state after a failed root operation. +- Register unbounded callbacks or resolver state before the referenced value is + available. +- Cause a no-progress loop or crash. + +Reference tracking validation is not required merely because a malformed flag is +not rejected at the earliest possible byte. Lazy rejection is acceptable when +the root operation still returns an error and no security invariant is violated. + +## Error Propagation And Cleanup + +Fory runtimes may intentionally use lazy error propagation. After a read records +an error, later read steps may continue until the outer operation observes and +returns the error. + +This is acceptable when the continued work cannot: + +- Crash or panic. +- Allocate or retain attacker-controlled state. +- Leak resources. 
+- Bypass required cleanup. +- Return success across an explicit validation or policy boundary. + +Nested `try`/`finally` or equivalent cleanup should be added only when the +outer root-operation cleanup cannot cover the state or resource owned by the +nested path. + +## Performance Requirements + +Security validation must preserve Fory hot-path performance. Do not add +validation solely for strictness when it introduces: + +- Per-element object allocation. +- Dynamic dispatch or callbacks in hot loops. +- Wrapper objects or result carriers on success paths. +- Extra copying for buffer-backed string, binary, or primitive-array reads. +- Branches that do not protect a security invariant. + +Prefer owner-local checks that can be inlined and that already use information +available in the current serializer. Do not move serializer-owned semantics into +generic read-context helpers. + +## Classification Guide + +Use the following questions when reviewing deserialization behavior: + +1. Can this input crash, panic, or access memory out of bounds? +2. Can a small or unproven input length cause disproportionate allocation? +3. Can a stream-backed reader grow a buffer before exact read or skip proves the + bytes exist? +4. Can a loop continue without byte progress or logical progress? +5. Can the path retain attacker-controlled state after the root operation fails? +6. Can the path leak resources or skip required cleanup? +7. Can the path return success across an explicit Fory policy boundary? +8. Is the proposed validation effectively free in the relevant hot path? + +If the answer to the first seven questions is no, the issue is normally not a +security finding. If the validation is not effectively free, avoid adding it +unless the protocol owner explicitly requires it. + +## Documentation Boundaries + +Security model documents must not include exploit samples, CVE narratives, +line-level vulnerability candidates, branch history, migration timelines, or +cleanup plans. 
Keep those details in private reports, issues, or pull requests +as appropriate. + +Public security documentation should describe durable boundaries and invariants, +not the history of how the implementation reached them. diff --git a/docs/security/index.md b/docs/security/index.md new file mode 100644 index 0000000000..c75b583538 --- /dev/null +++ b/docs/security/index.md @@ -0,0 +1,16 @@ +--- +title: Security +sidebar_position: 1 +--- + +This directory documents Apache Fory security models and security invariants. +It is not a vulnerability disclosure area and does not list CVE details, +exploit samples, issue timelines, or implementation history. + +Security model documents describe how Fory should classify and prevent security +risks while preserving the performance characteristics expected from Fory +serialization runtimes. + +## Models + +- [Deserialization Security Model](deserialization.md) diff --git a/docs/specification/xlang_implementation_guide.md b/docs/specification/xlang_implementation_guide.md index 9ad0ec1189..da68fd5859 100644 --- a/docs/specification/xlang_implementation_guide.md +++ b/docs/specification/xlang_implementation_guide.md @@ -271,6 +271,142 @@ The current root read flow mirrors the write flow: Primitive and string-like hot paths should read directly from the buffer; complex payloads delegate to the resolved serializer. +### Stream And Buffer Byte Reads + +Implementations must keep byte availability in the byte owner layer while +keeping string, binary, primitive-array, compression, and collection semantics in +serializers. + +The required byte-owner primitive for allocation-before-read checks is a +readability check such as `checkReadableBytes(byteCount)`. Implementations do +not need additional generic read-context methods for this design. After the +readability check succeeds, serializers use their existing local buffer read, +copy, or decode paths. + +The readability check is a byte operation only. 
It must not decode strings, +primitive-array element counts, compression modes, or collection capacity +policy. + +For large byte-counted values, every implementation should call the byte-owner +readability check before allocating a variable-length result. This applies to +binary values, strings, decimal or metadata bodies, and primitive wire arrays +whose encoded body is measured in bytes. For multi-byte primitive wire arrays, +compare the encoded byte count, not only the logical element count, with the +readable bytes. + +1. Validate the encoded byte count in the serializer. For fixed-width primitive + arrays, check overflow and element alignment before allocation, such as + `wireByteCount % elementByteWidth == 0`, then derive the logical element + count from the encoded byte count. +2. Call `checkReadableBytes(wireByteCount)` unconditionally before allocating + the variable-length result. Buffer-backed inputs normally return from this + check with only a bounds comparison. Stream-backed inputs use the same call; + the byte owner handles the fast path when enough bytes are already buffered + and otherwise fills the read buffer until the requested encoded body is + readable or an input error is recorded. +3. After readability is proven, allocate the final value once and copy or decode + from the current readable buffer into the final result. + +`checkReadableBytes` is not an `ensureCapacity(wireByteCount)` operation. In +stream mode it may end with the byte owner holding the full encoded body in its +read buffer, but it must grow that buffer as bytes are successfully read from +the stream. It should grow from current proven buffer capacity, such as by +doubling current capacity, and cap only when that bounded growth step reaches +the immediate target. 
A byte owner may use an owner-local availability signal as +a one-shot growth hint when the stream implementation itself is caller-owned +trusted code; if that hint is absent or insufficient, it must fall back to +bounded growth from already buffered bytes. It must not reserve the +attacker-declared length before input bytes or an owner-local growth hint +justify that intermediate buffer capacity. The stream slow path may pay one +extra intermediate buffer copy; this is preferable to serializer-local chunk +accumulation and repeated final-container growth. + +For byte-counted values, the serializer should not duplicate the byte owner's +fast-path branch by testing `availableBytes()` before calling +`checkReadableBytes`. Keeping that branch in the byte owner gives every language +the same correctness rule and keeps serializer hot paths focused on their own +wire semantics. + +For primitive wire arrays: + +- Compare and prove the encoded wire byte count, not only the logical element + count. +- Keep compression, bit-packing, byte-order conversion, and other primitive + array encoding semantics in the serializer. `checkReadableBytes` only proves + that the encoded bytes are present. +- For compressed or transformed bodies, the serializer must still validate the + decoded length and encoding-specific metadata before allocating or returning + the final value. + +The common serializer shape is: + +```text +wireByteCount = readVarUint32() +elementWidth = primitiveWireElementWidth(kind) +validate wireByteCount and element alignment +elementCount = wireByteCount / elementWidth + +ctx.checkReadableBytes(wireByteCount) +result = allocatePrimitiveResult(elementCount) +copy or decode wireByteCount bytes from the current readable buffer into result +advance the reader index by wireByteCount +return result +``` + +Byte values are the `elementWidth == 1` specialization of the same policy. 
In +that case the serializer shape is: + +```text +byteCount = readVarUint32() + +ctx.checkReadableBytes(byteCount) +result = allocateBytes(byteCount) +copy byteCount bytes from the current readable buffer into result +advance the reader index by byteCount +return result +``` + +This policy avoids three inefficient implementation shapes: + +- allocating the complete final contiguous value before the encoded body is + readable +- growing or repeatedly copying the final result container on stream slow paths +- adding serializer-local chunk buffers when the byte owner can prove + readability once and expose a normal buffered read + +Scratch buffers remain appropriate when the target representation is not a +direct byte target, such as string transcoding, compression, byte-order +conversion that is not performed in place, bit-packed values, or runtimes whose +stream API cannot read into a caller-provided target. + +For fixed-width primitive arrays, the final result must not become visible to +callers until the exact encoded byte count has been read successfully. + +For list, set, map, and other container readers, the declared logical element +count is not an encoded byte count, so serializers must still own all element, +chunk, nullability, reference, and type-dispatch semantics. It is still the +right allocation proof for count-based preallocation: after validating a +non-empty count and reading any serializer-owned header or type metadata that +precedes allocation, call `checkReadableBytes(logicalCount)` before allocating, +reserving, or size-hinting from that count. The byte owner handles buffer versus +stream readiness; the container serializer then allocates with the declared +count and reads elements through its normal owner path. + +This check is not a full container-body validation. It only prevents a small or +truncated input from causing a large count-based preallocation. 
Chunk sizes, +duplicate keys, element value semantics, and protocol strictness remain owned by +the container/map serializer and should be validated only when they protect a +real owner invariant. + +For TypeDef or TypeMeta bodies, first prove that the encoded metadata body bytes +are readable through the byte owner. Field-list allocation should happen after +that body readability check and should not use a separate small initial-capacity +cap as a security rule. + +Skip paths do not need to materialize skipped values. Existing byte-skip +operations should consume any available buffered prefix first, then skip or drop +remaining stream bytes in bounded steps. + ### Nested reads use `ReadContext` Important rules: diff --git a/go/fory/array.go b/go/fory/array.go index 6cb1751f3f..f99f6ff39f 100644 --- a/go/fory/array.go +++ b/go/fory/array.go @@ -228,6 +228,13 @@ func (s *arrayConcreteValueSerializer) ReadData(ctx *ReadContext, value reflect. buf := ctx.Buffer() err := ctx.Err() length := int(buf.ReadVarUint32(err)) + if ctx.HasError() { + return + } + if length != value.Len() { + ctx.SetError(DeserializationErrorf("array length %d does not match serialized length %d", value.Len(), length)) + return + } var trackRefs bool if length > 0 { @@ -312,7 +319,7 @@ func (s arrayDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // Create a temp slice to read into, then copy back to array sliceType := reflect.SliceOf(value.Type().Elem()) tempSlice := reflect.MakeSlice(sliceType, value.Len(), value.Len()) - s.sliceSerializer.ReadData(ctx, tempSlice) + s.sliceSerializer.readData(ctx, tempSlice, value.Len()) if ctx.HasError() { return } @@ -365,16 +372,29 @@ func (s byteArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s byteArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() - length := ctx.ReadCollectionLength() + err := ctx.Err() + length := buf.ReadLength(err) if ctx.HasError() { return } + if length 
!= value.Len() { + ctx.SetError(DeserializationErrorf("array length %d does not match serialized binary length %d", value.Len(), length)) + return + } + if !buf.CheckReadable(length, err) { + return + } + if length == 0 { + return + } + if value.CanAddr() { + buf.Read(value.Slice(0, length).Bytes()) + return + } data := make([]byte, length) buf.Read(data) - if value.CanSet() { - for i := 0; i < length && i < value.Len(); i++ { - value.Index(i).SetUint(uint64(data[i])) - } + for i := 0; i < length && i < value.Len(); i++ { + value.Index(i).SetUint(uint64(data[i])) } } diff --git a/go/fory/array_primitive.go b/go/fory/array_primitive.go index 06c76dc782..c8df38a894 100644 --- a/go/fory/array_primitive.go +++ b/go/fory/array_primitive.go @@ -25,6 +25,25 @@ import ( "github.com/apache/fory/go/fory/float16" ) +func checkWireArraySize(ctx *ReadContext, size, elemSize int, value reflect.Value) (int, bool) { + if ctx.HasError() { + return 0, false + } + if size%elemSize != 0 { + ctx.SetError(DeserializationErrorf("array byte length %d is not aligned to element size %d", size, elemSize)) + return 0, false + } + length := size / elemSize + if length != value.Type().Len() { + ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + return 0, false + } + if !ctx.Buffer().CheckReadable(size, ctx.Err()) { + return 0, false + } + return length, true +} + // ============================================================================ // boolArraySerializer - optimized [N]bool serialization // ============================================================================ @@ -65,20 +84,14 @@ func (s boolArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s boolArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() - err := ctx.Err() length := ctx.ReadBinaryLength() - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length 
%d does not match type %v", length, value.Type())) + if _, ok := checkWireArraySize(ctx, length, 1, value); !ok { return } if length > 0 { // Direct memory copy - bool is 1 byte in Go ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(length, err) - copy(unsafe.Slice((*byte)(ptr), length), raw) + buf.Read(unsafe.Slice((*byte)(ptr), length)) } } @@ -130,20 +143,14 @@ func (s int8ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeType func (s int8ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() - err := ctx.Err() length := ctx.ReadBinaryLength() - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + if _, ok := checkWireArraySize(ctx, length, 1, value); !ok { return } if length > 0 { // Direct memory copy - int8 is 1 byte ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(length, err) - copy(unsafe.Slice((*byte)(ptr), length), raw) + buf.Read(unsafe.Slice((*byte)(ptr), length)) } } @@ -198,19 +205,14 @@ func (s int16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 2 - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + length, ok := checkWireArraySize(ctx, size, 2, value) + if !ok { return } if length > 0 { if isLittleEndian { ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(ptr), size), raw) + buf.Read(unsafe.Slice((*byte)(ptr), size)) } else { for i := 0; i < length; i++ { value.Index(i).SetInt(int64(buf.ReadInt16(err))) @@ -270,19 +272,14 @@ func (s int32ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 4 
- if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + length, ok := checkWireArraySize(ctx, size, 4, value) + if !ok { return } if length > 0 { if isLittleEndian { ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(ptr), size), raw) + buf.Read(unsafe.Slice((*byte)(ptr), size)) } else { for i := 0; i < length; i++ { value.Index(i).SetInt(int64(buf.ReadInt32(err))) @@ -342,19 +339,14 @@ func (s int64ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 8 - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + length, ok := checkWireArraySize(ctx, size, 8, value) + if !ok { return } if length > 0 { if isLittleEndian { ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(ptr), size), raw) + buf.Read(unsafe.Slice((*byte)(ptr), size)) } else { for i := 0; i < length; i++ { value.Index(i).SetInt(buf.ReadInt64(err)) @@ -414,19 +406,14 @@ func (s float32ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) buf := ctx.Buffer() err := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 4 - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + length, ok := checkWireArraySize(ctx, size, 4, value) + if !ok { return } if length > 0 { if isLittleEndian { ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(ptr), size), raw) + buf.Read(unsafe.Slice((*byte)(ptr), size)) } else { for i := 0; i < length; i++ { value.Index(i).SetFloat(float64(buf.ReadFloat32(err))) @@ -486,19 
+473,14 @@ func (s float64ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) buf := ctx.Buffer() err := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 8 - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + length, ok := checkWireArraySize(ctx, size, 8, value) + if !ok { return } if length > 0 { if isLittleEndian { ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(ptr), size), raw) + buf.Read(unsafe.Slice((*byte)(ptr), size)) } else { for i := 0; i < length; i++ { value.Index(i).SetFloat(buf.ReadFloat64(err)) @@ -555,20 +537,14 @@ func (s uint8ArraySerializer) Write(ctx *WriteContext, refMode RefMode, writeTyp func (s uint8ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() - err := ctx.Err() length := ctx.ReadBinaryLength() - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + if _, ok := checkWireArraySize(ctx, length, 1, value); !ok { return } if length > 0 { // Direct memory copy - uint8 is 1 byte ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(length, err) - copy(unsafe.Slice((*byte)(ptr), length), raw) + buf.Read(unsafe.Slice((*byte)(ptr), length)) } } @@ -624,19 +600,14 @@ func (s uint16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 2 - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + length, ok := checkWireArraySize(ctx, size, 2, value) + if !ok { return } if length > 0 { if isLittleEndian { ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(size, err) - 
copy(unsafe.Slice((*byte)(ptr), size), raw) + buf.Read(unsafe.Slice((*byte)(ptr), size)) } else { for i := 0; i < length; i++ { value.Index(i).SetUint(uint64(uint16(buf.ReadInt16(err)))) @@ -695,19 +666,14 @@ func (s uint32ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 4 - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + length, ok := checkWireArraySize(ctx, size, 4, value) + if !ok { return } if length > 0 { if isLittleEndian { ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(ptr), size), raw) + buf.Read(unsafe.Slice((*byte)(ptr), size)) } else { for i := 0; i < length; i++ { value.Index(i).SetUint(uint64(uint32(buf.ReadInt32(err)))) @@ -765,19 +731,14 @@ func (s uint64ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() err := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 8 - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + length, ok := checkWireArraySize(ctx, size, 8, value) + if !ok { return } if length > 0 { if isLittleEndian { ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(ptr), size), raw) + buf.Read(unsafe.Slice((*byte)(ptr), size)) } else { for i := 0; i < length; i++ { value.Index(i).SetUint(uint64(buf.ReadInt64(err))) @@ -839,20 +800,15 @@ func (s float16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) buf := ctx.Buffer() ctxErr := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 2 - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", 
length, value.Type())) + length, ok := checkWireArraySize(ctx, size, 2, value) + if !ok { return } if length > 0 { if isLittleEndian { ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(size, ctxErr) - copy(unsafe.Slice((*byte)(ptr), size), raw) + buf.Read(unsafe.Slice((*byte)(ptr), size)) } else { for i := 0; i < length; i++ { value.Index(i).Set(reflect.ValueOf(float16.Float16FromBits(buf.ReadUint16(ctxErr)))) @@ -913,20 +869,15 @@ func (s bfloat16ArraySerializer) ReadData(ctx *ReadContext, value reflect.Value) buf := ctx.Buffer() ctxErr := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 2 - if ctx.HasError() { - return - } - if length != value.Type().Len() { - ctx.SetError(DeserializationErrorf("array length %d does not match type %v", length, value.Type())) + length, ok := checkWireArraySize(ctx, size, 2, value) + if !ok { return } if length > 0 { if isLittleEndian { ptr := value.Addr().UnsafePointer() - raw := buf.ReadBinary(size, ctxErr) - copy(unsafe.Slice((*byte)(ptr), size), raw) + buf.Read(unsafe.Slice((*byte)(ptr), size)) } else { for i := 0; i < length; i++ { value.Index(i).Set(reflect.ValueOf(bfloat16.BFloat16FromBits(buf.ReadUint16(ctxErr)))) diff --git a/go/fory/array_test.go b/go/fory/array_test.go index 9c18975ade..37cf3c41a5 100644 --- a/go/fory/array_test.go +++ b/go/fory/array_test.go @@ -97,3 +97,23 @@ func TestArrayDynSerializerRoundTrip(t *testing.T) { require.Equal(t, arr[3], resultSlice[3]) }) } + +func TestArrayRejectsLengthMismatch(t *testing.T) { + f := NewFory(WithXlang(false), WithCompatible(false)) + + t.Run("concrete", func(t *testing.T) { + bytes, err := f.Marshal([3]string{"a", "b", "c"}) + require.NoError(t, err) + + var out [2]string + require.Error(t, f.Unmarshal(bytes, &out)) + }) + + t.Run("dynamic", func(t *testing.T) { + bytes, err := f.Marshal([3]any{"a", "b", "c"}) + require.NoError(t, err) + + var out [2]any + require.Error(t, f.Unmarshal(bytes, &out)) + }) +} diff --git a/go/fory/buffer.go 
b/go/fory/buffer.go index 2e0f13655e..89e29f938d 100644 --- a/go/fory/buffer.go +++ b/go/fory/buffer.go @@ -50,6 +50,12 @@ func NewByteBufferFromReader(r io.Reader, bufferSize int) *ByteBuffer { //go:noinline func (b *ByteBuffer) fill(n int, errOut *Error) bool { + if n < 0 { + if errOut != nil { + *errOut = DeserializationErrorf("negative readable byte count: %d", n) + } + return false + } if b.reader == nil { if errOut != nil { *errOut = BufferOutOfBoundError(b.readerIndex, n, len(b.data)) @@ -69,24 +75,38 @@ func (b *ByteBuffer) fill(n int, errOut *Error) bool { b.data = b.data[:b.writerIndex] } - if cap(b.data) < n { - newCap := cap(b.data) * 2 - if newCap < n { - newCap = n - } - if newCap < b.bufferSize { - newCap = b.bufferSize - } - newData := make([]byte, len(b.data), newCap) - copy(newData, b.data) - b.data = newData - } - for len(b.data) < n { - spare := b.data[len(b.data):cap(b.data)] - if len(spare) == 0 { - return false + if len(b.data) == cap(b.data) { + // n can come from attacker-controlled wire lengths. Do not query + // reader availability here: interface/type-specific probes add + // hot-path cost for a rare fast path and are not the correctness + // source. Grow only from bytes already buffered so truncated streams + // fail before reserving the declared body size. 
+ currentCap := cap(b.data) + newCap := currentCap * 2 + if currentCap > MaxInt/2 { + newCap = MaxInt + } + if newCap < b.bufferSize { + newCap = b.bufferSize + } + if newCap <= currentCap { + newCap = currentCap + 1 + } + if newCap > n { + newCap = n + } + if newCap <= currentCap { + if errOut != nil { + *errOut = DeserializationErrorf("stream buffer size exceeds supported range") + } + return false + } + newData := make([]byte, len(b.data), newCap) + copy(newData, b.data) + b.data = newData } + spare := b.data[len(b.data):cap(b.data)] readBytes, err := b.reader.Read(spare) if readBytes > 0 { b.data = b.data[:len(b.data)+readBytes] @@ -105,6 +125,35 @@ func (b *ByteBuffer) fill(n int, errOut *Error) bool { } return false } + if readBytes == 0 { + if errOut != nil { + *errOut = DeserializationError("stream read made no progress") + } + return false + } + } + return true +} + +func (b *ByteBuffer) discardFromReader(length int, errOut *Error) bool { + var scratch [8192]byte + for length > 0 { + n := length + if n > len(scratch) { + n = len(scratch) + } + readBytes, err := io.ReadFull(b.reader, scratch[:n]) + length -= readBytes + if err != nil { + if errOut != nil { + if err == io.EOF || err == io.ErrUnexpectedEOF { + *errOut = BufferOutOfBoundError(b.readerIndex, n, readBytes) + } else { + *errOut = DeserializationError(fmt.Sprintf("stream read error: %v", err)) + } + } + return false + } } return true } @@ -197,7 +246,15 @@ func (b *ByteBuffer) WriteLength(value int) { } func (b *ByteBuffer) ReadLength(err *Error) int { - return int(b.ReadVarUint32(err)) + length := b.ReadVarUint32(err) + const maxInt32 = uint32(1<<31 - 1) + if intSize == 32 && length > maxInt32 { + if err != nil { + *err = DeserializationErrorf("length %d exceeds supported int range", length) + } + return 0 + } + return int(length) } func (b *ByteBuffer) WriteUint64(value uint64) { @@ -371,7 +428,13 @@ func (b *ByteBuffer) Read(p []byte) (n int, err error) { // ReadBinary reads n bytes and sets 
error on bounds violation func (b *ByteBuffer) ReadBinary(length int, err *Error) []byte { - if b.readerIndex+length > len(b.data) { + if length < 0 { + if err != nil { + *err = DeserializationErrorf("negative byte count: %d", length) + } + return nil + } + if length > len(b.data)-b.readerIndex { if !b.fill(length, err) { return nil } @@ -1555,7 +1618,13 @@ func (b *ByteBuffer) IncreaseReaderIndex(n int) { // ReadBytes reads n bytes and sets error on bounds violation func (b *ByteBuffer) ReadBytes(n int, err *Error) []byte { - if b.readerIndex+n > len(b.data) { + if n < 0 { + if err != nil { + *err = DeserializationErrorf("negative byte count: %d", n) + } + return nil + } + if n > len(b.data)-b.readerIndex { if !b.fill(n, err) { return nil } @@ -1576,10 +1645,25 @@ func (b *ByteBuffer) ReadBytes(n int, err *Error) []byte { // Skip skips n bytes and sets error on bounds violation func (b *ByteBuffer) Skip(length int, err *Error) { - if b.readerIndex+length > len(b.data) { - if !b.fill(length, err) { + if length < 0 { + if err != nil { + *err = DeserializationErrorf("negative skip length: %d", length) + } + return + } + if length > len(b.data)-b.readerIndex { + if b.reader == nil { + if err != nil { + *err = BufferOutOfBoundError(b.readerIndex, length, len(b.data)) + } return } + available := len(b.data) - b.readerIndex + b.readerIndex = len(b.data) + if !b.discardFromReader(length-available, err) { + return + } + return } b.readerIndex += length } @@ -1587,7 +1671,13 @@ func (b *ByteBuffer) Skip(length int, err *Error) { // CheckReadable ensures that at least n bytes are available to read. // In stream mode, it will attempt to fill the buffer if necessary. 
func (b *ByteBuffer) CheckReadable(n int, err *Error) bool { - if b.readerIndex+n > len(b.data) { + if n < 0 { + if err != nil { + *err = DeserializationErrorf("negative readable byte count: %d", n) + } + return false + } + if n > len(b.data)-b.readerIndex { return b.fill(n, err) } return true diff --git a/go/fory/buffer_test.go b/go/fory/buffer_test.go index b4f2022389..78ac38f4f4 100644 --- a/go/fory/buffer_test.go +++ b/go/fory/buffer_test.go @@ -125,6 +125,48 @@ func TestReadVarUint32RejectsOverflowFifthByte(t *testing.T) { } } +func TestStreamSkipDoesNotGrowToSkippedLength(t *testing.T) { + err := &Error{} + data := make([]byte, 1<<20) + reader := bytes.NewReader(data) + buf := NewByteBufferFromReader(reader, 16) + + buf.Skip(len(data), err) + + require.True(t, err.Ok()) + require.Zero(t, reader.Len()) + require.Equal(t, len(buf.data), buf.readerIndex) + require.LessOrEqual(t, cap(buf.data), 16) +} + +func TestStreamFillDoubleGrowsFromBufferedBytes(t *testing.T) { + var err Error + data := make([]byte, 100) + buf := NewByteBufferFromReader(bytes.NewReader(data), 4) + require.True(t, buf.fill(len(data), &err)) + require.True(t, err.Ok()) + require.Equal(t, len(data), cap(buf.data)) + + err = Error{} + truncated := make([]byte, 17) + buf = NewByteBufferFromReader(bytes.NewReader(truncated), 4) + require.False(t, buf.fill(100, &err)) + require.True(t, err.HasError()) + require.Less(t, cap(buf.data), 100) + require.LessOrEqual(t, cap(buf.data), 32) +} + +func TestReadCollectionLengthDoesNotTreatElementsAsBytes(t *testing.T) { + writer := NewByteBuffer(nil) + writer.WriteLength(1024) + + ctx := NewReadContext(false) + ctx.SetData(writer.Bytes()) + + require.Equal(t, 1024, ctx.ReadCollectionLength()) + require.False(t, ctx.HasError()) +} + func TestReadVarUint32Small7RejectsOverflowFifthByte(t *testing.T) { buf := NewByteBuffer([]byte{0x80, 0x80, 0x80, 0x80, 0x10}) var err Error diff --git a/go/fory/codegen/decoder.go b/go/fory/codegen/decoder.go index 
859f492d68..0e57021343 100644 --- a/go/fory/codegen/decoder.go +++ b/go/fory/codegen/decoder.go @@ -169,11 +169,20 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\tif isXlang {\n") fmt.Fprintf(buf, "\t\t\t// xlang mode: slices are not nullable, read directly without null flag\n") fmt.Fprintf(buf, "\t\t\tsliceLen := ctx.ReadCollectionLength()\n") + fmt.Fprintf(buf, "\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t}\n") fmt.Fprintf(buf, "\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") fmt.Fprintf(buf, "\t\t\t\t// ReadData collection flags (ignore for now)\n") fmt.Fprintf(buf, "\t\t\t\t_ = buf.ReadInt8(err)\n") + fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\tif !buf.CheckReadable(sliceLen, err) {\n") + fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\t// Create slice with proper capacity\n") fmt.Fprintf(buf, "\t\t\t\t%s = make([]any, sliceLen)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t\t// ReadData each element using ReadValue\n") @@ -188,11 +197,20 @@ func generateFieldReadTyped(buf *bytes.Buffer, field *FieldInfo) error { fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") fmt.Fprintf(buf, "\t\t\t\tsliceLen := ctx.ReadCollectionLength()\n") + fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make([]any, 0)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t\t} else {\n") fmt.Fprintf(buf, "\t\t\t\t\t// ReadData collection flags (ignore for now)\n") fmt.Fprintf(buf, "\t\t\t\t\t_ = 
buf.ReadInt8(err)\n") + fmt.Fprintf(buf, "\t\t\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t\t}\n") + fmt.Fprintf(buf, "\t\t\t\t\tif !buf.CheckReadable(sliceLen, err) {\n") + fmt.Fprintf(buf, "\t\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\t\t// Create slice with proper capacity\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make([]any, sliceLen)\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t\t\t// ReadData each element using ReadValue\n") @@ -371,48 +389,6 @@ func generateOptionValueRead(buf *bytes.Buffer, elemType types.Type, valueExpr s return nil } -// Note: generateSliceRead is no longer used since we use WriteReferencable/ReadValue for slice fields -// generateSliceRead generates code to deserialize a slice according to the list format -func generateSliceRead(buf *bytes.Buffer, sliceType *types.Slice, fieldAccess string) error { - elemType := sliceType.Elem() - - // Use block scope to avoid variable redeclaration across multiple slice fields - fmt.Fprintf(buf, "\t// ReadData slice %s\n", fieldAccess) - fmt.Fprintf(buf, "\t{\n") - fmt.Fprintf(buf, "\t\tsliceLen := int(buf.ReadVarUint32())\n") - fmt.Fprintf(buf, "\t\tif sliceLen == 0 {\n") - fmt.Fprintf(buf, "\t\t\t// Empty slice - matching reflection behavior where nil and empty are treated the same\n") - fmt.Fprintf(buf, "\t\t\t%s = nil\n", fieldAccess) - fmt.Fprintf(buf, "\t\t} else {\n") - - // ReadData collection flags for non-empty slice - fmt.Fprintf(buf, "\t\t\t// ReadData collection flags\n") - fmt.Fprintf(buf, "\t\t\tcollectFlag := buf.ReadInt8()\n") - fmt.Fprintf(buf, "\t\t\t// Check if CollectionIsDeclElementType flag is NOT set (meaning we need to read type ID)\n") - fmt.Fprintf(buf, "\t\t\tif (collectFlag & 4) == 0 {\n") - fmt.Fprintf(buf, "\t\t\t\t// ReadData element type ID (not declared, so we need to read it)\n") - fmt.Fprintf(buf, "\t\t\t\t_ = buf.ReadVarUint32()\n") - fmt.Fprintf(buf, 
"\t\t\t}\n") - - // Create slice - fmt.Fprintf(buf, "\t\t\t%s = make(%s, sliceLen)\n", fieldAccess, sliceType.String()) - - // ReadData elements - for declared type slices, use direct element reading without flags - fmt.Fprintf(buf, "\t\t\tfor i := 0; i < sliceLen; i++ {\n") - - // Generate element read code - for typed slices, read directly via serializer - elemAccess := fmt.Sprintf("%s[i]", fieldAccess) - if err := generateSliceElementReadDirect(buf, elemType, elemAccess); err != nil { - return err - } - - fmt.Fprintf(buf, "\t\t\t}\n") - fmt.Fprintf(buf, "\t\t}\n") - fmt.Fprintf(buf, "\t}\n") - - return nil -} - // generateSliceElementRead generates code to read a single slice element func generateSliceElementRead(buf *bytes.Buffer, elemType types.Type, elemAccess string) error { // Handle basic types @@ -540,6 +516,9 @@ func generateSliceReadInline(buf *bytes.Buffer, sliceType *types.Slice, fieldAcc fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") fmt.Fprintf(buf, "\t\t\t\tsliceLen := ctx.ReadCollectionLength()\n") + fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif sliceLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s, 0)\n", fieldAccess, sliceType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -563,10 +542,19 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi if iface, ok := unwrappedElem.(*types.Interface); ok && iface.Empty() { fmt.Fprintf(buf, "%s// Dynamic slice []any handling - no null flag\n", indent) fmt.Fprintf(buf, "%ssliceLen := ctx.ReadCollectionLength()\n", indent) + fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make([]any, 0)\n", indent, fieldAccess) fmt.Fprintf(buf, 
"%s} else {\n", indent) fmt.Fprintf(buf, "%s\t_ = buf.ReadInt8(err) // collection flags\n", indent) + fmt.Fprintf(buf, "%s\tif ctx.HasError() {\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s\tif !buf.CheckReadable(sliceLen, err) {\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) fmt.Fprintf(buf, "%s\t%s = make([]any, sliceLen)\n", indent, fieldAccess) fmt.Fprintf(buf, "%s\tfor i := range %s {\n", indent, fieldAccess) fmt.Fprintf(buf, "%s\t\tctx.ReadValue(reflect.ValueOf(&%s[i]).Elem(), fory.RefModeTracking, true)\n", indent, fieldAccess) @@ -577,6 +565,9 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi elemIsReferencable := isReferencableType(elemType) fmt.Fprintf(buf, "%ssliceLen := ctx.ReadCollectionLength()\n", indent) + fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif sliceLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s, 0)\n", indent, fieldAccess, sliceType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -591,6 +582,9 @@ func generateSliceReadInlineNoNull(buf *bytes.Buffer, sliceType *types.Slice, fi func writeSliceReadElements(buf *bytes.Buffer, sliceType *types.Slice, elemType types.Type, fieldAccess string, elemIsReferencable bool, indent string) error { // ReadData collection header fmt.Fprintf(buf, "%scollectFlag := buf.ReadInt8(err)\n", indent) + fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%s// Check if CollectionIsDeclElementType is set (bit 2, value 4)\n", indent) fmt.Fprintf(buf, "%shasDeclType := (collectFlag & 4) != 0\n", indent) if elemIsReferencable { @@ -599,6 +593,9 @@ func writeSliceReadElements(buf 
*bytes.Buffer, sliceType *types.Slice, elemType } // Create slice + fmt.Fprintf(buf, "%sif !buf.CheckReadable(sliceLen, err) {\n", indent) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%s%s = make(%s, sliceLen)\n", indent, fieldAccess, sliceType.String()) // ReadData elements based on whether CollectionIsDeclElementType is set @@ -849,6 +846,9 @@ func generateMapReadInline(buf *bytes.Buffer, mapType *types.Map, fieldAccess st fmt.Fprintf(buf, "\t\t\t\t%s = nil\n", fieldAccess) fmt.Fprintf(buf, "\t\t\t} else {\n") fmt.Fprintf(buf, "\t\t\t\tmapLen := ctx.ReadCollectionLength()\n") + fmt.Fprintf(buf, "\t\t\t\tif ctx.HasError() {\n") + fmt.Fprintf(buf, "\t\t\t\t\treturn ctx.TakeError()\n") + fmt.Fprintf(buf, "\t\t\t\t}\n") fmt.Fprintf(buf, "\t\t\t\tif mapLen == 0 {\n") fmt.Fprintf(buf, "\t\t\t\t\t%s = make(%s)\n", fieldAccess, mapType.String()) fmt.Fprintf(buf, "\t\t\t\t} else {\n") @@ -881,6 +881,9 @@ func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAcc indent := "\t\t\t" fmt.Fprintf(buf, "%smapLen := ctx.ReadCollectionLength()\n", indent) + fmt.Fprintf(buf, "%sif ctx.HasError() {\n", indent) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) fmt.Fprintf(buf, "%sif mapLen == 0 {\n", indent) fmt.Fprintf(buf, "%s\t%s = make(%s)\n", indent, fieldAccess, mapType.String()) fmt.Fprintf(buf, "%s} else {\n", indent) @@ -893,6 +896,9 @@ func generateMapReadInlineNoNull(buf *bytes.Buffer, mapType *types.Map, fieldAcc // writeMapReadChunks generates the map chunk reading code with specified indentation func writeMapReadChunks(buf *bytes.Buffer, mapType *types.Map, fieldAccess string, keyType, valueType types.Type, keyIsInterface, valueIsInterface bool, indent string) error { + fmt.Fprintf(buf, "%sif !buf.CheckReadable(mapLen, err) {\n", indent) + fmt.Fprintf(buf, "%s\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s}\n", indent) 
fmt.Fprintf(buf, "%s%s = make(%s, mapLen)\n", indent, fieldAccess, mapType.String()) fmt.Fprintf(buf, "%smapSize := mapLen\n", indent) @@ -901,6 +907,13 @@ func writeMapReadChunks(buf *bytes.Buffer, mapType *types.Map, fieldAccess strin fmt.Fprintf(buf, "%s\t// ReadData KV header\n", indent) fmt.Fprintf(buf, "%s\tkvHeader := buf.ReadByte(err)\n", indent) fmt.Fprintf(buf, "%s\tchunkSize := int(buf.ReadByte(err))\n", indent) + fmt.Fprintf(buf, "%s\tif ctx.HasError() {\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) + fmt.Fprintf(buf, "%s\tif chunkSize == 0 || chunkSize > mapSize {\n", indent) + fmt.Fprintf(buf, "%s\t\tctx.SetError(fory.DeserializationErrorf(\"invalid map chunk size %%d for remaining length %%d\", chunkSize, mapSize))\n", indent) + fmt.Fprintf(buf, "%s\t\treturn ctx.TakeError()\n", indent) + fmt.Fprintf(buf, "%s\t}\n", indent) // Parse header flags fmt.Fprintf(buf, "%s\ttrackKeyRef := (kvHeader & 0x1) != 0\n", indent) diff --git a/go/fory/compatible_scalar.go b/go/fory/compatible_scalar.go index 542899ac65..a34092f7b5 100644 --- a/go/fory/compatible_scalar.go +++ b/go/fory/compatible_scalar.go @@ -326,7 +326,7 @@ func readCompatibleScalarValue(ctx *ReadContext, typeID TypeId) compatibleScalar case 1: return compatibleScalarValue{typeID: typeID, boolVal: true} default: - compatibleScalarFail(ctx, "", typeID, UNKNOWN, "bool payload is not 0 or 1") + compatibleScalarFail(ctx, "", typeID, UNKNOWN, "bool byte is not 0 or 1") return compatibleScalarValue{} } case STRING: diff --git a/go/fory/compatible_scalar_test.go b/go/fory/compatible_scalar_test.go index 4a043b1099..0bd31559d3 100644 --- a/go/fory/compatible_scalar_test.go +++ b/go/fory/compatible_scalar_test.go @@ -391,13 +391,13 @@ func TestCompatibleScalarConversions(t *testing.T) { } } -func TestCompatibleScalarRejectsInvalidBoolPayload(t *testing.T) { +func TestRejectsInvalidBoolByte(t *testing.T) { f := 
NewForyWithOptions(WithXlang(true), WithCompatible(true)) f.readCtx.SetData([]byte{2}) _ = readCompatibleScalarValue(f.readCtx, BOOL) err := f.readCtx.CheckError() require.Error(t, err) - assert.Contains(t, err.Error(), "bool payload is not 0 or 1") + assert.Contains(t, err.Error(), "bool byte is not 0 or 1") } func TestCompatibleScalarRejectsRefValueFlag(t *testing.T) { @@ -439,11 +439,11 @@ func TestCompatibleScalarSameTypeNullableStrictRead(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "invalid compatible scalar null flag") - badPayload := append([]byte(nil), data...) - badPayload[len(badPayload)-1] = 2 - err = reader.Unmarshal(badPayload, &out) + badBoolByte := append([]byte(nil), data...) + badBoolByte[len(badBoolByte)-1] = 2 + err = reader.Unmarshal(badBoolByte, &out) require.Error(t, err) - assert.Contains(t, err.Error(), "bool payload") + assert.Contains(t, err.Error(), "bool byte") } func TestCompatibleScalarTrackingRefMismatch(t *testing.T) { diff --git a/go/fory/decimal.go b/go/fory/decimal.go index 5bc3a0d540..83ea8fa8b9 100644 --- a/go/fory/decimal.go +++ b/go/fory/decimal.go @@ -111,14 +111,14 @@ func writeDecimalParts(buffer *ByteBuffer, scale int32, unscaled *big.Int) { } abs := new(big.Int).Abs(unscaled) - payload := abs.Bytes() - if len(payload) == 0 { + magnitudeBytes := abs.Bytes() + if len(magnitudeBytes) == 0 { panic("decimal zero must use the small encoding") } - reverseBytes(payload) - meta := (uint64(len(payload)) << 1) | uint64(signBit(unscaled.Sign())) + reverseBytes(magnitudeBytes) + meta := (uint64(len(magnitudeBytes)) << 1) | uint64(signBit(unscaled.Sign())) buffer.WriteVarUint64((meta << 1) | 1) - buffer.WriteBinary(payload) + buffer.WriteBinary(magnitudeBytes) } func readDecimalParts(ctx *ReadContext) (int32, *big.Int) { @@ -138,19 +138,19 @@ func readDecimalParts(ctx *ReadContext) (int32, *big.Int) { ctx.SetError(DeserializationErrorf("invalid decimal magnitude length %d", length)) return 0, nil } - if length 
> uint64(ctx.maxBinarySize) { - ctx.SetError(MaxBinarySizeExceededError(int(length), ctx.maxBinarySize)) + if length > uint64(MaxInt) { + ctx.SetError(DeserializationErrorf("invalid decimal magnitude length %d", length)) return 0, nil } - payload := ctx.buffer.ReadBytes(int(length), err) + magnitudeBytes := ctx.buffer.ReadBytes(int(length), err) if ctx.HasError() { return 0, nil } - if payload[len(payload)-1] == 0 { - ctx.SetError(DeserializationError("non-canonical decimal payload: trailing zero byte")) + if magnitudeBytes[len(magnitudeBytes)-1] == 0 { + ctx.SetError(DeserializationError("non-canonical decimal magnitude bytes: trailing zero byte")) return 0, nil } - bigEndian := append([]byte(nil), payload...) + bigEndian := append([]byte(nil), magnitudeBytes...) reverseBytes(bigEndian) magnitude := new(big.Int).SetBytes(bigEndian) if magnitude.Sign() == 0 { diff --git a/go/fory/decimal_test.go b/go/fory/decimal_test.go index 12fd6d090b..6e377e140a 100644 --- a/go/fory/decimal_test.go +++ b/go/fory/decimal_test.go @@ -148,10 +148,9 @@ func TestDecimalOOM(t *testing.T) { data := buffer.Bytes() - f := New(WithXlang(true), WithCompatible(false), WithMaxBinarySize(1024*1024)) + f := New(WithXlang(true), WithCompatible(false)) var decoded Decimal err := f.DeserializeFromReader(bytes.NewReader(data), &decoded) require.Error(t, err) - require.Contains(t, err.Error(), "max binary size exceeded") } diff --git a/go/fory/errors.go b/go/fory/errors.go index 6dc092bf2d..25f7dd087a 100644 --- a/go/fory/errors.go +++ b/go/fory/errors.go @@ -52,10 +52,6 @@ const ( ErrKindInvalidTag // ErrKindInvalidUTF16String indicates malformed UTF-16 string data ErrKindInvalidUTF16String - // ErrKindMaxCollectionSizeExceeded indicates max collection size exceeded - ErrKindMaxCollectionSizeExceeded - // ErrKindMaxBinarySizeExceeded indicates max binary size exceeded - ErrKindMaxBinarySizeExceeded ) // Error is a lightweight error type optimized for hot path performance. 
@@ -300,26 +296,6 @@ func InvalidUTF16StringError(byteCount int) Error { }) } -// MaxCollectionSizeExceededError creates a max collection size exceeded error -// -//go:noinline -func MaxCollectionSizeExceededError(size, limit int) Error { - return panicIfEnabled(Error{ - kind: ErrKindMaxCollectionSizeExceeded, - message: fmt.Sprintf("max collection size exceeded: size=%d, limit=%d", size, limit), - }) -} - -// MaxBinarySizeExceededError creates a max binary size exceeded error -// -//go:noinline -func MaxBinarySizeExceededError(size, limit int) Error { - return panicIfEnabled(Error{ - kind: ErrKindMaxBinarySizeExceeded, - message: fmt.Sprintf("max binary size exceeded: size=%d, limit=%d", size, limit), - }) -} - // WrapError wraps a standard error into a fory Error // //go:noinline diff --git a/go/fory/fory.go b/go/fory/fory.go index 0ae33d4605..964d72c85d 100644 --- a/go/fory/fory.go +++ b/go/fory/fory.go @@ -65,24 +65,20 @@ const ( // Config holds configuration options for Fory instances type Config struct { - TrackRef bool - MaxDepth int - IsXlang bool - Compatible bool // Schema evolution compatibility mode - MaxCollectionSize int - MaxBinarySize int - MaxTypeFields int + TrackRef bool + MaxDepth int + IsXlang bool + Compatible bool // Schema evolution compatibility mode + MaxTypeFields int } // defaultConfig returns the default configuration func defaultConfig() Config { return Config{ - TrackRef: false, // Match Java's default: reference tracking disabled - MaxDepth: 20, - IsXlang: true, - MaxCollectionSize: 1_000_000, - MaxBinarySize: 64 * 1024 * 1024, - MaxTypeFields: 10000, + TrackRef: false, // Match Java's default: reference tracking disabled + MaxDepth: 20, + IsXlang: true, + MaxTypeFields: 10000, } } @@ -123,20 +119,6 @@ func WithCompatible(enabled bool) Option { } } -// WithMaxCollectionSize sets the maximum collection size limit -func WithMaxCollectionSize(size int) Option { - return func(f *Fory) { - f.config.MaxCollectionSize = size - } -} - -// 
WithMaxBinarySize sets the maximum binary size limit -func WithMaxBinarySize(size int) Option { - return func(f *Fory) { - f.config.MaxBinarySize = size - } -} - // WithMaxTypeFields sets the maximum field count limit for schema definition deserialization func WithMaxTypeFields(size int) Option { return func(f *Fory) { @@ -197,8 +179,6 @@ func New(opts ...Option) *Fory { f.writeCtx.xlang = f.config.IsXlang f.readCtx = NewReadContext(f.config.TrackRef) - f.readCtx.maxCollectionSize = f.config.MaxCollectionSize - f.readCtx.maxBinarySize = f.config.MaxBinarySize f.readCtx.typeResolver = f.typeResolver f.readCtx.refResolver = f.refResolver f.readCtx.compatible = f.config.Compatible diff --git a/go/fory/fory_test.go b/go/fory/fory_test.go index cef56ab6ed..2888cdfe5c 100644 --- a/go/fory/fory_test.go +++ b/go/fory/fory_test.go @@ -18,6 +18,7 @@ package fory import ( + "bytes" "fmt" "reflect" "testing" @@ -512,6 +513,33 @@ func TestSerializeCommonReference(t *testing.T) { } } +func TestReadBufferObjectRejectsTruncatedInBandData(t *testing.T) { + buf := NewByteBuffer(nil) + buf.WriteBool(true) + buf.WriteVarUint32(4) + buf.WriteBinary([]byte{1, 2}) + + ctx := NewReadContext(false) + ctx.SetData(buf.Bytes()) + + require.Nil(t, ctx.ReadBufferObject()) + require.Error(t, ctx.CheckError()) +} + +func TestReadBufferObjectCopiesInBandDataFromStream(t *testing.T) { + data := []byte{1, 3, 10, 11, 12, 20, 21, 22, 23, 24, 25} + ctx := NewReadContext(false) + ctx.buffer.ResetWithReader(bytes.NewReader(data), 2) + + buf := ctx.ReadBufferObject() + require.NoError(t, ctx.CheckError()) + require.Equal(t, []byte{10, 11, 12}, buf.GetData()) + + ctx.buffer.Skip(6, ctx.Err()) + require.NoError(t, ctx.CheckError()) + require.Equal(t, []byte{10, 11, 12}, buf.GetData()) +} + // TestSerializeZeroCopy is temporarily disabled during API refactoring // TODO: Re-enable when zero-copy serialization API is updated /* diff --git a/go/fory/limit_test.go b/go/fory/limit_test.go deleted file mode 100644 
index a0333cb12c..0000000000 --- a/go/fory/limit_test.go +++ /dev/null @@ -1,103 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package fory - -import ( - "github.com/stretchr/testify/require" - "testing" -) - -func TestMaxCollectionSizeGuardrail(t *testing.T) { - // 1. Test slice exceeding limit - t.Run("Slice exceeds MaxCollectionSize", func(t *testing.T) { - config := WithMaxCollectionSize(2) - f := NewFory(WithXlang(false), config, WithCompatible(false)) - - slice := []string{"a", "b", "c"} - fBase := NewFory(WithXlang(false), WithCompatible(false)) - bytes, _ := fBase.Serialize(slice) - - var decoded []string - err := f.Deserialize(bytes, &decoded) - require.Error(t, err) - require.Contains(t, err.Error(), "max collection size exceeded: size=3, limit=2") - }) - - // 2. 
Test map exceeding limit - t.Run("Map exceeds MaxCollectionSize", func(t *testing.T) { - config := WithMaxCollectionSize(2) - f := NewFory(WithXlang(false), config, WithCompatible(false)) - - m := map[int32]int32{1: 1, 2: 2, 3: 3} - fBase := NewFory(WithXlang(false), WithCompatible(false)) - bytes, _ := fBase.Serialize(m) - - var decoded map[int32]int32 - err := f.Deserialize(bytes, &decoded) - require.Error(t, err) - require.Contains(t, err.Error(), "max collection size exceeded: size=3, limit=2") - }) - - // 3. Test string is not affected by MaxCollectionSize - t.Run("String unaffected by MaxCollectionSize", func(t *testing.T) { - config := WithMaxCollectionSize(2) - f := NewFory(WithXlang(false), config, WithCompatible(false)) - - str := "hello world" // length 11 - bytes, err := f.Serialize(str) - require.NoError(t, err) - - var decoded string - err = f.Deserialize(bytes, &decoded) - require.NoError(t, err) - require.Equal(t, str, decoded) - }) -} - -func TestMaxBinarySizeGuardrail(t *testing.T) { - // 1. Test binary (byte slice) exceeding limit - t.Run("Byte slice exceeds MaxBinarySize", func(t *testing.T) { - config := WithMaxBinarySize(5) - f := NewFory(WithXlang(false), config, WithCompatible(false)) - - // We can serialize a byte slice using standard serializer, then decode with the f instance - slice := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - fBase := NewFory(WithXlang(false), WithCompatible(false)) - bytes, _ := fBase.Serialize(slice) - - var decoded []byte - err := f.Deserialize(bytes, &decoded) - require.Error(t, err) - require.Contains(t, err.Error(), "max binary size exceeded: size=10, limit=5") - }) - - // 2. 
Test string is not affected by MaxBinarySize - t.Run("String unaffected by MaxBinarySize", func(t *testing.T) { - config := WithMaxBinarySize(2) - f := NewFory(WithXlang(false), config, WithCompatible(false)) - - str := "hello world" // length 11 - bytes, err := f.Serialize(str) - require.NoError(t, err) - - var decoded string - err = f.Deserialize(bytes, &decoded) - require.NoError(t, err) - require.Equal(t, str, decoded) - }) -} diff --git a/go/fory/map.go b/go/fory/map.go index a59fd8e652..8b1d82cc95 100644 --- a/go/fory/map.go +++ b/go/fory/map.go @@ -293,20 +293,21 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { typeResolver := ctx.TypeResolver() type_ := value.Type() - // Initialize map - if value.IsNil() { - mapType := type_ - // For any maps without declared types, use map[any]any - if !s.hasGenerics && type_.Key().Kind() == reflect.Interface && type_.Elem().Kind() == reflect.Interface { - iface := reflect.TypeOf((*any)(nil)).Elem() - mapType = reflect.MapOf(iface, iface) - } - value.Set(reflect.MakeMap(mapType)) - } - refResolver.Reference(value) - size := ctx.ReadCollectionLength() - if size == 0 || ctx.HasError() { + if ctx.HasError() { + return + } + mapType := type_ + // For any maps without declared types, use map[any]any. 
+ if !s.hasGenerics && type_.Key().Kind() == reflect.Interface && type_.Elem().Kind() == reflect.Interface { + iface := reflect.TypeOf((*any)(nil)).Elem() + mapType = reflect.MapOf(iface, iface) + } + if size == 0 { + if value.IsNil() { + value.Set(reflect.MakeMap(mapType)) + } + refResolver.Reference(value) return } @@ -314,6 +315,13 @@ func (s mapSerializer) ReadData(ctx *ReadContext, value reflect.Value) { if ctx.HasError() { return } + if !buf.CheckReadable(size, ctxErr) { + return + } + if value.IsNil() { + value.Set(reflect.MakeMapWithSize(mapType, size)) + } + refResolver.Reference(value) keyType := type_.Key() valueType := type_.Elem() diff --git a/go/fory/map_primitive.go b/go/fory/map_primitive.go index 287777eaea..d520e97bd5 100644 --- a/go/fory/map_primitive.go +++ b/go/fory/map_primitive.go @@ -68,15 +68,27 @@ func writeMapStringString(buf *ByteBuffer, m map[string]string, hasGenerics bool } } +func readTypedMapSize(ctx *ReadContext) (int, bool) { + size := ctx.ReadCollectionLength() + if size == 0 || ctx.HasError() { + return size, false + } + if !ctx.Buffer().CheckReadable(size, ctx.Err()) { + return 0, false + } + return size, true +} + // readMapStringString reads map[string]string using chunk protocol func readMapStringString(ctx *ReadContext) map[string]string { err := ctx.Err() buf := ctx.Buffer() - size := ctx.ReadCollectionLength() - result := make(map[string]string, size) - if size == 0 { + size, ok := readTypedMapSize(ctx) + result := make(map[string]string) + if !ok { return result } + result = make(map[string]string, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -159,11 +171,12 @@ func writeMapStringInt64(buf *ByteBuffer, m map[string]int64, hasGenerics bool) func readMapStringInt64(ctx *ReadContext) map[string]int64 { err := ctx.Err() buf := ctx.Buffer() - size := ctx.ReadCollectionLength() - result := make(map[string]int64, size) - if size == 0 { + size, ok := readTypedMapSize(ctx) + result := make(map[string]int64) + if !ok 
{ return result } + result = make(map[string]int64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -243,11 +256,12 @@ func writeMapStringInt32(buf *ByteBuffer, m map[string]int32, hasGenerics bool) func readMapStringInt32(ctx *ReadContext) map[string]int32 { err := ctx.Err() buf := ctx.Buffer() - size := ctx.ReadCollectionLength() - result := make(map[string]int32, size) - if size == 0 { + size, ok := readTypedMapSize(ctx) + result := make(map[string]int32) + if !ok { return result } + result = make(map[string]int32, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -327,11 +341,12 @@ func writeMapStringInt(buf *ByteBuffer, m map[string]int, hasGenerics bool) { func readMapStringInt(ctx *ReadContext) map[string]int { err := ctx.Err() buf := ctx.Buffer() - size := ctx.ReadCollectionLength() - result := make(map[string]int, size) - if size == 0 { + size, ok := readTypedMapSize(ctx) + result := make(map[string]int) + if !ok { return result } + result = make(map[string]int, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -411,11 +426,12 @@ func writeMapStringFloat64(buf *ByteBuffer, m map[string]float64, hasGenerics bo func readMapStringFloat64(ctx *ReadContext) map[string]float64 { err := ctx.Err() buf := ctx.Buffer() - size := ctx.ReadCollectionLength() - result := make(map[string]float64, size) - if size == 0 { + size, ok := readTypedMapSize(ctx) + result := make(map[string]float64) + if !ok { return result } + result = make(map[string]float64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -495,11 +511,12 @@ func writeMapStringBool(buf *ByteBuffer, m map[string]bool, hasGenerics bool) { func readMapStringBool(ctx *ReadContext) map[string]bool { err := ctx.Err() buf := ctx.Buffer() - size := ctx.ReadCollectionLength() - result := make(map[string]bool, size) - if size == 0 { + size, ok := readTypedMapSize(ctx) + result := make(map[string]bool) + if !ok { return result } + result = make(map[string]bool, size) for size > 0 { 
chunkHeader := buf.ReadUint8(err) @@ -583,11 +600,12 @@ func writeMapInt32Int32(buf *ByteBuffer, m map[int32]int32, hasGenerics bool) { func readMapInt32Int32(ctx *ReadContext) map[int32]int32 { err := ctx.Err() buf := ctx.Buffer() - size := ctx.ReadCollectionLength() - result := make(map[int32]int32, size) - if size == 0 { + size, ok := readTypedMapSize(ctx) + result := make(map[int32]int32) + if !ok { return result } + result = make(map[int32]int32, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -667,11 +685,12 @@ func writeMapInt64Int64(buf *ByteBuffer, m map[int64]int64, hasGenerics bool) { func readMapInt64Int64(ctx *ReadContext) map[int64]int64 { err := ctx.Err() buf := ctx.Buffer() - size := ctx.ReadCollectionLength() - result := make(map[int64]int64, size) - if size == 0 { + size, ok := readTypedMapSize(ctx) + result := make(map[int64]int64) + if !ok { return result } + result = make(map[int64]int64, size) for size > 0 { chunkHeader := buf.ReadUint8(err) @@ -751,11 +770,12 @@ func writeMapIntInt(buf *ByteBuffer, m map[int]int, hasGenerics bool) { func readMapIntInt(ctx *ReadContext) map[int]int { err := ctx.Err() buf := ctx.Buffer() - size := ctx.ReadCollectionLength() - result := make(map[int]int, size) - if size == 0 { + size, ok := readTypedMapSize(ctx) + result := make(map[int]int) + if !ok { return result } + result = make(map[int]int, size) for size > 0 { chunkHeader := buf.ReadUint8(err) diff --git a/go/fory/meta_string_resolver.go b/go/fory/meta_string_resolver.go index 5a9e1cb8ff..d74d5aa2cd 100644 --- a/go/fory/meta_string_resolver.go +++ b/go/fory/meta_string_resolver.go @@ -163,6 +163,9 @@ func (r *MetaStringResolver) ReadMetaStringBytes(buf *ByteBuffer, ctxErr *Error) return nil, encErr } + if !buf.CheckReadable(length, ctxErr) { + return nil, *ctxErr + } data = make([]byte, length) _, err := buf.Read(data) if err != nil { @@ -187,6 +190,9 @@ func (r *MetaStringResolver) ReadMetaStringBytes(buf *ByteBuffer, ctxErr *Error) if encErr 
!= nil { return nil, encErr } + if !buf.CheckReadable(length, ctxErr) { + return nil, *ctxErr + } data = make([]byte, length) _, err = buf.Read(data) if err != nil { diff --git a/go/fory/reader.go b/go/fory/reader.go index 1de67ec3fb..3985bb4e2b 100644 --- a/go/fory/reader.go +++ b/go/fory/reader.go @@ -29,23 +29,21 @@ import ( // ReadContext holds all state needed during deserialization. type ReadContext struct { - buffer *ByteBuffer - refReader *RefReader - trackRef bool // Cached flag to avoid indirection - xlang bool // Cross-language serialization mode - rootHeader byte - compatible bool // Schema evolution compatibility mode - typeResolver *TypeResolver // For complex type deserialization - refResolver *RefResolver // For reference tracking in native-mode paths - outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization - outOfBandIndex int // Current index into out-of-band buffers - depth int // Current nesting depth for cycle detection - maxDepth int // Maximum allowed nesting depth - err Error // Accumulated error state for deferred checking - lastTypePtr uintptr - lastTypeInfo *TypeInfo - maxCollectionSize int // Size guardrail for collection reads - maxBinarySize int // Size guardrail for binary reads + buffer *ByteBuffer + refReader *RefReader + trackRef bool // Cached flag to avoid indirection + xlang bool // Cross-language serialization mode + rootHeader byte + compatible bool // Schema evolution compatibility mode + typeResolver *TypeResolver // For complex type deserialization + refResolver *RefResolver // For reference tracking in native-mode paths + outOfBandBuffers []*ByteBuffer // Out-of-band buffers for deserialization + outOfBandIndex int // Current index into out-of-band buffers + depth int // Current nesting depth for cycle detection + maxDepth int // Maximum allowed nesting depth + err Error // Accumulated error state for deferred checking + lastTypePtr uintptr + lastTypeInfo *TypeInfo } // IsXlang returns whether 
cross-language serialization mode is enabled @@ -252,31 +250,23 @@ func (c *ReadContext) ReadAndValidateTypeId(expected TypeId) { } } -// ReadCollectionLength reads a length value for collections with size guardrails +// ReadCollectionLength reads a length value for collections. func (c *ReadContext) ReadCollectionLength() int { err := c.Err() length := c.buffer.ReadLength(err) if c.err.HasError() { return 0 } - if length > c.maxCollectionSize { - c.SetError(MaxCollectionSizeExceededError(length, c.maxCollectionSize)) - return 0 - } return length } -// ReadBinaryLength reads a length value for binary data with size guardrails +// ReadBinaryLength reads a byte length value for binary data. func (c *ReadContext) ReadBinaryLength() int { err := c.Err() length := c.buffer.ReadLength(err) if c.err.HasError() { return 0 } - if length > c.maxBinarySize { - c.SetError(MaxBinarySizeExceededError(length, c.maxBinarySize)) - return 0 - } return length } @@ -681,6 +671,19 @@ func (c *ReadContext) ReadBufferObject() *ByteBuffer { isInBand := c.buffer.ReadBool(err) if isInBand { size := c.ReadBinaryLength() + if c.HasError() { + return nil + } + if c.buffer.reader != nil { + bytes := c.buffer.ReadBytes(size, err) + if c.HasError() { + return nil + } + return NewByteBuffer(bytes) + } + if !c.buffer.CheckReadable(size, err) { + return nil + } buf := c.buffer.Slice(c.buffer.readerIndex, size) c.buffer.readerIndex += size return buf diff --git a/go/fory/set.go b/go/fory/set.go index ed3f9bac1b..1a42739547 100644 --- a/go/fory/set.go +++ b/go/fory/set.go @@ -314,6 +314,9 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { type_ := value.Type() // ReadData collection length from buffer length := ctx.ReadCollectionLength() + if ctx.HasError() { + return + } if length == 0 { // Initialize empty set if length is 0 value.Set(reflect.MakeMap(type_)) @@ -322,6 +325,9 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // ReadData collection 
flags that indicate special characteristics collectFlag := buf.ReadInt8(err) + if ctx.HasError() { + return + } var elemTypeInfo *TypeInfo // If all elements are same type, get element type info @@ -344,10 +350,16 @@ func (s setSerializer) ReadData(ctx *ReadContext, value reflect.Value) { elemTypeInfo = ctx.TypeResolver().ReadTypeInfo(buf, err) } } + if ctx.HasError() { + return + } + if !buf.CheckReadable(length, err) { + return + } // Initialize set if nil if value.IsNil() { - value.Set(reflect.MakeMap(type_)) + value.Set(reflect.MakeMapWithSize(type_, length)) } // Register reference for tracking (handles circular references) ctx.RefResolver().Reference(value) diff --git a/go/fory/skip.go b/go/fory/skip.go index abc7466449..82b03a5054 100644 --- a/go/fory/skip.go +++ b/go/fory/skip.go @@ -22,6 +22,14 @@ import ( "reflect" ) +func skipSizedBytes(ctx *ReadContext, size uint64) { + if size > uint64(MaxInt) { + ctx.SetError(DeserializationErrorf("skip byte length exceeds supported int range: %d", size)) + return + } + ctx.buffer.Skip(int(size), ctx.Err()) +} + // SkipFieldValue skips a field value in compatible mode when the field doesn't exist // or is incompatible with the local type. // Uses context error state for deferred error checking. 
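Review note: the `skipSizedBytes` helper added above rejects lengths that do not fit in an `int` before skipping, and the UTF-16 skip path has to apply the same care before doubling a declared unit count. A minimal standalone sketch of both guards follows. The names `utf16ByteLength` and `skipBytes` are hypothetical and only illustrate the arithmetic; they are not part of Fory's API.

```go
package main

import (
	"fmt"
	"math"
)

// utf16ByteLength converts a declared UTF-16 code-unit count into a byte
// length. The comparison against MaxInt/2 happens before the multiplication,
// so units*2 can neither wrap around uint64 nor overflow int.
func utf16ByteLength(units uint64) (int, error) {
	if units > uint64(math.MaxInt)/2 {
		return 0, fmt.Errorf("UTF-16 byte length exceeds int range: %d units", units)
	}
	return int(units * 2), nil
}

// skipBytes returns the position after skipping n bytes of data starting at
// pos. It refuses to move past the end of the input instead of clamping.
func skipBytes(data []byte, pos int, n int) (int, error) {
	if n < 0 || n > len(data)-pos {
		return pos, fmt.Errorf("cannot skip %d bytes: only %d readable", n, len(data)-pos)
	}
	return pos + n, nil
}
```

Checking the multiplier bound first matters because a wrapped product could look like a small, valid length and let the reader desynchronise from the stream.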
@@ -638,42 +646,46 @@ func skipValue(ctx *ReadContext, fieldDef FieldDef, readRefFlag bool, isField bo encoding := header & 0b11 switch encoding { case 0: // Latin1 - 1 byte per char - _ = ctx.buffer.ReadBinary(int(size), err) + skipSizedBytes(ctx, size) case 1: // UTF-16LE - 2 bytes per char - _ = ctx.buffer.ReadBinary(int(size*2), err) + if size > uint64(MaxInt)/2 { + ctx.SetError(DeserializationErrorf("UTF-16 string byte length exceeds supported int range: %d", size)) + return + } + skipSizedBytes(ctx, size*2) case 2: // UTF-8 - variable, but size is byte count - _ = ctx.buffer.ReadBinary(int(size), err) + skipSizedBytes(ctx, size) } case BINARY: - length := uint32(ctx.ReadBinaryLength()) + length := ctx.ReadBinaryLength() if ctx.HasError() { return } - _ = ctx.buffer.ReadBinary(int(length), err) + ctx.buffer.Skip(length, err) case BOOL_ARRAY, INT8_ARRAY, UINT8_ARRAY: length := ctx.ReadBinaryLength() if ctx.HasError() { return } - _ = ctx.buffer.ReadBinary(length, err) + ctx.buffer.Skip(length, err) case INT16_ARRAY, UINT16_ARRAY, FLOAT16_ARRAY, BFLOAT16_ARRAY: size := ctx.ReadBinaryLength() if ctx.HasError() { return } - _ = ctx.buffer.ReadBinary(size, err) + ctx.buffer.Skip(size, err) case INT32_ARRAY, UINT32_ARRAY, FLOAT32_ARRAY: size := ctx.ReadBinaryLength() if ctx.HasError() { return } - _ = ctx.buffer.ReadBinary(size, err) + ctx.buffer.Skip(size, err) case INT64_ARRAY, UINT64_ARRAY, FLOAT64_ARRAY: size := ctx.ReadBinaryLength() if ctx.HasError() { return } - _ = ctx.buffer.ReadBinary(size, err) + ctx.buffer.Skip(size, err) // Date/Time types case DATE: diff --git a/go/fory/slice.go b/go/fory/slice.go index 7fd787b814..6d941b3bf6 100644 --- a/go/fory/slice.go +++ b/go/fory/slice.go @@ -303,6 +303,9 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() length := ctx.ReadCollectionLength() + if ctx.HasError() { + return + } isArrayType := value.Type().Kind() == reflect.Array if length == 0 { @@ 
-314,6 +317,9 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // ReadData collection flags collectFlag := buf.ReadInt8(ctxErr) + if ctx.HasError() { + return + } elemSerializer := s.elemSerializer @@ -335,6 +341,9 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { } } } + if ctx.HasError() { + return + } // IMPORTANT: collection readers must obey the TRACKING_REF bit written on the // wire, not whatever the local field annotation or inferred Go type would @@ -353,6 +362,9 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { return } } else { + if !buf.CheckReadable(length, ctxErr) { + return + } // For slices, allocate or resize as needed if value.Cap() < length { value.Set(reflect.MakeSlice(value.Type(), length, length)) @@ -370,14 +382,16 @@ func (s *sliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { if !trackRefs && !hasNull { if declaredGenericDispatch { for i := 0; i < length; i++ { - elemSerializer.Read(ctx, RefModeNone, false, true, value.Index(i)) + elem := value.Index(i) + elemSerializer.Read(ctx, RefModeNone, false, true, elem) if ctx.HasError() { return } } } else { for i := 0; i < length; i++ { - elemSerializer.ReadData(ctx, value.Index(i)) + elem := value.Index(i) + elemSerializer.ReadData(ctx, elem) if ctx.HasError() { return } diff --git a/go/fory/slice_dyn.go b/go/fory/slice_dyn.go index 90a7b8cad1..907fcddd4f 100644 --- a/go/fory/slice_dyn.go +++ b/go/fory/slice_dyn.go @@ -259,17 +259,30 @@ func (s sliceDynSerializer) Read(ctx *ReadContext, refMode RefMode, readType boo } func (s sliceDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { + s.readData(ctx, value, -1) +} + +func (s sliceDynSerializer) readData(ctx *ReadContext, value reflect.Value, expectedLength int) { buf := ctx.Buffer() ctxErr := ctx.Err() length := ctx.ReadCollectionLength() sliceType := value.Type() - value.Set(reflect.MakeSlice(sliceType, length, length)) + if 
ctx.HasError() { + return + } + if expectedLength >= 0 && length != expectedLength { + ctx.SetError(DeserializationErrorf("array length %d does not match serialized length %d", expectedLength, length)) + return + } if length == 0 { + value.Set(reflect.MakeSlice(sliceType, 0, 0)) return } collectFlag := buf.ReadInt8(ctxErr) - ctx.RefResolver().Reference(value) + if ctx.HasError() { + return + } var elemTypeInfo *TypeInfo var elemType reflect.Type @@ -286,10 +299,23 @@ func (s sliceDynSerializer) ReadData(ctx *ReadContext, value reflect.Value) { elemType = sliceType.Elem() elemSerializer, _ = ctx.TypeResolver().getSerializerByType(elemType, false) } - s.readSameType(ctx, buf, value, elemType, elemSerializer, collectFlag) + if ctx.HasError() { + return + } + if !buf.CheckReadable(length, ctxErr) { + return + } + value.Set(reflect.MakeSlice(sliceType, length, length)) + ctx.RefResolver().Reference(value) + s.readSameType(ctx, buf, value, elemType, elemSerializer, collectFlag, length) return } - s.readDifferentTypes(ctx, buf, value, collectFlag) + if !buf.CheckReadable(length, ctxErr) { + return + } + value.Set(reflect.MakeSlice(sliceType, length, length)) + ctx.RefResolver().Reference(value) + s.readDifferentTypes(ctx, buf, value, collectFlag, length) } func (s sliceDynSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, typeInfo *TypeInfo, value reflect.Value) { @@ -298,7 +324,7 @@ func (s sliceDynSerializer) ReadWithTypeInfo(ctx *ReadContext, refMode RefMode, } // readSameType handles deserialization of slices where all elements share the same type -func (s sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, value reflect.Value, elemType reflect.Type, serializer Serializer, flag int8) { +func (s sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, value reflect.Value, elemType reflect.Type, serializer Serializer, flag int8, length int) { trackRefs := (flag & CollectionTrackingRef) != 0 hasNull := (flag & CollectionHasNull) != 
0 ctxErr := ctx.Err() @@ -316,7 +342,7 @@ func (s sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, valu isNamedStruct = true } - for i := 0; i < value.Len(); i++ { + for i := 0; i < length; i++ { if trackRefs { refID, refErr := ctx.RefResolver().TryPreserveRefId(buf) if refErr != nil { @@ -349,6 +375,9 @@ func (s sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, valu serializer.ReadData(ctx, elem) ctx.RefResolver().Reference(elem) } + if ctx.HasError() { + return + } value.Index(i).Set(elem) } else if hasNull { refFlag := buf.ReadInt8(ctxErr) @@ -357,26 +386,29 @@ func (s sliceDynSerializer) readSameType(ctx *ReadContext, buf *ByteBuffer, valu } elem := reflect.New(elemType).Elem() serializer.ReadData(ctx, elem) + if ctx.HasError() { + return + } value.Index(i).Set(elem) } else { elem := reflect.New(elemType).Elem() serializer.ReadData(ctx, elem) + if ctx.HasError() { + return + } value.Index(i).Set(elem) } - if ctx.HasError() { - return - } } } // readDifferentTypes handles deserialization of slices with mixed element types func (s sliceDynSerializer) readDifferentTypes( - ctx *ReadContext, buf *ByteBuffer, value reflect.Value, flag int8) { + ctx *ReadContext, buf *ByteBuffer, value reflect.Value, flag int8, length int) { trackRefs := (flag & CollectionTrackingRef) != 0 hasNull := (flag & CollectionHasNull) != 0 ctxErr := ctx.Err() - for i := 0; i < value.Len(); i++ { + for i := 0; i < length; i++ { if trackRefs { refID, refErr := ctx.RefResolver().TryPreserveRefId(buf) if refErr != nil { @@ -402,6 +434,9 @@ func (s sliceDynSerializer) readDifferentTypes( elem := reflect.New(elemType).Elem() serializer.ReadData(ctx, elem) ctx.RefResolver().SetReadObject(refID, elem) + if ctx.HasError() { + return + } value.Index(i).Set(elem) } else { if hasNull { @@ -417,11 +452,11 @@ func (s sliceDynSerializer) readDifferentTypes( elemType, serializer := s.wrapSerializerIfNeeded(typeInfo.Type, typeInfo.Serializer) elem := 
reflect.New(elemType).Elem() serializer.ReadData(ctx, elem) + if ctx.HasError() { + return + } value.Index(i).Set(elem) } - if ctx.HasError() { - return - } } } diff --git a/go/fory/slice_primitive.go b/go/fory/slice_primitive.go index 76df5a0dc8..9b92691ac8 100644 --- a/go/fory/slice_primitive.go +++ b/go/fory/slice_primitive.go @@ -75,14 +75,19 @@ func (s byteSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() length := ctx.ReadBinaryLength() + if ctx.HasError() { + return + } ptr := (*[]byte)(value.Addr().UnsafePointer()) if length == 0 { *ptr = make([]byte, 0) return } + if !buf.CheckReadable(length, ctxErr) { + return + } result := make([]byte, length) - raw := buf.ReadBinary(length, ctxErr) - copy(result, raw) + buf.Read(result) *ptr = result } @@ -643,6 +648,9 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { buf := ctx.Buffer() ctxErr := ctx.Err() length := ctx.ReadCollectionLength() + if ctx.HasError() { + return + } ptr := (*[]string)(value.Addr().UnsafePointer()) if length == 0 { *ptr = make([]string, 0) @@ -651,11 +659,20 @@ func (s stringSliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) { // Read collection flags collectFlag := buf.ReadInt8(ctxErr) + if ctx.HasError() { + return + } // Read element type info if present (when CollectionIsSameType but not CollectionIsDeclElementType) if (collectFlag&CollectionIsSameType) != 0 && (collectFlag&CollectionIsDeclElementType) == 0 { _ = buf.ReadUint8(ctxErr) // Read and discard type ID (we know it's STRING) } + if ctx.HasError() { + return + } + if !buf.CheckReadable(length, ctxErr) { + return + } result := make([]string, length) @@ -702,12 +719,11 @@ func ReadByteSlice(buf *ByteBuffer, err *Error) []byte { if size == 0 { return make([]byte, 0) } - raw := buf.ReadBinary(size, err) - if err.HasError() { + if !buf.CheckReadable(size, err) { return nil } result := make([]byte, size) - copy(result, raw) + 
buf.Read(result) return result } @@ -726,12 +742,11 @@ func ReadBoolSlice(buf *ByteBuffer, err *Error) []bool { if size == 0 { return make([]bool, 0) } - raw := buf.ReadBinary(size, err) - if err.HasError() { + if !buf.CheckReadable(size, err) { return nil } result := make([]bool, size) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size)) return result } @@ -750,12 +765,11 @@ func ReadInt8Slice(buf *ByteBuffer, err *Error) []int8 { if size == 0 { return make([]int8, 0) } - raw := buf.ReadBinary(size, err) - if err.HasError() { + if !buf.CheckReadable(size, err) { return nil } result := make([]int8, size) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size)) return result } @@ -776,18 +790,24 @@ func WriteInt16Slice(buf *ByteBuffer, value []int16) { // ReadInt16Slice reads []int16 from buffer using ARRAY protocol func ReadInt16Slice(buf *ByteBuffer, err *Error) []int16 { - size := buf.ReadLength(err) - length := size / 2 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + return nil + } + if byteSize%2 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to element size %d", byteSize, 2) + return nil + } + if byteSize == 0 { return make([]int16, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + length := byteSize / 2 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]int16, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]int16, length) @@ -815,18 +835,24 @@ func WriteInt32Slice(buf *ByteBuffer, value []int32) { // ReadInt32Slice reads []int32 from buffer using ARRAY protocol func ReadInt32Slice(buf *ByteBuffer, err *Error) []int32 { - size := 
buf.ReadLength(err) - length := size / 4 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + return nil + } + if byteSize%4 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to element size %d", byteSize, 4) + return nil + } + if byteSize == 0 { return make([]int32, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + length := byteSize / 4 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]int32, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]int32, length) @@ -854,18 +880,24 @@ func WriteInt64Slice(buf *ByteBuffer, value []int64) { // ReadInt64Slice reads []int64 from buffer using ARRAY protocol func ReadInt64Slice(buf *ByteBuffer, err *Error) []int64 { - size := buf.ReadLength(err) - length := size / 8 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + return nil + } + if byteSize%8 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to element size %d", byteSize, 8) + return nil + } + if byteSize == 0 { return make([]int64, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + length := byteSize / 8 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]int64, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]int64, length) @@ -893,18 +925,24 @@ func WriteUint16Slice(buf *ByteBuffer, value []uint16) { // ReadUint16Slice reads []uint16 from buffer using ARRAY protocol func ReadUint16Slice(buf *ByteBuffer, err *Error) []uint16 { - size := buf.ReadLength(err) - length := size / 2 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + 
return nil + } + if byteSize%2 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to element size %d", byteSize, 2) + return nil + } + if byteSize == 0 { return make([]uint16, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + length := byteSize / 2 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]uint16, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]uint16, length) @@ -932,18 +970,24 @@ func WriteUint32Slice(buf *ByteBuffer, value []uint32) { // ReadUint32Slice reads []uint32 from buffer using ARRAY protocol func ReadUint32Slice(buf *ByteBuffer, err *Error) []uint32 { - size := buf.ReadLength(err) - length := size / 4 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + return nil + } + if byteSize%4 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to element size %d", byteSize, 4) + return nil + } + if byteSize == 0 { return make([]uint32, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + length := byteSize / 4 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]uint32, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]uint32, length) @@ -971,18 +1015,24 @@ func WriteUint64Slice(buf *ByteBuffer, value []uint64) { // ReadUint64Slice reads []uint64 from buffer using ARRAY protocol func ReadUint64Slice(buf *ByteBuffer, err *Error) []uint64 { - size := buf.ReadLength(err) - length := size / 8 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + return nil + } + if byteSize%8 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to 
element size %d", byteSize, 8) + return nil + } + if byteSize == 0 { return make([]uint64, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + length := byteSize / 8 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]uint64, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]uint64, length) @@ -1010,18 +1060,24 @@ func WriteFloat32Slice(buf *ByteBuffer, value []float32) { // ReadFloat32Slice reads []float32 from buffer using ARRAY protocol func ReadFloat32Slice(buf *ByteBuffer, err *Error) []float32 { - size := buf.ReadLength(err) - length := size / 4 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + return nil + } + if byteSize%4 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to element size %d", byteSize, 4) + return nil + } + if byteSize == 0 { return make([]float32, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + length := byteSize / 4 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]float32, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]float32, length) @@ -1049,18 +1105,24 @@ func WriteFloat64Slice(buf *ByteBuffer, value []float64) { // ReadFloat64Slice reads []float64 from buffer using ARRAY protocol func ReadFloat64Slice(buf *ByteBuffer, err *Error) []float64 { - size := buf.ReadLength(err) - length := size / 8 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + return nil + } + if byteSize%8 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to element size %d", byteSize, 8) + return nil + } + if byteSize == 0 { return 
make([]float64, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + length := byteSize / 8 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]float64, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]float64, length) @@ -1128,10 +1190,14 @@ func (s float16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) buf := ctx.Buffer() ctxErr := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 2 if ctx.HasError() { return } + if size%2 != 0 { + ctx.SetError(DeserializationErrorf("float16 array byte length %d is not aligned to element size 2", size)) + return + } + length := size / 2 // Ensure capacity ptr := (*[]float16.Float16)(value.Addr().UnsafePointer()) @@ -1140,13 +1206,15 @@ func (s float16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) return } + if !buf.CheckReadable(size, ctxErr) { + return + } result := make([]float16.Float16, length) if isLittleEndian { - raw := buf.ReadBinary(size, ctxErr) // unsafe copy targetPtr := unsafe.Pointer(&result[0]) - copy(unsafe.Slice((*byte)(targetPtr), size), raw) + buf.Read(unsafe.Slice((*byte)(targetPtr), size)) } else { for i := 0; i < length; i++ { // ReadUint16 handles endianness @@ -1187,19 +1255,25 @@ func WriteIntSlice(buf *ByteBuffer, value []int) { // ReadIntSlice reads []int from buffer using ARRAY protocol func ReadIntSlice(buf *ByteBuffer, err *Error) []int { - size := buf.ReadLength(err) if strconv.IntSize == 64 { - length := size / 8 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + return nil + } + if byteSize%8 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to element size %d", byteSize, 8) + return nil + } + if byteSize == 0 { return make([]int, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + 
length := byteSize / 8 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]int, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]int, length) @@ -1209,17 +1283,24 @@ func ReadIntSlice(buf *ByteBuffer, err *Error) []int { return result } } else { - length := size / 4 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + return nil + } + if byteSize%4 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to element size %d", byteSize, 4) + return nil + } + if byteSize == 0 { return make([]int, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + length := byteSize / 4 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]int, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]int, length) @@ -1262,19 +1343,25 @@ func WriteUintSlice(buf *ByteBuffer, value []uint) { // ReadUintSlice reads []uint from buffer using ARRAY protocol func ReadUintSlice(buf *ByteBuffer, err *Error) []uint { - size := buf.ReadLength(err) if strconv.IntSize == 64 { - length := size / 8 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + return nil + } + if byteSize%8 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to element size %d", byteSize, 8) + return nil + } + if byteSize == 0 { return make([]uint, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + length := byteSize / 8 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]uint, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + 
buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]uint, length) @@ -1284,17 +1371,24 @@ func ReadUintSlice(buf *ByteBuffer, err *Error) []uint { return result } } else { - length := size / 4 - if length == 0 { + byteSize := buf.ReadLength(err) + if err.HasError() { + return nil + } + if byteSize%4 != 0 { + *err = DeserializationErrorf("array byte size %d is not aligned to element size %d", byteSize, 4) + return nil + } + if byteSize == 0 { return make([]uint, 0) } + if !buf.CheckReadable(byteSize, err) { + return nil + } + length := byteSize / 4 if isLittleEndian { - raw := buf.ReadBinary(size, err) - if err.HasError() { - return nil - } result := make([]uint, length) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), byteSize)) return result } else { result := make([]uint, length) @@ -1337,6 +1431,9 @@ func ReadStringSlice(buf *ByteBuffer, err *Error) []string { if (collectFlag&CollectionIsSameType) != 0 && (collectFlag&CollectionIsDeclElementType) == 0 { _ = buf.ReadUint8(err) // Read and discard element type ID } + if err.HasError() || !buf.CheckReadable(length, err) { + return nil + } result := make([]string, length) trackRefs := (collectFlag & CollectionTrackingRef) != 0 hasNull := (collectFlag & CollectionHasNull) != 0 @@ -1405,10 +1502,14 @@ func (s bfloat16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) buf := ctx.Buffer() ctxErr := ctx.Err() size := ctx.ReadBinaryLength() - length := size / 2 if ctx.HasError() { return } + if size%2 != 0 { + ctx.SetError(DeserializationErrorf("bfloat16 array byte length %d is not aligned to element size 2", size)) + return + } + length := size / 2 ptr := (*[]bfloat16.BFloat16)(value.Addr().UnsafePointer()) if length == 0 { @@ -1416,12 +1517,14 @@ func (s bfloat16SliceSerializer) ReadData(ctx *ReadContext, value reflect.Value) return } + if !buf.CheckReadable(size, 
ctxErr) { + return + } result := make([]bfloat16.BFloat16, length) if isLittleEndian { - raw := buf.ReadBinary(size, ctxErr) targetPtr := unsafe.Pointer(&result[0]) - copy(unsafe.Slice((*byte)(targetPtr), size), raw) + buf.Read(unsafe.Slice((*byte)(targetPtr), size)) } else { for i := 0; i < length; i++ { result[i] = bfloat16.BFloat16FromBits(buf.ReadUint16(ctxErr)) diff --git a/go/fory/slice_primitive_list.go b/go/fory/slice_primitive_list.go index aa965d1e45..0335b2a08e 100644 --- a/go/fory/slice_primitive_list.go +++ b/go/fory/slice_primitive_list.go @@ -112,31 +112,31 @@ func (s primitiveListSerializer) writeDataWithGenerics(ctx *WriteContext, value func (s primitiveListSerializer) writeValues(buf *ByteBuffer, value reflect.Value) { switch s.type_.Elem().Kind() { case reflect.Bool: - writeBoolListPayload(buf, primitiveListSliceView[bool](value)) + writeBoolListValues(buf, primitiveListSliceView[bool](value)) case reflect.Int8: - writeInt8ListPayload(buf, primitiveListSliceView[int8](value)) + writeInt8ListValues(buf, primitiveListSliceView[int8](value)) case reflect.Uint8: - writeUint8ListPayload(buf, primitiveListSliceView[byte](value)) + writeUint8ListValues(buf, primitiveListSliceView[byte](value)) case reflect.Int16: - writeInt16ListPayload(buf, primitiveListSliceView[int16](value)) + writeInt16ListValues(buf, primitiveListSliceView[int16](value)) case reflect.Uint16: - writeUint16ListPayload(buf, primitiveListSliceView[uint16](value)) + writeUint16ListValues(buf, primitiveListSliceView[uint16](value)) case reflect.Int32: - writeInt32ListPayload(buf, primitiveListSliceView[int32](value), s.elemTypeID) + writeInt32ListValues(buf, primitiveListSliceView[int32](value), s.elemTypeID) case reflect.Uint32: - writeUint32ListPayload(buf, primitiveListSliceView[uint32](value), s.elemTypeID) + writeUint32ListValues(buf, primitiveListSliceView[uint32](value), s.elemTypeID) case reflect.Int64: - writeInt64ListPayload(buf, primitiveListSliceView[int64](value), 
s.elemTypeID) + writeInt64ListValues(buf, primitiveListSliceView[int64](value), s.elemTypeID) case reflect.Uint64: - writeUint64ListPayload(buf, primitiveListSliceView[uint64](value), s.elemTypeID) + writeUint64ListValues(buf, primitiveListSliceView[uint64](value), s.elemTypeID) case reflect.Int: - writeIntListPayload(buf, primitiveListSliceView[int](value), s.elemTypeID) + writeIntListValues(buf, primitiveListSliceView[int](value), s.elemTypeID) case reflect.Uint: - writeUintListPayload(buf, primitiveListSliceView[uint](value), s.elemTypeID) + writeUintListValues(buf, primitiveListSliceView[uint](value), s.elemTypeID) case reflect.Float32: - writeFloat32ListPayload(buf, primitiveListSliceView[float32](value)) + writeFloat32ListValues(buf, primitiveListSliceView[float32](value)) case reflect.Float64: - writeFloat64ListPayload(buf, primitiveListSliceView[float64](value)) + writeFloat64ListValues(buf, primitiveListSliceView[float64](value)) } } @@ -164,21 +164,33 @@ func (s primitiveListSerializer) ReadData(ctx *ReadContext, value reflect.Value) buf := ctx.Buffer() err := ctx.Err() length := ctx.ReadCollectionLength() + if ctx.HasError() { + return + } if length == 0 { value.Set(reflect.MakeSlice(value.Type(), 0, 0)) return } collectFlag := buf.ReadInt8(err) + if ctx.HasError() { + return + } if (collectFlag & CollectionIsSameType) != 0 { if (collectFlag & CollectionIsDeclElementType) == 0 { ctx.TypeResolver().ReadTypeInfo(buf, err) } } + if ctx.HasError() { + return + } if (collectFlag & CollectionTrackingRef) != 0 { ctx.SetError(DeserializationErrorf("primitive list does not support reference-tracked elements")) return } hasNull := (collectFlag & CollectionHasNull) != 0 + if !s.checkBodyReadable(buf, err, length, hasNull) { + return + } s.readValues(buf, err, value, length, hasNull) } @@ -227,11 +239,17 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val return } collectFlag := buf.ReadInt8(err) + if ctx.HasError() { + return + } if 
(collectFlag & CollectionIsSameType) != 0 { if (collectFlag & CollectionIsDeclElementType) == 0 { ctx.TypeResolver().ReadTypeInfo(buf, err) } } + if ctx.HasError() { + return + } if (collectFlag & CollectionTrackingRef) != 0 { ctx.SetError(DeserializationErrorf("array-compatible list does not support reference-tracked elements")) return @@ -244,6 +262,9 @@ func (s compatiblePrimitiveListToArraySerializer) ReadData(ctx *ReadContext, val ctx.SetError(DeserializationErrorf("array-compatible list requires declared same-type elements")) return } + if !s.listReader.checkBodyReadable(buf, err, length, false) { + return + } if value.Kind() == reflect.Slice { temp := reflect.New(value.Type()).Elem() s.listReader.readValues(buf, err, temp, length, false) @@ -260,54 +281,79 @@ func (s compatiblePrimitiveListToArraySerializer) ReadWithTypeInfo(ctx *ReadCont s.Read(ctx, refMode, false, false, value) } +func (s primitiveListSerializer) checkBodyReadable(buf *ByteBuffer, err *Error, length int, hasNull bool) bool { + if !hasNull { + switch s.type_.Elem().Kind() { + case reflect.Int16, reflect.Uint16: + return checkPrimitiveListBytes(buf, err, length, 2) + case reflect.Int32, reflect.Uint32: + if s.elemTypeID == INT32 || s.elemTypeID == UINT32 { + return checkPrimitiveListBytes(buf, err, length, 4) + } + case reflect.Int64, reflect.Uint64: + if s.elemTypeID == INT64 || s.elemTypeID == UINT64 { + return checkPrimitiveListBytes(buf, err, length, 8) + } + case reflect.Int: + if s.elemTypeID == INT64 { + return checkPrimitiveListBytes(buf, err, length, 8) + } else if s.elemTypeID == INT32 { + return checkPrimitiveListBytes(buf, err, length, 4) + } + case reflect.Uint: + if s.elemTypeID == UINT64 { + return checkPrimitiveListBytes(buf, err, length, 8) + } else if s.elemTypeID == UINT32 { + return checkPrimitiveListBytes(buf, err, length, 4) + } + case reflect.Float32: + return checkPrimitiveListBytes(buf, err, length, 4) + case reflect.Float64: + return checkPrimitiveListBytes(buf, 
err, length, 8) + } + } + return buf.CheckReadable(length, err) +} + func (s primitiveListSerializer) readValues(buf *ByteBuffer, err *Error, value reflect.Value, length int, hasNull bool) { switch s.type_.Elem().Kind() { case reflect.Bool: - *(*[]bool)(value.Addr().UnsafePointer()) = readBoolListPayload(buf, err, length, hasNull) + *(*[]bool)(value.Addr().UnsafePointer()) = readBoolListValues(buf, err, length, hasNull) case reflect.Int8: - *(*[]int8)(value.Addr().UnsafePointer()) = readInt8ListPayload(buf, err, length, hasNull) + *(*[]int8)(value.Addr().UnsafePointer()) = readInt8ListValues(buf, err, length, hasNull) case reflect.Uint8: - *(*[]byte)(value.Addr().UnsafePointer()) = readUint8ListPayload(buf, err, length, hasNull) + *(*[]byte)(value.Addr().UnsafePointer()) = readUint8ListValues(buf, err, length, hasNull) case reflect.Int16: - *(*[]int16)(value.Addr().UnsafePointer()) = readInt16ListPayload(buf, err, length, hasNull) + *(*[]int16)(value.Addr().UnsafePointer()) = readInt16ListValues(buf, err, length, hasNull) case reflect.Uint16: - *(*[]uint16)(value.Addr().UnsafePointer()) = readUint16ListPayload(buf, err, length, hasNull) + *(*[]uint16)(value.Addr().UnsafePointer()) = readUint16ListValues(buf, err, length, hasNull) case reflect.Int32: - *(*[]int32)(value.Addr().UnsafePointer()) = readInt32ListPayload(buf, err, length, hasNull, s.elemTypeID) + *(*[]int32)(value.Addr().UnsafePointer()) = readInt32ListValues(buf, err, length, hasNull, s.elemTypeID) case reflect.Uint32: - *(*[]uint32)(value.Addr().UnsafePointer()) = readUint32ListPayload(buf, err, length, hasNull, s.elemTypeID) + *(*[]uint32)(value.Addr().UnsafePointer()) = readUint32ListValues(buf, err, length, hasNull, s.elemTypeID) case reflect.Int64: - *(*[]int64)(value.Addr().UnsafePointer()) = readInt64ListPayload(buf, err, length, hasNull, s.elemTypeID) + *(*[]int64)(value.Addr().UnsafePointer()) = readInt64ListValues(buf, err, length, hasNull, s.elemTypeID) case reflect.Uint64: - 
*(*[]uint64)(value.Addr().UnsafePointer()) = readUint64ListPayload(buf, err, length, hasNull, s.elemTypeID) + *(*[]uint64)(value.Addr().UnsafePointer()) = readUint64ListValues(buf, err, length, hasNull, s.elemTypeID) case reflect.Int: - *(*[]int)(value.Addr().UnsafePointer()) = readIntListPayload(buf, err, length, hasNull, s.elemTypeID) + *(*[]int)(value.Addr().UnsafePointer()) = readIntListValues(buf, err, length, hasNull, s.elemTypeID) case reflect.Uint: - *(*[]uint)(value.Addr().UnsafePointer()) = readUintListPayload(buf, err, length, hasNull, s.elemTypeID) + *(*[]uint)(value.Addr().UnsafePointer()) = readUintListValues(buf, err, length, hasNull, s.elemTypeID) case reflect.Float32: - *(*[]float32)(value.Addr().UnsafePointer()) = readFloat32ListPayload(buf, err, length, hasNull) + *(*[]float32)(value.Addr().UnsafePointer()) = readFloat32ListValues(buf, err, length, hasNull) case reflect.Float64: - *(*[]float64)(value.Addr().UnsafePointer()) = readFloat64ListPayload(buf, err, length, hasNull) + *(*[]float64)(value.Addr().UnsafePointer()) = readFloat64ListValues(buf, err, length, hasNull) } } func (s primitiveListSerializer) readArrayValues(buf *ByteBuffer, err *Error, value reflect.Value, length int) { switch s.type_.Elem().Kind() { case reflect.Bool: - raw := buf.ReadBinary(length, err) - for i := 0; i < length; i++ { - value.Index(i).SetBool(raw[i] != 0) - } + buf.Read(unsafe.Slice((*byte)(value.Addr().UnsafePointer()), length)) case reflect.Int8: - raw := buf.ReadBinary(length, err) - for i := 0; i < length; i++ { - value.Index(i).SetInt(int64(int8(raw[i]))) - } + buf.Read(unsafe.Slice((*byte)(value.Addr().UnsafePointer()), length)) case reflect.Uint8: - raw := buf.ReadBinary(length, err) - for i := 0; i < length; i++ { - value.Index(i).SetUint(uint64(raw[i])) - } + buf.Read(unsafe.Slice((*byte)(value.Addr().UnsafePointer()), length)) case reflect.Int16: for i := 0; i < length; i++ { value.Index(i).SetInt(int64(buf.ReadInt16(err))) @@ -389,17 +435,43 @@ func (s 
primitiveListSerializer) readArrayValues(buf *ByteBuffer, err *Error, va } } -func writeBoolListPayload(buf *ByteBuffer, value []bool) { +func writeBoolListValues(buf *ByteBuffer, value []bool) { if len(value) > 0 { buf.WriteBinary(unsafe.Slice((*byte)(unsafe.Pointer(&value[0])), len(value))) } } -func readBoolListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool) []bool { +func primitiveListByteSize(length int, elemSize int, err *Error) (int, bool) { + if length < 0 { + *err = DeserializationErrorf("negative primitive list length: %d", length) + return 0, false + } + if elemSize <= 0 { + *err = DeserializationErrorf("invalid primitive element size: %d", elemSize) + return 0, false + } + if length > int(^uint(0)>>1)/elemSize { + *err = DeserializationErrorf("primitive list byte size overflows: length %d element size %d", length, elemSize) + return 0, false + } + return length * elemSize, true +} + +func checkPrimitiveListBytes(buf *ByteBuffer, err *Error, length int, elemSize int) bool { + size, ok := primitiveListByteSize(length, elemSize, err) + if !ok { + return false + } + if !buf.CheckReadable(size, err) { + return false + } + return true +} + +func readBoolListValues(buf *ByteBuffer, err *Error, length int, hasNull bool) []bool { result := make([]bool, length) if !hasNull { - raw := buf.ReadBinary(length, err) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), length), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), length)) return result } for i := 0; i < length; i++ { @@ -410,17 +482,16 @@ func readBoolListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool) return result } -func writeInt8ListPayload(buf *ByteBuffer, value []int8) { +func writeInt8ListValues(buf *ByteBuffer, value []int8) { if len(value) > 0 { buf.WriteBinary(unsafe.Slice((*byte)(unsafe.Pointer(&value[0])), len(value))) } } -func readInt8ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool) []int8 { +func readInt8ListValues(buf 
*ByteBuffer, err *Error, length int, hasNull bool) []int8 { result := make([]int8, length) if !hasNull { - raw := buf.ReadBinary(length, err) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), length), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), length)) return result } for i := 0; i < length; i++ { @@ -431,17 +502,16 @@ func readInt8ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool) return result } -func writeUint8ListPayload(buf *ByteBuffer, value []byte) { +func writeUint8ListValues(buf *ByteBuffer, value []byte) { if len(value) > 0 { buf.WriteBinary(value) } } -func readUint8ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool) []byte { +func readUint8ListValues(buf *ByteBuffer, err *Error, length int, hasNull bool) []byte { result := make([]byte, length) if !hasNull { - raw := buf.ReadBinary(length, err) - copy(result, raw) + buf.Read(result) return result } for i := 0; i < length; i++ { @@ -452,7 +522,7 @@ func readUint8ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool) return result } -func writeInt16ListPayload(buf *ByteBuffer, value []int16) { +func writeInt16ListValues(buf *ByteBuffer, value []int16) { size := len(value) * 2 if len(value) > 0 { if isLittleEndian { @@ -465,13 +535,12 @@ func writeInt16ListPayload(buf *ByteBuffer, value []int16) { } } -func readInt16ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool) []int16 { +func readInt16ListValues(buf *ByteBuffer, err *Error, length int, hasNull bool) []int16 { result := make([]int16, length) if !hasNull { size := length * 2 if isLittleEndian { - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size)) } else { for i := 0; i < length; i++ { result[i] = buf.ReadInt16(err) @@ -487,7 +556,7 @@ func readInt16ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool) return result } -func 
writeUint16ListPayload(buf *ByteBuffer, value []uint16) { +func writeUint16ListValues(buf *ByteBuffer, value []uint16) { size := len(value) * 2 if len(value) > 0 { if isLittleEndian { @@ -500,13 +569,12 @@ func writeUint16ListPayload(buf *ByteBuffer, value []uint16) { } } -func readUint16ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool) []uint16 { +func readUint16ListValues(buf *ByteBuffer, err *Error, length int, hasNull bool) []uint16 { result := make([]uint16, length) if !hasNull { size := length * 2 if isLittleEndian { - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size)) } else { for i := 0; i < length; i++ { result[i] = uint16(buf.ReadInt16(err)) @@ -522,9 +590,9 @@ func readUint16ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool return result } -func writeInt32ListPayload(buf *ByteBuffer, value []int32, typeID TypeId) { +func writeInt32ListValues(buf *ByteBuffer, value []int32, typeID TypeId) { if typeID == INT32 { - writeInt32FixedListPayload(buf, value) + writeInt32FixedListValues(buf, value) return } for _, v := range value { @@ -532,7 +600,7 @@ func writeInt32ListPayload(buf *ByteBuffer, value []int32, typeID TypeId) { } } -func writeInt32FixedListPayload(buf *ByteBuffer, value []int32) { +func writeInt32FixedListValues(buf *ByteBuffer, value []int32) { size := len(value) * 4 if len(value) > 0 { if isLittleEndian { @@ -545,13 +613,12 @@ func writeInt32FixedListPayload(buf *ByteBuffer, value []int32) { } } -func readInt32ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []int32 { +func readInt32ListValues(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []int32 { result := make([]int32, length) if !hasNull && typeID == INT32 { size := length * 4 if isLittleEndian { - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), 
size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size)) } else { for i := 0; i < length; i++ { result[i] = buf.ReadInt32(err) @@ -572,9 +639,9 @@ func readInt32ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool, return result } -func writeUint32ListPayload(buf *ByteBuffer, value []uint32, typeID TypeId) { +func writeUint32ListValues(buf *ByteBuffer, value []uint32, typeID TypeId) { if typeID == UINT32 { - writeUint32FixedListPayload(buf, value) + writeUint32FixedListValues(buf, value) return } for _, v := range value { @@ -582,7 +649,7 @@ func writeUint32ListPayload(buf *ByteBuffer, value []uint32, typeID TypeId) { } } -func writeUint32FixedListPayload(buf *ByteBuffer, value []uint32) { +func writeUint32FixedListValues(buf *ByteBuffer, value []uint32) { size := len(value) * 4 if len(value) > 0 { if isLittleEndian { @@ -595,13 +662,12 @@ func writeUint32FixedListPayload(buf *ByteBuffer, value []uint32) { } } -func readUint32ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []uint32 { +func readUint32ListValues(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []uint32 { result := make([]uint32, length) if !hasNull && typeID == UINT32 { size := length * 4 if isLittleEndian { - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size)) } else { for i := 0; i < length; i++ { result[i] = uint32(buf.ReadInt32(err)) @@ -622,10 +688,10 @@ func readUint32ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool return result } -func writeInt64ListPayload(buf *ByteBuffer, value []int64, typeID TypeId) { +func writeInt64ListValues(buf *ByteBuffer, value []int64, typeID TypeId) { switch typeID { case INT64: - writeInt64FixedListPayload(buf, value) + writeInt64FixedListValues(buf, value) case TAGGED_INT64: for _, v := range value { buf.WriteTaggedInt64(v) @@ -637,7 +703,7 
@@ func writeInt64ListPayload(buf *ByteBuffer, value []int64, typeID TypeId) { } } -func writeInt64FixedListPayload(buf *ByteBuffer, value []int64) { +func writeInt64FixedListValues(buf *ByteBuffer, value []int64) { size := len(value) * 8 if len(value) > 0 { if isLittleEndian { @@ -650,13 +716,12 @@ func writeInt64FixedListPayload(buf *ByteBuffer, value []int64) { } } -func readInt64ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []int64 { +func readInt64ListValues(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []int64 { result := make([]int64, length) if !hasNull && typeID == INT64 { size := length * 8 if isLittleEndian { - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size)) } else { for i := 0; i < length; i++ { result[i] = buf.ReadInt64(err) @@ -680,10 +745,10 @@ func readInt64ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool, return result } -func writeUint64ListPayload(buf *ByteBuffer, value []uint64, typeID TypeId) { +func writeUint64ListValues(buf *ByteBuffer, value []uint64, typeID TypeId) { switch typeID { case UINT64: - writeUint64FixedListPayload(buf, value) + writeUint64FixedListValues(buf, value) case TAGGED_UINT64: for _, v := range value { buf.WriteTaggedUint64(v) @@ -695,7 +760,7 @@ func writeUint64ListPayload(buf *ByteBuffer, value []uint64, typeID TypeId) { } } -func writeUint64FixedListPayload(buf *ByteBuffer, value []uint64) { +func writeUint64FixedListValues(buf *ByteBuffer, value []uint64) { size := len(value) * 8 if len(value) > 0 { if isLittleEndian { @@ -708,13 +773,12 @@ func writeUint64FixedListPayload(buf *ByteBuffer, value []uint64) { } } -func readUint64ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []uint64 { +func readUint64ListValues(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []uint64 
{ result := make([]uint64, length) if !hasNull && typeID == UINT64 { size := length * 8 if isLittleEndian { - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size)) } else { for i := 0; i < length; i++ { result[i] = uint64(buf.ReadInt64(err)) @@ -738,10 +802,10 @@ func readUint64ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool return result } -func writeIntListPayload(buf *ByteBuffer, value []int, typeID TypeId) { +func writeIntListValues(buf *ByteBuffer, value []int, typeID TypeId) { if reflect.TypeOf(int(0)).Size() == 8 { asInt64 := unsafe.Slice((*int64)(unsafe.Pointer(&value[0])), len(value)) - writeInt64ListPayload(buf, asInt64, typeID) + writeInt64ListValues(buf, asInt64, typeID) return } for _, v := range value { @@ -753,10 +817,10 @@ func writeIntListPayload(buf *ByteBuffer, value []int, typeID TypeId) { } } -func readIntListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []int { +func readIntListValues(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []int { result := make([]int, length) if reflect.TypeOf(int(0)).Size() == 8 { - values := readInt64ListPayload(buf, err, length, hasNull, typeID) + values := readInt64ListValues(buf, err, length, hasNull, typeID) copy(unsafe.Slice((*int64)(unsafe.Pointer(&result[0])), length), values) return result } @@ -773,10 +837,10 @@ func readIntListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool, t return result } -func writeUintListPayload(buf *ByteBuffer, value []uint, typeID TypeId) { +func writeUintListValues(buf *ByteBuffer, value []uint, typeID TypeId) { if reflect.TypeOf(uint(0)).Size() == 8 { asUint64 := unsafe.Slice((*uint64)(unsafe.Pointer(&value[0])), len(value)) - writeUint64ListPayload(buf, asUint64, typeID) + writeUint64ListValues(buf, asUint64, typeID) return } for _, v := range value { @@ -788,10 +852,10 @@ func 
writeUintListPayload(buf *ByteBuffer, value []uint, typeID TypeId) { } } -func readUintListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []uint { +func readUintListValues(buf *ByteBuffer, err *Error, length int, hasNull bool, typeID TypeId) []uint { result := make([]uint, length) if reflect.TypeOf(uint(0)).Size() == 8 { - values := readUint64ListPayload(buf, err, length, hasNull, typeID) + values := readUint64ListValues(buf, err, length, hasNull, typeID) copy(unsafe.Slice((*uint64)(unsafe.Pointer(&result[0])), length), values) return result } @@ -808,7 +872,7 @@ func readUintListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool, return result } -func writeFloat32ListPayload(buf *ByteBuffer, value []float32) { +func writeFloat32ListValues(buf *ByteBuffer, value []float32) { size := len(value) * 4 if len(value) > 0 { if isLittleEndian { @@ -821,13 +885,12 @@ func writeFloat32ListPayload(buf *ByteBuffer, value []float32) { } } -func readFloat32ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool) []float32 { +func readFloat32ListValues(buf *ByteBuffer, err *Error, length int, hasNull bool) []float32 { result := make([]float32, length) if !hasNull { size := length * 4 if isLittleEndian { - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size)) } else { for i := 0; i < length; i++ { result[i] = buf.ReadFloat32(err) @@ -843,7 +906,7 @@ func readFloat32ListPayload(buf *ByteBuffer, err *Error, length int, hasNull boo return result } -func writeFloat64ListPayload(buf *ByteBuffer, value []float64) { +func writeFloat64ListValues(buf *ByteBuffer, value []float64) { size := len(value) * 8 if len(value) > 0 { if isLittleEndian { @@ -856,13 +919,12 @@ func writeFloat64ListPayload(buf *ByteBuffer, value []float64) { } } -func readFloat64ListPayload(buf *ByteBuffer, err *Error, length int, hasNull bool) []float64 
{ +func readFloat64ListValues(buf *ByteBuffer, err *Error, length int, hasNull bool) []float64 { result := make([]float64, length) if !hasNull { size := length * 8 if isLittleEndian { - raw := buf.ReadBinary(size, err) - copy(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size), raw) + buf.Read(unsafe.Slice((*byte)(unsafe.Pointer(&result[0])), size)) } else { for i := 0; i < length; i++ { result[i] = buf.ReadFloat64(err) diff --git a/go/fory/slice_primitive_test.go b/go/fory/slice_primitive_test.go index f566eb366c..284a2b6126 100644 --- a/go/fory/slice_primitive_test.go +++ b/go/fory/slice_primitive_test.go @@ -19,6 +19,7 @@ package fory import ( "math" + "reflect" "testing" "github.com/apache/fory/go/fory/bfloat16" @@ -26,6 +27,14 @@ import ( "github.com/stretchr/testify/assert" ) +func TestPrimitiveListReadableOverflow(t *testing.T) { + var err Error + serializer := primitiveListSerializer{type_: reflect.TypeOf([]int64{}), elemTypeID: INT64} + length := int(^uint(0)>>1)/8 + 1 + assert.False(t, serializer.checkBodyReadable(NewByteBuffer(nil), &err, length, false)) + assert.True(t, err.HasError()) +} + func TestFloat16Slice(t *testing.T) { f := NewFory(WithXlang(false), WithCompatible(false)) @@ -467,13 +476,40 @@ func TestReadInt32Slice_OOM_Bug(t *testing.T) { assert.Equal(t, 0, len(result), "Expected an empty slice due to missing data") } +func TestReadFixedWidthSliceBytes(t *testing.T) { + t.Run("unaligned_size", func(t *testing.T) { + buf := NewByteBuffer(nil) + buf.WriteLength(3) + buf.WriteBinary([]byte{1, 2, 3}) + buf.SetReaderIndex(0) + + err := &Error{} + result := ReadInt16Slice(buf, err) + + assert.True(t, err.HasError()) + assert.Nil(t, result) + }) + + t.Run("missing_body", func(t *testing.T) { + buf := NewByteBuffer(nil) + buf.WriteLength(40000) + buf.SetReaderIndex(0) + + err := &Error{} + result := ReadFloat64Slice(buf, err) + + assert.True(t, err.HasError()) + assert.Nil(t, result) + }) +} + func TestReadBoolSliceWrappedBuffer(t *testing.T) { - 
payload := NewByteBuffer(nil) - WriteBoolSlice(payload, []bool{true, false}) + arrayBytes := NewByteBuffer(nil) + WriteBoolSlice(arrayBytes, []bool{true, false}) err := &Error{} - result := ReadBoolSlice(NewByteBuffer(payload.Bytes()), err) + result := ReadBoolSlice(NewByteBuffer(arrayBytes.Bytes()), err) - assert.False(t, err.HasError(), "Expected wrapped buffer reads to use the serialized payload") + assert.False(t, err.HasError(), "Expected wrapped buffer reads to use serialized array bytes") assert.Equal(t, []bool{true, false}, result) } diff --git a/go/fory/string.go b/go/fory/string.go index 9165aaf66c..fafab3288a 100644 --- a/go/fory/string.go +++ b/go/fory/string.go @@ -53,6 +53,10 @@ func readString(buf *ByteBuffer, err *Error) string { header := buf.ReadVaruint36Small(err) size := header >> 2 // Extract byte count encoding := header & 0b11 // Extract encoding type + if intSize == 32 && size > uint64(MaxInt) { + err.SetError(fmt.Errorf("string byte count %d exceeds supported int range", size)) + return "" + } switch encoding { case encodingLatin1: diff --git a/go/fory/tests/structs_fory_gen.go b/go/fory/tests/structs_fory_gen.go index bca11252ac..742135a8ba 100644 --- a/go/fory/tests/structs_fory_gen.go +++ b/go/fory/tests/structs_fory_gen.go @@ -1,6 +1,6 @@ // Code generated by forygen. DO NOT EDIT. 
// source: structs.go -// generated at: 2026-05-04T02:27:29+08:00 +// generated at: 2026-06-12T06:41:26+08:00 package fory @@ -186,11 +186,20 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v if isXlang { // xlang mode: slices are not nullable, read directly without null flag sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } if sliceLen == 0 { v.DynamicSlice = make([]any, 0) } else { // ReadData collection flags (ignore for now) _ = buf.ReadInt8(err) + if ctx.HasError() { + return ctx.TakeError() + } + if !buf.CheckReadable(sliceLen, err) { + return ctx.TakeError() + } // Create slice with proper capacity v.DynamicSlice = make([]any, sliceLen) // ReadData each element using ReadValue @@ -205,11 +214,20 @@ func (g *DynamicSliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v v.DynamicSlice = nil } else { sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } if sliceLen == 0 { v.DynamicSlice = make([]any, 0) } else { // ReadData collection flags (ignore for now) _ = buf.ReadInt8(err) + if ctx.HasError() { + return ctx.TakeError() + } + if !buf.CheckReadable(sliceLen, err) { + return ctx.TakeError() + } // Create slice with proper capacity v.DynamicSlice = make([]any, sliceLen) // ReadData each element using ReadValue @@ -647,12 +665,22 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if mapLen == 0 { v.IntMap = make(map[int]int) } else { + if !buf.CheckReadable(mapLen, err) { + return ctx.TakeError() + } v.IntMap = make(map[int]int, mapLen) mapSize := mapLen for mapSize > 0 { // ReadData KV header kvHeader := buf.ReadByte(err) chunkSize := int(buf.ReadByte(err)) + if ctx.HasError() { + return ctx.TakeError() + } + if chunkSize == 0 || chunkSize > mapSize { + ctx.SetError(fory.DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, mapSize)) + return ctx.TakeError() + } trackKeyRef := 
(kvHeader & 0x1) != 0 keyNotDeclared := (kvHeader & 0x4) != 0 trackValueRef := (kvHeader & 0x8) != 0 @@ -678,15 +706,28 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) v.IntMap = nil } else { mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } if mapLen == 0 { v.IntMap = make(map[int]int) } else { + if !buf.CheckReadable(mapLen, err) { + return ctx.TakeError() + } v.IntMap = make(map[int]int, mapLen) mapSize := mapLen for mapSize > 0 { // ReadData KV header kvHeader := buf.ReadByte(err) chunkSize := int(buf.ReadByte(err)) + if ctx.HasError() { + return ctx.TakeError() + } + if chunkSize == 0 || chunkSize > mapSize { + ctx.SetError(fory.DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, mapSize)) + return ctx.TakeError() + } trackKeyRef := (kvHeader & 0x1) != 0 keyNotDeclared := (kvHeader & 0x4) != 0 trackValueRef := (kvHeader & 0x8) != 0 @@ -717,12 +758,22 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if mapLen == 0 { v.MixedMap = make(map[string]int) } else { + if !buf.CheckReadable(mapLen, err) { + return ctx.TakeError() + } v.MixedMap = make(map[string]int, mapLen) mapSize := mapLen for mapSize > 0 { // ReadData KV header kvHeader := buf.ReadByte(err) chunkSize := int(buf.ReadByte(err)) + if ctx.HasError() { + return ctx.TakeError() + } + if chunkSize == 0 || chunkSize > mapSize { + ctx.SetError(fory.DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, mapSize)) + return ctx.TakeError() + } trackKeyRef := (kvHeader & 0x1) != 0 keyNotDeclared := (kvHeader & 0x4) != 0 trackValueRef := (kvHeader & 0x8) != 0 @@ -748,15 +799,28 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) v.MixedMap = nil } else { mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } if mapLen == 0 { v.MixedMap = make(map[string]int) } else { + if 
!buf.CheckReadable(mapLen, err) { + return ctx.TakeError() + } v.MixedMap = make(map[string]int, mapLen) mapSize := mapLen for mapSize > 0 { // ReadData KV header kvHeader := buf.ReadByte(err) chunkSize := int(buf.ReadByte(err)) + if ctx.HasError() { + return ctx.TakeError() + } + if chunkSize == 0 || chunkSize > mapSize { + ctx.SetError(fory.DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, mapSize)) + return ctx.TakeError() + } trackKeyRef := (kvHeader & 0x1) != 0 keyNotDeclared := (kvHeader & 0x4) != 0 trackValueRef := (kvHeader & 0x8) != 0 @@ -787,12 +851,22 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) if mapLen == 0 { v.StringMap = make(map[string]string) } else { + if !buf.CheckReadable(mapLen, err) { + return ctx.TakeError() + } v.StringMap = make(map[string]string, mapLen) mapSize := mapLen for mapSize > 0 { // ReadData KV header kvHeader := buf.ReadByte(err) chunkSize := int(buf.ReadByte(err)) + if ctx.HasError() { + return ctx.TakeError() + } + if chunkSize == 0 || chunkSize > mapSize { + ctx.SetError(fory.DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, mapSize)) + return ctx.TakeError() + } trackKeyRef := (kvHeader & 0x1) != 0 keyNotDeclared := (kvHeader & 0x4) != 0 trackValueRef := (kvHeader & 0x8) != 0 @@ -818,15 +892,28 @@ func (g *MapDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *MapDemo) v.StringMap = nil } else { mapLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } if mapLen == 0 { v.StringMap = make(map[string]string) } else { + if !buf.CheckReadable(mapLen, err) { + return ctx.TakeError() + } v.StringMap = make(map[string]string, mapLen) mapSize := mapLen for mapSize > 0 { // ReadData KV header kvHeader := buf.ReadByte(err) chunkSize := int(buf.ReadByte(err)) + if ctx.HasError() { + return ctx.TakeError() + } + if chunkSize == 0 || chunkSize > mapSize { + 
ctx.SetError(fory.DeserializationErrorf("invalid map chunk size %d for remaining length %d", chunkSize, mapSize)) + return ctx.TakeError() + } trackKeyRef := (kvHeader & 0x1) != 0 keyNotDeclared := (kvHeader & 0x4) != 0 trackValueRef := (kvHeader & 0x8) != 0 @@ -1167,8 +1254,14 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD v.BoolSlice = make([]bool, 0) } else { collectFlag := buf.ReadInt8(err) + if ctx.HasError() { + return ctx.TakeError() + } // Check if CollectionIsDeclElementType is set (bit 2, value 4) hasDeclType := (collectFlag & 4) != 0 + if !buf.CheckReadable(sliceLen, err) { + return ctx.TakeError() + } v.BoolSlice = make([]bool, sliceLen) if hasDeclType { // Elements are written directly without type IDs @@ -1193,12 +1286,21 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD v.BoolSlice = nil } else { sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } if sliceLen == 0 { v.BoolSlice = make([]bool, 0) } else { collectFlag := buf.ReadInt8(err) + if ctx.HasError() { + return ctx.TakeError() + } // Check if CollectionIsDeclElementType is set (bit 2, value 4) hasDeclType := (collectFlag & 4) != 0 + if !buf.CheckReadable(sliceLen, err) { + return ctx.TakeError() + } v.BoolSlice = make([]bool, sliceLen) if hasDeclType { // Elements are written directly without type IDs @@ -1229,8 +1331,14 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD v.FloatSlice = make([]float64, 0) } else { collectFlag := buf.ReadInt8(err) + if ctx.HasError() { + return ctx.TakeError() + } // Check if CollectionIsDeclElementType is set (bit 2, value 4) hasDeclType := (collectFlag & 4) != 0 + if !buf.CheckReadable(sliceLen, err) { + return ctx.TakeError() + } v.FloatSlice = make([]float64, sliceLen) if hasDeclType { // Elements are written directly without type IDs @@ -1255,12 +1363,21 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx 
*fory.ReadContext, v *SliceD v.FloatSlice = nil } else { sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } if sliceLen == 0 { v.FloatSlice = make([]float64, 0) } else { collectFlag := buf.ReadInt8(err) + if ctx.HasError() { + return ctx.TakeError() + } // Check if CollectionIsDeclElementType is set (bit 2, value 4) hasDeclType := (collectFlag & 4) != 0 + if !buf.CheckReadable(sliceLen, err) { + return ctx.TakeError() + } v.FloatSlice = make([]float64, sliceLen) if hasDeclType { // Elements are written directly without type IDs @@ -1291,8 +1408,14 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD v.IntSlice = make([]int32, 0) } else { collectFlag := buf.ReadInt8(err) + if ctx.HasError() { + return ctx.TakeError() + } // Check if CollectionIsDeclElementType is set (bit 2, value 4) hasDeclType := (collectFlag & 4) != 0 + if !buf.CheckReadable(sliceLen, err) { + return ctx.TakeError() + } v.IntSlice = make([]int32, sliceLen) if hasDeclType { // Elements are written directly without type IDs @@ -1317,12 +1440,21 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD v.IntSlice = nil } else { sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } if sliceLen == 0 { v.IntSlice = make([]int32, 0) } else { collectFlag := buf.ReadInt8(err) + if ctx.HasError() { + return ctx.TakeError() + } // Check if CollectionIsDeclElementType is set (bit 2, value 4) hasDeclType := (collectFlag & 4) != 0 + if !buf.CheckReadable(sliceLen, err) { + return ctx.TakeError() + } v.IntSlice = make([]int32, sliceLen) if hasDeclType { // Elements are written directly without type IDs @@ -1353,10 +1485,16 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD v.StringSlice = make([]string, 0) } else { collectFlag := buf.ReadInt8(err) + if ctx.HasError() { + return ctx.TakeError() + } // Check if CollectionIsDeclElementType is set (bit 
2, value 4) hasDeclType := (collectFlag & 4) != 0 // Check if CollectionTrackingRef is set (bit 0, value 1) trackRefs := (collectFlag & 1) != 0 + if !buf.CheckReadable(sliceLen, err) { + return ctx.TakeError() + } v.StringSlice = make([]string, sliceLen) if hasDeclType { // Elements are written directly without type IDs @@ -1387,14 +1525,23 @@ func (g *SliceDemo_ForyGenSerializer) ReadTyped(ctx *fory.ReadContext, v *SliceD v.StringSlice = nil } else { sliceLen := ctx.ReadCollectionLength() + if ctx.HasError() { + return ctx.TakeError() + } if sliceLen == 0 { v.StringSlice = make([]string, 0) } else { collectFlag := buf.ReadInt8(err) + if ctx.HasError() { + return ctx.TakeError() + } // Check if CollectionIsDeclElementType is set (bit 2, value 4) hasDeclType := (collectFlag & 4) != 0 // Check if CollectionTrackingRef is set (bit 0, value 1) trackRefs := (collectFlag & 1) != 0 + if !buf.CheckReadable(sliceLen, err) { + return ctx.TakeError() + } v.StringSlice = make([]string, sliceLen) if hasDeclType { // Elements are written directly without type IDs diff --git a/go/fory/type_def.go b/go/fory/type_def.go index 7520c42a6a..f654e7b8f4 100644 --- a/go/fory/type_def.go +++ b/go/fory/type_def.go @@ -1015,10 +1015,6 @@ func decodeTypeDef(fory *Fory, buffer *ByteBuffer, header int64) (*TypeDef, erro extraMetaSize = int(extra) metaSize += extraMetaSize } - if metaSize > fory.config.MaxBinarySize { - return nil, MaxBinarySizeExceededError(metaSize, fory.config.MaxBinarySize) - } - // Store the encoded bytes for the TypeDef (including meta header and metadata) encodedMeta := buffer.ReadBinary(metaSize, &bufErr) if bufErr.HasError() { diff --git a/go/fory/type_def_test.go b/go/fory/type_def_test.go index 71ccb7cf6a..a3f1bb0cc3 100644 --- a/go/fory/type_def_test.go +++ b/go/fory/type_def_test.go @@ -437,20 +437,10 @@ func TestTypeDefHeaderHashIncludesHeaderLowBits(t *testing.T) { require.Contains(t, err.Error(), "metadata hash") } -func 
TestTypeDefRejectsEncodedMetadataAboveMaxBinarySize(t *testing.T) { - fory := NewFory(WithXlang(false), WithMaxBinarySize(1), WithCompatible(false)) - body := typeDefTestBodyWithoutFields() - frame, header := typeDefTestFrame(t, body, false) - - _, err := decodeTypeDef(fory, frame, header) - require.Error(t, err) - require.Contains(t, err.Error(), "max binary size exceeded") -} - func TestTypeDefRejectsCompressedMetadata(t *testing.T) { decoded := typeDefTestBodyWithoutFields() compressed := deflateTypeDefTestBody(t, decoded) - fory := NewFory(WithXlang(false), WithMaxBinarySize(4096), WithCompatible(false)) + fory := NewFory(WithXlang(false), WithCompatible(false)) frame, header := typeDefTestFrame(t, compressed, true) _, err := decodeTypeDef(fory, frame, header) diff --git a/java/fory-core/src/main/java/org/apache/fory/config/Config.java b/java/fory-core/src/main/java/org/apache/fory/config/Config.java index 0b2a0b684f..8c9f18d821 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/Config.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/Config.java @@ -64,8 +64,6 @@ public class Config implements Serializable { private final boolean serializeEnumByName; private final int bufferSizeLimitBytes; private final int maxDepth; - private final int maxBinarySize; - private final int maxCollectionSize; private final float mapRefLoadFactor; private final boolean forVirtualThread; @@ -108,8 +106,6 @@ public Config(ForyBuilder builder) { serializeEnumByName = builder.serializeEnumByName; bufferSizeLimitBytes = builder.bufferSizeLimitBytes; maxDepth = builder.maxDepth; - maxBinarySize = builder.maxBinarySize; - maxCollectionSize = builder.maxCollectionSize; mapRefLoadFactor = builder.mapRefLoadFactor; forVirtualThread = builder.forVirtualThread; } @@ -296,16 +292,6 @@ public int maxDepth() { return maxDepth; } - /** Returns max binary payload size for attacker-controlled binary and primitive-array lengths. 
*/ - public int maxBinarySize() { - return maxBinarySize; - } - - /** Returns max collection allocation size for attacker-controlled collection lengths. */ - public int maxCollectionSize() { - return maxCollectionSize; - } - /** Returns loadFactor of MacRef's writtenObjects. */ public float mapRefLoadFactor() { return mapRefLoadFactor; @@ -340,8 +326,6 @@ public boolean equals(Object o) { && compressIntArray == config.compressIntArray && compressLongArray == config.compressLongArray && bufferSizeLimitBytes == config.bufferSizeLimitBytes - && maxBinarySize == config.maxBinarySize - && maxCollectionSize == config.maxCollectionSize && requireClassRegistration == config.requireClassRegistration && suppressClassRegistrationWarnings == config.suppressClassRegistrationWarnings && registerGuavaTypes == config.registerGuavaTypes @@ -380,8 +364,6 @@ public int hashCode() { compressIntArray, compressLongArray, bufferSizeLimitBytes, - maxBinarySize, - maxCollectionSize, requireClassRegistration, suppressClassRegistrationWarnings, registerGuavaTypes, diff --git a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java index 10f25d9d22..93bd88aabb 100644 --- a/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java +++ b/java/fory-core/src/main/java/org/apache/fory/config/ForyBuilder.java @@ -99,8 +99,6 @@ public final class ForyBuilder { Integer bufferSizeLimitBytes = -1; MetaCompressor metaCompressor = new DeflaterMetaCompressor(); int maxDepth = 50; - int maxBinarySize = 64 * 1024 * 1024; - int maxCollectionSize = 1_000_000; float mapRefLoadFactor = 0.51f; boolean forVirtualThread = false; TypeChecker typeChecker; @@ -513,30 +511,6 @@ public ForyBuilder withMaxDepth(int maxDepth) { return this; } - /** - * Set max binary payload size for deserialization. Binary and primitive-array byte lengths above - * this limit are rejected before allocation. Default max binary size is 64 MiB. 
- */ - public ForyBuilder withMaxBinarySize(int maxBinarySize) { - Preconditions.checkArgument( - maxBinarySize >= 0, "maxBinarySize must >= 0 but got %s", maxBinarySize); - this.maxBinarySize = maxBinarySize; - recordAction(b -> b.withMaxBinarySize(maxBinarySize)); - return this; - } - - /** - * Set max collection size for deserialization. Collection lengths and collection capacity fields - * above this limit are rejected before allocation. Default max collection size is 1,000,000. - */ - public ForyBuilder withMaxCollectionSize(int maxCollectionSize) { - Preconditions.checkArgument( - maxCollectionSize >= 0, "maxCollectionSize must >= 0 but got %s", maxCollectionSize); - this.maxCollectionSize = maxCollectionSize; - recordAction(b -> b.withMaxCollectionSize(maxCollectionSize)); - return this; - } - /** Set loadFactor of MapRefResolver writtenObjects. Default value is 0.51 */ public ForyBuilder withMapRefLoadFactor(float loadFactor) { Preconditions.checkArgument( diff --git a/java/fory-core/src/main/java/org/apache/fory/io/ForyInputStream.java b/java/fory-core/src/main/java/org/apache/fory/io/ForyInputStream.java index 16f1d6d2ef..3b2bcd0f26 100644 --- a/java/fory-core/src/main/java/org/apache/fory/io/ForyInputStream.java +++ b/java/fory-core/src/main/java/org/apache/fory/io/ForyInputStream.java @@ -50,37 +50,61 @@ public ForyInputStream(InputStream stream, int bufferSize) { @Override public int fillBuffer(int minFillSize) { MemoryBuffer buffer = this.buffer; - byte[] heapMemory = buffer.getHeapMemory(); - int offset = buffer.size(); - if (offset + minFillSize > heapMemory.length) { - heapMemory = growBuffer(minFillSize, buffer); + if (minFillSize < 0) { + throw new IndexOutOfBoundsException("Negative minimum fill size " + minFillSize); + } + if (minFillSize == 0) { + return 0; } + int totalRead = 0; + boolean checkedAvailable = false; try { - int read; - int len = heapMemory.length - offset; - read = stream.read(heapMemory, offset, len); - while (read < 
minFillSize) { - int newRead = stream.read(heapMemory, offset + read, len - read); - if (newRead < 0) { + while (totalRead < minFillSize) { + byte[] heapMemory = buffer.getHeapMemory(); + int offset = buffer.size(); + int remainingNeeded = minFillSize - totalRead; + long targetSize = (long) offset + remainingNeeded; + if (targetSize > Integer.MAX_VALUE - 8L) { + throw new IndexOutOfBoundsException("Stream buffer size exceeds supported range"); + } + if (targetSize > heapMemory.length) { + int newSize = 0; + if (!checkedAvailable) { + checkedAvailable = true; + // Use available() only as a one-shot growth hint. It may be expensive or + // conservative, so failed hints fall back to bounded doubling. Final value + // allocation still waits for fillBuffer to complete successfully. + if (stream.available() >= remainingNeeded) { + newSize = (int) targetSize; + } + } + if (newSize == 0 && offset == heapMemory.length) { + newSize = nextBufferSize(heapMemory.length, (int) targetSize); + } + if (newSize != 0) { + heapMemory = growBuffer(buffer, newSize); + } + } + int read = stream.read(heapMemory, offset, heapMemory.length - offset); + if (read <= 0) { throw new IndexOutOfBoundsException("No enough data in the stream " + stream); } - read += newRead; + if (read > 0) { + buffer.increaseSize(read); + totalRead += read; + } } - buffer.increaseSize(read); - return read; + return totalRead; } catch (IOException e) { throw new RuntimeException(e); } } - private static byte[] growBuffer(int minFillSize, MemoryBuffer buffer) { - int newSize; + private byte[] growBuffer(MemoryBuffer buffer, int newSize) { int size = buffer.size(); - int targetSize = size + minFillSize; - newSize = - targetSize < MemoryBuffer.BUFFER_GROW_STEP_THRESHOLD - ? 
targetSize << 2 - : (int) Math.min(targetSize * 1.5d, Integer.MAX_VALUE - 8); + if (newSize <= size) { + throw new IndexOutOfBoundsException("Stream buffer size exceeds supported range"); + } byte[] newBuffer = new byte[newSize]; byte[] heapMemory = buffer.getHeapMemory(); System.arraycopy(heapMemory, 0, newBuffer, 0, size); @@ -89,6 +113,15 @@ private static byte[] growBuffer(int minFillSize, MemoryBuffer buffer) { return heapMemory; } + private static int nextBufferSize(int size, int targetSize) { + long grown = size == 0 ? 1L : (long) size << 1; + int maxSize = Integer.MAX_VALUE - 8; + if (grown > maxSize) { + grown = maxSize; + } + return (int) Math.min(grown, targetSize); + } + @Override public void readTo(byte[] dst, int dstIndex, int len) { MemoryBuffer buf = buffer; diff --git a/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java b/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java index 00b27886eb..b4b9b771be 100644 --- a/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java +++ b/java/fory-core/src/main/java/org/apache/fory/io/ForyReadableChannel.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ReadableByteChannel; +import java.nio.channels.SeekableByteChannel; import javax.annotation.concurrent.NotThreadSafe; import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; @@ -31,19 +32,38 @@ @NotThreadSafe public class ForyReadableChannel implements ForyStreamReader, ReadableByteChannel { private final ReadableByteChannel channel; + private final SeekableByteChannel seekableChannel; private final MemoryBuffer memoryBuffer; private ByteBuffer byteBuffer; public ForyReadableChannel(ReadableByteChannel channel) { this( channel, - AndroidSupport.IS_ANDROID ? ByteBuffer.allocate(4096) : ByteBuffer.allocateDirect(4096)); + AndroidSupport.IS_ANDROID ? 
ByteBuffer.allocate(4096) : ByteBuffer.allocateDirect(4096), + null); + } + + public ForyReadableChannel(SeekableByteChannel channel) { + this( + channel, + AndroidSupport.IS_ANDROID ? ByteBuffer.allocate(4096) : ByteBuffer.allocateDirect(4096), + channel); } public ForyReadableChannel(ReadableByteChannel channel, ByteBuffer buffer) { + this(channel, buffer, null); + } + + public ForyReadableChannel(SeekableByteChannel channel, ByteBuffer buffer) { + this(channel, buffer, channel); + } + + private ForyReadableChannel( + ReadableByteChannel channel, ByteBuffer buffer, SeekableByteChannel seekableChannel) { Preconditions.checkArgument( !buffer.isReadOnly(), "ForyReadableChannel requires writable ByteBuffer."); this.channel = channel; + this.seekableChannel = seekableChannel; if (AndroidSupport.IS_ANDROID && buffer.isDirect()) { buffer = ByteBuffer.allocate(buffer.capacity()); } @@ -62,32 +82,83 @@ public ForyReadableChannel(ReadableByteChannel channel, ByteBuffer buffer) { @Override public int fillBuffer(int minFillSize) { + if (minFillSize < 0) { + throw new DeserializationException("Negative minimum fill size " + minFillSize); + } + if (minFillSize == 0) { + return 0; + } try { - ByteBuffer byteBuf = byteBuffer; - MemoryBuffer memoryBuf = memoryBuffer; - int position = byteBuf.position(); - int newLimit = position + minFillSize; - if (newLimit > byteBuf.capacity()) { - int newSize = - newLimit < MemoryBuffer.BUFFER_GROW_STEP_THRESHOLD - ? newLimit << 2 - : (int) Math.min(newLimit * 1.5d, Integer.MAX_VALUE); - ByteBuffer newByteBuf = - byteBuf.isDirect() ? 
ByteBuffer.allocateDirect(newSize) : ByteBuffer.allocate(newSize); - byteBuf.position(0); - newByteBuf.put(byteBuf); - byteBuf = byteBuffer = newByteBuf; - memoryBuf.initByteBuffer(byteBuf, position); + int totalRead = 0; + SeekableByteChannel seekableChannel = this.seekableChannel; + boolean checkedSeekableRemaining = seekableChannel == null; + while (totalRead < minFillSize) { + ByteBuffer byteBuf = byteBuffer; + MemoryBuffer memoryBuf = memoryBuffer; + int position = byteBuf.position(); + int remainingNeeded = minFillSize - totalRead; + long targetSize = (long) position + remainingNeeded; + if (targetSize > Integer.MAX_VALUE) { + throw new DeserializationException("Stream buffer size exceeds supported range"); + } + if (targetSize > byteBuf.capacity()) { + int newCapacity = 0; + if (!checkedSeekableRemaining) { + checkedSeekableRemaining = true; + // Query exact channel remaining bytes only as a one-shot fast path. Otherwise grow + // from bytes already buffered so truncated channels fail before reserving the body. 
+ if (seekableChannel.size() - seekableChannel.position() >= remainingNeeded) { + newCapacity = (int) targetSize; + } + } + if (newCapacity == 0 && position == byteBuf.capacity()) { + newCapacity = nextBufferSize(byteBuf.capacity(), (int) targetSize); + } + if (newCapacity != 0) { + byteBuf = growBuffer(byteBuf, memoryBuf, position, newCapacity); + } + } + byteBuf.limit(byteBuf.capacity()); + int read = channel.read(byteBuf); + if (read <= 0) { + throw new DeserializationException("Unexpected end of byte channel"); + } + totalRead += read; + memoryBuf.increaseSize(read); + byteBuf.limit(byteBuf.position()); } - byteBuf.limit(newLimit); - readFully(byteBuf, minFillSize); - memoryBuf.increaseSize(minFillSize); - return minFillSize; + return totalRead; } catch (IOException e) { throw new DeserializationException("Failed to read the provided byte channel", e); } } + private ByteBuffer growBuffer( + ByteBuffer byteBuf, MemoryBuffer memoryBuf, int position, int newCapacity) { + int oldCapacity = byteBuf.capacity(); + if (newCapacity <= oldCapacity) { + throw new DeserializationException("Stream buffer size exceeds supported range"); + } + ByteBuffer newByteBuf = + byteBuf.isDirect() + ? ByteBuffer.allocateDirect(newCapacity) + : ByteBuffer.allocate(newCapacity); + byteBuf.position(0); + byteBuf.limit(position); + newByteBuf.put(byteBuf); + byteBuffer = newByteBuf; + memoryBuf.initByteBuffer(newByteBuf, position); + return newByteBuf; + } + + private static int nextBufferSize(int oldCapacity, int targetSize) { + long grown = oldCapacity == 0 ? 
1L : (long) oldCapacity << 1; + if (grown > Integer.MAX_VALUE) { + grown = Integer.MAX_VALUE; + } + return (int) Math.min(grown, targetSize); + } + @Override public int read(ByteBuffer dst) throws IOException { int length = dst.remaining(); diff --git a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java index 0b273e3976..c88cf66e1e 100644 --- a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java +++ b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryBuffer.java @@ -3113,6 +3113,7 @@ private long skipPadding(long pos, int b) { } public byte[] readBytes(int length) { + checkReadableBytes(length); int readerIdx = readerIndex; byte[] bytes = new byte[length]; // use subtract to avoid overflow @@ -3254,7 +3255,7 @@ public int readBinarySize() { } int diff = size - readIdx; if (diff < binarySize) { - streamReader.fillBuffer(diff); + streamReader.fillBuffer(binarySize - diff); } return binarySize; } @@ -3277,7 +3278,7 @@ private int continueReadBinarySize(int readIdx, int bulkRead, int binarySize) { } int diff = size - readIdx; if (diff < binarySize) { - streamReader.fillBuffer(diff); + streamReader.fillBuffer(binarySize - diff); } return binarySize; } @@ -3302,13 +3303,12 @@ public byte[] readBytesAndSize() { } /** - * Reads a size-validated primitive byte-array payload into {@code values}. The caller owns size - * validation and destination allocation; this method reads payload bytes only, not the size - * prefix. + * Reads a size-validated primitive byte-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size prefix. 
*/ - public void readByteArrayPayload(byte[] values, int numBytes) { + public void readByteArrayBytes(byte[] values, int numBytes) { if (AndroidSupport.IS_ANDROID) { - MemoryOps.readByteArrayPayload(this, values, numBytes); + MemoryOps.readByteArrayBytes(this, values, numBytes); } else { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { @@ -3326,13 +3326,12 @@ public void readByteArrayPayload(byte[] values, int numBytes) { } /** - * Reads a size-validated primitive boolean-array payload into {@code values}. The caller owns - * size validation and destination allocation; this method reads payload bytes only, not the size - * prefix. + * Reads a size-validated primitive boolean-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size prefix. */ - public void readBooleanArrayPayload(boolean[] values, int numBytes) { + public void readBooleanArrayBytes(boolean[] values, int numBytes) { if (AndroidSupport.IS_ANDROID) { - MemoryOps.readBooleanArrayPayload(this, values, numBytes); + MemoryOps.readBooleanArrayBytes(this, values, numBytes); } else { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { @@ -3345,13 +3344,12 @@ public void readBooleanArrayPayload(boolean[] values, int numBytes) { } /** - * Reads a size-validated primitive char-array payload into {@code values}. The caller owns size - * validation and destination allocation; this method reads payload bytes only, not the size - * prefix. + * Reads a size-validated primitive char-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size prefix. 
*/ - public void readCharArrayPayload(char[] values, int numBytes) { + public void readCharArrayBytes(char[] values, int numBytes) { if (AndroidSupport.IS_ANDROID) { - MemoryOps.readCharArrayPayload(this, values, numBytes); + MemoryOps.readCharArrayBytes(this, values, numBytes); } else { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { @@ -3364,13 +3362,12 @@ public void readCharArrayPayload(char[] values, int numBytes) { } /** - * Reads a size-validated primitive int16-array payload into {@code values}. The caller owns size - * validation and destination allocation; this method reads payload bytes only, not the size - * prefix. + * Reads a size-validated primitive int16-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size prefix. */ - public void readInt16ArrayPayload(short[] values, int numBytes) { + public void readInt16ArrayBytes(short[] values, int numBytes) { if (AndroidSupport.IS_ANDROID) { - MemoryOps.readInt16ArrayPayload(this, values, numBytes); + MemoryOps.readInt16ArrayBytes(this, values, numBytes); } else { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { @@ -3383,13 +3380,12 @@ public void readInt16ArrayPayload(short[] values, int numBytes) { } /** - * Reads a size-validated primitive int32-array payload into {@code values}. The caller owns size - * validation and destination allocation; this method reads payload bytes only, not the size - * prefix. + * Reads a size-validated primitive int32-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size prefix. 
*/ - public void readInt32ArrayPayload(int[] values, int numBytes) { + public void readInt32ArrayBytes(int[] values, int numBytes) { if (AndroidSupport.IS_ANDROID) { - MemoryOps.readInt32ArrayPayload(this, values, numBytes); + MemoryOps.readInt32ArrayBytes(this, values, numBytes); } else { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { @@ -3402,13 +3398,12 @@ public void readInt32ArrayPayload(int[] values, int numBytes) { } /** - * Reads a size-validated primitive int64-array payload into {@code values}. The caller owns size - * validation and destination allocation; this method reads payload bytes only, not the size - * prefix. + * Reads a size-validated primitive int64-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size prefix. */ - public void readInt64ArrayPayload(long[] values, int numBytes) { + public void readInt64ArrayBytes(long[] values, int numBytes) { if (AndroidSupport.IS_ANDROID) { - MemoryOps.readInt64ArrayPayload(this, values, numBytes); + MemoryOps.readInt64ArrayBytes(this, values, numBytes); } else { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { @@ -3421,13 +3416,12 @@ public void readInt64ArrayPayload(long[] values, int numBytes) { } /** - * Reads a size-validated primitive float32-array payload into {@code values}. The caller owns - * size validation and destination allocation; this method reads payload bytes only, not the size - * prefix. + * Reads a size-validated primitive float32-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size prefix. 
*/ - public void readFloat32ArrayPayload(float[] values, int numBytes) { + public void readFloat32ArrayBytes(float[] values, int numBytes) { if (AndroidSupport.IS_ANDROID) { - MemoryOps.readFloat32ArrayPayload(this, values, numBytes); + MemoryOps.readFloat32ArrayBytes(this, values, numBytes); } else { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { @@ -3440,13 +3434,12 @@ public void readFloat32ArrayPayload(float[] values, int numBytes) { } /** - * Reads a size-validated primitive float64-array payload into {@code values}. The caller owns - * size validation and destination allocation; this method reads payload bytes only, not the size - * prefix. + * Reads a size-validated primitive float64-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size prefix. */ - public void readFloat64ArrayPayload(double[] values, int numBytes) { + public void readFloat64ArrayBytes(double[] values, int numBytes) { if (AndroidSupport.IS_ANDROID) { - MemoryOps.readFloat64ArrayPayload(this, values, numBytes); + MemoryOps.readFloat64ArrayBytes(this, values, numBytes); } else { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { @@ -3508,6 +3501,11 @@ public void readChars(char[] chars, int offset, int numElements) { @CodegenInvoke public char[] readCharsAndSize() { final int numBytes = readBinarySize(); + if ((numBytes & 1) != 0) { + throw new IllegalArgumentException( + "Char array byte size " + numBytes + " is not aligned to element size 2"); + } + checkReadableBytes(numBytes); int numElements = numBytes >> 1; char[] values = new char[numElements]; readChars(values, 0, numElements); @@ -3638,10 +3636,14 @@ public void checkReadableBytes(int minimumReadableBytes) { // use subtract to avoid overflow int remaining = size - readerIndex; if (minimumReadableBytes > remaining) { - streamReader.fillBuffer(minimumReadableBytes - remaining); + fillReadableBytes(minimumReadableBytes, 
remaining); } } + private void fillReadableBytes(int minimumReadableBytes, int remaining) { + streamReader.fillBuffer(minimumReadableBytes - remaining); + } + /** * Returns internal byte array if data is on heap and remaining buffer size is equal to internal * byte array size, or create a new byte array which copy remaining data from off-heap. diff --git a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryOps.java b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryOps.java index 02722024ac..77984a3450 100644 --- a/java/fory-core/src/main/java/org/apache/fory/memory/MemoryOps.java +++ b/java/fory-core/src/main/java/org/apache/fory/memory/MemoryOps.java @@ -924,12 +924,12 @@ static int readBinarySize(MemoryBuffer buffer) { int binarySize = readVarUInt32(buffer); int diff = buffer.size - buffer.readerIndex; if (diff < binarySize) { - buffer.streamReader.fillBuffer(diff); + buffer.streamReader.fillBuffer(binarySize - diff); } return binarySize; } - static void readByteArrayPayload(MemoryBuffer buffer, byte[] values, int numBytes) { + static void readByteArrayBytes(MemoryBuffer buffer, byte[] values, int numBytes) { if (buffer.readerIndex > buffer.size - numBytes) { buffer.streamReader.readTo(values, 0, numBytes); return; @@ -939,7 +939,7 @@ static void readByteArrayPayload(MemoryBuffer buffer, byte[] values, int numByte buffer.readerIndex = readerIdx + numBytes; } - static void readBooleanArrayPayload(MemoryBuffer buffer, boolean[] values, int numBytes) { + static void readBooleanArrayBytes(MemoryBuffer buffer, boolean[] values, int numBytes) { int readerIdx = buffer.readerIndex; if (readerIdx > buffer.size - numBytes) { buffer.streamReader.readBooleans(values, 0, numBytes); @@ -953,7 +953,7 @@ static void readBooleanArrayPayload(MemoryBuffer buffer, boolean[] values, int n buffer.readerIndex = readerIdx + numBytes; } - static void readCharArrayPayload(MemoryBuffer buffer, char[] values, int numBytes) { + static void readCharArrayBytes(MemoryBuffer 
buffer, char[] values, int numBytes) { int readerIdx = buffer.readerIndex; if (readerIdx > buffer.size - numBytes) { buffer.streamReader.readChars(values, 0, numBytes >>> 1); @@ -968,7 +968,7 @@ static void readCharArrayPayload(MemoryBuffer buffer, char[] values, int numByte buffer.readerIndex = readerIdx + numBytes; } - static void readInt16ArrayPayload(MemoryBuffer buffer, short[] values, int numBytes) { + static void readInt16ArrayBytes(MemoryBuffer buffer, short[] values, int numBytes) { int readerIdx = buffer.readerIndex; if (readerIdx > buffer.size - numBytes) { buffer.streamReader.readShorts(values, 0, numBytes >>> 1); @@ -983,7 +983,7 @@ static void readInt16ArrayPayload(MemoryBuffer buffer, short[] values, int numBy buffer.readerIndex = readerIdx + numBytes; } - static void readInt32ArrayPayload(MemoryBuffer buffer, int[] values, int numBytes) { + static void readInt32ArrayBytes(MemoryBuffer buffer, int[] values, int numBytes) { int readerIdx = buffer.readerIndex; if (readerIdx > buffer.size - numBytes) { buffer.streamReader.readInts(values, 0, numBytes >>> 2); @@ -1002,7 +1002,7 @@ static void readInt32ArrayPayload(MemoryBuffer buffer, int[] values, int numByte buffer.readerIndex = readerIdx + numBytes; } - static void readInt64ArrayPayload(MemoryBuffer buffer, long[] values, int numBytes) { + static void readInt64ArrayBytes(MemoryBuffer buffer, long[] values, int numBytes) { int readerIdx = buffer.readerIndex; if (readerIdx > buffer.size - numBytes) { buffer.streamReader.readLongs(values, 0, numBytes >>> 3); @@ -1025,7 +1025,7 @@ static void readInt64ArrayPayload(MemoryBuffer buffer, long[] values, int numByt buffer.readerIndex = readerIdx + numBytes; } - static void readFloat32ArrayPayload(MemoryBuffer buffer, float[] values, int numBytes) { + static void readFloat32ArrayBytes(MemoryBuffer buffer, float[] values, int numBytes) { int readerIdx = buffer.readerIndex; if (readerIdx > buffer.size - numBytes) { buffer.streamReader.readFloats(values, 0, 
numBytes >>> 2); @@ -1045,7 +1045,7 @@ static void readFloat32ArrayPayload(MemoryBuffer buffer, float[] values, int num buffer.readerIndex = readerIdx + numBytes; } - static void readFloat64ArrayPayload(MemoryBuffer buffer, double[] values, int numBytes) { + static void readFloat64ArrayBytes(MemoryBuffer buffer, double[] values, int numBytes) { int readerIdx = buffer.readerIndex; if (readerIdx > buffer.size - numBytes) { buffer.streamReader.readDoubles(values, 0, numBytes >>> 3); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java index 24a6a251a9..9fe08fdfb5 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ArraySerializers.java @@ -47,17 +47,8 @@ public final class ArraySerializers { private ArraySerializers() {} - private static void throwObjectArraySizeLimitExceeded(int size, int maxCollectionSize) { - throw new DeserializationException( - "Object array size " + size + " exceeds max collection size " + maxCollectionSize); - } - - private static void throwInvalidObjectArraySize(int size, int maxCollectionSize) { - if (size < 0) { - throw new DeserializationException("Object array size must be non-negative: " + size); - } else { - throwObjectArraySizeLimitExceeded(size, maxCollectionSize); - } + private static void throwInvalidObjectArraySize(int size) { + throw new DeserializationException("Object array size must be non-negative: " + size); } /** @@ -99,7 +90,6 @@ public static Serializer newObjectArraySerializer(TypeResolver typeResolver, public static final class ObjectArraySerializer extends Serializer { private final TypeResolver typeResolver; private final TypeInfoHolder elementTypeInfoHolder; - private final int maxCollectionSize; public ObjectArraySerializer(TypeResolver typeResolver, Class cls) { super(typeResolver.getConfig(), (Class) 
cls); @@ -109,7 +99,6 @@ public ObjectArraySerializer(TypeResolver typeResolver, Class cls) { } Preconditions.checkArgument(cls.isArray() && !cls.getComponentType().isPrimitive()); elementTypeInfoHolder = typeResolver.nilTypeInfoHolder(); - maxCollectionSize = typeResolver.getConfig().maxCollectionSize(); } @Override @@ -143,9 +132,10 @@ public Object[] read(ReadContext readContext) { int numElements = buffer.readVarUInt32Small7(); // Keep this as direct primitive branches. Object-array reads allocate immediately; using // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. - if (numElements < 0 || numElements > maxCollectionSize) { - throwInvalidObjectArraySize(numElements, maxCollectionSize); + if (numElements < 0) { + throwInvalidObjectArraySize(numElements); } + buffer.checkReadableBytes(numElements); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { @@ -179,7 +169,6 @@ private abstract static class SameTypeObjectArraySerializer extends Serializer componentType; private final Serializer elementSerializer; private final TypeInfoHolder elementTypeInfoHolder; - private final int maxCollectionSize; SameTypeObjectArraySerializer( TypeResolver typeResolver, Class arrayType, Class componentType) { @@ -191,7 +180,6 @@ private abstract static class SameTypeObjectArraySerializer extends Serializer maxCollectionSize) { - throwInvalidObjectArraySize(numElements, maxCollectionSize); + if (numElements < 0) { + throwInvalidObjectArraySize(numElements); } + buffer.checkReadableBytes(numElements); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { @@ -644,7 +633,6 @@ public static final class UnknownArraySerializer extends Serializer { private final String className; private final TypeResolver typeResolver; private final TypeInfoHolder elementTypeInfoHolder; - private final int maxCollectionSize; public UnknownArraySerializer(TypeResolver typeResolver, Class 
cls) { this(typeResolver, "Unknown", cls); @@ -656,7 +644,6 @@ public UnknownArraySerializer(TypeResolver typeResolver, String className, Class this.className = className; this.typeResolver = typeResolver; elementTypeInfoHolder = typeResolver.nilTypeInfoHolder(); - maxCollectionSize = typeResolver.getConfig().maxCollectionSize(); } @Override @@ -671,9 +658,10 @@ public Object[] read(ReadContext readContext) { int numElements = buffer.readVarUInt32Small7(); // Keep this as direct primitive branches. Object-array reads allocate immediately; using // Preconditions.checkArgument here would add helper/varargs overhead on the valid path. - if (numElements < 0 || numElements > maxCollectionSize) { - throwInvalidObjectArraySize(numElements, maxCollectionSize); + if (numElements < 0) { + throwInvalidObjectArraySize(numElements); } + buffer.checkReadableBytes(numElements); Object[] value = newArray(numElements); readContext.reference(value); if (numElements != 0) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/BigIntegerSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/BigIntegerSerializer.java index 6378e5e72f..650746d288 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/BigIntegerSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/BigIntegerSerializer.java @@ -30,12 +30,10 @@ public final class BigIntegerSerializer extends ImmutableSerializer implements Shareable { private final boolean xlang; - private final int maxBinarySize; public BigIntegerSerializer(Config config) { super(config, BigInteger.class); xlang = config.isXlang(); - maxBinarySize = config.maxBinarySize(); } @Override @@ -65,7 +63,7 @@ private void writeNative(WriteContext writeContext, BigInteger value) { private BigInteger readNative(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int len = buffer.readVarUInt32Small7(); - checkBinaryPayloadLength(len, maxBinarySize); + 
checkBinaryBodyLength(len); buffer.checkReadableBytes(len); byte[] bytes = buffer.readBytes(len); return new BigInteger(bytes); @@ -76,16 +74,12 @@ private void writeXlang(WriteContext writeContext, BigInteger value) { } private BigInteger readXlang(ReadContext readContext) { - return DecimalSerializer.readXlangBigInteger(readContext.getBuffer(), maxBinarySize); + return DecimalSerializer.readXlangBigInteger(readContext.getBuffer()); } - private static void checkBinaryPayloadLength(int len, int maxBinarySize) { + private static void checkBinaryBodyLength(int len) { if (len <= 0) { - throw new DeserializationException("BigInteger payload length must be positive: " + len); - } - if (len > maxBinarySize) { - throw new DeserializationException( - "BigInteger payload length " + len + " exceeds max binary size " + maxBinarySize); + throw new DeserializationException("BigInteger body length must be positive: " + len); } } } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java index 7f4a8856e9..35eeca550a 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/CompatibleCollectionArrayReader.java @@ -36,7 +36,6 @@ import org.apache.fory.collection.UInt32List; import org.apache.fory.collection.UInt64List; import org.apache.fory.collection.UInt8List; -import org.apache.fory.config.Config; import org.apache.fory.context.ReadContext; import org.apache.fory.context.RefReader; import org.apache.fory.exception.DeserializationException; @@ -101,8 +100,8 @@ static ReadAction readAction(TypeResolver resolver, Descriptor descriptor) { int untrackedPeerListElementTypeId = untrackedListElementTypeId(descriptor); int localListElementTypeId = untrackedListElementTypeId(localFieldType); int peerArrayTypeId = 
denseArrayTypeId(peerListElementTypeId); - // Actual null or ref-tracked payload elements are rejected by - // readListPayloadAsPrimitiveArray. + // Actual null or ref-tracked body elements are rejected by + // readListBodyAsPrimitiveArray. if (untrackedPeerListElementTypeId != Types.UNKNOWN && localListElementTypeId != Types.UNKNOWN && peerArrayTypeId != Types.UNKNOWN @@ -155,8 +154,8 @@ static ReadAction readAction( int untrackedPeerListElementTypeId = untrackedListElementTypeId(remoteFieldType); int localListElementTypeId = listElementTypeId(localType); int peerArrayTypeId = denseArrayTypeId(peerListElementTypeId); - // Actual null or ref-tracked payload elements are rejected by - // readListPayloadAsPrimitiveArray. + // Actual null or ref-tracked body elements are rejected by + // readListBodyAsPrimitiveArray. if (untrackedPeerListElementTypeId != Types.UNKNOWN && localListElementTypeId != Types.UNKNOWN && peerArrayTypeId != Types.UNKNOWN @@ -340,21 +339,21 @@ private static Object readNotNull( int elementTypeId, Class targetType) { if (readMode == READ_LIST_TO_ARRAY) { - Object array = readListPayloadAsPrimitiveArray(readContext, arrayTypeId, elementTypeId); + Object array = readListBodyAsPrimitiveArray(readContext, arrayTypeId, elementTypeId); if (array == null) { return null; } return materializeTarget(array, arrayTypeId, targetType); } if (readMode == READ_LIST_TO_LIST) { - return readListPayloadAsListTarget(readContext, arrayTypeId, elementTypeId, targetType); + return readListBodyAsListTarget(readContext, arrayTypeId, elementTypeId, targetType); } if (readMode == READ_ARRAY_TO_LIST) { - Object array = readDenseArrayPayload(readContext, arrayTypeId); + Object array = readDenseArrayBody(readContext, arrayTypeId); return materializeTarget(array, arrayTypeId, targetType); } if (readMode == READ_ARRAY_TO_ARRAY) { - Object array = readDenseArrayPayload(readContext, arrayTypeId); + Object array = readDenseArrayBody(readContext, arrayTypeId); return 
materializeTarget(array, arrayTypeId, targetType); } throw new IllegalStateException("Unexpected compatible read mode " + readMode); @@ -373,7 +372,7 @@ private static int listElementTypeId(FieldTypes.FieldType fieldType, boolean req ((FieldTypes.CollectionFieldType) fieldType).getElementType(); if (elementType instanceof FieldTypes.RegisteredFieldType) { // Nullable element schema is allowed for list -> array compatibility; - // actual null payload elements are rejected by the dense-array reader. + // actual null body elements are rejected by the dense-array reader. if (requireUntracked && elementType.trackingRef()) { return Types.UNKNOWN; } @@ -398,13 +397,13 @@ private static int listElementTypeId(Descriptor descriptor, boolean requireUntra int typeId = extMeta.typeId(); if (Types.isPrimitiveArray(typeId)) { // A compatible descriptor can keep the local primitive-list carrier while the remote - // TypeDef says the peer wrote a dense array payload. Treat the TypeExtMeta as the remote + // TypeDef says the peer wrote a dense array body. Treat the TypeExtMeta as the remote // wire shape here; otherwise array->list reads are misclassified as list->list reads. return Types.UNKNOWN; } if (Types.isPrimitiveType(typeId) && (!requireUntracked || !extMeta.trackingRef())) { // Nullable element metadata is not a schema-pair rejection. The - // dense-array read path fails only when the payload contains nulls. + // dense-array read path fails only when the body contains nulls. return typeId; } } @@ -443,13 +442,13 @@ private static int listElementTypeId(TypeRef typeRef, boolean requireUntracke int typeId = extMeta.typeId(); if (Types.isPrimitiveArray(typeId)) { // A compatible descriptor can keep the local primitive-list raw carrier while the remote - // TypeDef says the peer wrote a dense array payload. Treat the TypeExtMeta as the remote + // TypeDef says the peer wrote a dense array body. 
Treat the TypeExtMeta as the remote // wire shape here; otherwise array->list reads are misclassified as list->list reads. return Types.UNKNOWN; } if (Types.isPrimitiveType(typeId) && (!requireUntracked || !extMeta.trackingRef())) { // Nullable element metadata is not a schema-pair rejection. The - // dense-array read path fails only when the payload contains nulls. + // dense-array read path fails only when the body contains nulls. return typeId; } } @@ -477,7 +476,7 @@ private static int untrackedListElementTypeId(TypeRef typeRef) { } private static boolean isPrimitiveElement(TypeExtMeta elementExtMeta, boolean requireUntracked) { - // Nullable element metadata is allowed; actual null payload elements fail while reading. + // Nullable element metadata is allowed; actual null body elements fail while reading. return elementExtMeta != null && Types.isPrimitiveType(elementExtMeta.typeId()) && (!requireUntracked || !elementExtMeta.trackingRef()); @@ -580,12 +579,11 @@ private static int denseArrayTypeId(int elementTypeId) { } } - private static Object readListPayloadAsPrimitiveArray( + private static Object readListBodyAsPrimitiveArray( ReadContext readContext, int arrayTypeId, int elementTypeId) { MemoryBuffer buffer = readContext.getBuffer(); int numElements = buffer.readVarUInt32Small7(); - validateElementCount(readContext.getConfig(), numElements); - validateElementStorageSize(readContext.getConfig(), numElements, elementSize(arrayTypeId)); + validateElementCount(numElements); if (numElements > 0) { int flags = buffer.readByte(); boolean hasNull = (flags & CollectionFlags.HAS_NULL) == CollectionFlags.HAS_NULL; @@ -595,18 +593,18 @@ private static Object readListPayloadAsPrimitiveArray( (flags & CollectionFlags.IS_DECL_ELEMENT_TYPE) == CollectionFlags.IS_DECL_ELEMENT_TYPE; if (trackingRef) { throw new DeserializationException( - "Cannot read ref-tracked peer list payload into local array field"); + "Cannot read ref-tracked peer list body into local array field"); 
} if (!sameType) { throw new DeserializationException( - "Cannot read peer list payload into local array field"); + "Cannot read peer list body into local array field"); } if (!declared) { - TypeInfo payloadElementTypeInfo = readContext.getTypeResolver().readTypeInfo(readContext); - if (payloadElementTypeInfo.getTypeId() != elementTypeId) { + TypeInfo bodyElementTypeInfo = readContext.getTypeResolver().readTypeInfo(readContext); + if (bodyElementTypeInfo.getTypeId() != elementTypeId) { throw new DeserializationException( "Cannot read peer list element type id " - + payloadElementTypeInfo.getTypeId() + + bodyElementTypeInfo.getTypeId() + " as local element type id " + elementTypeId); } @@ -616,12 +614,11 @@ private static Object readListPayloadAsPrimitiveArray( return readListPrimitiveElements(buffer, numElements, arrayTypeId, elementTypeId, false); } - private static Object readListPayloadAsListTarget( + private static Object readListBodyAsListTarget( ReadContext readContext, int arrayTypeId, int elementTypeId, Class targetType) { MemoryBuffer buffer = readContext.getBuffer(); int numElements = buffer.readVarUInt32Small7(); - validateElementCount(readContext.getConfig(), numElements); - validateElementStorageSize(readContext.getConfig(), numElements, elementSize(arrayTypeId)); + validateElementCount(numElements); if (numElements == 0) { Object array = readListPrimitiveElements(buffer, 0, arrayTypeId, elementTypeId, false); return materializeTarget(array, arrayTypeId, targetType); @@ -634,18 +631,17 @@ private static Object readListPayloadAsListTarget( (flags & CollectionFlags.IS_DECL_ELEMENT_TYPE) == CollectionFlags.IS_DECL_ELEMENT_TYPE; if (trackingRef) { throw new DeserializationException( - "Cannot read ref-tracked peer list payload into local list field"); + "Cannot read ref-tracked peer list body into local list field"); } if (!sameType) { - throw new DeserializationException( - "Cannot read peer list payload into local list field"); + throw new 
DeserializationException("Cannot read peer list body into local list field"); } if (!declared) { - TypeInfo payloadElementTypeInfo = readContext.getTypeResolver().readTypeInfo(readContext); - if (payloadElementTypeInfo.getTypeId() != elementTypeId) { + TypeInfo bodyElementTypeInfo = readContext.getTypeResolver().readTypeInfo(readContext); + if (bodyElementTypeInfo.getTypeId() != elementTypeId) { throw new DeserializationException( "Cannot read peer list element type id " - + payloadElementTypeInfo.getTypeId() + + bodyElementTypeInfo.getTypeId() + " as local element type id " + elementTypeId); } @@ -653,7 +649,7 @@ private static Object readListPayloadAsListTarget( if (hasNull) { // Nullable LIST element metadata is not a schema-pair rejection. Only boxed list targets can // preserve actual null elements; dense primitive array/list targets fail while reading the - // nullable payload because they cannot represent null elements. + // nullable body because they cannot represent null elements. 
if (!targetType.isAssignableFrom(ArrayList.class)) { throw new DeserializationException( "Cannot read null peer list element into local list field"); @@ -665,11 +661,12 @@ private static Object readListPayloadAsListTarget( return materializeTarget(array, arrayTypeId, targetType); } - private static Object readDenseArrayPayload(ReadContext readContext, int arrayTypeId) { + private static Object readDenseArrayBody(ReadContext readContext, int arrayTypeId) { MemoryBuffer buffer = readContext.getBuffer(); int byteSize = buffer.readVarUInt32Small7(); int elemSize = elementSize(arrayTypeId); - validateBinarySize(readContext.getConfig(), buffer, byteSize, elemSize); + validateBinarySize(byteSize, elemSize); + buffer.checkReadableBytes(byteSize); return readPrimitiveElements(buffer, byteSize, byteSize / elemSize, arrayTypeId); } @@ -679,14 +676,14 @@ private static Object readPrimitiveElements( case Types.BOOL_ARRAY: { boolean[] values = new boolean[numElements]; - buffer.readBooleanArrayPayload(values, byteSize); + buffer.readBooleanArrayBytes(values, byteSize); return values; } case Types.INT8_ARRAY: case Types.UINT8_ARRAY: { byte[] values = new byte[numElements]; - buffer.readByteArrayPayload(values, byteSize); + buffer.readByteArrayBytes(values, byteSize); return values; } case Types.INT16_ARRAY: @@ -696,7 +693,7 @@ private static Object readPrimitiveElements( { short[] values = new short[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt16ArrayPayload(values, byteSize); + buffer.readInt16ArrayBytes(values, byteSize); } else { for (int i = 0; i < numElements; i++) { values[i] = buffer.readInt16(); @@ -709,7 +706,7 @@ private static Object readPrimitiveElements( { int[] values = new int[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt32ArrayPayload(values, byteSize); + buffer.readInt32ArrayBytes(values, byteSize); } else { for (int i = 0; i < numElements; i++) { values[i] = buffer.readInt32(); @@ -722,7 +719,7 @@ private static 
Object readPrimitiveElements( { long[] values = new long[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt64ArrayPayload(values, byteSize); + buffer.readInt64ArrayBytes(values, byteSize); } else { for (int i = 0; i < numElements; i++) { values[i] = buffer.readInt64(); @@ -734,7 +731,7 @@ private static Object readPrimitiveElements( { float[] values = new float[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readFloat32ArrayPayload(values, byteSize); + buffer.readFloat32ArrayBytes(values, byteSize); } else { for (int i = 0; i < numElements; i++) { values[i] = buffer.readFloat32(); @@ -746,7 +743,7 @@ private static Object readPrimitiveElements( { double[] values = new double[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readFloat64ArrayPayload(values, byteSize); + buffer.readFloat64ArrayBytes(values, byteSize); } else { for (int i = 0; i < numElements; i++) { values[i] = buffer.readFloat64(); @@ -761,6 +758,7 @@ private static Object readPrimitiveElements( private static Object readListPrimitiveElements( MemoryBuffer buffer, int numElements, int arrayTypeId, int elementTypeId, boolean hasNull) { + buffer.checkReadableBytes(minReadablePrimitiveListBytes(numElements, elementTypeId, hasNull)); switch (elementTypeId) { case Types.BOOL: { @@ -922,6 +920,49 @@ private static Object readListPrimitiveElements( } } + private static int minReadablePrimitiveListBytes( + int numElements, int elementTypeId, boolean hasNull) { + int valueBytes; + switch (elementTypeId) { + case Types.BOOL: + case Types.INT8: + case Types.UINT8: + case Types.VARINT32: + case Types.VAR_UINT32: + case Types.VARINT64: + case Types.TAGGED_INT64: + case Types.VAR_UINT64: + case Types.TAGGED_UINT64: + valueBytes = 1; + break; + case Types.INT16: + case Types.UINT16: + case Types.FLOAT16: + case Types.BFLOAT16: + valueBytes = 2; + break; + case Types.INT32: + case Types.UINT32: + case Types.FLOAT32: + valueBytes = 4; + break; + case Types.INT64: + case 
Types.UINT64: + case Types.FLOAT64: + valueBytes = 8; + break; + default: + throw new IllegalArgumentException( + "Unsupported primitive element type id " + elementTypeId); + } + int bytesPerElement = hasNull ? valueBytes + 1 : valueBytes; + long byteSize = (long) numElements * bytesPerElement; + if (byteSize > Integer.MAX_VALUE) { + throw new DeserializationException("Primitive list body size exceeds int range"); + } + return (int) byteSize; + } + private static void readNonNullListElement(MemoryBuffer buffer) { byte headFlag = buffer.readByte(); if (headFlag == Fory.NULL_FLAG) { @@ -936,6 +977,7 @@ private static void readNonNullListElement(MemoryBuffer buffer) { private static List readNullableListBoxedElements( MemoryBuffer buffer, int numElements, int arrayTypeId, int elementTypeId) { + buffer.checkReadableBytes(minReadablePrimitiveListBytes(numElements, elementTypeId, true)); ArrayList values = new ArrayList<>(numElements); for (int i = 0; i < numElements; i++) { byte headFlag = buffer.readByte(); @@ -1225,41 +1267,19 @@ private static int elementSize(int arrayTypeId) { } } - private static void validateElementCount(Config config, int numElements) { + private static void validateElementCount(int numElements) { if (numElements < 0) { throw new DeserializationException("Collection size must be non-negative: " + numElements); } - if (numElements > config.maxCollectionSize()) { - throw new DeserializationException( - "Collection size " - + numElements - + " exceeds max collection size " - + config.maxCollectionSize()); - } } - private static void validateElementStorageSize(Config config, int numElements, int elemSize) { - if (numElements > config.maxBinarySize() / elemSize) { - throw new DeserializationException( - "Binary payload size " - + ((long) numElements * elemSize) - + " exceeds max binary size " - + config.maxBinarySize()); - } - } - - private static void validateBinarySize( - Config config, MemoryBuffer buffer, int byteSize, int elemSize) { + private 
static void validateBinarySize(int byteSize, int elemSize) { if (byteSize < 0) { - throw new DeserializationException("Binary payload size must be non-negative: " + byteSize); - } - if (byteSize > config.maxBinarySize()) { - throw new DeserializationException( - "Binary payload size " + byteSize + " exceeds max binary size " + config.maxBinarySize()); + throw new DeserializationException("Binary body size must be non-negative: " + byteSize); } if (byteSize % elemSize != 0) { throw new DeserializationException( - "Binary payload size " + byteSize + " is not aligned to element size " + elemSize); + "Binary body size " + byteSize + " is not aligned to element size " + elemSize); } } } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/DecimalSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/DecimalSerializer.java index 05043113f4..b6ebadbbf5 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/DecimalSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/DecimalSerializer.java @@ -33,12 +33,10 @@ public final class DecimalSerializer extends ImmutableSerializer imp private static final BigInteger LONG_MIN = BigInteger.valueOf(Long.MIN_VALUE); private static final BigInteger LONG_MAX = BigInteger.valueOf(Long.MAX_VALUE); private final boolean xlang; - private final int maxBinarySize; public DecimalSerializer(Config config) { super(config, BigDecimal.class); xlang = config.isXlang(); - maxBinarySize = config.maxBinarySize(); } @Override @@ -72,7 +70,7 @@ private BigDecimal readNative(ReadContext readContext) { int scale = buffer.readVarUInt32Small7(); int precision = buffer.readVarUInt32Small7(); int len = buffer.readVarUInt32Small7(); - checkBinaryPayloadLength(len, maxBinarySize); + checkBinaryBodyLength(len); buffer.checkReadableBytes(len); byte[] bytes = buffer.readBytes(len); BigInteger bigInteger = new BigInteger(bytes); @@ -84,7 +82,7 @@ private void writeXlang(WriteContext 
writeContext, BigDecimal value) { } private BigDecimal readXlang(ReadContext readContext) { - return readXlangDecimal(readContext.getBuffer(), maxBinarySize); + return readXlangDecimal(readContext.getBuffer()); } static void writeXlangDecimal(MemoryBuffer buffer, int scale, BigInteger unscaled) { @@ -97,29 +95,21 @@ static void writeXlangDecimal(MemoryBuffer buffer, int scale, BigInteger unscale } int sign = unscaled.signum() < 0 ? 1 : 0; - byte[] payload = toCanonicalLittleEndianMagnitude(unscaled.abs()); - long meta = (((long) payload.length) << 1) | sign; + byte[] magnitudeBytes = toCanonicalLittleEndianMagnitude(unscaled.abs()); + long meta = (((long) magnitudeBytes.length) << 1) | sign; long header = (meta << 1) | 1L; buffer.writeVarUInt64(header); - buffer.writeBytes(payload); + buffer.writeBytes(magnitudeBytes); } static BigDecimal readXlangDecimal(MemoryBuffer buffer) { - return readXlangDecimal(buffer, Integer.MAX_VALUE); - } - - static BigDecimal readXlangDecimal(MemoryBuffer buffer, int maxBinarySize) { int scale = buffer.readVarInt32(); - return new BigDecimal(readXlangUnscaled(buffer, maxBinarySize), scale); + return new BigDecimal(readXlangUnscaled(buffer), scale); } static BigInteger readXlangBigInteger(MemoryBuffer buffer) { - return readXlangBigInteger(buffer, Integer.MAX_VALUE); - } - - static BigInteger readXlangBigInteger(MemoryBuffer buffer, int maxBinarySize) { int scale = buffer.readVarInt32(); - BigInteger unscaled = readXlangUnscaled(buffer, maxBinarySize); + BigInteger unscaled = readXlangUnscaled(buffer); if (scale != 0) { throw new IllegalArgumentException( "Cannot deserialize xlang decimal with scale " + scale + " into BigInteger"); @@ -127,7 +117,7 @@ static BigInteger readXlangBigInteger(MemoryBuffer buffer, int maxBinarySize) { return unscaled; } - private static BigInteger readXlangUnscaled(MemoryBuffer buffer, int maxBinarySize) { + private static BigInteger readXlangUnscaled(MemoryBuffer buffer) { long header = 
buffer.readVarUInt64(); if ((header & 1L) == 0L) { return BigInteger.valueOf(decodeZigZag64(header >>> 1)); @@ -137,19 +127,15 @@ private static BigInteger readXlangUnscaled(MemoryBuffer buffer, int maxBinarySi long lenLong = meta >>> 1; if (lenLong <= 0 || lenLong > Integer.MAX_VALUE) { throw new IllegalArgumentException( - "Invalid decimal magnitude length " + lenLong + " in xlang payload"); - } - if (lenLong > maxBinarySize) { - throw new DeserializationException( - "Decimal magnitude length " + lenLong + " exceeds max binary size " + maxBinarySize); + "Invalid decimal magnitude length " + lenLong + " in xlang body"); } int len = (int) lenLong; buffer.checkReadableBytes(len); - byte[] payload = buffer.readBytes(len); - if (payload[len - 1] == 0) { - throw new IllegalArgumentException("Non-canonical decimal payload: trailing zero byte"); + byte[] magnitudeBytes = buffer.readBytes(len); + if (magnitudeBytes[len - 1] == 0) { + throw new IllegalArgumentException("Non-canonical decimal body: trailing zero byte"); } - byte[] magnitude = toBigEndian(payload); + byte[] magnitude = toBigEndian(magnitudeBytes); BigInteger abs = new BigInteger(1, magnitude); if (abs.signum() == 0) { throw new IllegalArgumentException("Big decimal encoding must not represent zero"); @@ -157,13 +143,9 @@ private static BigInteger readXlangUnscaled(MemoryBuffer buffer, int maxBinarySi return sign == 0 ? 
abs : abs.negate(); } - private static void checkBinaryPayloadLength(int len, int maxBinarySize) { + private static void checkBinaryBodyLength(int len) { if (len <= 0) { - throw new DeserializationException("Decimal payload length must be positive: " + len); - } - if (len > maxBinarySize) { - throw new DeserializationException( - "Decimal payload length " + len + " exceeds max binary size " + maxBinarySize); + throw new DeserializationException("Decimal body length must be positive: " + len); } } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java index e5689087ff..b3b54e18be 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/ExceptionSerializers.java @@ -500,14 +500,13 @@ private static void readAndSkipLayerClassMeta(ReadContext readContext) { private static List readSuppressedExceptions(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int numSuppressedExceptions = buffer.readVarUInt32(); - int maxCollectionSize = readContext.getConfig().maxCollectionSize(); - if (numSuppressedExceptions < 0 || numSuppressedExceptions > maxCollectionSize) { + if (numSuppressedExceptions < 0) { throw new ForyException( "Throwable suppressed exception count " + numSuppressedExceptions - + " exceeds max collection size " - + maxCollectionSize); + + " must be non-negative"); } + buffer.checkReadableBytes(numSuppressedExceptions); List suppressedExceptions = new ArrayList<>(numSuppressedExceptions); for (int i = 0; i < numSuppressedExceptions; i++) { suppressedExceptions.add((Throwable) readContext.readRef()); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveArraySerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveArraySerializers.java index d9d725f96a..552a6f0355 100644 
--- a/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveArraySerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/PrimitiveArraySerializers.java @@ -105,47 +105,24 @@ public MemoryBuffer toBuffer() { public abstract static class PrimitiveArraySerializer extends Serializer implements Shareable { protected final Config config; - protected final int maxBinarySize; public PrimitiveArraySerializer(TypeResolver typeResolver, Class cls) { super(typeResolver.getConfig(), cls); this.config = typeResolver.getConfig(); - maxBinarySize = config.maxBinarySize(); } } - private static void throwBinarySizeLimitExceeded(long size, int maxBinarySize) { - throw new DeserializationException( - "Binary payload size " + size + " exceeds max binary size " + maxBinarySize); - } - private static void throwNegativeBinarySize(int size) { - throw new DeserializationException("Binary payload size must be non-negative: " + size); + throw new DeserializationException("Binary body size must be non-negative: " + size); } private static void throwNegativeElementCount(int numElements) { throw new DeserializationException("Element count must be non-negative: " + numElements); } - private static void throwInvalidBinarySize(int size, int maxBinarySize) { - if (size < 0) { - throwNegativeBinarySize(size); - } else { - throwBinarySizeLimitExceeded(size, maxBinarySize); - } - } - - private static void throwInvalidElementCount(int numElements, int maxBinarySize, int elemSize) { - if (numElements < 0) { - throwNegativeElementCount(numElements); - } else { - throwBinarySizeLimitExceeded((long) numElements * elemSize, maxBinarySize); - } - } - private static void throwUnalignedBinarySize(int size, int elemSize) { throw new DeserializationException( - "Binary payload size " + size + " is not aligned to element size " + elemSize); + "Binary body size " + size + " is not aligned to element size " + elemSize); } public static final class BooleanArraySerializer extends 
PrimitiveArraySerializer { @@ -175,19 +152,18 @@ public boolean[] read(ReadContext readContext) { if (readContext.isPeerOutOfBandEnabled()) { MemoryBuffer buf = readContext.readBufferObject(); int size = buf.remaining(); - if (size > maxBinarySize) { - throwBinarySizeLimitExceeded(size, maxBinarySize); - } + buf.checkReadableBytes(size); boolean[] values = new boolean[size]; - buf.readBooleanArrayPayload(values, size); + buf.readBooleanArrayBytes(values, size); return values; } int size = buffer.readVarUInt32Small7(); - if (size < 0 || size > maxBinarySize) { - throwInvalidBinarySize(size, maxBinarySize); + if (size < 0) { + throwNegativeBinarySize(size); } + buffer.checkReadableBytes(size); boolean[] values = new boolean[size]; - buffer.readBooleanArrayPayload(values, size); + buffer.readBooleanArrayBytes(values, size); return values; } } @@ -219,19 +195,18 @@ public byte[] read(ReadContext readContext) { if (readContext.isPeerOutOfBandEnabled()) { MemoryBuffer buf = readContext.readBufferObject(); int size = buf.remaining(); - if (size > maxBinarySize) { - throwBinarySizeLimitExceeded(size, maxBinarySize); - } + buf.checkReadableBytes(size); byte[] values = new byte[size]; - buf.readByteArrayPayload(values, size); + buf.readByteArrayBytes(values, size); return values; } int size = buffer.readVarUInt32Small7(); - if (size < 0 || size > maxBinarySize) { - throwInvalidBinarySize(size, maxBinarySize); + if (size < 0) { + throwNegativeBinarySize(size); } + buffer.checkReadableBytes(size); byte[] values = new byte[size]; - buffer.readByteArrayPayload(values, size); + buffer.readByteArrayBytes(values, size); return values; } } @@ -287,13 +262,11 @@ public char[] read(ReadContext readContext) { if ((size & 1) != 0) { throwUnalignedBinarySize(size, 2); } - if (size > maxBinarySize) { - throwBinarySizeLimitExceeded(size, maxBinarySize); - } int numElements = size >>> 1; + buf.checkReadableBytes(size); char[] values = new char[numElements]; if 
(NativeByteOrder.IS_LITTLE_ENDIAN) { - buf.readCharArrayPayload(values, size); + buf.readCharArrayBytes(values, size); } else { readCharBySwapEndian(buf, values, numElements); } @@ -303,13 +276,14 @@ public char[] read(ReadContext readContext) { if ((size & 1) != 0) { throwUnalignedBinarySize(size, 2); } - if (size < 0 || size > maxBinarySize) { - throwInvalidBinarySize(size, maxBinarySize); + if (size < 0) { + throwNegativeBinarySize(size); } int numElements = size >>> 1; + buffer.checkReadableBytes(size); char[] values = new char[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readCharArrayPayload(values, size); + buffer.readCharArrayBytes(values, size); } else { readCharBySwapEndian(buffer, values, numElements); } @@ -319,8 +293,8 @@ public char[] read(ReadContext readContext) { private void readCharBySwapEndian(MemoryBuffer buffer, char[] values, int numElements) { int size = numElements << 1; // Do not loop through MemoryBuffer._unsafeGet* here; those helpers carry Android dispatch. - // Copy the payload once, then byte-swap the destination values locally. - buffer.readCharArrayPayload(values, size); + // Copy the body bytes once, then byte-swap the destination values locally. 
+ buffer.readCharArrayBytes(values, size); for (int i = 0; i < numElements; i++) { values[i] = Character.reverseBytes(values[i]); } @@ -344,7 +318,7 @@ public short[] copy(CopyContext copyContext, short[] originArray) { @Override public short[] read(ReadContext readContext) { - return readShortBits(readContext, maxBinarySize); + return readShortBits(readContext); } } @@ -397,14 +371,12 @@ public int[] read(ReadContext readContext) { if ((size & 3) != 0) { throwUnalignedBinarySize(size, 4); } - if (size > maxBinarySize) { - throwBinarySizeLimitExceeded(size, maxBinarySize); - } int numElements = size >>> 2; + buf.checkReadableBytes(size); int[] values = new int[numElements]; if (size > 0) { if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buf.readInt32ArrayPayload(values, size); + buf.readInt32ArrayBytes(values, size); } else { readInt32BySwapEndian(buf, values, numElements); } @@ -418,14 +390,15 @@ public int[] read(ReadContext readContext) { if ((size & 3) != 0) { throwUnalignedBinarySize(size, 4); } - if (size < 0 || size > maxBinarySize) { - throwInvalidBinarySize(size, maxBinarySize); + if (size < 0) { + throwNegativeBinarySize(size); } int numElements = size >>> 2; + buffer.checkReadableBytes(size); int[] values = new int[numElements]; if (size > 0) { if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt32ArrayPayload(values, size); + buffer.readInt32ArrayBytes(values, size); } else { readInt32BySwapEndian(buffer, values, numElements); } @@ -436,8 +409,8 @@ public int[] read(ReadContext readContext) { private void readInt32BySwapEndian(MemoryBuffer buffer, int[] values, int numElements) { int size = numElements << 2; // Do not loop through MemoryBuffer._unsafeGet* here; those helpers carry Android dispatch. - // Copy the payload once, then byte-swap the destination values locally. - buffer.readInt32ArrayPayload(values, size); + // Copy the body bytes once, then byte-swap the destination values locally. 
+ buffer.readInt32ArrayBytes(values, size); for (int i = 0; i < numElements; i++) { values[i] = Integer.reverseBytes(values[i]); } @@ -452,9 +425,10 @@ private void writeInt32Compressed(MemoryBuffer buffer, int[] value) { private int[] readInt32Compressed(MemoryBuffer buffer) { int numElements = buffer.readVarUInt32Small7(); - if (numElements < 0 || numElements > maxBinarySize / 4) { - throwInvalidElementCount(numElements, maxBinarySize, 4); + if (numElements < 0) { + throwNegativeElementCount(numElements); } + buffer.checkReadableBytes(numElements); int[] values = new int[numElements]; for (int i = 0; i < numElements; i++) { values[i] = buffer.readVarInt32(); @@ -518,14 +492,12 @@ public long[] read(ReadContext readContext) { if ((size & 7) != 0) { throwUnalignedBinarySize(size, 8); } - if (size > maxBinarySize) { - throwBinarySizeLimitExceeded(size, maxBinarySize); - } int numElements = size >>> 3; + buf.checkReadableBytes(size); long[] values = new long[numElements]; if (size > 0) { if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buf.readInt64ArrayPayload(values, size); + buf.readInt64ArrayBytes(values, size); } else { readInt64BySwapEndian(buf, values, numElements); } @@ -539,14 +511,15 @@ public long[] read(ReadContext readContext) { if ((size & 7) != 0) { throwUnalignedBinarySize(size, 8); } - if (size < 0 || size > maxBinarySize) { - throwInvalidBinarySize(size, maxBinarySize); + if (size < 0) { + throwNegativeBinarySize(size); } int numElements = size >>> 3; + buffer.checkReadableBytes(size); long[] values = new long[numElements]; if (size > 0) { if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt64ArrayPayload(values, size); + buffer.readInt64ArrayBytes(values, size); } else { readInt64BySwapEndian(buffer, values, numElements); } @@ -557,8 +530,8 @@ public long[] read(ReadContext readContext) { private void readInt64BySwapEndian(MemoryBuffer buffer, long[] values, int numElements) { int size = numElements << 3; // Do not loop through 
MemoryBuffer._unsafeGet* here; those helpers carry Android dispatch. - // Copy the payload once, then byte-swap the destination values locally. - buffer.readInt64ArrayPayload(values, size); + // Copy the body bytes once, then byte-swap the destination values locally. + buffer.readInt64ArrayBytes(values, size); for (int i = 0; i < numElements; i++) { values[i] = Long.reverseBytes(values[i]); } @@ -581,9 +554,10 @@ private void writeInt64Compressed( private long[] readInt64Compressed(MemoryBuffer buffer, Int64Encoding longEncoding) { int numElements = buffer.readVarUInt32Small7(); - if (numElements < 0 || numElements > maxBinarySize / 8) { - throwInvalidElementCount(numElements, maxBinarySize, 8); + if (numElements < 0) { + throwNegativeElementCount(numElements); } + buffer.checkReadableBytes(numElements); long[] values = new long[numElements]; if (longEncoding == Int64Encoding.TAGGED) { for (int i = 0; i < numElements; i++) { @@ -643,13 +617,11 @@ public float[] read(ReadContext readContext) { if ((size & 3) != 0) { throwUnalignedBinarySize(size, 4); } - if (size > maxBinarySize) { - throwBinarySizeLimitExceeded(size, maxBinarySize); - } int numElements = size >>> 2; + buf.checkReadableBytes(size); float[] values = new float[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buf.readFloat32ArrayPayload(values, size); + buf.readFloat32ArrayBytes(values, size); } else { readFloat32BySwapEndian(buf, values, numElements); } @@ -659,13 +631,14 @@ public float[] read(ReadContext readContext) { if ((size & 3) != 0) { throwUnalignedBinarySize(size, 4); } - if (size < 0 || size > maxBinarySize) { - throwInvalidBinarySize(size, maxBinarySize); + if (size < 0) { + throwNegativeBinarySize(size); } int numElements = size >>> 2; + buffer.checkReadableBytes(size); float[] values = new float[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readFloat32ArrayPayload(values, size); + buffer.readFloat32ArrayBytes(values, size); } else { readFloat32BySwapEndian(buffer, 
values, numElements); } @@ -675,8 +648,8 @@ public float[] read(ReadContext readContext) { private void readFloat32BySwapEndian(MemoryBuffer buffer, float[] values, int numElements) { int size = numElements << 2; // Do not loop through MemoryBuffer._unsafeGet* here; those helpers carry Android dispatch. - // Copy the payload once, then byte-swap the destination values locally. - buffer.readFloat32ArrayPayload(values, size); + // Copy the body bytes once, then byte-swap the destination values locally. + buffer.readFloat32ArrayBytes(values, size); for (int i = 0; i < numElements; i++) { values[i] = Float.intBitsToFloat(Integer.reverseBytes(Float.floatToRawIntBits(values[i]))); } @@ -728,13 +701,11 @@ public double[] read(ReadContext readContext) { if ((size & 7) != 0) { throwUnalignedBinarySize(size, 8); } - if (size > maxBinarySize) { - throwBinarySizeLimitExceeded(size, maxBinarySize); - } int numElements = size >>> 3; + buf.checkReadableBytes(size); double[] values = new double[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buf.readFloat64ArrayPayload(values, size); + buf.readFloat64ArrayBytes(values, size); } else { readFloat64BySwapEndian(buf, values, numElements); } @@ -744,13 +715,14 @@ public double[] read(ReadContext readContext) { if ((size & 7) != 0) { throwUnalignedBinarySize(size, 8); } - if (size < 0 || size > maxBinarySize) { - throwInvalidBinarySize(size, maxBinarySize); + if (size < 0) { + throwNegativeBinarySize(size); } int numElements = size >>> 3; + buffer.checkReadableBytes(size); double[] values = new double[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readFloat64ArrayPayload(values, size); + buffer.readFloat64ArrayBytes(values, size); } else { readFloat64BySwapEndian(buffer, values, numElements); } @@ -760,8 +732,8 @@ public double[] read(ReadContext readContext) { private void readFloat64BySwapEndian(MemoryBuffer buffer, double[] values, int numElements) { int size = numElements << 3; // Do not loop through 
MemoryBuffer._unsafeGet* here; those helpers carry Android dispatch. - // Copy the payload once, then byte-swap the destination values locally. - buffer.readFloat64ArrayPayload(values, size); + // Copy the body bytes once, then byte-swap the destination values locally. + buffer.readFloat64ArrayBytes(values, size); for (int i = 0; i < numElements; i++) { values[i] = Double.longBitsToDouble(Long.reverseBytes(Double.doubleToRawLongBits(values[i]))); @@ -786,7 +758,7 @@ public Float16Array copy(CopyContext copyContext, Float16Array originArray) { @Override public Float16Array read(ReadContext readContext) { - return Float16Array.wrapBits(readShortBits(readContext, maxBinarySize)); + return Float16Array.wrapBits(readShortBits(readContext)); } } @@ -808,7 +780,7 @@ public BFloat16Array copy(CopyContext copyContext, BFloat16Array originArray) { @Override public BFloat16Array read(ReadContext readContext) { - return BFloat16Array.wrapBits(readShortBits(readContext, maxBinarySize)); + return BFloat16Array.wrapBits(readShortBits(readContext)); } } @@ -837,7 +809,7 @@ private static void writeInt16BySwapEndian(MemoryBuffer buffer, short[] value) { buffer._unsafeWriterIndex(idx + length * 2); } - private static short[] readShortBits(ReadContext readContext, int maxBinarySize) { + private static short[] readShortBits(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); if (readContext.isPeerOutOfBandEnabled()) { MemoryBuffer buf = readContext.readBufferObject(); @@ -845,13 +817,11 @@ private static short[] readShortBits(ReadContext readContext, int maxBinarySize) if ((size & 1) != 0) { throwUnalignedBinarySize(size, 2); } - if (size > maxBinarySize) { - throwBinarySizeLimitExceeded(size, maxBinarySize); - } int numElements = size >>> 1; + buf.checkReadableBytes(size); short[] values = new short[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buf.readInt16ArrayPayload(values, size); + buf.readInt16ArrayBytes(values, size); } else { 
readInt16BySwapEndian(buf, values, numElements); } @@ -861,13 +831,14 @@ private static short[] readShortBits(ReadContext readContext, int maxBinarySize) if ((size & 1) != 0) { throwUnalignedBinarySize(size, 2); } - if (size < 0 || size > maxBinarySize) { - throwInvalidBinarySize(size, maxBinarySize); + if (size < 0) { + throwNegativeBinarySize(size); } int numElements = size >>> 1; + buffer.checkReadableBytes(size); short[] values = new short[numElements]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt16ArrayPayload(values, size); + buffer.readInt16ArrayBytes(values, size); } else { readInt16BySwapEndian(buffer, values, numElements); } @@ -877,8 +848,8 @@ private static short[] readShortBits(ReadContext readContext, int maxBinarySize) private static void readInt16BySwapEndian(MemoryBuffer buffer, short[] values, int numElements) { int size = numElements << 1; // Do not loop through MemoryBuffer._unsafeGet* here; those helpers carry Android dispatch. - // Copy the payload once, then byte-swap the destination values locally. - buffer.readInt16ArrayPayload(values, size); + // Copy the body bytes once, then byte-swap the destination values locally. 
+ buffer.readInt16ArrayBytes(values, size); for (int i = 0; i < numElements; i++) { values[i] = Short.reverseBytes(values[i]); } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/SerializedLambdaSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/SerializedLambdaSerializer.java index 0c225bddaf..6e8bcebc35 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/SerializedLambdaSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/SerializedLambdaSerializer.java @@ -42,7 +42,6 @@ public class SerializedLambdaSerializer extends Serializer { static final Class SERIALIZED_LAMBDA = SerializedLambda.class; private static final MethodHandle READ_RESOLVE_HANDLE; private final TypeResolver typeResolver; - private final int maxCollectionSize; static { if (AndroidSupport.IS_ANDROID) { @@ -63,7 +62,6 @@ public class SerializedLambdaSerializer extends Serializer { public SerializedLambdaSerializer(TypeResolver typeResolver, Class cls) { super(typeResolver.getConfig(), cls); this.typeResolver = typeResolver; - maxCollectionSize = typeResolver.getConfig().maxCollectionSize(); Preconditions.checkArgument(cls == SERIALIZED_LAMBDA); } @@ -130,9 +128,10 @@ Object readUnresolved(ReadContext readContext) { int implMethodKind = buffer.readVarInt32(); String instantiatedMethodType = readContext.readStringRef(); int capturedArgCount = buffer.readVarUInt32Small7(); - if (capturedArgCount < 0 || capturedArgCount > maxCollectionSize) { + if (capturedArgCount < 0) { throwInvalidCapturedArgCount(capturedArgCount); } + buffer.checkReadableBytes(capturedArgCount); Object[] capturedArgs = new Object[capturedArgCount]; for (int i = 0; i < capturedArgCount; i++) { capturedArgs[i] = readContext.readRef(); @@ -152,15 +151,8 @@ Object readUnresolved(ReadContext readContext) { } private void throwInvalidCapturedArgCount(int capturedArgCount) { - if (capturedArgCount < 0) { - throw new DeserializationException( - 
"SerializedLambda captured arg count must be non-negative: " + capturedArgCount); - } throw new DeserializationException( - "SerializedLambda captured arg count " - + capturedArgCount - + " exceeds max collection size " - + maxCollectionSize); + "SerializedLambda captured arg count must be non-negative: " + capturedArgCount); } static Object readResolve(Object replacement) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/StringSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/StringSerializer.java index f0f956d84a..cd4d8a75db 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/StringSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/StringSerializer.java @@ -85,7 +85,6 @@ public final class StringSerializer extends ImmutableSerializer { private final boolean compressString; private final boolean writeNumUtf16BytesForUtf8Encoding; private final boolean xlang; - private final long maxBinarySize; // set default length to 0, since char array and bytes array won't be used at the same time. 
private static final byte[] EMPTY_BYTES_STUB = new byte[0]; @@ -104,7 +103,6 @@ public StringSerializer(Config config) { Preconditions.checkArgument(compressString, "compress string muse be enabled for xlang mode"); } writeNumUtf16BytesForUtf8Encoding = config.writeNumUtf16BytesForUtf8Encoding(); - maxBinarySize = config.maxBinarySize(); } @Override @@ -177,8 +175,13 @@ public String readBytesString(MemoryBuffer buffer) { byte coder = (byte) (header & 0b11); int numBytes = readStringSize(header); byte[] bytes; - if (!NativeByteOrder.IS_LITTLE_ENDIAN && coder == UTF16) { - bytes = readBytesUTF16BE(buffer, numBytes); + if (coder == UTF16) { + checkUtf16Bytes(numBytes); + if (NativeByteOrder.IS_LITTLE_ENDIAN) { + bytes = readBytesUnCompressedUTF16(buffer, numBytes); + } else { + bytes = readBytesUTF16BE(buffer, numBytes); + } } else { bytes = readBytesUnCompressedUTF16(buffer, numBytes); } @@ -224,6 +227,7 @@ public String readCompressedBytesString(MemoryBuffer buffer) { } else if (coder == LATIN1) { return newBytesStringZeroCopy(coder, readBytesUnCompressedUTF16(buffer, numBytes)); } else if (coder == UTF16) { + checkUtf16Bytes(numBytes); byte[] bytes; if (NativeByteOrder.IS_LITTLE_ENDIAN) { bytes = readBytesUnCompressedUTF16(buffer, numBytes); @@ -566,8 +570,8 @@ public char[] readCharsLatin1(MemoryBuffer buffer, int numBytes) { } public byte[] readBytesUTF8(MemoryBuffer buffer, int numBytes) { - byte[] tmpArray = getByteArray(numBytes << 1); buffer.checkReadableBytes(numBytes); + byte[] tmpArray = getByteArray(numBytes << 1); int utf16NumBytes; byte[] srcArray = buffer.getHeapMemory(); if (srcArray != null) { @@ -586,9 +590,9 @@ public byte[] readBytesUTF8(MemoryBuffer buffer, int numBytes) { private byte[] readBytesUTF8PerfOptimized(MemoryBuffer buffer, int numBytes) { int udf8Bytes = buffer.readInt32(); checkStringSize(udf8Bytes); - byte[] bytes = new byte[numBytes]; // noinspection Duplicates buffer.checkReadableBytes(udf8Bytes); + byte[] bytes = new 
byte[numBytes]; byte[] srcArray = buffer.getHeapMemory(); if (srcArray != null) { int srcIndex = buffer._unsafeHeapReaderIndex(); @@ -620,7 +624,9 @@ public byte[] readBytesUnCompressedUTF16(MemoryBuffer buffer, int numBytes) { } public char[] readCharsUTF16(MemoryBuffer buffer, int numBytes) { + checkUtf16Bytes(numBytes); if (NativeByteOrder.IS_LITTLE_ENDIAN) { + buffer.checkReadableBytes(numBytes); char[] chars = new char[numBytes >> 1]; // FIXME JDK11 utf16 string uses little-endian order. buffer.readChars(chars, numBytes >> 1); @@ -631,9 +637,9 @@ public char[] readCharsUTF16(MemoryBuffer buffer, int numBytes) { } public String readCharsUTF8(MemoryBuffer buffer, int numBytes) { + buffer.checkReadableBytes(numBytes); char[] chars = getCharArray(numBytes); int charsLen; - buffer.checkReadableBytes(numBytes); byte[] srcArray = buffer.getHeapMemory(); if (srcArray != null) { int srcIndex = buffer._unsafeHeapReaderIndex(); @@ -652,9 +658,9 @@ public String readCharsUTF8PerfOptimized(MemoryBuffer buffer, int numBytes) { int udf16Chars = numBytes >> 1; int udf8Bytes = buffer.readInt32(); checkStringSize(udf8Bytes); - char[] chars = new char[udf16Chars]; // noinspection Duplicates buffer.checkReadableBytes(udf8Bytes); + char[] chars = new char[udf16Chars]; byte[] srcArray = buffer.getHeapMemory(); if (srcArray != null) { int srcIndex = buffer._unsafeHeapReaderIndex(); @@ -672,21 +678,27 @@ public String readCharsUTF8PerfOptimized(MemoryBuffer buffer, int numBytes) { private int readStringSize(long header) { long size = header >>> 2; - if (size > maxBinarySize) { + if (size > Integer.MAX_VALUE) { throwStringSizeOutOfBounds(size); } return (int) size; } + private static void checkUtf16Bytes(int numBytes) { + if ((numBytes & 1) != 0) { + throw new IllegalArgumentException( + "UTF-16 byte size " + numBytes + " is not aligned to element size 2"); + } + } + private void checkStringSize(int size) { - if (size < 0 || size > maxBinarySize) { + if (size < 0) { 
throwStringSizeOutOfBounds(size); } } private void throwStringSizeOutOfBounds(long size) { - throw new DeserializationException( - "String payload size " + size + " is outside allowed range [0, " + maxBinarySize + "]"); + throw new DeserializationException("Invalid string byte size " + size); } public void writeCharsLatin1(MemoryBuffer buffer, char[] chars, int numBytes) { diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java index 90df165485..3915b5d888 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionLikeSerializer.java @@ -49,7 +49,6 @@ public abstract class CollectionLikeSerializer extends Serializer { private MethodHandle constructor; private int numElements; protected final Config config; - protected final int maxCollectionSize; protected final boolean supportCodegenHook; protected final TypeInfoHolder elementTypeInfoHolder; protected final TypeResolver typeResolver; @@ -71,7 +70,6 @@ public CollectionLikeSerializer( TypeResolver typeResolver, Class cls, boolean supportCodegenHook) { super(typeResolver.getConfig(), cls); this.config = typeResolver.getConfig(); - maxCollectionSize = config.maxCollectionSize(); this.supportCodegenHook = supportCodegenHook; elementTypeInfoHolder = typeResolver.nilTypeInfoHolder(); this.typeResolver = typeResolver; @@ -81,7 +79,6 @@ public CollectionLikeSerializer( TypeResolver typeResolver, Class cls, boolean supportCodegenHook, boolean immutable) { super(typeResolver.getConfig(), cls, immutable); this.config = typeResolver.getConfig(); - maxCollectionSize = config.maxCollectionSize(); this.supportCodegenHook = supportCodegenHook; elementTypeInfoHolder = typeResolver.nilTypeInfoHolder(); this.typeResolver = typeResolver; @@ -566,24 
+563,20 @@ protected void setNumElements(int numElements) { protected final int readCollectionSize(MemoryBuffer buffer) { int numElements = buffer.readVarUInt32Small7(); checkCollectionSize(numElements); + buffer.checkReadableBytes(numElements); return numElements; } protected final void checkCollectionSize(int numElements) { // Keep this as direct primitive branches. Collection reads are hot enough that // Preconditions.checkArgument would add helper/varargs overhead on the valid path. - if (numElements < 0 || numElements > maxCollectionSize) { + if (numElements < 0) { throwInvalidCollectionSize(numElements); } } private void throwInvalidCollectionSize(int numElements) { - if (numElements < 0) { - throw new DeserializationException("Collection size must be non-negative: " + numElements); - } else { - throw new DeserializationException( - "Collection size " + numElements + " exceeds max collection size " + maxCollectionSize); - } + throw new DeserializationException("Collection size must be non-negative: " + numElements); } public abstract T onCollectionRead(Collection collection); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java index d820331c6f..7c67d1607e 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/CollectionSerializers.java @@ -49,7 +49,6 @@ import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.LinkedBlockingQueue; import org.apache.fory.collection.CollectionSnapshot; -import org.apache.fory.config.Config; import org.apache.fory.context.CopyContext; import org.apache.fory.context.ReadContext; import org.apache.fory.context.WriteContext; @@ -87,21 +86,16 @@ private static void requireXlangNaturalOrdering(Class type, Comparator com } } - private static 
void throwBinarySizeLimitExceeded(long size, int maxBinarySize) { - throw new DeserializationException( - "Binary payload size " + size + " exceeds max binary size " + maxBinarySize); - } - private static void throwNegativeBinarySize(int size) { - throw new DeserializationException("Binary payload size must be non-negative: " + size); + throw new DeserializationException("Binary body size must be non-negative: " + size); } private static void throwUnalignedBinarySize(int size, int elemSize) { throw new DeserializationException( - "Binary payload size " + size + " is not aligned to element size " + elemSize); + "Binary body size " + size + " is not aligned to element size " + elemSize); } - private static void checkBoundedQueueCapacity(Config config, int numElements, int capacity) { + private static void checkBoundedQueueCapacity(int numElements, int capacity) { // Keep these as direct primitive branches. This collection read path is JIT-sensitive; using // Preconditions.checkArgument here adds helper/varargs overhead and hurts inlining. if (numElements < 0) { @@ -114,11 +108,6 @@ private static void checkBoundedQueueCapacity(Config config, int numElements, in throw new DeserializationException( "Queue capacity " + capacity + " is smaller than serialized size " + numElements); } - int maxCollectionSize = config.maxCollectionSize(); - if (capacity > maxCollectionSize) { - throw new DeserializationException( - "Queue capacity " + capacity + " exceeds max collection size " + maxCollectionSize); - } } private static UnsupportedOperationException unsupportedBoundedQueueWrite(Class type) { @@ -126,7 +115,7 @@ private static UnsupportedOperationException unsupportedBoundedQueueWrite(Class< "Serializing or copying " + type.getName() + " requires access to its exact capacity field. 
This runtime can deserialize existing " - + "payloads for this type, but cannot serialize or copy it without JDK concurrent " + + "wire bodies for this type, but cannot serialize or copy it without JDK concurrent " + "field access."); } @@ -349,7 +338,7 @@ public List read(ReadContext readContext) { int numElements = readCollectionSize(readContext.getBuffer()); if (numElements != 0) { throw new DeserializationException( - "Empty list payload must have zero elements but got " + numElements); + "Empty list body must have zero elements but got " + numElements); } } return Collections.EMPTY_LIST; @@ -622,7 +611,7 @@ public Collection newCollection(ReadContext readContext) { } else { if (!MemoryUtils.JDK_COLLECTION_FIELD_ACCESS) { throw new UnsupportedOperationException( - "This runtime cannot read SetFromMap backing-map payloads that require hidden JDK field " + "This runtime cannot read SetFromMap backing-map bodies that require hidden JDK field " + "restoration"); } Map map = (Map) mapSerializer.read(readContext); @@ -631,7 +620,7 @@ public Collection newCollection(ReadContext readContext) { SetFromMapAccess.restore(set, map); } catch (Throwable e) { throw new UnsupportedOperationException( - "This runtime cannot restore SetFromMap backing-map payloads through final JDK fields", + "This runtime cannot restore SetFromMap backing-map bodies through final JDK fields", e); } setNumElements(0); @@ -813,11 +802,9 @@ public EnumSet copy(CopyContext copyContext, EnumSet originCollection) { } public static class BitSetSerializer extends Serializer { - private final int maxBinarySize; public BitSetSerializer(TypeResolver typeResolver, Class type) { super(typeResolver.getConfig(), type); - maxBinarySize = typeResolver.getConfig().maxBinarySize(); } @Override @@ -842,11 +829,9 @@ public BitSet read(ReadContext readContext) { if ((size & 7) != 0) { throwUnalignedBinarySize(size, 8); } - if (size > maxBinarySize) { - throwBinarySizeLimitExceeded(size, maxBinarySize); - } + 
buffer.checkReadableBytes(size); long[] values = new long[size >>> 3]; - buffer.readInt64ArrayPayload(values, size); + buffer.readInt64ArrayBytes(values, size); return BitSet.valueOf(values); } } @@ -941,7 +926,8 @@ public ArrayBlockingQueue newCollection(ReadContext readContext) { int numElements = readCollectionSize(buffer); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); - checkBoundedQueueCapacity(config, numElements, capacity); + checkBoundedQueueCapacity(numElements, capacity); + buffer.checkReadableBytes(capacity); ArrayBlockingQueue queue = new ArrayBlockingQueue<>(capacity); readContext.reference(queue); return queue; @@ -1007,7 +993,8 @@ public LinkedBlockingQueue newCollection(ReadContext readContext) { int numElements = readCollectionSize(buffer); setNumElements(numElements); int capacity = buffer.readVarUInt32Small7(); - checkBoundedQueueCapacity(config, numElements, capacity); + checkBoundedQueueCapacity(numElements, capacity); + buffer.checkReadableBytes(capacity); LinkedBlockingQueue queue = new LinkedBlockingQueue<>(capacity); readContext.reference(queue); return queue; diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/Container.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/Container.java index ae0b28e0be..6a198d3a6d 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/Container.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/Container.java @@ -24,6 +24,7 @@ import java.util.Comparator; import java.util.Iterator; import java.util.Set; +import org.apache.fory.exception.DeserializationException; class Container {} @@ -104,6 +105,10 @@ class JDKImmutableMapContainer extends AbstractMap { private int offset; JDKImmutableMapContainer(int mapCapacity) { + if (mapCapacity > (Integer.MAX_VALUE >>> 1)) { + throw new DeserializationException( + "Immutable map size exceeds internal array capacity: " + mapCapacity); + } 
array = new Object[mapCapacity << 1]; } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java index ac4e801104..c28aa04561 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/GuavaCollectionSerializers.java @@ -403,6 +403,9 @@ public void write(WriteContext writeContext, Object value) { public Object read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int size = buffer.readVarUInt32Small7(); + if (size != 0) { + buffer.checkReadableBytes(size); + } ImmutableMap.Builder builder = biMap ? newImmutableBiMapBuilder(size) : newImmutableMapBuilder(size); for (int i = 0; i < size; i++) { @@ -488,6 +491,9 @@ public void write(WriteContext writeContext, ImmutableIntArray value) { public ImmutableIntArray read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int length = buffer.readVarUInt32Small7(); + if (length != 0) { + buffer.checkReadableBytes(length); + } int[] values = new int[length]; for (int i = 0; i < length; i++) { values[i] = buffer.readVarInt32(); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java index 3eef7973c2..cd69f2b6cf 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/ImmutableCollectionSerializers.java @@ -161,8 +161,9 @@ public Collection onCollectionRead(Collection collection) { if (JdkVersion.MAJOR_VERSION > 8) { CollectionContainer container = (CollectionContainer) collection; if 
(!MemoryUtils.JDK_COLLECTION_FIELD_ACCESS) { - ArrayList list = new ArrayList(container.elements.length); - Collections.addAll(list, container.elements); + Object[] elements = container.elements; + ArrayList list = new ArrayList(elements.length); + Collections.addAll(list, elements); return Collections.unmodifiableList(list); } try { @@ -221,8 +222,9 @@ public Collection onCollectionRead(Collection collection) { if (JdkVersion.MAJOR_VERSION > 8) { CollectionContainer container = (CollectionContainer) collection; if (!MemoryUtils.JDK_COLLECTION_FIELD_ACCESS) { - HashSet set = new HashSet(container.elements.length); - Collections.addAll(set, container.elements); + Object[] elements = container.elements; + HashSet set = new HashSet(elements.length); + Collections.addAll(set, elements); return Collections.unmodifiableSet(set); } try { @@ -288,7 +290,8 @@ public Map onMapRead(Map map) { } try { if (container.size() == 1) { - map = (Map) map1Factory.invoke(container.array[0], container.array[1]); + Object[] elements = container.array; + map = (Map) map1Factory.invoke(elements[0], elements[1]); } else { map = (Map) mapNFactory.invoke(container.array); } diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java index 912725d58a..334bd8a35c 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java +++ b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/MapLikeSerializer.java @@ -83,7 +83,6 @@ private MapTypeCache(TypeResolver typeResolver) { protected MethodHandle constructor; protected final Config config; - protected final int maxCollectionSize; protected final boolean supportCodegenHook; private final GenericType objType; // For subclass whose kv type are instantiated already, such as @@ -111,7 +110,6 @@ public MapLikeSerializer( TypeResolver typeResolver, Class cls, boolean 
supportCodegenHook, boolean immutable) { super(typeResolver.getConfig(), cls, immutable); this.config = typeResolver.getConfig(); - maxCollectionSize = config.maxCollectionSize(); this.typeResolver = typeResolver; trackRef = typeResolver.getConfig().trackingRef(); this.supportCodegenHook = supportCodegenHook; @@ -1031,24 +1029,27 @@ public void setNumElements(int numElements) { protected final int readMapSize(MemoryBuffer buffer) { int numElements = buffer.readVarUInt32Small7(); checkMapSize(numElements); + if (numElements > Integer.MAX_VALUE / 2) { + throwInvalidMapBodySize(numElements); + } + buffer.checkReadableBytes(numElements << 1); return numElements; } protected final void checkMapSize(int numElements) { // Keep this as direct primitive branches. Map reads are hot enough that // Preconditions.checkArgument would add helper/varargs overhead on the valid path. - if (numElements < 0 || numElements > maxCollectionSize) { + if (numElements < 0) { throwInvalidMapSize(numElements); } } private void throwInvalidMapSize(int numElements) { - if (numElements < 0) { - throw new DeserializationException("Map size must be non-negative: " + numElements); - } else { - throw new DeserializationException( - "Map size " + numElements + " exceeds max collection size " + maxCollectionSize); - } + throw new DeserializationException("Map size must be non-negative: " + numElements); + } + + private void throwInvalidMapBodySize(int numElements) { + throw new DeserializationException("Map size is too large to read: " + numElements); } public abstract T onMapCopy(Map map); diff --git a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java index 8d44967d16..27b100e2c4 100644 --- a/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java +++ 
b/java/fory-core/src/main/java/org/apache/fory/serializer/collection/PrimitiveListSerializers.java @@ -53,13 +53,9 @@ /** Serializers for primitive list types. */ @SuppressWarnings({"rawtypes", "unchecked"}) public class PrimitiveListSerializers { - private static void throwBinarySizeLimitExceeded(long size, int maxBinarySize) { - throw new DeserializationException( - "Binary payload size " + size + " exceeds max binary size " + maxBinarySize); - } private static void throwNegativeBinarySize(int size) { - throw new DeserializationException("Binary payload size must be non-negative: " + size); + throw new DeserializationException("Binary body size must be non-negative: " + size); } private static void throwNegativeElementCount(int size) { @@ -68,13 +64,12 @@ private static void throwNegativeElementCount(int size) { private static void throwUnalignedBinarySize(int size, int elemSize) { throw new DeserializationException( - "Binary payload size " + size + " is not aligned to element size " + elemSize); + "Binary body size " + size + " is not aligned to element size " + elemSize); } private abstract static class PrimitiveListSerializer extends CollectionLikeSerializer implements Shareable { private final boolean denseArrayPayload; - protected final int maxBinarySize; private PrimitiveListSerializer(TypeResolver typeResolver, Class cls) { this(typeResolver, cls, false); @@ -84,7 +79,6 @@ private PrimitiveListSerializer( TypeResolver typeResolver, Class cls, boolean denseArrayPayload) { super(typeResolver, cls, false, false); this.denseArrayPayload = denseArrayPayload; - maxBinarySize = config.maxBinarySize(); } @Override @@ -146,9 +140,7 @@ protected final int readOneByteHeader(MemoryBuffer buffer) { if (size < 0) { throwNegativeBinarySize(size); } - if (size > maxBinarySize) { - throwBinarySizeLimitExceeded(size, maxBinarySize); - } + buffer.checkReadableBytes(size); return size; } @@ -158,10 +150,8 @@ protected final int readFixedWidthHeader(MemoryBuffer buffer, int 
elemSize) { byteSize = buffer.readVarUInt32Small7(); } else if (config.isXlang()) { int size = readXlangListHeader(buffer); - if (size > maxBinarySize / elemSize) { - throwBinarySizeLimitExceeded((long) size * elemSize, maxBinarySize); - } - byteSize = size * elemSize; + byteSize = Math.multiplyExact(size, elemSize); + buffer.checkReadableBytes(byteSize); return size; } else { byteSize = buffer.readVarUInt32Small7(); @@ -172,9 +162,7 @@ protected final int readFixedWidthHeader(MemoryBuffer buffer, int elemSize) { if (byteSize % elemSize != 0) { throwUnalignedBinarySize(byteSize, elemSize); } - if (byteSize > maxBinarySize) { - throwBinarySizeLimitExceeded(byteSize, maxBinarySize); - } + buffer.checkReadableBytes(byteSize); return byteSize / elemSize; } } @@ -203,7 +191,7 @@ public BoolList read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int size = readOneByteHeader(buffer); boolean[] array = new boolean[size]; - buffer.readBooleanArrayPayload(array, size); + buffer.readBooleanArrayBytes(array, size); return new BoolList(array); } @@ -235,7 +223,7 @@ public Int8List read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int size = readOneByteHeader(buffer); byte[] array = new byte[size]; - buffer.readByteArrayPayload(array, size); + buffer.readByteArrayBytes(array, size); return new Int8List(array); } @@ -276,7 +264,7 @@ public Int16List read(ReadContext readContext) { int byteSize = size << 1; short[] array = new short[size]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt16ArrayPayload(array, byteSize); + buffer.readInt16ArrayBytes(array, byteSize); } else { for (int i = 0; i < size; i++) { array[i] = buffer.readInt16(); @@ -329,7 +317,7 @@ public Int32List read(ReadContext readContext) { int byteSize = size << 2; int[] array = new int[size]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt32ArrayPayload(array, byteSize); + buffer.readInt32ArrayBytes(array, byteSize); } else { for (int i = 0; 
i < size; i++) { array[i] = buffer.readInt32(); @@ -351,9 +339,8 @@ private Int32List readInt32Compressed(MemoryBuffer buffer) { if (size < 0) { throwNegativeElementCount(size); } - if (size > maxBinarySize / 4) { - throwBinarySizeLimitExceeded((long) size * 4, maxBinarySize); - } + + buffer.checkReadableBytes(size); Int32List list = new Int32List(size); for (int i = 0; i < size; i++) { list.add(buffer.readVarInt32()); @@ -415,7 +402,7 @@ public Int64List read(ReadContext readContext) { int byteSize = size << 3; long[] array = new long[size]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt64ArrayPayload(array, byteSize); + buffer.readInt64ArrayBytes(array, byteSize); } else { for (int i = 0; i < size; i++) { array[i] = buffer.readInt64(); @@ -445,9 +432,8 @@ private Int64List readInt64Compressed(MemoryBuffer buffer, Int64Encoding longEnc if (size < 0) { throwNegativeElementCount(size); } - if (size > maxBinarySize / 8) { - throwBinarySizeLimitExceeded((long) size * 8, maxBinarySize); - } + + buffer.checkReadableBytes(size); Int64List list = new Int64List(size); if (longEncoding == Int64Encoding.TAGGED) { for (int i = 0; i < size; i++) { @@ -489,7 +475,7 @@ public UInt8List read(ReadContext readContext) { MemoryBuffer buffer = readContext.getBuffer(); int size = readOneByteHeader(buffer); byte[] array = new byte[size]; - buffer.readByteArrayPayload(array, size); + buffer.readByteArrayBytes(array, size); return new UInt8List(array); } @@ -530,7 +516,7 @@ public UInt16List read(ReadContext readContext) { int byteSize = size << 1; short[] array = new short[size]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt16ArrayPayload(array, byteSize); + buffer.readInt16ArrayBytes(array, byteSize); } else { for (int i = 0; i < size; i++) { array[i] = buffer.readInt16(); @@ -583,7 +569,7 @@ public UInt32List read(ReadContext readContext) { int byteSize = size << 2; int[] array = new int[size]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - 
buffer.readInt32ArrayPayload(array, byteSize); + buffer.readInt32ArrayBytes(array, byteSize); } else { for (int i = 0; i < size; i++) { array[i] = buffer.readInt32(); @@ -605,9 +591,8 @@ private UInt32List readUInt32Compressed(MemoryBuffer buffer) { if (size < 0) { throwNegativeElementCount(size); } - if (size > maxBinarySize / 4) { - throwBinarySizeLimitExceeded((long) size * 4, maxBinarySize); - } + + buffer.checkReadableBytes(size); UInt32List list = new UInt32List(size); for (int i = 0; i < size; i++) { list.add(buffer.readVarInt32()); @@ -669,7 +654,7 @@ public UInt64List read(ReadContext readContext) { int byteSize = size << 3; long[] array = new long[size]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt64ArrayPayload(array, byteSize); + buffer.readInt64ArrayBytes(array, byteSize); } else { for (int i = 0; i < size; i++) { array[i] = buffer.readInt64(); @@ -699,9 +684,8 @@ private UInt64List readUInt64Compressed(MemoryBuffer buffer, Int64Encoding longE if (size < 0) { throwNegativeElementCount(size); } - if (size > maxBinarySize / 8) { - throwBinarySizeLimitExceeded((long) size * 8, maxBinarySize); - } + + buffer.checkReadableBytes(size); UInt64List list = new UInt64List(size); if (longEncoding == Int64Encoding.TAGGED) { for (int i = 0; i < size; i++) { @@ -752,7 +736,7 @@ public Float32List read(ReadContext readContext) { int byteSize = size << 2; float[] array = new float[size]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readFloat32ArrayPayload(array, byteSize); + buffer.readFloat32ArrayBytes(array, byteSize); } else { for (int i = 0; i < size; i++) { array[i] = buffer.readFloat32(); @@ -798,7 +782,7 @@ public Float64List read(ReadContext readContext) { int byteSize = size << 3; double[] array = new double[size]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readFloat64ArrayPayload(array, byteSize); + buffer.readFloat64ArrayBytes(array, byteSize); } else { for (int i = 0; i < size; i++) { array[i] = buffer.readFloat64(); @@ -844,7 +828,7 
@@ public Float16List read(ReadContext readContext) { int byteSize = size << 1; short[] array = new short[size]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt16ArrayPayload(array, byteSize); + buffer.readInt16ArrayBytes(array, byteSize); } else { for (int i = 0; i < size; i++) { array[i] = buffer.readInt16(); @@ -890,7 +874,7 @@ public BFloat16List read(ReadContext readContext) { int byteSize = size << 1; short[] array = new short[size]; if (NativeByteOrder.IS_LITTLE_ENDIAN) { - buffer.readInt16ArrayPayload(array, byteSize); + buffer.readInt16ArrayBytes(array, byteSize); } else { for (int i = 0; i < size; i++) { array[i] = buffer.readInt16(); diff --git a/java/fory-core/src/main/java16/org/apache/fory/serializer/CompressedArraySerializers.java b/java/fory-core/src/main/java16/org/apache/fory/serializer/CompressedArraySerializers.java index 9f4eee99cd..04d4173ad6 100644 --- a/java/fory-core/src/main/java16/org/apache/fory/serializer/CompressedArraySerializers.java +++ b/java/fory-core/src/main/java16/org/apache/fory/serializer/CompressedArraySerializers.java @@ -49,17 +49,13 @@ private CompressedArraySerializers() { // Utility class } - private static void validateBinarySize(int size, int maxBinarySize, int elemSize) { + private static void validateBinarySize(int size, int elemSize) { if (size < 0) { - throw new DeserializationException("Binary payload size must be non-negative: " + size); - } - if (size > maxBinarySize) { - throw new DeserializationException( - "Binary payload size " + size + " exceeds max binary size " + maxBinarySize); + throw new DeserializationException("Binary body size must be non-negative: " + size); } if ((size & (elemSize - 1)) != 0) { throw new DeserializationException( - "Binary payload size " + size + " is not aligned to element size " + elemSize); + "Binary body size " + size + " is not aligned to element size " + elemSize); } } @@ -222,33 +218,37 @@ public int[] read(ReadContext readContext) { private int[] 
readFromBufferObject(ReadContext readContext) { MemoryBuffer buf = readContext.readBufferObject(); int size = buf.remaining(); - validateBinarySize(size, maxBinarySize, 4); + validateBinarySize(size, 4); + buf.checkReadableBytes(size); int[] values = new int[size >>> 2]; - buf.readInt32ArrayPayload(values, size); + buf.readInt32ArrayBytes(values, size); return values; } private int[] readCompressedFromBytes(MemoryBuffer buffer) { int size = buffer.readVarUInt32Small7(); - validateBinarySize(size, maxBinarySize, 1); + validateBinarySize(size, 1); + buffer.checkReadableBytes(size); byte[] values = new byte[size]; - buffer.readByteArrayPayload(values, size); + buffer.readByteArrayBytes(values, size); return ArrayCompressionUtils.decompressFromBytes(values); } private int[] readCompressedFromShorts(MemoryBuffer buffer) { int size = buffer.readVarUInt32Small7(); - validateBinarySize(size, maxBinarySize, 2); + validateBinarySize(size, 2); + buffer.checkReadableBytes(size); short[] values = new short[size >>> 1]; - buffer.readInt16ArrayPayload(values, size); + buffer.readInt16ArrayBytes(values, size); return ArrayCompressionUtils.decompressFromShorts(values); } private int[] readUncompressed(MemoryBuffer buffer) { int size = buffer.readVarUInt32Small7(); - validateBinarySize(size, maxBinarySize, 4); + validateBinarySize(size, 4); + buffer.checkReadableBytes(size); int[] values = new int[size >>> 2]; - buffer.readInt32ArrayPayload(values, size); + buffer.readInt32ArrayBytes(values, size); return values; } } @@ -326,25 +326,28 @@ public long[] read(ReadContext readContext) { private long[] readFromBufferObject(ReadContext readContext) { MemoryBuffer buf = readContext.readBufferObject(); int size = buf.remaining(); - validateBinarySize(size, maxBinarySize, 8); + validateBinarySize(size, 8); + buf.checkReadableBytes(size); long[] values = new long[size >>> 3]; - buf.readInt64ArrayPayload(values, size); + buf.readInt64ArrayBytes(values, size); return values; } private long[] 
readCompressedFromInts(MemoryBuffer buffer) { int size = buffer.readVarUInt32Small7(); - validateBinarySize(size, maxBinarySize, 4); + validateBinarySize(size, 4); + buffer.checkReadableBytes(size); int[] values = new int[size >>> 2]; - buffer.readInt32ArrayPayload(values, size); + buffer.readInt32ArrayBytes(values, size); return ArrayCompressionUtils.decompressFromInts(values); } private long[] readUncompressed(MemoryBuffer buffer) { int size = buffer.readVarUInt32Small7(); - validateBinarySize(size, maxBinarySize, 8); + validateBinarySize(size, 8); + buffer.checkReadableBytes(size); long[] values = new long[size >>> 3]; - buffer.readInt64ArrayPayload(values, size); + buffer.readInt64ArrayBytes(values, size); return values; } } diff --git a/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java b/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java index 4e9e410e7e..309d9e6e43 100644 --- a/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java +++ b/java/fory-core/src/main/java25/org/apache/fory/memory/MemoryBuffer.java @@ -3058,6 +3058,7 @@ private long skipPadding(long pos, int b) { } public byte[] readBytes(int length) { + checkReadableBytes(length); int readerIdx = readerIndex; byte[] bytes = new byte[length]; // use subtract to avoid overflow @@ -3187,7 +3188,7 @@ public int readBinarySize() { } int diff = size - readIdx; if (diff < binarySize) { - streamReader.fillBuffer(diff); + streamReader.fillBuffer(binarySize - diff); } return binarySize; } @@ -3210,7 +3211,7 @@ private int continueReadBinarySize(int readIdx, int bulkRead, int binarySize) { } int diff = size - readIdx; if (diff < binarySize) { - streamReader.fillBuffer(diff); + streamReader.fillBuffer(binarySize - diff); } return binarySize; } @@ -3230,11 +3231,11 @@ public byte[] readBytesAndSize() { } /** - * Reads a size-validated primitive byte-array payload into {@code values}. 
The caller owns size - * validation and destination allocation; this method reads payload bytes only, not the size + * Reads a size-validated primitive byte-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size * prefix. */ - public void readByteArrayPayload(byte[] values, int numBytes) { + public void readByteArrayBytes(byte[] values, int numBytes) { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { streamReader.readTo(values, 0, numBytes); @@ -3245,11 +3246,11 @@ public void readByteArrayPayload(byte[] values, int numBytes) { } /** - * Reads a size-validated primitive boolean-array payload into {@code values}. The caller owns - * size validation and destination allocation; this method reads payload bytes only, not the size + * Reads a size-validated primitive boolean-array body into {@code values}. The caller owns + * size validation and destination allocation; this method reads body bytes only, not the size * prefix. */ - public void readBooleanArrayPayload(boolean[] values, int numBytes) { + public void readBooleanArrayBytes(boolean[] values, int numBytes) { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { streamReader.readBooleans(values, 0, numBytes); @@ -3260,11 +3261,11 @@ public void readBooleanArrayPayload(boolean[] values, int numBytes) { } /** - * Reads a size-validated primitive char-array payload into {@code values}. The caller owns size - * validation and destination allocation; this method reads payload bytes only, not the size + * Reads a size-validated primitive char-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size * prefix. 
*/ - public void readCharArrayPayload(char[] values, int numBytes) { + public void readCharArrayBytes(char[] values, int numBytes) { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { streamReader.readChars(values, 0, numBytes >>> 1); @@ -3275,11 +3276,11 @@ public void readCharArrayPayload(char[] values, int numBytes) { } /** - * Reads a size-validated primitive int16-array payload into {@code values}. The caller owns size - * validation and destination allocation; this method reads payload bytes only, not the size + * Reads a size-validated primitive int16-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size * prefix. */ - public void readInt16ArrayPayload(short[] values, int numBytes) { + public void readInt16ArrayBytes(short[] values, int numBytes) { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { streamReader.readShorts(values, 0, numBytes >>> 1); @@ -3290,11 +3291,11 @@ public void readInt16ArrayPayload(short[] values, int numBytes) { } /** - * Reads a size-validated primitive int32-array payload into {@code values}. The caller owns size - * validation and destination allocation; this method reads payload bytes only, not the size + * Reads a size-validated primitive int32-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size * prefix. */ - public void readInt32ArrayPayload(int[] values, int numBytes) { + public void readInt32ArrayBytes(int[] values, int numBytes) { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { streamReader.readInts(values, 0, numBytes >>> 2); @@ -3305,11 +3306,11 @@ public void readInt32ArrayPayload(int[] values, int numBytes) { } /** - * Reads a size-validated primitive int64-array payload into {@code values}. 
The caller owns size - * validation and destination allocation; this method reads payload bytes only, not the size + * Reads a size-validated primitive int64-array body into {@code values}. The caller owns size + * validation and destination allocation; this method reads body bytes only, not the size * prefix. */ - public void readInt64ArrayPayload(long[] values, int numBytes) { + public void readInt64ArrayBytes(long[] values, int numBytes) { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { streamReader.readLongs(values, 0, numBytes >>> 3); @@ -3320,11 +3321,11 @@ public void readInt64ArrayPayload(long[] values, int numBytes) { } /** - * Reads a size-validated primitive float32-array payload into {@code values}. The caller owns - * size validation and destination allocation; this method reads payload bytes only, not the size + * Reads a size-validated primitive float32-array body into {@code values}. The caller owns + * size validation and destination allocation; this method reads body bytes only, not the size * prefix. */ - public void readFloat32ArrayPayload(float[] values, int numBytes) { + public void readFloat32ArrayBytes(float[] values, int numBytes) { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { streamReader.readFloats(values, 0, numBytes >>> 2); @@ -3335,11 +3336,11 @@ public void readFloat32ArrayPayload(float[] values, int numBytes) { } /** - * Reads a size-validated primitive float64-array payload into {@code values}. The caller owns - * size validation and destination allocation; this method reads payload bytes only, not the size + * Reads a size-validated primitive float64-array body into {@code values}. The caller owns + * size validation and destination allocation; this method reads body bytes only, not the size * prefix. 
*/ - public void readFloat64ArrayPayload(double[] values, int numBytes) { + public void readFloat64ArrayBytes(double[] values, int numBytes) { int readerIdx = readerIndex; if (readerIdx > size - numBytes) { streamReader.readDoubles(values, 0, numBytes >>> 3); @@ -3385,6 +3386,11 @@ public void readChars(char[] chars, int offset, int numElements) { @CodegenInvoke public char[] readCharsAndSize() { final int numBytes = readBinarySize(); + if ((numBytes & 1) != 0) { + throw new IllegalArgumentException( + "Char array byte size " + numBytes + " is not aligned to element size 2"); + } + checkReadableBytes(numBytes); int numElements = numBytes >> 1; char[] values = new char[numElements]; readChars(values, 0, numElements); diff --git a/java/fory-core/src/test/java/org/apache/fory/StreamTest.java b/java/fory-core/src/test/java/org/apache/fory/StreamTest.java index 7f715335d9..0cbd2efdff 100644 --- a/java/fory-core/src/test/java/org/apache/fory/StreamTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/StreamTest.java @@ -21,6 +21,7 @@ import static org.apache.fory.io.ForyStreamReader.of; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertTrue; import com.google.common.collect.Lists; @@ -336,6 +337,40 @@ public void testReadableChannelRequiresExactReads() throws IOException { } } + @Test + public void testStreamFillGrowsFromBufferedBytes() throws IOException { + byte[] complete = new byte[100]; + ForyInputStream inputWithAvailable = new ForyInputStream(new ByteArrayInputStream(complete), 4); + assertEquals(inputWithAvailable.fillBuffer(100), 100); + assertEquals(inputWithAvailable.getBuffer().getHeapMemory().length, 100); + + byte[] truncated = new byte[17]; + ForyInputStream input = new ForyInputStream(new ByteArrayInputStream(truncated), 4); + Assert.assertThrows(IndexOutOfBoundsException.class, () -> input.fillBuffer(100)); + int inputCapacity = 
input.getBuffer().getHeapMemory().length; + assertTrue(inputCapacity < 100); + assertTrue(inputCapacity <= 32); + + try (ForyReadableChannel channel = + new ForyReadableChannel( + new ChunkedReadableByteChannel(truncated, truncated.length), ByteBuffer.allocate(4))) { + Assert.assertThrows(DeserializationException.class, () -> channel.fillBuffer(100)); + int channelCapacity = channel.getBuffer().getHeapMemory().length; + assertTrue(channelCapacity < 100); + assertTrue(channelCapacity <= 32); + } + + Path tempFile = Files.createTempFile("readable_channel_available", "data"); + Files.write(tempFile, complete); + try (ForyReadableChannel channel = + new ForyReadableChannel(Files.newByteChannel(tempFile), ByteBuffer.allocate(4))) { + assertEquals(channel.fillBuffer(100), 100); + assertEquals(channel.getBuffer().getHeapMemory().length, 100); + } finally { + Files.delete(tempFile); + } + } + @Test public void testScopedMetaShare() throws IOException { Fory fory = @@ -470,7 +505,7 @@ public void testBigBufferStreamingMetaShare() throws IOException { } @Test - public void testPrimitiveArrayStreamReaderUsesTypedReads() throws IOException { + public void testStreamPrimitiveArrayBody() throws IOException { Fory fory = builder().requireClassRegistration(false).build(); int[] ints = new int[257]; @@ -480,7 +515,7 @@ public void testPrimitiveArrayStreamReaderUsesTypedReads() throws IOException { TrackingForyInputStream input = new TrackingForyInputStream(new ChunkedInputStream(fory.serialize(ints), 1), 3); Assert.assertEquals((int[]) fory.deserialize(input), ints); - assertTrue(input.readIntsCalled); + assertFalse(input.readIntsCalled); long[] longs = new long[257]; for (int i = 0; i < longs.length; i++) { @@ -491,7 +526,7 @@ public void testPrimitiveArrayStreamReaderUsesTypedReads() throws IOException { new TrackingForyReadableChannel( new ChunkedReadableByteChannel(serialized, 1), ByteBuffer.allocateDirect(5))) { Assert.assertEquals((long[]) fory.deserialize(channel), longs); - 
assertTrue(channel.readLongsCalled); + assertFalse(channel.readLongsCalled); } ByteBuffer limitedDirectBuffer = ByteBuffer.allocateDirect(serialized.length + 8); @@ -500,7 +535,7 @@ public void testPrimitiveArrayStreamReaderUsesTypedReads() throws IOException { new TrackingForyReadableChannel( new ChunkedReadableByteChannel(serialized, 1), limitedDirectBuffer)) { Assert.assertEquals((long[]) fory.deserialize(channel), longs); - assertTrue(channel.readLongsCalled); + assertFalse(channel.readLongsCalled); } } diff --git a/java/fory-core/src/test/java/org/apache/fory/memory/MemoryBufferTest.java b/java/fory-core/src/test/java/org/apache/fory/memory/MemoryBufferTest.java index a7f2f062f0..1054c2a018 100644 --- a/java/fory-core/src/test/java/org/apache/fory/memory/MemoryBufferTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/memory/MemoryBufferTest.java @@ -84,6 +84,13 @@ public void testBufferWrite() { assertEquals(buffer.readerIndex(), buffer.writerIndex()); } + @Test + public void testReadCharsAndSizeRequiresBodyBeforeAlloc() { + MemoryBuffer buffer = MemoryUtils.buffer(8); + buffer.writeVarUInt32(1 << 20); + assertThrows(IndexOutOfBoundsException.class, buffer::readCharsAndSize); + } + @Test public void testDirectBufferRejectsHeap() { assertThrows( @@ -577,6 +584,14 @@ public void testWritePrimitiveArrayWithSizeEmbedded() { assertEquals(chars, readChars); } + @Test + public void testReadCharsAndSizeAlignment() { + MemoryBuffer buf = MemoryUtils.buffer(8); + buf.writeVarUInt32(3); + buf.writeBytes(new byte[] {1, 2, 3}); + assertThrows(IllegalArgumentException.class, buf::readCharsAndSize); + } + @Test public void testWriteVarUInt32() { for (int i = 0; i < 32; i++) { diff --git a/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java b/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java index 8c11ce4edc..2d6bbce13e 100644 --- 
a/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/meta/NativeTypeDefEncoderTest.java @@ -72,25 +72,6 @@ public void testBigMetaEncoding() { } } - @Test - public void testTypeDefCountIgnoresLimit() { - Fory writer = Fory.builder().withXlang(false).withMetaShare(true).withCompatible(false).build(); - Fory reader = - Fory.builder() - .withXlang(false) - .withMetaShare(true) - .withMaxCollectionSize(1) - .withCompatible(false) - .build(); - TypeDef typeDef = - TypeDef.buildTypeDef(writer.getTypeResolver(), TypeDefTest.TestFieldsOrderClass1.class); - - TypeDef decoded = - TypeDef.readTypeDef( - reader.getTypeResolver(), MemoryBuffer.fromByteArray(typeDef.getEncoded())); - Assert.assertEquals(decoded, typeDef); - } - @Test public void testTypeDefArrayDimensionLimit() { Fory fory = Fory.builder().withXlang(false).withCompatible(false).build(); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java index 5e9b3225f1..b0fb46f0f4 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ArraySerializersTest.java @@ -209,19 +209,18 @@ public void testObjectArrayCopy(Fory fory) { } @Test - public void testObjectArrayReadRejectsOversizedElementCount() { + public void testObjectArrayReadRequiresBodyByte() { Fory fory = Fory.builder() .withXlang(false) .withRefTracking(true) .requireClassRegistration(false) - .withMaxCollectionSize(1) .withCompatible(false) .build(); assertThrows( - DeserializationException.class, () -> readObjectArrayPayload(fory, Object[].class, 2)); + IndexOutOfBoundsException.class, () -> readObjectArrayBody(fory, Object[].class, 2)); assertThrows( - DeserializationException.class, () -> readObjectArrayPayload(fory, String[].class, 2)); + 
IndexOutOfBoundsException.class, () -> readObjectArrayBody(fory, String[].class, 2)); } @Test(dataProvider = "crossLanguageReferenceTrackingConfig") @@ -310,51 +309,20 @@ public static void testPrimitiveArray(Fory fory1, Fory fory2) { } @Test - public void testPrimitiveArrayReadRejectsOversizedBinaryPayload() { - Fory fory = - Fory.builder() - .withXlang(false) - .withMaxBinarySize(4) - .withIntArrayCompressed(true) - .withLongArrayCompressed(true) - .withCompatible(false) - .build(); - for (Class arrayType : - new Class[] { - boolean[].class, - byte[].class, - char[].class, - short[].class, - int[].class, - long[].class, - float[].class, - double[].class - }) { - assertThrows( - DeserializationException.class, - () -> readPrimitiveArrayPayload(fory, arrayType, 8, false)); - } - assertThrows( - DeserializationException.class, - () -> readPrimitiveArrayPayload(fory, byte[].class, 5, true)); - } - - @Test - public void testPrimitiveArrayReadRejectsUnalignedBinaryPayload() { - Fory fory = Fory.builder().withXlang(false).withMaxBinarySize(64).withCompatible(false).build(); + public void testPrimitiveArrayReadRejectsUnalignedBinaryBody() { + Fory fory = Fory.builder().withXlang(false).withCompatible(false).build(); for (Class arrayType : new Class[] { char[].class, short[].class, int[].class, long[].class, float[].class, double[].class }) { assertThrows( - DeserializationException.class, - () -> readPrimitiveArrayPayload(fory, arrayType, 3, false)); + DeserializationException.class, () -> readPrimitiveArrayBody(fory, arrayType, 3, false)); } } @Test - public void testPrimitiveArrayReadRejectsTruncatedPayload() { - Fory fory = Fory.builder().withXlang(false).withMaxBinarySize(64).withCompatible(false).build(); + public void testPrimitiveArrayReadRejectsTruncatedBody() { + Fory fory = Fory.builder().withXlang(false).withCompatible(false).build(); Class[] arrayTypes = new Class[] { boolean[].class, @@ -372,16 +340,16 @@ public void 
testPrimitiveArrayReadRejectsTruncatedPayload() { int byteSize = byteSizes[i]; assertThrows( IndexOutOfBoundsException.class, - () -> readTruncatedPrimitiveArrayPayload(fory, arrayType, byteSize)); + () -> readTruncatedPrimitiveArrayBody(fory, arrayType, byteSize)); } } @Test - public void testPrimitiveArrayReadRejectsNegativeDecodedBinaryPayload() { + public void testPrimitiveArrayReadRejectsNegativeDecodedBinaryBody() { Fory fixedWidthFory = Fory.builder().withXlang(false).withCompatible(false).build(); assertThrows( DeserializationException.class, - () -> readPrimitiveArrayRawPayload(fixedWidthFory, char[].class)); + () -> readPrimitiveArrayRawBody(fixedWidthFory, char[].class)); Fory compressedFory = Fory.builder() @@ -392,13 +360,13 @@ public void testPrimitiveArrayReadRejectsNegativeDecodedBinaryPayload() { .build(); assertThrows( DeserializationException.class, - () -> readPrimitiveArrayRawPayload(compressedFory, int[].class)); + () -> readPrimitiveArrayRawBody(compressedFory, int[].class)); assertThrows( DeserializationException.class, - () -> readPrimitiveArrayRawPayload(compressedFory, long[].class)); + () -> readPrimitiveArrayRawBody(compressedFory, long[].class)); } - private static Object readPrimitiveArrayPayload( + private static Object readPrimitiveArrayBody( Fory fory, Class arrayType, int byteSize, boolean outOfBand) { ReadContext readContext = fory.getReadContext(); if (outOfBand) { @@ -414,7 +382,7 @@ private static Object readPrimitiveArrayPayload( return fory.getSerializer(arrayType).read(readContext); } - private static Object readTruncatedPrimitiveArrayPayload( + private static Object readTruncatedPrimitiveArrayBody( Fory fory, Class arrayType, int byteSize) { ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); @@ -424,7 +392,7 @@ private static Object readTruncatedPrimitiveArrayPayload( return fory.getSerializer(arrayType).read(readContext); } - private static Object 
readPrimitiveArrayRawPayload(Fory fory, Class arrayType) { + private static Object readPrimitiveArrayRawBody(Fory fory, Class arrayType) { ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); writeNegativeDecodedVarUInt32(buffer); @@ -432,7 +400,7 @@ private static Object readPrimitiveArrayRawPayload(Fory fory, Class arrayType return fory.getSerializer(arrayType).read(readContext); } - private static Object readObjectArrayPayload(Fory fory, Class arrayType, int numElements) { + private static Object readObjectArrayBody(Fory fory, Class arrayType, int numElements) { ReadContext readContext = fory.getReadContext(); MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(numElements); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java index f9604bb1d9..cd3d31dcac 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/CompatibleSerializerTest.java @@ -19,6 +19,8 @@ package org.apache.fory.serializer; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -31,6 +33,8 @@ import org.apache.fory.ForyTestBase; import org.apache.fory.TestUtils; import org.apache.fory.config.Language; +import org.apache.fory.memory.MemoryBuffer; +import org.apache.fory.memory.MemoryUtils; import org.apache.fory.serializer.collection.UnmodifiableSerializersTest; import org.apache.fory.test.bean.BeanA; import org.apache.fory.test.bean.BeanB; @@ -38,6 +42,7 @@ import org.apache.fory.test.bean.Foo; import org.apache.fory.test.bean.MapFields; import org.apache.fory.test.bean.Struct; +import org.apache.fory.type.Types; import org.testng.Assert; import 
org.testng.annotations.Test; @@ -129,6 +134,20 @@ public void testWriteCompatibleBasic() throws Exception { } } + @Test + public void testNullableListBodyBounds() throws Exception { + Method method = + CompatibleCollectionArrayReader.class.getDeclaredMethod( + "readNullableListBoxedElements", MemoryBuffer.class, int.class, int.class, int.class); + method.setAccessible(true); + MemoryBuffer buffer = MemoryUtils.buffer(0); + InvocationTargetException exception = + Assert.expectThrows( + InvocationTargetException.class, + () -> method.invoke(null, buffer, 1024, Types.INT32_ARRAY, Types.INT32)); + Assert.assertTrue(exception.getCause() instanceof IndexOutOfBoundsException); + } + @Test public void testWriteNestedCollection() throws Exception { Fory fory = diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java index 9d9b914cb5..2b0505e077 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/ExceptionSerializersTest.java @@ -60,24 +60,6 @@ public void testBuiltInThrowableRoundTrip(Fory fory) { Assert.assertEquals(copy.getSuppressed()[1].getMessage(), "suppressed-2"); } - @Test - public void testSuppressedCountLimit() { - Fory writer = Fory.builder().withXlang(false).withCompatible(false).build(); - Fory reader = - Fory.builder().withXlang(false).withMaxCollectionSize(1).withCompatible(false).build(); - RuntimeException value = new RuntimeException("outer"); - RuntimeException suppressed1 = new RuntimeException("suppressed-1"); - RuntimeException suppressed2 = new RuntimeException("suppressed-2"); - value.setStackTrace(new StackTraceElement[0]); - suppressed1.setStackTrace(new StackTraceElement[0]); - suppressed2.setStackTrace(new StackTraceElement[0]); - value.addSuppressed(suppressed1); - value.addSuppressed(suppressed2); - byte[] 
bytes = writer.serialize(value); - - Assert.assertThrows(ForyException.class, () -> reader.deserialize(bytes)); - } - @Test(dataProvider = "javaFory") public void testStackTraceElementRoundTrip(Fory fory) { StackTraceElement value = new Exception().getStackTrace()[0]; diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/LambdaSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/LambdaSerializerTest.java index 8327115d9d..ae40bdb812 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/LambdaSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/LambdaSerializerTest.java @@ -33,7 +33,6 @@ import java.util.function.Function; import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; -import org.apache.fory.exception.DeserializationException; import org.apache.fory.exception.InsecureException; import org.testng.Assert; import org.testng.annotations.Test; @@ -129,28 +128,6 @@ public void testSerializedLambdaAdmission() throws Exception { Assert.assertThrows(InsecureException.class, () -> reader.deserialize(bytes)); } - @Test - public void testSerializedLambdaArgLimit() throws Exception { - int delta = 7; - Function function = - (Serializable & Function) (x) -> x + delta; - Fory writer = - Fory.builder() - .withXlang(false) - .requireClassRegistration(false) - .withCompatible(false) - .build(); - Fory reader = - Fory.builder() - .withXlang(false) - .requireClassRegistration(false) - .withMaxCollectionSize(0) - .withCompatible(false) - .build(); - byte[] bytes = writer.serialize(extractSerializedLambda(function)); - Assert.assertThrows(DeserializationException.class, () -> reader.deserialize(bytes)); - } - @Test(dataProvider = "foryCopyConfig") public void testSerializedLambdaCopy(Fory fory) throws Exception { int delta = 7; diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java 
b/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java index 4c18e28a84..bc945d60ff 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/PrimitiveSerializersTest.java @@ -269,31 +269,24 @@ public void testPrimitiveListAsCollectionFieldWithCodegen() { } @Test - public void testPrimitiveListReadRejectsMalformedBinaryPayloadSize() { + public void testPrimitiveListReadRejectsMalformedBinaryBodySize() { Fory fory = Fory.builder() .withXlang(false) - .withMaxBinarySize(4) .withIntArrayCompressed(true) .withLongArrayCompressed(true) .withCompatible(false) .build(); assertThrows( - DeserializationException.class, () -> readPrimitiveListPayload(fory, Int8List.class, 5)); - assertThrows( - DeserializationException.class, () -> readPrimitiveListPayload(fory, Int16List.class, 3)); - assertThrows( - DeserializationException.class, () -> readPrimitiveListPayload(fory, Int32List.class, 2)); - assertThrows( - DeserializationException.class, () -> readPrimitiveListPayload(fory, Int64List.class, 1)); + DeserializationException.class, () -> readPrimitiveListBody(fory, Int16List.class, 3)); } @Test - public void testPrimitiveListReadRejectsNegativeDecodedBinaryPayload() { + public void testPrimitiveListReadRejectsNegativeDecodedBinaryBody() { Fory fixedWidthFory = Fory.builder().withXlang(false).withCompatible(false).build(); assertThrows( DeserializationException.class, - () -> readPrimitiveListRawPayload(fixedWidthFory, Int16List.class)); + () -> readPrimitiveListRawBody(fixedWidthFory, Int16List.class)); Fory compressedFory = Fory.builder() @@ -304,13 +297,13 @@ public void testPrimitiveListReadRejectsNegativeDecodedBinaryPayload() { .build(); assertThrows( DeserializationException.class, - () -> readPrimitiveListRawPayload(compressedFory, Int32List.class)); + () -> readPrimitiveListRawBody(compressedFory, Int32List.class)); assertThrows( 
DeserializationException.class, - () -> readPrimitiveListRawPayload(compressedFory, Int64List.class)); + () -> readPrimitiveListRawBody(compressedFory, Int64List.class)); } - private static Object readPrimitiveListPayload(Fory fory, Class listType, int headerSize) { + private static Object readPrimitiveListBody(Fory fory, Class listType, int headerSize) { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); buffer.writeVarUInt32Small7(headerSize); ReadContext readContext = fory.getReadContext(); @@ -318,7 +311,7 @@ private static Object readPrimitiveListPayload(Fory fory, Class listType, int return fory.getSerializer(listType).read(readContext); } - private static Object readPrimitiveListRawPayload(Fory fory, Class listType) { + private static Object readPrimitiveListRawBody(Fory fory, Class listType) { MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(5); writeNegativeDecodedVarUInt32(buffer); ReadContext readContext = fory.getReadContext(); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java index 81e29cdee3..dec19e459c 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/SerializersTest.java @@ -50,7 +50,6 @@ import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; import org.apache.fory.config.ForyBuilder; -import org.apache.fory.exception.DeserializationException; import org.apache.fory.exception.InsecureException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; @@ -93,42 +92,6 @@ public void testBigInt(boolean referenceTracking) { fory1, new BigInteger("11111111110101010000283895380202208220050200000000111111111")); } - @Test - public void testBigNumberReadsRejectOversizedBinaryPayload() { - Fory fory = - Fory.builder() - .withXlang(false) - .withMaxBinarySize(1) - .requireClassRegistration(false) - 
.withCompatible(false) - .build(); - - assertThrows( - DeserializationException.class, - () -> readSerializer(fory, fory.getSerializer(BigInteger.class), bigIntegerPayload(2))); - assertThrows( - DeserializationException.class, - () -> readSerializer(fory, fory.getSerializer(BigDecimal.class), bigDecimalPayload(2))); - - Fory xlangFory = - Fory.builder() - .withXlang(true) - .withCompatible(false) - .withMaxBinarySize(1) - .requireClassRegistration(false) - .build(); - assertThrows( - DeserializationException.class, - () -> - readSerializer( - xlangFory, xlangFory.getSerializer(BigInteger.class), xlangDecimalPayload(2))); - assertThrows( - DeserializationException.class, - () -> - readSerializer( - xlangFory, xlangFory.getSerializer(BigDecimal.class), xlangDecimalPayload(2))); - } - private static MemoryBuffer bigIntegerPayload(int len) { MemoryBuffer buffer = MemoryUtils.buffer(16); buffer.writeVarUInt32Small7(len); diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/StringSerializerTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/StringSerializerTest.java index c484e4149b..756c00e554 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/StringSerializerTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/StringSerializerTest.java @@ -33,7 +33,6 @@ import org.apache.fory.Fory; import org.apache.fory.ForyTestBase; import org.apache.fory.collection.Tuple2; -import org.apache.fory.exception.DeserializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; import org.apache.fory.platform.JdkVersion; @@ -50,6 +49,17 @@ public static Object[][] stringCompress() { return new Object[][] {{false}, {true}}; } + @Test + public void testRejectOddUtf16ByteSize() { + Fory fory = Fory.builder().build(); + StringSerializer serializer = new StringSerializer(fory.getConfig()); + MemoryBuffer buffer = MemoryBuffer.newHeapBuffer(8); + int writerIndex = 
buffer._unsafePutVarUint36Small(0, (3L << 2) | 1); + buffer._unsafeWriterIndex(writerIndex); + buffer.writeBytes(new byte[] {1, 2, 3}); + Assert.assertThrows(IllegalArgumentException.class, () -> serializer.readString(buffer)); + } + @Test public void testJavaStringZeroCopy() { if (JdkVersion.MAJOR_VERSION >= 17) { @@ -167,19 +177,6 @@ public void testJavaStringSimple() { } } - @Test - public void testStringSizeLimit() { - Fory writer = Fory.builder().withXlang(false).withCompatible(false).build(); - Fory reader = - Fory.builder().withXlang(false).withMaxBinarySize(2).withCompatible(false).build(); - MemoryBuffer buffer = MemoryUtils.buffer(32); - new StringSerializer(writer.getConfig()).writeString(buffer, "abcd"); - - Assert.assertThrows( - DeserializationException.class, - () -> new StringSerializer(reader.getConfig()).readString(buffer)); - } - @Data public static class Simple { private String str; diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java index c5ab5088a6..c4d5a35730 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/CollectionSerializersTest.java @@ -1136,24 +1136,22 @@ public void testSerializeJavaBlockingQueue() { .withCompatible(false) .build(); { - ArrayBlockingQueue queue = new ArrayBlockingQueue<>(10); + ArrayBlockingQueue queue = new ArrayBlockingQueue<>(3); queue.add(1); queue.add(2); queue.add(3); ArrayBlockingQueue deserialized = serDe(fory, queue); assertEquals(new ArrayList<>(deserialized), new ArrayList<>(queue)); - // Verify capacity is preserved - assertEquals(deserialized.remainingCapacity() + deserialized.size(), 10); + assertEquals(deserialized.remainingCapacity() + deserialized.size(), 3); } { - LinkedBlockingQueue queue = new 
LinkedBlockingQueue<>(10); + LinkedBlockingQueue queue = new LinkedBlockingQueue<>(3); queue.add(1); queue.add(2); queue.add(3); LinkedBlockingQueue deserialized = serDe(fory, queue); assertEquals(new ArrayList<>(deserialized), new ArrayList<>(queue)); - // Verify capacity is preserved - assertEquals(deserialized.remainingCapacity() + deserialized.size(), 10); + assertEquals(deserialized.remainingCapacity() + deserialized.size(), 3); } } @@ -1164,39 +1162,37 @@ public void testBlockingQueueBadCapacity() { .withXlang(false) .withRefTracking(true) .requireClassRegistration(false) - .withMaxCollectionSize(4) .withCompatible(false) .build(); - CollectionSerializers.ArrayBlockingQueueSerializer arraySerializer = - new CollectionSerializers.ArrayBlockingQueueSerializer( - fory.getTypeResolver(), ArrayBlockingQueue.class); CollectionSerializers.LinkedBlockingQueueSerializer linkedSerializer = new CollectionSerializers.LinkedBlockingQueueSerializer( fory.getTypeResolver(), LinkedBlockingQueue.class); - MemoryBuffer oversizedCapacity = MemoryUtils.buffer(8); - oversizedCapacity.writeVarUInt32Small7(2); - oversizedCapacity.writeVarUInt32Small7(5); - Assert.expectThrows( - DeserializationException.class, - () -> withReadContext(fory, oversizedCapacity, arraySerializer::newCollection)); - MemoryBuffer undersizedCapacity = MemoryUtils.buffer(8); undersizedCapacity.writeVarUInt32Small7(2); undersizedCapacity.writeVarUInt32Small7(1); Assert.expectThrows( DeserializationException.class, () -> withReadContext(fory, undersizedCapacity, linkedSerializer::newCollection)); + + CollectionSerializers.ArrayBlockingQueueSerializer arraySerializer = + new CollectionSerializers.ArrayBlockingQueueSerializer( + fory.getTypeResolver(), ArrayBlockingQueue.class); + MemoryBuffer sparseCapacity = MemoryUtils.buffer(8); + sparseCapacity.writeVarUInt32Small7(2); + sparseCapacity.writeVarUInt32Small7(10); + Assert.expectThrows( + IndexOutOfBoundsException.class, + () -> withReadContext(fory, 
sparseCapacity, arraySerializer::newCollection)); } @Test - public void testCollectionRejectsTooManyElements() { + public void testCollectionReadRequiresBodyByte() { Fory fory = Fory.builder() .withXlang(false) .withRefTracking(true) .requireClassRegistration(false) - .withMaxCollectionSize(1) .withCompatible(false) .build(); CollectionSerializers.ArrayListSerializer serializer = @@ -1204,7 +1200,7 @@ public void testCollectionRejectsTooManyElements() { MemoryBuffer buffer = MemoryUtils.buffer(8); buffer.writeVarUInt32Small7(2); Assert.expectThrows( - DeserializationException.class, + IndexOutOfBoundsException.class, () -> withReadContext(fory, buffer, serializer::newCollection)); } diff --git a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java index 2e7061f0f3..bb5f721c59 100644 --- a/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/serializer/collection/MapSerializersTest.java @@ -58,7 +58,6 @@ import org.apache.fory.annotation.Ref; import org.apache.fory.collection.LazyMap; import org.apache.fory.collection.MapEntry; -import org.apache.fory.exception.DeserializationException; import org.apache.fory.exception.SerializationException; import org.apache.fory.memory.MemoryBuffer; import org.apache.fory.memory.MemoryUtils; @@ -675,13 +674,12 @@ public void testEmptyMap() { } @Test - public void testMapReadRejectsOversizedElementCount() { + public void testMapReadRequiresBodyByte() { Fory fory = Fory.builder() .withXlang(false) .withRefTracking(true) .requireClassRegistration(false) - .withMaxCollectionSize(1) .withCompatible(false) .build(); MapSerializers.HashMapSerializer serializer = @@ -689,7 +687,7 @@ public void testMapReadRejectsOversizedElementCount() { MemoryBuffer buffer = MemoryUtils.buffer(8); buffer.writeVarUInt32Small7(2); 
Assert.expectThrows( - DeserializationException.class, () -> withReadContext(fory, buffer, serializer::newMap)); + IndexOutOfBoundsException.class, () -> withReadContext(fory, buffer, serializer::newMap)); } @Test(dataProvider = "foryCopyConfig") diff --git a/javascript/packages/core/lib/compatible/scalar.ts b/javascript/packages/core/lib/compatible/scalar.ts index 56759ec06d..5d78cbccc3 100644 --- a/javascript/packages/core/lib/compatible/scalar.ts +++ b/javascript/packages/core/lib/compatible/scalar.ts @@ -156,11 +156,14 @@ function readDecimal(reader: BinaryReader): Decimal { if (length <= 0 || length > 0x7fffffff) { throw new Error(`Invalid decimal magnitude length ${length}.`); } - const payload = reader.buffer(length); - if (payload[length - 1] === 0) { - throw new Error("Non-canonical decimal payload: trailing zero byte."); + const magnitudeBytes = reader.buffer(length); + if (magnitudeBytes[length - 1] === 0) { + throw new Error( + "Non-canonical decimal magnitude bytes: trailing zero byte.", + ); } - const magnitude = DecimalCodec.fromCanonicalLittleEndianMagnitude(payload); + const magnitude + = DecimalCodec.fromCanonicalLittleEndianMagnitude(magnitudeBytes); if (magnitude === 0n) { throw new Error("Big decimal encoding must not represent zero."); } diff --git a/javascript/packages/core/lib/context.ts b/javascript/packages/core/lib/context.ts index d13a08138c..466d50ece1 100644 --- a/javascript/packages/core/lib/context.ts +++ b/javascript/packages/core/lib/context.ts @@ -346,8 +346,6 @@ export class WriteContext { private disposeTypeInfo: TypeInfo[] = []; private dynamicTypeId = 0; - private _maxBinarySize: number; - private _maxCollectionSize: number; constructor( readonly typeResolver: TypeResolverLike, @@ -356,8 +354,6 @@ export class WriteContext { this.writer = new BinaryWriter(config); this.refWriter = new RefWriter(); this.metaStringWriter = new MetaStringWriter(); - this._maxBinarySize = config.maxBinarySize ?? 
64 * 1024 * 1024; - this._maxCollectionSize = config.maxCollectionSize ?? 1_000_000; } reset() { @@ -371,24 +367,6 @@ export class WriteContext { this.dynamicTypeId = 0; } - checkCollectionSize(size: number) { - if (size > this._maxCollectionSize) { - throw new Error( - `Collection size ${size} exceeds maxCollectionSize ${this._maxCollectionSize}. ` - + "The data may be malicious, or increase maxCollectionSize if needed.", - ); - } - } - - checkBinarySize(size: number) { - if (size > this._maxBinarySize) { - throw new Error( - `Binary size ${size} exceeds maxBinarySize ${this._maxBinarySize}. ` - + "The data may be malicious, or increase maxBinarySize if needed.", - ); - } - } - isCompatible() { return this.typeResolver.isCompatible(); } @@ -549,14 +527,6 @@ export class WriteContext { setUint32Position(offset: number, value: number) { this.writer.setUint32Position(offset, value); } - - get maxBinarySize() { - return this._maxBinarySize; - } - - get maxCollectionSize() { - return this._maxCollectionSize; - } } export class ReadContext { @@ -584,8 +554,6 @@ export class ReadContext { private _depth = 0; private _maxDepth: number; - private _maxBinarySize: number; - private _maxCollectionSize: number; private static typeMetaHeaderHash(headerLow: number, headerHigh: number) { return headerHigh * 0x100000 + (headerLow >>> 12); @@ -641,8 +609,6 @@ export class ReadContext { this.refReader = new RefReader(this.reader); this.metaStringReader = new MetaStringReader(); this._maxDepth = config.maxDepth ?? 50; - this._maxBinarySize = config.maxBinarySize ?? 64 * 1024 * 1024; - this._maxCollectionSize = config.maxCollectionSize ?? 1_000_000; } reset(bytes: Uint8Array) { @@ -671,24 +637,6 @@ export class ReadContext { this._depth--; } - checkCollectionSize(size: number) { - if (size > this._maxCollectionSize) { - throw new Error( - `Collection size ${size} exceeds maxCollectionSize ${this._maxCollectionSize}. 
` - + "The data may be malicious, or increase maxCollectionSize if needed.", - ); - } - } - - checkBinarySize(size: number) { - if (size > this._maxBinarySize) { - throw new Error( - `Binary size ${size} exceeds maxBinarySize ${this._maxBinarySize}. ` - + "The data may be malicious, or increase maxBinarySize if needed.", - ); - } - } - readRefFlag() { return this.refReader.readRefFlag(); } @@ -1388,12 +1336,4 @@ export class ReadContext { get maxDepth() { return this._maxDepth; } - - get maxBinarySize() { - return this._maxBinarySize; - } - - get maxCollectionSize() { - return this._maxCollectionSize; - } } diff --git a/javascript/packages/core/lib/fory.ts b/javascript/packages/core/lib/fory.ts index 50a367a14c..3c122d790b 100644 --- a/javascript/packages/core/lib/fory.ts +++ b/javascript/packages/core/lib/fory.ts @@ -34,9 +34,6 @@ import { ReadContext, WriteContext } from "./context"; const DEFAULT_DEPTH_LIMIT = 50 as const; const MIN_DEPTH_LIMIT = 2 as const; -const DEFAULT_MAX_COLLECTION_SIZE = 1_000_000 as const; -const DEFAULT_MAX_BINARY_SIZE = 64 * 1024 * 1024; // 64 MiB - export default class Fory { readonly typeResolver: TypeResolver; readonly anySerializer: Serializer; @@ -61,20 +58,6 @@ export default class Fory { `maxDepth must be an integer >= ${MIN_DEPTH_LIMIT} but got ${maxDepth}`, ); } - const maxBinarySize = this.config.maxBinarySize ?? DEFAULT_MAX_BINARY_SIZE; - if (!Number.isInteger(maxBinarySize) || maxBinarySize < 0) { - throw new Error( - `maxBinarySize must be a non-negative integer but got ${maxBinarySize}`, - ); - } - const maxCollectionSize - = this.config.maxCollectionSize ?? 
DEFAULT_MAX_COLLECTION_SIZE;
-    if (!Number.isInteger(maxCollectionSize) || maxCollectionSize < 0) {
-      throw new Error(
-        `maxCollectionSize must be a non-negative integer but got ${maxCollectionSize}`,
-      );
-    }
-
     this.typeResolver = new TypeResolver(this.config);
     this.writeContext = new WriteContext(this.typeResolver, this.config);
     this.readContext = new ReadContext(this.typeResolver, this.config);
@@ -88,8 +71,6 @@ export default class Fory {
       ref: Boolean(config?.ref),
       useSliceString: Boolean(config?.useSliceString),
       maxDepth: config?.maxDepth,
-      maxBinarySize: config?.maxBinarySize,
-      maxCollectionSize: config?.maxCollectionSize,
       hooks: config?.hooks || {},
       compatible: config?.compatible ?? true,
       hps: config?.hps,
diff --git a/javascript/packages/core/lib/gen/array.ts b/javascript/packages/core/lib/gen/array.ts
index 7027353f18..111b66a27e 100644
--- a/javascript/packages/core/lib/gen/array.ts
+++ b/javascript/packages/core/lib/gen/array.ts
@@ -41,7 +41,8 @@ class ArraySerializerGenerator extends CollectionSerializerGenerator {
   }
   newCollection(lenAccessor: string): string {
-    return `new Array(${lenAccessor})`;
+    void lenAccessor;
+    return `[]`;
   }
   putAccessor(result: string, item: string, index: string): string {
diff --git a/javascript/packages/core/lib/gen/builder.ts b/javascript/packages/core/lib/gen/builder.ts
index 7f82c7c101..fe60dc1286 100644
--- a/javascript/packages/core/lib/gen/builder.ts
+++ b/javascript/packages/core/lib/gen/builder.ts
@@ -36,6 +36,10 @@ export class BinaryReaderBuilder {
     return `${this.holder}.readSetCursor(${v})`;
   }
+  checkReadableBytes(len: string | number) {
+    return `${this.holder}.checkReadableBytes(${len})`;
+  }
+
   getDataView() {
     return `${this.holder}.getDataView()`;
   }
diff --git a/javascript/packages/core/lib/gen/collection.ts b/javascript/packages/core/lib/gen/collection.ts
index e3607bff62..74344fc00f 100644
--- a/javascript/packages/core/lib/gen/collection.ts
+++ b/javascript/packages/core/lib/gen/collection.ts
@@ -64,7
+64,10 @@ export const CollectionFlags = {
   SAME_TYPE: 0b1000,
 };
-function compatibleArrayCollectionExpr(elementTypeId: number, len: string): string {
+function compatibleArrayCollectionExpr(
+  elementTypeId: number,
+  len: string,
+): string {
   switch (elementTypeId) {
     case TypeId.BOOL:
       return `new external.BoolArray(${len})`;
@@ -242,12 +245,11 @@ class CollectionAnySerializer {
   ): any {
     void fromRef;
     const len = this.readContext.reader.readVarUint32Small7();
-    const result = createCollection(len);
     if (len === 0) {
-      return result;
+      return createCollection(len);
     }
-    this.readContext.checkCollectionSize(len);
     const flags = this.readContext.reader.readUint8();
+    const result = createCollection(len);
     // IMPORTANT: collection readers must obey the ref/null bits written on the
     // wire, not local TypeScript metadata that may imply a different ref
     // policy. Shared xlang tests intentionally deserialize one ref policy and
@@ -447,14 +449,22 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera
     const useDeclaredStructElementReader = TypeId.structType(
       this.innerGenerator.getTypeId()!,
     );
-    const compatibleReadAction = getCompatibleCollectionArrayReadAction(this.typeInfo);
+    const compatibleReadAction = getCompatibleCollectionArrayReadAction(
+      this.typeInfo,
+    );
     const compatibleListToArray = compatibleReadAction?.target === "array";
     const newCollection = compatibleListToArray
       ? compatibleArrayCollectionExpr(compatibleReadAction!.elementTypeId, len)
       : this.newCollection(len);
-    const putAccessor = (item: string, index: string) => compatibleListToArray
-      ? compatibleArrayPutAccessor(compatibleReadAction!.elementTypeId, result, item, index)
-      : this.putAccessor(result, item, index);
+    const putAccessor = (item: string, index: string) =>
+      compatibleListToArray
+        ?
compatibleArrayPutAccessor(
+            compatibleReadAction!.elementTypeId,
+            result,
+            item,
+            index,
+          )
+        : this.putAccessor(result, item, index);
     const rejectCompatiblePayload = compatibleListToArray
       ? `
         if (${flags} & (${CollectionFlags.HAS_NULL} | ${CollectionFlags.TRACKING_REF})) {
@@ -479,18 +489,21 @@ export abstract class CollectionSerializerGenerator extends BaseSerializerGenera
         : innerReader.readWithDepth(assignStmt, refState);
     };
     const readElementTypeInfo = useDeclaredStructElementReader
-      ? this.innerGenerator.readEmbed().readTypeInfo(
-          (expr: string) => `${elemSerializer} = ${expr};`,
-        )
+      ? this.innerGenerator
+          .readEmbed()
+          .readTypeInfo((expr: string) => `${elemSerializer} = ${expr};`)
       : `${elemSerializer} = ${anyHelper}.detectSerializer(${readContextName});`;
     return `
       const ${len} = ${this.builder.reader.readVarUint32Small7()};
-      ${this.builder.getReadContextName()}.checkCollectionSize(${len});
+      let ${flags} = 0;
+      if (${len} > 0) {
+        ${flags} = ${this.builder.reader.readUint8()};
+        ${rejectCompatiblePayload}
+        ${this.builder.reader.checkReadableBytes(len)}
+      }
       const ${result} = ${newCollection};
       ${this.maybeReference(result, refState)}
       if (${len} > 0) {
-        const ${flags} = ${this.builder.reader.readUint8()};
-        ${rejectCompatiblePayload}
         let ${elemSerializer} = null;
         if (!(${flags} & ${CollectionFlags.DECL_ELEMENT_TYPE})) {
           ${readElementTypeInfo}
diff --git a/javascript/packages/core/lib/gen/decimal.ts b/javascript/packages/core/lib/gen/decimal.ts
index e0f81115d0..14adc5f4b1 100644
--- a/javascript/packages/core/lib/gen/decimal.ts
+++ b/javascript/packages/core/lib/gen/decimal.ts
@@ -37,7 +37,7 @@ class DecimalSerializerGenerator extends BaseSerializerGenerator {
     const codec = this.builder.getExternal(DecimalCodec.name);
     const scale = this.scope.uniqueName("decimal_scale");
     const unscaled = this.scope.uniqueName("decimal_unscaled");
-    const payload = this.scope.uniqueName("decimal_payload");
+    const magnitudeBytes =
this.scope.uniqueName("decimal_magnitude_bytes");
     const meta = this.scope.uniqueName("decimal_meta");
     return `
       const ${scale} = ${accessor}.scale;
@@ -46,10 +46,10 @@ class DecimalSerializerGenerator extends BaseSerializerGenerator {
       if (${codec}.canUseSmallEncoding(${unscaled})) {
         ${this.builder.writer.writeVarUInt64(`(${codec}.encodeZigZag64(${unscaled}) << 1n)`)}
       } else {
-        const ${payload} = ${codec}.toCanonicalLittleEndianMagnitude(${unscaled});
-        const ${meta} = (BigInt(${payload}.length) << 1n) | (${unscaled} < 0n ? 1n : 0n);
+        const ${magnitudeBytes} = ${codec}.toCanonicalLittleEndianMagnitude(${unscaled});
+        const ${meta} = (BigInt(${magnitudeBytes}.length) << 1n) | (${unscaled} < 0n ? 1n : 0n);
         ${this.builder.writer.writeVarUInt64(`((${meta} << 1n) | 1n)`)}
-        ${this.builder.writer.buffer(payload)}
+        ${this.builder.writer.buffer(magnitudeBytes)}
       }
     `;
   }
@@ -61,7 +61,7 @@ class DecimalSerializerGenerator extends BaseSerializerGenerator {
     const header = this.scope.uniqueName("decimal_header");
     const meta = this.scope.uniqueName("decimal_meta");
     const length = this.scope.uniqueName("decimal_length");
-    const payload = this.scope.uniqueName("decimal_payload");
+    const magnitudeBytes = this.scope.uniqueName("decimal_magnitude_bytes");
     const magnitude = this.scope.uniqueName("decimal_magnitude");
     const unscaled = this.scope.uniqueName("decimal_unscaled");
     return `
@@ -75,11 +75,11 @@ class DecimalSerializerGenerator extends BaseSerializerGenerator {
         if (${length} <= 0 || ${length} > 0x7fffffff) {
           throw new Error(\`Invalid decimal magnitude length \${${length}}.\`);
         }
-        const ${payload} = ${this.builder.reader.buffer(length)};
-        if (${payload}[${length} - 1] === 0) {
-          throw new Error("Non-canonical decimal payload: trailing zero byte.");
+        const ${magnitudeBytes} = ${this.builder.reader.buffer(length)};
+        if (${magnitudeBytes}[${length} - 1] === 0) {
+          throw new Error("Non-canonical decimal magnitude bytes: trailing zero byte.");
         }
-        const ${magnitude} =
${codec}.fromCanonicalLittleEndianMagnitude(${payload});
+        const ${magnitude} = ${codec}.fromCanonicalLittleEndianMagnitude(${magnitudeBytes});
         if (${magnitude} === 0n) {
           throw new Error("Big decimal encoding must not represent zero.");
         }
diff --git a/javascript/packages/core/lib/gen/map.ts b/javascript/packages/core/lib/gen/map.ts
index 4c3070f4de..c447691d40 100644
--- a/javascript/packages/core/lib/gen/map.ts
+++ b/javascript/packages/core/lib/gen/map.ts
@@ -298,7 +298,6 @@ class MapAnySerializer {
   read(fromRef: boolean): any {
     let count = this.readContext.reader.readVarUint32Small7();
-    this.readContext.checkCollectionSize(count);
     const result = new Map();
     if (fromRef) {
       this.readContext.reference(result);
@@ -528,7 +527,6 @@ export class MapSerializerGenerator extends BaseSerializerGenerator {
     return `
       let ${count} = ${this.builder.reader.readVarUint32Small7()};
-      ${this.builder.getReadContextName()}.checkCollectionSize(${count});
       const ${result} = new Map();
       if (${refState}) {
         ${this.builder.referenceResolver.reference(result)}
diff --git a/javascript/packages/core/lib/gen/typedArray.ts b/javascript/packages/core/lib/gen/typedArray.ts
index 3f0e3cb343..562f31c63d 100644
--- a/javascript/packages/core/lib/gen/typedArray.ts
+++ b/javascript/packages/core/lib/gen/typedArray.ts
@@ -134,10 +134,10 @@ function build(
         const idx = this.scope.uniqueName("idx");
         return `
           const ${rawLen} = ${this.builder.reader.readVarUInt32()};
-          ${this.builder.getReadContextName()}.checkBinarySize(${rawLen});
           if ((${rawLen} % ${size}) !== 0) {
             throw new Error("dense array byte length is not divisible by element size");
           }
+          ${this.builder.reader.checkReadableBytes(rawLen)};
           const ${len} = ${rawLen} / ${size};
           const ${result} = new ${creator}(${len});
           ${this.maybeReference(result, refState)}
@@ -149,7 +149,6 @@ function build(
       }
       return `
         const ${len} = ${this.builder.reader.readVarUInt32()};
-        ${this.builder.getReadContextName()}.checkBinarySize(${len});
         const ${copied} =
${this.builder.reader.buffer(len)}
         const ${result} = new ${creator}(${copied}.buffer, ${copied}.byteOffset, ${copied}.byteLength / ${size});
         ${this.maybeReference(result, refState)}
@@ -199,7 +198,7 @@ class BoolArraySerializerGenerator extends BaseSerializerGenerator {
     const readByte = this.builder.reader.readUint8();
     return `
       const ${len} = ${this.builder.reader.readVarUInt32()};
-      ${this.builder.getReadContextName()}.checkCollectionSize(${len});
+      ${this.builder.reader.checkReadableBytes(len)};
       let ${result};
       if (${len} <= 4) {
         let ${bits};
@@ -279,10 +278,10 @@ class Float16ArraySerializerGenerator extends BaseSerializerGenerator {
     const raw = this.scope.uniqueName("raw");
     return `
       const ${rawLen} = ${this.builder.reader.readVarUInt32()};
-      ${this.builder.getReadContextName()}.checkBinarySize(${rawLen});
       if ((${rawLen} % 2) !== 0) {
         throw new Error("float16 array byte length is not divisible by element size");
       }
+      ${this.builder.reader.checkReadableBytes(rawLen)};
       const ${len} = ${rawLen} / 2;
       const ${raw} = new Uint16Array(${len});
       for (let ${idx} = 0; ${idx} < ${len}; ${idx}++) {
@@ -337,10 +336,10 @@ class BFloat16ArraySerializerGenerator extends BaseSerializerGenerator {
     const raw = this.scope.uniqueName("raw");
     return `
       const ${rawLen} = ${this.builder.reader.readVarUInt32()};
-      ${this.builder.getReadContextName()}.checkBinarySize(${rawLen});
       if ((${rawLen} % 2) !== 0) {
         throw new Error("bfloat16 array byte length is not divisible by element size");
       }
+      ${this.builder.reader.checkReadableBytes(rawLen)};
       const ${len} = ${rawLen} / 2;
       const ${raw} = new Uint16Array(${len});
       for (let ${idx} = 0; ${idx} < ${len}; ${idx}++) {
@@ -358,19 +357,55 @@ class BFloat16ArraySerializerGenerator extends BaseSerializerGenerator {
 }
 CodegenRegistry.register(TypeId.BOOL_ARRAY, BoolArraySerializerGenerator);
-CodegenRegistry.register(TypeId.BINARY, build(Type.uint8(), `Uint8Array`, 1, "readUint8", "writeUint8"));
-CodegenRegistry.register(TypeId.INT8_ARRAY,
build(Type.int8(), `Int8Array`, 1, "readInt8", "writeInt8"));
-CodegenRegistry.register(TypeId.INT16_ARRAY, build(Type.int16(), `Int16Array`, 2, "readInt16", "writeInt16"));
-CodegenRegistry.register(TypeId.INT32_ARRAY, build(Type.int32(), `Int32Array`, 4, "readInt32", "writeInt32"));
-CodegenRegistry.register(TypeId.INT64_ARRAY, build(Type.int64(), `BigInt64Array`, 8, "readInt64", "writeInt64"));
-CodegenRegistry.register(TypeId.UINT8_ARRAY, build(Type.uint8(), `Uint8Array`, 1, "readUint8", "writeUint8"));
-CodegenRegistry.register(TypeId.UINT16_ARRAY, build(Type.uint16(), `Uint16Array`, 2, "readUint16", "writeUint16"));
-CodegenRegistry.register(TypeId.UINT32_ARRAY, build(Type.uint32(), `Uint32Array`, 4, "readUint32", "writeUint32"));
-CodegenRegistry.register(TypeId.UINT64_ARRAY, build(Type.uint64(), `BigUint64Array`, 8, "readUint64", "writeUint64"));
+CodegenRegistry.register(
+  TypeId.BINARY,
+  build(Type.uint8(), `Uint8Array`, 1, "readUint8", "writeUint8"),
+);
+CodegenRegistry.register(
+  TypeId.INT8_ARRAY,
+  build(Type.int8(), `Int8Array`, 1, "readInt8", "writeInt8"),
+);
+CodegenRegistry.register(
+  TypeId.INT16_ARRAY,
+  build(Type.int16(), `Int16Array`, 2, "readInt16", "writeInt16"),
+);
+CodegenRegistry.register(
+  TypeId.INT32_ARRAY,
+  build(Type.int32(), `Int32Array`, 4, "readInt32", "writeInt32"),
+);
+CodegenRegistry.register(
+  TypeId.INT64_ARRAY,
+  build(Type.int64(), `BigInt64Array`, 8, "readInt64", "writeInt64"),
+);
+CodegenRegistry.register(
+  TypeId.UINT8_ARRAY,
+  build(Type.uint8(), `Uint8Array`, 1, "readUint8", "writeUint8"),
+);
+CodegenRegistry.register(
+  TypeId.UINT16_ARRAY,
+  build(Type.uint16(), `Uint16Array`, 2, "readUint16", "writeUint16"),
+);
+CodegenRegistry.register(
+  TypeId.UINT32_ARRAY,
+  build(Type.uint32(), `Uint32Array`, 4, "readUint32", "writeUint32"),
+);
+CodegenRegistry.register(
+  TypeId.UINT64_ARRAY,
+  build(Type.uint64(), `BigUint64Array`, 8, "readUint64", "writeUint64"),
+);
 CodegenRegistry.register(TypeId.FLOAT16_ARRAY, Float16ArraySerializerGenerator);
-CodegenRegistry.register(TypeId.BFLOAT16_ARRAY, BFloat16ArraySerializerGenerator);
-CodegenRegistry.register(TypeId.FLOAT32_ARRAY, build(Type.float32(), `Float32Array`, 4, "readFloat32", "writeFloat32"));
-CodegenRegistry.register(TypeId.FLOAT64_ARRAY, build(Type.float64(), `Float64Array`, 8, "readFloat64", "writeFloat64"));
+CodegenRegistry.register(
+  TypeId.BFLOAT16_ARRAY,
+  BFloat16ArraySerializerGenerator,
+);
+CodegenRegistry.register(
+  TypeId.FLOAT32_ARRAY,
+  build(Type.float32(), `Float32Array`, 4, "readFloat32", "writeFloat32"),
+);
+CodegenRegistry.register(
+  TypeId.FLOAT64_ARRAY,
+  build(Type.float64(), `Float64Array`, 8, "readFloat64", "writeFloat64"),
+);
 CodegenRegistry.registerExternal(BFloat16Array);
 CodegenRegistry.registerExternal(BoolArray);
 CodegenRegistry.registerExternal(createFloat16Array);
diff --git a/javascript/packages/core/lib/meta/TypeMeta.ts b/javascript/packages/core/lib/meta/TypeMeta.ts
index c4cfed4af7..90c4696c45 100644
--- a/javascript/packages/core/lib/meta/TypeMeta.ts
+++ b/javascript/packages/core/lib/meta/TypeMeta.ts
@@ -44,6 +44,7 @@ const HASH_SHIFT_BITS = 64n - BigInt(NUM_HASH_BITS);
 const UINT64_MASK = 0xffffffffffffffffn;
 const HEADER_HASH_MASK = UINT64_MASK ^ ((1n << HASH_SHIFT_BITS) - 1n);
 const BIG_NAME_THRESHOLD = 0b111111;
+const MAX_TYPE_META_NESTING = 128;
 const PRIMITIVE_TYPE_IDS = [
   TypeId.BOOL,
@@ -468,6 +469,8 @@ export class TypeMeta {
     const headerHash = Number(header >> HASH_SHIFT_BITS);
     const bodyStart = reader.readGetCursor();
+    reader.checkReadableBytes(metaSize);
+    const bodyEnd = bodyStart + metaSize;
     const classHeader = reader.readUint8();
     const isStruct = (classHeader & STRUCT_TYPEDEF_FLAG) !== 0;
     let numFields = 0;
@@ -508,6 +511,9 @@ export class TypeMeta {
     }
     // Read fields
+    if (numFields > bodyEnd - reader.readGetCursor()) {
+      throw new Error("TypeMeta field count exceeds metadata body size");
+    }
     const fields:
FieldInfo[] = [];
     for (let i = 0; i < numFields; i++) {
       const fieldInfo = this.readFieldInfo(reader);
@@ -622,7 +628,13 @@ export class TypeMeta {
   private static readTypeId(
     reader: BinaryReader,
     readFlag = false,
+    depth = 0,
   ): InnerFieldInfo {
+    if (depth > MAX_TYPE_META_NESTING) {
+      throw new Error(
+        `TypeMeta nesting depth limit exceeded: ${depth} > ${MAX_TYPE_META_NESTING}`,
+      );
+    }
     const options: InnerFieldInfoOptions = {};
     let nullable = false;
     let trackingRef = false;
@@ -639,7 +651,7 @@ export class TypeMeta {
       ) {
         typeId = TypeId.UNION;
       }
-      this.readNestedTypeInfo(reader, typeId, options);
+      this.readNestedTypeInfo(reader, typeId, options, depth);
       return { typeId, userTypeId: -1, nullable, trackingRef, options };
     }
     let typeId = reader.readUint8();
@@ -648,7 +660,7 @@ export class TypeMeta {
     } else if (typeId === TypeId.NAMED_UNION || typeId === TypeId.TYPED_UNION) {
       typeId = TypeId.UNION;
     }
-    this.readNestedTypeInfo(reader, typeId, options);
+    this.readNestedTypeInfo(reader, typeId, options, depth);
     return { typeId, userTypeId: -1, nullable, trackingRef, options };
   }
@@ -656,17 +668,18 @@ export class TypeMeta {
     reader: BinaryReader,
     typeId: number,
     options: InnerFieldInfoOptions,
+    depth: number,
   ) {
     switch (typeId) {
       case TypeId.LIST:
-        options.inner = this.readTypeId(reader, true);
+        options.inner = this.readTypeId(reader, true, depth + 1);
         break;
       case TypeId.SET:
-        options.key = this.readTypeId(reader, true);
+        options.key = this.readTypeId(reader, true, depth + 1);
         break;
       case TypeId.MAP:
-        options.key = this.readTypeId(reader, true);
-        options.value = this.readTypeId(reader, true);
+        options.key = this.readTypeId(reader, true, depth + 1);
+        options.value = this.readTypeId(reader, true, depth + 1);
         break;
       default:
         break;
diff --git a/javascript/packages/core/lib/reader/index.ts b/javascript/packages/core/lib/reader/index.ts
index 38562948ac..d39acd693a 100644
--- a/javascript/packages/core/lib/reader/index.ts
+++
b/javascript/packages/core/lib/reader/index.ts
@@ -93,6 +93,12 @@ export class BinaryReader {
     this.cursor += len;
   }
+  checkReadableBytes(len: number) {
+    if (len < 0 || len > this.byteLength - this.cursor) {
+      throw new Error("Insufficient bytes to read");
+    }
+  }
+
   readInt32() {
     const result = this.dataView.getInt32(this.cursor, true);
     this.cursor += 4;
diff --git a/javascript/packages/core/lib/type.ts b/javascript/packages/core/lib/type.ts
index d88d666e0e..5999c99b70 100644
--- a/javascript/packages/core/lib/type.ts
+++ b/javascript/packages/core/lib/type.ts
@@ -149,10 +149,22 @@ export const TypeId = {
     ].includes(id as any);
   },
   polymorphicType(id: number) {
-    return [TypeId.STRUCT, TypeId.NAMED_STRUCT, TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT, TypeId.EXT, TypeId.NAMED_EXT].includes(id as any);
+    return [
+      TypeId.STRUCT,
+      TypeId.NAMED_STRUCT,
+      TypeId.COMPATIBLE_STRUCT,
+      TypeId.NAMED_COMPATIBLE_STRUCT,
+      TypeId.EXT,
+      TypeId.NAMED_EXT,
+    ].includes(id as any);
   },
   structType(id: number) {
-    return [TypeId.STRUCT, TypeId.NAMED_STRUCT, TypeId.COMPATIBLE_STRUCT, TypeId.NAMED_COMPATIBLE_STRUCT].includes(id as any);
+    return [
+      TypeId.STRUCT,
+      TypeId.NAMED_STRUCT,
+      TypeId.COMPATIBLE_STRUCT,
+      TypeId.NAMED_COMPATIBLE_STRUCT,
+    ].includes(id as any);
   },
   extType(id: number) {
     return [TypeId.EXT, TypeId.NAMED_EXT].includes(id as any);
@@ -161,12 +173,14 @@ export const TypeId = {
     return [TypeId.ENUM, TypeId.NAMED_ENUM].includes(id as any);
   },
   userDefinedType(id: number) {
-    return this.structType(id)
+    return (
+      this.structType(id)
       || this.extType(id)
       || this.enumType(id)
       || id == TypeId.UNION
       || id == TypeId.TYPED_UNION
-      || id == TypeId.NAMED_UNION;
+      || id == TypeId.NAMED_UNION
+    );
   },
   isBuiltin(id: number) {
     return !this.userDefinedType(id) && id !== TypeId.UNKNOWN;
@@ -202,7 +216,8 @@ export const TypeId = {
     // NONE(36), DURATION(37), TIMESTAMP(38), DATE(39), DECIMAL(40), BINARY(41)
     if (typeId >= TypeId.NONE && typeId <= TypeId.BINARY) return
true;
     // Typed arrays BOOL_ARRAY(43)..FLOAT64_ARRAY(56)
-    if (typeId >= TypeId.BOOL_ARRAY && typeId <= TypeId.FLOAT64_ARRAY) return true;
+    if (typeId >= TypeId.BOOL_ARRAY && typeId <= TypeId.FLOAT64_ARRAY)
+      return true;
     return false;
   },
 } as const;
@@ -254,7 +269,7 @@ export enum RefFlags {
 export const MaxInt32 = 2147483647;
 export const MinInt32 = -2147483648;
-export const MaxUInt32 = 0xFFFFFFFF;
+export const MaxUInt32 = 0xffffffff;
 export const MinUInt32 = 0;
 export const HalfMaxInt32 = MaxInt32 / 2;
 export const HalfMinInt32 = MinInt32 / 2;
@@ -276,8 +291,6 @@ export interface Config {
   ref: boolean;
   useSliceString: boolean;
   maxDepth?: number;
-  maxBinarySize?: number;
-  maxCollectionSize?: number;
   hooks: {
     afterCodeGenerated?: (code: string) => string;
   };
diff --git a/javascript/packages/core/lib/types/decimal.ts b/javascript/packages/core/lib/types/decimal.ts
index 4baf3eec81..fa967500c7 100644
--- a/javascript/packages/core/lib/types/decimal.ts
+++ b/javascript/packages/core/lib/types/decimal.ts
@@ -37,9 +37,11 @@ export class Decimal {
   }
   equals(other: unknown): boolean {
-    return other instanceof Decimal
+    return (
+      other instanceof Decimal
       && other.scale === this.scale
-      && other.unscaledValue === this.unscaledValue;
+      && other.unscaledValue === this.unscaledValue
+    );
   }
   toString(): string {
@@ -73,10 +75,10 @@ export class DecimalCodec {
     return Uint8Array.from(bytes);
   }
-  static fromCanonicalLittleEndianMagnitude(payload: Uint8Array): bigint {
+  static fromCanonicalLittleEndianMagnitude(bytes: Uint8Array): bigint {
     let magnitude = 0n;
-    for (let i = payload.length - 1; i >= 0; i--) {
-      magnitude = (magnitude << 8n) | BigInt(payload[i]);
+    for (let i = bytes.length - 1; i >= 0; i--) {
+      magnitude = (magnitude << 8n) | BigInt(bytes[i]);
     }
     return magnitude;
   }
diff --git a/javascript/test/sizeLimit.test.ts b/javascript/test/sizeLimit.test.ts
deleted file mode 100644
index 73a61caaf0..0000000000
--- a/javascript/test/sizeLimit.test.ts
+++ /dev/null
@@
-1,351 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-import Fory, { Type, BoolArray } from '../packages/core/index';
-import { describe, expect, test } from '@jest/globals';
-
-describe('size-limit guardrails', () => {
-  describe('configuration', () => {
-    test('should have default limits matching Go', () => {
-      const fory = new Fory({ compatible: false });
-      expect(fory.readContext.maxBinarySize).toBe(64 * 1024 * 1024);
-      expect(fory.readContext.maxCollectionSize).toBe(1_000_000);
-    });
-
-    test('should accept custom maxBinarySize', () => {
-      const fory = new Fory({ compatible: false, maxBinarySize: 1024 });
-      expect(fory.readContext.maxBinarySize).toBe(1024);
-    });
-
-    test('should accept custom maxCollectionSize', () => {
-      const fory = new Fory({ compatible: false, maxCollectionSize: 500 });
-      expect(fory.readContext.maxCollectionSize).toBe(500);
-    });
-
-    test('should accept zero as a valid limit', () => {
-      const fory = new Fory({ compatible: false, maxBinarySize: 0, maxCollectionSize: 0 });
-      expect(fory.readContext.maxBinarySize).toBe(0);
-      expect(fory.readContext.maxCollectionSize).toBe(0);
-    });
-
-    test('should reject negative maxBinarySize', () => {
-      expect(() => new Fory({
compatible: false, maxBinarySize: -1 })).toThrow(
-        'maxBinarySize must be a non-negative integer'
-      );
-    });
-
-    test('should reject non-integer maxBinarySize', () => {
-      expect(() => new Fory({ compatible: false, maxBinarySize: 1.5 })).toThrow(
-        'maxBinarySize must be a non-negative integer'
-      );
-    });
-
-    test('should reject NaN maxBinarySize', () => {
-      expect(() => new Fory({ compatible: false, maxBinarySize: NaN })).toThrow(
-        'maxBinarySize must be a non-negative integer'
-      );
-    });
-
-    test('should reject negative maxCollectionSize', () => {
-      expect(() => new Fory({ compatible: false, maxCollectionSize: -10 })).toThrow(
-        'maxCollectionSize must be a non-negative integer'
-      );
-    });
-
-    test('should reject non-integer maxCollectionSize', () => {
-      expect(() => new Fory({ compatible: false, maxCollectionSize: 2.7 })).toThrow(
-        'maxCollectionSize must be a non-negative integer'
-      );
-    });
-
-    test('should work with other options combined', () => {
-      const fory = new Fory({
-        maxDepth: 100,
-        maxBinarySize: 1024,
-        maxCollectionSize: 500,
-        ref: true,
-        compatible: true,
-      });
-      expect(fory.readContext.maxDepth).toBe(100);
-      expect(fory.readContext.maxBinarySize).toBe(1024);
-      expect(fory.readContext.maxCollectionSize).toBe(500);
-    });
-  });
-
-  describe('checkCollectionSize', () => {
-    test('should not throw when size is within default limit', () => {
-      const fory = new Fory({ compatible: false });
-      expect(() => fory.readContext.checkCollectionSize(999999)).not.toThrow();
-    });
-
-    test('should not throw when size is within limit', () => {
-      const fory = new Fory({ compatible: false, maxCollectionSize: 100 });
-      expect(() => fory.readContext.checkCollectionSize(100)).not.toThrow();
-      expect(() => fory.readContext.checkCollectionSize(0)).not.toThrow();
-    });
-
-    test('should throw when size exceeds limit', () => {
-      const fory = new Fory({ compatible: false, maxCollectionSize: 100 });
-      expect(() => fory.readContext.checkCollectionSize(101)).toThrow(
-
'Collection size 101 exceeds maxCollectionSize 100'
-      );
-    });
-
-    test('error message should include helpful suggestion', () => {
-      const fory = new Fory({ compatible: false, maxCollectionSize: 10 });
-      expect(() => fory.readContext.checkCollectionSize(20)).toThrow(
-        'increase maxCollectionSize if needed'
-      );
-    });
-  });
-
-  describe('checkBinarySize', () => {
-    test('should not throw when size is within default limit', () => {
-      const fory = new Fory({ compatible: false });
-      expect(() => fory.readContext.checkBinarySize(999999)).not.toThrow();
-    });
-
-    test('should not throw when size is within limit', () => {
-      const fory = new Fory({ compatible: false, maxBinarySize: 1024 });
-      expect(() => fory.readContext.checkBinarySize(1024)).not.toThrow();
-      expect(() => fory.readContext.checkBinarySize(0)).not.toThrow();
-    });
-
-    test('should throw when size exceeds limit', () => {
-      const fory = new Fory({ compatible: false, maxBinarySize: 1024 });
-      expect(() => fory.readContext.checkBinarySize(1025)).toThrow(
-        'Binary size 1025 exceeds maxBinarySize 1024'
-      );
-    });
-  });
-
-  describe('list deserialization with maxCollectionSize', () => {
-    test('should deserialize list within limit', () => {
-      const fory = new Fory({ compatible: false, maxCollectionSize: 10 });
-      const { serialize, deserialize } = fory.register(Type.list(Type.int32()));
-      const data = [1, 2, 3];
-      const result = deserialize(serialize(data));
-      expect(result).toEqual(data);
-    });
-
-    test('should throw when list exceeds maxCollectionSize', () => {
-      const serializeFory = new Fory({ compatible: false });
-      const { serialize } = serializeFory.register(Type.list(Type.int32()));
-      const bytes = serialize([1, 2, 3, 4, 5]);
-
-      const deserializeFory = new Fory({ compatible: false, maxCollectionSize: 3 });
-      const { deserialize } = deserializeFory.register(Type.list(Type.int32()));
-      expect(() => deserialize(bytes)).toThrow('exceeds maxCollectionSize');
-    });
-
-    test('should deserialize list at exact
limit', () => {
-      const fory = new Fory({ compatible: false, maxCollectionSize: 3 });
-      const { serialize, deserialize } = fory.register(Type.list(Type.int32()));
-      const data = [1, 2, 3];
-      const result = deserialize(serialize(data));
-      expect(result).toEqual(data);
-    });
-  });
-
-  describe('set deserialization with maxCollectionSize', () => {
-    test('should deserialize set within limit', () => {
-      const fory = new Fory({ compatible: false, maxCollectionSize: 10, ref: true });
-      const { serialize, deserialize } = fory.register(Type.set(Type.int32()));
-      const data = new Set([1, 2, 3]);
-      const result = deserialize(serialize(data));
-      expect(result).toEqual(data);
-    });
-
-    test('should throw when set exceeds maxCollectionSize', () => {
-      const serializeFory = new Fory({ compatible: false, ref: true });
-      const { serialize } = serializeFory.register(Type.set(Type.int32()));
-      const bytes = serialize(new Set([1, 2, 3, 4, 5]));
-
-      const deserializeFory = new Fory({ compatible: false, maxCollectionSize: 3, ref: true });
-      const { deserialize } = deserializeFory.register(Type.set(Type.int32()));
-      expect(() => deserialize(bytes)).toThrow('exceeds maxCollectionSize');
-    });
-  });
-
-  describe('map deserialization with maxCollectionSize', () => {
-    test('should deserialize map within limit', () => {
-      const fory = new Fory({ compatible: false, maxCollectionSize: 10, ref: true });
-      const { serialize, deserialize } = fory.register(
-        Type.map(Type.string(), Type.int32())
-      );
-      const data = new Map([['a', 1], ['b', 2]]);
-      const result = deserialize(serialize(data));
-      expect(result).toEqual(data);
-    });
-
-    test('should throw when map exceeds maxCollectionSize', () => {
-      const serializeFory = new Fory({ compatible: false, ref: true });
-      const { serialize } = serializeFory.register(
-        Type.map(Type.string(), Type.int32())
-      );
-      const bytes = serialize(new Map([['a', 1], ['b', 2], ['c', 3], ['d', 4]]));
-
-      const deserializeFory = new Fory({ compatible: false,
maxCollectionSize: 2, ref: true });
-      const { deserialize } = deserializeFory.register(
-        Type.map(Type.string(), Type.int32())
-      );
-      expect(() => deserialize(bytes)).toThrow('exceeds maxCollectionSize');
-    });
-  });
-
-  describe('binary deserialization with maxBinarySize', () => {
-    test('should deserialize binary within limit', () => {
-      const fory = new Fory({ compatible: false, maxBinarySize: 1024, ref: true });
-      const { serialize, deserialize } = fory.register(Type.struct("test.binary", {
-        data: Type.binary(),
-      }));
-      const data = { data: new Uint8Array([1, 2, 3]) };
-      const result = deserialize(serialize(data));
-      expect(result!.data![0]).toBe(1);
-      expect(result!.data![1]).toBe(2);
-      expect(result!.data![2]).toBe(3);
-    });
-
-    test('should throw when binary exceeds maxBinarySize', () => {
-      const serializeFory = new Fory({ compatible: false, ref: true });
-      const { serialize } = serializeFory.register(Type.struct("test.binary2", {
-        data: Type.binary(),
-      }));
-      const bytes = serialize({ data: new Uint8Array(100) });
-
-      const deserializeFory = new Fory({ compatible: false, maxBinarySize: 50, ref: true });
-      const { deserialize } = deserializeFory.register(Type.struct("test.binary2", {
-        data: Type.binary(),
-      }));
-      expect(() => deserialize(bytes)).toThrow('exceeds maxBinarySize');
-    });
-  });
-
-  describe('default limits allow normal payloads', () => {
-    test('should allow large collections within default limit', () => {
-      const fory = new Fory({ compatible: false, ref: true });
-      const { serialize, deserialize } = fory.register(Type.list(Type.int32()));
-      const bigArray = Array.from({ length: 1000 }, (_, i) => i);
-      const result = deserialize(serialize(bigArray));
-      expect(result).toEqual(bigArray);
-    });
-  });
-
-  describe('polymorphic (any-typed) collection paths', () => {
-    test('should enforce maxCollectionSize on untyped list', () => {
-      const serializeFory = new Fory({ compatible: false, ref: true });
-      const bytes = serializeFory.serialize([1, "two", 3.0]);
-
-      const deserializeFory = new Fory({ compatible: false, maxCollectionSize: 2, ref: true });
-      expect(() => deserializeFory.deserialize(bytes)).toThrow('exceeds maxCollectionSize');
-    });
-
-    test('should enforce maxCollectionSize on untyped map', () => {
-      const serializeFory = new Fory({ compatible: false, ref: true });
-      const bytes = serializeFory.serialize(new Map([["a", 1], ["b", 2], ["c", 3]]));
-
-      const deserializeFory = new Fory({ compatible: false, maxCollectionSize: 2, ref: true });
-      expect(() => deserializeFory.deserialize(bytes)).toThrow('exceeds maxCollectionSize');
-    });
-  });
-
-  describe('bool array deserialization with maxCollectionSize', () => {
-    // BoolArraySerializerGenerator reads an element count — guarded by checkCollectionSize
-    test('should deserialize bool array within limit', () => {
-      const fory = new Fory({ compatible: false, maxCollectionSize: 10 });
-      const { serialize, deserialize } = fory.register(Type.struct("test.boolArr", {
-        flags: Type.boolArray(),
-      }));
-      const data = { flags: [true, false, true] };
-      const result = deserialize(serialize(data));
-      expect(result!.flags).toBeInstanceOf(BoolArray);
-      expect(Array.from(result!.flags)).toEqual([true, false, true]);
-    });
-
-    test('should throw when bool array exceeds maxCollectionSize', () => {
-      const serializeFory = new Fory({ compatible: false });
-      const { serialize } = serializeFory.register(Type.struct("test.boolArr2", {
-        flags: Type.boolArray(),
-      }));
-      const bytes = serialize({ flags: [true, false, true, true, false] });
-
-      const deserializeFory = new Fory({ compatible: false, maxCollectionSize: 3 });
-      const { deserialize } = deserializeFory.register(Type.struct("test.boolArr2", {
-        flags: Type.boolArray(),
-      }));
-      expect(() => deserialize(bytes)).toThrow('exceeds maxCollectionSize');
-    });
-  });
-
-  describe('float16 array deserialization with maxBinarySize', () => {
-    // Float16ArraySerializerGenerator writes byte count (elements * 2) — guarded by
checkBinarySize - test('should deserialize float16 array within limit', () => { - const fory = new Fory({ compatible: false, maxBinarySize: 1024 }); - const { serialize, deserialize } = fory.register(Type.struct("test.f16Arr", { - vals: Type.float16Array(), - })); - const data = { vals: [1.0, 2.0, 3.0] }; - const result = deserialize(serialize(data)); - expect(result!.vals!.length).toBe(3); - }); - - test('should throw when float16 array byte length exceeds maxBinarySize', () => { - const serializeFory = new Fory({ compatible: false }); - const { serialize } = serializeFory.register(Type.struct("test.f16Arr2", { - vals: Type.float16Array(), - })); - // 10 elements × 2 bytes each = 20 raw bytes on the wire - const bytes = serialize({ vals: Array.from({ length: 10 }, (_, i) => i * 0.5) }); - - const deserializeFory = new Fory({ compatible: false, maxBinarySize: 10 }); // 10 < 20 - const { deserialize } = deserializeFory.register(Type.struct("test.f16Arr2", { - vals: Type.float16Array(), - })); - expect(() => deserialize(bytes)).toThrow('exceeds maxBinarySize'); - }); - }); - - describe('bfloat16 array deserialization with maxBinarySize', () => { - // BFloat16ArraySerializerGenerator writes byte count (elements * 2) — same pattern as float16 - test('should deserialize bfloat16 array within limit', () => { - const fory = new Fory({ compatible: false, maxBinarySize: 1024 }); - const { serialize, deserialize } = fory.register(Type.struct("test.bf16Arr", { - vals: Type.bfloat16Array(), - })); - const data = { vals: [1.0, 2.0, 3.0] }; - const result = deserialize(serialize(data)); - expect(result!.vals!.length).toBe(3); - }); - - test('should throw when bfloat16 array byte length exceeds maxBinarySize', () => { - const serializeFory = new Fory({ compatible: false }); - const { serialize } = serializeFory.register(Type.struct("test.bf16Arr2", { - vals: Type.bfloat16Array(), - })); - // 10 elements × 2 bytes each = 20 raw bytes on the wire - const bytes = serialize({ vals: 
Array.from({ length: 10 }, (_, i) => i * 0.5) }); - - const deserializeFory = new Fory({ compatible: false, maxBinarySize: 10 }); // 10 < 20 - const { deserialize } = deserializeFory.register(Type.struct("test.bf16Arr2", { - vals: Type.bfloat16Array(), - })); - expect(() => deserialize(bytes)).toThrow('exceeds maxBinarySize'); - }); - }); -}); diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt index 0c288ef604..7517ca54ac 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt +++ b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/KotlinSerializerSourceWriter.kt @@ -1732,10 +1732,10 @@ internal class KotlinSerializerSourceWriter(private val struct: KotlinSourceStru private fun directReadExpression(field: KotlinSourceField): String? { val denseRead = denseUnsignedArrayRead(field) if (denseRead != null && !field.nullable) { - return "KotlinXlangArrayEncoding.$denseRead(readContext, typeResolver.config.maxBinarySize())" + return "KotlinXlangArrayEncoding.$denseRead(readContext)" } if (denseRead != null && field.nullable && !field.trackingRef) { - return "if (buffer.readByte() == Fory.NULL_FLAG) null else KotlinXlangArrayEncoding.$denseRead(readContext, typeResolver.config.maxBinarySize())" + return "if (buffer.readByte() == Fory.NULL_FLAG) null else KotlinXlangArrayEncoding.$denseRead(readContext)" } if (isScalarUnsigned(field) && field.nullable && !field.trackingRef) { val readValue = diff --git a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/UnionSerializerSourceWriter.kt b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/UnionSerializerSourceWriter.kt index 6b7591787b..ca609755ed 100644 --- a/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/UnionSerializerSourceWriter.kt +++ 
b/kotlin/fory-kotlin-ksp/src/main/kotlin/org/apache/fory/kotlin/ksp/UnionSerializerSourceWriter.kt @@ -268,7 +268,7 @@ internal class UnionSerializerSourceWriter(private val union: KotlinSourceUnion) } private fun directPayloadWrite(type: KotlinSourceTypeNode, value: String): String? { - val listWrite = directListPayloadWrite(type, value) + val listWrite = directListBodyWrite(type, value) if (listWrite != null) { return listWrite } @@ -304,13 +304,13 @@ internal class UnionSerializerSourceWriter(private val union: KotlinSourceUnion) } private fun directPayloadRead(type: KotlinSourceTypeNode): String? { - val listRead = directListPayloadRead(type) + val listRead = directListBodyRead(type) if (listRead != null) { return listRead } val denseUnsignedArrayRead = denseUnsignedArrayRead(type) if (denseUnsignedArrayRead != null && !type.trackingRef) { - return "KotlinXlangArrayEncoding.$denseUnsignedArrayRead(readContext, typeResolver.config.maxBinarySize())" + return "KotlinXlangArrayEncoding.$denseUnsignedArrayRead(readContext)" } if (!canDirect(type)) { return null @@ -342,7 +342,7 @@ internal class UnionSerializerSourceWriter(private val union: KotlinSourceUnion) private fun canDirect(type: KotlinSourceTypeNode): Boolean = !type.trackingRef && type.typeArguments.isEmpty() && type.componentType == null - private fun directListPayloadWrite(type: KotlinSourceTypeNode, value: String): String? { + private fun directListBodyWrite(type: KotlinSourceTypeNode, value: String): String? { if (type.typeId != "Types.LIST" || type.typeArguments.size != 1 || type.nullable) { return null } @@ -354,7 +354,7 @@ internal class UnionSerializerSourceWriter(private val union: KotlinSourceUnion) return "$value.let { listValue -> buffer.writeVarUInt32Small7(listValue.size); if (listValue.isNotEmpty()) { buffer.writeByte(CollectionFlags.DECL_SAME_TYPE_NOT_HAS_NULL); for (element in listValue) { $writeElement } } }" } - private fun directListPayloadRead(type: KotlinSourceTypeNode): String? 
{ + private fun directListBodyRead(type: KotlinSourceTypeNode): String? { if (type.typeId != "Types.LIST" || type.typeArguments.size != 1 || type.nullable) { return null } @@ -364,7 +364,7 @@ internal class UnionSerializerSourceWriter(private val union: KotlinSourceUnion) } val readElement = directPayloadRead(elementType) ?: return null val valueType = type.valueTypeName.removeSuffix("?") - return "run { val size = buffer.readVarUInt32Small7(); val result = java.util.ArrayList(size); if (size > 0) { check(buffer.readByte().toInt() == CollectionFlags.DECL_SAME_TYPE_NOT_HAS_NULL); for (i in 0 until size) { result.add($readElement) } }; result as $valueType }" + return "run { val size = buffer.readVarUInt32Small7(); val result = if (size == 0) java.util.ArrayList(0) else { check(buffer.readByte().toInt() == CollectionFlags.DECL_SAME_TYPE_NOT_HAS_NULL); buffer.checkReadableBytes(size); val values = java.util.ArrayList(size); for (i in 0 until size) { values.add($readElement) }; values }; result as $valueType }" } private fun denseUnsignedArrayWrite(type: KotlinSourceTypeNode): String? 
= diff --git a/kotlin/fory-kotlin-ksp/src/test/kotlin/org/apache/fory/kotlin/ksp/ProcessorValidationTest.kt b/kotlin/fory-kotlin-ksp/src/test/kotlin/org/apache/fory/kotlin/ksp/ProcessorValidationTest.kt index 2ce36fa3dc..2644c12c17 100644 --- a/kotlin/fory-kotlin-ksp/src/test/kotlin/org/apache/fory/kotlin/ksp/ProcessorValidationTest.kt +++ b/kotlin/fory-kotlin-ksp/src/test/kotlin/org/apache/fory/kotlin/ksp/ProcessorValidationTest.kt @@ -1625,7 +1625,8 @@ class ProcessorValidationTest { assertTrue(source.contains("DurationSerializers.serializer(typeResolver.config")) assertTrue(source.contains("listValue.isNotEmpty()")) assertTrue(source.contains("buffer.writeByte(CollectionFlags.DECL_SAME_TYPE_NOT_HAS_NULL)")) - assertTrue(source.contains("if (size > 0)")) + assertTrue(source.contains("if (size == 0) java.util.ArrayList(0)")) + assertTrue(source.contains("buffer.checkReadableBytes(size)")) assertTrue(source.contains("java.util.ArrayList(size)")) assertTrue( source.contains("KotlinXlangArrayEncoding.writeUIntArray(writeContext, value.value)") diff --git a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/kotlin/KotlinXlangArrayEncoding.kt b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/kotlin/KotlinXlangArrayEncoding.kt index 95ef3ed11f..99502ed750 100644 --- a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/kotlin/KotlinXlangArrayEncoding.kt +++ b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/kotlin/KotlinXlangArrayEncoding.kt @@ -27,7 +27,7 @@ import org.apache.fory.exception.DeserializationException import org.apache.fory.memory.MemoryBuffer import org.apache.fory.type.Types -/** Kotlin/JVM carrier helpers for Fory xlang dense unsigned array payloads. */ +/** Kotlin/JVM carrier helpers for Fory xlang dense unsigned array bodies. 
*/ public object KotlinXlangArrayEncoding { @JvmStatic public fun writeUByteArray(writeContext: WriteContext, value: UByteArray) { @@ -43,9 +43,9 @@ public object KotlinXlangArrayEncoding { } @JvmStatic - public fun readUByteArray(readContext: ReadContext, maxBinarySize: Int): UByteArray { - val buffer = payloadBuffer(readContext) - val size = payloadSize(readContext, buffer, maxBinarySize, 1) + public fun readUByteArray(readContext: ReadContext): UByteArray { + val buffer = arrayBuffer(readContext) + val size = arrayByteSize(readContext, buffer, 1) return UByteArray(size) { buffer.readByte().toUByte() } } @@ -63,9 +63,9 @@ public object KotlinXlangArrayEncoding { } @JvmStatic - public fun readUShortArray(readContext: ReadContext, maxBinarySize: Int): UShortArray { - val buffer = payloadBuffer(readContext) - val size = payloadSize(readContext, buffer, maxBinarySize, Short.SIZE_BYTES) + public fun readUShortArray(readContext: ReadContext): UShortArray { + val buffer = arrayBuffer(readContext) + val size = arrayByteSize(readContext, buffer, Short.SIZE_BYTES) return UShortArray(size / Short.SIZE_BYTES) { buffer.readInt16().toUShort() } } @@ -83,9 +83,9 @@ public object KotlinXlangArrayEncoding { } @JvmStatic - public fun readUIntArray(readContext: ReadContext, maxBinarySize: Int): UIntArray { - val buffer = payloadBuffer(readContext) - val size = payloadSize(readContext, buffer, maxBinarySize, Int.SIZE_BYTES) + public fun readUIntArray(readContext: ReadContext): UIntArray { + val buffer = arrayBuffer(readContext) + val size = arrayByteSize(readContext, buffer, Int.SIZE_BYTES) return UIntArray(size / Int.SIZE_BYTES) { buffer.readInt32().toUInt() } } @@ -103,9 +103,9 @@ public object KotlinXlangArrayEncoding { } @JvmStatic - public fun readULongArray(readContext: ReadContext, maxBinarySize: Int): ULongArray { - val buffer = payloadBuffer(readContext) - val size = payloadSize(readContext, buffer, maxBinarySize, Long.SIZE_BYTES) + public fun readULongArray(readContext: 
ReadContext): ULongArray { + val buffer = arrayBuffer(readContext) + val size = arrayByteSize(readContext, buffer, Long.SIZE_BYTES) return ULongArray(size / Long.SIZE_BYTES) { buffer.readInt64().toULong() } } @@ -153,27 +153,21 @@ public object KotlinXlangArrayEncoding { return result } - private fun payloadBuffer(readContext: ReadContext): MemoryBuffer = + private fun arrayBuffer(readContext: ReadContext): MemoryBuffer = if (readContext.isPeerOutOfBandEnabled) readContext.readBufferObject() else readContext.buffer - private fun payloadSize( - readContext: ReadContext, - buffer: MemoryBuffer, - maxBinarySize: Int, - elementSize: Int - ): Int { + private fun arrayByteSize(readContext: ReadContext, buffer: MemoryBuffer, elementSize: Int): Int { val size = if (readContext.isPeerOutOfBandEnabled) buffer.remaining() else buffer.readVarUInt32Small7() - if (size < 0 || size > maxBinarySize) { - throw DeserializationException( - "Binary payload size $size exceeds max binary size $maxBinarySize" - ) + if (size < 0) { + throw DeserializationException("Array byte size must be non-negative: $size") } if (size % elementSize != 0) { throw DeserializationException( - "Binary payload size $size is not aligned to element size $elementSize" + "Array byte size $size is not aligned to element size $elementSize" ) } + buffer.checkReadableBytes(size) return size } diff --git a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt index f420e3b7bc..e3b36d4785 100644 --- a/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt +++ b/kotlin/fory-kotlin/src/main/kotlin/org/apache/fory/serializer/kotlin/CollectionSerializer.kt @@ -21,6 +21,7 @@ package org.apache.fory.serializer.kotlin import org.apache.fory.context.ReadContext import org.apache.fory.context.WriteContext +import 
org.apache.fory.exception.DeserializationException import org.apache.fory.resolver.TypeResolver import org.apache.fory.serializer.collection.CollectionLikeSerializer @@ -58,7 +59,13 @@ public class KotlinArrayDequeSerializer( override fun newCollection(readContext: ReadContext): Collection { val buffer = readContext.buffer val numElements = buffer.readVarUInt32Small7() + if (numElements < 0) { + throw DeserializationException("Collection size must be non-negative: $numElements") + } setNumElements(numElements) + if (numElements != 0) { + buffer.checkReadableBytes(numElements) + } return ArrayDequeBuilder(ArrayDeque(numElements)) } } diff --git a/python/pyfory/_fory.py b/python/pyfory/_fory.py index b15dcca966..477b3bfb73 100644 --- a/python/pyfory/_fory.py +++ b/python/pyfory/_fory.py @@ -126,8 +126,6 @@ class Fory: "max_depth", "field_nullable", "policy", - "max_collection_size", - "max_binary_size", ) def __init__( @@ -140,8 +138,6 @@ def __init__( policy: DeserializationPolicy = None, field_nullable: bool = False, meta_compressor=None, - max_collection_size: int = 1_000_000, - max_binary_size: int = 64 * 1024 * 1024, ): """ Initialize a Fory serialization instance. @@ -180,17 +176,6 @@ def __init__( field_nullable: Treat all dataclass fields as nullable regardless of Optional annotation. - max_collection_size: Maximum allowed size for collections (lists, sets, tuples) - and maps (dicts) during deserialization. This limit is used to prevent - out-of-memory attacks from malicious payloads that claim extremely large - collection sizes, as collections preallocate memory based on the declared - size. Raises an exception if exceeded. Default is 1,000,000. - - max_binary_size: Maximum allowed size in bytes for binary data reads during - deserialization (default: 64 MB). Raises an exception if a single binary - read exceeds this limit, preventing out-of-memory attacks from malicious - payloads that claim extremely large binary sizes. 
- Example: >>> # Python native mode with reference tracking >>> fory = Fory(xlang=False, ref=True) @@ -206,8 +191,6 @@ def __init__( self.compatible = compatible self.field_nullable = field_nullable self.max_depth = max_depth - self.max_collection_size = max_collection_size - self.max_binary_size = max_binary_size self.config = Config( xlang=xlang, track_ref=ref, @@ -219,8 +202,6 @@ def __init__( field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, - max_collection_size=max_collection_size, - max_binary_size=max_binary_size, ) from pyfory.registry import SharedRegistry, TypeResolver @@ -232,7 +213,7 @@ def __init__( self.type_resolver.initialize() self.write_context = WriteContext(self.config, self.type_resolver) self.read_context = ReadContext(self.config, self.type_resolver) - self.buffer = Buffer.allocate(32, max_binary_size=max_binary_size) + self.buffer = Buffer.allocate(32) def register( self, @@ -534,7 +515,7 @@ def _deserialize( unsupported_objects: Iterable = None, ): if isinstance(buffer, bytes): - buffer = Buffer(buffer, max_binary_size=self.max_binary_size) + buffer = Buffer(buffer) read_context = self.read_context reader_index = buffer.get_reader_index() buffer.set_reader_index(reader_index + 1) @@ -606,11 +587,6 @@ class ThreadSafeFory: in both xlang and Python native mode. Set False only when every reader and writer always uses the same Python class schema and smaller payloads matter. max_depth (int): Maximum depth for deserialization. Defaults to 50. - max_collection_size (int): Maximum allowed size for collections and maps during - deserialization. Defaults to 1,000,000. - max_binary_size (int): Maximum allowed size in bytes for binary data reads during - deserialization. Defaults to 64 MB. 
- Example: >>> import pyfory >>> import threading diff --git a/python/pyfory/_serializer.py b/python/pyfory/_serializer.py index 758fa05e5e..ec268190a4 100644 --- a/python/pyfory/_serializer.py +++ b/python/pyfory/_serializer.py @@ -300,13 +300,13 @@ def write(self, write_context, value): write_context.write_buffer(swapped) def read(self, read_context): - payload_size = read_context.read_var_uint32() - data = read_context.read_bytes(payload_size) + byte_size = read_context.read_var_uint32() + data = read_context.read_bytes(byte_size) if self.wrapper_type is BoolArray: return BoolArray(bool(value) for value in data) if self.reduced_precision: - if payload_size & 1: - raise ValueError(f"{self.wrapper_type.__name__} payload size mismatch") + if byte_size & 1: + raise ValueError(f"{self.wrapper_type.__name__} byte size mismatch") raw = array.array("H") raw.frombytes(data) if not is_little_endian: diff --git a/python/pyfory/buffer.pxi b/python/pyfory/buffer.pxi index 6f1e8fd46b..e037e34770 100644 --- a/python/pyfory/buffer.pxi +++ b/python/pyfory/buffer.pxi @@ -18,6 +18,7 @@ from libcpp.memory cimport shared_ptr, unique_ptr from libcpp.utility cimport move from cython.operator cimport dereference as deref from libcpp.string cimport string as c_string +from libc.string cimport memcpy from libc.stdint cimport * from libcpp cimport bool as c_bool from pyfory.includes.libutil cimport( @@ -125,11 +126,8 @@ cdef class Buffer: object output_stream Py_ssize_t shape[1] Py_ssize_t stride[1] - int32_t max_binary_size - - def __init__(self, data not None, int32_t offset=0, length=None, int32_t max_binary_size= 64 * 1024 * 1024): + def __init__(self, data not None, int32_t offset=0, length=None): self.data = data - self.max_binary_size = max_binary_size cdef int32_t buffer_len = len(data) cdef int length_ if length is None: @@ -150,7 +148,7 @@ cdef class Buffer: self.output_stream = None @classmethod - def from_stream(cls, stream not None, uint32_t buffer_size=4096, int32_t 
max_binary_size=64 * 1024 * 1024): + def from_stream(cls, stream not None, uint32_t buffer_size=4096): cdef CBuffer* stream_buffer cdef c_string stream_error if Fory_PyCreateBufferFromStream( @@ -160,7 +158,6 @@ cdef class Buffer: if stream_buffer == NULL: raise ValueError("failed to create stream buffer") cdef Buffer buffer = Buffer.__new__(Buffer) - buffer.max_binary_size = max_binary_size buffer.c_buffer_owner.reset(stream_buffer) buffer.c_buffer = buffer.c_buffer_owner.get() buffer.data = stream @@ -172,7 +169,6 @@ cdef class Buffer: @staticmethod cdef Buffer wrap(shared_ptr[CBuffer] c_buffer): cdef Buffer buffer = Buffer.__new__(Buffer) - buffer.max_binary_size = 64 * 1024 * 1024 cdef CBuffer* ptr = c_buffer.get() buffer.c_buffer = ptr cdef _SharedBufferOwner owner = _SharedBufferOwner.__new__(_SharedBufferOwner) @@ -184,12 +180,11 @@ cdef class Buffer: return buffer @classmethod - def allocate(cls, int32_t size, int32_t max_binary_size=64 * 1024 * 1024): + def allocate(cls, int32_t size): cdef CBuffer* buf = allocate_buffer(size) if buf == NULL: raise MemoryError("out of memory") cdef Buffer buffer = Buffer.__new__(Buffer) - buffer.max_binary_size = max_binary_size buffer.c_buffer_owner.reset(buf) buffer.c_buffer = buffer.c_buffer_owner.get() buffer.data = None @@ -334,7 +329,7 @@ cdef class Buffer: cpdef inline check_bound(self, int32_t offset, int32_t length): cdef int32_t size_ = self.c_buffer.size() - if offset | length | (offset + length) | (size_- (offset + length)) < 0: + if offset < 0 or length < 0 or offset > size_ or length > size_ - offset: raise_fory_error( CErrorCode.BufferOutOfBound, f"Address range {offset, offset + length} out of bound {0, size_}", @@ -419,17 +414,19 @@ cdef class Buffer: cpdef inline bytes read_bytes(self, int32_t length): if length == 0: return b"" - - if length > self.max_binary_size: - raise ValueError(f"Binary size {length} exceeds the configured limit of {self.max_binary_size}") + if length < 0: + 
raise_fory_error(CErrorCode.InvalidData, f"Binary size {length} is negative") + if not self.c_buffer.ensure_readable(length, self._error): + if not self._error.ok(): + self._raise_if_error() cdef bytes py_bytes = PyBytes_FromStringAndSize(NULL, length) if py_bytes is None: raise MemoryError("out of memory") cdef char* buf = PyBytes_AS_STRING(py_bytes) - self.c_buffer.read_bytes(buf, length, self._error) - if not self._error.ok(): - self._raise_if_error() + cdef uint32_t offset = self.c_buffer.reader_index() + memcpy(buf, self.c_buffer.data() + offset, length) + self.c_buffer.reader_index(offset + length) return py_bytes cpdef inline int64_t read_bytes_as_int64(self, int32_t length): @@ -713,8 +710,6 @@ cdef class Buffer: cpdef inline str read_string(self): cdef uint64_t header = self.read_var_uint64() cdef uint64_t size64 = header >> 2 - if size64 > self.max_binary_size: - raise ValueError(f"String size {size64} exceeds the configured limit of {self.max_binary_size}") if size64 > 2147483647: raise ValueError(f"String size {size64} exceeds the maximum supported size") cdef uint32_t size = size64 diff --git a/python/pyfory/collection.pxi b/python/pyfory/collection.pxi index 94c9d3a50c..0183b26231 100644 --- a/python/pyfory/collection.pxi +++ b/python/pyfory/collection.pxi @@ -467,15 +467,13 @@ cdef class ListSerializer(CollectionSerializer): cdef int32_t ref_id cdef int64_t i - if len_ > read_context.max_collection_size: - raise ValueError( - f"List size {len_} exceeds the configured limit of {read_context.max_collection_size}" - ) - list_ = PyList_New(len_) if len_ == 0: + list_ = PyList_New(0) return list_ + read_context.check_readable_bytes(len_) collect_flag = buffer.read_int8() + list_ = PyList_New(len_) # IMPORTANT: collection readers must obey the ref/null bits written on # the wire, not local Python/Cython element metadata that may imply a # different ref policy. 
Shared xlang tests intentionally deserialize @@ -586,15 +584,13 @@ cdef class TupleSerializer(CollectionSerializer): cdef int8_t head_flag cdef int64_t i - if len_ > read_context.max_collection_size: - raise ValueError( - f"Tuple size {len_} exceeds the configured limit of {read_context.max_collection_size}" - ) - tuple_ = PyTuple_New(len_) if len_ == 0: + tuple_ = PyTuple_New(0) return tuple_ + read_context.check_readable_bytes(len_) collect_flag = buffer.read_int8() + tuple_ = PyTuple_New(len_) if (collect_flag & COLL_IS_SAME_TYPE) != 0: if (collect_flag & COLL_IS_DECL_ELEMENT_TYPE) == 0: typeinfo = type_resolver.read_type_info(read_context) @@ -708,10 +704,6 @@ cdef class SetSerializer(CollectionSerializer): read_context.reference(instance) len_ = buffer.read_var_uint32() - if len_ > read_context.max_collection_size: - raise ValueError( - f"Set size {len_} exceeds the configured limit of {read_context.max_collection_size}" - ) if len_ == 0: return instance @@ -1054,12 +1046,14 @@ cdef class MapSerializer(Serializer): cpdef inline read(self, ReadContext read_context): cdef int32_t size = read_context.read_var_uint32() cdef int32_t ref_id - if size > read_context.max_collection_size: - raise ValueError(f"Map size {size} exceeds the configured limit of {read_context.max_collection_size}") - cdef dict map_ = _PyDict_NewPresized(size) + cdef dict map_ cdef int8_t chunk_header = 0 - if size != 0: + if size == 0: + map_ = {} + else: + read_context.check_readable_bytes(size) chunk_header = read_context.read_uint8() + map_ = _PyDict_NewPresized(size) cdef RefReader ref_reader = read_context.ref_reader cdef Serializer key_serializer = self.key_serializer cdef Serializer value_serializer = self.value_serializer diff --git a/python/pyfory/collection.py b/python/pyfory/collection.py index f9f0610c0d..d78673a6dc 100644 --- a/python/pyfory/collection.py +++ b/python/pyfory/collection.py @@ -176,8 +176,6 @@ def _write_different_types(self, write_context, value, collect_flag=0): 
def read(self, read_context): length = read_context.read_var_uint32() - if length > read_context.max_collection_size: - raise ValueError(f"Collection size {length} exceeds the configured limit of {read_context.max_collection_size}") collection_ = self.new_instance(read_context, self.type_) if length == 0: return collection_ @@ -457,8 +455,6 @@ def write(self, write_context, obj): def read(self, read_context): size = read_context.read_var_uint32() - if size > read_context.max_collection_size: - raise ValueError(f"Map size {size} exceeds the configured limit of {read_context.max_collection_size}") map_ = {} ref_reader = read_context.ref_reader read_context.reference(map_) diff --git a/python/pyfory/context.pxi b/python/pyfory/context.pxi index 1a29fe8de3..702f09769c 100644 --- a/python/pyfory/context.pxi +++ b/python/pyfory/context.pxi @@ -483,8 +483,6 @@ cdef class WriteContext: cdef readonly bint compatible cdef readonly bint field_nullable cdef readonly object policy - cdef readonly int32_t max_collection_size - cdef readonly int32_t max_binary_size cdef readonly RefWriter ref_writer cdef readonly MetaStringWriter meta_string_writer cdef readonly MetaShareWriteContext meta_share_context @@ -503,8 +501,6 @@ cdef class WriteContext: self.compatible = config.compatible self.field_nullable = config.field_nullable self.policy = config.policy - self.max_collection_size = config.max_collection_size - self.max_binary_size = config.max_binary_size self.ref_writer = RefWriter(self.track_ref) self.meta_string_writer = MetaStringWriter() self.meta_share_context = MetaShareWriteContext() if config.scoped_meta_share_enabled else None @@ -750,8 +746,6 @@ cdef class ReadContext: cdef readonly bint field_nullable cdef readonly object policy cdef readonly int32_t max_depth - cdef readonly int32_t max_collection_size - cdef readonly int32_t max_binary_size cdef readonly RefReader ref_reader cdef readonly MetaStringReader meta_string_reader cdef readonly MetaShareReadContext 
meta_share_context @@ -772,8 +766,6 @@ cdef class ReadContext: self.field_nullable = config.field_nullable self.policy = config.policy self.max_depth = config.max_depth - self.max_collection_size = config.max_collection_size - self.max_binary_size = config.max_binary_size self.ref_reader = RefReader(self.track_ref) self.meta_string_reader = MetaStringReader(self.type_resolver.shared_registry) self.meta_share_context = MetaShareReadContext() if config.scoped_meta_share_enabled else None @@ -951,11 +943,10 @@ cdef class ReadContext: cdef Buffer buf if not self.peer_out_of_band_enabled: size = self.read_var_uint32() - if size > self.max_binary_size: - raise ValueError(f"Binary size {size} exceeds the configured limit of {self.max_binary_size}") if self.buffer.has_input_stream(): return self.buffer.read_bytes(size) reader_index = self.buffer.get_reader_index() + self.buffer.check_bound(reader_index, size) buf = self.buffer.slice(reader_index, size) self.buffer.set_reader_index(reader_index + size) return buf @@ -963,11 +954,10 @@ cdef class ReadContext: assert self.buffers is not None return next(self.buffers) size = self.read_var_uint32() - if size > self.max_binary_size: - raise ValueError(f"Binary size {size} exceeds the configured limit of {self.max_binary_size}") if self.buffer.has_input_stream(): return self.buffer.read_bytes(size) reader_index = self.buffer.get_reader_index() + self.buffer.check_bound(reader_index, size) buf = self.buffer.slice(reader_index, size) self.buffer.set_reader_index(reader_index + size) return buf @@ -1110,6 +1100,17 @@ cdef class ReadContext: cpdef read_bytes_and_size(self): return self.buffer.read_bytes_and_size() + cpdef check_readable_bytes(self, int32_t length): + cdef Buffer buffer + if length < 0: + raise_fory_error(CErrorCode.InvalidData, f"Readable byte count {length} is negative") + if length == 0: + return + buffer = self.buffer + if not self.c_buffer.ensure_readable(length, buffer._error): + if not buffer._error.ok(): + 
buffer._raise_if_error() + cpdef inline int32_t get_reader_index(self): return self.buffer.get_reader_index() diff --git a/python/pyfory/context.py b/python/pyfory/context.py index 2439b9046d..3abfb46e3d 100644 --- a/python/pyfory/context.py +++ b/python/pyfory/context.py @@ -230,8 +230,6 @@ class WriteContext: "compatible", "field_nullable", "policy", - "max_collection_size", - "max_binary_size", "ref_writer", "meta_string_writer", "meta_share_context", @@ -250,8 +248,6 @@ def __init__(self, config: Config, type_resolver): self.compatible = config.compatible self.field_nullable = config.field_nullable self.policy = config.policy - self.max_collection_size = config.max_collection_size - self.max_binary_size = config.max_binary_size self.ref_writer = MapRefWriter() if self.track_ref else NoRefWriter() self.meta_string_writer = MetaStringWriter() self.meta_share_context = MetaShareWriteContext() if config.scoped_meta_share_enabled else None @@ -473,8 +469,6 @@ class ReadContext: "compatible", "field_nullable", "policy", - "max_collection_size", - "max_binary_size", "max_depth", "ref_reader", "meta_string_reader", @@ -495,8 +489,6 @@ def __init__(self, config: Config, type_resolver): self.compatible = config.compatible self.field_nullable = config.field_nullable self.policy = config.policy - self.max_collection_size = config.max_collection_size - self.max_binary_size = config.max_binary_size self.max_depth = config.max_depth self.ref_reader = MapRefReader() if self.track_ref else NoRefReader() self.meta_string_reader = MetaStringReader(type_resolver.shared_registry) @@ -514,6 +506,14 @@ def __getattr__(self, name): raise AttributeError(name) return getattr(buffer, name) + def check_readable_bytes(self, length): + if length < 0: + raise ValueError(f"Readable byte count {length} is negative") + if length == 0: + return + reader_index = self.buffer.get_reader_index() + self.buffer.check_bound(reader_index, length) + def prepare( self, buffer, @@ -636,11 +636,10 @@ def 
_read_non_ref_internal(self, serializer=None): def read_buffer_object(self): if not self.peer_out_of_band_enabled: size = self.buffer.read_var_uint32() - if size > self.max_binary_size: - raise ValueError(f"Binary size {size} exceeds the configured limit of {self.max_binary_size}") if self.buffer.has_input_stream(): return self.buffer.read_bytes(size) reader_index = self.buffer.get_reader_index() + self.buffer.check_bound(reader_index, size) buf = self.buffer.slice(reader_index, size) self.buffer.set_reader_index(reader_index + size) return buf @@ -649,11 +648,10 @@ def read_buffer_object(self): assert self.buffers is not None return next(self.buffers) size = self.buffer.read_var_uint32() - if size > self.max_binary_size: - raise ValueError(f"Binary size {size} exceeds the configured limit of {self.max_binary_size}") if self.buffer.has_input_stream(): return self.buffer.read_bytes(size) reader_index = self.buffer.get_reader_index() + self.buffer.check_bound(reader_index, size) buf = self.buffer.slice(reader_index, size) self.buffer.set_reader_index(reader_index + size) return buf diff --git a/python/pyfory/converter.py b/python/pyfory/converter.py index 6538a489a9..81a7fa6030 100644 --- a/python/pyfory/converter.py +++ b/python/pyfory/converter.py @@ -362,14 +362,14 @@ def compatible_scalar_convert(value, remote_type_id: int, local_type_id: int): raise ValueError(f"type id {local_type_id} is not a compatible scalar target") -def _read_compatible_scalar_payload(read_context, remote_serializer, remote_type_id: int): +def _read_compatible_scalar_value(read_context, remote_serializer, remote_type_id: int): if remote_type_id == TypeId.BOOL: raw = read_context.read_uint8() if raw == 0: return False if raw == 1: return True - raise ValueError("bool payload must be encoded as 0 or 1") + raise ValueError("bool byte must be encoded as 0 or 1") return remote_serializer.read(read_context) @@ -396,7 +396,7 @@ def write(self, write_context, value): def read(self, read_context): 
value = None try: - value = _read_compatible_scalar_payload(read_context, self.remote_serializer, self.remote_type_id) + value = _read_compatible_scalar_value(read_context, self.remote_serializer, self.remote_type_id) return compatible_scalar_convert(value, self.remote_type_id, self.local_type_id) except (ValueError, OverflowError, decimal.InvalidOperation) as exc: _scalar_conversion_error(self.field_name, self.remote_type_id, self.local_type_id, value, exc) @@ -455,8 +455,6 @@ def read(self, read_context): from pyfory.error import TypeNotCompatibleError length = read_context.read_var_uint32() - if length > read_context.max_collection_size: - raise ValueError(f"Collection size {length} exceeds the configured limit of {read_context.max_collection_size}") if length == 0: return self._empty_target() collect_flag = read_context.read_int8() @@ -469,6 +467,7 @@ def read(self, read_context): f"Field {self.field_name!r} requires declared same-type list elements for array compatible read", ) + read_context.check_readable_bytes(length) target = self._new_target(length) append = None if np is not None and _is_numpy_1d_array_serializer(self.target_serializer) else target.append for index in range(length): diff --git a/python/pyfory/cpp/pyfory.cc b/python/pyfory/cpp/pyfory.cc index 06b587db85..864c34b912 100644 --- a/python/pyfory/cpp/pyfory.cc +++ b/python/pyfory/cpp/pyfory.cc @@ -300,24 +300,30 @@ class PyInputStream final : public InputStream { } const uint32_t read_pos = buffer_->reader_index_; - const uint32_t deficit = min_fill_size - remaining_size(); constexpr uint64_t k_max_u32 = std::numeric_limits<uint32_t>::max(); - const uint64_t required = static_cast<uint64_t>(buffer_->size_) + deficit; - if (required > k_max_u32) { + const uint64_t target = static_cast<uint64_t>(read_pos) + min_fill_size; + if (target > k_max_u32) { return Unexpected( Error::out_of_bound("stream buffer size exceeds uint32 range")); } - if (required > data_.size()) { - uint64_t new_size = - std::max(required, 
static_cast<uint64_t>(data_.size()) * 2); - if (new_size > k_max_u32) { - new_size = k_max_u32; - } - reserve(static_cast<uint32_t>(new_size)); - } uint32_t write_pos = buffer_->size_; while (remaining_size() < min_fill_size) { + if (write_pos == data_.size()) { + // min_fill_size can come from attacker-controlled wire lengths. Grow + // only from bytes already buffered so truncated streams fail before + // reserving the declared body size. + uint64_t new_size = + std::max(static_cast<uint64_t>(data_.size()) * 2, + static_cast<uint64_t>(initial_buffer_size_)); + if (new_size <= data_.size()) { + new_size = static_cast<uint64_t>(data_.size()) + 1; + } + if (new_size > target) { + new_size = target; + } + reserve(static_cast<uint32_t>(new_size)); + } uint32_t writable = static_cast<uint32_t>(data_.size()) - write_pos; auto read_result = recv_into(data_.data() + write_pos, writable); if (FORY_PREDICT_FALSE(!read_result.ok())) { diff --git a/python/pyfory/number.pxi b/python/pyfory/number.pxi index b11cdf2316..cc5be5c204 100644 --- a/python/pyfory/number.pxi +++ b/python/pyfory/number.pxi @@ -852,7 +852,7 @@ cpdef object _float16_array_from_buffer(object buffer): cdef object bits cdef Float16Array values = Float16Array() if len(raw_bytes) & 1: - raise ValueError("float16 bits payload size mismatch") + raise ValueError("float16 bits byte size mismatch") raw.frombytes(raw_bytes) for bits in raw: values._values.push_back(bits) @@ -880,7 +880,7 @@ cpdef object _bfloat16_array_from_buffer(object buffer): cdef object bits cdef BFloat16Array values = BFloat16Array() if len(raw_bytes) & 1: - raise ValueError("bfloat16 bits payload size mismatch") + raise ValueError("bfloat16 bits byte size mismatch") raw.frombytes(raw_bytes) for bits in raw: values._values.push_back(bits) @@ -896,10 +896,16 @@ cdef class Float16Serializer(Serializer): return _float16_bits_to_float(read_context.read_uint16()) -cdef inline uint32_t _array_payload_count(uint32_t payload_size, uint32_t item_size, str name) except *: - if payload_size % item_size != 0: - raise 
ValueError(f"{name} payload size mismatch") - return payload_size // item_size +cdef inline uint32_t _array_element_count(uint32_t byte_size, uint32_t item_size, str name) except *: + if byte_size % item_size != 0: + raise ValueError(f"{name} byte size mismatch") + return byte_size // item_size + + +cdef inline void _ensure_array_bytes_readable(ReadContext read_context, uint32_t byte_size) except *: + if byte_size > 0 and not read_context.c_buffer.ensure_readable(byte_size, read_context.buffer._error): + if not read_context.buffer._error.ok(): + read_context.buffer._raise_if_error() cdef inline void _write_uint8_vector(WriteContext write_context, vector[uint8_t]& values) except *: @@ -910,10 +916,11 @@ cdef inline void _write_uint8_vector(WriteContext write_context, vector[uint8_t] cdef inline void _read_uint8_vector(ReadContext read_context, vector[uint8_t]& values, str name) except *: - cdef uint32_t payload_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) - values.resize(payload_size) - if payload_size > 0: - read_context.c_buffer.read_bytes(&values[0], payload_size, read_context.buffer._error) + cdef uint32_t byte_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) + _ensure_array_bytes_readable(read_context, byte_size) + values.resize(byte_size) + if byte_size > 0: + read_context.c_buffer.read_bytes(&values[0], byte_size, read_context.buffer._error) cdef inline void _write_int8_vector(WriteContext write_context, vector[int8_t]& values) except *: @@ -924,10 +931,11 @@ cdef inline void _write_int8_vector(WriteContext write_context, vector[int8_t]& cdef inline void _read_int8_vector(ReadContext read_context, vector[int8_t]& values, str name) except *: - cdef uint32_t payload_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) - values.resize(payload_size) - if payload_size > 0: - read_context.c_buffer.read_bytes(&values[0], payload_size, read_context.buffer._error) + cdef uint32_t byte_size = 
read_context.c_buffer.read_var_uint32(read_context.buffer._error) + _ensure_array_bytes_readable(read_context, byte_size) + values.resize(byte_size) + if byte_size > 0: + read_context.c_buffer.read_bytes(&values[0], byte_size, read_context.buffer._error) cdef inline void _write_int16_vector(WriteContext write_context, vector[int16_t]& values) except *: @@ -946,13 +954,14 @@ cdef inline void _write_int16_vector(WriteContext write_context, vector[int16_t] cdef inline void _read_int16_vector(ReadContext read_context, vector[int16_t]& values, str name) except *: cdef uint32_t i - cdef uint32_t payload_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) - cdef uint32_t count = _array_payload_count(payload_size, sizeof(int16_t), name) + cdef uint32_t byte_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) + cdef uint32_t count = _array_element_count(byte_size, sizeof(int16_t), name) + _ensure_array_bytes_readable(read_context, byte_size) values.resize(count) - if payload_size == 0: + if byte_size == 0: return if is_little_endian: - read_context.c_buffer.read_bytes(&values[0], payload_size, read_context.buffer._error) + read_context.c_buffer.read_bytes(&values[0], byte_size, read_context.buffer._error) else: for i in range(count): values[i] = read_context.c_buffer.read_int16(read_context.buffer._error) @@ -974,13 +983,14 @@ cdef inline void _write_uint16_vector(WriteContext write_context, vector[uint16_ cdef inline void _read_uint16_vector(ReadContext read_context, vector[uint16_t]& values, str name) except *: cdef uint32_t i - cdef uint32_t payload_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) - cdef uint32_t count = _array_payload_count(payload_size, sizeof(uint16_t), name) + cdef uint32_t byte_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) + cdef uint32_t count = _array_element_count(byte_size, sizeof(uint16_t), name) + _ensure_array_bytes_readable(read_context, byte_size) 
values.resize(count) - if payload_size == 0: + if byte_size == 0: return if is_little_endian: - read_context.c_buffer.read_bytes(&values[0], payload_size, read_context.buffer._error) + read_context.c_buffer.read_bytes(&values[0], byte_size, read_context.buffer._error) else: for i in range(count): values[i] = read_context.c_buffer.read_uint16(read_context.buffer._error) @@ -1002,13 +1012,14 @@ cdef inline void _write_int32_vector(WriteContext write_context, vector[int32_t] cdef inline void _read_int32_vector(ReadContext read_context, vector[int32_t]& values, str name) except *: cdef uint32_t i - cdef uint32_t payload_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) - cdef uint32_t count = _array_payload_count(payload_size, sizeof(int32_t), name) + cdef uint32_t byte_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) + cdef uint32_t count = _array_element_count(byte_size, sizeof(int32_t), name) + _ensure_array_bytes_readable(read_context, byte_size) values.resize(count) - if payload_size == 0: + if byte_size == 0: return if is_little_endian: - read_context.c_buffer.read_bytes(&values[0], payload_size, read_context.buffer._error) + read_context.c_buffer.read_bytes(&values[0], byte_size, read_context.buffer._error) else: for i in range(count): values[i] = read_context.c_buffer.read_int32(read_context.buffer._error) @@ -1030,13 +1041,14 @@ cdef inline void _write_uint32_vector(WriteContext write_context, vector[uint32_ cdef inline void _read_uint32_vector(ReadContext read_context, vector[uint32_t]& values, str name) except *: cdef uint32_t i - cdef uint32_t payload_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) - cdef uint32_t count = _array_payload_count(payload_size, sizeof(uint32_t), name) + cdef uint32_t byte_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) + cdef uint32_t count = _array_element_count(byte_size, sizeof(uint32_t), name) + 
_ensure_array_bytes_readable(read_context, byte_size) values.resize(count) - if payload_size == 0: + if byte_size == 0: return if is_little_endian: - read_context.c_buffer.read_bytes(&values[0], payload_size, read_context.buffer._error) + read_context.c_buffer.read_bytes(&values[0], byte_size, read_context.buffer._error) else: for i in range(count): values[i] = read_context.c_buffer.read_uint32(read_context.buffer._error) @@ -1058,13 +1070,14 @@ cdef inline void _write_int64_vector(WriteContext write_context, vector[int64_t] cdef inline void _read_int64_vector(ReadContext read_context, vector[int64_t]& values, str name) except *: cdef uint32_t i - cdef uint32_t payload_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) - cdef uint32_t count = _array_payload_count(payload_size, sizeof(int64_t), name) + cdef uint32_t byte_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) + cdef uint32_t count = _array_element_count(byte_size, sizeof(int64_t), name) + _ensure_array_bytes_readable(read_context, byte_size) values.resize(count) - if payload_size == 0: + if byte_size == 0: return if is_little_endian: - read_context.c_buffer.read_bytes(&values[0], payload_size, read_context.buffer._error) + read_context.c_buffer.read_bytes(&values[0], byte_size, read_context.buffer._error) else: for i in range(count): values[i] = read_context.c_buffer.read_int64(read_context.buffer._error) @@ -1086,13 +1099,14 @@ cdef inline void _write_uint64_vector(WriteContext write_context, vector[uint64_ cdef inline void _read_uint64_vector(ReadContext read_context, vector[uint64_t]& values, str name) except *: cdef uint32_t i - cdef uint32_t payload_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) - cdef uint32_t count = _array_payload_count(payload_size, sizeof(uint64_t), name) + cdef uint32_t byte_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) + cdef uint32_t count = _array_element_count(byte_size, 
sizeof(uint64_t), name) + _ensure_array_bytes_readable(read_context, byte_size) values.resize(count) - if payload_size == 0: + if byte_size == 0: return if is_little_endian: - read_context.c_buffer.read_bytes(&values[0], payload_size, read_context.buffer._error) + read_context.c_buffer.read_bytes(&values[0], byte_size, read_context.buffer._error) else: for i in range(count): values[i] = read_context.c_buffer.read_uint64(read_context.buffer._error) @@ -1114,13 +1128,14 @@ cdef inline void _write_float_vector(WriteContext write_context, vector[float]& cdef inline void _read_float_vector(ReadContext read_context, vector[float]& values, str name) except *: cdef uint32_t i - cdef uint32_t payload_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) - cdef uint32_t count = _array_payload_count(payload_size, sizeof(float), name) + cdef uint32_t byte_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) + cdef uint32_t count = _array_element_count(byte_size, sizeof(float), name) + _ensure_array_bytes_readable(read_context, byte_size) values.resize(count) - if payload_size == 0: + if byte_size == 0: return if is_little_endian: - read_context.c_buffer.read_bytes(&values[0], payload_size, read_context.buffer._error) + read_context.c_buffer.read_bytes(&values[0], byte_size, read_context.buffer._error) else: for i in range(count): values[i] = read_context.c_buffer.read_float(read_context.buffer._error) @@ -1142,13 +1157,14 @@ cdef inline void _write_double_vector(WriteContext write_context, vector[double] cdef inline void _read_double_vector(ReadContext read_context, vector[double]& values, str name) except *: cdef uint32_t i - cdef uint32_t payload_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) - cdef uint32_t count = _array_payload_count(payload_size, sizeof(double), name) + cdef uint32_t byte_size = read_context.c_buffer.read_var_uint32(read_context.buffer._error) + cdef uint32_t count = 
_array_element_count(byte_size, sizeof(double), name) + _ensure_array_bytes_readable(read_context, byte_size) values.resize(count) - if payload_size == 0: + if byte_size == 0: return if is_little_endian: - read_context.c_buffer.read_bytes(&values[0], payload_size, read_context.buffer._error) + read_context.c_buffer.read_bytes(&values[0], byte_size, read_context.buffer._error) else: for i in range(count): values[i] = read_context.c_buffer.read_double(read_context.buffer._error) diff --git a/python/pyfory/registry.py b/python/pyfory/registry.py index ae73d4b090..15231b9d9f 100644 --- a/python/pyfory/registry.py +++ b/python/pyfory/registry.py @@ -342,8 +342,6 @@ class TypeResolver: "compatible", "field_nullable", "policy", - "max_collection_size", - "max_binary_size", "shared_registry", "_type_id_counter", "_types_info", @@ -374,8 +372,6 @@ def __init__(self, config, *, shared_registry): self.compatible = config.compatible self.field_nullable = config.field_nullable self.policy = config.policy - self.max_collection_size = config.max_collection_size - self.max_binary_size = config.max_binary_size self.shared_registry = shared_registry self.require_registration = self.strict self._metastr_to_type = dict() diff --git a/python/pyfory/serialization.pyx b/python/pyfory/serialization.pyx index 584e76c8db..0270ef829b 100644 --- a/python/pyfory/serialization.pyx +++ b/python/pyfory/serialization.pyx @@ -121,7 +121,7 @@ cdef class Config: directional read/write contexts. The Cython runtime treats this object as the single source of truth for - execution-mode flags and guardrail limits. Higher-level facades may expose + execution-mode and maximum-depth flags. Higher-level facades may expose convenience accessors, but runtime code should read these values from the config instance instead of mirroring them onto other owners. @@ -136,8 +136,6 @@ cdef class Config: field_nullable: Treats struct/dataclass fields as nullable by default. 
policy: Deserialization policy used for security-sensitive checks. meta_compressor: Optional typedef/meta compressor implementation. - max_collection_size: Upper bound for declared collection/map sizes. - max_binary_size: Upper bound for a single binary payload read. """ cdef public bint xlang @@ -150,8 +148,6 @@ cdef class Config: cdef public bint field_nullable cdef public object policy cdef public object meta_compressor - cdef public int32_t max_collection_size - cdef public int32_t max_binary_size def __init__( self, @@ -166,8 +162,6 @@ cdef class Config: field_nullable, policy, meta_compressor, - max_collection_size, - max_binary_size, ): """ Build a runtime config object for one Python or Cython Fory instance. @@ -183,8 +177,6 @@ cdef class Config: field_nullable: Treat all struct fields as nullable by default. policy: Deserialization policy implementation. meta_compressor: Optional typedef/meta compressor. - max_collection_size: Maximum declared collection/map size. - max_binary_size: Maximum binary payload size for one read. 
""" self.xlang = xlang self.track_ref = track_ref @@ -196,8 +188,6 @@ cdef class Config: self.field_nullable = field_nullable self.policy = policy self.meta_compressor = meta_compressor - self.max_collection_size = max_collection_size - self.max_binary_size = max_binary_size cdef inline bint _is_struct_type_id(uint8_t type_id): @@ -232,8 +222,6 @@ cdef class TypeResolver: cdef readonly bint compatible cdef readonly bint field_nullable cdef readonly object policy - cdef readonly int32_t max_collection_size - cdef readonly int32_t max_binary_size cdef readonly bint meta_share cdef readonly dict _types_info cdef readonly dict _type_id_to_type_info @@ -267,8 +255,6 @@ cdef class TypeResolver: self.compatible = resolver.compatible self.field_nullable = resolver.field_nullable self.policy = resolver.policy - self.max_collection_size = resolver.max_collection_size - self.max_binary_size = resolver.max_binary_size self.meta_share = resolver.meta_share self._types_info = resolver._types_info self._type_id_to_type_info = resolver._type_id_to_type_info @@ -820,8 +806,6 @@ cdef class Fory: cdef public bint field_nullable cdef public int32_t max_depth cdef public object policy - cdef public int32_t max_collection_size - cdef public int32_t max_binary_size cdef public Config config cdef public TypeResolver type_resolver cdef public WriteContext write_context @@ -838,8 +822,6 @@ cdef class Fory: policy=None, field_nullable=False, meta_compressor=None, - max_collection_size=1_000_000, - max_binary_size=64 * 1024 * 1024, ): """ Initialize a Cython-backed Fory runtime instance. @@ -854,8 +836,6 @@ cdef class Fory: policy: Optional deserialization policy implementation. field_nullable: Treat struct fields as nullable by default. meta_compressor: Optional typedef/meta compressor implementation. - max_collection_size: Maximum allowed declared collection/map size. - max_binary_size: Maximum allowed binary payload size for one read. 
""" compatible = True if compatible is None else compatible self.xlang = xlang @@ -870,8 +850,6 @@ cdef class Fory: self.compatible = compatible self.field_nullable = field_nullable self.max_depth = max_depth - self.max_collection_size = max_collection_size - self.max_binary_size = max_binary_size self.config = Config( xlang=xlang, track_ref=ref, @@ -883,8 +861,6 @@ cdef class Fory: field_nullable=field_nullable, policy=self.policy, meta_compressor=meta_compressor, - max_collection_size=max_collection_size, - max_binary_size=max_binary_size, ) from pyfory.registry import SharedRegistry @@ -896,7 +872,7 @@ cdef class Fory: self.type_resolver.initialize() self.write_context = WriteContext(self.config, self.type_resolver) self.read_context = ReadContext(self.config, self.type_resolver) - self.buffer = Buffer.allocate(32, max_binary_size=max_binary_size) + self.buffer = Buffer.allocate(32) def register( self, @@ -1040,7 +1016,7 @@ cdef class Fory: cdef uint8_t bitmap cdef bint peer_out_of_band_enabled if isinstance(buffer, bytes): - buffer = Buffer(buffer, max_binary_size=self.max_binary_size) + buffer = Buffer(buffer) read_buffer = buffer reader_index = read_buffer.get_reader_index() read_buffer.set_reader_index(reader_index + 1) diff --git a/python/pyfory/serializer.py b/python/pyfory/serializer.py index 888f208cfd..bc07ed2933 100644 --- a/python/pyfory/serializer.py +++ b/python/pyfory/serializer.py @@ -64,11 +64,9 @@ def _resolve_validated_module_qualname(policy, module_name, qualname): return obj -def _check_collection_size(read_context, size, kind): +def _check_non_negative_size(size, kind): if size < 0: raise ValueError(f"{kind} size {size} must be non-negative") - if size > read_context.max_collection_size: - raise ValueError(f"{kind} size {size} exceeds the configured limit of {read_context.max_collection_size}") def _is_local_qualname(module_name, qualname): @@ -119,6 +117,11 @@ def _validate_function_value(policy, func, is_local): return func +def 
_authorize_callable_materialization(policy, callable_type, **kwargs): + if policy is not DEFAULT_POLICY: + policy.authorize_instantiation(callable_type, **kwargs) + + def _bind_static_method(obj, method_name): cls = obj if isinstance(obj, type) else obj.__class__ try: @@ -365,10 +368,10 @@ def _write_decimal_parts(write_context, scale: int, unscaled: int): magnitude = abs(unscaled) if magnitude == 0: raise ValueError("Zero must use the small decimal encoding") - payload = magnitude.to_bytes((magnitude.bit_length() + 7) // 8, "little", signed=False) - meta = (len(payload) << 1) | (1 if unscaled < 0 else 0) + magnitude_bytes = magnitude.to_bytes((magnitude.bit_length() + 7) // 8, "little", signed=False) + meta = (len(magnitude_bytes) << 1) | (1 if unscaled < 0 else 0) _write_var_uint64(write_context, (meta << 1) | 1) - write_context.write_bytes(payload) + write_context.write_bytes(magnitude_bytes) def _write_var_uint64(write_context, value: int): @@ -390,10 +393,10 @@ def _read_decimal_parts(read_context) -> Tuple[int, int]: length = meta >> 1 if length <= 0: raise ValueError(f"Invalid decimal magnitude length {length}") - payload = read_context.read_bytes(length) - if payload[-1] == 0: - raise ValueError("Non-canonical decimal payload: trailing zero byte") - magnitude = int.from_bytes(payload, "little", signed=False) + magnitude_bytes = read_context.read_bytes(length) + if magnitude_bytes[-1] == 0: + raise ValueError("Non-canonical decimal magnitude bytes: trailing zero byte") + magnitude = int.from_bytes(magnitude_bytes, "little", signed=False) if magnitude == 0: raise ValueError("Big decimal encoding must not represent zero") return scale, -magnitude if sign else magnitude @@ -925,15 +928,16 @@ def read(self, read_context): read_context.set_reader_index(reader_index) dtype = np.dtype(read_context.read_string()) ndim = read_context.read_var_uint32() - _check_collection_size(read_context, ndim, "ndarray dimension") + _check_non_negative_size(ndim, "ndarray 
dimension") shape = tuple(read_context.read_var_uint32() for _ in range(ndim)) if dtype.kind == "O": length = read_context.read_varint32() - _check_collection_size(read_context, length, "ndarray object") + _check_non_negative_size(length, "ndarray object") + read_context.check_readable_bytes(length) items = [read_context.read_ref() for _ in range(length)] return np.array(items, dtype=object) for dim in shape: - _check_collection_size(read_context, dim, "ndarray dimension") + _check_non_negative_size(dim, "ndarray dimension") fory_buf = read_context.read_buffer_object() if isinstance(fory_buf, memoryview): return np.frombuffer(fory_buf, dtype=dtype).reshape(shape) @@ -1323,7 +1327,7 @@ def _deserialize_local_class(self, read_context): ref_id = read_context.last_preserved_ref_id() num_bases = read_context.read_var_uint32() - _check_collection_size(read_context, num_bases, "local class base") + _check_non_negative_size(num_bases, "local class base") bases = tuple(read_context.read_ref() for _ in range(num_bases)) read_context.policy.authorize_instantiation(type, module=module, qualname=qualname, bases=bases) cls = type(name, bases, {}) @@ -1331,7 +1335,7 @@ def _deserialize_local_class(self, read_context): read_context.policy.validate_class(cls, is_local=True) num_class_methods = read_context.read_var_uint32() - _check_collection_size(read_context, num_class_methods, "local class method") + _check_non_negative_size(num_class_methods, "local class method") for _ in range(num_class_methods): attr_name = read_context.read_string() func = read_context.read_ref() @@ -1504,9 +1508,10 @@ def _deserialize_function(self, read_context): func_type_id = read_context.read_int8() if func_type_id == 0: + policy = read_context.policy + _authorize_callable_materialization(policy, types.MethodType) self_obj = read_context.read_ref() method_name = read_context.read_string() - policy = read_context.policy if policy is DEFAULT_POLICY: return getattr(self_obj, method_name) return 
_resolve_validated_bound_method(policy, self_obj, method_name, is_local=_is_local_receiver(self_obj)) @@ -1519,7 +1524,15 @@ def _deserialize_function(self, read_context): module = read_context.read_string() qualname = read_context.read_string() - mod = _import_validated_module(read_context.policy, module, is_local=_is_local_qualname(module, qualname)) + policy = read_context.policy + mod = _import_validated_module(policy, module, is_local=_is_local_qualname(module, qualname)) + _authorize_callable_materialization( + policy, + types.FunctionType, + module=module, + qualname=qualname, + is_local=True, + ) name = qualname.rsplit(".")[-1] marshalled_code = read_context.read_bytes_and_size() @@ -1529,7 +1542,7 @@ def _deserialize_function(self, read_context): defaults = None if has_defaults: num_defaults = read_context.read_var_uint32() - _check_collection_size(read_context, num_defaults, "function default") + _check_non_negative_size(num_defaults, "function default") default_values = [] for _ in range(num_defaults): default_values.append(read_context.read_ref()) @@ -1537,7 +1550,7 @@ def _deserialize_function(self, read_context): has_closure = read_context.read_bool() num_freevars = read_context.read_var_uint32() - _check_collection_size(read_context, num_freevars, "function closure") + _check_non_negative_size(num_freevars, "function closure") closure = None closure_values = [] @@ -1548,7 +1561,7 @@ def _deserialize_function(self, read_context): closure = tuple(types.CellType(value) for value in closure_values) num_freevars = read_context.read_var_uint32() - _check_collection_size(read_context, num_freevars, "function free variable") + _check_non_negative_size(num_freevars, "function free variable") freevars = [] for _ in range(num_freevars): freevars.append(read_context.read_string()) @@ -1609,8 +1622,9 @@ def read(self, read_context): ) func = _validate_function_value(read_context.policy, func, is_local=_is_local_callable(func)) else: - obj = 
read_context.read_ref() policy = read_context.policy + _authorize_callable_materialization(policy, types.MethodType, method_name=name) + obj = read_context.read_ref() if policy is DEFAULT_POLICY: func = getattr(obj, name) else: @@ -1634,6 +1648,7 @@ def write(self, write_context, value): write_context.write_string(method_name) def read(self, read_context): + _authorize_callable_materialization(read_context.policy, self.cls) instance = read_context.read_ref() method_name = read_context.read_string() @@ -1682,8 +1697,7 @@ def read(self, read_context): obj = self.type_.__new__(self.type_) read_context.reference(obj) num_fields = read_context.read_var_uint32() - if num_fields > read_context.max_collection_size: - raise ValueError(f"object field size {num_fields} exceeds the configured limit of {read_context.max_collection_size}") + _check_non_negative_size(num_fields, "object field") state = {} for _ in range(num_fields): field_name = read_context.read_string() @@ -1700,8 +1714,7 @@ def read(self, read_context): obj = self.type_.__new__(self.type_) read_context.reference(obj) num_fields = read_context.read_var_uint32() - if num_fields > read_context.max_collection_size: - raise ValueError(f"object field size {num_fields} exceeds the configured limit of {read_context.max_collection_size}") + _check_non_negative_size(num_fields, "object field") for _ in range(num_fields): field_name = read_context.read_string() field_value = read_context.read_ref() diff --git a/python/pyfory/tests/test_metastring_resolver.py b/python/pyfory/tests/test_metastring_resolver.py index 256f2c371f..615dff53ba 100644 --- a/python/pyfory/tests/test_metastring_resolver.py +++ b/python/pyfory/tests/test_metastring_resolver.py @@ -17,10 +17,11 @@ import pytest -from pyfory import Buffer +from pyfory import Buffer, Fory from pyfory.context import EncodedMetaString, MetaStringReader, MetaStringWriter from pyfory.meta.metastring import MetaStringEncoder from pyfory.registry import 
MAX_CACHED_ENCODED_META_STRINGS, SharedRegistry +from pyfory.types import TypeId try: from pyfory.serialization import MetaStringReader as CythonMetaStringReader @@ -123,6 +124,12 @@ def test_cython_cached_big_metastring_validates_bytes_before_reuse(): reader.read_encoded_meta_string(buffer) +def test_malformed_metastring_ref_raises_value_error(): + data = bytes([1, 255, TypeId.NAMED_STRUCT, 3]) + with pytest.raises(ValueError, match="Invalid dynamic metastring id"): + Fory(xlang=True, compatible=False, strict=False).deserialize(data) + + def test_read_metastring_reset_clears_dynamic_ids_only(): shared_registry = SharedRegistry() encoded_meta_string = shared_registry.get_encoded_meta_string(MetaStringEncoder("$", "_").encode("hello")) diff --git a/python/pyfory/tests/test_policy.py b/python/pyfory/tests/test_policy.py index 4ae7a074a6..ad3453dc15 100644 --- a/python/pyfory/tests/test_policy.py +++ b/python/pyfory/tests/test_policy.py @@ -21,6 +21,7 @@ from pyfory import Fory, DeserializationPolicy from pyfory.serializer import ( FunctionSerializer, + MethodSerializer, NativeFuncMethodSerializer, TypeSerializer, ) @@ -258,7 +259,13 @@ def intercept_setstate(self, obj, state, **kwargs): raise ValueError("state blocked") FalseyState.bool_called = False - fory = Fory(xlang=False, ref=True, strict=False, policy=BlockSetStatePolicy(), compatible=False) + fory = Fory( + xlang=False, + ref=True, + strict=False, + policy=BlockSetStatePolicy(), + compatible=False, + ) data = fory.serialize(FalseyStatePayload()) with pytest.raises(ValueError, match="state blocked"): @@ -278,7 +285,13 @@ def intercept_setstate(self, obj, state, **kwargs): ObjectSetAttrPayload.setattr_called = False writer = Fory(xlang=False, ref=True, strict=False, compatible=False) - reader = Fory(xlang=False, ref=True, strict=False, policy=BlockSetStatePolicy(), compatible=False) + reader = Fory( + xlang=False, + ref=True, + strict=False, + policy=BlockSetStatePolicy(), + compatible=False, + ) 
writer.register(ObjectSetAttrPayload) reader.register(ObjectSetAttrPayload) @@ -490,6 +503,95 @@ def authorize_instantiation(self, cls, **kwargs): assert policy.authorize_instantiation_calls == 1 +def test_function_bound_method_authorizes_before_receiver_read(): + class BlockMethodMaterializationPolicy(DeserializationPolicy): + def __init__(self): + self.calls = [] + + def authorize_instantiation(self, cls, **kwargs): + self.calls.append((cls, kwargs)) + if cls is types.MethodType: + raise ValueError("method materialization blocked") + + policy = BlockMethodMaterializationPolicy() + fory = Fory(xlang=False, ref=True, strict=False, policy=policy, compatible=False) + serializer = FunctionSerializer(fory.type_resolver, type(policy_global_function)) + read_context = FakeReadContext(policy, [0]) + + with pytest.raises(ValueError, match="method materialization blocked"): + serializer._deserialize_function(read_context) + assert policy.calls == [(types.MethodType, {})] + + +def test_local_function_authorizes_before_body_read(): + class BlockFunctionMaterializationPolicy(DeserializationPolicy): + def __init__(self): + self.calls = [] + + def authorize_instantiation(self, cls, **kwargs): + self.calls.append((cls, kwargs)) + if cls is types.FunctionType: + raise ValueError("function materialization blocked") + + policy = BlockFunctionMaterializationPolicy() + fory = Fory(xlang=False, ref=True, strict=False, policy=policy, compatible=False) + serializer = FunctionSerializer(fory.type_resolver, type(policy_global_function)) + read_context = FakeReadContext(policy, [2, __name__, "policy_global_function"]) + + with pytest.raises(ValueError, match="function materialization blocked"): + serializer._deserialize_function(read_context) + assert policy.calls == [ + ( + types.FunctionType, + { + "module": __name__, + "qualname": "policy_global_function", + "is_local": True, + }, + ) + ] + + +def test_native_bound_method_authorizes_before_receiver_read(): + class 
BlockMethodMaterializationPolicy(DeserializationPolicy): + def __init__(self): + self.calls = [] + + def authorize_instantiation(self, cls, **kwargs): + self.calls.append((cls, kwargs)) + if cls is types.MethodType: + raise ValueError("method materialization blocked") + + policy = BlockMethodMaterializationPolicy() + fory = Fory(xlang=False, ref=True, strict=False, policy=policy, compatible=False) + serializer = NativeFuncMethodSerializer(fory.type_resolver, type(policy_global_function)) + read_context = FakeReadContext(policy, ["run", False]) + + with pytest.raises(ValueError, match="method materialization blocked"): + serializer.read(read_context) + assert policy.calls == [(types.MethodType, {"method_name": "run"})] + + +def test_method_serializer_authorizes_before_instance_read(): + class BlockMethodMaterializationPolicy(DeserializationPolicy): + def __init__(self): + self.calls = [] + + def authorize_instantiation(self, cls, **kwargs): + self.calls.append((cls, kwargs)) + if cls is types.MethodType: + raise ValueError("method materialization blocked") + + policy = BlockMethodMaterializationPolicy() + fory = Fory(xlang=False, ref=True, strict=False, policy=policy, compatible=False) + serializer = MethodSerializer(fory.type_resolver, types.MethodType) + read_context = FakeReadContext(policy, []) + + with pytest.raises(ValueError, match="method materialization blocked"): + serializer.read(read_context) + assert policy.calls == [(types.MethodType, {})] + + def test_validate_module(): """Test validate_module policy hook for module deserialization.""" import json @@ -500,7 +602,13 @@ def validate_module(self, module_name, is_local, **kwargs): assert not is_local return collections - fory1 = Fory(xlang=False, ref=True, strict=False, policy=ReturnModulePolicy(), compatible=False) + fory1 = Fory( + xlang=False, + ref=True, + strict=False, + policy=ReturnModulePolicy(), + compatible=False, + ) data = fory1.serialize(json) assert fory1.deserialize(data) is json @@ -572,7 
+680,13 @@ class ReturnClassPolicy(DeserializationPolicy): def validate_class(self, cls, is_local, **kwargs): return SafeClass if is_local else None - fory = Fory(xlang=False, ref=True, strict=False, policy=ReturnClassPolicy(), compatible=False) + fory = Fory( + xlang=False, + ref=True, + strict=False, + policy=ReturnClassPolicy(), + compatible=False, + ) decoded = fory.deserialize(fory.serialize(make_payload_class())) assert decoded is not SafeClass assert decoded.run() == "payload" @@ -648,7 +762,13 @@ def validate_method(self, method, is_local, **kwargs): obj = GuardedMethod() method = types.MethodType(GuardedMethod.run, obj) - fory = Fory(xlang=False, ref=True, strict=False, policy=BlockMethodPolicy(), compatible=False) + fory = Fory( + xlang=False, + ref=True, + strict=False, + policy=BlockMethodPolicy(), + compatible=False, + ) data = fory.serialize(method) GuardedMethod.getattribute_called = False diff --git a/python/pyfory/tests/test_size_guardrails.py b/python/pyfory/tests/test_size_guardrails.py deleted file mode 100644 index 55399fa268..0000000000 --- a/python/pyfory/tests/test_size_guardrails.py +++ /dev/null @@ -1,218 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -""" -Test max_collection_size and max_binary_size guardrails to prevent OOM attacks -from malicious payloads. - -Collections preallocate memory based on declared size, so they need guardrails. -Binary reads are guarded by max_binary_size on the Buffer. -""" - -from dataclasses import dataclass -from typing import List - -import pytest - -import pyfory -from pyfory import Fory -from pyfory.serialization import Buffer -from pyfory.types import TypeId - - -class ObjectPayload: - pass - - -def roundtrip(data, limit, xlang=False, ref=False): - """Serialize and deserialize with given collection size limit.""" - writer = Fory(xlang=xlang, ref=ref, compatible=xlang) - reader = Fory(xlang=xlang, ref=ref, max_collection_size=limit, compatible=xlang) - return reader.deserialize(writer.serialize(data)) - - -def roundtrip_binary(data, max_binary_size, xlang=False, ref=False): - """Serialize and deserialize with given binary size limit.""" - writer = Fory(xlang=xlang, ref=ref, compatible=xlang) - reader = Fory(xlang=xlang, ref=ref, max_binary_size=max_binary_size, compatible=xlang) - return reader.deserialize(writer.serialize(data)) - - -class TestCollectionSizeLimit: - """Collections (list/set/dict) preallocate memory, so need size limits.""" - - @pytest.mark.parametrize("xlang", [False, True]) - @pytest.mark.parametrize( - "data,limit", - [ - ([1, 2, 3], 10), # list within limit - ({1, 2, 3}, 10), # set within limit - ({"a": 1}, 10), # dict within limit - ([], 0), # empty list ok - (set(), 0), # empty set ok - ({}, 0), # empty dict ok - ], - ) - def test_within_limit_succeeds(self, xlang, data, limit): - assert roundtrip(data, limit, xlang=xlang) == data - - @pytest.mark.parametrize("xlang", [False, True]) - @pytest.mark.parametrize( - "data,limit", - [ - (list(range(10)), 5), # list exceeds - (set(range(10)), 5), # set exceeds - ({str(i): i for i in range(10)}, 5), # dict exceeds - ([[1], list(range(10))], 5), # nested inner exceeds - ], - ) - def 
test_exceeds_limit_fails(self, xlang, data, limit): - with pytest.raises(ValueError, match="exceeds the configured limit"): - roundtrip(data, limit, xlang=xlang) - - @pytest.mark.parametrize("ref", [False, True]) - @pytest.mark.parametrize( - "data,limit,should_fail", - [ - ((1, 2, 3), 10, False), - (tuple(range(10)), 5, True), - ], - ) - def test_tuple_limit(self, ref, data, limit, should_fail): - """Tuple only works in xlang=False mode.""" - if should_fail: - with pytest.raises(ValueError, match="exceeds the configured limit"): - roundtrip(data, limit, xlang=False, ref=ref) - else: - assert roundtrip(data, limit, xlang=False, ref=ref) == data - - def test_default_limit_is_one_million(self): - assert ( - Fory( - xlang=False, - compatible=False, - ).max_collection_size - == 1_000_000 - ) - - def test_dataclass_list_field_exceeds_limit(self): - @dataclass - class Container: - items: List[pyfory.Int32] - - writer = Fory(xlang=True, compatible=False) - writer.register(Container) - reader = Fory(xlang=True, compatible=False, max_collection_size=5) - reader.register(Container) - - with pytest.raises(ValueError, match="exceeds the configured limit"): - reader.deserialize(writer.serialize(Container(items=list(range(10))))) - - def test_object_field_count_exceeds_limit(self): - obj = ObjectPayload() - obj.value = 1 - writer = Fory(xlang=False, ref=True, strict=False, compatible=False) - reader = Fory(xlang=False, ref=True, strict=False, max_collection_size=0, compatible=False) - writer.register(ObjectPayload) - reader.register(ObjectPayload) - - with pytest.raises(ValueError, match="object field size 1 exceeds"): - reader.deserialize(writer.serialize(obj)) - - def test_local_class_base_count_exceeds_limit(self): - def make_local_class(): - class LocalPayload: - pass - - return LocalPayload - - writer = Fory(xlang=False, ref=True, strict=False, compatible=False) - reader = Fory(xlang=False, ref=True, strict=False, max_collection_size=0, compatible=False) - - with 
pytest.raises(ValueError, match="local class base size 1 exceeds"): - reader.deserialize(writer.serialize(make_local_class())) - - def test_local_function_defaults_exceed_limit(self): - def local_function(value=1): - return value - - writer = Fory(xlang=False, ref=True, strict=False, compatible=False) - reader = Fory(xlang=False, ref=True, strict=False, max_collection_size=0, compatible=False) - - with pytest.raises(ValueError, match="function default size 1 exceeds"): - reader.deserialize(writer.serialize(local_function)) - - def test_object_ndarray_length_exceeds_limit(self): - np = pytest.importorskip("numpy") - arr = np.array([object(), object()], dtype=object) - writer = Fory(xlang=False, ref=True, strict=False, compatible=False) - reader = Fory(xlang=False, ref=True, strict=False, max_collection_size=1, compatible=False) - - with pytest.raises(ValueError, match="ndarray object size 2 exceeds"): - reader.deserialize(writer.serialize(arr)) - - -class TestBinarySizeLimit: - """Binary reads are guarded by max_binary_size on the Buffer.""" - - def test_default_limit_is_64mib(self): - assert ( - Fory( - xlang=False, - compatible=False, - ).max_binary_size - == 64 * 1024 * 1024 - ) - - @pytest.mark.parametrize("xlang", [False, True]) - def test_within_limit_succeeds(self, xlang): - assert roundtrip_binary(b"x" * 100, max_binary_size=1024, xlang=xlang) == b"x" * 100 - - @pytest.mark.parametrize("xlang", [False, True]) - def test_exceeds_limit_fails(self, xlang): - with pytest.raises(ValueError, match="exceeds the configured limit"): - roundtrip_binary(b"x" * 200, max_binary_size=100, xlang=xlang) - - @pytest.mark.parametrize("xlang", [False, True]) - def test_string_exceeds_limit_fails(self, xlang): - writer = Fory(xlang=xlang, compatible=xlang) - reader = Fory(xlang=xlang, max_binary_size=1, compatible=xlang) - with pytest.raises(ValueError, match="String size 2 exceeds"): - reader.deserialize(writer.serialize("xx")) - - def test_from_stream_respects_limit(self): - 
import io - - payload = Fory( - xlang=False, - compatible=False, - ).serialize(b"x" * 200) - buf = Buffer.from_stream(io.BytesIO(payload), max_binary_size=100) - with pytest.raises(ValueError, match="exceeds the configured limit"): - Fory(xlang=False, max_binary_size=100, compatible=False).deserialize(buf) - - def test_in_band_buffer_object_respects_limit(self): - payload = b"x" * 200 - data = Fory(xlang=False, ref=True, compatible=False).serialize(payload, buffer_callback=lambda _buffer: True) - - with pytest.raises(ValueError, match="exceeds the configured limit"): - Fory(xlang=False, ref=True, max_binary_size=100, compatible=False).deserialize(data, buffers=[]) - - def test_malformed_metastring_ref_raises_value_error(self): - payload = bytes([1, 255, TypeId.NAMED_STRUCT, 3]) - with pytest.raises(ValueError, match="Invalid dynamic metastring id"): - Fory(xlang=True, compatible=False, strict=False).deserialize(payload) diff --git a/python/pyfory/union.py b/python/pyfory/union.py index 82c1629b10..4b37093c3f 100644 --- a/python/pyfory/union.py +++ b/python/pyfory/union.py @@ -47,7 +47,7 @@ class UnionSerializer(Serializer): """ Serializer for generated union classes and typing.Union. 
- For generated unions, the payload is: + For generated unions, the case body is: | case_id (varuint32) | case_value (Any-style value) | """ @@ -110,17 +110,17 @@ def read(self, read_context): value = read_context.get_read_ref() return self._build_union(case_id, value) self.type_resolver.read_type_info(read_context) - value = self._read_case_payload(read_context, serializer) + value = self._read_case_value(read_context, serializer) read_context.set_read_ref(ref_id, value) else: if read_context.read_int8() == NULL_FLAG: value = None else: self.type_resolver.read_type_info(read_context) - value = self._read_case_payload(read_context, serializer) + value = self._read_case_value(read_context, serializer) return self._build_union(case_id, value) - def _read_case_payload(self, read_context, serializer): + def _read_case_value(self, read_context, serializer): read_context.increase_depth() try: return serializer.read(read_context) diff --git a/rust/fory-core/src/buffer.rs b/rust/fory-core/src/buffer.rs index 94e6f97d57..25a9a391b7 100644 --- a/rust/fory-core/src/buffer.rs +++ b/rust/fory-core/src/buffer.rs @@ -574,7 +574,7 @@ impl<'a> Reader<'a> { } #[inline(always)] - fn check_bound(&self, n: usize) -> Result<(), Error> { + pub(crate) fn check_bound(&self, n: usize) -> Result<(), Error> { let end = self .cursor .checked_add(n) diff --git a/rust/fory-core/src/config.rs b/rust/fory-core/src/config.rs index b46588576d..f7152d7292 100644 --- a/rust/fory-core/src/config.rs +++ b/rust/fory-core/src/config.rs @@ -40,12 +40,6 @@ pub struct Config { /// When enabled, shared references and circular references are tracked /// and preserved during serialization/deserialization. pub track_ref: bool, - /// Maximum allowed size for binary data in bytes. - /// Prevents excessive memory allocation from untrusted payloads. - pub max_binary_size: u32, - /// Maximum allowed number of elements in a collection or entries in a map. 
- /// Prevents excessive memory allocation from untrusted payloads. - pub max_collection_size: u32, } impl Default for Config { @@ -59,8 +53,6 @@ impl Default for Config { max_dyn_depth: 5, check_struct_version: false, track_ref: false, - max_binary_size: 64 * 1024 * 1024, // 64MB default - max_collection_size: 1024 * 1024, // 1M elements default } } } @@ -118,16 +110,4 @@ impl Config { pub fn is_track_ref(&self) -> bool { self.track_ref } - - /// Get maximum allowed binary data size in bytes. - #[inline(always)] - pub fn max_binary_size(&self) -> u32 { - self.max_binary_size - } - - /// Get maximum allowed collection/map element count. - #[inline(always)] - pub fn max_collection_size(&self) -> u32 { - self.max_collection_size - } } diff --git a/rust/fory-core/src/context.rs b/rust/fory-core/src/context.rs index ce2db2c135..933c80ac38 100644 --- a/rust/fory-core/src/context.rs +++ b/rust/fory-core/src/context.rs @@ -358,8 +358,6 @@ pub struct ReadContext<'a> { max_dyn_depth: u32, check_struct_version: bool, check_string_read: bool, - max_binary_size: u32, - max_collection_size: u32, // Context-specific fields pub reader: Reader<'a>, @@ -388,8 +386,6 @@ impl<'a> ReadContext<'a> { max_dyn_depth: config.max_dyn_depth, check_struct_version: config.check_struct_version, check_string_read: config.check_string_read, - max_binary_size: config.max_binary_size, - max_collection_size: config.max_collection_size, reader: Reader::default(), meta_resolver: MetaReaderResolver::default(), meta_string_resolver: MetaStringReaderResolver::default(), @@ -440,18 +436,6 @@ impl<'a> ReadContext<'a> { self.max_dyn_depth } - /// Get maximum allowed binary data size in bytes. - #[inline(always)] - pub fn max_binary_size(&self) -> u32 { - self.max_binary_size - } - - /// Get maximum allowed collection/map element count. 
- #[inline(always)] - pub fn max_collection_size(&self) -> u32 { - self.max_collection_size - } - #[inline(always)] pub fn attach_reader(&mut self, reader: Reader<'a>) { self.reader = reader; diff --git a/rust/fory-core/src/error.rs b/rust/fory-core/src/error.rs index 9311a8a4a2..6d3d61990d 100644 --- a/rust/fory-core/src/error.rs +++ b/rust/fory-core/src/error.rs @@ -196,15 +196,6 @@ pub enum Error { /// Do not construct this variant directly; use [`Error::struct_version_mismatch`] instead. #[error("{0}")] StructVersionMismatch(Cow<'static, str>), - - /// Deserialization size limit exceeded. - /// - /// Returned when a payload-driven length exceeds a configured guardrail - /// (e.g. `max_binary_size` or `max_collection_size`). - /// - /// Do not construct this variant directly; use [`Error::size_limit_exceeded`] instead. - #[error("{0}")] - SizeLimitExceeded(Cow<'static, str>), } impl Error { @@ -504,27 +495,6 @@ impl Error { err } - /// Creates a new [`Error::SizeLimitExceeded`] from a string or static message. - /// - /// If `FORY_PANIC_ON_ERROR` environment variable is set, this will panic with the error message. - /// - /// # Example - /// ``` - /// use fory_core::error::Error; - /// - /// let err = Error::size_limit_exceeded("Collection size 2000000 exceeds limit 1048576"); - /// ``` - #[inline(always)] - #[cold] - #[track_caller] - pub fn size_limit_exceeded>>(s: S) -> Self { - let err = Error::SizeLimitExceeded(s.into()); - if PANIC_ON_ERROR { - panic!("FORY_PANIC_ON_ERROR: {}", err); - } - err - } - /// Enhances a [`Error::TypeError`] with additional type name information. /// /// If the error is a `TypeError`, appends the type name to the message. diff --git a/rust/fory-core/src/fory.rs b/rust/fory-core/src/fory.rs index 90ebc00fd9..f9885fd3d6 100644 --- a/rust/fory-core/src/fory.rs +++ b/rust/fory-core/src/fory.rs @@ -298,65 +298,6 @@ impl ForyBuilder { self } - /// Sets the maximum allowed size for binary data during deserialization. 
- /// - /// # Arguments - /// - /// * `max_binary_size` - The maximum number of bytes allowed for a single binary/primitive-array - /// payload during deserialization. Payloads exceeding this limit will cause a - /// `SizeLimitExceeded` error. - /// - /// # Returns - /// - /// Returns `self` for method chaining. - /// - /// # Default - /// - /// The default value is `64 * 1024 * 1024` (64 MB). - /// - /// # Examples - /// - /// ```rust - /// use fory_core::Fory; - /// - /// // Limit binary payloads to 1 MB - /// let fory = Fory::builder().max_binary_size(1024 * 1024).build(); - /// ``` - pub fn max_binary_size(mut self, max_binary_size: u32) -> Self { - self.config.max_binary_size = max_binary_size; - self - } - - /// Sets the maximum allowed number of elements in a collection or entries in a map - /// during deserialization. - /// - /// # Arguments - /// - /// * `max_collection_size` - The maximum number of elements/entries allowed for a single - /// collection or map during deserialization. Payloads exceeding this limit will cause a - /// `SizeLimitExceeded` error. - /// - /// # Returns - /// - /// Returns `self` for method chaining. - /// - /// # Default - /// - /// The default value is `1024 * 1024` (1 million elements). - /// - /// # Examples - /// - /// ```rust - /// use fory_core::Fory; - /// - /// // Limit collections to 10000 elements - /// let fory = Fory::builder().max_collection_size(10000).build(); - /// ``` - pub fn max_collection_size(mut self, max_collection_size: u32) -> Self { - self.config.max_collection_size = max_collection_size; - self - } - fn finish_config(self) -> Config { let mut config = self.config; if !self.compatible_set { @@ -495,16 +436,6 @@ impl Fory { self.config.max_dyn_depth } - /// Returns the maximum allowed binary data size in bytes. - pub fn get_max_binary_size(&self) -> u32 { - self.config.max_binary_size - } - - /// Returns the maximum allowed collection/map element count. 
- pub fn get_max_collection_size(&self) -> u32 { - self.config.max_collection_size - } - /// Returns whether class version checking is enabled. /// /// # Returns diff --git a/rust/fory-core/src/meta/type_meta.rs b/rust/fory-core/src/meta/type_meta.rs index a6c1013e63..5913bb2796 100644 --- a/rust/fory-core/src/meta/type_meta.rs +++ b/rust/fory-core/src/meta/type_meta.rs @@ -1180,6 +1180,7 @@ impl TypeMeta { type_name = empty_name; } + reader.check_bound(num_fields)?; let mut field_infos = Vec::with_capacity(num_fields); for _ in 0..num_fields { field_infos.push(FieldInfo::from_bytes(reader)?); diff --git a/rust/fory-core/src/serializer/codec.rs b/rust/fory-core/src/serializer/codec.rs index c0198aa66f..34059103f5 100644 --- a/rust/fory-core/src/serializer/codec.rs +++ b/rust/fory-core/src/serializer/codec.rs @@ -1518,6 +1518,13 @@ signed_int_codec!( pub struct VecCodec(PhantomData<(T, C)>); +#[inline(always)] +fn check_sequence_len(context: &ReadContext, len: u32) -> Result { + let len = len as usize; + context.reader.check_bound(len)?; + Ok(len) +} + #[inline(always)] fn read_vec_items( context: &mut ReadContext, @@ -1529,7 +1536,7 @@ where T: 'static, C: Codec, { - let mut vec = Vec::with_capacity(len as usize); + let mut vec = Vec::with_capacity(check_sequence_len(context, len)?); match read_type { None | Some(ElementReadType::Direct) => { if has_null { @@ -1696,13 +1703,6 @@ where if len == 0 { return Ok(Vec::new()); } - let max = context.max_collection_size(); - if len > max { - return Err(Error::size_limit_exceeded(format!( - "Collection size {} exceeds limit {}", - len, max - ))); - } let header = context.reader.read_u8()?; if C::is_polymorphic() || C::is_shared_ref() { let field_type = Self::field_type(context.get_type_resolver())?; @@ -1731,13 +1731,6 @@ where if len == 0 { return Ok(Vec::new()); } - let max = context.max_collection_size(); - if len > max { - return Err(Error::size_limit_exceeded(format!( - "Collection size {} exceeds limit {}", - len, max 
- ))); - } let header = context.reader.read_u8()?; let has_null = (header & HAS_NULL) != 0; let is_same_type = (header & IS_SAME_TYPE) != 0; @@ -2088,7 +2081,7 @@ where } else { RefMode::None }; - let mut vec = Vec::with_capacity(len as usize); + let mut vec = Vec::with_capacity(check_sequence_len(context, len)?); if is_same_type { if C::is_polymorphic() { if is_declared { @@ -2280,13 +2273,6 @@ where if len == 0 { return Ok(HashMap::new()); } - let max = context.max_collection_size(); - if len > max { - return Err(Error::size_limit_exceeded(format!( - "Map size {} exceeds limit {}", - len, max - ))); - } if KC::is_polymorphic() || KC::is_shared_ref() || VC::is_polymorphic() @@ -2306,13 +2292,6 @@ where if len == 0 { return Ok(HashMap::new()); } - let max = context.max_collection_size(); - if len > max { - return Err(Error::size_limit_exceeded(format!( - "Map size {} exceeds limit {}", - len, max - ))); - } if KC::is_polymorphic() || KC::is_shared_ref() || VC::is_polymorphic() @@ -2320,7 +2299,7 @@ where { return read_map_dynamic::(context, len, remote_field_type); } - let mut map = HashMap::with_capacity(len as usize); + let mut map = HashMap::with_capacity(check_map_len(context, len)?); let mut len_counter = 0; while len_counter < len { let header = context.reader.read_u8()?; @@ -2459,6 +2438,13 @@ struct MapEntryReadType { type_info: Option>, } +#[inline(always)] +fn check_map_len(context: &ReadContext, len: u32) -> Result { + let len = len as usize; + context.reader.check_bound(len)?; + Ok(len) +} + fn read_map_static( context: &mut ReadContext, len: u32, @@ -2469,7 +2455,7 @@ where KC: Codec, VC: Codec, { - let mut map = HashMap::with_capacity(len as usize); + let mut map = HashMap::with_capacity(check_map_len(context, len)?); let mut len_counter = 0u32; while len_counter < len { let header = context.reader.read_u8()?; @@ -2575,7 +2561,7 @@ where KC: Codec, VC: Codec, { - let mut map = HashMap::with_capacity(len as usize); + let mut map = 
HashMap::with_capacity(check_map_len(context, len)?); let mut len_counter = 0u32; while len_counter < len { let header = context.reader.read_u8()?; diff --git a/rust/fory-core/src/serializer/collection.rs b/rust/fory-core/src/serializer/collection.rs index 0e84016089..ee16166bb4 100644 --- a/rust/fory-core/src/serializer/collection.rs +++ b/rust/fory-core/src/serializer/collection.rs @@ -36,22 +36,10 @@ pub const DECL_ELEMENT_TYPE: u8 = 0b100; // Whether collection elements type same. pub const IS_SAME_TYPE: u8 = 0b1000; -#[cold] -fn collection_size_limit_exceeded(len: u32, max: u32) -> Error { - Error::size_limit_exceeded(format!("Collection size {} exceeds limit {}", len, max)) -} - -fn check_collection_len(context: &ReadContext, len: u32) -> Result<(), Error> { - if std::mem::size_of::() == 0 { - return Ok(()); - } +fn check_collection_len(context: &ReadContext, len: u32) -> Result { let len = len as usize; - let remaining = context.reader.slice_after_cursor().len(); - if len > remaining { - let cursor = context.reader.get_cursor(); - return Err(Error::buffer_out_of_bound(cursor, len, cursor + remaining)); - } - Ok(()) + context.reader.check_bound(len)?; + Ok(len) } pub fn write_collection_type_info( @@ -254,10 +242,6 @@ where if len == 0 { return Ok(C::from_iter(std::iter::empty())); } - let max = context.max_collection_size(); - if len > max { - return Err(collection_size_limit_exceeded(len, max)); - } if T::fory_is_polymorphic() || T::fory_is_shared_ref() { return read_collection_data_dyn_ref(context, len); } @@ -273,7 +257,7 @@ where (header & IS_SAME_TYPE) != 0, Error::type_error("Type inconsistent, target type is not polymorphic") ); - check_collection_len::(context, len)?; + let _ = check_collection_len(context, len)?; if !has_null { (0..len) .map(|_| T::fory_read_data(context)) @@ -300,10 +284,6 @@ where if len == 0 { return Ok(Vec::new()); } - let max = context.max_collection_size(); - if len > max { - return Err(collection_size_limit_exceeded(len, 
max)); - } if T::fory_is_polymorphic() || T::fory_is_shared_ref() { return read_vec_data_dyn_ref(context, len); } @@ -317,8 +297,7 @@ where (header & IS_SAME_TYPE) != 0, Error::type_error("Type inconsistent, target type is not polymorphic") ); - check_collection_len::(context, len)?; - let mut vec = Vec::with_capacity(len as usize); + let mut vec = Vec::with_capacity(check_collection_len(context, len)?); if !has_null { for _ in 0..len { vec.push(T::fory_read_data(context)?); @@ -364,8 +343,7 @@ where } else { T::fory_get_type_info(context.get_type_resolver())? }; - check_collection_len::(context, len)?; - let mut vec = Vec::with_capacity(len as usize); + let mut vec = Vec::with_capacity(check_collection_len(context, len)?); if elem_ref_mode == RefMode::None { for _ in 0..len { vec.push(T::fory_read_with_type_info( @@ -385,8 +363,7 @@ where } Ok(vec) } else { - check_collection_len::(context, len)?; - let mut vec = Vec::with_capacity(len as usize); + let mut vec = Vec::with_capacity(check_collection_len(context, len)?); for _ in 0..len { vec.push(T::fory_read(context, elem_ref_mode, true)?); } @@ -426,8 +403,8 @@ where } else { T::fory_get_type_info(context.get_type_resolver())? }; - check_collection_len::(context, len)?; // All elements are same type + let _ = check_collection_len(context, len)?; if elem_ref_mode == RefMode::None { // No null elements, no tracking (0..len) @@ -440,7 +417,7 @@ where .collect::>() } } else { - check_collection_len::(context, len)?; + let _ = check_collection_len(context, len)?; (0..len) .map(|_| T::fory_read(context, elem_ref_mode, true)) .collect::>() @@ -510,11 +487,13 @@ fn read_primitive_array_data_bulk( } #[cfg(target_endian = "little")] { - let mut vec: Vec = Vec::with_capacity(len); + // Prove the encoded primitive-array body exists before allocating from + // its declared byte length. 
let src = match context.reader.read_bytes(size_bytes) { Ok(src) => src, Err(error) => return Some(Err(error)), }; + let mut vec: Vec = Vec::with_capacity(len); unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), vec.as_mut_ptr() as *mut u8, size_bytes); vec.set_len(len); @@ -581,22 +560,7 @@ where if size_bytes % elem_size != 0 { return Err(Error::invalid_data("Invalid data length")); } - let max = context.max_binary_size() as usize; - if size_bytes > max { - return Err(Error::size_limit_exceeded(format!( - "Binary size {} exceeds limit {}", - size_bytes, max - ))); - } - let remaining = context.reader.slice_after_cursor().len(); - if size_bytes > remaining { - let cursor = context.reader.get_cursor(); - return Err(Error::buffer_out_of_bound( - cursor, - size_bytes, - cursor + remaining, - )); - } + context.reader.check_bound(size_bytes)?; let len = size_bytes / elem_size; let element_type_id = primitive_array_element_type_id(remote_field_type.type_id) .ok_or_else(|| Error::type_error("array-compatible field is not a primitive array"))?; @@ -763,13 +727,6 @@ where if len == 0 { return Ok(Vec::new()); } - let max = context.max_collection_size(); - if len > max { - return Err(Error::size_limit_exceeded(format!( - "Collection size {} exceeds limit {}", - len, max - ))); - } let header = context.reader.read_u8()?; if (header & HAS_NULL) != 0 { return Err(Error::type_error( @@ -791,6 +748,7 @@ where "array-compatible list must declare element type", )); } + context.reader.check_bound(len as usize)?; let mut vec = Vec::with_capacity(len as usize); for _ in 0..len { vec.push(T::read_list_array_element(context, element_type.type_id)?); diff --git a/rust/fory-core/src/serializer/map.rs b/rust/fory-core/src/serializer/map.rs index b467f37c4a..3d0dc094e7 100644 --- a/rust/fory-core/src/serializer/map.rs +++ b/rust/fory-core/src/serializer/map.rs @@ -35,19 +35,10 @@ const TRACKING_VALUE_REF: u8 = 0b1000; pub const VALUE_NULL: u8 = 0b10000; pub const DECL_VALUE_TYPE: u8 = 
0b100000;
-#[cold]
-fn map_size_limit_exceeded(len: u32, max: u32) -> Error {
-    Error::size_limit_exceeded(format!("Map size {} exceeds limit {}", len, max))
-}
-
-fn check_map_len(context: &ReadContext, len: u32) -> Result<(), Error> {
+fn check_map_len(context: &ReadContext, len: u32) -> Result<usize, Error> {
     let len = len as usize;
-    let remaining = context.reader.slice_after_cursor().len();
-    if len > remaining {
-        let cursor = context.reader.get_cursor();
-        return Err(Error::buffer_out_of_bound(cursor, len, cursor + remaining));
-    }
-    Ok(())
+    context.reader.check_bound(len)?;
+    Ok(len)
 }
 
 fn write_chunk_size(context: &mut WriteContext, header_offset: usize, size: u8) {
@@ -571,20 +562,16 @@ impl max {
-            return Err(map_size_limit_exceeded(len, max));
-        }
-        check_map_len(context, len)?;
+        let capacity = check_map_len(context, len)?;
         if K::fory_is_polymorphic()
             || K::fory_is_shared_ref()
             || V::fory_is_polymorphic()
             || V::fory_is_shared_ref()
         {
-            let map: HashMap<K, V> = HashMap::with_capacity(len as usize);
+            let map: HashMap<K, V> = HashMap::with_capacity(capacity);
             return read_hashmap_data_dyn_ref(context, map, len);
         }
-        let mut map = HashMap::<K, V>::with_capacity(len as usize);
+        let mut map = HashMap::<K, V>::with_capacity(capacity);
         let mut len_counter = 0;
         loop {
             if len_counter == len {
@@ -727,11 +714,7 @@ impl max {
-            return Err(map_size_limit_exceeded(len, max));
-        }
-        check_map_len(context, len)?;
+        let _ = check_map_len(context, len)?;
         let mut map = BTreeMap::<K, V>::new();
         if K::fory_is_polymorphic()
             || K::fory_is_shared_ref()
diff --git a/rust/fory-core/src/serializer/primitive_list.rs b/rust/fory-core/src/serializer/primitive_list.rs
index c1b927bd3a..f7bd6d39f3 100644
--- a/rust/fory-core/src/serializer/primitive_list.rs
+++ b/rust/fory-core/src/serializer/primitive_list.rs
@@ -22,11 +22,6 @@ use crate::error::Error;
 use crate::serializer::Serializer;
 use crate::type_id::TypeId;
 
-#[cold]
-fn binary_size_limit_exceeded(size_bytes: usize, max: usize) -> Error {
-    Error::size_limit_exceeded(format!("Binary size {} exceeds limit {}", size_bytes, max))
-}
-
 pub fn fory_write_data(this: &[T], context: &mut WriteContext) -> Result<(), Error> {
     // U128, USIZE, ISIZE, INT128 are Rust-specific and not supported in xlang mode
     if context.is_xlang() {
@@ -88,19 +83,7 @@ pub fn fory_read_data(context: &mut ReadContext) -> Result
     if size_bytes % std::mem::size_of::<T>() != 0 {
         return Err(Error::invalid_data("Invalid data length"));
     }
-    let max = context.max_binary_size() as usize;
-    if size_bytes > max {
-        return Err(binary_size_limit_exceeded(size_bytes, max));
-    }
-    let remaining = context.reader.slice_after_cursor().len();
-    if size_bytes > remaining {
-        let cursor = context.reader.get_cursor();
-        return Err(Error::buffer_out_of_bound(
-            cursor,
-            size_bytes,
-            cursor + remaining,
-        ));
-    }
+    context.reader.check_bound(size_bytes)?;
     let len = size_bytes / std::mem::size_of::<T>();
     let mut vec: Vec<T> = Vec::with_capacity(len);
diff --git a/rust/tests/tests/test_collection.rs b/rust/tests/tests/test_collection.rs
index 1e2524841a..6b9adeaa95 100644
--- a/rust/tests/tests/test_collection.rs
+++ b/rust/tests/tests/test_collection.rs
@@ -117,33 +117,3 @@ fn test_heap_container() {
     assert_eq!(deserialized.binary_heap.len(), 3);
     assert_eq!(deserialized.binary_heap.peek(), Some(&3));
 }
-
-#[test]
-fn test_hashset_max_collection_size_guardrail() {
-    let fory = Fory::builder().xlang(false).compatible(false).build();
-    let original = HashSet::from([
-        "apple".to_string(),
-        "banana".to_string(),
-        "cherry".to_string(),
-    ]);
-    let serialized = fory.serialize(&original).unwrap();
-
-    let limited_fory = Fory::builder()
-        .xlang(false)
-        .max_collection_size(2)
-        .compatible(false)
-        .build();
-    let err = limited_fory
-        .deserialize::<HashSet<String>>(&serialized)
-        .expect_err("expected collection size guardrail to reject the payload");
-
-    assert!(
-        matches!(err, fory_core::Error::SizeLimitExceeded(_)),
-        "expected SizeLimitExceeded, got: {err}"
-    );
-    assert!(
-        err.to_string()
-            .contains("Collection size 3 exceeds limit 2"),
-        "unexpected error message: {err}"
-    );
-}
diff --git a/rust/tests/tests/test_fory.rs b/rust/tests/tests/test_fory.rs
index b9f30ee3e4..8880b17d67 100644
--- a/rust/tests/tests/test_fory.rs
+++ b/rust/tests/tests/test_fory.rs
@@ -328,42 +328,3 @@ fn test_type_mismatch_error_shows_type_name() {
         err_str
     );
 }
-
-#[test]
-fn test_size_guardrail_configuration_accessors() {
-    let default_fory = Fory::builder().xlang(false).compatible(false).build();
-    assert_eq!(default_fory.get_max_binary_size(), 64 * 1024 * 1024);
-    assert_eq!(default_fory.get_max_collection_size(), 1024 * 1024);
-
-    let configured_fory = Fory::builder()
-        .xlang(false)
-        .max_binary_size(4096)
-        .max_collection_size(128)
-        .compatible(false)
-        .build();
-    assert_eq!(configured_fory.get_max_binary_size(), 4096);
-    assert_eq!(configured_fory.get_max_collection_size(), 128);
-}
-
-#[test]
-fn test_max_binary_size_does_not_limit_string_reads() {
-    let fory = Fory::builder().xlang(false).compatible(false).build();
-    let original = "this string should not be treated as binary".repeat(4);
-    let serialized = fory.serialize(&original).unwrap();
-
-    let limited_fory = Fory::builder()
-        .xlang(false)
-        .max_binary_size(4)
-        .compatible(false)
-        .build();
-    let deserialized: String = limited_fory.deserialize(&serialized).unwrap();
-
-    assert_eq!(deserialized, original);
-}
-
-#[test]
-fn test_size_limit_exceeded_error_display() {
-    let err = Error::size_limit_exceeded("Collection size 3 exceeds limit 2");
-    assert!(matches!(err, Error::SizeLimitExceeded(_)));
-    assert_eq!(err.to_string(), "Collection size 3 exceeds limit 2");
-}
diff --git a/rust/tests/tests/test_list.rs b/rust/tests/tests/test_list.rs
index 7cf7ca1a31..8ade50987d 100644
--- a/rust/tests/tests/test_list.rs
+++ b/rust/tests/tests/test_list.rs
@@ -247,29 +247,3 @@ fn test_vec_bfloat16_special_values() {
     assert_eq!(obj[3].to_bits(), bfloat16::MAX.to_bits());
     assert!(obj[5].is_subnormal());
 }
-
-#[test]
-fn test_vec_max_collection_size_guardrail() {
-    let fory = Fory::builder().xlang(false).compatible(false).build();
-    let original = vec!["alpha".to_string(), "beta".to_string(), "gamma".to_string()];
-    let serialized = fory.serialize(&original).unwrap();
-
-    let limited_fory = Fory::builder()
-        .xlang(false)
-        .max_collection_size(2)
-        .compatible(false)
-        .build();
-    let err = limited_fory
-        .deserialize::<Vec<String>>(&serialized)
-        .expect_err("expected vec deserialization to fail on max_collection_size");
-
-    assert!(
-        matches!(err, fory_core::Error::SizeLimitExceeded(_)),
-        "expected SizeLimitExceeded, got: {err}"
-    );
-    assert!(
-        err.to_string()
-            .contains("Collection size 3 exceeds limit 2"),
-        "unexpected error message: {err}"
-    );
-}
diff --git a/rust/tests/tests/test_map.rs b/rust/tests/tests/test_map.rs
index bd8344a3ea..ee057d1645 100644
--- a/rust/tests/tests/test_map.rs
+++ b/rust/tests/tests/test_map.rs
@@ -67,61 +67,3 @@ fn test_struct_with_maps() {
     let obj: MapContainer = fory.deserialize(&bin).expect("deserialize");
     assert_eq!(container, obj);
 }
-
-#[test]
-fn test_hashmap_max_collection_size_guardrail() {
-    let fory = Fory::builder().xlang(false).compatible(false).build();
-    let map = HashMap::from([
-        ("key1".to_string(), 1_i32),
-        ("key2".to_string(), 2_i32),
-        ("key3".to_string(), 3_i32),
-    ]);
-    let serialized = fory.serialize(&map).unwrap();
-
-    let limited_fory = Fory::builder()
-        .xlang(false)
-        .max_collection_size(2)
-        .compatible(false)
-        .build();
-    let err = limited_fory
-        .deserialize::<HashMap<String, i32>>(&serialized)
-        .expect_err("expected hashmap deserialization to fail on max_collection_size");
-
-    assert!(
-        matches!(err, fory_core::Error::SizeLimitExceeded(_)),
-        "expected SizeLimitExceeded, got: {err}"
-    );
-    assert!(
-        err.to_string().contains("Map size 3 exceeds limit 2"),
-        "unexpected error message: {err}"
-    );
-}
-
-#[test]
-fn test_btreemap_max_collection_size_guardrail() {
-    let fory = Fory::builder().xlang(false).compatible(false).build();
-    let map = BTreeMap::from([
-        ("key1".to_string(), 1_i32),
-        ("key2".to_string(), 2_i32),
-        ("key3".to_string(), 3_i32),
-    ]);
-    let serialized = fory.serialize(&map).unwrap();
-
-    let limited_fory = Fory::builder()
-        .xlang(false)
-        .max_collection_size(2)
-        .compatible(false)
-        .build();
-    let err = limited_fory
-        .deserialize::<BTreeMap<String, i32>>(&serialized)
-        .expect_err("expected btreemap deserialization to fail on max_collection_size");
-
-    assert!(
-        matches!(err, fory_core::Error::SizeLimitExceeded(_)),
-        "expected SizeLimitExceeded, got: {err}"
-    );
-    assert!(
-        err.to_string().contains("Map size 3 exceeds limit 2"),
-        "unexpected error message: {err}"
-    );
-}
diff --git a/rust/tests/tests/test_unsigned.rs b/rust/tests/tests/test_unsigned.rs
index dbf9cbbbc5..249ed868c9 100644
--- a/rust/tests/tests/test_unsigned.rs
+++ b/rust/tests/tests/test_unsigned.rs
@@ -129,56 +129,6 @@ fn test_binary_when_xlang() {
     assert_eq!(data, result);
 }
 
-#[test]
-fn test_binary_max_size_guardrail_for_vec_u8() {
-    let fory = Fory::builder().xlang(false).compatible(false).build();
-    let original = vec![1_u8, 2, 3, 4, 5];
-    let serialized = fory.serialize(&original).unwrap();
-
-    let limited_fory = Fory::builder()
-        .xlang(false)
-        .max_binary_size(4)
-        .compatible(false)
-        .build();
-    let err = limited_fory
-        .deserialize::<Vec<u8>>(&serialized)
-        .expect_err("expected binary size guardrail to reject the payload");
-
-    assert!(
-        matches!(err, fory_core::Error::SizeLimitExceeded(_)),
-        "expected SizeLimitExceeded, got: {err}"
-    );
-    assert!(
-        err.to_string().contains("Binary size 5 exceeds limit 4"),
-        "unexpected error message: {err}"
-    );
-}
-
-#[test]
-fn test_binary_max_size_guardrail_for_vec_u32() {
-    let fory = Fory::builder().xlang(false).compatible(false).build();
-    let original = vec![10_u32, 20, 30];
-    let serialized = fory.serialize(&original).unwrap();
-
-    let limited_fory = Fory::builder()
-        .xlang(false)
-        .max_binary_size(8)
-        .compatible(false)
-        .build();
-    let err = limited_fory
-        .deserialize::<Vec<u32>>(&serialized)
-        .expect_err("expected primitive array size guardrail to reject the payload");
-
-    assert!(
-        matches!(err, fory_core::Error::SizeLimitExceeded(_)),
-        "expected SizeLimitExceeded, got: {err}"
-    );
-    assert!(
-        err.to_string().contains("Binary size 12 exceeds limit 8"),
-        "unexpected error message: {err}"
-    );
-}
-
 #[test]
 fn test_unsigned_struct_non_compatible() {
     #[derive(ForyStruct, Debug, PartialEq)]
diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala
index caba2faf02..066e24c629 100644
--- a/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala
+++ b/scala/src/main/scala/org/apache/fory/serializer/scala/CollectionSerializer.scala
@@ -55,9 +55,13 @@ abstract class AbstractScalaCollectionSerializer[A, T <: Iterable[A]](
   override def newCollection(readContext: ReadContext): util.Collection[_] = {
     val buffer = readContext.getBuffer
     val numElements = buffer.readVarUInt32()
+    checkCollectionSize(numElements)
     setNumElements(numElements)
     val factory = readContext.readRef().asInstanceOf[Factory[A, T]]
     val builder = factory.newBuilder
+    if (numElements != 0) {
+      buffer.checkReadableBytes(numElements)
+    }
     builder.sizeHint(numElements)
     new JavaCollectionBuilder[A, T](builder)
   }
diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala
index 8ba41f535b..9c21954b7d 100644
--- a/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala
+++ b/scala/src/main/scala/org/apache/fory/serializer/scala/MapSerializer.scala
@@ -52,9 +52,13 @@ abstract class AbstractScalaMapSerializer[K, V, T](typeResolver: TypeResolver, c
   override def newMap(readContext: ReadContext): util.Map[_, _] = {
     val buffer = readContext.getBuffer
     val numElements = buffer.readVarUInt32()
+    checkMapSize(numElements)
     setNumElements(numElements)
     val factory = readContext.readRef().asInstanceOf[Factory[(K, V), T]]
     val builder = factory.newBuilder
+    if (numElements != 0) {
+      buffer.checkReadableBytes(numElements)
+    }
     builder.sizeHint(numElements)
     new MapBuilder[K, V, T](builder)
   }
diff --git a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala
index 49edf89d52..9eeab286d2 100644
--- a/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala
+++ b/scala/src/main/scala/org/apache/fory/serializer/scala/XlangCollectionSerializer.scala
@@ -43,8 +43,12 @@ abstract class AbstractScalaXlangCollectionSerializer[A, T <: scala.collection.I
   }
 
   override def newCollection(readContext: ReadContext): util.Collection[_] = {
-    val numElements = readCollectionSize(readContext.getBuffer)
+    val buffer = readContext.getBuffer
+    val numElements = readCollectionSize(buffer)
     setNumElements(numElements)
+    if (numElements != 0) {
+      buffer.checkReadableBytes(numElements)
+    }
     val builder = newBuilder(numElements)
     if (ScalaXlangCollectionShape.hasOptionElement(readContext)) {
       new XlangOptionCollectionBuilder[A, T](builder)
@@ -364,9 +368,14 @@ abstract class AbstractScalaXlangMapSerializer[K, V, T <: scala.collection.Map[K
   }
 
   override def newMap(readContext: ReadContext): util.Map[_, _] = {
-    val numElements = readMapSize(readContext.getBuffer)
+    val buffer = readContext.getBuffer
+    val numElements = readMapSize(buffer)
     setNumElements(numElements)
-    val builder = ScalaXlangCollectionShape.mapBuilder[K, V, T](cls, numElements)
+    if (numElements != 0) {
+      buffer.checkReadableBytes(numElements)
+    }
+    val builder =
+      ScalaXlangCollectionShape.mapBuilder[K, V, T](cls, numElements)
     val optionKey = ScalaXlangCollectionShape.hasOptionKey(readContext)
     val optionValue = ScalaXlangCollectionShape.hasOptionValue(readContext)
     if (optionKey || optionValue) {
diff --git a/swift/Sources/Fory/ByteBuffer.swift b/swift/Sources/Fory/ByteBuffer.swift
index 0ddc18028c..fac208ab80 100644
--- a/swift/Sources/Fory/ByteBuffer.swift
+++ b/swift/Sources/Fory/ByteBuffer.swift
@@ -461,7 +461,7 @@ public final class ByteBuffer {
     @inline(__always)
     public func checkBound(_ need: Int) throws {
         let length = readableCount
-        if cursor + need > length {
+        if need < 0 || cursor > length || need > length - cursor {
             throw ForyError.outOfBounds(cursor: cursor, need: need, length: length)
         }
     }
@@ -885,11 +885,20 @@ public final class ByteBuffer {
 
     @inlinable
     public func readBytes(count: Int) throws -> [UInt8] {
+        try checkBound(count)
         if count == 0 {
             return []
         }
-        return try [UInt8](unsafeUninitializedCapacity: count) { destination, initializedCount in
-            try readBytes(into: UnsafeMutableRawBufferPointer(destination))
+        return [UInt8](unsafeUninitializedCapacity: count) { destination, initializedCount in
+            withUnsafeReadableBytes { rawBytes in
+                let sourceBase = rawBytes.baseAddress!
+                let destinationBase = destination.baseAddress!
+                UnsafeMutableRawPointer(destinationBase).copyMemory(
+                    from: sourceBase.advanced(by: cursor),
+                    byteCount: count
+                )
+            }
+            cursor += count
             initializedCount = count
         }
     }
diff --git a/swift/Sources/Fory/CollectionSerializers.swift b/swift/Sources/Fory/CollectionSerializers.swift
index df7e33889a..1be59fb6b4 100644
--- a/swift/Sources/Fory/CollectionSerializers.swift
+++ b/swift/Sources/Fory/CollectionSerializers.swift
@@ -235,19 +235,19 @@ func writePrimitiveArray(_ value: [Element], context: Write
 }
 
 func readPrimitiveArray(_ context: ReadContext) throws -> [Element] {
-    let payloadSize = Int(try context.buffer.readVarUInt32())
-    try context.ensureRemainingBytes(payloadSize, label: "primitive_array_payload")
+    let byteSize = Int(try context.buffer.readVarUInt32())
+    try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes")
     if Element.self == UInt8.self {
-        try context.ensureCollectionLength(payloadSize, label: "uint8_array")
-        let bytes = try context.buffer.readBytes(count: payloadSize)
+        try context.ensureCollectionLength(byteSize, label: "uint8_array")
+        let bytes = try context.buffer.readBytes(count: byteSize)
         return uncheckedArrayCast(bytes, to: Element.self)
     }
 
     if Element.self == Bool.self {
-        try context.ensureCollectionLength(payloadSize, label: "bool_array")
-        let out = try readArrayUninitialized(count: payloadSize) { destination in
-            for index in 0..(_ context: ReadContext) throws -> [
     }
 
     if Element.self == Int8.self {
-        try context.ensureCollectionLength(payloadSize, label: "int8_array")
-        var out = Array(repeating: Int8(0), count: payloadSize)
+        try context.ensureCollectionLength(byteSize, label: "int8_array")
+        var out = Array(repeating: Int8(0), count: byteSize)
         try out.withUnsafeMutableBytes { rawBytes in
             try context.buffer.readBytes(into: rawBytes)
         }
@@ -264,8 +264,8 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [
     }
 
     if Element.self == Int16.self {
-        if payloadSize % 2 != 0 { throw ForyError.invalidData("int16 array payload size mismatch") }
-        let count = payloadSize / 2
+        if byteSize % 2 != 0 { throw ForyError.invalidData("int16 array byte size mismatch") }
+        let count = byteSize / 2
         try context.ensureCollectionLength(count, label: "int16_array")
         if hostIsLittleEndian {
             var out = Array(repeating: Int16(0), count: count)
@@ -283,8 +283,8 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [
     }
 
     if Element.self == Int32.self {
-        if payloadSize % 4 != 0 { throw ForyError.invalidData("int32 array payload size mismatch") }
-        let count = payloadSize / 4
+        if byteSize % 4 != 0 { throw ForyError.invalidData("int32 array byte size mismatch") }
+        let count = byteSize / 4
         try context.ensureCollectionLength(count, label: "int32_array")
         if hostIsLittleEndian {
             var out = Array(repeating: Int32(0), count: count)
@@ -302,8 +302,8 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [
     }
 
     if Element.self == UInt32.self {
-        if payloadSize % 4 != 0 { throw ForyError.invalidData("uint32 array payload size mismatch") }
-        let count = payloadSize / 4
+        if byteSize % 4 != 0 { throw ForyError.invalidData("uint32 array byte size mismatch") }
+        let count = byteSize / 4
         try context.ensureCollectionLength(count, label: "uint32_array")
         if hostIsLittleEndian {
             var out = Array(repeating: UInt32(0), count: count)
@@ -321,8 +321,8 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [
     }
 
     if Element.self == Int64.self {
-        if payloadSize % 8 != 0 { throw ForyError.invalidData("int64 array payload size mismatch") }
-        let count = payloadSize / 8
+        if byteSize % 8 != 0 { throw ForyError.invalidData("int64 array byte size mismatch") }
+        let count = byteSize / 8
         try context.ensureCollectionLength(count, label: "int64_array")
         if hostIsLittleEndian {
             var out = Array(repeating: Int64(0), count: count)
@@ -340,8 +340,8 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [
     }
 
     if Element.self == UInt64.self {
-        if payloadSize % 8 != 0 { throw ForyError.invalidData("uint64 array payload size mismatch") }
-        let count = payloadSize / 8
+        if byteSize % 8 != 0 { throw ForyError.invalidData("uint64 array byte size mismatch") }
+        let count = byteSize / 8
         try context.ensureCollectionLength(count, label: "uint64_array")
         if hostIsLittleEndian {
             var out = Array(repeating: UInt64(0), count: count)
@@ -359,8 +359,8 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [
     }
 
     if Element.self == UInt16.self {
-        if payloadSize % 2 != 0 { throw ForyError.invalidData("uint16 array payload size mismatch") }
-        let count = payloadSize / 2
+        if byteSize % 2 != 0 { throw ForyError.invalidData("uint16 array byte size mismatch") }
+        let count = byteSize / 2
         try context.ensureCollectionLength(count, label: "uint16_array")
         if hostIsLittleEndian {
             var out = Array(repeating: UInt16(0), count: count)
@@ -378,8 +378,8 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [
     }
 
     if Element.self == Float16.self {
-        if payloadSize % 2 != 0 { throw ForyError.invalidData("float16 array payload size mismatch") }
-        let count = payloadSize / 2
+        if byteSize % 2 != 0 { throw ForyError.invalidData("float16 array byte size mismatch") }
+        let count = byteSize / 2
         try context.ensureCollectionLength(count, label: "float16_array")
         let values = try readArrayUninitialized(count: count) { destination in
             for index in 0..(_ context: ReadContext) throws -> [
     }
 
     if Element.self == BFloat16.self {
-        if payloadSize % 2 != 0 { throw ForyError.invalidData("bfloat16 array payload size mismatch") }
-        let count = payloadSize / 2
+        if byteSize % 2 != 0 { throw ForyError.invalidData("bfloat16 array byte size mismatch") }
+        let count = byteSize / 2
         try context.ensureCollectionLength(count, label: "bfloat16_array")
         let values = try readArrayUninitialized(count: count) { destination in
             for index in 0..(_ context: ReadContext) throws -> [
     }
 
     if Element.self == Float.self {
-        if payloadSize % 4 != 0 { throw ForyError.invalidData("float32 array payload size mismatch") }
-        let count = payloadSize / 4
+        if byteSize % 4 != 0 { throw ForyError.invalidData("float32 array byte size mismatch") }
+        let count = byteSize / 4
         try context.ensureCollectionLength(count, label: "float32_array")
         if hostIsLittleEndian {
             var out = Array(repeating: Float(0), count: count)
@@ -420,8 +420,8 @@ func readPrimitiveArray(_ context: ReadContext) throws -> [
         return uncheckedArrayCast(out, to: Element.self)
     }
 
-    if payloadSize % 8 != 0 { throw ForyError.invalidData("float64 array payload size mismatch") }
-    let count = payloadSize / 8
+    if byteSize % 8 != 0 { throw ForyError.invalidData("float64 array byte size mismatch") }
+    let count = byteSize / 8
     try context.ensureCollectionLength(count, label: "float64_array")
     if hostIsLittleEndian {
         var out = Array(repeating: Double(0), count: count)
@@ -541,6 +541,7 @@ extension Array: Serializer where Element: Serializer {
         let declared = (header & CollectionHeader.declaredElementType) != 0
         let sameType = (header & CollectionHeader.sameType) != 0
         if !sameType {
+            try context.ensureRemainingBytes(length, label: "array")
             if trackRef {
                 return try readArrayUninitialized(count: length) { destination in
                     for index in 0..
0 else {
             throw ForyError.invalidData("invalid decimal magnitude length \(length)")
         }
-        let payload = try context.buffer.readBytes(count: length)
-        guard payload[length - 1] != 0 else {
-            throw ForyError.invalidData("non-canonical decimal payload: trailing zero byte")
+        let magnitudeBytes = try context.buffer.readBytes(count: length)
+        guard magnitudeBytes[length - 1] != 0 else {
+            throw ForyError.invalidData("non-canonical decimal magnitude bytes: trailing zero byte")
         }
-        let normalized = normalizeDecimalMagnitude(payload)
+        let normalized = normalizeDecimalMagnitude(magnitudeBytes)
         guard !normalized.isEmpty else {
             throw ForyError.invalidData("big decimal encoding must not represent zero")
         }
diff --git a/swift/Sources/Fory/FieldCodecs.swift b/swift/Sources/Fory/FieldCodecs.swift
index 96e0babd1b..b9e8825967 100644
--- a/swift/Sources/Fory/FieldCodecs.swift
+++ b/swift/Sources/Fory/FieldCodecs.swift
@@ -963,7 +963,8 @@ where KeyCodec.Value: Hashable {
         }
 
         var map: Value = [:]
-        map.reserveCapacity(Swift.min(totalLength, context.buffer.remaining))
+        try context.ensureRemainingBytes(totalLength, label: "map")
+        map.reserveCapacity(totalLength)
         var readCount = 0
         while readCount < totalLength {
             let header = try context.buffer.readUInt8()
@@ -1511,12 +1512,12 @@ private func readPackedArrayElementCount(
     width: Int,
     label: String
 ) throws -> Int {
-    let payloadSize = Int(try context.buffer.readVarUInt32())
-    try context.ensureRemainingBytes(payloadSize, label: "primitive_array_payload")
-    if payloadSize % width != 0 {
-        throw ForyError.invalidData("\(label) payload size mismatch")
+    let byteSize = Int(try context.buffer.readVarUInt32())
+    try context.ensureRemainingBytes(byteSize, label: "primitive_array_bytes")
+    if byteSize % width != 0 {
+        throw ForyError.invalidData("\(label) byte size mismatch")
     }
-    let count = payloadSize / width
+    let count = byteSize / width
     try context.ensureCollectionLength(count, label: label)
     return count
 }
@@ -1603,6 +1604,7 @@ private func readCollectionPayload(
     let sameType = (header & CollectionHeader.sameType) != 0
 
     var result: [ElementCodec.Value] = []
+    try context.ensureRemainingBytes(length, label: "array")
     result.reserveCapacity(length)
 
     if !sameType {
@@ -1698,9 +1700,6 @@ private func readListPayloadAsArrayPayload(
     let declared = (header & CollectionHeader.declaredElementType) != 0
     let sameType = (header & CollectionHeader.sameType) != 0
 
-    var result: [ElementCodec.Value] = []
-    result.reserveCapacity(length)
-
     if !sameType {
         throw ForyError.invalidData("compatible list-to-array field requires same-type elements")
     }
@@ -1714,6 +1713,9 @@ private func readListPayloadAsArrayPayload(
     } else {
         throw ForyError.invalidData("compatible list-to-array field requires declared elements")
     }
+    try context.ensureRemainingBytes(length, label: "array")
+    var result: [ElementCodec.Value] = []
+    result.reserveCapacity(length)
     return try ElementCodec.withTypeInfo(elementTypeInfo, context) {
         for _ in 0.. Data {
         let length = try context.buffer.readVarUInt32()
         let byteLength = Int(length)
-        try context.ensureBinaryLength(byteLength, label: "binary")
         try context.ensureRemainingBytes(byteLength, label: "binary")
         let bytes = try context.buffer.readBytes(count: byteLength)
         return Data(bytes)
diff --git a/swift/Sources/Fory/ReadContext.swift b/swift/Sources/Fory/ReadContext.swift
index 664916a0fc..99c089be61 100644
--- a/swift/Sources/Fory/ReadContext.swift
+++ b/swift/Sources/Fory/ReadContext.swift
@@ -25,8 +25,6 @@ public final class ReadContext {
     public let trackRef: Bool
     public let compatible: Bool
     public let checkClassVersion: Bool
-    public let maxCollectionSize: Int
-    public let maxBinarySize: Int
     public let maxDepth: Int
     public let refReader: RefReader
     private let compatibleTypeDefTypeInfos = ReusableArray(defaultValue: nil, reserve: 2)
@@ -43,8 +41,6 @@ public final class ReadContext {
         trackRef: Bool,
         compatible: Bool = false,
         checkClassVersion: Bool = true,
-        maxCollectionSize: Int = 1_000_000,
-        maxBinarySize: Int = 64 * 1024 * 1024,
         maxDepth: Int = 5
     ) {
         self.buffer = buffer
@@ -52,8 +48,6 @@ public final class ReadContext {
         self.trackRef = trackRef
         self.compatible = compatible
         self.checkClassVersion = checkClassVersion
-        self.maxCollectionSize = maxCollectionSize
-        self.maxBinarySize = maxBinarySize
         self.maxDepth = maxDepth
         self.refReader = RefReader()
     }
@@ -84,23 +78,6 @@ public final class ReadContext {
         if length < 0 {
             throw ForyError.invalidData("\(label) length is negative")
         }
-        if length > maxCollectionSize {
-            throw ForyError.invalidData(
-                "\(label) length \(length) exceeds configured maxCollectionSize \(maxCollectionSize)"
-            )
-        }
-    }
-
-    @inline(__always)
-    func ensureBinaryLength(_ length: Int, label: String) throws {
-        if length < 0 {
-            throw ForyError.invalidData("\(label) size is negative")
-        }
-        if length > maxBinarySize {
-            throw ForyError.invalidData(
-                "\(label) size \(length) exceeds configured maxBinarySize \(maxBinarySize)"
-            )
-        }
     }
 
     @inline(__always)
@@ -582,9 +559,7 @@ public final class ReadContext {
         if dynamicAnyDepth != 0 {
             dynamicAnyDepth = 0
         }
-        if trackRef {
-            refReader.reset()
-        }
+        refReader.reset()
         if !typeInfoStack.isEmpty {
             typeInfoStack.clear()
         }
diff --git a/swift/Tests/ForyTests/CollectionSerializerTests.swift b/swift/Tests/ForyTests/CollectionSerializerTests.swift
index 7f8e4b8f7b..0945e82939 100644
--- a/swift/Tests/ForyTests/CollectionSerializerTests.swift
+++ b/swift/Tests/ForyTests/CollectionSerializerTests.swift
@@ -461,7 +461,7 @@ func collectionSerializersRejectMalformedPrimitivePayloads() throws {
         let _: [Int16] = try ArrayFieldCodec.readPayload(int16Context)
         #expect(Bool(false))
     } catch {
-        #expect("\(error)".contains("payload size mismatch"))
+        #expect("\(error)".contains("byte size mismatch"))
     }
 
     let float64Buffer = ByteBuffer()
@@ -476,6 +476,6 @@ func collectionSerializersRejectMalformedPrimitivePayloads() throws {
         let _: [Double] = try ArrayFieldCodec.readPayload(float64Context)
         #expect(Bool(false))
     } catch {
-        #expect("\(error)".contains("payload size mismatch"))
+        #expect("\(error)".contains("byte size mismatch"))
     }
 }
diff --git a/swift/Tests/ForyTests/DateTimeTests.swift b/swift/Tests/ForyTests/DateTimeTests.swift
index 7cc493fde1..0f8bba02a5 100644
--- a/swift/Tests/ForyTests/DateTimeTests.swift
+++ b/swift/Tests/ForyTests/DateTimeTests.swift
@@ -112,8 +112,6 @@ func dateAndTimestampContextHelpersUseExpectedWireProtocols() throws {
         trackRef: false,
         compatible: true,
         checkClassVersion: true,
-        maxCollectionSize: 1_000_000,
-        maxBinarySize: 64 * 1024 * 1024,
         maxDepth: 5
     )
     let xlangLocalDateDecoded = try xlangReadContext.readLocalDate(refMode: RefMode.nullOnly, readTypeInfo: true)
@@ -137,8 +135,6 @@ func dateAndTimestampContextHelpersUseExpectedWireProtocols() throws {
         trackRef: false,
         compatible: true,
         checkClassVersion: true,
-        maxCollectionSize: 1_000_000,
-        maxBinarySize: 64 * 1024 * 1024,
         maxDepth: 5
     )
     let timestampDecoded = try timestampReadContext.readTimestamp(refMode: RefMode.nullOnly, readTypeInfo: true)
diff --git a/swift/Tests/ForyTests/ForySwiftTests.swift b/swift/Tests/ForyTests/ForySwiftTests.swift
index 9d01eaecb1..0aac15035f 100644
--- a/swift/Tests/ForyTests/ForySwiftTests.swift
+++ b/swift/Tests/ForyTests/ForySwiftTests.swift
@@ -395,38 +395,6 @@ func structEvolvingOverrideUsesSmallerCompatiblePayload() throws {
     #expect(decodedFixed == fixed)
 }
 
-@Test
-func decodeLimitsRejectOversizedPayloads() throws {
-    let writer = Fory()
-
-    let oversizedCollection = try writer.serialize(["a", "b", "c"])
-    let collectionLimited = Fory(config: .init(maxCollectionSize: 2))
-    do {
-        let _: [String] = try collectionLimited.deserialize(oversizedCollection)
-        #expect(Bool(false))
-    } catch {}
-
-    let oversizedMap = try writer.serialize([Int32(1): Int32(1), 2: 2, 3: 3])
-    do {
-        let _: [Int32: Int32] = try collectionLimited.deserialize(oversizedMap)
-        #expect(Bool(false))
-    } catch {}
-
-    let oversizedBinary = try writer.serialize(Data([0x01, 0x02, 0x03, 0x04]))
-    let binaryLimited = Fory(config: .init(maxBinarySize: 3))
-    do {
-        let _: Data = try binaryLimited.deserialize(oversizedBinary)
-        #expect(Bool(false))
-    } catch {}
-
-    let oversizedArrayPayload = try writer.serialize([UInt16(1), 2])
-    let payloadLimited = Fory(config: .init(maxCollectionSize: 1))
-    do {
-        let _: [UInt16] = try payloadLimited.deserialize(oversizedArrayPayload)
-        #expect(Bool(false))
-    } catch {}
-}
-
 @Test
 func deserializeRejectsTrailingBytes() throws {
     let fory = Fory()