Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 104 additions & 10 deletions processing/scripts/make_txt_embedding.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,44 @@
"""
Makes the entire set of BioCLIP 2 text embeddings for all possible names in the tree of life.
Makes the entire set of text embeddings for all possible taxonomic names in the tree of life.
Designed for the txt_emb_species.json file from TreeOfLife-200M.

Generalized for any open_clip-compatible model accessible via Hugging Face Hub. Use
--preset for the common BioCLIP variants, or pass --model / --tokenizer / --embed-dim
to point at any other model (e.g. BioCAP, future BioCLIP releases).

Note: lower --batch-size for larger models (e.g. bioclip-2.5-vith14) or smaller
GPUs to avoid CUDA OOM.

Usage:
python make_txt_embedding.py \\
--names-path NAMES.json \\
--out-path OUT.npy \\
(--preset PRESET | --model MODEL [--tokenizer TOKENIZER] --embed-dim N) \\
[--batch-size N]

Examples:
# BioCLIP (ViT-B/16, 512-dim) via preset
python make_txt_embedding.py \\
--names-path txt_emb_bioclip.json \\
--out-path txt_emb_bioclip.npy \\
--preset bioclip \\
--batch-size 16384

# BioCLIP 2.5 Huge (ViT-H/14, 1024-dim) via preset
python make_txt_embedding.py \\
--names-path txt_emb_bioclip-2.5-vith14.json \\
--out-path txt_emb_bioclip-2.5-vith14.npy \\
--preset bioclip-2.5-vith14 \\
--batch-size 16384

# Arbitrary model via explicit args (e.g. a future release)
python make_txt_embedding.py \\
--names-path txt_emb_species.json \\
--out-path txt_emb_custom.npy \\
--model hf-hub:imageomics/<model-id> \\
--tokenizer hf-hub:imageomics/<model-id> \\
--embed-dim <model-dim> \\
--batch-size 8192
"""
import argparse
import json
Expand All @@ -20,23 +58,54 @@
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger()

model_str = "hf-hub:imageomics/bioclip-2"
tokenizer_str = "ViT-L-14"
# Known model presets, keyed by preset name; each entry is a dict with the keys
# "model", "tokenizer", and "embed_dim".
# --preset is a shorthand; when given, it takes precedence over --model / --tokenizer / --embed-dim.
# Keys are the values accepted by --preset. Each entry supplies the open_clip
# model identifier, the tokenizer identifier, and the embedding dimension that
# the script uses to size the output array for that model.
PRESETS = {
    "bioclip": {
        "model": "hf-hub:imageomics/bioclip",
        "tokenizer": "hf-hub:imageomics/bioclip",  # ViT-B/16
        "embed_dim": 512,
    },
    "bioclip-2": {
        "model": "hf-hub:imageomics/bioclip-2",
        "tokenizer": "hf-hub:imageomics/bioclip-2",  # ViT-L/14
        "embed_dim": 768,
    },
    "bioclip-2.5-vith14": {
        "model": "hf-hub:imageomics/bioclip-2.5-vith14",
        "tokenizer": "hf-hub:imageomics/bioclip-2.5-vith14",  # ViT-H/14
        "embed_dim": 1024,
    },
    "biocap": {
        "model": "hf-hub:imageomics/biocap",
        "tokenizer": "hf-hub:imageomics/biocap",  # ViT-B/16
        "embed_dim": 512,
    },
}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


@torch.no_grad()
def write_txt_features(all_names):
def write_txt_features(all_names, embed_dim):
if os.path.isfile(args.out_path):
all_features = np.load(args.out_path)
if all_features.shape != (embed_dim, len(all_names)):
raise SystemExit(
f"Existing {args.out_path} has shape {all_features.shape} but expected "
f"({embed_dim}, {len(all_names)}). Move it aside or pick a fresh --out-path."
)
else:
all_features = np.zeros((768, len(all_names)), dtype=np.float32)
all_features = np.zeros((embed_dim, len(all_names)), dtype=np.float32)

