diff --git a/examples/cuda/sepconv_parallel.py b/examples/cuda/sepconv_parallel.py
new file mode 100644
index 00000000..074200e1
--- /dev/null
+++ b/examples/cuda/sepconv_parallel.py
@@ -0,0 +1,88 @@
+#!/usr/bin/env python
+import numpy
+from kernel_tuner import tune_kernel
+from collections import OrderedDict
+
+
+def tune():
+    with open("convolution.cu", "r") as f:
+        kernel_string = f.read()
+
+    # setup tunable parameters
+    tune_params = OrderedDict()
+    tune_params["filter_height"] = [i for i in range(3, 19, 2)]
+    tune_params["filter_width"] = [i for i in range(3, 19, 2)]
+    tune_params["block_size_x"] = [16 * i for i in range(1, 65)]
+    tune_params["block_size_y"] = [2**i for i in range(6)]
+    tune_params["tile_size_x"] = [i for i in range(1, 11)]
+    tune_params["tile_size_y"] = [i for i in range(1, 11)]
+
+    tune_params["use_padding"] = [0, 1]  # toggle the insertion of padding in shared memory
+    tune_params["read_only"] = [0, 1]  # toggle using the read-only cache
+
+    # limit the search to only use padding when it's effective, and at least 32 threads in a block
+    restrict = ["use_padding==0 or (block_size_x % 32 != 0)", "block_size_x*block_size_y >= 32"]
+
+    # setup input and output dimensions
+    problem_size = (4096, 4096)
+    size = numpy.prod(problem_size)
+    largest_fh = max(tune_params["filter_height"])
+    largest_fw = max(tune_params["filter_width"])
+    input_size = (problem_size[0] + largest_fw - 1) * (problem_size[1] + largest_fh - 1)
+
+    # create input data
+    output_image = numpy.zeros(size).astype(numpy.float32)
+    input_image = numpy.random.randn(input_size).astype(numpy.float32)
+    filter_weights = numpy.random.randn(largest_fh * largest_fw).astype(numpy.float32)
+
+    # setup kernel arguments
+    cmem_args = {"d_filter": filter_weights}
+    args = [output_image, input_image, filter_weights]
+
+    # tell the Kernel Tuner how to compute grid dimensions
+    grid_div_x = ["block_size_x", "tile_size_x"]
+    grid_div_y = ["block_size_y", "tile_size_y"]
+
+    # start tuning separable convolution (row)
+    tune_params["filter_height"] = [1]
+    tune_params["tile_size_y"] = [1]
+    results_row = tune_kernel(
+        "convolution_kernel",
+        kernel_string,
+        problem_size,
+        args,
+        tune_params,
+        grid_div_y=grid_div_y,
+        grid_div_x=grid_div_x,
+        cmem_args=cmem_args,
+        verbose=False,
+        restrictions=restrict,
+        parallel_workers=True,
+        cache="convolution_kernel_row",
+    )
+
+    # start tuning separable convolution (col)
+    tune_params["filter_height"] = tune_params["filter_width"][:]
+    tune_params["tile_size_y"] = tune_params["tile_size_x"][:]
+    tune_params["filter_width"] = [1]
+    tune_params["tile_size_x"] = [1]
+    results_col = tune_kernel(
+        "convolution_kernel",
+        kernel_string,
+        problem_size,
+        args,
+        tune_params,
+        grid_div_y=grid_div_y,
+        grid_div_x=grid_div_x,
+        cmem_args=cmem_args,
+        verbose=False,
+        restrictions=restrict,
+        parallel_workers=True,
+        cache="convolution_kernel_col",
+    )
+
+    return results_row, results_col
+
+
+if __name__ == "__main__":
+    results_row, results_col = tune()
diff --git a/examples/cuda/vector_add_parallel.py b/examples/cuda/vector_add_parallel.py
new file mode 100644
index 00000000..8d35ce7c
--- /dev/null
+++ b/examples/cuda/vector_add_parallel.py
@@ -0,0 +1,35 @@
+#!/usr/bin/env python
+
+import numpy
+from kernel_tuner import tune_kernel
+
+
+def tune():
+    kernel_string = """
+    __global__ void vector_add(float *c, float *a, float *b, int n) {
+        int i = (blockIdx.x * block_size_x) + threadIdx.x;
+        if ( i < n ) {
+            c[i] = a[i] + b[i];
+        }
+    }
+    """
+
+    size = 10000000
+
+    a = numpy.random.randn(size).astype(numpy.float32)
+    b = numpy.random.randn(size).astype(numpy.float32)
+    c = numpy.zeros_like(b)
+    n = numpy.int32(size)
+
+    args = [c, a, b, n]
+
+    tune_params = dict()
+    tune_params["block_size_x"] = [32 * i for i in range(1, 33)]
+
+    results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params, parallel_workers=True)
+    print(env)
+    return results
+
+
+if __name__ == "__main__":
+    tune()
diff --git a/kernel_tuner/interface.py b/kernel_tuner/interface.py
index 32e91c86..053f71f2 100644
--- a/kernel_tuner/interface.py
+++ b/kernel_tuner/interface.py
@@ -39,8 +39,6 @@
 import kernel_tuner.util as util
 from kernel_tuner.file_utils import get_input_file, get_t4_metadata, get_t4_results, import_class_from_file
 from kernel_tuner.integration import get_objective_defaults
-from kernel_tuner.runners.sequential import SequentialRunner
-from kernel_tuner.runners.simulation import SimulationRunner
 from kernel_tuner.searchspace import Searchspace
 
 try:
@@ -473,6 +471,7 @@ def __deepcopy__(self, _):
         ),
         ("metrics", ("specifies user-defined metrics, please see :ref:`metrics`.", "dict")),
         ("simulation_mode", ("Simulate an auto-tuning search from an existing cachefile", "bool")),
+        ("parallel_workers", ("Set to `True` or an integer to enable parallel tuning. If set to an integer, this will be the number of parallel workers.", "int|bool")),
         ("observers", ("""A list of Observers to use during tuning, please see :ref:`observers`.""", "list")),
     ]
 )
@@ -584,6 +583,7 @@ def tune_kernel(
     cache=None,
     metrics=None,
     simulation_mode=False,
+    parallel_workers=None,
     observers=None,
     objective=None,
     objective_higher_is_better=None,
@@ -651,9 +651,22 @@ def tune_kernel(
         strategy = brute_force
 
     # select the runner for this job based on input
-    selected_runner = SimulationRunner if simulation_mode else SequentialRunner
+    # TODO: we could use the "match case" syntax when removing support for 3.9
     tuning_options.simulated_time = 0
-    runner = selected_runner(kernelsource, kernel_options, device_options, iterations, observers)
+
+    if parallel_workers and simulation_mode:
+        raise ValueError("Enabling `parallel_workers` and `simulation_mode` together is not supported")
+    elif simulation_mode:
+        from kernel_tuner.runners.simulation import SimulationRunner
+        runner = SimulationRunner(kernelsource, kernel_options, device_options, iterations, observers)
+    elif parallel_workers:
+        from kernel_tuner.runners.parallel import ParallelRunner
+        num_workers = None if parallel_workers is True else parallel_workers
+        runner = ParallelRunner(kernelsource, kernel_options, device_options, tuning_options, iterations, observers, num_workers=num_workers)
+    else:
+        from kernel_tuner.runners.sequential import SequentialRunner
+        runner = SequentialRunner(kernelsource, kernel_options, device_options, iterations, observers)
+
     # the user-specified function may or may not have an optional atol argument;
     # we normalize it so that it always accepts atol.
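The two example scripts above exercise the `parallel_workers` option that this hunk adds to `tune_kernel`: `True` lets the `ParallelRunner` claim every GPU that Ray reports, while an integer pins the number of workers. A minimal usage sketch for a multi-GPU node follows; the `ray.init(address="auto")` call and the worker count of 4 are illustrative assumptions and not part of the patch (the runner calls `ray.init()` itself when Ray has not been initialized).

```python
# Hypothetical usage sketch, not part of this patch.
import numpy
import ray
from kernel_tuner import tune_kernel

kernel_string = """
__global__ void vector_add(float *c, float *a, float *b, int n) {
    int i = (blockIdx.x * block_size_x) + threadIdx.x;
    if (i < n) {
        c[i] = a[i] + b[i];
    }
}
"""

size = 10000000
a = numpy.random.randn(size).astype(numpy.float32)
b = numpy.random.randn(size).astype(numpy.float32)
c = numpy.zeros_like(b)
n = numpy.int32(size)
args = [c, a, b, n]

tune_params = {"block_size_x": [32 * i for i in range(1, 33)]}

# Connect to a running Ray cluster; omit address to start a local Ray instance.
# ParallelRunner would otherwise call ray.init() on its own.
ray.init(address="auto", ignore_reinit_error=True)

# One Ray actor (and one GPU) per worker; parallel_workers=True would use all GPUs.
results, env = tune_kernel("vector_add", kernel_string, size, args, tune_params,
                           parallel_workers=4)
```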
@@ -669,16 +682,20 @@ def preprocess_cache(filepath): # process cache if cache: cache = preprocess_cache(cache) - util.process_cache(cache, kernel_options, tuning_options, runner) + tuning_options.cachefile = cache + tuning_options.cache = util.process_cache(cache, kernel_options, tuning_options, runner) else: - tuning_options.cache = {} tuning_options.cachefile = None + tuning_options.cache = {} # create search space tuning_options.restrictions_unmodified = deepcopy(restrictions) - searchspace = Searchspace(tune_params, restrictions, runner.dev.max_threads, **searchspace_construction_options) + device_info = runner.get_device_info() + searchspace = Searchspace(tune_params, restrictions, device_info.max_threads, **searchspace_construction_options) + restrictions = searchspace._modified_restrictions tuning_options.restrictions = restrictions + if verbose: print(f"Searchspace has {searchspace.size} configurations after restrictions.") @@ -696,6 +713,9 @@ def preprocess_cache(filepath): results = strategy.tune(searchspace, runner, tuning_options) env = runner.get_environment(tuning_options) + # Shut down the runner + runner.shutdown() + # finished iterating over search space if results: # checks if results is not empty best_config = util.get_best_config(results, objective, objective_higher_is_better) diff --git a/kernel_tuner/runners/parallel.py b/kernel_tuner/runners/parallel.py new file mode 100644 index 00000000..a05fc2fa --- /dev/null +++ b/kernel_tuner/runners/parallel.py @@ -0,0 +1,260 @@ +"""A specialized runner that tunes in parallel the parameter space.""" +import logging +import socket +from time import perf_counter +from kernel_tuner.core import DeviceInterface +from kernel_tuner.interface import Options +from kernel_tuner.runners.runner import Runner +from kernel_tuner.util import ErrorConfig, print_config_output, process_metrics, store_cache +from datetime import datetime, timezone + +logger = logging.getLogger(__name__) + +try: + import ray +except ImportError as e: + raise ImportError(f"unable to initialize the parallel runner: {e}") from e + + +@ray.remote(num_gpus=1) +class DeviceActor: + def __init__(self, kernel_source, kernel_options, device_options, tuning_options, iterations, observers): + # detect language and create high-level device interface + self.dev = DeviceInterface(kernel_source, iterations=iterations, observers=observers, **device_options) + + self.units = self.dev.units + self.quiet = device_options.quiet + self.kernel_source = kernel_source + self.warmed_up = False if self.dev.requires_warmup else True + self.start_time = perf_counter() + self.last_strategy_start_time = self.start_time + self.last_strategy_time = 0 + self.kernel_options = kernel_options + self.tuning_options = tuning_options + + # move data to the GPU + self.gpu_args = self.dev.ready_argument_list(kernel_options.arguments) + + def shutdown(self): + ray.actor.exit_actor() + + def get_environment(self): + # Get the device properties + env = dict(self.dev.get_environment()) + + # Get the host name + env["host_name"] = socket.gethostname() + + # Get info about the ray instance + ctx = ray.get_runtime_context() + env["ray"] = { + "node_id": ctx.get_node_id(), + "worker_id": ctx.get_worker_id(), + "actor_id": ctx.get_actor_id(), + } + + return env + + def run(self, element): + # TODO: logging.debug("sequential runner started for " + self.kernel_options.kernel_name) + objective = self.tuning_options.objective + metrics = self.tuning_options.metrics + + params = dict(element) + result = None + 
warmup_time = 0 + + # attempt to warmup the GPU by running the first config in the parameter space and ignoring the result + if not self.warmed_up: + warmup_time = perf_counter() + self.dev.compile_and_benchmark( + self.kernel_source, self.gpu_args, params, self.kernel_options, self.tuning_options + ) + self.warmed_up = True + warmup_time = 1e3 * (perf_counter() - warmup_time) + + result = self.dev.compile_and_benchmark( + self.kernel_source, self.gpu_args, params, self.kernel_options, self.tuning_options + ) + + if isinstance(result.get(objective), ErrorConfig): + logging.debug("kernel configuration was skipped silently due to compile or runtime failure") + + params.update(result) + + # only compute metrics on configs that have not errored + if metrics and not isinstance(params.get(objective), ErrorConfig): + params = process_metrics(params, metrics) + + # get the framework time by estimating based on other times + total_time = 1000 * ((perf_counter() - self.start_time) - warmup_time) + params["strategy_time"] = self.last_strategy_time + params["framework_time"] = max( + total_time + - ( + params["compile_time"] + + params["verification_time"] + + params["benchmark_time"] + + params["strategy_time"] + ), + 0, + ) + + params["timestamp"] = datetime.now(timezone.utc).isoformat() + params["ray_actor_id"] = ray.get_runtime_context().get_actor_id() + params["host_name"] = socket.gethostname() + + self.start_time = perf_counter() + + # all visited configurations are added to results to provide a trace for optimization strategies + return params + + +class DeviceActorState: + def __init__(self, actor): + self.actor = actor + self.running_jobs = [] + self.maximum_running_jobs = 1 + self.is_running = True + self.env = ray.get(actor.get_environment.remote()) + + def __repr__(self): + actor_id = self.env["ray"]["actor_id"] + host_name = self.env["host_name"] + return f"{actor_id} ({host_name})" + + def shutdown(self): + if not self.is_running: + return + + self.is_running = False + + try: + self.actor.shutdown.remote() + except Exception: + logger.exception("Failed to request actor shutdown: %s", self) + + def submit(self, config): + logger.info(f"jobs submitted to worker {self}: {config}") + job = self.actor.run.remote(config) + self.running_jobs.append(job) + return job + + def is_available(self): + if not self.is_running: + return False + + # Check for ready jobs, but do not block + ready_jobs, self.running_jobs = ray.wait(self.running_jobs, timeout=0) + ray.get(ready_jobs) + + # Available if this actor can now run another job + return len(self.running_jobs) < self.maximum_running_jobs + + +class ParallelRunner(Runner): + def __init__(self, kernel_source, kernel_options, device_options, tuning_options, iterations, observers, num_workers=None): + if not ray.is_initialized(): + ray.init() + + if num_workers is None: + num_workers = int(ray.cluster_resources().get("GPU", 0)) + + if num_workers == 0: + raise RuntimeError("failed to initialize parallel runner: no GPUs found") + + if num_workers < 1: + raise RuntimeError(f"failed to initialize parallel runner: invalid number of GPUs specified: {num_workers}") + + self.workers = [] + + try: + for index in range(num_workers): + actor = DeviceActor.remote(kernel_source, kernel_options, device_options, tuning_options, iterations, observers) + worker = DeviceActorState(actor) + self.workers.append(worker) + + logger.info(f"launched worker {index}: {worker}") + except: + # If an exception occurs, shut down the worker + self.shutdown() + raise + + # Check if 
all workers have the same device
+        device_names = {w.env.get("device_name") for w in self.workers}
+        if len(device_names) != 1:
+            self.shutdown()
+            raise RuntimeError(
+                f"failed to initialize parallel runner: workers have different devices: {sorted(device_names)}"
+            )
+
+        self.device_name = device_names.pop()
+
+        # TODO: Get this from the device
+        self.units = {"time": "ms"}
+        self.quiet = device_options.quiet
+
+    def get_device_info(self):
+        return Options({"max_threads": 1024})
+
+    def get_environment(self, tuning_options):
+        return {
+            "device_name": self.device_name,
+            "workers": [w.env for w in self.workers]
+        }
+
+    def shutdown(self):
+        for worker in self.workers:
+            try:
+                worker.shutdown()
+            except Exception:
+                logger.exception("error while shutting down worker %s", worker)
+
+    def submit_job(self, *args):
+        while True:
+            # Round-robin: first available worker gets the job and goes to the back of the list
+            for i, worker in enumerate(list(self.workers)):
+                if worker.is_available():
+                    self.workers.pop(i)
+                    self.workers.append(worker)
+                    return worker.submit(*args)
+
+            # Gather all running jobs
+            running_jobs = [job for w in self.workers for job in w.running_jobs]
+
+            # If there are no running jobs, then something must be wrong.
+            # Maybe a worker has crashed or gotten into an invalid state.
+            if not running_jobs:
+                raise RuntimeError("invalid state: no Ray workers are available to run job")
+
+            # Wait until any running job completes
+            ray.wait(running_jobs, num_returns=1)
+
+    def run(self, parameter_space, tuning_options):
+        running_jobs = dict()
+        completed_jobs = dict()
+
+        # Submit jobs which are not in the cache
+        for config in parameter_space:
+            params = dict(zip(tuning_options.tune_params.keys(), config))
+            key = ",".join([str(i) for i in config])
+
+            if key in tuning_options.cache:
+                completed_jobs[key] = tuning_options.cache[key]
+            else:
+                assert key not in running_jobs
+                running_jobs[key] = self.submit_job(params)
+                completed_jobs[key] = None
+
+        # Wait for the running jobs to finish
+        for key, job in running_jobs.items():
+            result = ray.get(job)
+            completed_jobs[key] = result
+
+            # print configuration to the console
+            print_config_output(tuning_options.tune_params, result, self.quiet, tuning_options.metrics, self.units)
+
+            # add configuration to cache
+            store_cache(key, result, tuning_options.cachefile, tuning_options.cache)
+
+        return list(completed_jobs.values())
diff --git a/kernel_tuner/runners/runner.py b/kernel_tuner/runners/runner.py
index 80ab3214..3a886ad1 100644
--- a/kernel_tuner/runners/runner.py
+++ b/kernel_tuner/runners/runner.py
@@ -13,8 +13,15 @@ def __init__(
     ):
         pass
 
+    def shutdown(self):
+        pass
+
+    @abstractmethod
+    def get_device_info(self):
+        pass
+
     @abstractmethod
-    def get_environment(self):
+    def get_environment(self, tuning_options):
         pass
 
     @abstractmethod
diff --git a/kernel_tuner/runners/sequential.py b/kernel_tuner/runners/sequential.py
index 5e53093b..2bd554bf 100644
--- a/kernel_tuner/runners/sequential.py
+++ b/kernel_tuner/runners/sequential.py
@@ -20,15 +20,13 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob
         :param kernel_options: A dictionary with all options for the kernel.
         :type kernel_options: kernel_tuner.interface.Options
 
-        :param device_options: A dictionary with all options for the device
-            on which the kernel should be tuned.
+        :param device_options: A dictionary with all options for the device on which the kernel should be tuned.
:type device_options: kernel_tuner.interface.Options - :param iterations: The number of iterations used for benchmarking - each kernel instance. + :param iterations: The number of iterations used for benchmarking each kernel instance. :type iterations: int """ - #detect language and create high-level device interface + # detect language and create high-level device interface self.dev = DeviceInterface(kernel_source, iterations=iterations, observers=observers, **device_options) self.units = self.dev.units @@ -41,9 +39,12 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob self.last_strategy_time = 0 self.kernel_options = kernel_options - #move data to the GPU + # move data to the GPU self.gpu_args = self.dev.ready_argument_list(kernel_options.arguments) + def get_device_info(self): + return self.dev + def get_environment(self, tuning_options): return self.dev.get_environment() @@ -53,16 +54,14 @@ def run(self, parameter_space, tuning_options): :param parameter_space: The parameter space as an iterable. :type parameter_space: iterable - :param tuning_options: A dictionary with all options regarding the tuning - process. - :type tuning_options: kernel_tuner.iterface.Options + :param tuning_options: A dictionary with all options regarding the tuning process. + :type tuning_options: kernel_tuner.interface.Options - :returns: A list of dictionaries for executed kernel configurations and their - execution times. - :rtype: dict()) + :returns: A list of dictionaries for executed kernel configurations and their execution times. + :rtype: dict() """ - logging.debug('sequential runner started for ' + self.kernel_options.kernel_name) + logging.debug("sequential runner started for " + self.kernel_options.kernel_name) results = [] @@ -77,33 +76,46 @@ def run(self, parameter_space, tuning_options): x_int = ",".join([str(i) for i in element]) if tuning_options.cache and x_int in tuning_options.cache: params.update(tuning_options.cache[x_int]) - params['compile_time'] = 0 - params['verification_time'] = 0 - params['benchmark_time'] = 0 + params["compile_time"] = 0 + params["verification_time"] = 0 + params["benchmark_time"] = 0 else: # attempt to warmup the GPU by running the first config in the parameter space and ignoring the result if not self.warmed_up: warmup_time = perf_counter() - self.dev.compile_and_benchmark(self.kernel_source, self.gpu_args, params, self.kernel_options, tuning_options) + self.dev.compile_and_benchmark( + self.kernel_source, self.gpu_args, params, self.kernel_options, tuning_options + ) self.warmed_up = True warmup_time = 1e3 * (perf_counter() - warmup_time) - result = self.dev.compile_and_benchmark(self.kernel_source, self.gpu_args, params, self.kernel_options, tuning_options) + result = self.dev.compile_and_benchmark( + self.kernel_source, self.gpu_args, params, self.kernel_options, tuning_options + ) params.update(result) if tuning_options.objective in result and isinstance(result[tuning_options.objective], ErrorConfig): - logging.debug('kernel configuration was skipped silently due to compile or runtime failure') + logging.debug("kernel configuration was skipped silently due to compile or runtime failure") # only compute metrics on configs that have not errored if tuning_options.metrics and not isinstance(params.get(tuning_options.objective), ErrorConfig): params = process_metrics(params, tuning_options.metrics) # get the framework time by estimating based on other times - total_time = 1000 * ((perf_counter() - self.start_time) - warmup_time) - 
params['strategy_time'] = self.last_strategy_time - params['framework_time'] = max(total_time - (params['compile_time'] + params['verification_time'] + params['benchmark_time'] + params['strategy_time']), 0) - params['timestamp'] = str(datetime.now(timezone.utc)) + total_time = 1000 * ((perf_counter() - self.start_time) - warmup_time) + params["strategy_time"] = self.last_strategy_time + params["framework_time"] = max( + total_time + - ( + params["compile_time"] + + params["verification_time"] + + params["benchmark_time"] + + params["strategy_time"] + ), + 0, + ) + params["timestamp"] = str(datetime.now(timezone.utc)) self.start_time = perf_counter() if result: @@ -111,7 +123,7 @@ def run(self, parameter_space, tuning_options): print_config_output(tuning_options.tune_params, params, self.quiet, tuning_options.metrics, self.units) # add configuration to cache - store_cache(x_int, params, tuning_options) + store_cache(x_int, params, tuning_options.cachefile, tuning_options.cache) # all visited configurations are added to results to provide a trace for optimization strategies results.append(params) diff --git a/kernel_tuner/runners/simulation.py b/kernel_tuner/runners/simulation.py index 9695879d..b369b85a 100644 --- a/kernel_tuner/runners/simulation.py +++ b/kernel_tuner/runners/simulation.py @@ -16,11 +16,11 @@ class SimulationDevice(_SimulationDevice): @property def name(self): - return self.env['device_name'] + return self.env["device_name"] @name.setter def name(self, value): - self.env['device_name'] = value + self.env["device_name"] = value if not self.quiet: print("Simulating: " + value) @@ -40,12 +40,10 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob :param kernel_options: A dictionary with all options for the kernel. :type kernel_options: kernel_tuner.interface.Options - :param device_options: A dictionary with all options for the device - on which the kernel should be tuned. + :param device_options: A dictionary with all options for the device on which the kernel should be tuned. :type device_options: kernel_tuner.interface.Options - :param iterations: The number of iterations used for benchmarking - each kernel instance. + :param iterations: The number of iterations used for benchmarking each kernel instance. :type iterations: int """ self.quiet = device_options.quiet @@ -60,6 +58,9 @@ def __init__(self, kernel_source, kernel_options, device_options, iterations, ob self.last_strategy_time = 0 self.units = {} + def get_device_info(self): + return self.dev + def get_environment(self, tuning_options): env = self.dev.get_environment() env["simulation"] = True @@ -72,21 +73,18 @@ def run(self, parameter_space, tuning_options): :param parameter_space: The parameter space as an iterable. :type parameter_space: iterable - :param tuning_options: A dictionary with all options regarding the tuning - process. + :param tuning_options: A dictionary with all options regarding the tuning process. :type tuning_options: kernel_tuner.iterface.Options - :returns: A list of dictionaries for executed kernel configurations and their - execution times. + :returns: A list of dictionaries for executed kernel configurations and their execution times. 
:rtype: dict() """ - logging.debug('simulation runner started for ' + self.kernel_options.kernel_name) + logging.debug("simulation runner started for " + self.kernel_options.kernel_name) results = [] - # iterate over parameter space + # iterate over parameter space for element in parameter_space: - # check if element is in the cache x_int = ",".join([str(i) for i in element]) if tuning_options.cache and x_int in tuning_options.cache: @@ -105,21 +103,22 @@ def run(self, parameter_space, tuning_options): # configuration is already counted towards the unique_results. # It is the responsibility of cost_func to add configs to unique_results. if x_int in tuning_options.unique_results: - - result['compile_time'] = 0 - result['verification_time'] = 0 - result['benchmark_time'] = 0 + result["compile_time"] = 0 + result["verification_time"] = 0 + result["benchmark_time"] = 0 else: # configuration is evaluated for the first time, print to the console - util.print_config_output(tuning_options.tune_params, result, self.quiet, tuning_options.metrics, self.units) + util.print_config_output( + tuning_options.tune_params, result, self.quiet, tuning_options.metrics, self.units + ) # Everything but the strategy time and framework time are simulated, # self.last_strategy_time is set by cost_func - result['strategy_time'] = self.last_strategy_time + result["strategy_time"] = self.last_strategy_time try: - simulated_time = result['compile_time'] + result['verification_time'] + result['benchmark_time'] + simulated_time = result["compile_time"] + result["verification_time"] + result["benchmark_time"] tuning_options.simulated_time += simulated_time except KeyError: if "time_limit" in tuning_options: @@ -129,7 +128,7 @@ def run(self, parameter_space, tuning_options): total_time = 1000 * (perf_counter() - self.start_time) self.start_time = perf_counter() - result['framework_time'] = total_time - self.last_strategy_time + result["framework_time"] = total_time - self.last_strategy_time results.append(result) continue diff --git a/kernel_tuner/strategies/bayes_opt.py b/kernel_tuner/strategies/bayes_opt.py index a814e7ce..31a68ca8 100644 --- a/kernel_tuner/strategies/bayes_opt.py +++ b/kernel_tuner/strategies/bayes_opt.py @@ -455,6 +455,33 @@ def fit_observations_to_model(self): """Update the model based on the current list of observations.""" self.__model.fit(self.__valid_params, self.__valid_observations) + def evaluate_parallel_objective_function(self, param_configs: list[tuple]) -> list[float]: + """Evaluates the objective function for multiple configurations in parallel.""" + results = [] + valid_param_configs = [] + valid_indices = [] + + # Extract the valid configurations + for param_config in param_configs: + param_config = self.unprune_param_config(param_config) + denormalized_param_config = self.denormalize_param_config(param_config) + if not self.__searchspace_obj.is_param_config_valid(denormalized_param_config): + results.append(self.invalid_value) + else: + valid_indices.append(len(results)) + results.append(None) + valid_param_configs.append(param_config) + + # Run valid configurations in parallel + scores = self.cost_func.run_all(valid_param_configs) + + # Put the scores at the right location in the result + for idx, score in zip(valid_indices, scores): + results[idx] = score + + self.fevals += len(scores) + return results + def evaluate_objective_function(self, param_config: tuple) -> float: """Evaluates the objective function.""" param_config = self.unprune_param_config(param_config) diff --git 
a/kernel_tuner/strategies/common.py b/kernel_tuner/strategies/common.py index 9ffe999b..e567d998 100644 --- a/kernel_tuner/strategies/common.py +++ b/kernel_tuner/strategies/common.py @@ -103,6 +103,12 @@ def __init__( def __call__(self, x, check_restrictions=True): + return self.run_one(x, check_restrictions=check_restrictions) + + def run_all(self, xs, check_restrictions=True): + return [self.run_one(x, check_restrictions=check_restrictions) for x in xs] + + def run_one(self, x, check_restrictions=True): """Cost function used by almost all strategies.""" self.runner.last_strategy_time = 1000 * (perf_counter() - self.runner.last_strategy_start_time) diff --git a/kernel_tuner/strategies/diff_evo.py b/kernel_tuner/strategies/diff_evo.py index d80b6e8e..6350b7d9 100644 --- a/kernel_tuner/strategies/diff_evo.py +++ b/kernel_tuner/strategies/diff_evo.py @@ -140,7 +140,7 @@ def differential_evolution(searchspace, cost_func, bounds, popsize, maxiter, F, population[0] = cost_func.get_start_pos() # Calculate the initial cost for each individual in the population - population_cost = np.array([cost_func(ind) for ind in population]) + population_cost = np.array(cost_func.run_all(population)) # Keep track of the best solution found so far best_idx = np.argmin(population_cost) @@ -209,7 +209,7 @@ def differential_evolution(searchspace, cost_func, bounds, popsize, maxiter, F, # --- c. Selection --- # Calculate the cost of the new trial vectors - trial_population_cost = np.array([cost_func(ind) for ind in trial_population]) + trial_population_cost = np.array(cost_func.run_all(trial_population)) # Keep track of whether population changes over time no_change = True diff --git a/kernel_tuner/strategies/firefly_algorithm.py b/kernel_tuner/strategies/firefly_algorithm.py index a732d404..861c5f86 100644 --- a/kernel_tuner/strategies/firefly_algorithm.py +++ b/kernel_tuner/strategies/firefly_algorithm.py @@ -44,13 +44,14 @@ def tune(searchspace: Searchspace, runner, tuning_options): swarm[0].position = x0 # compute initial intensities - for j in range(num_particles): - try: + try: + for j in range(num_particles): swarm[j].compute_intensity(cost_func) - except StopCriterionReached as e: - if tuning_options.verbose: - print(e) - return cost_func.results + except StopCriterionReached as e: + if tuning_options.verbose: + print(e) + return cost_func.results + for j in range(num_particles): if swarm[j].score <= best_score_global: best_position_global = swarm[j].position best_score_global = swarm[j].score diff --git a/kernel_tuner/strategies/genetic_algorithm.py b/kernel_tuner/strategies/genetic_algorithm.py index 804758ee..230cfd49 100644 --- a/kernel_tuner/strategies/genetic_algorithm.py +++ b/kernel_tuner/strategies/genetic_algorithm.py @@ -43,19 +43,17 @@ def tune(searchspace: Searchspace, runner, tuning_options): # determine fitness of population members weighted_population = [] - for dna in population: - try: - # if we are not constraint-aware we should check restrictions upon evaluation - time = cost_func(dna, check_restrictions=not constraint_aware) - num_evaluated += 1 - except StopCriterionReached as e: - if tuning_options.verbose: - print(e) - return cost_func.results - - weighted_population.append((dna, time)) + try: + # if we are not constraint-aware we should check restrictions upon evaluation + times = cost_func.run_all(population, check_restrictions=not constraint_aware) + num_evaluated += len(population) + except StopCriterionReached as e: + if tuning_options.verbose: + print(e) + return 
cost_func.results
 
     # population is sorted such that better configs have higher chance of reproducing
+    weighted_population = list(zip(population, times))
     weighted_population.sort(key=lambda x: x[1])
 
     # 'best_score' is used only for printing
diff --git a/kernel_tuner/strategies/hillclimbers.py b/kernel_tuner/strategies/hillclimbers.py
index ccd4eebf..cc53d7db 100644
--- a/kernel_tuner/strategies/hillclimbers.py
+++ b/kernel_tuner/strategies/hillclimbers.py
@@ -72,33 +72,39 @@ def base_hillclimb(base_sol: tuple, neighbor_method: str, max_fevals: int, searc
         if randomize:
             random.shuffle(indices)
 
+        children = []
+
         # in each dimension see the possible values
         for index in indices:
             neighbors = searchspace.get_param_neighbors(tuple(child), index, neighbor_method, randomize)
 
            # for each value in this dimension build a copy that differs from child in this dimension only
             for val in neighbors:
-                orig_val = child[index]
-                child[index] = val
+                neighbor = list(child)
+                neighbor[index] = val
+                children.append(neighbor)
 
+        if restart:
+            for child in children:
                 # get score for this position
                 score = cost_func(child, check_restrictions=False)
 
-                # generalize this to other tuning objectives
                 if score < best_score:
                     best_score = score
                     base_sol = child[:]
                     found_improved = True
-                    if restart:
-                        break
-                else:
-                    child[index] = orig_val
+                    break
+        else:
+            # get score for all positions in parallel
+            scores = cost_func.run_all(children, check_restrictions=False)
 
-            fevals = len(tuning_options.unique_results)
-            if fevals >= max_fevals:
-                return base_sol
+            for child, score in zip(children, scores):
+                if score < best_score:
+                    best_score = score
+                    base_sol = child[:]
+                    found_improved = True
 
-        if found_improved and restart:
-            break
+        if found_improved and restart:
+            break
 
     return base_sol
diff --git a/kernel_tuner/strategies/pso.py b/kernel_tuner/strategies/pso.py
index e8489d12..4e38aa31 100644
--- a/kernel_tuner/strategies/pso.py
+++ b/kernel_tuner/strategies/pso.py
@@ -51,24 +51,26 @@ def tune(searchspace: Searchspace, runner, tuning_options):
         if tuning_options.verbose:
             print("start iteration ", i, "best time global", best_score_global)
 
+        try:
+            scores = cost_func.run_all([p.position for p in swarm])
+        except StopCriterionReached as e:
+            if tuning_options.verbose:
+                print(e)
+            return cost_func.results
+
         # evaluate particle positions
-        for j in range(num_particles):
-            try:
-                swarm[j].evaluate(cost_func)
-            except StopCriterionReached as e:
-                if tuning_options.verbose:
-                    print(e)
-                return cost_func.results
+        for p, score in zip(swarm, scores):
+            p.set_score(score)
 
             # update global best if needed
-            if swarm[j].score <= best_score_global:
-                best_position_global = swarm[j].position
-                best_score_global = swarm[j].score
+            if score <= best_score_global:
+                best_position_global = p.position
+                best_score_global = score
 
         # update particle velocities and positions
-        for j in range(0, num_particles):
-            swarm[j].update_velocity(best_position_global, w, c1, c2)
-            swarm[j].update_position(bounds)
+        for p in swarm:
+            p.update_velocity(best_position_global, w, c1, c2)
+            p.update_position(bounds)
 
     if tuning_options.verbose:
         print("Final result:")
@@ -92,7 +94,10 @@ def __init__(self, bounds):
         self.score = sys.float_info.max
 
     def evaluate(self, cost_func):
-        self.score = cost_func(self.position)
+        self.set_score(cost_func(self.position))
+
+    def set_score(self, score):
+        self.score = score
         # update best_pos if needed
         if self.score < self.best_score:
             self.best_pos = self.position
diff --git a/kernel_tuner/strategies/random_sample.py b/kernel_tuner/strategies/random_sample.py
index 33b5075d..4efe8615 100644
--- a/kernel_tuner/strategies/random_sample.py
+++ b/kernel_tuner/strategies/random_sample.py
@@ -20,16 +20,13 @@ def tune(searchspace: Searchspace, runner, tuning_options):
     num_samples = min(tuning_options.max_fevals, searchspace.size)
 
     samples = searchspace.get_random_sample(num_samples)
 
     cost_func = CostFunc(searchspace, tuning_options, runner)
 
-    for sample in samples:
-        try:
-            cost_func(sample, check_restrictions=False)
-        except StopCriterionReached as e:
-            if tuning_options.verbose:
-                print(e)
-            return cost_func.results
+    try:
+        cost_func.run_all(samples, check_restrictions=False)
+    except StopCriterionReached as e:
+        if tuning_options.verbose:
+            print(e)
 
     return cost_func.results
diff --git a/kernel_tuner/util.py b/kernel_tuner/util.py
index 2d9e3f1b..635c6de7 100644
--- a/kernel_tuner/util.py
+++ b/kernel_tuner/util.py
@@ -1152,7 +1152,7 @@ def check_matching_problem_size(cached_problem_size, problem_size):
     if cached_problem_size_arr.size != problem_size_arr.size or not (cached_problem_size_arr == problem_size_arr).all():
         raise ValueError(f"Cannot load cache which contains results for different problem_size, cache: {cached_problem_size}, requested: {problem_size}")
 
-def process_cache(cache, kernel_options, tuning_options, runner):
+def process_cache(cachefile, kernel_options, tuning_options, runner):
     """Cache file for storing tuned configurations.
 
     the cache file is stored using JSON and uses the following format:
@@ -1181,9 +1181,9 @@
         raise ValueError("Caching only works correctly when tunable parameters are stored in a dictionary")
 
     # if file does not exist, create new cache
-    if not os.path.isfile(cache):
+    if not os.path.isfile(cachefile):
         if tuning_options.simulation_mode:
-            raise ValueError(f"Simulation mode requires an existing cachefile: file {cache} does not exist")
+            raise ValueError(f"Simulation mode requires an existing cachefile: file {cachefile} does not exist")
 
         c = dict()
         c["device_name"] = runner.dev.name
@@ -1197,15 +1197,14 @@
         contents = json.dumps(c, cls=NpEncoder, indent="")[:-3]  # except the last "}\n}"
 
         # write the header to the cachefile
-        with open(cache, "w") as cachefile:
-            cachefile.write(contents)
+        with open(cachefile, "w") as f:
+            f.write(contents)
 
-        tuning_options.cachefile = cache
-        tuning_options.cache = {}
+        return {}
 
     # if file exists
     else:
-        cached_data = read_cache(cache, open_cache=not tuning_options.simulation_mode)
+        cached_data = read_cache(cachefile, open_cache=not tuning_options.simulation_mode)
 
         # if in simulation mode, use the device name from the cache file as the runner device name
         if runner.simulation_mode:
@@ -1231,17 +1230,16 @@
             )
             raise ValueError(
                 f"Cannot load cache which contains results obtained with different tunable parameters. 
\ - Cache at '{cache}' has: {cached_data['tune_params_keys']}, tuning_options has: {list(tuning_options.tune_params.keys())}" + Cache at '{cachefile}' has: {cached_data['tune_params_keys']}, tuning_options has: {list(tuning_options.tune_params.keys())}" ) - tuning_options.cachefile = cache - tuning_options.cache = cached_data["cache"] + return cached_data["cache"] -def correct_open_cache(cache, open_cache=True): +def correct_open_cache(cachefile, open_cache=True): """If cache file was not properly closed, pretend it was properly closed.""" - with open(cache, "r") as cachefile: - filestr = cachefile.read().strip() + with open(cachefile, "r") as f: + filestr = f.read().strip() # if file was not properly closed, pretend it was properly closed if len(filestr) > 0 and filestr[-3:] not in ["}\n}", "}}}"]: @@ -1253,15 +1251,15 @@ def correct_open_cache(cache, open_cache=True): else: if open_cache: # if it was properly closed, open it for appending new entries - with open(cache, "w") as cachefile: - cachefile.write(filestr[:-3] + ",") + with open(cachefile, "w") as f: + f.write(filestr[:-3] + ",") return filestr -def read_cache(cache, open_cache=True): +def read_cache(cachefile, open_cache=True): """Read the cachefile into a dictionary, if open_cache=True prepare the cachefile for appending.""" - filestr = correct_open_cache(cache, open_cache) + filestr = correct_open_cache(cachefile, open_cache) error_configs = { "InvalidConfig": InvalidConfig(), @@ -1279,25 +1277,25 @@ def read_cache(cache, open_cache=True): return cache_data -def close_cache(cache): - if not os.path.isfile(cache): +def close_cache(cachefile): + if not os.path.isfile(cachefile): raise ValueError("close_cache expects cache file to exist") - with open(cache, "r") as fh: + with open(cachefile, "r") as fh: contents = fh.read() # close to file to make sure it can be read by JSON parsers if contents[-1] == ",": - with open(cache, "w") as fh: + with open(cachefile, "w") as fh: fh.write(contents[:-1] + "}\n}") -def store_cache(key, params, tuning_options): +def store_cache(key, params, cachefile, cache): """Stores a new entry (key, params) to the cachefile.""" # logging.debug('store_cache called, cache=%s, cachefile=%s' % (tuning_options.cache, tuning_options.cachefile)) - if isinstance(tuning_options.cache, dict): - if key not in tuning_options.cache: - tuning_options.cache[key] = params + if isinstance(cache, dict): + if key not in cache: + cache[key] = params # Convert ErrorConfig objects to string, wanted to do this inside the JSONconverter but couldn't get it to work output_params = params.copy() @@ -1305,9 +1303,9 @@ def store_cache(key, params, tuning_options): if isinstance(v, ErrorConfig): output_params[k] = str(v) - if tuning_options.cachefile: - with open(tuning_options.cachefile, "a") as cachefile: - cachefile.write("\n" + json.dumps({key: output_params}, cls=NpEncoder)[1:-1] + ",") + if cachefile: + with open(cachefile, "a") as f: + f.write("\n" + json.dumps({key: output_params}, cls=NpEncoder)[1:-1] + ",") def dump_cache(obj: str, tuning_options): diff --git a/test/test_util_functions.py b/test/test_util_functions.py index 4a1858f3..56a5a761 100644 --- a/test/test_util_functions.py +++ b/test/test_util_functions.py @@ -621,25 +621,25 @@ def assert_open_cachefile_is_correctly_parsed(cache): try: # call process_cache without pre-existing cache - process_cache(cache, kernel_options, tuning_options, runner) + tuning_options.cachefile = cache + tuning_options.cache = process_cache(cache, kernel_options, tuning_options, runner) 
# check if file has been created assert os.path.isfile(cache) assert_open_cachefile_is_correctly_parsed(cache) - assert tuning_options.cachefile == cache assert isinstance(tuning_options.cache, dict) assert len(tuning_options.cache) == 0 # store one entry in the cache params = {"x": 4, "time": np.float32(0.1234)} - store_cache("4", params, tuning_options) + store_cache("4", params, cache, tuning_options.cache) assert len(tuning_options.cache) == 1 # close the cache close_cache(cache) # now test process cache with a pre-existing cache file - process_cache(cache, kernel_options, tuning_options, runner) + tuning_options.cache = process_cache(cache, kernel_options, tuning_options, runner) assert_open_cachefile_is_correctly_parsed(cache) assert tuning_options.cache["4"]["time"] == params["time"] @@ -648,7 +648,7 @@ def assert_open_cachefile_is_correctly_parsed(cache): # a different kernel, device, or parameter set with pytest.raises(ValueError) as excep: kernel_options.kernel_name = "wrong_kernel" - process_cache(cache, kernel_options, tuning_options, runner) + tuning_options.cache = process_cache(cache, kernel_options, tuning_options, runner) assert "kernel" in str(excep.value) # correct the kernel name from last test @@ -656,7 +656,7 @@ def assert_open_cachefile_is_correctly_parsed(cache): with pytest.raises(ValueError) as excep: runner.dev.name = "wrong_device" - process_cache(cache, kernel_options, tuning_options, runner) + tuning_options.cache = process_cache(cache, kernel_options, tuning_options, runner) assert "device" in str(excep.value) # correct the device from last test @@ -664,7 +664,7 @@ def assert_open_cachefile_is_correctly_parsed(cache): with pytest.raises(ValueError) as excep: tuning_options.tune_params["y"] = ["a", "b"] - process_cache(cache, kernel_options, tuning_options, runner) + tuning_options.cache = process_cache(cache, kernel_options, tuning_options, runner) assert "parameter" in str(excep.value) finally:
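The strategy changes in this patch (genetic algorithm, PSO, firefly, hill climbers, random sampling, Bayesian optimization) all route batches of candidate configurations through the new `CostFunc.run_all` hook rather than calling the cost function once per configuration. A minimal sketch of a custom strategy built on that hook is shown below; the batch size and the overall `tune` structure are illustrative assumptions, and the default `run_all` added in `common.py` simply evaluates the batch one configuration at a time.

```python
# Hypothetical sketch, not part of this patch: a custom strategy using CostFunc.run_all.
from kernel_tuner.searchspace import Searchspace
from kernel_tuner.strategies.common import CostFunc
from kernel_tuner.util import StopCriterionReached


def tune(searchspace: Searchspace, runner, tuning_options):
    cost_func = CostFunc(searchspace, tuning_options, runner)

    # draw a small batch of candidate configurations from the search space
    batch = searchspace.get_random_sample(min(32, searchspace.size))

    try:
        # run_all evaluates the whole batch with the same bookkeeping as a
        # single cost_func(x) call; the default implementation loops over run_one
        scores = cost_func.run_all(batch, check_restrictions=False)
    except StopCriterionReached as e:
        if tuning_options.verbose:
            print(e)
        return cost_func.results

    # report the best configuration seen in this batch
    best_index = scores.index(min(scores))
    if tuning_options.verbose:
        print("best configuration in batch:", batch[best_index], scores[best_index])

    return cost_func.results
```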