diff options
Diffstat (limited to 'mlir/lib/Dialect/MemRef')
| -rw-r--r-- | mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp | 29 |
2 files changed, 28 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp index a15bf89..6fa8ce4 100644 --- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp @@ -66,7 +66,7 @@ struct ExpandShapeOpInterface ValueBoundsConstraintSet &cstr) const { auto expandOp = cast<memref::ExpandShapeOp>(op); assert(value == expandOp.getResult() && "invalid value"); - cstr.bound(value)[dim] == expandOp.getOutputShape()[dim]; + cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim]; } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index 291da1f..14152c5 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Interfaces/RuntimeVerifiableOpInterface.h" using namespace mlir; @@ -273,7 +274,9 @@ struct SubViewOpInterface Value one = arith::ConstantIndexOp::create(builder, loc, 1); auto metadataOp = ExtractStridedMetadataOp::create(builder, loc, subView.getSource()); - for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { + for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) { + // Reset insertion point to before the operation for each dimension + builder.setInsertionPoint(subView); Value offset = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedOffsets()[i]); Value size = getValueOrCreateConstantIndexOp(builder, loc, @@ -290,6 +293,16 @@ struct SubViewOpInterface std::to_string(i) + " is out-of-bounds")); + // Only verify if size > 0 + Value sizeIsNonZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::sgt, size, zero); + + auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(), + sizeIsNonZero, /*withElseRegion=*/true); + + // Populate the "then" region (for size > 0). + builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + // Verify that slice does not run out-of-bounds. Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); Value sizeMinusOneTimesStride = @@ -298,8 +311,20 @@ struct SubViewOpInterface arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); + + scf::YieldOp::create(builder, loc, lastPosInBounds); + + // Populate the "else" region (for size == 0). + builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); + Value trueVal = + arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); + scf::YieldOp::create(builder, loc, trueVal); + + builder.setInsertionPointAfter(ifOp); + Value finalCondition = ifOp.getResult(0); + cf::AssertOp::create( - builder, loc, lastPosInBounds, + builder, loc, finalCondition, generateErrorMessage(op, "subview runs out-of-bounds along dimension " + std::to_string(i))); |
