diff --git a/src/post_processing/utils/core_utils.py b/src/post_processing/utils/core_utils.py index 11449e9..25ea173 100644 --- a/src/post_processing/utils/core_utils.py +++ b/src/post_processing/utils/core_utils.py @@ -379,9 +379,6 @@ def get_count( def get_labels_and_annotators(df: DataFrame) -> tuple[list, list]: """Extract and align annotation labels and annotators from an APLOSE DataFrame. - If only one label is present, it is duplicated to match the number of annotators. - Similarly, if one annotator is present, it is duplicated to match the labels. - Parameters ---------- df : DataFrame @@ -397,18 +394,8 @@ def get_labels_and_annotators(df: DataFrame) -> tuple[list, list]: msg = "`df` contains no data" raise ValueError(msg) - annotators = df["annotator"].unique().tolist() - labels = df["annotation"].unique().tolist() - if len(labels) == 1: - labels = [labels[0]] * len(annotators) - if len(annotators) == 1: - annotators = [annotators[0]] * len(labels) - - if len(annotators) != len(labels): - msg = f"{len(annotators)} annotators and {len(labels)} labels must match." - raise ValueError(msg) - - return labels, annotators + unique_pairs = df[["annotator", "annotation"]].drop_duplicates() + return unique_pairs["annotation"].to_list(), unique_pairs["annotator"].to_list() def localize_timestamps(timestamps: list[Timestamp], tz: tzinfo) -> list[Timestamp]: diff --git a/tests/test_core_utils.py b/tests/test_core_utils.py index 6cfa494..dd34fc1 100644 --- a/tests/test_core_utils.py +++ b/tests/test_core_utils.py @@ -252,15 +252,6 @@ def test_single_annotator_multiple_labels(sample_df: DataFrame) -> None: assert sorted(annotators) == ["ann1"] * 2 -def test_mismatched_labels_and_annotators() -> None: - df = DataFrame({ - "annotation": ["lbl1", "lbl2", "lbl3"], - "annotator": ["ann1", "ann2", "ann2"], - }) - with pytest.raises(ValueError, match="annotators and .* labels must match"): - get_labels_and_annotators(df) - - def test_get_labels_and_annotators_empty_dataframe() -> None: with pytest.raises(ValueError, match="`df` contains no data"): get_labels_and_annotators(DataFrame())