batch_size = args.batch_size // len(openai_imagenet_template)
num_batches = int(len(all_names) / batch_size)
# Ceiling division so the trailing partial batch is processed.
num_batches = (len(all_names) + batch_size - 1) // batch_size
for batch_idx in tqdm(range(num_batches), desc="Extracting text features"):
start = batch_idx * batch_size
end = start + batch_size
# Clamp final batch end to len(all_names) to avoid an IndexError on
# the trailing partial batch.
end = min(start + batch_size, len(all_names))
if all_features[:, start:end].any():
logger.info(
"Skipping batch %d (%d to %d) because it already exists in the output file.",
Expand All @@ -58,7 +127,7 @@ def write_txt_features(all_names):
txts = tokenizer(txts).to(device)
txt_features = model.encode_text(txts)
txt_features = torch.reshape(
txt_features, (len(names), len(openai_imagenet_template), 768)
txt_features, (len(names), len(openai_imagenet_template), embed_dim)
)
txt_features = F.normalize(txt_features, dim=2).mean(dim=1)
txt_features /= txt_features.norm(dim=1, keepdim=True)
Expand All @@ -74,9 +143,34 @@ def write_txt_features(all_names):
parser = argparse.ArgumentParser()
parser.add_argument("--names-path", help="Path to the taxonomic names file (e.g., txt_emb_species.json).", required=True)
parser.add_argument("--out-path", help="Path to the output file.", required=True)
parser.add_argument("--batch-size", help="Batch size.", default=2**14, type=int)
parser.add_argument("--batch-size", default=2**14, type=int,
help="Outer batch size (taxa per step). Lower for larger models / smaller "
"GPUs to avoid CUDA OOM.")
parser.add_argument("--preset", choices=sorted(PRESETS.keys()),
help="Shorthand for a known model (see PRESETS). Overrides --model / "
"--tokenizer / --embed-dim when set.")
parser.add_argument("--model",
help="open_clip model identifier (e.g. 'hf-hub:imageomics/bioclip-2', "
"'hf-hub:imageomics/biocap'). Required unless --preset is given.")
parser.add_argument("--tokenizer",
help="open_clip tokenizer identifier. Defaults to --model when not set.")
parser.add_argument("--embed-dim", type=int,
help="Joint embedding dimension. Required unless --preset is given.")
args = parser.parse_args()

if args.preset:
preset = PRESETS[args.preset]
model_str = preset["model"]
tokenizer_str = preset["tokenizer"]
embed_dim = preset["embed_dim"]
else:
if not args.model or args.embed_dim is None:
parser.error("either --preset or both --model and --embed-dim are required")
model_str = args.model
tokenizer_str = args.tokenizer or args.model
embed_dim = args.embed_dim
logger.info("model=%s tokenizer=%s embed_dim=%d", model_str, tokenizer_str, embed_dim)

model = create_model(model_str, output_dict=True, require_pretrained=True)
model = model.to(device)
logger.info("Created model.")
Expand All @@ -87,4 +181,4 @@ def write_txt_features(all_names):
names = json.load(fd)

tokenizer = get_tokenizer(tokenizer_str)
write_txt_features(names)
write_txt_features(names, embed_dim)
71 changes: 58 additions & 13 deletions processing/scripts/make_txt_embedding_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

This script creates a JSON file for species embeddings by:
- Filtering catalog to only entries with non-null kingdom and non-null species
- Removing rows where any taxonomic rank looks like a CSV-parsing leak
(ISO-8601 timestamp or the literal string 'true' / 'false'); disable with
--no-corruption-filter to reproduce the pre-v2 upstream behavior
- For each remaining unique taxonomy, collecting all available common names
- Preferring English common names from GBIF VernacularNames.tsv (from GBIF Backbone Taxonomy), falling back to any language
- Sorting by taxonomy and outputting in [[taxonomy_array], common_name] format
Expand All @@ -23,6 +26,12 @@
import argparse
from pathlib import Path

# Regex that matches the two CSV-parsing leak patterns observed in TOL-200M
# catalogs (kingdom slot occasionally contains an ISO-8601 timestamp or a
# Boolean literal that bled in from an adjacent column).
#
# The pattern has two alternatives:
#   * timestamp: anchored only at the start, so any value that BEGINS with
#     "YYYY-MM-DDTHH:MM" matches, whatever follows (seconds, zone, ...);
#   * boolean: anchored at both ends and case-insensitive, so only a value
#     that is exactly "true" or "false" (in any letter case) matches.
CORRUPTION_PATTERN = r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}|^(?i:true|false)$"
# Rank columns that the corruption filter checks against CORRUPTION_PATTERN,
# listed from kingdom down to species.
TAXONOMIC_RANKS = ["kingdom", "phylum", "class", "order", "family", "genus", "species"]

