From 3f09d05db241fa26ae88b3ca9bead2442b6e9805 Mon Sep 17 00:00:00 2001
From: d-w-moore
Date: Sun, 18 May 2025 22:55:25 -0400
Subject: [PATCH] [_722] fix segfault and hung threads when signals abort a
 parallel transfer.

We now provide the abort_parallel_transfers() function, which can be used to
shut down parallel PUTs and GETs in an orderly way.  For its usage, see the
README.md.

The segmentation faults were not successfully reproduced; I believe they are
likely to have occurred when quitting a main process that had spawned daemon
threads.  Note that in 3.9 (the minimum Python interpreter we now support),
concurrent.futures no longer uses daemon threads, and this futures mechanism
is what underlies parallel transfers in the Python client.

Note also that os._exit, heretofore the simple application's best resort for
exiting while aborting all unfinished threads (daemonic or not), is now best
replaced by the builtin exit function or sys.exit, preceded of course by the
appropriate call(s) to abort_parallel_transfers().
---
 README.md                                     |  99 +++++++
 irods/manager/data_object_manager.py          |  61 ++--
 irods/parallel.py                             | 262 ++++++++++++++----
 irods/session.py                              |  17 +-
 irods/test/data_obj_test.py                   |  15 +
 ...test_signal_handling_in_multithread_get.py | 127 +++++++++
 ...test_signal_handling_in_multithread_put.py | 179 ++++++++++++
 irods/test/modules/tools.py                   |  31 +++
 8 files changed, 707 insertions(+), 84 deletions(-)
 create mode 100644 irods/test/modules/test_signal_handling_in_multithread_get.py
 create mode 100644 irods/test/modules/test_signal_handling_in_multithread_put.py
 create mode 100644 irods/test/modules/tools.py

diff --git a/README.md b/README.md
index d8ee206de..391cc943b 100644
--- a/README.md
+++ b/README.md
@@ -312,6 +312,105 @@ will spawn a number of threads in order to optimize performance for iRODS
 server versions 4.2.9+ and file sizes larger than a default threshold value of
 32 Megabytes.
 
+Because multithreaded processes under Unix-type operating systems sometimes
+need special handling, it is recommended that any put or get of a large file
+be prepared for the case in which a terminating signal aborts the
+transfer:
+
+```python
+import signal
+from irods.parallel import abort_parallel_transfers
+
+def handler(sig, _frame):
+    abort_parallel_transfers()
+
+signal.signal(signal.SIGTERM, handler)
+
+try:
+    # A multi-1247 put or get can leave non-daemon threads running if not treated with care.
+    session.data_objects.put(...)
+except KeyboardInterrupt:
+    # Internally, the library has likely already started a shutdown for the
+    # present put operation, but we can non-destructively issue the following
+    # call to ensure any other ongoing transfers are aborted, prior to re-raising:
+    abort_parallel_transfers()
+
+    print('Due to a SIGINT or Control-C, the put failed.')
+    # Re-raise, as is customary when catching an exception derived directly from BaseException.
+    raise
+except RuntimeError:
+    print('The put failed.')
+# ...
+```
+
+In general, applications wanting to gracefully handle interrupted lengthy
+data transfers to or from iRODS data objects should anticipate Control-C by
+handling both `KeyboardInterrupt` and `RuntimeError`, as shown above.
+
+Of course, had we intercepted SIGINT (meaning we assigned it a custom
+handler), we could avoid the need to handle the `KeyboardInterrupt` in an
+exception clause.
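+
+For instance, here is a minimal sketch (it assumes `session` is an established
+iRODSSession and `hc` is a writable collection path, as in the example further
+below) in which both SIGINT and SIGTERM merely abort any in-flight transfers,
+so that no `KeyboardInterrupt` clause is needed:
+
+```python
+import signal
+from irods.parallel import abort_parallel_transfers
+
+def handler(sig, _frame):
+    # Safe to call from a signal handler; returns promptly after initiating shutdown.
+    abort_parallel_transfers()
+
+signal.signal(signal.SIGINT, handler)
+signal.signal(signal.SIGTERM, handler)
+
+try:
+    session.data_objects.put('my_large_file', f'{hc}/target')
+except RuntimeError:
+    # An aborted parallel transfer surfaces here as a failed put.
+    print('The put failed (possibly aborted by a signal).')
+```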
+
+Internally, of course, the PRC must handle all eventualities, including
+`KeyboardInterrupt`, by closing down the transfer currently being set up or
+waited on, in consideration of very simple applications which might do no
+signal or exception handling of their own.  Otherwise, some non-daemon
+threads might never finish, which could prevent a prompt and orderly exit
+from the main program.
+
+When a signal or exception handler calls `abort_parallel_transfers()`, all
+parallel transfers are aborted immediately.  Upon return to the normal flow
+of the main program, the affected transfers will raise `RuntimeError` to
+indicate the PUT or GET operation has failed.
+
+Note that `abort_parallel_transfers()` is designed to be safe for inclusion
+in signal handlers (e.g. it may be called several times without detrimental
+effects), and that it returns promptly after initiating the shutdown of the
+data transfer threads rather than waiting for them to terminate first.  This
+owes to the best practice of minimizing time spent in signal handlers.
+However, if desired, `abort_parallel_transfers(dry_run=True)` may be called
+repeatedly afterward to track the progress of the shutdown.  The default
+object returned (a dictionary whose keys are weak references to the thread
+managers) will have a boolean value of `False` once all transfer threads
+have exited.
+
+The following example shows how to abort all synchronous ("foreground") puts
+while leaving background transfers alone:
+
+```python
+import irods.helpers, threading
+from irods.parallel import abort_parallel_transfers, FILTER_FUNCTIONS, io_main, Oper
+
+session = irods.helpers.make_session()
+hc = irods.helpers.home_collection(session)
+
+sessions_for_threads = [session.clone() for _ in range(3)]
+
+# Launch an asynchronous (i.e. "background") transfer.
+io_main(session, f'{hc}/target_0', Oper.PUT|Oper.NONBLOCKING, 'my_large_file')
+
+Threads = [ threading.Thread(target=lambda _sess, put_args: _sess.data_objects.put(*put_args, num_threads=2),
+                             args=(sess, ["my_large_file", f"{hc}/target_"+str(i+1)]))
+            for i, sess in enumerate(sessions_for_threads) ]
+
+try:
+    # Launch transfer threads.
+    for t in Threads:
+        t.start()
+    # Wait on transfer threads.
+    for t in Threads:
+        t.join()
+except KeyboardInterrupt:
+    # Trapping Control-C, stop the transfers launched synchronously, i.e. those running in the
+    # foreground of the transfer threads explicitly launched above.
+    abort_parallel_transfers(filter_function=FILTER_FUNCTIONS.foreground)
+
+# As the main application continues (or exits), the asynchronous transfer will be the only one to finish
+# successfully after this call to `abort_parallel_transfers`.  Note: Although care is taken not to
+# leave replicas in a locked state, these remnants of unfinished transfers can still remain in the
+# object catalog.
+```
+
 Progress bars
 -------------
 
diff --git a/irods/manager/data_object_manager.py b/irods/manager/data_object_manager.py
index f2c5ed31b..96a1980dd 100644
--- a/irods/manager/data_object_manager.py
+++ b/irods/manager/data_object_manager.py
@@ -130,13 +130,16 @@ def __init__(self, *a, **kwd):
         # if provided via keyword '_session'.
self._iRODS_session = kwd.pop("_session", None) super(ManagedBufferedRandom, self).__init__(*a, **kwd) - import irods.session + self.do_close = True + + import irods.session with irods.session._fds_lock: - irods.session._fds[self] = None + if irods.session._fds is not None: + irods.session._fds[self] = None def __del__(self): - if not self.closed: + if self.do_close and not self.closed: self.close() call___del__if_exists(super(ManagedBufferedRandom, self)) @@ -245,15 +248,21 @@ def _download(self, obj, local_path, num_threads, updatables=(), **options): if self.should_parallelize_transfer( num_threads, o, open_options=options.items() ): - if not self.parallel_get( - (obj, o), - local_file, - num_threads=num_threads, - target_resource_name=options.get(kw.RESC_NAME_KW, ""), - data_open_returned_values=data_open_returned_values_, - updatables=updatables, - ): - raise RuntimeError("parallel get failed") + error = RuntimeError("parallel get failed") + try: + if not self.parallel_get( + (obj, o), + local_file, + num_threads=num_threads, + target_resource_name=options.get(kw.RESC_NAME_KW, ""), + data_open_returned_values=data_open_returned_values_, + updatables=updatables, + ): + raise error + except ex.iRODSException as e: + raise e + except BaseException as e: + raise error from e else: with open(local_file, "wb") as f: for chunk in chunks(o, self.READ_BUFFER_SIZE): @@ -353,17 +362,23 @@ def put( ): o = deferred_call(self.open, (obj, "w"), options) f.close() - if not self.parallel_put( - local_path, - (obj, o), - total_bytes=sizelist[0], - num_threads=num_threads, - target_resource_name=options.get(kw.RESC_NAME_KW, "") - or options.get(kw.DEST_RESC_NAME_KW, ""), - open_options=options, - updatables=updatables, - ): - raise RuntimeError("parallel put failed") + error = RuntimeError("parallel put failed") + try: + if not self.parallel_put( + local_path, + (obj, o), + total_bytes=sizelist[0], + num_threads=num_threads, + target_resource_name=options.get(kw.RESC_NAME_KW, "") + or options.get(kw.DEST_RESC_NAME_KW, ""), + open_options=options, + updatables=updatables, + ): + raise error + except ex.iRODSException as e: + raise e + except BaseException as e: + raise error from e else: with self.open(obj, "w", **options) as o: # Set operation type to trigger acPostProcForPut diff --git a/irods/parallel.py b/irods/parallel.py index 2ad03492d..4d949cf44 100644 --- a/irods/parallel.py +++ b/irods/parallel.py @@ -9,13 +9,73 @@ import concurrent.futures import threading import multiprocessing -from typing import List, Union +from typing import List, Union, Any +import weakref from irods.data_object import iRODSDataObject from irods.exception import DataObjectDoesNotExist import irods.keywords as kw from queue import Queue, Full, Empty +paths_active: weakref.WeakValueDictionary[str, "AsyncNotify"] = ( + weakref.WeakValueDictionary() +) +transfer_managers: weakref.WeakKeyDictionary["_Multipart_close_manager", Any] = ( + weakref.WeakKeyDictionary() +) + + +class FILTER_FUNCTIONS: + """The members of this class are free functions designed to be passed to + the "filter_function" parameter of the abort_parallel_transfers function. + """ + foreground = staticmethod(lambda item: isinstance(item[1], tuple)) + background = staticmethod(lambda item: not isinstance(item[1], tuple)) + + +def abort_parallel_transfers( + dry_run=False, filter_function=None, transform=weakref.WeakKeyDictionary +): + """ + If no explicit arguments are given, all ongoing parallel puts and gets are + cancelled as soon as possible. 
The corresponding threads are signalled to
+    exit by calling the quit() method on their respective transfer-manager
+    objects.
+
+    Setting dry_run=True results in no such cancellation being performed,
+    although a dict object is still computed for the return value, containing
+    as its keys the transfer-manager objects that would have been affected.
+
+    filter_function is usually left to its default value of None.  By passing
+    a member of the FILTER_FUNCTIONS class, the caller may specify which
+    transfer-managers are affected and/or reflected in the return value:
+
+      FILTER_FUNCTIONS.foreground, for example, limits the scope to invocations
+      of session.data_objects.put() and session.data_objects.get() that are
+      running synchronously.  These calls block within the thread(s) that
+      are calling them.
+
+      FILTER_FUNCTIONS.background limits the scope to transfers started by
+      calls to io_main() which used Oper.NONBLOCKING to spawn the put or
+      get operation in the background.
+
+    transform defaults to a dictionary type with weak keys, since allowing
+    strong references to transfer-manager objects may artificially increase
+    the lifetimes of threads and other objects unnecessarily, and complicate
+    troubleshooting by altering library behavior.  Consider using
+    transform=len if all that is desired is a count of how many parallel
+    transfers currently exist.
+    """
+    mgrs = dict(filter(filter_function, transfer_managers.items()))
+    if not dry_run:
+        for mgr, item in mgrs.items():
+            if isinstance(item, tuple):
+                quit_func, args = item[:2]
+                quit_func(*args)
+            else:
+                mgr.quit()
+    return transform(mgrs)
+
+
 logger = logging.getLogger(__name__)
 _nullh = logging.NullHandler()
@@ -91,9 +151,11 @@ def __init__(
             for future in self._futures:
                 future.add_done_callback(self)
         else:
-            self.__invoke_done_callback()
+            self.__invoke_futures_done_logic()
+            return
 
         self.progress = [0, 0]
+
         if (progress_Queue) and (total is not None):
             self.progress[1] = total
@@ -112,7 +174,7 @@ def _progress(Q, this):  # - thread to update progress indicator
             self._progress_fn = _progress
             self._progress_thread = threading.Thread(
-                target=self._progress_fn, args=(progress_Queue, self)
+                target=self._progress_fn, args=(progress_Queue, self), daemon=True
             )
             self._progress_thread.start()
 
@@ -153,11 +215,14 @@ def __call__(
         with self._lock:
             self._futures_done[future] = future.result()
             if len(self._futures) == len(self._futures_done):
-                self.__invoke_done_callback()
+                # If a future returns None rather than an integer byte count, it has aborted the transfer.
+                self.__invoke_futures_done_logic(
+                    skip_user_callback=(None in self._futures_done.values())
+                )
 
-    def __invoke_done_callback(self):
+    def __invoke_futures_done_logic(self, skip_user_callback=False):
         try:
-            if callable(self.done_callback):
+            if not skip_user_callback and callable(self.done_callback):
                 self.done_callback(self)
         finally:
             self.keep.pop("mgr", None)
@@ -240,6 +305,12 @@ def _copy_part(src, dst, length, queueObject, debug_info, mgr, updatables=()):
     bytecount = 0
     accum = 0
     while True and bytecount < length:
+        if mgr._quit:
+            # Indicate by the return value that we are aborting (this part of) the data transfer.
+            # In the great majority of cases, this should be seen by the application as an overall
+            # abort of the PUT or GET of the requested object.
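+            # (AsyncNotify.__call__ treats a None among the futures' results as an aborted
+            # transfer and skips invoking any user-supplied done-callback.)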
+ bytecount = None + break buf = src.read(min(COPY_BUF_SIZE, length - bytecount)) buf_len = len(buf) if 0 == buf_len: @@ -274,11 +345,39 @@ class _Multipart_close_manager: """ - def __init__(self, initial_io_, exit_barrier_): + def __init__(self, initial_io_, exit_barrier_, executor=None): + self._quit = False self.exit_barrier = exit_barrier_ self.initial_io = initial_io_ self.__lock = threading.Lock() self.aux = [] + self.futures = set() + self.executor = executor + + def add_future(self, future): + self.futures.add(future) + + @property + def active_futures(self): + return tuple(_ for _ in self.futures if not _.done()) + + def shutdown(self): + if self.executor: + self.executor.shutdown(cancel_futures=True) + + def quit(self): + from irods.session import _exclude_fds_from_auto_close + + _exclude_fds_from_auto_close(self.aux + [self.initial_io]) + + if not self._quit: + self._quit = True + + # Disable barrier and abort threads. + self.exit_barrier.abort() + self.shutdown() + + return self.active_futures def __contains__(self, Io): with self.__lock: @@ -297,6 +396,7 @@ def add_io(self, Io): # synchronizes all of the parallel threads just before exit, so that we know # exactly when to perform a finalizing close on the data object + def remove_io(self, Io): is_initial = True with self.__lock: @@ -304,8 +404,12 @@ def remove_io(self, Io): Io.close() self.aux.remove(Io) is_initial = False - self.exit_barrier.wait() - if is_initial: + broken = False + try: + self.exit_barrier.wait() + except threading.BrokenBarrierError: + broken = True + if is_initial and not (broken or self._quit): self.finalize() def finalize(self): @@ -393,7 +497,7 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk): futures = [] executor = concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) num_threads = min(num_threads, len(ranges)) - mgr = _Multipart_close_manager(Io, Barrier(num_threads)) + mgr = _Multipart_close_manager(Io, Barrier(num_threads), executor) counter = 1 gen_file_handle = lambda: open( fname, Operation.disk_file_mode(initial_open=(counter == 1)) @@ -405,48 +509,84 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk): "queueObject": queueObject, } - for byte_range in ranges: - if Io is None: - Io = session.data_objects.open( - Data_object.path, - Operation.data_object_mode(initial_open=False), - create=False, - finalize_on_close=False, - allow_redirect=False, - **{ - kw.NUM_THREADS_KW: str(num_threads), - kw.DATA_SIZE_KW: str(total_size), - kw.RESC_HIER_STR_KW: hier_str, - kw.REPLICA_TOKEN_KW: replica_token, - } - ) - mgr.add_io(Io) - logger.debug("target_host = %s", Io.raw.session.pool.account.host) - if File is None: - File = gen_file_handle() - futures.append( - executor.submit( - _io_part, - Io, - byte_range, - File, - Operation, - mgr, - thread_debug_id=str(counter), - **thread_opts - ) - ) - counter += 1 - Io = File = None + transfer_managers[mgr] = (_quit_current_transfer, [id(mgr)]) - if Operation.isNonBlocking(): - if queueLength: - return futures, queueObject, mgr + try: + thread_setup_error = None + + for byte_range in ranges: + if Io is None: + Io = session.data_objects.open( + Data_object.path, + Operation.data_object_mode(initial_open=False), + create=False, + finalize_on_close=False, + allow_redirect=False, + **{ + kw.NUM_THREADS_KW: str(num_threads), + kw.DATA_SIZE_KW: str(total_size), + kw.RESC_HIER_STR_KW: hier_str, + kw.REPLICA_TOKEN_KW: replica_token, + } + ) + mgr.add_io(Io) + logger.debug("target_host = %s", Io.raw.session.pool.account.host) + if 
File is None: + File = gen_file_handle() + try: + f = None + futures.append( + f := executor.submit( + _io_part, + Io, + byte_range, + File, + Operation, + mgr, + thread_debug_id=str(counter), + **thread_opts + ) + ) + except RuntimeError as error: + # Executor was probably shut down before parallel transfer could be initiated. + thread_setup_error = error + break + else: + mgr.add_future(f) + + counter += 1 + Io = File = None + + if thread_setup_error: + raise thread_setup_error + + bytes_transferred = 0 + + if Operation.isNonBlocking(): + transfer_managers[mgr] = None + return (futures, mgr, queueObject) else: - return futures - else: - bytecounts = [f.result() for f in futures] - return sum(bytecounts), total_size + # Enable user attempts to cancel the current synchronous transfer. + # At any given time, only one transfer manager key should map to a tuple object T. + # You should be able to quit all threads of the current transfer by calling T[0](*T[1]). + bytecounts = [future.result() for future in futures] + # If, rather than an integer byte-count, the "None" object was included as one of futures' return values, this + # is an indication that the PUT or GET operation should be marked as aborted, i.e. no bytes transferred. + if None not in bytecounts: + bytes_transferred = sum(bytecounts) + + return (bytes_transferred, total_size) + + except BaseException as e: + if isinstance(e, (SystemExit, KeyboardInterrupt, RuntimeError)): + mgr.quit() + raise + + +def _quit_current_transfer(obj_id): + l = [_ for _ in transfer_managers if id(_) == obj_id] + if l: + l[0].quit() def io_main(session, Data, opr_, fname, R="", **kwopt): @@ -559,18 +699,20 @@ def io_main(session, Data, opr_, fname, R="", **kwopt): if Operation.isNonBlocking(): - if queueLength > 0: - (futures, chunk_notify_queue, mgr) = retval - else: - futures = retval - chunk_notify_queue = total_bytes = None + (futures, mgr, chunk_notify_queue) = retval - return AsyncNotify( + # For convenience, this information can help determine which data object mgr is tracking. + transfer_managers[mgr] = Data.path + + paths_active[Data.path] = async_notify = AsyncNotify( futures, # individual futures, one per transfer thread progress_Queue=chunk_notify_queue, # for notifying the progress indicator thread total=total_bytes, # total number of bytes for parallel transfer - keep_={"mgr": mgr}, - ) # an open raw i/o object needing to be persisted, if any + keep_={ + "mgr": mgr + }, # objects needing to be persisted while futures are pending + ) + return async_notify else: (_bytes_transferred, _bytes_total) = retval return _bytes_transferred == _bytes_total @@ -645,10 +787,10 @@ def setupLoggingWithDateTimeHeader(name, level=logging.DEBUG): timeout=10.0 ) # - or do other useful work here if done: - bytes_transferred = sum(ret.futures_done.values()) + bytes_transferred_total = sum(ret.futures_done.values()) print( "Asynch transfer complete. Total bytes transferred:", - bytes_transferred, + bytes_transferred_total, file=sys.stderr, ) else: diff --git a/irods/session.py b/irods/session.py index eabb437e8..8eb51694a 100644 --- a/irods/session.py +++ b/irods/session.py @@ -2,12 +2,15 @@ import atexit import copy import errno +from io import BufferedRandom import json import logging from numbers import Number import os import threading +from typing import Iterable, Any, Optional import weakref + import irods.auth from irods.query import Query from irods.genquery2 import GenQuery2 @@ -29,12 +32,24 @@ from . import at_client_exit from . 
import DEFAULT_CONNECTION_TIMEOUT, MAXIMUM_CONNECTION_TIMEOUT -_fds = None +_fds : Optional[dict[BufferedRandom, Any]] = None _fds_lock = threading.Lock() _sessions = None _sessions_lock = threading.Lock() +def _exclude_fds_from_auto_close(descriptors: Iterable): + """Remove all descriptors from consideration for auto_close.""" + from irods.manager.data_object_manager import ManagedBufferedRandom + + with _fds_lock: + fds : dict[BufferedRandom, Any] = _fds or {} + for fd in descriptors: + fds.pop(fd, None) + if isinstance(fd, ManagedBufferedRandom): + fd.do_close = False + + def _cleanup_remaining_sessions(): for fd in list((_fds or {}).keys()): if not fd.closed: diff --git a/irods/test/data_obj_test.py b/irods/test/data_obj_test.py index 071771717..ad0f12334 100644 --- a/irods/test/data_obj_test.py +++ b/irods/test/data_obj_test.py @@ -3320,6 +3320,21 @@ def test_access_time__issue_700(self): # Test that access_time is there, and of the right type. self.assertIs(type(data.access_time), datetime) + def test_handling_of_termination_signals_during_multithread_get__issue_722(self): + from irods.test.modules.test_signal_handling_in_multithread_get import ( + test as test_get__issue_722, + ) + + test_get__issue_722(self) + + def test_handling_of_termination_signals_during_multithread_put__issue_722(self): + from irods.test.modules.test_signal_handling_in_multithread_put import ( + test as test_put__issue_722, + ) + + test_put__issue_722(self) + + if __name__ == "__main__": # let the tests find the parent irods lib sys.path.insert(0, os.path.abspath("../..")) diff --git a/irods/test/modules/test_signal_handling_in_multithread_get.py b/irods/test/modules/test_signal_handling_in_multithread_get.py new file mode 100644 index 000000000..cd42b29d9 --- /dev/null +++ b/irods/test/modules/test_signal_handling_in_multithread_get.py @@ -0,0 +1,127 @@ +import os +import re +import signal +import subprocess +import sys +import tempfile + +import irods.helpers +from irods.test import modules as test_modules +from irods.parallel import abort_parallel_transfers + +OBJECT_SIZE = 4 * 1024**3 +OBJECT_NAME = "data_get_issue__722" +LOCAL_TEMPFILE_NAME = "data_object_for_issue_722.dat" + + +def test(test_case, signal_names=("SIGTERM", "SIGINT")): + """Creates a child process executing a long get() and ensures the process can be + terminated using SIGINT or SIGTERM. + """ + from .tools import wait_till_true + + program = os.path.join(test_modules.__path__[0], os.path.basename(__file__)) + + for signal_name in signal_names: + + with test_case.subTest(f"Testing with signal {signal_name}"): + + # Call into this same module as a command. This will initiate another Python process that + # performs a lengthy data object "get" operation (see the main body of the script, below.) + process = subprocess.Popen( + [sys.executable, program], + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True, + ) + + # Wait for download process to reach the point of spawning data transfer threads. In Python 3.9+ versions + # of the concurrent.futures module, these are nondaemon threads and will block the exit of the main thread + # unless measures are taken. + localfile = process.stdout.readline().strip() + # Use timeout of 10 minutes for test transfer, which should be more than enough. 
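+            # (No explicit timeout argument is given, so wait_till_true falls back to its
+            # default of tools.LARGE_TEST_TIMEOUT, i.e. ten minutes.)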
+            test_case.assertTrue(
+                wait_till_true(
+                    lambda: os.path.exists(localfile)
+                    and os.stat(localfile).st_size > OBJECT_SIZE // 2,
+                ),
+                "Parallel download from data_objects.get() probably experienced a fatal error before spawning auxiliary data transfer threads.",
+            )
+
+            sig = getattr(signal, signal_name)
+
+            signal_offset_return_code = lambda s: 128 - s if s < 0 else s
+            signal_plus_128 = lambda sig: 128 + sig
+
+            # Interrupt the subprocess with the given signal.
+            process.send_signal(sig)
+
+            # Assert that this signal is what killed the subprocess, rather than a timed-out process "wait"
+            # or a natural exit due to improper or incomplete handling of the signal.
+            try:
+                translated_return_code = signal_offset_return_code(
+                    process.wait(timeout=15)
+                )
+                test_case.assertIn(
+                    translated_return_code,
+                    [1, signal_plus_128(sig)],
+                    f"Expected subprocess return code of {signal_plus_128(sig) = }; got {translated_return_code = }",
+                )
+            except subprocess.TimeoutExpired:
+                test_case.fail(
+                    "Subprocess timed out before terminating. "
+                    "Non-daemon thread(s) probably prevented subprocess's main thread from exiting."
+                )
+            # Assert that in the case of SIGINT, the process registered a KeyboardInterrupt.
+            if sig == signal.SIGINT:
+                test_case.assertTrue(
+                    re.search("KeyboardInterrupt", process.stderr.read()),
+                    "Did not find expected string 'KeyboardInterrupt' in log output.",
+                )
+
+
+if __name__ == "__main__":
+    # These lines are run only if the module is launched as a process.
+    session = irods.helpers.make_session()
+    hc = irods.helpers.home_collection(session)
+    TESTFILE_FILL = b"_" * (1024 * 1024)
+    object_path = f"{hc}/{OBJECT_NAME}"
+
+    # Create the object to be downloaded.
+    with session.data_objects.open(object_path, "w") as f:
+        for y in range(OBJECT_SIZE // len(TESTFILE_FILL)):
+            f.write(TESTFILE_FILL)
+    local_path = None
+    # Establish where (i.e. at which absolute path) to place the downloaded file, i.e. the get() target.
+    try:
+        with tempfile.NamedTemporaryFile(
+            prefix="local_file_issue_722.dat", delete=True
+        ) as t:
+            local_path = t.name
+
+            # Tell the parent process the name of the local file, i.e. the result of the "get" from iRODS.
+            # That parent process is the unittest, which will use the filename to verify the threads are started
+            # and we're somewhere mid-transfer.
+            print(local_path)
+            sys.stdout.flush()
+
+            def handler(sig, *_):
+                abort_parallel_transfers()
+                if sig == signal.SIGTERM:
+                    os._exit(128 + sig)
+
+            signal.signal(signal.SIGTERM, handler)
+
+            try:
+                # Download the object.
+                session.data_objects.get(object_path, local_path)
+            except KeyboardInterrupt:
+                abort_parallel_transfers()
+                raise
+
+    finally:
+        # Clean up, whether or not the download succeeded.
+        if local_path is not None and os.path.exists(local_path):
+            os.unlink(local_path)
+        if session.data_objects.exists(object_path):
+            session.data_objects.unlink(object_path, force=True)
diff --git a/irods/test/modules/test_signal_handling_in_multithread_put.py b/irods/test/modules/test_signal_handling_in_multithread_put.py
new file mode 100644
index 000000000..0fd4cf571
--- /dev/null
+++ b/irods/test/modules/test_signal_handling_in_multithread_put.py
@@ -0,0 +1,179 @@
+import datetime
+import os
+import re
+import signal
+import subprocess
+import sys
+import irods.helpers
+from irods.session import iRODSSession
+from irods.test.helpers import unique_name
+from irods.test import modules as test_modules
+from irods.parallel import abort_parallel_transfers
+
+OBJECT_SIZE = 4 * 1024**3
+LOCAL_TEMPFILE_NAME = "data_object_for_issue_722.dat"
+
+
+def test(test_case, signal_names=("SIGTERM", "SIGINT")):
+    """Creates a child process executing a long put() and ensures the process can be terminated using SIGINT or SIGTERM."""
+    from .tools import wait_till_true
+
+    program = os.path.join(test_modules.__path__[0], os.path.basename(__file__))
+    session = getattr(test_case, "sess", None) or irods.helpers.make_session()
+
+    for signal_name in signal_names:
+
+        with test_case.subTest(f"Testing with signal {signal_name}"):
+
+            # Initialize so the finally clause below cannot hit an unbound name if Popen fails.
+            logical_path = ""
+            try:
+                # Call into this same module as a command.  This will initiate another Python process that
+                # performs a lengthy data object "put" operation (see the main body of the script, below.)
+                process = subprocess.Popen(
+                    # -k: Keep object around for replica status testing.
+                    [sys.executable, program, "-k"],
+                    stderr=subprocess.PIPE,
+                    stdout=subprocess.PIPE,
+                    text=True,
+                )
+
+                # Wait for the upload process to reach the point of spawning data transfer threads.  In Python 3.9+
+                # versions of the concurrent.futures module, these are non-daemon threads and will block the exit
+                # of the main thread unless measures are taken.
+                logical_path = process.stdout.readline().strip()
+
+                # Use timeout of 10 minutes for test transfer, which should be more than enough.
+                test_case.assertTrue(
+                    wait_till_true(
+                        lambda: session.data_objects.exists(logical_path)
+                        and named_irods_data_object(
+                            session, logical_path, delete=False
+                        ).data.size
+                        > OBJECT_SIZE // 2,
+                    ),
+                    "Parallel upload from data_objects.put() probably experienced a fatal error before spawning auxiliary data transfer threads.",
+                )
+                sig = getattr(signal, signal_name)
+
+                signal_offset_return_code = lambda s: 128 - s if s < 0 else s
+                signal_plus_128 = lambda sig: 128 + sig
+
+                # Interrupt the subprocess with the given signal.
+                process.send_signal(sig)
+
+                # Assert that this signal is what killed the subprocess, rather than a timed-out process "wait"
+                # or a natural exit due to improper or incomplete handling of the signal.
+                try:
+                    translated_return_code = signal_offset_return_code(
+                        process.wait(timeout=15)
+                    )
+                    test_case.assertIn(
+                        translated_return_code,
+                        [1, signal_plus_128(sig)],
+                        f"Expected subprocess return code of {signal_plus_128(sig) = }; got {translated_return_code = }",
+                    )
+                except subprocess.TimeoutExpired:
+                    test_case.fail(
+                        "Subprocess timed out before terminating. "
+                        "Non-daemon thread(s) probably prevented subprocess's main thread from exiting."
+                    )
+
+                # Assert that in the case of SIGINT, the process registered a KeyboardInterrupt.
+                if sig == signal.SIGINT:
+                    test_case.assertTrue(
+                        re.search("KeyboardInterrupt", process.stderr.read()),
+                        "Did not find expected string 'KeyboardInterrupt' in log output.",
+                    )
+
+                # Assert that the replica is not left in a locked or intermediate state
+                # (i.e. that its status value is below 2).
+                test_case.assertTrue(
+                    wait_till_true(
+                        lambda: int(
+                            session.data_objects.get(logical_path).replica_status
+                        )
+                        < 2
+                    )
+                )
+
+            finally:
+                if logical_path and (
+                    d := irods.helpers.get_data_object(session, logical_path)
+                ):
+                    d.unlink(force=True)
+
+
+class named_irods_data_object:
+
+    def __init__(self, /, session: iRODSSession, path: str = "", delete: bool = True):
+        self.sess = session
+        self.delete = delete
+        if not path:
+            path = (
+                irods.helpers.home_collection(session)
+                + "/"
+                + unique_name(datetime.datetime.now())
+            )
+        self.path = path
+
+    @property
+    def data(self):
+        return irods.helpers.get_data_object(self.sess, self.path)
+
+    def __del__(self):
+        if self.delete:
+            self.remove()
+
+    def remove(self):
+        if d := self.data:
+            d.unlink(force=True)
+
+    def create(self):
+        self.sess.data_objects.create(self.path)
+        return self
+
+
+if __name__ == "__main__":
+    import getopt
+
+    opts, _ = getopt.getopt(sys.argv[1:], "k")
+    keep_data_object = "-k" in (_[0] for _ in opts)
+
+    # These lines are run only if the module is launched as a process.
+    test_session = irods.helpers.make_session()
+    hc = irods.helpers.home_collection(test_session)
+    TESTFILE_FILL = b"_" * (1024 * 1024)
+
+    object_path = named_irods_data_object(test_session, delete=True).create().path
+    local_path = object_path.split("/")[-1]
+
+    # Create the local file to be uploaded.
+    with open(local_path, "wb") as f:
+        for y in range(OBJECT_SIZE // len(TESTFILE_FILL)):
+            f.write(TESTFILE_FILL)
+
+    try:
+        # Tell the parent process the data object's logical path, i.e. the target of the "put" to iRODS.
+        # That parent process is the unittest, which will use the logical path to verify the threads are started
+        # and we're somewhere mid-transfer.
+        print(object_path)
+        sys.stdout.flush()
+
+        def handler(sig, *_):
+            abort_parallel_transfers()
+            if sig == signal.SIGTERM:
+                os._exit(128 + sig)
+
+        signal.signal(signal.SIGTERM, handler)
+
+        try:
+            # Upload the object.
+            test_session.data_objects.put(local_path, object_path)
+        except KeyboardInterrupt:
+            abort_parallel_transfers()
+            raise
+
+    finally:
+        # Clean up, whether or not the upload succeeded.
+        if local_path is not None and os.path.exists(local_path):
+            os.unlink(local_path)
+        if not keep_data_object:
+            # Construct a temporary named_irods_data_object whose finalizer unlinks the data object.
+            named_irods_data_object(test_session, path=object_path, delete=True)
diff --git a/irods/test/modules/tools.py b/irods/test/modules/tools.py
new file mode 100644
index 000000000..65272b55e
--- /dev/null
+++ b/irods/test/modules/tools.py
@@ -0,0 +1,31 @@
+import time
+
+_clock_polling_interval = max(0.01, time.clock_getres(time.CLOCK_BOOTTIME))
+
+LARGE_TEST_TIMEOUT = 10 * 60.0  # ten minutes.
+
+
+def wait_till_true(callback, timeout=LARGE_TEST_TIMEOUT, msg=""):
+    """Wait, for test purposes, until a condition becomes true, as determined by the
+    return value of the provided test function.
+
+    By default, we wait at most LARGE_TEST_TIMEOUT seconds for the callback to return
+    a true value before timing out.  Alternatively, a timeout of None is a request
+    never to time out.
+
+    If the msg value passed in is a nonzero-length string, it is used to raise a
+    TimeoutError on timing out; otherwise timing out causes a normal return, relaying
+    as the return value the last value returned from the test callback function.
+ """ + start_time = time.clock_gettime_ns(time.CLOCK_BOOTTIME) + while not (truth_value := callback()): + if ( + timeout is not None + and (time.clock_gettime_ns(time.CLOCK_BOOTTIME) - start_time) * 1e-9 + > timeout + ): + if msg: + raise TimeoutError(msg) + else: + break + time.sleep(_clock_polling_interval) + return truth_value