MOGDx & PNet
This notebook demonstrates how MOGDx can be used with a biologically interpretable encoder architecture modelled on the one described in the PNet paper.
import pandas as pd
import numpy as np
import os
import sys
sys.path.insert(0 , './../')
from MAIN.utils import *
from MAIN.train import *
import MAIN.preprocess_functions
from MAIN.GNN_MME import GCN_MME , GSage_MME , GAT_MME
from Modules.PNetTorch.MAIN.reactome import ReactomeNetwork
from Modules.PNetTorch.MAIN.Pnet import MaskedLinear , PNET
from Modules.PNetTorch.MAIN.utils import numpy_array_to_one_hot, get_gpu_memory
from Modules.PNetTorch.MAIN.interpret import interpret , evaluate_interpret_save , visualize_importances
import torch
import torch.nn.functional as F
import dgl
from dgl.dataloading import MultiLayerFullNeighborSampler
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold
import networkx as nx
from datetime import datetime
import joblib
import warnings
import gc
import copy
print("Finished Library Import \n")
# ---- Run configuration ----

# Directory holding the raw modality tables, and the file name of the
# GraphML network that is later read from the sibling 'Networks' directory.
data_input = './../../data/TCGA/BRCA/raw/'
snf_net = 'RPPA_mRNA_graph.graphml'

# Index column and target column names handed to the data parser.
index_col = 'index'
target = 'paper_BRCA_Subtype_PAM50'

# Switches: collect interpretation scores / build the PNet Reactome network.
interpret_feat = True
pnet = True

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using %s device" % device)
# Load the requested modalities ('mRNA' and 'RPPA') together with the target
# labels through the project helper data_parsing.
datModalities , meta = data_parsing(data_input , ['mRNA' , 'RPPA' ] , target , index_col)
# When interpretation is enabled, store each modality's column names under the
# modality's positional index (0, 1, ...).
# NOTE(review): block indentation is not visible in this view; the loop below is
# assumed to belong to the `if` body - confirm against the full notebook.
if interpret_feat :
features = {}
for i , mod in enumerate(datModalities) :
features[i] = list(datModalities[mod].columns)
# Containers filled during cross-validation: per-layer score tables
# (model_scores) and per-fold layer-wise importances (layer_importance_scores).
model_scores = {}
layer_importance_scores = {}
if pnet:
    # Cancer gene list taken from the PNet paper dataset (tab-separated file
    # whose first row is the header).
    genes = pd.read_csv(
        './../../data/genelists/BRCA_genelist.txt',
        header=0,
        sep='\t',
    )

    # Reactome network over the de-duplicated genes of interest, built with
    # five levels, to obtain the gene and pathway relationships.
    net = ReactomeNetwork(
        genes_of_interest=np.unique(list(genes['genes'].values)),
        n_levels=5,
    )
