diff options
author | Diego Caballero <dieg0ca6aller0@gmail.com> | 2024-09-19 10:17:13 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-09-19 10:17:13 -0700 |
commit | bcd65ba6129bea92485432fdd09874bc3fc6671e (patch) | |
tree | d3442acda945a47d641dbe0418bd1bf6739bcaa9 | |
parent | 8a34f6dba14e49332ff63abfa6a8aa3ca560fc50 (diff) | |
download | llvm-bcd65ba6129bea92485432fdd09874bc3fc6671e.zip llvm-bcd65ba6129bea92485432fdd09874bc3fc6671e.tar.gz llvm-bcd65ba6129bea92485432fdd09874bc3fc6671e.tar.bz2 |
[mlir][Vector] Verify that masked ops implement MaskableOpInterface (#108123)
This PR fixes a bug in `MaskOp::verifier` that allowed `vector.mask` to
mask operations that did not implement the MaskableOpInterface.
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 11 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/canonicalize.mlir | 12 | ||||
-rw-r--r-- | mlir/test/Dialect/Vector/invalid.mlir | 8 |
3 files changed, 23 insertions, 8 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8164477..1438ddd 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6131,7 +6131,9 @@ LogicalResult MaskOp::verify() { Block &block = getMaskRegion().getBlocks().front(); if (block.getOperations().empty()) return emitOpError("expects a terminator within the mask region"); - if (block.getOperations().size() > 2) + + unsigned numMaskRegionOps = block.getOperations().size(); + if (numMaskRegionOps > 2) return emitOpError("expects only one operation to mask"); // Terminator checks. @@ -6143,11 +6145,14 @@ LogicalResult MaskOp::verify() { return emitOpError( "expects number of results to match mask region yielded values"); - auto maskableOp = dyn_cast<MaskableOpInterface>(block.front()); // Empty vector.mask. Nothing else to check. - if (!maskableOp) + if (numMaskRegionOps == 1) return success(); + auto maskableOp = dyn_cast<MaskableOpInterface>(block.front()); + if (!maskableOp) + return emitOpError("expects a MaskableOpInterface within the mask region"); + // Result checks. if (maskableOp->getNumResults() != getNumResults()) return emitOpError("expects number of results to match maskable operation " diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index e71a6eb..b7c78de 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2471,13 +2471,15 @@ func.func @empty_vector_mask_with_return(%a : vector<8xf32>, %mask : vector<8xi1 // ----- // CHECK-LABEL: func @all_true_vector_mask -// CHECK-SAME: %[[IN:.*]]: vector<3x4xf32> -func.func @all_true_vector_mask(%a : vector<3x4xf32>) -> vector<3x4xf32> { +// CHECK-SAME: %[[IN:.*]]: tensor<3x4xf32> +func.func @all_true_vector_mask(%ta : tensor<3x4xf32>) -> vector<3x4xf32> { // CHECK-NOT: vector.mask -// CHECK: %[[ADD:.*]] = arith.addf %[[IN]], %[[IN]] : vector<3x4xf32> -// CHECK: return %[[ADD]] : vector<3x4xf32> +// CHECK: %[[LD:.*]] = vector.transfer_read %[[IN]] +// CHECK: return %[[LD]] : vector<3x4xf32> + %c0 = arith.constant 0 : index + %cf0 = arith.constant 0.0 : f32 %all_true = vector.constant_mask [3, 4] : vector<3x4xi1> - %0 = vector.mask %all_true { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32> + %0 = vector.mask %all_true { vector.transfer_read %ta[%c0, %c0], %cf0 : tensor<3x4xf32>, vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32> return %0 : vector<3x4xf32> } diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir index c95b8bd5..e2bc5ef 100644 --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -1724,6 +1724,14 @@ func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor<?xf3 vector.mask %m0, %pt0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> vector<16xf32> return } +// ----- + +func.func @vector_mask_non_maskable_op(%a : vector<3x4xf32>) -> vector<3x4xf32> { + %m0 = vector.constant_mask [2, 2] : vector<3x4xi1> + // expected-error@+1 {{'vector.mask' op expects a MaskableOpInterface within the mask region}} + %0 = vector.mask %m0 { arith.addf %a, %a : vector<3x4xf32> } : vector<3x4xi1> -> vector<3x4xf32> + return %0 : vector<3x4xf32> +} // ----- |