author     Zhaoshi Zheng <zhaoshiz@quicinc.com>    2024-04-25 13:54:47 -0700
committer  GitHub <noreply@github.com>             2024-04-25 13:54:47 -0700
commit     dbcc4549e6b75ff328256e3d914763c9a74b2635 (patch)
tree       9e3859a83a445eb0a4f967ad6d177f698d4efd5b
parent     45b59cb1d42b57a40c79a61afc8d1b8892826480 (diff)
[MLIR][Vector] Allow Scalable Dim in OneDimMultiReductionToTwoDim (#89978)
To correctly lower multi_reduction of a 1-dim scalable vector, e.g., vector<[4]xf32>, the casted 2-D vector and mask types must preserve the scalable dimension.
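For context, a minimal C++ sketch of the type construction the patch relies on: MLIR's three-argument VectorType::get, which accepts per-dimension scalability flags. The helper name makeCastedType and the standalone setup below are illustrative only, not part of the patch; the call mirrors the new code in OneDimMultiReductionToTwoDim, turning e.g. vector<[4]xf32> into vector<1x[4]xf32> instead of silently dropping the scalable flag.

// Illustrative sketch (hypothetical helper, not from the patch): build the
// 2-D cast type while preserving the trailing dimension's scalability.
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

static VectorType makeCastedType(VectorType srcVectorType) {
  // For vector<[4]xf32> this yields vector<1x[4]xf32>: the leading unit dim
  // stays fixed, the trailing dim keeps the source's scalable flag.
  return VectorType::get(
      ArrayRef<int64_t>{1, srcVectorType.getShape().back()},
      srcVectorType.getElementType(),
      ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
}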
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp | 15 +++++++++------
-rw-r--r--  mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir    | 17 +++++++++++++++++
2 files changed, 26 insertions(+), 6 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 2f21c50..ac576ed0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -437,8 +437,10 @@ struct OneDimMultiReductionToTwoDim
     auto loc = multiReductionOp.getLoc();
     auto srcVectorType = multiReductionOp.getSourceVectorType();
     auto srcShape = srcVectorType.getShape();
-    auto castedType = VectorType::get(ArrayRef<int64_t>{1, srcShape.back()},
-                                      srcVectorType.getElementType());
+    auto castedType = VectorType::get(
+        ArrayRef<int64_t>{1, srcShape.back()}, srcVectorType.getElementType(),
+        ArrayRef<bool>{false, srcVectorType.getScalableDims().back()});
+
     auto accType = VectorType::get(ArrayRef<int64_t>{1},
                                    srcVectorType.getElementType());
     assert(!llvm::isa<VectorType>(multiReductionOp.getDestType()) &&
@@ -455,10 +457,11 @@ struct OneDimMultiReductionToTwoDim
         loc, accType, multiReductionOp.getAcc());
     Value castMask;
     if (maskableOp.isMasked()) {
-      auto maskType = llvm::cast<ShapedType>(mask.getType());
-      auto castMaskType =
-          VectorType::get(ArrayRef<int64_t>{1, maskType.getShape().back()},
-                          maskType.getElementType());
+      auto maskType = llvm::cast<VectorType>(mask.getType());
+      auto castMaskType = VectorType::get(
+          ArrayRef<int64_t>{1, maskType.getShape().back()},
+          maskType.getElementType(),
+          ArrayRef<bool>{false, maskType.getScalableDims().back()});
       castMask = rewriter.create<vector::BroadcastOp>(loc, castMaskType, mask);
     }
 
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
index 22808aa..f70d23a 100644
--- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir
@@ -281,6 +281,23 @@ func.func private @scalable_dims(%A : vector<8x[4]x2xf32>, %B: vector<8x[4]xf32>
 // CHECK: %[[VAL_163:.*]] = vector.shape_cast %[[VAL_162]] : vector<[32]xf32> to vector<8x[4]xf32>
 // CHECK: return %[[VAL_163]] : vector<8x[4]xf32>
 
+// Check that OneDimMultiReductionToTwoDim handles scalable dim
+func.func @scalable_dim_1d(%A: vector<[4]xf32>, %B: f32, %C: vector<[4]xi1>) -> f32 {
+  %0 = vector.mask %C { vector.multi_reduction <add>, %A, %B [0] : vector<[4]xf32> to f32 } : vector<[4]xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @scalable_dim_1d(
+// CHECK-SAME:    %[[ARG_0:.*]]: vector<[4]xf32>,
+// CHECK-SAME:    %[[ARG_1:.*]]: f32,
+// CHECK-SAME:    %[[ARG_2:.*]]: vector<[4]xi1>) -> f32 {
+// CHECK-DAG:     %[[VAL_0:.*]] = arith.constant 0 : index
+// CHECK-DAG:     %[[VAL_1:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32>
+// CHECK:         %[[VAL_2:.*]] = vector.mask %[[ARG_2]] { vector.reduction <add>, %[[ARG_0]], %[[ARG_1]] : vector<[4]xf32> into f32 } : vector<[4]xi1> -> f32
+// CHECK:         %[[VAL_3:.*]] = vector.insertelement %[[VAL_2]], %[[VAL_1]][%[[VAL_0]] : index] : vector<1xf32>
+// CHECK:         %[[VAL_4:.*]] = vector.extract %[[VAL_3]][0] : f32 from vector<1xf32>
+// CHECK:         return %[[VAL_4]] : f32
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%root : !transform.any_op {transform.readonly}) {
     %func_op = transform.structured.match ops{["func.func"]} in %root : (!transform.any_op) -> !transform.op<"func.func">