Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 3 additions & 14 deletions api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
amp_editor_dependency,
amp_viewer_dependency,
)
from db import ThingContactAssociation, Thing, Contact, Email, Phone, Address
from db import Contact, Email, Phone, Address
from schemas.contact import (
CreateContact,
CreateAddress,
Expand All @@ -43,14 +43,11 @@
UpdateAddress,
)
from services.crud_helper import model_patcher, model_deleter, model_adder
from services.contact_helper import (
add_contact,
)
from services.contact_helper import add_contact, get_db_contacts
from services.lexicon_helper import get_terms_by_category
from services.query_helper import (
simple_get_by_id,
paginated_all_getter,
order_sort_filter,
)
from services.exceptions_helper import PydanticStyleException

Expand Down Expand Up @@ -484,15 +481,7 @@ async def get_contacts(
:param session:
:return:
"""
if thing_id:
sql = select(Contact)
sql = sql.join(ThingContactAssociation).join(Thing)
sql = sql.where(Thing.id == thing_id)

sql = order_sort_filter(sql, Contact, sort=sort, order=order, filter_=filter_)
return paginate(query=sql, conn=session)
else:
return paginated_all_getter(session, Contact, sort, order, filter_)
return get_db_contacts(session, thing_id, sort, order, filter_)


@router.get("/{contact_id}", summary="Get contact by ID")
Expand Down
3 changes: 2 additions & 1 deletion api/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,15 @@ async def update_sample(
async def get_samples(
session: session_dependency,
user: viewer_dependency,
thing_id: int | None = None,
sort: str = None,
order: str = None,
filter_: str = Query(alias="filter", default=None),
) -> CustomPage[SampleResponse]:
"""
Endpoint to retrieve samples.
"""
return get_db_samples(session, sort=sort, order=order, filter_=filter_)
return get_db_samples(session, thing_id, sort=sort, order=order, filter_=filter_)


@router.get("/{sample_id}", summary="Get Sample by ID")
Expand Down
4 changes: 2 additions & 2 deletions schemas/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ class SampleResponse(BaseResponseModel):
field_event: FieldEventResponse
field_activity: FieldActivityResponse
contact: ContactResponse
field_activity_id: int
field_event_contact_id: int
# field_activity_id: int
# field_event_contact_id: int
sample_date: Annotated[AwareDatetime, PastDatetime()]
sample_name: str
sample_matrix: str
Expand Down
30 changes: 29 additions & 1 deletion services/contact_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,36 @@
from schemas.contact import (
CreateContact,
)
from services.query_helper import order_sort_filter
from services.audit_helper import audit_add
from sqlalchemy.orm import Session

from fastapi_pagination.ext.sqlalchemy import paginate
from sqlalchemy.orm import Session, joinedload


def get_db_contacts(
session: Session,
thing_id: int | None = None,
sort: str | None = None,
order: str | None = None,
filter_: str | None = None,
):
sql = session.query(Contact).options(
# eagerly load related tables to avoid N+1 problems
joinedload(Contact.emails),
joinedload(Contact.phones),
joinedload(Contact.addresses),
joinedload(Contact.thing_associations).joinedload(
ThingContactAssociation.thing
),
)

if thing_id:
sql = sql.join(ThingContactAssociation)
sql = sql.where(ThingContactAssociation.thing_id == thing_id)

sql = order_sort_filter(sql, Contact, sort, order, filter_)
return paginate(sql)


def add_contact(session: Session, data: CreateContact | dict, user: dict) -> Contact:
Expand Down
6 changes: 6 additions & 0 deletions services/sample_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

def get_db_samples(
session: Session,
thing_id: int | None = None,
order: str | None = None,
sort: str | None = None,
filter_: str | None = None,
Expand All @@ -21,6 +22,11 @@ def get_db_samples(
), # Eagerly load related Contact
)

if thing_id:
query = query.join(FieldActivity)
query = query.join(FieldEvent)
query = query.where(FieldEvent.thing_id == thing_id)

query = order_sort_filter(query, Sample, sort, order, filter_)

return paginate(query)
9 changes: 9 additions & 0 deletions tests/test_contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,15 @@ def test_get_contacts(contact, email, address, phone):
assert data["items"][0]["addresses"][0]["release_status"] == address.release_status


def test_get_contacts_by_thing_id(contact, second_contact, water_well_thing):
response = client.get(f"/contact?thing_id={water_well_thing.id}")
data = response.json()

assert response.status_code == 200
assert data["total"] == 1
assert data["items"][0]["id"] == contact.id


def test_get_contact_by_id(contact, email, address, phone):
response = client.get(f"/contact/{contact.id}")
assert response.status_code == 200
Expand Down
32 changes: 22 additions & 10 deletions tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_validate_sample_top_and_bottom():

# ============= Post tests for samples =============================================
def test_add_sample(
groundwater_level_field_activity, water_well_thing, field_event_contact
groundwater_level_field_activity, water_well_thing, contact, field_event_contact
):
"""
Test adding a sample.
Expand Down Expand Up @@ -84,9 +84,7 @@ def test_add_sample(
assert data["thing"]["id"] == water_well_thing.id
assert data["field_event"]["id"] == groundwater_level_field_activity.field_event_id
assert data["field_activity"]["id"] == groundwater_level_field_activity.id
assert data["field_activity_id"] == payload["field_activity_id"]
assert data["contact"]["id"] == field_event_contact.contact_id
assert data["field_event_contact_id"] == payload["field_event_contact_id"]
assert data["contact"]["id"] == contact.id
assert data["sample_date"] == payload["sample_date"]
assert data["sample_name"] == payload["sample_name"]
assert data["sample_matrix"] == payload["sample_matrix"]
Expand Down Expand Up @@ -183,7 +181,6 @@ def test_patch_sample(water_chemistry_sample, groundwater_level_field_activity):
"""
payload = {
"field_activity_id": groundwater_level_field_activity.id,
# "field_event_contact_id": third_contact.id,
"sample_date": "2025-01-02T00:00:00Z",
"sample_name": "patched sample name",
"sample_matrix": "soil",
Expand All @@ -199,6 +196,8 @@ def test_patch_sample(water_chemistry_sample, groundwater_level_field_activity):
data = response.json()

for key, value in payload.items():
if key in ["field_event_contact_id", "field_activity_id"]:
continue
assert data[key] == value

# rollback after updating the sample
Expand Down Expand Up @@ -288,9 +287,7 @@ def test_get_samples(water_chemistry_sample, groundwater_level_sample):
assert "thing" in item
assert "field_event" in item
assert "field_activity" in item
assert "field_activity_id" in item
assert "contact" in item
assert "field_event_contact_id" in item
assert "sample_date" in item
assert "sample_name" in item
assert "sample_matrix" in item
Expand All @@ -302,12 +299,28 @@ def test_get_samples(water_chemistry_sample, groundwater_level_sample):
assert "release_status" in item


def test_get_samples_by_thing_id(
water_chemistry_sample, groundwater_level_sample, water_well_thing
):
response = client.get(f"/sample?thing_id={water_well_thing.id}")
assert response.status_code == 200
data = response.json()
assert data["total"] == 2

data_ids = [d["id"] for d in data["items"]]
sorted_data_ids = sorted(data_ids)

assert sorted_data_ids == sorted(
[water_chemistry_sample.id, groundwater_level_sample.id]
)


def test_get_sample_by_id(
water_chemistry_sample,
water_chemistry_field_activity,
field_event,
water_well_thing,
field_event_contact,
contact,
):
"""
Test retrieving a sample by its ID.
Expand All @@ -322,8 +335,7 @@ def test_get_sample_by_id(
assert data["thing"]["id"] == water_well_thing.id
assert data["field_event"]["id"] == field_event.id
assert data["field_activity"]["id"] == water_chemistry_field_activity.id
assert data["field_activity_id"] == water_chemistry_field_activity.id
assert data["field_event_contact_id"] == field_event_contact.id
assert data["contact"]["id"] == contact.id
assert data["sample_date"] == water_chemistry_sample.sample_date
assert data["sample_name"] == water_chemistry_sample.sample_name
assert data["sample_matrix"] == water_chemistry_sample.sample_matrix
Expand Down
Loading