diff --git a/endaq/calc/utils.py b/endaq/calc/utils.py index 6bde969..2474938 100644 --- a/endaq/calc/utils.py +++ b/endaq/calc/utils.py @@ -2,8 +2,9 @@ from __future__ import annotations +import bisect import typing -from typing import Optional, Union, Literal +from typing import Optional, Union, Literal, List import warnings import numpy as np @@ -128,32 +129,45 @@ def to_dB( "audio_intensity": 1e-12, # W/m² } - -def resample(df: pd.DataFrame, sample_rate: Optional[float] = None) -> pd.DataFrame: +def resample( + df: pd.DataFrame, + sample_rate: Optional[float] = None, + num_samples: Optional[int] = None + ) -> pd.DataFrame: """ - Resample a dataframe to a desired sample rate (in Hz) - + Resample a dataframe to a desired sample rate (in Hz) or a desired number of points. + Note that ``sample_rate`` and ``num_samples`` are mutually exclusive. If + neither sample_rate nor num_samples is supplied, it will use the same sample_rate + as it currently does, but makes the time stamps uniformly spaced. + :param df: The DataFrame to resample, indexed by time :param sample_rate: The desired sample rate to resample the given data to. - If one is not supplied, then it will use the same as it currently does, but - make the time stamps uniformly spaced + :param num_samples: The desired number of samples to resample the given data to.
def align_dataframes(dfs: List[pd.DataFrame]) -> List[pd.DataFrame]:
    """
    Resample the given dataframes to all be equal-sized with uniform, shared
    timestamps.

    Any samples outside of the time range shared by every dataframe are
    dropped.  If a dataframe has no sample exactly at the shared start/end
    instant, a value is linearly interpolated there from the two samples
    bracketing that instant.

    :param dfs: a List of dataframes with DatetimeIndex to align.
    :return: a list of dataframes in the same order that they were inputted in.
    :raises ValueError: if the dataframes share no overlapping range of time.
    """
    aligned_start = max(df.index[0] for df in dfs)
    aligned_end = min(df.index[-1] for df in dfs)

    if aligned_start >= aligned_end:
        raise ValueError("No range of time shared between dataframes")

    # Per dataframe: position of the last sample at/before the shared start,
    # and of the first sample at/after the shared end.
    left_idx = [bisect.bisect_right(df.index, aligned_start) - 1 for df in dfs]
    right_idx = [bisect.bisect_left(df.index, aligned_end) for df in dfs]

    # Keep only the samples strictly inside (aligned_start, aligned_end); the
    # boundary values are re-added (or interpolated) below.  The slice end must
    # be `right_idx[i]` (exclusive) -- using `right_idx[i] - 1` would silently
    # drop the last in-range sample.
    trimmed_dfs = [dfs[i][left_idx[i] + 1: right_idx[i]] for i in range(len(dfs))]

    for i, (df, l_idx) in enumerate(zip(dfs, left_idx)):
        if df.index[l_idx] != aligned_start:
            # The shared start falls strictly between samples l_idx and
            # l_idx + 1: linearly interpolate a value on that segment.
            # (Interpolating from l_idx - 1 would use the wrong segment and,
            # when l_idx == 0, would wrap around to the end of the dataframe.)
            dt = (df.index[l_idx + 1] - df.index[l_idx]).total_seconds()
            slope = (df.iloc[l_idx + 1] - df.iloc[l_idx]) / dt
            new_dt = (aligned_start - df.index[l_idx]).total_seconds()
            new_point = df.iloc[l_idx] + new_dt * slope
            trimmed_dfs[i] = pd.concat([
                pd.DataFrame([new_point], index=[aligned_start]),
                trimmed_dfs[i],
            ])
        else:
            # This dataframe already has a sample exactly at the shared start.
            trimmed_dfs[i] = pd.concat([df.loc[[aligned_start]], trimmed_dfs[i]])

    # Mirror of the loop above, for the shared end point.  Here the segment
    # [r_idx - 1, r_idx] brackets aligned_end, so interpolating from
    # iloc[r_idx - 1] with the slope of that segment is correct.
    for i, (df, r_idx) in enumerate(zip(dfs, right_idx)):
        if df.index[r_idx] != aligned_end:
            dt = (df.index[r_idx] - df.index[r_idx - 1]).total_seconds()
            slope = (df.iloc[r_idx] - df.iloc[r_idx - 1]) / dt
            new_dt = (aligned_end - df.index[r_idx - 1]).total_seconds()
            new_point = df.iloc[r_idx - 1] + new_dt * slope
            trimmed_dfs[i] = pd.concat([
                trimmed_dfs[i],
                pd.DataFrame([new_point], index=[aligned_end]),
            ])
        else:
            trimmed_dfs[i] = pd.concat([trimmed_dfs[i], df.loc[[aligned_end]]])

    # Resample everything to the size of the densest trimmed dataframe.
    total_samples = max(tdf.shape[0] for tdf in trimmed_dfs)
    resampled_dfs = [resample(tdf, num_samples=total_samples) for tdf in trimmed_dfs]

    # In the current implementation of scipy's resample, there can be some
    # inconsistent rounding when the datetimes are regenerated.  Find one index
    # that meets spec (correct start and end points) and share it across all
    # of the resampled dataframes.
    datepoints = None
    for rdf in resampled_dfs:
        if rdf.index[0] == aligned_start and rdf.index[-1] == aligned_end:
            datepoints = rdf.index
            break
    if datepoints is None:
        raise RuntimeError("resampling error, timestamps inconsistent with inputs")
    for rdf in resampled_dfs:
        rdf.index = datepoints
    return resampled_dfs


def compute_orientation(doc):
    """
    Use a Madgwick filter to compute the absolute orientation of the enDAQ
    device from its primary accelerometer, gyroscope and, optionally, a
    magnetometer.

    Note that this method relies on the use of :py:func:`~endaq.calc.resample()`,
    which can create artifacts at the starting / ending points of the dataset.
    If important data is contained within those segments, it is not recommended
    to use this method.

    :param doc: An open `Dataset` object, see :py:func:`~endaq.ide.get_doc()`
        for more.
    :return: A pandas DataFrame containing orientation data with
        column names ['X', 'Y', 'Z', 'W'] (scalar-last order quaternion).
    :raises ValueError: if an acceleration or rotation channel is not present
    """
    acc = get_primary_sensor_data(doc=doc, measurement_type=ACCELERATION)
    rot = get_primary_sensor_data(doc=doc, measurement_type=ROTATION)
    # Convert the gyroscope data from degrees to radians (ahrs' Madgwick
    # filter expects rad/s).
    rot = np.deg2rad(rot)
    try:
        mag = get_primary_sensor_data(doc=doc, measurement_type=MAGNETIC)
        al_acc, al_rot, al_mag = align_dataframes([acc, rot, mag])
    except Exception:
        # No usable magnetometer channel: fall back to an accelerometer +
        # gyroscope (IMU-only) filter.  A bare `except:` here would also have
        # swallowed KeyboardInterrupt/SystemExit.
        al_acc, al_rot = align_dataframes([acc, rot])
        al_mag = None

    timestamps = al_acc.index
    # Uniform sample period produced by align_dataframes().
    dt = (timestamps[1] - timestamps[0]).total_seconds()
    data = ahrs.filters.Madgwick(
        gyr=al_rot.to_numpy(),
        acc=al_acc.to_numpy(),
        mag=al_mag.to_numpy() if al_mag is not None else None,
        Dt=dt,
    ).Q

    # ahrs returns scalar-first quaternions (W, X, Y, Z); reorder the columns
    # to the scalar-last convention promised by the docstring.
    return pd.DataFrame(
        data,
        index=timestamps,
        columns=['W', 'X', 'Y', 'Z'],
    ).iloc[:, [1, 2, 3, 0]]
TCO1_FILENAME = os.path.join(os.path.dirname(__file__), "tco_normal.IDE")
TCO2_FILENAME = os.path.join(os.path.dirname(__file__), "tco_no_acc.IDE")
TCO3_FILENAME = os.path.join(os.path.dirname(__file__), "tco_no_rot.IDE")
TCO4_FILENAME = os.path.join(os.path.dirname(__file__), "tco_no_mag.IDE")


@pytest.fixture
def tco_pass_IDE():
    # Orientation-computation recordings that **will** pass.  Opened as
    # context managers (matching the existing `test_IDE` fixture) so the
    # underlying files are closed again at teardown instead of leaking.
    with importFile(TCO1_FILENAME) as ds1, importFile(TCO4_FILENAME) as ds4:
        yield [ds1, ds4]


@pytest.fixture
def tco_fail_IDE():
    # Orientation-computation recordings that **will not** pass (missing a
    # required acceleration / rotation channel).
    with importFile(TCO2_FILENAME) as ds2, importFile(TCO3_FILENAME) as ds3:
        yield [ds2, ds3]


def test_compute_orientation(tco_pass_IDE, tco_fail_IDE):
    """
    compute_orientation() should raise if the needed channel information is
    not present; when it succeeds, the result must follow the scalar-last
    quaternion column convention.
    """
    for ide in tco_pass_IDE:
        # Plain list comparison: avoids relying on a numpy import in this
        # module (none is visible in its import block).
        assert list(compute_orientation(ide).columns) == ['X', 'Y', 'Z', 'W']

    for ide in tco_fail_IDE:
        # NOTE: no `assert False` inside the `raises` block -- its
        # AssertionError would itself satisfy pytest.raises(Exception), making
        # the test pass even when compute_orientation() does not raise.
        with pytest.raises(Exception):
            compute_orientation(ide)


if __name__ == '__main__':
    unittest.main()