Training Models¶
This guide covers RESOLVE's training configuration and optimization strategies.
Basic Training¶
model = resolve.ResolveModel(schema=dataset.schema, targets=targets)
trainer = resolve.Trainer(model, dataset)
result = trainer.fit()
Trainer Configuration¶
Core Parameters¶
trainer = resolve.Trainer(
    model=model,
    dataset=dataset,
    max_epochs=200,   # Maximum training epochs
    patience=30,      # Early stopping patience
    batch_size=256,   # Samples per batch
    lr=1e-3,          # Learning rate
    device="auto",    # "auto", "cpu", or "cuda"
)
Phased Training¶
RESOLVE uses phased training for regression targets:
trainer = resolve.Trainer(
    model=model,
    dataset=dataset,
    phase_boundaries=(50, 150),  # Phase transitions at epochs 50 and 150
)
Phases:
- Phase 1 (epochs 0-50): MAE loss - robust initial learning
- Phase 2 (epochs 50-150): SMAPE loss - scale-invariant refinement
- Phase 3 (epochs 150+): Band accuracy - calibrated predictions
Validation Split¶
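The trainer holds out part of the dataset for validation, which drives early stopping. A minimal sketch, assuming a `val_split` fraction parameter (a hypothetical name; check your version's API reference):

```python
trainer = resolve.Trainer(
    model=model,
    dataset=dataset,
    val_split=0.2,  # hypothetical: hold out 20% of plots for validation
    patience=30,    # early stopping monitors the validation loss
)
```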
Model Architecture¶
Hidden Dimensions¶
Configure the shared encoder's hidden layers:
model = resolve.ResolveModel(
    schema=dataset.schema,
    targets=targets,
    hidden_dims=[256, 128, 64],  # 3 hidden layers
)
Hash Dimension¶
Control the species feature hashing dimension:
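A sketch of the configuration, assuming the parameter is called `hash_dim` (a hypothetical name, not confirmed API):

```python
model = resolve.ResolveModel(
    schema=dataset.schema,
    targets=targets,
    hash_dim=512,  # hypothetical: width of the hashed species feature vector
)
```

Larger values reduce hash collisions between species at the cost of a wider input layer.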
Species Encoding Options¶
RESOLVE supports four encoding modes via the species_encoding parameter.
Hash encoding (default)¶
Feature hashing maps the full species list into a fixed-dimension vector. Fast and works with any species pool size.
Embed encoding¶
Learned per-species embeddings. More expressive but limited to the top-k species seen during training.
trainer = resolve.Trainer(
    model,
    dataset,
    species_encoding="embed",
    species_embed_dim=32,
    top_k_species=10,
)
Rank-pool encoding¶
Variable-length species lists with weighted pooling. Avoids padding waste when species richness varies widely across plots.
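A configuration sketch; the mode string `"rank_pool"` below is an assumption based on the section title, not confirmed API:

```python
trainer = resolve.Trainer(
    model,
    dataset,
    species_encoding="rank_pool",  # assumed mode name for rank-pool encoding
)
```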
Transformer encoding¶
Transformer-based encoder with self-attention over species. Captures co-occurrence patterns but requires more data and compute.
trainer = resolve.Trainer(
    model,
    dataset,
    species_encoding="transformer",
    n_heads=4,
    n_attention_layers=2,
    transformer_ff_dim=256,
    transformer_pooling="attention",  # or "mean"
)
Abundance Normalization¶
Control how species abundance values are scaled before they enter the encoder. Options:
- "raw": Use abundance values directly
- "relative_plot": Normalize to sum to 1 per plot
- "log_scaled": Apply log1p transform
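A sketch of selecting a mode, assuming an `abundance_normalization` keyword (a hypothetical name; verify against your version's API reference):

```python
trainer = resolve.Trainer(
    model,
    dataset,
    abundance_normalization="log_scaled",  # hypothetical keyword; applies log1p
)
```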
Unknown Species Tracking¶
Enable detailed tracking of novel species:
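A sketch, assuming a boolean flag such as `track_unknown_species` (a hypothetical name):

```python
trainer = resolve.Trainer(
    model,
    dataset,
    track_unknown_species=True,  # hypothetical flag: record species absent from training
)
```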
Training Results¶
The fit() method returns a result object:
result = trainer.fit()
print(f"Best epoch: {result.best_epoch}")
print(f"Stopped at epoch: {result.stopped_epoch}")
for target, metrics in result.final_metrics.items():
    print(f"\n{target}:")
    for metric, value in metrics.items():
        print(f"  {metric}: {value:.4f}")
Metrics¶
Regression targets:
- mae: Mean Absolute Error
- smape: Symmetric Mean Absolute Percentage Error
- band_5pct: Fraction within 5% of true value
- band_10pct: Fraction within 10% of true value
- band_20pct: Fraction within 20% of true value
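For reference, the regression metrics can be reproduced in plain Python. This is a sketch of the standard definitions; RESOLVE's internal edge-case handling (e.g. zero denominators) may differ:

```python
def mae(y_true, y_pred):
    """Mean Absolute Error."""
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

def smape(y_true, y_pred):
    """Symmetric MAPE: mean of 2|t - p| / (|t| + |p|).

    Pairs where both values are 0 contribute 0 to the mean.
    """
    return sum(
        2 * abs(t - p) / (abs(t) + abs(p))
        for t, p in zip(y_true, y_pred)
        if abs(t) + abs(p) > 0
    ) / len(y_true)

def band_accuracy(y_true, y_pred, band=0.10):
    """Fraction of predictions within +/- band of the true value."""
    hits = sum(abs(t - p) <= band * abs(t) for t, p in zip(y_true, y_pred))
    return hits / len(y_true)
```

With `band=0.05`, `band=0.10`, and `band=0.20` this corresponds to `band_5pct`, `band_10pct`, and `band_20pct` respectively.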
Classification targets:
- accuracy: Overall accuracy
- f1_macro: Macro-averaged F1 score
Saving and Loading¶
# Save trained model
trainer.save("model.pt")
# Load for prediction
predictor = resolve.Predictor.load("model.pt")
GPU Training¶
Set the device explicitly with device="cuda", or let RESOLVE pick one automatically with device="auto".
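Both options, using the documented `device` values from the Trainer configuration above:

```python
# Explicit GPU training
trainer = resolve.Trainer(model, dataset, device="cuda")

# Auto-detect: uses CUDA when available, otherwise falls back to CPU
trainer = resolve.Trainer(model, dataset, device="auto")
```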
Next Steps¶
- Making Predictions: Use trained models for inference
- Understanding Embeddings: Extract and interpret latent representations