interscale.model.LocalModel
- class interscale.model.LocalModel(adata, cfg)
Local model with only a local component.
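Methods table
| Method | Description |
|---|---|
| get_anndata_manager(adata[, required]) | Retrieves the AnnDataManager for a given AnnData object. |
| get_model_output([adata, prefix]) | Save the embeddings, predictions and attentions in the adata object. |
| load(dir_path, adata, cfg[, model_name, ...]) | Load a saved model. |
| predict_nodewise([adata, indices, batch_size]) | Return cell label predictions. |
| save([dir_path, overwrite, postfix, save_kwargs]) | Save the state of the model. |
| save_evaluation_results(adata, prefix, y_pred_local_df, y_pred_global_df[, ...]) | Save the evaluation results in the adata object. |
| train(max_epochs[, shuffle_set_split, ...]) | Train the model. |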
Methods
- LocalModel.get_anndata_manager(adata, required=False)
Retrieves the AnnDataManager for a given AnnData object. Requires self.id to have been set. Checks for an AnnDataManager specific to this model instance.
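The per-instance lookup described above can be pictured as a two-level registry, keyed first by the model's id and then by the AnnData object. The sketch below is a hypothetical illustration, not interscale's code: ManagerRegistry, register and get are invented names, keying AnnData objects by Python identity is an assumption, and treating required=True as "raise if no manager is found" is also an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass
class ManagerRegistry:
    """Hypothetical registry: model id -> {AnnData key -> manager}."""

    managers: dict[str, dict[int, Any]] = field(default_factory=dict)

    def register(self, model_id: str, adata: Any, manager: Any) -> None:
        # Assumption: the AnnData object is keyed by its Python identity.
        self.managers.setdefault(model_id, {})[id(adata)] = manager

    def get(self, model_id: str | None, adata: Any, required: bool = False) -> Any | None:
        # Mirrors "Requires self.id to have been set": no id, no lookup.
        if model_id is None:
            raise ValueError("model id must be set before looking up a manager")
        manager = self.managers.get(model_id, {}).get(id(adata))
        # Assumption: required=True turns a missing manager into an error.
        if manager is None and required:
            raise KeyError("no manager registered for this AnnData object")
        return manager
```

Keying by model id first means two models registered on the same AnnData object each keep their own manager, which matches the "specific to this model instance" wording.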
- LocalModel.get_model_output(adata=None, prefix='')
Save the embeddings, predictions and attentions in the adata object.
- classmethod LocalModel.load(dir_path, adata, cfg, model_name=None, local_component=False, global_component=False, postfix=None, wandb_save=False, enable_remapping=True)
Load a saved model.
- Parameters:
  - dir_path (str) – Path to saved model directory.
  - adata (AnnData) – AnnData object to load the model with.
  - cfg (CfgNode) – Configuration object.
  - model_name (str | None, default: None) – Name of the model to load. If None, the model name is inferred from the config file.
  - local_component (bool, default: False) – Whether this is a local component model.
  - global_component (bool, default: False) – Whether this is a global component model.
  - wandb_save (bool, default: False) – Whether this was saved via wandb.
  - enable_remapping (bool, default: True) – Whether to enable automatic state dict key remapping.
- Returns:
  model – Loaded model.
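The enable_remapping option refers to renaming state dict keys so that a checkpoint saved under one module layout can be loaded into another. The rules interscale applies are not documented here, so the sketch below shows only one common form of remapping, prefix renaming, as an assumption; remap_state_dict_keys and prefix_map are hypothetical names.

```python
from __future__ import annotations

from collections import OrderedDict
from typing import Any, Mapping


def remap_state_dict_keys(
    state_dict: Mapping[str, Any], prefix_map: Mapping[str, str]
) -> OrderedDict[str, Any]:
    """Return a copy of state_dict with key prefixes renamed per prefix_map.

    Hypothetical helper: the first matching old prefix wins, and keys that
    match no rule are copied unchanged.
    """
    remapped: OrderedDict[str, Any] = OrderedDict()
    for key, value in state_dict.items():
        new_key = key
        for old_prefix, new_prefix in prefix_map.items():
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix):]
                break
        # Two old keys collapsing onto one new key would silently drop a value.
        if new_key in remapped:
            raise KeyError(f"remapping produced a duplicate key: {new_key!r}")
        remapped[new_key] = value
    return remapped
```

With PyTorch, the remapped mapping could then be passed to torch.nn.Module.load_state_dict.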
- LocalModel.predict_nodewise(adata=None, indices=None, batch_size=None)
Return cell label predictions.
- LocalModel.save(dir_path=None, overwrite=False, postfix=None, save_kwargs=None)
Save the state of the model. The file is saved as <dataset_name>_<prediction_task[:4]>_<prediction_level>_<local_component_name>_<global_component_name>_<model_state_dict>.pt
- Parameters:
  - dir_path (str | None, default: None) – Path to a directory. If None, cfg.model.save_path is used.
  - overwrite (bool, default: False) – Whether to overwrite existing data. If False and a directory already exists at dir_path, an error is raised.
  - save_kwargs (dict | None, default: None) – Keyword arguments passed into save().
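The file-name pattern above can be assembled mechanically. The sketch below is a hypothetical helper, not interscale's code: it assumes each placeholder is a plain string and treats the final <model_state_dict> segment as a literal "model_state_dict" suffix, which the pattern does not state explicitly. Where the postfix argument ends up in the name is not documented, so it is left out.

```python
def build_save_filename(
    dataset_name: str,
    prediction_task: str,
    prediction_level: str,
    local_component_name: str,
    global_component_name: str,
    suffix: str = "model_state_dict",
) -> str:
    """Assemble a file name following the documented pattern (hypothetical helper)."""
    parts = [
        dataset_name,
        prediction_task[:4],  # only the first four characters of the task name
        prediction_level,
        local_component_name,
        global_component_name,
        suffix,  # assumption: literal "model_state_dict" by default
    ]
    return "_".join(parts) + ".pt"
```

For example, a dataset "lung" with task "classification" yields a name starting with "lung_clas_", because of the [:4] slice.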
- LocalModel.save_evaluation_results(adata, prefix, y_pred_local_df, y_pred_global_df, local_embeddings_df=None, global_embeddings_df=None, attention_matrix_df=None, cls_token_horizontal=None, cls_token_vertical=None)
Save the evaluation results in the adata object.
- LocalModel.train(max_epochs, shuffle_set_split=True, load_sparse_tensor=False, early_stopping=True, patience=5, datasplitter_kwargs=None, plan_kwargs=None, datamodule=None, wandb_use=None, **trainer_kwargs)
Train the model.
- Parameters:
  - max_epochs (int) – The maximum number of epochs to train the model. The actual number of epochs may be lower if early stopping is enabled.
  - shuffle_set_split (bool, default: True) – Whether to shuffle indices before splitting. If False, the val, train, and test sets are split in the sequential order of the data.
  - load_sparse_tensor (bool, default: False) – Whether to load data as sparse tensors.
  - early_stopping (bool, default: True) – Whether to perform early stopping. Additional arguments can be passed in through **trainer_kwargs.
  - patience (int, default: 5) – Patience for early stopping: the number of epochs to wait for improvement before stopping.
  - datasplitter_kwargs (dict | None, default: None) – Additional keyword arguments passed into the data splitter. Not used if datamodule is passed in.
  - plan_kwargs (dict | None, default: None) – Additional keyword arguments passed into the training plan.
  - datamodule (LightningDataModule | None, default: None) – A LightningDataModule instance to use for training.
  - wandb_use (bool | None, default: None) – Whether to log to Weights & Biases. If None, the value from the project config is used.
  - **trainer_kwargs – Additional keyword arguments passed into the Trainer.
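The shuffle_set_split option decides whether indices are permuted before being carved into the val, train, and test sets. The sketch below illustrates that behaviour under stated assumptions: split_indices is a hypothetical helper, and its train_size, val_size and seed arguments are illustrative defaults, not documented LocalModel.train parameters (the real split fractions would come through datasplitter_kwargs or the config).

```python
from __future__ import annotations

import random


def split_indices(
    n_obs: int,
    train_size: float = 0.8,
    val_size: float = 0.1,
    shuffle_set_split: bool = True,
    seed: int = 0,
) -> tuple[list[int], list[int], list[int]]:
    """Split range(n_obs) into (train, val, test) index lists (hypothetical helper)."""
    if train_size <= 0 or val_size < 0 or train_size + val_size > 1:
        raise ValueError("need train_size > 0, val_size >= 0, and their sum <= 1")
    indices = list(range(n_obs))
    if shuffle_set_split:
        # Shuffle once with a fixed seed so the split is reproducible.
        random.Random(seed).shuffle(indices)
    n_train = int(n_obs * train_size)
    n_val = int(n_obs * val_size)
    # Carve out val, then train, then test, in the (possibly shuffled) order.
    val = indices[:n_val]
    train = indices[n_val:n_val + n_train]
    test = indices[n_val + n_train:]
    return train, val, test
```

With shuffle_set_split=False and 10 observations, the validation set is [0], training covers 1 through 8, and the test set is [9], so any ordering in the data (for example by sample or slide) carries straight into the splits.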
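The patience rule for early stopping can be stated in a few lines. If the underlying trainer is PyTorch Lightning, as the LightningDataModule and Trainer references suggest, this is typically delegated to Lightning's EarlyStopping callback. The sketch below only illustrates the rule itself. EarlyStopper and epochs_run are hypothetical names, and monitoring a validation loss that should decrease is an assumption about what LocalModel tracks.

```python
from __future__ import annotations


class EarlyStopper:
    """Patience-based early stopping on a metric that should decrease."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric: float) -> bool:
        """Record one epoch's metric; return True when training should stop."""
        if metric < self.best - self.min_delta:
            # Improvement: remember it and reset the counter.
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def epochs_run(val_losses: list[float], max_epochs: int, patience: int) -> int:
    """Return how many epochs a loop would run before stopping (hypothetical)."""
    stopper = EarlyStopper(patience=patience)
    epoch = 0
    for epoch, loss in enumerate(val_losses[:max_epochs], start=1):
        if stopper.step(loss):
            break
    return epoch
```

With patience=5, a loss that last improved at epoch 3 stops training at epoch 8, unless max_epochs is reached first.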