spine.model.grappa
GrapPA: Graph Neural Network for Particle Aggregation.
This module implements the GrapPA (Graph Particle Aggregation) architecture, a graph neural network designed for clustering and grouping particle instances.
GrapPA learns to aggregate fragment-level features into particle-level clusters through message passing and edge classification, enabling particle instance segmentation and identification.
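To make the idea of message passing followed by edge classification concrete, here is a minimal pure-Python sketch. It runs one round of mean-aggregation message passing and then scores each edge. It is a generic illustration, not GrapPA's actual backbone. The function names, the mean aggregation and the dot-product edge score are all assumptions chosen for brevity.

```python
import math


def message_pass(features, edges):
    """One round of mean-aggregation message passing (generic sketch).

    features: list of per-node feature vectors (lists of floats)
    edges: list of (i, j) pairs, treated as undirected
    Returns new features: each node's vector averaged with its neighbors'.
    """
    neighbors = {i: [] for i in range(len(features))}
    for i, j in edges:
        neighbors[i].append(j)
        neighbors[j].append(i)
    updated = []
    for i, x in enumerate(features):
        # Synchronous update: only the original features are read
        group = [x] + [features[j] for j in neighbors[i]]
        updated.append([sum(v[k] for v in group) / len(group) for k in range(len(x))])
    return updated


def edge_scores(features, edges):
    """Score each edge with the sigmoid of the dot product of its endpoints."""
    scores = []
    for i, j in edges:
        logit = sum(a * b for a, b in zip(features[i], features[j]))
        scores.append(1.0 / (1.0 + math.exp(-logit)))
    return scores


# Three fragments: 0 and 1 look alike, 2 points the other way
feats = [[1.0, 0.0], [0.9, 0.1], [-1.0, 0.0]]
edges = [(0, 1), (1, 2)]
print(edge_scores(message_pass(feats, edges), edges))
```

In a trained model, the aggregation and the scoring are learned functions. The overall flow stays the same: node features absorb information from neighbors, and each edge then receives a score for whether its two fragments belong together.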
Classes
GrapPA – Graph Particle Aggregator (GrapPA) model.
GrapPALoss – Takes the output of GrapPA and computes the total loss.
- class spine.model.grappa.GrapPA(*args: Any, **kwargs: Any)
Graph Particle Aggregator (GrapPA) model.
This class mostly acts as a wrapper that hands the graph data to the underlying graph neural network (GNN).
When trained standalone, this model must be provided with a cluster label tensor. It uses that tensor to build a set of input clusters from the cluster label boundaries and their semantic types.
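The sketch below shows, in pure Python, how input clusters can be built from a label tensor. Voxel indices are grouped by (batch ID, cluster ID), and only clusters of the requested semantic types with enough voxels are kept. The row layout (batch ID, cluster ID, shape), the function name and the use of -1 for unlabeled voxels are all assumptions for illustration, not SPINE's actual tensor layout.

```python
from collections import defaultdict


def build_clusters(rows, shapes=None, min_size=-1):
    """Group voxel indices into clusters (hypothetical helper).

    rows: list of (batch_id, cluster_id, shape) tuples, one per voxel
          (assumed layout; cluster_id == -1 marks unlabeled voxels)
    shapes: iterable of semantic types to keep, or None to keep all
    min_size: minimum number of voxels per cluster (-1 disables the cut)
    Returns a list of voxel index lists, ordered by (batch_id, cluster_id).
    """
    groups = defaultdict(list)
    shape_of = {}
    for idx, (batch_id, cluster_id, shape) in enumerate(rows):
        if cluster_id < 0:
            continue  # skip unlabeled voxels
        key = (batch_id, cluster_id)
        groups[key].append(idx)
        shape_of[key] = shape  # assumes one semantic type per cluster
    clusts = []
    for key in sorted(groups):
        if shapes is not None and shape_of[key] not in shapes:
            continue
        if len(groups[key]) < min_size:
            continue
        clusts.append(groups[key])
    return clusts


voxels = [(0, 0, 0), (0, 0, 0), (0, 1, 1), (0, -1, 2), (1, 0, 0), (1, 0, 0), (1, 0, 0)]
print(build_clusters(voxels, shapes=[0], min_size=2))
```

Cluster IDs are keyed together with the batch ID, so identical IDs in different batch entries stay separate clusters.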
A typical configuration looks like this:
    model:
      name: grappa
      modules:
        grappa:
          nodes:
            <dictionary of arguments to specify the input type>
          graph:
            name: <name of the input graph type>
            <dictionary of arguments to specify the graph>
          node_encoder:
            name: <name of the type of node encoder>
            <dictionary of arguments to specify the node encoder>
          edge_encoder:
            name: <name of the type of edge encoder>
            <dictionary of arguments to specify the edge encoder>
          global_encoder:
            name: <name of the type of global encoder>
            <dictionary of arguments to specify the global encoder>
          gnn_model:
            name: <name of the type of backbone GNN feature extractor>
            <dictionary of arguments to specify the GNN>
See configuration files prefixed with grappa_ under the config directory for detailed examples of working configurations.
Methods
__call__(*args, **kwargs) – Call self as a function.
forward(data[, coord_label, clusts, ...]) – Prepares particle clusters and feeds them to the GNN model.
process_dbscan_config([shapes, min_size]) – Process the DBSCAN fragmenter configuration.
process_final_config(final, prefix) – Process a final layer configuration.
process_gnn_config([node_pred, edge_pred, ...]) – Process the GNN backbone structure and the output layers.
process_model_config(gnn_model[, nodes, ...]) – Process the top-level configuration block.
process_node_config([source, shapes, ...]) – Process the node parameters of the model.
- MODULES = [('grappa', ['base', 'dbscan', 'node_encoder', 'edge_encoder', 'gnn_model']), 'grappa_loss']
- process_model_config(gnn_model, nodes=None, graph=None, node_encoder=None, edge_encoder=None, global_encoder=None, dbscan=None, return_features=False)
Process the top-level configuration block.
This dispatches each block to its own configuration processor.
- Parameters:
gnn_model (dict) – Underlying graph neural network configuration
nodes (dict, optional) – Input node configuration
graph (dict, optional) – Input graph configuration
node_encoder (dict, optional) – Node encoder configuration
edge_encoder (dict, optional) – Edge encoder configuration
global_encoder (dict, optional) – Global encoder configuration
dbscan (dict, optional) – DBSCAN fragmentation configuration
return_features (bool, default False) – If True, the model will return the node/edge/global features
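The dispatch described above can be pictured with the following hypothetical class. It routes the node block to its own processor and records the other blocks. The method names mirror those documented on this page, but their bodies, the empty-dict fallback for an omitted node block, and the `handled` attribute are assumptions, not SPINE's implementation.

```python
class ConfigDispatcher:
    """Hypothetical stand-in showing how a top-level block is dispatched."""

    def __init__(self, **cfg):
        self.handled = {}
        self.process_model_config(**cfg)

    def process_model_config(self, gnn_model, nodes=None, graph=None,
                             node_encoder=None, edge_encoder=None,
                             global_encoder=None, dbscan=None,
                             return_features=False):
        # Omitted node block: fall back to the processor's defaults
        self.process_node_config(**(nodes or {}))
        # Graph and encoder blocks are simply recorded in this sketch
        for key, block in (('graph', graph), ('node_encoder', node_encoder),
                           ('edge_encoder', edge_encoder),
                           ('global_encoder', global_encoder)):
            self.handled[key] = block
        if dbscan is not None:
            self.handled['dbscan'] = dbscan
        self.handled['gnn_model'] = gnn_model
        self.return_features = return_features

    def process_node_config(self, source='cluster', shapes=None, min_size=-1):
        self.handled['nodes'] = {'source': source, 'shapes': shapes,
                                 'min_size': min_size}


cfg = {
    'nodes': {'shapes': [0, 1], 'min_size': 10},
    'graph': {'name': '<graph type>'},    # placeholder, not a real graph name
    'gnn_model': {'name': '<backbone>'},  # placeholder, not a real backbone
}
model = ConfigDispatcher(**cfg)
print(model.handled['nodes'])
```

Because each block name is a keyword argument, a misspelled block name in the configuration raises a `TypeError` rather than being silently ignored.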
- process_node_config(source='cluster', shapes=None, min_size=-1, make_groups=False, grouping_method='score', grouping_through_track=False)
Process the node parameters of the model.
- Parameters:
source (str, default 'cluster') – Column name in the label tensor which contains the input cluster IDs
shapes (int, optional) – Type of nodes to include in the input. If not specified, include all types
min_size (int, default -1) – Minimum number of voxels in a cluster to be included in the input
make_groups (bool, default False) – Use edge predictions to build node groups
grouping_method (str, default 'score') – Algorithm used to build a node partition
grouping_through_track (bool, default False) – If True, shower objects can only be connected to one track object
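When `make_groups` is enabled, edge predictions are turned into a node partition. One simple way to do this is sketched below. Every edge whose score passes a threshold is joined with a union-find structure, and the resulting connected components become the groups. This is a swapped-in illustration: SPINE's `'score'` grouping method may work differently. The function name and the 0.5 threshold are assumptions.

```python
def group_nodes(num_nodes, edges, scores, threshold=0.5):
    """Partition nodes by joining every edge whose score passes the threshold.

    Returns one group ID per node (the smallest node index in its group).
    """
    parent = list(range(num_nodes))

    def find(i):
        # Path halving keeps the union-find trees shallow
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for (i, j), score in zip(edges, scores):
        if score >= threshold:
            ri, rj = find(i), find(j)
            if ri != rj:
                # The smaller root survives, so roots stay group minima
                parent[max(ri, rj)] = min(ri, rj)

    return [find(i) for i in range(num_nodes)]


edges = [(0, 1), (1, 2), (2, 3), (3, 4)]
scores = [0.9, 0.2, 0.8, 0.7]
print(group_nodes(5, edges, scores))
```

A pure connected-components rule lets a single false-positive edge merge two large groups. This is one reason a score-aware partitioning method can be preferable in practice.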
- process_gnn_config(node_pred=None, edge_pred=None, global_pred=None, **gnn_model)
Process the GNN backbone structure and the output layers.
- Parameters:
node_pred (Union[int, dict], optional) – Number of node predictions. If there are multiple node predictions, provide a (key, value) pair for each type of prediction
edge_pred (Union[int, dict], optional) – Number of edge predictions. If there are multiple edge predictions, provide a (key, value) pair for each type of prediction
global_pred (Union[int, dict], optional) – Number of global predictions. If there are multiple global predictions, provide a (key, value) pair for each type of prediction
**gnn_model (dict) – Parameters to initialize the GNN backbone
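The output heads are named arguments next to `**gnn_model`, so the signature alone separates them from the backbone parameters. The following is a minimal sketch of that split. The `split_gnn_config` helper, the `num_layers` key and the `type`/`primary` prediction names are hypothetical placeholders.

```python
def split_gnn_config(node_pred=None, edge_pred=None, global_pred=None, **gnn_model):
    """Separate output-head settings from the backbone parameters (hypothetical).

    Python binds node_pred/edge_pred/global_pred by keyword, so everything
    left in gnn_model is meant for the backbone constructor.
    """
    heads = {'node': node_pred, 'edge': edge_pred, 'global': global_pred}
    # Keep only the heads that were actually requested
    heads = {key: value for key, value in heads.items() if value is not None}
    return heads, gnn_model


cfg = {'name': '<backbone>', 'num_layers': 3, 'edge_pred': 2,
       'node_pred': {'type': 6, 'primary': 2}}
heads, backbone = split_gnn_config(**cfg)
print(heads)
print(backbone)
```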
- process_final_config(final, prefix)
Process a final layer configuration.
- Parameters:
final (Union[int, dict]) – Final layer configuration
prefix (str) – Name of the final layer
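One plausible reading of the `Union[int, dict]` form is that an integer gives the size of a single output, while a dictionary gives one output per named prediction. The sketch below expands either form into a map from head name to output size. The function name and the `'<prefix>_<key>'` naming scheme are assumptions, not the library's convention.

```python
def expand_heads(final, prefix):
    """Turn an int or a dict of ints into a {head_name: num_outputs} map.

    An int yields a single head named after the prefix; a dict yields one
    head per key, named '<prefix>_<key>' (a hypothetical naming scheme).
    """
    if isinstance(final, int):
        return {prefix: final}
    if isinstance(final, dict):
        return {f'{prefix}_{key}': value for key, value in final.items()}
    raise TypeError(f'Unsupported {prefix} configuration: {final!r}')


print(expand_heads(2, 'edge_pred'))
print(expand_heads({'type': 6, 'primary': 2}, 'node_pred'))
```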
- process_dbscan_config(shapes=None, min_size=None, **kwargs)
Process the DBSCAN fragmenter configuration.
- Parameters:
shapes (Union[int, list], optional) – This should not be specified (fetched from the node configuration)
min_size (Union[int, list], optional) – This should not be specified (fetched from the node configuration)
**kwargs (dict, optional) – Rest of the DBSCAN configuration
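Because `shapes` and `min_size` come from the node configuration, a DBSCAN block only needs its remaining parameters. The sketch below fills in the two node-level values. It rejects them if they were set in the DBSCAN block itself. Raising an error, the helper name and the `eps` placeholder key are assumptions; the real method may warn or behave differently.

```python
def merge_dbscan_config(dbscan_cfg, node_shapes, node_min_size):
    """Fill the DBSCAN block with the node-level shapes and min_size.

    Raises if the user set either key in the DBSCAN block itself, since
    those values must come from the node configuration (assumed policy).
    """
    for key in ('shapes', 'min_size'):
        if dbscan_cfg.get(key) is not None:
            raise ValueError(f'Do not set `{key}` in the DBSCAN block; '
                             'it is taken from the node configuration.')
    merged = dict(dbscan_cfg)  # leave the caller's dict untouched
    merged['shapes'] = node_shapes
    merged['min_size'] = node_min_size
    return merged


# 'eps' is a placeholder for "the rest of the DBSCAN configuration"
print(merge_dbscan_config({'eps': 1.8}, node_shapes=[0, 2], node_min_size=10))
```

Pulling both values from a single source means the fragmenter and the node selection can never disagree on which clusters enter the graph.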
- forward(data, coord_label=None, clusts=None, edge_index=None, node_features=None, edge_features=None, global_features=None, shapes=None, groups=None, points=None, extra=None)
Prepares particle clusters and feeds them to the GNN model.
- Parameters:
data (TensorBatch) – Tensor of voxel/value pairs with shape (N, 1 + D + N_f), where N is the total number of voxels, the leading column stores the batch ID, D is the image dimensionality and N_f is the number of features. When clusts is not provided, the features must also contain the labels needed to build clusters on the fly.
coord_label (TensorBatch, optional) – (P, 1 + 2*D + 2) Tensor of label points (start/end/time/shape)
clusts (IndexBatch, optional) – List of indexes corresponding to each cluster
edge_index (EdgeIndexBatch, optional) – (E, 2) Incidence matrix. If not provided, it will be built based on the cluster indexes and the graph configuration
node_features (TensorBatch, optional) – (C, N_{c,f}) Node features. If not provided, they will be built based on the cluster indexes and the node encoder configuration
edge_features (TensorBatch, optional) – (C, N_{e,f}) Edge features. If not provided, they will be built based on the cluster indexes and the edge encoder configuration
global_features (TensorBatch, optional) – (C, N_{g,f}) Global features. If not provided, they will be built based on the cluster indexes and the global encoder configuration
shapes (TensorBatch, optional) – List of cluster semantic classes used to define the max length
groups (TensorBatch, optional) – (C) List of node groups, one per cluster. If specified, removes connections between nodes that belong to different groups.
points (TensorBatch, optional) – (C, 3/6) Tensor of start (and end) points
extra (TensorBatch, optional) – (C, N_f) Batch of features to append to the existing node features
- Returns:
clusts (IndexBatch) – (C, N_c, N_{c,i}) Cluster indexes
edge_index (EdgeIndexBatch) – (E, 2) Incidence matrix
node_features (TensorBatch) – (C, N_{c,f}) Node features
edge_features (TensorBatch) – (C, N_{e,f}) Edge features
global_features (TensorBatch) – (C, N_{g,f}) Global features
node_pred (TensorBatch) – (C, N_n) Node predictions (logits)
edge_pred (TensorBatch) – (C, N_e) Edge predictions (logits)
global_pred (TensorBatch) – (C, N_g) Global predictions (logits)
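When `edge_index` is omitted, it is built from the clusters and the graph configuration. When `groups` is given, connections between different groups are removed. The sketch below assumes the simplest graph, a complete graph within each batch entry, and returns an (E, 2) list of index pairs. The helper name and the per-node `batch_ids` input are hypothetical, and the configured graph type may connect nodes differently.

```python
from itertools import combinations


def complete_edge_index(batch_ids, groups=None):
    """Build an (E, 2) list of undirected edges, one per eligible node pair.

    batch_ids: batch entry of each node; nodes never connect across entries
    groups: optional group ID per node; if given, only nodes that share a
            group are connected (mirrors the `groups` argument above)
    """
    edges = []
    for i, j in combinations(range(len(batch_ids)), 2):
        if batch_ids[i] != batch_ids[j]:
            continue
        if groups is not None and groups[i] != groups[j]:
            continue
        edges.append((i, j))
    return edges


print(complete_edge_index([0, 0, 0, 1, 1]))
print(complete_edge_index([0, 0, 0, 1, 1], groups=[5, 5, 6, 7, 7]))
```

A complete graph grows quadratically with the number of clusters per entry. This is one reason the group restriction, and sparser graph types, are useful for busy events.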
- class spine.model.grappa.GrapPALoss(*args: Any, **kwargs: Any)
Takes the output of GrapPA and computes the total loss.
For use in config:
    model:
      name: grappa
      modules:
        grappa_loss:
          node_loss:
            name: <name of the node loss>
            <dictionary of arguments to pass to the loss>
          edge_loss:
            name: <name of the edge loss>
            <dictionary of arguments to pass to the loss>
          global_loss:
            name: <name of the global loss>
            <dictionary of arguments to pass to the loss>
Each loss block can also hold multiple losses: nest one sub-block per loss below it, each with its own name key. Each loss of a given type must have a corresponding output from GrapPA.
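One way to read the multi-loss rule: a block with a top-level `name` key configures a single loss, and a block without one maps each key to its own loss. The sketch below normalizes both forms and checks each resulting loss against the outputs the model produced. The helper, the `'<prefix>_<key>'` output naming and the example output names are all hypothetical.

```python
def normalize_loss_block(block, prefix, outputs):
    """Map each output name to its loss configuration (hypothetical helper).

    block: a single loss config (has a 'name' key) or a dict of them
    prefix: output type the block applies to, e.g. 'node_pred'
    outputs: names of the outputs actually produced by the model
    """
    if 'name' in block:
        losses = {prefix: block}
    else:
        # Nested form: one loss per key, named '<prefix>_<key>' (assumed scheme)
        losses = {f'{prefix}_{key}': cfg for key, cfg in block.items()}
    missing = [name for name in losses if name not in outputs]
    if missing:
        raise KeyError(f'No model output for loss(es): {missing}')
    return losses


outputs = {'node_pred_type', 'node_pred_primary', 'edge_pred'}
print(normalize_loss_block({'name': '<edge loss>'}, 'edge_pred', outputs))
print(normalize_loss_block({'type': {'name': '<loss A>'},
                            'primary': {'name': '<loss B>'}},
                           'node_pred', outputs))
```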
See configuration files prefixed with grappa_ under the config directory for detailed examples of working configurations.
Methods
__call__(*args, **kwargs) – Call self as a function.
forward(clust_label[, coord_label, ...]) – Apply the node/edge/global losses to the logits from GrapPA.
process_loss_config([node_loss, edge_loss, ...]) – Process the loss configuration.
process_single_loss_config(prefix, loss, ...) – Process a loss configuration.
- process_loss_config(node_loss=None, edge_loss=None, global_loss=None)
Process the loss configuration.
- Parameters:
node_loss (Union[dict, Dict[dict]], optional) – Node loss configuration
edge_loss (Union[dict, Dict[dict]], optional) – Edge loss configuration
global_loss (Union[dict, Dict[dict]], optional) – Global loss configuration
- process_single_loss_config(prefix, loss, constructor)
Process a loss configuration.
- Parameters:
prefix (str) – Name of the output type to apply the loss to
loss (Union[int, dict]) – Loss configuration
constructor (object) – Loss constructor function
- forward(clust_label, coord_label=None, graph_label=None, iteration=None, **output)
Apply the node/edge/global losses to the logits from GrapPA.
- Parameters:
clust_label (TensorBatch) – (N, 1 + D + N_f) Tensor of voxel/value pairs, where N is the total number of voxels in the image, the leading column stores the batch ID, D is the number of dimensions in the input image and N_f is the number of cluster labels
coord_label (TensorBatch, optional) – (P, 1 + D + 8) Tensor of start/end point labels for each true particle in the image
graph_label (EdgeIndexTensor, optional) – (2, E) Tensor of edges that correspond to physical connections between true particles in the image
iteration (int, optional) – Iteration index
**output (dict) – Output of the GrapPA model
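As an illustration of an edge loss on GrapPA's logits, the sketch below derives a binary label for each edge. An edge is labeled 1 when both endpoints share a true group ID. The sketch then averages a numerically stable binary cross-entropy over the edge logits. Deriving edge labels from group IDs and choosing binary cross-entropy are assumptions made for this example, and the configured `edge_loss` may use a different target or loss function.

```python
import math


def edge_labels(edges, group_ids):
    """Label an edge 1 if both endpoints belong to the same true group."""
    return [1.0 if group_ids[i] == group_ids[j] else 0.0 for i, j in edges]


def bce_with_logits(logits, labels):
    """Mean binary cross-entropy, computed stably from raw logits."""
    total = 0.0
    for x, y in zip(logits, labels):
        # max(x, 0) - x*y + log(1 + exp(-|x|)) avoids overflow for large |x|
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)


edges = [(0, 1), (1, 2), (2, 3)]
true_groups = [0, 0, 1, 1]
labels = edge_labels(edges, true_groups)  # [1.0, 0.0, 1.0]
good = bce_with_logits([4.0, -4.0, 4.0], labels)
bad = bce_with_logits([-4.0, 4.0, -4.0], labels)
print(labels, good, bad)
```

Logits that agree with the labels give a small loss, and confident wrong logits give a large one. The total loss would then combine such terms across the node, edge and global outputs.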