Extract linear-layer weights and map them to feature / class names
get_weights.Rd
Reads a model exported by the Rust Burn pipeline, together with its companion metadata file, and returns a weight table whose rows correspond to the original feature names and whose columns correspond to the class labels.
Value
A base R data.frame with dimension \(n_{\text{features}} \times n_{\text{classes}}\). Its row names (rownames()) are the feature names and its column names (colnames()) are the class labels. Cell [i, j] is the weight connecting feature i to logit j.
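Because rows index features and columns index classes, multiplying one observation's feature vector by the weight table (converted to a matrix) gives that observation's logits before any bias is added. The sketch below uses made-up toy values and hypothetical feature and class names; it assumes the returned table holds weights only, so a bias term would have to be added separately.

```r
# Toy weight table in the documented layout: rows = features, columns = classes.
# Values and names are made up for illustration.
w <- data.frame(
  cat = c(0.5, -1.0, 0.0),
  dog = c(-0.5, 2.0, 1.0),
  row.names = c("f1", "f2", "f3")
)

# One observation, with a value for every feature.
x <- c(f1 = 2, f2 = 1, f3 = 4)

# logit_j = sum_i x_i * w[i, j]; align x to the row order of w first.
# The bias is not part of the returned table, so it is omitted here.
logits <- drop(x[rownames(w)] %*% as.matrix(w))
logits  # cat = 0, dog = 5
```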
Details
Internally the function:
1. deserialises the two MessagePack files (model and metadata) with msgpackR;
2. decodes the raw tensor bytes with decode_param();
3. reshapes the flat vector into a (column-major) R matrix using the stored shape [out_dim, in_dim];
4. transposes it so that rows align with features and columns with classes;
5. re-labels rows and columns from the metadata lists.
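The decode, reshape, transpose and re-label steps can be sketched in base R as follows. This is not the package's implementation: it assumes the tensor blob holds little-endian 32-bit floats in row-major order, uses readBin() in place of decode_param(), and the names feature_names and class_labels are hypothetical stand-ins for the metadata lists.

```r
# Hypothetical stored tensor: a 2 x 3 weight with shape [out_dim, in_dim],
# laid out row by row and serialised as little-endian float32 bytes.
out_dim <- 2L
in_dim  <- 3L
true_w  <- matrix(c(1, 2, 3,
                    4, 5, 6), nrow = out_dim, byrow = TRUE)
blob <- writeBin(as.vector(t(true_w)), raw(), size = 4, endian = "little")

# Decode: read out_dim * in_dim 4-byte floats from the raw vector.
flat <- readBin(blob, what = "numeric", n = out_dim * in_dim,
                size = 4, endian = "little")

# Reshape: fill row by row to recover the [out_dim, in_dim] matrix ...
w_out_in <- matrix(flat, nrow = out_dim, ncol = in_dim, byrow = TRUE)
# ... then transpose so rows are features and columns are classes.
w_in_out <- t(w_out_in)

# Re-label from the (hypothetical) metadata lists and return a data.frame.
feature_names <- c("f1", "f2", "f3")
class_labels  <- c("neg", "pos")
dimnames(w_in_out) <- list(feature_names, class_labels)
w <- as.data.frame(w_in_out)
w
```

Under the row-major assumption, matrix(flat, nrow = in_dim, ncol = out_dim) produces the same features-by-classes matrix in one step, because R fills matrices column by column. If the blob were column-major instead, dropping byrow = TRUE would be the correct reshape before transposing.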
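The result can be converted with as.matrix() and passed to matrix-based tools such as pheatmap() or corrplot() (the latter with is.corr = FALSE, since the weights are not correlations) for further analysis.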
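A minimal base-R sketch of the as.matrix() route, using stats::heatmap() instead of pheatmap() so that no extra packages are needed; the toy table and its names are made up.

```r
# Toy weight table (random values): 4 features x 3 classes.
set.seed(1)
w <- as.data.frame(matrix(rnorm(12), nrow = 4,
                          dimnames = list(paste0("feature_", 1:4),
                                          paste0("class_", 1:3))))

m <- as.matrix(w)  # numeric matrix; row and column names carried over as dimnames

# scale = "none" keeps the raw weights (stats::heatmap scales rows by default);
# Rowv = NA and Colv = NA keep the original feature and class order.
heatmap(m, Rowv = NA, Colv = NA, scale = "none", margins = c(6, 8))
```

Row scaling would hide how weight magnitudes compare across features, which is usually the point of inspecting a linear layer, so scale = "none" is the safer default here.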
See also
msgpack_read() from msgpackR – generic MessagePack reader
decode_param() – helper that converts Burn tensor blobs into R vectors
Examples
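if (FALSE) { # \dontrun{
w <- get_weights("artifacts/run-42")
head(w[, seq_len(min(5, ncol(w)))])  # first (up to) 5 classes

# visualise top positive / negative features for class 3
cls <- 3
# w[, cls] drops the row names, so re-attach them to keep feature labels on the bars
w_sorted <- sort(setNames(w[, cls], rownames(w)))
op <- par(mar = c(4, 10, 2, 1))  # widen the left margin for feature names
barplot(tail(w_sorted, 10), horiz = TRUE, las = 1, main = paste0("Class ", cls, ": largest weights"))
barplot(head(w_sorted, 10), horiz = TRUE, las = 1, main = paste0("Class ", cls, ": smallest weights"))
par(op)  # restore the previous graphics settings
} # }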