This notebook covers some basic supervised learning techniques that can be used in microbiome science.
The following methods can be used to find features (for example ASVs) that predict some outcome of interest, for example whether a sample comes from a control or a treatment group, or some other metadata associated with the samples. In some sense these methods have the same aim as Differential abundance analysis, but with supervised machine learning the purpose is not inference (based on p-values) but prediction.
Getting in-depth: If you want to learn more about machine learning I can highly recommend The Elements of Statistical Learning by Hastie, Tibshirani, and Friedman.
Short primers:
Let's load our example dataset
library(phyloseq)
load("../data/physeq.RData")
To ensure our model is not overfitting (fitting our dataset well but failing to generalize to other similar datasets), we need to split our dataset into 3 parts. The train set is used to train a specific model, the validation set is used to compare models and choose the hyperparameters of the model, and the test set is used only to check how well our final model works.
A widely used method for choosing hyperparameters is cross-validation. With cross-validation the train and validation datasets are combined and split into k parts. The model is then fit k times, each time using a different part as the validation set and the remainder as the train set.
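As a minimal sketch of how such fold assignments could look (cv.glmnet, which we use further down, handles this internally, so this snippet is purely illustrative):
# Randomly assign each sample to one of k folds
k <- 5
n <- nsamples(phy)                         # phyloseq helper for the number of samples
folds <- sample(rep(1:k, length.out = n))  # shuffle a balanced fold assignment
table(folds)
# In round i, samples with folds == i form the validation set; the rest form the train set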
If you want to know how good your model really is, you should use a test set which has not been used for training at all. It is actually uncommon to see a test set in microbiome science, but because it is best practice we will use one in this notebook.
# Split into test and train/validate (30 random samples are used as a test set)
set.seed(42)
test_set <- sample(sample_names(phy), 30)
phy_train <- subset_samples(phy, !sample_names(phy) %in% test_set)
phy_test <- subset_samples(phy, sample_names(phy) %in% test_set)
phy_train
phy_test
A simple supervised learning method would be to use multiple linear regression and simply add all features as independent variables. However, the problem with microbiome datasets is that we usually have many more features than samples (the p > n problem), which means we cannot fit such models directly. A way around this is sparse regularization: we penalize the model for adding features, thereby forcing it to use only the features that are important enough for the prediction.
In-depth paper on regularization
L1 penalty, or LASSO, is a penalty which sets the estimates of "non-important" features to zero; that is, it selects which features are most important for predicting the outcome. If there are highly correlated features, it will choose randomly among them.
L2 penalty, or Ridge, is a penalty which shrinks the estimates of all features as more features are added. It therefore does not select features (most estimates will be non-zero), but it regularizes the model. It is better suited than LASSO when highly correlated features are included in the model.
Elastic net is a generalized penalty which introduces an alpha parameter. When alpha=1 it is a LASSO penalty, when alpha=0 it is Ridge, and with alpha between 0 and 1 it is a mix of the two.
All these sparse regularization methods need a lambda hyperparameter, which controls how strong the penalty is. Elastic net additionally needs the alpha hyperparameter.
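For reference, the elastic net penalty as parameterized in glmnet (which we use below) can be written as

$$\lambda \left( \alpha \lVert \beta \rVert_1 + \frac{1 - \alpha}{2} \lVert \beta \rVert_2^2 \right)$$

so alpha = 1 keeps only the L1 term (LASSO), alpha = 0 keeps only the L2 term (Ridge), and lambda scales the overall strength of the penalty.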
Let's fit a LASSO model. We use logistic regression as we want to predict whether our sample comes from a child born by C-section (1) or by vaginal birth (0).
library(glmnet)
# Extract outcome and make it binary
y <- ifelse(unlist(sample_data(phy_train)[,"Delivery"]) == "Sectio", 1, 0)
# Extract features and normalize and transform them
X <- otu_table(phy_train)
X <- apply(X, 2, function(x) (x + 1) / sum(x + 1))  # relative abundances with a pseudocount
X <- t(log10(X))
Note on transformation: As the model assumes linearity, we log-transform the relative abundances to make them more normally distributed. Alternatively, one could use a CLR transformation of the abundances.
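As a sketch of that alternative (not used further below), a CLR transform could look like this, assuming taxa are rows in the otu_table as in the code above:
# Centered log-ratio (CLR) transform per sample: log(x) minus the mean of log(x)
X_counts <- otu_table(phy_train)
X_clr <- apply(X_counts, 2, function(x) {
  rel <- (x + 1) / sum(x + 1)       # relative abundances with a pseudocount
  log(rel) - mean(log(rel))         # centre by the per-sample geometric mean (on the log scale)
})
X_clr <- t(X_clr)                   # samples as rows, as for X above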
Fit the model (with 5 cross-validation folds; 5-10 folds are usually recommended):
cvfit <- cv.glmnet(X, y, family = "binomial", alpha = 1, nfolds = 5)
We can plot the lambda parameter against the deviance. Low deviance means it's a good fit.
plot(cvfit)
If we start reading the plot from the left, we have many features in the model (84). As lambda increases (moving to the right), we get fewer features and the deviance decreases. At some point the deviance starts rising again as we get even fewer features. The first vertical line marks the best model (lowest deviance); the second vertical line marks the simplest model whose deviance is within 1 standard error of that of the best model.
With few features in the model (high lambda, right in the plot) we simply don't have enough information to make good predictions. With many features (low lambda, left in the plot) we start overfitting; features are added that only contribute variance to the model, so they might correlate with the outcome in the training set but not in the validation set, and are therefore probably noise. So the sweet spot is somewhere in between - the lowest point of the U. If the curve is not U-shaped, you might have specified the model incorrectly, or there is simply no signal in the data.
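We can also extract the two lambda values marked by the vertical lines, and count the non-zero coefficients at each (subtracting 1 for the intercept):
cvfit$lambda.min                              # lambda with the lowest cross-validated deviance
cvfit$lambda.1se                              # largest lambda within 1 SE of the minimum
sum(coef(cvfit, s = "lambda.min") != 0) - 1   # number of selected features at lambda.min
sum(coef(cvfit, s = "lambda.1se") != 0) - 1   # number of selected features at lambda.1se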
Let's see the coefficients of the simplest (but still good) model:
all_1se <- as.matrix(coef(cvfit, s = "lambda.1se"))
chosen_1se <- all_1se[all_1se != 0, ]  # keep only the non-zero coefficients
chosen_1se
Get taxonomy of the chosen ones (the -1 removes the intercept):
tax_table(phy)[names(chosen_1se)[-1]]
Above we have our chosen features and their associated estimates in the model. glmnet standardizes the features internally before fitting but reports the coefficients on the original scale; since all our features are on the same scale (log10 relative abundances), the estimates can still be compared, and the highest estimate (in absolute terms) can be said to be most important for the prediction. A positive estimate means that a higher abundance results in increased odds of being in the 1 group (C-section) compared to the 0 group (vaginal birth), and vice versa for negative estimates. This is a strength of linear models compared to, for example, decision trees (e.g. random forest), where the associations can be non-linear and therefore not necessarily easy to interpret.
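Because this is a logistic model, exponentiating a coefficient gives an odds ratio: the multiplicative change in the odds of C-section per one-unit increase of the feature (one unit on the log10 scale corresponds to a tenfold increase in relative abundance). A quick sketch:
exp(chosen_1se[-1])   # odds ratios for the selected features; [-1] drops the intercept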
Let's check how well the model performs on the test set:
# Extract outcome and make it binary
y_test <- ifelse(unlist(sample_data(phy_test)[,"Delivery"]) == "Sectio", 1, 0)
# Extract features and normalize and transform them
X_test <- otu_table(phy_test)
X_test <- apply(X_test, 2, function(x) (x + 1) / sum(x + 1))  # relative abundances with a pseudocount
X_test <- t(log10(X_test))
table(y_test, predict(object = cvfit, s = "lambda.1se", newx = X_test, type = "class"))
The rows are the truth (test set) and the columns are the predictions. Of the 14 samples that were 0 (vaginal birth), 11 were correctly predicted as such and 3 were falsely predicted as 1 (C-section). Of the 16 samples that were 1 (C-section), 12 were correctly predicted as such and 4 were falsely predicted as 0 (vaginal birth). The accuracy is 77% ((11+12)/30).
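Instead of reading it off the confusion table, the accuracy can also be computed directly (a small sketch; pred_test is just a temporary name):
pred_test <- predict(cvfit, s = "lambda.1se", newx = X_test, type = "class")
mean(as.character(pred_test) == y_test)   # proportion of correct predictions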
Let's test our model on our train/validation set:
table(y, predict(object = cvfit, s = "lambda.1se", newx = X, type = "class"))
Now the accuracy is 94%. So we can see that the accuracy becomes falsely inflated if we test the model on the same dataset that we used for training.
For more details on sparse regularized linear models, see here
Random forest is a model based on an ensemble of decision trees, combined using bagging (bootstrap aggregating) and random subsets of the features.
A decision tree is a type of model that splits the samples into groups (leaves on the tree) based on the features. For example, a simple tree could contain a single split such that samples in which an ASV is less abundant than 1% are in branch A and samples in which this ASV is more abundant than 1% are in branch B. The decision trees can have multiple splits, such that they are based on multiple features. In-depth paper on decision trees
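As a toy illustration of such a single split rule (a hypothetical ASV with made-up relative abundances, purely for illustration):
toy_abundance <- c(s1 = 0.002, s2 = 0.015, s3 = 0.030, s4 = 0.001)  # relative abundance of one ASV
ifelse(toy_abundance < 0.01, "branch A", "branch B")                # the split at 1%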
With bootstrapping one does random sampling with replacement of the samples. That is, instead of using the raw data as input to the model, the model randomly draws the same number of samples, but some samples can be included more than once (and some might be left out). Bootstrapping is crucial: without it the trees would be highly similar (correlated), which would bias the final model. In-depth paper on the bootstrap
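A minimal sketch of one bootstrap draw on our training samples (illustrative only; randomForest does this internally):
set.seed(1)
boot_ids <- sample(sample_names(phy_train), size = nsamples(phy_train), replace = TRUE)
table(table(boot_ids))                              # how many samples were drawn once, twice, ...
length(setdiff(sample_names(phy_train), boot_ids))  # samples left out of this draw ("out-of-bag")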
To create more variation among the decision trees, only a random subset of features are used at each split.
Random forest fits multiple decision trees; each tree is trained on a bootstrapped set of samples, and each split in a tree uses a random subset of the features. Each individual tree is a poor predictor, but aggregating many trees produces good predictions. For classification problems the trees are aggregated by majority vote, and for regression problems the aggregation is simply the mean prediction across all trees. In-depth paper on bagging and random forest
There are 2 main hyperparameters for random forests: the number of trees, and the number of randomly chosen features at each split (AKA mtry). For the latter there are widely used defaults: sqrt(n_features) for classification problems and n_features/3 for regression problems. As for the number of trees, a few hundred usually works well, but this number could also be tuned through cross-validation.
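For our training matrix these defaults would be approximately:
floor(sqrt(ncol(X)))   # default mtry for classification (square root of the number of features)
floor(ncol(X) / 3)     # common default for regression problems (not relevant here)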
Below we simply use the default parameters for random forest, and we therefore only have a training and test set.
library(randomForest)
fit <- randomForest(y = factor(y), x = X,
ytest = factor(y_test), xtest = X_test,
importance = TRUE)
fit
So the model has a test accuracy of 73%, a little worse than the LASSO linear model used above. OOB stands for Out-Of-Bag: each training sample is predicted using only the trees that did not include it in their bootstrap sample, and the error rate of these predictions gives the OOB estimate.
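Since we supplied xtest and ytest, the test-set predictions are stored in the fitted object, and the test accuracy can be recomputed as a quick check:
mean(as.character(fit$test$predicted) == y_test)   # proportion of correctly predicted test samples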
We can see which features are important for the model. MeanDecreaseAccuracy is the decrease in accuracy when the values of this feature are randomly permuted; a high value therefore means the feature is important:
fit$importance[rev(order(fit$importance[, "MeanDecreaseAccuracy"])), ]
We can merge this with the taxonomy:
tax <- data.frame(tax_table(phy))
imp <- fit$importance
imp_tax <- merge(imp, tax, by = "row.names")
imp_tax[rev(order(imp_tax$MeanDecreaseAccuracy)), ]