In [13]:
# Import library
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.preprocessing import LabelEncoder, StandardScaler
from scipy.sparse import hstack
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import GridSearchCV

# Load data
ml_data = pd.read_csv('ml_data.csv', index_col=0)

Will be using multi-modal classification. Input will be the AA sequence along with the numerical features and the target would be the protein classification. The first step will be to convert the AA sequence (unstructured string data) into n-gram of 4. n-gram helps capture local patterns of the protein sequence and converts them in to numeric vectors. Scikit-learn's CountVectorizer that extracts the n-grams from the strings.

In [14]:
# Overview of data
print(ml_data)
       classification structureId  \
0      OXIDOREDUCTASE        1A72   
1       VIRAL PROTEIN        1A8O   
2      OXIDOREDUCTASE        1AR4   
3         TRANSFERASE        1AUE   
4           HYDROLASE        1AUK   
...               ...         ...   
33766   VIRAL PROTEIN        6F5U   
33767       HYDROLASE        6F6P   
33768       HYDROLASE        6F6P   
33769   VIRAL PROTEIN        6F6S   
33770   VIRAL PROTEIN        6F8P   

                                                sequence  resolution  \
0      STAGKVIKCKAAVLWEEKKPFSIEEVEVAPPKAHEVRIKMVATGIC...        2.60   
1      MDIRQGPKEPFRDYVDRFYKTLRAEQASQEVKNWMTETLLVQNANP...        1.70   
2      AVYTLPELPYDYSALEPYISGEIMELHHDKHHKAYVDGANTALDKL...        1.90   
3      ELIRVAILWHEMWHEGLEEASRLYFGERNVKGMFEVLEPLHAMMER...        2.33   
4      RPPNIVLIFADDLGYGDLGCYGHPSSTTPNLDQLAAGGLRFTDFYV...        2.10   
...                                                  ...         ...   
33766  EAIVNAQPKCNPNLHYWTTQDEGAAIGLAWIPYFGPAAEGIYIEGL...        2.07   
33767  GAASRLRSPSVLEVREKGYERLKEELAKAQRELKLKDEECERLSKV...        2.45   
33768  GASSRLRSPSVLEVREKGYERLKEELAKAQRELKLKDEECERLSKV...        2.45   
33769  ETGRSIPLGVIHNSALQVSDVDKLVCRDKLSSTNQLRSVGLNLEGN...        2.29   
33770  EDPHLRNRPGKGHNYIDGMTQEDATCKPVTYAGACSSFDVLLEKGK...        1.60   

       structureMolecularWeight  crystallizationTempK  densityMatthews  \
0                      40658.50                 277.0             2.30   
1                       8175.72                 277.0             2.21   
2                      45428.53                 277.0             2.05   
3                      24203.73                 277.0             2.25   
4                      52423.45                 291.0             3.30   
...                         ...                   ...              ...   
33766                  57299.72                 293.0             3.48   
33767                  47994.95                 291.0             2.61   
33768                  47994.95                 291.0             2.61   
33769                  58337.03                 293.0             3.83   
33770                  34958.86                 298.0             2.47   

       densityPercentSol  phValue  residueCount  
0                  46.82      8.4           374  
1                  43.80      8.0            70  
2                  32.00      6.1           402  
3                  45.00      8.0           200  
4                  63.00      5.4           489  
...                  ...      ...           ...  
33766              64.67      5.2           498  
33767              56.00      7.3           424  
33768              56.00      7.3           424  
33769              67.89      5.2           497  
33770              50.15      7.0           316  

[33763 rows x 10 columns]
In [ ]:
# Prepare target variable: Classification
le = LabelEncoder()
y = le.fit_transform(ml_data['classification']) # Strings to numeric integers

# prepare 4-mer sequence feature
vectorizer = CountVectorizer(analyzer='char', ngram_range=(4, 4))
X_seq = vectorizer.fit_transform(ml_data['sequence']) 

# prepare numeric features
numeric_col = ['resolution', 'crystallizationTempK', 'densityMatthews', 'densityPercentSol', 'phValue', 'residueCount']
X_num = ml_data[numeric_col]

# Standardize feature 
scaler = StandardScaler()
X_num_scaled = scaler.fit_transform(X_num)

# Combine features
X_all = hstack([X_seq, X_num_scaled]) # Horizontally stacks array side by side 


# Train test split
X_train, X_test, y_train, y_test = train_test_split(X_all, y, test_size=0.2, random_state=42, stratify=y)
# stratify= ensures that the proportion preserved in both the training and testing sets. Needed for balanced splits of data
In [17]:
# Random Forest

model = RandomForestClassifier(n_estimators =100, class_weight='balanced', random_state=42)
# Since there were an imbalance of classification counts
# Will add class_weight argument for the model to focus more on rare classes

model.fit(X_train, y_train)

y_pred = model.predict(X_test)
print(classification_report(y_test, y_pred, target_names=le.classes_))
                                   precision    recall  f1-score   support

                    CELL ADHESION       0.79      0.37      0.50       119
                       CELL CYCLE       0.67      0.35      0.46        63
                        CHAPERONE       0.94      0.42      0.59        80
              DNA BINDING PROTEIN       0.71      0.24      0.36        99
               ELECTRON TRANSPORT       0.66      0.56      0.61        87
                  GENE REGULATION       0.33      0.23      0.27        30
                        HYDROLASE       0.38      0.87      0.53      1238
    HYDROLASE/HYDROLASE INHIBITOR       0.47      0.47      0.47       183
                    IMMUNE SYSTEM       0.84      0.83      0.84       694
                        ISOMERASE       0.96      0.55      0.70       149
                           LIGASE       0.74      0.42      0.54       160
                            LYASE       0.96      0.65      0.77       273
                 MEMBRANE PROTEIN       0.75      0.45      0.56       127
            METAL BINDING PROTEIN       0.64      0.28      0.39        76
                   OXIDOREDUCTASE       0.88      0.61      0.72       735
                   PHOTOSYNTHESIS       0.64      0.83      0.72        46
                  PROTEIN BINDING       0.60      0.19      0.29       190
                PROTEIN TRANSPORT       0.87      0.46      0.61       114
              RNA BINDING PROTEIN       0.68      0.22      0.33        59
                SIGNALING PROTEIN       0.78      0.34      0.47       238
               STRUCTURAL PROTEIN       0.72      0.36      0.48       100
            SUGAR BINDING PROTEIN       0.83      0.61      0.70        82
                            TOXIN       0.92      0.48      0.63        69
                    TRANSCRIPTION       0.67      0.32      0.43       244
                      TRANSFERASE       0.64      0.65      0.65       924
TRANSFERASE/TRANSFERASE INHIBITOR       0.40      0.32      0.36        84
                TRANSPORT PROTEIN       0.82      0.40      0.54       236
                 UNKNOWN FUNCTION       0.65      0.19      0.29        80
                    VIRAL PROTEIN       0.88      0.48      0.62       174

                         accuracy                           0.59      6753
                        macro avg       0.72      0.45      0.53      6753
                     weighted avg       0.69      0.59      0.59      6753