spine.model.manager
Centralize all methods associated with a machine-learning model.
Classes
ModelManager – Groups all relevant functions to construct a model and its loss.
- class spine.model.manager.ModelManager(name, modules, network_input, loss_input=None, weight_path=None, weight_list=None, train: Mapping[str, Any] | None = None, to_numpy=False, time_dependent_loss=False, dtype='float32', distributed=False, rank=None, detect_anomaly=False, find_unused_parameters=False, iter_per_epoch=None)
Groups all relevant functions to construct a model and its loss.
Methods
__call__(data[, iteration, epoch]) – Calls the forward (and backward) function on a batch of data.
backward(loss) – Run the backward step on the model.
cast_to_numpy(result) – Casts the model output data products to NumPy objects in place.
clean_config(config) – Remove model loading/freezing keys from all levels of a dictionary.
forward(data[, iteration]) – Pass one minibatch of data through the network and the loss.
freeze_weights() – Freeze the weights of certain model components.
initialize_train(optimizer[, weight_prefix, ...]) – Initialize the training regimen.
load_weights(full_weight_path) – Load the weights of certain model components.
prepare_data(data) – Fetches the necessary data products to form the input to the forward function and the input to the loss function.
save_state(iteration, epoch) – Save the model state.
- initialize_train(optimizer, weight_prefix='snapshot', restore_optimizer=False, save_step=None, save_epoch=None, lr_scheduler=None, iter_per_epoch=None)
Initialize the training regimen.
- Parameters:
optimizer (dict) – Configuration of the optimizer
weight_prefix (str, default 'snapshot') – Path and file-name prefix of the saved weight files
restore_optimizer (bool, default False) – Whether to load the optimizer state from the torch checkpoint
save_step (int, optional) – Number of training iterations between successive saves of the model weights
save_epoch (float, optional) – Fraction of an epoch to train on between successive saves of the model weights
lr_scheduler (dict, optional) – Configuration of the learning rate scheduler
iter_per_epoch (int, optional) – Number of iterations per epoch (relevant for training)
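How save_step, save_epoch and iter_per_epoch interact is not spelled out above. The sketch below is a minimal illustration, assuming that save_epoch is converted to an iteration interval by multiplying it by iter_per_epoch, that the two options are mutually exclusive, and that a snapshot is written whenever the 1-based iteration count is a multiple of that interval. The helpers resolve_save_step and should_save are hypothetical, not part of spine.

```python
from typing import Optional


def resolve_save_step(save_step: Optional[int] = None,
                      save_epoch: Optional[float] = None,
                      iter_per_epoch: Optional[int] = None) -> Optional[int]:
    """Return the number of iterations between weight snapshots (hypothetical helper)."""
    # Assumption: the two options are mutually exclusive
    if save_step is not None and save_epoch is not None:
        raise ValueError("Specify either save_step or save_epoch, not both")
    if save_epoch is not None:
        if iter_per_epoch is None:
            raise ValueError("save_epoch requires iter_per_epoch")
        # Convert the epoch fraction into a whole number of iterations (at least 1)
        return max(1, round(save_epoch * iter_per_epoch))
    return save_step


def should_save(iteration: int, step: Optional[int]) -> bool:
    """True if a snapshot is due after the 0-based `iteration` (assumed convention)."""
    return step is not None and (iteration + 1) % step == 0


# 1000 iterations per epoch, snapshot every quarter epoch
step = resolve_save_step(save_epoch=0.25, iter_per_epoch=1000)
print(step, [i for i in range(1000) if should_save(i, step)])
# 250 [249, 499, 749, 999]
```

Expressing the interval in iterations keeps the check inside the training loop a single modulo test, whichever option the user configured.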
- clean_config(config)
Remove model loading/freezing keys from all levels of a nested dictionary.
This is used to strip the weight loading/freezing options from the input configuration before it is passed to the model and loss classes.
- Parameters:
config (dict) – Dictionary to remove the keys from
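A recursive key-stripping pass like the one clean_config describes can be sketched as follows. This assumes the loading/freezing keys are weight_path and freeze_weights (the names used by load_weights and freeze_weights), and that a cleaned copy is returned rather than the input being edited in place; the actual method may differ on both points. The module names in the example configuration are illustrative only.

```python
from typing import Any

# Assumed set of keys that control weight loading/freezing
MODEL_LOADING_KEYS = {"weight_path", "freeze_weights"}


def clean_config(config: Any) -> Any:
    """Return a copy of `config` with loading/freezing keys removed at every level."""
    if isinstance(config, dict):
        return {
            key: clean_config(value)
            for key, value in config.items()
            if key not in MODEL_LOADING_KEYS
        }
    if isinstance(config, list):
        # Nested lists may also contain module blocks
        return [clean_config(item) for item in config]
    # Leaves (numbers, strings, ...) are returned unchanged
    return config


cfg = {
    "backbone": {"depth": 5, "weight_path": "backbone.ckpt", "freeze_weights": True},
    "head": {"score_threshold": 0.5},
}
print(clean_config(cfg))
# {'backbone': {'depth': 5}, 'head': {'score_threshold': 0.5}}
```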
- freeze_weights()
Freeze the weights of certain model components.
Performs a breadth-first search for freeze_weights parameters in the model configuration. If freeze_weights is True under a module block, requires_grad is set to False for that module's parameters, and its batch normalization and dropout layers are set to evaluation mode.
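The breadth-first traversal can be sketched on a plain nested dictionary. The sketch assumes module blocks are nested dictionaries keyed by module name and only collects the dotted paths of blocks whose freeze_weights is True; the PyTorch side of freezing is described in a trailing comment, not executed. The function name find_frozen_modules is hypothetical.

```python
from collections import deque
from typing import Any, Dict, List


def find_frozen_modules(model_config: Dict[str, Any]) -> List[str]:
    """Breadth-first search for blocks with freeze_weights set to True.

    Returns the dotted paths of the matching blocks, shallowest first.
    """
    frozen = []
    queue = deque([("", model_config)])
    while queue:
        path, block = queue.popleft()
        if block.get("freeze_weights", False):
            frozen.append(path)
        # Enqueue child blocks so that shallower levels are visited first
        for key, value in block.items():
            if isinstance(value, dict):
                child = f"{path}.{key}" if path else key
                queue.append((child, value))
    return frozen


config = {
    "backbone": {"freeze_weights": True, "encoder": {"depth": 5}},
    "head": {"decoder": {"freeze_weights": True}},
}
print(find_frozen_modules(config))  # ['backbone', 'head.decoder']

# For each returned path, the matching torch.nn.Module would then be frozen with:
#   for p in module.parameters(): p.requires_grad = False
#   module.eval()  # puts its batch-norm and dropout layers in evaluation mode
```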
- load_weights(full_weight_path)
Load the weights of certain model components.
Performs a breadth-first search for weight_path parameters in the model configuration. If weight_path is found under a module block, the weights are loaded for that module's parameters.
If no weight_path is found for a given module, that module's weights are instead loaded from the overall weight_path under trainval (the full-model weights).
- Parameters:
full_weight_path (str) – Path to the weights for the full model
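The per-module lookup and the fallback to the full-model weights might combine as in the sketch below. For brevity it assumes module blocks are the top-level entries of the model configuration (the method itself searches breadth-first) and that a module-level weight_path overrides full_weight_path. It also shows one common way to pull a single submodule's entries out of a full-model state dictionary, by filtering on the module-name prefix; that convention is an assumption here. Both function names are hypothetical.

```python
from typing import Any, Dict, Optional


def resolve_weight_paths(model_config: Dict[str, Any],
                         full_weight_path: Optional[str]) -> Dict[str, Optional[str]]:
    """Map each top-level module to the weight file it should be loaded from."""
    paths = {}
    for name, block in model_config.items():
        if not isinstance(block, dict):
            continue  # skip scalar options that are not module blocks
        # A module-specific weight_path wins; otherwise fall back to the full model
        paths[name] = block.get("weight_path", full_weight_path)
    return paths


def extract_submodule_state(state: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """Keep only the entries under `prefix.` and strip that prefix from their keys."""
    start = len(prefix) + 1
    return {key[start:]: value for key, value in state.items()
            if key.startswith(prefix + ".")}


config = {"backbone": {"weight_path": "backbone.ckpt"}, "head": {"depth": 3}}
print(resolve_weight_paths(config, "full_model.ckpt"))
# {'backbone': 'backbone.ckpt', 'head': 'full_model.ckpt'}

full_state = {"backbone.conv.weight": 1.0, "head.score.bias": 2.0}
print(extract_submodule_state(full_state, "head"))  # {'score.bias': 2.0}
```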
- prepare_data(data)
Fetches the necessary data products to form the input to the forward function and the input to the loss function.
- Parameters:
data (dict) – Dictionary of input data product keys, each of which maps to its associated batched data product
- Returns:
input_dict (dict) – Input to the forward pass of the model
loss_dict (dict) – Labels to be used in the loss computation
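The constructor's network_input and loss_input arguments presumably name the data products routed to each function. The sketch below assumes each is a mapping from a function-argument name to a key of the batched data dictionary; the real configuration format may differ, and the data keys shown are illustrative.

```python
from typing import Any, Dict, Optional, Tuple


def prepare_data(data: Dict[str, Any],
                 network_input: Dict[str, str],
                 loss_input: Optional[Dict[str, str]] = None
                 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a batch into forward-pass inputs and loss labels (illustrative only)."""
    # Each entry maps an argument name of forward() to a key of `data`
    input_dict = {arg: data[key] for arg, key in network_input.items()}
    # Loss labels are optional, e.g. at inference time
    loss_dict = {}
    if loss_input is not None:
        loss_dict = {arg: data[key] for arg, key in loss_input.items()}
    return input_dict, loss_dict


batch = {"input_data": [[0, 0, 1.0]], "seg_label": [[0, 0, 2]], "index": [7]}
inputs, labels = prepare_data(batch, {"data": "input_data"}, {"seg_label": "seg_label"})
print(inputs)  # {'data': [[0, 0, 1.0]]}
print(labels)  # {'seg_label': [[0, 0, 2]]}
```

Products not listed in either mapping (here, index) are simply not forwarded.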
- forward(data, iteration=None)
Pass one minibatch of data through the network and the loss.
Loads one minibatch of data, passes it through the network's forward function and the loss computation, and stores the output.
- Parameters:
data (dict) – Dictionary of input data product keys, each of which maps to its associated batched data product
iteration (int, optional) – Iteration number (relevant for time-dependent losses)
- Returns:
Dictionary of model and loss outputs
- Return type:
dict
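The forward flow can be sketched with plain Python callables standing in for the network and loss modules. This assumes the loss receives the labels plus the model output, that the iteration is passed to the loss only when it is time-dependent (the time_dependent_loss constructor flag), and that model and loss outputs are merged into one dictionary; the function signature and toy callables are hypothetical.

```python
from typing import Any, Callable, Dict, Optional


def forward(model: Callable[..., Dict[str, Any]],
            loss_fn: Optional[Callable[..., Dict[str, Any]]],
            input_dict: Dict[str, Any],
            loss_dict: Dict[str, Any],
            iteration: Optional[int] = None,
            time_dependent_loss: bool = False) -> Dict[str, Any]:
    """Run the network and, when labels are available, the loss on one minibatch."""
    result = model(**input_dict)
    if loss_fn is not None and loss_dict:
        kwargs = dict(loss_dict)
        if time_dependent_loss:
            kwargs["iteration"] = iteration  # e.g. for a loss-weight schedule
        # Merge loss outputs (loss, accuracy, ...) with the model outputs
        result.update(loss_fn(**kwargs, **result))
    return result


def toy_model(data):
    return {"pred": [2 * x for x in data]}


def toy_loss(label, pred, iteration=None):
    err = sum((p - t) ** 2 for p, t in zip(pred, label)) / len(label)
    return {"loss": err, "iteration": iteration}


out = forward(toy_model, toy_loss, {"data": [1, 2]}, {"label": [2, 5]},
              iteration=3, time_dependent_loss=True)
print(out)  # {'pred': [2, 4], 'loss': 0.5, 'iteration': 3}
```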
- backward(loss)
Run the backward step on the model.
- Parameters:
loss (torch.Tensor) – Scalar loss value used to update the model weights