def load_vernacular_names(vernacular_path: str) -> tuple[set, set]:
"""Load and return sets of vernacular names from GBIF's TSV file."""
print(f"Loading vernacular names from: {vernacular_path}")
Expand Down Expand Up @@ -78,21 +87,43 @@ def select_best_common_name_from_list(names_list, english_names: set, all_names:
return ""


def process_catalog_to_embeddings(catalog_path: str, english_names: set, all_names: set, output_path: str):
def drop_corrupted_rows(df: pl.DataFrame) -> tuple[pl.DataFrame, int]:
    """Remove rows in which at least one taxonomic rank matches CORRUPTION_PATTERN.

    Args:
        df: Frame containing every column named in TAXONOMIC_RANKS.

    Returns:
        A ``(filtered, n_dropped)`` pair: the frame without the matching rows,
        and how many rows were removed.
    """
    # One boolean expression per rank column. The explicit null check makes a
    # null rank contribute False instead of propagating a null into the mask.
    rank_flags = [
        pl.col(rank).is_not_null() & pl.col(rank).str.contains(CORRUPTION_PATTERN)
        for rank in TAXONOMIC_RANKS
    ]

    # OR the per-rank flags together: a row is corrupt if any rank is.
    any_corrupt = pl.lit(False)
    for flag in rank_flags:
        any_corrupt = any_corrupt | flag

    kept = df.filter(~any_corrupt)
    return kept, len(df) - len(kept)


def process_catalog_to_embeddings(catalog_path: str, english_names: set, all_names: set, output_path: str, apply_corruption_filter: bool = True):
"""Process catalog data into embeddings JSON format."""
print(f"Loading catalog from: {catalog_path}")
# Load catalog

# Load catalog
df_catalog = pl.read_parquet(catalog_path)
print(f"\tTotal catalog entries: {len(df_catalog)}")

# Filter to only keep entries with non-null kingdom AND species
df_filtered = df_catalog.filter(
(pl.col("kingdom").is_not_null()) &
(pl.col("kingdom").is_not_null()) &
(pl.col("species").is_not_null())
)
print(f"\tAfter null kingdom/species filtering: {len(df_filtered)}")


# Drop rows whose taxonomic ranks contain CSV-parsing leaks (ISO-8601
# timestamps or boolean literals). Off via --no-corruption-filter.
if apply_corruption_filter:
df_filtered, n_dropped = drop_corrupted_rows(df_filtered)
print(f"\tCorruption filter dropped: {n_dropped} rows")

# Get all unique taxonomies with their common names from the catalog
df_grouped = (
df_filtered
Expand Down Expand Up @@ -188,25 +219,39 @@ def main():
default="txt_emb_species.json",
help="Output JSON file path"
)


parser.add_argument(
"--no-corruption-filter",
action="store_true",
help="Disable the ISO-8601 / boolean corruption filter on taxonomic "
"ranks (reproduces the BioCLIP 2 upstream behavior)."
)

args = parser.parse_args()

# Check input files exist
if not Path(args.catalog_path).exists():
raise FileNotFoundError(f"Catalog file not found: {args.catalog_path}")
if not Path(args.vernacular_path).exists():
raise FileNotFoundError(f"VernacularNames file not found: {args.vernacular_path}")

print(f"Catalog: {args.catalog_path}")
print(f"Vernacular names: {args.vernacular_path}")
print(f"Output: {args.output}")

print(f"Corruption filter: {'OFF' if args.no_corruption_filter else 'ON'}")

# Load vernacular names
english_names, all_names = load_vernacular_names(args.vernacular_path)

# Process catalog, generate JSON
entry_count = process_catalog_to_embeddings(args.catalog_path, english_names, all_names, args.output)

entry_count = process_catalog_to_embeddings(
args.catalog_path,
english_names,
all_names,
args.output,
apply_corruption_filter=not args.no_corruption_filter,
)

print(f"\nEmbeddings JSON complete.")


Expand Down