Graph Neural Network with Multi-Modal Encoder
- class GNN_MME.Encoder(input_dim, latent_dim, output_dim)[source]
A simple feed-forward neural network used as an encoder with two linear layers separated by dropout and batch normalization.
- encoder
Contains the linear layers of the encoder.
- Type:
nn.ModuleList
- norm
Contains the batch normalization layers corresponding to each encoder layer.
- Type:
nn.ModuleList
- decoder
A sequential module containing the decoder part of the model.
- Type:
torch.nn.Sequential
- drop
Dropout layer to prevent overfitting.
- Type:
nn.Dropout
- Parameters:
input_dim (int) – Dimensionality of the input features.
latent_dim (int) – Dimensionality of the latent space (middle layer’s output).
output_dim (int) – Dimensionality of the output features after decoding.
- forward(x)[source]
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance itself rather than forward() directly, since the instance call takes care of running the registered hooks while a direct call silently ignores them.
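As an illustration, an encoder of the kind described above can be sketched in a few lines of PyTorch. The exact layer ordering inside GNN_MME.Encoder may differ; ToyEncoder and its arrangement here are assumptions, kept only to show the roles of input_dim, latent_dim, and output_dim.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of a two-layer encoder in the style described above:
# Linear -> BatchNorm -> Dropout -> Linear. The real GNN_MME.Encoder may
# order these layers differently; only the shapes are illustrated here.
class ToyEncoder(nn.Module):
    def __init__(self, input_dim, latent_dim, output_dim):
        super().__init__()
        self.encoder = nn.ModuleList([nn.Linear(input_dim, latent_dim)])
        self.norm = nn.ModuleList([nn.BatchNorm1d(latent_dim)])
        self.drop = nn.Dropout(p=0.5)
        self.decoder = nn.Sequential(nn.Linear(latent_dim, output_dim))

    def forward(self, x):
        # Each encoder layer is followed by batch norm and dropout.
        for lin, bn in zip(self.encoder, self.norm):
            x = self.drop(bn(lin(x)))
        return self.decoder(x)

enc = ToyEncoder(input_dim=100, latent_dim=32, output_dim=16)
out = enc(torch.randn(8, 100))
print(out.shape)  # torch.Size([8, 16])
```

The latent_dim controls the width of the middle layer, while output_dim is the dimension every modality is decoded to before fusion.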
- class GNN_MME.GAT_MME(input_dims, latent_dims, decoder_dim, hidden_feats, num_classes, PNet=None)[source]
A multi-modal GAT (Graph Attention Network) model utilizing encoder modules for initial feature transformation, applying GATConv convolution over the graph structure.
This model combines several data modalities, each processed by separate encoders, integrates the encoded features, and performs graph-based learning to produce node embeddings or class scores.
- encoder_dims
List of encoder modules for each modality of data.
- Type:
nn.ModuleList
- gnnlayers
List of GATConv convolution layers for propagating and transforming node features across the graph.
- Type:
nn.ModuleList
- batch_norms
Batch normalization layers applied to the outputs of all GNN layers except the last.
- Type:
nn.ModuleList
- num_layers
Total number of GNN layers.
- Type:
int
- input_dims
List of input dimensions, one for each data modality.
- Type:
list
- hidden_feats
List of the feature dimensions for each hidden layer in the GNN.
- Type:
list
- num_classes
Number of output classes or the dimension of the output features.
- Type:
int
- drop
Dropout layer applied after each GNN layer for regularization.
- Type:
nn.Dropout
- Parameters:
input_dims (list) – Input dimensions for each modality of input data.
latent_dims (list) – Latent dimensions for corresponding encoders processing each modality of input data.
decoder_dim (int) – Unified dimension to which all modalities are decoded.
hidden_feats (list) – Dimensions for hidden layers of the GNN.
num_classes (int) – Number of classes for classification tasks.
PNet (optional) – A PNet model for embedding pathway networks, used as an optional modality-specific encoder.
- embedding_extraction(g, h, device, batch_size)[source]
Extract embeddings for the nodes in the graph. This method is typically used to retrieve node embeddings that can then be used for visualization, clustering, or as input for downstream tasks.
- Parameters:
g (dgl.DGLGraph) – The graph for which embeddings are to be retrieved.
h (torch.Tensor) – Node features tensor.
device (torch.device) – The device to perform computations on.
batch_size (int) – Size of the batches to use during the computation.
- Returns:
Node embeddings extracted by the model.
- Return type:
torch.Tensor
- forward(g, h)[source]
Forward pass for GAT_MME embedding computation.
- Parameters:
g (dgl.DGLGraph) – Input graph.
h (torch.Tensor) – Feature matrix.
- Returns:
Output after passing through the GNN layers.
- Return type:
torch.Tensor
- inference(g, h, device, batch_size)[source]
Perform a forward pass using the model for inference, without computing gradients. Usually used after the model has been trained.
- Parameters:
g (dgl.DGLGraph) – The DGL graph on which inference is performed.
h (torch.Tensor) – Node features for all nodes in the graph.
device (torch.device) – The device tensors will be sent to.
batch_size (int) – The size of batches to use during inference.
- Returns:
The outputs of the inference.
- Return type:
torch.Tensor
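The batched, gradient-free evaluation that inference() describes can be sketched as below. This is a simplified stand-in: batched_inference is a hypothetical helper, and the real method additionally operates over the DGL graph g (neighbour sampling is omitted here).

```python
import torch

# Hypothetical sketch of inference()-style evaluation: node features are
# processed in fixed-size batches with gradients disabled, and the
# per-batch outputs are concatenated. The real method also traverses the
# DGL graph, which is omitted in this feature-only sketch.
def batched_inference(model, h, batch_size, device):
    model.eval()
    outs = []
    with torch.no_grad():  # no gradients during inference
        for start in range(0, h.shape[0], batch_size):
            batch = h[start:start + batch_size].to(device)
            outs.append(model(batch).cpu())
    return torch.cat(outs, dim=0)

# Usage with a stand-in model:
model = torch.nn.Linear(4, 2)
out = batched_inference(model, torch.randn(10, 4), batch_size=3,
                        device=torch.device("cpu"))
print(out.shape)  # torch.Size([10, 2])
```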
- class GNN_MME.GCN_MME(input_dims, latent_dims, decoder_dim, hidden_feats, num_classes, PNet=None)[source]
A multi-modal GCN (Graph Convolutional Network) model utilizing encoder modules for initial feature transformation, applying GraphConv convolution over the graph structure.
This model combines several data modalities, each processed by separate encoders, integrates the encoded features, and performs graph-based learning to produce node embeddings or class scores.
- encoder_dims
List of encoder modules for each modality of data.
- Type:
nn.ModuleList
- gnnlayers
List of GraphConv convolution layers for propagating and transforming node features across the graph.
- Type:
nn.ModuleList
- batch_norms
Batch normalization layers applied to the outputs of all GNN layers except the last.
- Type:
nn.ModuleList
- num_layers
Total number of GNN layers.
- Type:
int
- input_dims
List of input dimensions, one for each data modality.
- Type:
list
- hidden_feats
List of the feature dimensions for each hidden layer in the GNN.
- Type:
list
- num_classes
Number of output classes or the dimension of the output features.
- Type:
int
- drop
Dropout layer applied after each GNN layer for regularization.
- Type:
nn.Dropout
- Parameters:
input_dims (list) – Input dimensions for each modality of input data.
latent_dims (list) – Latent dimensions for corresponding encoders processing each modality of input data.
decoder_dim (int) – Unified dimension to which all modalities are decoded.
hidden_feats (list) – Dimensions for hidden layers of the GNN.
num_classes (int) – Number of classes for classification tasks.
PNet (optional) – A PNet model for embedding pathway networks, used as an optional modality-specific encoder.
- embedding_extraction(g, h, device, batch_size)[source]
Extract embeddings for the nodes in the graph. This method is typically used to retrieve node embeddings that can then be used for visualization, clustering, or as input for downstream tasks.
- Parameters:
g (dgl.DGLGraph) – The graph for which embeddings are to be retrieved.
h (torch.Tensor) – Node features tensor.
device (torch.device) – The device to perform computations on.
batch_size (int) – Size of the batches to use during the computation.
- Returns:
Node embeddings extracted by the model.
- Return type:
torch.Tensor
- feature_importance(test_dataloader, device)[source]
Calculate feature importances using the Conductance algorithm through Captum.
- Parameters:
test_dataloader (torch.utils.data.DataLoader) – DataLoader yielding the test data for which to calculate importances.
device (torch.device) – The device to perform computations on.
- Returns:
A dataframe containing the feature importances.
- Return type:
pd.DataFrame
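To illustrate the return format of feature_importance(), the sketch below assembles attribution scores into a pd.DataFrame. It substitutes a plain gradient-times-input attribution for Captum's Conductance algorithm, and feature_importance_sketch with its arguments is hypothetical.

```python
import torch
import pandas as pd

# Simplified, hypothetical sketch of feature attribution. The real
# feature_importance() uses Captum's Conductance algorithm; plain
# gradient x input is used here only to show the DataFrame output format.
def feature_importance_sketch(model, x, target_class, feature_names):
    x = x.clone().requires_grad_(True)
    model(x)[:, target_class].sum().backward()
    # Average |gradient * input| over the batch as a per-feature score.
    importances = (x.grad * x).abs().mean(dim=0)
    return pd.DataFrame({
        "feature": feature_names,
        "importance": importances.detach().numpy(),
    })

model = torch.nn.Linear(3, 2)
df = feature_importance_sketch(model, torch.randn(5, 3),
                               target_class=0,
                               feature_names=["gene_a", "gene_b", "gene_c"])
```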
- forward(g, h)[source]
Forward pass for GCN_MME embedding computation.
- Parameters:
g (dgl.DGLGraph) – Input graph.
h (torch.Tensor) – Feature matrix.
- Returns:
Output after passing through the GNN layers.
- Return type:
torch.Tensor
- inference(g, h, device, batch_size)[source]
Perform a forward pass using the model for inference, without computing gradients. Usually used after the model has been trained.
- Parameters:
g (dgl.DGLGraph) – The DGL graph on which inference is performed.
h (torch.Tensor) – Node features for all nodes in the graph.
device (torch.device) – The device tensors will be sent to.
batch_size (int) – The size of batches to use during inference.
- Returns:
The outputs of the inference.
- Return type:
torch.Tensor
- layerwise_importance(test_dataloader, device)[source]
Compute layer-wise importance scores across all layers for given targets.
- Parameters:
test_dataloader (torch.utils.data.DataLoader) – DataLoader yielding the test data for which to calculate importances.
device (torch.device) – The device to perform computations on.
- Returns:
A list containing the importance scores for each layer.
- Return type:
List[pd.DataFrame]
- class GNN_MME.GSage_MME(input_dims, latent_dims, decoder_dim, hidden_feats, num_classes, PNet=None)[source]
A multi-modal GraphSAGE model utilizing encoder modules for initial feature transformation, applying GraphSAGE convolution over the graph structure.
This model combines several data modalities, each processed by separate encoders, integrates the encoded features, and performs graph-based learning to produce node embeddings or class scores.
- encoder_dims
List of encoder modules for each modality of data.
- Type:
nn.ModuleList
- gnnlayers
List of GraphSAGE convolution layers for propagating and transforming node features across the graph.
- Type:
nn.ModuleList
- batch_norms
Batch normalization layers applied to the outputs of all GNN layers except the last.
- Type:
nn.ModuleList
- num_layers
Total number of GNN layers.
- Type:
int
- input_dims
List of input dimensions, one for each data modality.
- Type:
list
- hidden_feats
List of the feature dimensions for each hidden layer in the GNN.
- Type:
list
- num_classes
Number of output classes or the dimension of the output features.
- Type:
int
- drop
Dropout layer applied after each GNN layer for regularization.
- Type:
nn.Dropout
- Parameters:
input_dims (list) – Input dimensions for each modality of input data.
latent_dims (list) – Latent dimensions for corresponding encoders processing each modality of input data.
decoder_dim (int) – Unified dimension to which all modalities are decoded.
hidden_feats (list) – Dimensions for hidden layers of the GNN.
num_classes (int) – Number of classes for classification tasks.
PNet (optional) – A PNet model for embedding pathway networks, used as an optional modality-specific encoder.
- embedding_extraction(g, h, device, batch_size)[source]
Extract embeddings for the nodes in the graph. This method is typically used to retrieve node embeddings that can then be used for visualization, clustering, or as input for downstream tasks.
- Parameters:
g (dgl.DGLGraph) – The graph for which embeddings are to be retrieved.
h (torch.Tensor) – Node features tensor.
device (torch.device) – The device to perform computations on.
batch_size (int) – Size of the batches to use during the computation.
- Returns:
Node embeddings extracted by the model.
- Return type:
torch.Tensor
- forward(g, h)[source]
Forward pass for GSage_MME embedding computation.
- Parameters:
g (dgl.DGLGraph) – Input graph.
h (torch.Tensor) – Feature matrix.
- Returns:
Output after passing through the GNN layers.
- Return type:
torch.Tensor
- inference(g, h, device, batch_size)[source]
Perform a forward pass using the model for inference, without computing gradients. Usually used after the model has been trained.
- Parameters:
g (dgl.DGLGraph) – The DGL graph on which inference is performed.
h (torch.Tensor) – Node features for all nodes in the graph.
device (torch.device) – The device tensors will be sent to.
batch_size (int) – The size of batches to use during inference.
- Returns:
The outputs of the inference.
- Return type:
torch.Tensor
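Finally, the multi-modal front end shared by GCN_MME, GAT_MME, and GSage_MME can be sketched as follows: one encoder per modality maps its input to a common decoder_dim, and the encoded modalities are fused before the GNN layers. Summation fusion and the toy dimensions below are assumptions; the actual fusion rule in GNN_MME may differ.

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the multi-modal front end: one encoder per
# modality, each ending in the shared decoder_dim. Summation fusion is an
# assumption; the real models may combine modalities differently.
input_dims, latent_dims, decoder_dim = [50, 30], [16, 8], 12

encoders = nn.ModuleList(
    nn.Sequential(nn.Linear(d_in, d_lat), nn.ReLU(),
                  nn.Linear(d_lat, decoder_dim))
    for d_in, d_lat in zip(input_dims, latent_dims)
)

# 6 nodes, two modalities with different feature widths.
modalities = [torch.randn(6, 50), torch.randn(6, 30)]

# Encode each modality, then fuse into one (nodes x decoder_dim) tensor
# that the GNN layers would consume.
fused = torch.stack([enc(x) for enc, x in zip(encoders, modalities)]).sum(dim=0)
print(fused.shape)  # torch.Size([6, 12])
```

Because every encoder decodes to the same decoder_dim, modalities with very different input widths end up in a shared space before graph convolution.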