interscale.evaluation.gene_loadings
interscale.evaluation.gene_loadings(adata, model, layer_key, local_latent_key='_local_emb', global_latent_key='_global_emb', local_varm_key='_local_std_gene_loadings', global_varm_key='_global_std_gene_loadings', eps=1e-08)
Compute standardized local and global gene loadings from two linear decoders.
The model is assumed to decode log-normalized expression linearly from each of its two transformer outputs (the local and global embeddings):
$$\hat{x}_{\text{local}} = W_{\text{local}}\, z_{\text{local}} + b_{\text{local}}$$
$$\hat{x}_{\text{global}} = W_{\text{global}}\, z_{\text{global}} + b_{\text{global}}$$
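A minimal NumPy sketch of the two decoders. It assumes the weights follow the `torch.nn.Linear` layout (shape n_genes × n_latent, as the `decoder.weight` key names suggest), so that `W[g, k]` connects latent dimension k to gene g, and that embeddings are stored one row per cell as in `adata.obsm`. The sizes and random values are placeholders, not taken from interscale.

```python
import numpy as np

# Placeholder sizes for illustration only.
n_cells, n_latent_local, n_latent_global, n_genes = 100, 8, 4, 50
rng = np.random.default_rng(0)

# Latent embeddings: one row per cell (the layout of adata.obsm entries).
z_local = rng.normal(size=(n_cells, n_latent_local))
z_global = rng.normal(size=(n_cells, n_latent_global))

# Assumed torch.nn.Linear layout: weight is (n_genes, n_latent), bias is (n_genes,).
W_local = rng.normal(size=(n_genes, n_latent_local))
b_local = rng.normal(size=n_genes)
W_global = rng.normal(size=(n_genes, n_latent_global))
b_global = rng.normal(size=n_genes)


def linear_decode(z, W, b):
    """Apply x_hat = W z + b to every row of z (cells x latent)."""
    return z @ W.T + b


x_hat_local = linear_decode(z_local, W_local, b_local)      # (n_cells, n_genes)
x_hat_global = linear_decode(z_global, W_global, b_global)  # (n_cells, n_genes)
```

The batched form `z @ W.T + b` is the row-wise equivalent of the column-vector equations above.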
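Standardized loading (computed separately for the local and global decoders):
$$S_{gk} = W_{gk}\, \frac{\operatorname{std}(z_k)}{\operatorname{std}(x_g)}$$
where g indexes genes, k indexes latent dimensions, std(z_k) is the standard deviation of latent dimension k across cells, and std(x_g) is the standard deviation of the log-normalized expression of gene g across cells.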
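The formula can be sketched in NumPy as follows. This is an illustration, not the interscale implementation: `standardized_loadings` is a hypothetical helper, the weight layout is assumed to be (n_genes, n_latent), and adding `eps` to the gene standard deviations is an assumption about where the stability constant enters.

```python
import numpy as np


def standardized_loadings(W, Z, X, eps=1e-8):
    """Return S with S[g, k] = W[g, k] * std(Z[:, k]) / std(X[:, g]).

    W: decoder weight, shape (n_genes, n_latent) (assumed nn.Linear layout).
    Z: latent embeddings, shape (n_cells, n_latent).
    X: dense log-normalized expression, shape (n_cells, n_genes).
    eps: added to the gene standard deviations (an assumption of this sketch)
         so that genes with zero variance do not cause division by zero.
    """
    W = np.asarray(W, dtype=float)
    z_std = np.asarray(Z, dtype=float).std(axis=0)  # (n_latent,), across cells
    x_std = np.asarray(X, dtype=float).std(axis=0)  # (n_genes,), across cells
    # Broadcasting scales column k by std(z_k) and row g by 1 / std(x_g).
    return W * z_std[None, :] / (x_std[:, None] + eps)
```

Because std(z_k) and std(x_g) are computed over the same cells, the choice of `ddof` cancels in the ratio (up to `eps`). If the expression layer is a SciPy sparse matrix, densify it first (for example with `.toarray()`) or compute the per-gene variance from the means of x and x².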
Parameters:
- adata (AnnData) – AnnData object.
- model – Trained model (possibly DDP-wrapped). The decoder weights are accessed via:
    model.module.state_dict()['local_module.decoder.decoder.weight']
    model.module.state_dict()['global_module.decoder.decoder.weight']
- layer_key (str) – adata.layers[layer_key] must contain log-normalized expression.
- local_latent_key (str, default: '_local_emb') – adata.obsm key containing the local transformer output embeddings.
- global_latent_key (str, default: '_global_emb') – adata.obsm key containing the global transformer output embeddings.
- local_varm_key (str, default: '_local_std_gene_loadings') – Key under which the standardized local gene loadings are stored in adata.varm.
- global_varm_key (str, default: '_global_std_gene_loadings') – Key under which the standardized global gene loadings are stored in adata.varm.
- eps (float, default: 1e-08) – Small constant for numerical stability.
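Reading the decoder weights comes down to two state-dict lookups. The helper below is hypothetical, not part of interscale: it accepts any mapping and, as a convenience, also tries the `module.` prefix that `torch.nn.parallel.DistributedDataParallel` adds to the keys of a wrapped model's state dict.

```python
from collections.abc import Mapping

# Decoder weight keys as documented above.
LOCAL_WEIGHT_KEY = "local_module.decoder.decoder.weight"
GLOBAL_WEIGHT_KEY = "global_module.decoder.decoder.weight"


def get_decoder_weight(state_dict: Mapping, key: str):
    """Return state_dict[key], falling back to a DDP-style 'module.' prefix."""
    for candidate in (key, "module." + key):
        if candidate in state_dict:
            return state_dict[candidate]
    raise KeyError(f"{key!r} not found in state_dict")
```

With a PyTorch model, `state_dict()` returns detached tensors by default, so the looked-up weight can be converted with `.cpu().numpy()` before further NumPy processing.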