Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ jobs:
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
- name: Format with Black
uses: psf/black@stable
with:
options: ". -l 79 --check"
- name: Install ruff
run: pip install ruff>=0.9.0
- name: Format check with ruff
run: ruff format --check .
check-version:
name: Check version
runs-on: ubuntu-latest
Expand Down
10 changes: 6 additions & 4 deletions .github/workflows/push.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ jobs:
steps:
- name: Checkout repo
uses: actions/checkout@v4
- name: Check formatting
uses: "lgeiger/black-action@master"
with:
args: ". -l 79 --check"
- name: Setup Python
uses: actions/setup-python@v5
- name: Install ruff
run: pip install ruff>=0.9.0
- name: Format check with ruff
run: ruff format --check .
ensure-model-version-aligns-with-sim-api:
name: Ensure model version aligns with simulation API
runs-on: ubuntu-latest
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ debug-test:
MAX_HOUSEHOLDS=1000 FLASK_DEBUG=1 pytest -vv --durations=0 tests

format:
black . -l 79
ruff format .

deploy:
python gcp/export.py
Expand Down
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: minor
changes:
changed:
- Switched code formatter from Black to Ruff.
20 changes: 5 additions & 15 deletions dashboard/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@

st.subheader("Look up a policy")

policy_id = int(
st.text_input("Enter a policy ID", "1", key="policy_lookup_text")
)
country_id = st.text_input(
"Enter a country ID", "uk", key="policy_lookup_country"
)
policy_id = int(st.text_input("Enter a policy ID", "1", key="policy_lookup_text"))
country_id = st.text_input("Enter a country ID", "uk", key="policy_lookup_country")
if st.button("Look up policy", key="policy_lookup"):
try:
results = database.query(
Expand All @@ -40,9 +36,7 @@

policy_id = int(st.text_input("Enter a policy ID", "1"))
country_id = st.text_input("Enter a country ID", "uk")
new_label = st.text_input(
"Enter a new label", "New label", key="policy_label_text"
)
new_label = st.text_input("Enter a new label", "New label", key="policy_label_text")
if st.button("Set policy label", key="policy_label"):
try:
database.set_policy_label(policy_id, country_id, new_label)
Expand All @@ -54,12 +48,8 @@

st.subheader("Delete a policy")

policy_id = int(
st.text_input("Enter a policy ID", "1", key="policy_delete_text")
)
country_id = st.text_input(
"Enter a country ID", "uk", key="policy_delete_country"
)
policy_id = int(st.text_input("Enter a policy ID", "1", key="policy_delete_text"))
country_id = st.text_input("Enter a country ID", "uk", key="policy_delete_country")
if st.button("Delete policy", key="policy_delete"):
try:
database.delete_policy(policy_id, country_id)
Expand Down
14 changes: 4 additions & 10 deletions dashboard/experiments/gpt4_api_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@

@st.cache_data
def get_embeddings_df():
embeddings_df = pd.read_csv(
"parameter_embeddings.csv.gz", compression="gzip"
)
embeddings_df.parameter_embedding = (
embeddings_df.parameter_embedding.apply(lambda x: eval(x))
embeddings_df = pd.read_csv("parameter_embeddings.csv.gz", compression="gzip")
embeddings_df.parameter_embedding = embeddings_df.parameter_embedding.apply(
lambda x: eval(x)
)
return embeddings_df

Expand Down Expand Up @@ -55,11 +53,7 @@ def embed(prompt, engine="text-embedding-ada-002"):
lambda x: cosine_similarity(x, embedding)
)

top5 = (
embeddings_df.sort_values("similarities", ascending=False)
.head(5)["json"]
.values
)
top5 = embeddings_df.sort_values("similarities", ascending=False).head(5)["json"].values
# display in streamlit


Expand Down
10 changes: 4 additions & 6 deletions gcp/bump_country_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ def main():
help="Country package to bump",
choices=["policyengine-uk", "policyengine-us", "policyengine-canada"],
)
parser.add_argument(
"--version", type=str, required=True, help="Version to bump to"
)
parser.add_argument("--version", type=str, required=True, help="Version to bump to")
args = parser.parse_args()
country = args.country
version = args.version
Expand All @@ -44,9 +42,9 @@ def bump_country_package(country, version):
with open(setup_py_path, "w") as f:
f.write(setup_py)

country_package_full_name = country.replace(
"policyengine", "PolicyEngine"
).replace("-", " ")
country_package_full_name = country.replace("policyengine", "PolicyEngine").replace(
"-", " "
)
country_id = country.replace("policyengine-", "")
country_package_full_name = country_package_full_name.replace(
country_id, country_id.upper()
Expand Down
8 changes: 2 additions & 6 deletions gcp/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,9 @@
dockerfile = dockerfile.replace(
".github_microdata_token", GITHUB_MICRODATA_TOKEN
)
dockerfile = dockerfile.replace(
".anthropic_api_key", ANTHROPIC_API_KEY
)
dockerfile = dockerfile.replace(".anthropic_api_key", ANTHROPIC_API_KEY)
dockerfile = dockerfile.replace(".openai_api_key", OPENAI_API_KEY)
dockerfile = dockerfile.replace(
".hugging_face_token", HUGGING_FACE_TOKEN
)
dockerfile = dockerfile.replace(".hugging_face_token", HUGGING_FACE_TOKEN)

with open(dockerfile_location, "w") as f:
f.write(dockerfile)
12 changes: 3 additions & 9 deletions policyengine_api/ai_prompts/simulation_analysis_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,12 @@ def generate_simulation_analysis_prompt(params: InboundParameters) -> str:
)

impact_budget: str = json.dumps(parameters.impact["budget"])
impact_intra_decile: dict[str, Any] = json.dumps(
parameters.impact["intra_decile"]
)
impact_intra_decile: dict[str, Any] = json.dumps(parameters.impact["intra_decile"])
impact_decile: str = json.dumps(parameters.impact["decile"])
impact_inequality: str = json.dumps(parameters.impact["inequality"])
impact_poverty: str = json.dumps(parameters.impact["poverty"]["poverty"])
impact_deep_poverty: str = json.dumps(
parameters.impact["poverty"]["deep_poverty"]
)
impact_poverty_by_gender: str = json.dumps(
parameters.impact["poverty_by_gender"]
)
impact_deep_poverty: str = json.dumps(parameters.impact["poverty"]["deep_poverty"])
impact_poverty_by_gender: str = json.dumps(parameters.impact["poverty_by_gender"])

all_parameters: AllParameters = AllParameters.model_validate(
{
Expand Down
16 changes: 4 additions & 12 deletions policyengine_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,7 @@ def log_timing(message):

app.route("/<country_id>/calculate-full", methods=["POST"])(
cache.cached(make_cache_key=make_cache_key)(
lambda *args, **kwargs: get_calculate(
*args, **kwargs, add_missing=True
)
lambda *args, **kwargs: get_calculate(*args, **kwargs, add_missing=True)
)
)
log_timing("Calculate-full endpoint registered")
Expand All @@ -153,9 +151,7 @@ def log_timing(message):
app.route("/<country_id>/user-policy", methods=["PUT"])(update_user_policy)
log_timing("User policy update endpoint registered")

app.route("/<country_id>/user-policy/<user_id>", methods=["GET"])(
get_user_policy
)
app.route("/<country_id>/user-policy/<user_id>", methods=["GET"])(get_user_policy)
log_timing("User policy get endpoint registered")

app.register_blueprint(user_profile_bp)
Expand All @@ -177,19 +173,15 @@ def log_timing(message):

@app.route("/liveness-check", methods=["GET"])
def liveness_check():
return flask.Response(
"OK", status=200, headers={"Content-Type": "text/plain"}
)
return flask.Response("OK", status=200, headers={"Content-Type": "text/plain"})


log_timing("Liveness check endpoint registered")


@app.route("/readiness-check", methods=["GET"])
def readiness_check():
return flask.Response(
"OK", status=200, headers={"Content-Type": "text/plain"}
)
return flask.Response("OK", status=200, headers={"Content-Type": "text/plain"})


log_timing("Readiness check endpoint registered")
Expand Down
50 changes: 17 additions & 33 deletions policyengine_api/country.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ def build_metadata(self):
}[self.country_id],
basicInputs=self.tax_benefit_system.basic_inputs,
modelled_policies=self.tax_benefit_system.modelled_policies,
version=get_package_version(
self.country_package_name.replace("_", "-")
),
version=get_package_version(self.country_package_name.replace("_", "-")),
)

def build_microsimulation_options(self) -> dict:
Expand All @@ -77,13 +75,9 @@ def build_microsimulation_options(self) -> dict:
region = [
dict(name="uk", label="the UK", type="national"),
dict(name="country/england", label="England", type="country"),
dict(
name="country/scotland", label="Scotland", type="country"
),
dict(name="country/scotland", label="Scotland", type="country"),
dict(name="country/wales", label="Wales", type="country"),
dict(
name="country/ni", label="Northern Ireland", type="country"
),
dict(name="country/ni", label="Northern Ireland", type="country"),
]
for i in range(len(constituency_names)):
region.append(
Expand Down Expand Up @@ -130,9 +124,7 @@ def build_microsimulation_options(self) -> dict:
dict(name="state/co", label="Colorado", type="state"),
dict(name="state/ct", label="Connecticut", type="state"),
dict(name="state/de", label="Delaware", type="state"),
dict(
name="state/dc", label="District of Columbia", type="state"
),
dict(name="state/dc", label="District of Columbia", type="state"),
dict(name="state/fl", label="Florida", type="state"),
dict(name="state/ga", label="Georgia", type="state"),
dict(name="state/hi", label="Hawaii", type="state"),
Expand Down Expand Up @@ -276,9 +268,9 @@ def build_variables(self) -> dict:
dict(value=value.name, label=value.value)
for value in variable.possible_values
]
variable_data[variable_name][
"defaultValue"
] = variable.default_value.name
variable_data[variable_name]["defaultValue"] = (
variable.default_value.name
)
return variable_data

