Skip to content
代码片段 群组 项目
提交 92374260 编辑于 作者: Gunhyun Park's avatar Gunhyun Park 提交者: TensorFlower Gardener
浏览文件

Update implicit broadcast output shape element type to match operand.

Some binary ops return types whose element type differs from the operand element types (e.g. Ne). While broadcasting, it should use the element type of the operand instead of the inferred type.

Before:
```
HloModule UnboundedBinaryOpTest_8.15, entry_computation_layout={(f32[?,10]{1,0}, f32[1]{0})->pred[?,10]{1,0}}

ENTRY %UnboundedBinaryOpTest_8.15 (lhs.1: f32[?,10], rhs.2: f32[1]) -> pred[?,10] {
  %constant.3 = s32[1]{0} constant({1})
  %lhs.1 = f32[?,10]{1,0} parameter(0)
  %get-dimension-size.4 = s32[] get-dimension-size(f32[?,10]{1,0} %lhs.1), dimensions={0}
  %reshape.5 = s32[1]{0} reshape(s32[] %get-dimension-size.4)
  %constant.6 = s32[1]{0} constant({10})
  %concatenate.9 = s32[2]{0} concatenate(s32[1]{0} %reshape.5, s32[1]{0} %constant.6), dimensions={0}
  %constant.7 = s32[1]{0} constant({1})
  %constant.8 = s32[1]{0} constant({1})
  %concatenate.10 = s32[2]{0} concatenate(s32[1]{0} %constant.7, s32[1]{0} %constant.8), dimensions={0}
  %maximum.11 = s32[2]{0} maximum(s32[2]{0} %concatenate.9, s32[2]{0} %concatenate.10)
  %custom-call.12 = pred[?,10]{1,0} custom-call(f32[?,10]{1,0} %lhs.1, s32[2]{0} %maximum.11), custom_call_target="mhlo.dynamic_broadcast_in_dim", backend_config={broadcast_dimensions=[0,1]}
  %rhs.2 = f32[1]{0} parameter(1)
  %custom-call.13 = pred[?,10]{1,0} custom-call(f32[1]{0} %rhs.2, s32[2]{0} %maximum.11), custom_call_target="mhlo.dynamic_broadcast_in_dim", backend_config={broadcast_dimensions=[1]}
  ROOT %compare.14 = pred[?,10]{1,0} compare(pred[?,10]{1,0} %custom-call.12, pred[?,10]{1,0} %custom-call.13), direction=NE
}
```

After:
```
HloModule UnboundedBinaryOpTest_8.15, entry_computation_layout={(f32[?,10]{1,0}, f32[1]{0})->pred[?,10]{1,0}}

ENTRY %UnboundedBinaryOpTest_8.15 (lhs.1: f32[?,10], rhs.2: f32[1]) -> pred[?,10] {
  %constant.3 = s32[1]{0} constant({1})
  %lhs.1 = f32[?,10]{1,0} parameter(0)
  %get-dimension-size.4 = s32[] get-dimension-size(f32[?,10]{1,0} %lhs.1), dimensions={0}
  %reshape.5 = s32[1]{0} reshape(s32[] %get-dimension-size.4)
  %constant.6 = s32[1]{0} constant({10})
  %concatenate.9 = s32[2]{0} concatenate(s32[1]{0} %reshape.5, s32[1]{0} %constant.6), dimensions={0}
  %constant.7 = s32[1]{0} constant({1})
  %constant.8 = s32[1]{0} constant({1})
  %concatenate.10 = s32[2]{0} concatenate(s32[1]{0} %constant.7, s32[1]{0} %constant.8), dimensions={0}
  %maximum.11 = s32[2]{0} maximum(s32[2]{0} %concatenate.9, s32[2]{0} %concatenate.10)
  %custom-call.12 = f32[?,10]{1,0} custom-call(f32[?,10]{1,0} %lhs.1, s32[2]{0} %maximum.11), custom_call_target="mhlo.dynamic_broadcast_in_dim", backend_config={broadcast_dimensions=[0,1]}
  %rhs.2 = f32[1]{0} parameter(1)
  %custom-call.13 = f32[?,10]{1,0} custom-call(f32[1]{0} %rhs.2, s32[2]{0} %maximum.11), custom_call_target="mhlo.dynamic_broadcast_in_dim", backend_config={broadcast_dimensions=[1]}
  ROOT %compare.14 = pred[?,10]{1,0} compare(f32[?,10]{1,0} %custom-call.12, f32[?,10]{1,0} %custom-call.13), direction=NE
}
```

Notice the difference in `mhlo.dynamic_broadcast_in_dim` custom call return type `pred[?,10]` vs `f32[?,10]`. The HLO after the change correctly returns unchanged element type from the operand.

PiperOrigin-RevId: 617345816
上级 76e23e8e
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册