Here we will be using a small subset of the BeatAML data to predict the AUC of drug response values.
Our predictive variables will be transcriptomics, proteomics, and drug structure.
This small subset does not have high predictive power, so this notebook is meant only as a guide for creating your own models.
import pandas as pd
import numpy as np
import coderdata as cd
We will be using the Keras functional API. This allows us to use multiple inputs in our model.
You may need to install a couple of these dependencies to proceed.
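If any of the imports below fail, an install cell along these lines should cover them. The package names are assumptions inferred from the imports (RDKit's PyPI name in particular has changed over time, so check your environment):

%pip install coderdata tensorflow rdkit scikit-learn matplotlib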
import keras
from keras import layers
from keras.metrics import MeanAbsoluteError
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from rdkit import Chem
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
The first step is to download the data. We will then import the data into the DatasetLoader object.
If there are issues with the download process, please ensure that your connection is not blocked by a VPN.
cd.download_data_by_prefix('beataml')
beataml = cd.DatasetLoader("beataml")
Downloaded beataml_drugs.tsv.gz to local repository.
Downloaded beataml_proteomics.csv.gz to local repository.
Downloaded beataml_mutations.csv.gz to local repository.
Downloaded beataml_experiments.csv.gz to local repository.
Downloaded beataml_samples.csv to local repository.
Downloaded beataml_transcriptomics.csv.gz to local repository.
Processing Data...
This info command allows us to see which data types are present in the beataml DatasetLoader object.
We will be using the Transcriptomics, Proteomics, Drugs, and Experiments Data in our example.
In other examples you may want to include the samples data to associate cancer and model type with each sample. However, in the BeatAML data, all data is from Acute Myeloid Leukemia patients.
beataml.info()
Dataset Type: beataml
Beat acute myeloid leukemia (BeatAML) data was collected through GitHub and Synapse.
Available Datatypes and Their Formats:
- drugs: Format not specified
- experiments: Format not specified
- mutations: long format
- proteomics: long format
- samples: long format
- transcriptomics: long format
beataml.drugs
| improve_drug_id | chem_name | formula | weight | inCHIKey | canSMILES | isoSMILES
---|---|---|---|---|---|---|---|
0 | SMI_43 | 17-AAG | C31H43N3O8 | 585.70 | AYUNIORJHRXIBJ-TXHRRWQRSA-N | CC1CC(C(C(C=C(C(C(C=CC=C(C(=O)NC2=CC(=O)C(=C(C... | C[C@H]1C[C@@H]([C@@H]([C@H](/C=C(/[C@@H]([C@H]... |
1 | SMI_3222 | A-674563 | C22H22N4O | 358.40 | BPNUQXPIQBZCMR-IBGZPJMESA-N | CC1=C2C=C(C=CC2=NN1)C3=CC(=CN=C3)OCC(CC4=CC=CC... | CC1=C2C=C(C=CC2=NN1)C3=CC(=CN=C3)OC[C@H](CC4=C... |
2 | SMI_146 | ABT-737 | C42H45ClN6O5S2 | 813.40 | HPLNQCPCUACXLM-PGUFJCEWSA-N | CN(C)CCC(CSC1=CC=CC=C1)NC2=C(C=C(C=C2)S(=O)(=O... | CN(C)CC[C@H](CSC1=CC=CC=C1)NC2=C(C=C(C=C2)S(=O... |
3 | SMI_57721 | AKT Inhibitor IV | C31H27IN4S | 614.50 | NAYRELMNTQSBIN-UHFFFAOYSA-M | CC[N+]1=C(N(C2=C1C=C(C=C2)C3=NC4=CC=CC=C4S3)C5... | CC[N+]1=C(N(C2=C1C=C(C=C2)C3=NC4=CC=CC=C4S3)C5... |
4 | SMI_57735 | AMPK Inhibitor | C24H25N5O | 399.50 | XHBVYDAKJHETMP-UHFFFAOYSA-N | C1CCN(CC1)CCOC2=CC=C(C=C2)C3=CN4C(=C(C=N4)C5=C... | C1CCN(CC1)CCOC2=CC=C(C=C2)C3=CN4C(=C(C=N4)C5=C... |
... | ... | ... | ... | ... | ... | ... | ... |
35420 | SMI_414 | Venetoclax | C45H50ClN7O7S | 868.40 | LQBVNQSMGBZMKD-UHFFFAOYSA-N | CC1(CCC(=C(C1)C2=CC=C(C=C2)Cl)CN3CCN(CC3)C4=CC... | CC1(CCC(=C(C1)C2=CC=C(C=C2)Cl)CN3CCN(CC3)C4=CC... |
35421 | SMI_1062 | Vismodegib | C19H14Cl2N2O3S | 421.30 | BPQMGSKTAYIVFO-UHFFFAOYSA-N | CS(=O)(=O)C1=CC(=C(C=C1)C(=O)NC2=CC(=C(C=C2)Cl... | CS(=O)(=O)C1=CC(=C(C=C1)C(=O)NC2=CC(=C(C=C2)Cl... |
35422 | SMI_3146 | Volasertib | C34H50N8O3 | 618.80 | SXNJFOWDRLKDSF-XKHVUIRMSA-N | CCC1C(=O)N(C2=CN=C(N=C2N1C(C)C)NC3=C(C=C(C=C3)... | CC[C@@H]1C(=O)N(C2=CN=C(N=C2N1C(C)C)NC3=C(C=C(... |
35423 | SMI_1137 | XAV-939 | C14H11F3N2OS | 312.31 | KLGQSVMIPOVQAX-UHFFFAOYSA-N | C1CSCC2=C1N=C(NC2=O)C3=CC=C(C=C3)C(F)(F)F | C1CSCC2=C1N=C(NC2=O)C3=CC=C(C=C3)C(F)(F)F |
35424 | SMI_181 | YM-155 | C20H19BrN4O3 | 443.30 | QBIYUDDJPRGKNJ-UHFFFAOYSA-M | CC1=[N+](C2=C(N1CCOC)C(=O)C3=CC=CC=C3C2=O)CC4=... | CC1=[N+](C2=C(N1CCOC)C(=O)C3=CC=CC=C3C2=O)CC4=... |
35425 rows × 7 columns
beataml.experiments
| source | improve_sample_id | improve_drug_id | study | auc | ic50 | ec50 | ec50se | r2fit | einf | hs | aac1 | auc1 | dss1
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | synapse | 4634 | SMI_43 | BeatAML | 0.5716 | NaN | 11.430 | 0.000 | 0.0000 | 0.5716 | 3.525 | 0.4284 | 0.5716 | 0.3649 |
1 | synapse | 4634 | SMI_3222 | BeatAML | 0.6310 | NaN | 10.890 | 0.000 | -0.0000 | 0.6310 | 3.171 | 0.3690 | 0.6310 | 0.2989 |
2 | synapse | 4634 | SMI_146 | BeatAML | 0.7496 | NaN | 11.590 | 0.000 | 0.0000 | 0.7496 | 3.594 | 0.2504 | 0.7496 | 0.1671 |
3 | synapse | 4634 | SMI_57721 | BeatAML | 0.7264 | NaN | 7.525 | 3.249 | 0.0277 | 0.5343 | 4.000 | 0.4609 | 0.5391 | 0.4010 |
4 | synapse | 4634 | SMI_57735 | BeatAML | 0.5588 | NaN | 11.540 | 0.000 | -0.0000 | 0.5588 | 3.611 | 0.4412 | 0.5588 | 0.3791 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
38775 | synapse | 4346 | SMI_414 | BeatAML | 0.4791 | 10.94 | 11.360 | 0.000 | -0.0000 | 0.4791 | 3.294 | 0.5209 | 0.4791 | 0.4677 |
38776 | synapse | 4346 | SMI_1062 | BeatAML | 0.9319 | NaN | 9.805 | 0.000 | -0.0000 | 0.8638 | 0.000 | 0.0681 | 0.9319 | 0.0000 |
38777 | synapse | 4346 | SMI_3146 | BeatAML | 0.4195 | 10.81 | 11.060 | 0.000 | -0.0000 | 0.4195 | 3.247 | 0.5805 | 0.4195 | 0.5339 |
38778 | synapse | 4346 | SMI_1137 | BeatAML | 0.8347 | NaN | 0.036 | 0.000 | -0.0000 | 0.6693 | 0.000 | 0.1653 | 0.8347 | 0.0726 |
38779 | synapse | 4346 | SMI_181 | BeatAML | 0.7405 | NaN | 11.370 | 0.000 | -0.0000 | 0.7405 | 3.878 | 0.2595 | 0.7405 | 0.1773 |
38780 rows × 14 columns
This merge allows us to place the isoSMILES, sample ID, drug ID, and AUC values together.
merged_df = pd.merge(beataml.experiments[['improve_sample_id', 'improve_drug_id', 'auc']],
beataml.drugs[['improve_drug_id', 'isoSMILES']],
on='improve_drug_id',
how='inner').drop_duplicates()
merged_df = merged_df.dropna(subset=['isoSMILES'])
merged_df
| improve_sample_id | improve_drug_id | auc | isoSMILES
---|---|---|---|---|
0 | 4634 | SMI_43 | 0.5716 | C[C@H]1C[C@@H]([C@@H]([C@H](/C=C(/[C@@H]([C@H]... |
284 | 4634 | SMI_3222 | 0.6310 | CC1=C2C=C(C=CC2=NN1)C3=CC(=CN=C3)OC[C@H](CC4=C... |
569 | 4634 | SMI_146 | 0.7496 | CN(C)CC[C@H](CSC1=CC=CC=C1)NC2=C(C=C(C=C2)S(=O... |
781 | 4634 | SMI_57721 | 0.7264 | CC[N+]1=C(N(C2=C1C=C(C=C2)C3=NC4=CC=CC=C4S3)C5... |
846 | 4634 | SMI_57735 | 0.5588 | C1CCN(CC1)CCOC2=CC=C(C=C2)C3=CN4C(=C(C=N4)C5=C... |
... | ... | ... | ... | ... |
10668667 | 4346 | SMI_414 | 0.4791 | CC1(CCC(=C(C1)C2=CC=C(C=C2)Cl)CN3CCN(CC3)C4=CC... |
10668944 | 4346 | SMI_1062 | 0.9319 | CS(=O)(=O)C1=CC(=C(C=C1)C(=O)NC2=CC(=C(C=C2)Cl... |
10669230 | 4346 | SMI_3146 | 0.4195 | CC[C@@H]1C(=O)N(C2=CN=C(N=C2N1C(C)C)NC3=C(C=C(... |
10669505 | 4346 | SMI_1137 | 0.8347 | C1CSCC2=C1N=C(NC2=O)C3=CC=C(C=C3)C(F)(F)F |
10669748 | 4346 | SMI_181 | 0.7405 | CC1=[N+](C2=C(N1CCOC)C(=O)C3=CC=CC=C3C2=O)CC4=... |
34850 rows × 4 columns
The chemical fingerprint is then mapped to improve_drug_id in the fingerprint_map. These values will be retrieved later. The merged_df is also subsetted to remove redundant information.
# Convert SMILES to Morgan fingerprints
def smiles_to_fingerprint(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:  # guard against unparsable SMILES
        return None
    fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=1024)  # Adjust radius and nBits as needed
    return np.array(fingerprint)

# smiles_to_fingerprint(merged_df.isoSMILES[0])
# Apply the function to your SMILES column
merged_df['fingerprint'] = merged_df['isoSMILES'].apply(smiles_to_fingerprint)
fingerprint_map = merged_df[['fingerprint',"improve_drug_id"]]
merged_df = merged_df[["improve_sample_id","improve_drug_id","auc"]]
merged_df
| improve_sample_id | improve_drug_id | auc
---|---|---|---|
0 | 4634 | SMI_43 | 0.5716 |
284 | 4634 | SMI_3222 | 0.6310 |
569 | 4634 | SMI_146 | 0.7496 |
781 | 4634 | SMI_57721 | 0.7264 |
846 | 4634 | SMI_57735 | 0.5588 |
... | ... | ... | ... |
10668667 | 4346 | SMI_414 | 0.4791 |
10668944 | 4346 | SMI_1062 | 0.9319 |
10669230 | 4346 | SMI_3146 | 0.4195 |
10669505 | 4346 | SMI_1137 | 0.8347 |
10669748 | 4346 | SMI_181 | 0.7405 |
34850 rows × 3 columns
fingerprint_map
| fingerprint | improve_drug_id
---|---|---|
0 | [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ... | SMI_43 |
284 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | SMI_3222 |
569 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, ... | SMI_146 |
781 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | SMI_57721 |
846 | [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, ... | SMI_57735 |
... | ... | ... |
10668667 | [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, ... | SMI_414 |
10668944 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | SMI_1062 |
10669230 | [0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, ... | SMI_3146 |
10669505 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | SMI_1137 |
10669748 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ... | SMI_181 |
34850 rows × 2 columns
fingerprint_dict = fingerprint_map.set_index('improve_drug_id')['fingerprint'].to_dict()
fingerprint_dict
{'SMI_43': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_3222': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_146': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_57721': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57735': array([0, 0, 1, ..., 0, 0, 0]), 'SMI_1356': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_949': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_990': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_224': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_274': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_388': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_398': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_98': array([0, 0, 0, ..., 1, 0, 0]), 'SMI_381': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_145': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_331': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_97': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1302': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_341': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_378': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_225': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_988': array([0, 0, 0, ..., 0, 0, 1]), 'SMI_57': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_337': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_4071': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1303': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_58': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1003': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_230': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1134': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1397': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1393': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_384': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57737': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_52446': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_3033': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_342': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_95': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57722': array([0, 0, 1, ..., 0, 0, 0]), 'SMI_52331': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1412': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_1842': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_37': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_57732': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_57724': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57728': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_64': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_338': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_56627': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_41054': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_376': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1060': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_221': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_755': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_125': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57730': array([0, 0, 0, ..., 1, 0, 0]), 'SMI_3276': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_227': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1057': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_3125': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_318': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_231': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_184': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57723': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1312': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57726': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_148': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_233': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57731': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_53380': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_1108': array([1, 0, 0, ..., 0, 0, 0]), 'SMI_283': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1403': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1056': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_339': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_63': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_262': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1063': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_276': array([0, 1, 1, ..., 0, 0, 0]), 'SMI_144': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_55709': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_52447': array([0, 
0, 0, ..., 0, 0, 0]), 'SMI_50': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_272': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_232': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57736': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_236': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_223': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_365': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_102': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_51': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_107': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_374': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_369': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_373': array([0, 1, 1, ..., 0, 0, 0]), 'SMI_188': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_54925': array([0, 0, 1, ..., 0, 0, 0]), 'SMI_1012': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_995': array([0, 0, 0, ..., 0, 1, 0]), 'SMI_57733': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_101': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_3148': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1062': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1137': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_181': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_57727': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1963': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_375': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_261': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_278': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1155': array([0, 1, 1, ..., 0, 0, 0]), 'SMI_1109': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_10': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_52922': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_2582': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_75': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1777': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_1107': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_55184': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_3166': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_410': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_256': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_265': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1327': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_106': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_140': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_1055': array([0, 0, 0, ..., 0, 1, 0]), 'SMI_60': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_281': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57729': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_1952': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_285': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_414': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_3146': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_1061': array([0, 0, 1, ..., 0, 0, 0]), 'SMI_1104': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57720': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_349': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_167': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_28': array([1, 1, 0, ..., 0, 0, 0]), 'SMI_165': array([1, 0, 0, ..., 0, 0, 0]), 'SMI_1859': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_54398': array([0, 0, 1, ..., 0, 0, 0]), 'SMI_1984': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_182': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57734': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_53718': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_2064': array([0, 1, 0, ..., 0, 0, 0]), 'SMI_299': array([0, 0, 0, ..., 0, 0, 0]), 'SMI_57725': array([0, 0, 0, ..., 0, 0, 0])}
This method of merging drops all samples that do not have both transcriptomics and proteomics data.
This reduces our sample size to 15 and allows the tutorial to run much faster on low-powered machines.
# Merge merged_df with transcriptomics_df based on improve_sample_id
tp = pd.merge(beataml.transcriptomics,
beataml.proteomics[['improve_sample_id', 'proteomics', 'entrez_id']],
on=['improve_sample_id','entrez_id'],
how='inner').drop_duplicates()
merged_df = pd.merge(merged_df,
tp[['improve_sample_id', 'transcriptomics','proteomics', 'entrez_id']],
on='improve_sample_id',
how='inner').drop_duplicates()
merged_df
| improve_sample_id | improve_drug_id | auc | transcriptomics | proteomics | entrez_id
---|---|---|---|---|---|---|
0 | 3986 | SMI_236 | 0.0000 | 7.068854 | -0.135 | 8813.0 |
1 | 3986 | SMI_236 | 0.0000 | 3.859919 | -0.300 | 6359.0 |
2 | 3986 | SMI_236 | 0.0000 | 3.859919 | -0.300 | 57147.0 |
3 | 3986 | SMI_236 | 0.0000 | 9.634914 | 1.140 | 2268.0 |
4 | 3986 | SMI_236 | 0.0000 | -1.401864 | -0.174 | 3075.0 |
... | ... | ... | ... | ... | ... | ... |
8148099 | 4007 | SMI_57727 | 0.9307 | 1.519607 | -0.192 | 80755.0 |
8148100 | 4007 | SMI_57727 | 0.9307 | 2.088872 | -0.260 | 55957.0 |
8148101 | 4007 | SMI_57727 | 0.9307 | 2.154935 | 0.865 | 6023.0 |
8148102 | 4007 | SMI_57727 | 0.9307 | 7.165478 | 0.472 | 11165.0 |
8148103 | 4007 | SMI_57727 | 0.9307 | -0.435228 | 0.365 | 6047.0 |
8148104 rows × 6 columns
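You can verify the reduced sample count mentioned above directly:

merged_df['improve_sample_id'].nunique()  # 15 samples, per the note above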
Here we recommend modifying the imputation method; taking the global mean may not be appropriate for your data. This is just a simple method and a reminder that imputation may be needed.
In this example, no transcriptomics values are actually imputed, and fewer than 3,500 of the 8,148,104 proteomics values (~0.04%) are imputed.
#Impute missing proteomics (and transcriptomics if there are any) based on the global mean.
columns_to_fill = ['transcriptomics', 'proteomics']
for i in columns_to_fill:
    if merged_df[i].isnull().any():
        merged_df[i] = merged_df[i].fillna(merged_df[i].mean())
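If a per-gene average suits your data better than a global one, grouping by entrez_id before filling is one alternative; a sketch (not what we ran here):

# Alternative: impute each omics column with its per-gene (entrez_id) mean
for col in ['transcriptomics', 'proteomics']:
    merged_df[col] = merged_df.groupby('entrez_id')[col].transform(
        lambda s: s.fillna(s.mean())
    )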
merged_df
| improve_sample_id | improve_drug_id | auc | transcriptomics | proteomics | entrez_id
---|---|---|---|---|---|---|
0 | 3986 | SMI_236 | 0.0000 | 7.068854 | -0.135 | 8813.0 |
1 | 3986 | SMI_236 | 0.0000 | 3.859919 | -0.300 | 6359.0 |
2 | 3986 | SMI_236 | 0.0000 | 3.859919 | -0.300 | 57147.0 |
3 | 3986 | SMI_236 | 0.0000 | 9.634914 | 1.140 | 2268.0 |
4 | 3986 | SMI_236 | 0.0000 | -1.401864 | -0.174 | 3075.0 |
... | ... | ... | ... | ... | ... | ... |
8148099 | 4007 | SMI_57727 | 0.9307 | 1.519607 | -0.192 | 80755.0 |
8148100 | 4007 | SMI_57727 | 0.9307 | 2.088872 | -0.260 | 55957.0 |
8148101 | 4007 | SMI_57727 | 0.9307 | 2.154935 | 0.865 | 6023.0 |
8148102 | 4007 | SMI_57727 | 0.9307 | 7.165478 | 0.472 | 11165.0 |
8148103 | 4007 | SMI_57727 | 0.9307 | -0.435228 | 0.365 | 6047.0 |
8148104 rows × 6 columns
merged_df['fingerprint'] = merged_df['improve_drug_id'].map(fingerprint_dict)
Now all of our data is gathered in one place. However, we will need to do some restructuring to get this into a suitable input for our deep learning model.
merged_df
| improve_sample_id | improve_drug_id | auc | transcriptomics | proteomics | entrez_id | fingerprint
---|---|---|---|---|---|---|---|
0 | 3986 | SMI_236 | 0.0000 | 7.068854 | -0.135 | 8813.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ... |
1 | 3986 | SMI_236 | 0.0000 | 3.859919 | -0.300 | 6359.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ... |
2 | 3986 | SMI_236 | 0.0000 | 3.859919 | -0.300 | 57147.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ... |
3 | 3986 | SMI_236 | 0.0000 | 9.634914 | 1.140 | 2268.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ... |
4 | 3986 | SMI_236 | 0.0000 | -1.401864 | -0.174 | 3075.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ... |
... | ... | ... | ... | ... | ... | ... | ... |
8148099 | 4007 | SMI_57727 | 0.9307 | 1.519607 | -0.192 | 80755.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... |
8148100 | 4007 | SMI_57727 | 0.9307 | 2.088872 | -0.260 | 55957.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... |
8148101 | 4007 | SMI_57727 | 0.9307 | 2.154935 | 0.865 | 6023.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... |
8148102 | 4007 | SMI_57727 | 0.9307 | 7.165478 | 0.472 | 11165.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... |
8148103 | 4007 | SMI_57727 | 0.9307 | -0.435228 | 0.365 | 6047.0 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... |
8148104 rows × 7 columns
Scale the transcriptomics and proteomics data to the same 0-1 scale if you wish them to have equal weights in the deep learning model.
scaler = MinMaxScaler()
merged_df[["transcriptomics","proteomics"]] = scaler.fit_transform(merged_df[["transcriptomics","proteomics"]])
These are restructured into lists within the pandas DataFrame, one per sample. Each array must be ordered in the same way, by entrez_id. In this case, we used a hashmap (nested dictionary) to reduce runtime, but you can use whatever method you prefer; just make sure that each list of proteomics/transcriptomics values is linked to an entrez_id list.
#Create a hashmap of {improve_sample_id: {improve_drug_id: {entrez_id: {transcriptomics, proteomics, auc}}}}
data_dict = {}
# Iterate over the DataFrame rows to populate the nested dictionary
for index, row in merged_df.iterrows():
    improve_sample_id = row['improve_sample_id']
    improve_drug_id = row['improve_drug_id']
    auc = row['auc']
    transcriptomics = row['transcriptomics']
    proteomics = row['proteomics']
    entrez_id = row['entrez_id']
    # Initialize improve_sample_id key if not present
    if improve_sample_id not in data_dict:
        data_dict[improve_sample_id] = {}
    # Initialize improve_drug_id key if not present
    if improve_drug_id not in data_dict[improve_sample_id]:
        data_dict[improve_sample_id][improve_drug_id] = {}
    # Write this row's values to the nested dictionary
    data_dict[improve_sample_id][improve_drug_id][entrez_id] = {
        'transcriptomics': transcriptomics,
        'proteomics': proteomics,
        'auc': auc
    }
The key takeaway here is that each of the input variables should be in a list format.
#Write [transcriptomics], [proteomics], and auc data to aggregated_df. These are all written in the same order as the unique_entrez_ids list.
# Extract the unique entrez_ids (sorted, so the gene order is deterministic across runs)
unique_entrez_ids = sorted(merged_df.entrez_id.unique())
# Construct the final DataFrame
final_data = []
for improve_sample_id, sample_id_data in data_dict.items():
    for improve_drug_id, drug_id_data in sample_id_data.items():
        transcriptomics = []
        proteomics = []
        for entrez_id in unique_entrez_ids:
            if entrez_id in drug_id_data:
                transcriptomics.append(drug_id_data[entrez_id]['transcriptomics'])
                proteomics.append(drug_id_data[entrez_id]['proteomics'])
                auc = drug_id_data[entrez_id]['auc']
        final_data.append([auc, transcriptomics, proteomics, improve_sample_id, improve_drug_id])
# Create the DataFrame
aggregated_df = pd.DataFrame(final_data, columns=['auc', 'transcriptomics', 'proteomics', 'improve_sample_id', 'improve_drug_id'])
aggregated_df
| auc | transcriptomics | proteomics | improve_sample_id | improve_drug_id
---|---|---|---|---|---|
0 | 0.0000 | [0.23608062994353704, 0.5037978202291047, 0.48... | [0.3384677419354839, 0.37193548387096775, 0.38... | 3986 | SMI_236 |
1 | 0.7431 | [0.23608062994353704, 0.5037978202291047, 0.48... | [0.3384677419354839, 0.37193548387096775, 0.38... | 3986 | SMI_223 |
2 | 0.6135 | [0.23608062994353704, 0.5037978202291047, 0.48... | [0.3384677419354839, 0.37193548387096775, 0.38... | 3986 | SMI_365 |
3 | 0.7425 | [0.18000067030285077, 0.4920917394967491, 0.47... | [0.43870967741935485, 0.3573387096774194, 0.38... | 3987 | SMI_54925 |
4 | 0.0000 | [0.18000067030285077, 0.4920917394967491, 0.47... | [0.43870967741935485, 0.3573387096774194, 0.38... | 3987 | SMI_1012 |
... | ... | ... | ... | ... | ... |
988 | 0.8771 | [0.29431256693950825, 0.5401532345082279, 0.47... | [0.34040322580645166, 0.3787903225806452, 0.37... | 4007 | SMI_3148 |
989 | 1.0000 | [0.29431256693950825, 0.5401532345082279, 0.47... | [0.34040322580645166, 0.3787903225806452, 0.37... | 4007 | SMI_1062 |
990 | 0.5961 | [0.29431256693950825, 0.5401532345082279, 0.47... | [0.34040322580645166, 0.3787903225806452, 0.37... | 4007 | SMI_1137 |
991 | 0.9997 | [0.29431256693950825, 0.5401532345082279, 0.47... | [0.34040322580645166, 0.3787903225806452, 0.37... | 4007 | SMI_181 |
992 | 0.9307 | [0.29431256693950825, 0.5401532345082279, 0.47... | [0.34040322580645166, 0.3787903225806452, 0.37... | 4007 | SMI_57727 |
993 rows × 5 columns
# Map in the fingerprint by improve_drug_id. Remove improve_drug_id.
aggregated_df['fingerprint'] = aggregated_df['improve_drug_id'].map(fingerprint_dict)
aggregated_df = aggregated_df[["auc","transcriptomics","proteomics","fingerprint"]]
The model will be trained, validated, and evaluated to predict AUC from the transcriptomics, proteomics, and fingerprint variables.
# Training/Test/Validation Data is all under the name training_data.
training_data = aggregated_df
training_data
| auc | transcriptomics | proteomics | fingerprint
---|---|---|---|---|
0 | 0.0000 | [0.23608062994353704, 0.5037978202291047, 0.48... | [0.3384677419354839, 0.37193548387096775, 0.38... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ... |
1 | 0.7431 | [0.23608062994353704, 0.5037978202291047, 0.48... | [0.3384677419354839, 0.37193548387096775, 0.38... | [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, ... |
2 | 0.6135 | [0.23608062994353704, 0.5037978202291047, 0.48... | [0.3384677419354839, 0.37193548387096775, 0.38... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, ... |
3 | 0.7425 | [0.18000067030285077, 0.4920917394967491, 0.47... | [0.43870967741935485, 0.3573387096774194, 0.38... | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... |
4 | 0.0000 | [0.18000067030285077, 0.4920917394967491, 0.47... | [0.43870967741935485, 0.3573387096774194, 0.38... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ... |
... | ... | ... | ... | ... |
988 | 0.8771 | [0.29431256693950825, 0.5401532345082279, 0.47... | [0.34040322580645166, 0.3787903225806452, 0.37... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... |
989 | 1.0000 | [0.29431256693950825, 0.5401532345082279, 0.47... | [0.34040322580645166, 0.3787903225806452, 0.37... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... |
990 | 0.5961 | [0.29431256693950825, 0.5401532345082279, 0.47... | [0.34040322580645166, 0.3787903225806452, 0.37... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... |
991 | 0.9997 | [0.29431256693950825, 0.5401532345082279, 0.47... | [0.34040322580645166, 0.3787903225806452, 0.37... | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ... |
992 | 0.9307 | [0.29431256693950825, 0.5401532345082279, 0.47... | [0.34040322580645166, 0.3787903225806452, 0.37... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... |
993 rows × 4 columns
The data here is split into 80% training, 10% validation, and 10% test sets.
# Splitting the merged_data DataFrame into features (X) and target variable (y)
X = training_data[['transcriptomics', 'proteomics',"fingerprint"]] # These columns are your features
y = training_data['auc'] # Using 'auc' column as the target variable
# Splitting data into training, validation, and testing sets
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.1111, random_state=42) # 0.1111 of the remaining 90% ≈ 10% of the total
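Note that this random row-level split can place responses from the same patient in both the training and test sets. If you want a stricter evaluation, a group-aware split on improve_sample_id is one option. A minimal sketch on toy data (in this tutorial improve_sample_id was dropped from aggregated_df, so you would need to retain it to use this):

import numpy as np
from sklearn.model_selection import GroupShuffleSplit

X_toy = np.arange(20).reshape(10, 2)    # 10 toy feature rows
y_toy = np.linspace(0, 1, 10)           # 10 toy AUC values
groups = np.repeat([0, 1, 2, 3, 4], 2)  # two rows per (hypothetical) sample
gss = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(gss.split(X_toy, y_toy, groups=groups))
# No sample's rows appear in both train_idx and test_idx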
Here you may modify the inputs, layers, activation functions, and outputs of the model.
See if you can improve upon our basic setup.
A note: running some of these steps multiple times in the Jupyter notebook may result in odd issues, so please be careful with this behavior.
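One common mitigation when re-running the model-definition cells is to reset the Keras state first; a minimal sketch (not part of the original run):

keras.backend.clear_session()  # clears the old graph and layer-name counters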
#These are the array lengths of each feature set.
num_proteomics = 7284       # i.e., len(X_train['proteomics'].iloc[0])
num_transcriptomics = 7284  # i.e., len(X_train['transcriptomics'].iloc[0])
num_fingerprint = 1024      # the nBits used for the Morgan fingerprints
#Define inputs
transcriptomics_input = keras.Input(
shape=(num_transcriptomics,), name="transcriptomics"
)
proteomics_input = keras.Input(
shape=(num_proteomics,), name="proteomics"
)
fingerprint_input = keras.Input(
shape=(num_fingerprint,), name="fingerprint"
)
# You may adjust these layers and activation functions to get better (or worse) results.
# Merge all available features into a single large vector via concatenation
x = layers.concatenate([transcriptomics_input, proteomics_input, fingerprint_input])
x = layers.Dense(16, activation='relu')(x)
x = layers.Dense(64, activation="relu")(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(32, activation='relu')(x)
x = layers.Dense(12, activation='relu')(x)
# The prediction output (named "priority") is here. You may add layers (and another output) after this too.
# This just allows for an early prediction in case you have a large dataset and want preliminary results.
priority_pred = layers.Dense(1, name="priority", activation='relu')(x)
# Instantiate an end-to-end model predicting AUC through the "priority" output
model = keras.Model(
inputs=[transcriptomics_input, proteomics_input, fingerprint_input],
outputs={"priority": priority_pred},
)
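Before plotting the model graph, you can also print a text summary of the architecture, which needs no extra dependencies:

model.summary()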
keras.utils.plot_model(model, "multi_input_model.png", show_shapes=True)
You must install pydot (`pip install pydot`) and install graphviz (see instructions at https://graphviz.gitlab.io/download/) for plot_model to work.
You may wish to test out different optimizers, learning rates, and metrics.
model.compile(
optimizer=keras.optimizers.Adam(learning_rate=0.0001),
loss={
"priority": keras.losses.MeanSquaredError(),
},
metrics=[MeanAbsoluteError()],
loss_weights={"priority": 1},
)
The model does not accept lists as inputs, so these are converted to NumPy arrays.
Also ensure that the arrays are all the same length within each input variable; if, for example, the proteomics variable contains arrays of two different lengths, you get an uninformative error message.
Feel free to change the epochs and batch size.
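A quick pre-fit sanity check along these lines can surface a length mismatch with a readable message (a sketch; X_train and the column names are as defined above):

for col in ['transcriptomics', 'proteomics', 'fingerprint']:
    lengths = X_train[col].apply(len).unique()
    assert len(lengths) == 1, f"{col} arrays have mixed lengths: {lengths}"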
history = model.fit(
[np.array(X_train['transcriptomics'].tolist()),
np.array(X_train['proteomics'].tolist()),
np.array(X_train['fingerprint'].tolist())
],
y_train,
epochs=100,
batch_size=32,
validation_data=(
[
np.array(X_val['transcriptomics'].tolist()),
np.array(X_val['proteomics'].tolist()),
np.array(X_val['fingerprint'].tolist())
],
y_val)
)
Epoch 1/100
2024-03-07 22:06:39.578569: W tensorflow/tsl/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
25/25 [==============================] - 1s 16ms/step - loss: 0.3187 - mean_absolute_error: 0.4867 - val_loss: 0.2245 - val_mean_absolute_error: 0.4414 Epoch 2/100 25/25 [==============================] - 0s 5ms/step - loss: 0.2099 - mean_absolute_error: 0.3757 - val_loss: 0.1138 - val_mean_absolute_error: 0.3021 Epoch 3/100 25/25 [==============================] - 0s 5ms/step - loss: 0.1783 - mean_absolute_error: 0.3399 - val_loss: 0.0712 - val_mean_absolute_error: 0.2290 Epoch 4/100 25/25 [==============================] - 0s 4ms/step - loss: 0.1419 - mean_absolute_error: 0.3088 - val_loss: 0.0421 - val_mean_absolute_error: 0.1584 Epoch 5/100 25/25 [==============================] - 0s 5ms/step - loss: 0.1281 - mean_absolute_error: 0.2879 - val_loss: 0.0451 - val_mean_absolute_error: 0.1690 Epoch 6/100 25/25 [==============================] - 0s 5ms/step - loss: 0.1105 - mean_absolute_error: 0.2707 - val_loss: 0.0426 - val_mean_absolute_error: 0.1616 Epoch 7/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0958 - mean_absolute_error: 0.2483 - val_loss: 0.0388 - val_mean_absolute_error: 0.1455 Epoch 8/100 25/25 [==============================] - 0s 4ms/step - loss: 0.0988 - mean_absolute_error: 0.2488 - val_loss: 0.0382 - val_mean_absolute_error: 0.1433 Epoch 9/100 25/25 [==============================] - 0s 4ms/step - loss: 0.0867 - mean_absolute_error: 0.2319 - val_loss: 0.0379 - val_mean_absolute_error: 0.1409 Epoch 10/100 25/25 [==============================] - 0s 5ms/step - loss: 0.0809 - mean_absolute_error: 0.2240 - val_loss: 0.0400 - val_mean_absolute_error: 0.1426 Epoch 11/100 25/25 [==============================] - 0s 4ms/step - loss: 0.0820 - mean_absolute_error: 0.2302 - val_loss: 0.0414 - val_mean_absolute_error: 0.1450 Epoch 12/100 25/25 [==============================] - 0s 7ms/step - loss: 0.0767 - mean_absolute_error: 0.2233 - val_loss: 0.0391 - val_mean_absolute_error: 0.1558 Epoch 13/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0804 - mean_absolute_error: 0.2246 - val_loss: 0.0366 - val_mean_absolute_error: 0.1444 Epoch 14/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0736 - mean_absolute_error: 0.2140 - val_loss: 0.0364 - val_mean_absolute_error: 0.1438 Epoch 15/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0742 - mean_absolute_error: 0.2169 - val_loss: 0.0420 - val_mean_absolute_error: 0.1664 Epoch 16/100 25/25 [==============================] - 0s 4ms/step - loss: 0.0803 - mean_absolute_error: 0.2277 - val_loss: 0.0417 - val_mean_absolute_error: 0.1655 Epoch 17/100 25/25 [==============================] - 0s 6ms/step - loss: 0.0665 - mean_absolute_error: 0.2101 - val_loss: 0.0380 - val_mean_absolute_error: 0.1364 Epoch 18/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0641 - mean_absolute_error: 0.2001 - val_loss: 0.0376 - val_mean_absolute_error: 0.1515 Epoch 19/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0650 - mean_absolute_error: 0.2043 - val_loss: 0.0381 - val_mean_absolute_error: 0.1360 Epoch 20/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0579 - mean_absolute_error: 0.1924 - val_loss: 0.0430 - val_mean_absolute_error: 0.1476 Epoch 21/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0610 - mean_absolute_error: 0.1939 - val_loss: 0.0606 - val_mean_absolute_error: 0.2099 Epoch 22/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0576 - mean_absolute_error: 0.1909 - val_loss: 0.0352 
- val_mean_absolute_error: 0.1401 Epoch 23/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0554 - mean_absolute_error: 0.1864 - val_loss: 0.0348 - val_mean_absolute_error: 0.1357 Epoch 24/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0578 - mean_absolute_error: 0.1866 - val_loss: 0.0401 - val_mean_absolute_error: 0.1606 Epoch 25/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0553 - mean_absolute_error: 0.1847 - val_loss: 0.0380 - val_mean_absolute_error: 0.1538 Epoch 26/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0481 - mean_absolute_error: 0.1711 - val_loss: 0.0354 - val_mean_absolute_error: 0.1427 Epoch 27/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0528 - mean_absolute_error: 0.1798 - val_loss: 0.0383 - val_mean_absolute_error: 0.1354 Epoch 28/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0516 - mean_absolute_error: 0.1753 - val_loss: 0.0355 - val_mean_absolute_error: 0.1435 Epoch 29/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0506 - mean_absolute_error: 0.1744 - val_loss: 0.0349 - val_mean_absolute_error: 0.1393 Epoch 30/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0567 - mean_absolute_error: 0.1873 - val_loss: 0.0591 - val_mean_absolute_error: 0.2049 Epoch 31/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0544 - mean_absolute_error: 0.1824 - val_loss: 0.0346 - val_mean_absolute_error: 0.1352 Epoch 32/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0459 - mean_absolute_error: 0.1687 - val_loss: 0.0368 - val_mean_absolute_error: 0.1491 Epoch 33/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0447 - mean_absolute_error: 0.1643 - val_loss: 0.0348 - val_mean_absolute_error: 0.1342 Epoch 34/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0425 - mean_absolute_error: 0.1626 - val_loss: 0.0347 - val_mean_absolute_error: 0.1354 Epoch 35/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0422 - mean_absolute_error: 0.1596 - val_loss: 0.0381 - val_mean_absolute_error: 0.1530 Epoch 36/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0454 - mean_absolute_error: 0.1644 - val_loss: 0.0350 - val_mean_absolute_error: 0.1382 Epoch 37/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0432 - mean_absolute_error: 0.1594 - val_loss: 0.0375 - val_mean_absolute_error: 0.1358 Epoch 38/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0435 - mean_absolute_error: 0.1617 - val_loss: 0.0348 - val_mean_absolute_error: 0.1345 Epoch 39/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0390 - mean_absolute_error: 0.1518 - val_loss: 0.0349 - val_mean_absolute_error: 0.1369 Epoch 40/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0416 - mean_absolute_error: 0.1586 - val_loss: 0.0365 - val_mean_absolute_error: 0.1353 Epoch 41/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0380 - mean_absolute_error: 0.1496 - val_loss: 0.0352 - val_mean_absolute_error: 0.1354 Epoch 42/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0406 - mean_absolute_error: 0.1561 - val_loss: 0.0373 - val_mean_absolute_error: 0.1367 Epoch 43/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0365 - mean_absolute_error: 0.1461 - val_loss: 0.0351 - val_mean_absolute_error: 0.1364 Epoch 44/100 25/25 [==============================] - 0s 6ms/step - loss: 
0.0397 - mean_absolute_error: 0.1514 - val_loss: 0.0364 - val_mean_absolute_error: 0.1448 Epoch 45/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0391 - mean_absolute_error: 0.1487 - val_loss: 0.0363 - val_mean_absolute_error: 0.1364 Epoch 46/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0402 - mean_absolute_error: 0.1519 - val_loss: 0.0361 - val_mean_absolute_error: 0.1366 Epoch 47/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0392 - mean_absolute_error: 0.1544 - val_loss: 0.0362 - val_mean_absolute_error: 0.1411 Epoch 48/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0412 - mean_absolute_error: 0.1554 - val_loss: 0.0357 - val_mean_absolute_error: 0.1373 Epoch 49/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0386 - mean_absolute_error: 0.1504 - val_loss: 0.0372 - val_mean_absolute_error: 0.1377 Epoch 50/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0357 - mean_absolute_error: 0.1443 - val_loss: 0.0370 - val_mean_absolute_error: 0.1373 Epoch 51/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0387 - mean_absolute_error: 0.1513 - val_loss: 0.0429 - val_mean_absolute_error: 0.1456 Epoch 52/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0415 - mean_absolute_error: 0.1567 - val_loss: 0.0508 - val_mean_absolute_error: 0.1599 Epoch 53/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0410 - mean_absolute_error: 0.1550 - val_loss: 0.0362 - val_mean_absolute_error: 0.1399 Epoch 54/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0349 - mean_absolute_error: 0.1401 - val_loss: 0.0371 - val_mean_absolute_error: 0.1441 Epoch 55/100 25/25 [==============================] - 0s 4ms/step - loss: 0.0441 - mean_absolute_error: 0.1634 - val_loss: 0.0546 - val_mean_absolute_error: 0.1954 Epoch 56/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0404 - mean_absolute_error: 0.1574 - val_loss: 0.0385 - val_mean_absolute_error: 0.1388 Epoch 57/100 25/25 [==============================] - 0s 7ms/step - loss: 0.0332 - mean_absolute_error: 0.1395 - val_loss: 0.0406 - val_mean_absolute_error: 0.1418 Epoch 58/100 25/25 [==============================] - 0s 4ms/step - loss: 0.0370 - mean_absolute_error: 0.1474 - val_loss: 0.0523 - val_mean_absolute_error: 0.1635 Epoch 59/100 25/25 [==============================] - 0s 4ms/step - loss: 0.0356 - mean_absolute_error: 0.1432 - val_loss: 0.0447 - val_mean_absolute_error: 0.1492 Epoch 60/100 25/25 [==============================] - 0s 5ms/step - loss: 0.0360 - mean_absolute_error: 0.1446 - val_loss: 0.0370 - val_mean_absolute_error: 0.1425 Epoch 61/100 25/25 [==============================] - 0s 4ms/step - loss: 0.0357 - mean_absolute_error: 0.1438 - val_loss: 0.0379 - val_mean_absolute_error: 0.1387 Epoch 62/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0361 - mean_absolute_error: 0.1469 - val_loss: 0.0368 - val_mean_absolute_error: 0.1416 Epoch 63/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0341 - mean_absolute_error: 0.1407 - val_loss: 0.0395 - val_mean_absolute_error: 0.1411 Epoch 64/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0364 - mean_absolute_error: 0.1440 - val_loss: 0.0432 - val_mean_absolute_error: 0.1468 Epoch 65/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0333 - mean_absolute_error: 0.1380 - val_loss: 0.0367 - val_mean_absolute_error: 0.1389 Epoch 66/100 25/25 
[==============================] - 0s 2ms/step - loss: 0.0346 - mean_absolute_error: 0.1425 - val_loss: 0.0389 - val_mean_absolute_error: 0.1404 Epoch 67/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0374 - mean_absolute_error: 0.1482 - val_loss: 0.0396 - val_mean_absolute_error: 0.1413 Epoch 68/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0340 - mean_absolute_error: 0.1386 - val_loss: 0.0366 - val_mean_absolute_error: 0.1403 Epoch 69/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0351 - mean_absolute_error: 0.1398 - val_loss: 0.0378 - val_mean_absolute_error: 0.1390 Epoch 70/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0330 - mean_absolute_error: 0.1348 - val_loss: 0.0376 - val_mean_absolute_error: 0.1392 Epoch 71/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0339 - mean_absolute_error: 0.1389 - val_loss: 0.0383 - val_mean_absolute_error: 0.1402 Epoch 72/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0322 - mean_absolute_error: 0.1342 - val_loss: 0.0375 - val_mean_absolute_error: 0.1396 Epoch 73/100 25/25 [==============================] - 0s 5ms/step - loss: 0.0333 - mean_absolute_error: 0.1385 - val_loss: 0.0410 - val_mean_absolute_error: 0.1439 Epoch 74/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0312 - mean_absolute_error: 0.1330 - val_loss: 0.0374 - val_mean_absolute_error: 0.1432 Epoch 75/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0328 - mean_absolute_error: 0.1367 - val_loss: 0.0448 - val_mean_absolute_error: 0.1494 Epoch 76/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0318 - mean_absolute_error: 0.1333 - val_loss: 0.0399 - val_mean_absolute_error: 0.1425 Epoch 77/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0304 - mean_absolute_error: 0.1323 - val_loss: 0.0549 - val_mean_absolute_error: 0.1687 Epoch 78/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0343 - mean_absolute_error: 0.1414 - val_loss: 0.0498 - val_mean_absolute_error: 0.1582 Epoch 79/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0340 - mean_absolute_error: 0.1409 - val_loss: 0.0525 - val_mean_absolute_error: 0.1639 Epoch 80/100 25/25 [==============================] - 0s 4ms/step - loss: 0.0377 - mean_absolute_error: 0.1504 - val_loss: 0.0481 - val_mean_absolute_error: 0.1555 Epoch 81/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0335 - mean_absolute_error: 0.1385 - val_loss: 0.0416 - val_mean_absolute_error: 0.1448 Epoch 82/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0305 - mean_absolute_error: 0.1309 - val_loss: 0.0388 - val_mean_absolute_error: 0.1411 Epoch 83/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0328 - mean_absolute_error: 0.1383 - val_loss: 0.0377 - val_mean_absolute_error: 0.1402 Epoch 84/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0319 - mean_absolute_error: 0.1346 - val_loss: 0.0377 - val_mean_absolute_error: 0.1406 Epoch 85/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0313 - mean_absolute_error: 0.1343 - val_loss: 0.0402 - val_mean_absolute_error: 0.1432 Epoch 86/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0295 - mean_absolute_error: 0.1294 - val_loss: 0.0370 - val_mean_absolute_error: 0.1422 Epoch 87/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0299 - mean_absolute_error: 0.1296 - val_loss: 0.0372 
- val_mean_absolute_error: 0.1411 Epoch 88/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0315 - mean_absolute_error: 0.1334 - val_loss: 0.0368 - val_mean_absolute_error: 0.1437 Epoch 89/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0320 - mean_absolute_error: 0.1355 - val_loss: 0.0401 - val_mean_absolute_error: 0.1433 Epoch 90/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0295 - mean_absolute_error: 0.1295 - val_loss: 0.0377 - val_mean_absolute_error: 0.1409 Epoch 91/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0287 - mean_absolute_error: 0.1268 - val_loss: 0.0381 - val_mean_absolute_error: 0.1412 Epoch 92/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0321 - mean_absolute_error: 0.1355 - val_loss: 0.0456 - val_mean_absolute_error: 0.1517 Epoch 93/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0311 - mean_absolute_error: 0.1331 - val_loss: 0.0415 - val_mean_absolute_error: 0.1456 Epoch 94/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0310 - mean_absolute_error: 0.1324 - val_loss: 0.0411 - val_mean_absolute_error: 0.1450 Epoch 95/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0296 - mean_absolute_error: 0.1299 - val_loss: 0.0429 - val_mean_absolute_error: 0.1474 Epoch 96/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0299 - mean_absolute_error: 0.1307 - val_loss: 0.0374 - val_mean_absolute_error: 0.1414 Epoch 97/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0296 - mean_absolute_error: 0.1288 - val_loss: 0.0449 - val_mean_absolute_error: 0.1504 Epoch 98/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0317 - mean_absolute_error: 0.1353 - val_loss: 0.0381 - val_mean_absolute_error: 0.1505 Epoch 99/100 25/25 [==============================] - 0s 3ms/step - loss: 0.0303 - mean_absolute_error: 0.1298 - val_loss: 0.0413 - val_mean_absolute_error: 0.1455 Epoch 100/100 25/25 [==============================] - 0s 2ms/step - loss: 0.0305 - mean_absolute_error: 0.1282 - val_loss: 0.0408 - val_mean_absolute_error: 0.1447
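In the log above, the validation loss largely plateaus well before epoch 100. A Keras EarlyStopping callback is one way to cut the run short; a minimal sketch (not part of the run shown):

from keras.callbacks import EarlyStopping

early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
# Pass callbacks=[early_stop] to model.fit(...) to stop once val_loss stalls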
Here we use the test data to evaluate our model.
losses = model.evaluate(
[np.array(X_test['transcriptomics'].tolist()),
np.array(X_test['proteomics'].tolist()),
np.array(X_test['fingerprint'].tolist())
],
y_test,
return_dict=True)
print(losses)
4/4 [==============================] - 0s 1ms/step - loss: 0.0218 - mean_absolute_error: 0.1116
{'loss': 0.021844929084181786, 'mean_absolute_error': 0.11160688102245331}
These plots show how your model (hopefully) improved during training.
# summarize history for mean absolute error
plt.plot(history.history['val_mean_absolute_error'])
plt.plot(history.history['mean_absolute_error'])
plt.title('Mean Absolute Error of Deep Learning Model')
plt.ylabel('Mean Absolute Error')
plt.xlabel('epoch')
plt.legend(['Validation Set','Training Set',], loc='upper right')
plt.show()
# summarize history for loss
plt.plot(history.history['val_loss'])
plt.plot(history.history['loss'])
plt.title('Loss Function of Deep Learning Model')
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.legend(['Validation Set','Training Set'], loc='upper right')
plt.show()
Here we use our test data again, but you'd likely want to apply this to other data. The inputs (and their format) must match the model that you trained above.
Below we provide some plots and summary statistics for evaluating true vs predicted values.
Try to create a model better than ours!
# make a prediction
predict = model.predict([np.array(X_test['transcriptomics'].tolist()),
np.array(X_test['proteomics'].tolist()),
np.array(X_test['fingerprint'].tolist())
])
4/4 [==============================] - 0s 1ms/step
new_df = pd.DataFrame({
'Predicted Values': predict["priority"].tolist(),
'True Values': y_test
})
new_df['Predicted Values'] = new_df['Predicted Values'].apply(lambda x: x[0] if isinstance(x, list) and len(x) == 1 else x)
sorted_df = new_df.sort_values(by='True Values', ascending=True) # Change 'ascending' to False for descending order
sorted_df
| Predicted Values | True Values
---|---|---|
830 | 0.689598 | 0.0000 |
198 | 0.613777 | 0.3066 |
10 | 0.576904 | 0.3380 |
355 | 0.569279 | 0.3572 |
656 | 0.576972 | 0.4010 |
... | ... | ... |
485 | 0.907541 | 0.9939 |
371 | 0.961458 | 0.9997 |
323 | 0.957583 | 1.0000 |
430 | 0.911885 | 1.0000 |
362 | 0.750975 | 1.0000 |
100 rows × 2 columns
true_values = np.array(sorted_df["True Values"])
predicted_values = np.array(sorted_df["Predicted Values"])
# Plot the true values and predicted values
plt.plot(true_values)
plt.plot(predicted_values)
# Fit trendlines
true_values_trendline = LinearRegression().fit(np.arange(len(true_values)).reshape(-1, 1), true_values.reshape(-1, 1)).predict(np.arange(len(true_values)).reshape(-1, 1))
predicted_values_trendline = LinearRegression().fit(np.arange(len(predicted_values)).reshape(-1, 1), predicted_values.reshape(-1, 1)).predict(np.arange(len(predicted_values)).reshape(-1, 1))
# Plot trendlines
plt.plot(true_values_trendline, linestyle='--', color='blue', alpha=0.3)
plt.plot(predicted_values_trendline, linestyle='--', color='orange', alpha=0.3)
# Customize plot
plt.title('Predicted AUC vs True AUC')
plt.ylabel('Drug Response AUC')
plt.xlabel('Values')
plt.legend(['True Values', 'Predicted Values', 'True Values Trendline', 'Predicted Values Trendline'], loc='lower right')
plt.show()
true_values = np.array(new_df["True Values"])
predicted_values = np.array(new_df["Predicted Values"])
# Plot the true values and predicted values
plt.plot(true_values)
plt.plot(predicted_values)
# Fit trendlines
true_values_trendline = LinearRegression().fit(np.arange(len(true_values)).reshape(-1, 1), true_values.reshape(-1, 1)).predict(np.arange(len(true_values)).reshape(-1, 1))
predicted_values_trendline = LinearRegression().fit(np.arange(len(predicted_values)).reshape(-1, 1), predicted_values.reshape(-1, 1)).predict(np.arange(len(predicted_values)).reshape(-1, 1))
# Plot trendlines
plt.plot(true_values_trendline, linestyle='--', color='blue', alpha=0.3)
plt.plot(predicted_values_trendline, linestyle='--', color='orange', alpha=0.3)
# Customize plot
plt.title('Predicted AUC vs True AUC')
plt.ylabel('Drug Response AUC')
plt.xlabel('Values')
plt.legend(['True Values', 'Predicted Values', 'True Values Trendline', 'Predicted Values Trendline'], loc='lower right')
plt.show()
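A scatter of predicted vs. true values with the y = x reference line is another common way to read this comparison; a minimal sketch using the new_df built above:

import matplotlib.pyplot as plt

plt.scatter(new_df['True Values'], new_df['Predicted Values'], alpha=0.5)
plt.plot([0, 1], [0, 1], linestyle='--', color='gray')  # perfect-prediction line
plt.xlabel('True AUC')
plt.ylabel('Predicted AUC')
plt.title('Predicted vs True AUC')
plt.show()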
# Calculate Summary Statistics
# Calculate Mean Absolute Error (MAE)
mae = mean_absolute_error(new_df['True Values'], new_df['Predicted Values'])
# Calculate Root Mean Squared Error (RMSE)
rmse = np.sqrt(mean_squared_error(new_df['True Values'], new_df['Predicted Values']))
# Calculate R-squared (R2) score
r2 = r2_score(new_df['True Values'], new_df['Predicted Values'])
summary_statistics = new_df.describe()
# Print the statistics
print("Mean Absolute Error (MAE):", mae)
print("Root Mean Squared Error (RMSE):", rmse)
print("R-squared (R2) Score:", r2)
print("\n")
# Print summary statistics
print(summary_statistics)
Mean Absolute Error (MAE): 0.11160688844966887
Root Mean Squared Error (RMSE): 0.14780030079805684
R-squared (R2) Score: 0.37366931237133216

| Predicted Values | True Values
---|---|---
count | 100.000000 | 100.000000
mean | 0.731800 | 0.723970
std | 0.128604 | 0.187696
min | 0.477577 | 0.000000
25% | 0.639343 | 0.616050
50% | 0.733571 | 0.736050
75% | 0.838911 | 0.869575
max | 0.961458 | 1.000000
#Side-by-side comparison of the first 50 values.
new_df[0:50]
| Predicted Values | True Values
---|---|---|
918 | 0.777755 | 0.7367 |
525 | 0.716599 | 0.6711 |
567 | 0.684600 | 0.8426 |
656 | 0.576972 | 0.4010 |
915 | 0.833463 | 0.8296 |
429 | 0.644341 | 0.6900 |
855 | 0.477577 | 0.6126 |
711 | 0.530543 | 0.4982 |
174 | 0.477577 | 0.4917 |
604 | 0.572155 | 0.6209 |
865 | 0.881899 | 0.8678 |
449 | 0.695015 | 0.6786 |
777 | 0.683768 | 0.4146 |
580 | 0.851610 | 0.9087 |
76 | 0.796451 | 0.8499 |
371 | 0.961458 | 0.9997 |
884 | 0.799178 | 0.9071 |
136 | 0.515090 | 0.5763 |
158 | 0.729075 | 0.7231 |
290 | 0.787732 | 0.6100 |
673 | 0.748977 | 0.6931 |
321 | 0.702108 | 0.5028 |
757 | 0.899316 | 0.8691 |
70 | 0.666315 | 0.7865 |
355 | 0.569279 | 0.3572 |
359 | 0.961239 | 0.8362 |
107 | 0.886554 | 0.9438 |
265 | 0.640555 | 0.5975 |
825 | 0.701618 | 0.6212 |
139 | 0.563990 | 0.6810 |
184 | 0.626397 | 0.7691 |
708 | 0.688292 | 0.6542 |
622 | 0.725480 | 0.5147 |
941 | 0.828256 | 0.9047 |
305 | 0.676587 | 0.6286 |
809 | 0.726858 | 0.6427 |
306 | 0.818704 | 0.8063 |
767 | 0.721371 | 0.7445 |
863 | 0.902120 | 0.8096 |
323 | 0.957583 | 1.0000 |
59 | 0.477577 | 0.9167 |
298 | 0.494210 | 0.4433 |
668 | 0.631534 | 0.7201 |
617 | 0.635709 | 0.8713 |
370 | 0.789941 | 0.7637 |
23 | 0.700036 | 0.7924 |
30 | 0.774137 | 0.9460 |
816 | 0.760774 | 0.8721 |
10 | 0.576904 | 0.3380 |
514 | 0.685097 | 0.6305 |