diff --git a/ledsa/core/image_reading.py b/ledsa/core/image_reading.py index 57b4585..f7cc8f9 100644 --- a/ledsa/core/image_reading.py +++ b/ledsa/core/image_reading.py @@ -26,6 +26,25 @@ def read_channel_data_from_img(filename: str, channel: int) -> np.ndarray: channel_array = _read_channel_data_from_raw_file(filename, channel) return channel_array +def read_img_array_from_img(filename: str, channel: int) -> np.ndarray: + """ + Returns a single 2D array of the whole image (grayscale for JPG/PNG, raw sensor data for CR2/CR3). + + + :param filename: The path of the image file to read. + :type filename: str + :param channel: The color channel used for calculating black level values. + :type channel: int + :return: A 2D array containing the image data: a grayscale conversion for JPG/PNG files, raw sensor data for CR2/CR3 files. + :rtype: np.ndarray + """ + extension = os.path.splitext(filename)[-1] + if extension in ['.JPG', '.JPEG', '.jpg', '.jpeg', '.PNG', '.png']: + img_array = _read_grayscale_img_array_from_img_file(filename) + elif extension in ['.CR2', '.CR3']: + img_array, _ = _read_img_array_from_raw_file(filename, channel) + return img_array + def _read_channel_data_from_img_file(filename: str, channel: int) -> np.ndarray: """ @@ -38,7 +57,7 @@ def _read_channel_data_from_img_file(filename: str, channel: int) -> np.ndarray: :return: A 2D numpy array containing the data of the specified color channel from the image. :rtype: np.ndarray """ - img_array = read_img_array_from_img_file(filename) + img_array = _read_img_array_from_img_file(filename) return img_array[:, :, channel] @@ -54,7 +73,7 @@ def _read_channel_data_from_raw_file(filename: str, channel: int) -> np.ndarray: :return: A 2D numpy array representing the extracted channel, with all other channel values masked or set to zero. 
:rtype: np.ndarray """ - img_array, filter_array = read_img_array_from_raw_file(filename, channel) + img_array, filter_array = _read_img_array_from_raw_file(filename, channel) if channel == 0 or channel == 2: channel_array = np.where(filter_array == channel, img_array, 0) elif channel == 1: @@ -62,7 +81,7 @@ def _read_channel_data_from_raw_file(filename: str, channel: int) -> np.ndarray: return channel_array -def read_img_array_from_raw_file(filename: str, channel: int) -> np.ndarray: +def _read_img_array_from_raw_file(filename: str, channel: int) -> np.ndarray: # TODO: channel is only relevant for black level, consider individually! with rawpy.imread(filename) as raw: data = raw.raw_image_visible.copy() @@ -74,10 +93,16 @@ def read_img_array_from_raw_file(filename: str, channel: int) -> np.ndarray: img_array = np.clip(img_array, 0, white_level) return img_array, filter_array -def read_img_array_from_img_file(filename: str) -> np.ndarray: +def _read_img_array_from_img_file(filename: str) -> np.ndarray: img_array = plt.imread(filename) return img_array +def _read_grayscale_img_array_from_img_file(filename: str) -> np.ndarray: + img_array = plt.imread(filename) + weights = np.array([0.2989, 0.5870, 0.1140]) + gray = np.dot(img_array[..., :3], weights).astype(np.uint8) + return gray + def get_exif_entry(filename: str, tag: str) -> str: """ diff --git a/ledsa/data_extraction/DataExtractor.py b/ledsa/data_extraction/DataExtractor.py index aacf9f4..47978b5 100644 --- a/ledsa/data_extraction/DataExtractor.py +++ b/ledsa/data_extraction/DataExtractor.py @@ -3,6 +3,7 @@ import os import matplotlib.pyplot as plt +from matplotlib.colors import LogNorm import numpy as np from tqdm import tqdm @@ -26,7 +27,9 @@ class DataExtractor: :vartype channels: Tuple :ivar fit_leds: Whether to fit LEDs or not. :vartype fit_leds: bool - :ivar search_areas: 2D numpy array with dimension (# of LEDs) x (LED_id, x, y). + :ivar threshold: The threshold value used for LED detection. 
+ :vartype threshold: float, optional + :ivar search_areas: 2D numpy array with dimension (# of LEDs) x (LED_id, x, y). :vartype search_areas: numpy.ndarray, optional :ivar line_indices: 2D list with dimension (# of LED arrays) x (# of LEDs per array) or None. :vartype line_indices: list[list[int]], optional @@ -46,6 +49,7 @@ def __init__(self, channels=(0), load_config_file=True, build_experiment_infos=T self.config = ConfigData(load_config_file=load_config_file) self.channels = list(channels) self.fit_leds = fit_leds + self.threshold = None # 2D numpy array with dimension (# of LEDs) x (LED_id, x, y) self.search_areas = None @@ -95,12 +99,13 @@ def find_search_areas(self) -> None: max_num_leds = int(config['max_num_leds']) pixel_value_percentile = float(config['pixel_value_percentile']) if channel == 'all': - data, _ = ledsa.core.image_reading.read_img_array_from_raw_file(in_file_path, channel=0) # TODO: Channel to be removed here! + # TODO this currently only works for RAW files but should work for JPG files as well + data = ledsa.core.image_reading.read_img_array_from_img(in_file_path, channel=0) # TODO: Channel to be removed here! 
 else: channel = int(channel) data = ledsa.core.image_reading.read_channel_data_from_img(in_file_path, channel=channel) - self.search_areas = ledsa.data_extraction.step_1_functions.find_search_areas(data, search_area_radius=search_area_radius, max_n_leds=max_num_leds, pixel_value_percentile=pixel_value_percentile) + self.search_areas, self.threshold = ledsa.data_extraction.step_1_functions.find_search_areas(data, search_area_radius=search_area_radius, max_n_leds=max_num_leds, pixel_value_percentile=pixel_value_percentile) self.write_search_areas() self.plot_search_areas() ledsa.core.file_handling.remove_flag('reorder_leds') @@ -122,12 +127,14 @@ def plot_search_areas(self, reorder_leds=False) -> None: self.load_search_areas() in_file_path = os.path.join(config['img_directory'], config['img_name_string'].format(int(config['ref_img_id']))) - data = ledsa.core.image_reading.read_channel_data_from_img(in_file_path, channel=0) + # TODO this currently only works for RAW files but should work for JPG files as well + data = ledsa.core.image_reading.read_img_array_from_img(in_file_path, channel=0) search_area_radius = int(config['search_area_radius']) plt.figure(dpi=1200) ax = plt.gca() ledsa.data_extraction.step_1_functions.add_search_areas_to_plot(self.search_areas, search_area_radius, ax) - plt.imshow(data, cmap='Greys') + plt.imshow(data, norm=LogNorm(vmin=self.threshold, vmax=data.max()), cmap='Greys') + plt.xlim(self.search_areas[:, 2].min() - 5 * search_area_radius, self.search_areas[:, 2].max() + 5 * search_area_radius) plt.ylim(self.search_areas[:, 1].max() + 5 * search_area_radius, self.search_areas[:, 1].min() - 5 * search_area_radius) plt.colorbar() diff --git a/ledsa/data_extraction/step_1_functions.py b/ledsa/data_extraction/step_1_functions.py index 8fb5715..7e6cb96 100644 --- a/ledsa/data_extraction/step_1_functions.py +++ b/ledsa/data_extraction/step_1_functions.py @@ -1,9 +1,13 @@ +from typing import Any + import numpy as np from matplotlib import pyplot as 
plt import cv2 +from numpy import ndarray, dtype, floating -def find_search_areas(image: np.ndarray, search_area_radius, pixel_value_percentile=99.875, max_n_leds=1300) -> np.ndarray: +def find_search_areas(image: np.ndarray, search_area_radius, pixel_value_percentile=99.875, max_n_leds=1300) -> tuple[ + ndarray[Any, dtype[Any]], floating[Any]]: """ Identifies and extracts locations of LEDs in an image. @@ -15,8 +19,10 @@ def find_search_areas(image: np.ndarray, search_area_radius, pixel_value_percent :type pixel_value_percentile: float :param max_n_leds: The maximum number of LED locations to identify in the image. :type max_n_leds: int - :return: A numpy array of identified LED locations, each represented as (LED ID, y-coordinate, x-coordinate). - :rtype: np.ndarray + :return: A tuple containing: + - A numpy array of identified LED locations, each represented as (LED ID, y-coordinate, x-coordinate). + - The threshold value used for LED detection. + :rtype: tuple[np.ndarray, float] """ (_, max_pixel_value, _, max_pixel_loc) = cv2.minMaxLoc(image) threshold = np.percentile(image, pixel_value_percentile) @@ -34,7 +40,7 @@ def find_search_areas(image: np.ndarray, search_area_radius, pixel_value_percent led_id += 1 print('\n') print(f"Found {led_id} LEDS") - return np.array(search_areas_list) + return np.array(search_areas_list), threshold def add_search_areas_to_plot(search_areas: np.ndarray, search_area_radius: int, ax: plt.axes) -> None: