Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 15
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp | 6
-rw-r--r--  mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp | 48
-rw-r--r--  mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp | 6
-rw-r--r--  mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp | 98
-rw-r--r--  mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp | 4
-rw-r--r--  mlir/lib/Dialect/DLTI/DLTI.cpp | 5
-rw-r--r--  mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 64
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 19
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 36
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp | 20
-rw-r--r--  mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 49
-rw-r--r--  mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SCF/Utils/Utils.cpp | 353
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 17
-rw-r--r--  mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp | 198
-rw-r--r--  mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp | 21
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp | 21
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp | 3
-rw-r--r--  mlir/lib/IR/PatternMatch.cpp | 9
-rw-r--r--  mlir/lib/Interfaces/DataLayoutInterfaces.cpp | 25
-rw-r--r--  mlir/lib/Interfaces/ValueBoundsOpInterface.cpp | 94
-rw-r--r--  mlir/lib/Support/Timing.cpp | 198
-rw-r--r--  mlir/lib/Target/Cpp/TranslateToCpp.cpp | 2
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 102
-rw-r--r--  mlir/lib/Transforms/Mem2Reg.cpp | 16
32 files changed, 1005 insertions, 447 deletions
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 0e3b646..25fa158 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -62,8 +62,14 @@ struct ConvertLoad final : public OpConversionPattern<memref::LoadOp> {
return rewriter.notifyMatchFailure(op.getLoc(), "cannot convert type");
}
+ auto arrayValue =
+ dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
+ if (!arrayValue) {
+ return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
+ }
+
auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), operands.getMemref(), operands.getIndices());
+ op.getLoc(), arrayValue, operands.getIndices());
auto noInit = emitc::OpaqueAttr::get(getContext(), "");
auto var =
@@ -81,9 +87,14 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
LogicalResult
matchAndRewrite(memref::StoreOp op, OpAdaptor operands,
ConversionPatternRewriter &rewriter) const override {
+ auto arrayValue =
+ dyn_cast<TypedValue<emitc::ArrayType>>(operands.getMemref());
+ if (!arrayValue) {
+ return rewriter.notifyMatchFailure(op.getLoc(), "expected array type");
+ }
auto subscript = rewriter.create<emitc::SubscriptOp>(
- op.getLoc(), operands.getMemref(), operands.getIndices());
+ op.getLoc(), arrayValue, operands.getIndices());
rewriter.replaceOpWithNewOp<emitc::AssignOp>(op, subscript,
operands.getValue());
return success();
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
index 1dc69ab..05c7707 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopCoalescing.cpp
@@ -39,9 +39,9 @@ struct LoopCoalescingPass
func::FuncOp func = getOperation();
func.walk<WalkOrder::PreOrder>([](Operation *op) {
if (auto scfForOp = dyn_cast<scf::ForOp>(op))
- (void)coalescePerfectlyNestedLoops(scfForOp);
+ (void)coalescePerfectlyNestedSCFForLoops(scfForOp);
else if (auto affineForOp = dyn_cast<AffineForOp>(op))
- (void)coalescePerfectlyNestedLoops(affineForOp);
+ (void)coalescePerfectlyNestedAffineLoops(affineForOp);
});
}
};
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 37b36f7..117ee8e 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -84,7 +84,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
// We are trying to reify a bound for `value` in terms of the owning op's
// operands. Construct a stop condition that evaluates to "true" for any SSA
// value except for `value`. I.e., the bound will be computed in terms of
@@ -100,7 +101,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
return v != value;
};
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index af59973..268050a 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -2765,3 +2765,51 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest,
return success();
}
+
+LogicalResult affine::coalescePerfectlyNestedAffineLoops(AffineForOp op) {
+ LogicalResult result(failure());
+ SmallVector<AffineForOp> loops;
+ getPerfectlyNestedLoops(loops, op);
+ if (loops.size() <= 1)
+ return success();
+
+ // Look for a band of loops that can be coalesced, i.e. perfectly nested
+ // loops with bounds defined above some loop.
+ // 1. For each loop, find above which parent loop its operands are
+ // defined.
+ SmallVector<unsigned> operandsDefinedAbove(loops.size());
+ for (unsigned i = 0, e = loops.size(); i < e; ++i) {
+ operandsDefinedAbove[i] = i;
+ for (unsigned j = 0; j < i; ++j) {
+ if (areValuesDefinedAbove(loops[i].getOperands(), loops[j].getRegion())) {
+ operandsDefinedAbove[i] = j;
+ break;
+ }
+ }
+ }
+
+ // 2. Identify bands of loops such that the operands of all of them are
+ // defined above the first loop in the band. Traverse the nest bottom-up
+ // so that modifications don't invalidate the inner loops.
+ for (unsigned end = loops.size(); end > 0; --end) {
+ unsigned start = 0;
+ for (; start < end - 1; ++start) {
+ auto maxPos =
+ *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+ std::next(operandsDefinedAbove.begin(), end));
+ if (maxPos > start)
+ continue;
+ assert(maxPos == start &&
+ "expected loop bounds to be known at the start of the band");
+ auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
+ if (succeeded(coalesceLoops(band)))
+ result = success();
+ break;
+ }
+ // If a band was found and transformed, keep looking at the loops above
+ // the outermost transformed loop.
+ if (start != end - 1)
+ end = start + 1;
+ }
+ return result;
+}
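
For reference, the band search in coalescePerfectlyNestedAffineLoops reduces to a small scan over operandsDefinedAbove; a standalone C++ model of just that scan, with hypothetical inputs, is:

#include <algorithm>
#include <utility>
#include <vector>

// Standalone model of the band search above: operandsDefinedAbove[i] is the
// outermost depth above which loop i's bound operands are defined, and the
// scan returns the [start, end) bands that would be handed to coalesceLoops,
// walking the nest bottom-up exactly as the pass does.
static std::vector<std::pair<unsigned, unsigned>>
findCoalescableBands(const std::vector<unsigned> &operandsDefinedAbove) {
  std::vector<std::pair<unsigned, unsigned>> bands;
  for (unsigned end = operandsDefinedAbove.size(); end > 0; --end) {
    unsigned start = 0;
    for (; start < end - 1; ++start) {
      unsigned maxPos =
          *std::max_element(operandsDefinedAbove.begin() + start,
                            operandsDefinedAbove.begin() + end);
      if (maxPos > start)
        continue;
      bands.emplace_back(start, end);
      break;
    }
    if (start != end - 1)
      end = start + 1;
  }
  return bands;
}
// e.g. {0, 0, 0} yields a single band covering all three loops, while
// {0, 0, 1} (the innermost bounds depend on the outermost IV) yields only the
// band [1, 3).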
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index 8d9fd14..fad2212 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -119,7 +119,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition,
bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
// We are trying to reify a bound for `value` in terms of the owning op's
// operands. Construct a stop condition that evaluates to "true" for any SSA
// value except for `value`. I.e., the bound will be computed in terms of
@@ -135,7 +136,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
OpBuilder &b, Location loc, presburger::BoundType type, Value value,
ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
- auto reifyToOperands = [&](Value v, std::optional<int64_t> d) {
+ auto reifyToOperands = [&](Value v, std::optional<int64_t> d,
+ ValueBoundsConstraintSet &cstr) {
return v != value;
};
return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 1f48d27..1374022 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -40,8 +40,9 @@ static Type matchContainerType(Type element, Type container) {
/// Lowering from a vector::contractOp arm neon smmla intrinsic. This will tile
/// any vector.contract into multiple smmla instructions with unrolling so long
-/// as [2,2,8] is a divisor of its shape. If no unrolling is necessary, a single
-/// smmla instruction is emitted.
+/// as [2,2,8] is a divisor of its shape. It can also process vecmats with dimM
+/// = 1 (either explicit or inferred when the LHS has only dimK). If no
+/// unrolling is necessary, a single smmla instruction is emitted.
class LowerContractionToSMMLAPattern
: public OpRewritePattern<vector::ContractionOp> {
public:
@@ -49,32 +50,35 @@ public:
LogicalResult matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- // Check index maps that represent M N K in contract.
- auto indexingMaps = op.getIndexingMapsArray();
- if (llvm::any_of(indexingMaps, [](mlir::AffineMap affineMap) {
- return affineMap.isPermutation() || affineMap.getNumDims() != 3 ||
- affineMap.getNumResults() != 2;
- })) {
- return failure();
- }
- // Check iterator types for contract.
- auto iteratorTypes = op.getIteratorTypesArray();
- if (iteratorTypes.size() != 3 ||
- iteratorTypes[0] != vector::IteratorType::parallel ||
- iteratorTypes[1] != vector::IteratorType::parallel ||
- iteratorTypes[2] != vector::IteratorType::reduction) {
- return failure();
- }
- // Infer tile sizes from operands; Note: RHS is not transposed.
+ // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
+ // Note: RHS is not transposed.
mlir::VectorType lhsType = op.getLhsType();
mlir::VectorType rhsType = op.getRhsType();
- auto dimM = lhsType.getDimSize(0);
+ auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
- auto dimK = lhsType.getDimSize(1);
-
+ auto dimK = rhsType.getDimSize(1);
+ bool isVecmat = dimM == 1 ? true : false;
+ if (lhsType.getDimSize(lhsType.getRank() - 1) !=
+ rhsType.getDimSize(rhsType.getRank() - 1)) {
+ return failure(); // dimK mismatch
+ }
// Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
// tiling.
- if (dimM % 2 != 0 || dimN % 2 != 0 || dimK % 8 != 0) {
+ if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
+ return failure();
+ }
+
+ // Check iterator types for contract. All iterators except inner-most
+ // dimension must be parallel.
+ auto iteratorTypes = op.getIteratorTypesArray();
+ if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
+ vector::IteratorType::reduction) {
+ return failure();
+ }
+ if (llvm::any_of(ArrayRef<vector::IteratorType>(iteratorTypes).drop_back(1),
+ [](vector::IteratorType iteratorType) {
+ return iteratorType != vector::IteratorType::parallel;
+ })) {
return failure();
}
@@ -120,11 +124,14 @@ public:
loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
- SmallVector<int64_t> smmlaShape{2, 2, 8};
- SmallVector<int64_t> loopOrder{0, 1, 2};
+ SmallVector<int64_t> smmlaShape{2, 8};
+ SmallVector<int64_t> loopOrder{0, 1};
+ if (unrolledSize.size() == 3) {
+ smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
+ loopOrder.push_back(2);
+ }
for (SmallVector<int64_t> offsets :
StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-
// Helper to compute the new shape of each operand and extract the slice.
auto extractOperand = [&](Value operand, AffineMap permutationMap,
ArrayRef<int64_t> operandOffsets) {
@@ -150,16 +157,40 @@ public:
Value tiledAcc =
extractOperand(op.getAcc(), accPermutationMap, accOffsets);
+ auto inputElementType =
+ tiledLhs.getType().cast<ShapedType>().getElementType();
+ auto accElementType =
+ tiledAcc.getType().cast<ShapedType>().getElementType();
+ auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
+ auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+ // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
+ // rows along dimM. Expand their shapes to match the smmla op.
+ if (isVecmat) {
+ auto expandForSMMLA = [&](Value tiledOperand,
+ VectorType expandedTypeType) {
+ auto emptyOperand = rewriter.create<arith::ConstantOp>(
+ loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
+ SmallVector<int64_t> offsets(
+ emptyOperand.getType().cast<ShapedType>().getRank(), 0);
+ SmallVector<int64_t> strides(
+ tiledOperand.getType().cast<ShapedType>().getRank(), 1);
+ return rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, tiledOperand, emptyOperand, offsets, strides);
+ };
+ tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
+ tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
+ }
+
// Collapse tiled operands to 1D vectors required by smmla intrinsic
- auto collapsedInputType = VectorType::get(
- tiledLhs.getType().cast<ShapedType>().getNumElements(),
- tiledLhs.getType().cast<ShapedType>().getElementType());
- auto collapsedOutputType = VectorType::get(
- {4}, tiledAcc.getType().cast<ShapedType>().getElementType());
+ auto collapsedInputType =
+ VectorType::get(inputExpandedType.getNumElements(), inputElementType);
auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledLhs.getLoc(), collapsedInputType, tiledLhs);
auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+ auto collapsedOutputType =
+ VectorType::get(outputExpandedType.getNumElements(), accElementType);
auto collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
@@ -172,6 +203,11 @@ public:
Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
smmlaOp.getLoc(), tiledAcc.getType(), smmlaOp);
+ // With vecmat, only one row of the tiled ACC can be inserted into the final result
+ if (isVecmat) {
+ tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+ }
+
// Insert the tiled result back into the non tiled result of the
// contract op.
SmallVector<int64_t> strides(
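
For reference, the unrolling step walks tile offsets over the contract shape via StaticTileOffsetRange; a minimal standalone C++ model of that enumeration (assuming each dimension is a multiple of the tile size, which the early checks above guarantee) is:

#include <cstdint>
#include <vector>

// Enumerate the offsets at which each tile is extracted; every offset produced
// here corresponds to one smmla intrinsic emitted by the pattern above.
static std::vector<std::vector<int64_t>>
tileOffsets(const std::vector<int64_t> &shape,
            const std::vector<int64_t> &tile) {
  std::vector<std::vector<int64_t>> offsets{{}};
  for (size_t d = 0; d < shape.size(); ++d) {
    std::vector<std::vector<int64_t>> next;
    for (const std::vector<int64_t> &prefix : offsets) {
      for (int64_t o = 0; o < shape[d]; o += tile[d]) {
        std::vector<int64_t> extended = prefix;
        extended.push_back(o);
        next.push_back(extended);
      }
    }
    offsets = std::move(next);
  }
  return offsets;
}
// e.g. a 4x4x16 contract tiled by {2, 2, 8} yields 2*2*2 = 8 offsets (8 smmla
// instructions); a vecmat (dimM == 1) uses a {1, 2, 8} tile instead.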
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 5d11f8f6..1320db3 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -531,8 +531,8 @@ static ParseResult parseSwitchOpCases(
failed(parser.parseSuccessor(destination)))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
- if (failed(parser.parseOperandList(operands, OpAsmParser::Delimiter::None,
- /*allowResultNumber=*/false)) ||
+ if (failed(parser.parseOperandList(operands,
+ OpAsmParser::Delimiter::None)) ||
failed(parser.parseColonTypeList(operandTypes)) ||
failed(parser.parseRParen()))
return failure();
diff --git a/mlir/lib/Dialect/DLTI/DLTI.cpp b/mlir/lib/Dialect/DLTI/DLTI.cpp
index daef234..98a8865 100644
--- a/mlir/lib/Dialect/DLTI/DLTI.cpp
+++ b/mlir/lib/Dialect/DLTI/DLTI.cpp
@@ -282,6 +282,11 @@ DataLayoutEntryListRef DataLayoutSpecAttr::getEntries() const {
}
StringAttr
+DataLayoutSpecAttr::getEndiannessIdentifier(MLIRContext *context) const {
+ return Builder(context).getStringAttr(DLTIDialect::kDataLayoutEndiannessKey);
+}
+
+StringAttr
DataLayoutSpecAttr::getAllocaMemorySpaceIdentifier(MLIRContext *context) const {
return Builder(context).getStringAttr(
DLTIDialect::kDataLayoutAllocaMemorySpaceKey);
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index f4a9dc3..7cbf28b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -70,6 +70,11 @@ bool mlir::emitc::isSupportedIntegerType(Type type) {
return false;
}
+bool mlir::emitc::isIntegerIndexOrOpaqueType(Type type) {
+ return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
+ isSupportedIntegerType(type);
+}
+
bool mlir::emitc::isSupportedFloatType(Type type) {
if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
switch (floatType.getWidth()) {
@@ -780,12 +785,61 @@ LogicalResult emitc::YieldOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult emitc::SubscriptOp::verify() {
- if (getIndices().size() != (size_t)getArray().getType().getRank()) {
- return emitOpError() << "requires number of indices ("
- << getIndices().size()
- << ") to match the rank of the array type ("
- << getArray().getType().getRank() << ")";
+ // Checks for array operand.
+ if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
+ // Check number of indices.
+ if (getIndices().size() != (size_t)arrayType.getRank()) {
+ return emitOpError() << "on array operand requires number of indices ("
+ << getIndices().size()
+ << ") to match the rank of the array type ("
+ << arrayType.getRank() << ")";
+ }
+ // Check types of index operands.
+ for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
+ Type type = getIndices()[i].getType();
+ if (!isIntegerIndexOrOpaqueType(type)) {
+ return emitOpError() << "on array operand requires index operand " << i
+ << " to be integer-like, but got " << type;
+ }
+ }
+ // Check element type.
+ Type elementType = arrayType.getElementType();
+ if (elementType != getType()) {
+ return emitOpError() << "on array operand requires element type ("
+ << elementType << ") and result type (" << getType()
+ << ") to match";
+ }
+ return success();
}
+
+ // Checks for pointer operand.
+ if (auto pointerType =
+ llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
+ // Check number of indices.
+ if (getIndices().size() != 1) {
+ return emitOpError()
+ << "on pointer operand requires one index operand, but got "
+ << getIndices().size();
+ }
+ // Check types of index operand.
+ Type type = getIndices()[0].getType();
+ if (!isIntegerIndexOrOpaqueType(type)) {
+ return emitOpError() << "on pointer operand requires index operand to be "
+ "integer-like, but got "
+ << type;
+ }
+ // Check pointee type.
+ Type pointeeType = pointerType.getPointee();
+ if (pointeeType != getType()) {
+ return emitOpError() << "on pointer operand requires pointee type ("
+ << pointeeType << ") and result type (" << getType()
+ << ") to match";
+ }
+ return success();
+ }
+
+ // The operand has opaque type, so we can't assume anything about the number
+ // or types of index operands.
return success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
index 06c1fdd..5637552 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp
@@ -168,13 +168,6 @@ bool LLVM::StoreOp::canUsesBeRemoved(
DeletionKind LLVM::StoreOp::removeBlockingUses(
const MemorySlot &slot, const SmallPtrSetImpl<OpOperand *> &blockingUses,
RewriterBase &rewriter, Value reachingDefinition) {
- // `canUsesBeRemoved` checked this blocking use must be the stored slot
- // pointer.
- for (Operation *user : slot.ptr.getUsers())
- if (auto declareOp = dyn_cast<LLVM::DbgDeclareOp>(user))
- rewriter.create<LLVM::DbgValueOp>(declareOp->getLoc(), getValue(),
- declareOp.getVarInfo(),
- declareOp.getLocationExpr());
return DeletionKind::Delete;
}
@@ -407,6 +400,18 @@ DeletionKind LLVM::DbgValueOp::removeBlockingUses(
return DeletionKind::Keep;
}
+bool LLVM::DbgDeclareOp::requiresReplacedValues() { return true; }
+
+void LLVM::DbgDeclareOp::visitReplacedValues(
+ ArrayRef<std::pair<Operation *, Value>> definitions,
+ RewriterBase &rewriter) {
+ for (auto [op, value] : definitions) {
+ rewriter.setInsertionPointAfter(op);
+ rewriter.create<LLVM::DbgValueOp>(getLoc(), value, getVarInfo(),
+ getLocationExpr());
+ }
+}
+
//===----------------------------------------------------------------------===//
// Interfaces for GEPOp
//===----------------------------------------------------------------------===//
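
The diffstat lists a matching change in mlir/lib/Transforms/Mem2Reg.cpp that is not shown here; a hedged sketch of how a promotion driver could invoke the new hooks, assuming they are exposed through PromotableOpInterface, is:

// Hedged sketch (assumed interface placement): after promotion, hand each
// (definition op, reaching value) pair to ops that opted in, so an
// llvm.dbg.value can be emitted per definition instead of once at the store.
static void notifyReplacedValues(
    PromotableOpInterface promotable,
    ArrayRef<std::pair<Operation *, Value>> definitions,
    RewriterBase &rewriter) {
  if (promotable.requiresReplacedValues())
    promotable.visitReplacedValues(definitions, rewriter);
}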
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 2d7219f..9c5c58f 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -373,14 +373,15 @@ namespace {
class RegionBuilderHelper {
public:
- RegionBuilderHelper(MLIRContext *context, Block &block)
- : context(context), block(block) {}
+ RegionBuilderHelper(OpBuilder &builder, Block &block)
+ : builder(builder), block(block) {}
// Build the unary functions defined by OpDSL.
Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
if (!isFloatingPoint(arg))
llvm_unreachable("unsupported non numeric type");
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
switch (unaryFn) {
case UnaryFn::exp:
return builder.create<math::ExpOp>(arg.getLoc(), arg);
@@ -407,7 +408,8 @@ public:
arg1.getType().getIntOrFloatBitWidth() == 1;
if (!allComplex && !allFloatingPoint && !allInteger)
llvm_unreachable("unsupported non numeric type");
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
switch (binaryFn) {
case BinaryFn::add:
if (allComplex)
@@ -481,29 +483,32 @@ public:
}
void yieldOutputs(ValueRange values) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
builder.create<YieldOp>(loc, values);
}
Value constant(const std::string &value) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
Location loc = builder.getUnknownLoc();
Attribute valueAttr = parseAttribute(value, builder.getContext());
return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
}
Value index(int64_t dim) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
}
Type getIntegerType(unsigned width) {
- return IntegerType::get(context, width);
+ return IntegerType::get(builder.getContext(), width);
}
- Type getFloat32Type() { return Float32Type::get(context); }
- Type getFloat64Type() { return Float64Type::get(context); }
+ Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
+ Type getFloat64Type() { return Float64Type::get(builder.getContext()); }
private:
// Generates operations to cast the given operand to a specified type.
@@ -511,7 +516,8 @@ private:
// operand returned as-is (which will presumably yield a verification
// issue downstream).
Value cast(Type toType, Value operand, bool isUnsignedCast) {
- OpBuilder builder = getBuilder();
+ OpBuilder::InsertionGuard g(builder);
+ builder.setInsertionPointToEnd(&block);
auto loc = operand.getLoc();
return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
}
@@ -526,13 +532,7 @@ private:
return llvm::isa<IntegerType>(value.getType());
}
- OpBuilder getBuilder() {
- OpBuilder builder(context);
- builder.setInsertionPointToEnd(&block);
- return builder;
- }
-
- MLIRContext *context;
+ OpBuilder &builder;
Block &block;
};
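
RegionBuilderHelper now threads the caller's builder through and scopes each emission with an insertion guard instead of constructing a throwaway OpBuilder, so op creation stays visible to any listener attached to that builder. A minimal sketch of the pattern (the helper name is illustrative):

// The guard saves the shared builder's insertion point and restores it when
// the scope ends; ops are appended to `block` through the caller's builder.
static Value emitZeroIndexAtEnd(OpBuilder &builder, Block &block,
                                Location loc) {
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToEnd(&block);
  return builder.create<arith::ConstantIndexOp>(loc, 0);
}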
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
index 58fb2e9..899b8c8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/IR/DstBufferizableOpInterfaceImpl.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
@@ -110,6 +111,10 @@ struct LinalgOpInterface
ArrayRef<OpOperand *> opOperands) const {
auto linalgOp = cast<linalg::LinalgOp>(op);
+ // Accesses into sparse data structures are not necessarily elementwise.
+ if (sparse_tensor::hasAnySparseOperand(linalgOp))
+ return false;
+
// All loops must be parallel.
if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
return false;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index b32ea8e..c3a08ce 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter,
FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound(
rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(),
/*stopCondition=*/
- [&](Value v, std::optional<int64_t> d) {
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
if (v == forOp.getUpperBound())
return false;
// Compute a bound that is independent of any affine op results.
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index 0b85462..42629e1 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -216,20 +216,30 @@ static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
// Convert `math.fpowi` to a series of `arith.mulf` operations.
// If the power is negative, we divide one by the result.
// If both the base and power are zero, the result is 1.
-static LogicalResult convertFPowICstOp(math::FPowIOp op,
- PatternRewriter &rewriter) {
+// In the case of non constant power, we convert the operation to `math.powf`.
+static LogicalResult convertFPowIOp(math::FPowIOp op,
+ PatternRewriter &rewriter) {
ImplicitLocOpBuilder b(op->getLoc(), rewriter);
Value base = op.getOperand(0);
Value power = op.getOperand(1);
Type baseType = base.getType();
+ auto convertFPowItoPowf = [&]() -> LogicalResult {
+ Value castPowerToFp =
+ rewriter.create<arith::SIToFPOp>(op.getLoc(), baseType, power);
+ Value res = rewriter.create<math::PowFOp>(op.getLoc(), baseType, base,
+ castPowerToFp);
+ rewriter.replaceOp(op, res);
+ return success();
+ };
+
Attribute cstAttr;
if (!matchPattern(power, m_Constant(&cstAttr)))
- return failure();
+ return convertFPowItoPowf();
APInt value;
if (!matchPattern(cstAttr, m_ConstantInt(&value)))
- return failure();
+ return convertFPowItoPowf();
int64_t powerInt = value.getSExtValue();
bool isNegative = powerInt < 0;
@@ -591,7 +601,7 @@ void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
}
void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
- patterns.add(convertFPowICstOp);
+ patterns.add(convertFPowIOp);
}
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
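
For a constant power the expansion is still a chain of multiplications with an optional final division; a scalar model of that behavior and of the conventions stated above (a zero power yields 1, non-constant powers now fall back to math.powf rather than failing to match) is:

#include <cstdint>

// Scalar model of the constant-power expansion: repeated multiplication, with
// a final reciprocal for negative powers; a zero power yields 1.
static double fpowiModel(double base, int64_t power) {
  if (power == 0)
    return 1.0;
  int64_t n = power < 0 ? -power : power;
  double result = base;
  for (int64_t i = 1; i < n; ++i)
    result *= base;
  return power < 0 ? 1.0 / result : result;
}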
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index bf58750..a043431 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1538,6 +1538,21 @@ static void printAtomicReductionRegion(OpAsmPrinter &printer,
printer.printRegion(region);
}
+static ParseResult parseCleanupReductionRegion(OpAsmParser &parser,
+ Region &region) {
+ if (parser.parseOptionalKeyword("cleanup"))
+ return success();
+ return parser.parseRegion(region);
+}
+
+static void printCleanupReductionRegion(OpAsmPrinter &printer,
+ DeclareReductionOp op, Region &region) {
+ if (region.empty())
+ return;
+ printer << "cleanup ";
+ printer.printRegion(region);
+}
+
LogicalResult DeclareReductionOp::verifyRegions() {
if (getInitializerRegion().empty())
return emitOpError() << "expects non-empty initializer region";
@@ -1571,21 +1586,29 @@ LogicalResult DeclareReductionOp::verifyRegions() {
"of the reduction type";
}
- if (getAtomicReductionRegion().empty())
+ if (!getAtomicReductionRegion().empty()) {
+ Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
+ if (atomicReductionEntryBlock.getNumArguments() != 2 ||
+ atomicReductionEntryBlock.getArgumentTypes()[0] !=
+ atomicReductionEntryBlock.getArgumentTypes()[1])
+ return emitOpError() << "expects atomic reduction region with two "
+ "arguments of the same type";
+ auto ptrType = llvm::dyn_cast<PointerLikeType>(
+ atomicReductionEntryBlock.getArgumentTypes()[0]);
+ if (!ptrType ||
+ (ptrType.getElementType() && ptrType.getElementType() != getType()))
+ return emitOpError() << "expects atomic reduction region arguments to "
+ "be accumulators containing the reduction type";
+ }
+
+ if (getCleanupRegion().empty())
return success();
+ Block &cleanupEntryBlock = getCleanupRegion().front();
+ if (cleanupEntryBlock.getNumArguments() != 1 ||
+ cleanupEntryBlock.getArgument(0).getType() != getType())
+ return emitOpError() << "expects cleanup region with one argument "
+ "of the reduction type";
- Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
- if (atomicReductionEntryBlock.getNumArguments() != 2 ||
- atomicReductionEntryBlock.getArgumentTypes()[0] !=
- atomicReductionEntryBlock.getArgumentTypes()[1])
- return emitOpError() << "expects atomic reduction region with two "
- "arguments of the same type";
- auto ptrType = llvm::dyn_cast<PointerLikeType>(
- atomicReductionEntryBlock.getArgumentTypes()[0]);
- if (!ptrType ||
- (ptrType.getElementType() && ptrType.getElementType() != getType()))
- return emitOpError() << "expects atomic reduction region arguments to "
- "be accumulators containing the reduction type";
return success();
}
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index cb36e0c..1e13e60 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -58,7 +58,7 @@ struct ForOpInterface
ValueDimList boundOperands;
LogicalResult status = ValueBoundsConstraintSet::computeBound(
bound, boundOperands, BoundType::EQ, yieldedValue, dim,
- [&](Value v, std::optional<int64_t> d) {
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
// Stop when reaching a block argument of the loop body.
if (auto bbArg = llvm::dyn_cast<BlockArgument>(v))
return bbArg.getOwner()->getParentOp() == forOp;
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index c091841..7e4faf8 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -332,9 +332,9 @@ transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter,
transform::TransformState &state) {
LogicalResult result(failure());
if (scf::ForOp scfForOp = dyn_cast<scf::ForOp>(op))
- result = coalescePerfectlyNestedLoops(scfForOp);
+ result = coalescePerfectlyNestedSCFForLoops(scfForOp);
else if (AffineForOp affineForOp = dyn_cast<AffineForOp>(op))
- result = coalescePerfectlyNestedLoops(affineForOp);
+ result = coalescePerfectlyNestedAffineLoops(affineForOp);
results.push_back(op);
if (failed(result)) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
index a69df02..6ba7020 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopCollapsing.cpp
@@ -28,6 +28,7 @@ namespace {
struct TestSCFParallelLoopCollapsing
: public impl::TestSCFParallelLoopCollapsingBase<
TestSCFParallelLoopCollapsing> {
+
void runOnOperation() override {
Operation *module = getOperation();
@@ -88,6 +89,7 @@ struct TestSCFParallelLoopCollapsing
// Only apply the transformation on parallel loops where the specified
// transformation is valid, but do NOT early abort in the case of invalid
// loops.
+ IRRewriter rewriter(&getContext());
module->walk([&](scf::ParallelOp op) {
if (flattenedCombinedLoops.size() != op.getNumLoops()) {
op.emitOpError("has ")
@@ -97,7 +99,7 @@ struct TestSCFParallelLoopCollapsing
<< flattenedCombinedLoops.size() << " iter args.";
return;
}
- collapseParallelLoops(op, combinedLoops);
+ collapseParallelLoops(rewriter, op, combinedLoops);
});
}
};
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 914aeb4..9279081 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -13,6 +13,7 @@
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
@@ -472,18 +473,23 @@ LogicalResult mlir::loopUnrollByFactor(
return success();
}
-/// Return the new lower bound, upper bound, and step in that order. Insert any
-/// additional bounds calculations before the given builder and any additional
-/// conversion back to the original loop induction value inside the given Block.
-static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
- OpBuilder &insideLoopBuilder, Location loc,
- Value lowerBound, Value upperBound, Value step,
- Value inductionVar) {
+/// Transform a loop with a strictly positive step
+/// for %i = %lb to %ub step %s
+/// into a 0-based loop with step 1
+/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
+/// %i = %ii * %s + %lb
+/// Insert any additional bounds computations at the current insertion point
+/// of `rewriter`; callers position that insertion point before the outermost
+/// loop of the band. The remapping of the original induction variable is
+/// emitted separately by `denormalizeInductionVariable`.
+static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+ Value lb, Value ub, Value step) {
+ // For non-index types, generate `arith` instructions
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
bool isZeroBased = false;
- if (auto ubCst = getConstantIntValue(lowerBound))
- isZeroBased = ubCst.value() == 0;
+ if (auto lbCst = getConstantIntValue(lb))
+ isZeroBased = lbCst.value() == 0;
bool isStepOne = false;
if (auto stepCst = getConstantIntValue(step))
@@ -493,62 +499,90 @@ static LoopParams normalizeLoop(OpBuilder &boundsBuilder,
// assuming the step is strictly positive. Update the bounds and the step
// of the loop to go from 0 to the number of iterations, if necessary.
if (isZeroBased && isStepOne)
- return {/*lowerBound=*/lowerBound, /*upperBound=*/upperBound,
- /*step=*/step};
+ return {lb, ub, step};
- Value diff = boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound);
+ Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
Value newUpperBound =
- boundsBuilder.create<arith::CeilDivSIOp>(loc, diff, step);
-
- Value newLowerBound =
- isZeroBased ? lowerBound
- : boundsBuilder.create<arith::ConstantOp>(
- loc, boundsBuilder.getZeroAttr(lowerBound.getType()));
- Value newStep =
- isStepOne ? step
- : boundsBuilder.create<arith::ConstantOp>(
- loc, boundsBuilder.getIntegerAttr(step.getType(), 1));
-
- // Insert code computing the value of the original loop induction variable
- // from the "normalized" one.
- Value scaled =
- isStepOne
- ? inductionVar
- : insideLoopBuilder.create<arith::MulIOp>(loc, inductionVar, step);
- Value shifted =
- isZeroBased
- ? scaled
- : insideLoopBuilder.create<arith::AddIOp>(loc, scaled, lowerBound);
-
- SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
- shifted.getDefiningOp()};
- inductionVar.replaceAllUsesExcept(shifted, preserve);
- return {/*lowerBound=*/newLowerBound, /*upperBound=*/newUpperBound,
- /*step=*/newStep};
+ isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
+
+ Value newLowerBound = isZeroBased
+ ? lb
+ : rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(lb.getType()));
+ Value newStep = isStepOne
+ ? step
+ : rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getIntegerAttr(step.getType(), 1));
+
+ return {newLowerBound, newUpperBound, newStep};
}
-/// Transform a loop with a strictly positive step
-/// for %i = %lb to %ub step %s
-/// into a 0-based loop with step 1
-/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
-/// %i = %ii * %s + %lb
-/// Insert the induction variable remapping in the body of `inner`, which is
-/// expected to be either `loop` or another loop perfectly nested under `loop`.
-/// Insert the definition of new bounds immediate before `outer`, which is
-/// expected to be either `loop` or its parent in the loop nest.
-static void normalizeLoop(scf::ForOp loop, scf::ForOp outer, scf::ForOp inner) {
- OpBuilder builder(outer);
- OpBuilder innerBuilder = OpBuilder::atBlockBegin(inner.getBody());
- auto loopPieces = normalizeLoop(builder, innerBuilder, loop.getLoc(),
- loop.getLowerBound(), loop.getUpperBound(),
- loop.getStep(), loop.getInductionVar());
-
- loop.setLowerBound(loopPieces.lowerBound);
- loop.setUpperBound(loopPieces.upperBound);
- loop.setStep(loopPieces.step);
+/// Get back the original induction variable values after loop normalization
+static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
+ Value normalizedIv, Value origLb,
+ Value origStep) {
+ Value denormalizedIv;
+ SmallPtrSet<Operation *, 2> preserve;
+ bool isStepOne = isConstantIntValue(origStep, 1);
+ bool isZeroBased = isConstantIntValue(origLb, 0);
+
+ Value scaled = normalizedIv;
+ if (!isStepOne) {
+ scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep);
+ preserve.insert(scaled.getDefiningOp());
+ }
+ denormalizedIv = scaled;
+ if (!isZeroBased) {
+ denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb);
+ preserve.insert(denormalizedIv.getDefiningOp());
+ }
+
+ rewriter.replaceAllUsesExcept(normalizedIv, denormalizedIv, preserve);
}
-LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
+/// Helper function to multiply a sequence of values.
+static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
+ ArrayRef<Value> values) {
+ assert(!values.empty() && "unexpected empty list");
+ Value productOf = values.front();
+ for (auto v : values.drop_front()) {
+ productOf = rewriter.create<arith::MulIOp>(loc, productOf, v);
+ }
+ return productOf;
+}
+
+/// For each original loop, the value of the
+/// induction variable can be obtained by dividing the induction variable of
+/// the linearized loop by the total number of iterations of the loops nested
+/// in it modulo the number of iterations in this loop (remove the values
+/// related to the outer loops):
+/// iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
+/// Compute these iteratively from the innermost loop by creating a "running
+/// quotient" of division by the range.
+static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2>>
+delinearizeInductionVariable(RewriterBase &rewriter, Location loc,
+ Value linearizedIv, ArrayRef<Value> ubs) {
+ Value previous = linearizedIv;
+ SmallVector<Value> delinearizedIvs(ubs.size());
+ SmallPtrSet<Operation *, 2> preservedUsers;
+ for (unsigned i = 0, e = ubs.size(); i < e; ++i) {
+ unsigned idx = ubs.size() - i - 1;
+ if (i != 0) {
+ previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]);
+ preservedUsers.insert(previous.getDefiningOp());
+ }
+ Value iv = previous;
+ if (i != e - 1) {
+ iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]);
+ preservedUsers.insert(iv.getDefiningOp());
+ }
+ delinearizedIvs[idx] = iv;
+ }
+ return {delinearizedIvs, preservedUsers};
+}
+
+LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
+ MutableArrayRef<scf::ForOp> loops) {
if (loops.size() < 2)
return failure();
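
A scalar model of the normalization performed by emitNormalizedLoopBounds and undone by denormalizeInductionVariable, assuming a strictly positive step and lb <= ub, is:

#include <cstdint>

// The loop over [lb, ub) with step s becomes a loop over [0, ceildiv(ub-lb, s))
// with step 1; each normalized IV maps back as iv = ivNorm * s + lb.
struct ScalarLoopParams {
  int64_t lowerBound, upperBound, step;
};

static ScalarLoopParams normalizeBounds(int64_t lb, int64_t ub, int64_t step) {
  int64_t tripCount = (ub - lb + step - 1) / step; // ceildiv, positive step
  return {0, tripCount, 1};
}

static int64_t denormalizeIv(int64_t ivNorm, int64_t origLb, int64_t origStep) {
  return ivNorm * origStep + origLb;
}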
@@ -557,57 +591,148 @@ LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
// 1. Make sure all loops iterate from 0 to upperBound with step 1. This
// allows the following code to assume upperBound is the number of iterations.
- for (auto loop : loops)
- normalizeLoop(loop, outermost, innermost);
+ for (auto loop : loops) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(outermost);
+ Value lb = loop.getLowerBound();
+ Value ub = loop.getUpperBound();
+ Value step = loop.getStep();
+ auto newLoopParams =
+ emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
+
+ rewriter.modifyOpInPlace(loop, [&]() {
+ loop.setLowerBound(newLoopParams.lowerBound);
+ loop.setUpperBound(newLoopParams.upperBound);
+ loop.setStep(newLoopParams.step);
+ });
+
+ rewriter.setInsertionPointToStart(innermost.getBody());
+ denormalizeInductionVariable(rewriter, loop.getLoc(),
+ loop.getInductionVar(), lb, step);
+ }
// 2. Emit code computing the upper bound of the coalesced loop as product
// of the number of iterations of all loops.
- OpBuilder builder(outermost);
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(outermost);
Location loc = outermost.getLoc();
- Value upperBound = outermost.getUpperBound();
- for (auto loop : loops.drop_front())
- upperBound =
- builder.create<arith::MulIOp>(loc, upperBound, loop.getUpperBound());
+ SmallVector<Value> upperBounds = llvm::map_to_vector(
+ loops, [](auto loop) { return loop.getUpperBound(); });
+ Value upperBound = getProductOfIntsOrIndexes(rewriter, loc, upperBounds);
outermost.setUpperBound(upperBound);
- builder.setInsertionPointToStart(outermost.getBody());
-
- // 3. Remap induction variables. For each original loop, the value of the
- // induction variable can be obtained by dividing the induction variable of
- // the linearized loop by the total number of iterations of the loops nested
- // in it modulo the number of iterations in this loop (remove the values
- // related to the outer loops):
- // iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
- // Compute these iteratively from the innermost loop by creating a "running
- // quotient" of division by the range.
- Value previous = outermost.getInductionVar();
+ rewriter.setInsertionPointToStart(innermost.getBody());
+ auto [delinearizeIvs, preservedUsers] = delinearizeInductionVariable(
+ rewriter, loc, outermost.getInductionVar(), upperBounds);
+ rewriter.replaceAllUsesExcept(outermost.getInductionVar(), delinearizeIvs[0],
+ preservedUsers);
+
+ for (int i = loops.size() - 1; i > 0; --i) {
+ auto outerLoop = loops[i - 1];
+ auto innerLoop = loops[i];
+
+ Operation *innerTerminator = innerLoop.getBody()->getTerminator();
+ auto yieldedVals = llvm::to_vector(innerTerminator->getOperands());
+ rewriter.eraseOp(innerTerminator);
+
+ SmallVector<Value> innerBlockArgs;
+ innerBlockArgs.push_back(delinearizeIvs[i]);
+ llvm::append_range(innerBlockArgs, outerLoop.getRegionIterArgs());
+ rewriter.inlineBlockBefore(innerLoop.getBody(), outerLoop.getBody(),
+ Block::iterator(innerLoop), innerBlockArgs);
+ rewriter.replaceOp(innerLoop, yieldedVals);
+ }
+ return success();
+}
+
+LogicalResult mlir::coalesceLoops(MutableArrayRef<scf::ForOp> loops) {
+ if (loops.empty()) {
+ return failure();
+ }
+ IRRewriter rewriter(loops.front().getContext());
+ return coalesceLoops(rewriter, loops);
+}
+
+LogicalResult mlir::coalescePerfectlyNestedSCFForLoops(scf::ForOp op) {
+ LogicalResult result(failure());
+ SmallVector<scf::ForOp> loops;
+ getPerfectlyNestedLoops(loops, op);
+
+ // Look for a band of loops that can be coalesced, i.e. perfectly nested
+ // loops with bounds defined above some loop.
+
+ // 1. For each loop, find above which parent loop its bounds operands are
+ // defined.
+ SmallVector<unsigned> operandsDefinedAbove(loops.size());
for (unsigned i = 0, e = loops.size(); i < e; ++i) {
- unsigned idx = loops.size() - i - 1;
- if (i != 0)
- previous = builder.create<arith::DivSIOp>(loc, previous,
- loops[idx + 1].getUpperBound());
-
- Value iv = (i == e - 1) ? previous
- : builder.create<arith::RemSIOp>(
- loc, previous, loops[idx].getUpperBound());
- replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
- loops.back().getRegion());
+ operandsDefinedAbove[i] = i;
+ for (unsigned j = 0; j < i; ++j) {
+ SmallVector<Value> boundsOperands = {loops[i].getLowerBound(),
+ loops[i].getUpperBound(),
+ loops[i].getStep()};
+ if (areValuesDefinedAbove(boundsOperands, loops[j].getRegion())) {
+ operandsDefinedAbove[i] = j;
+ break;
+ }
+ }
}
- // 4. Move the operations from the innermost just above the second-outermost
- // loop, delete the extra terminator and the second-outermost loop.
- scf::ForOp second = loops[1];
- innermost.getBody()->back().erase();
- outermost.getBody()->getOperations().splice(
- Block::iterator(second.getOperation()),
- innermost.getBody()->getOperations());
- second.erase();
- return success();
+ // 2. For each inner loop, check that the iter_args of the immediately outer
+ // loop are the inits of the immediately inner loop, and that the results of
+ // the inner loop are the operands of the immediately outer loop's yield.
+ // Keep track of where the chain starts from for each loop.
+ SmallVector<unsigned> iterArgChainStart(loops.size());
+ iterArgChainStart[0] = 0;
+ for (unsigned i = 1, e = loops.size(); i < e; ++i) {
+ // By default set the start of the chain to itself.
+ iterArgChainStart[i] = i;
+ auto outerloop = loops[i - 1];
+ auto innerLoop = loops[i];
+ if (outerloop.getNumRegionIterArgs() != innerLoop.getNumRegionIterArgs()) {
+ continue;
+ }
+ if (!llvm::equal(outerloop.getRegionIterArgs(), innerLoop.getInitArgs())) {
+ continue;
+ }
+ auto outerloopTerminator = outerloop.getBody()->getTerminator();
+ if (!llvm::equal(outerloopTerminator->getOperands(),
+ innerLoop.getResults())) {
+ continue;
+ }
+ iterArgChainStart[i] = iterArgChainStart[i - 1];
+ }
+
+ // 3. Identify bands of loops such that the operands of all of them are
+ // defined above the first loop in the band. Traverse the nest bottom-up
+ // so that modifications don't invalidate the inner loops.
+ for (unsigned end = loops.size(); end > 0; --end) {
+ unsigned start = 0;
+ for (; start < end - 1; ++start) {
+ auto maxPos =
+ *std::max_element(std::next(operandsDefinedAbove.begin(), start),
+ std::next(operandsDefinedAbove.begin(), end));
+ if (maxPos > start)
+ continue;
+ if (iterArgChainStart[end - 1] > start)
+ continue;
+ auto band = llvm::MutableArrayRef(loops.data() + start, end - start);
+ if (succeeded(coalesceLoops(band)))
+ result = success();
+ break;
+ }
+ // If a band was found and transformed, keep looking at the loops above
+ // the outermost transformed loop.
+ if (start != end - 1)
+ end = start + 1;
+ }
+ return result;
}
void mlir::collapseParallelLoops(
- scf::ParallelOp loops, ArrayRef<std::vector<unsigned>> combinedDimensions) {
- OpBuilder outsideBuilder(loops);
+ RewriterBase &rewriter, scf::ParallelOp loops,
+ ArrayRef<std::vector<unsigned>> combinedDimensions) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(loops);
Location loc = loops.getLoc();
// Presort combined dimensions.
@@ -619,25 +744,29 @@ void mlir::collapseParallelLoops(
SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
normalizedUpperBounds;
for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
- OpBuilder insideLoopBuilder = OpBuilder::atBlockBegin(loops.getBody());
- auto resultBounds =
- normalizeLoop(outsideBuilder, insideLoopBuilder, loc,
- loops.getLowerBound()[i], loops.getUpperBound()[i],
- loops.getStep()[i], loops.getBody()->getArgument(i));
-
- normalizedLowerBounds.push_back(resultBounds.lowerBound);
- normalizedUpperBounds.push_back(resultBounds.upperBound);
- normalizedSteps.push_back(resultBounds.step);
+ OpBuilder::InsertionGuard g2(rewriter);
+ rewriter.setInsertionPoint(loops);
+ Value lb = loops.getLowerBound()[i];
+ Value ub = loops.getUpperBound()[i];
+ Value step = loops.getStep()[i];
+ auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
+ normalizedLowerBounds.push_back(newLoopParams.lowerBound);
+ normalizedUpperBounds.push_back(newLoopParams.upperBound);
+ normalizedSteps.push_back(newLoopParams.step);
+
+ rewriter.setInsertionPointToStart(loops.getBody());
+ denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,
+ step);
}
// Combine iteration spaces.
SmallVector<Value, 3> lowerBounds, upperBounds, steps;
- auto cst0 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 0);
- auto cst1 = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
+ auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1);
for (auto &sortedDimension : sortedDimensions) {
- Value newUpperBound = outsideBuilder.create<arith::ConstantIndexOp>(loc, 1);
+ Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1);
for (auto idx : sortedDimension) {
- newUpperBound = outsideBuilder.create<arith::MulIOp>(
+ newUpperBound = rewriter.create<arith::MulIOp>(
loc, newUpperBound, normalizedUpperBounds[idx]);
}
lowerBounds.push_back(cst0);
@@ -651,7 +780,7 @@ void mlir::collapseParallelLoops(
// value. The remainders then determine based on that range, which iteration
// of the original induction value this represents. This is a normalized value
// that is un-normalized already by the previous logic.
- auto newPloop = outsideBuilder.create<scf::ParallelOp>(
+ auto newPloop = rewriter.create<scf::ParallelOp>(
loc, lowerBounds, upperBounds, steps,
[&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) {
for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) {
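
The induction-variable remapping done by delinearizeInductionVariable is a mixed-radix decomposition; a standalone scalar model with a worked example, assuming the loops were normalized first so each upper bound is a trip count, is:

#include <cstdint>
#include <vector>

// ivs[i] = floordiv(linearIv, product of inner trip counts) mod ubs[i],
// computed innermost-first with a running quotient, as in the code above.
static std::vector<int64_t> delinearize(int64_t linearIv,
                                        const std::vector<int64_t> &ubs) {
  std::vector<int64_t> ivs(ubs.size());
  int64_t previous = linearIv;
  for (size_t i = 0, e = ubs.size(); i < e; ++i) {
    size_t idx = e - i - 1;     // innermost dimension first
    if (i != 0)
      previous /= ubs[idx + 1]; // running quotient
    ivs[idx] = (i + 1 != e) ? previous % ubs[idx] : previous;
  }
  return ivs;
}
// e.g. ubs = {3, 4, 5} and linearIv = 38 give {1, 3, 3}, since
// 38 == (1 * 4 + 3) * 5 + 3.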
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
index d4c1792..acea25f 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp
@@ -274,7 +274,7 @@ struct SparseTensorCodegenPass
});
// The following operations and dialects may be introduced by the
// codegen rules, and are therefore marked as legal.
- target.addLegalOp<linalg::FillOp>();
+ target.addLegalOp<linalg::FillOp, linalg::YieldOp>();
target.addLegalDialect<
arith::ArithDialect, bufferization::BufferizationDialect,
complex::ComplexDialect, memref::MemRefDialect, scf::SCFDialect>();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 6e6e843..e06ac9a 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -955,25 +955,34 @@ LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
}
mlir::LogicalResult tosa::ReshapeOp::verify() {
- ShapedType inputType = llvm::cast<ShapedType>(getInput1().getType());
- ShapedType outputType = llvm::cast<ShapedType>(getType());
+ TensorType inputType = getInput1().getType();
+ RankedTensorType outputType = getType();
if (hasZeroDimension(inputType) || hasZeroDimension(outputType))
return emitOpError() << "tensor has a dimension with size zero. Each "
"dimension of a tensor must have size >= 1";
+ if ((int64_t) getNewShape().size() != outputType.getRank())
+ return emitOpError() << "new shape does not match result rank";
+
+ for (auto [newShapeDim, outputShapeDim] :
+ zip(getNewShape(), outputType.getShape()))
+ if (newShapeDim != -1 && outputShapeDim != ShapedType::kDynamic &&
+ newShapeDim != outputShapeDim)
+ return emitOpError() << "new shape is inconsistent with result shape";
+
if (inputType.hasStaticShape() && outputType.hasStaticShape()) {
int64_t inputElementsNum = inputType.getNumElements();
int64_t outputElementsNum = outputType.getNumElements();
if (inputElementsNum != outputElementsNum) {
- return emitOpError() << "Cannot reshape " << inputElementsNum
+ return emitOpError() << "cannot reshape " << inputElementsNum
<< " elements into " << outputElementsNum;
}
}
int missingDims = llvm::count(getNewShape(), -1);
if (missingDims > 1)
- return emitOpError() << "At most one target dimension can be -1";
+ return emitOpError() << "expected at most one target dimension to be -1";
return mlir::success();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index ad28c56..8614559 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -18,14 +18,9 @@
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/BuiltinOps.h"
-#include "mlir/IR/IRMapping.h"
-#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/Support/FormatVariadic.h"
namespace mlir {
namespace tosa {
@@ -39,9 +34,87 @@ using namespace mlir::tosa;
namespace {
-void propagateShapesInRegion(Region &region);
+// Check whether this use case is replaceable. We define an op as
+// being replaceable if it is used by a TosaOp, or an op with a
+// type-inference related interface.
+// When a non-replaceable use is encountered, the value is wrapped in a
+// cast back to the original type after inference.
+bool isReplaceableUser(Operation *user) {
+ // Handle unregistered dialects.
+ if (!user->getDialect())
+ return false;
+
+ return user->getDialect()->getNamespace() ==
+ TosaDialect::getDialectNamespace() ||
+ isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
+}
+
+// During type propagation, the types of values in the operator graph are
+// updated. For the tosa.while_loop operation, types are speculatively updated
+// within the body region to determine the output type of the while_loop. This
+// process is performed until a fixed point is reached, then the types are
+// reverted.
+//
+// This class encapsulates the state information needed to perform the reversion
+// process or to commit to the final changes.
+class TypeModificationState {
+public:
+ TypeModificationState() = default;
+
+ ~TypeModificationState() {
+ // Ensure the recorded modifications are either committed or reverted.
+ assert(oldTypes.empty() && "unhandled type modifications");
+ }
+
+ // Update the state of the value and record the old type.
+ void setType(Value value, Type type) {
+ if (value.getType() != type) {
+ oldTypes.emplace_back(value, value.getType());
+ value.setType(type);
+ }
+ }
-void propagateShapesToTosaIf(Operation &op) {
+ // Revert changes made to the types in the IR by setting all the affected
+ // values to their old types.
+ void revert() {
+ // Otherwise revert the changes.
+ for (auto [value, type] : oldTypes)
+ value.setType(type);
+
+ oldTypes.clear();
+ }
+
+ // Commit the changes to the types in the IR.
+ // This requires inserting tensor.cast operations to mediate the newly
+ // inferred result types with users that do not support type inference.
+ void commit() {
+ // For each use whose type changed, cast the value with the new type back to
+ // the old type.
+ for (auto [value, oldType] : oldTypes) {
+ for (auto &use : value.getUses()) {
+ if (isReplaceableUser(use.getOwner()))
+ continue;
+
+ OpBuilder builder(value.getContext());
+ builder.setInsertionPoint(use.getOwner());
+
+ Location loc = value.getLoc();
+ use.set(builder.create<tensor::CastOp>(loc, oldType, value));
+ }
+ }
+
+ oldTypes.clear();
+ }
+
+private:
+ // A record of each value whose type was updated along with that value's
+ // previous type.
+ llvm::SmallVector<std::pair<Value, Type>> oldTypes;
+};
+
+void propagateShapesInRegion(Region &region, TypeModificationState &state);
+
+void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
IfOp ifOp = dyn_cast<IfOp>(op);
if (!ifOp)
return;
@@ -58,7 +131,7 @@ void propagateShapesToTosaIf(Operation &op) {
if (inferredTy.hasRank()) {
Type newType = oldType.clone(inferredTy.getShape());
- blockArg.setType(newType);
+ state.setType(blockArg, newType);
}
}
@@ -71,14 +144,14 @@ void propagateShapesToTosaIf(Operation &op) {
ValueKnowledge::join(operandKnowledge, blockKnowledge);
if (!joinedKnowledge)
continue;
- frontBlock.getArgument(i).setType(joinedKnowledge.getType());
+ state.setType(frontBlock.getArgument(i), joinedKnowledge.getType());
}
- propagateShapesInRegion(region);
+ propagateShapesInRegion(region, state);
}
}
-void propagateShapesToTosaWhile(Operation &op) {
+void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
WhileOp whileOp = dyn_cast<WhileOp>(op);
if (!whileOp)
return;
@@ -86,49 +159,29 @@ void propagateShapesToTosaWhile(Operation &op) {
// Determine what the expected argument types are to the cond/body blocks.
   // The expected arguments should be compatible with every iteration of the
// loop body / condition for tosa.while.
- llvm::SmallVector<Type> argTypes;
- for (auto operand : op.getOperands()) {
- auto operandTy = cast<ShapedType>(operand.getType());
- if (operandTy.hasRank()) {
- auto newTy = operandTy.clone(operandTy.getShape());
- argTypes.push_back(newTy);
- } else {
- argTypes.push_back(operand.getType());
- }
- }
-
- // Save out the type information so we can restore at the end.
- llvm::DenseMap<Value, Type> originalTypeMap;
- for (auto &block : op.getRegion(1)) {
- for (auto arg : block.getArguments())
- originalTypeMap[arg] = arg.getType();
- for (auto &op : block)
- for (auto result : op.getResults())
- originalTypeMap[result] = result.getType();
- }
+ SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());
bool hasNewTypes = true;
while (hasNewTypes) {
+ TypeModificationState localState;
// Set types on the block args.
Region &bodyRegion = op.getRegion(1);
Block &block = bodyRegion.front();
for (int i = 0, s = argTypes.size(); i < s; i++) {
- block.getArgument(i).setType(argTypes[i]);
+ localState.setType(block.getArgument(i), argTypes[i]);
}
// Propagate to the end.
- propagateShapesInRegion(bodyRegion);
+ propagateShapesInRegion(bodyRegion, localState);
- // Find all the tosa yield types and verify there is atleast one.
+    // Collect the tosa.yield ops and verify there is exactly one.
llvm::SmallVector<YieldOp> yieldOps;
for (auto &block : bodyRegion)
if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
yieldOps.push_back(yieldOp);
- if (yieldOps.empty())
- return;
-
+ assert(yieldOps.size() == 1 && "missing or non-unique yield op");
// Using the new tosa.yield operand types, infer the new subtypes.
llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
for (auto ty : argTypes) {
@@ -158,17 +211,8 @@ void propagateShapesToTosaWhile(Operation &op) {
argTypes[i] = newType;
}
- // The types inferred in the block assume the operand types specified for
- // this iteration. We need to restore the original types to ensure that
- // future iterations only use the already specified types, not possible
- // types from previous iterations.
- for (auto &block : bodyRegion) {
- for (auto arg : block.getArguments())
- arg.setType(originalTypeMap[arg]);
- for (auto &op : block)
- for (auto result : op.getResults())
- result.setType(originalTypeMap[result]);
- }
+ // Revert all changes made during the speculative part of the algorithm.
+ localState.revert();
}
// We now set the block arguments according to the most recent shape
@@ -176,41 +220,22 @@ void propagateShapesToTosaWhile(Operation &op) {
// iteration.
for (auto &region : op.getRegions()) {
for (unsigned int i = 0, s = argTypes.size(); i < s; i++) {
- region.front().getArgument(i).setType(argTypes[i]);
+ state.setType(region.front().getArgument(i), argTypes[i]);
}
- propagateShapesInRegion(region);
+ propagateShapesInRegion(region, state);
}
}
-// Track the old type for each operand whose type was updated
-// during inference. This information is used to introduce casts
-// back to the type expected by the operand after inference.
-struct TypeRewriteInfo {
- OpOperand *operand;
- Type oldType;
-};
-
-void propagateShapesInRegion(Region &region) {
- // Check whether this use case is replaceable. We define an op as
- // being replaceable if it is used by a TosaOp, or an op with a
- // type-inference related interface.
- // When a non-replaceable use is encountered, the value is wrapped in a
- // cast back to the original type after inference.
- auto isReplaceableUser = [](Operation *user) -> bool {
- return user->getDialect()->getNamespace() ==
- TosaDialect::getDialectNamespace() ||
- isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
- };
-
- llvm::SmallVector<TypeRewriteInfo> requiresUpdate;
+void propagateShapesInRegion(Region &region, TypeModificationState &state) {
for (auto &block : region) {
for (Operation &op : block) {
- if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
+ if (!op.getDialect() ||
+ op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
continue;
- propagateShapesToTosaIf(op);
- propagateShapesToTosaWhile(op);
+ propagateShapesToTosaIf(op, state);
+ propagateShapesToTosaWhile(op, state);
InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op);
@@ -252,30 +277,11 @@ void propagateShapesInRegion(Region &region) {
continue;
// Set new type
- result.setType(newKnowledge.getType());
-
- // Collect all uses of the operation which require update.
- for (auto &user : result.getUses()) {
- if (!isReplaceableUser(user.getOwner()))
- requiresUpdate.push_back({&user, resultTy});
- }
+ state.setType(result, newKnowledge.getType());
}
}
}
}
-
- // For each use whose type changed, cast the value with the new type back to
- // the old type.
- IRRewriter rewriter(region.getContext());
- for (auto [operand, oldType] : requiresUpdate) {
- rewriter.setInsertionPoint(operand->getOwner());
-
- auto oldValue = operand->get();
-
- auto loc = oldValue.getLoc();
- auto castOp = rewriter.create<tensor::CastOp>(loc, oldType, oldValue);
- operand->set(castOp);
- }
}
/// Pass that performs shape propagation across TOSA operations. This includes
@@ -285,7 +291,9 @@ struct TosaInferShapes
public:
void runOnOperation() override {
func::FuncOp func = getOperation();
- propagateShapesInRegion(func.getBody());
+ TypeModificationState state;
+ propagateShapesInRegion(func.getBody(), state);
+ state.commit();
}
};
} // namespace
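
The record/revert/commit pattern introduced by TypeModificationState above is not TOSA-specific: each type change is logged together with the previous type so that a speculative batch (the tosa.while_loop fixed-point iteration) can be rolled back and the final batch committed. A stripped-down, non-MLIR sketch of the same bookkeeping, with illustrative names only:

    #include <cassert>
    #include <string>
    #include <utility>
    #include <vector>

    // Logs each mutation together with the previous value so that the whole
    // batch can later be reverted (speculative runs) or committed (final run).
    class ModificationLog {
    public:
      ~ModificationLog() { assert(old_.empty() && "unhandled modifications"); }

      void set(std::string &slot, std::string newValue) {
        if (slot != newValue) {
          old_.emplace_back(&slot, slot);
          slot = std::move(newValue);
        }
      }

      void revert() {
        for (auto &[slot, oldValue] : old_)
          *slot = oldValue;
        old_.clear();
      }

      void commit() { old_.clear(); }

    private:
      std::vector<std::pair<std::string *, std::string>> old_;
    };

The MLIR version does the same over (Value, Type) pairs, with commit() additionally inserting tensor.cast ops at uses that cannot accept the refined type.
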
diff --git a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
index 6d7e3bc..52359fa8 100644
--- a/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
+++ b/mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
@@ -47,17 +47,26 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
unsigned vscaleMax, presburger::BoundType boundType, bool closedUB,
StopConditionFn stopCondition) {
using namespace presburger;
-
assert(vscaleMin <= vscaleMax);
- ScalableValueBoundsConstraintSet scalableCstr(value.getContext(), vscaleMin,
- vscaleMax);
- int64_t pos = scalableCstr.populateConstraintsSet(value, dim, stopCondition);
+ // No stop condition specified: Keep adding constraints until the worklist
+ // is empty.
+ auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
+ mlir::ValueBoundsConstraintSet &cstr) {
+ return false;
+ };
+
+ ScalableValueBoundsConstraintSet scalableCstr(
+ value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
+ vscaleMin, vscaleMax);
+ int64_t pos = scalableCstr.populateConstraintsSet(value, dim);
// Project out all variables apart from vscale.
// This should result in constraints in terms of vscale only.
- scalableCstr.projectOut(
- [&](ValueDim p) { return p.first != scalableCstr.getVscaleValue(); });
+ auto projectOutFn = [&](ValueDim p) {
+ return p.first != scalableCstr.getVscaleValue();
+ };
+ scalableCstr.projectOut(projectOutFn);
assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
scalableCstr.positionToValueDim.size() &&
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index 593c1e5..8d733c5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -398,13 +398,30 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp,
transposeResults.push_back(targetExpr);
}
}
+
+    // Checks if only the outer, unit dimensions (of size 1) are permuted.
+    // Such transposes do not materially affect the underlying vector and can
+    // be omitted, e.g. perm [1, 0, 2] applied to vector<1x1x8xi32>.
+ bool transposeNonOuterUnitDims = false;
+ auto operandShape = operands[it.index()].getType().cast<ShapedType>();
+ for (auto [index, dim] :
+ llvm::enumerate(ArrayRef<int64_t>(perm).drop_back(1))) {
+ if (dim != static_cast<int64_t>(index) &&
+ operandShape.getDimSize(index) != 1) {
+ transposeNonOuterUnitDims = true;
+ break;
+ }
+ }
+
     // Do the transpose now if needed so that we can drop the
// correct dim using extract later.
if (tranposeNeeded) {
map = AffineMap::get(map.getNumDims(), 0, transposeResults,
contractOp.getContext());
- operands[it.index()] = rewriter.create<vector::TransposeOp>(
- loc, operands[it.index()], perm);
+ if (transposeNonOuterUnitDims) {
+ operands[it.index()] = rewriter.createOrFold<vector::TransposeOp>(
+ loc, operands[it.index()], perm);
+ }
}
}
// We have taken care to have the dim to be dropped be
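
The new guard only materializes a vector.transpose when the permutation moves a dimension of extent greater than one; permutations that merely shuffle leading unit dimensions do not change the underlying data. A standalone sketch of the predicate added above, with illustrative names:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Returns true if `perm` moves any dimension (other than the innermost,
    // which the pattern above keeps in place) whose extent is not 1.
    bool permutesNonUnitDims(const std::vector<int64_t> &shape,
                             const std::vector<int64_t> &perm) {
      for (std::size_t i = 0; i + 1 < perm.size(); ++i)
        if (perm[i] != static_cast<int64_t>(i) && shape[i] != 1)
          return true;
      return false;
    }

For shape {1, 1, 8} and perm {1, 0, 2} this returns false, so the transpose is folded away; for shape {2, 1, 8} it returns true and the transpose is created as before.
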
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 4fa5b8a..b59e906 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -26,6 +26,9 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
// Reject index since getElementTypeBitWidth will abort for Index types.
if (!vecType || vecType.getElementType().isIndex())
return false;
+  // There are no dimensions to fold if it is a 0-D vector.
+ if (vecType.getRank() == 0)
+ return false;
unsigned trailingVecDimBitWidth =
vecType.getShape().back() * vecType.getElementTypeBitWidth();
if (trailingVecDimBitWidth >= targetBitWidth)
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5944a0e..286f47c 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -11,6 +11,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
+#include "llvm/ADT/SmallPtrSet.h"
using namespace mlir;
@@ -250,6 +251,14 @@ void RewriterBase::finalizeOpModification(Operation *op) {
rewriteListener->notifyOperationModified(op);
}
+void RewriterBase::replaceAllUsesExcept(
+ Value from, Value to, const SmallPtrSetImpl<Operation *> &preservedUsers) {
+ return replaceUsesWithIf(from, to, [&](OpOperand &use) {
+ Operation *user = use.getOwner();
+ return !preservedUsers.contains(user);
+ });
+}
+
void RewriterBase::replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced) {
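
The added replaceAllUsesExcept overload is a thin wrapper over replaceUsesWithIf, keyed on a set of user ops instead of a predicate. A hedged usage sketch; the helper name and the single preserved op are assumptions for illustration:

    #include "mlir/IR/PatternMatch.h"
    #include "llvm/ADT/SmallPtrSet.h"

    // Redirect every use of `oldValue` to `newValue`, except uses owned by
    // `preservedOp` (e.g. a cast that was just created from `newValue`).
    static void redirectUsesExceptOne(mlir::RewriterBase &rewriter,
                                      mlir::Value oldValue, mlir::Value newValue,
                                      mlir::Operation *preservedOp) {
      llvm::SmallPtrSet<mlir::Operation *, 4> preserved;
      preserved.insert(preservedOp);
      rewriter.replaceAllUsesExcept(oldValue, newValue, preserved);
    }

This lets a pattern exclude several users at once without writing a replaceUsesWithIf lambda by hand.
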
diff --git a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
index b5b7d78..e93a9ef 100644
--- a/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
+++ b/mlir/lib/Interfaces/DataLayoutInterfaces.cpp
@@ -234,6 +234,15 @@ std::optional<uint64_t> mlir::detail::getDefaultIndexBitwidth(
return std::nullopt;
}
+// Returns the endianness if specified in the given entry. If the entry is
+// empty, the default endianness (an empty attribute) is returned.
+Attribute mlir::detail::getDefaultEndianness(DataLayoutEntryInterface entry) {
+ if (entry == DataLayoutEntryInterface())
+ return Attribute();
+
+ return entry.getValue();
+}
+
// Returns the memory space used for alloca operations if specified in the
// given entry. If the entry is empty the default memory space represented by
// an empty attribute is returned.
@@ -548,6 +557,22 @@ std::optional<uint64_t> mlir::DataLayout::getTypeIndexBitwidth(Type t) const {
});
}
+mlir::Attribute mlir::DataLayout::getEndianness() const {
+ checkValid();
+ if (endianness)
+ return *endianness;
+ DataLayoutEntryInterface entry;
+ if (originalLayout)
+ entry = originalLayout.getSpecForIdentifier(
+ originalLayout.getEndiannessIdentifier(originalLayout.getContext()));
+
+ if (auto iface = dyn_cast_or_null<DataLayoutOpInterface>(scope))
+ endianness = iface.getEndianness(entry);
+ else
+ endianness = detail::getDefaultEndianness(entry);
+ return *endianness;
+}
+
mlir::Attribute mlir::DataLayout::getAllocaMemorySpace() const {
checkValid();
if (allocaMemorySpace)
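
With the new entry, endianness can be queried like the other layout properties; an empty attribute means the layout spec did not override the default. A small sketch, assuming the usual DataLayout::closest lookup and an illustrative helper name:

    #include "mlir/Interfaces/DataLayoutInterfaces.h"

    // Returns true if the data layout in scope for `op` explicitly specifies
    // an endianness entry; a null attribute means the default applies.
    static bool hasExplicitEndianness(mlir::Operation *op) {
      mlir::DataLayout layout = mlir::DataLayout::closest(op);
      return static_cast<bool>(layout.getEndianness());
    }
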
diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
index 99598f2..0d362c7 100644
--- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
+++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
@@ -67,8 +67,11 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) {
return std::nullopt;
}
-ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx)
- : builder(ctx) {}
+ValueBoundsConstraintSet::ValueBoundsConstraintSet(
+ MLIRContext *ctx, StopConditionFn stopCondition)
+ : builder(ctx), stopCondition(stopCondition) {
+ assert(stopCondition && "expected non-null stop condition");
+}
char ValueBoundsConstraintSet::ID = 0;
@@ -193,7 +196,8 @@ static Operation *getOwnerOfValue(Value value) {
return value.getDefiningOp();
}
-void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
+void ValueBoundsConstraintSet::processWorklist() {
+ LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n");
while (!worklist.empty()) {
int64_t pos = worklist.front();
worklist.pop();
@@ -214,13 +218,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
// Do not process any further if the stop condition is met.
auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim);
- if (stopCondition(value, maybeDim))
+ if (stopCondition(value, maybeDim, *this)) {
+ LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value
+ << " (dim: " << maybeDim << ")\n");
continue;
+ }
// Query `ValueBoundsOpInterface` for constraints. New items may be added to
// the worklist.
auto valueBoundsOp =
dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value));
+ LLVM_DEBUG(llvm::dbgs()
+ << "Query value bounds for: " << value
+ << " (owner: " << getOwnerOfValue(value)->getName() << ")\n");
if (valueBoundsOp) {
if (dim == kIndexValue) {
valueBoundsOp.populateBoundsForIndexValue(value, *this);
@@ -229,6 +239,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) {
}
continue;
}
+ LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n");
// If the op does not implement `ValueBoundsOpInterface`, check if it
// implements the `DestinationStyleOpInterface`. OpResults of such ops are
@@ -278,32 +289,20 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
bool closedUB) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
- assert(!stopCondition(value, dim) &&
- "stop condition should not be satisfied for starting point");
#endif // NDEBUG
int64_t ubAdjustment = closedUB ? 0 : 1;
Builder b(value.getContext());
mapOperands.clear();
- if (stopCondition(value, dim)) {
- // Special case: If the stop condition is satisfied for the input
- // value/dimension, directly return it.
- mapOperands.push_back(std::make_pair(value, dim));
- AffineExpr bound = b.getAffineDimExpr(0);
- if (type == BoundType::UB)
- bound = bound + ubAdjustment;
- resultMap = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
- b.getAffineDimExpr(0));
- return success();
- }
-
// Process the backward slice of `value` (i.e., reverse use-def chain) until
// `stopCondition` is met.
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
- ValueBoundsConstraintSet cstr(value.getContext());
+ ValueBoundsConstraintSet cstr(value.getContext(), stopCondition);
+ assert(!stopCondition(value, dim, cstr) &&
+ "stop condition should not be satisfied for starting point");
int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false);
- cstr.processWorklist(stopCondition);
+ cstr.processWorklist();
// Project out all variables (apart from `valueDim`) that do not match the
// stop condition.
@@ -313,7 +312,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound(
return false;
auto maybeDim =
p.second == kIndexValue ? std::nullopt : std::make_optional(p.second);
- return !stopCondition(p.first, maybeDim);
+ return !stopCondition(p.first, maybeDim, cstr);
});
// Compute lower and upper bounds for `valueDim`.
@@ -419,7 +418,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound(
bool closedUB) {
return computeBound(
resultMap, mapOperands, type, value, dim,
- [&](Value v, std::optional<int64_t> d) {
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
return llvm::is_contained(dependencies, std::make_pair(v, d));
},
closedUB);
@@ -455,7 +454,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound(
// Reify bounds in terms of any independent values.
return computeBound(
resultMap, mapOperands, type, value, dim,
- [&](Value v, std::optional<int64_t> d) { return isIndependent(v); },
+ [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) {
+ return isIndependent(v);
+ },
closedUB);
}
@@ -488,21 +489,19 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
presburger::BoundType type, AffineMap map, ValueDimList operands,
StopConditionFn stopCondition, bool closedUB) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
- ValueBoundsConstraintSet cstr(map.getContext());
- int64_t pos = 0;
- if (stopCondition) {
- cstr.populateConstraintsSet(map, operands, stopCondition, &pos);
- } else {
- // No stop condition specified: Keep adding constraints until a bound could
- // be computed.
- cstr.populateConstraintsSet(
- map, operands,
- [&](Value v, std::optional<int64_t> dim) {
- return cstr.cstr.getConstantBound64(type, pos).has_value();
- },
- &pos);
- }
+ // Default stop condition if none was specified: Keep adding constraints until
+ // a bound could be computed.
+ int64_t pos;
+ auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
+ ValueBoundsConstraintSet &cstr) {
+ return cstr.cstr.getConstantBound64(type, pos).has_value();
+ };
+
+ ValueBoundsConstraintSet cstr(
+ map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
+ cstr.populateConstraintsSet(map, operands, &pos);
+
// Compute constant bound for `valueDim`.
int64_t ubAdjustment = closedUB ? 0 : 1;
if (auto bound = cstr.cstr.getConstantBound64(type, pos))
@@ -510,8 +509,9 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
return failure();
}
-int64_t ValueBoundsConstraintSet::populateConstraintsSet(
- Value value, std::optional<int64_t> dim, StopConditionFn stopCondition) {
+int64_t
+ValueBoundsConstraintSet::populateConstraintsSet(Value value,
+ std::optional<int64_t> dim) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
#endif // NDEBUG
@@ -519,12 +519,12 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet(
AffineMap map =
AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
Builder(value.getContext()).getAffineDimExpr(0));
- return populateConstraintsSet(map, {{value, dim}}, stopCondition);
+ return populateConstraintsSet(map, {{value, dim}});
}
-int64_t ValueBoundsConstraintSet::populateConstraintsSet(
- AffineMap map, ValueDimList operands, StopConditionFn stopCondition,
- int64_t *posOut) {
+int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map,
+ ValueDimList operands,
+ int64_t *posOut) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
int64_t pos = insert(/*isSymbol=*/false);
if (posOut)
@@ -545,13 +545,7 @@ int64_t ValueBoundsConstraintSet::populateConstraintsSet(
// Process the backward slice of `operands` (i.e., reverse use-def chain)
// until `stopCondition` is met.
- if (stopCondition) {
- processWorklist(stopCondition);
- } else {
- // No stop condition specified: Keep adding constraints until the worklist
- // is empty.
- processWorklist([](Value v, std::optional<int64_t> dim) { return false; });
- }
+ processWorklist();
return pos;
}
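
Stop conditions now receive the constraint set as a third argument, which is what lets the default conditions above query getConstantBound64 while the traversal is still running. A hedged sketch of a caller-side condition in the new form, assuming the Value-based computeConstantBound overload keeps this parameter order:

    #include "mlir/Interfaces/ValueBoundsOpInterface.h"
    #include <optional>

    // Compute a closed upper bound for `value`, stopping the backward
    // traversal of the use-def chain at block arguments.
    static mlir::FailureOr<int64_t> upperBoundUpToBlockArgs(mlir::Value value) {
      auto stopAtBlockArgs = [](mlir::Value v, std::optional<int64_t> dim,
                                mlir::ValueBoundsConstraintSet &cstr) {
        return llvm::isa<mlir::BlockArgument>(v);
      };
      return mlir::ValueBoundsConstraintSet::computeConstantBound(
          mlir::presburger::BoundType::UB, value, /*dim=*/std::nullopt,
          stopAtBlockArgs, /*closedUB=*/true);
    }
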
diff --git a/mlir/lib/Support/Timing.cpp b/mlir/lib/Support/Timing.cpp
index 2249312..ac16eb7 100644
--- a/mlir/lib/Support/Timing.cpp
+++ b/mlir/lib/Support/Timing.cpp
@@ -32,6 +32,7 @@
using namespace mlir;
using namespace detail;
using DisplayMode = DefaultTimingManager::DisplayMode;
+using OutputFormat = DefaultTimingManager::OutputFormat;
constexpr llvm::StringLiteral kTimingDescription =
"... Execution time report ...";
@@ -109,56 +110,105 @@ TimingIdentifier TimingIdentifier::get(StringRef str, TimingManager &tm) {
namespace {
-/// Simple record class to record timing information.
-struct TimeRecord {
- TimeRecord(double wall = 0.0, double user = 0.0) : wall(wall), user(user) {}
+class OutputTextStrategy : public OutputStrategy {
+public:
+ OutputTextStrategy(raw_ostream &os) : OutputStrategy(os) {}
+
+ void printHeader(const TimeRecord &total) override {
+    // Figure out how many spaces are needed to center the description.
+ unsigned padding = (80 - kTimingDescription.size()) / 2;
+ os << "===" << std::string(73, '-') << "===\n";
+ os.indent(padding) << kTimingDescription << '\n';
+ os << "===" << std::string(73, '-') << "===\n";
- TimeRecord &operator+=(const TimeRecord &other) {
- wall += other.wall;
- user += other.user;
- return *this;
+ // Print the total time followed by the section headers.
+ os << llvm::format(" Total Execution Time: %.4f seconds\n\n", total.wall);
+ if (total.user != total.wall)
+ os << " ----User Time----";
+ os << " ----Wall Time---- ----Name----\n";
}
- TimeRecord &operator-=(const TimeRecord &other) {
- wall -= other.wall;
- user -= other.user;
- return *this;
+ void printFooter() override { os.flush(); }
+
+ void printTime(const TimeRecord &time, const TimeRecord &total) override {
+ if (total.user != total.wall) {
+ os << llvm::format(" %8.4f (%5.1f%%)", time.user,
+ 100.0 * time.user / total.user);
+ }
+ os << llvm::format(" %8.4f (%5.1f%%) ", time.wall,
+ 100.0 * time.wall / total.wall);
}
- /// Print the current time record to 'os', with a breakdown showing
- /// contributions to the give 'total' time record.
- void print(raw_ostream &os, const TimeRecord &total) {
- if (total.user != total.wall)
- os << llvm::format(" %8.4f (%5.1f%%)", user, 100.0 * user / total.user);
- os << llvm::format(" %8.4f (%5.1f%%) ", wall, 100.0 * wall / total.wall);
+ void printListEntry(StringRef name, const TimeRecord &time,
+ const TimeRecord &total, bool lastEntry) override {
+ printTime(time, total);
+ os << name << "\n";
+ }
+
+ void printTreeEntry(unsigned indent, StringRef name, const TimeRecord &time,
+ const TimeRecord &total) override {
+ printTime(time, total);
+ os.indent(indent) << name << "\n";
}
- double wall, user;
+ void printTreeEntryEnd(unsigned indent, bool lastEntry) override {}
};
-} // namespace
+class OutputJsonStrategy : public OutputStrategy {
+public:
+ OutputJsonStrategy(raw_ostream &os) : OutputStrategy(os) {}
-/// Utility to print a single line entry in the timer output.
-static void printTimeEntry(raw_ostream &os, unsigned indent, StringRef name,
- TimeRecord time, TimeRecord total) {
- time.print(os, total);
- os.indent(indent) << name << "\n";
-}
+ void printHeader(const TimeRecord &total) override { os << "[" << "\n"; }
-/// Utility to print the timer heading information.
-static void printTimeHeader(raw_ostream &os, TimeRecord total) {
- // Figure out how many spaces to description name.
- unsigned padding = (80 - kTimingDescription.size()) / 2;
- os << "===" << std::string(73, '-') << "===\n";
- os.indent(padding) << kTimingDescription << '\n';
- os << "===" << std::string(73, '-') << "===\n";
-
- // Print the total time followed by the section headers.
- os << llvm::format(" Total Execution Time: %.4f seconds\n\n", total.wall);
- if (total.user != total.wall)
- os << " ----User Time----";
- os << " ----Wall Time---- ----Name----\n";
-}
+ void printFooter() override {
+ os << "]" << "\n";
+ os.flush();
+ }
+
+ void printTime(const TimeRecord &time, const TimeRecord &total) override {
+ if (total.user != total.wall) {
+ os << "\"user\": {";
+ os << "\"duration\": " << llvm::format("%8.4f", time.user) << ", ";
+ os << "\"percentage\": "
+ << llvm::format("%5.1f", 100.0 * time.user / total.user);
+ os << "}, ";
+ }
+ os << "\"wall\": {";
+ os << "\"duration\": " << llvm::format("%8.4f", time.wall) << ", ";
+ os << "\"percentage\": "
+ << llvm::format("%5.1f", 100.0 * time.wall / total.wall);
+ os << "}";
+ }
+
+ void printListEntry(StringRef name, const TimeRecord &time,
+ const TimeRecord &total, bool lastEntry) override {
+ os << "{";
+ printTime(time, total);
+ os << ", \"name\": " << "\"" << name << "\"";
+ os << "}";
+ if (!lastEntry)
+ os << ",";
+ os << "\n";
+ }
+
+ void printTreeEntry(unsigned indent, StringRef name, const TimeRecord &time,
+ const TimeRecord &total) override {
+ os.indent(indent) << "{";
+ printTime(time, total);
+ os << ", \"name\": " << "\"" << name << "\"";
+ os << ", \"passes\": [" << "\n";
+ }
+
+ void printTreeEntryEnd(unsigned indent, bool lastEntry) override {
+ os.indent(indent) << "{}]";
+ os << "}";
+ if (!lastEntry)
+ os << ",";
+ os << "\n";
+ }
+};
+
+} // namespace
//===----------------------------------------------------------------------===//
// Timer Implementation for DefaultTimingManager
@@ -176,7 +226,8 @@ public:
using ChildrenMap = llvm::MapVector<const void *, std::unique_ptr<TimerImpl>>;
using AsyncChildrenMap = llvm::DenseMap<uint64_t, ChildrenMap>;
- TimerImpl(std::string &&name) : threadId(llvm::get_threadid()), name(name) {}
+ TimerImpl(std::string &&name, std::unique_ptr<OutputStrategy> &output)
+ : threadId(llvm::get_threadid()), name(name), output(output) {}
/// Start the timer.
void start() { startTime = std::chrono::steady_clock::now(); }
@@ -206,7 +257,7 @@ public:
TimerImpl *nestTail(std::unique_ptr<TimerImpl> &child,
function_ref<std::string()> nameBuilder) {
if (!child)
- child = std::make_unique<TimerImpl>(nameBuilder());
+ child = std::make_unique<TimerImpl>(nameBuilder(), output);
return child.get();
}
@@ -320,7 +371,7 @@ public:
}
/// Print the timing result in list mode.
- void printAsList(raw_ostream &os, TimeRecord total) {
+ void printAsList(TimeRecord total) {
// Flatten the leaf timers in the tree and merge them by name.
llvm::StringMap<TimeRecord> mergedTimers;
std::function<void(TimerImpl *)> addTimer = [&](TimerImpl *timer) {
@@ -343,34 +394,37 @@ public:
// Print the timing information sequentially.
for (auto &timeData : timerNameAndTime)
- printTimeEntry(os, 0, timeData.first, timeData.second, total);
+ output->printListEntry(timeData.first, timeData.second, total);
}
/// Print the timing result in tree mode.
- void printAsTree(raw_ostream &os, TimeRecord total, unsigned indent = 0) {
+ void printAsTree(TimeRecord total, unsigned indent = 0) {
unsigned childIndent = indent;
if (!hidden) {
- printTimeEntry(os, indent, name, getTimeRecord(), total);
+ output->printTreeEntry(indent, name, getTimeRecord(), total);
childIndent += 2;
}
for (auto &child : children) {
- child.second->printAsTree(os, total, childIndent);
+ child.second->printAsTree(total, childIndent);
+ }
+ if (!hidden) {
+ output->printTreeEntryEnd(indent);
}
}
/// Print the current timing information.
- void print(raw_ostream &os, DisplayMode displayMode) {
+ void print(DisplayMode displayMode) {
// Print the banner.
auto total = getTimeRecord();
- printTimeHeader(os, total);
+ output->printHeader(total);
// Defer to a specialized printer for each display mode.
switch (displayMode) {
case DisplayMode::List:
- printAsList(os, total);
+ printAsList(total);
break;
case DisplayMode::Tree:
- printAsTree(os, total);
+ printAsTree(total);
break;
}
@@ -379,9 +433,9 @@ public:
auto rest = total;
for (auto &child : children)
rest -= child.second->getTimeRecord();
- printTimeEntry(os, 0, "Rest", rest, total);
- printTimeEntry(os, 0, "Total", total, total);
- os.flush();
+ output->printListEntry("Rest", rest, total);
+ output->printListEntry("Total", total, total, /*lastEntry=*/true);
+ output->printFooter();
}
/// The last time instant at which the timer was started.
@@ -415,6 +469,8 @@ public:
/// Mutex for the async children.
std::mutex asyncMutex;
+
+ std::unique_ptr<OutputStrategy> &output;
};
} // namespace
@@ -435,9 +491,6 @@ public:
/// The configured display mode.
DisplayMode displayMode = DisplayMode::Tree;
- /// The stream where we should print our output. This will always be non-null.
- raw_ostream *output = &llvm::errs();
-
/// The root timer.
std::unique_ptr<TimerImpl> rootTimer;
};
@@ -446,7 +499,8 @@ public:
} // namespace mlir
DefaultTimingManager::DefaultTimingManager()
- : impl(std::make_unique<DefaultTimingManagerImpl>()) {
+ : impl(std::make_unique<DefaultTimingManagerImpl>()),
+ out(std::make_unique<OutputTextStrategy>(llvm::errs())) {
clear(); // initializes the root timer
}
@@ -469,26 +523,22 @@ DefaultTimingManager::DisplayMode DefaultTimingManager::getDisplayMode() const {
}
/// Change the stream where the output will be printed to.
-void DefaultTimingManager::setOutput(raw_ostream &os) { impl->output = &os; }
-
-/// Return the current output stream where the output will be printed to.
-raw_ostream &DefaultTimingManager::getOutput() const {
- assert(impl->output);
- return *impl->output;
+void DefaultTimingManager::setOutput(std::unique_ptr<OutputStrategy> output) {
+ out = std::move(output);
}
/// Print and clear the timing results.
void DefaultTimingManager::print() {
if (impl->enabled) {
impl->rootTimer->finalize();
- impl->rootTimer->print(*impl->output, impl->displayMode);
+ impl->rootTimer->print(impl->displayMode);
}
clear();
}
/// Clear the timing results.
void DefaultTimingManager::clear() {
- impl->rootTimer = std::make_unique<TimerImpl>("root");
+ impl->rootTimer = std::make_unique<TimerImpl>("root", out);
impl->rootTimer->hidden = true;
}
@@ -500,13 +550,13 @@ void DefaultTimingManager::dumpTimers(raw_ostream &os) {
/// Debug print the timers as a list.
void DefaultTimingManager::dumpAsList(raw_ostream &os) {
impl->rootTimer->finalize();
- impl->rootTimer->print(os, DisplayMode::List);
+ impl->rootTimer->print(DisplayMode::List);
}
/// Debug print the timers as a tree.
void DefaultTimingManager::dumpAsTree(raw_ostream &os) {
impl->rootTimer->finalize();
- impl->rootTimer->print(os, DisplayMode::Tree);
+ impl->rootTimer->print(DisplayMode::Tree);
}
std::optional<void *> DefaultTimingManager::rootTimer() {
@@ -549,6 +599,13 @@ struct DefaultTimingManagerOptions {
"display the results in a list sorted by total time"),
clEnumValN(DisplayMode::Tree, "tree",
"display the results ina with a nested tree view"))};
+ llvm::cl::opt<OutputFormat> outputFormat{
+ "mlir-output-format", llvm::cl::desc("Output format for timing data"),
+ llvm::cl::init(OutputFormat::Text),
+ llvm::cl::values(clEnumValN(OutputFormat::Text, "text",
+ "display the results in text format"),
+ clEnumValN(OutputFormat::Json, "json",
+ "display the results in JSON format"))};
};
} // namespace
@@ -564,4 +621,11 @@ void mlir::applyDefaultTimingManagerCLOptions(DefaultTimingManager &tm) {
return;
tm.setEnabled(options->timing);
tm.setDisplayMode(options->displayMode);
+
+ std::unique_ptr<OutputStrategy> printer;
+ if (options->outputFormat == OutputFormat::Text)
+ printer = std::make_unique<OutputTextStrategy>(llvm::errs());
+ else if (options->outputFormat == OutputFormat::Json)
+ printer = std::make_unique<OutputJsonStrategy>(llvm::errs());
+ tm.setOutput(std::move(printer));
}
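
The output strategy is now chosen by the new -mlir-output-format option alongside -mlir-timing. A minimal sketch of a standalone tool wiring these options through the Timing.h entry points (the nested scope shown is purely illustrative):

    #include "mlir/Support/Timing.h"
    #include "llvm/Support/CommandLine.h"

    // With this patch, running the tool with
    //   -mlir-timing -mlir-output-format=json
    // emits the timing report as JSON instead of the text table.
    int main(int argc, char **argv) {
      mlir::registerDefaultTimingManagerCLOptions();
      llvm::cl::ParseCommandLineOptions(argc, argv, "timing demo\n");

      mlir::DefaultTimingManager tm;
      mlir::applyDefaultTimingManagerCLOptions(tm);

      mlir::TimingScope root = tm.getRootScope();
      {
        mlir::TimingScope work = root.nest("work");
        // ... do the timed work here ...
      }
      root.stop();
      tm.print(); // Print and clear the collected timing data.
      return 0;
    }
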
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 0b07b4b..ee87c1d 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1104,7 +1104,7 @@ CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop)
std::string CppEmitter::getSubscriptName(emitc::SubscriptOp op) {
std::string out;
llvm::raw_string_ostream ss(out);
- ss << getOrCreateName(op.getArray());
+ ss << getOrCreateName(op.getValue());
for (auto index : op.getIndices()) {
ss << "[" << getOrCreateName(index) << "]";
}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 646d0ed..08ec578 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -804,13 +804,13 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
/// Allocate space for privatized reduction variables.
template <typename T>
-static void
-allocByValReductionVars(T loop, llvm::IRBuilderBase &builder,
- LLVM::ModuleTranslation &moduleTranslation,
- llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
- SmallVector<omp::DeclareReductionOp> &reductionDecls,
- SmallVector<llvm::Value *> &privateReductionVariables,
- DenseMap<Value, llvm::Value *> &reductionVariableMap) {
+static void allocByValReductionVars(
+ T loop, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation,
+ llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
+ SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+ SmallVectorImpl<llvm::Value *> &privateReductionVariables,
+ DenseMap<Value, llvm::Value *> &reductionVariableMap) {
llvm::IRBuilderBase::InsertPointGuard guard(builder);
builder.restoreIP(allocaIP);
auto args =
@@ -825,6 +825,25 @@ allocByValReductionVars(T loop, llvm::IRBuilderBase &builder,
}
}
+/// Map the i-th reduction variable to the argument of its initializer region.
+template <typename T>
+static void
+mapInitializationArg(T loop, LLVM::ModuleTranslation &moduleTranslation,
+ SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+ unsigned i) {
+ // map input argument to the initialization region
+ mlir::omp::DeclareReductionOp &reduction = reductionDecls[i];
+ Region &initializerRegion = reduction.getInitializerRegion();
+ Block &entry = initializerRegion.front();
+ assert(entry.getNumArguments() == 1 &&
+ "the initialization region has one argument");
+
+ mlir::Value mlirSource = loop.getReductionVars()[i];
+ llvm::Value *llvmSource = moduleTranslation.lookupValue(mlirSource);
+ assert(llvmSource && "lookup reduction var");
+ moduleTranslation.mapValue(entry.getArgument(0), llvmSource);
+}
+
/// Collect reduction info
template <typename T>
static void collectReductionInfo(
@@ -858,6 +877,32 @@ static void collectReductionInfo(
}
}
+/// Inline the cleanup region of each DeclareReductionOp, if present.
+static LogicalResult inlineReductionCleanup(
+ llvm::SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
+ llvm::ArrayRef<llvm::Value *> privateReductionVariables,
+ LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder) {
+ for (auto [i, reductionDecl] : llvm::enumerate(reductionDecls)) {
+ Region &cleanupRegion = reductionDecl.getCleanupRegion();
+ if (cleanupRegion.empty())
+ continue;
+
+ // map the argument to the cleanup region
+ Block &entry = cleanupRegion.front();
+ moduleTranslation.mapValue(entry.getArgument(0),
+ privateReductionVariables[i]);
+
+ if (failed(inlineConvertOmpRegions(cleanupRegion, "omp.reduction.cleanup",
+ builder, moduleTranslation)))
+ return failure();
+
+ // clear block argument mapping in case it needs to be re-created with a
+ // different source for another use of the same reduction decl
+ moduleTranslation.forgetMapping(cleanupRegion);
+ }
+ return success();
+}
+
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -902,6 +947,10 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
loop.getRegion().getArguments().take_back(loop.getNumReductionVars());
for (unsigned i = 0; i < loop.getNumReductionVars(); ++i) {
SmallVector<llvm::Value *> phis;
+
+ // map block argument to initializer region
+ mapInitializationArg(loop, moduleTranslation, reductionDecls, i);
+
if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
"omp.reduction.neutral", builder,
moduleTranslation, &phis)))
@@ -925,6 +974,11 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
builder.CreateStore(phis[0], privateReductionVariables[i]);
// the rest was handled in allocByValReductionVars
}
+
+ // forget the mapping for the initializer region because we might need a
+ // different mapping if this reduction declaration is re-used for a
+ // different variable
+ moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
}
// Store the mapping between reduction variables and their private copies on
@@ -1044,7 +1098,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
tempTerminator->eraseFromParent();
builder.restoreIP(nextInsertionPoint);
- return success();
+ // after the workshare loop, deallocate private reduction variables
+ return inlineReductionCleanup(reductionDecls, privateReductionVariables,
+ moduleTranslation, builder);
}
/// A RAII class that on construction replaces the region arguments of the
@@ -1097,13 +1153,13 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
LogicalResult bodyGenStatus = success();
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
- auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
- // Collect reduction declarations
- SmallVector<omp::DeclareReductionOp> reductionDecls;
- collectReductionDecls(opInst, reductionDecls);
+ // Collect reduction declarations
+ SmallVector<omp::DeclareReductionOp> reductionDecls;
+ collectReductionDecls(opInst, reductionDecls);
+ SmallVector<llvm::Value *> privateReductionVariables;
+ auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
// Allocate reduction vars
- SmallVector<llvm::Value *> privateReductionVariables;
DenseMap<Value, llvm::Value *> reductionVariableMap;
if (!isByRef) {
allocByValReductionVars(opInst, builder, moduleTranslation, allocaIP,
@@ -1118,6 +1174,9 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
opInst.getNumReductionVars());
for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
SmallVector<llvm::Value *> phis;
+
+ // map the block argument
+ mapInitializationArg(opInst, moduleTranslation, reductionDecls, i);
if (failed(inlineConvertOmpRegions(
reductionDecls[i].getInitializerRegion(), "omp.reduction.neutral",
builder, moduleTranslation, &phis)))
@@ -1144,6 +1203,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
builder.CreateStore(phis[0], privateReductionVariables[i]);
// the rest is done in allocByValReductionVars
}
+
+ // clear block argument mapping in case it needs to be re-created with a
+ // different source for another use of the same reduction decl
+ moduleTranslation.forgetMapping(reductionDecls[i].getInitializerRegion());
}
// Store the mapping between reduction variables and their private copies on
@@ -1296,7 +1359,18 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
// TODO: Perform finalization actions for variables. This has to be
// called for variables which have destructors/finalizers.
- auto finiCB = [&](InsertPointTy codeGenIP) {};
+ auto finiCB = [&](InsertPointTy codeGenIP) {
+ InsertPointTy oldIP = builder.saveIP();
+ builder.restoreIP(codeGenIP);
+
+ // if the reduction has a cleanup region, inline it here to finalize the
+ // reduction variables
+ if (failed(inlineReductionCleanup(reductionDecls, privateReductionVariables,
+ moduleTranslation, builder)))
+ bodyGenStatus = failure();
+
+ builder.restoreIP(oldIP);
+ };
llvm::Value *ifCond = nullptr;
if (auto ifExprVar = opInst.getIfExprVar())
diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp
index 80e3b79..abe565e 100644
--- a/mlir/lib/Transforms/Mem2Reg.cpp
+++ b/mlir/lib/Transforms/Mem2Reg.cpp
@@ -202,6 +202,7 @@ private:
/// Contains the reaching definition at this operation. Reaching definitions
/// are only computed for promotable memory operations with blocking uses.
DenseMap<PromotableMemOpInterface, Value> reachingDefs;
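+  /// Maps memory operations that store to the slot to the value they store,
+  /// used to notify interested ops via `visitReplacedValues` after promotion.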
+ DenseMap<PromotableMemOpInterface, Value> replacedValuesMap;
DominanceInfo &dominance;
MemorySlotPromotionInfo info;
const Mem2RegStatistics &statistics;
@@ -438,6 +439,7 @@ Value MemorySlotPromoter::computeReachingDefInBlock(Block *block,
assert(stored && "a memory operation storing to a slot must provide a "
"new definition of the slot");
reachingDef = stored;
+ replacedValuesMap[memOp] = stored;
}
}
}
@@ -552,6 +554,10 @@ void MemorySlotPromoter::removeBlockingUses() {
dominanceSort(usersToRemoveUses, *slot.ptr.getParentBlock()->getParent());
llvm::SmallVector<Operation *> toErase;
+  // List of (promoted store op, stored value) pairs for this slot.
+ llvm::SmallVector<std::pair<Operation *, Value>> replacedValuesList;
+ // Ops to visit with the `visitReplacedValues` method.
+ llvm::SmallVector<PromotableOpInterface> toVisit;
for (Operation *toPromote : llvm::reverse(usersToRemoveUses)) {
if (auto toPromoteMemOp = dyn_cast<PromotableMemOpInterface>(toPromote)) {
Value reachingDef = reachingDefs.lookup(toPromoteMemOp);
@@ -565,7 +571,9 @@ void MemorySlotPromoter::removeBlockingUses() {
slot, info.userToBlockingUses[toPromote], rewriter,
reachingDef) == DeletionKind::Delete)
toErase.push_back(toPromote);
-
+ if (toPromoteMemOp.storesTo(slot))
+ if (Value replacedValue = replacedValuesMap[toPromoteMemOp])
+ replacedValuesList.push_back({toPromoteMemOp, replacedValue});
continue;
}
@@ -574,6 +582,12 @@ void MemorySlotPromoter::removeBlockingUses() {
if (toPromoteBasic.removeBlockingUses(info.userToBlockingUses[toPromote],
rewriter) == DeletionKind::Delete)
toErase.push_back(toPromote);
+ if (toPromoteBasic.requiresReplacedValues())
+ toVisit.push_back(toPromoteBasic);
+ }
+ for (PromotableOpInterface op : toVisit) {
+ rewriter.setInsertionPointAfter(op);
+ op.visitReplacedValues(replacedValuesList, rewriter);
}
for (Operation *toEraseOp : toErase)