# Read the GraphML network from the 'Networks' directory next to data_input.
graph_file = data_input + '../Networks/' + snf_net
g = nx.read_graphml(graph_file)
# Restrict the labels to the nodes present in the graph, then order them by
# sorted index.
meta = meta.loc[list(g.nodes())]
meta = meta.loc[sorted(meta.index)]
# NOTE(review): the next statement is cut off in this view (unbalanced
# parentheses); it appears to one-hot encode the categorical labels - restore
# it from the full notebook before running.
label = F.one_hot(torch.Tensor(list(meta.astype('category')
# 5-fold stratified cross-validation with shuffling. No random_state is given,
# so the fold assignment differs between runs.
skf = StratifiedKFold(n_splits=5 , shuffle=True)
# Number of feature columns per modality; passed to the model constructor in
# the cross-validation loop.
MME_input_shapes = [ datModalities[mod].shape[1] for mod in datModalities]
# Merge all modality tables into one feature table, then align its rows with
# the label index and sort them by index.
# NOTE(review): `reduce` and `merge_dfs` are not imported explicitly in this
# notebook; presumably they arrive through the star imports - confirm.
h = reduce(merge_dfs , list(datModalities.values()))
h = h.loc[meta.index]
h = h.loc[sorted(h.index)]
# Convert the NetworkX graph to a DGL graph, carrying over the 'idx' and
# 'label' node attributes.
g = dgl.from_networkx(g , node_attrs=['idx' , 'label'])
# Attach the merged features and the one-hot labels as node data; this
# overwrites the 'label' attribute imported from the GraphML file.
g.ndata['feat'] = torch.Tensor(h.to_numpy())
g.ndata['label'] = label
#g = dgl.add_self_loop(g)
# The per-modality tables are not used again after being merged into h.
del datModalities
# Accumulators for the per-fold metrics and for the test logits/labels.
output_metrics = []
test_logits = []
test_labels = []
# Cross-validation: one model is built, trained and evaluated per stratified fold.
# NOTE(review): indentation and several statements of this loop are missing from
# this view (the truncated `g =` line, the unclosed sampler / DataLoader calls,
# and whatever appends to output_metrics, test_logits and test_labels) -
# consult the full notebook before editing this cell.
for i, (train_index, test_index) in enumerate(skf.split(meta.index, meta)) :
# NOTE(review): `net` is only assigned under `if pnet` earlier in the notebook;
# presumably this constructor call fails when pnet is False - confirm.
model = GCN_MME(MME_input_shapes , [16 , 32] , 32 , [16] , len(meta.unique()), PNet=net).to(device)
g =
# The positional arguments after `label` are presumably epochs, learning rate
# and patience - TODO confirm against MAIN.train.train.
loss_plot = train(g, train_index, device , model , label , 2000 , 1e-3 , 100 , batch_size=1024)
plt.title(f'Loss for split {i}')
sampler = MultiLayerFullNeighborSampler(
len(model.gnnlayers), # number of GNN layers (the full-neighbour sampler takes a layer count, not a fanout)
test_dataloader = DataLoader(
test_output_metrics = evaluate(model , g, test_dataloader)
# NOTE(review): as visible here this formatted string is a bare expression; the
# enclosing print( call appears to be missing from this view.
"Fold : {:01d} | Test Accuracy = {:.4f} | F1 = {:.4f} ".format(
i+1 , test_output_metrics[1] , test_output_metrics[2] )
# Keep a CPU copy of the model from the fold whose test_output_metrics[1]
# (reported above as test accuracy) is the highest seen so far.
if i == 0 :
best_model = copy.deepcopy(model).to('cpu')
best_idx = i
elif output_metrics[best_idx][1] < test_output_metrics[1] :
best_model = copy.deepcopy(model).to('cpu')
best_idx = i
if interpret_feat :
# Flatten the per-modality column-name lists into one feature-name list.
model.features = [element for sublist in features.values() for element in sublist]
# Fold 0 creates the 'Input Features' table from the mean absolute feature
# importance; later folds add one row labelled with the fold index.
if i ==0 :
model_scores['Input Features'] = {}
model_scores['Input Features']['mad'] = pd.DataFrame(model.feature_importance(test_dataloader , device).abs().mean(axis=0)).T
else :
model_scores['Input Features']['mad'].loc[i] = model.feature_importance(test_dataloader , device).abs().mean(axis=0)
layer_importance_scores[i] = model.layerwise_importance(test_dataloader , device)
# Get the number of layers of the model
n_layers = len(next(iter(layer_importance_scores[i].values())))
# Sum corresponding modalities importances
# 'mad': per layer, the mean absolute importance summed over the dictionary keys.
mean_absolute_distance = [sum([layer_importance_scores[i][k][ii].abs().mean() for k in layer_importance_scores[i].keys()]) for ii in range(n_layers)]
# 'sva': per layer, the standard deviation scaled by its maximum, summed over the keys.
summed_variation_attr = [sum([layer_importance_scores[i][k][ii].std()/max(layer_importance_scores[i][k][ii].std()) for k in layer_importance_scores[i].keys()]) for ii in range(n_layers)]
# Layer 0 is stored as "Gene Importance"; layer ii > 0 as "Pathway Level ii Importance".
for ii , (mad , sva) in enumerate(zip(mean_absolute_distance , summed_variation_attr)) :
layer_title = f"Pathway Level {ii} Importance" if ii > 0 else "Gene Importance"
# Fold 0 creates the per-layer tables; later folds add one row per fold.
if i == 0 :
model_scores[layer_title] = {}
model_scores[layer_title]['mad'] = pd.DataFrame(mad).T
model_scores[layer_title]['sva'] = pd.DataFrame(sva).T
else :
model_scores[layer_title]['mad'].loc[i] = mad
model_scores[layer_title]['sva'].loc[i] = sva
# Drop the references to this fold's model and dataloader before the next fold.
del model , test_dataloader
print('Clearing gpu memory')
# Stack the collected test logits and labels (lists created before the
# cross-validation loop) into single tensors.
test_logits = torch.stack(test_logits)
test_labels = torch.stack(test_labels)
accuracy = []
F1 = []
i = 0
# NOTE(review): the body of this loop is not visible in this view; presumably it
# fills `accuracy` and `F1` from each fold's metrics - confirm.
for metric in output_metrics :
# Report mean and standard deviation across folds, as percentages.
# NOTE(review): the fold count 5 is hard-coded in these messages and has to be
# kept in step with n_splits of the StratifiedKFold splitter.
print("%i Fold Cross Validation Accuracy = %2.2f \u00B1 %2.2f" %(5 , np.mean(accuracy)*100 , np.std(accuracy)*100))
print("%i Fold Cross Validation F1 = %2.2f \u00B1 %2.2f" %(5 , np.mean(F1)*100 , np.std(F1)*100))
# Confusion matrix over the stacked test predictions, labelled with the class
# category names; the title shows the mean fold accuracy.
confusion_matrix(test_logits , test_labels , meta.astype('category').cat.categories)
plt.title('Test Accuracy = %2.1f %%' % (np.mean(accuracy)*100))
# AUROC is a project helper; it returns a plot object and per-sample predictions.
precision_recall_plot , all_predictions_conf = AUROC(test_logits, test_labels , meta)
node_predictions = []
node_true = []
display_label = meta.astype('category').cat.categories
# NOTE(review): the body of this loop is not visible in this view; presumably it
# fills node_predictions and node_true - confirm.
for pred , true in zip(all_predictions_conf.argmax(1) , list(test_labels.detach().cpu().argmax(1).numpy())) :
# Table of actual versus predicted values built from the two lists above.
preds = pd.DataFrame({'Actual' : node_true , 'Predicted' : node_predictions})
Graph(num_nodes=1076, num_edges=18300,
ndata_schemes={'idx': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(5,), dtype=torch.int64), 'feat': Scheme(shape=(30459,), dtype=torch.float32)}
Fold : 1 | Test Accuracy = 0.8380 | F1 = 0.8185
Graph(num_nodes=1076, num_edges=18300,
ndata_schemes={'idx': Scheme(shape=(), dtype=torch.int64), 'label': Scheme(shape=(5,), dtype=torch.int64), 'feat': Scheme(shape=(30459,), dtype=torch.float32)}
Fold : 2 | Test Accuracy = 0.8512 | F1 = 0.8233
Fold : 3 | Test Accuracy = 0.8558 | F1 = 0.8163
Fold : 4 | Test Accuracy = 0.8791 | F1 = 0.8675
Fold : 5 | Test Accuracy = 0.8698 | F1 = 0.8441
5 Fold Cross Validation Accuracy = 85.88 ± 1.44
5 Fold Cross Validation F1 = 83.39 ± 1.94


from Modules.PNetTorch.MAIN.interpret import pnet_significance_testing, pnet_model_significance_testing , calculate_es
# Run the library's per-layer significance test on the 'mad' score tables that
# were collected across the cross-validation folds.
sig_feats = pnet_significance_testing(model_scores , key = 'mad')
16960 Feautures have p-value < 0.01 in Layer Input Features:
212 Feautures have p-value < 0.01 in Layer Gene Importance:
48 Feautures have p-value < 0.01 in Layer Pathway Level 1 Importance:
19 Feautures have p-value < 0.01 in Layer Pathway Level 2 Importance:
6 Feautures have p-value < 0.01 in Layer Pathway Level 3 Importance:
4 Feautures have p-value < 0.01 in Layer Pathway Level 4 Importance:
2 Feautures have p-value < 0.01 in Layer Pathway Level 5 Importance:
# Average the per-fold 'sva' gene importances and join them to the gene list
# so that every scored gene carries its annotation columns.
mean_gene_sva = model_scores['Gene Importance']['sva'].mean(axis=0).reset_index()
avg_gene_ranking = mean_gene_sva.merge(genes, left_on='index', right_on='genes')

# Display the 20 genes with the highest averaged score (held in column 0).
avg_gene_ranking.sort_values(by=0, ascending=False).head(20)
index | 0 | genes | group | |
1645 | RUNX1 | 0.593016 | RUNX1 | pos |
1414 | PIK3R1 | 0.582591 | PIK3R1 | pos |
1371 | PB1 | 0.571091 | PB1 | neg |
1068 | MAPK1 | 0.565907 | MAPK1 | pos |
1503 | PSMA7 | 0.565385 | PSMA7 | neg |
1471 | PPP2R5C | 0.545946 | PPP2R5C | neg |
1766 | SMAD3 | 0.542921 | SMAD3 | pos |
1519 | PTPN11 | 0.522787 | PTPN11 | pos |
1469 | PPP2R1A | 0.518152 | PPP2R1A | pos |
1810 | SRC | 0.516633 | SRC | pos |
811 | HSP90AA1 | 0.510108 | HSP90AA1 | pos |
1590 | RHOA | 0.462977 | RHOA | pos |
898 | JAK1 | 0.461686 | JAK1 | pos |
412 | CREB1 | 0.444376 | CREB1 | pos |
1833 | STAT5B | 0.431349 | STAT5B | pos |
446 | CYP2C8 | 0.427268 | CYP2C8 | pos |
316 | CDK1 | 0.425561 | CDK1 | neg |
1547 | RAC1 | 0.419924 | RAC1 | pos |
426 | CSPG4 | 0.414304 | CSPG4 | neg |
1506 | PSMC5 | 0.411303 | PSMC5 | neg |
# Run the library's gene-set significance test on the per-fold 'sva' gene
# scores, using the genes whose group is 'pos' as the gene set.
positive_genes = genes.loc[genes['group'] == 'pos', 'genes']
pnet_model_significance_testing(model_scores['Gene Importance']['sva'], positive_genes)
Enrichment Score for Pos: 0.342657803328948
The Observed Effect Size (ES) of genes related to outcome is 0.342657803328948 with significance p-value 0.005994005994005994

Enrichment Score for Pos: 0.3150978000639709
The Observed Effect Size (ES) of genes related to outcome is 0.3150978000639709 with significance p-value 0.001998001998001998

Enrichment Score for Pos: 0.4337781609699964
The Observed Effect Size (ES) of genes related to outcome is 0.4337781609699964 with significance p-value 0.0

Enrichment Score for Pos: 0.41806920601503694
The Observed Effect Size (ES) of genes related to outcome is 0.41806920601503694 with significance p-value 0.0

Enrichment Score for Pos: 0.34847527669874134
The Observed Effect Size (ES) of genes related to outcome is 0.34847527669874134 with significance p-value 0.0

# Permutation test for enrichment of the 'pos' gene set among the
# fold-averaged 'sva' gene importances.
N_PERMUTATIONS = 1000

# S: the gene set of interest (genes whose group is 'pos' in the gene list).
S = np.array(genes[genes['group'] == 'pos']['genes'])

# r: fold-averaged gene scores, ranked from highest to lowest; r_index keeps
# the gene names in the same order.
r = model_scores['Gene Importance']['sva'].mean(axis=0).sort_values(ascending=False)
r_index = np.array(r.index)
r = r.to_numpy()

# Observed statistic: membership of each ranked gene in S, and the sum of the
# absolute scores of the member genes.
hits = np.isin(r_index , S)
N_R = np.sum(np.abs(r[hits])**1)
real_es_pos = calculate_es(S , r, hits , N_R)
print(f"Enrichment Score for Pos: {real_es_pos}")

# Null distribution: shuffle the gene names against the fixed ranking and
# recompute the statistic. Each score has to be stored; without the append the
# list stays empty, the p-value is always 0 and the histogram below is blank.
perm_es_scores = []
for _ in range(N_PERMUTATIONS):
    hits = np.isin(np.random.permutation(r_index) , S)
    perm_es = calculate_es(S, r, hits , N_R)
    perm_es_scores.append(perm_es)

# Randomised-test p-value: the observed statistic counts as one draw from the
# null, so both numerator and denominator are increased by one. This keeps the
# estimate strictly positive instead of reporting an impossible p-value of 0.
n_as_extreme = np.sum(np.asarray(perm_es_scores) >= real_es_pos)
p_value = (n_as_extreme + 1) / (N_PERMUTATIONS + 1)
print(f'The Observed Effect Size (ES) of genes related to outcome is {real_es_pos} with significance p-value {p_value}')

# Plotting the permutation ES scores
plt.hist(perm_es_scores, bins=30, alpha=0.75, label='Permutation ES')
plt.axvline(x=real_es_pos, color='red', label='Observed ES')
plt.title('Permutation Test for Gene Set Enrichment')
plt.xlabel('Enrichment Score (ES)')
# The label= arguments above are only rendered once a legend is drawn.
plt.legend()
Enrichment Score for Pos: 0.35872396224896697
The Observed Effect Size (ES) of genes related to outcome is 0.35872396224896697 with significance p-value 0.0

# Build one figure per entry of model_scores.
#   model_layers_importance     : display title -> key in model_scores
#   model_layers_importance_fig : display title -> figure
model_layers_importance = {}
model_layers_importance_fig= {}
for i, layer in enumerate(model_scores):
    if i == 0 :
        # The first entry of model_scores is the input-feature table.
        fig = plt.figure(figsize=(12,6))
        plt.xticks(rotation=45, ha='right', rotation_mode='anchor')
        plt.title('Input Feature Importance')
        model_layers_importance_fig['Feature Importance'] = fig
    else :
        # enumerate() also counts the input-feature entry, so the gene layer is
        # at i == 1 and pathway level n is at i == n + 1. Using i - 1 keeps the
        # title in step with the key (previously 'Pathway Level 1 Importance'
        # was titled "Pathway Level 2 Importance", and so on).
        layer_title = f"Pathway Level {i - 1} Importance" if i > 1 else "Gene Importance"
        model_layers_importance[layer_title] = layer
        model_layers_importance_fig[layer_title] = visualize_importances(
            model_scores[layer]['sva'].mean(axis=0), title=f"Average {layer_title}")







#del model , train_loader , test_loader
# test the model
# NOTE(review): this call is syntactically incomplete in this view (doubled
# comma and unclosed parenthesis); an argument appears to be missing - restore
# it from the full notebook.
# Inference is requested for every node of g (np.arange(len(g.nodes()))), so the
# figure printed as "Test Accuracy" presumably includes nodes used for
# training - confirm.
acc = layerwise_infer(
device, g, np.arange(len(g.nodes())),, batch_size=4096
print("Test Accuracy {:.4f}".format(acc.item()))
Test Accuracy 0.9015
# Extract node embeddings with the model kept from the best fold; gradient
# tracking is switched off for the forward pass.
# The trailing 4096 is presumably the inference batch size - TODO confirm.
with torch.no_grad():
    node_features = g.ndata['feat']
    emb = best_model.embedding_extraction(g, node_features, device, 4096)

# Move the embeddings to host memory as a NumPy array and pass them, together
# with meta, to the project's t-SNE plotting helper.
embedding_array = emb.detach().cpu().numpy()
tsne_embedding_plot(embedding_array, meta)
