Commit 5aefd2fe authored by Phil Wang

make sure sparse attention gets relative positional encoding as well

Parent 14e4f282
Tag 0.0.78
No related merge requests
@@ -262,6 +262,11 @@ class Attention(nn.Module):
         self.rotary_sinu_emb = FixedPositionalEmbedding(dim_head) if rotary_rpe else None
 
+    def apply_rpe(self, q, k):
+        sinu_emb = self.rotary_sinu_emb(q)
+        q, k = apply_rotary_pos_emb(q, k, sinu_emb)
+        return q, k
+
     def forward(self, x, context = None, mask = None, context_mask = None, tie_attn_dim = None, **kwargs):
         device, orig_shape, h, has_context = x.device, x.shape, self.heads, exists(context)
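
The helpers used by `apply_rpe` (`FixedPositionalEmbedding` and `apply_rotary_pos_emb`) are defined elsewhere in the repository and are not part of this diff. For orientation only, a minimal sketch of standard rotary position embeddings applied to per-head query/key tensors might look like the following; the function names and details here are illustrative assumptions, not the library's exact implementation.

import torch

def fixed_positional_embedding(x, dim_head):
    # sinusoidal angle table, one frequency per pair of channels: shape (n, dim_head)
    n, device = x.shape[-2], x.device
    inv_freq = 1. / (10000 ** (torch.arange(0, dim_head, 2, device = device).float() / dim_head))
    t = torch.arange(n, device = device).float()
    freqs = torch.einsum('i,j->ij', t, inv_freq)   # (n, dim_head // 2)
    return torch.cat((freqs, freqs), dim = -1)     # (n, dim_head)

def rotate_half(x):
    # split the feature dimension in two and rotate: (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary(q, k, sinu):
    # broadcast the (n, dim_head) angle table over batch and heads, then rotate q and k
    cos, sin = sinu.cos()[None, None, :, :], sinu.sin()[None, None, :, :]
    q = (q * cos) + (rotate_half(q) * sin)
    k = (k * cos) + (rotate_half(k) * sin)
    return q, k

# example: per-head queries/keys of shape (batch, heads, seq, dim_head)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
sinu = fixed_positional_embedding(q, dim_head = 64)
q, k = apply_rotary(q, k, sinu)
print(q.shape, k.shape)  # torch.Size([2, 8, 16, 64]) for both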
@@ -300,8 +305,7 @@ class Attention(nn.Module):
         # rotary relative positional encoding
 
         if exists(self.rotary_sinu_emb) and not has_context:
-            sinu_emb = self.rotary_sinu_emb(q)
-            q, k = apply_rotary_pos_emb(q, k, sinu_emb)
+            q, k = self.apply_rpe(q, k)
 
         # for tying row-attention, for MSA axial self-attention
@@ -390,6 +394,9 @@ class SparseAttention(Attention):
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
 
+        if exists(self.rotary_sinu_emb):
+            q, k = self.apply_rpe(q, k)
+
         key_pad_mask = None
         if exists(mask):
             key_pad_mask = repeat(~mask, 'b n -> b h n', h = h)
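
Both attention variants call `apply_rpe` on tensors that have already been split per head by the `rearrange` pattern shown above. A quick, self-contained shape check of that einops call, using made-up dimensions (the real sizes depend on the model configuration):

import torch
from einops import rearrange

b, n, h, d = 2, 16, 8, 64        # batch, sequence length, heads, head dimension
x = torch.randn(b, n, h * d)     # fused projection output, shape (b, n, h * d)

q = rearrange(x, 'b n (h d) -> b h n d', h = h)
print(q.shape)                   # torch.Size([2, 8, 16, 64]), the layout apply_rpe receives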
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'alphafold2-pytorch',
   packages = find_packages(),
-  version = '0.0.77',
+  version = '0.0.78',
   license='MIT',
   description = 'AlphaFold2 - Pytorch',
   author = 'Phil Wang, Eric Alcaide',