Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Bindings/Python/IRCore.cpp                                     |   2
-rw-r--r--  mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp                   |   2
-rw-r--r--  mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp                            |   6
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp                   |   9
-rw-r--r--  mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp                   |   6
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp               |   2
-rw-r--r--  mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp            |  29
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp                                         |   6
-rw-r--r--  mlir/lib/Dialect/SCF/Utils/Utils.cpp                                    | 145
-rw-r--r--  mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp                                  |  17
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp |  16
-rw-r--r--  mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp            |  30
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp                      |   7
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp                              |   7
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp                     |   7
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp                             |   2
-rw-r--r--  mlir/lib/ExecutionEngine/ExecutionEngine.cpp                            |   2
-rw-r--r--  mlir/lib/IR/MLIRContext.cpp                                             |   4
-rw-r--r--  mlir/lib/Target/LLVMIR/ModuleTranslation.cpp                            |  52
-rw-r--r--  mlir/lib/Target/SPIRV/Serialization/Serializer.cpp                      |  18
20 files changed, 286 insertions(+), 83 deletions(-)
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 06d0256..cda4fe1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -598,7 +598,7 @@ class PyOpOperand {
public:
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
- PyOpView getOwner() {
+ nb::typed<nb::object, PyOpView> getOwner() {
MlirOperation owner = mlirOpOperandGetOwner(opOperand);
PyMlirContextRef context =
PyMlirContext::forContext(mlirOperationGetContext(owner));
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index b711e33..a4c66e1 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -692,7 +692,7 @@ SymbolRefAttr PatternLowering::generateRewriter(
llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
args.append(mappedArgs.begin(), mappedArgs.end());
pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(),
- /*resultTypes=*/TypeRange(), rewriteName,
+ /*results=*/TypeRange(), rewriteName,
args);
} else {
// Otherwise this is a dag rewriter defined using PDL operations.
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 4c4965e..585b6da 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -422,11 +422,11 @@ LogicalResult MFMAOp::verify() {
Type sourceElem = sourceType, destElem = destType;
uint32_t sourceLen = 1, destLen = 1;
- if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
+ if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
sourceLen = sourceVector.getNumElements();
sourceElem = sourceVector.getElementType();
}
- if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
+ if (auto destVector = dyn_cast<VectorType>(destType)) {
destLen = destVector.getNumElements();
destElem = destVector.getElementType();
}
@@ -451,7 +451,7 @@ LogicalResult MFMAOp::verify() {
return emitOpError("expected both non-small-float source operand types "
"to match exactly");
}
- // Normalize the wider integer types the compiler expects to i8
+ // Normalize the wider integer types the compiler expects to i8.
if (sourceElem.isInteger(32)) {
sourceLen *= 4;
sourceElem = b.getI8Type();
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 50a0f3d..e08cc6f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -978,12 +978,11 @@ static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp,
LLVM_DEBUG(
dbgs() << "\n[early-vect]+++++ affine.apply on vector operand\n");
return nullptr;
- } else {
- Value updatedOperand = state.valueScalarReplacement.lookupOrNull(operand);
- if (!updatedOperand)
- updatedOperand = operand;
- updatedOperands.push_back(updatedOperand);
}
+ Value updatedOperand = state.valueScalarReplacement.lookupOrNull(operand);
+ if (!updatedOperand)
+ updatedOperand = operand;
+ updatedOperands.push_back(updatedOperand);
}
auto newApplyOp = AffineApplyOp::create(
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index d925c19..a651710 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -216,8 +216,8 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
for (auto condBranch : worklist) {
auto loc = condBranch.getLoc();
Block *block = condBranch->getBlock();
- auto newTrueBranch = rewriter.splitBlock(block, block->end());
- auto newFalseBranch = rewriter.splitBlock(block, block->end());
+ auto *newTrueBranch = rewriter.splitBlock(block, block->end());
+ auto *newFalseBranch = rewriter.splitBlock(block, block->end());
insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
condBranch.getTrueDestOperands());
insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
@@ -382,7 +382,7 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
// Find or create a live range for `value`.
auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
LiveRange &valueLiveRange = it->second;
- auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+ auto *lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
// Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
unsigned startOpIdx =
operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index a15bf89..6fa8ce4 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -66,7 +66,7 @@ struct ExpandShapeOpInterface
ValueBoundsConstraintSet &cstr) const {
auto expandOp = cast<memref::ExpandShapeOp>(op);
assert(value == expandOp.getResult() && "invalid value");
- cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+ cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
}
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 291da1f..14152c5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
using namespace mlir;
@@ -273,7 +274,9 @@ struct SubViewOpInterface
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
auto metadataOp =
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
- for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+ for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+ // Reset insertion point to before the operation for each dimension.
+ builder.setInsertionPoint(subView);
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(builder, loc,
@@ -290,6 +293,16 @@ struct SubViewOpInterface
std::to_string(i) +
" is out-of-bounds"));
+ // Only verify if size > 0.
+ Value sizeIsNonZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+ auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
+ sizeIsNonZero, /*withElseRegion=*/true);
+
+ // Populate the "then" region (for size > 0).
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
@@ -298,8 +311,20 @@ struct SubViewOpInterface
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
+
+ scf::YieldOp::create(builder, loc, lastPosInBounds);
+
+ // Populate the "else" region (for size == 0).
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ Value trueVal =
+ arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
+ scf::YieldOp::create(builder, loc, trueVal);
+
+ builder.setInsertionPointAfter(ifOp);
+ Value finalCondition = ifOp.getResult(0);
+
cf::AssertOp::create(
- builder, loc, lastPosInBounds,
+ builder, loc, finalCondition,
generateErrorMessage(op,
"subview runs out-of-bounds along dimension " +
std::to_string(i)));
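
For context on the guard added above: when size == 0, the unguarded check computes lastPos = offset + (size - 1) * stride = offset - stride, which can fail the bounds check even though an empty subview is legal. A minimal standalone C++ sketch of the arithmetic (all values assumed for illustration):

  #include <cstdint>
  #include <cstdio>

  int main() {
    int64_t offset = 0, size = 0, stride = 4, dimSize = 16;
    // Unguarded: lastPos underflows to -4 and fails 0 <= lastPos < dimSize,
    // a false positive for a legal empty slice.
    int64_t lastPos = offset + (size - 1) * stride;
    bool unguardedOk = lastPos >= 0 && lastPos < dimSize;
    // Guarded (what the scf.if above encodes): an empty slice is trivially
    // in bounds.
    bool guardedOk = size > 0 ? unguardedOk : true;
    std::printf("unguarded=%d guarded=%d\n", unguardedOk, guardedOk);
    return 0;
  }
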
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 744a595..1ab01d8 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -111,10 +111,8 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
return nullptr;
}
-/// Helper function to compute the difference between two values. This is used
-/// by the loop implementations to compute the trip count.
-static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
- bool isSigned) {
+std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
+ bool isSigned) {
llvm::APSInt diff;
auto addOp = ub.getDefiningOp<arith::AddIOp>();
if (!addOp)
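
With computeUbMinusLb now exposed in the mlir::scf namespace (its declaration is assumed to live in the SCF dialect headers), callers such as the trip-count utility added in Utils.cpp below can reuse it. A hedged usage sketch:

  // lb and ub are the bound Values of a loop; yields ub - lb when it can be
  // derived, e.g. when ub is defined as lb + constant.
  if (std::optional<llvm::APSInt> diff =
          mlir::scf::computeUbMinusLb(lb, ub, /*isSigned=*/true)) {
    int64_t extent = diff->getSExtValue();
    (void)extent;
  }
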
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 10eae89..888dd44 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -291,47 +291,61 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
return arith::DivUIOp::create(builder, loc, sum, divisor);
}
-/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
-/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
-/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
-/// unrolled iteration using annotateFn.
-static void generateUnrolledLoop(
- Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
+void mlir::generateUnrolledLoop(
+ Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
- ValueRange iterArgs, ValueRange yieldedValues) {
+ ValueRange iterArgs, ValueRange yieldedValues,
+ IRMapping *clonedToSrcOpsMap) {
+
+ // Check if the op was cloned from another source op, and return it if found
+ (or the same op if not found).
+ auto findOriginalSrcOp =
+ [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * {
+ Operation *srcOp = op;
+ // If the source op derives from another op, traverse the chain to find
+ // the original source op.
+ while (srcOp && clonedToSrcOpsMap.contains(srcOp))
+ srcOp = clonedToSrcOpsMap.lookup(srcOp);
+ return srcOp;
+ };
+
// Builder to insert unrolled bodies just before the terminator of the body of
- // 'forOp'.
+ // the loop.
auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
- constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
+ static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
if (!annotateFn)
- annotateFn = defaultAnnotateFn;
+ annotateFn = noopAnnotateFn;
// Keep a pointer to the last non-terminator operation in the original block
// so that we know what to clone (since we are doing this in-place).
Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
- // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
+ // Unroll the contents of the loop body (append unrollFactor - 1 additional
+ // copies).
SmallVector<Value, 4> lastYielded(yieldedValues);
for (unsigned i = 1; i < unrollFactor; i++) {
- IRMapping operandMap;
-
// Prepare operand map.
+ IRMapping operandMap;
operandMap.map(iterArgs, lastYielded);
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
- if (!forOpIV.use_empty()) {
- Value ivUnroll = ivRemapFn(i, forOpIV, builder);
- operandMap.map(forOpIV, ivUnroll);
+ if (!iv.use_empty()) {
+ Value ivUnroll = ivRemapFn(i, iv, builder);
+ operandMap.map(iv, ivUnroll);
}
// Clone the original body of 'forOp'.
for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
- Operation *clonedOp = builder.clone(*it, operandMap);
+ Operation *srcOp = &(*it);
+ Operation *clonedOp = builder.clone(*srcOp, operandMap);
annotateFn(i, clonedOp, builder);
+ if (clonedToSrcOpsMap)
+ clonedToSrcOpsMap->map(clonedOp,
+ findOriginalSrcOp(srcOp, *clonedToSrcOpsMap));
}
// Update yielded values.
@@ -1544,3 +1558,100 @@ bool mlir::isPerfectlyNestedForLoops(
}
return true;
}
+
+llvm::SmallVector<int64_t>
+mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
+ std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
+ std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
+ std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
+ if (!loBnds || !upBnds || !steps)
+ return {};
+ llvm::SmallVector<int64_t> tripCounts;
+ for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
+ std::optional<llvm::APInt> numIter = constantTripCount(
+ lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
+ if (!numIter)
+ return {};
+ tripCounts.push_back(numIter->getSExtValue());
+ }
+ return tripCounts;
+}
+
+FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
+ scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
+ RewriterBase &rewriter,
+ function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
+ IRMapping *clonedToSrcOpsMap) {
+ const unsigned numLoops = op.getNumLoops();
+ assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) &&
+ "Expected positive unroll factors");
+ assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) &&
+ "Expected non-empty unroll factors of size <= to the number of loops");
+
+ // Bail out if all unroll factors are 1; there is nothing to unroll.
+ if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; }))
+ return rewriter.notifyMatchFailure(
+ op, "Unrolling not applied if all factors are 1");
+
+ // Return if the loop body is empty.
+ if (llvm::hasSingleElement(op.getBody()->getOperations()))
+ return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body");
+
+ // If the provided unroll factors do not cover all the loop dims, they are
+ // applied to the inner loop dimensions.
+ const unsigned firstLoopDimIdx = numLoops - unrollFactors.size();
+
+ // Make sure that the unroll factors divide the iteration space evenly.
+ // TODO: Support unrolling loops with dynamic iteration spaces.
+ const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op);
+ if (tripCounts.empty())
+ return rewriter.notifyMatchFailure(
+ op, "Failed to compute constant trip counts for the loop. Note that "
+ "dynamic loop sizes are not supported.");
+
+ for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
+ const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
+ if (tripCounts[dimIdx] % unrollFactor)
+ return rewriter.notifyMatchFailure(
+ op, "Unroll factors don't divide the iteration space evenly");
+ }
+
+ std::optional<SmallVector<OpFoldResult>> maybeFoldSteps = op.getLoopSteps();
+ if (!maybeFoldSteps)
+ return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps");
+ llvm::SmallVector<size_t> steps{};
+ for (auto step : *maybeFoldSteps)
+ steps.push_back(static_cast<size_t>(*getConstantIntValue(step)));
+
+ for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
+ const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
+ if (unrollFactor == 1)
+ continue;
+ const size_t origStep = steps[dimIdx];
+ const int64_t newStep = origStep * unrollFactor;
+ // Use the caller-provided clone-to-source map when given; otherwise fall
+ // back to a local map for this dimension.
+ IRMapping localClonedToSrcOpsMap;
+ IRMapping *opsMap =
+ clonedToSrcOpsMap ? clonedToSrcOpsMap : &localClonedToSrcOpsMap;
+
+ ValueRange iterArgs = ValueRange(op.getRegionIterArgs());
+ auto yieldedValues = op.getBody()->getTerminator()->getOperands();
+
+ generateUnrolledLoop(
+ op.getBody(), op.getInductionVars()[dimIdx], unrollFactor,
+ [&](unsigned i, Value iv, OpBuilder b) {
+ // iv' = iv + step * i;
+ const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i);
+ const auto map =
+ b.getDimIdentityMap().dropResult(0).insertResult(expr, 0);
+ return affine::AffineApplyOp::create(b, iv.getLoc(), map,
+ ValueRange{iv});
+ },
+ /*annotateFn=*/annotateFn, iterArgs, yieldedValues, opsMap);
+
+ // Update the loop step.
+ auto prevInsertPoint = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPoint(op);
+ op.getStepMutable()[dimIdx].assign(
+ arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep));
+ rewriter.restoreInsertionPoint(prevInsertPoint);
+ }
+ return op;
+}
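
A hedged usage sketch of the two utilities added above (parallelOp and the factor values are illustrative; this is not code from the change itself):

  // Constant per-dimension trip counts; empty if any bound or step is dynamic.
  llvm::SmallVector<int64_t> tripCounts = mlir::getConstLoopTripCounts(parallelOp);

  // Unroll the two innermost dimensions by 2 and 4; each factor must divide
  // the corresponding trip count evenly or the call reports a match failure.
  IRRewriter rewriter(parallelOp->getContext());
  FailureOr<scf::ParallelOp> unrolled = mlir::parallelLoopUnrollByFactors(
      parallelOp, /*unrollFactors=*/{2, 4}, rewriter,
      /*annotateFn=*/nullptr, /*clonedToSrcOpsMap=*/nullptr);
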
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index fe50865..0c8114d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1276,12 +1276,19 @@ LogicalResult spirv::GlobalVariableOp::verify() {
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), init.getAttr());
// TODO: Currently only variable initialization with specialization
- // constants and other variables is supported. They could be normal
- // constants in the module scope as well.
- if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
- spirv::SpecConstantCompositeOp>(initOp)) {
+ // constants is supported. There could be normal constants in the module
+ // scope as well.
+ //
+ // In the current setup we also cannot initialize one global variable with
+ // another. The problem is that when we try to initialize a pointer to type X
+ // with another pointer, the validator fails because it expects the
+ // initializer to have type X, not pointer-to-X. Since `spirv.GlobalVariable`
+ // only allows pointer types, in the current design we cannot initialize one
+ // `spirv.GlobalVariable` with another.
+ if (!initOp ||
+ !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
return emitOpError("initializer must be result of a "
- "spirv.SpecConstant or spirv.GlobalVariable or "
+ "spirv.SpecConstant or "
"spirv.SpecConstantCompositeOp op");
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index 73e0f3d..f53d272 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -159,14 +159,22 @@ IterationGraphSorter::IterationGraphSorter(
loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
strategy(strategy) {
// One map per tensor.
- assert(loop2InsLvl.size() == ins.size());
+ assert(this->loop2InsLvl.size() == this->ins.size());
// All the affine maps have the same number of dimensions (loops).
assert(llvm::all_equal(llvm::map_range(
- loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+ this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
// The number of results of the map should match the rank of the tensor.
- assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
+ assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) {
auto [m, v] = mvPair;
- return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
+
+ // For ranked types the rank must match.
+ // Simply return true for UnrankedTensorType.
+ if (auto shapedType = llvm::dyn_cast<ShapedType>(v.getType())) {
+ return !shapedType.hasRank() ||
+ (m.getNumResults() == shapedType.getRank());
+ }
+ // Non-shaped (scalar) types behave like rank-0.
+ return m.getNumResults() == 0;
}));
itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
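
The relaxed rank check above, restated as a hedged standalone helper (name hypothetical):

  static bool mapMatchesRank(mlir::AffineMap m, mlir::Type t) {
    if (auto shaped = llvm::dyn_cast<mlir::ShapedType>(t))
      return !shaped.hasRank() || m.getNumResults() == shaped.getRank();
    // Non-shaped (scalar) operands behave like rank 0.
    return m.getNumResults() == 0;
  }
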
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
index c031118..753cb95 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -158,7 +159,11 @@ struct ExtractSliceOpInterface
// 0 <= offset + (size - 1) * stride < dim_size
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
- for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+
+ for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+ // Reset insertion point to before the operation for each dimension.
+ builder.setInsertionPoint(extractSliceOp);
+
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, extractSliceOp.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(
@@ -176,6 +181,16 @@ struct ExtractSliceOpInterface
std::to_string(i) +
" is out-of-bounds"));
+ // Only verify if size > 0.
+ Value sizeIsNonZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+ auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
+ sizeIsNonZero, /*withElseRegion=*/true);
+
+ // Populate the "then" region (for size > 0).
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
@@ -184,8 +199,19 @@ struct ExtractSliceOpInterface
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
+ scf::YieldOp::create(builder, loc, lastPosInBounds);
+
+ // Populate the "else" region (for size == 0).
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ Value trueVal =
+ arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
+ scf::YieldOp::create(builder, loc, trueVal);
+
+ builder.setInsertionPointAfter(ifOp);
+ Value finalCondition = ifOp.getResult(0);
+
cf::AssertOp::create(
- builder, loc, lastPosInBounds,
+ builder, loc, finalCondition,
generateErrorMessage(
op, "extract_slice runs out-of-bounds along dimension " +
std::to_string(i)));
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index a85ff10a..293c6af 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -38,7 +38,7 @@ using namespace mlir::tosa;
//===----------------------------------------------------------------------===//
// Check that the zero point of the tensor and padding operations are aligned.
-bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
+static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
// Check that padConst is a constant value and a scalar tensor
DenseElementsAttr padConstAttr;
if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
@@ -889,8 +889,9 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
template <typename IntFolder, typename FloatFolder>
-DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
- RankedTensorType returnTy) {
+static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
+ DenseElementsAttr rhs,
+ RankedTensorType returnTy) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 24e9095..f9aa28d5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -113,9 +113,12 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
if (layout.size() != shape.size())
return std::nullopt;
auto ratio = computeShapeRatio(shape, layout);
- if (!ratio.has_value())
+ if (ratio.has_value()) {
+ newShape = ratio.value();
+ } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
return std::nullopt;
- newShape = ratio.value();
+ }
+ // Round-robin case: continue with the original newShape.
}
if (data.size()) {
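
For intuition on the round-robin fallback, a self-contained C++ sketch of the two divisibility checks (a simplified stand-in for computeShapeRatio; exact semantics assumed):

  #include <cstdint>
  #include <optional>
  #include <vector>

  static std::optional<std::vector<int64_t>>
  ratioOf(const std::vector<int64_t> &a, const std::vector<int64_t> &b) {
    std::vector<int64_t> result;
    for (size_t i = 0; i < a.size(); ++i) {
      if (b[i] == 0 || a[i] % b[i] != 0)
        return std::nullopt; // not elementwise divisible
      result.push_back(a[i] / b[i]);
    }
    return result;
  }

  int main() {
    std::vector<int64_t> shape = {16, 16}, layout = {32, 16};
    bool blocked = ratioOf(shape, layout).has_value(); // false: 16 % 32 != 0
    // New fallback: layout / shape = {2, 1} succeeds, so with the round-robin
    // flag set the shape is still considered distributable.
    bool roundRobin = !blocked && ratioOf(layout, shape).has_value(); // true
    return roundRobin ? 0 : 1;
  }
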
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 2c37140..ec5feb8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -344,6 +344,13 @@ void XeGPUBlockingPass::runOnOperation() {
xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
+ // Remove leading unit dimensions from vector ops and then
+ // do the unrolling.
+ {
+ RewritePatternSet patterns(ctx);
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ (void)applyPatternsGreedily(op, std::move(patterns));
+ }
xegpu::UnrollOptions options;
options.setFilterConstraint(
[&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index b4605cd..a38993e 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -147,7 +147,7 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
- auto parentOp = arg.getOwner()->getParentOp();
+ auto *parentOp = arg.getOwner()->getParentOp();
if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
if (tiedInit)
diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index 52162a4..2255633 100644
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -239,6 +239,8 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
// Remember all entry-points if object dumping is enabled.
if (options.enableObjectDump) {
for (auto funcOp : m->getRegion(0).getOps<LLVM::LLVMFuncOp>()) {
+ if (funcOp.getBlocks().empty())
+ continue;
StringRef funcName = funcOp.getSymName();
engine->functionNames.push_back(funcName.str());
}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 5f63fe6..73219c6 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -709,7 +709,7 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
/// Return information for registered operations by dialect.
ArrayRef<RegisteredOperationName>
MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
- auto lowerBound = llvm::lower_bound(
+ auto *lowerBound = llvm::lower_bound(
impl->sortedRegisteredOperations, dialectName, [](auto &lhs, auto &rhs) {
return lhs.getDialect().getNamespace().compare(rhs);
});
@@ -718,7 +718,7 @@ MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
lowerBound->getDialect().getNamespace() != dialectName)
return ArrayRef<RegisteredOperationName>();
- auto upperBound =
+ auto *upperBound =
std::upper_bound(lowerBound, impl->sortedRegisteredOperations.end(),
dialectName, [](auto &lhs, auto &rhs) {
return lhs.compare(rhs.getDialect().getNamespace());
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 2acbd03..64e3c5f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -649,40 +649,38 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
auto *arrayType = llvm::ArrayType::get(elementType, numElements);
if (child->isZeroValue() && !elementType->isFPOrFPVectorTy()) {
return llvm::ConstantAggregateZero::get(arrayType);
- } else {
- if (llvm::ConstantDataSequential::isElementTypeCompatible(
- elementType)) {
- // TODO: Handle all compatible types. This code only handles integer.
- if (isa<llvm::IntegerType>(elementType)) {
- if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) {
- if (ci->getBitWidth() == 8) {
- SmallVector<int8_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
- if (ci->getBitWidth() == 16) {
- SmallVector<int16_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
- if (ci->getBitWidth() == 32) {
- SmallVector<int32_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
- if (ci->getBitWidth() == 64) {
- SmallVector<int64_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
+ }
+ if (llvm::ConstantDataSequential::isElementTypeCompatible(elementType)) {
+ // TODO: Handle all compatible types. This code only handles integer.
+ if (isa<llvm::IntegerType>(elementType)) {
+ if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) {
+ if (ci->getBitWidth() == 8) {
+ SmallVector<int8_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
+ }
+ if (ci->getBitWidth() == 16) {
+ SmallVector<int16_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
+ }
+ if (ci->getBitWidth() == 32) {
+ SmallVector<int32_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
+ }
+ if (ci->getBitWidth() == 64) {
+ SmallVector<int64_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
}
}
}
+ }
// std::vector is used here to accommodate large numbers of elements that
// exceed SmallVector capacity.
std::vector<llvm::Constant *> constants(numElements, child);
return llvm::ConstantArray::get(arrayType, constants);
- }
}
}
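
The splat fast path kept by the restructuring above can be exercised in isolation; a hedged sketch (ctx is an assumed llvm::LLVMContext):

  // Four identical i32 values stored compactly as a llvm::ConstantDataArray
  // rather than as a general llvm::ConstantArray of repeated ConstantInts.
  llvm::SmallVector<int32_t> values(4, 7);
  llvm::Constant *splat = llvm::ConstantDataArray::get(ctx, values);
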
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index b88fbaa..29ed5a4 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -89,6 +89,22 @@ static bool isZeroValue(Attribute attr) {
return false;
}
+/// Move all function declarations before function definitions. In SPIR-V,
+/// "declarations" are functions without a body and "definitions" are functions
+/// with a body. This is stronger than necessary: it would suffice to ensure
+/// that declarations precede their uses rather than all definitions, but doing
+/// it this way avoids analyzing every function in the module.
+static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp) {
+ Block::OpListType &ops = moduleOp.getBody()->getOperations();
+ if (ops.empty())
+ return;
+ Operation &firstOp = ops.front();
+ for (Operation &op : llvm::drop_begin(ops))
+ if (auto funcOp = dyn_cast<spirv::FuncOp>(op))
+ if (funcOp.getBody().empty())
+ funcOp->moveBefore(&firstOp);
+}
+
namespace mlir {
namespace spirv {
@@ -119,6 +135,8 @@ LogicalResult Serializer::serialize() {
processMemoryModel();
processDebugInfo();
+ moveFuncDeclarationsToTop(module);
+
// Iterate over the module body to serialize it. The assumption is that there
// is only one basic block in the moduleOp.
for (auto &op : *module.getBody()) {