geneformer.emb_extractor

Geneformer embedding extractor.

Description:

Extracts gene or cell embeddings.
Plots cell embeddings as heatmaps or UMAPs.
Generates cell state embedding dictionary for use with InSilicoPerturber.
class EmbExtractor(model_type='Pretrained', num_classes=0, emb_mode='cls', cell_emb_style='mean_pool', gene_emb_style='mean_pool', filter_data=None, max_ncells=1000, emb_layer=-1, emb_label=None, labels_to_plot=None, forward_batch_size=100, nproc=4, summary_stat=None, model_version='V2', token_dictionary_file=None)

Initialize embedding extractor.

Parameters:

model_type : {"Pretrained", "GeneClassifier", "CellClassifier", "Pretrained-Quantized"}
    Whether the model is the pretrained Geneformer (full or quantized) or a fine-tuned gene or cell classifier.
num_classes : int
    If the model is a gene or cell classifier, the number of classes it was trained to classify.
    For the pretrained Geneformer model, the number of classes is 0 because it is not a classifier.
emb_mode : {"cls", "cell", "gene"}
    Whether to output CLS, cell, or gene embeddings.
    CLS embeddings are cell embeddings derived from the CLS token at the front of the rank value encoding.
cell_emb_style : {"mean_pool"}
    Method for summarizing cell embeddings when not using the CLS token.
    Currently the only option is mean pooling of the gene embeddings for a given cell.
gene_emb_style : {"mean_pool"}
    Method for summarizing gene embeddings.
    Currently the only option is mean pooling of the contextual gene embeddings for a given gene.
filter_data : None, dict
    Default is to extract embeddings from all input data.
    Otherwise, a dictionary specifying a .dataset column name and the list of values to filter by.
max_ncells : None, int
    Maximum number of cells to extract embeddings from.
    Default is 1000 cells randomly sampled from the input data.
    If None, embeddings are extracted from all cells.
emb_layer : {-1, 0}
    Embedding layer to extract.
    The last layer is most specifically weighted to optimize the given learning objective.
    Generally, it is best to extract the 2nd to last layer to get a more general representation.
    -1: 2nd to last layer
    0: last layer
emb_label : None, list
    List of column name(s) in .dataset to add as labels to the embedding output.
labels_to_plot : None, list
    Cell labels to plot.
    Shown as a color bar in the heatmap.
    Shown as the cell color in the UMAP.
    Plotting a UMAP requires labels to plot.
forward_batch_size : int
    Batch size for the forward pass.
nproc : int
    Number of CPU processes to use.
summary_stat : {None, "mean", "median", "exact_mean", "exact_median"}
    If exact_mean or exact_median, outputs only the exact mean or median embedding of the input data.
    If mean or median, outputs only the approximated mean or median embedding of the input data.
    The non-exact options are recommended if encountering memory constraints while generating goal embedding positions.
    They are slower but more memory-efficient.
model_version : str
    Auto-selects settings for a model version other than the current default.
    Current options:
    V1: models pretrained on ~30M cells
    V2: models pretrained on ~104M cells
token_dictionary_file : Path
    Path to pickle file containing the token dictionary (Ensembl ID:token).
    Default is the Geneformer token dictionary.
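
As a rough illustration of how filter_data and max_ncells interact, the sketch below filters a toy list of cell records by column values and then randomly downsamples the result. The helper select_cells, its seed argument, and the list-of-dicts input are assumptions made for this example only. They are not part of Geneformer, which operates on a tokenized .dataset.

```python
import random


def select_cells(rows, filter_data=None, max_ncells=1000, seed=0):
    """Filter rows by column values, then randomly downsample.

    Hypothetical helper: `rows` is a list of dicts standing in for a
    tokenized .dataset; this is not Geneformer's own implementation.
    """
    if filter_data is not None:
        # Keep a row only if every listed column holds an allowed value.
        rows = [
            row
            for row in rows
            if all(row[col] in allowed for col, allowed in filter_data.items())
        ]
    if max_ncells is not None and len(rows) > max_ncells:
        # Sample without replacement; the fixed seed is an assumption
        # added here so the example is reproducible.
        rows = random.Random(seed).sample(rows, max_ncells)
    return rows


cells = [
    {"cell_type": "cardiomyocyte", "disease": "dcm"},
    {"cell_type": "fibroblast", "disease": "nf"},
    {"cell_type": "cardiomyocyte", "disease": "nf"},
    {"cell_type": "cardiomyocyte", "disease": "hcm"},
]

# Keep cardiomyocytes only, then cap the selection at two cells.
subset = select_cells(cells, {"cell_type": ["cardiomyocyte"]}, max_ncells=2)
```

With filter_data=None and max_ncells=None the helper returns every row unchanged, mirroring the documented defaults for "all input data" and "all cells".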

Examples:

>>> from geneformer import EmbExtractor
>>> embex = EmbExtractor(model_type="CellClassifier",
...         num_classes=3,
...         emb_mode="cell",
...         filter_data={"cell_type":["cardiomyocyte"]},
...         max_ncells=1000,
...         emb_layer=-1,
...         emb_label=["disease", "cell_type"],
...         labels_to_plot=["disease", "cell_type"])
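
The mean pooling named by cell_emb_style and the exact summary statistics named by summary_stat can be pictured with plain Python lists. This is only a conceptual sketch: the functions mean_pool and summarize are hypothetical, the toy embeddings are made up, and the real extractor works on batched tensors and has to handle details such as padding that are ignored here.

```python
from statistics import mean, median


def mean_pool(gene_embs):
    """Average one cell's gene embeddings, dimension by dimension."""
    return [mean(dim) for dim in zip(*gene_embs)]


def summarize(cell_embs, summary_stat):
    """Exact per-dimension mean or median across cell embeddings."""
    stat = {"exact_mean": mean, "exact_median": median}[summary_stat]
    return [stat(dim) for dim in zip(*cell_embs)]


# Three toy cells with gene embeddings in two dimensions.
cell_a = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
cell_b = [[0.0, 0.0], [2.0, 2.0], [4.0, 4.0]]
cell_c = [[10.0, 0.0]]

# One pooled embedding per cell.
cell_embs = [mean_pool(cell_a), mean_pool(cell_b), mean_pool(cell_c)]

# One summary embedding for the whole input.
exact_mean_emb = summarize(cell_embs, "exact_mean")
exact_median_emb = summarize(cell_embs, "exact_median")
```

In this toy input the outlying cell_c pulls the mean further than the median, which is the usual reason to prefer one statistic over the other.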
extract_embs(model_directory, input_data_file, output_directory, output_prefix, output_torch_embs=False, cell_state=None)

Extract embeddings from input data and save as results in output_directory.

Parameters:

model_directory : Path
    Path to directory containing model.
input_data_file : Path
    Path to directory containing .dataset inputs.
output_directory : Path
    Path to directory where embedding data will be saved as csv.
output_prefix : str
    Prefix for output file.
output_torch_embs : bool
    Whether to also output the embeddings as a tensor.
    If True, embeddings are output as both a dataframe and a tensor.
cell_state : dict
    Cell state key and value for state embedding extraction.

Examples:

>>> embs = embex.extract_embs("path/to/model",
...                           "path/to/input_data",
...                           "path/to/output_directory",
...                           "output_prefix")
get_state_embs(cell_states_to_model, model_directory, input_data_file, output_directory, output_prefix, output_torch_embs=True)

Extract exact mean or exact median cell state embedding positions from input data and save as results in output_directory.

Parameters:

cell_states_to_model : None, dict
    Cell states to model if testing perturbations that achieve a goal state change.
    Four-item dictionary with keys: state_key, start_state, goal_state, and alt_states
    state_key: key specifying the name of the column in .dataset that defines the start/goal states
    start_state: value in the state_key column that specifies the start state
    goal_state: value in the state_key column that specifies the goal end state
    alt_states: list of values in the state_key column that specify the alternate end states
    For example:
    {"state_key": "disease",
     "start_state": "dcm",
     "goal_state": "nf",
     "alt_states": ["hcm", "other1", "other2"]}
model_directory : Path
    Path to directory containing model.
input_data_file : Path
    Path to directory containing .dataset inputs.
output_directory : Path
    Path to directory where embedding data will be saved as csv.
output_prefix : str
    Prefix for output file.
output_torch_embs : bool
    Whether to also output the embeddings as a tensor.
    If True, embeddings are output as both a dataframe and a tensor.

Outputs:

Outputs state_embs_dict for use with the in silico perturber.
The format is a dictionary of embedding positions for each cell state to model shifts from/towards.
Keys specify each possible cell state to model.
Values are target embedding positions as torch.tensor.
For example:
{"nf": emb_nf,
 "hcm": emb_hcm,
 "dcm": emb_dcm,
 "other1": emb_other1,
 "other2": emb_other2}
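
To make the relationship between cell_states_to_model and the keys of state_embs_dict concrete, the following sketch checks the four required keys and lists the states that would each need an embedding position. The function states_to_embed is a hypothetical helper written for this example and is not part of the Geneformer API. The placeholder dictionary at the end only mirrors the shape of the output and holds no real tensors.

```python
REQUIRED_KEYS = {"state_key", "start_state", "goal_state", "alt_states"}


def states_to_embed(cell_states_to_model):
    """Return every cell state named in the dictionary, start state first."""
    missing = REQUIRED_KEYS - cell_states_to_model.keys()
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    states = [
        cell_states_to_model["start_state"],
        cell_states_to_model["goal_state"],
        *cell_states_to_model["alt_states"],
    ]
    if len(set(states)) != len(states):
        raise ValueError("each cell state should be listed only once")
    return states


cell_states_to_model = {
    "state_key": "disease",
    "start_state": "dcm",
    "goal_state": "nf",
    "alt_states": ["hcm", "other1", "other2"],
}

states = states_to_embed(cell_states_to_model)

# Placeholder with the same keys as state_embs_dict; the real values
# would be torch tensors rather than None.
state_embs_placeholder = {state: None for state in states}
```

Every value under start_state, goal_state, and alt_states ends up as one key, which matches the five-entry example output above.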
plot_embs(embs, plot_style, output_directory, output_prefix, max_ncells_to_plot=1000, kwargs_dict=None)

Plot embeddings, coloring by provided labels.

Parameters:

embs : pandas.core.frame.DataFrame
    Pandas dataframe containing embeddings output from extract_embs.
plot_style : str
    Style of plot: "heatmap" or "umap".
output_directory : Path
    Path to directory where plots will be saved as pdf.
output_prefix : str
    Prefix for output file.
max_ncells_to_plot : None, int
    Maximum number of cells to plot.
    Default is 1000 cells randomly sampled from the embeddings.
    If None, embeddings from all cells are plotted.
kwargs_dict : dict
    Dictionary of kwargs to pass to the plotting function.

Examples:

>>> embex.plot_embs(embs=embs,
...                 plot_style="heatmap",
...                 output_directory="path/to/output_directory",
...                 output_prefix="output_prefix")