def build_parameters(self) -> dict:
Expand All @@ -299,9 +291,7 @@ def build_parameters(self) -> dict:
),
}
elif isinstance(parameter, ParameterScaleBracket):
bracket_index = int(
parameter.name[parameter.name.index("[") + 1 : -1]
)
bracket_index = int(parameter.name[parameter.name.index("[") + 1 : -1])
# Set the label to 'first bracket' for the first bracket, 'second bracket' for the second, etc.
bracket_label = f"bracket {bracket_index + 1}"
parameter_data[parameter.name] = {
Expand Down Expand Up @@ -378,9 +368,7 @@ def calculate(
for parameter_name in reform:
for time_period, value in reform[parameter_name].items():
start_instant, end_instant = time_period.split(".")
parameter = get_parameter(
system.parameters, parameter_name
)
parameter = get_parameter(system.parameters, parameter_name)
node_type = type(parameter.values_list[-1].value)
if node_type == int:
node_type = float
Expand Down Expand Up @@ -434,9 +422,9 @@ def calculate(
if any([math.isinf(value) for value in result]):
raise ValueError("Infinite value")
else:
household[entity_plural][entity_id][variable_name][
period
] = result
household[entity_plural][entity_id][variable_name][period] = (
result
)
else:
entity_index = population.get_index(entity_id)
if variable.value_type == Enum:
Expand All @@ -453,19 +441,15 @@ def calculate(
else:
entity_result = result.tolist()[entity_index]

household[entity_plural][entity_id][variable_name][
period
] = entity_result
household[entity_plural][entity_id][variable_name][period] = (
entity_result
)
except Exception as e:
if "axes" in household:
pass
else:
household[entity_plural][entity_id][variable_name][
period
] = None
print(
f"Error computing {variable_name} for {entity_id}: {e}"
)
household[entity_plural][entity_id][variable_name][period] = None
print(f"Error computing {variable_name} for {entity_id}: {e}")

tracer_output = simulation.tracer.computation_log
log_lines = tracer_output.lines(aggregate=False, max_depth=10)
Expand Down
4 changes: 1 addition & 3 deletions policyengine_api/data/congressional_districts.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,9 +683,7 @@ def build_congressional_district_metadata() -> list[dict]:
return [
{
"name": _build_district_name(district.state_code, district.number),
"label": _build_district_label(
district.state_code, district.number
),
"label": _build_district_label(district.state_code, district.number),
"type": "congressional_district",
"state_abbreviation": district.state_code,
"state_name": STATE_CODE_TO_NAME[district.state_code],
Expand Down
8 changes: 2 additions & 6 deletions policyengine_api/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def __init__(
self.local = local
if local:
# Local development uses a sqlite database.
self.db_url = (
REPO / "policyengine_api" / "data" / "policyengine.db"
)
self.db_url = REPO / "policyengine_api" / "data" / "policyengine.db"
if initialize or not Path(self.db_url).exists():
self.initialize()
else:
Expand All @@ -69,9 +67,7 @@ def __init__(
self.initialize()

def _create_pool(self):
instance_connection_name = (
"policyengine-api:us-central1:policyengine-api-data"
)
instance_connection_name = "policyengine-api:us-central1:policyengine-api-data"
self.connector = Connector()
db_user = "policyengine"
db_pass = os.environ["POLICYENGINE_DB_PASSWORD"]
Expand Down
8 changes: 2 additions & 6 deletions policyengine_api/data/model_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,7 @@ def get_dataset_version(country_id: str) -> str | None:


for dataset in datasets["uk"]:
datasets["uk"][
dataset
] = f"{datasets['uk'][dataset]}@{get_dataset_version('uk')}"
datasets["uk"][dataset] = f"{datasets['uk'][dataset]}@{get_dataset_version('uk')}"

for dataset in datasets["us"]:
datasets["us"][
dataset
] = f"{datasets['us'][dataset]}@{get_dataset_version('us')}"
datasets["us"][dataset] = f"{datasets['us'][dataset]}@{get_dataset_version('us')}"
3 changes: 1 addition & 2 deletions policyengine_api/data/places.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,5 @@ def validate_place_code(place_code: str) -> None:

if not place_fips.isdigit() or len(place_fips) != 5:
raise ValueError(
f"Invalid FIPS code in place: '{place_fips}'. "
"Expected 5-digit FIPS code"
f"Invalid FIPS code in place: '{place_fips}'. Expected 5-digit FIPS code"
)
Loading
Loading