Predicting Drug Solubility with Deep Learning

Sully Chen
Dec 16, 2020

Code used in this article is available on my GitHub.

The Problem

The process of developing and releasing a new pharmaceutical is an absolutely tremendous feat. Often the first step in drug discovery is testing hundreds of thousands of candidate drugs in vitro through high-throughput screening: automated processes that run simple pharmacological experiments on a drug target in parallel. As you would imagine, this process is ridiculously expensive, time-consuming, and resource-intensive. Computational approaches that screen drug candidates in silico are therefore highly sought after; such a system could reduce the number of necessary physical experiments by orders of magnitude, drastically speeding up drug discovery and reducing costs. These approaches already exist, and you can see my favorite example of the concept in this paper published in Cell earlier this year.

In this article, I’m going to demonstrate how to write a very simple Python and PyTorch script that takes a molecule as input (represented by a SMILES string) and classifies it as insoluble, slightly soluble, or soluble in water. The classifier is a deep neural network trained on the AqSolDB dataset (published on the Harvard Dataverse), which lists roughly 10,000 compounds and their properties. Solubility is, as one may expect, an extremely important characteristic of a potential drug. A drug needs to be able to distribute itself to the target area(s) of the body, and it cannot do so if it remains mostly solid after being taken orally. Furthermore, it would be absolutely disastrous if the compound crashed out of solution when injected intravenously into a patient. You can read more about the importance of drug solubility here.

How do we even approach this with deep learning?

Most neural networks take inputs of a fixed dimension, or at least inputs with fixed constraints on their dimensionality. For example, a typical convolutional neural network that ends in fully connected layers and was trained on 256x256 pixel images can’t process a 255x257 pixel image, purely from an architectural standpoint. (Of course, the image can be cropped, resized, or otherwise transformed so that it fits.) Other networks, like recurrent neural networks (RNNs), can take inputs of arbitrary length, but they require every dimension besides the length to be fixed. For example, an RNN trained on character-level data may accept sequences of dimension (Nx26), where N is the sequence length and each row is a 26-dimensional one-hot vector for a letter of the alphabet. In short, the flexibility of a neural network’s inputs varies from model to model, but there are almost always constraints on the input dimension.
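The (Nx26) example can be made concrete with a minimal numpy sketch of a vanilla RNN. This is hand-written, not PyTorch’s nn.RNN. The names one_hot_word and rnn_final_state are illustrative, and the weights are random rather than trained. The same weights handle words of any length N, because the loop runs once per character. However, every row must have exactly 26 features to match w_in.

```python
import numpy as np

ALPHABET = "abcdefghijklmnopqrstuvwxyz"

def one_hot_word(word):
    """Encode a lowercase word as an (N x 26) matrix: one one-hot row per letter."""
    out = np.zeros((len(word), len(ALPHABET)))
    for t, ch in enumerate(word):
        out[t, ALPHABET.index(ch)] = 1.0
    return out

def rnn_final_state(seq, w_in, w_h):
    """Vanilla RNN, h_t = tanh(x_t W_in + h_{t-1} W_h); returns the final hidden state."""
    h = np.zeros(w_h.shape[0])
    for x_t in seq:  # one step per row, so any sequence length N works
        h = np.tanh(x_t @ w_in + h @ w_h)
    return h

rng = np.random.default_rng(0)
w_in = rng.normal(scale=0.1, size=(26, 16))  # rows must equal the feature dimension (26)
w_h = rng.normal(scale=0.1, size=(16, 16))   # hidden-to-hidden weights, shared across steps

for word in ["cat", "molecule"]:
    seq = one_hot_word(word)
    print(seq.shape, rnn_final_state(seq, w_in, w_h).shape)
# (3, 26) (16,)
# (8, 26) (16,)
```

A matrix with 27 columns would fail at the first matrix product. That is exactly the fixed-feature-dimension constraint the paragraph describes.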

Molecules, on the other hand, vary wildly in structure. Even molecules with exactly the same number and type of atoms can be connected differently: octane (C8H18), for instance, has 18 structural isomers. In other words, the composition, size, and connectivity can all vary between molecules. What neural architecture can we use to process this kind of wildly varying data?
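Atom counts alone don’t pin down a molecule, and two octane isomers show this quickly. Take the carbon skeletons (hydrogens omitted) of n-octane (CCCCCCCC) and 2,2,4-trimethylpentane (CC(C)(C)CC(C)C). Both have 8 carbons and 7 C–C bonds, but their connectivity differs, and this shows up immediately in their degree sequences. Below is a small pure-Python illustration; degree_sequence is an illustrative name.

```python
from collections import Counter

def degree_sequence(num_atoms, bonds):
    """Sorted list of how many bonds each atom takes part in (undirected graph)."""
    deg = Counter()
    for u, v in bonds:
        deg[u] += 1
        deg[v] += 1
    return sorted(deg[i] for i in range(num_atoms))

# Carbon skeletons only; node i is the i-th carbon in the SMILES string.
n_octane = [(i, i + 1) for i in range(7)]     # CCCCCCCC: a straight chain
isooctane = [(0, 1), (1, 2), (1, 3), (1, 4),  # CC(C)(C)CC(C)C:
             (4, 5), (5, 6), (5, 7)]          # 2,2,4-trimethylpentane

print(len(n_octane), len(isooctane))  # 7 7
print(degree_sequence(8, n_octane))   # [1, 1, 2, 2, 2, 2, 2, 2]
print(degree_sequence(8, isooctane))  # [1, 1, 1, 1, 1, 2, 3, 4]
```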

First off, it helps to think of molecules as undirected graphs, with atoms as the nodes and bonds as the edges. Each node in such a graph then carries a set of features, such as the element it represents, its stereochemistry, its hybridization state, its position in space, etc. So to start off, let’s write some Python code that takes a SMILES string and converts it into a NetworkX graph using pysmiles. This will let us handle the data in a meaningful way.

import numpy as np
from pysmiles import read_smiles
G = read_smiles("CN(C)C(=N)N=C(N)N", explicit_hydrogen=True)
print(G.nodes(data='element'))
print(G.edges)

It’s as easy as that: we now have a NetworkX graph, G, that contains the structure of our molecule (metformin, in this case). The snippet prints:

[(0, 'C'), (1, 'N'), (2, 'C'), (3, 'C'), (4, 'N'), (5, 'N'), (6, 'C'), (7, 'N'), (8, 'N'), (9, 'H'), (10, 'H'), (11, 'H'), (12, 'H'), (13, 'H'), (14, 'H'), (15, 'H'), (16, 'H'), (17, 'H'), (18, 'H'), (19, 'H')]
[(0, 1), (0, 9), (0, 10), (0, 11), (1, 2), (1, 3), (2, 12), (2, 13), (2, 14), (3, 4), (3, 5), (4, 15), (5, 6), (6, 7), (6, 8), (7, 16), (7, 17), (8, 18), (8, 19)]

Notice how each node, from 0 to 19, has a corresponding element, and how each edge is a pair of node indices. Because the graph is undirected, NetworkX lists each bond only once, e.g. (0, 1) but not (1, 0).

The Fun Part

Now it’s time to actually do some deep learning. There is a whole lot of research devoted to processing graphs with deep neural networks, as the problem has wide applications. In this project, we will use a Graph Convolutional Network (GCN), which handles the arbitrary and varied structure of graphs in a very elegant way.

Each node in our graph has its own 1D feature vector of a fixed dimension. At each layer of the network, the nodes “exchange” information by sending a transformed version of their feature vector to their neighbors. Each node also sends the transformed vector to itself. The transformation is a learnable weight matrix, which we train through gradient descent. Each node then combines all the vectors it receives with a dimension-preserving operation, such as an elementwise average, a max-pool, or something else. In the GCN layer we use, this operation is a sum normalized by the degrees of the sending and receiving nodes. It is a little more complicated, but very beautiful; I suggest reading this excellent piece on GCNs. Finally, a non-linear activation function such as ReLU, tanh, or sigmoid is applied.

After several such layers, we can use the processed features at each node to compute something about the properties of each node, or of the graph as a whole. For example, if we were trying to predict the age of each user in a social network, we could run regression on the features at each node. In our case, we want graph classification: we want to classify the whole graph as soluble, insoluble, or slightly soluble. To do this, we perform one final dimension-preserving aggregation over all the nodes in the graph and run a classification network on the resulting feature vector.
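Here is a minimal numpy sketch of a single GCN layer followed by the sum readout. It assumes the symmetric-normalization propagation rule from Kipf and Welling’s GCN, H' = ReLU(D^-1/2 (A + I) D^-1/2 H W). In that rule, A is the adjacency matrix, I adds the self-loops, D is the diagonal degree matrix of A + I, and W is the learnable weight matrix. This should roughly match what GCNConv followed by F.relu computes in our model, minus GCNConv’s learnable bias. The names gcn_layer and readout are illustrative, and the weights are random rather than trained.

```python
import numpy as np

def gcn_layer(x, edge_index, weight):
    """One GCN step: ReLU(D^-1/2 (A + I) D^-1/2 X W).

    x: [num_nodes, in_dim] node features; edge_index: [2, num_edges] node-index pairs;
    weight: [in_dim, out_dim] learnable matrix (random here)."""
    n = x.shape[0]
    a = np.zeros((n, n))
    src, dst = edge_index
    a[src, dst] = 1.0
    a[dst, src] = 1.0                    # treat every edge as undirected
    a_hat = a + np.eye(n)                # self-loops: each node also "sends" to itself
    d_inv_sqrt = 1.0 / np.sqrt(a_hat.sum(axis=1))
    norm = d_inv_sqrt[:, None] * a_hat * d_inv_sqrt[None, :]  # symmetric degree normalization
    # transform every node's features with W, aggregate over neighbors, then apply ReLU
    return np.maximum(norm @ (x @ weight), 0.0)

def readout(h):
    """Sum node features into one fixed-size graph vector, like torch.sum(x, dim=0)."""
    return h.sum(axis=0)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5))                    # toy 4-atom graph, 5 features per atom
edge_index = np.array([[0, 1, 2], [1, 2, 3]])  # chain 0-1-2-3, each bond listed once
w = rng.normal(size=(5, 8))
h = gcn_layer(x, edge_index, w)
g = readout(h)
print(h.shape, g.shape)  # (4, 8) (8,)
```

The readout is a sum, so relabeling the atoms leaves the graph vector unchanged; relabeling here means permuting the rows of x and renumbering edge_index to match. As a result, the order in which pysmiles happens to number the atoms doesn’t matter.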

A nice picture of a graph neural network (source)

The Code

We will be using Python and PyTorch Geometric to train our model. PyTorch Geometric is built on top of PyTorch and implements many graph network layers and algorithms, which will make this process a lot easier for us. First, let’s get all of our import statements out of the way so they don’t clutter the rest of our code:

import numpy as np
import random
import matplotlib.pyplot as plt
from pysmiles import read_smiles
import pandas as pd
import logging
from tqdm import tqdm
import torch
from torch.nn import Sequential as Seq, Linear, ReLU, CrossEntropyLoss
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, GCNConv
from torch_geometric.utils import remove_self_loops, add_self_loops, degree
from torch_geometric.data import Data
logging.getLogger('pysmiles').setLevel(logging.CRITICAL) # silence pysmiles warnings; only CRITICAL messages are shown

Loading the Data

First, let’s load up our dataset using Pandas, and format it into the data structure that PyTorch Geometric uses.

df = pd.read_csv('dataset.csv') #read dataset
X_smiles = list(df['SMILES']) #get smiles strings
Y = np.asarray(df['Solubility']) #get solubility values
#list of all elements in the dataset, which I've precomputed
elements = ['K', 'Y', 'V', 'Sm', 'Dy', 'In', 'Lu', 'Hg', 'Co', 'Mg',
            'Cu', 'Rh', 'Hf', 'O', 'As', 'Ge', 'Au', 'Mo', 'Br', 'Ce',
            'Zr', 'Ag', 'Ba', 'N', 'Cr', 'Sr', 'Fe', 'Gd', 'I', 'Al',
            'B', 'Se', 'Pr', 'Te', 'Cd', 'Pd', 'Si', 'Zn', 'Pb', 'Sn',
            'Cl', 'Mn', 'Cs', 'Na', 'S', 'Ti', 'Ni', 'Ru', 'Ca', 'Nd',
            'W', 'H', 'Li', 'Sb', 'Bi', 'La', 'Pt', 'Nb', 'P', 'F', 'C']
#convert a sequence of element symbols into an array with one one-hot row (length len(elements)) per atom
def element_to_onehot(element):
    out = []
    for i in range(0, len(element)):
        v = np.zeros(len(elements))
        v[elements.index(element[i])] = 1.0
        out.append(v)
    return np.asarray(out)

Above, we’ve extracted the SMILES strings and solubility values from the .csv file, defined a list of all the elements that appear in the strings, and defined a function called “element_to_onehot”. This function converts a sequence of element symbols (‘C’, ‘Ru’, etc.) into a matrix with one 61-dimensional one-hot row per atom. These rows will be the starting features of our nodes.
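Because element_to_onehot calls elements.index for every atom, each lookup scans the whole list. An equivalent encoder can instead precompute a symbol-to-index dictionary and select rows of an identity matrix. Below is a minimal numpy sketch with a toy four-element vocabulary; make_onehot_encoder is a hypothetical helper, not part of the original code.

```python
import numpy as np

def make_onehot_encoder(vocabulary):
    """Build an encoder that maps a sequence of element symbols to a one-hot matrix."""
    position = {symbol: i for i, symbol in enumerate(vocabulary)}  # O(1) lookup per atom
    eye = np.eye(len(vocabulary))

    def encode(symbols):
        # An unknown symbol raises KeyError (the original raises ValueError); both mean "skip this molecule".
        return eye[[position[s] for s in symbols]]

    return encode

encode = make_onehot_encoder(["C", "N", "O", "H"])  # toy vocabulary, not the 61-element list
m = encode(["C", "N", "H", "H"])
print(m.shape)           # (4, 4)
print(m.argmax(axis=1))  # [0 1 3 3]
```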

def val_to_class(val):
    if val < -3.65: #insoluble
        return [1, 0, 0]
    elif val < -1.69: #slightly soluble
        return [0, 1, 0]
    else: #soluble
        return [0, 0, 1]

Next, we define our solubility classes in this function. The Solubility column holds LogS, the base-10 logarithm of the solubility in mol/L. If a saturated solution holds less than 10^-3.65 mol/L of a compound, we consider it insoluble. If it isn’t insoluble, but a saturated solution still holds less than 10^-1.69 mol/L, we consider it slightly soluble. Otherwise, we label it soluble. Granted, this is not the best definition of solubility. I simply chose these cutoffs so the dataset would divide into roughly even thirds. It will do for the purposes of this example, though.
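To reproduce that “even thirds” choice for your own data, take the two cutoffs as the 1/3 and 2/3 quantiles of the LogS values. Below is a minimal numpy sketch. Since the dataset isn’t reproduced here, it uses synthetic, normally distributed values with arbitrary parameters, so its cutoffs will not match -3.65 and -1.69. The names tertile_cutoffs and to_class_index are illustrative.

```python
import numpy as np

def tertile_cutoffs(log_s):
    """Return the two LogS values that split the data into three equal-sized groups."""
    low, high = np.quantile(np.asarray(log_s, dtype=float), [1 / 3, 2 / 3])
    return low, high

def to_class_index(val, low, high):
    """0 = insoluble, 1 = slightly soluble, 2 = soluble (same ordering as val_to_class)."""
    if val < low:
        return 0
    elif val < high:
        return 1
    return 2

# Synthetic stand-in for the Solubility column; the distribution is arbitrary.
rng = np.random.default_rng(0)
log_s = rng.normal(loc=-2.9, scale=2.4, size=9000)
low, high = tertile_cutoffs(log_s)
counts = np.bincount([to_class_index(v, low, high) for v in log_s], minlength=3)
print(round(low, 2), round(high, 2), counts)  # counts come out at (almost exactly) 3000 each
```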

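#process SMILES strings into graphs
nodes = []
edge_index = []
kept_labels = [] #solubility values of the molecules we keep, so labels stay aligned with graphs
for smiles, label in tqdm(zip(X_smiles, Y), total=len(X_smiles)):
    try:
        G = read_smiles(smiles, explicit_hydrogen=True)
        feature = element_to_onehot([element for _, element in G.nodes(data='element')])
        edges = np.asarray(list(G.edges)) #shape [num_bonds, 2]; a molecule with no bonds fails below and is skipped
        #PyTorch Geometric reads edge_index as directed (source, target) pairs, so store every bond both ways
        index = np.asarray([np.concatenate([edges[:, 0], edges[:, 1]]),
                            np.concatenate([edges[:, 1], edges[:, 0]])])
        nodes.append(feature)
        edge_index.append(index)
        kept_labels.append(label)
    except Exception: #skip unparsable SMILES, unknown elements, and molecules without bonds
        pass
Y = np.asarray(kept_labels) #Y[i] now belongs to nodes[i] and edge_index[i]

Here we process the SMILES strings into graphs as we did earlier and store them in lists. PyTorch Geometric treats each column of edge_index as a directed (source, target) pair. Every bond is therefore stored in both directions, so that information can flow both ways along it. The try/except block skips any SMILES string that fails to parse, contains an element outside our list, or has no bonds. Y is filtered the same way, so each label stays attached to its molecule. tqdm provides a progress bar.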

#Generate Data objects
data = list()
#process graphs into torch_geometric Data objects
for i in tqdm(range(0, len(nodes))):
    x = torch.tensor(nodes[i], dtype=torch.float) #node features, shape [num_atoms, 61]
    edges = torch.tensor(edge_index[i], dtype=torch.long) #edge index, shape [2, num_edges]
    y = torch.tensor([val_to_class(Y[i])], dtype=torch.float) #one-hot class label, shape [1, 3]
    data.append(Data(x=x, edge_index=edges, y=y))
random.shuffle(data)
train = data[:int(len(data)*0.8)] #train set (80%)
test = data[int(len(data)*0.8):] #validation set (20%), never used for training

Next, we generate the PyTorch Geometric Data() objects. This simply requires a list of node features (x) and an edge index (edges). We also add our data label (y) to the object for training purposes. Finally, we partition our dataset into train and validation sets with an 80:20 split.
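For the validation number to mean anything, the two sets must never share a molecule. The sketch below is a small helper that makes this explicit. It also makes the split reproducible by seeding a private random.Random instance instead of the global random module. train_val_split is a hypothetical name; in this pipeline you would pass it the list of Data objects.

```python
import random

def train_val_split(items, val_fraction=0.2, seed=0):
    """Shuffle a copy of items with a fixed seed and split it into (train, val) lists."""
    items = list(items)
    random.Random(seed).shuffle(items)  # local RNG: reproducible, leaves global state alone
    n_val = round(len(items) * val_fraction)
    return items[:len(items) - n_val], items[len(items) - n_val:]

train, val = train_val_split(range(10))
print(len(train), len(val))        # 8 2
print(set(train).isdisjoint(val))  # True
```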

Defining the network

We’re finally ready to define our graph convolutional network. We set it up as we would any other PyTorch network, so I’ll paste the block of code below and we’ll go over the parts of the code that are unique and important.

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(61, 32)
        self.conv2 = GCNConv(32, 32)
        self.conv3 = GCNConv(32, 32)
        self.conv4 = GCNConv(32, 32)
        self.lin1 = Linear(32, 16)
        self.lin2 = Linear(16, 3)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.25, training=self.training)

        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.25, training=self.training)

        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.25, training=self.training)

        x = self.conv4(x, edge_index)
        x = F.relu(x)

        x = torch.sum(x, dim=0) #sum over all atoms: one 32-dim vector per molecule
        x = self.lin1(x)
        x = F.relu(x)

        x = self.lin2(x)
        return x #3 raw scores (logits), one per class

First off, we use GCNConv layers, which take an input dimension and an output dimension. Notice how our first layer takes an input size of 61, our one-hot dimension, and outputs a size of 32. After this layer, every node has a feature vector of size 32. Before our first linear layer, I sum the node features element-wise over all atoms. This produces a single 32-dimensional vector, which a normal fully connected network then classifies. Notice that we don’t apply softmax or any other normalization to the final output of size 3 (one value per class). PyTorch’s CrossEntropyLoss, which we use later on, applies log-softmax to the raw outputs itself.
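To see what “applies log-softmax itself” means, here is a minimal numpy sketch of the per-sample quantity. For logits z and target class y, the loss is -log softmax(z)[y] = logsumexp(z) - z[y]. This should correspond to what CrossEntropyLoss computes for the single [1, 3] output in our training loop, with the default mean reduction over a batch of one. cross_entropy_from_logits is an illustrative name, not a PyTorch function.

```python
import numpy as np

def cross_entropy_from_logits(logits, target):
    """-log(softmax(logits)[target]), computed as logsumexp(logits) - logits[target]."""
    z = np.asarray(logits, dtype=float)
    m = z.max()                                  # subtract the max for numerical stability
    log_sum_exp = m + np.log(np.exp(z - m).sum())
    return log_sum_exp - z[target]

print(cross_entropy_from_logits([0.0, 0.0, 0.0], 1))  # log(3), about 1.0986
print(cross_entropy_from_logits([5.0, 0.0, 0.0], 0))  # about 0.0134: confident and correct
```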

Training

It’s now time for us to start training the model. Let’s initialize our environment first:

#set up device and create model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') #use CUDA if available
model = Net().to(device) #create network and send to the device memory
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4) #use Adam optimizer
CSE = CrossEntropyLoss()

Here, we define the device we’re using, declare our model and transfer it to device memory, declare our optimizer (Adam), and declare our loss function (cross-entropy).

Now, we’ll define our main training loop:

#train model
model.train() #set model to training mode (enables dropout)
for epoch in range(32): #run 32 epochs of training
    sum_loss = 0 #used to compute average loss in an epoch
    num_correct = 0
    random.shuffle(train) #shuffle the training data each epoch
    for d in tqdm(train): #go over each training point
        graph = d.to(device) #send data to device
        optimizer.zero_grad() #zero gradients
        out = model(graph) #evaluate data point
        if torch.argmax(out) == torch.argmax(graph.y): #if prediction is correct, increment counter for accuracy calculation
            num_correct += 1
        loss = CSE(torch.reshape(out, [1, 3]), torch.reshape(torch.argmax(graph.y), [1])) #compute loss
        sum_loss += float(loss) #add loss value to aggregate loss
        loss.backward() #compute gradients
        optimizer.step() #apply optimization
    print('Epoch: {:03d}, Average loss: {:.5f}, Training accuracy: {:.5f}'.format(epoch, sum_loss/len(train), num_correct/len(train)))

This really isn’t any different from a normal PyTorch training loop, so not much needs to be said about it (note that the accuracy it prints is training accuracy). After 32 epochs, we have a model with about 43% accuracy on our validation set. Compared to random guessing (33% accuracy), that is an improvement of about 10 percentage points. Granted, this isn’t great, but solubility is an incredibly challenging problem. It’s amazing that we see any performance increase at all with such a simplistic setup.
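The 33% figure is the expected accuracy of guessing uniformly among three classes. The classes are only roughly balanced, so it is also worth checking the accuracy of always predicting the most common class, which can be a little higher. Below is a minimal numpy sketch. chance_baselines is a hypothetical helper, and the labels are made up for illustration rather than taken from AqSolDB.

```python
import numpy as np

def chance_baselines(labels, num_classes=3):
    """Accuracy of two naive predictors on integer class labels:
    uniform random guessing (expected value) and always predicting the most common class."""
    counts = np.bincount(np.asarray(labels), minlength=num_classes)
    uniform = 1.0 / num_classes
    majority = counts.max() / counts.sum()
    return uniform, majority

# Made-up validation labels (0 = insoluble, 1 = slightly soluble, 2 = soluble).
labels = [0, 0, 0, 1, 1, 2, 2, 2, 2, 1]
uniform, majority = chance_baselines(labels)
print(f"uniform {uniform:.3f}, majority {majority:.3f}")  # uniform 0.333, majority 0.400
```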

Testing it out

Let’s play around with the model a bit to see what it’s learned.
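The examples below call evaluate_smiles, whose definition isn’t shown. A hypothetical version would do the following steps:

1. Parse the string with read_smiles.
2. Featurize it exactly like the training data (element_to_onehot plus the edge index) and wrap it in a Data object.
3. Call model.eval() so dropout is off, then run the model under torch.no_grad().
4. Translate the index of the largest output back into a class name.

The last step is sketched below in numpy. CLASS_NAMES and logits_to_label are illustrative names, not from the original code, and the class order is assumed to follow val_to_class.

```python
import numpy as np

# Assumed class order, matching the one-hot layout returned by val_to_class.
CLASS_NAMES = ["insoluble", "slightly soluble", "soluble"]

def logits_to_label(logits):
    """Return the class name for a vector of 3 raw network outputs (argmax over logits)."""
    logits = np.asarray(logits, dtype=float).reshape(-1)
    if logits.size != len(CLASS_NAMES):
        raise ValueError(f"expected {len(CLASS_NAMES)} outputs, got {logits.size}")
    return CLASS_NAMES[int(np.argmax(logits))]

print(logits_to_label([2.0, 0.5, -1.0]))  # insoluble
print(logits_to_label([-0.3, 1.2, 0.9]))  # slightly soluble
```

With the PyTorch model, you would pass it out.detach().cpu().numpy().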

First, let’s feed it the SMILES string for decane (C10H22), a hydrocarbon that is basically insoluble in water.

Decane
>>> evaluate_smiles('CCCCCCCCCC')
insoluble

Great! Now let’s add carboxylic acid functional groups to the ends of our molecule and see how that affects solubility.

Sebacic acid
>>> evaluate_smiles('C(=O)(O)CCCCCCCCC(=O)(O)')
slightly soluble

Indeed, the molecule is slightly soluble in water. Awesome, it looks like the model works for this case!

Alright, let’s try something harder. Let’s give it the SMILES structure of Aspirin:

Aspirin
>>> evaluate_smiles('O=C(C1=C(OC(C)=O)C=CC=C1)O')
soluble

I would hope so! Aspirin better be soluble, otherwise I’ve been taking a placebo for my headaches the past few years…

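The model isn’t perfect, though. It reports benzene as “slightly soluble” in water, which is off by one class. Benzene dissolves to roughly 1.8 g/L at room temperature (log S of about -1.6), just above our -1.69 cutoff. By our own thresholds, it therefore belongs in the “soluble” class: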

Benzene
>>> evaluate_smiles('c1ccccc1')
slightly soluble

When we add a couple of tert-butyl groups to our benzene, the model reports it as totally insoluble, which is great!

1,4-di-tert-butylbenzene
>>> evaluate_smiles('CC(C)(C)c1ccc(C(C)(C)C)cc1')
insoluble

Okay, for the final test, let’s try the protein insulin and see how it handles it…

>>> evaluate_smiles('CCC(C)C1C(=O)NC2CSSCC(C(=O)NC(CSSCC(C(=O)NCC(=O)NC(C(=O)NC(C(=O)NC(C(=O)NC(C(=O)NC(C(=O)NC(C(=O)NC(C(=O)NC(C(=O)NC(C(=O)NC(C(=O)NC(CSSCC(NC(=O)C(NC(=O)C(NC(=O)C(NC(=O)C(NC(=O)C(NC(=O)C(NC(=O)C(NC(=O)C(NC2=O)CO)CC(C)C)CC3=CC=C(C=C3)O)CCC(=O)N)CC(C)C)CCC(=O)O)CC(=O)N)CC4=CC=C(C=C4)O)C(=O)NC(CC(=O)N)C(=O)O)C(=O)NCC(=O)NC(CCC(=O)O)C(=O)NC(CCCNC(=N)N)C(=O)NCC(=O)NC(CC5=CC=CC=C5)C(=O)NC(CC6=CC=CC=C6)C(=O)NC(CC7=CC=C(C=C7)O)C(=O)NC(C(C)O)C(=O)N8CCCC8C(=O)NC(CCCCN)C(=O)NC(C)C(=O)O)C(C)C)CC(C)C)CC9=CC=C(C=C9)O)CC(C)C)C)CCC(=O)O)C(C)C)CC(C)C)CC2=CN=CN2)CO)NC(=O)C(CC(C)C)NC(=O)C(CC2=CN=CN2)NC(=O)C(CCC(=O)N)NC(=O)C(CC(=O)N)NC(=O)C(C(C)C)NC(=O)C(CC2=CC=CC=C2)N)C(=O)NC(C(=O)NC(C(=O)N1)CO)C(C)O)NC(=O)C(CCC(=O)N)NC(=O)C(CCC(=O)O)NC(=O)C(C(C)C)NC(=O)C(C(C)CC)NC(=O)CN')
insoluble

Hmm, that’s unexpected, isn’t it? We can inject insulin directly into our veins. Actually, insulin is insoluble in pure water, but it dissolves easily in our blood plasma. This is a great example of the limitations of such a model: compounds measured to be soluble or insoluble in water may not behave the same way in a physiological environment! Additionally, proteins like insulin fold into a particular 3D structure, which this model does not explicitly take into account. For large molecules, this model would therefore likely be highly inaccurate.

Finally, let’s try a super soluble organic compound, vitamin C:

Vitamin C
>>> evaluate_smiles('OC([C@H](O1)[C@H](CO)O)=C(O)C1=O')
soluble

Phew! That would be embarrassing if it got that one wrong.

Wrap-Up

Using a simple Graph Convolutional Network in PyTorch Geometric, we built a classifier that takes an arbitrary graph representation of a molecule and predicts its solubility. Its accuracy is about 10 percentage points better than random guessing. Not great, but amazing given the complexity of the problem! We also had some fun playing around with different molecules in our software.
