interscale.model.LocalModel#

class interscale.model.LocalModel(adata, cfg)#

Local model with only a local component.

Methods table#

get_anndata_manager(adata[, required])

Retrieves the AnnDataManager for a given AnnData object.

get_model_output([adata, prefix])

Save the embeddings, predictions, and attentions in the adata object.

load(dir_path, adata, cfg[, model_name, ...])

Load a saved model.

predict_nodewise([adata, indices, batch_size])

Return cell label predictions.

save([dir_path, overwrite, postfix, save_kwargs])

Save the state of the model.

save_evaluation_results(adata, prefix, ...)

Save the evaluation results in the adata object.

train(max_epochs[, shuffle_set_split, ...])

Train the model.

Methods#

LocalModel.get_anndata_manager(adata, required=False)#

Retrieves the AnnDataManager for a given AnnData object.

Requires that self.id has been set. Checks for an AnnDataManager specific to this model instance.

Parameters:
  • adata (AnnData) – AnnData object to find the manager instance for.

  • required (bool (default: False)) – If True, raises an error when the manager is missing. Otherwise, returns None when the manager is missing.

Return type:

AnnDataManager | None
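
A minimal sketch of how calling code might use the `required` flag. The names `has_manager`, `_StubModel`, and `register` are hypothetical and introduced only for illustration; the stub mimics the documented "return None when the manager is missing" behaviour and is not the library's implementation. The exception type raised for `required=True` is not documented here, so the one used below is an assumption.

```python
class _StubModel:
    """Hypothetical stand-in that mimics the documented lookup behaviour."""

    def __init__(self):
        self._managers = {}  # maps id(adata) -> manager object

    def register(self, adata, manager):
        # Hypothetical helper; the real registration path is not shown here.
        self._managers[id(adata)] = manager

    def get_anndata_manager(self, adata, required=False):
        manager = self._managers.get(id(adata))
        if manager is None and required:
            # Assumption: the real exception type is not documented.
            raise LookupError("no manager registered for this AnnData object")
        return manager


def has_manager(model, adata):
    """Return True if the model has a manager registered for `adata`."""
    return model.get_anndata_manager(adata, required=False) is not None


model = _StubModel()
adata = object()  # placeholder for an AnnData object
print(has_manager(model, adata))  # False: nothing registered yet
model.register(adata, manager="manager")
print(has_manager(model, adata))  # True
```

Checking with `required=False` first lets calling code decide how to handle a missing manager instead of relying on the error path.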

LocalModel.get_model_output(adata=None, prefix='')#

Save the embeddings, predictions, and attentions in the adata object.

Parameters:
  • adata (AnnData | None (default: None)) – AnnData object to run the model on. If None, the model’s AnnData object is used.

  • prefix (str (default: '')) – Prefix for the output columns.

classmethod LocalModel.load(dir_path, adata, cfg, model_name=None, local_component=False, global_component=False, postfix=None, wandb_save=False, enable_remapping=True)#

Load a saved model.

Parameters:
  • dir_path (str) – Path to saved model directory.

  • adata (AnnData) – AnnData object to load the model with.

  • cfg (CfgNode) – Configuration object.

  • model_name (str | None (default: None)) – Name of the model to load. If None, the model name is inferred from the config file.

  • local_component (bool (default: False)) – Whether this is a local component model.

  • global_component (bool (default: False)) – Whether this is a global component model.

  • wandb_save (bool (default: False)) – Whether this was saved via wandb.

  • enable_remapping (bool (default: True)) – Whether to enable automatic state dict key remapping.

Returns:

model – Loaded model.
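
A minimal sketch of a save-then-load round trip, using only the documented `save(dir_path=..., overwrite=...)` and `load(dir_path, adata, cfg)` signatures. The helper name `save_and_reload` is hypothetical, and the sketch assumes that `load` can locate the file written by `save` when both receive the same `dir_path`, which this page does not state explicitly.

```python
def save_and_reload(model, dir_path, adata, cfg):
    """Save `model` into `dir_path`, then load it back via the same class."""
    # overwrite=True avoids the documented error when dir_path already exists.
    model.save(dir_path=dir_path, overwrite=True)
    # `load` is a classmethod, so it is called on the model's class.
    return type(model).load(dir_path, adata, cfg)
```

If the model was saved with a `postfix` or through wandb, the corresponding `postfix` and `wandb_save` arguments of `load` would presumably need to match.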

LocalModel.predict_nodewise(adata=None, indices=None, batch_size=None)#

Return cell label predictions.

Parameters:
  • adata (AnnData | None (default: None)) – AnnData object that has been registered via the corresponding setup method in the model class.

  • indices (Sequence[int] | None (default: None)) – Indices of the data to predict. If None, all data is predicted.

LocalModel.save(dir_path=None, overwrite=False, postfix=None, save_kwargs=None)#

Save the state of the model. File is saved as <dataset_name>_<prediction_task[:4]>_<prediction_level>_<local_component_name>_<global_component_name>_<model_state_dict>.pt

Parameters:
  • dir_path (str | None (default: None)) – Path to a directory. If None, cfg.model.save_path is used.

  • overwrite (bool (default: False)) – Whether to overwrite existing data. If False and a directory already exists at dir_path, an error is raised.

  • save_kwargs (dict | None (default: None)) – Keyword arguments passed into save().
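
The file name pattern given for `save` can be illustrated with a small sketch. The function `checkpoint_filename` and all example values are hypothetical; in particular, the last field is passed in as a plain string because this page does not say what `<model_state_dict>` expands to, and the effect of `postfix` on the name is not documented.

```python
def checkpoint_filename(dataset_name, prediction_task, prediction_level,
                        local_component_name, global_component_name, suffix):
    """Join the documented fields with underscores and append '.pt'."""
    parts = [
        dataset_name,
        prediction_task[:4],  # only the first four characters are used
        prediction_level,
        local_component_name,
        global_component_name,
        suffix,
    ]
    return "_".join(parts) + ".pt"


# Example values are made up for illustration.
print(checkpoint_filename("pbmc", "classification", "cell",
                          "mlp", "none", "model_state_dict"))
# pbmc_clas_cell_mlp_none_model_state_dict.pt
```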

LocalModel.save_evaluation_results(adata, prefix, y_pred_local_df, y_pred_global_df, local_embeddings_df=None, global_embeddings_df=None, attention_matrix_df=None, cls_token_horizontal=None, cls_token_vertical=None)#

Save the evaluation results in the adata object.

Returns:

adata (AnnData) – AnnData object with the evaluation results saved in obsm and layers.

LocalModel.train(max_epochs, shuffle_set_split=True, load_sparse_tensor=False, early_stopping=True, patience=5, datasplitter_kwargs=None, plan_kwargs=None, datamodule=None, wandb_use=None, **trainer_kwargs)#

Train the model.

Parameters:
  • max_epochs (int) – The maximum number of epochs to train the model. The actual number of epochs may be less if early stopping is enabled.

  • shuffle_set_split (bool (default: True)) – Whether to shuffle indices before splitting. If False, the validation, train, and test sets are split in the sequential order of the data.

  • load_sparse_tensor (bool (default: False)) – Whether to load data as sparse tensors.

  • early_stopping (bool (default: True)) – Whether to perform early stopping. Additional arguments can be passed in through **trainer_kwargs.

  • patience (int (default: 5)) – Patience for early stopping: number of epochs to wait for improvement before stopping.

  • datasplitter_kwargs (dict | None (default: None)) – Additional keyword arguments passed into the data splitter. Not used if datamodule is passed in.

  • plan_kwargs (dict | None (default: None)) – Additional keyword arguments passed into the training plan.

  • datamodule (LightningDataModule | None (default: None)) – A LightningDataModule instance to use for training.

  • wandb_use (bool | None (default: None)) – Whether to log to Weights & Biases. Defaults to the project config.

  • **trainer_kwargs – Additional keyword arguments passed into the Trainer.
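
The following sketch illustrates two of the parameters above: how `shuffle_set_split` changes the split, and how `patience` can end training before `max_epochs`. It is a simplified illustration, not the library's implementation. The function names, the explicit `n_val` and `n_train` sizes, the validation-train-test ordering (taken from the wording of the parameter description), and the use of validation loss as the monitored quantity are all assumptions.

```python
import random


def split_indices(n_obs, n_val, n_train, shuffle=True, seed=0):
    """Split range(n_obs) into validation, train, and test index lists."""
    indices = list(range(n_obs))
    if shuffle:
        # Shuffle a copy of the indices with a fixed seed for reproducibility.
        random.Random(seed).shuffle(indices)
    val = indices[:n_val]
    train = indices[n_val:n_val + n_train]
    test = indices[n_val + n_train:]
    return val, train, test


def epochs_run(val_losses, max_epochs, patience=5):
    """Return how many epochs run before patience-based stopping triggers."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses[:max_epochs], start=1):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch  # stopped early
    return min(len(val_losses), max_epochs)


# Without shuffling, the sets follow the order of the data.
print(split_indices(10, n_val=2, n_train=6, shuffle=False))
# ([0, 1], [2, 3, 4, 5, 6, 7], [8, 9])

# The best loss occurs at epoch 3; five epochs without improvement follow.
losses = [1.0, 0.9, 0.8, 0.85, 0.86, 0.87, 0.88, 0.89, 0.7]
print(epochs_run(losses, max_epochs=20, patience=5))  # 8
```

In the second example training stops at epoch 8, so the lower loss at epoch 9 is never reached; a larger `patience` trades longer training for a better chance of escaping such plateaus.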