aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/MemRef
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/MemRef')
-rw-r--r--mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp29
2 files changed, 28 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index a15bf89..6fa8ce4 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -66,7 +66,7 @@ struct ExpandShapeOpInterface
ValueBoundsConstraintSet &cstr) const {
auto expandOp = cast<memref::ExpandShapeOp>(op);
assert(value == expandOp.getResult() && "invalid value");
- cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+ cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
}
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 291da1f..14152c5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
using namespace mlir;
@@ -273,7 +274,9 @@ struct SubViewOpInterface
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
auto metadataOp =
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
- for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+ for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+ // Reset insertion point to before the operation for each dimension
+ builder.setInsertionPoint(subView);
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(builder, loc,
@@ -290,6 +293,16 @@ struct SubViewOpInterface
std::to_string(i) +
" is out-of-bounds"));
+ // Only verify if size > 0
+ Value sizeIsNonZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+ auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
+ sizeIsNonZero, /*withElseRegion=*/true);
+
+ // Populate the "then" region (for size > 0).
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
@@ -298,8 +311,20 @@ struct SubViewOpInterface
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
+
+ scf::YieldOp::create(builder, loc, lastPosInBounds);
+
+ // Populate the "else" region (for size == 0).
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ Value trueVal =
+ arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
+ scf::YieldOp::create(builder, loc, trueVal);
+
+ builder.setInsertionPointAfter(ifOp);
+ Value finalCondition = ifOp.getResult(0);
+
cf::AssertOp::create(
- builder, loc, lastPosInBounds,
+ builder, loc, finalCondition,
generateErrorMessage(op,
"subview runs out-of-bounds along dimension " +
std::to_string(i)));