Unverified commit 946dbd62, authored by YuliangLiu0306, committed by GitHub

[hotfix]fix bugs caused by refactored pipeline (#1133)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [hotfix]fix bugs caused by refactored pipeline
Parent 789cad30
No related merge requests found
@@ -67,8 +67,8 @@ class NonPipelineSchedule(BaseSchedule):
             "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
         batch_data = self.load_batch(data_iter)
-        if self.batch_data_process_func:
-            data, label = self.batch_data_process_func(batch_data)
+        if self.data_process_func:
+            data, label = self.data_process_func(batch_data)
         else:
             # if no batch data process func is given,
             # then we regard the batch data as a simple tuple of (data, label)
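
For context, a minimal sketch of how the renamed hook is meant to be used: data_process_func (previously batch_data_process_func) receives the raw batch and returns a (data, label) pair. The dataloader keys and the constructor wiring below are assumptions made for illustration, not the library's documented API.

    # Hypothetical preprocessing hook matching the call site in the diff:
    #     data, label = self.data_process_func(batch_data)
    def data_process_func(batch_data):
        # assume the dataloader yields dicts; the key names are made up for this sketch
        data = batch_data['input_ids']
        label = batch_data['targets']
        return data, label

    # assumed wiring; check the schedule's constructor for the actual keyword
    # schedule = NonPipelineSchedule(data_process_func=data_process_func)
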
@@ -141,6 +141,8 @@ class PipelineSchedule(BaseSchedule):
             for element in data:
                 if isinstance(element, dict):
                     data_dict.update({k: v[offset:offset + self.microbatch_size] for k, v in element.items()})
+                elif data_dict:
+                    data_dict['label'] = element[offset:offset + self.microbatch_size]
             if data_dict:
                 return data_dict
         return [val[offset:offset + self.microbatch_size] for val in data]
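
The new elif branch covers batches where a bare tensor rides alongside a dict element: that tensor is now sliced into the micro-batch under the 'label' key. Below is a standalone re-implementation of the slicing logic above, run on a toy batch, purely to illustrate the behaviour (it is not the library API).

    import torch

    def slice_microbatch(data, offset, microbatch_size):
        # standalone copy of the slicing logic in the hunk above, for illustration
        data_dict = {}
        for element in data:
            if isinstance(element, dict):
                data_dict.update({k: v[offset:offset + microbatch_size] for k, v in element.items()})
            elif data_dict:
                # a bare tensor next to a dict element is treated as the label
                data_dict['label'] = element[offset:offset + microbatch_size]
        if data_dict:
            return data_dict
        return [val[offset:offset + microbatch_size] for val in data]

    batch = ({'input_ids': torch.arange(8).reshape(8, 1)}, torch.arange(8))
    micro = slice_microbatch(batch, offset=0, microbatch_size=2)
    # micro -> {'input_ids': tensor([[0], [1]]), 'label': tensor([0, 1])}
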
@@ -175,7 +177,10 @@ class PipelineSchedule(BaseSchedule):
         elif isinstance(data, (list, tuple)):
             return model(*data)
         elif isinstance(data, dict):
-            return model(**data)
+            stage_output = None
+            if 'stage_output' in data:
+                stage_output = data.pop('stage_output')
+            return model(stage_output, **data)
         else:
             raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
@@ -204,41 +209,14 @@ class PipelineSchedule(BaseSchedule):
                 data = stage_output
                 _, label = micro_batch_data
         elif isinstance(micro_batch_data, dict):
-            args = []
             data = {}
-            label = {}
-            # we feed the stage output to args first
-            # then map each arg in args to its param name
-            if stage_output is not None:
-                if isinstance(stage_output, torch.Tensor):
-                    args.append(stage_output)
-                elif isinstance(stage_output, (list, tuple)):
-                    args.extend(stage_output)
-                else:
-                    raise TypeError(
-                        f"Expected the values passed from previous pipeline stage to be torch.Tensor, list or tuple, but got {type(input_obj)}"
-                    )
-            # get all parameter names for the forward function of the model
-            fwd_sig = self._get_actual_forward_func(model)
-            fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
-            # build the kwargs for the forward function
-            for idx, param_name in enumerate(fwd_sig_param_name):
-                if idx < len(args):
-                    data[param_name] = args[idx]
-                else:
-                    if param_name in micro_batch_data:
-                        data[param_name] = micro_batch_data[param_name]
-            # get the tensors for loss
-            loss_sig = inspect.signature(criterion)
-            loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
-            for param_name in loss_sig_param_name:
-                if param_name in micro_batch_data:
-                    label[param_name] = micro_batch_data[param_name]
+            data['stage_output'] = stage_output
+            if 'label' in micro_batch_data:
+                label = micro_batch_data.pop('label')
+            else:
+                label = None
+            load_data = micro_batch_data
+            data.update(load_data)
         return data, label

     def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
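
The rewritten branch drops the signature inspection: instead of mapping the previous stage's output onto the forward parameters by name, it simply stores it under 'stage_output', pops 'label' for the criterion, and forwards the rest of the micro-batch untouched. Below is a standalone trace of that dict path on toy data, for illustration only.

    import torch

    def get_data_label(stage_output, micro_batch_data):
        # standalone trace of the simplified dict branch in the hunk above
        data = {'stage_output': stage_output}
        label = micro_batch_data.pop('label', None)
        data.update(micro_batch_data)
        return data, label

    micro_batch = {'attention_mask': torch.ones(2, 4), 'label': torch.tensor([0, 1])}
    data, label = get_data_label(torch.randn(2, 4), micro_batch)
    # data keys: 'stage_output', 'attention_mask'; label: tensor([0, 1])
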
@@ -66,8 +66,11 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
         modified_args = []
         for arg in args:
             if isinstance(arg, torch.nn.Module):
-                # (lyl)TODO: if nn.Module is an argument of the root module, then we should just record the module instance itself.
-                arg = self._layer_spec_dict[id(arg)]
+                # if nn.Module is an argument of a non-root module, convert it to a layer spec, which makes sure the correct init method is used in the real build.
+                # if nn.Module is an argument of the root module, just record the module instance itself, because those instances have been built outside of the context.
+                if id(arg) in self._layer_spec_dict:
+                    arg = self._layer_spec_dict[id(arg)]
             modified_args.append(arg)
         # do the same for the keyword arguments
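
The guarded lookup means only sub-modules that were registered inside the context (and therefore appear in _layer_spec_dict) are replaced by their layer specs; a module built outside the context, for example one passed into the root module, keeps its original instance. A small hypothetical illustration of that distinction follows; the placeholder spec string stands in for a real layer spec.

    import torch.nn as nn

    def convert_args(args, layer_spec_dict):
        # mirrors the guarded lookup above: only modules recorded inside the
        # context have an id() entry, everything else is kept as-is
        modified_args = []
        for arg in args:
            if isinstance(arg, nn.Module) and id(arg) in layer_spec_dict:
                arg = layer_spec_dict[id(arg)]
            modified_args.append(arg)
        return modified_args

    inner = nn.Linear(4, 4)    # pretend this was built inside the context
    outer = nn.Linear(4, 4)    # pretend this was built outside the context
    specs = {id(inner): '<layer spec for Linear(4, 4)>'}   # placeholder spec
    print(convert_args([inner, outer, 42], specs))
    # -> ['<layer spec for Linear(4, 4)>', Linear(in_features=4, out_features=4, bias=True), 42]
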