Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp | 16
-rw-r--r--  mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 50
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp | 9
-rw-r--r--  mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Async/IR/Async.cpp | 11
-rw-r--r--  mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp | 11
-rw-r--r--  mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 8
-rw-r--r--  mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 2
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 2
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp | 25
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp | 29
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp | 58
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp | 1
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp | 1
-rw-r--r--  mlir/lib/Dialect/SCF/Utils/Utils.cpp | 145
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 17
-rw-r--r--  mlir/lib/Dialect/Shape/IR/Shape.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp | 16
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp | 30
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 7
-rw-r--r--  mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 37
-rw-r--r--  mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp | 5
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 7
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp | 7
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 2
26 files changed, 374 insertions(+), 134 deletions(-)
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 4c4965e..df955fc 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -399,13 +399,15 @@ LogicalResult WMMAOp::verify() {
if (!sourceAElemType.isFloat(8) && sourceAElemType != sourceBElemType) {
return emitOpError(
- "source element types much match (except for fp8) but have ")
+ "source element types must match (except for fp8/bf8) but have ")
<< sourceAType << " and " << sourceBType;
}
- if (!sourceAElemType.isInteger(4) && getK() != 16) {
- return emitOpError("K dimension must be 16 for source element type ")
- << sourceAElemType;
+ if (isSrcFloat) {
+ if (getClamp())
+ return emitOpError("clamp flag is not supported for float types");
+ if (getUnsignedA() || getUnsignedB())
+ return emitOpError("unsigned flags are not supported for float types");
}
return success();
}
@@ -422,11 +424,11 @@ LogicalResult MFMAOp::verify() {
Type sourceElem = sourceType, destElem = destType;
uint32_t sourceLen = 1, destLen = 1;
- if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
+ if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
sourceLen = sourceVector.getNumElements();
sourceElem = sourceVector.getElementType();
}
- if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
+ if (auto destVector = dyn_cast<VectorType>(destType)) {
destLen = destVector.getNumElements();
destElem = destVector.getElementType();
}
@@ -451,7 +453,7 @@ LogicalResult MFMAOp::verify() {
return emitOpError("expected both non-small-float source operand types "
"to match exactly");
}
- // Normalize the wider integer types the compiler expects to i8
+ // Normalize the wider integer types the compiler expects to i8.
if (sourceElem.isInteger(32)) {
sourceLen *= 4;
sourceElem = b.getI8Type();
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index e0a53cd..0c35921 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2716,8 +2716,9 @@ LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
return success(folded);
}
-OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert((point.isParent() || point == getRegion()) && "invalid region point");
+OperandRange AffineForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert((successor.isParent() || successor.getSuccessor() == &getRegion()) &&
+ "invalid region point");
// The initial operands map to the loop arguments after the induction
// variable or are forwarded to the results when the trip count is zero.
@@ -2726,34 +2727,41 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
void AffineForOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
- assert((point.isParent() || point == getRegion()) && "expected loop region");
+ assert((point.isParent() ||
+ point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getRegion()) &&
+ "expected loop region");
// The loop may typically branch back to its body or to the parent operation.
// If the predecessor is the parent op and the trip count is known to be at
// least one, branch into the body using the iterator arguments. And in cases
// we know the trip count is zero, it can only branch back to its parent.
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
- if (point.isParent() && tripCount.has_value()) {
- if (tripCount.value() > 0) {
- regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- return;
- }
- if (tripCount.value() == 0) {
- regions.push_back(RegionSuccessor(getResults()));
- return;
+ if (tripCount.has_value()) {
+ if (!point.isParent()) {
+ // From the loop body, if the trip count is one, we can only branch back
+ // to the parent.
+ if (tripCount == 1) {
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ return;
+ }
+ if (tripCount == 0)
+ return;
+ } else {
+ if (tripCount.value() > 0) {
+ regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
+ return;
+ }
+ if (tripCount.value() == 0) {
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
+ return;
+ }
}
}
- // From the loop body, if the trip count is one, we can only branch back to
- // the parent.
- if (!point.isParent() && tripCount == 1) {
- regions.push_back(RegionSuccessor(getResults()));
- return;
- }
-
// In all other cases, the loop may branch back to itself or the parent
// operation.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
AffineBound AffineForOp::getLowerBound() {
@@ -3142,7 +3150,7 @@ void AffineIfOp::getSuccessorRegions(
RegionSuccessor(&getThenRegion(), getThenRegion().getArguments()));
// If the "else" region is empty, branch bach into parent.
if (getElseRegion().empty()) {
- regions.push_back(getResults());
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
} else {
regions.push_back(
RegionSuccessor(&getElseRegion(), getElseRegion().getArguments()));
@@ -3152,7 +3160,7 @@ void AffineIfOp::getSuccessorRegions(
// If the predecessor is the `else`/`then` region, then branching into parent
// op is valid.
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
LogicalResult AffineIfOp::verify() {
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 50a0f3d..e08cc6f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -978,12 +978,11 @@ static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp,
LLVM_DEBUG(
dbgs() << "\n[early-vect]+++++ affine.apply on vector operand\n");
return nullptr;
- } else {
- Value updatedOperand = state.valueScalarReplacement.lookupOrNull(operand);
- if (!updatedOperand)
- updatedOperand = operand;
- updatedOperands.push_back(updatedOperand);
}
+ Value updatedOperand = state.valueScalarReplacement.lookupOrNull(operand);
+ if (!updatedOperand)
+ updatedOperand = operand;
+ updatedOperands.push_back(updatedOperand);
}
auto newApplyOp = AffineApplyOp::create(
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index d925c19..a651710 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -216,8 +216,8 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
for (auto condBranch : worklist) {
auto loc = condBranch.getLoc();
Block *block = condBranch->getBlock();
- auto newTrueBranch = rewriter.splitBlock(block, block->end());
- auto newFalseBranch = rewriter.splitBlock(block, block->end());
+ auto *newTrueBranch = rewriter.splitBlock(block, block->end());
+ auto *newFalseBranch = rewriter.splitBlock(block, block->end());
insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
condBranch.getTrueDestOperands());
insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
@@ -382,7 +382,7 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
// Find or create a live range for `value`.
auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
LiveRange &valueLiveRange = it->second;
- auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+ auto *lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
// Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
unsigned startOpIdx =
operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp
index dc7b07d..8e4a49d 100644
--- a/mlir/lib/Dialect/Async/IR/Async.cpp
+++ b/mlir/lib/Dialect/Async/IR/Async.cpp
@@ -36,8 +36,9 @@ void AsyncDialect::initialize() {
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";
-OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBodyRegion() && "invalid region index");
+OperandRange ExecuteOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBodyRegion() &&
+ "invalid region index");
return getBodyOperands();
}
@@ -53,8 +54,10 @@ bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `body` region branch back to the parent operation.
- if (point == getBodyRegion()) {
- regions.push_back(RegionSuccessor(getBodyResults()));
+ if (!point.isParent() &&
+ point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getBodyRegion()) {
+ regions.push_back(RegionSuccessor(getOperation(), getBodyResults()));
return;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index b593cca..36a759c 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -562,8 +562,11 @@ LogicalResult
BufferDeallocation::updateFunctionSignature(FunctionOpInterface op) {
SmallVector<TypeRange> returnOperandTypes(llvm::map_range(
op.getFunctionBody().getOps<RegionBranchTerminatorOpInterface>(),
- [](RegionBranchTerminatorOpInterface op) {
- return op.getSuccessorOperands(RegionBranchPoint::parent()).getTypes();
+ [&](RegionBranchTerminatorOpInterface branchOp) {
+ return branchOp
+ .getSuccessorOperands(RegionSuccessor(
+ op.getOperation(), op.getOperation()->getResults()))
+ .getTypes();
}));
if (!llvm::all_equal(returnOperandTypes))
return op->emitError(
@@ -942,8 +945,8 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
// about, but we would need to check how many successors there are and under
// which condition they are taken, etc.
- MutableOperandRange operands =
- op.getMutableSuccessorOperands(RegionBranchPoint::parent());
+ MutableOperandRange operands = op.getMutableSuccessorOperands(
+ RegionSuccessor(op.getOperation(), op.getOperation()->getResults()));
SmallVector<Value> updatedOwnerships;
auto result = deallocation_impl::insertDeallocOpForReturnLike(
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4754f0b..0992ce14 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -845,7 +845,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
// The `then` and the `else` region branch back to the parent operation.
if (!point.isParent()) {
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
return;
}
@@ -854,7 +855,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
// Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
else
regions.push_back(RegionSuccessor(elseRegion));
}
@@ -871,7 +873,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index b5f8dda..6c6d8d2 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2399,7 +2399,7 @@ ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
void WarpExecuteOnLane0Op::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index c551fba..1c21a2f 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -405,7 +405,7 @@ ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
void AllocaScopeOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index a15bf89..3aa801b 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -66,7 +66,7 @@ struct ExpandShapeOpInterface
ValueBoundsConstraintSet &cstr) const {
auto expandOp = cast<memref::ExpandShapeOp>(op);
assert(value == expandOp.getResult() && "invalid value");
- cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+ cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
}
};
@@ -98,6 +98,27 @@ struct RankOpInterface
}
};
+struct CollapseShapeOpInterface
+ : public ValueBoundsOpInterface::ExternalModel<CollapseShapeOpInterface,
+ memref::CollapseShapeOp> {
+ void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
+ ValueBoundsConstraintSet &cstr) const {
+ auto collapseOp = cast<memref::CollapseShapeOp>(op);
+ assert(value == collapseOp.getResult() && "invalid value");
+
+ // Multiply the expressions for the dimensions in the reassociation group.
+ const ReassociationIndices &reassocIndices =
+ collapseOp.getReassociationIndices()[dim];
+ AffineExpr productExpr =
+ cstr.getExpr(collapseOp.getSrc(), reassocIndices[0]);
+ for (size_t i = 1; i < reassocIndices.size(); ++i) {
+ productExpr =
+ productExpr * cstr.getExpr(collapseOp.getSrc(), reassocIndices[i]);
+ }
+ cstr.bound(value)[dim] == productExpr;
+ }
+};
+
struct SubViewOpInterface
: public ValueBoundsOpInterface::ExternalModel<SubViewOpInterface,
SubViewOp> {
@@ -134,6 +155,8 @@ void mlir::memref::registerValueBoundsOpInterfaceExternalModels(
memref::AllocOpInterface<memref::AllocaOp>>(*ctx);
memref::CastOp::attachInterface<memref::CastOpInterface>(*ctx);
memref::DimOp::attachInterface<memref::DimOpInterface>(*ctx);
+ memref::CollapseShapeOp::attachInterface<memref::CollapseShapeOpInterface>(
+ *ctx);
memref::ExpandShapeOp::attachInterface<memref::ExpandShapeOpInterface>(
*ctx);
memref::GetGlobalOp::attachInterface<memref::GetGlobalOpInterface>(*ctx);
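The CollapseShapeOpInterface added above equates each collapsed result dimension with the product of the source dimensions in its reassociation group. A minimal illustration of the relationship it encodes (hypothetical IR, not part of this patch; shapes and names are made up):

  func.func @collapse_bound(%m: memref<?x4x?xf32>) -> index {
    %c0 = arith.constant 0 : index
    %collapsed = memref.collapse_shape %m [[0, 1], [2]]
        : memref<?x4x?xf32> into memref<?x?xf32>
    // With the new interface, dim 0 of %collapsed is known to equal
    // dim(%m, 0) * 4, the product over its reassociation group {0, 1}.
    %d = memref.dim %collapsed, %c0 : memref<?x?xf32>
    return %d : index
  }
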
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 291da1f..14152c5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
using namespace mlir;
@@ -273,7 +274,9 @@ struct SubViewOpInterface
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
auto metadataOp =
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
- for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+ for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+ // Reset insertion point to before the operation for each dimension
+ builder.setInsertionPoint(subView);
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(builder, loc,
@@ -290,6 +293,16 @@ struct SubViewOpInterface
std::to_string(i) +
" is out-of-bounds"));
+ // Only verify if size > 0
+ Value sizeIsNonZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+ auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
+ sizeIsNonZero, /*withElseRegion=*/true);
+
+ // Populate the "then" region (for size > 0).
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
@@ -298,8 +311,20 @@ struct SubViewOpInterface
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
+
+ scf::YieldOp::create(builder, loc, lastPosInBounds);
+
+ // Populate the "else" region (for size == 0).
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ Value trueVal =
+ arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
+ scf::YieldOp::create(builder, loc, trueVal);
+
+ builder.setInsertionPointAfter(ifOp);
+ Value finalCondition = ifOp.getResult(0);
+
cf::AssertOp::create(
- builder, loc, lastPosInBounds,
+ builder, loc, finalCondition,
generateErrorMessage(op,
"subview runs out-of-bounds along dimension " +
std::to_string(i)));
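With this change the per-dimension out-of-bounds assertion is only evaluated when the slice size is positive; zero-sized slices trivially pass. A rough hand-written sketch of the guard now emitted around each check (names are illustrative, and generateInBoundsCheck is expanded here as two compares plus an and); the same guard is introduced for tensor.extract_slice further below:

  func.func @guarded_check(%offset: index, %size: index, %stride: index,
                           %dim_size: index) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    // Only verify if size > 0.
    %nonzero = arith.cmpi sgt, %size, %c0 : index
    %ok = scf.if %nonzero -> (i1) {
      %size_m1 = arith.subi %size, %c1 : index
      %span = arith.muli %size_m1, %stride : index
      %last = arith.addi %offset, %span : index
      %ge = arith.cmpi sge, %last, %c0 : index
      %lt = arith.cmpi slt, %last, %dim_size : index
      %in_bounds = arith.andi %ge, %lt : i1
      scf.yield %in_bounds : i1
    } else {
      %true = arith.constant true
      scf.yield %true : i1
    }
    cf.assert %ok, "subview runs out-of-bounds along dimension 0"
    return
  }
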
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 744a595..2946b53 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -111,10 +111,8 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
return nullptr;
}
-/// Helper function to compute the difference between two values. This is used
-/// by the loop implementations to compute the trip count.
-static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
- bool isSigned) {
+std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
+ bool isSigned) {
llvm::APSInt diff;
auto addOp = ub.getDefiningOp<arith::AddIOp>();
if (!addOp)
@@ -399,7 +397,7 @@ void ExecuteRegionOp::getSuccessorRegions(
}
// Otherwise, the region branches back to the parent operation.
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
//===----------------------------------------------------------------------===//
@@ -407,10 +405,11 @@ void ExecuteRegionOp::getSuccessorRegions(
//===----------------------------------------------------------------------===//
MutableOperandRange
-ConditionOp::getMutableSuccessorOperands(RegionBranchPoint point) {
- assert((point.isParent() || point == getParentOp().getAfter()) &&
- "condition op can only exit the loop or branch to the after"
- "region");
+ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
+ assert(
+ (point.isParent() || point.getSuccessor() == &getParentOp().getAfter()) &&
+ "condition op can only exit the loop or branch to the after"
+ "region");
// Pass all operands except the condition to the successor region.
return getArgsMutable();
}
@@ -428,7 +427,7 @@ void ConditionOp::getSuccessorRegions(
regions.emplace_back(&whileOp.getAfter(),
whileOp.getAfter().getArguments());
if (!boolAttr || !boolAttr.getValue())
- regions.emplace_back(whileOp.getResults());
+ regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
}
//===----------------------------------------------------------------------===//
@@ -751,7 +750,7 @@ ForOp mlir::scf::getForInductionVarOwner(Value val) {
return dyn_cast_or_null<ForOp>(containingOp);
}
-OperandRange ForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+OperandRange ForOp::getEntrySuccessorOperands(RegionSuccessor successor) {
return getInitArgs();
}
@@ -761,7 +760,7 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
// back into the operation itself. It is possible for loop not to enter the
// body.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
@@ -2055,9 +2054,10 @@ void ForallOp::getSuccessorRegions(RegionBranchPoint point,
// parallel by multiple threads. We should not expect to branch back into
// the forall body after the region's execution is complete.
if (point.isParent())
- regions.push_back(RegionSuccessor(&getRegion()));
+ regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
else
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
}
//===----------------------------------------------------------------------===//
@@ -2335,9 +2335,10 @@ void IfOp::print(OpAsmPrinter &p) {
void IfOp::getSuccessorRegions(RegionBranchPoint point,
SmallVectorImpl<RegionSuccessor> &regions) {
- // The `then` and the `else` region branch back to the parent operation.
+ // The `then` and the `else` region branch back to the parent operation or one
+ // of the recursive parent operations (early exit case).
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
@@ -2346,7 +2347,8 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
// Don't consider the else region if it is empty.
Region *elseRegion = &this->getElseRegion();
if (elseRegion->empty())
- regions.push_back(RegionSuccessor());
+ regions.push_back(
+ RegionSuccessor(getOperation(), getOperation()->getResults()));
else
regions.push_back(RegionSuccessor(elseRegion));
}
@@ -2363,7 +2365,7 @@ void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
if (!getElseRegion().empty())
regions.emplace_back(&getElseRegion());
else
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getResults());
}
}
@@ -3387,7 +3389,8 @@ void ParallelOp::getSuccessorRegions(
// back into the operation itself. It is possible for loop not to enter the
// body.
regions.push_back(RegionSuccessor(&getRegion()));
- regions.push_back(RegionSuccessor());
+ regions.push_back(RegionSuccessor(
+ getOperation(), ResultRange{getResults().end(), getResults().end()}));
}
//===----------------------------------------------------------------------===//
@@ -3433,7 +3436,7 @@ LogicalResult ReduceOp::verifyRegions() {
}
MutableOperandRange
-ReduceOp::getMutableSuccessorOperands(RegionBranchPoint point) {
+ReduceOp::getMutableSuccessorOperands(RegionSuccessor point) {
// No operands are forwarded to the next iteration.
return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
}
@@ -3516,8 +3519,8 @@ Block::BlockArgListType WhileOp::getRegionIterArgs() {
return getBeforeArguments();
}
-OperandRange WhileOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBefore() &&
+OperandRange WhileOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBefore() &&
"WhileOp is expected to branch only to the first region");
return getInits();
}
@@ -3530,15 +3533,18 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,
return;
}
- assert(llvm::is_contained({&getAfter(), &getBefore()}, point) &&
+ assert(llvm::is_contained(
+ {&getAfter(), &getBefore()},
+ point.getTerminatorPredecessorOrNull()->getParentRegion()) &&
"there are only two regions in a WhileOp");
// The body region always branches back to the condition region.
- if (point == getAfter()) {
+ if (point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getAfter()) {
regions.emplace_back(&getBefore(), getBefore().getArguments());
return;
}
- regions.emplace_back(getResults());
+ regions.emplace_back(getOperation(), getResults());
regions.emplace_back(&getAfter(), getAfter().getArguments());
}
@@ -4447,7 +4453,7 @@ void IndexSwitchOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &successors) {
// All regions branch back to the parent op.
if (!point.isParent()) {
- successors.emplace_back(getResults());
+ successors.emplace_back(getOperation(), getResults());
return;
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index ae52af5..ddcbda8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -23,7 +23,6 @@ namespace mlir {
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
-using namespace llvm;
using namespace mlir;
using scf::ForOp;
using scf::WhileOp;
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
index a2f03f1..00bef70 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForallToFor.cpp
@@ -21,7 +21,6 @@ namespace mlir {
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
-using namespace llvm;
using namespace mlir;
using scf::LoopNest;
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 10eae89..888dd44 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -291,47 +291,61 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
return arith::DivUIOp::create(builder, loc, sum, divisor);
}
-/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
-/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
-/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
-/// unrolled iteration using annotateFn.
-static void generateUnrolledLoop(
- Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
+void mlir::generateUnrolledLoop(
+ Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
- ValueRange iterArgs, ValueRange yieldedValues) {
+ ValueRange iterArgs, ValueRange yieldedValues,
+ IRMapping *clonedToSrcOpsMap) {
+
+ // Check if the op was cloned from another source op, and return it if found
+ // (or the same op if not found)
+ auto findOriginalSrcOp =
+ [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * {
+ Operation *srcOp = op;
+ // If the source op derives from another op: traverse the chain to find the
+ // original source op
+ while (srcOp && clonedToSrcOpsMap.contains(srcOp))
+ srcOp = clonedToSrcOpsMap.lookup(srcOp);
+ return srcOp;
+ };
+
// Builder to insert unrolled bodies just before the terminator of the body of
- // 'forOp'.
+ // the loop.
auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
- constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
+ static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
if (!annotateFn)
- annotateFn = defaultAnnotateFn;
+ annotateFn = noopAnnotateFn;
// Keep a pointer to the last non-terminator operation in the original block
// so that we know what to clone (since we are doing this in-place).
Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
- // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
+ // Unroll the contents of the loop body (append unrollFactor - 1 additional
+ // copies).
SmallVector<Value, 4> lastYielded(yieldedValues);
for (unsigned i = 1; i < unrollFactor; i++) {
- IRMapping operandMap;
-
// Prepare operand map.
+ IRMapping operandMap;
operandMap.map(iterArgs, lastYielded);
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
- if (!forOpIV.use_empty()) {
- Value ivUnroll = ivRemapFn(i, forOpIV, builder);
- operandMap.map(forOpIV, ivUnroll);
+ if (!iv.use_empty()) {
+ Value ivUnroll = ivRemapFn(i, iv, builder);
+ operandMap.map(iv, ivUnroll);
}
// Clone the original body of 'forOp'.
for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
- Operation *clonedOp = builder.clone(*it, operandMap);
+ Operation *srcOp = &(*it);
+ Operation *clonedOp = builder.clone(*srcOp, operandMap);
annotateFn(i, clonedOp, builder);
+ if (clonedToSrcOpsMap)
+ clonedToSrcOpsMap->map(clonedOp,
+ findOriginalSrcOp(srcOp, *clonedToSrcOpsMap));
}
// Update yielded values.
@@ -1544,3 +1558,100 @@ bool mlir::isPerfectlyNestedForLoops(
}
return true;
}
+
+llvm::SmallVector<int64_t>
+mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
+ std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
+ std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
+ std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
+ if (!loBnds || !upBnds || !steps)
+ return {};
+ llvm::SmallVector<int64_t> tripCounts;
+ for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
+ std::optional<llvm::APInt> numIter = constantTripCount(
+ lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
+ if (!numIter)
+ return {};
+ tripCounts.push_back(numIter->getSExtValue());
+ }
+ return tripCounts;
+}
+
+FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
+ scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
+ RewriterBase &rewriter,
+ function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
+ IRMapping *clonedToSrcOpsMap) {
+ const unsigned numLoops = op.getNumLoops();
+ assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) &&
+ "Expected positive unroll factors");
+ assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) &&
+ "Expected non-empty unroll factors of size <= to the number of loops");
+
+ // Bail out if no valid unroll factors were provided
+ if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; }))
+ return rewriter.notifyMatchFailure(
+ op, "Unrolling not applied if all factors are 1");
+
+ // Return if the loop body is empty.
+ if (llvm::hasSingleElement(op.getBody()->getOperations()))
+ return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body");
+
+ // If the provided unroll factors do not cover all the loop dims, they are
+ // applied to the inner loop dimensions.
+ const unsigned firstLoopDimIdx = numLoops - unrollFactors.size();
+
+ // Make sure that the unroll factors divide the iteration space evenly
+ // TODO: Support unrolling loops with dynamic iteration spaces.
+ const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op);
+ if (tripCounts.empty())
+ return rewriter.notifyMatchFailure(
+ op, "Failed to compute constant trip counts for the loop. Note that "
+ "dynamic loop sizes are not supported.");
+
+ for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
+ const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
+ if (tripCounts[dimIdx] % unrollFactor)
+ return rewriter.notifyMatchFailure(
+ op, "Unroll factors don't divide the iteration space evenly");
+ }
+
+ std::optional<SmallVector<OpFoldResult>> maybeFoldSteps = op.getLoopSteps();
+ if (!maybeFoldSteps)
+ return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps");
+ llvm::SmallVector<size_t> steps{};
+ for (auto step : *maybeFoldSteps)
+ steps.push_back(static_cast<size_t>(*getConstantIntValue(step)));
+
+ for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
+ const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
+ if (unrollFactor == 1)
+ continue;
+ const size_t origStep = steps[dimIdx];
+ const int64_t newStep = origStep * unrollFactor;
+ IRMapping clonedToSrcOpsMap;
+
+ ValueRange iterArgs = ValueRange(op.getRegionIterArgs());
+ auto yieldedValues = op.getBody()->getTerminator()->getOperands();
+
+ generateUnrolledLoop(
+ op.getBody(), op.getInductionVars()[dimIdx], unrollFactor,
+ [&](unsigned i, Value iv, OpBuilder b) {
+ // iv' = iv + step * i;
+ const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i);
+ const auto map =
+ b.getDimIdentityMap().dropResult(0).insertResult(expr, 0);
+ return affine::AffineApplyOp::create(b, iv.getLoc(), map,
+ ValueRange{iv});
+ },
+ /*annotateFn*/ annotateFn, iterArgs, yieldedValues, &clonedToSrcOpsMap);
+
+ // Update loop step
+ auto prevInsertPoint = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPoint(op);
+ op.getStepMutable()[dimIdx].assign(
+ arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep));
+ rewriter.restoreInsertionPoint(prevInsertPoint);
+ }
+ return op;
+}
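For reference, unrolling the innermost dimension of an scf.parallel by a factor of 2 has roughly the following effect: the step doubles and the body is cloned once, with the clone's induction variable remapped to iv + original_step via an affine.apply, as in the lambda above. Hand-written before/after sketch (shapes and names are illustrative only, not produced by this patch):

  func.func @before(%buf: memref<8xf32>, %v: f32) {
    %c0 = arith.constant 0 : index
    %c1 = arith.constant 1 : index
    %c8 = arith.constant 8 : index
    scf.parallel (%i) = (%c0) to (%c8) step (%c1) {
      memref.store %v, %buf[%i] : memref<8xf32>
      scf.reduce
    }
    return
  }

  func.func @after_unroll_by_2(%buf: memref<8xf32>, %v: f32) {
    %c0 = arith.constant 0 : index
    %c2 = arith.constant 2 : index
    %c8 = arith.constant 8 : index
    scf.parallel (%i) = (%c0) to (%c8) step (%c2) {
      memref.store %v, %buf[%i] : memref<8xf32>
      // Cloned body for the second unrolled iteration: iv' = iv + 1.
      %i1 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
      memref.store %v, %buf[%i1] : memref<8xf32>
      scf.reduce
    }
    return
  }
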
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index fe50865..0c8114d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1276,12 +1276,19 @@ LogicalResult spirv::GlobalVariableOp::verify() {
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), init.getAttr());
// TODO: Currently only variable initialization with specialization
- // constants and other variables is supported. They could be normal
- // constants in the module scope as well.
- if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
- spirv::SpecConstantCompositeOp>(initOp)) {
+ // constants is supported. There could be normal constants in the module
+ // scope as well.
+ //
+ // In the current setup we also cannot initialize one global variable with
+ // another. The problem is that if we try to initialize pointer of type X
+ // with another pointer type, the validator fails because it expects the
+ // variable to be initialized to be type X, not pointer to X. Now
+ // `spirv.GlobalVariable` only allows pointer type, so in the current design
+ // we cannot initialize one `spirv.GlobalVariable` with another.
+ if (!initOp ||
+ !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
return emitOpError("initializer must be result of a "
- "spirv.SpecConstant or spirv.GlobalVariable or "
+ "spirv.SpecConstant or "
"spirv.SpecConstantCompositeOp op");
}
}
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 5ba8289..f0f22e5 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -346,7 +346,7 @@ void AssumingOp::getSuccessorRegions(
// parent, so return the correct RegionSuccessor purely based on the index
// being None or 0.
if (!point.isParent()) {
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
return;
}
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 1a9d9e1..3962e3e 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2597,7 +2597,7 @@ std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() {
std::optional<ResultRange> IterateOp::getLoopResults() { return getResults(); }
-OperandRange IterateOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+OperandRange IterateOp::getEntrySuccessorOperands(RegionSuccessor successor) {
return getInitArgs();
}
@@ -2607,7 +2607,7 @@ void IterateOp::getSuccessorRegions(RegionBranchPoint point,
// or back into the operation itself.
regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
// It is possible for loop not to enter the body.
- regions.push_back(RegionSuccessor(getResults()));
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
}
void CoIterateOp::build(OpBuilder &builder, OperationState &odsState,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index 73e0f3d..f53d272 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -159,14 +159,22 @@ IterationGraphSorter::IterationGraphSorter(
loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
strategy(strategy) {
// One map per tensor.
- assert(loop2InsLvl.size() == ins.size());
+ assert(this->loop2InsLvl.size() == this->ins.size());
// All the affine maps have the same number of dimensions (loops).
assert(llvm::all_equal(llvm::map_range(
- loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+ this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
// The number of results of the map should match the rank of the tensor.
- assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
+ assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) {
auto [m, v] = mvPair;
- return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
+
+ // For ranked types the rank must match.
+ // Simply return true for UnrankedTensorType
+ if (auto shapedType = llvm::dyn_cast<ShapedType>(v.getType())) {
+ return !shapedType.hasRank() ||
+ (m.getNumResults() == shapedType.getRank());
+ }
+ // Non-shaped (scalar) types behave like rank-0.
+ return m.getNumResults() == 0;
}));
itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
index c031118..753cb95 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -158,7 +159,11 @@ struct ExtractSliceOpInterface
// 0 <= offset + (size - 1) * stride < dim_size
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
- for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+
+ for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+ // Reset insertion point to before the operation for each dimension
+ builder.setInsertionPoint(extractSliceOp);
+
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, extractSliceOp.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(
@@ -176,6 +181,16 @@ struct ExtractSliceOpInterface
std::to_string(i) +
" is out-of-bounds"));
+ // Only verify if size > 0
+ Value sizeIsNonZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+ auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
+ sizeIsNonZero, /*withElseRegion=*/true);
+
+ // Populate the "then" region (for size > 0).
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
@@ -184,8 +199,19 @@ struct ExtractSliceOpInterface
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
+ scf::YieldOp::create(builder, loc, lastPosInBounds);
+
+ // Populate the "else" region (for size == 0).
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ Value trueVal =
+ arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
+ scf::YieldOp::create(builder, loc, trueVal);
+
+ builder.setInsertionPointAfter(ifOp);
+ Value finalCondition = ifOp.getResult(0);
+
cf::AssertOp::create(
- builder, loc, lastPosInBounds,
+ builder, loc, finalCondition,
generateErrorMessage(
op, "extract_slice runs out-of-bounds along dimension " +
std::to_string(i)));
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index a85ff10a..293c6af 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -38,7 +38,7 @@ using namespace mlir::tosa;
//===----------------------------------------------------------------------===//
// Check that the zero point of the tensor and padding operations are aligned.
-bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
+static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
// Check that padConst is a constant value and a scalar tensor
DenseElementsAttr padConstAttr;
if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
@@ -889,8 +889,9 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
template <typename IntFolder, typename FloatFolder>
-DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
- RankedTensorType returnTy) {
+static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
+ DenseElementsAttr rhs,
+ RankedTensorType returnTy) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 365afab..062606e 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -96,9 +96,9 @@ ensurePayloadIsSeparateFromTransform(transform::TransformOpInterface transform,
// AlternativesOp
//===----------------------------------------------------------------------===//
-OperandRange
-transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- if (!point.isParent() && getOperation()->getNumOperands() == 1)
+OperandRange transform::AlternativesOp::getEntrySuccessorOperands(
+ RegionSuccessor successor) {
+ if (!successor.isParent() && getOperation()->getNumOperands() == 1)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
getOperation()->operand_end());
@@ -107,15 +107,18 @@ transform::AlternativesOp::getEntrySuccessorOperands(RegionBranchPoint point) {
void transform::AlternativesOp::getSuccessorRegions(
RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
for (Region &alternative : llvm::drop_begin(
- getAlternatives(),
- point.isParent() ? 0
- : point.getRegionOrNull()->getRegionNumber() + 1)) {
+ getAlternatives(), point.isParent()
+ ? 0
+ : point.getTerminatorPredecessorOrNull()
+ ->getParentRegion()
+ ->getRegionNumber() +
+ 1)) {
regions.emplace_back(&alternative, !getOperands().empty()
? alternative.getArguments()
: Block::BlockArgListType());
}
if (!point.isParent())
- regions.emplace_back(getOperation()->getResults());
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
void transform::AlternativesOp::getRegionInvocationBounds(
@@ -1740,16 +1743,18 @@ void transform::ForeachOp::getSuccessorRegions(
}
// Branch back to the region or the parent.
- assert(point == getBody() && "unexpected region index");
+ assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getBody() &&
+ "unexpected region index");
regions.emplace_back(bodyRegion, bodyRegion->getArguments());
- regions.emplace_back();
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
OperandRange
-transform::ForeachOp::getEntrySuccessorOperands(RegionBranchPoint point) {
+transform::ForeachOp::getEntrySuccessorOperands(RegionSuccessor successor) {
// Each block argument handle is mapped to a subset (one op to be precise)
// of the payload of the corresponding `targets` operand of ForeachOp.
- assert(point == getBody() && "unexpected region index");
+ assert(successor.getSuccessor() == &getBody() && "unexpected region index");
return getOperation()->getOperands();
}
@@ -2948,8 +2953,8 @@ void transform::SequenceOp::getEffects(
}
OperandRange
-transform::SequenceOp::getEntrySuccessorOperands(RegionBranchPoint point) {
- assert(point == getBody() && "unexpected region index");
+transform::SequenceOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ assert(successor.getSuccessor() == &getBody() && "unexpected region index");
if (getOperation()->getNumOperands() > 0)
return getOperation()->getOperands();
return OperandRange(getOperation()->operand_end(),
@@ -2966,8 +2971,10 @@ void transform::SequenceOp::getSuccessorRegions(
return;
}
- assert(point == getBody() && "unexpected region index");
- regions.emplace_back(getOperation()->getResults());
+ assert(point.getTerminatorPredecessorOrNull()->getParentRegion() ==
+ &getBody() &&
+ "unexpected region index");
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
void transform::SequenceOp::getRegionInvocationBounds(
diff --git a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
index c627158..f727118 100644
--- a/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
+++ b/mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
@@ -9,6 +9,7 @@
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
@@ -112,7 +113,7 @@ static void printAlternativesOpSelectedRegion(OpAsmPrinter &printer,
}
OperandRange transform::tune::AlternativesOp::getEntrySuccessorOperands(
- RegionBranchPoint point) {
+ RegionSuccessor successor) {
// No operands will be forwarded to the region(s).
return getOperands().slice(0, 0);
}
@@ -128,7 +129,7 @@ void transform::tune::AlternativesOp::getSuccessorRegions(
for (Region &alternative : getAlternatives())
regions.emplace_back(&alternative, Block::BlockArgListType());
else
- regions.emplace_back(getOperation()->getResults());
+ regions.emplace_back(getOperation(), getOperation()->getResults());
}
void transform::tune::AlternativesOp::getRegionInvocationBounds(
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 24e9095..f9aa28d5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -113,9 +113,12 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
if (layout.size() != shape.size())
return std::nullopt;
auto ratio = computeShapeRatio(shape, layout);
- if (!ratio.has_value())
+ if (ratio.has_value()) {
+ newShape = ratio.value();
+ } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
return std::nullopt;
- newShape = ratio.value();
+ }
+ // Round-robin case: continue with original newShape
}
if (data.size()) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 2c37140..ec5feb8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -344,6 +344,13 @@ void XeGPUBlockingPass::runOnOperation() {
xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
+ // Remove leading unit dimensions from vector ops and then
+ // do the unrolling.
+ {
+ RewritePatternSet patterns(ctx);
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ (void)applyPatternsGreedily(op, std::move(patterns));
+ }
xegpu::UnrollOptions options;
options.setFilterConstraint(
[&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index b4605cd..a38993e 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -147,7 +147,7 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
- auto parentOp = arg.getOwner()->getParentOp();
+ auto *parentOp = arg.getOwner()->getParentOp();
if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
if (tiedInit)