geneformer.mtl_classifier
Geneformer multi-task cell classifier.
Input data:
Single-cell transcriptomes as Geneformer rank value encodings, with cell state labels for each task, in Geneformer .dataset format (generated from single-cell RNA-seq data by tokenizer.py). Each dataset must contain a “unique_cell_id” column, which is used for logging.
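Before launching a study, it can help to confirm that each split actually has the required columns. The sketch below is a minimal, hypothetical helper (`missing_columns` is not part of Geneformer). It assumes the only requirements are one column per task plus “unique_cell_id”. With the Hugging Face `datasets` library, the column list could come from `datasets.load_from_disk(path).column_names`, assuming the path holds a single saved `Dataset` rather than a `DatasetDict`.

```python
def missing_columns(column_names, task_columns, id_column="unique_cell_id"):
    """Return the required columns that are absent, in a stable order.

    Assumption: the required columns are every task column plus the cell ID column.
    """
    present = set(column_names)
    required = list(task_columns) + [id_column]
    return [col for col in required if col not in present]


# Column names as they might appear in a tokenized dataset (hypothetical example).
columns = ["input_ids", "length", "task1", "unique_cell_id"]
print(missing_columns(columns, ["task1", "task2"]))  # ['task2']
```

An empty list means every task column and the ID column are present.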
Usage:
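>>> from geneformer import MTLClassifier
>>> # illustrative search space; see the hyperparameters parameter below
>>> hyperparameters = {"learning_rate": {"type": "float", "low": 1e-5, "high": 1e-3, "log": True}}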
>>> mc = MTLClassifier(task_columns = ["task1", "task2"],
... study_name = "mtl",
... pretrained_path = "/path/pretrained/model",
... train_path = "/path/train/set",
... val_path = "/path/eval/set",
... test_path = "/path/test/set",
... model_save_path = "/results/directory/save_path",
... trials_result_path = "/results/directory/results.txt",
... results_dir = "/results/directory",
... tensorboard_log_dir = "/results/tblogdir",
... hyperparameters = hyperparameters)
>>> mc.run_optuna_study()
>>> mc.load_and_evaluate_test_model()
>>> mc.save_model_without_heads()
- class MTLClassifier(task_columns=None, train_path=None, val_path=None, test_path=None, pretrained_path=None, model_save_path=None, results_dir=None, trials_result_path=None, batch_size=4, n_trials=15, study_name='mtl', max_layers_to_freeze=None, epochs=1, tensorboard_log_dir='/results/tblogdir', distributed_training=False, master_addr='localhost', master_port='12355', use_attention_pooling=True, use_task_weights=True, hyperparameters=None, manual_hyperparameters=None, use_manual_hyperparameters=False, use_wandb=False, wandb_project=None, gradient_clipping=False, max_grad_norm=None, gradient_accumulation_steps=1, seed=42)
Initialize Geneformer multi-task classifier.
Parameters:
- task_columns : list
- List of tasks for cell state classification. Input data columns are labeled with the corresponding task names.
- study_name : None, str
- Study name for labeling output files
- pretrained_path : None, str
- Path to pretrained model
- train_path : None, str
- Path to training dataset with task columns and “unique_cell_id” column
- val_path : None, str
- Path to validation dataset with task columns and “unique_cell_id” column
- test_path : None, str
- Path to test dataset with task columns and “unique_cell_id” column
- model_save_path : None, str
- Path to directory to save output model (either full model or model without heads)
- trials_result_path : None, str
- Path for saving hyperparameter tuning trial results (a text file such as results.txt in the usage example above)
- results_dir : None, str
- Path to directory to save results
- tensorboard_log_dir : None, str
- Path to directory for TensorBoard logging results
- distributed_training : None, bool
- Whether to use distributed data parallel training across multiple GPUs
- master_addr : None, str
- Master address for distributed training (default: localhost)
- master_port : None, str
- Master port for distributed training (default: 12355)
- use_attention_pooling : None, bool
- Whether to use attention pooling
- use_task_weights : None, bool
- Whether to use task weights
- batch_size : None, int
- Batch size to use
- n_trials : None, int
- Number of trials for hyperparameter tuning
- epochs : None, int
- Number of epochs for training
- max_layers_to_freeze : None, dict
- Dictionary with keys “min” and “max” giving the minimum and maximum number of layers (int) to freeze from fine-tuning. 0: no layers will be frozen; 2: the first two layers will be frozen; etc.
- hyperparameters : None, dict
- Dictionary specifying the search space (type, lower and upper bounds, and whether to sample on a log scale) for each hyperparameter to tune. For example: {"learning_rate": {"type": "float", "low": 1e-5, "high": 1e-3, "log": True}, "task_weights": {…}, …}
- manual_hyperparameters : None, dict
- Dictionary of manually set values for each hyperparameter. For example: {"learning_rate": 0.001, "task_weights": [1, 1], …}
- use_manual_hyperparameters : None, bool
- Whether to use manually set hyperparameters
- use_wandb : None, bool
- Whether to use Weights & Biases for logging
- wandb_project : None, str
- Weights & Biases project name
- gradient_clipping : None, bool
- Whether to use gradient clipping
- max_grad_norm : None, int, float
- Maximum norm for gradient clipping
- gradient_accumulation_steps : None, int
- Number of batches over which gradients are accumulated before each optimizer update
- seed : None, int
- Random seed
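The range-style parameters above, hyperparameters and max_layers_to_freeze, can both be read as search spaces. The sketch below is a minimal illustration assuming Optuna-like semantics, not MTLClassifier's actual tuning code. It draws one configuration from a dictionary in the documented format: "float" entries are sampled uniformly, or log-uniformly when "log" is True, and the number of leading layers to freeze is drawn as an integer between "min" and "max" inclusive. All function names are hypothetical, and entries of other types (such as the elided "task_weights" specification) are not handled.

```python
import math
import random


def sample_float(spec, rng):
    """Sample one value from a {"type": "float", "low", "high", "log"} spec."""
    # float() also accepts string bounds such as "1e-5".
    low, high = float(spec["low"]), float(spec["high"])
    if spec.get("log", False):
        # Log-uniform: sample uniformly in log space, then exponentiate.
        value = math.exp(rng.uniform(math.log(low), math.log(high)))
    else:
        value = rng.uniform(low, high)
    # Clamp to guard against floating-point round-off at the bounds.
    return min(max(value, low), high)


def sample_trial(hyperparameters, max_layers_to_freeze, rng):
    """Hypothetical helper: draw one trial configuration."""
    params = {}
    for name, spec in hyperparameters.items():
        if spec.get("type") != "float":
            raise ValueError(f"{name}: only 'float' specs are handled in this sketch")
        params[name] = sample_float(spec, rng)
    # Number of leading layers to freeze; randint includes both bounds.
    params["layers_to_freeze"] = rng.randint(
        int(max_layers_to_freeze["min"]), int(max_layers_to_freeze["max"])
    )
    return params


def frozen_mask(num_layers, layers_to_freeze):
    """True for each layer excluded from fine-tuning (the first N layers)."""
    return [i < layers_to_freeze for i in range(num_layers)]


rng = random.Random(42)
search_space = {"learning_rate": {"type": "float", "low": 1e-5, "high": 1e-3, "log": True}}
trial = sample_trial(search_space, {"min": 0, "max": 2}, rng)
print(trial)
# Six layers is an arbitrary choice for illustration.
print(frozen_mask(6, trial["layers_to_freeze"]))
```

Log-uniform sampling spreads trials evenly across orders of magnitude, which is why a learning-rate range such as 1e-5 to 1e-3 is usually searched with "log": True.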
- load_and_evaluate_test_model()
Loads previously fine-tuned multi-task model and evaluates on test data.
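Because each task has its own labels, test-set results are naturally reported per task. As a hedged illustration only (the metrics that load_and_evaluate_test_model actually computes and logs are not specified here), the sketch below scores predictions task by task with accuracy and macro-averaged F1. All function names and the toy labels are hypothetical.

```python
def accuracy(labels, preds):
    """Fraction of predictions that match the true labels."""
    return sum(y == p for y, p in zip(labels, preds)) / len(labels)


def macro_f1(labels, preds):
    """Unweighted mean of per-class F1 over classes seen in labels or preds."""
    scores = []
    for cls in sorted(set(labels) | set(preds)):
        tp = sum(y == cls and p == cls for y, p in zip(labels, preds))
        fp = sum(y != cls and p == cls for y, p in zip(labels, preds))
        fn = sum(y == cls and p != cls for y, p in zip(labels, preds))
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append(f1)
    return sum(scores) / len(scores)


def evaluate_per_task(labels_by_task, preds_by_task):
    """Score each task independently; keys are task column names."""
    return {
        task: {
            "accuracy": accuracy(labels, preds_by_task[task]),
            "macro_f1": macro_f1(labels, preds_by_task[task]),
        }
        for task, labels in labels_by_task.items()
    }


labels = {"task1": ["a", "a", "b", "b"], "task2": ["x", "y", "x", "y"]}
preds = {"task1": ["a", "b", "b", "b"], "task2": ["x", "y", "x", "y"]}
print(evaluate_per_task(labels, preds))
```

Macro-averaging weights every class equally, so a rare cell state counts as much as a common one; accuracy alone can hide poor performance on minority classes.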