A deeper look at viewmastR
2024-08-08
InDepth.Rmd
Before diving into how viewmastR works, let’s first set up the necessary
environment and go through essential functions to streamline your
training and analysis workflow.
1. Installing Rust
Before using viewmastR, you’ll need an up-to-date installation of Rust,
as it’s a core dependency. Follow the instructions on the official Rust
installation page to set up Rust on your system.
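Before installing, it can help to confirm from within R that the Rust toolchain is visible on the PATH. The following is a minimal sketch using only base R; the helper name has_rust is hypothetical and is not part of viewmastR.

```r
# Hypothetical helper (not part of viewmastR): checks whether the Rust
# toolchain executables can be found on the PATH of the current R session.
has_rust <- function() {
  tools <- Sys.which(c("cargo", "rustc"))
  all(nzchar(tools))
}

if (has_rust()) {
  # Print the compiler version reported by rustc
  print(system2("rustc", "--version", stdout = TRUE))
} else {
  message("Rust was not found on the PATH; install it before installing viewmastR.")
}
```

If the check fails right after installing Rust, restarting the R session usually lets it pick up the updated PATH.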
2. Installing viewmastR
Once Rust is installed, you can install viewmastR directly from GitHub.
Ensure you have the devtools package installed, and then use the
following command:
devtools::install_github("furlan-lab/viewmastR")
3. Viewing the Training History
viewmastR tracks key data during the training process, which can be
accessed by setting the return_type parameter to "list". This returns:
1. The query object with predicted cell types.
2. The training results.
Here’s how you can retrieve and visualize the training data:
suppressPackageStartupMessages({
library(viewmastR)
library(Seurat)
library(ggplot2)
library(scCustomize)
library(plotly)
})
# Load query and reference datasets
seu <- readRDS(file.path(ROOT_DIR1, "240813_final_object.RDS"))
vg <- get_selected_genes(seu)
seur <- readRDS(file.path(ROOT_DIR2, "230329_rnaAugmented_seurat.RDS"))
# View training history
output_list <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, return_type = "list")
Visualizing Training Data
To plot training vs validation loss, you can use the following:
plot_training_data(output_list)
For rendering the plot without details:
plt <- plot_training_data(output_list)
plt
Tip: If the training loss decreases while the validation loss plateaus, it may indicate overfitting.
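The tip above can be turned into a simple numeric check. The sketch below uses synthetic loss values, not output from viewmastR, and the helper check_overfitting is hypothetical. It reports the epoch with the lowest validation loss and whether the training loss kept falling after that point.

```r
# Synthetic loss curves for illustration only (not produced by viewmastR)
train_loss <- c(1.20, 0.80, 0.55, 0.40, 0.30, 0.22, 0.16, 0.12)
val_loss   <- c(1.25, 0.90, 0.70, 0.62, 0.60, 0.61, 0.63, 0.66)

# Hypothetical helper: find the epoch with the lowest validation loss and
# test whether training loss continued to fall after validation loss
# stopped improving.
check_overfitting <- function(train_loss, val_loss) {
  stopifnot(length(train_loss) == length(val_loss))
  best_epoch <- which.min(val_loss)
  last_epoch <- length(val_loss)
  diverging <- best_epoch < last_epoch &&
    train_loss[last_epoch] < train_loss[best_epoch]
  list(best_epoch = best_epoch, diverging = diverging)
}

res <- check_overfitting(train_loss, val_loss)
res$best_epoch  # 5
res$diverging   # TRUE
```

With real runs, the two vectors would have to be extracted from the training results; where exactly they are stored is not shown here, so that step is left out.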
4. Tuning for Speed
viewmastR can run on three backends: candle, wgpu, and nd (see Burn for
more details). The candle backend tends to run faster on Apple M1/M2
processors. Here’s how you can compare the performance of different
backends on your system:
run1 <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, max_epochs = 3, backend = "candle", return_type = "list")
run2 <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, max_epochs = 3, backend = "wgpu", return_type = "list")
run3 <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, max_epochs = 3, backend = "nd", return_type = "list")
# Compare training times
gp <- ggplot(
  data.frame(
    training_times = c(run1$training_output$duration$training_duration,
                       run2$training_output$duration$training_duration,
                       run3$training_output$duration$training_duration),
    backend = c("candle", "wgpu", "nd")
  ),
  aes(x = backend, y = training_times, fill = backend)
) +
  geom_col() +
  theme_bw() +
  labs(x = "Backend", y = "Training Time (s)") +
  NoLegend() +
  ggtitle(paste("Arch: ", as.character(Sys.info()["machine"])))
ggplotly(gp)
# To automatically set backend
if (as.character(Sys.info()["machine"]) == "x86_64") {
  backend <- "wgpu"
} else {
  backend <- "candle"
}
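Instead of choosing a backend from the machine type, one could pick whichever backend was fastest in the timing comparison. The sketch below uses made-up timings purely for illustration; in practice the values would come from the training_duration entries of the three runs above.

```r
# Illustrative timings in seconds (made-up numbers, not measurements)
training_times <- c(candle = 4.2, wgpu = 2.9, nd = 6.5)

# Name of the backend with the shortest training time
fastest_backend <- names(which.min(training_times))
fastest_backend

# Slowdown of each backend relative to the fastest one
round(training_times / min(training_times), 2)
```

Timings from a three-epoch run are noisy, so repeating the comparison a few times before settling on a backend may be worthwhile.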
5. Saving Training Subsets
To inspect the training and test data used by viewmastR, or to evaluate
them with other learning frameworks, use the setup_training function:
ti <- setup_training(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, return_type = "matrix", backend = backend)
# Convert labels to max class and save them
train_label <- apply(ti$Ytrain_label, 1, which.max)
test_label <- apply(ti$Ytest_label, 1, which.max)
# Save training data and labels
writeMMgz(as(ti$Xtrain_data, "dgCMatrix"), "/path/to/train.mm.gz")
writeMMgz(as(ti$Xtest_data, "dgCMatrix"), "/path/to/test.mm.gz")
writeMMgz(as(ti$query, "dgCMatrix"), "/path/to/query.mm.gz")
data.table::fwrite(data.frame(train = train_label), "/path/to/train_labels.tsv.gz", compress = "gzip")
data.table::fwrite(data.frame(test = test_label), "/path/to/test_labels.tsv.gz", compress = "gzip")
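The label conversion above relies on each row of the label matrix having its largest value in the column of the true class. The toy example below, with invented class names, shows what apply(..., 1, which.max) returns for a small one-hot matrix and how the indices map back to names.

```r
# Toy one-hot label matrix: 4 cells (rows) x 3 classes (columns).
# The class names are invented for illustration.
Y <- matrix(c(1, 0, 0,
              0, 0, 1,
              0, 1, 0,
              0, 0, 1),
            nrow = 4, byrow = TRUE,
            dimnames = list(NULL, c("B", "NK", "T")))

# Index of the largest entry in each row, i.e. the 1-based class index
label_idx <- apply(Y, 1, which.max)
label_idx

# Map the indices back to class names
label_names <- colnames(Y)[label_idx]
label_names
```

Note that the saved label files contain 1-based indices, so frameworks that expect 0-based class labels would need the values shifted by one.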
6. Analyzing Probabilities
Run inference and obtain prediction probabilities:
ti <- setup_training(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, return_type = "matrix")
seu <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, max_epochs = 3, backend = backend)
# Obtain prediction probabilities
seu <- viewmastR_infer(seu, "/tmp/sc_local/model.mpk", vg, labels = levels(factor(seur$SFClassification)), return_probs = TRUE)
scCustomize::FeaturePlot_scCustom(seu, features = "prob_19_CD8.EM")
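Per-class probabilities can also be summarized per cell, for example to flag low-confidence calls. The sketch below works on a toy data frame whose columns mimic the prob_ prefix seen in the feature name above. It assumes the probabilities are stored as one column per class, and the classes, values, and 0.5 cutoff are all invented for illustration.

```r
# Toy per-cell probability table; column names mimic the "prob_" prefix,
# but the classes and values are invented for illustration.
probs <- data.frame(
  prob_B  = c(0.90, 0.10, 0.40),
  prob_NK = c(0.05, 0.80, 0.35),
  prob_T  = c(0.05, 0.10, 0.25),
  row.names = c("cell1", "cell2", "cell3")
)

prob_cols <- grep("^prob_", colnames(probs), value = TRUE)
pmat <- as.matrix(probs[, prob_cols])

# Highest probability per cell and the class it belongs to
max_prob  <- apply(pmat, 1, max)
top_class <- sub("^prob_", "", prob_cols[max.col(pmat, ties.method = "first")])

# Flag cells whose best class falls below an arbitrary confidence cutoff
low_confidence <- max_prob < 0.5
data.frame(top_class, max_prob, low_confidence)
```

For a Seurat object, the same steps could be applied to the matching columns of its cell metadata.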
7. Evaluating Model Weights
To inspect model weights (note this only works for mlr - see below):
mod <- RcppMsgPack::msgpack_read("/tmp/sc_local/model.mpk", simplify = TRUE)
weights <- mod$item$linear1$weight$param$value
shape <- mod$item$linear1$weight$param$shape
wmat <- data.frame(t(matrix(weights, nrow = shape[2])))
rownames(wmat) <- ti$features
colnames(wmat) <- ti$label_text
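# Take the 20 genes with the largest weights for the NK class
top_NK_genes <- rownames(wmat)[sort(wmat$`21_NK`, index.return = TRUE, decreasing = TRUE)$ix[1:20]]
# AddModuleScore names the resulting score column "Cluster1" by default
seu <- AddModuleScore(seu, features = list(top_nk_genes = top_NK_genes))
FeaturePlot_scCustom(seu, features = "Cluster1")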
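The reshape step above is easy to get wrong, so here is a toy version with invented gene and class names. It assumes, as the transpose in the code above implies, that the flat weight vector is stored row by row with shape c(n_features, n_classes).

```r
# Toy example: 3 genes x 2 classes, flattened row by row (gene-major order).
# Gene names, class names, and values are invented for illustration.
shape   <- c(3, 2)                       # c(n_features, n_classes)
weights <- c(0.1, 0.9,                   # geneA: class1, class2
             0.5, 0.2,                   # geneB
             0.7, 0.4)                   # geneC

# matrix() fills column by column, so fill a classes x genes matrix and
# transpose it to recover the genes x classes layout.
wmat_toy <- t(matrix(weights, nrow = shape[2]))
dimnames(wmat_toy) <- list(c("geneA", "geneB", "geneC"), c("class1", "class2"))
wmat_toy

# Genes ranked by their weight for one class, largest first
top_genes <- rownames(wmat_toy)[order(wmat_toy[, "class1"], decreasing = TRUE)]
top_genes
```

The same matrix can be built with matrix(weights, nrow = shape[1], byrow = TRUE), which some readers may find easier to follow.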
8. Comparing Different Algorithms
viewmastR supports several algorithms: a pseudo multinomial logistic
regression (mlr), multinomial naive Bayes (nb), and a multi-layer
perceptron (nn). Note that our mlr function is a neural network with no
hidden layers using ReLU activation. You can read more about the
similarity between this simple neural network and logistic regression
here (https://medium.com/@axegggl/neural-networks-decoded-how-logistic-regression-is-the-hidden-first-step-495f4a0b5fd#:~:text=When%20you%20think%20about%20a,like%20a%20logistic%20regression%20model.).
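To make the link between a network without hidden layers and logistic regression concrete, the sketch below computes class probabilities from a single linear layer followed by a softmax. This is the textbook formulation with invented numbers. It is not viewmastR's mlr implementation, which the text above describes as using ReLU activation.

```r
# Toy inputs: 2 cells x 3 genes, and weights for 2 classes.
# All values are invented for illustration.
X <- matrix(c(1, 0, 2,
              0, 3, 1), nrow = 2, byrow = TRUE)
W <- matrix(c( 0.5, -0.5,
              -0.2,  0.4,
               0.1,  0.3), nrow = 3, byrow = TRUE)
b <- c(0.0, 0.1)

# Row-wise softmax; subtracting the row maximum avoids overflow
softmax <- function(z) {
  e <- exp(z - apply(z, 1, max))
  e / rowSums(e)
}

# A single linear layer followed by softmax is multinomial logistic regression
logits     <- sweep(X %*% W, 2, b, "+")
class_prob <- softmax(logits)
class_prob
rowSums(class_prob)  # each row sums to 1
```

Adding one or more hidden layers between X and the output layer turns this model into a multi-layer perceptron.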
The code below shows how you can run and compare these methods:
seu <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", FUNC = "mlr", query_celldata_col = "mlr_pred", selected_genes = vg, backend = backend)
seu <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", FUNC = "nb", query_celldata_col = "nb_pred", selected_genes = vg, backend = backend)
seu <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", FUNC = "nn", query_celldata_col = "nn_pred", selected_genes = vg, hidden_layers = c(200), backend = backend)
# Visualize predictions
DimPlot(seu, group.by = "mlr_pred")
DimPlot(seu, group.by = "nb_pred")
DimPlot(seu, group.by = "nn_pred")
# Evaluate accuracy (assumes seu has a ground_truth metadata column holding the known labels)
accuracy_mlr <- length(which(seu$mlr_pred == seu$ground_truth)) / dim(seu)[2]
accuracy_nb <- length(which(seu$nb_pred == seu$ground_truth)) / dim(seu)[2]
accuracy_nn <- length(which(seu$nn_pred == seu$ground_truth)) / dim(seu)[2]
# Compare accuracies
gp <- ggplot(
  data.frame(
    accuracy = c(accuracy_mlr, accuracy_nb, accuracy_nn) * 100,
    algorithm = c("mlr", "nb", "nn")
  ),
  aes(x = algorithm, y = accuracy, fill = algorithm)
) +
  geom_col() +
  theme_bw() +
  labs(x = "Algorithm", y = "Accuracy (%)") +
  NoLegend()
ggplotly(gp)
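Overall accuracy can hide classes that are predicted poorly. The sketch below, using invented labels rather than real predictions, shows the accuracy calculation on plain vectors together with a confusion table and per-class recall.

```r
# Toy predictions and ground-truth labels, invented for illustration
truth <- c("B", "B", "NK", "T", "T", "NK")
pred  <- c("B", "T", "NK", "T", "T", "B")

# Overall accuracy: fraction of cells whose prediction matches the truth
accuracy <- mean(pred == truth)
accuracy

# Confusion table (rows = truth, columns = prediction)
lv <- sort(unique(c(truth, pred)))
confusion <- table(truth = factor(truth, levels = lv),
                   pred  = factor(pred,  levels = lv))
confusion

# Per-class recall: correct calls divided by the number of true cells per class
recall <- diag(confusion) / rowSums(confusion)
recall
```

When no labels are missing, mean(pred == truth) gives the same value as the length(which(...)) / n form used above.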
Appendix: Session Information
To ensure reproducibility, here’s how you can capture session information:
sessionInfo()
## R version 4.4.0 (2024-04-24)
## Platform: x86_64-apple-darwin20
## Running under: macOS Ventura 13.6.7
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRlapack.dylib; LAPACK version 3.12.0
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## time zone: America/Los_Angeles
## tzcode source: internal
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] plotly_4.10.4 scCustomize_2.1.2 ggplot2_3.5.1 Seurat_5.1.0
## [5] SeuratObject_5.0.2 sp_2.1-4 viewmastR_0.2.3
##
## loaded via a namespace (and not attached):
## [1] fs_1.6.4 matrixStats_1.3.0
## [3] spatstat.sparse_3.0-3 RcppMsgPack_0.2.3
## [5] lubridate_1.9.3 httr_1.4.7
## [7] RColorBrewer_1.1-3 doParallel_1.0.17
## [9] tools_4.4.0 sctransform_0.4.1
## [11] backports_1.5.0 utf8_1.2.4
## [13] R6_2.5.1 lazyeval_0.2.2
## [15] uwot_0.2.2 GetoptLong_1.0.5
## [17] withr_3.0.0 gridExtra_2.3
## [19] progressr_0.14.0 cli_3.6.2
## [21] Biobase_2.64.0 textshaping_0.4.0
## [23] spatstat.explore_3.2-7 fastDummies_1.7.3
## [25] labeling_0.4.3 sass_0.4.9
## [27] spatstat.data_3.0-4 ggridges_0.5.6
## [29] pbapply_1.7-2 pkgdown_2.0.9
## [31] systemfonts_1.1.0 foreign_0.8-86
## [33] R.utils_2.12.3 parallelly_1.37.1
## [35] rstudioapi_0.16.0 generics_0.1.3
## [37] shape_1.4.6.1 crosstalk_1.2.1
## [39] ica_1.0-3 spatstat.random_3.2-3
## [41] dplyr_1.1.4 Matrix_1.7-0
## [43] ggbeeswarm_0.7.2 fansi_1.0.6
## [45] S4Vectors_0.42.0 abind_1.4-5
## [47] R.methodsS3_1.8.2 lifecycle_1.0.4
## [49] yaml_2.3.8 snakecase_0.11.1
## [51] SummarizedExperiment_1.34.0 recipes_1.1.0
## [53] SparseArray_1.4.8 Rtsne_0.17
## [55] paletteer_1.6.0 grid_4.4.0
## [57] promises_1.3.0 crayon_1.5.2
## [59] miniUI_0.1.1.1 lattice_0.22-6
## [61] cowplot_1.1.3 pillar_1.9.0
## [63] knitr_1.46 ComplexHeatmap_2.20.0
## [65] GenomicRanges_1.56.0 rjson_0.2.21
## [67] boot_1.3-30 future.apply_1.11.2
## [69] codetools_0.2-20 leiden_0.4.3.1
## [71] glue_1.7.0 data.table_1.15.4
## [73] vctrs_0.6.5 png_0.1-8
## [75] spam_2.10-0 gtable_0.3.5
## [77] rematch2_2.1.2 assertthat_0.2.1
## [79] cachem_1.1.0 gower_1.0.1
## [81] xfun_0.44 S4Arrays_1.4.1
## [83] mime_0.12 prodlim_2024.06.25
## [85] survival_3.6-4 timeDate_4041.110
## [87] SingleCellExperiment_1.26.0 iterators_1.0.14
## [89] pbmcapply_1.5.1 hardhat_1.4.0
## [91] lava_1.8.0 fitdistrplus_1.1-11
## [93] ROCR_1.0-11 ipred_0.9-15
## [95] nlme_3.1-164 RcppAnnoy_0.0.22
## [97] GenomeInfoDb_1.40.1 bslib_0.7.0
## [99] irlba_2.3.5.1 vipor_0.4.7
## [101] KernSmooth_2.23-24 rpart_4.1.23
## [103] colorspace_2.1-0 BiocGenerics_0.50.0
## [105] Hmisc_5.1-2 nnet_7.3-19
## [107] ggrastr_1.0.2 tidyselect_1.2.1
## [109] compiler_4.4.0 htmlTable_2.4.2
## [111] desc_1.4.3 DelayedArray_0.30.1
## [113] checkmate_2.3.1 scales_1.3.0
## [115] lmtest_0.9-40 stringr_1.5.1
## [117] digest_0.6.35 goftest_1.2-3
## [119] spatstat.utils_3.1-0 minqa_1.2.7
## [121] rmarkdown_2.27 XVector_0.44.0
## [123] htmltools_0.5.8.1 pkgconfig_2.0.3
## [125] base64enc_0.1-3 lme4_1.1-35.3
## [127] sparseMatrixStats_1.16.0 MatrixGenerics_1.16.0
## [129] highr_0.10 fastmap_1.2.0
## [131] rlang_1.1.4 GlobalOptions_0.1.2
## [133] htmlwidgets_1.6.4 UCSC.utils_1.0.0
## [135] shiny_1.8.1.1 DelayedMatrixStats_1.26.0
## [137] farver_2.1.2 jquerylib_0.1.4
## [139] zoo_1.8-12 jsonlite_1.8.8
## [141] ModelMetrics_1.2.2.2 R.oo_1.26.0
## [143] magrittr_2.0.3 Formula_1.2-5
## [145] GenomeInfoDbData_1.2.12 dotCall64_1.1-1
## [147] patchwork_1.2.0 munsell_0.5.1
## [149] Rcpp_1.0.12 reticulate_1.37.0
## [151] stringi_1.8.4 pROC_1.18.5
## [153] zlibbioc_1.50.0 MASS_7.3-60.2
## [155] plyr_1.8.9 parallel_4.4.0
## [157] listenv_0.9.1 ggrepel_0.9.5
## [159] forcats_1.0.0 deldir_2.0-4
## [161] splines_4.4.0 tensor_1.5
## [163] circlize_0.4.16 igraph_2.0.3
## [165] spatstat.geom_3.2-9 RcppHNSW_0.6.0
## [167] reshape2_1.4.4 stats4_4.4.0
## [169] evaluate_0.23 ggprism_1.0.5
## [171] nloptr_2.0.3 foreach_1.5.2
## [173] httpuv_1.6.15 RANN_2.6.1
## [175] tidyr_1.3.1 purrr_1.0.2
## [177] polyclip_1.10-6 future_1.33.2
## [179] clue_0.3-65 scattermore_1.2
## [181] janitor_2.2.0 xtable_1.8-4
## [183] monocle3_1.3.7 RSpectra_0.16-1
## [185] later_1.3.2 viridisLite_0.4.2
## [187] class_7.3-22 ragg_1.3.2
## [189] tibble_3.2.1 memoise_2.0.1
## [191] beeswarm_0.4.0 IRanges_2.38.0
## [193] cluster_2.1.6 timechange_0.3.0
## [195] globals_0.16.3 caret_6.0-94
getwd()
## [1] "/Users/sfurla/develop/viewmastR/vignettes"
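Session information can also be written to a file so that it travels with the analysis outputs. This is a minimal sketch using base R only; the file name and the use of a temporary directory are arbitrary choices for the example.

```r
# Write the session information to a text file.
# The file name and location are arbitrary choices for this example.
out_file <- file.path(tempdir(), "session_info.txt")
writeLines(utils::capture.output(utils::sessionInfo()), out_file)

# Confirm what was written
head(readLines(out_file), 3)
```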