获取提交引用时发生错误。请稍后再试。
Update implicit broadcast output shape element type to match operand.
Some binary ops return types whose element type differs from the operand element types (e.g. Ne). While broadcasting, it should use the element type of the operand instead of the inferred type. Before: ``` HloModule UnboundedBinaryOpTest_8.15, entry_computation_layout={(f32[?,10]{1,0}, f32[1]{0})->pred[?,10]{1,0}} ENTRY %UnboundedBinaryOpTest_8.15 (lhs.1: f32[?,10], rhs.2: f32[1]) -> pred[?,10] { %constant.3 = s32[1]{0} constant({1}) %lhs.1 = f32[?,10]{1,0} parameter(0) %get-dimension-size.4 = s32[] get-dimension-size(f32[?,10]{1,0} %lhs.1), dimensions={0} %reshape.5 = s32[1]{0} reshape(s32[] %get-dimension-size.4) %constant.6 = s32[1]{0} constant({10}) %concatenate.9 = s32[2]{0} concatenate(s32[1]{0} %reshape.5, s32[1]{0} %constant.6), dimensions={0} %constant.7 = s32[1]{0} constant({1}) %constant.8 = s32[1]{0} constant({1}) %concatenate.10 = s32[2]{0} concatenate(s32[1]{0} %constant.7, s32[1]{0} %constant.8), dimensions={0} %maximum.11 = s32[2]{0} maximum(s32[2]{0} %concatenate.9, s32[2]{0} %concatenate.10) %custom-call.12 = pred[?,10]{1,0} custom-call(f32[?,10]{1,0} %lhs.1, s32[2]{0} %maximum.11), custom_call_target="mhlo.dynamic_broadcast_in_dim", backend_config={broadcast_dimensions=[0,1]} %rhs.2 = f32[1]{0} parameter(1) %custom-call.13 = pred[?,10]{1,0} custom-call(f32[1]{0} %rhs.2, s32[2]{0} %maximum.11), custom_call_target="mhlo.dynamic_broadcast_in_dim", backend_config={broadcast_dimensions=[1]} ROOT %compare.14 = pred[?,10]{1,0} compare(pred[?,10]{1,0} %custom-call.12, pred[?,10]{1,0} %custom-call.13), direction=NE } ``` After: ``` HloModule UnboundedBinaryOpTest_8.15, entry_computation_layout={(f32[?,10]{1,0}, f32[1]{0})->pred[?,10]{1,0}} ENTRY %UnboundedBinaryOpTest_8.15 (lhs.1: f32[?,10], rhs.2: f32[1]) -> pred[?,10] { %constant.3 = s32[1]{0} constant({1}) %lhs.1 = f32[?,10]{1,0} parameter(0) %get-dimension-size.4 = s32[] get-dimension-size(f32[?,10]{1,0} %lhs.1), dimensions={0} %reshape.5 = s32[1]{0} reshape(s32[] %get-dimension-size.4) %constant.6 = s32[1]{0} constant({10}) %concatenate.9 = s32[2]{0} concatenate(s32[1]{0} %reshape.5, s32[1]{0} %constant.6), dimensions={0} %constant.7 = s32[1]{0} constant({1}) %constant.8 = s32[1]{0} constant({1}) %concatenate.10 = s32[2]{0} concatenate(s32[1]{0} %constant.7, s32[1]{0} %constant.8), dimensions={0} %maximum.11 = s32[2]{0} maximum(s32[2]{0} %concatenate.9, s32[2]{0} %concatenate.10) %custom-call.12 = f32[?,10]{1,0} custom-call(f32[?,10]{1,0} %lhs.1, s32[2]{0} %maximum.11), custom_call_target="mhlo.dynamic_broadcast_in_dim", backend_config={broadcast_dimensions=[0,1]} %rhs.2 = f32[1]{0} parameter(1) %custom-call.13 = f32[?,10]{1,0} custom-call(f32[1]{0} %rhs.2, s32[2]{0} %maximum.11), custom_call_target="mhlo.dynamic_broadcast_in_dim", backend_config={broadcast_dimensions=[1]} ROOT %compare.14 = pred[?,10]{1,0} compare(f32[?,10]{1,0} %custom-call.12, f32[?,10]{1,0} %custom-call.13), direction=NE } ``` Notice the difference in `mhlo.dynamic_broadcast_in_dim` custom call return type `pred[?,10]` vs `f32[?,10]`. The HLO after the change correctly returns unchanged element type from the operand. PiperOrigin-RevId: 617345816
想要评论请 注册 或 登录