Commit 92afe510 authored by Phil Wang

protect against edge case where MSA row is made of padding

Parent 2c9fed09
Tag 0.0.78
No related merge requests
......
@@ -180,7 +180,7 @@ alphafold2 = Alphafold2(
 Todo:
 
 - [x] make sure MSA Transformer embeddings work
-- [ ] process MSA embeddings one by one if any rows are pure padding
+- [x] process MSA embeddings one by one if any rows are pure padding
 - [ ] make sure ESM embedding wrapper works
 
 ## Real-Value Distance Prediction
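
A row counts as "pure padding" when its boolean mask contains no real residues. A minimal sketch of that check, using a small hypothetical msa_mask of shape (batch, num_msa, seq_len) like the mask argument introduced in the source diff below:

import torch

# hypothetical mask: one sequence, three MSA rows, the last row entirely padding
msa_mask = torch.tensor([[[True, True], [True, False], [False, False]]])

pure_padding_rows = ~msa_mask.any(dim = -1)  # True where an MSA row has no real residues
print(pure_padding_rows)                     # tensor([[False, False,  True]])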
......
 import torch
 import torch.nn.functional as F
 from torch import nn
-from alphafold2_pytorch.utils import get_msa_embedd, get_esm_embedd
+from alphafold2_pytorch.utils import get_msa_embedd, get_esm_embedd, exists
 from alphafold2_pytorch.constants import MSA_MODEL_PATH, MSA_EMBED_DIM
 from einops import rearrange
......
@@ -18,17 +19,37 @@ class MSAEmbedWrapper(nn.Module):
         self.batch_converter = batch_converter
         self.project_embed = nn.Linear(MSA_EMBED_DIM, alphafold2.dim) if MSA_EMBED_DIM != alphafold2.dim else nn.Identity()
 
-    def forward(self, seq, msa, **kwargs):
+    def forward(self, seq, msa, msa_mask = None, **kwargs):
         assert seq.shape[-1] == msa.shape[-1], 'sequence and msa must have the same length if you wish to use MSA transformer embeddings'
         model, batch_converter, device = self.model, self.batch_converter, seq.device
 
         seq_and_msa = torch.cat((seq.unsqueeze(1), msa), dim = 1)
 
-        embeds = get_msa_embedd(seq_and_msa, model, batch_converter, device = device)
-        embeds = self.project_embed(embeds)
+        if exists(msa_mask):
+            # in the event that there are rows in the MSA that are completely padding
+            # process each batch element individually, so that padding isn't processed
+            # with row-tied attention
+            num_msa = msa_mask.any(dim = -1).sum(dim = -1).tolist()
+            seq_and_msa_list = seq_and_msa.unbind(dim = 0)
+            num_rows = seq_and_msa.shape[1]
+
+            embeds = []
+            for num, batch_el in zip(num_msa, seq_and_msa_list):
+                batch_el = rearrange(batch_el, '... -> () ...')
+                batch_el = batch_el[:, :num]
+                embed = get_msa_embedd(batch_el, model, batch_converter, device = device)
+                embed = F.pad(embed, (0, 0, 0, 0, 0, num_rows - num), value = 0.)
+                embeds.append(embed)
+
+            embeds = torch.cat(embeds, dim = 0)
+        else:
+            embeds = get_msa_embedd(seq_and_msa, model, batch_converter, device = device)
+
+        embeds = self.project_embed(embeds)
         seq_embed, msa_embed = embeds[:, 0], embeds[:, 1:]
-        return self.alphafold2(seq, msa, seq_embed = seq_embed, msa_embed = msa_embed, **kwargs)
+        return self.alphafold2(seq, msa, seq_embed = seq_embed, msa_embed = msa_embed, msa_mask = msa_mask, **kwargs)
 
 class ESMEmbedWrapper(nn.Module):
     def __init__(self, *, alphafold2):
......
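
To see the per-example splitting and re-padding in isolation, here is a self-contained sketch. The MSA Transformer call is replaced by a hypothetical dummy_embed stand-in and all shapes are illustrative, so this is not the wrapper's real embedding path, only the batching logic from the diff above:

import torch
import torch.nn.functional as F

def dummy_embed(tokens):
    # hypothetical stand-in for get_msa_embedd: (1, rows, length) -> (1, rows, length, dim)
    return torch.randn(*tokens.shape, 8)

batch, num_rows, seq_len = 2, 4, 16
tokens = torch.randint(0, 21, (batch, num_rows, seq_len))

# boolean mask over rows; the last two rows of the second example are pure padding
mask = torch.ones(batch, num_rows, seq_len).bool()
mask[1, 2:] = False

num_real = mask.any(dim = -1).sum(dim = -1).tolist()   # [4, 2] non-padding rows per example

embeds = []
for num, batch_el in zip(num_real, tokens.unbind(dim = 0)):
    batch_el = batch_el[None, :num]                    # keep only the real rows
    embed = dummy_embed(batch_el)
    # zero-pad the row dimension (third from last) back up to num_rows
    embed = F.pad(embed, (0, 0, 0, 0, 0, num_rows - num), value = 0.)
    embeds.append(embed)

embeds = torch.cat(embeds, dim = 0)
print(embeds.shape)               # torch.Size([2, 4, 16, 8])
print(embeds[1, 2:].abs().sum())  # tensor(0.) - the padded rows stay zero

Splitting the batch this way keeps all-padding rows out of the embedder's row-tied attention, at the cost of one forward pass per batch element.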
......
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'alphafold2-pytorch',
   packages = find_packages(),
-  version = '0.0.85',
+  version = '0.0.86',
   license='MIT',
   description = 'AlphaFold2 - Pytorch',
   author = 'Phil Wang, Eric Alcaide',
......