Predictor
The Predictor class provides the inference interface for trained models.
Class Definition
```python
class Predictor:
    """
    Inference interface for trained RESOLVE models.
    Loads saved checkpoints and produces predictions on new data.
    """
```
Constructor
Predictor(model, species_encoder, scalers, device="auto")
Create a predictor from model components.
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| model | ResolveModel | required | Trained model |
| species_encoder | SpeciesEncoder | required | Fitted species encoder |
| scalers | dict | required | Fitted scalers |
| device | str | "auto" | Device for inference |
Class Methods
load(path, device="auto")
Load a predictor from a saved checkpoint.
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| path | str \| Path | required | Path to checkpoint file |
| device | str | "auto" | Device for inference |
Returns: Predictor
Example:
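The rule behind device="auto" is not specified in this section. The following is a minimal sketch assuming the common convention of using a CUDA GPU when one is available and the CPU otherwise. resolve_device is a hypothetical helper, not part of RESOLVE, and cuda_available is passed in explicitly so the sketch runs without a GPU library; in a PyTorch-based setup (also an assumption) that flag would typically come from torch.cuda.is_available().

```python
def resolve_device(device: str = "auto", cuda_available: bool = False) -> str:
    """Map a device request to a concrete device string.

    Hypothetical helper: assumes "auto" means "use CUDA if a GPU is
    available, otherwise the CPU". RESOLVE's actual rule may differ.
    """
    if device == "auto":
        return "cuda" if cuda_available else "cpu"
    # Explicit requests such as "cpu" or "cuda:1" are passed through unchanged.
    return device


print(resolve_device("auto", cuda_available=True))
print(resolve_device("auto", cuda_available=False))
print(resolve_device("cpu", cuda_available=True))
```

Passing an explicit string such as "cpu" bypasses the automatic choice, which can be useful for reproducible CPU-only inference.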
Methods
predict(dataset, return_latent=False, output_space="raw")
Generate predictions for a dataset.
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| dataset | ResolveDataset | required | Dataset to predict on |
| return_latent | bool | False | Include latent representations |
| output_space | str | "raw" | Output space for regression ("raw" or "transformed") |
Returns: ResolvePredictions
Example:
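What output_space="raw" undoes depends on the scalers passed to the constructor, which this section does not describe. As a hedged illustration, assume each regression target was standardized with a fitted mean and standard deviation: "transformed" then returns the model's scaled outputs, and "raw" maps them back to the original units. StandardScalerSketch and to_output_space are hypothetical names, not RESOLVE API.

```python
import numpy as np


class StandardScalerSketch:
    """Hypothetical per-target scaler: stores the mean and std seen during fitting."""

    def __init__(self, mean: float, std: float) -> None:
        self.mean = mean
        self.std = std

    def inverse_transform(self, z: np.ndarray) -> np.ndarray:
        # Undo standardization z = (x - mean) / std, i.e. x = z * std + mean.
        return z * self.std + self.mean


def to_output_space(z: np.ndarray, scaler: StandardScalerSketch, output_space: str = "raw") -> np.ndarray:
    """Return model outputs z as-is ("transformed") or in original units ("raw")."""
    if output_space == "transformed":
        return z
    if output_space == "raw":
        return scaler.inverse_transform(z)
    raise ValueError(f"output_space must be 'raw' or 'transformed', got {output_space!r}")


z = np.array([-1.0, 0.0, 2.0])  # model outputs in the scaled space
scaler = StandardScalerSketch(mean=10.0, std=2.5)
print(to_output_space(z, scaler, "transformed"))
print(to_output_space(z, scaler, "raw"))
```

If the real scalers apply a different transform (for example a log step before standardizing), only inverse_transform would change; the raw/transformed switch stays the same.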
get_embeddings(dataset)
Get latent embeddings for all plots.
Returns: np.ndarray of shape (n_plots, latent_dim)
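One common use of an (n_plots, latent_dim) matrix is finding plots with similar latent representations. A minimal sketch using cosine similarity in NumPy; the small hard-coded array stands in for the output of get_embeddings(dataset), cosine_similarity_matrix is a hypothetical helper, and the sketch assumes row i corresponds to the i-th plot of the dataset.

```python
import numpy as np


def cosine_similarity_matrix(embeddings: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of an (n, d) array."""
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Clip norms to avoid dividing by zero for all-zero rows.
    unit = embeddings / np.clip(norms, 1e-12, None)
    return unit @ unit.T


# Stand-in for predictor.get_embeddings(dataset): 3 plots, latent_dim = 2.
emb = np.array([[1.0, 0.0],
                [0.9, 0.1],
                [0.0, 1.0]])
sim = cosine_similarity_matrix(emb)
np.fill_diagonal(sim, -np.inf)  # ignore each plot's similarity with itself
nearest = sim.argmax(axis=1)    # index of the most similar other plot
print(nearest)
```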
get_genus_embeddings()
Get learned genus embedding weights.
Returns: np.ndarray of shape (n_genera, genus_emb_dim)
get_family_embeddings()
Get learned family embedding weights.
Returns: np.ndarray of shape (n_families, family_emb_dim)
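Learned embedding weight matrices like these are often inspected visually. A minimal sketch that projects one onto its first two principal components via NumPy's SVD; the random matrix stands in for get_genus_embeddings() or get_family_embeddings(), and pca_2d is a hypothetical helper. Mapping rows back to genus or family names would require the fitted SpeciesEncoder's index order, which this section does not document.

```python
import numpy as np


def pca_2d(weights: np.ndarray) -> np.ndarray:
    """Project an (n_items, emb_dim) matrix onto its top two principal components."""
    centered = weights - weights.mean(axis=0, keepdims=True)
    # Rows of vt are the principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


rng = np.random.default_rng(0)
genus_weights = rng.normal(size=(50, 16))  # stand-in for get_genus_embeddings()
coords = pca_2d(genus_weights)             # (50, 2) points, ready to scatter-plot
print(coords.shape)
```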
ResolvePredictions
Container for model predictions.
Attributes
| Attribute | Type | Description |
|---|---|---|
| predictions | dict[str, np.ndarray] | Predictions per target |
| plot_ids | np.ndarray | Plot identifiers |
| latent | np.ndarray \| None | Latent representations (if requested) |
Methods
\_\_getitem\_\_(target)
Get predictions for a specific target.
to_polars()
Convert predictions to a Polars DataFrame.
Returns: pl.DataFrame
to_csv(path)
Save predictions to a CSV file.
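To make the container concrete, here is a hypothetical stand-in (PredictionsSketch, not the RESOLVE class) with the same three attributes, item access by target name, and a CSV writer built on Python's csv module. The column layout (a plot_id column followed by one column per target) and the target name "richness" are assumptions; the real to_csv output may differ, and to_polars is omitted to keep the sketch dependency-light.

```python
from __future__ import annotations

import csv
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import numpy as np


@dataclass
class PredictionsSketch:
    """Hypothetical stand-in mirroring the documented ResolvePredictions attributes."""

    predictions: dict[str, np.ndarray]   # target name -> per-plot predictions
    plot_ids: np.ndarray                 # one identifier per plot
    latent: Optional[np.ndarray] = None  # only set when latents are requested

    def __getitem__(self, target: str) -> np.ndarray:
        # Raises KeyError for an unknown target name.
        return self.predictions[target]

    def to_csv(self, path: str | Path) -> None:
        # Assumed layout: header "plot_id,<target1>,...", then one row per plot.
        targets = list(self.predictions)
        # tolist() yields plain Python scalars, so csv writes "7.5" rather than a NumPy repr.
        columns = [np.asarray(self.predictions[t]).tolist() for t in targets]
        with open(path, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["plot_id", *targets])
            for pid, *values in zip(self.plot_ids.tolist(), *columns):
                writer.writerow([pid, *values])


preds = PredictionsSketch(
    predictions={"richness": np.array([12.0, 7.5])},
    plot_ids=np.array(["P1", "P2"]),
)
print(preds["richness"])

with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "predictions.csv"
    preds.to_csv(out)
    print(out.read_text())
```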