Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 13 additions & 53 deletions codecarbon/core/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

# from httpx import AsyncClient
import dataclasses
import json
from datetime import timedelta, tzinfo

import arrow
Expand Down Expand Up @@ -86,9 +85,7 @@ def check_auth(self):
url = self.url + "/auth/check"
headers = self._get_headers()
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def get_list_organizations(self):
Expand All @@ -98,18 +95,14 @@ def get_list_organizations(self):
url = self.url + "/organizations"
headers = self._get_headers()
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def check_organization_exists(self, organization_name: str):
"""
Check if an organization exists
"""
organizations = self.get_list_organizations()
if organizations is None:
return False
for organization in organizations:
if organization["name"] == organization_name:
return organization
Expand All @@ -129,9 +122,7 @@ def create_organization(self, organization: OrganizationCreate):
else:
headers = self._get_headers()
r = requests.post(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 201:
self._log_error(url, payload, r)
return None
r.raise_for_status()
return r.json()

def get_organization(self, organization_id):
Expand All @@ -141,9 +132,7 @@ def get_organization(self, organization_id):
headers = self._get_headers()
url = self.url + "/organizations/" + organization_id
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def update_organization(self, organization: OrganizationCreate):
Expand All @@ -154,9 +143,7 @@ def update_organization(self, organization: OrganizationCreate):
headers = self._get_headers()
url = self.url + "/organizations/" + organization.id
r = requests.patch(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, payload, r)
return None
r.raise_for_status()
return r.json()

def list_projects_from_organization(self, organization_id):
Expand All @@ -166,9 +153,7 @@ def list_projects_from_organization(self, organization_id):
url = self.url + "/organizations/" + organization_id + "/projects"
headers = self._get_headers()
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def create_project(self, project: ProjectCreate):
Expand All @@ -179,9 +164,7 @@ def create_project(self, project: ProjectCreate):
url = self.url + "/projects"
headers = self._get_headers()
r = requests.post(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 201:
self._log_error(url, payload, r)
return None
r.raise_for_status()
return r.json()

def get_project(self, project_id):
Expand All @@ -191,9 +174,7 @@ def get_project(self, project_id):
url = self.url + "/projects/" + project_id
headers = self._get_headers()
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def add_emission(self, carbon_emission: dict):
Expand Down Expand Up @@ -237,9 +218,7 @@ def add_emission(self, carbon_emission: dict):
url = self.url + "/emissions"
headers = self._get_headers()
r = requests.post(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 201:
self._log_error(url, payload, r)
return False
r.raise_for_status()
logger.debug(f"ApiClient - Successful upload emission {payload} to {url}")
except Exception as e:
logger.error(e, exc_info=True)
Expand Down Expand Up @@ -279,9 +258,7 @@ def _create_run(self, experiment_id: str):
url = self.url + "/runs"
headers = self._get_headers()
r = requests.post(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 201:
self._log_error(url, payload, r)
return None
r.raise_for_status()
self.run_id = r.json()["id"]
logger.info(
"ApiClient Successfully registered your run on the API.\n\n"
Expand All @@ -304,9 +281,7 @@ def list_experiments_from_project(self, project_id: str):
url = self.url + "/projects/" + project_id + "/experiments"
headers = self._get_headers()
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return []
r.raise_for_status()
return r.json()

def set_experiment(self, experiment_id: str):
Expand All @@ -324,9 +299,7 @@ def add_experiment(self, experiment: ExperimentCreate):
url = self.url + "/experiments"
headers = self._get_headers()
r = requests.post(url=url, json=payload, timeout=2, headers=headers)
if r.status_code != 201:
self._log_error(url, payload, r)
return None
r.raise_for_status()
return r.json()

def get_experiment(self, experiment_id):
Expand All @@ -336,22 +309,9 @@ def get_experiment(self, experiment_id):
url = self.url + "/experiments/" + experiment_id
headers = self._get_headers()
r = requests.get(url=url, timeout=2, headers=headers)
if r.status_code != 200:
self._log_error(url, {}, r)
return None
r.raise_for_status()
return r.json()

def _log_error(self, url, payload, response):
if len(payload) > 0:
logger.error(
f"ApiClient Error when calling the API on {url} with : {json.dumps(payload)}"
)
else:
logger.error(f"ApiClient Error when calling the API on {url}")
logger.error(
f"ApiClient API return http code {response.status_code} and answer : {response.text}"
)

def close_experiment(self):
"""
Tell the API that the experiment has ended.
Expand Down
26 changes: 16 additions & 10 deletions tests/test_api_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
from uuid import uuid4

import requests
import requests_mock

from codecarbon.core.api_client import ApiClient
Expand Down Expand Up @@ -137,7 +138,7 @@ def test_call_api(self):
assert payload["ram_utilization_percent"] == 56.5
assert payload["wue"] == 0.8

def test_check_auth_returns_none_on_error(self):
def test_check_auth_raises_on_error(self):
with requests_mock.Mocker() as m:
m.get("http://test.com/auth/check", text="bad", status_code=401)
api = ApiClient(
Expand All @@ -146,17 +147,19 @@ def test_check_auth_returns_none_on_error(self):
create_run_automatically=False,
)

self.assertIsNone(api.check_auth())
with self.assertRaises(requests.HTTPError):
api.check_auth()

def test_check_organization_exists_returns_false_when_list_fails(self):
def test_check_organization_exists_raises_when_list_fails(self):
with requests_mock.Mocker() as m:
m.get("http://test.com/organizations", text="bad", status_code=500)
api = ApiClient(
endpoint_url="http://test.com",
create_run_automatically=False,
)

self.assertFalse(api.check_organization_exists("missing"))
with self.assertRaises(requests.HTTPError):
api.check_organization_exists("missing")

def test_create_organization_skips_when_name_exists(self):
organization = OrganizationCreate(name="existing", description="desc")
Expand Down Expand Up @@ -267,7 +270,7 @@ def test_create_run_returns_none_on_unsuccessful_status(self):
self.assertIsNone(api._create_run("experiment_id"))
self.assertIsNone(api.run_id)

def test_list_experiments_from_project_returns_empty_list_on_error(self):
def test_list_experiments_from_project_raises_on_error(self):
with requests_mock.Mocker() as m:
m.get(
"http://test.com/projects/proj-1/experiments",
Expand All @@ -279,7 +282,8 @@ def test_list_experiments_from_project_returns_empty_list_on_error(self):
create_run_automatically=False,
)

self.assertEqual(api.list_experiments_from_project("proj-1"), [])
with self.assertRaises(requests.HTTPError):
api.list_experiments_from_project("proj-1")

def test_set_experiment_updates_value(self):
api = ApiClient(endpoint_url="http://test.com", create_run_automatically=False)
Expand All @@ -288,7 +292,7 @@ def test_set_experiment_updates_value(self):

self.assertEqual(api.experiment_id, "exp-2")

def test_add_experiment_returns_none_on_error(self):
def test_add_experiment_raises_on_error(self):
experiment = ExperimentCreate(
timestamp="2024-01-01T00:00:00+00:00",
name="exp",
Expand All @@ -303,14 +307,16 @@ def test_add_experiment_returns_none_on_error(self):
create_run_automatically=False,
)

self.assertIsNone(api.add_experiment(experiment))
with self.assertRaises(requests.HTTPError):
api.add_experiment(experiment)

def test_get_experiment_returns_none_on_error(self):
def test_get_experiment_raises_on_error(self):
with requests_mock.Mocker() as m:
m.get("http://test.com/experiments/exp-1", text="bad", status_code=404)
api = ApiClient(
endpoint_url="http://test.com",
create_run_automatically=False,
)

self.assertIsNone(api.get_experiment("exp-1"))
with self.assertRaises(requests.HTTPError):
api.get_experiment("exp-1")