Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ logs/
# data -- include the README but ignore everything else
data/*
!data/README.md

evaluation_checkpoint/*
11 changes: 11 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ To generate new data, a minimal example is
python main.py --mode generate --num_samples [num_samples] --load [path/to/parameters.pt]
```

### Evaluate
To evaluate the generated samples, a minimal example is (here assuming WBM is the reference set)
```
python main.py --mode evaluate --train_datafile data/wbm/raw/wbm_train.csv --generated_datafile [path/to/generated_samples.pt]
```
This will default to using 10k samples (from both the generated and reference/train set) when computing all metrics. If you prefer another number, this can be set with `--num_samples_in_evaluation`.

**Note regarding pre-trained weights:** The pre-trained weights of the Wrenformer used for computing FWD will be automatically downloaded. If you prefer to download these weights yourself, create a new directory called `evaluation_checkpoint` and place the weights file, called `checkpoint.pth`, there. The code will run a checksum to verify that they are the same weights.

**Note regarding compatibility with other models:** The evaluation code also supports reading data generated by other models that we compared to in the paper (CDVAE, DiffCSP++, SymmCD). The data will be converted to protostructures.

### Parse generated data
To convert generated data to protostructures and prototypes, run
```
Expand Down
22 changes: 22 additions & 0 deletions wyckoff_generation/common/args_and_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
"mlp_activation": "SiLU",
# Generation
"num_samples": 10000,
# Evaluation
"num_samples_in_evaluation": 10000,
}


Expand Down Expand Up @@ -242,6 +244,26 @@ def get_parser():
help="Number of samples to generate",
)

# Evaluation
parser.add_argument(
"--generated_datafile",
type=str,
help="Path to generated data",
)

parser.add_argument(
"--train_datafile",
type=str,
help="Path to csv file with training data (for evaluation)",
)

parser.add_argument(
"--num_samples_in_evaluation",
type=int,
default=default_args_dict["num_samples_in_evaluation"],
help="Number of samples to compute evaluation statistics on",
)

# Post processing of generated samples
parser.add_argument(
"--post_process_all",
Expand Down
9 changes: 9 additions & 0 deletions wyckoff_generation/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

"""

import hashlib
import importlib
import os
import re
Expand All @@ -34,6 +35,14 @@ def get_pretrained_checkpoint(load_path, best=True):
return checkpoint


def compare_hash(data_path, correct_hash):
    """Return True if the SHA-256 digest of the file at *data_path* matches
    *correct_hash* (a lowercase hex string).

    The file is read in 4 KiB chunks so arbitrarily large files can be
    verified without loading them into memory.
    """
    digest = hashlib.sha256()
    with open(data_path, "rb") as handle:
        while chunk := handle.read(4096):
            digest.update(chunk)
    return digest.hexdigest() == correct_hash


def increment_filename(file_path):
# Split the file path into directory, filename, and extension
directory, filename = os.path.split(file_path)
Expand Down
80 changes: 80 additions & 0 deletions wyckoff_generation/datasets/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import aviary.wren.data as aviary_wren_data
import aviary.wren.utils as aviary_wren_utils
import pandas as pd
import torch
from aviary.wren.utils import (
canonicalize_element_wyckoffs,
Expand Down Expand Up @@ -304,3 +305,82 @@ def enrich_dataset(
return dataset, all_aflow_labels, all_prototype_labels

return dataset


def compare_generated_with_training_dataset_fast_label_list(
    generated_dataset: list[list[int, str, str]],
    training_dataset: list[str],
    set_name: str,
    return_duplicate_names: bool = False,
) -> pd.DataFrame:

    """Flag each generated sample as novel or duplicated w.r.t. a reference set.

    Parameters
    ----------
    generated_dataset:
        list of lists, where the second level contains
        [original_index, protostructure_label, prototype_label].
    training_dataset:
        list of protostructure labels from the reference set
        (e.g. the train/val/test split).
    set_name:
        name of the reference set; embedded in the duplicate column names,
        so include the split for explicitness.
    return_duplicate_names:
        if True, additionally return the two duplicate column names so
        callers do not have to re-derive them.

    Returns
    -------
    pd.DataFrame with columns: original_index, protostructures, prototypes,
    novel, novel_prototype, duplicates_{set_name}_set and
    duplicates_{set_name}_set_prototype (the last two hold the list of
    matching reference labels; empty list when the sample is novel).
    When return_duplicate_names is True, the return value is instead the
    tuple (df, duplicate_column_name, duplicate_prototype_column_name).
    """

    print(
        f"Identifying novelty of generated data compared with {set_name} dataset, saving to attribute 'novel' and 'duplicates_{set_name}_set'...",
        file=sys.stdout,
    )

    # Create a dataframe from the aflow label lists and canonical prototypes in the generated dataset
    gen_df = pd.DataFrame(
        generated_dataset, columns=["original_index", "protostructures", "prototypes"]
    )

    # --- Protostructures and prototypes matching

    # Derive the prototype for every reference protostructure so both label
    # granularities can be compared.
    train_df = pd.DataFrame(training_dataset, columns=["protostructures"])
    train_df["prototypes"] = train_df.protostructures.apply(
        aviary_wren_utils.get_prototype_from_protostructure
    )

    # Check if novel (note that novel is opposite to is-in)
    gen_df["novel"] = ~gen_df.protostructures.isin(train_df.protostructures)
    gen_df["novel_prototype"] = ~gen_df.prototypes.isin(train_df.prototypes)

    # Collect duplicates
    # Protostructures: group the reference labels so that each row maps a
    # label to the list of its occurrences in the reference set, then
    # left-merge onto the generated samples.
    duplicate_attribute_name = f"duplicates_{set_name}_set"
    train_grouped_protostructures = (
        train_df.groupby("protostructures")["protostructures"]
        .apply(list)
        .reset_index(name=duplicate_attribute_name)
    )
    gen_df = gen_df.merge(
        train_grouped_protostructures,
        how="left",
        left_on="protostructures",
        right_on="protostructures",
    )
    # Rows with no reference match come back as NaN from the left merge;
    # normalize them to an empty list.
    gen_df[duplicate_attribute_name] = gen_df[duplicate_attribute_name].apply(
        lambda x: x if isinstance(x, list) else []
    )

    # Prototypes: same group-and-merge scheme at prototype granularity.
    duplicate_prototype_attribute_name = f"duplicates_{set_name}_set_prototype"
    train_grouped_prototypes = (
        train_df.groupby("prototypes")["prototypes"]
        .apply(list)
        .reset_index(name=duplicate_prototype_attribute_name)
    )
    gen_df = gen_df.merge(
        train_grouped_prototypes,
        how="left",
        left_on="prototypes",
        right_on="prototypes",
    )
    gen_df[duplicate_prototype_attribute_name] = gen_df[
        duplicate_prototype_attribute_name
    ].apply(lambda x: x if isinstance(x, list) else [])

    if return_duplicate_names:
        return gen_df, duplicate_attribute_name, duplicate_prototype_attribute_name

    return gen_df
9 changes: 1 addition & 8 deletions wyckoff_generation/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch_geometric.loader import DataLoader

from wyckoff_generation.common.registry import registry
from wyckoff_generation.common.utils import compare_hash
from wyckoff_generation.datasets.lookup_tables import (
element_number,
spg_wyckoff,
Expand Down Expand Up @@ -199,14 +200,6 @@ def preprocess(raw_file_path) -> pd.DataFrame:
return parsed_aflow_labels


def compare_hash(data_path, correct_hash):
sha256_hash = hashlib.sha256()
with open(data_path, "rb") as f:
for byte_block in iter(lambda: f.read(4096), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest() == correct_hash


def decompress_bz2_file(
compressed_file_path, remove_original=False, decompressed_file_path=None
):
Expand Down
146 changes: 146 additions & 0 deletions wyckoff_generation/evaluation/compute_fwd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import torch
from aviary.wrenformer.data import df_to_in_mem_dataloader
from aviary.wrenformer.model import Wrenformer
from tqdm import tqdm

from wyckoff_generation.evaluation import read_file_utils
from wyckoff_generation.evaluation.frechet_distance import (
frechet_distance_from_embeddings,
)
from wyckoff_generation.evaluation.novelty_helper import (
get_enriched_df,
get_statistics_from_df,
)


def get_embeddings(model, dataset):
    """Collect the penultimate-layer activations of *model* over *dataset*.

    A forward hook on the second-to-last child module records its output for
    every batch; the recorded outputs are concatenated along dim 0.

    model: a torch.nn.Module whose forward takes
        (padded_features, mask, equivalence_counts) — here the Wrenformer.
    dataset: iterable yielding ((padded_features, mask, equivalence_counts),
        targets, ids) batches, e.g. from df_to_in_mem_dataloader.
    returns: a single tensor of all hooked embeddings, stacked along dim 0.
    """
    store = []

    def hook_fn(module, input, output):
        store.append(output)
        return output

    # NOTE(review): assumes the embedding layer is the second-to-last
    # registered child module of the model — confirm against Wrenformer.
    target_layer = list(model.children())[-2]
    hook_handle = target_layer.register_forward_hook(hook_fn)

    try:
        with torch.no_grad():
            for (padded_features, mask, equivalence_counts), targets, ids in dataset:
                model(padded_features, mask, equivalence_counts)
    finally:
        # Always detach the hook — even if a forward pass raises — so the
        # model is left unmodified for subsequent calls.
        hook_handle.remove()

    return torch.cat(store, dim=0)


def main(args):
    """Compute FWD (Frechet Wrenformer Distance) and novelty/uniqueness stats
    for a set of generated samples against a reference (training) set.

    args: dict-like with keys train_datafile, generated_datafile,
        num_samples_in_evaluation, batch_size, device.
    returns: (full_results_dict, gen_data_df_subsampled,
        gen_data_novel_subsampled) — the float metrics plus the exact
        subsampled dataframes they were computed on.
    """
    train_data_df_full = read_file_utils.get_dataset_df(
        args["train_datafile"],
    )
    print("Parsing generated materials")
    gen_data_df_full = read_file_utils.get_dataset_df(args["generated_datafile"])

    # Both sets must be at least as large as the requested evaluation size,
    # since sampling below is without replacement.
    assert len(train_data_df_full.index) >= args["num_samples_in_evaluation"], len(
        train_data_df_full.index
    )
    assert len(gen_data_df_full.index) >= args["num_samples_in_evaluation"], len(
        gen_data_df_full.index
    )

    # Annotate generated samples with novelty/duplicate info, then draw a
    # fixed-seed subsample of each set so results are reproducible.
    gen_data_df_enriched = get_enriched_df(gen_data_df_full, train_data_df_full)
    gen_data_df_subsampled = gen_data_df_enriched.sample(
        n=args["num_samples_in_evaluation"],
        replace=False,
        ignore_index=True,
        random_state=42,
    )
    train_data_df_subsampled = train_data_df_full.sample(
        n=args["num_samples_in_evaluation"],
        replace=False,
        ignore_index=True,
        random_state=42,
    )

    # NOTE(review): input_col differs between the two dataloaders
    # ("protostructures" vs "wyckoff") — presumably the column names differ
    # between generated and reference dataframes; verify in read_file_utils.
    gen_data_fwd = df_to_in_mem_dataloader(
        gen_data_df_subsampled,
        input_col="protostructures",
        batch_size=args["batch_size"],
        shuffle=False,
    )
    train_data_fwd = df_to_in_mem_dataloader(
        train_data_df_subsampled,
        input_col="wyckoff",
        batch_size=args["batch_size"],
        shuffle=False,
    )

    # Load the pre-trained Wrenformer checkpoint (see README: placed in
    # evaluation_checkpoint/ either manually or by automatic download) and
    # put the model in eval mode.
    state_dict = torch.load("evaluation_checkpoint/checkpoint.pth", map_location="cpu")
    model = Wrenformer(**state_dict["model_params"]).to(args["device"])
    model.load_state_dict(state_dict["model_state"])
    model.train(False)
    assert not model.training
    print("Computing training embeddings")

    # to improve stability, use double precision
    print("Computing Wrenformer embeddings of generated and training materials")
    train_embeddings = get_embeddings(model, train_data_fwd).double()
    gen_embeddings = get_embeddings(model, gen_data_fwd).double()
    fwd = float(frechet_distance_from_embeddings(train_embeddings, gen_embeddings))

    stats_subsampled = get_statistics_from_df(gen_data_df_subsampled)
    stats_subsampled["fwd"] = fwd
    print("\n\n----Stats for generated materials----")
    for key, value in stats_subsampled.items():
        if isinstance(value, float):
            print(f"{key}: {value}")

    # Repeat the FWD computation restricted to *novel* generated samples,
    # using the same sample size and seed for comparability.
    gen_data_novel_only = gen_data_df_enriched.loc[gen_data_df_enriched["novel"]]
    assert len(gen_data_novel_only.index) >= args["num_samples_in_evaluation"]
    gen_data_novel_subsampled = gen_data_novel_only.sample(
        n=args["num_samples_in_evaluation"],
        replace=False,
        ignore_index=True,
        random_state=42,
    )
    gen_data_novel_fwd = df_to_in_mem_dataloader(
        gen_data_novel_subsampled,
        input_col="protostructures",
        batch_size=args["batch_size"],
        shuffle=False,
    )

    gen_novel_embeddings = get_embeddings(model, gen_data_novel_fwd).double()
    fwd_novel = float(
        frechet_distance_from_embeddings(train_embeddings, gen_novel_embeddings)
    )
    stats_novel = get_statistics_from_df(gen_data_novel_subsampled)
    stats_novel["fwd"] = fwd_novel
    print("\n\n----Stats for generated novel materials----")
    for key, value in stats_novel.items():
        if isinstance(value, float):
            print(f"{key}: {value}")

    # Pre-format the metrics as one " & "-separated row for the paper's
    # LaTeX results table.
    result_string = " & ".join(
        [
            f"{stats_subsampled['fwd']:.2f}",
            f"{stats_subsampled['novelty']*100:.1f}",
            f"{stats_subsampled['uniqueness']*100:.1f}",
            f"{stats_novel['fwd']:.2f}",
            f"{stats_novel['uniqueness']*100:.1f}",
        ]
    )
    print("\n\n----Results string for LaTeX table----\n", result_string)

    # Keep only the scalar (float) metrics in the returned dict.
    full_results_dict = {
        "stats_subsampled": {
            key: value
            for key, value in stats_subsampled.items()
            if isinstance(value, float)
        },
        "stats_novel": {
            key: value for key, value in stats_novel.items() if isinstance(value, float)
        },
        "result_string": result_string,
    }
    return full_results_dict, gen_data_df_subsampled, gen_data_novel_subsampled
19 changes: 19 additions & 0 deletions wyckoff_generation/evaluation/compute_prototype_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os
import sys

import pandas as pd


def main(folder, num_samples):
    """Count unique novel prototypes among previously saved novel samples.

    Reads the csv dumped by an evaluation run (named by its sample count),
    sanity-checks it, and prints/returns the number of unique prototypes
    among rows flagged as having a novel prototype.

    folder: directory containing the evaluation csv dump.
    num_samples: sample count the evaluation was run with; used both to
        locate the file and to validate its length.
    returns: number of unique novel prototypes (also printed to stdout).
    """
    file = os.path.join(
        folder, f"gen_data_novel_subsampled_num_samples={num_samples}.csv"
    )
    df = pd.read_csv(file)
    # The dump must contain exactly the requested number of samples
    # (was hardcoded to 10000, which broke any other num_samples) and must
    # only contain rows flagged as novel protostructures.
    assert len(df.index) == num_samples
    assert df["novel"].all()

    # Restrict to rows whose *prototype* (not just protostructure) is novel.
    novel_prototype_df = df[df["novel_prototype"]]

    unique_prototypes = novel_prototype_df["prototypes"].unique().tolist()
    print(len(unique_prototypes))
    return len(unique_prototypes)
Loading