geneformer.classifier

Geneformer classifier.

Input data:

Cell state classifier:
Single-cell transcriptomes as Geneformer rank value encodings with cell state labels in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
Gene classifier:
Dictionary in format {Gene_label: list(genes)} for gene labels and single-cell transcriptomes as Geneformer rank value encodings in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)

Usage:

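>>> from geneformer import Classifier
>>> output_prefix = "output_prefix"  # reused in the f-string paths below
>>> training_args = None  # None uses inferred defaults; or pass a dict of Hugging Face TrainingArguments settings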
>>> cc = Classifier(classifier="cell",  # example of cell state classifier
...                 cell_state_dict={"state_key": "disease", "states": "all"},
...                 filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]},
...                 training_args=training_args,
...                 freeze_layers=2,
...                 num_crossval_splits=1,
...                 forward_batch_size=200,
...                 nproc=16)
>>> cc.prepare_data(input_data_file="path/to/input_data",
...                 output_directory="path/to/output_directory",
...                 output_prefix="output_prefix")
>>> all_metrics = cc.validate(model_directory="path/to/model",
...                           prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset",
...                           id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
...                           output_directory="path/to/output_directory",
...                           output_prefix="output_prefix",
...                           predict_eval=True)
>>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
...                  output_directory="path/to/output_directory",
...                  output_prefix="output_prefix",
...                  custom_class_order=["healthy","disease1","disease2"])
>>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl",
...                     id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
...                     title="disease",
...                     output_directory="path/to/output_directory",
...                     output_prefix="output_prefix",
...                     custom_class_order=["healthy","disease1","disease2"])
class Classifier(classifier=None, quantize=False, cell_state_dict=None, gene_class_dict=None, filter_data=None, rare_threshold=0, max_ncells=None, max_ncells_per_class=None, training_args=None, ray_config=None, freeze_layers=0, num_crossval_splits=1, split_sizes={'test': 0.1, 'train': 0.8, 'valid': 0.1}, stratify_splits_col=None, no_eval=False, forward_batch_size=100, model_version='V2', token_dictionary_file=None, nproc=4, ngpu=1)

Initialize Geneformer classifier.

Parameters

classifier : {“cell”, “gene”}
Whether to fine-tune a cell state or gene classifier.
quantize : bool, dict
Whether to fine-tune a quantized model.
If True and no config is provided, will use the default config.
If a custom config is provided, will use it instead.
Configs should be provided as a dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft).
For example: {“bnb_config”: BitsAndBytesConfig(…),
“peft_config”: LoraConfig(…)}
cell_state_dict : None, dict
Cell states to fine-tune the model to distinguish.
Two-item dictionary with keys: state_key and states
state_key: key specifying the name of the column in .dataset that defines the states to model
states: list of values in the state_key column that specify the states to model
Alternatively, instead of a list of states, specify “all” to use all states in that state_key column of the input data.
Of note, if using “all”, states will be defined after the data is filtered.
Must have at least 2 states to model.
For example: {“state_key”: “disease”,
“states”: [“nf”, “hcm”, “dcm”]}
or
{“state_key”: “disease”,
“states”: “all”}
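The rules above (exactly two keys, “all” resolved after filtering, at least 2 states) can be illustrated with a small sketch. `resolve_states` is a hypothetical helper written for this page, not part of Geneformer, and the toy records stand in for rows of an already-filtered .dataset.

```python
def resolve_states(cell_state_dict, cells):
    """Return the states to model (hypothetical helper, not part of Geneformer)."""
    if set(cell_state_dict) != {"state_key", "states"}:
        raise ValueError("cell_state_dict needs exactly the keys 'state_key' and 'states'")
    key = cell_state_dict["state_key"]
    states = cell_state_dict["states"]
    if states == "all":
        # "all" is resolved from the already-filtered cells, in first-seen order
        states = list(dict.fromkeys(cell[key] for cell in cells))
    if len(states) < 2:
        raise ValueError("at least 2 states are required")
    return list(states)


# Toy records standing in for rows of a filtered .dataset
cells = [{"disease": "nf"}, {"disease": "hcm"}, {"disease": "dcm"}, {"disease": "nf"}]
print(resolve_states({"state_key": "disease", "states": "all"}, cells))
```

Because “all” is evaluated on the filtered records, a state removed by filter_data never becomes a class.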
gene_class_dict : None, dict
Gene classes to fine-tune the model to distinguish.
Dictionary in format: {Gene_label_A: list(geneA1, geneA2, …),
Gene_label_B: list(geneB1, geneB2, …)}
Gene values should be Ensembl IDs.
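A quick way to sanity-check a gene_class_dict before passing it in is to invert it into a gene-to-label mapping. This is a hypothetical sketch: it assumes human Ensembl gene IDs (“ENSG” followed by 11 digits), and rejecting a gene listed under two labels is a choice of this sketch, not documented Geneformer behavior. The IDs below are placeholders.

```python
import re

# Assumption: human Ensembl gene IDs, i.e. "ENSG" followed by 11 digits
ENSEMBL_GENE = re.compile(r"^ENSG\d{11}$")


def gene_to_label(gene_class_dict):
    """Invert {label: [genes]} into {gene: label} (hypothetical helper)."""
    mapping = {}
    for label, genes in gene_class_dict.items():
        for gene in genes:
            if not ENSEMBL_GENE.match(gene):
                raise ValueError(f"{gene!r} does not look like an Ensembl gene ID")
            if gene in mapping and mapping[gene] != label:
                raise ValueError(f"{gene} is listed under both {mapping[gene]} and {label}")
            mapping[gene] = label
    return mapping


# Placeholder IDs, not real genes
gene_class_dict = {
    "dosage_sensitive": ["ENSG00000000001", "ENSG00000000002"],
    "dosage_insensitive": ["ENSG00000000003"],
}
print(gene_to_label(gene_class_dict))
```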
filter_data : None, dict
Default is to fine-tune with all input data.
Otherwise, dictionary mapping each .dataset column name to a list of values to keep.
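The filter semantics can be sketched in plain Python, with a list of dicts standing in for .dataset rows. `apply_filter` is hypothetical; it assumes a row is kept only if it matches the allowed values in every listed column, as in the usage example above.

```python
def apply_filter(rows, filter_data=None):
    """Keep rows whose value in each filter column is one of the allowed values."""
    if filter_data is None:
        return list(rows)  # default: use all input data
    return [
        row for row in rows
        if all(row[column] in allowed for column, allowed in filter_data.items())
    ]


rows = [
    {"cell_type": "Cardiomyocyte1", "disease": "nf"},
    {"cell_type": "Fibroblast", "disease": "hcm"},
    {"cell_type": "Cardiomyocyte2", "disease": "dcm"},
]
kept = apply_filter(rows, {"cell_type": ["Cardiomyocyte1", "Cardiomyocyte2"]})
print([row["disease"] for row in kept])
```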
rare_threshold : float
Threshold below which rare cell states should be removed.
For example, setting it to 0.05 will remove cell states representing
< 5% of the total cells from the cell state classifier’s possible classes.
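The threshold arithmetic is a per-state share of all cells. A minimal sketch follows; `drop_rare_states` is a hypothetical name and the labels are made up for illustration.

```python
from collections import Counter


def drop_rare_states(labels, rare_threshold):
    """Return the states whose share of all cells is at least rare_threshold."""
    counts = Counter(labels)
    total = len(labels)
    return sorted(state for state, n in counts.items() if n / total >= rare_threshold)


labels = ["A"] * 60 + ["B"] * 37 + ["C"] * 3  # C is 3% of cells
print(drop_rare_states(labels, 0.05))
```

With rare_threshold=0.05, state C (3% of cells) is dropped; the default of 0 keeps every state.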
max_ncells : None, int
Maximum number of cells to use for fine-tuning.
Default is to fine-tune with all input data.
max_ncells_per_class : None, int
Maximum number of cells per cell class to use for fine-tuning.
Of note, will be applied after max_ncells above.
(Only valid for cell classification.)
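The ordering “max_ncells first, then max_ncells_per_class” can be sketched as two sampling passes. This is a hypothetical illustration: the use of uniform random sampling and the seed are assumptions of the sketch, not documented Geneformer behavior.

```python
import random
from collections import Counter, defaultdict


def downsample(rows, label_key, max_ncells=None, max_ncells_per_class=None, seed=42):
    """Cap the total number of cells, then cap each class (hypothetical helper)."""
    rng = random.Random(seed)
    rows = list(rows)
    if max_ncells is not None and len(rows) > max_ncells:
        rows = rng.sample(rows, max_ncells)  # global cap is applied first
    if max_ncells_per_class is not None:
        by_class = defaultdict(list)
        for row in rows:
            by_class[row[label_key]].append(row)
        rows = []
        for members in by_class.values():
            if len(members) > max_ncells_per_class:
                members = rng.sample(members, max_ncells_per_class)
            rows.extend(members)
    return rows


cells = [{"id": i, "disease": "nf" if i < 70 else "hcm"} for i in range(100)]
subset = downsample(cells, "disease", max_ncells=50, max_ncells_per_class=20)
print(Counter(row["disease"] for row in subset))
```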
training_args : None, dict
Training arguments for fine-tuning.
If None, defaults will be inferred for 6-layer Geneformer.
Otherwise, will use the Hugging Face defaults (see the Hugging Face TrainingArguments documentation).
Note: Hyperparameter tuning is highly recommended, rather than using defaults.
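A training_args dictionary is a mapping of Hugging Face `TrainingArguments` field names to values. The keys below are real `TrainingArguments` fields; the values are placeholders for illustration only, not tuned or recommended settings.

```python
# Placeholder values, not recommendations; keys are transformers.TrainingArguments fields
training_args = {
    "learning_rate": 5e-5,
    "num_train_epochs": 1,
    "per_device_train_batch_size": 12,
    "lr_scheduler_type": "linear",
    "warmup_steps": 500,
    "weight_decay": 0.01,
    "seed": 42,
}
```

The dictionary is then passed as `Classifier(..., training_args=training_args)`.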
ray_config : None, dict
Training argument ranges for tuning hyperparameters with Ray.
freeze_layers : int
Number of layers to freeze during fine-tuning.
0: no layers will be frozen; 2: the first two layers will be frozen; etc.
num_crossval_splits : {0, 1, 5}
0: train on all data without splitting
1: split data into train and eval sets by the designated split_sizes[“valid”]
5: split data into 5 folds of train and eval sets by the designated split_sizes[“valid”]
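The 5-fold option follows the usual k-fold pattern: every item lands in exactly one eval fold and is used for training in the others. The sketch below is a generic shuffled k-fold over indices, not Geneformer's splitting code; `kfold_splits` and the seed are hypothetical.

```python
import random


def kfold_splits(n_items, k=5, seed=42):
    """Yield (train, eval) index lists for k shuffled, roughly equal folds."""
    idx = list(range(n_items))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]
    for i in range(k):
        train = [j for f, fold in enumerate(folds) if f != i for j in fold]
        yield train, folds[i]


for train_idx, eval_idx in kfold_splits(10, k=5):
    print(len(train_idx), len(eval_idx))
```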
split_sizes : None, dict
Dictionary of the proportions of data to allocate to the train, validation, and test sets,
e.g. {“train”: 0.8, “valid”: 0.1, “test”: 0.1} for an 80/10/10 train/valid/test split
stratify_splits_col : None, str
Name of the column in .dataset to be used for stratified splitting.
The proportion of each class in this column will be the same in the splits as in the original dataset.
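How split_sizes and stratification combine can be sketched by splitting each class separately with the same proportions. `stratified_split` is hypothetical and only illustrates the idea; rounding per class means the proportions are approximate for small classes.

```python
import random
from collections import defaultdict


def stratified_split(labels, split_sizes, seed=42):
    """Split indices so each class keeps roughly its original proportion in every split."""
    if abs(sum(split_sizes.values()) - 1.0) > 1e-9:
        raise ValueError("split_sizes must sum to 1")
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    names = list(split_sizes)
    splits = {name: [] for name in names}
    for members in by_class.values():
        rng.shuffle(members)
        start = 0
        for j, name in enumerate(names):
            # the last split takes the remainder so no index is dropped
            if j == len(names) - 1:
                end = len(members)
            else:
                end = start + round(split_sizes[name] * len(members))
            splits[name].extend(members[start:end])
            start = end
    return splits


labels = ["a"] * 80 + ["b"] * 20
splits = stratified_split(labels, {"train": 0.8, "valid": 0.1, "test": 0.1})
print({name: len(idx) for name, idx in splits.items()})
```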
no_eval : bool
If True, will skip the eval step and use all data for training.
Otherwise, will perform eval during training.
forward_batch_size : int
Batch size for forward pass (for evaluation, not training).
model_version : str
Model version whose settings should be auto-selected, if different from the current default (V2).
Current options: V1: models pretrained on ~30M cells, V2: models pretrained on ~104M cells
token_dictionary_file : None, str
Default is to use the token dictionary file from Geneformer.
Otherwise, will load a custom gene token dictionary.
nproc : int
Number of CPU processes to use.
ngpu : int
Number of GPUs available.
evaluate_model(model, num_classes, id_class_dict, eval_data, predict=False, output_directory=None, output_prefix=None)

Evaluate the fine-tuned model.

Parameters

model : nn.Module
Loaded fine-tuned model (e.g. trainer.model)
num_classes : int
Number of classes for classifier
id_class_dict : dict
Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data
(dictionary mapping numerical IDs to class labels)
eval_data : Dataset
Loaded evaluation .dataset input
predict : bool
Whether or not to save eval predictions
output_directory : Path
Path to directory where eval data will be saved
output_prefix : str
Prefix for output files
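The id_class_dict is a pickled dictionary from numerical IDs to class labels, so loading it and mapping predicted IDs back to labels takes only the standard library. The sketch below writes a hypothetical mapping to a temporary file so it is self-contained; in practice the file is the `{output_prefix}_id_class_dict.pkl` produced by Classifier.prepare_data, and integer IDs are an assumption here.

```python
import pickle
import tempfile
from pathlib import Path

# Hypothetical mapping in the documented shape (numerical ID -> class label)
id_class_dict = {0: "healthy", 1: "disease1", 2: "disease2"}

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "output_prefix_id_class_dict.pkl"
    with open(path, "wb") as f:
        pickle.dump(id_class_dict, f)
    with open(path, "rb") as f:
        loaded = pickle.load(f)

num_classes = len(loaded)  # the value passed as num_classes
predicted_ids = [2, 0, 1, 0]
print([loaded[i] for i in predicted_ids])
```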
evaluate_saved_model(model_directory, id_class_dict_file, test_data_file, output_directory, output_prefix, predict=True)

Evaluate a fine-tuned model saved in model_directory.

Parameters

model_directory : Path
Path to directory containing model
id_class_dict_file : Path
Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
(dictionary mapping numerical IDs to class labels)
test_data_file : Path
Path to directory containing test .dataset
output_directory : Path
Path to directory where eval data will be saved
output_prefix : str
Prefix for output files
predict : bool
Whether or not to save eval predictions
hyperopt_classifier(model_directory, num_classes, train_data, eval_data, output_directory, n_trials=100)

Fine-tune model for cell state or gene classification with hyperparameter optimization.

Parameters

model_directory : Path
Path to directory containing model
num_classes : int
Number of classes for classifier
train_data : Dataset
Loaded training .dataset input
For cell classifier, labels in column “label”.
For gene classifier, labels in column “labels”.
eval_data : None, Dataset
(Optional) Loaded evaluation .dataset input
For cell classifier, labels in column “label”.
For gene classifier, labels in column “labels”.
output_directory : Path
Path to directory where fine-tuned model will be saved
n_trials : int
Number of trials to run for hyperparameter optimization
plot_conf_mat(conf_mat_dict, output_directory, output_prefix, custom_class_order=None)

Plot confusion matrix results of evaluating the fine-tuned model.

Parameters

conf_mat_dict : dict
Dictionary of model_name : confusion_matrix_DataFrame
(all_metrics[“conf_matrix”] from self.validate)
output_directory : Path
Path to directory where plots will be saved
output_prefix : str
Prefix for output file
custom_class_order : None, list
List of classes in custom order for plots.
Same order will be used for all models.
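The effect of custom_class_order is to fix the row and column order of the confusion matrix. The sketch below builds one in plain Python in a chosen order; the rows-are-true, columns-are-predicted orientation is an assumption (the common scikit-learn convention), not a statement about the DataFrame Geneformer returns.

```python
def confusion_matrix(true_labels, pred_labels, class_order):
    """Count (true, predicted) pairs; rows are true classes, columns predicted, both in class_order."""
    index = {label: i for i, label in enumerate(class_order)}
    matrix = [[0] * len(class_order) for _ in class_order]
    for true, pred in zip(true_labels, pred_labels):
        matrix[index[true]][index[pred]] += 1
    return matrix


order = ["healthy", "disease1", "disease2"]
true = ["healthy", "disease1", "disease1", "disease2"]
pred = ["healthy", "disease1", "disease2", "disease2"]
for label, row in zip(order, confusion_matrix(true, pred, order)):
    print(label, row)
```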
plot_predictions(predictions_file, id_class_dict_file, title, output_directory, output_prefix, custom_class_order=None, kwargs_dict=None)

Plot prediction results of evaluating the fine-tuned model.

Parameters

predictions_file : Path
Path of model predictions output to plot
(saved output from self.validate if predict_eval=True)
(or saved output from self.evaluate_saved_model)
id_class_dict_file : Path
Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
(dictionary mapping numerical IDs to class labels)
title : str
Title for legend containing class labels.
output_directory : Path
Path to directory where plots will be saved
output_prefix : str
Prefix for output file
custom_class_order : None, list
List of classes in custom order for plots.
Same order will be used for all models.
kwargs_dict : None, dict
Dictionary of kwargs to pass to plotting function.
plot_roc(roc_metric_dict, model_style_dict, title, output_directory, output_prefix)

Plot ROC curve results of evaluating the fine-tuned model.

Parameters

roc_metric_dict : dict
Dictionary of model_name : roc_metrics
(all_metrics[“all_roc_metrics”] from self.validate)
model_style_dict : dict[dict]
Dictionary of model_name : dictionary of style_attribute : style
where style includes color and linestyle
e.g. {‘Model_A’: {‘color’: ‘black’, ‘linestyle’: ‘-’}, ‘Model_B’: …}
title : str
Title of plot (e.g. ‘Dosage-sensitive vs -insensitive factors’)
output_directory : Path
Path to directory where plots will be saved
output_prefix : str
Prefix for output file
prepare_data(input_data_file, output_directory, output_prefix, split_id_dict=None, test_size=None, attr_to_split=None, attr_to_balance=None, max_trials=100, pval_threshold=0.1)

Prepare data for cell state or gene classification.

Parameters

input_data_file : Path
Path to directory containing .dataset input
output_directory : Path
Path to directory where prepared data will be saved
output_prefix : str
Prefix for output file
split_id_dict : None, dict
Dictionary of IDs for train and test splits
Three-item dictionary with keys: attr_key, train, test
attr_key: key specifying the name of the column in .dataset that contains the IDs for the data splits
train: list of IDs in the attr_key column to include in the train split
test: list of IDs in the attr_key column to include in the test split
For example: {“attr_key”: “individual”,
“train”: [“patient1”, “patient2”, “patient3”, “patient4”],
“test”: [“patient5”, “patient6”]}
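Splitting by split_id_dict assigns every row to a split according to the ID in its attr_key column, so all cells from one individual stay together. In this hypothetical sketch, overlapping IDs raise an error and rows whose ID is in neither list are dropped; both behaviors are assumptions, not documented Geneformer behavior.

```python
def split_by_ids(rows, split_id_dict):
    """Route rows to train/test by the ID in split_id_dict["attr_key"] (hypothetical helper)."""
    key = split_id_dict["attr_key"]
    train_ids, test_ids = set(split_id_dict["train"]), set(split_id_dict["test"])
    overlap = train_ids & test_ids
    if overlap:
        raise ValueError(f"IDs listed in both splits: {sorted(overlap)}")
    train = [row for row in rows if row[key] in train_ids]
    test = [row for row in rows if row[key] in test_ids]
    return train, test


split_id_dict = {
    "attr_key": "individual",
    "train": ["patient1", "patient2", "patient3", "patient4"],
    "test": ["patient5", "patient6"],
}
rows = [{"individual": f"patient{i}", "cell": c} for i in range(1, 8) for c in range(2)]
train, test = split_by_ids(rows, split_id_dict)
print(len(train), len(test))
```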
test_size : None, float
Proportion of data to be saved separately and held out for the test set
(e.g. 0.2 if intending to hold out 20%)
If None, will inherit from split_sizes[“test”] from Classifier
The training set will be further split into train/validation sets in self.validate
Note: only available for cell state classifiers (classifier=“cell”)
attr_to_split : None, str
Key for attribute on which to split data while balancing potential confounders
e.g. “patient_id” for splitting by patient while balancing other characteristics
Note: only available for cell state classifiers (classifier=“cell”)
attr_to_balance : None, list
List of attribute keys on which to balance data while splitting on attr_to_split
e.g. [“age”, “sex”] for balancing these characteristics while splitting by patient
Note: only available for cell state classifiers (classifier=“cell”)
max_trials : None, int
Maximum number of random splitting trials used to try to achieve balance across the other attributes
If no split is found without significant (p < pval_threshold) differences in other attributes, will select the best split found
Note: only available for cell state classifiers (classifier=“cell”)
pval_threshold : None, float
P-value threshold to use for attribute balancing across splits
E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
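The attr_to_split / attr_to_balance / max_trials / pval_threshold interplay amounts to a retry loop: randomly split the attr_to_split values, test each balanced attribute for a difference between splits, accept the first split where every p ≥ pval_threshold, and otherwise keep the best one seen. The sketch below is a hypothetical illustration that swaps in a simple permutation test on means (numeric or 0/1-encoded attributes only); the statistical tests Geneformer actually uses may differ. The patient data are made up.

```python
import random
from statistics import mean


def permutation_pvalue(a, b, n_perm=199, rng=None):
    """Two-sided permutation-test p-value for a difference in means between a and b."""
    rng = rng or random.Random(0)
    observed = abs(mean(a) - mean(b))
    pooled = list(a) + list(b)
    hits = 0
    for _ in range(n_perm):
        rng.shuffle(pooled)
        if abs(mean(pooled[:len(a)]) - mean(pooled[len(a):])) >= observed:
            hits += 1
    return (hits + 1) / (n_perm + 1)


def balanced_split(subjects, attr_to_balance, test_frac,
                   max_trials=100, pval_threshold=0.1, seed=0):
    """Randomly split subject IDs, retrying until no attribute differs at p < pval_threshold."""
    rng = random.Random(seed)
    ids = sorted(subjects)
    n_test = max(1, round(len(ids) * test_frac))
    best, best_p = None, -1.0
    for _ in range(max_trials):
        shuffled = ids[:]
        rng.shuffle(shuffled)
        test_ids, train_ids = shuffled[:n_test], shuffled[n_test:]
        # a split is only as good as its worst-balanced attribute
        worst_p = min(
            permutation_pvalue([subjects[i][attr] for i in train_ids],
                               [subjects[i][attr] for i in test_ids], rng=rng)
            for attr in attr_to_balance
        )
        if worst_p > best_p:
            best, best_p = (train_ids, test_ids), worst_p
        if worst_p >= pval_threshold:
            break  # accept the first split that passes for every attribute
    return best[0], best[1], best_p


# Made-up patients; sex is encoded as 0/1 so a difference in means applies
patients = {
    f"patient{i}": {"age": age, "sex": sex}
    for i, (age, sex) in enumerate(
        [(34, 0), (51, 1), (47, 0), (62, 1), (29, 1),
         (55, 0), (41, 1), (38, 0), (66, 1), (45, 0)], start=1)
}
train_ids, test_ids, p = balanced_split(patients, ["age", "sex"], test_frac=0.2)
print(sorted(test_ids), round(p, 3))
```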
train_all_data(model_directory, prepared_input_data_file, id_class_dict_file, output_directory, output_prefix, save_eval_output=True, gene_balance=False)

Train cell state or gene classifier using all data.

Parameters

model_directory : Path
Path to directory containing model
prepared_input_data_file : Path
Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
id_class_dict_file : Path
Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
(dictionary mapping numerical IDs to class labels)
output_directory : Path
Path to directory where model and eval data will be saved
output_prefix : str
Prefix for output files
save_eval_output : bool
Whether to save cross-fold eval output
Saves as pickle file of dictionary of eval metrics
gene_balance : None, bool
Whether to automatically balance genes in training set.
Only available for binary gene classifications.
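One simple way to balance a binary gene classification is to downsample the larger gene class to the size of the smaller one. This is a hypothetical sketch of that idea only; the exact balancing strategy behind gene_balance is not described here, and it may differ (for example, balancing by token counts rather than gene counts).

```python
import random


def balance_binary_genes(gene_class_dict, seed=42):
    """Downsample the larger of two gene classes to the size of the smaller one."""
    if len(gene_class_dict) != 2:
        raise ValueError("this sketch only handles binary gene classifications")
    rng = random.Random(seed)
    n_min = min(len(genes) for genes in gene_class_dict.values())
    return {label: rng.sample(list(genes), n_min) for label, genes in gene_class_dict.items()}


# Placeholder gene names
genes = {"class_a": ["g1", "g2", "g3", "g4"], "class_b": ["g5", "g6"]}
print({label: len(g) for label, g in balance_binary_genes(genes).items()})
```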

Output

Returns trainer after fine-tuning with all data.

train_classifier(model_directory, num_classes, train_data, eval_data, output_directory, predict=False)

Fine-tune model for cell state or gene classification.

Parameters

model_directory : Path
Path to directory containing model
num_classes : int
Number of classes for classifier
train_data : Dataset
Loaded training .dataset input
For cell classifier, labels in column “label”.
For gene classifier, labels in column “labels”.
eval_data : None, Dataset
(Optional) Loaded evaluation .dataset input
For cell classifier, labels in column “label”.
For gene classifier, labels in column “labels”.
output_directory : Path
Path to directory where fine-tuned model will be saved
predict : bool
Whether or not to save eval predictions from trainer
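The “label” versus “labels” distinction reflects sequence-level versus token-level classification: a cell classifier needs one class ID per cell, while a gene classifier needs one entry per token. The sketch below builds a per-token “labels” list under the assumption that tokens of unlabeled genes get -100, the default `ignore_index` of PyTorch's `CrossEntropyLoss`; the token IDs and the `token_to_class` mapping are hypothetical.

```python
def gene_token_labels(input_ids, token_to_class, ignore_index=-100):
    """Per-token labels for one cell: class ID for labeled genes, ignore_index otherwise."""
    return [token_to_class.get(token, ignore_index) for token in input_ids]


input_ids = [5, 17, 3, 42]  # hypothetical rank value encoding of one cell
cell_example = {"input_ids": input_ids, "label": 1}  # cell classifier: one label per cell
gene_example = {
    "input_ids": input_ids,
    "labels": gene_token_labels(input_ids, {17: 0, 42: 1}),  # gene classifier: one per token
}
print(gene_example["labels"])
```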
validate(model_directory, prepared_input_data_file, id_class_dict_file, output_directory, output_prefix, split_id_dict=None, attr_to_split=None, attr_to_balance=None, gene_balance=False, max_trials=100, pval_threshold=0.1, save_eval_output=True, predict_eval=True, predict_trainer=False, n_hyperopt_trials=0, save_gene_split_datasets=True, debug_gene_split_datasets=False)

(Cross-)validate cell state or gene classifier.

Parameters

model_directory : Path
Path to directory containing model
prepared_input_data_file : Path
Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
id_class_dict_file : Path
Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data
(dictionary mapping numerical IDs to class labels)
output_directory : Path
Path to directory where model and eval data will be saved
output_prefix : str
Prefix for output files
split_id_dict : None, dict
Dictionary of IDs for train and eval splits
Three-item dictionary with keys: attr_key, train, eval
attr_key: key specifying the name of the column in .dataset that contains the IDs for the data splits
train: list of IDs in the attr_key column to include in the train split
eval: list of IDs in the attr_key column to include in the eval split
For example: {“attr_key”: “individual”,
“train”: [“patient1”, “patient2”, “patient3”, “patient4”],
“eval”: [“patient5”, “patient6”]}
Note: only available for cell state classifiers with a 1-fold split (self.classifier=“cell”; self.num_crossval_splits=1)
attr_to_split : None, str
Key for attribute on which to split data while balancing potential confounders
e.g. “patient_id” for splitting by patient while balancing other characteristics
Note: only available for cell state classifiers with a 1-fold split (self.classifier=“cell”; self.num_crossval_splits=1)
attr_to_balance : None, list
List of attribute keys on which to balance data while splitting on attr_to_split
e.g. [“age”, “sex”] for balancing these characteristics while splitting by patient
gene_balance : None, bool
Whether to automatically balance genes in training set.
Only available for binary gene classifications.
max_trials : None, int
Maximum number of random splitting trials used to try to achieve balance across the other attributes
If no split is found without significant (p < pval_threshold) differences in other attributes, will select the best split found
pval_threshold : None, float
P-value threshold to use for attribute balancing across splits
E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
save_eval_output : bool
Whether to save cross-fold eval output
Saves as pickle file of dictionary of eval metrics
predict_eval : bool
Whether or not to save eval predictions
Saves as a pickle file of self.evaluate_model predictions
predict_trainer : bool
Whether or not to save eval predictions from trainer
Saves as a pickle file of trainer predictions
n_hyperopt_trials : int
Number of trials to run for hyperparameter optimization
If 0, will not optimize hyperparameters
save_gene_split_datasets : bool
Whether or not to save train, valid, and test gene-labeled datasets