diff --git a/include/libserial/serial.hpp b/include/libserial/serial.hpp index 5483122..6b615c4 100644 --- a/include/libserial/serial.hpp +++ b/include/libserial/serial.hpp @@ -302,6 +302,26 @@ void setMinNumberCharRead(uint16_t); */ void setBaudRate(BaudRate baud_rate); +/** + * @brief Sets the maximum safe read size + * + * Configures the maximum number of bytes that can be read + * in a single read operation to prevent excessive memory usage. + * + * @param size The desired maximum safe read size in bytes + */ +void setMaxSafeReadSize(size_t size); + +/** + * @brief Gets the maximum safe read size + * + * Retrieves the maximum number of bytes that can be read + * in a single read operation to prevent excessive memory usage. + * + * @return The maximum safe read size in bytes + */ +size_t getMaxSafeReadSize() const; + /** * @brief Gets the current baud rate * @@ -318,8 +338,9 @@ int getBaudRate() const; void setFdForTest(int fd) { fd_serial_port_ = fd; } - -// For testing - allow injection of mock functions +// WARNING: Test helper only! This function allows injection of custom +// system call functions for testing error handling. It should NEVER be +// used in production code. void setSystemCallFunctions( std::function poll_func, std::function read_func) { @@ -410,8 +431,9 @@ std::chrono::milliseconds write_timeout_ms_{1000}; ///< Write timeout in mill * * Defines the maximum number of bytes that can be read * in a single read operation to prevent excessive memory usage. + * Default is 2048 bytes (2KB). */ -static constexpr size_t kMaxSafeReadSize = 2048; // 2KB limit +size_t max_safe_read_size_{2048}; // 2KB limit /** * @brief Timeout value in milliseconds diff --git a/src/serial.cpp b/src/serial.cpp index ff70861..7fc2600 100644 --- a/src/serial.cpp +++ b/src/serial.cpp @@ -65,7 +65,7 @@ size_t Serial::read(std::shared_ptr buffer) { } buffer->clear(); - buffer->resize(kMaxSafeReadSize); + buffer->resize(max_safe_read_size_); struct pollfd fd_poll; fd_poll.fd = fd_serial_port_; @@ -83,7 +83,8 @@ size_t Serial::read(std::shared_ptr buffer) { } // Data available: do the read - ssize_t bytes_read = read_(fd_serial_port_, const_cast(buffer->data()), kMaxSafeReadSize); + ssize_t bytes_read = read_(fd_serial_port_, const_cast(buffer->data()), + max_safe_read_size_); if (bytes_read < 0) { throw IOException(std::string("Error reading from serial port: ") + strerror(errno)); } @@ -108,7 +109,7 @@ size_t Serial::readBytes(std::shared_ptr buffer, size_t num_bytes) buffer->clear(); buffer->resize(num_bytes); - ssize_t bytes_read = ::read(fd_serial_port_, buffer->data(), num_bytes); // codacy-ignore[buffer-boundary] + ssize_t bytes_read = read_(fd_serial_port_, buffer->data(), num_bytes); // codacy-ignore[buffer-boundary] if (bytes_read < 0) { throw IOException("Error reading from serial port: " + std::string(strerror(errno))); @@ -130,9 +131,9 @@ size_t Serial::readUntil(std::shared_ptr buffer, char terminator) { while (temp_char != terminator) { // Check buffer size limit to prevent excessive memory usage - if (buffer->size() >= kMaxSafeReadSize) { + if (buffer->size() >= max_safe_read_size_) { throw IOException("Read buffer exceeded maximum size limit of " + - std::to_string(kMaxSafeReadSize) + + std::to_string(max_safe_read_size_) + " bytes without finding terminator"); } // Check timeout if enabled (0 means no timeout) @@ -344,6 +345,14 @@ void Serial::setMinNumberCharRead(uint16_t num) { this->setTermios2(); } +void Serial::setMaxSafeReadSize(size_t size) { + max_safe_read_size_ = size; +} + +size_t Serial::getMaxSafeReadSize() const { + return max_safe_read_size_; +} + int Serial::getAvailableData() const { int bytes_available; if (ioctl(fd_serial_port_, FIONREAD, &bytes_available) < 0) { diff --git a/test/test_serial_pty.cpp b/test/test_serial_pty.cpp index 89047db..3cb2f73 100644 --- a/test/test_serial_pty.cpp +++ b/test/test_serial_pty.cpp @@ -426,6 +426,39 @@ TEST_F(PseudoTerminalTest, ReadBytesWithInvalidNumBytes) { }, libserial::IOException); } +TEST_F(PseudoTerminalTest, ReadBytesWithReadFail) { + libserial::Serial serial_port; + + serial_port.open(slave_port_); + serial_port.setBaudRate(9600); + serial_port.setCanonicalMode(libserial::CanonicalMode::DISABLE); + + auto read_buffer = std::make_shared(); + + for (const auto& [error_num, error_msg] : errors_read_) { + serial_port.setSystemCallFunctions( + [](struct pollfd*, nfds_t, int) -> int { + return 1; + }, + [error_num](int, void*, size_t) -> ssize_t { + errno = error_num; + return -1; + }); + + auto expected_what = "Error reading from serial port: " + error_msg; + + EXPECT_THROW({ + try { + serial_port.readBytes(read_buffer, 10); + } + catch (const libserial::IOException& e) { + EXPECT_STREQ(expected_what.c_str(), e.what()); + throw; + } + }, libserial::IOException); + } +} + TEST_F(PseudoTerminalTest, ReadBytesCanonicalMode) { libserial::Serial serial_port; @@ -472,6 +505,25 @@ TEST_F(PseudoTerminalTest, ReadUntil) { EXPECT_EQ(*read_buffer, "Read Until!"); } +TEST_F(PseudoTerminalTest, ReadUntilWithNullBuffer) { + libserial::Serial serial_port; + + serial_port.open(slave_port_); + serial_port.setBaudRate(9600); + + std::shared_ptr null_buffer; + + EXPECT_THROW({ + try { + serial_port.readUntil(null_buffer, '!'); + } + catch (const libserial::IOException& e) { + EXPECT_STREQ("Null pointer passed to readUntil function", e.what()); + throw; + } + }, libserial::IOException); +} + TEST_F(PseudoTerminalTest, ReadUntilTimeout) { libserial::Serial serial_port; @@ -552,3 +604,38 @@ TEST_F(PseudoTerminalTest, ReadUntilWithPollFail) { }, libserial::IOException); } } + +TEST_F(PseudoTerminalTest, ReadUntilWithOverflowBuffer) { + libserial::Serial serial_port; + + serial_port.open(slave_port_); + serial_port.setBaudRate(9600); + EXPECT_NO_THROW(serial_port.setMaxSafeReadSize(10)); // Set max safe read size to 10 bytes + + std::string test_message(15, 'a'); + test_message.push_back('\n'); + + ssize_t bytes_written = write(master_fd_, test_message.c_str(), test_message.length()); + ASSERT_GT(bytes_written, 0) << "Failed to write to master end"; + + // Give time for data to propagate + fsync(master_fd_); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // Test reading with shared pointer - only read what's available + auto read_buffer = std::make_shared(); + + auto expected_what = "Read buffer exceeded maximum size limit of " + + std::to_string(serial_port.getMaxSafeReadSize()) + + " bytes without finding terminator"; + + EXPECT_THROW({ + try { + serial_port.readUntil(read_buffer, '!'); + } + catch (const libserial::IOException& e) { + EXPECT_STREQ(expected_what.c_str(), e.what()); + throw; + } + }, libserial::IOException); +}