geneformer.classifier
Geneformer classifier.
Input data:
Cell state classifier:
Single-cell transcriptomes as Geneformer rank value encodings with cell state labels in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
Gene classifier:
Dictionary in format {Gene_label: list(genes)} for gene labels and single-cell transcriptomes as Geneformer rank value encodings in Geneformer .dataset format (generated from single-cell RNAseq data by tokenizer.py)
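As a minimal sketch of the gene classifier input, the label dictionary can be built and sanity-checked in plain Python. The class labels and the `check_gene_class_dict` helper are hypothetical, the ID pattern assumes human Ensembl gene IDs (ENSG followed by 11 digits), and the IDs listed are only placeholders for a real gene list.

```python
import re

# Assumption: human Ensembl gene IDs look like "ENSG" followed by 11 digits.
ENSEMBL_GENE_ID = re.compile(r"^ENSG\d{11}$")


def check_gene_class_dict(gene_class_dict):
    """Hypothetical helper: validate a {Gene_label: list(genes)} dictionary."""
    for label, genes in gene_class_dict.items():
        bad = [g for g in genes if not ENSEMBL_GENE_ID.match(g)]
        if bad:
            raise ValueError(f"class {label!r} has non-Ensembl IDs: {bad}")
    return True


# Placeholder labels and IDs, for illustration only.
gene_class_dict = {
    "class_A": ["ENSG00000141510", "ENSG00000012048"],
    "class_B": ["ENSG00000111640", "ENSG00000075624"],
}
gene_classes_ok = check_gene_class_dict(gene_class_dict)
```

A check like this catches gene symbols passed by mistake before any data preparation starts.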
Usage:
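>>> from geneformer import Classifier
>>> output_prefix = "output_prefix"  # reused in the f-string paths below
>>> training_args = None  # None infers defaults; pass a dict of training arguments to override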
>>> cc = Classifier(classifier="cell", # example of cell state classifier
... cell_state_dict={"state_key": "disease", "states": "all"},
... filter_data={"cell_type":["Cardiomyocyte1","Cardiomyocyte2","Cardiomyocyte3"]},
... training_args=training_args,
... freeze_layers=2,
... num_crossval_splits=1,
... forward_batch_size=200,
... nproc=16)
>>> cc.prepare_data(input_data_file="path/to/input_data",
... output_directory="path/to/output_directory",
... output_prefix="output_prefix")
>>> all_metrics = cc.validate(model_directory="path/to/model",
... prepared_input_data_file=f"path/to/output_directory/{output_prefix}_labeled.dataset",
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
... output_directory="path/to/output_directory",
... output_prefix="output_prefix",
... predict_eval=True)
>>> cc.plot_conf_mat(conf_mat_dict={"Geneformer": all_metrics["conf_matrix"]},
... output_directory="path/to/output_directory",
... output_prefix="output_prefix",
... custom_class_order=["healthy","disease1","disease2"])
>>> cc.plot_predictions(predictions_file=f"path/to/output_directory/datestamp_geneformer_cellClassifier_{output_prefix}/ksplit1/predictions.pkl",
... id_class_dict_file=f"path/to/output_directory/{output_prefix}_id_class_dict.pkl",
... title="disease",
... output_directory="path/to/output_directory",
... output_prefix="output_prefix",
... custom_class_order=["healthy","disease1","disease2"])
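The file names passed to validate in the usage above follow from the output directory and prefix given to prepare_data. Below is a small sketch of assembling them with pathlib. The `prepared_paths` helper is hypothetical, and only the two suffixes that appear in the usage above are assumed.

```python
from pathlib import Path


def prepared_paths(output_directory, output_prefix):
    """Hypothetical helper: build the two paths used in the usage example."""
    out_dir = Path(output_directory)
    return {
        "prepared_input_data_file": (out_dir / f"{output_prefix}_labeled.dataset").as_posix(),
        "id_class_dict_file": (out_dir / f"{output_prefix}_id_class_dict.pkl").as_posix(),
    }


paths = prepared_paths("path/to/output_directory", "output_prefix")
```

Building both paths from one prefix keeps the prepare_data and validate calls consistent with each other.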
- class Classifier(classifier=None, quantize=False, cell_state_dict=None, gene_class_dict=None, filter_data=None, rare_threshold=0, max_ncells=None, max_ncells_per_class=None, training_args=None, ray_config=None, freeze_layers=0, num_crossval_splits=1, split_sizes={'test': 0.1, 'train': 0.8, 'valid': 0.1}, stratify_splits_col=None, no_eval=False, forward_batch_size=100, model_version='V2', token_dictionary_file=None, nproc=4, ngpu=1)
Initialize Geneformer classifier.
Parameters
- classifier : {"cell", "gene"}
- Whether to fine-tune a cell state or gene classifier.
- quantize : bool, dict
- Whether to fine-tune a quantized model. If True and no config is provided, the default config is used. If a custom config is provided, it is used instead. Configs should be provided as a dictionary of BitsAndBytesConfig (transformers) and LoraConfig (peft). For example: {"bnb_config": BitsAndBytesConfig(…), "peft_config": LoraConfig(…)}
- cell_state_dict : None, dict
- Cell states to fine-tune model to distinguish. Two-item dictionary with keys state_key and states. state_key: name of the column in .dataset that defines the states to model. states: list of values in the state_key column that specifies the states to model. Alternatively, instead of a list of states, specify "all" to use all states in that state key from the input data. Of note, if using "all", states will be defined after data is filtered. Must have at least 2 states to model. For example: {"state_key": "disease", "states": ["nf", "hcm", "dcm"]} or {"state_key": "disease", "states": "all"}
- gene_class_dict : None, dict
- Gene classes to fine-tune model to distinguish. Dictionary in format: {Gene_label_A: list(geneA1, geneA2, …), Gene_label_B: list(geneB1, geneB2, …)}. Gene values should be Ensembl IDs.
- filter_data : None, dict
- Default is to fine-tune with all input data. Otherwise, dictionary specifying .dataset column name and list of values to filter by.
- rare_threshold : float
- Threshold below which rare cell states should be removed. For example, setting to 0.05 will remove cell states representing < 5% of the total cells from the cell state classifier's possible classes.
- max_ncells : None, int
- Maximum number of cells to use for fine-tuning. Default is to fine-tune with all input data.
- max_ncells_per_class : None, int
- Maximum number of cells per cell class to use for fine-tuning. Of note, will be applied after max_ncells above. (Only valid for cell classification.)
- training_args : None, dict
- Training arguments for fine-tuning. If None, defaults will be inferred for 6 layer Geneformer. Otherwise, will use the Hugging Face defaults (see the transformers TrainingArguments documentation). Note: Hyperparameter tuning is highly recommended, rather than using defaults.
- ray_config : None, dict
- Training argument ranges for tuning hyperparameters with Ray.
- freeze_layers : int
- Number of layers to freeze from fine-tuning. 0: no layers will be frozen; 2: first two layers will be frozen; etc.
- num_crossval_splits : {0, 1, 5}
- 0: train on all data without splitting. 1: split data into train and eval sets by designated split_sizes["valid"]. 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"].
- split_sizes : None, dict
- Dictionary of proportion of data to hold out for train, validation, and test sets. For example: {"train": 0.8, "valid": 0.1, "test": 0.1} if intending an 80/10/10 train/valid/test split.
- stratify_splits_col : None, str
- Name of column in .dataset to be used for stratified splitting. Proportion of each class in this column will be the same in the splits as in the original dataset.
- no_eval : bool
- If True, will skip eval step and use all data for training. Otherwise, will perform eval during training.
- forward_batch_size : int
- Batch size for forward pass (for evaluation, not training).
- model_version : str
- To auto-select settings for a model version other than the current default. Current options: V1: models pretrained on ~30M cells; V2: models pretrained on ~104M cells.
- token_dictionary_file : None, str
- Default is to use the token dictionary file from Geneformer. Otherwise, will load a custom gene token dictionary.
- nproc : int
- Number of CPU processes to use.
- ngpu : int
- Number of GPUs available.
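To make the rare_threshold description concrete, the following sketch drops classes that fall below a fraction of the total cells. It is plain Python over a list of labels, not Geneformer's implementation; the `drop_rare_states` helper and the toy labels are hypothetical.

```python
from collections import Counter


def drop_rare_states(labels, rare_threshold):
    """Hypothetical helper: keep states covering >= rare_threshold of all cells."""
    counts = Counter(labels)
    total = len(labels)
    return sorted(s for s, n in counts.items() if n / total >= rare_threshold)


# Toy labels: 100 cells, with "dcm" making up only 3% of them.
labels = ["nf"] * 60 + ["hcm"] * 37 + ["dcm"] * 3
kept_states = drop_rare_states(labels, rare_threshold=0.05)
all_states = drop_rare_states(labels, rare_threshold=0)
```

With a threshold of 0.05 the 3% class is removed from the possible classes, while the default of 0 keeps every state.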
- evaluate_model(model, num_classes, id_class_dict, eval_data, predict=False, output_directory=None, output_prefix=None)
Evaluate the fine-tuned model.
Parameters
- model : nn.Module
- Loaded fine-tuned model (e.g. trainer.model)
- num_classes : int
- Number of classes for classifier
- id_class_dict : dict
- Loaded _id_class_dict.pkl previously prepared by Classifier.prepare_data (dictionary of format: numerical IDs: class_labels)
- eval_data : Dataset
- Loaded evaluation .dataset input
- predict : bool
- Whether or not to save eval predictions
- output_directory : Path
- Path to directory where eval data will be saved
- output_prefix : str
- Prefix for output files
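The id_class_dict maps numerical class IDs to class labels. A minimal sketch of round-tripping such a dictionary through a pickle file and decoding predicted IDs follows. The file name, labels, and predicted IDs are made up for illustration.

```python
import pickle
import tempfile
from pathlib import Path

# Toy mapping in the documented format: numerical IDs -> class labels.
id_class_dict = {0: "healthy", 1: "disease1", 2: "disease2"}

with tempfile.TemporaryDirectory() as tmp_dir:
    pkl_path = Path(tmp_dir) / "example_id_class_dict.pkl"  # hypothetical name
    with open(pkl_path, "wb") as f:
        pickle.dump(id_class_dict, f)
    with open(pkl_path, "rb") as f:
        loaded = pickle.load(f)

predicted_ids = [2, 0, 0, 1]  # toy predictions
predicted_labels = [loaded[i] for i in predicted_ids]
num_classes = len(loaded)
```

The length of the loaded dictionary is one natural source for the num_classes argument.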
- evaluate_saved_model(model_directory, id_class_dict_file, test_data_file, output_directory, output_prefix, predict=True)
Evaluate the fine-tuned model.
Parameters
- model_directory : Path
- Path to directory containing model
- id_class_dict_file : Path
- Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data (dictionary of format: numerical IDs: class_labels)
- test_data_file : Path
- Path to directory containing test .dataset
- output_directory : Path
- Path to directory where eval data will be saved
- output_prefix : str
- Prefix for output files
- predict : bool
- Whether or not to save eval predictions
- hyperopt_classifier(model_directory, num_classes, train_data, eval_data, output_directory, n_trials=100)
Fine-tune model for cell state or gene classification, running hyperparameter optimization for n_trials trials.
Parameters
- model_directory : Path
- Path to directory containing model
- num_classes : int
- Number of classes for classifier
- train_data : Dataset
- Loaded training .dataset input. For cell classifier, labels in column "label". For gene classifier, labels in column "labels".
- eval_data : None, Dataset
- (Optional) Loaded evaluation .dataset input. For cell classifier, labels in column "label". For gene classifier, labels in column "labels".
- output_directory : Path
- Path to directory where fine-tuned model will be saved
- n_trials : int
- Number of trials to run for hyperparameter optimization
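For intuition about what a fixed budget of n_trials means, here is a plain random-search sketch. It does not use Ray and is not the search strategy Geneformer applies. The search space, the `toy_objective` function, and the `random_search` helper are all hypothetical stand-ins for fine-tuning and evaluating a model.

```python
import random

# Hypothetical ranges, for illustration only.
search_space = {
    "learning_rate": (1e-5, 1e-3),
    "weight_decay": (0.0, 0.3),
}


def toy_objective(config):
    """Stand-in for 'fine-tune, then return eval loss' (lower is better)."""
    return abs(config["learning_rate"] - 1e-4) + 0.01 * config["weight_decay"]


def random_search(space, objective, n_trials, seed=0):
    """Sample n_trials configs uniformly and return the best plus all trials."""
    rng = random.Random(seed)
    trials = []
    for _ in range(n_trials):
        config = {name: rng.uniform(lo, hi) for name, (lo, hi) in space.items()}
        trials.append((objective(config), config))
    # The best trial is the one with the lowest objective value.
    return min(trials, key=lambda t: t[0]), trials


(best_score, best_config), all_trials = random_search(
    search_space, toy_objective, n_trials=20
)
```

Each trial here costs one call to the objective, which in real use would be a full fine-tuning run, so n_trials directly sets the compute budget.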
- plot_conf_mat(conf_mat_dict, output_directory, output_prefix, custom_class_order=None)
Plot confusion matrix results of evaluating the fine-tuned model.
Parameters
- conf_mat_dict : dict
- Dictionary of model_name : confusion_matrix_DataFrame (all_metrics["conf_matrix"] from self.validate)
- output_directory : Path
- Path to directory where plots will be saved
- output_prefix : str
- Prefix for output file
- custom_class_order : None, list
- List of classes in custom order for plots. Same order will be used for all models.
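The effect of custom_class_order can be pictured as reordering the rows and columns of a confusion matrix. The sketch below uses nested dictionaries instead of the DataFrame that plot_conf_mat expects, so it only illustrates the reordering. The `confusion_counts` and `reorder` helpers and the toy labels are hypothetical.

```python
from collections import Counter


def confusion_counts(true_labels, pred_labels, classes):
    """Count (true, predicted) pairs into a nested dict: matrix[true][pred]."""
    pairs = Counter(zip(true_labels, pred_labels))
    return {t: {p: pairs[(t, p)] for p in classes} for t in classes}


def reorder(matrix, class_order):
    """Return the matrix as lists, with rows and columns in class_order."""
    return [[matrix[t][p] for p in class_order] for t in class_order]


true_labels = ["healthy", "healthy", "disease1", "disease2", "disease2"]
pred_labels = ["healthy", "disease1", "disease1", "disease2", "healthy"]
classes = sorted(set(true_labels))  # alphabetical: disease1, disease2, healthy

matrix = confusion_counts(true_labels, pred_labels, classes)
rows = reorder(matrix, ["healthy", "disease1", "disease2"])
```

Putting the reference class first, as in the usage example, makes plots from different models directly comparable.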
- plot_predictions(predictions_file, id_class_dict_file, title, output_directory, output_prefix, custom_class_order=None, kwargs_dict=None)
Plot prediction results of evaluating the fine-tuned model.
Parameters
- predictions_file : path
- Path of model predictions output to plot (saved output from self.validate if predict_eval=True, or saved output from self.evaluate_saved_model)
- id_class_dict_file : Path
- Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data (dictionary of format: numerical IDs: class_labels)
- title : str
- Title for legend containing class labels.
- output_directory : Path
- Path to directory where plots will be saved
- output_prefix : str
- Prefix for output file
- custom_class_order : None, list
- List of classes in custom order for plots. Same order will be used for all models.
- kwargs_dict : None, dict
- Dictionary of kwargs to pass to plotting function.
- plot_roc(roc_metric_dict, model_style_dict, title, output_directory, output_prefix)
Plot ROC curve results of evaluating the fine-tuned model.
Parameters
- roc_metric_dict : dict
- Dictionary of model_name : roc_metrics (all_metrics["all_roc_metrics"] from self.validate)
- model_style_dict : dict[dict]
- Dictionary of model_name : dictionary of style_attribute : style, where style includes color and linestyle, e.g. {'Model_A': {'color': 'black', 'linestyle': '-'}, 'Model_B': …}
- title : str
- Title of plot (e.g. 'Dosage-sensitive vs -insensitive factors')
- output_directory : Path
- Path to directory where plots will be saved
- output_prefix : str
- Prefix for output file
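A sketch of preparing the two plot_roc dictionaries side by side, checking that every model to be plotted has a style entry. The `check_styles` helper is hypothetical, and the roc_metric_dict values are placeholders (None) standing in for all_metrics["all_roc_metrics"] from validate.

```python
def check_styles(roc_metric_dict, model_style_dict):
    """Hypothetical helper: every model needs a color and a linestyle."""
    for model_name in roc_metric_dict:
        style = model_style_dict.get(model_name)
        if style is None:
            raise KeyError(f"no style entry for model {model_name!r}")
        missing = {"color", "linestyle"} - set(style)
        if missing:
            raise KeyError(f"{model_name!r} is missing {sorted(missing)}")
    return True


# Placeholder metrics; real values come from validate.
roc_metric_dict = {"Model_A": None, "Model_B": None}
model_style_dict = {
    "Model_A": {"color": "black", "linestyle": "-"},
    "Model_B": {"color": "red", "linestyle": "--"},
}
styles_ok = check_styles(roc_metric_dict, model_style_dict)
```

Keying both dictionaries by the same model names is what ties each ROC curve to its style.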
- prepare_data(input_data_file, output_directory, output_prefix, split_id_dict=None, test_size=None, attr_to_split=None, attr_to_balance=None, max_trials=100, pval_threshold=0.1)
Prepare data for cell state or gene classification.
Parameters
- input_data_file : Path
- Path to directory containing .dataset input
- output_directory : Path
- Path to directory where prepared data will be saved
- output_prefix : str
- Prefix for output file
- split_id_dict : None, dict
- Dictionary of IDs for train and test splits. Three-item dictionary with keys attr_key, train, and test. attr_key: name of the column in .dataset that contains the IDs for the data splits. train: list of IDs in the attr_key column to include in the train split. test: list of IDs in the attr_key column to include in the test split. For example: {"attr_key": "individual", "train": ["patient1", "patient2", "patient3", "patient4"], "test": ["patient5", "patient6"]}
- test_size : None, float
- Proportion of data to be saved separately and held out for the test set (e.g. 0.2 if intending to hold out 20%). If None, will inherit from split_sizes["test"] from Classifier. The training set will be further split into train / validation in self.validate. Note: only available for CellClassifiers.
- attr_to_split : None, str
- Key for attribute on which to split data while balancing potential confounders, e.g. "patient_id" for splitting by patient while balancing other characteristics. Note: only available for CellClassifiers.
- attr_to_balance : None, list
- List of attribute keys on which to balance data while splitting on attr_to_split, e.g. ["age", "sex"] for balancing these characteristics while splitting by patient. Note: only available for CellClassifiers.
- max_trials : None, int
- Maximum number of trials of random splitting to try to achieve balanced other attributes. If no split is found without significant (p < pval_threshold) differences in other attributes, will select the best. Note: only available for CellClassifiers.
- pval_threshold : None, float
- P-value threshold to use for attribute balancing across splits. E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance.
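How a split_id_dict partitions the data can be sketched with plain dictionaries standing in for rows of the .dataset. This illustrates the format only and is not Geneformer's code; `split_by_ids` and the toy rows are hypothetical.

```python
def split_by_ids(rows, split_id_dict):
    """Partition rows by the IDs listed under 'train' and 'test'."""
    key = split_id_dict["attr_key"]
    train_ids = set(split_id_dict["train"])
    test_ids = set(split_id_dict["test"])
    overlap = train_ids & test_ids
    if overlap:
        raise ValueError(f"IDs in both splits: {sorted(overlap)}")
    train = [r for r in rows if r[key] in train_ids]
    test = [r for r in rows if r[key] in test_ids]
    return train, test


# Toy rows standing in for cells with an "individual" column.
rows = [
    {"individual": "patient1", "disease": "nf"},
    {"individual": "patient2", "disease": "hcm"},
    {"individual": "patient5", "disease": "dcm"},
    {"individual": "patient5", "disease": "dcm"},
]
split_id_dict = {
    "attr_key": "individual",
    "train": ["patient1", "patient2", "patient3", "patient4"],
    "test": ["patient5", "patient6"],
}
train_rows, test_rows = split_by_ids(rows, split_id_dict)
```

Splitting on an ID column such as the individual keeps all cells from one patient on the same side of the split.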
- train_all_data(model_directory, prepared_input_data_file, id_class_dict_file, output_directory, output_prefix, save_eval_output=True, gene_balance=False)
Train cell state or gene classifier using all data.
Parameters
- model_directory : Path
- Path to directory containing model
- prepared_input_data_file : Path
- Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
- id_class_dict_file : Path
- Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data (dictionary of format: numerical IDs: class_labels)
- output_directory : Path
- Path to directory where model and eval data will be saved
- output_prefix : str
- Prefix for output files
- save_eval_output : bool
- Whether to save cross-fold eval output. Saves as pickle file of dictionary of eval metrics.
- gene_balance : None, bool
- Whether to automatically balance genes in training set. Only available for binary gene classifications.
Output
Returns trainer after fine-tuning with all data.
- train_classifier(model_directory, num_classes, train_data, eval_data, output_directory, predict=False)
Fine-tune model for cell state or gene classification.
Parameters
- model_directory : Path
- Path to directory containing model
- num_classes : int
- Number of classes for classifier
- train_data : Dataset
- Loaded training .dataset input. For cell classifier, labels in column "label". For gene classifier, labels in column "labels".
- eval_data : None, Dataset
- (Optional) Loaded evaluation .dataset input. For cell classifier, labels in column "label". For gene classifier, labels in column "labels".
- output_directory : Path
- Path to directory where fine-tuned model will be saved
- predict : bool
- Whether or not to save eval predictions from trainer
- validate(model_directory, prepared_input_data_file, id_class_dict_file, output_directory, output_prefix, split_id_dict=None, attr_to_split=None, attr_to_balance=None, gene_balance=False, max_trials=100, pval_threshold=0.1, save_eval_output=True, predict_eval=True, predict_trainer=False, n_hyperopt_trials=0, save_gene_split_datasets=True, debug_gene_split_datasets=False)
(Cross-)validate cell state or gene classifier.
Parameters
- model_directory : Path
- Path to directory containing model
- prepared_input_data_file : Path
- Path to directory containing _labeled.dataset previously prepared by Classifier.prepare_data
- id_class_dict_file : Path
- Path to _id_class_dict.pkl previously prepared by Classifier.prepare_data (dictionary of format: numerical IDs: class_labels)
- output_directory : Path
- Path to directory where model and eval data will be saved
- output_prefix : str
- Prefix for output files
- split_id_dict : None, dict
- Dictionary of IDs for train and eval splits. Three-item dictionary with keys attr_key, train, and eval. attr_key: name of the column in .dataset that contains the IDs for the data splits. train: list of IDs in the attr_key column to include in the train split. eval: list of IDs in the attr_key column to include in the eval split. For example: {"attr_key": "individual", "train": ["patient1", "patient2", "patient3", "patient4"], "eval": ["patient5", "patient6"]}. Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1).
- attr_to_split : None, str
- Key for attribute on which to split data while balancing potential confounders, e.g. "patient_id" for splitting by patient while balancing other characteristics. Note: only available for CellClassifiers with 1-fold split (self.classifier="cell"; self.num_crossval_splits=1).
- attr_to_balance : None, list
- List of attribute keys on which to balance data while splitting on attr_to_split, e.g. ["age", "sex"] for balancing these characteristics while splitting by patient.
- gene_balance : None, bool
- Whether to automatically balance genes in training set. Only available for binary gene classifications.
- max_trials : None, int
- Maximum number of trials of random splitting to try to achieve balanced other attributes. If no split is found without significant (p < pval_threshold) differences in other attributes, will select the best.
- pval_threshold : None, float
- P-value threshold to use for attribute balancing across splits. E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance.
- save_eval_output : bool
- Whether to save cross-fold eval output. Saves as pickle file of dictionary of eval metrics.
- predict_eval : bool
- Whether or not to save eval predictions. Saves as a pickle file of self.evaluate predictions.
- predict_trainer : bool
- Whether or not to save eval predictions from trainer. Saves as a pickle file of trainer predictions.
- n_hyperopt_trials : int
- Number of trials to run for hyperparameter optimization. If 0, will not optimize hyperparameters.
- save_gene_split_datasets : bool
- Whether or not to save train, valid, and test gene-labeled datasets
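The interplay of attr_to_split, attr_to_balance, and max_trials can be pictured as repeated random splitting with an acceptance test. The sketch below is not Geneformer's implementation: it replaces the statistical test and pval_threshold with a simple difference in group means and a hypothetical `tolerance`, and the patient ages are toy values.

```python
import random
from statistics import mean


def balanced_group_split(groups, eval_fraction, max_trials, tolerance, seed=0):
    """groups: {group_id: attribute_value}. Returns (train_ids, eval_ids, gap)."""
    rng = random.Random(seed)
    ids = sorted(groups)
    n_eval = max(1, round(len(ids) * eval_fraction))
    best = None
    for _ in range(max_trials):
        shuffled = rng.sample(ids, len(ids))
        eval_ids, train_ids = shuffled[:n_eval], shuffled[n_eval:]
        # Stand-in imbalance score: gap between the mean attribute of each split.
        gap = abs(
            mean(groups[i] for i in train_ids) - mean(groups[i] for i in eval_ids)
        )
        if best is None or gap < best[2]:
            best = (sorted(train_ids), sorted(eval_ids), gap)
        if gap <= tolerance:
            break  # accept the first sufficiently balanced trial
    return best  # otherwise, the best trial seen within max_trials


# Toy attribute to balance: one age per patient.
ages = {"patient1": 34, "patient2": 61, "patient3": 47,
        "patient4": 55, "patient5": 40, "patient6": 58}
train_ids, eval_ids, gap = balanced_group_split(
    ages, eval_fraction=1 / 3, max_trials=100, tolerance=2.0
)
```

The loop mirrors the documented behavior in shape only: stop at the first trial that passes the balance check, and fall back to the best trial if none passes within max_trials.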