Tutorial 5: Guided Generation with ESM3

Guided generation is a powerful tool that lets you sample outputs from ESM3 that maximize a score function of your choice.

For example, you may want to

  1. Guide generations towards higher quality metrics like pTM
  2. Constrain the distribution of outputs to have certain amino acid frequencies or structural attributes
  3. Minimize a biophysical energy function
  4. Use experimental screening data to guide designs with a regression model
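To make item 2 concrete, here is a minimal sketch of a composition score that works on a plain sequence string. It returns the negative L1 distance between the observed fractions of selected residues and their target fractions, so a perfect match scores 0 and guidance pushes toward it. The function name and the `target` dictionary are hypothetical; in this notebook you would call it from the `__call__` of a `GuidedDecodingScoringFunction` subclass (introduced below) with `protein.sequence`. We assume that sequence is fully decoded (no `_` mask characters).

```python
from collections import Counter


def composition_score(sequence: str, target: dict[str, float]) -> float:
    """Hypothetical score: negative L1 distance between observed and target
    fractions, computed only over the residues listed in `target`.

    Higher (closer to 0) is better, since guided decoding maximizes the score.
    """
    n = len(sequence)
    if n == 0:
        return float("-inf")  # nothing to score
    counts = Counter(sequence)  # missing residues count as 0
    return -sum(abs(counts[aa] / n - frac) for aa, frac in target.items())


# Ask for roughly 40% alanine and 40% leucine; other residues are unconstrained
print(composition_score("AAAALLLLGG", {"A": 0.4, "L": 0.4}))
print(composition_score("GGGGGGGGGG", {"A": 0.4}))
```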

As long as your scoring function takes a protein as input and returns a single score, you can use it to guide designs. To accomplish this, we use an implementation of derivative-free guidance inspired by Soft Value-Based Decoding, described in Li et al. (2024).
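To build intuition for derivative-free guidance, here is a deliberately simplified, self-contained sketch in the spirit of value-based decoding: at each step, propose several candidate partial decodings, estimate each one's value by scoring a completed version of it, and keep the best. Everything here is a toy assumption: `propose_step` unmasks random positions with uniformly random residues instead of sampling from ESM3, and `lookahead` fills the remaining masks randomly rather than using the model's own prediction. It illustrates the control flow only and is not the `ESM3GuidedDecoding` implementation.

```python
import random
from typing import Callable

ALPHABET = "ACDEFGHIKLMNPQRSTVWY"
MASK = "_"


def propose_step(seq: str, n_unmask: int, rng: random.Random) -> str:
    """Toy stand-in for one model decoding step (NOT ESM3): unmask up to
    n_unmask random positions with uniformly random residues."""
    masked = [i for i, c in enumerate(seq) if c == MASK]
    chars = list(seq)
    for i in rng.sample(masked, min(n_unmask, len(masked))):
        chars[i] = rng.choice(ALPHABET)
    return "".join(chars)


def lookahead(seq: str, rng: random.Random) -> str:
    """Toy value estimate: fill every remaining mask so the candidate can be scored."""
    return "".join(rng.choice(ALPHABET) if c == MASK else c for c in seq)


def guided_decode(
    length: int,
    score_fn: Callable[[str], float],
    num_decoding_steps: int,
    num_samples_per_step: int,
    seed: int = 0,
) -> str:
    rng = random.Random(seed)
    seq = MASK * length
    per_step = -(-length // num_decoding_steps)  # ceiling division
    for _ in range(num_decoding_steps):
        if MASK not in seq:
            break
        # Propose several partial continuations from the current state
        candidates = [propose_step(seq, per_step, rng) for _ in range(num_samples_per_step)]
        # Keep the candidate whose completed lookahead scores highest
        seq = max(candidates, key=lambda c: score_fn(lookahead(c, rng)))
    return seq


# Same objective as the no-cysteine example later in this notebook
protein = guided_decode(64, lambda s: -s.count("C"), num_decoding_steps=8, num_samples_per_step=10)
print(protein, protein.count("C"))
```

Raising `num_samples_per_step` in this toy gives the selector more candidates per step at the cost of more scoring calls, which is presumably the same trade-off behind the `num_samples_per_step` argument used below.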

In this notebook we will walk through a few examples to illustrate how to use guided generation.

  1. Guide towards high pTM for improved generation quality
  2. Generate a protein with no cysteine (C) residues
  3. Maximize protein globularity by minimizing the radius of gyration

Imports

In [ ]:
!pip install git+https://github.com/evolutionaryscale/esm.git
!pip install py3dmol
In [1]:
import biotite.structure as bs
import py3Dmol
from esm.models.esm3 import ESM3
from esm.sdk.api import ESMProtein, GenerationConfig
from esm.sdk.experimental import ESM3GuidedDecoding, GuidedDecodingScoringFunction

Creating a scoring function

To get started with the guided generation API, all you need is a callable class that inherits from GuidedDecodingScoringFunction. Its `__call__` method receives an ESMProtein object and returns a numerical score; higher scores are preferred.

For example, one of the computational metrics we can use to measure the quality of a generated protein structure is the Predicted Template Modelling (pTM) score, so we'll use it to create a PTMScoringFunction.

Fortunately for us, every time we generate a protein using ESM3 (either locally or on Forge) we also get its pTM, so all our class needs to do when it's called is return the `ptm` attribute of its input.

In [2]:
# Create scoring function (e.g. PTM scoring function)
class PTMScoringFunction(GuidedDecodingScoringFunction):
    def __call__(self, protein: ESMProtein) -> float:
        # Minimal example of a scoring function that scores proteins based on their pTM score
        # Given that ESM3 already has a pTM prediction head, we can directly access the pTM score
        assert protein.ptm is not None, "Protein must have pTM scores to be scored"
        return float(protein.ptm)
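Because a scoring function is just a callable from a protein to a float, several objectives can be folded into one score. The sketch below uses a hypothetical `ToyProtein` dataclass in place of `ESMProtein` so it runs without ESM3, plus a hypothetical `WeightedSumScore` helper; the weights are illustrative, not tuned. In the notebook you would put the same `__call__` logic in a `GuidedDecodingScoringFunction` subclass and pass it real `ESMProtein` objects.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence


@dataclass
class ToyProtein:
    """Hypothetical stand-in for ESMProtein with only the fields used here."""
    sequence: Optional[str] = None
    ptm: Optional[float] = None


class WeightedSumScore:
    """Hypothetical helper: combine (weight, scoring callable) pairs into one score."""

    def __init__(self, terms: Sequence[tuple[float, Callable[[ToyProtein], float]]]):
        self.terms = list(terms)

    def __call__(self, protein: ToyProtein) -> float:
        return sum(weight * fn(protein) for weight, fn in self.terms)


def ptm_term(protein: ToyProtein) -> float:
    assert protein.ptm is not None, "Protein must have a pTM score"
    return float(protein.ptm)


def cysteine_term(protein: ToyProtein) -> float:
    assert protein.sequence is not None, "Protein must have a sequence"
    return -float(protein.sequence.count("C"))  # fewer cysteines = higher score


score = WeightedSumScore([(1.0, ptm_term), (0.1, cysteine_term)])
print(score(ToyProtein(sequence="MCKLC", ptm=0.8)))
```

Weights matter because the terms live on different scales: pTM lies between 0 and 1, while a cysteine count can be much larger than 1, so an unweighted sum would let the count dominate.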

Initialize your client

Guided generation works with both local inference using the ESM3 class and remote inference with the Forge client.

In [4]:
# To use the tokenizers and the open model, you'll need to log in to Hugging Face
# ! pip install ipywidgets
from huggingface_hub import notebook_login

notebook_login()
In [5]:
## Locally with ESM3-open
model = ESM3.from_pretrained().to("cuda")

## On Forge with larger ESM3 models
# from getpass import getpass

# from esm.sdk import client

# token = getpass("Token from Forge console: ")
# model = client(model="esm3-open", url="https://forge.evolutionaryscale.ai", token=token)

Guide towards high pTM for improved generation quality

Once your scoring function is defined and your model is initialized, you can create an ESM3GuidedDecoding instance to sample from it.

In [6]:
ptm_guided_decoding = ESM3GuidedDecoding(
    client=model, scoring_function=PTMScoringFunction()
)
In [7]:
# Start from a fully masked protein
PROTEIN_LENGTH = 256
starting_protein = ESMProtein(sequence="_" * PROTEIN_LENGTH)

# Call guided_generate
generated_protein = ptm_guided_decoding.guided_generate(
    protein=starting_protein,
    num_decoding_steps=len(starting_protein) // 8,
    num_samples_per_step=10,
)
Current score: 0.95: 100%|██████████| 32/32 [00:27<00:00,  1.15it/s]
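A quick worked calculation shows what these arguments mean for a 256-residue protein. With `num_decoding_steps = 256 // 8`, decoding takes 32 steps (matching the 32/32 in the progress bar), so on average 8 positions are unmasked per step. Assuming each step scores `num_samples_per_step` candidates, the run evaluates 32 * 10 = 320 candidates in total; that per-step count is our assumption, not something the output above reports.

```python
PROTEIN_LENGTH = 256
num_decoding_steps = PROTEIN_LENGTH // 8  # 32 steps, as in the progress bar
num_samples_per_step = 10

# Average number of positions unmasked per step, if all positions are decoded by the end
positions_per_step = PROTEIN_LENGTH / num_decoding_steps
# Total candidates scored, ASSUMING every step scores num_samples_per_step candidates
candidate_evaluations = num_decoding_steps * num_samples_per_step

print(positions_per_step, candidate_evaluations)  # 8.0 320
```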

Compare against baseline with no guidance

First, we sample a protein without any guidance and then fold it. Without pTM guidance, the sampled protein may lack a well-defined structure.

In [10]:
# Generate a protein WITHOUT guidance
generated_protein_no_guided: ESMProtein = model.generate(
    input=starting_protein,
    config=GenerationConfig(track="sequence", num_steps=len(starting_protein) // 8),
)  # type: ignore

# Fold
generated_protein_no_guided: ESMProtein = model.generate(
    input=generated_protein_no_guided,
    config=GenerationConfig(track="structure", num_steps=1),
)  # type: ignore
In [11]:
# Create a 1x2 grid of viewers (1 row, 2 columns)
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))

# Convert ESMProtein objects to ProteinChain objects
protein_chain1 = generated_protein_no_guided.to_protein_chain()
protein_chain2 = generated_protein.to_protein_chain()

# Add models to respective panels
view.addModel(protein_chain1.to_pdb_string(), "pdb", viewer=(0, 0))
view.addModel(protein_chain2.to_pdb_string(), "pdb", viewer=(0, 1))

# Set styles for each protein
view.setStyle({}, {"cartoon": {"color": "spectrum"}}, viewer=(0, 0))
view.setStyle({}, {"cartoon": {"color": "spectrum"}}, viewer=(0, 1))

# Zoom and center the view
view.zoomTo()
view.show()

Generate a Protein with No Cysteines

Guided generation is not limited to structural metrics; you can also use it to steer sequence generation.

For example, we can create a NoCysteineScoringFunction that penalizes a protein for each cysteine residue it contains.

In [12]:
class NoCysteineScoringFunction(GuidedDecodingScoringFunction):
    def __call__(self, protein: ESMProtein) -> float:
        # Penalize proteins that contain cysteine
        assert protein.sequence is not None, "Protein must have a sequence to be scored"
        # Note that we use a negative score here, to discourage the presence of cysteine
        return -protein.sequence.count("C")
In [13]:
no_cysteine_guided_decoding = ESM3GuidedDecoding(
    client=model, scoring_function=NoCysteineScoringFunction()
)
In [14]:
no_cysteine_protein = no_cysteine_guided_decoding.guided_generate(
    protein=starting_protein,
    num_decoding_steps=len(starting_protein) // 8,
    num_samples_per_step=10,
)
Current score: 0.00: 100%|██████████| 32/32 [00:25<00:00,  1.23it/s]

Let's check our sequence!

If guided generation converged to a score of 0.00, the resulting protein should contain no cysteine residues.

In [15]:
assert no_cysteine_protein.sequence is not None, "Protein must have a sequence"
print(no_cysteine_protein.sequence)
print(f"Number of cysteine residues: {no_cysteine_protein.sequence.count('C')}")
MANKILKNLRTTSKYISSRTTSRLTAYLIGFAEPRGLELLPITPAGRNPNDLLKLLSERIGWVSKRFSIKNVTVGSLVPINTSAVNVYRRTLSKTKTSLQSEVSTRQGTYTIPVNSFAIIEYTNLRKFIEELAGVKVRKVEFLLNEESLIIKIIPYISKDVQELRQLKVDIPKEIIEQFFGKSSIDKISKKFNKNNRIVEEKRKDYSREYYDIRTFPVENNEFKGSAEILSTHPVYVFETKNHQVESGVFLPLEIF
Number of cysteine residues: 0

Maximize Globularity

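We use the radius of gyration as a proxy for globularity: for a given length, a more compact structure has a smaller radius of gyration, so the scoring function returns its negative. We also encourage high pTM by doubling the (negative) score of any structure with pTM below 0.5.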
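For reference, the radius of gyration is the root-mean-square distance of the atoms from their centroid. The NumPy sketch below computes an unweighted (equal-mass) version and compares eight points with unit spacing laid out as a straight line versus a cube: the compact arrangement has the much smaller radius, which is why minimizing it favors globular structures. The helper is our own illustration; the notebook itself uses biotite's `gyration_radius`, which can optionally weight atoms by mass.

```python
import numpy as np


def radius_of_gyration(coords: np.ndarray) -> float:
    """Unweighted radius of gyration: RMS distance of points from their centroid."""
    coords = np.asarray(coords, dtype=float)
    centroid = coords.mean(axis=0)
    squared_dists = np.sum((coords - centroid) ** 2, axis=1)
    return float(np.sqrt(np.mean(squared_dists)))


# Eight points with unit spacing, arranged two ways
line = np.array([[float(i), 0.0, 0.0] for i in range(8)])
cube = np.array([[x, y, z] for x in (0.0, 1.0) for y in (0.0, 1.0) for z in (0.0, 1.0)])

print(radius_of_gyration(line))  # about 2.29 (sqrt(5.25))
print(radius_of_gyration(cube))  # about 0.87 (sqrt(0.75))
```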

In [16]:
class RadiusOfGyrationScoringFunction(GuidedDecodingScoringFunction):
    def __call__(self, protein: ESMProtein) -> float:
        # A smaller radius of gyration means a more compact structure, so negate it
        score = -1 * self.radius_of_gyration(protein)

        assert protein.ptm is not None, "Protein must have pTM scores to be scored"
        if protein.ptm < 0.5:
            # Penalize proteins with low pTM scores (doubling a negative score makes it worse)
            score = score * 2

        return score

    @staticmethod
    def radius_of_gyration(protein: ESMProtein) -> float:
        protein_chain = protein.to_protein_chain()
        arr = protein_chain.atom_array_no_insertions
        return float(bs.gyration_radius(arr))
In [17]:
radius_guided_decoding = ESM3GuidedDecoding(
    client=model, scoring_function=RadiusOfGyrationScoringFunction()
)
In [18]:
radius_guided_protein = radius_guided_decoding.guided_generate(
    protein=starting_protein,
    num_decoding_steps=len(starting_protein) // 8,
    num_samples_per_step=10,
)
Current score: -16.94: 100%|██████████| 32/32 [00:34<00:00,  1.08s/it]
In [19]:
view = py3Dmol.view(width=800, height=400)
view.addModel(radius_guided_protein.to_pdb_string(), "pdb")
view.setStyle({}, {"cartoon": {"color": "spectrum"}})
view.zoomTo()
view.show()