HPCSource / alphafold2 · Commit 92afe510

protect against edge case where MSA row is made of padding

Authored 3 years ago by Phil Wang
Parent: 2c9fed09
Tag containing this commit: 0.0.78
No related merge requests found
Changes: 3 files changed, 28 additions and 7 deletions

- README.md (+1, −1)
- alphafold2_pytorch/embeds.py (+26, −5)
- setup.py (+1, −1)
README.md (+1, −1)

```diff
@@ -180,7 +180,7 @@ alphafold2 = Alphafold2(
 Todo:
 - [x] make sure MSA Transformer embeddings work
-- [ ] process MSA embeddings one by one if any rows are pure padding
+- [x] process MSA embeddings one by one if any rows are pure padding
 - [ ] make sure ESM embedding wrapper works

 ## Real-Value Distance Prediction
```
alphafold2_pytorch/embeds.py (+26, −5)

```diff
 import torch
+import torch.nn.functional as F
 from torch import nn
-from alphafold2_pytorch.utils import get_msa_embedd, get_esm_embedd
+from alphafold2_pytorch.utils import get_msa_embedd, get_esm_embedd, exists
 from alphafold2_pytorch.constants import MSA_MODEL_PATH, MSA_EMBED_DIM
+from einops import rearrange
@@ -18,17 +19,37 @@ class MSAEmbedWrapper(nn.Module):
         self.batch_converter = batch_converter
         self.project_embed = nn.Linear(MSA_EMBED_DIM, alphafold2.dim) if MSA_EMBED_DIM != alphafold2.dim else nn.Identity()

-    def forward(self, seq, msa, **kwargs):
+    def forward(self, seq, msa, msa_mask = None, **kwargs):
         assert seq.shape[-1] == msa.shape[-1], 'sequence and msa must have the same length if you wish to use MSA transformer embeddings'
         model, batch_converter, device = self.model, self.batch_converter, seq.device

         seq_and_msa = torch.cat((seq.unsqueeze(1), msa), dim = 1)
-        embeds = get_msa_embedd(seq_and_msa, model, batch_converter, device = device)
-        embeds = self.project_embed(embeds)
+
+        if exists(msa_mask):
+            # in the event that there are rows in the MSA that are completely padding
+            # process each batch element individually, so that padding isn't processed
+            # with row-tied attention
+            num_msa = msa_mask.any(dim = -1).sum(dim = -1).tolist()
+            seq_and_msa_list = seq_and_msa.unbind(dim = 0)
+            num_rows = seq_and_msa.shape[1]
+
+            embeds = []
+            for num, batch_el in zip(num_msa, seq_and_msa_list):
+                batch_el = rearrange(batch_el, '... -> () ...')
+                batch_el = batch_el[:, :num]
+                embed = get_msa_embedd(batch_el, model, batch_converter, device = device)
+                embed = F.pad(embed, (0, 0, 0, 0, 0, num_rows - num), value = 0.)
+                embeds.append(embed)
+
+            embeds = torch.cat(embeds, dim = 0)
+        else:
+            embeds = get_msa_embedd(seq_and_msa, model, batch_converter, device = device)
+
+        embeds = self.project_embed(embeds)

         seq_embed, msa_embed = embeds[:, 0], embeds[:, 1:]
-        return self.alphafold2(seq, msa, seq_embed = seq_embed, msa_embed = msa_embed, **kwargs)
+        return self.alphafold2(seq, msa, seq_embed = seq_embed, msa_embed = msa_embed, msa_mask = msa_mask, **kwargs)

 class ESMEmbedWrapper(nn.Module):
     def __init__(self, *, alphafold2):
```
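The new masked branch above can be sketched in plain Python, without torch or the real `get_msa_embedd` (the `fake_embed` helper below is a hypothetical stand-in, and for simplicity the mask here spans all rows of `seq_and_msa`, whereas the commit's `msa_mask` covers only the MSA rows):

```python
# Sketch of the per-batch-element masking logic: count rows that contain at
# least one non-padding position, embed only those rows, then zero-pad back
# to the full row count (mirroring F.pad with value = 0.).

def fake_embed(rows):
    # hypothetical stand-in for get_msa_embedd: one float per position
    return [[float(tok) for tok in row] for row in rows]

def embed_with_mask(seq_and_msa, msa_mask):
    num_rows = len(seq_and_msa[0])
    out = []
    for rows, mask in zip(seq_and_msa, msa_mask):
        # rows whose mask is all False are pure padding and are skipped
        num = sum(1 for row_mask in mask if any(row_mask))
        embed = fake_embed(rows[:num])
        # restore the original row count with zero rows
        pad = [[0.0] * len(rows[0]) for _ in range(num_rows - num)]
        out.append(embed + pad)
    return out

# batch of 2, each with 3 rows of length 2; the second element's last row
# is entirely padding, so only its first two rows are embedded
seq_and_msa = [[[1, 2], [3, 4], [5, 6]],
               [[7, 8], [9, 1], [0, 0]]]
msa_mask = [[[True, True], [True, True], [True, True]],
            [[True, True], [True, False], [False, False]]]

embeds = embed_with_mask(seq_and_msa, msa_mask)
```

In the real code the same trimming is done per batch element precisely so that all-padding rows never enter the MSA Transformer's row-tied attention, where they would otherwise contaminate the other rows.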
setup.py (+1, −1)

```diff
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
 setup(
   name = 'alphafold2-pytorch',
   packages = find_packages(),
-  version = '0.0.85',
+  version = '0.0.86',
   license = 'MIT',
   description = 'AlphaFold2 - Pytorch',
   author = 'Phil Wang, Eric Alcaide',
```