interscale.evaluation.gene_loadings


interscale.evaluation.gene_loadings(adata, model, layer_key, local_latent_key='_local_emb', global_latent_key='_global_emb', local_varm_key='_local_std_gene_loadings', global_varm_key='_global_std_gene_loadings', eps=1e-08)

Compute standardized local and global gene loadings from two linear decoders.

The model is assumed to decode log-normalized expression from two transformer outputs, a local and a global embedding, each through its own linear decoder:

$$\hat{x}_{\text{local}} = W_{\text{local}}\, z_{\text{local}} + b_{\text{local}}$$

$$\hat{x}_{\text{global}} = W_{\text{global}}\, z_{\text{global}} + b_{\text{global}}$$
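As a minimal sketch of the decoding step, the two reconstructions can be written in NumPy as below. It assumes each decoder behaves like a PyTorch `nn.Linear` layer (suggested by the `decoder.decoder.weight` state-dict keys, but not confirmed here), so its weight has shape `(n_genes, n_latent)`, and that embeddings are stored one cell per row. All sizes and values are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_genes = 100, 50
d_local, d_global = 8, 4  # hypothetical latent sizes

# Per-cell embeddings (one row per cell), as they would sit in adata.obsm
z_local = rng.normal(size=(n_cells, d_local))
z_global = rng.normal(size=(n_cells, d_global))

# Decoder parameters in nn.Linear layout: weight is (out_features, in_features)
W_local = rng.normal(size=(n_genes, d_local))
b_local = rng.normal(size=n_genes)
W_global = rng.normal(size=(n_genes, d_global))
b_global = rng.normal(size=n_genes)

# x_hat = W z + b for each cell; with cells as rows this is z @ W.T + b
x_hat_local = z_local @ W_local.T + b_local
x_hat_global = z_global @ W_global.T + b_global

print(x_hat_local.shape, x_hat_global.shape)  # (100, 50) (100, 50)
```

Row `i` of `x_hat_local` equals `W_local @ z_local[i] + b_local`, i.e. the per-cell formula applied to every cell at once.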

The standardized loading of gene g on latent dimension k is computed separately for the local and global decoders:

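$$S_{gk} = W_{gk} \cdot \frac{\operatorname{std}(z_k)}{\operatorname{std}(x_g)}$$

where std(z_k) is the standard deviation of latent dimension k across cells and std(x_g) is the standard deviation of gene g's log-normalized expression across cells.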
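Rescaling a raw weight this way expresses it as the change in gene g, in units of its own standard deviation, per one-standard-deviation change in latent dimension k, so loadings become comparable across dimensions with different spreads. The NumPy sketch below computes S for a single decoder; using the population standard deviation (`ddof=0`) and adding `eps` to std(x_g) are assumptions, not confirmed details of `gene_loadings`, and `standardized_loadings` is a hypothetical helper name.

```python
import numpy as np

def standardized_loadings(W, Z, X, eps=1e-8):
    """S[g, k] = W[g, k] * std(Z[:, k]) / std(X[:, g]).

    W: (n_genes, n_latent) decoder weight
    Z: (n_cells, n_latent) latent embeddings
    X: (n_cells, n_genes) log-normalized expression
    """
    z_std = Z.std(axis=0)        # (n_latent,) spread of each latent dimension
    x_std = X.std(axis=0) + eps  # (n_genes,); eps guards zero-variance genes (assumed placement)
    return W * z_std[None, :] / x_std[:, None]

rng = np.random.default_rng(0)
Z = rng.normal(size=(200, 3)) * np.array([1.0, 5.0, 0.1])  # latent dims on very different scales
W = np.ones((4, 3))                                         # identical raw weights
X = Z @ W.T + rng.normal(scale=0.5, size=(200, 4))
S = standardized_loadings(W, Z, X)
print(S.shape)  # (4, 3)
```

Although every raw weight is 1, the high-variance dimension (column 1) receives the largest standardized loading and the low-variance one (column 2) the smallest, which is exactly the effect the rescaling is meant to capture.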

Parameters:
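  • adata (AnnData) – AnnData object holding the expression layer and the latent embeddings; the computed loadings are written to its varm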

  • model

    Trained model (possibly wrapped in DistributedDataParallel, DDP). Decoder weights are accessed via:

    • model.module.state_dict()['local_module.decoder.decoder.weight']

    • model.module.state_dict()['global_module.decoder.decoder.weight']

  • layer_key (str) – adata.layers[layer_key] must contain log-normalized expression

  • local_latent_key (str (default: '_local_emb')) – adata.obsm key containing local transformer output embeddings

  • global_latent_key (str (default: '_global_emb')) – adata.obsm key containing global transformer output embeddings

  • local_varm_key (str (default: '_local_std_gene_loadings')) – Key to store standardized LOCAL gene loadings in adata.varm

  • global_varm_key (str (default: '_global_std_gene_loadings')) – Key to store standardized GLOBAL gene loadings in adata.varm

  • eps (float (default: 1e-08)) – Small constant for numerical stability
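The parameters fit together roughly as follows. The sketch below is a hypothetical NumPy re-implementation of the documented data flow, not the library's code: `gene_loadings_sketch` is an illustrative name, `types.SimpleNamespace` objects stand in for `adata` and for a DDP-wrapped model (whose `module.state_dict()` here returns NumPy arrays rather than tensors), and the layer name `"lognorm"` is made up. Dense expression, `ddof=0`, and adding `eps` to the per-gene standard deviation are assumptions. Each result has shape `(n_genes, n_latent)`, which matches AnnData's requirement that `varm` entries have one row per variable.

```python
from types import SimpleNamespace
import numpy as np

def gene_loadings_sketch(adata, model, layer_key,
                         local_latent_key="_local_emb", global_latent_key="_global_emb",
                         local_varm_key="_local_std_gene_loadings",
                         global_varm_key="_global_std_gene_loadings", eps=1e-8):
    """Hypothetical re-implementation of the documented data flow."""
    state = model.module.state_dict()        # DDP-style wrapper exposes the inner model as .module
    X = np.asarray(adata.layers[layer_key])  # assumes a dense (n_cells, n_genes) array
    x_std = X.std(axis=0) + eps              # per-gene spread (eps placement assumed)

    for prefix, latent_key, varm_key in [
        ("local_module", local_latent_key, local_varm_key),
        ("global_module", global_latent_key, global_varm_key),
    ]:
        W = np.asarray(state[f"{prefix}.decoder.decoder.weight"])  # (n_genes, n_latent)
        Z = np.asarray(adata.obsm[latent_key])                     # (n_cells, n_latent)
        adata.varm[varm_key] = W * Z.std(axis=0)[None, :] / x_std[:, None]

# Toy stand-ins: 30 cells, 5 genes, 3 local and 2 global latent dimensions
rng = np.random.default_rng(1)
adata = SimpleNamespace(
    layers={"lognorm": rng.normal(size=(30, 5))},
    obsm={"_local_emb": rng.normal(size=(30, 3)), "_global_emb": rng.normal(size=(30, 2))},
    varm={},
)
weights = {
    "local_module.decoder.decoder.weight": rng.normal(size=(5, 3)),
    "global_module.decoder.decoder.weight": rng.normal(size=(5, 2)),
}
model = SimpleNamespace(module=SimpleNamespace(state_dict=lambda: weights))

gene_loadings_sketch(adata, model, layer_key="lognorm")
print(adata.varm["_local_std_gene_loadings"].shape)   # (5, 3)
print(adata.varm["_global_std_gene_loadings"].shape)  # (5, 2)
```

Looping over a `(prefix, latent_key, varm_key)` table keeps the local and global computations identical, mirroring the statement that the standardized loading is computed separately, but in the same way, for each decoder.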