geneformer.emb_extractor
Geneformer embedding extractor.
Description:
Extracts gene or cell embeddings.
Plots cell embeddings as heatmaps or UMAPs.
Generates cell state embedding dictionary for use with InSilicoPerturber.
- class EmbExtractor(model_type='Pretrained', num_classes=0, emb_mode='cls', cell_emb_style='mean_pool', gene_emb_style='mean_pool', filter_data=None, max_ncells=1000, emb_layer=-1, emb_label=None, labels_to_plot=None, forward_batch_size=100, nproc=4, summary_stat=None, model_version='V2', token_dictionary_file=None)
Initialize embedding extractor.
Parameters:
- model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "Pretrained-Quantized"}
- Whether model is the pretrained Geneformer (full or quantized) or a fine-tuned gene or cell classifier.
- num_classes : int
- If model is a gene or cell classifier, specify the number of classes it was trained to classify. For the pretrained Geneformer model, the number of classes is 0 as it is not a classifier.
- emb_mode : {"cls", "cell", "gene"}
- Whether to output CLS, cell, or gene embeddings. CLS embeddings are cell embeddings derived from the CLS token at the front of the rank value encoding.
- cell_emb_style : {"mean_pool"}
- Method for summarizing cell embeddings if not using the CLS token. Currently the only option is mean pooling of gene embeddings for a given cell.
- gene_emb_style : {"mean_pool"}
- Method for summarizing gene embeddings. Currently the only option is mean pooling of contextual gene embeddings for a given gene.
- filter_data : None, dict
- Default is to extract embeddings from all input data. Otherwise, dictionary specifying .dataset column name and list of values to filter by.
- max_ncells : None, int
- Maximum number of cells to extract embeddings from. Default is 1000 cells randomly sampled from input data. If None, will extract embeddings from all cells.
- emb_layer : {-1, 0}
- Embedding layer to extract. The last layer is most specifically weighted to optimize the given learning objective. Generally, it is best to extract the 2nd to last layer to get a more general representation. -1: 2nd to last layer; 0: last layer.
- emb_label : None, list
- List of column name(s) in .dataset to add as labels to embedding output.
- labels_to_plot : None, list
- Cell labels to plot. Shown as color bar in heatmap. Shown as cell color in UMAP. Plotting a UMAP requires labels to plot.
- forward_batch_size : int
- Batch size for forward pass.
- nproc : int
- Number of CPU processes to use.
- summary_stat : {None, "mean", "median", "exact_mean", "exact_median"}
- If "exact_mean" or "exact_median", outputs only the exact mean or median embedding of input data. If "mean" or "median", outputs only the approximated mean or median embedding of input data. Non-exact is recommended if encountering memory constraints while generating goal embedding positions; it is slower but more memory-efficient.
- model_version : str
- To auto-select settings for a model version other than the current default. Current options: V1: models pretrained on ~30M cells; V2: models pretrained on ~104M cells.
- token_dictionary_file : Path
- Path to pickle file containing token dictionary (Ensembl ID:token). Default is the Geneformer token dictionary.
Examples:
>>> from geneformer import EmbExtractor
>>> embex = EmbExtractor(model_type="CellClassifier",
...                      num_classes=3,
...                      emb_mode="cell",
...                      filter_data={"cell_type":["cardiomyocyte"]},
...                      max_ncells=1000,
...                      emb_layer=-1,
...                      emb_label=["disease", "cell_type"],
...                      labels_to_plot=["disease", "cell_type"])
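The emb_layer and cell_emb_style options can be illustrated with a small, library-independent sketch. It assumes the model exposes a list of per-layer outputs (one entry per layer, last entry being the last layer) and that padding positions are excluded from the mean. The helper names select_layer and mean_pool, and the n_genes argument, are hypothetical and are not part of the Geneformer API.

```python
def select_layer(hidden_states, emb_layer=-1):
    """Pick a layer using the documented convention: 0 = last, -1 = 2nd to last."""
    if emb_layer not in (-1, 0):
        raise ValueError("emb_layer must be -1 or 0")
    return hidden_states[len(hidden_states) - 1 + emb_layer]


def mean_pool(gene_embs, n_genes):
    """Average the first n_genes gene vectors (assumption: the rest is padding)."""
    if n_genes <= 0:
        raise ValueError("n_genes must be positive")
    dim = len(gene_embs[0])
    return [sum(vec[d] for vec in gene_embs[:n_genes]) / n_genes for d in range(dim)]


# Toy data: 3 layers, one cell with 2 real genes plus 1 padding position, dim 2.
hidden_states = [
    [[0.0, 0.0], [0.0, 0.0], [9.0, 9.0]],
    [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],
    [[5.0, 6.0], [7.0, 8.0], [9.0, 9.0]],
]

# Cell embedding from the 2nd to last layer, ignoring the padding position.
cell_emb = mean_pool(select_layer(hidden_states, emb_layer=-1), n_genes=2)
```

With emb_layer=0 the same call would pool the last list entry instead, which is the layer described above as most specific to the learning objective.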
- extract_embs(model_directory, input_data_file, output_directory, output_prefix, output_torch_embs=False, cell_state=None)
Extract embeddings from input data and save as results in output_directory.
Parameters:
- model_directory : Path
- Path to directory containing model
- input_data_file : Path
- Path to directory containing .dataset inputs
- output_directory : Path
- Path to directory where embedding data will be saved as csv
- output_prefix : str
- Prefix for output file
- output_torch_embs : bool
- Whether or not to also output the embeddings as a tensor. Note: if true, will output embeddings as both dataframe and tensor.
- cell_state : dict
- Cell state key and value for state embedding extraction.
Examples:
>>> embs = embex.extract_embs("path/to/model",
...                           "path/to/input_data",
...                           "path/to/output_directory",
...                           "output_prefix")
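Which cells reach the forward pass depends on the filter_data and max_ncells settings of the extractor. The sketch below mimics that selection on a plain list of dictionaries, one per cell. It is an illustration under assumptions, not the library's implementation: select_cells and its seed argument are hypothetical, and the real input is a .dataset rather than a Python list.

```python
import random


def select_cells(cells, filter_data=None, max_ncells=1000, seed=0):
    """Filter cells by column values, then randomly downsample to max_ncells."""
    if filter_data:
        # Keep a cell only if every filtered column holds one of the allowed values.
        cells = [
            cell for cell in cells
            if all(cell[col] in allowed for col, allowed in filter_data.items())
        ]
    if max_ncells is not None and len(cells) > max_ncells:
        # Random sample without replacement; seed is only here for reproducibility.
        cells = random.Random(seed).sample(cells, max_ncells)
    return cells


cells = [
    {"cell_type": "cardiomyocyte", "disease": "dcm"},
    {"cell_type": "fibroblast", "disease": "nf"},
    {"cell_type": "cardiomyocyte", "disease": "nf"},
]

# Same filter as in the EmbExtractor example: keep cardiomyocytes only.
selected = select_cells(cells, filter_data={"cell_type": ["cardiomyocyte"]})

# max_ncells=None keeps every cell that passes the filter.
all_cells = select_cells(cells, filter_data=None, max_ncells=None)
```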
- get_state_embs(cell_states_to_model, model_directory, input_data_file, output_directory, output_prefix, output_torch_embs=True)
Extract exact mean or exact median cell state embedding positions from input data and save as results in output_directory.
Parameters:
- cell_states_to_model : None, dict
- Cell states to model if testing perturbations that achieve goal state change. Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states.
  - state_key: key specifying name of column in .dataset that defines the start/goal states
  - start_state: value in the state_key column that specifies the start state
  - goal_state: value in the state_key column that specifies the goal end state
  - alt_states: list of values in the state_key column that specify the alternate end states
  For example: {"state_key": "disease", "start_state": "dcm", "goal_state": "nf", "alt_states": ["hcm", "other1", "other2"]}
- model_directory : Path
- Path to directory containing model
- input_data_file : Path
- Path to directory containing .dataset inputs
- output_directory : Path
- Path to directory where embedding data will be saved as csv
- output_prefix : str
- Prefix for output file
- output_torch_embs : bool
- Whether or not to also output the embeddings as a tensor. Note: if true, will output embeddings as both dataframe and tensor.
Outputs:
Outputs state_embs_dict for use with InSilicoPerturber. Format is a dictionary of embedding positions of each cell state to model shifts from/towards. Keys specify each possible cell state to model. Values are target embedding positions as torch.tensor.
For example: {"nf": emb_nf, "hcm": emb_hcm, "dcm": emb_dcm, "other1": emb_other1, "other2": emb_other2}
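The relationship between cell_states_to_model and the resulting dictionary can be sketched without the library. The example below checks the four required keys and computes an exact mean embedding per state from toy per-cell embeddings. It is a hedged illustration: states_to_model and mean_state_embs are hypothetical helpers, the "emb" field is an assumed stand-in for extracted cell embeddings, and plain lists replace the torch.tensor values of the real output.

```python
from statistics import fmean

REQUIRED_KEYS = {"state_key", "start_state", "goal_state", "alt_states"}


def states_to_model(cell_states_to_model):
    """Return every state named in the dictionary: start, goal, then alternates."""
    missing = REQUIRED_KEYS.difference(cell_states_to_model)
    if missing:
        raise ValueError(f"cell_states_to_model is missing keys: {sorted(missing)}")
    return [
        cell_states_to_model["start_state"],
        cell_states_to_model["goal_state"],
        *cell_states_to_model["alt_states"],
    ]


def mean_state_embs(cells, cell_states_to_model):
    """Map each state to the exact mean of the embeddings of its cells."""
    state_key = cell_states_to_model["state_key"]
    state_embs = {}
    for state in states_to_model(cell_states_to_model):
        embs = [cell["emb"] for cell in cells if cell[state_key] == state]
        if not embs:
            raise ValueError(f"no cells found for state {state!r}")
        # zip(*embs) iterates over embedding dimensions across cells.
        state_embs[state] = [fmean(dim_values) for dim_values in zip(*embs)]
    return state_embs


cells = [
    {"disease": "dcm", "emb": [1.0, 2.0]},
    {"disease": "dcm", "emb": [3.0, 4.0]},
    {"disease": "nf", "emb": [0.0, 0.0]},
    {"disease": "hcm", "emb": [2.0, 2.0]},
]
cell_states = {
    "state_key": "disease",
    "start_state": "dcm",
    "goal_state": "nf",
    "alt_states": ["hcm"],
}
state_embs_dict = mean_state_embs(cells, cell_states)
```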
- plot_embs(embs, plot_style, output_directory, output_prefix, max_ncells_to_plot=1000, kwargs_dict=None)
Plot embeddings, coloring by provided labels.
Parameters:
- embs : pandas.core.frame.DataFrame
- Pandas dataframe containing embeddings output from extract_embs
- plot_style : str
- Style of plot: "heatmap" or "umap"
- output_directory : Path
- Path to directory where plots will be saved as pdf
- output_prefix : str
- Prefix for output file
- max_ncells_to_plot : None, int
- Maximum number of cells to plot. Default is 1000 cells randomly sampled from embeddings. If None, will plot embeddings from all cells.
- kwargs_dict : dict
- Dictionary of kwargs to pass to plotting function.
Examples:
>>> embex.plot_embs(embs=embs,
...                 plot_style="heatmap",
...                 output_directory="path/to/output_directory",
...                 output_prefix="output_prefix")