aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp4
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp6
-rw-r--r--mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp48
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp6
-rw-r--r--mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp98
-rw-r--r--mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp4
-rw-r--r--mlir/lib/Dialect/DLTI/DLTI.cpp5
-rw-r--r--mlir/lib/Dialect/EmitC/IR/EmitC.cpp64
-rw-r--r--mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp19
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp36
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp5
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp2
-rw-r--r--mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp20
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp49
-rw-r--r--mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp2
-rw-r--r--mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp4
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp4
-rw-r--r--mlir/lib/Dialect/SCF/Utils/Utils.cpp353
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp2
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp17
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp198
-rw-r--r--mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp21
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp21
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp3
24 files changed, 679 insertions, 312 deletions
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
index 1dc69ab..05c7707 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
@@ -39,9 +39,9 @@ struct LoopCoalescingPass
func::FuncOp func = getOperation();
func.walk<WalkOrder::PreOrder>([](Operation *op) {
if (auto scfForOp = dyn_cast<scf::ForOp>(op))
- (void)coalescePerfectlyNestedLoops(scfForOp);
+ (void)coalescePerfectlyNestedSCFForLoops(scfForOp);
else if (auto affineForOp = dyn_cast<AffineForOp>(op))
- (void)coalescePerfectlyNestedLoops(affineForOp);
+ (void)coalescePerfectlyNestedAffineLoops(affineForOp);
});
}
};
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 37b36f7..117ee8e 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -84,7 +84,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
// We are trying to reify a bound for `value` in terms of the owning op's
// operands. Construct a stop condition that evaluates to "true" for any SSA
// value except for `value`. I.e., the bound will be computed in terms of
@@ -100,7 +101,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
return v != value;
};
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index af59973..268050a 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -2765,3 +2765,51 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
return success();
}
+
+LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
+ LogicalResult result(failure());
+ SmallVector<AffineForOp> loops;
+ getPerfectlyNestedLoops(loops, op);
+ if (loops.size() <= 1)
+ return success();
+
+ // Look for a band of loops that can be coalesced, i.e. perfectly nested
+ // loops with bounds defined above some loop.
+ // 1. For each loop, find above which parent loop its operands are
+ // defined.
+ SmallVector<unsigned> operandsDefinedAbove(loops.size());
+ for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+ operandsDefinedAbove[i] = i;
+ for (unsigned j = 0; j < i; ++j) {
+ if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
+ operandsDefinedAbove[i] = j;
+ break;
+ }
+ }
+ }
+
+ // 2. Identify bands of loops such that the operands of all of them are
+ // defined above the first loop in the band. Traverse the nest bottom-up
+ // so that modifications don't invalidate the inner loops.
+ for (unsigned end = loops.size(); end > 0; --end) {
+ unsigned start = 0;
+ for (; start < end - 1; ++start) {
+ auto maxPos =
+ *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+ std::next(operandsDefinedAbove.begin(), end));
+ if (maxPos > start)
+ continue;
+ assert(maxPos == start &&
+ "expected loop bounds to be known at the start of the band");
+ auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
+ if (succeeded(coalesceLoops(band)))
+ result = success();
+ break;
+ }
+ // If a band was found and transformed, keep looking at the loops above
+ // the outermost transformed loop.
+ if (start != end - 1)
+ end = start + 1;
+ }
+ return result;
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 8d9fd14..fad2212 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -119,7 +119,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
// We are trying to reify a bound for `value` in terms of the owning op's
// operands. Construct a stop condition that evaluates to "true" for any SSA
// value expect for `value`. I.e., the bound will be computed in terms of
@@ -135,7 +136,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
return v != value;
};
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 1f48d27..1374022 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -40,8 +40,9 @@ static Type matchContainerType(Type element, Type container) {
/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
-/// smmla instruction is emitted.
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
+/// = 1 (either explicitly or inferred if LHS has only dimK) If no unrolling is
+/// necessary, a single smmla instruction is emitted.
class LowerContractionToSMMLAPattern
: public OpRewritePattern<vector::ContractionOp> {
public:
@@ -49,32 +50,35 @@ public:
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- // Check index maps that represent M N K in contract.
- auto indexingMaps = op.getIndexingMapsArray();
- if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
- return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
- affineMap.getNumResults() != 2;
- })) {
- return failure();
- }
- // Check iterator types for contract.
- auto iteratorTypes = op.getIteratorTypesArray();
- if (iteratorTypes.size() != 3 ||
- iteratorTypes[0] != vector::IteratorType::parallel ||
- iteratorTypes[1] != vector::IteratorType::parallel ||
- iteratorTypes[2] != vector::IteratorType::reduction) {
- return failure();
- }
- // Infer tile sizes from operands; Note: RHS is not transposed.
+ // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
+ // Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
- auto dimM = lhsType.getDimSize(0);
+ auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
- auto dimK = lhsType.getDimSize(1);
-
+ auto dimK = rhsType.getDimSize(1);
+ bool isVecmat = dimM == 1 ? true : false;
+ if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+ rhsType.getDimSize(rhsType.getRank() - 1)) {
+ return failure(); // dimK mismatch
+ }
// Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
// tiling.
- if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+ if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+ return failure();
+ }
+
+ // Check iterator types for contract. All iterators except inner-most
+ // dimension must be parallel.
+ auto iteratorTypes = op.getIteratorTypesArray();
+ if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+ vector::IteratorType::reduction) {
+ return failure();
+ }
+ if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
+ [](vector::IteratorType iteratorType) {
+ return iteratorType != vector::IteratorType::parallel;
+ })) {
return failure();
}
@@ -120,11 +124,14 @@ public:
loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
- SmallVector<int64_t> smmlaShape{2, 2, 8};
- SmallVector<int64_t> loopOrder{0, 1, 2};
+ SmallVector<int64_t> smmlaShape{2, 8};
+ SmallVector<int64_t> loopOrder{0, 1};
+ if (unrolledSize.size() == 3) {
+ smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
+ loopOrder.push_back(2);
+ }
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-
// Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](Value operand, AffineMap permutationMap,
ArrayRef<int64_t> operandOffsets) {
@@ -150,16 +157,40 @@ public:
Value tiledAcc =
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+ auto inputElementType =
+ tiledLhs.getType().cast<ShapedType>().getElementType();
+ auto accElementType =
+ tiledAcc.getType().cast<ShapedType>().getElementType();
+ auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
+ auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+ // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
+ // rows along dimM. Expand their shapes to match the smmla op.
+ if (isVecmat) {
+ auto expandForSMMLA = [&](Value tiledOperand,
+ VectorType expandedTypeType) {
+ auto emptyOperand = rewriter.create<arith::ConstantOp>(
+ loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
+ SmallVector<int64_t> offsets(
+ emptyOperand.getType().cast<ShapedType>().getRank(), 0);
+ SmallVector<int64_t> strides(
+ tiledOperand.getType().cast<ShapedType>().getRank(), 1);
+ return rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, tiledOperand, emptyOperand, offsets, strides);
+ };
+ tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
+ tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
+ }
+
// Collapse tiled operands to 1D vectors required by smmla intrinsic
- auto collapsedInputType = VectorType::get(
- tiledLhs.getType().cast<ShapedType>().getNumElements(),
- tiledLhs.getType().cast<ShapedType>().getElementType());
- auto collapsedOutputType = VectorType::get(
- {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
+ auto collapsedInputType =
+ VectorType::get(inputExpandedType.getNumElements(), inputElementType);
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+ auto collapsedOutputType =
+ VectorType::get(outputExpandedType.getNumElements(), accElementType);
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
@@ -172,6 +203,11 @@ public:
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+ // With vecmat, only one row of tiled ACC can be inserted inot file result
+ if (isVecmat) {
+ tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+ }
+
// Insert the tiled result back into the non tiled result of the
// contract op.
SmallVector<int64_t> strides(
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 5d11f8f6..1320db3 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -531,8 +531,8 @@ static ParseResult parseSwitchOpCases(
failed(parser.parseSuccessor(destination)))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
- if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
- /*allowResultNumber=*/false)) ||
+ if (failed(parser.parseOperandList(operands,
+ OpAsmParser::Delimiter::None)) ||
failed(parser.parseColonTypeList(operandTypes)) ||
failed(parser.parseRParen()))
return failure();
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index daef234..98a8865 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -282,6 +282,11 @@ DataLayoutEntryListRef DataLayoutSpecAttr::getEntries() const {
}
StringAttr
+DataLayoutSpecAttr::getEndiannessIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(DLTIDialect::kDataLayoutEndiannessKey);
+}
+
+StringAttr
DataLayoutSpecAttr::getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
return Builder(context).getStringAttr(
DLTIDialect::kDataLayoutAllocaMemorySpaceKey);
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index f4a9dc3..7cbf28b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -70,6 +70,11 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
return false;
}
+bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
+ return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
+ isSupportedIntegerType(type);
+}
+
bool mlir::emitc::isSupportedFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
@@ -780,12 +785,61 @@ LogicalResult emitc::YieldOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult emitc::SubscriptOp::verify() {
- if (getIndices().size() != (size_t)getArray().getType().getRank()) {
- return emitOpError() << "requires number of indices ("
- << getIndices().size()
- << ") to match the rank of the array type ("
- << getArray().getType().getRank() << ")";
+ // Checks for array operand.
+ if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
+ // Check number of indices.
+ if (getIndices().size() != (size_t)arrayType.getRank()) {
+ return emitOpError() << "on array operand requires number of indices ("
+ << getIndices().size()
+ << ") to match the rank of the array type ("
+ << arrayType.getRank() << ")";
+ }
+ // Check types of index operands.
+ for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
+ Type type = getIndices()[i].getType();
+ if (!isIntegerIndexOrOpaqueType(type)) {
+ return emitOpError() << "on array operand requires index operand " << i
+ << " to be integer-like, but got " << type;
+ }
+ }
+ // Check element type.
+ Type elementType = arrayType.getElementType();
+ if (elementType != getType()) {
+ return emitOpError() << "on array operand requires element type ("
+ << elementType << ") and result type (" << getType()
+ << ") to match";
+ }
+ return success();
}
+
+ // Checks for pointer operand.
+ if (auto pointerType =
+ llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
+ // Check number of indices.
+ if (getIndices().size() != 1) {
+ return emitOpError()
+ << "on pointer operand requires one index operand, but got "
+ << getIndices().size();
+ }
+ // Check types of index operand.
+ Type type = getIndices()[0].getType();
+ if (!isIntegerIndexOrOpaqueType(type)) {
+ return emitOpError() << "on pointer operand requires index operand to be "
+ "integer-like, but got "
+ << type;
+ }
+ // Check pointee type.
+ Type pointeeType = pointerType.getPointee();
+ if (pointeeType != getType()) {
+ return emitOpError() << "on pointer operand requires pointee type ("
+ << pointeeType << ") and result type (" << getType()
+ << ") to match";
+ }
+ return success();
+ }
+
+ // The operand has opaque type, so we can't assume anything about the number
+ // or types of index operands.
return success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 06c1fdd..5637552 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -168,13 +168,6 @@ bool LLVM::StoreOp::canUsesBeRemoved(
DeletionKind LLVM::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
- // `canUsesBeRemoved` checked this blocking use must be the stored slot
- // pointer.
- for (Operation *user : slot.ptr.getUsers())
- if (auto declareOp = dyn_cast<LLVM::DbgDeclareOp>(user))
- rewriter.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
- declareOp.getVarInfo(),
- declareOp.getLocationExpr());
return DeletionKind::Delete;
}
@@ -407,6 +400,18 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses(
return DeletionKind::Keep;
}
+bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }
+
+void LLVM::DbgDeclareOp::visitReplacedValues(
+ ArrayRef<std::pair<Operation *, Value>> definitions,
+ RewriterBase &rewriter) {
+ for (auto [op, value] : definitions) {
+ rewriter.setInsertionPointAfter(op);
+ rewriter.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
+ getLocationExpr());
+ }
+}
+
//===----------------------------------------------------------------------===//
// Interfaces for GEPOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2d7219f..9c5c58f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -373,14 +373,15 @@ namespace {
class RegionBuilderHelper {
public:
- RegionBuilderHelper(MLIRContext *context, Block &block)
- : context(context), block(block) {}
+ RegionBuilderHelper(OpBuilder &builder, Block &block)
+ : builder(builder), block(block) {}
// Build the unary functions defined by OpDSL.
Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
if (!isFloatingPoint(arg))
llvm_unreachable("unsupported non numeric type");
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
switch (unaryFn) {
case UnaryFn::exp:
return builder.create<math::ExpOp>(arg.getLoc(), arg);
@@ -407,7 +408,8 @@ public:
arg1.getType().getIntOrFloatBitWidth() == 1;
if (!allComplex && !allFloatingPoint && !allInteger)
llvm_unreachable("unsupported non numeric type");
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
switch (binaryFn) {
case BinaryFn::add:
if (allComplex)
@@ -481,29 +483,32 @@ public:
}
void yieldOutputs(ValueRange values) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
builder.create<YieldOp>(loc, values);
}
Value constant(const std::string &value) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
Attribute valueAttr = parseAttribute(value, builder.getContext());
return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
}
Value index(int64_t dim) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
}
Type getIntegerType(unsigned width) {
- return IntegerType::get(context, width);
+ return IntegerType::get(builder.getContext(), width);
}
- Type getFloat32Type() { return Float32Type::get(context); }
- Type getFloat64Type() { return Float64Type::get(context); }
+ Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
+ Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
private:
// Generates operations to cast the given operand to a specified type.
@@ -511,7 +516,8 @@ private:
// operand returned as-is (which will presumably yield a verification
// issue downstream).
Value cast(Type toType, Value operand, bool isUnsignedCast) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
auto loc = operand.getLoc();
return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
}
@@ -526,13 +532,7 @@ private:
return llvm::isa<IntegerType>(value.getType());
}
- OpBuilder getBuilder() {
- OpBuilder builder(context);
- builder.setInsertionPointToEnd(&block);
- return builder;
- }
-
- MLIRContext *context;
+ OpBuilder &builder;
Block &block;
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 58fb2e9..899b8c8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
@@ -110,6 +111,10 @@ struct LinalgOpInterface
ArrayRef<OpOperand *> opOperands) const {
auto linalgOp = cast<linalg::LinalgOp>(op);
+ // Accesses into sparse data structures are not necessarily elementwise.
+ if (sparse_tensor::hasAnySparseOperand(linalgOp))
+ return false;
+
// All loops must be parallel.
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
return false;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index b32ea8e..c3a08ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
/*stopCondition=*/
- [&](Value v, std::optional<int64_t> d) {
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
if (v == forOp.getUpperBound())
return false;
// Compute a bound that is independent of any affine op results.
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 0b85462..42629e1 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -216,20 +216,30 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
// Convert `math.fpowi` to a series of `arith.mulf` operations.
// If the power is negative, we divide one by the result.
// If both the base and power are zero, the result is 1.
-static LogicalResult convertFPowICstOp(math::FPowIOp op,
- PatternRewriter &rewriter) {
+// In the case of non constant power, we convert the operation to `math.powf`.
+static LogicalResult convertFPowIOp(math::FPowIOp op,
+ PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value base = op.getOperand(0);
Value power = op.getOperand(1);
Type baseType = base.getType();
+ auto convertFPowItoPowf = [&]() -> LogicalResult {
+ Value castPowerToFp =
+ rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
+ Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
+ castPowerToFp);
+ rewriter.replaceOp(op, res);
+ return success();
+ };
+
Attribute cstAttr;
if (!matchPattern(power, m_Constant(&cstAttr)))
- return failure();
+ return convertFPowItoPowf();
APInt value;
if (!matchPattern(cstAttr, m_ConstantInt(&value)))
- return failure();
+ return convertFPowItoPowf();
int64_t powerInt = value.getSExtValue();
bool isNegative = powerInt < 0;
@@ -591,7 +601,7 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
}
void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
- patterns.add(convertFPowICstOp);
+ patterns.add(convertFPowIOp);
}
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index bf58750..a043431 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1538,6 +1538,21 @@ static void printAtomicReductionRegion(OpAsmPrinter &printer,
printer.printRegion(region);
}
+static ParseResult parseCleanupReductionRegion(OpAsmParser &parser,
+ Region &region) {
+ if (parser.parseOptionalKeyword("cleanup"))
+ return success();
+ return parser.parseRegion(region);
+}
+
+static void printCleanupReductionRegion(OpAsmPrinter &printer,
+ DeclareReductionOp op, Region &region) {
+ if (region.empty())
+ return;
+ printer << "cleanup ";
+ printer.printRegion(region);
+}
+
LogicalResult DeclareReductionOp::verifyRegions() {
if (getInitializerRegion().empty())
return emitOpError() << "expects non-empty initializer region";
@@ -1571,21 +1586,29 @@ LogicalResult DeclareReductionOp::verifyRegions() {
"of the reduction type";
}
- if (getAtomicReductionRegion().empty())
+ if (!getAtomicReductionRegion().empty()) {
+ Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
+ if (atomicReductionEntryBlock.getNumArguments() != 2 ||
+ atomicReductionEntryBlock.getArgumentTypes()[0] !=
+ atomicReductionEntryBlock.getArgumentTypes()[1])
+ return emitOpError() << "expects atomic reduction region with two "
+ "arguments of the same type";
+ auto ptrType = llvm::dyn_cast<PointerLikeType>(
+ atomicReductionEntryBlock.getArgumentTypes()[0]);
+ if (!ptrType ||
+ (ptrType.getElementType() && ptrType.getElementType() != getType()))
+ return emitOpError() << "expects atomic reduction region arguments to "
+ "be accumulators containing the reduction type";
+ }
+
+ if (getCleanupRegion().empty())
return success();
+ Block &cleanupEntryBlock = getCleanupRegion().front();
+ if (cleanupEntryBlock.getNumArguments() != 1 ||
+ cleanupEntryBlock.getArgument(0).getType() != getType())
+ return emitOpError() << "expects cleanup region with one argument "
+ "of the reduction type";
- Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
- if (atomicReductionEntryBlock.getNumArguments() != 2 ||
- atomicReductionEntryBlock.getArgumentTypes()[0] !=
- atomicReductionEntryBlock.getArgumentTypes()[1])
- return emitOpError() << "expects atomic reduction region with two "
- "arguments of the same type";
- auto ptrType = llvm::dyn_cast<PointerLikeType>(
- atomicReductionEntryBlock.getArgumentTypes()[0]);
- if (!ptrType ||
- (ptrType.getElementType() && ptrType.getElementType() != getType()))
- return emitOpError() << "expects atomic reduction region arguments to "
- "be accumulators containing the reduction type";
return success();
}
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index cb36e0c..1e13e60 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -58,7 +58,7 @@ struct ForOpInterface
ValueDimList boundOperands;
LogicalResult status = ValueBoundsConstraintSet::computeBound(
bound, boundOperands, BoundType::EQ, yieldedValue, dim,
- [&](Value v, std::optional<int64_t> d) {
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
// Stop when reaching a block argument of the loop body.
if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
return bbArg.getOwner()->getParentOp() == forOp;
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index c091841..7e4faf8 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -332,9 +332,9 @@ transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
transform::TransformState &state) {
LogicalResult result(failure());
if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
- result = coalescePerfectlyNestedLoops(scfForOp);
+ result = coalescePerfectlyNestedSCFForLoops(scfForOp);
else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
- result = coalescePerfectlyNestedLoops(affineForOp);
+ result = coalescePerfectlyNestedAffineLoops(affineForOp);
results.push_back(op);
if (failed(result)) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
index a69df02..6ba7020 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
@@ -28,6 +28,7 @@ namespace {
struct TestSCFParallelLoopCollapsing
: public impl::TestSCFParallelLoopCollapsingBase<
TestSCFParallelLoopCollapsing> {
+
void runOnOperation() override {
Operation *module = getOperation();
@@ -88,6 +89,7 @@ struct TestSCFParallelLoopCollapsing
// Only apply the transformation on parallel loops where the specified
// transformation is valid, but do NOT early abort in the case of invalid
// loops.
+ IRRewriter rewriter(&getContext());
module->walk([&](scf::ParallelOp op) {
if (flattenedCombinedLoops.size() != op.getNumLoops()) {
op.emitOpError("has ")
@@ -97,7 +99,7 @@ struct TestSCFParallelLoopCollapsing
<< flattenedCombinedLoops.size() << " iter args.";
return;
}
- collapseParallelLoops(op, combinedLoops);
+ collapseParallelLoops(rewriter, op, combinedLoops);
});
}
};
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 914aeb4..9279081 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
@@ -472,18 +473,23 @@ LogicalResult mlir::loopUnrollByFactor(
return success();
}
-/// Return the new lower bound, upper bound, and step in that order. Insert any
-/// additional bounds calculations before the given builder and any additional
-/// conversion back to the original loop induction value inside the given Block.
-static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
- OpBuilder &insideLoopBuilder, Location loc,
- Value lowerBound, Value upperBound, Value step,
- Value inductionVar) {
+/// Transform a loop with a strictly positive step
+/// for %i = %lb to %ub step %s
+/// into a 0-based loop with step 1
+/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
+/// %i = %ii * %s + %lb
+/// Insert the induction variable remapping in the body of `inner`, which is
+/// expected to be either `loop` or another loop perfectly nested under `loop`.
+/// Insert the definition of new bounds immediate before `outer`, which is
+/// expected to be either `loop` or its parent in the loop nest.
+static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+ Value lb, Value ub, Value step) {
+ // For non-index types, generate `arith` instructions
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
bool isZeroBased = false;
- if (auto ubCst = getConstantIntValue(lowerBound))
- isZeroBased = ubCst.value() == 0;
+ if (auto lbCst = getConstantIntValue(lb))
+ isZeroBased = lbCst.value() == 0;
bool isStepOne = false;
if (auto stepCst = getConstantIntValue(step))
@@ -493,62 +499,90 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
// assuming the step is strictly positive. Update the bounds and the step
// of the loop to go from 0 to the number of iterations, if necessary.
if (isZeroBased && isStepOne)
- return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
- /*step=*/step};
+ return {lb, ub, step};
- Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
+ Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
Value newUpperBound =
- boundsBuilder.create<arith::CeilDivSIOp>(loc, diff, step);
-
- Value newLowerBound =
- isZeroBased ? lowerBound
- : boundsBuilder.create<arith::ConstantOp>(
- loc, boundsBuilder.getZeroAttr(lowerBound.getType()));
- Value newStep =
- isStepOne ? step
- : boundsBuilder.create<arith::ConstantOp>(
- loc, boundsBuilder.getIntegerAttr(step.getType(), 1));
-
- // Insert code computing the value of the original loop induction variable
- // from the "normalized" one.
- Value scaled =
- isStepOne
- ? inductionVar
- : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
- Value shifted =
- isZeroBased
- ? scaled
- : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);
-
- SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
- shifted.getDefiningOp()};
- inductionVar.replaceAllUsesExcept(shifted, preserve);
- return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
- /*step=*/newStep};
+ isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
+
+ Value newLowerBound = isZeroBased
+ ? lb
+ : rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(lb.getType()));
+ Value newStep = isStepOne
+ ? step
+ : rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(step.getType(), 1));
+
+ return {newLowerBound, newUpperBound, newStep};
}
-/// Transform a loop with a strictly positive step
-/// for %i = %lb to %ub step %s
-/// into a 0-based loop with step 1
-/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
-/// %i = %ii * %s + %lb
-/// Insert the induction variable remapping in the body of `inner`, which is
-/// expected to be either `loop` or another loop perfectly nested under `loop`.
-/// Insert the definition of new bounds immediate before `outer`, which is
-/// expected to be either `loop` or its parent in the loop nest.
-static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
- OpBuilder builder(outer);
- OpBuilder innerBuilder = OpBuilder::atBlockBegin(inner.getBody());
- auto loopPieces = normalizeLoop(builder, innerBuilder, loop.getLoc(),
- loop.getLowerBound(), loop.getUpperBound(),
- loop.getStep(), loop.getInductionVar());
-
- loop.setLowerBound(loopPieces.lowerBound);
- loop.setUpperBound(loopPieces.upperBound);
- loop.setStep(loopPieces.step);
+/// Get back the original induction variable values after loop normalization
+static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
+ Value normalizedIv, Value origLb,
+ Value origStep) {
+ Value denormalizedIv;
+ SmallPtrSet<Operation *, 2> preserve;
+ bool isStepOne = isConstantIntValue(origStep, 1);
+ bool isZeroBased = isConstantIntValue(origLb, 0);
+
+ Value scaled = normalizedIv;
+ if (!isStepOne) {
+ scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep);
+ preserve.insert(scaled.getDefiningOp());
+ }
+ denormalizedIv = scaled;
+ if (!isZeroBased) {
+ denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb);
+ preserve.insert(denormalizedIv.getDefiningOp());
+ }
+
+ rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
}
-LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
+/// Helper function to multiply a sequence of values.
+static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
+ ArrayRef<Value> values) {
+ assert(!values.empty() && "unexpected empty list");
+ Value productOf = values.front();
+ for (auto v : values.drop_front()) {
+ productOf = rewriter.create<arith::MulIOp>(loc, productOf, v);
+ }
+ return productOf;
+}
+
+/// For each original loop, the value of the
+/// induction variable can be obtained by dividing the induction variable of
+/// the linearized loop by the total number of iterations of the loops nested
+/// in it modulo the number of iterations in this loop (remove the values
+/// related to the outer loops):
+/// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
+/// Compute these iteratively from the innermost loop by creating a "running
+/// quotient" of division by the range.
+static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
+delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
+ Value linearizedIv, ArrayRef<Value> ubs) {
+ Value previous = linearizedIv;
+ SmallVector<Value> delinearizedIvs(ubs.size());
+ SmallPtrSet<Operation *, 2> preservedUsers;
+ for (unsigned i = 0, e = ubs.size(); i < e; ++i) {
+ unsigned idx = ubs.size() - i - 1;
+ if (i != 0) {
+ previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
+ preservedUsers.insert(previous.getDefiningOp());
+ }
+ Value iv = previous;
+ if (i != e - 1) {
+ iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
+ preservedUsers.insert(iv.getDefiningOp());
+ }
+ delinearizedIvs[idx] = iv;
+ }
+ return {delinearizedIvs, preservedUsers};
+}
+
+LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
+ MutableArrayRef<scf::ForOp> loops) {
if (loops.size() < 2)
return failure();
@@ -557,57 +591,148 @@ LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
// 1. Make sure all loops iterate from 0 to upperBound with step 1. This
// allows the following code to assume upperBound is the number of iterations.
- for (auto loop : loops)
- normalizeLoop(loop, outermost, innermost);
+ for (auto loop : loops) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(outermost);
+ Value lb = loop.getLowerBound();
+ Value ub = loop.getUpperBound();
+ Value step = loop.getStep();
+ auto newLoopParams =
+ emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
+
+ rewriter.modifyOpInPlace(loop, [&]() {
+ loop.setLowerBound(newLoopParams.lowerBound);
+ loop.setUpperBound(newLoopParams.upperBound);
+ loop.setStep(newLoopParams.step);
+ });
+
+ rewriter.setInsertionPointToStart(innermost.getBody());
+ denormalizeInductionVariable(rewriter, loop.getLoc(),
+ loop.getInductionVar(), lb, step);
+ }
// 2. Emit code computing the upper bound of the coalesced loop as product
// of the number of iterations of all loops.
- OpBuilder builder(outermost);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(outermost);
Location loc = outermost.getLoc();
- Value upperBound = outermost.getUpperBound();
- for (auto loop : loops.drop_front())
- upperBound =
- builder.create<arith::MulIOp>(loc, upperBound, loop.getUpperBound());
+ SmallVector<Value> upperBounds = llvm::map_to_vector(
+ loops, [](auto loop) { return loop.getUpperBound(); });
+ Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
outermost.setUpperBound(upperBound);
- builder.setInsertionPointToStart(outermost.getBody());
-
- // 3. Remap induction variables. For each original loop, the value of the
- // induction variable can be obtained by dividing the induction variable of
- // the linearized loop by the total number of iterations of the loops nested
- // in it modulo the number of iterations in this loop (remove the values
- // related to the outer loops):
- // iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
- // Compute these iteratively from the innermost loop by creating a "running
- // quotient" of division by the range.
- Value previous = outermost.getInductionVar();
+ rewriter.setInsertionPointToStart(innermost.getBody());
+ auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
+ rewriter, loc, outermost.getInductionVar(), upperBounds);
+ rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
+ preservedUsers);
+
+ for (int i = loops.size() - 1; i > 0; --i) {
+ auto outerLoop = loops[i - 1];
+ auto innerLoop = loops[i];
+
+ Operation *innerTerminator = innerLoop.getBody()->getTerminator();
+ auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
+ rewriter.eraseOp(innerTerminator);
+
+ SmallVector<Value> innerBlockArgs;
+ innerBlockArgs.push_back(delinearizeIvs[i]);
+ llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
+ rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
+ Block::iterator(innerLoop), innerBlockArgs);
+ rewriter.replaceOp(innerLoop, yieldedVals);
+ }
+ return success();
+}
+
+LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
+ if (loops.empty()) {
+ return failure();
+ }
+ IRRewriter rewriter(loops.front().getContext());
+ return coalesceLoops(rewriter, loops);
+}
+
+LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
+ LogicalResult result(failure());
+ SmallVector<scf::ForOp> loops;
+ getPerfectlyNestedLoops(loops, op);
+
+ // Look for a band of loops that can be coalesced, i.e. perfectly nested
+ // loops with bounds defined above some loop.
+
+ // 1. For each loop, find above which parent loop its bounds operands are
+ // defined.
+ SmallVector<unsigned> operandsDefinedAbove(loops.size());
for (unsigned i = 0, e = loops.size(); i < e; ++i) {
- unsigned idx = loops.size() - i - 1;
- if (i != 0)
- previous = builder.create<arith::DivSIOp>(loc, previous,
- loops[idx + 1].getUpperBound());
-
- Value iv = (i == e - 1) ? previous
- : builder.create<arith::RemSIOp>(
- loc, previous, loops[idx].getUpperBound());
- replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
- loops.back().getRegion());
+ operandsDefinedAbove[i] = i;
+ for (unsigned j = 0; j < i; ++j) {
+ SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
+ loops[i].getUpperBound(),
+ loops[i].getStep()};
+ if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
+ operandsDefinedAbove[i] = j;
+ break;
+ }
+ }
}
- // 4. Move the operations from the innermost just above the second-outermost
- // loop, delete the extra terminator and the second-outermost loop.
- scf::ForOp second = loops[1];
- innermost.getBody()->back().erase();
- outermost.getBody()->getOperations().splice(
- Block::iterator(second.getOperation()),
- innermost.getBody()->getOperations());
- second.erase();
- return success();
+ // 2. For each inner loop check that the iter_args for the immediately outer
+ // loop are the init for the immediately inner loop and that the yields of the
+ // return of the inner loop is the yield for the immediately outer loop. Keep
+ // track of where the chain starts from for each loop.
+ SmallVector<unsigned> iterArgChainStart(loops.size());
+ iterArgChainStart[0] = 0;
+ for (unsigned i = 1, e = loops.size(); i < e; ++i) {
+ // By default set the start of the chain to itself.
+ iterArgChainStart[i] = i;
+ auto outerloop = loops[i - 1];
+ auto innerLoop = loops[i];
+ if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
+ continue;
+ }
+ if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
+ continue;
+ }
+ auto outerloopTerminator = outerloop.getBody()->getTerminator();
+ if (!llvm::equal(outerloopTerminator->getOperands(),
+ innerLoop.getResults())) {
+ continue;
+ }
+ iterArgChainStart[i] = iterArgChainStart[i - 1];
+ }
+
+ // 3. Identify bands of loops such that the operands of all of them are
+ // defined above the first loop in the band. Traverse the nest bottom-up
+ // so that modifications don't invalidate the inner loops.
+ for (unsigned end = loops.size(); end > 0; --end) {
+ unsigned start = 0;
+ for (; start < end - 1; ++start) {
+ auto maxPos =
+ *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+ std::next(operandsDefinedAbove.begin(), end));
+ if (maxPos > start)
+ continue;
+ if (iterArgChainStart[end - 1] > start)
+ continue;
+ auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
+ if (succeeded(coalesceLoops(band)))
+ result = success();
+ break;
+ }
+ // If a band was found and transformed, keep looking at the loops above
+ // the outermost transformed loop.
+ if (start != end - 1)
+ end = start + 1;
+ }
+ return result;
}
void mlir::collapseParallelLoops(
- scf::ParallelOp loops, ArrayRef<std::vector<unsigned>> combinedDimensions) {
- OpBuilder outsideBuilder(loops);
+ RewriterBase &rewriter, scf::ParallelOp loops,
+ ArrayRef<std::vector<unsigned>> combinedDimensions) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(loops);
Location loc = loops.getLoc();
// Presort combined dimensions.
@@ -619,25 +744,29 @@ void mlir::collapseParallelLoops(
SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
normalizedUpperBounds;
for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
- OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(loops.getBody());
- auto resultBounds =
- normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
- loops.getLowerBound()[i], loops.getUpperBound()[i],
- loops.getStep()[i], loops.getBody()->getArgument(i));
-
- normalizedLowerBounds.push_back(resultBounds.lowerBound);
- normalizedUpperBounds.push_back(resultBounds.upperBound);
- normalizedSteps.push_back(resultBounds.step);
+ OpBuilder::InsertionGuard g2(rewriter);
+ rewriter.setInsertionPoint(loops);
+ Value lb = loops.getLowerBound()[i];
+ Value ub = loops.getUpperBound()[i];
+ Value step = loops.getStep()[i];
+ auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
+ normalizedLowerBounds.push_back(newLoopParams.lowerBound);
+ normalizedUpperBounds.push_back(newLoopParams.upperBound);
+ normalizedSteps.push_back(newLoopParams.step);
+
+ rewriter.setInsertionPointToStart(loops.getBody());
+ denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
+ step);
}
// Combine iteration spaces.
SmallVector<Value, 3> lowerBounds, upperBounds, steps;
- auto cst0 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 0);
- auto cst1 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
+ auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
for (auto &sortedDimension : sortedDimensions) {
- Value newUpperBound = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
+ Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1);
for (auto idx : sortedDimension) {
- newUpperBound = outsideBuilder.create<arith::MulIOp>(
+ newUpperBound = rewriter.create<arith::MulIOp>(
loc, newUpperBound, normalizedUpperBounds[idx]);
}
lowerBounds.push_back(cst0);
@@ -651,7 +780,7 @@ void mlir::collapseParallelLoops(
// value. The remainders then determine based on that range, which iteration
// of the original induction value this represents. This is a normalized value
// that is un-normalized already by the previous logic.
- auto newPloop = outsideBuilder.create<scf::ParallelOp>(
+ auto newPloop = rewriter.create<scf::ParallelOp>(
loc, lowerBounds, upperBounds, steps,
[&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index d4c1792..acea25f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -274,7 +274,7 @@ struct SparseTensorCodegenPass
});
// The following operations and dialects may be introduced by the
// codegen rules, and are therefore marked as legal.
- target.addLegalOp<linalg::FillOp>();
+ target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
target.addLegalDialect<
arith::ArithDialect, bufferization::BufferizationDialect,
complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6e6e843..e06ac9a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -955,25 +955,34 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
}
mlir::LogicalResult tosa::ReshapeOp::verify() {
- ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
- ShapedType outputType = llvm::cast<ShapedType>(getType());
+ TensorType inputType = getInput1().getType();
+ RankedTensorType outputType = getType();
if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
return emitOpError() << "tensor has a dimension with size zero. Each "
"dimension of a tensor must have size >= 1";
+ if ((int64_t) getNewShape().size() != outputType.getRank())
+ return emitOpError() << "new shape does not match result rank";
+
+ for (auto [newShapeDim, outputShapeDim] :
+ zip(getNewShape(), outputType.getShape()))
+ if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
+ newShapeDim != outputShapeDim)
+ return emitOpError() << "new shape is inconsistent with result shape";
+
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
int64_t inputElementsNum = inputType.getNumElements();
int64_t outputElementsNum = outputType.getNumElements();
if (inputElementsNum != outputElementsNum) {
- return emitOpError() << "Cannot reshape " << inputElementsNum
+ return emitOpError() << "cannot reshape " << inputElementsNum
<< " elements into " << outputElementsNum;
}
}
int missingDims = llvm::count(getNewShape(), -1);
if (missingDims > 1)
- return emitOpError() << "At most one target dimension can be -1";
+ return emitOpError() << "expected at most one target dimension to be -1";
return mlir::success();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index ad28c56..8614559 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -18,14 +18,9 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {
@@ -39,9 +34,87 @@ using namespace mlir::tosa;
namespace {
-void propagateShapesInRegion(Region &region);
+// Check whether this use case is replaceable. We define an op as
+// being replaceable if it is used by a TosaOp, or an op with a
+// type-inference related interface.
+// When a non-replaceable use is encountered, the value is wrapped in a
+// cast back to the original type after inference.
+bool isReplaceableUser(Operation *user) {
+ // Handle unregistered dialects.
+ if (!user->getDialect())
+ return false;
+
+ return user->getDialect()->getNamespace() ==
+ TosaDialect::getDialectNamespace() ||
+ isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
+}
+
+// During type propagation, the types of values in the operator graph are
+// updated. For the tosa.while_loop operation, types are speculatively updated
+// within the body region to determine the output type of the while_loop. This
+// process is performed until a fixed point is reached, then the types are
+// reverted.
+//
+// This class encapsulates the state information needed to perform the reversion
+// process or to commit to the final changes.
+class TypeModificationState {
+public:
+ TypeModificationState() = default;
+
+ ~TypeModificationState() {
+ // Ensure the recorded modifications are either committed or reverted.
+ assert(oldTypes.empty() && "unhandled type modifications");
+ }
+
+ // Update the state of the value and record the old type.
+ void setType(Value value, Type type) {
+ if (value.getType() != type) {
+ oldTypes.emplace_back(value, value.getType());
+ value.setType(type);
+ }
+ }
-void propagateShapesToTosaIf(Operation &op) {
+ // Revert changes made to the types in the IR by setting all the affected
+ // values to their old types.
+ void revert() {
+ // Otherwise revert the changes.
+ for (auto [value, type] : oldTypes)
+ value.setType(type);
+
+ oldTypes.clear();
+ }
+
+ // Commit the changes to the types in the IR.
+ // This requires inserting tensor.cast operations to mediate the newly
+ // inferred result types with users that do not support type inference.
+ void commit() {
+ // For each use whose type changed, cast the value with the new type back to
+ // the old type.
+ for (auto [value, oldType] : oldTypes) {
+ for (auto &use : value.getUses()) {
+ if (isReplaceableUser(use.getOwner()))
+ continue;
+
+ OpBuilder builder(value.getContext());
+ builder.setInsertionPoint(use.getOwner());
+
+ Location loc = value.getLoc();
+ use.set(builder.create<tensor::CastOp>(loc, oldType, value));
+ }
+ }
+
+ oldTypes.clear();
+ }
+
+private:
+ // A record of each value whose type was updated along with that value's
+ // previous type.
+ llvm::SmallVector<std::pair<Value, Type>> oldTypes;
+};
+
+void propagateShapesInRegion(Region &region, TypeModificationState &state);
+
+void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
IfOp ifOp = dyn_cast<IfOp>(op);
if (!ifOp)
return;
@@ -58,7 +131,7 @@ void propagateShapesToTosaIf(Operation &op) {
if (inferredTy.hasRank()) {
Type newType = oldType.clone(inferredTy.getShape());
- blockArg.setType(newType);
+ state.setType(blockArg, newType);
}
}
@@ -71,14 +144,14 @@ void propagateShapesToTosaIf(Operation &op) {
ValueKnowledge::join(operandKnowledge, blockKnowledge);
if (!joinedKnowledge)
continue;
- frontBlock.getArgument(i).setType(joinedKnowledge.getType());
+ state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
}
- propagateShapesInRegion(region);
+ propagateShapesInRegion(region, state);
}
}
-void propagateShapesToTosaWhile(Operation &op) {
+void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
WhileOp whileOp = dyn_cast<WhileOp>(op);
if (!whileOp)
return;
@@ -86,49 +159,29 @@ void propagateShapesToTosaWhile(Operation &op) {
// Determine what the expected argument types are to the cond/body blocks.
// The expected arguments should be compatible with ever iteration of the
// loop body / condition for tosa.while.
- llvm::SmallVector<Type> argTypes;
- for (auto operand : op.getOperands()) {
- auto operandTy = cast<ShapedType>(operand.getType());
- if (operandTy.hasRank()) {
- auto newTy = operandTy.clone(operandTy.getShape());
- argTypes.push_back(newTy);
- } else {
- argTypes.push_back(operand.getType());
- }
- }
-
- // Save out the type information so we can restore at the end.
- llvm::DenseMap<Value, Type> originalTypeMap;
- for (auto &block : op.getRegion(1)) {
- for (auto arg : block.getArguments())
- originalTypeMap[arg] = arg.getType();
- for (auto &op : block)
- for (auto result : op.getResults())
- originalTypeMap[result] = result.getType();
- }
+ SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
bool hasNewTypes = true;
while (hasNewTypes) {
+ TypeModificationState localState;
// Set types on the block args.
Region &bodyRegion = op.getRegion(1);
Block &block = bodyRegion.front();
for (int i = 0, s = argTypes.size(); i < s; i++) {
- block.getArgument(i).setType(argTypes[i]);
+ localState.setType(block.getArgument(i), argTypes[i]);
}
// Propagate to the end.
- propagateShapesInRegion(bodyRegion);
+ propagateShapesInRegion(bodyRegion, localState);
- // Find all the tosa yield types and verify there is atleast one.
+ // Find all the tosa yield types and verify there is a single one.
llvm::SmallVector<YieldOp> yieldOps;
for (auto &block : bodyRegion)
if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
yieldOps.push_back(yieldOp);
- if (yieldOps.empty())
- return;
-
+ assert(yieldOps.size() == 1 && "missing or non-unique yield op");
// Using the new tosa.yield operand types, infer the new subtypes.
llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
for (auto ty : argTypes) {
@@ -158,17 +211,8 @@ void propagateShapesToTosaWhile(Operation &op) {
argTypes[i] = newType;
}
- // The types inferred in the block assume the operand types specified for
- // this iteration. We need to restore the original types to ensure that
- // future iterations only use the already specified types, not possible
- // types from previous iterations.
- for (auto &block : bodyRegion) {
- for (auto arg : block.getArguments())
- arg.setType(originalTypeMap[arg]);
- for (auto &op : block)
- for (auto result : op.getResults())
- result.setType(originalTypeMap[result]);
- }
+ // Revert all changes made during the speculative part of the algorithm.
+ localState.revert();
}
// We now set the block arguments according to the most recent shape
@@ -176,41 +220,22 @@ void propagateShapesToTosaWhile(Operation &op) {
// iteration.
for (auto &region : op.getRegions()) {
for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
- region.front().getArgument(i).setType(argTypes[i]);
+ state.setType(region.front().getArgument(i), argTypes[i]);
}
- propagateShapesInRegion(region);
+ propagateShapesInRegion(region, state);
}
}
-// Track the old type for each operand whose type was updated
-// during inference. This information is used to introduce casts
-// back to the type expected by the operand after inference.
-struct TypeRewriteInfo {
- OpOperand *operand;
- Type oldType;
-};
-
-void propagateShapesInRegion(Region &region) {
- // Check whether this use case is replaceable. We define an op as
- // being replaceable if it is used by a TosaOp, or an op with a
- // type-inference related interface.
- // When a non-replaceable use is encountered, the value is wrapped in a
- // cast back to the original type after inference.
- auto isReplaceableUser = [](Operation *user) -> bool {
- return user->getDialect()->getNamespace() ==
- TosaDialect::getDialectNamespace() ||
- isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
- };
-
- llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
+void propagateShapesInRegion(Region &region, TypeModificationState &state) {
for (auto &block : region) {
for (Operation &op : block) {
- if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+ if (!op.getDialect() ||
+ op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
continue;
- propagateShapesToTosaIf(op);
- propagateShapesToTosaWhile(op);
+ propagateShapesToTosaIf(op, state);
+ propagateShapesToTosaWhile(op, state);
InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op);
@@ -252,30 +277,11 @@ void propagateShapesInRegion(Region &region) {
continue;
// Set new type
- result.setType(newKnowledge.getType());
-
- // Collect all uses of the operation which require update.
- for (auto &user : result.getUses()) {
- if (!isReplaceableUser(user.getOwner()))
- requiresUpdate.push_back({&user, resultTy});
- }
+ state.setType(result, newKnowledge.getType());
}
}
}
}
-
- // For each use whose type changed, cast the value with the new type back to
- // the old type.
- IRRewriter rewriter(region.getContext());
- for (auto [operand, oldType] : requiresUpdate) {
- rewriter.setInsertionPoint(operand->getOwner());
-
- auto oldValue = operand->get();
-
- auto loc = oldValue.getLoc();
- auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue);
- operand->set(castOp);
- }
}
/// Pass that performs shape propagation across TOSA operations. This includes
@@ -285,7 +291,9 @@ struct TosaInferShapes
public:
void runOnOperation() override {
func::FuncOp func = getOperation();
- propagateShapesInRegion(func.getBody());
+ TypeModificationState state;
+ propagateShapesInRegion(func.getBody(), state);
+ state.commit();
}
};
} // namespace
diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
index 6d7e3bc..52359fa8 100644
--- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
+++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
@@ -47,17 +47,26 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
StopConditionFn stopCondition) {
using namespace presburger;
-
assert(vscaleMin <= vscaleMax);
- ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin,
- vscaleMax);
- int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition);
+ // No stop condition specified: Keep adding constraints until the worklist
+ // is empty.
+ auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
+ mlir::ValueBoundsConstraintSet &cstr) {
+ return false;
+ };
+
+ ScalableValueBoundsConstraintSet scalableCstr(
+ value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
+ vscaleMin, vscaleMax);
+ int64_t pos = scalableCstr.populateConstraintsSet(value, dim);
// Project out all variables apart from vscale.
// This should result in constraints in terms of vscale only.
- scalableCstr.projectOut(
- [&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); });
+ auto projectOutFn = [&](ValueDim p) {
+ return p.first != scalableCstr.getVscaleValue();
+ };
+ scalableCstr.projectOut(projectOutFn);
assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
scalableCstr.positionToValueDim.size() &&
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 593c1e5..8d733c5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -398,13 +398,30 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
transposeResults.push_back(targetExpr);
}
}
+
+ // Checks if only the outer, unit dimensions (of size 1) are permuted.
+ // Such transposes do not materially effect the underlying vector and can
+ // be omitted. EG: perm [1, 0, 2] applied to vector<1x1x8xi32>
+ bool transposeNonOuterUnitDims = false;
+ auto operandShape = operands[it.index()].getType().cast<ShapedType>();
+ for (auto [index, dim] :
+ llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
+ if (dim != static_cast<int64_t>(index) &&
+ operandShape.getDimSize(index) != 1) {
+ transposeNonOuterUnitDims = true;
+ break;
+ }
+ }
+
// Do the tranpose now if needed so that we can drop the
// correct dim using extract later.
if (tranposeNeeded) {
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
contractOp.getContext());
- operands[it.index()] = rewriter.create<vector::TransposeOp>(
- loc, operands[it.index()], perm);
+ if (transposeNonOuterUnitDims) {
+ operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+ loc, operands[it.index()], perm);
+ }
}
}
// We have taken care to have the dim to be dropped be
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 4fa5b8a..b59e906 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,6 +26,9 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
// Reject index since getElementTypeBitWidth will abort for Index types.
if (!vecType || vecType.getElementType().isIndex())
return false;
+ // There are no dimension to fold if it is a 0-D vector.
+ if (vecType.getRank() == 0)
+ return false;
unsigned trailingVecDimBitWidth =
vecType.getShape().back() * vecType.getElementTypeBitWidth();
if (trailingVecDimBitWidth >= targetBitWidth)