Tutorial 4: Generating with ESM3¶

ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.

ESM3 is a generative masked language model. You can prompt it with partial sequence, structure, and function keywords, and iteratively sample masked positions until all positions are unmasked. This iterative sampling is what the .generate() function does.
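To make this concrete, here is a minimal sketch of the prompt-and-generate pattern used throughout this tutorial (not meant to be run here): positions marked with "_" are masked, and .generate() fills them in. The ESMProtein and GenerationConfig classes and the model object are introduced in the cells below.

from esm.sdk.api import ESMProtein, GenerationConfig

# "_" marks masked sequence positions for ESM3 to fill in
prompt = ESMProtein(sequence="___KVFERCELAR___")
config = GenerationConfig(track="sequence", num_steps=3, temperature=0.7)
# generated = model.generate(prompt, config)  # `model` is loaded in the setup cells below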

The ESM3 architecture is highly scalable due to its transformer backbone and all-to-all reasoning over discrete token sequences. At its largest scale, ESM3 was trained with 1.07e24 FLOPs on 2.78 billion proteins and 771 billion unique tokens, and has 98 billion parameters. Here we use esm3-open-small, which at 1.4B parameters is the smallest and fastest model in the family and was trained specifically to be open-sourced. ESM3-open is available under a non-commercial license.

Imports¶

In [2]:
%set_env TOKENIZERS_PARALLELISM=false
# !pip install esm
import numpy as np
import torch

# !pip install py3Dmol
import py3Dmol
from esm.sdk import client
from esm.sdk.api import ESMProtein, GenerationConfig
from esm.utils.structure.protein_chain import ProteinChain
env: TOKENIZERS_PARALLELISM=false

Set up the client to Forge¶

Grab a token from the Forge console and add it below. Note that your token is like a password for your account, so take care to protect it. For this reason it is recommended to create new tokens frequently and delete old, unused ones. It is also recommended to store the token in an environment variable or read it with a utility like getpass, as shown below, so tokens are not accidentally shared or checked into code repositories.

In [ ]:
# from getpass import getpass

# token = getpass("Token from Forge console: ")
# model = client(model="esm3-open", url="https://forge.evolutionaryscale.ai", token=token)
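# Alternatively (a sketch): read the token from an environment variable so it is never
# pasted into the notebook. The variable name ESM_FORGE_TOKEN below is just an example.
# import os
# token = os.environ["ESM_FORGE_TOKEN"]
# model = client(model="esm3-open", url="https://forge.evolutionaryscale.ai", token=token)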
In [5]:
from esm.models.esm3 import ESM3
from huggingface_hub import notebook_login
notebook_login()
In [6]:
model = ESM3.from_pretrained().to("cuda")

Let's construct a prompt for ESM3, focusing on the task of scaffolding a motif from a natural protein¶

First, we can use the ProteinChain class from the esm SDK to grab a protein structure from the PDB. We'll work with a human renal (kidney) dipeptidase, an enzyme that cleaves dipeptides (two amino acids bound together). Renal dipeptidases are of particular interest because they metabolize certain antibiotics.

In [4]:
pdb_id = "1ITU"  # PDB ID corresponding to Renal Dipeptidase
chain_id = "A"  # Chain ID corresponding to Renal Dipeptidase in the PDB structure
renal_dipep_chain = ProteinChain.from_rcsb(pdb_id, chain_id)
# Alternatively, we could have used ProteinChain.from_pdb() to load a protein structure from a local PDB file

The ProteinChain class is an object that makes it easy to work with protein structures. It contains a sequence attribute that holds the amino acid sequence of the protein.

In [5]:
print(renal_dipep_chain.sequence)
DFFRDEAERIMRDSPVIDGHNDLPWQLLDMFNNRLQDERANLTTLAGTHTNIPKLRAGFVGGQFWSVYTPCDTQNKDAVRRTLEQMDVVHRMCRMYPETFLYVTSSAGIRQAFREGKVASLIGVEGGHSIDSSLGVLRALYQLGMRYLTLTHSCNTPWADNWLVDTGDSEPQSQGLSPFGQRVVKELNRLGVLIDLAHVSVATMKATLQLSRAPVIFSHSSAYSVCASRRNVPDDVLRLVKQTDSLVMVNFYNNYISCTNKANLSQVADHLDHIKEVAGARAVGFGGDFDGVPRVPEGLEDVSKYPDLIAELLRRNWTEAEVKGALADNLLRVFEAVEQASNLTQAPEEEPIPLDQLGGSCRTHYGYSS

ProteinChain also contains an atom37_positions numpy array that contains the atomic coordinates of each of the residues in the protein.

The shape of the array is (n_residues, 37, 3), where n_residues is the number of residues in the protein and 37 is the number of distinct atom types that can occur across all amino acids (e.g. the first three atoms are the N, C-alpha, and C atoms of the protein backbone). The 3 corresponds to the x, y, and z coordinates of each atom. The atom37 representation lets us use a single format to conveniently represent all amino acids -- coordinates are filled in only for the atoms that exist in a given amino acid and are nan otherwise.

In [6]:
print("atom37_positions shape: ", renal_dipep_chain.atom37_positions.shape)
print(renal_dipep_chain.atom37_positions[:3])
atom37_positions shape:  (369, 37, 3)
[[[-40.525  -9.87   -2.643]
  [-39.79   -9.325  -3.825]
  [-38.765 -10.354  -4.294]
  [-39.096  -8.012  -3.45 ]
  [-37.878 -10.748  -3.53 ]
  [-38.41   -7.359  -4.629]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [-39.105  -7.036  -5.617]
  [-37.177  -7.161  -4.562]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]]

 [[-38.877 -10.768  -5.555]
  [-37.975 -11.767  -6.115]
  [-36.508 -11.389  -6.096]
  [-38.365 -12.141  -7.546]
  [-35.674 -12.205  -5.716]
  [-37.411 -13.109  -8.19 ]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [-36.568 -12.698  -9.215]
  [-37.342 -14.432  -7.756]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [-35.67  -13.589  -9.799]
  [-36.447 -15.332  -8.333]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [-35.612 -14.91   -9.356]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]]

 [[-36.191 -10.172  -6.525]
  [-34.798  -9.736  -6.576]
  [-34.127  -9.485  -5.225]
  [-34.629  -8.57   -7.553]
  [-32.912  -9.65   -5.097]
  [-34.691  -8.997  -9.002]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [-33.837  -9.991  -9.482]
  [-35.629  -8.45   -9.87 ]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [-33.912 -10.442 -10.806]
  [-35.714  -8.891 -11.195]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [-34.852  -9.892 -11.662]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]
  [    nan     nan     nan]]]
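Since missing atoms are stored as nan, we can quickly check which atoms are actually resolved for each residue. A small optional sketch using the array loaded above:

coords = renal_dipep_chain.atom37_positions
# An atom is considered present if none of its x/y/z coordinates are nan
present = ~np.isnan(coords).any(axis=-1)  # shape: (n_residues, 37)
print("Atoms present in the first residue:", int(present[0].sum()))
print("Backbone (N, CA, C) resolved for every residue:", bool(present[:, :3].all()))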

We can visualize the protein chain using the py3Dmol library

In [7]:
# First we can create a `py3Dmol` view object
view = py3Dmol.view(width=500, height=500)
# py3Dmol requires the atomic coordinates to be in PDB format, so we convert the `ProteinChain` object to a PDB string
pdb_str = renal_dipep_chain.to_pdb_string()
# Load the PDB string into the `py3Dmol` view object
view.addModel(pdb_str, "pdb")
# Set the style of the protein chain
view.setStyle({"cartoon": {"color": "spectrum"}})
# Zoom in on the protein chain
view.zoomTo()
# Display the protein chain
view.show()

Now, let's try to scaffold a motif from this protein using ESM3 -- we'll prompt the model with the sequence and structure of a helix-coil motif from renal dipeptidase and have the model generate a larger scaffold that includes the motif

In [8]:
motif_inds = np.arange(123, 146)
# `ProteinChain` objects can be indexed like numpy arrays to extract the sequence and atomic coordinates of a subset of residues
motif_sequence = renal_dipep_chain[motif_inds].sequence
motif_atom37_positions = renal_dipep_chain[motif_inds].atom37_positions
print("Motif sequence: ", motif_sequence)
print("Motif atom37_positions shape: ", motif_atom37_positions.shape)
Motif sequence:  VEGGHSIDSSLGVLRALYQLGMR
Motif atom37_positions shape:  (23, 37, 3)

We can also visualize the motif in the original chain using py3Dmol. We'll color the original chain in grey and the motif in blue

In [9]:
view = py3Dmol.view(width=500, height=500)
view.addModel(pdb_str, "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
motif_res_inds = (
    motif_inds + 1
).tolist()  # residue indices are 1-indexed in PDB files, so we add 1 to the indices
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}})
view.zoomTo()
view.show()

Now, we can use the ESMProtein class to construct a prompt that will instruct ESM3 to scaffold the motif

In [10]:
prompt_length = 200
# First, we can construct a sequence prompt of all masks
sequence_prompt = ["_"] * prompt_length
# Then, we can insert the motif sequence into the prompt (we arbitrarily choose to place it at index 72 here)
sequence_prompt[72 : 72 + len(motif_sequence)] = list(motif_sequence)
sequence_prompt = "".join(sequence_prompt)
print("Sequence prompt: ", sequence_prompt)
print("Length of sequence prompt: ", len(sequence_prompt))

# Next, we can construct a structure prompt of all nan coordinates
structure_prompt = torch.full((prompt_length, 37, 3), np.nan)
# Then, we can insert the motif atomic coordinates into the prompt, starting at index 72
structure_prompt[72 : 72 + len(motif_atom37_positions)] = torch.tensor(
    motif_atom37_positions
)
print("Structure prompt shape: ", structure_prompt.shape)
print(
    "Indices with structure conditioning: ",
    torch.where(~torch.isnan(structure_prompt).any(dim=-1).all(dim=-1))[0].tolist(),
)

# Finally, we can use the ESMProtein class to compose the sequence and structure prompts into a single prompt that can be passed to ESM3
protein_prompt = ESMProtein(sequence=sequence_prompt, coordinates=structure_prompt)
Sequence prompt:  ________________________________________________________________________VEGGHSIDSSLGVLRALYQLGMR_________________________________________________________________________________________________________
Length of sequence prompt:  200
Structure prompt shape:  torch.Size([200, 37, 3])
Indices with structure conditioning:  [72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94]

Now, we can use the generate method of the model to iteratively sample a protein sequence based on the prompt. Under the hood, the model performs num_steps forward passes, sampling a set of tokens at each step, until the track being generated is fully unmasked.

In [11]:
# We'll have to first construct a `GenerationConfig` object that specifies the decoding parameters that we want to use
sequence_generation_config = GenerationConfig(
    track="sequence",  # We want ESM3 to generate tokens for the sequence track
    num_steps=sequence_prompt.count("_")
    // 2,  # We'll use num(mask tokens) // 2 steps to decode the sequence
    temperature=0.5,  # We'll use a temperature of 0.5 to control the randomness of the decoding process
)

# Now, we can use the `generate` method of the model to decode the sequence
sequence_generation = model.generate(protein_prompt, sequence_generation_config)
print("Sequence Prompt:\n\t", protein_prompt.sequence)
print("Generated sequence:\n\t", sequence_generation.sequence)
Sequence Prompt:
	 ________________________________________________________________________VEGGHSIDSSLGVLRALYQLGMR_________________________________________________________________________________________________________
Generated sequence:
	 DLAALRAGGVDAQFFAVYVPPEYAGRAVEATLEQIAAVHRLVARHPDRLALARTAADVRAARAAGRIAALIGVEGGHSIDSSLGVLRALYQLGMRYMTLTWNDANDWADGVTEPRGGGLSAFGREVVAEMNRLGMLVDLSHISERTFWDVLALSRAPAIASHSNARALCDHPRNLTDAQLRALAASGGVVMVNFYSAFVA
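As a quick optional sanity check, we can confirm that the motif residues from the prompt are preserved verbatim in the generated sequence, at the position where we placed them:

assert motif_sequence in sequence_generation.sequence
print("Motif found at index:", sequence_generation.sequence.find(motif_sequence))  # should be 72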

We can also use the generate method to predict the structure of the generated sequence by iteratively sampling structure tokens.

In [12]:
structure_prediction_config = GenerationConfig(
    track="structure",  # We want ESM3 to generate tokens for the structure track
    num_steps=len(sequence_generation) // 8,
    temperature=0.7,
)
structure_prediction_prompt = ESMProtein(sequence=sequence_generation.sequence)
structure_prediction = model.generate(
    structure_prediction_prompt, structure_prediction_config
)

Now, we can visualize the generated structure using py3Dmol. We'll visualize the generated structure (right, green) alongside the original structure (left, grey) from which the motif was drawn. The motif residues are colored in cyan.

In [13]:
# Convert the generated structure back into a ProteinChain object
structure_prediction_chain = structure_prediction.to_protein_chain()
# Align the generated structure to the original structure using the motif residues
motif_inds_in_generation = np.arange(72, 72 + len(motif_sequence))
structure_prediction_chain.align(
    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds
)
crmsd = structure_prediction_chain.rmsd(
    renal_dipep_chain, mobile_inds=motif_inds_in_generation, target_inds=motif_inds
)
print(
    "cRMSD of the motif in the generated structure vs the original structure: ", crmsd
)

view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(pdb_str, "pdb", viewer=(0, 0))
view.addModel(structure_prediction_chain.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "lightgrey"}}, viewer=(0, 0))
view.setStyle({"cartoon": {"color": "lightgreen"}}, viewer=(0, 1))
view.addStyle({"resi": motif_res_inds}, {"cartoon": {"color": "cyan"}}, viewer=(0, 0))
view.addStyle(
    {"resi": (motif_inds_in_generation + 1).tolist()},
    {"cartoon": {"color": "cyan"}},
    viewer=(0, 1),
)
view.zoomTo()
view.show()
cRMSD of the motif in the generated structure vs the original structure:  0.4935894152515273

Secondary Structure Editing Example: Helix Shortening¶

Now, we can try another generation task with ESM3. We'll use the secondary structure track, along with the sequence track, to shorten a helix-coil-helix region (residues 39-111) in a protein structure (colored in blue below)

In [14]:
helix_shortening_chain = ProteinChain.from_rcsb("7XBQ", "A")
view = py3Dmol.view(width=500, height=500)
view.addModel(helix_shortening_chain.to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
helix_region = np.arange(38, 111)  # zero-indexed
view.addStyle(
    {"resi": (helix_region + 1).tolist()}, {"cartoon": {"color": "lightblue"}}
)
view.zoomTo()
view.show()
helix_shortening_ss8 = "CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC"
print(
    "Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) \n\t",
    helix_shortening_ss8,
)

Secondary structure of protein: (H: Alpha Helix, E: Beta Strand, C: Coil) 
	 CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC

The helix-coil-helix region in the original protein is 73 residues long. We will try to shorten it to 45 residues by prompting the model with partial sequence and secondary structure

In [15]:
shortened_region_length = 45

# We'll construct a sequence prompt that masks the (shortened) helix-coil-helix region, but leaves the flanking regions unmasked
sequence_prompt = (
    helix_shortening_chain.sequence[: helix_region[0]]
    + "_" * shortened_region_length
    + helix_shortening_chain.sequence[helix_region[-1] + 1 :]
)
print("Sequence prompt:\n\t", sequence_prompt)

# We'll construct a secondary structure prompt that retains the secondary structure of the flanking regions, and shortens the lengths of helices in the helix-coil-helix region
ss8_prompt = (
    helix_shortening_ss8[: helix_region[0]]
    + (
        ((shortened_region_length - 3) // 2) * "H"
        + "C" * 3
        + ((shortened_region_length - 3) // 2) * "H"
    )
    + helix_shortening_ss8[helix_region[-1] + 1 :]
)
print("SS8 prompt:\n\t", ss8_prompt)
print(
    "Proposed SS8 for shortened helix-coil-helix region:\n\t",
    " " * helix_region[0] + ss8_prompt[helix_region[0] : helix_region[0] + 45],
)

print("")
print("Original sequence:\n\t", helix_shortening_chain.sequence)
print("Original SS8:\n\t", helix_shortening_ss8)
print(
    "Original SS8 for helix-coil-helix region:\n\t",
    " " * helix_region[0]
    + helix_shortening_ss8[helix_region[0] : helix_region[-1] + 1],
)


# We can again use the ESMProtein class to compose the sequence and secondary structure prompts into a single prompt that can be passed to ESM3
protein_prompt = ESMProtein(sequence=sequence_prompt, secondary_structure=ss8_prompt)
Sequence prompt:
	 MAREENVYMAKLAEQAERYEEMVQFMEKVSTSLGSEEL_____________________________________________SASNGDSKVFYLKMKGDYHRYLAEFKTGAERKEAAESTLSAYKAAQDIANTELAPTHPIRLGLALNFSVFYYEILNSPDRACNLAKQAFDEAIAELDTLGEESYKDSTLIMQLLRDNLTLWT
SS8 prompt:
	 CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCHHHHHHHHHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHHHHHHHHGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC
Proposed SS8 for shortened helix-coil-helix region:
	                                       HHHHHHHHHHHHHHHHHHHHHCCCHHHHHHHHHHHHHHHHHHHHH

Original sequence:
	 MAREENVYMAKLAEQAERYEEMVQFMEKVSTSLGSEELTVEERNLLSVAYKNVIGARRASWRIISSIEQKEESRGNEEHVKCIKEYRSKIESELSNICDGILKLLDSNLIPSASNGDSKVFYLKMKGDYHRYLAEFKTGAERKEAAESTLSAYKAAQDIANTELAPTHPIRLGLALNFSVFYYEILNSPDRACNLAKQAFDEAIAELDTLGEESYKDSTLIMQLLRDNLTLWT
Original SS8:
	 CCCSHHHHHHHHHHHTTCHHHHHHHHHHHHHTCSSCCCCHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGGGCCSHHHHHHHHHHHHHHHHHHHHHCCHHHHHHHHHHHHHHHHHHHHHHHHHSCTTCHHHHHHHHHHHHHIIIIICCHHHHHHHHHHHHHHHHTTCTTCCSSHHHHHHHHHHHHHHHHHHHC
Original SS8 for helix-coil-helix region:
	                                       CHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHHTTCHHHHHHHHHHHHHHHHHHHHHHHHHHHHIIIIIGG
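Before generating, it's worth a quick optional check that the two prompt tracks line up position-by-position and that the prompt is the expected 73 - 45 = 28 residues shorter than the original chain:

# Both tracks must have the same length so ESMProtein can align them per position
assert len(sequence_prompt) == len(ss8_prompt)
print("Residues removed:", len(helix_shortening_chain.sequence) - len(sequence_prompt))  # expected: 28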

We can again use the generate method of the model to iteratively decode a protein sequence based on the prompt

In [16]:
print("Generating protein sequence...")
sequence_generation = model.generate(
    protein_prompt,
    GenerationConfig(
        track="sequence",
        num_steps=protein_prompt.sequence.count("_") // 2,
        temperature=0.5,
    ),
)
print("Folding protein...")
structure_prediction = model.generate(
    ESMProtein(sequence=sequence_generation.sequence),
    GenerationConfig(
        track="structure", num_steps=len(protein_prompt) // 4, temperature=0
    ),
)
Generating protein sequence...
Folding protein...

Now, we can visualize the generated structure using py3Dmol. We'll visualize the generated structure (right) alongside the original structure (left) from which the motif was drawn. The helix-coil-helix region in the original structure is colored in blue and the shortened region in the generated structure is colored in red.

In [17]:
predicted_chain = structure_prediction.to_protein_chain()
predicted_chain = predicted_chain.align(
    helix_shortening_chain,
    mobile_inds=np.arange(len(predicted_chain) - 120, len(predicted_chain)),
    target_inds=np.arange(
        len(helix_shortening_chain) - 120, len(helix_shortening_chain)
    ),
)
view = py3Dmol.view(width=1000, height=500, viewergrid=(1, 2))
view.addModel(helix_shortening_chain.to_pdb_string(), "pdb", viewer=(0, 0))
view.addModel(predicted_chain.to_pdb_string(), "pdb", viewer=(0, 1))
view.setStyle({"cartoon": {"color": "lightgrey"}})
view.addStyle(
    {"resi": (helix_region + 1).tolist()},
    {"cartoon": {"color": "lightblue"}},
    viewer=(0, 0),
)
view.addStyle(
    {"resi": (np.arange(helix_region[0], helix_region[0] + 45) + 1).tolist()},
    {"cartoon": {"color": "red"}},
    viewer=(0, 1),
)
view.zoomTo()
view.show()

SASA Editing Example: Exposing a buried helix¶

Let's grab 1LBS from the PDB and visualize it using py3Dmol. 1LBS has an alternating alpha-beta sandwich fold, with a buried helix in the center, highlighted in red

In [7]:
lipase_chain = ProteinChain.from_rcsb("1LBS", "A")
span_start = 105
span_end = 116
view = py3Dmol.view(width=500, height=500)
view.addModel(lipase_chain.to_pdb_string(), "pdb")
view.setStyle({"cartoon": {"color": "lightgrey"}})
view.addStyle(
    {"resi": (np.arange(span_start, span_end) + 1).tolist()},
    {"cartoon": {"color": "red"}},
)
view.zoomTo()
view.show()
lipase_ss8 = "CCSSCCCCSSCHHHHHHTEEETTBBTTBCSSEEEEECCTTCCHHHHHTTTHHHHHHHTTCEEEEECCTTTTCSCHHHHHHHHHHHHHHHHHHTTSCCEEEEEETHHHHHHHHHHHHCGGGGGTEEEEEEESCCTTCBGGGHHHHHTTCBCHHHHHTBTTCHHHHHHHHTTTTBCSSCEEEEECTTCSSSCCCCSSSTTSTTCCBTSEEEEHHHHHCTTCCCCSHHHHHBHHHHHHHHHHHHCTTSSCCGGGCCSTTCCCSBCTTSCHHHHHHHHSTHHHHHHHHHHSCCBSSCCCCCGGGGGGSTTCEETTEECCC"

We can construct a multimodal prompt for ESM3 to instruct it to expose the buried helix as follows:

  1. Prompt with the structure of the buried helix highlighted in red -- this will prompt ESM3 to generate a protein that contains that same helix
  2. Prompt with high SASA values for the residues in the buried helix -- this will prompt ESM3 to expose the helix to the surface of the protein
In [8]:
structure_prompt = torch.full((len(lipase_chain), 37, 3), torch.nan)
structure_prompt[span_start:span_end] = torch.tensor(
    lipase_chain[span_start:span_end].atom37_positions, dtype=torch.float32
)

sasa_prompt = [None] * len(lipase_chain)
sasa_prompt[span_start:span_end] = [40.0] * (span_end - span_start)

print("SASA prompt (just for buried region): ", sasa_prompt[span_start:span_end])

protein_prompt = ESMProtein(
    sequence="_" * len(lipase_chain), coordinates=structure_prompt, sasa=sasa_prompt
)
SASA prompt (just for buried region):  [40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0, 40.0]

This is a more difficult task, so you may need to sample more generations from ESM3 before finding a good solution. We'll sample 16 here and sort the generations by ESM3's predicted TM-score (pTM), highest first.

In [ ]:
import concurrent.futures


def generate_protein_sequence_and_structure(protein_prompt, model):
    sequence_generation = model.generate(
        protein_prompt,
        GenerationConfig(
            track="sequence",
            num_steps=protein_prompt.sequence.count("_") // 2,
            temperature=0.5,
        ),
    )
    structure_prediction = model.generate(
        ESMProtein(sequence=sequence_generation.sequence),
        GenerationConfig(
            track="structure", num_steps=len(protein_prompt) // 4, temperature=0.7
        ),
    )
    return structure_prediction


N_SAMPLES = 16
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
    futures = [
        executor.submit(generate_protein_sequence_and_structure, protein_prompt, model)
        for _ in range(N_SAMPLES)
    ]

    generated_proteins = [future.result() for future in futures]


# Sort generations by ptm
generated_proteins = sorted(
    generated_proteins, key=lambda x: x.ptm.item(), reverse=True
)

Let's visualize the top 4 generations by pTM alongside the original protein (on the left)

In [12]:
N_SAMPLES_TO_SHOW = 4
view = py3Dmol.view(width=2000, height=500, viewergrid=(1, N_SAMPLES_TO_SHOW + 1))
view.addModel(lipase_chain.to_pdb_string(), "pdb", viewer=(0, 0))
for i in range(N_SAMPLES_TO_SHOW):
    print(
        "PTM of generated protein {}: {:.2f}".format(
            i + 1, generated_proteins[i].ptm.item()
        )
    )
    view.addModel(
        generated_proteins[i].to_protein_chain().to_pdb_string(),
        "pdb",
        viewer=(0, i + 1),
    )
view.setStyle({"cartoon": {"color": "lightgrey"}})
view.addStyle(
    {"resi": (np.arange(span_start, span_end) + 1).tolist()},
    {"cartoon": {"color": "red"}},
)
view.zoomTo()
view.show()
PTM of generated protein 1: 0.93
PTM of generated protein 2: 0.92
PTM of generated protein 3: 0.92
PTM of generated protein 4: 0.67
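If we want to keep the best generation, we can write it out as a PDB file using the same to_pdb_string method we used for visualization (a small sketch; the filename is arbitrary):

best_chain = generated_proteins[0].to_protein_chain()
with open("best_sasa_generation.pdb", "w") as f:
    f.write(best_chain.to_pdb_string())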
