Tutorial 1: Input tracks of ESMProtein

ESM3 is a frontier generative model for biology, able to jointly reason across three fundamental biological properties of proteins: sequence, structure, and function. These three data modalities are represented as tracks of discrete tokens at the input and output of ESM3. You can present the model with a combination of partial inputs across the tracks, and ESM3 will provide output predictions for all the tracks.

In this notebook, we will familiarize ourselves with the ESMProtein class, which holds multiple properties of a protein representing sequence, structure, and function. ESM3 models take these properties as inputs (prompts) and generate them as outputs.

An ESMProtein has 5 attributes that represent input (promptable) tracks:

  • sequence: amino acid sequence
  • coordinates: 3D coordinates of atoms in each amino acid of the protein
  • secondary_structure: 8-class secondary structure (SS8)
  • sasa: solvent-accessible surface area (SASA)
  • function_annotations: function annotations derived from InterPro

You can prompt an ESM3 model by setting any subset of these tracks to be partially unmasked when calling the model with an ESMProtein instance.
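As an illustration of a partially unmasked track, the sketch below builds a sequence prompt in which only selected positions are kept and every other position is masked. It assumes that `_` is the mask character in ESMProtein sequence prompts, as in the public ESM3 examples; the helper `mask_outside` is hypothetical and not part of the esm package.

```python
def mask_outside(sequence: str, keep: list[tuple[int, int]], mask_char: str = "_") -> str:
    """Return `sequence` with every residue outside the `keep` ranges masked.

    Ranges are 1-indexed and inclusive. `mask_char="_"` is an assumption
    about how ESM3 sequence prompts mark masked positions.
    """
    out = [mask_char] * len(sequence)
    for start, end in keep:
        if not 1 <= start <= end <= len(sequence):
            raise ValueError(f"range ({start}, {end}) is out of bounds")
        for i in range(start - 1, end):  # convert to 0-indexed positions
            out[i] = sequence[i]
    return "".join(out)


# Keep residues 2-3 and 6 of a toy sequence; everything else is masked.
print(mask_outside("ABCDEFG", [(2, 3), (6, 6)]))  # _BC__F_
```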

One way to create an ESMProtein object is from a PDB ID and chain ID on RCSB. Below, we first create a ProteinChain from the PDB ID and chain ID, and then create an ESMProtein from it. This populates the sequence and coordinates properties.

Imports

In [1]:
# Install esm and other dependencies
! pip install esm
! pip install py3Dmol
! pip install matplotlib
! pip install dna-features-viewer

Create an ESMProtein object

In [5]:
from biotite.database import rcsb
from esm.sdk.api import ESMProtein
from esm.utils.structure.protein_chain import ProteinChain
from esm.utils.types import FunctionAnnotation

pdb_id = "1cm4"
chain_id = "A"

# Create a protein using a pdb format file from RCSB
# Note: instead of the next two lines, we could use
# protein_chain = ProteinChain.from_rcsb(pdb_id, chain_id)
# but in future implementations, this function may use the mmcif file
# which would throw off some indices later on in this notebook
str_io = rcsb.fetch(pdb_id, "pdb")
protein_chain = ProteinChain.from_pdb(str_io, chain_id=chain_id, id=pdb_id)
protein = ESMProtein.from_protein_chain(protein_chain)

## We can also load from a local pdb file by passing its path
# protein_chain = ProteinChain.from_pdb('xxxx.pdb', chain_id=chain_id, id=pdb_id)
# The chain_id and id arguments are optional and will be inferred if None

sequence

The sequence track contains the amino acid sequence of the protein as a string of one-letter codes:

In [6]:
print(protein.sequence)
LTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMINEVDADGNGTIDFPEFLTMMARKMKDTDSEEEIREAFRVFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIREADIDGDGQVNYEEFVQMMT

coordinates

coordinates holds the 3D coordinates of the atoms in the protein as a tensor of shape (n_residues, 37, 3), where

  • n_residues is the number of amino acids in the protein.
  • 37 is the number of atom slots in the atom37 representation: one slot for each distinct heavy-atom type that occurs across the 20 standard amino acids. Each residue fills only the slots for atoms it actually has, so unused slots, as well as atoms missing from the structure, show up as nan.
  • 3 is for 3D (x,y,z) coordinates.
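To make the atom37 layout concrete, the sketch below reads backbone atoms out of per-residue atom37 rows and computes two sanity values: the distance between consecutive CA atoms (about 3.8 Å in a protein chain) and the C-N peptide bond length (about 1.33 Å). It assumes the AlphaFold-style atom37 ordering, where slots 0, 1 and 2 hold N, CA and C; the coordinates are the first three atoms of residues 1 and 2 from the tensor printed below. Plain Python lists stand in for the torch tensor so the sketch runs without torch, and `backbone_checks` is a hypothetical helper.

```python
import math

# Assumed atom37 slot indices (AlphaFold-style ordering): N, CA, C.
N_IDX, CA_IDX, C_IDX = 0, 1, 2


def distance(a, b):
    """Euclidean distance between two 3D points; NaN if either point is missing."""
    if any(math.isnan(v) for v in (*a, *b)):
        return float("nan")
    return math.dist(a, b)


def backbone_checks(coords):
    """For atom37-style coords (n_res x 37 x 3 nested lists), return the
    consecutive CA-CA distances and C(i)-N(i+1) peptide bond lengths."""
    ca_ca, peptide = [], []
    for prev, curr in zip(coords, coords[1:]):
        ca_ca.append(distance(prev[CA_IDX], curr[CA_IDX]))
        peptide.append(distance(prev[C_IDX], curr[N_IDX]))
    return ca_ca, peptide


nan3 = [float("nan")] * 3
# First two residues of 1cm4 chain A; the other 34 slots are NaN, as in the output.
res1 = [[29.490, 55.230, 89.495], [28.583, 56.066, 88.736], [28.272, 57.336, 89.517]] + [nan3] * 34
res2 = [[27.551, 58.264, 88.880], [27.154, 59.532, 89.489], [25.685, 59.865, 89.239]] + [nan3] * 34

ca_ca, peptide = backbone_checks([res1, res2])
print(f"CA-CA: {ca_ca[0]:.2f} A, C-N: {peptide[0]:.2f} A")  # CA-CA: 3.82 A, C-N: 1.34 A
```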
In [7]:
print(protein.coordinates.shape)
torch.Size([143, 37, 3])
In [8]:
print(protein.coordinates)
tensor([[[29.4900, 55.2300, 89.4950],
         [28.5830, 56.0660, 88.7360],
         [28.2720, 57.3360, 89.5170],
         ...,
         [    nan,     nan,     nan],
         [    nan,     nan,     nan],
         [    nan,     nan,     nan]],

        [[27.5510, 58.2640, 88.8800],
         [27.1540, 59.5320, 89.4890],
         [25.6850, 59.8650, 89.2390],
         ...,
         [    nan,     nan,     nan],
         [    nan,     nan,     nan],
         [    nan,     nan,     nan]],

        [[25.1770, 60.9740, 89.8140],
         [23.7800, 61.4010, 89.6760],
         [23.3690, 61.5970, 88.2230],
         ...,
         [    nan,     nan,     nan],
         [    nan,     nan,     nan],
         [    nan,     nan,     nan]],

        ...,

        [[30.1690, 65.5070, 73.3640],
         [29.0820, 65.4730, 74.3330],
         [28.9170, 64.0580, 74.8810],
         ...,
         [    nan,     nan,     nan],
         [    nan,     nan,     nan],
         [    nan,     nan,     nan]],

        [[29.4320, 63.0820, 74.1230],
         [29.4880, 61.7100, 74.5770],
         [30.6400, 61.5040, 75.5490],
         ...,
         [    nan,     nan,     nan],
         [    nan,     nan,     nan],
         [    nan,     nan,     nan]],

        [[31.8270, 62.0180, 75.2010],
         [33.0270, 61.8650, 76.0150],
         [33.1590, 62.9650, 77.0630],
         ...,
         [    nan,     nan,     nan],
         [    nan,     nan,     nan],
         [    nan,     nan,     nan]]])

We define two functions below to visualize the coordinates attribute (there is no need to read through them):

  • visualize_3D_coordinates() visualizes directly from the coordinates tensor by creating a PDB file in which every residue is alanine
  • visualize_3D_protein() visualizes from the ESMProtein instance, which has the correct amino acids
In [9]:
# Functions for visualizing 3D structure

import py3Dmol


def visualize_pdb(pdb_string):
    view = py3Dmol.view(width=400, height=400)
    view.addModel(pdb_string, "pdb")
    view.setStyle({"cartoon": {"color": "spectrum"}})
    view.zoomTo()
    view.render()
    view.center()
    return view


def visualize_3D_coordinates(coordinates):
    """
    This uses all Alanines
    """
    protein_with_same_coords = ESMProtein(coordinates=coordinates)
    # pdb with all alanines
    pdb_string = protein_with_same_coords.to_pdb_string()
    return visualize_pdb(pdb_string)


def visualize_3D_protein(protein):
    pdb_string = protein.to_pdb_string()
    return visualize_pdb(pdb_string)
In [10]:
# visualize from just the coordinates
visualize_3D_coordinates(protein.coordinates)

Out[10]:
<py3Dmol.view at 0x7efeb6dbea50>
In [11]:
# visualize from sequence and coordinates
visualize_3D_protein(protein)

Out[11]:
<py3Dmol.view at 0x7efd6de07b10>

secondary_structure

The secondary_structure property contains a representation of the secondary structure. At a coarse level, each amino acid can be assigned to one of three classes: alpha helix, beta sheet, or coil, all of which are visible in the previous 3D visualization.

ESMProtein uses an 8-class secondary structure (SS8) that can be computed with DSSP from 3D atom coordinates. Since installing DSSP is a separate process from installing the esm package, this notebook instead computes the coarser 3-class classification with biotite's annotate_sse. We can still set the secondary_structure property with this 3-class classification, because its letters H (helix), E (strand), and C (coil) are also valid SS8 letters.

In [12]:
from biotite.structure import annotate_sse


def get_approximate_ss(protein_chain: ProteinChain):
    # get biotite's ss3 representation
    ss3_arr = annotate_sse(protein_chain.atom_array)
    biotite_ss3_str = "".join(ss3_arr)

    # translate into ESM3's representation
    translation_table = str.maketrans(
        {
            "a": "H",  # alpha helix
            "b": "E",  # beta sheet
            "c": "C",  # coil
        }
    )
    esm_ss3 = biotite_ss3_str.translate(translation_table)
    return esm_ss3
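If you do have SS8 labels (for example from DSSP), collapsing them to three classes is a per-character lookup. The sketch below uses the same grouping as the `dssp_to_abc` table in the plotting cell further down (H is helix; E and B are strand; G, I, T, S and C are coil) but emits ESM3's H/E/C letters. Other groupings exist (a common one also maps G and I to helix), so treat this table as one convention rather than a standard; `ss8_to_ss3` is a hypothetical helper.

```python
# SS8 -> SS3 mapping mirroring this notebook's dssp_to_abc table,
# written with ESM3's letters: H = helix, E = strand, C = coil.
SS8_TO_SS3 = {"H": "H", "E": "E", "B": "E", "G": "C", "I": "C", "T": "C", "S": "C", "C": "C"}


def ss8_to_ss3(ss8: str) -> str:
    """Collapse an SS8 string into H/E/C; raises KeyError on unknown letters."""
    return "".join(SS8_TO_SS3[c] for c in ss8)


print(ss8_to_ss3("GHHEBTSC"))  # CHHEECCC
```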
In [13]:
protein.secondary_structure = get_approximate_ss(protein_chain)
print(protein.secondary_structure)
CCHHHHHHHHHHHHHHCCCCCCCCCHHHHHHHHHHHCCCCCHHHHHHCCCCCCCCCCCCCCHHHHHHHHHEEEEECCCCCHHHHHHHHHHCCCCCCCCHHHHHHHHHHHCCCCCHHHHHHHHHHHCCCCCCCCCHHHHHHHHC

The next cell defines functions that visualize the secondary structure; there is no need to read them.

In [14]:
# Slightly modified version of secondary structure plotting code from
# https://www.biotite-python.org/examples/gallery/structure/transketolase_sse.html
# Code source: Patrick Kunzmann
# License: BSD 3 clause

import biotite
import biotite.sequence as seq
import biotite.sequence.graphics as graphics
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle


# Create 'FeaturePlotter' subclasses
# for drawing the secondary structure features
class HelixPlotter(graphics.FeaturePlotter):
    def __init__(self):
        pass

    # Check whether this class is applicable for drawing a feature
    def matches(self, feature):
        if feature.key == "SecStr":
            if "sec_str_type" in feature.qual:
                if feature.qual["sec_str_type"] == "helix":
                    return True
        return False

    # The drawing function itself
    def draw(self, axes, feature, bbox, loc, style_param):
        # Approx. 1 turn per 3.6 residues to resemble natural helix
        n_turns = np.ceil((loc.last - loc.first + 1) / 3.6)
        x_val = np.linspace(0, n_turns * 2 * np.pi, 100)
        # Curve ranges from 0.3 to 0.7
        y_val = (-0.4 * np.sin(x_val) + 1) / 2

        # Transform values for correct location in feature map
        x_val *= bbox.width / (n_turns * 2 * np.pi)
        x_val += bbox.x0
        y_val *= bbox.height
        y_val += bbox.y0

        # Draw white background to overlay the guiding line
        background = Rectangle(
            bbox.p0, bbox.width, bbox.height, color="white", linewidth=0
        )
        axes.add_patch(background)
        axes.plot(x_val, y_val, linewidth=2, color=biotite.colors["dimgreen"])


class SheetPlotter(graphics.FeaturePlotter):
    def __init__(self, head_width=0.8, tail_width=0.5):
        self._head_width = head_width
        self._tail_width = tail_width

    def matches(self, feature):
        if feature.key == "SecStr":
            if "sec_str_type" in feature.qual:
                if feature.qual["sec_str_type"] == "sheet":
                    return True
        return False

    def draw(self, axes, feature, bbox, loc, style_param):
        x = bbox.x0
        y = bbox.y0 + bbox.height / 2
        dx = bbox.width
        dy = 0

        if loc.defect & seq.Location.Defect.MISS_RIGHT:
            # If the feature extends into the previous or next line
            # do not draw an arrow head
            draw_head = False
        else:
            draw_head = True

        axes.add_patch(
            biotite.AdaptiveFancyArrow(
                x,
                y,
                dx,
                dy,
                self._tail_width * bbox.height,
                self._head_width * bbox.height,
                # Create head with 90 degrees tip
                # -> head width/length ratio = 1/2
                head_ratio=0.5,
                draw_head=draw_head,
                color=biotite.colors["orange"],
                linewidth=0,
            )
        )


# Converter for the DSSP secondary structure elements
# to the classical ones
dssp_to_abc = {
    "I": "c",
    "S": "c",
    "H": "a",
    "E": "b",
    "G": "c",
    "B": "b",
    "T": "c",
    "C": "c",
}


def visualize_secondary_structure(sse, first_id):
    """
    Helper function to convert secondary structure array to annotation
    and visualize it.
    """

    def _add_sec_str(annotation, first, last, str_type):
        if str_type == "a":
            str_type = "helix"
        elif str_type == "b":
            str_type = "sheet"
        else:
            # coil
            return
        feature = seq.Feature(
            "SecStr", [seq.Location(first, last)], {"sec_str_type": str_type}
        )
        annotation.add_feature(feature)

    # Find the intervals for each secondary structure element
    # and add to annotation
    annotation = seq.Annotation()
    curr_sse = None
    curr_start = None
    for i in range(len(sse)):
        if curr_start is None:
            curr_start = i
            curr_sse = sse[i]
        else:
            if sse[i] != sse[i - 1]:
                _add_sec_str(
                    annotation, curr_start + first_id, i - 1 + first_id, curr_sse
                )
                curr_start = i
                curr_sse = sse[i]
    # Add last secondary structure element to annotation
    _add_sec_str(annotation, curr_start + first_id, i + first_id, curr_sse)

    fig = plt.figure(figsize=(30.0, 3.0))
    ax = fig.add_subplot(111)
    graphics.plot_feature_map(
        ax,
        annotation,
        symbols_per_line=150,
        loc_range=(first_id, first_id + len(sse)),
        feature_plotters=[HelixPlotter(), SheetPlotter()],
    )
    fig.tight_layout()
    return fig, ax


def plot_ss8(ss8_string):
    ss3 = np.array([dssp_to_abc[e] for e in ss8_string], dtype="U1")
    _, ax = visualize_secondary_structure(ss3, 1)
    ax.set_xticks([])
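The visualize_secondary_structure helper above walks the secondary-structure array to find contiguous runs of the same class. The same idea as a standalone sketch, using `itertools.groupby`, returns each run as a (class, start, end) tuple with 1-indexed, inclusive bounds; `ss_segments` is a hypothetical helper, not part of esm or biotite.

```python
from itertools import groupby


def ss_segments(ss: str) -> list[tuple[str, int, int]]:
    """Split a secondary-structure string into runs of identical letters.

    Returns (letter, start, end) with 1-indexed, inclusive positions.
    """
    segments = []
    pos = 1
    for letter, run in groupby(ss):
        length = sum(1 for _ in run)
        segments.append((letter, pos, pos + length - 1))
        pos += length
    return segments


# Helices only, e.g. to compare against the cartoon in the 3D view.
toy = "CCHHHHEEC"
print([seg for seg in ss_segments(toy) if seg[0] == "H"])  # [('H', 3, 6)]
```

Applied to protein.secondary_structure, this gives the helix and strand intervals that the plot draws.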

Using these functions, we can visualize the secondary structure that we obtained. The alpha helices are represented in green, the beta sheets are represented in orange, and coils are represented by gray lines.

Note: because this secondary structure assignment algorithm differs from the one used by the 3D viewer, the plot differs slightly from the cartoon representation in the 3D visualization.

In [15]:
plot_ss8(protein.secondary_structure)

function_annotations

An ESMProtein can also hold function annotations derived from InterPro. Annotations taken directly from InterPro cover the following entry types:

  • Family
  • Domain
  • Homologous superfamily
  • Repeat
  • Site (conserved site, active site, binding site, post-translational modification site)
In [16]:
interpro_function_annotations = [
    FunctionAnnotation(label="IPR050145", start=1, end=142),  # 1-indexed, inclusive
    FunctionAnnotation(label="IPR002048", start=4, end=75),
    FunctionAnnotation(label="IPR002048", start=77, end=144),
    FunctionAnnotation(label="IPR011992", start=1, end=143),
    FunctionAnnotation(label="IPR018247", start=17, end=29),
    FunctionAnnotation(label="IPR018247", start=53, end=65),
    FunctionAnnotation(label="IPR018247", start=90, end=102),
    FunctionAnnotation(label="IPR018247", start=126, end=138),
]
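Because these annotations are 1-indexed and inclusive while Python slicing is 0-indexed and end-exclusive, the residues an annotation covers are `sequence[start - 1:end]`. The sketch below uses a stand-in `Annotation` dataclass with the same `label`, `start` and `end` fields as FunctionAnnotation (so it runs without esm) and extracts the residues of the first IPR018247 (EF-hand calcium-binding site) annotation; `annotated_residues` is a hypothetical helper.

```python
from dataclasses import dataclass


@dataclass
class Annotation:
    """Stand-in with the same fields as esm's FunctionAnnotation."""

    label: str
    start: int  # 1-indexed, inclusive
    end: int  # 1-indexed, inclusive


def annotated_residues(sequence: str, annotation: Annotation) -> str:
    """Return the residues an annotation covers (1-indexed inclusive -> Python slice)."""
    return sequence[annotation.start - 1 : annotation.end]


# First 29 residues of 1cm4 chain A, from the printed sequence above.
prefix = "LTEEQIAEFKEAFSLFDKDGDGTITTKEL"
site = Annotation(label="IPR018247", start=17, end=29)
print(annotated_residues(prefix, site))  # DKDGDGTITTKEL
```

Python slicing silently truncates an end index past the sequence length. The IPR002048 annotation above ends at 144 while this chain has 143 residues, so a bounds check may be worth adding if exact ranges matter.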

We can visualize these InterPro annotations with the following function:

In [17]:
# Functions for visualizing InterPro function annotations

from dna_features_viewer import GraphicFeature, GraphicRecord
from esm.utils.function.interpro import InterPro, InterProEntryType
from matplotlib import colormaps


def visualize_function_annotations(
    annotations: list[FunctionAnnotation],
    sequence_length: int,
    ax: plt.Axes,
    interpro_=InterPro(),
):
    cmap = colormaps["tab10"]
    colors = [cmap(i) for i in range(len(InterProEntryType))]
    type_colors = dict(zip(InterProEntryType, colors))

    features = []
    for annotation in annotations:
        if annotation.label in interpro_.entries:
            entry = interpro_.entries[annotation.label]
            label = entry.name
            entry_type = entry.type
        else:
            label = annotation.label
            entry_type = InterProEntryType.UNKNOWN

        feature = GraphicFeature(
            start=annotation.start - 1,  # one index -> zero index
            end=annotation.end,
            label=label,
            color=type_colors[entry_type],
            strand=None,
        )
        features.append(feature)

    record = GraphicRecord(
        sequence=None, sequence_length=sequence_length, features=features
    )

    record.plot(figure_width=12, plot_sequence=False, ax=ax)

We plot the InterPro annotations below, with colors indicating each annotation's InterPro entry type.

In [18]:
fig, ax = plt.subplots(figsize=(20.0, 4.0))
visualize_function_annotations(interpro_function_annotations, len(protein), ax)

When using our ESM3 model, we recommend using keyword annotations: keywords drawn from the description of the InterPro entry and from its associated Gene Ontology terms (via InterPro2GO). For instance, for the InterPro entry IPR011992, the keywords are "ef", "hand", "pair", "ef hand", "hand domain", and "domain pair". For more details on how the keywords were computed, please refer to our preprint.

In practice, we can derive keyword annotations from the InterPro annotations with the function below. Each InterPro annotation maps to zero or more keyword annotations covering the same range; for example, IPR050145 (residues 1-142) yields no keywords in the output below.

In [19]:
from esm.tokenization import InterProQuantizedTokenizer


def get_keywords_from_interpro(
    interpro_annotations,
    interpro2keywords=InterProQuantizedTokenizer().interpro2keywords,
):
    keyword_annotations_list = []
    for interpro_annotation in interpro_annotations:
        keywords = interpro2keywords.get(interpro_annotation.label, [])
        keyword_annotations_list.extend(
            [
                FunctionAnnotation(
                    label=keyword,
                    start=interpro_annotation.start,
                    end=interpro_annotation.end,
                )
                for keyword in keywords
            ]
        )
    return keyword_annotations_list
In [20]:
protein.function_annotations = get_keywords_from_interpro(interpro_function_annotations)
protein.function_annotations
Out[20]:
[FunctionAnnotation(label='ef', start=4, end=75),
 FunctionAnnotation(label='hand', start=4, end=75),
 FunctionAnnotation(label='ef hand', start=4, end=75),
 FunctionAnnotation(label='hand domain', start=4, end=75),
 FunctionAnnotation(label='calcium', start=4, end=75),
 FunctionAnnotation(label='calcium ion', start=4, end=75),
 FunctionAnnotation(label='calcium', start=4, end=75),
 FunctionAnnotation(label='calcium ion', start=4, end=75),
 FunctionAnnotation(label='cation binding', start=4, end=75),
 FunctionAnnotation(label='ef', start=77, end=144),
 FunctionAnnotation(label='hand', start=77, end=144),
 FunctionAnnotation(label='ef hand', start=77, end=144),
 FunctionAnnotation(label='hand domain', start=77, end=144),
 FunctionAnnotation(label='calcium', start=77, end=144),
 FunctionAnnotation(label='calcium ion', start=77, end=144),
 FunctionAnnotation(label='calcium', start=77, end=144),
 FunctionAnnotation(label='calcium ion', start=77, end=144),
 FunctionAnnotation(label='cation binding', start=77, end=144),
 FunctionAnnotation(label='ef', start=1, end=143),
 FunctionAnnotation(label='hand', start=1, end=143),
 FunctionAnnotation(label='pair', start=1, end=143),
 FunctionAnnotation(label='ef hand', start=1, end=143),
 FunctionAnnotation(label='hand domain', start=1, end=143),
 FunctionAnnotation(label='domain pair', start=1, end=143),
 FunctionAnnotation(label='ef', start=17, end=29),
 FunctionAnnotation(label='hand', start=17, end=29),
 FunctionAnnotation(label='ef hand', start=17, end=29),
 FunctionAnnotation(label='hand 1', start=17, end=29),
 FunctionAnnotation(label='calcium', start=17, end=29),
 FunctionAnnotation(label='site', start=17, end=29),
 FunctionAnnotation(label='calcium binding', start=17, end=29),
 FunctionAnnotation(label='binding site', start=17, end=29),
 FunctionAnnotation(label='ef', start=53, end=65),
 FunctionAnnotation(label='hand', start=53, end=65),
 FunctionAnnotation(label='ef hand', start=53, end=65),
 FunctionAnnotation(label='hand 1', start=53, end=65),
 FunctionAnnotation(label='calcium', start=53, end=65),
 FunctionAnnotation(label='site', start=53, end=65),
 FunctionAnnotation(label='calcium binding', start=53, end=65),
 FunctionAnnotation(label='binding site', start=53, end=65),
 FunctionAnnotation(label='ef', start=90, end=102),
 FunctionAnnotation(label='hand', start=90, end=102),
 FunctionAnnotation(label='ef hand', start=90, end=102),
 FunctionAnnotation(label='hand 1', start=90, end=102),
 FunctionAnnotation(label='calcium', start=90, end=102),
 FunctionAnnotation(label='site', start=90, end=102),
 FunctionAnnotation(label='calcium binding', start=90, end=102),
 FunctionAnnotation(label='binding site', start=90, end=102),
 FunctionAnnotation(label='ef', start=126, end=138),
 FunctionAnnotation(label='hand', start=126, end=138),
 FunctionAnnotation(label='ef hand', start=126, end=138),
 FunctionAnnotation(label='hand 1', start=126, end=138),
 FunctionAnnotation(label='calcium', start=126, end=138),
 FunctionAnnotation(label='site', start=126, end=138),
 FunctionAnnotation(label='calcium binding', start=126, end=138),
 FunctionAnnotation(label='binding site', start=126, end=138)]
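The list above contains exact repeats; for example, 'calcium' and 'calcium ion' over residues 4-75 each appear twice. If you want each (label, start, end) only once, an order-preserving deduplication keyed on those three fields is enough. In the sketch, a `NamedTuple` named `Ann` is a hypothetical stand-in for FunctionAnnotation so it runs without esm; the function only reads `.label`, `.start` and `.end`, which FunctionAnnotation also has.

```python
from typing import NamedTuple


class Ann(NamedTuple):
    """Stand-in with FunctionAnnotation's label/start/end fields."""

    label: str
    start: int
    end: int


def dedupe_annotations(annotations):
    """Drop repeated (label, start, end) entries, keeping first-seen order."""
    seen = set()
    unique = []
    for a in annotations:
        key = (a.label, a.start, a.end)  # compare by value, not identity
        if key not in seen:
            seen.add(key)
            unique.append(a)
    return unique


keywords = [
    Ann("calcium", 4, 75),
    Ann("calcium ion", 4, 75),
    Ann("calcium", 4, 75),
    Ann("calcium ion", 4, 75),
    Ann("cation binding", 4, 75),
]
print([a.label for a in dedupe_annotations(keywords)])
# ['calcium', 'calcium ion', 'cation binding']
```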

We can also visualize the keyword annotations. They all share the same color because keywords are not InterPro entries, so the plotting function assigns them the UNKNOWN entry type.

In [21]:
fig, ax = plt.subplots(figsize=(20.0, 8.0))
visualize_function_annotations(protein.function_annotations, len(protein), ax)

sasa

The final input track of ESMProtein is the solvent-accessible surface area, or SASA. For each amino acid, this track indicates how much of its surface is accessible to the solvent. We can compute it with ProteinChain's sasa method, which uses biotite's sasa function under the hood.

In [22]:
protein.sasa = protein_chain.sasa()
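biotite's sasa function returns one value per atom, so a per-residue track needs an aggregation step. A minimal sketch, assuming the per-residue value is the sum of its atoms' SASA with missing (NaN) atom values skipped; check ProteinChain.sasa in your esm version for the exact reduction it uses. `residue_sasa` is a hypothetical helper.

```python
import math
from itertools import groupby


def residue_sasa(atom_sasa, atom_res_ids):
    """Sum per-atom SASA into per-residue SASA (assumed reduction: NaN-skipping sum).

    atom_sasa: per-atom values, NaN where not computed.
    atom_res_ids: residue id of each atom; atoms of one residue are contiguous.
    """
    totals = []
    for _, group in groupby(zip(atom_res_ids, atom_sasa), key=lambda p: p[0]):
        totals.append(sum((v for _, v in group if not math.isnan(v)), 0.0))
    return totals


print(residue_sasa([1.0, 2.0, float("nan"), 4.0, 0.5], [1, 1, 2, 2, 3]))  # [3.0, 4.0, 0.5]
```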

One way to visualize this track is to plot its values along the amino acid sequence.

In [23]:
plt.plot(protein.sasa)
Out[23]:
[<matplotlib.lines.Line2D at 0x7efd5ce93890>]

We can also map these SASA values onto the 3D visualization of the structure, leveraging the fact that we have this protein's 3D coordinates.

First we define which colors map to which values:

In [24]:
cmap = colormaps["cividis"]
clip_sasa_lower = 10
clip_sasa_upper = 90


def plot_heatmap_legend(cmap, clip_sasa_lower, clip_sasa_upper):
    gradient = np.linspace(0, 1, 256)
    gradient = np.vstack((gradient, gradient))
    _, ax = plt.subplots(figsize=(5, 0.3), dpi=350)
    ax.imshow(gradient, aspect="auto", cmap=cmap)
    ax.text(
        0.1,
        -0.3,
        f"{clip_sasa_lower} or lower",
        va="center",
        ha="right",
        fontsize=7,
        transform=ax.transAxes,
    )
    ax.text(
        0.5,
        -0.3,
        f"{(clip_sasa_lower + clip_sasa_upper) // 2}",
        va="center",
        ha="center",  # center the midpoint label under the middle of the bar
        fontsize=7,
        transform=ax.transAxes,
    )
    ax.text(
        0.9,
        -0.3,
        f"{clip_sasa_upper} or higher",
        va="center",
        ha="left",
        fontsize=7,
        transform=ax.transAxes,
    )
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])
    plt.show()


plot_heatmap_legend(cmap, clip_sasa_lower, clip_sasa_upper)
In [25]:
# Functions for visualizing SASA as colors on the 3D structure


def get_color_strings(sasa, clip_sasa_lower, clip_sasa_upper, cmap):
    transformed_sasa = np.clip(sasa, clip_sasa_lower, clip_sasa_upper)
    transformed_sasa = (transformed_sasa - clip_sasa_lower) / (
        clip_sasa_upper - clip_sasa_lower
    )
    rgbas = (cmap(transformed_sasa) * 255).astype(int)

    return [f"rgb({rgba[0]},{rgba[1]},{rgba[2]})" for rgba in rgbas]


def visualize_sasa_3D_protein(
    protein, clip_sasa_lower=clip_sasa_lower, clip_sasa_upper=clip_sasa_upper, cmap=cmap
):
    pdb_string = protein.to_pdb_string()
    plot_heatmap_legend(cmap, clip_sasa_lower, clip_sasa_upper)
    view = py3Dmol.view(width=400, height=400)
    view.addModel(pdb_string, "pdb")

    for res_pos, res_color in enumerate(
        get_color_strings(protein.sasa, clip_sasa_lower, clip_sasa_upper, cmap)
    ):
        view.setStyle(
            {"chain": "A", "resi": res_pos + 1}, {"cartoon": {"color": res_color}}
        )
    view.zoomTo()
    view.render()
    view.center()

    return view
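get_color_strings first clips SASA to [clip_sasa_lower, clip_sasa_upper] and rescales it linearly to [0, 1] before looking up the colormap. The sketch below isolates that step in plain Python; NaN is passed through, as np.clip does, since a missing SASA value has no meaningful position on the scale. `normalize_sasa` is a hypothetical helper whose defaults copy the clip values used above.

```python
import math


def normalize_sasa(values, lower=10.0, upper=90.0):
    """Clip each value to [lower, upper], then rescale linearly to [0, 1].

    Mirrors the np.clip + rescale in get_color_strings; NaN stays NaN.
    """
    if upper <= lower:
        raise ValueError("upper must be greater than lower")
    out = []
    for v in values:
        if math.isnan(v):
            out.append(v)
            continue
        clipped = min(max(v, lower), upper)
        out.append((clipped - lower) / (upper - lower))
    return out


print(normalize_sasa([0.0, 10.0, 50.0, 90.0, 120.0]))  # [0.0, 0.0, 0.5, 1.0, 1.0]
```

With these defaults, every residue at or below 10 lands on the low end of the colormap and every residue at or above 90 on the high end, which is why the legend labels its ends "10 or lower" and "90 or higher".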

We visualize SASA on the 3D structure below. Note that the amino acids that are on the inside have lower SASA values, and the amino acids at the surface have higher SASA values.

In [26]:
visualize_sasa_3D_protein(protein)

Out[26]:
<py3Dmol.view at 0x7efd671d4f50>

We have now covered all the tracks of ESMProtein.

We can initialize an ESMProtein by providing any of these tracks. For instance, to initialize a protein with the same coordinates as our protein, we would do:

In [27]:
same_structure_protein = ESMProtein(coordinates=protein.coordinates)

and similarly for any other track.

We hope this helps you get started with our ESM3 models!