
A Deeper Look at viewmastR

Before diving into how viewmastR works, let’s first set up the necessary environment and go through essential functions to streamline your training and analysis workflow.

1. Installing Rust

Before using viewmastR, you need an up-to-date installation of Rust, which is a core dependency. Follow the instructions on the official Rust installation page to set up Rust on your system.
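To confirm from within R that the Rust toolchain is visible before installing, the sketch below checks whether `cargo` and `rustc` are on the PATH. The helper name `rust_toolchain_found` is hypothetical and not part of viewmastR.

```r
# Hypothetical helper: TRUE when both cargo and rustc are found on the PATH
rust_toolchain_found <- function() {
  tools <- Sys.which(c("cargo", "rustc"))  # empty string for any tool not found
  all(nzchar(tools))
}

if (rust_toolchain_found()) {
  # Print the installed compiler version, e.g. to check that it is recent
  system2("rustc", "--version")
} else {
  message("Rust was not found on the PATH; install it before installing viewmastR.")
}
```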

2. Installing viewmastR

Once Rust is installed, you can install viewmastR directly from GitHub. Ensure you have the devtools package installed, and then use the following command:

if (!requireNamespace("devtools", quietly = TRUE)) install.packages("devtools")
devtools::install_github("furlan-lab/viewmastR")

3. Viewing the Training History

viewmastR tracks key data during the training process, which can be accessed by setting the return_type parameter to "list". This returns (1) the query object with predicted cell types and (2) the training results.

Here’s how you can retrieve and visualize the training data:

suppressPackageStartupMessages({
  library(viewmastR)
  library(Seurat)
  library(ggplot2)
  library(scCustomize)
  library(plotly)
})

# Load query (seu) and reference (seur) datasets;
# ROOT_DIR1 and ROOT_DIR2 must point to the directories that contain these files
seu <- readRDS(file.path(ROOT_DIR1, "240813_final_object.RDS"))
vg <- get_selected_genes(seu)
seur <- readRDS(file.path(ROOT_DIR2, "230329_rnaAugmented_seurat.RDS"))

# Train, returning both the annotated query and the training history
output_list <- viewmastR(seu, seur, ref_celldata_col = "SFClassification",
                         selected_genes = vg, return_type = "list")

Visualizing Training Data

To plot training vs validation loss, you can use the following:

plot_training_data(output_list)

Alternatively, store the plot in a variable and print it when needed:

plt <- plot_training_data(output_list)
plt

Tip: If the training loss decreases while the validation loss plateaus, it may indicate overfitting.
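The tip above can be turned into a simple check. This sketch assumes you have per-epoch training and validation losses as two numeric vectors; how to pull them out of `output_list$training_output` is not shown here, the loss values below are made up for illustration, and `flag_overfitting` is a hypothetical helper, not a viewmastR function. It is a heuristic, not a formal test.

```r
# Hypothetical helper: compare training and validation loss curves.
# train_loss, valid_loss: numeric vectors with one value per epoch.
flag_overfitting <- function(train_loss, valid_loss, tol = 1e-3) {
  stopifnot(length(train_loss) == length(valid_loss))
  best_epoch <- which.min(valid_loss)   # epoch with the lowest validation loss
  last <- length(valid_loss)
  # Overfitting signal: validation loss bottomed out before the final epoch,
  # yet training loss kept falling after that point.
  possible_overfit <- best_epoch < last &&
    train_loss[last] < train_loss[best_epoch] - tol
  list(best_epoch = best_epoch, possible_overfit = possible_overfit)
}

# Illustrative (made-up) loss curves
train <- c(1.20, 0.80, 0.55, 0.40, 0.30, 0.22)
valid <- c(1.25, 0.90, 0.70, 0.68, 0.69, 0.71)
flag_overfitting(train, valid)
```

Here validation loss is lowest at epoch 4 while training loss keeps dropping, so the check flags possible overfitting; stopping near that epoch (for example via a smaller max_epochs) is one way to respond.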

4. Tuning for Speed

viewmastR supports three backends (candle, wgpu, and nd) provided by the Rust deep learning framework Burn; see the Burn documentation for details. The candle backend tends to run faster on Apple M1/M2 processors. You can compare the performance of the backends on your system as follows:

run1 <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, max_epochs = 3, backend = "candle", return_type = "list")
run2 <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, max_epochs = 3, backend = "wgpu", return_type = "list")
run3 <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, max_epochs = 3, backend = "nd", return_type = "list")

# Compare training times
times <- data.frame(
  training_time = c(run1$training_output$duration$training_duration,
                    run2$training_output$duration$training_duration,
                    run3$training_output$duration$training_duration),
  backend = c("candle", "wgpu", "nd")
)
gp <- ggplot(times, aes(x = backend, y = training_time, fill = backend)) +
  geom_col() +
  theme_bw() +
  labs(x = "Backend", y = "Training Time (s)") +
  NoLegend() +
  ggtitle(paste("Arch:", Sys.info()[["machine"]]))

ggplotly(gp)

# Choose a backend automatically from the CPU architecture:
# wgpu on x86_64, candle otherwise (e.g. Apple M1/M2)
backend <- if (Sys.info()[["machine"]] == "x86_64") "wgpu" else "candle"
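If you want to benchmark something beyond the duration reported in `training_output`, base R's `system.time()` gives wall-clock timings for arbitrary expressions. The sketch below uses a hypothetical helper, `time_candidates`, that times a named list of zero-argument functions and ranks them fastest first. In practice each function would wrap one of the viewmastR() calls above (an assumption for illustration); the toy workloads here only stand in for them.

```r
# Hypothetical helper: time each candidate once and rank by elapsed seconds
time_candidates <- function(candidates) {
  elapsed <- vapply(candidates, function(f) {
    system.time(f())[["elapsed"]]   # wall-clock time in seconds
  }, numeric(1))
  res <- data.frame(candidate = names(candidates), seconds = unname(elapsed))
  res[order(res$seconds), , drop = FALSE]
}

# Toy workloads standing in for viewmastR runs with different backends
candidates <- list(
  small = function() sum(sqrt(seq_len(1e4))),
  large = function() sum(sqrt(seq_len(1e6)))
)
timings <- time_candidates(candidates)
print(timings)
best <- timings$candidate[1]   # name of the fastest candidate
```

A single run per candidate is noisy; repeating each timing a few times and taking the median gives a more reliable comparison.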

5. Saving Training Subsets

To inspect the training and test data used by viewmastR, or to evaluate them with other learning frameworks, use the setup_training function:

ti <- setup_training(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, return_type = "matrix", backend = backend)

# Convert label matrices to class indices (the column with the maximum value in each row)
train_label <- apply(ti$Ytrain_label, 1, which.max)
test_label <- apply(ti$Ytest_label, 1, which.max)

# Save training data and labels
writeMMgz(as(ti$Xtrain_data, "dgCMatrix"), "/path/to/train.mm.gz")
writeMMgz(as(ti$Xtest_data, "dgCMatrix"), "/path/to/test.mm.gz")
writeMMgz(as(ti$query, "dgCMatrix"), "/path/to/query.mm.gz")
data.table::fwrite(data.frame(train = train_label), "/path/to/train_labels.tsv.gz", compress = "gzip")
data.table::fwrite(data.frame(test = test_label), "/path/to/test_labels.tsv.gz", compress = "gzip")
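The `apply(..., 1, which.max)` step converts a one-hot (or score) matrix into integer class indices. Base R's `max.col()` performs the same row-wise argmax in a single vectorized call; with `ties.method = "first"` it matches `which.max`. The toy matrix and class names below are illustrative, and mapping indices back to names assumes the class names are in the same order as the matrix columns.

```r
# Toy one-hot label matrix: 4 cells x 3 classes
Y <- rbind(c(1, 0, 0),
           c(0, 0, 1),
           c(0, 1, 0),
           c(0, 0, 1))
label_text <- c("B", "CD8.EM", "NK")   # illustrative class names

idx_apply  <- apply(Y, 1, which.max)              # as in the code above
idx_maxcol <- max.col(Y, ties.method = "first")   # vectorized equivalent

stopifnot(identical(as.integer(idx_apply), idx_maxcol))
label_text[idx_maxcol]   # "B" "NK" "CD8.EM" "NK"
```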

6. Analyzing Probabilities

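Train a model, then run inference with viewmastR_infer to obtain per-class prediction probabilities:

# Training setup; ti is reused in section 7 for feature and label names
ti <- setup_training(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, return_type = "matrix")
# Train; the model file read by viewmastR_infer below is /tmp/sc_local/model.mpk
seu <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", selected_genes = vg, max_epochs = 3, backend = backend)

# Obtain prediction probabilities from the saved model
seu <- viewmastR_infer(seu, "/tmp/sc_local/model.mpk", vg, labels = levels(factor(seur$SFClassification)), return_probs = TRUE)
# Plot the probability for one class (probability columns carry a prob_ prefix)
scCustomize::FeaturePlot_scCustom(seu, features = "prob_19_CD8.EM")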
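A common follow-up is to summarize per-cell confidence: the top probability and the margin between the two most likely classes. This sketch works on a plain cells x classes probability matrix. Collecting the probability columns from the Seurat metadata (for example with `grep("^prob_", colnames(seu@meta.data))`) is an assumption based on the `prob_19_CD8.EM` name above, and the helper `summarize_confidence` and its 0.5 / 0.2 thresholds are hypothetical choices to adapt.

```r
# Hypothetical helper: per-cell confidence from a cells x classes probability matrix
summarize_confidence <- function(P, min_prob = 0.5, min_margin = 0.2) {
  top_idx <- max.col(P, ties.method = "first")
  sorted  <- t(apply(P, 1, sort, decreasing = TRUE))   # each row sorted high to low
  top     <- sorted[, 1]
  margin  <- sorted[, 1] - sorted[, 2]                 # gap between best and runner-up
  data.frame(
    predicted      = colnames(P)[top_idx],
    top_prob       = top,
    margin         = margin,
    low_confidence = top < min_prob | margin < min_margin,
    row.names      = rownames(P)
  )
}

# Illustrative probabilities for 3 cells and 3 classes
P <- matrix(c(0.90, 0.05, 0.05,
              0.40, 0.35, 0.25,
              0.10, 0.15, 0.75),
            nrow = 3, byrow = TRUE,
            dimnames = list(paste0("cell", 1:3), c("B", "CD8.EM", "NK")))
summarize_confidence(P)
```

Cells flagged as low confidence are often worth inspecting separately, since they may belong to states poorly represented in the reference.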

7. Evaluating Model Weights

To inspect model weights (note: this works only for the mlr model; see section 8):

# Read the saved model and extract the weights of its linear layer
mod <- RcppMsgPack::msgpack_read("/tmp/sc_local/model.mpk", simplify = TRUE)
weights <- mod$item$linear1$weight$param$value
shape <- mod$item$linear1$weight$param$shape
# Reshape into a features x labels table
wmat <- data.frame(t(matrix(weights, nrow = shape[2])))
rownames(wmat) <- ti$features
colnames(wmat) <- ti$label_text
# Top 20 genes with the largest weights for the NK class
top_NK_genes <- rownames(wmat)[order(wmat[["21_NK"]], decreasing = TRUE)[1:20]]
# Score cells on these genes; AddModuleScore appends "1" to the given name
seu <- AddModuleScore(seu, features = list(top_NK_genes), name = "top_NK")
FeaturePlot_scCustom(seu, features = "top_NK1")
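The same ranking can be applied to every label at once. This sketch assumes `wmat` is a features x labels table as built above (rows named by gene, columns by label) and uses `order()`, which ranks like the `sort(..., index.return = TRUE)$ix` idiom. The helper `top_features_per_label` is hypothetical and the toy weights are made up.

```r
# Hypothetical helper: top n features (row names) for every column of a weight table
top_features_per_label <- function(w, n = 20) {
  w <- as.matrix(w)
  n <- min(n, nrow(w))
  sapply(colnames(w), function(lbl) {
    rownames(w)[order(w[, lbl], decreasing = TRUE)[seq_len(n)]]
  }, simplify = FALSE)   # named list: one character vector per label
}

# Toy weights: 5 genes x 2 labels
w <- matrix(c(0.1, 2.0, -0.5, 1.2, 0.3,
              1.5, -0.2, 0.9, 0.0, 2.2),
            ncol = 2,
            dimnames = list(paste0("gene_", letters[1:5]),
                            c("21_NK", "19_CD8.EM")))
top_features_per_label(w, n = 2)
```

Each element of the returned list can be passed to `AddModuleScore()` in the same way as `top_NK_genes` above.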

8. Comparing Different Algorithms

viewmastR supports several algorithms: a pseudo multinomial logistic regression (mlr), multinomial naive Bayes (nb), and a multi-layer perceptron (nn). Note that our mlr function is a neural network with no hidden layers that uses ReLU activation. You can read more about the similarity between this simple neural network and logistic regression here: https://medium.com/@axegggl/neural-networks-decoded-how-logistic-regression-is-the-hidden-first-step-495f4a0b5fd
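To make the connection concrete: a network with no hidden layers computes one linear score per class, z = xW + b, and multinomial logistic regression turns those scores into probabilities with a softmax. The base R sketch below shows that forward pass on random toy data. It illustrates the general equivalence only and is not viewmastR's implementation; in particular it does not include the ReLU activation mentioned above.

```r
# Numerically stable row-wise softmax
softmax_rows <- function(z) {
  z <- z - apply(z, 1, max)   # subtract each row's max to avoid overflow
  e <- exp(z)
  e / rowSums(e)
}

# Forward pass of a zero-hidden-layer network: X (cells x genes) -> class probabilities
mlr_forward <- function(X, W, b) {
  z <- X %*% W                # linear scores, cells x classes
  z <- sweep(z, 2, b, "+")    # add one bias per class
  softmax_rows(z)
}

set.seed(1)
X <- matrix(rpois(4 * 3, lambda = 2), nrow = 4)   # 4 cells x 3 genes (toy counts)
W <- matrix(rnorm(3 * 2), nrow = 3)               # 3 genes x 2 classes
b <- c(0.1, -0.1)
P <- mlr_forward(X, W, b)
rowSums(P)   # each row sums to 1
```

With two classes the softmax reduces to the logistic sigmoid of the score difference, which is why the binary case is ordinary logistic regression.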

The code below shows how you can run and compare these methods:

seu <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", FUNC = "mlr", query_celldata_col = "mlr_pred", selected_genes = vg, backend = backend)
seu <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", FUNC = "nb", query_celldata_col = "nb_pred", selected_genes = vg, backend = backend)
seu <- viewmastR(seu, seur, ref_celldata_col = "SFClassification", FUNC = "nn", query_celldata_col = "nn_pred", selected_genes = vg, hidden_layers = c(200), backend = backend)

# Visualize predictions
DimPlot(seu, group.by = "mlr_pred")

DimPlot(seu, group.by = "nb_pred")

DimPlot(seu, group.by = "nn_pred")

# Evaluate accuracy against the query's ground_truth metadata column
# (labels are compared as character to avoid factor-level mismatches)
accuracy <- function(pred, truth) {
  sum(as.character(pred) == as.character(truth), na.rm = TRUE) / length(truth)
}
accuracy_mlr <- accuracy(seu$mlr_pred, seu$ground_truth)
accuracy_nb <- accuracy(seu$nb_pred, seu$ground_truth)
accuracy_nn <- accuracy(seu$nn_pred, seu$ground_truth)

# Compare accuracies
acc <- data.frame(accuracy = c(accuracy_mlr, accuracy_nb, accuracy_nn) * 100,
                  algorithm = c("mlr", "nb", "nn"))
gp <- ggplot(acc, aes(x = algorithm, y = accuracy, fill = algorithm)) +
  geom_col() +
  theme_bw() +
  labs(x = "Algorithm", y = "Accuracy (%)") +
  NoLegend()
ggplotly(gp)
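Overall accuracy can hide classes that are systematically mislabeled. Cross-tabulating predicted against true labels shows where the errors fall. The sketch below computes a confusion matrix, per-class recall, and balanced accuracy from two label vectors; in practice you would pass, for example, `seu$mlr_pred` and `seu$ground_truth`. The helper `evaluate_labels` is hypothetical and the toy labels are made up.

```r
# Hypothetical helper: confusion matrix, per-class recall and balanced accuracy
evaluate_labels <- function(pred, truth) {
  pred  <- as.character(pred)
  truth <- as.character(truth)
  classes <- sort(unique(c(pred, truth)))
  cm <- table(truth = factor(truth, levels = classes),
              pred  = factor(pred,  levels = classes))
  recall <- diag(cm) / rowSums(cm)   # fraction of each true class recovered
  list(confusion = cm,
       accuracy = sum(diag(cm)) / sum(cm),
       recall = recall,
       # classes that never occur in truth give NaN recall and are skipped
       balanced_accuracy = mean(recall, na.rm = TRUE))
}

truth <- c("B", "B", "NK", "NK", "NK", "CD8.EM")
pred  <- c("B", "NK", "NK", "NK", "CD8.EM", "CD8.EM")
res <- evaluate_labels(pred, truth)
res$confusion
res$balanced_accuracy
```

Balanced accuracy averages recall over classes, so rare cell types weigh as much as abundant ones.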

Appendix: Session Information

To ensure reproducibility, capture session information with sessionInfo():

sessionInfo()

## R version 4.4.0 (2024-04-24)
## Platform: x86_64-apple-darwin20
## Running under: macOS Ventura 13.6.7
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRblas.0.dylib 
## LAPACK: /Library/Frameworks/R.framework/Versions/4.4-x86_64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.0
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## time zone: America/Los_Angeles
## tzcode source: internal
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
## [1] plotly_4.10.4      scCustomize_2.1.2  ggplot2_3.5.1      Seurat_5.1.0      
## [5] SeuratObject_5.0.2 sp_2.1-4           viewmastR_0.2.3   
## 
## loaded via a namespace (and not attached):
##   [1] fs_1.6.4                    matrixStats_1.3.0          
##   [3] spatstat.sparse_3.0-3       RcppMsgPack_0.2.3          
##   [5] lubridate_1.9.3             httr_1.4.7                 
##   [7] RColorBrewer_1.1-3          doParallel_1.0.17          
##   [9] tools_4.4.0                 sctransform_0.4.1          
##  [11] backports_1.5.0             utf8_1.2.4                 
##  [13] R6_2.5.1                    lazyeval_0.2.2             
##  [15] uwot_0.2.2                  GetoptLong_1.0.5           
##  [17] withr_3.0.0                 gridExtra_2.3              
##  [19] progressr_0.14.0            cli_3.6.2                  
##  [21] Biobase_2.64.0              textshaping_0.4.0          
##  [23] spatstat.explore_3.2-7      fastDummies_1.7.3          
##  [25] labeling_0.4.3              sass_0.4.9                 
##  [27] spatstat.data_3.0-4         ggridges_0.5.6             
##  [29] pbapply_1.7-2               pkgdown_2.0.9              
##  [31] systemfonts_1.1.0           foreign_0.8-86             
##  [33] R.utils_2.12.3              parallelly_1.37.1          
##  [35] rstudioapi_0.16.0           generics_0.1.3             
##  [37] shape_1.4.6.1               crosstalk_1.2.1            
##  [39] ica_1.0-3                   spatstat.random_3.2-3      
##  [41] dplyr_1.1.4                 Matrix_1.7-0               
##  [43] ggbeeswarm_0.7.2            fansi_1.0.6                
##  [45] S4Vectors_0.42.0            abind_1.4-5                
##  [47] R.methodsS3_1.8.2           lifecycle_1.0.4            
##  [49] yaml_2.3.8                  snakecase_0.11.1           
##  [51] SummarizedExperiment_1.34.0 recipes_1.1.0              
##  [53] SparseArray_1.4.8           Rtsne_0.17                 
##  [55] paletteer_1.6.0             grid_4.4.0                 
##  [57] promises_1.3.0              crayon_1.5.2               
##  [59] miniUI_0.1.1.1              lattice_0.22-6             
##  [61] cowplot_1.1.3               pillar_1.9.0               
##  [63] knitr_1.46                  ComplexHeatmap_2.20.0      
##  [65] GenomicRanges_1.56.0        rjson_0.2.21               
##  [67] boot_1.3-30                 future.apply_1.11.2        
##  [69] codetools_0.2-20            leiden_0.4.3.1             
##  [71] glue_1.7.0                  data.table_1.15.4          
##  [73] vctrs_0.6.5                 png_0.1-8                  
##  [75] spam_2.10-0                 gtable_0.3.5               
##  [77] rematch2_2.1.2              assertthat_0.2.1           
##  [79] cachem_1.1.0                gower_1.0.1                
##  [81] xfun_0.44                   S4Arrays_1.4.1             
##  [83] mime_0.12                   prodlim_2024.06.25         
##  [85] survival_3.6-4              timeDate_4041.110          
##  [87] SingleCellExperiment_1.26.0 iterators_1.0.14           
##  [89] pbmcapply_1.5.1             hardhat_1.4.0              
##  [91] lava_1.8.0                  fitdistrplus_1.1-11        
##  [93] ROCR_1.0-11                 ipred_0.9-15               
##  [95] nlme_3.1-164                RcppAnnoy_0.0.22           
##  [97] GenomeInfoDb_1.40.1         bslib_0.7.0                
##  [99] irlba_2.3.5.1               vipor_0.4.7                
## [101] KernSmooth_2.23-24          rpart_4.1.23               
## [103] colorspace_2.1-0            BiocGenerics_0.50.0        
## [105] Hmisc_5.1-2                 nnet_7.3-19                
## [107] ggrastr_1.0.2               tidyselect_1.2.1           
## [109] compiler_4.4.0              htmlTable_2.4.2            
## [111] desc_1.4.3                  DelayedArray_0.30.1        
## [113] checkmate_2.3.1             scales_1.3.0               
## [115] lmtest_0.9-40               stringr_1.5.1              
## [117] digest_0.6.35               goftest_1.2-3              
## [119] spatstat.utils_3.1-0        minqa_1.2.7                
## [121] rmarkdown_2.27              XVector_0.44.0             
## [123] htmltools_0.5.8.1           pkgconfig_2.0.3            
## [125] base64enc_0.1-3             lme4_1.1-35.3              
## [127] sparseMatrixStats_1.16.0    MatrixGenerics_1.16.0      
## [129] highr_0.10                  fastmap_1.2.0              
## [131] rlang_1.1.4                 GlobalOptions_0.1.2        
## [133] htmlwidgets_1.6.4           UCSC.utils_1.0.0           
## [135] shiny_1.8.1.1               DelayedMatrixStats_1.26.0  
## [137] farver_2.1.2                jquerylib_0.1.4            
## [139] zoo_1.8-12                  jsonlite_1.8.8             
## [141] ModelMetrics_1.2.2.2        R.oo_1.26.0                
## [143] magrittr_2.0.3              Formula_1.2-5              
## [145] GenomeInfoDbData_1.2.12     dotCall64_1.1-1            
## [147] patchwork_1.2.0             munsell_0.5.1              
## [149] Rcpp_1.0.12                 reticulate_1.37.0          
## [151] stringi_1.8.4               pROC_1.18.5                
## [153] zlibbioc_1.50.0             MASS_7.3-60.2              
## [155] plyr_1.8.9                  parallel_4.4.0             
## [157] listenv_0.9.1               ggrepel_0.9.5              
## [159] forcats_1.0.0               deldir_2.0-4               
## [161] splines_4.4.0               tensor_1.5                 
## [163] circlize_0.4.16             igraph_2.0.3               
## [165] spatstat.geom_3.2-9         RcppHNSW_0.6.0             
## [167] reshape2_1.4.4              stats4_4.4.0               
## [169] evaluate_0.23               ggprism_1.0.5              
## [171] nloptr_2.0.3                foreach_1.5.2              
## [173] httpuv_1.6.15               RANN_2.6.1                 
## [175] tidyr_1.3.1                 purrr_1.0.2                
## [177] polyclip_1.10-6             future_1.33.2              
## [179] clue_0.3-65                 scattermore_1.2            
## [181] janitor_2.2.0               xtable_1.8-4               
## [183] monocle3_1.3.7              RSpectra_0.16-1            
## [185] later_1.3.2                 viridisLite_0.4.2          
## [187] class_7.3-22                ragg_1.3.2                 
## [189] tibble_3.2.1                memoise_2.0.1              
## [191] beeswarm_0.4.0              IRanges_2.38.0             
## [193] cluster_2.1.6               timechange_0.3.0           
## [195] globals_0.16.3              caret_6.0-94