geneformer.mtl_classifier

Geneformer multi-task cell classifier.

Input data:

Single-cell transcriptomes as Geneformer rank value encodings with cell state labels for each task in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py). Must contain “unique_cell_id” column for logging.
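
Before training, it can help to confirm that each dataset split exposes the columns described above. The helper below is a minimal sketch and not part of Geneformer: `missing_columns` is a hypothetical name, and it only inspects a list of column names (for a Hugging Face dataset this would typically come from its `column_names` attribute). The example column list is made up for illustration.

```python
def missing_columns(column_names, task_columns, id_column="unique_cell_id"):
    """Return the required columns that are absent, in a stable order."""
    # Every task column plus the logging ID column is treated as required.
    required = list(task_columns) + [id_column]
    present = set(column_names)
    return [col for col in required if col not in present]


# Hypothetical column names for one tokenized dataset split.
columns = ["input_ids", "length", "task1", "unique_cell_id"]
print(missing_columns(columns, ["task1", "task2"]))
```

An empty list means the split carries every task column and the “unique_cell_id” column.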

Usage:

>>> from geneformer import MTLClassifier
>>> mc = MTLClassifier(task_columns = ["task1", "task2"],
...                 study_name = "mtl",
...                 pretrained_path = "/path/pretrained/model",
...                 train_path = "/path/train/set",
...                 val_path = "/path/eval/set",
...                 test_path = "/path/test/set",
...                 model_save_path = "/results/directory/save_path",
...                 trials_result_path = "/results/directory/results.txt",
...                 results_dir = "/results/directory",
...                 tensorboard_log_dir = "/results/tblogdir",
...                 hyperparameters = hyperparameters)
>>> mc.run_optuna_study()
>>> mc.load_and_evaluate_test_model()
>>> mc.save_model_without_heads()

class MTLClassifier(task_columns=None, train_path=None, val_path=None, test_path=None, pretrained_path=None, model_save_path=None, results_dir=None, trials_result_path=None, batch_size=4, n_trials=15, study_name='mtl', max_layers_to_freeze=None, epochs=1, tensorboard_log_dir='/results/tblogdir', distributed_training=False, master_addr='localhost', master_port='12355', use_attention_pooling=True, use_task_weights=True, hyperparameters=None, manual_hyperparameters=None, use_manual_hyperparameters=False, use_wandb=False, wandb_project=None, gradient_clipping=False, max_grad_norm=None, gradient_accumulation_steps=1, seed=42)

Initialize Geneformer multi-task classifier.

Parameters:

task_columns : list
List of tasks for cell state classification
Input data columns are labeled with corresponding task names
study_name : None, str
Study name for labeling output files
pretrained_path : None, str
Path to pretrained model
train_path : None, str
Path to training dataset with task columns and “unique_cell_id” column
val_path : None, str
Path to validation dataset with task columns and “unique_cell_id” column
test_path : None, str
Path to test dataset with task columns and “unique_cell_id” column
model_save_path : None, str
Path to directory to save output model (either full model or model without heads)
trials_result_path : None, str
Path to directory to save hyperparameter tuning trial results
results_dir : None, str
Path to directory to save results
tensorboard_log_dir : None, str
Path to directory for Tensorboard logging results
distributed_training : None, bool
Whether to use distributed data parallel training across multiple GPUs
master_addr : None, str
Master address for distributed training (default: localhost)
master_port : None, str
Master port for distributed training (default: 12355)
use_attention_pooling : None, bool
Whether to use attention pooling
use_task_weights : None, bool
Whether to use task weights
batch_size : None, int
Batch size to use
n_trials : None, int
Number of trials for hyperparameter tuning
epochs : None, int
Number of epochs for training
max_layers_to_freeze : None, dict
Dictionary with keys “min” and “max” indicating the min and max layers to freeze from fine-tuning (int)
0: no layers will be frozen; 2: first two layers will be frozen; etc.
hyperparameters : None, dict
Dictionary defining the search space for each hyperparameter to tune (for example, its type and low/high bounds)
For example:
{"learning_rate": {"type": "float", "low": "1e-5", "high": "1e-3", "log": True}, "task_weights": {…}, …}
manual_hyperparameters : None, dict
Dictionary of manually set value for each hyperparameter
For example:
{"learning_rate": 0.001, "task_weights": [1, 1], …}
use_manual_hyperparameters : None, bool
Whether to use manually set hyperparameters
use_wandb : None, bool
Whether to use Weights & Biases for logging
wandb_project : None, str
Weights & Biases project name
gradient_clipping : None, bool
Whether to use gradient clipping
max_grad_norm : None, int, float
Maximum norm for gradient clipping
gradient_accumulation_steps : None, int
Number of steps over which to accumulate gradients before performing an optimizer update
seed : None, int
Random seed
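
Several of the parameters above interact when sizing a training run. The sketch below is arithmetic only and is not taken from Geneformer: `effective_batch_size` and `world_size` are hypothetical names, and it assumes the common convention that `batch_size` is the per-process batch size, so that the number of samples contributing to one optimizer update is the product of batch size, accumulation steps, and number of processes.

```python
def effective_batch_size(batch_size, gradient_accumulation_steps=1, world_size=1):
    """Samples per optimizer update, assuming batch_size is per process."""
    values = {
        "batch_size": batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "world_size": world_size,
    }
    for name, value in values.items():
        if value < 1:
            raise ValueError(f"{name} must be at least 1, got {value}")
    return batch_size * gradient_accumulation_steps * world_size


# With the documented default batch_size=4 and 8 accumulation steps on one process.
print(effective_batch_size(4, gradient_accumulation_steps=8))  # 32
```

Under the same assumption, enabling `distributed_training` on two GPUs would double that figure.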
load_and_evaluate_test_model()

Loads previously fine-tuned multi-task model and evaluates on test data.

run_manual_tuning()

Manual hyperparameter tuning and multi-task fine-tuning of pretrained model.
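
A manually specified configuration can be sanity-checked before it is passed to the classifier. The helper below is a hedged sketch, not Geneformer code: `check_manual_hyperparameters` is a hypothetical name, it covers only the two keys shown in the `manual_hyperparameters` example, and the rule that `task_weights` holds one weight per task column is an assumption inferred from that example.

```python
def check_manual_hyperparameters(task_columns, manual_hyperparameters):
    """Return a list of human-readable problems; empty if none were found."""
    problems = []

    # learning_rate is expected to be a positive number (bool is excluded).
    lr = manual_hyperparameters.get("learning_rate")
    if isinstance(lr, bool) or not isinstance(lr, (int, float)) or lr <= 0:
        problems.append("learning_rate should be a positive number")

    # Assumption: one task weight per entry in task_columns.
    weights = manual_hyperparameters.get("task_weights")
    if weights is not None and len(weights) != len(task_columns):
        problems.append(
            f"task_weights has {len(weights)} entries "
            f"but there are {len(task_columns)} task columns"
        )
    return problems


manual = {"learning_rate": 0.001, "task_weights": [1, 1]}
print(check_manual_hyperparameters(["task1", "task2"], manual))
```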

run_optuna_study()

Hyperparameter optimization and/or multi-task fine-tuning of pretrained model.
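
To make the search-space format concrete, the sketch below draws values for the documented `learning_rate` entry. It is an illustration using only the standard library, not Geneformer's or Optuna's implementation: `sample_float` is a hypothetical helper, and it assumes the string bounds are converted to floats and that `"log": True` means sampling uniformly in the log domain, as Optuna's `suggest_float(..., log=True)` does.

```python
import math
import random


def sample_float(spec, rng):
    """Draw one value from a {"type": "float", "low", "high", "log"} entry."""
    low, high = float(spec["low"]), float(spec["high"])
    if spec.get("log", False):
        # Uniform in log space, so each order of magnitude is equally likely.
        value = math.exp(rng.uniform(math.log(low), math.log(high)))
    else:
        value = rng.uniform(low, high)
    # Guard against floating-point drift just outside the bounds.
    return min(max(value, low), high)


spec = {"type": "float", "low": "1e-5", "high": "1e-3", "log": True}
rng = random.Random(42)  # fixed seed so repeated runs give the same draws
samples = [sample_float(spec, rng) for _ in range(5)]
print(samples)
```

Sampling in the log domain spreads trials evenly between 1e-5 and 1e-3 on a logarithmic scale, rather than concentrating them near the upper bound as linear sampling would.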