aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect')
-rw-r--r--mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp6
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp15
-rw-r--r--mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp8
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp2
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp15
-rw-r--r--mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp32
-rw-r--r--mlir/lib/Dialect/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Padding.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp54
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp5
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp133
-rw-r--r--mlir/lib/Dialect/Polynomial/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt19
-rw-r--r--mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp96
-rw-r--r--mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp213
-rw-r--r--mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp41
-rw-r--r--mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp15
-rw-r--r--mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp17
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp214
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp11
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp20
-rw-r--r--mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp110
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp1
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp3
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp3
-rw-r--r--mlir/lib/Dialect/Tensor/Utils/Utils.cpp4
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp7
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp21
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp254
31 files changed, 1171 insertions, 163 deletions
diff --git a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
index e0c3abe..82a9fb0 100644
--- a/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -120,9 +120,7 @@ mlir::affine::fullyComposeAndComputeConstantDelta(Value value1, Value value2) {
mapOperands.push_back(value1);
mapOperands.push_back(value2);
affine::fullyComposeAffineMapAndOperands(&map, &mapOperands);
- ValueDimList valueDims;
- for (Value v : mapOperands)
- valueDims.push_back({v, std::nullopt});
return ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::EQ, map, valueDims);
+ presburger::BoundType::EQ,
+ ValueBoundsConstraintSet::Variable(map, mapOperands));
}
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 117ee8e..1a266b7 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -16,16 +16,15 @@
using namespace mlir;
using namespace mlir::affine;
-static FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
- Value value, std::optional<int64_t> dim,
- ValueBoundsConstraintSet::StopConditionFn stopCondition,
- bool closedUB) {
+FailureOr<OpFoldResult> mlir::affine::reifyValueBound(
+ OpBuilder &b, Location loc, presburger::BoundType type,
+ const ValueBoundsConstraintSet::Variable &var,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
- boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+ boundMap, mapOperands, type, var, stopCondition, closedUB)))
return failure();
// Reify bound.
@@ -93,7 +92,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound(
// the owner of `value`.
return v != value;
};
- return reifyValueBound(b, loc, type, value, dim,
+ return reifyValueBound(b, loc, type, {value, dim},
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
@@ -105,7 +104,7 @@ FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound(
ValueBoundsConstraintSet &cstr) {
return v != value;
};
- return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+ return reifyValueBound(b, loc, type, value,
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
diff --git a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
index f0d4380..7cfcc41 100644
--- a/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -107,9 +107,9 @@ struct SelectOpInterface
// If trueValue <= falseValue:
// * result <= falseValue
// * result >= trueValue
- if (cstr.compare(trueValue, dim,
+ if (cstr.compare(/*lhs=*/{trueValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
- falseValue, dim)) {
+ /*rhs=*/{falseValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
@@ -121,9 +121,9 @@ struct SelectOpInterface
// If falseValue <= trueValue:
// * result <= trueValue
// * result >= falseValue
- if (cstr.compare(falseValue, dim,
+ if (cstr.compare(/*lhs=*/{falseValue, dim},
ValueBoundsConstraintSet::ComparisonOperator::LE,
- trueValue, dim)) {
+ /*rhs=*/{trueValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
index 79fabd6..f87f3d6 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp
@@ -449,7 +449,7 @@ struct IndexCastPattern final : NarrowingPattern<CastOp> {
return failure();
FailureOr<int64_t> ub = ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, in, /*dim=*/std::nullopt,
+ presburger::BoundType::UB, in,
/*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(ub))
return failure();
diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
index fad2212..5fb7953 100644
--- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp
@@ -61,16 +61,15 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map,
return buildExpr(map.getResult(0));
}
-static FailureOr<OpFoldResult>
-reifyValueBound(OpBuilder &b, Location loc, presburger::BoundType type,
- Value value, std::optional<int64_t> dim,
- ValueBoundsConstraintSet::StopConditionFn stopCondition,
- bool closedUB) {
+FailureOr<OpFoldResult> mlir::arith::reifyValueBound(
+ OpBuilder &b, Location loc, presburger::BoundType type,
+ const ValueBoundsConstraintSet::Variable &var,
+ ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) {
// Compute bound.
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeBound(
- boundMap, mapOperands, type, value, dim, stopCondition, closedUB)))
+ boundMap, mapOperands, type, var, stopCondition, closedUB)))
return failure();
// Materialize tensor.dim/memref.dim ops.
@@ -128,7 +127,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound(
// the owner of `value`.
return v != value;
};
- return reifyValueBound(b, loc, type, value, dim,
+ return reifyValueBound(b, loc, type, {value, dim},
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
@@ -140,7 +139,7 @@ FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound(
ValueBoundsConstraintSet &cstr) {
return v != value;
};
- return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt,
+ return reifyValueBound(b, loc, type, value,
stopCondition ? stopCondition : reifyToOperands,
closedUB);
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
index 31500c6..b595c6d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp
@@ -165,6 +165,35 @@ int getNumberOfSMETilesForVectorType(VectorType type) {
return (vectorRows * vectorCols) / (minNumElts * minNumElts);
}
+/// Legalize `arith.constant dense<value>` splat operations to fit within SME
+/// tiles by decomposing them into tile-sized operations.
+struct LegalizeArithConstantOpsByDecomposition
+ : public OneToNOpConversionPattern<arith::ConstantOp> {
+ using OneToNOpConversionPattern::OneToNOpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor,
+ OneToNPatternRewriter &rewriter) const override {
+ auto vectorType = dyn_cast<VectorType>(constantOp.getType());
+ auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr());
+ if (!vectorType || !denseAttr || !denseAttr.isSplat())
+ return failure();
+
+ if (!isMultipleOfSMETileVectorType(vectorType))
+ return rewriter.notifyMatchFailure(constantOp,
+ kMatchFailureNotSMETileTypeMultiple);
+
+ auto smeTileType = getSMETileTypeForElement(vectorType.getElementType());
+ auto tileCount = getNumberOfSMETilesForVectorType(vectorType);
+ auto tileSplat = rewriter.create<arith::ConstantOp>(
+ constantOp.getLoc(), denseAttr.resizeSplat(smeTileType));
+ rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat),
+ adaptor.getResultMapping());
+
+ return success();
+ }
+};
+
/// Legalize `vector.outerproduct` operations to fit within SME tiles by
/// decomposing them into tile-sized operations.
struct LegalizeVectorOuterProductOpsByDecomposition
@@ -637,7 +666,8 @@ struct VectorLegalizationPass
// Note: High benefit to ensure masked outer products are lowered first.
patterns.add<LegalizeMaskedVectorOuterProductOpsByDecomposition>(
converter, context, 1024);
- patterns.add<LegalizeVectorOuterProductOpsByDecomposition,
+ patterns.add<LegalizeArithConstantOpsByDecomposition,
+ LegalizeVectorOuterProductOpsByDecomposition,
LegalizeTransferReadOpsByDecomposition,
LegalizeTransferWriteOpsByDecomposition>(converter, context);
populateFuncTypeConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
index b1ba5a3..a324ce7 100644
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -28,6 +28,7 @@ add_subdirectory(OpenACCMPCommon)
add_subdirectory(OpenMP)
add_subdirectory(PDL)
add_subdirectory(PDLInterp)
+add_subdirectory(Polynomial)
add_subdirectory(Quant)
add_subdirectory(SCF)
add_subdirectory(Shape)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
index 8c4b70d..518d2e1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp
@@ -72,8 +72,10 @@ static LogicalResult computePaddedShape(linalg::LinalgOp opToPad,
// Otherwise, try to compute a constant upper bound for the size value.
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, opOperand->get(),
- /*dim=*/i, /*stopCondition=*/nullptr, /*closedUB=*/true);
+ presburger::BoundType::UB,
+ {opOperand->get(),
+ /*dim=*/i},
+ /*stopCondition=*/nullptr, /*closedUB=*/true);
if (failed(upperBound)) {
LLVM_DEBUG(DBGS() << "----could not compute a bounding box for padding");
return failure();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
index ac896d6..71eb59d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp
@@ -257,14 +257,12 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer(
if (auto attr = llvm::dyn_cast_if_present<Attribute>(rangeValue.size)) {
size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
} else {
- Value materializedSize =
- getValueOrCreateConstantIndexOp(b, loc, rangeValue.size);
FailureOr<int64_t> upperBound =
ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB, materializedSize, /*dim=*/std::nullopt,
+ presburger::BoundType::UB, rangeValue.size,
/*stopCondition=*/nullptr, /*closedUB=*/true);
size = failed(upperBound)
- ? materializedSize
+ ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size)
: b.create<arith::ConstantIndexOp>(loc, *upperBound);
}
LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 2578565..df61381 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1412,10 +1412,11 @@ static SmallVector<int64_t> getTiledPackShape(tensor::PackOp packOp,
/// Create a TransferReadOp from `source` with static shape `readShape`. If the
/// vector type for the read is not the same as the type of `source`, then a
-/// mask is created on the read.
+/// mask is created on the read. If `doMasking` parameter is set to false we
+/// update the `inBounds` attribute instead of masking.
static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
Value source, ArrayRef<int64_t> readShape,
- Value padValue) {
+ Value padValue, bool doMasking = true) {
assert(llvm::none_of(readShape,
[](int64_t s) { return s == ShapedType::kDynamic; }));
auto sourceShape = dyn_cast<ShapedType>(source.getType()).getShape();
@@ -1424,14 +1425,21 @@ static Value createReadOrMaskedRead(OpBuilder &builder, Location loc,
auto vectorType = VectorType::get(readShape, padValue.getType());
int64_t readRank = readShape.size();
auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
+ SmallVector<bool> inBoundsVal(readRank, true);
+ if (!doMasking) {
+ // Update the inBounds attribute.
+ for (unsigned i = 0; i < readRank; i++)
+ inBoundsVal[i] = sourceShape[i] == readShape[i];
+ }
auto transferReadOp = builder.create<vector::TransferReadOp>(
loc,
/*vectorType=*/vectorType,
/*source=*/source,
/*indices=*/SmallVector<Value>(readRank, zero),
/*padding=*/padValue,
- /*inBounds=*/SmallVector<bool>(readRank, true));
- if (llvm::equal(readShape, sourceShape)) {
+ /*inBounds=*/inBoundsVal);
+
+ if (llvm::equal(readShape, sourceShape) || !doMasking) {
return transferReadOp;
}
SmallVector<OpFoldResult> mixedSourceDims =
@@ -1482,11 +1490,10 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
return write;
}
-/// Vectorize tensor::PackOp with (1) static innerTiles and (2) constant
-/// padding value into:
+/// Vectorize tensor::PackOp with (1) static innerTiles (2) constant
+/// padding value and (3) input vector sizes into:
/// masked_transfer_read->shape_cast->transpose->transfer_write_in_bounds
/// As in the following example:
-///
/// %pack = tensor.pack %src inner_dims_pos = [2, 1] inner_tiles = [16, 2]
/// into %dst : tensor<32x8x16xf32> -> tensor<32x4x1x16x2xf32>
///
@@ -1505,6 +1512,10 @@ static Operation *createWriteOrMaskedWrite(OpBuilder &builder, Location loc,
/// %empty[%c0_0, %c0_0, %c0_0, %c0_0, %c0_0]
/// {in_bounds = [true, true, true, true, true]}
/// : vector<32x4x1x16x2xf32>, tensor<32x4x1x16x2xf32>
+///
+/// If the (3) input vector sizes are not provided, the vector sizes are
+/// determined by the result tensor shape. Also, we update the inBounds
+/// attribute instead of masking.
static LogicalResult
vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes,
@@ -1525,6 +1536,16 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
(void)status; // prevent unused variable warning on non-assert builds.
assert(succeeded(status) && "failed to reify result shapes");
+ // If the input vector sizes are not provided, then the vector sizes are
+ // determined by the result tensor shape. In case the vector sizes aren't
+ // provided, we update the inBounds attribute instead of masking.
+ bool doMasking = true;
+ if (inputVectorSizes.empty()) {
+ ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
+ inputVectorSizes = resultTensorShape.take_front(packOp.getSourceRank());
+ doMasking = false;
+ }
+
// Create masked TransferReadOp.
SmallVector<int64_t> inputShape(inputVectorSizes);
auto innerTiles = packOp.getStaticInnerTiles();
@@ -1536,7 +1557,7 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
for (auto [idx, size] : enumerate(innerTiles))
inputShape[innerDimsPos[idx]] *= size;
auto maskedRead = createReadOrMaskedRead(rewriter, loc, packOp.getSource(),
- inputShape, padValue);
+ inputShape, padValue, doMasking);
// Create ShapeCastOp.
SmallVector<int64_t> destShape(inputVectorSizes);
@@ -1763,7 +1784,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op,
/// Returns success if `inputVectorSizes` is a valid masking configuraion for
/// given `shape`, i.e., it meets:
/// 1. The numbers of elements in both array are equal.
-/// 2. `inputVectorSizes` does nos have dynamic dimensions.
+/// 2. `inputVectorSizes` does not have dynamic dimensions.
/// 3. All the values in `inputVectorSizes` are greater than or equal to
/// static sizes in `shape`.
static LogicalResult
@@ -1881,18 +1902,25 @@ static LogicalResult vectorizeLinalgOpPrecondition(
return success();
}
-/// TODO: Use a matcher to check for a constant padding value.
static LogicalResult
vectorizePackOpPrecondition(tensor::PackOp packOp,
ArrayRef<int64_t> inputVectorSizes) {
auto padValue = packOp.getPaddingValue();
- if (padValue && !padValue.getDefiningOp<arith::ConstantOp>()) {
+ Attribute cstAttr;
+ if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) {
LDBG("pad value is not constant: " << packOp << "\n");
return failure();
}
-
ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape();
- if (failed(isValidMaskedInputVector(
+ bool satisfyEmptyCond = true;
+ if (inputVectorSizes.empty()) {
+ if (!packOp.getDestType().hasStaticShape() ||
+ !packOp.getSourceType().hasStaticShape())
+ satisfyEmptyCond = false;
+ }
+
+ if (!satisfyEmptyCond &&
+ failed(isValidMaskedInputVector(
resultTensorShape.take_front(packOp.getSourceRank()),
inputVectorSizes)))
return failure();
diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
index 10ba508..1f06318 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp
@@ -23,12 +23,11 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
ValueRange independencies) {
if (ofr.is<Attribute>())
return ofr;
- Value value = ofr.get<Value>();
AffineMap boundMap;
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
- boundMap, mapOperands, presburger::BoundType::UB, value,
- /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+ boundMap, mapOperands, presburger::BoundType::UB, ofr, independencies,
+ /*closedUB=*/true)))
return failure();
return affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 35fb174..5d2281c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1548,90 +1548,41 @@ void printWsloop(OpAsmPrinter &p, Operation *op, Region &region,
p.printRegion(region, /*printEntryBlockArgs=*/false);
}
-/// loop-control ::= `(` ssa-id-list `)` `:` type `=` loop-bounds
-/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` inclusive? steps
-/// steps := `step` `(`ssa-id-list`)`
-ParseResult
-parseLoopControl(OpAsmParser &parser, Region &region,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &lowerBound,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &upperBound,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &steps,
- SmallVectorImpl<Type> &loopVarTypes, UnitAttr &inclusive) {
- // Parse an opening `(` followed by induction variables followed by `)`
- SmallVector<OpAsmParser::Argument> ivs;
- Type loopVarType;
- if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
- parser.parseColonType(loopVarType) ||
- // Parse loop bounds.
- parser.parseEqual() ||
- parser.parseOperandList(lowerBound, ivs.size(),
- OpAsmParser::Delimiter::Paren) ||
- parser.parseKeyword("to") ||
- parser.parseOperandList(upperBound, ivs.size(),
- OpAsmParser::Delimiter::Paren))
- return failure();
-
- if (succeeded(parser.parseOptionalKeyword("inclusive")))
- inclusive = UnitAttr::get(parser.getBuilder().getContext());
-
- // Parse step values.
- if (parser.parseKeyword("step") ||
- parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
- return failure();
-
- // Now parse the body.
- loopVarTypes = SmallVector<Type>(ivs.size(), loopVarType);
- for (auto &iv : ivs)
- iv.type = loopVarType;
-
- return parser.parseRegion(region, ivs);
-}
-
-void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
- ValueRange lowerBound, ValueRange upperBound,
- ValueRange steps, TypeRange loopVarTypes,
- UnitAttr inclusive) {
- auto args = region.front().getArguments();
- p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
- << ") to (" << upperBound << ") ";
- if (inclusive)
- p << "inclusive ";
- p << "step (" << steps << ") ";
- p.printRegion(region, /*printEntryBlockArgs=*/false);
-}
-
//===----------------------------------------------------------------------===//
// Simd construct [2.9.3.1]
//===----------------------------------------------------------------------===//
-void SimdLoopOp::build(OpBuilder &builder, OperationState &state,
- const SimdLoopClauseOps &clauses) {
+void SimdOp::build(OpBuilder &builder, OperationState &state,
+ const SimdClauseOps &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, reductionByRefAttr, reductionVars,
// privatizers, reductionDeclSymbols.
- SimdLoopOp::build(
- builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar,
- clauses.alignedVars, makeArrayAttr(ctx, clauses.alignmentAttrs),
- clauses.ifVar, clauses.nontemporalVars, clauses.orderAttr,
- clauses.simdlenAttr, clauses.safelenAttr, clauses.loopInclusiveAttr);
+ SimdOp::build(builder, state, clauses.alignedVars,
+ makeArrayAttr(ctx, clauses.alignmentAttrs), clauses.ifVar,
+ clauses.nontemporalVars, clauses.orderAttr, clauses.simdlenAttr,
+ clauses.safelenAttr);
}
-LogicalResult SimdLoopOp::verify() {
- if (this->getLowerBound().empty()) {
- return emitOpError() << "empty lowerbound for simd loop operation";
- }
- if (this->getSimdlen().has_value() && this->getSafelen().has_value() &&
- this->getSimdlen().value() > this->getSafelen().value()) {
+LogicalResult SimdOp::verify() {
+ if (getSimdlen().has_value() && getSafelen().has_value() &&
+ getSimdlen().value() > getSafelen().value())
return emitOpError()
<< "simdlen clause and safelen clause are both present, but the "
"simdlen value is not less than or equal to safelen value";
- }
- if (verifyAlignedClause(*this, this->getAlignmentValues(),
- this->getAlignedVars())
+
+ if (verifyAlignedClause(*this, getAlignmentValues(), getAlignedVars())
.failed())
return failure();
- if (verifyNontemporalClause(*this, this->getNontemporalVars()).failed())
+
+ if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
return failure();
+
+ if (!isWrapper())
+ return emitOpError() << "must be a loop wrapper";
+
+ if (getNestedWrapper())
+ return emitOpError() << "must wrap an 'omp.loop_nest' directly";
+
return success();
}
@@ -1656,6 +1607,17 @@ LogicalResult DistributeOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
+ if (!isWrapper())
+ return emitOpError() << "must be a loop wrapper";
+
+ if (LoopWrapperInterface nested = getNestedWrapper()) {
+ // Check for the allowed leaf constructs that may appear in a composite
+ // construct directly after DISTRIBUTE.
+ if (!isa<ParallelOp, SimdOp>(nested))
+ return emitError() << "only supported nested wrappers are 'omp.parallel' "
+ "and 'omp.simd'";
+ }
+
return success();
}
@@ -1818,9 +1780,8 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
TaskloopOp::build(
- builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar,
- clauses.loopInclusiveAttr, clauses.ifVar, clauses.finalVar,
- clauses.untiedAttr, clauses.mergeableAttr, clauses.inReductionVars,
+ builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
+ clauses.mergeableAttr, clauses.inReductionVars,
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
@@ -1859,6 +1820,16 @@ LogicalResult TaskloopOp::verify() {
"the grainsize clause and num_tasks clause are mutually exclusive and "
"may not appear on the same taskloop directive");
}
+
+ if (!isWrapper())
+ return emitOpError() << "must be a loop wrapper";
+
+ if (LoopWrapperInterface nested = getNestedWrapper()) {
+ // Check for the allowed leaf constructs that may appear in a composite
+ // construct directly after TASKLOOP.
+ if (!isa<SimdOp>(nested))
+ return emitError() << "only supported nested wrapper is 'omp.simd'";
+ }
return success();
}
@@ -1936,9 +1907,27 @@ LogicalResult LoopNestOp::verify() {
<< "range argument type does not match corresponding IV type";
}
+ auto wrapper =
+ llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
+
+ if (!wrapper || !wrapper.isWrapper())
+ return emitOpError() << "expects parent op to be a valid loop wrapper";
+
return success();
}
+void LoopNestOp::gatherWrappers(
+ SmallVectorImpl<LoopWrapperInterface> &wrappers) {
+ Operation *parent = (*this)->getParentOp();
+ while (auto wrapper =
+ llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
+ if (!wrapper.isWrapper())
+ break;
+ wrappers.push_back(wrapper);
+ parent = parent->getParentOp();
+ }
+}
+
//===----------------------------------------------------------------------===//
// WsloopOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Polynomial/CMakeLists.txt b/mlir/lib/Dialect/Polynomial/CMakeLists.txt
new file mode 100644
index 0000000..f33061b2
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
new file mode 100644
index 0000000..7f5b325
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_dialect_library(MLIRPolynomialDialect
+ Polynomial.cpp
+ PolynomialAttributes.cpp
+ PolynomialDialect.cpp
+ PolynomialOps.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Polynomial
+
+ DEPENDS
+ MLIRPolynomialIncGen
+ MLIRPolynomialAttributesIncGen
+ MLIRBuiltinAttributesIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRSupport
+ MLIRDialect
+ MLIRIR
+ )
diff --git a/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
new file mode 100644
index 0000000..5916ffb
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/Polynomial.cpp
@@ -0,0 +1,96 @@
+//===- Polynomial.cpp - MLIR storage type for static Polynomial -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Support/raw_ostream.h"
+
+namespace mlir {
+namespace polynomial {
+
+FailureOr<Polynomial> Polynomial::fromMonomials(ArrayRef<Monomial> monomials) {
+ // A polynomial's terms are canonically stored in order of increasing degree.
+ auto monomialsCopy = llvm::SmallVector<Monomial>(monomials);
+ std::sort(monomialsCopy.begin(), monomialsCopy.end());
+
+ // Ensure non-unique exponents are not present. Since we sorted the list by
+ // exponent, a linear scan of adjancent monomials suffices.
+ if (std::adjacent_find(monomialsCopy.begin(), monomialsCopy.end(),
+ [](const Monomial &lhs, const Monomial &rhs) {
+ return lhs.exponent == rhs.exponent;
+ }) != monomialsCopy.end()) {
+ return failure();
+ }
+
+ return Polynomial(monomialsCopy);
+}
+
+Polynomial Polynomial::fromCoefficients(ArrayRef<int64_t> coeffs) {
+ llvm::SmallVector<Monomial> monomials;
+ auto size = coeffs.size();
+ monomials.reserve(size);
+ for (size_t i = 0; i < size; i++) {
+ monomials.emplace_back(coeffs[i], i);
+ }
+ auto result = Polynomial::fromMonomials(monomials);
+ // Construction guarantees unique exponents, so the failure mode of
+ // fromMonomials can be bypassed.
+ assert(succeeded(result));
+ return result.value();
+}
+
+void Polynomial::print(raw_ostream &os, ::llvm::StringRef separator,
+ ::llvm::StringRef exponentiation) const {
+ bool first = true;
+ for (const Monomial &term : terms) {
+ if (first) {
+ first = false;
+ } else {
+ os << separator;
+ }
+ std::string coeffToPrint;
+ if (term.coefficient == 1 && term.exponent.uge(1)) {
+ coeffToPrint = "";
+ } else {
+ llvm::SmallString<16> coeffString;
+ term.coefficient.toStringSigned(coeffString);
+ coeffToPrint = coeffString.str();
+ }
+
+ if (term.exponent == 0) {
+ os << coeffToPrint;
+ } else if (term.exponent == 1) {
+ os << coeffToPrint << "x";
+ } else {
+ llvm::SmallString<16> expString;
+ term.exponent.toStringSigned(expString);
+ os << coeffToPrint << "x" << exponentiation << expString;
+ }
+ }
+}
+
+void Polynomial::print(raw_ostream &os) const { print(os, " + ", "**"); }
+
+std::string Polynomial::toIdentifier() const {
+ std::string result;
+ llvm::raw_string_ostream os(result);
+ print(os, "_", "");
+ return os.str();
+}
+
+unsigned Polynomial::getDegree() const {
+ return terms.back().exponent.getZExtValue();
+}
+
+} // namespace polynomial
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
new file mode 100644
index 0000000..ee09c73
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -0,0 +1,213 @@
+//===- PolynomialAttributes.cpp - Polynomial dialect attrs ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/StringExtras.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+
+namespace mlir {
+namespace polynomial {
+
+void PolynomialAttr::print(AsmPrinter &p) const {
+ p << '<';
+ p << getPolynomial();
+ p << '>';
+}
+
+/// Try to parse a monomial. If successful, populate the fields of the outparam
+/// `monomial` with the results, and the `variable` outparam with the parsed
+/// variable name. Sets shouldParseMore to true if the monomial is followed by
+/// a '+'.
+ParseResult parseMonomial(AsmParser &parser, Monomial &monomial,
+ llvm::StringRef &variable, bool &isConstantTerm,
+ bool &shouldParseMore) {
+ APInt parsedCoeff(apintBitWidth, 1);
+ auto parsedCoeffResult = parser.parseOptionalInteger(parsedCoeff);
+ monomial.coefficient = parsedCoeff;
+
+ isConstantTerm = false;
+ shouldParseMore = false;
+
+ // A + indicates it's a constant term with more to go, as in `1 + x`.
+ if (succeeded(parser.parseOptionalPlus())) {
+ // If no coefficient was parsed, and there's a +, then it's effectively
+ // parsing an empty string.
+ if (!parsedCoeffResult.has_value()) {
+ return failure();
+ }
+ monomial.exponent = APInt(apintBitWidth, 0);
+ isConstantTerm = true;
+ shouldParseMore = true;
+ return success();
+ }
+
+ // A monomial can be a trailing constant term, as in `x + 1`.
+ if (failed(parser.parseOptionalKeyword(&variable))) {
+ // If neither a coefficient nor a variable was found, then it's effectively
+ // parsing an empty string.
+ if (!parsedCoeffResult.has_value()) {
+ return failure();
+ }
+
+ monomial.exponent = APInt(apintBitWidth, 0);
+ isConstantTerm = true;
+ return success();
+ }
+
+ // Parse exponentiation symbol as `**`. We can't use caret because it's
+ // reserved for basic block identifiers If no star is present, it's treated
+ // as a polynomial with exponent 1.
+ if (succeeded(parser.parseOptionalStar())) {
+ // If there's one * there must be two.
+ if (failed(parser.parseStar())) {
+ return failure();
+ }
+
+ // If there's a **, then the integer exponent is required.
+ APInt parsedExponent(apintBitWidth, 0);
+ if (failed(parser.parseInteger(parsedExponent))) {
+ parser.emitError(parser.getCurrentLocation(),
+ "found invalid integer exponent");
+ return failure();
+ }
+
+ monomial.exponent = parsedExponent;
+ } else {
+ monomial.exponent = APInt(apintBitWidth, 1);
+ }
+
+ if (succeeded(parser.parseOptionalPlus())) {
+ shouldParseMore = true;
+ }
+ return success();
+}
+
+Attribute PolynomialAttr::parse(AsmParser &parser, Type type) {
+ if (failed(parser.parseLess()))
+ return {};
+
+ llvm::SmallVector<Monomial> monomials;
+ llvm::StringSet<> variables;
+
+ while (true) {
+ Monomial parsedMonomial;
+ llvm::StringRef parsedVariableRef;
+ bool isConstantTerm;
+ bool shouldParseMore;
+ if (failed(parseMonomial(parser, parsedMonomial, parsedVariableRef,
+ isConstantTerm, shouldParseMore))) {
+ parser.emitError(parser.getCurrentLocation(), "expected a monomial");
+ return {};
+ }
+
+ if (!isConstantTerm) {
+ std::string parsedVariable = parsedVariableRef.str();
+ variables.insert(parsedVariable);
+ }
+ monomials.push_back(parsedMonomial);
+
+ if (shouldParseMore)
+ continue;
+
+ if (succeeded(parser.parseOptionalGreater())) {
+ break;
+ }
+ parser.emitError(
+ parser.getCurrentLocation(),
+ "expected + and more monomials, or > to end polynomial attribute");
+ return {};
+ }
+
+ if (variables.size() > 1) {
+ std::string vars = llvm::join(variables.keys(), ", ");
+ parser.emitError(
+ parser.getCurrentLocation(),
+ "polynomials must have one indeterminate, but there were multiple: " +
+ vars);
+ }
+
+ auto result = Polynomial::fromMonomials(monomials);
+ if (failed(result)) {
+ parser.emitError(parser.getCurrentLocation())
+ << "parsed polynomial must have unique exponents among monomials";
+ return {};
+ }
+ return PolynomialAttr::get(parser.getContext(), result.value());
+}
+
+void RingAttr::print(AsmPrinter &p) const {
+ p << "#polynomial.ring<coefficientType=" << getCoefficientType()
+ << ", coefficientModulus=" << getCoefficientModulus()
+ << ", polynomialModulus=" << getPolynomialModulus() << '>';
+}
+
+Attribute RingAttr::parse(AsmParser &parser, Type type) {
+ if (failed(parser.parseLess()))
+ return {};
+
+ if (failed(parser.parseKeyword("coefficientType")))
+ return {};
+
+ if (failed(parser.parseEqual()))
+ return {};
+
+ Type ty;
+ if (failed(parser.parseType(ty)))
+ return {};
+
+ if (failed(parser.parseComma()))
+ return {};
+
+ IntegerAttr coefficientModulusAttr = nullptr;
+ if (succeeded(parser.parseKeyword("coefficientModulus"))) {
+ if (failed(parser.parseEqual()))
+ return {};
+
+ IntegerType iType = ty.dyn_cast<IntegerType>();
+ if (!iType) {
+ parser.emitError(parser.getCurrentLocation(),
+ "coefficientType must specify an integer type");
+ return {};
+ }
+ APInt coefficientModulus(iType.getWidth(), 0);
+ auto result = parser.parseInteger(coefficientModulus);
+ if (failed(result)) {
+ parser.emitError(parser.getCurrentLocation(),
+ "invalid coefficient modulus");
+ return {};
+ }
+ coefficientModulusAttr = IntegerAttr::get(iType, coefficientModulus);
+
+ if (failed(parser.parseComma()))
+ return {};
+ }
+
+ PolynomialAttr polyAttr = nullptr;
+ if (succeeded(parser.parseKeyword("polynomialModulus"))) {
+ if (failed(parser.parseEqual()))
+ return {};
+
+ PolynomialAttr attr;
+ if (failed(parser.parseAttribute<PolynomialAttr>(attr)))
+ return {};
+ polyAttr = attr;
+ }
+
+ if (failed(parser.parseGreater()))
+ return {};
+
+ return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
+ polyAttr);
+}
+
+} // namespace polynomial
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
new file mode 100644
index 0000000..a672a59
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialDialect.cpp
@@ -0,0 +1,41 @@
+//===- PolynomialDialect.cpp - Polynomial dialect ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialOps.h"
+#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+using namespace mlir::polynomial;
+
+#include "mlir/Dialect/Polynomial/IR/PolynomialDialect.cpp.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.cpp.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.cpp.inc"
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc"
+
+void PolynomialDialect::initialize() {
+ addAttributes<
+#define GET_ATTRDEF_LIST
+#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.cpp.inc"
+ >();
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/Polynomial/IR/PolynomialTypes.cpp.inc"
+ >();
+ addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc"
+ >();
+}
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
new file mode 100644
index 0000000..96c59a2
--- /dev/null
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -0,0 +1,15 @@
+//===- PolynomialOps.cpp - Polynomial dialect ops ---------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Polynomial/IR/Polynomial.h"
+
+using namespace mlir;
+using namespace mlir::polynomial;
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Polynomial/IR/Polynomial.cpp.inc"
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 087ffc4..17a1c01 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -61,12 +61,13 @@ struct ForOpInterface
// An EQ constraint can be added if the yielded value (dimension size)
// equals the corresponding block argument (dimension size).
if (cstr.populateAndCompare(
- yieldedValue, dim, ValueBoundsConstraintSet::ComparisonOperator::EQ,
- iterArg, dim)) {
+ /*lhs=*/{yieldedValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::EQ,
+ /*rhs=*/{iterArg, dim})) {
if (dim.has_value()) {
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
} else {
- cstr.bound(value) == initArg;
+ cstr.bound(value) == cstr.getExpr(initArg);
}
}
}
@@ -113,8 +114,9 @@ struct IfOpInterface
// * result <= elseValue
// * result >= thenValue
if (cstr.populateAndCompare(
- thenValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
- elseValue, dim)) {
+ /*lhs=*/{thenValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ /*rhs=*/{elseValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
@@ -127,8 +129,9 @@ struct IfOpInterface
// * result <= thenValue
// * result >= elseValue
if (cstr.populateAndCompare(
- elseValue, dim, ValueBoundsConstraintSet::ComparisonOperator::LE,
- thenValue, dim)) {
+ /*lhs=*/{elseValue, dim},
+ ValueBoundsConstraintSet::ComparisonOperator::LE,
+ /*rhs=*/{thenValue, dim})) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
diff --git a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
index e549420..a2925ae 100644
--- a/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/Transforms/CMakeLists.txt
@@ -14,6 +14,7 @@ add_mlir_dialect_library(MLIRSCFTransforms
StructuralTypeConversions.cpp
TileUsingInterface.cpp
WrapInZeroTripCheck.cpp
+ UpliftWhileToFor.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SCF
diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
new file mode 100644
index 0000000..7b4024b
--- /dev/null
+++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp
@@ -0,0 +1,214 @@
+//===- UpliftWhileToFor.cpp - scf.while to scf.for loop uplifting ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Transforms SCF.WhileOp's into SCF.ForOp's.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+
+namespace {
+struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::WhileOp loop,
+ PatternRewriter &rewriter) const override {
+ return upliftWhileToForLoop(rewriter, loop);
+ }
+};
+} // namespace
+
+FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter,
+ scf::WhileOp loop) {
+ Block *beforeBody = loop.getBeforeBody();
+ if (!llvm::hasSingleElement(beforeBody->without_terminator()))
+ return rewriter.notifyMatchFailure(loop, "Loop body must have single op");
+
+ auto cmp = dyn_cast<arith::CmpIOp>(beforeBody->front());
+ if (!cmp)
+ return rewriter.notifyMatchFailure(loop,
+ "Loop body must have single cmp op");
+
+ scf::ConditionOp beforeTerm = loop.getConditionOp();
+ if (!cmp->hasOneUse() || beforeTerm.getCondition() != cmp.getResult())
+ return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+ diag << "Expected single condition use: " << *cmp;
+ });
+
+ // All `before` block args must be directly forwarded to ConditionOp.
+ // They will be converted to `scf.for` `iter_vars` except induction var.
+ if (ValueRange(beforeBody->getArguments()) != beforeTerm.getArgs())
+ return rewriter.notifyMatchFailure(loop, "Invalid args order");
+
+ using Pred = arith::CmpIPredicate;
+ Pred predicate = cmp.getPredicate();
+ if (predicate != Pred::slt && predicate != Pred::sgt)
+ return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+ diag << "Expected 'slt' or 'sgt' predicate: " << *cmp;
+ });
+
+ BlockArgument inductionVar;
+ Value ub;
+ DominanceInfo dom;
+
+ // Check if cmp has a suitable form. One of the arguments must be a `before`
+ // block arg, other must be defined outside `scf.while` and will be treated
+ // as upper bound.
+ for (bool reverse : {false, true}) {
+ auto expectedPred = reverse ? Pred::sgt : Pred::slt;
+ if (cmp.getPredicate() != expectedPred)
+ continue;
+
+ auto arg1 = reverse ? cmp.getRhs() : cmp.getLhs();
+ auto arg2 = reverse ? cmp.getLhs() : cmp.getRhs();
+
+ auto blockArg = dyn_cast<BlockArgument>(arg1);
+ if (!blockArg || blockArg.getOwner() != beforeBody)
+ continue;
+
+ if (!dom.properlyDominates(arg2, loop))
+ continue;
+
+ inductionVar = blockArg;
+ ub = arg2;
+ break;
+ }
+
+ if (!inductionVar)
+ return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+ diag << "Unrecognized cmp form: " << *cmp;
+ });
+
+ // inductionVar must have 2 uses: one is in `cmp` and other is `condition`
+ // arg.
+ if (!llvm::hasNItems(inductionVar.getUses(), 2))
+ return rewriter.notifyMatchFailure(loop, [&](Diagnostic &diag) {
+ diag << "Unrecognized induction var: " << inductionVar;
+ });
+
+ Block *afterBody = loop.getAfterBody();
+ scf::YieldOp afterTerm = loop.getYieldOp();
+ unsigned argNumber = inductionVar.getArgNumber();
+ Value afterTermIndArg = afterTerm.getResults()[argNumber];
+
+ Value inductionVarAfter = afterBody->getArgument(argNumber);
+
+ // Find suitable `addi` op inside `after` block, one of the args must be an
+ // Induction var passed from `before` block and second arg must be defined
+ // outside of the loop and will be considered step value.
+ // TODO: Add `subi` support?
+ auto addOp = afterTermIndArg.getDefiningOp<arith::AddIOp>();
+ if (!addOp)
+ return rewriter.notifyMatchFailure(loop, "Didn't found suitable 'addi' op");
+
+ Value step;
+ if (addOp.getLhs() == inductionVarAfter) {
+ step = addOp.getRhs();
+ } else if (addOp.getRhs() == inductionVarAfter) {
+ step = addOp.getLhs();
+ }
+
+ if (!step || !dom.properlyDominates(step, loop))
+ return rewriter.notifyMatchFailure(loop, "Invalid 'addi' form");
+
+ Value lb = loop.getInits()[argNumber];
+
+ assert(lb.getType().isIntOrIndex());
+ assert(lb.getType() == ub.getType());
+ assert(lb.getType() == step.getType());
+
+ llvm::SmallVector<Value> newArgs;
+
+ // Populate inits for new `scf.for`, skip induction var.
+ newArgs.reserve(loop.getInits().size());
+ for (auto &&[i, init] : llvm::enumerate(loop.getInits())) {
+ if (i == argNumber)
+ continue;
+
+ newArgs.emplace_back(init);
+ }
+
+ Location loc = loop.getLoc();
+
+ // With `builder == nullptr`, ForOp::build will try to insert terminator at
+ // the end of newly created block and we don't want it. Provide empty
+ // dummy builder instead.
+ auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {};
+ auto newLoop =
+ rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder);
+
+ Block *newBody = newLoop.getBody();
+
+ // Populate block args for `scf.for` body, move induction var to the front.
+ newArgs.clear();
+ ValueRange newBodyArgs = newBody->getArguments();
+ for (auto i : llvm::seq<size_t>(0, newBodyArgs.size())) {
+ if (i < argNumber) {
+ newArgs.emplace_back(newBodyArgs[i + 1]);
+ } else if (i == argNumber) {
+ newArgs.emplace_back(newBodyArgs.front());
+ } else {
+ newArgs.emplace_back(newBodyArgs[i]);
+ }
+ }
+
+ rewriter.inlineBlockBefore(loop.getAfterBody(), newBody, newBody->end(),
+ newArgs);
+
+ auto term = cast<scf::YieldOp>(newBody->getTerminator());
+
+ // Populate new yield args, skipping the induction var.
+ newArgs.clear();
+ for (auto &&[i, arg] : llvm::enumerate(term.getResults())) {
+ if (i == argNumber)
+ continue;
+
+ newArgs.emplace_back(arg);
+ }
+
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(term);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(term, newArgs);
+
+ // Compute induction var value after loop execution.
+ rewriter.setInsertionPointAfter(newLoop);
+ Value one;
+ if (isa<IndexType>(step.getType())) {
+ one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ } else {
+ one = rewriter.create<arith::ConstantIntOp>(loc, 1, step.getType());
+ }
+
+ Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one);
+ Value len = rewriter.create<arith::SubIOp>(loc, ub, lb);
+ len = rewriter.create<arith::AddIOp>(loc, len, stepDec);
+ len = rewriter.create<arith::DivSIOp>(loc, len, step);
+ len = rewriter.create<arith::SubIOp>(loc, len, one);
+ Value res = rewriter.create<arith::MulIOp>(loc, len, step);
+ res = rewriter.create<arith::AddIOp>(loc, lb, res);
+
+ // Reconstruct `scf.while` results, inserting final induction var value
+ // into proper place.
+ newArgs.clear();
+ llvm::append_range(newArgs, newLoop.getResults());
+ newArgs.insert(newArgs.begin() + argNumber, res);
+ rewriter.replaceOp(loop, newArgs);
+ return newLoop;
+}
+
+void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) {
+ patterns.add<UpliftWhileOp>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
index 5b7c0a5..bbc318e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/TargetAndABI.cpp
@@ -120,17 +120,16 @@ bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) {
StringRef spirv::getEntryPointABIAttrName() { return "spirv.entry_point_abi"; }
-spirv::EntryPointABIAttr
-spirv::getEntryPointABIAttr(MLIRContext *context,
- ArrayRef<int32_t> workgroupSize,
- std::optional<int> subgroupSize) {
+spirv::EntryPointABIAttr spirv::getEntryPointABIAttr(
+ MLIRContext *context, ArrayRef<int32_t> workgroupSize,
+ std::optional<int> subgroupSize, std::optional<int> targetWidth) {
DenseI32ArrayAttr workgroupSizeAttr;
if (!workgroupSize.empty()) {
assert(workgroupSize.size() == 3);
workgroupSizeAttr = DenseI32ArrayAttr::get(context, workgroupSize);
}
- return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr,
- subgroupSize);
+ return spirv::EntryPointABIAttr::get(context, workgroupSizeAttr, subgroupSize,
+ targetWidth);
}
spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) {
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 6150b5e..2024a2e 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -157,7 +157,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
// Erase workgroup size.
entryPointAttr = spirv::EntryPointABIAttr::get(
entryPointAttr.getContext(), DenseI32ArrayAttr(),
- entryPointAttr.getSubgroupSize());
+ entryPointAttr.getSubgroupSize(), entryPointAttr.getTargetWidth());
}
}
if (std::optional<int> subgroupSize = entryPointAttr.getSubgroupSize()) {
@@ -170,10 +170,24 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
// Erase subgroup size.
entryPointAttr = spirv::EntryPointABIAttr::get(
entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
- std::nullopt);
+ std::nullopt, entryPointAttr.getTargetWidth());
}
}
- if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize())
+ if (std::optional<int> targetWidth = entryPointAttr.getTargetWidth()) {
+ std::optional<ArrayRef<spirv::Capability>> caps =
+ spirv::getCapabilities(spirv::ExecutionMode::SignedZeroInfNanPreserve);
+ if (!caps || targetEnv.allows(*caps)) {
+ builder.create<spirv::ExecutionModeOp>(
+ funcOp.getLoc(), funcOp,
+ spirv::ExecutionMode::SignedZeroInfNanPreserve, *targetWidth);
+ // Erase target width.
+ entryPointAttr = spirv::EntryPointABIAttr::get(
+ entryPointAttr.getContext(), entryPointAttr.getWorkgroupSize(),
+ entryPointAttr.getSubgroupSize(), std::nullopt);
+ }
+ }
+ if (entryPointAttr.getWorkgroupSize() || entryPointAttr.getSubgroupSize() ||
+ entryPointAttr.getTargetWidth())
funcOp->setAttr(entryPointAttrName, entryPointAttr);
else
funcOp->removeAttr(entryPointAttrName);
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index e905839..516b094 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -30,6 +30,14 @@
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.cpp.inc"
#include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrEnums.cpp.inc"
+// Forward declarations, following custom print/parsing methods are referenced
+// by the generated code for SparseTensorTypes.td.
+static mlir::ParseResult parseLevelRange(mlir::AsmParser &,
+ mlir::sparse_tensor::Level &,
+ mlir::sparse_tensor::Level &);
+static void printLevelRange(mlir::AsmPrinter &, mlir::sparse_tensor::Level,
+ mlir::sparse_tensor::Level);
+
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/SparseTensor/IR/SparseTensorTypes.cpp.inc"
@@ -1953,6 +1961,108 @@ LogicalResult SortOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Sparse Tensor Iteration Operations.
+//===----------------------------------------------------------------------===//
+
+IterSpaceType IteratorType::getIterSpaceType() const {
+ return IterSpaceType::get(getContext(), getEncoding(), getLoLvl(),
+ getHiLvl());
+}
+
+IteratorType IterSpaceType::getIteratorType() const {
+ return IteratorType::get(getContext(), getEncoding(), getLoLvl(), getHiLvl());
+}
+
+/// Parses a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
+static ParseResult parseLevelRange(AsmParser &parser, Level &lvlLo,
+ Level &lvlHi) {
+ if (parser.parseInteger(lvlLo))
+ return failure();
+
+ if (succeeded(parser.parseOptionalKeyword("to"))) {
+ if (parser.parseInteger(lvlHi))
+ return failure();
+ } else {
+ lvlHi = lvlLo + 1;
+ }
+
+ if (lvlHi <= lvlLo)
+ parser.emitError(parser.getNameLoc(),
+ "expect larger level upper bound than lower bound");
+
+ return success();
+}
+
+/// Parses a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
+static ParseResult parseLevelRange(OpAsmParser &parser, IntegerAttr &lvlLoAttr,
+ IntegerAttr &lvlHiAttr) {
+ Level lvlLo, lvlHi;
+ if (parseLevelRange(parser, lvlLo, lvlHi))
+ return failure();
+
+ lvlLoAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlLo);
+ lvlHiAttr = IntegerAttr::get(parser.getBuilder().getIndexType(), lvlHi);
+ return success();
+}
+
+/// Prints a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
+static void printLevelRange(AsmPrinter &p, Level lo, Level hi) {
+
+ if (lo + 1 == hi)
+ p << lo;
+ else
+ p << lo << " to " << hi;
+}
+
+/// Prints a level range in the form "$lo `to` $hi"
+/// or simply "$lo" if $hi - $lo = 1
+static void printLevelRange(OpAsmPrinter &p, Operation *, IntegerAttr lvlLo,
+ IntegerAttr lvlHi) {
+ unsigned lo = lvlLo.getValue().getZExtValue();
+ unsigned hi = lvlHi.getValue().getZExtValue();
+ printLevelRange(p, lo, hi);
+}
+
+LogicalResult ExtractIterSpaceOp::inferReturnTypes(
+ MLIRContext *ctx, std::optional<Location> loc, ValueRange ops,
+ DictionaryAttr attr, OpaqueProperties prop, RegionRange region,
+ SmallVectorImpl<mlir::Type> &ret) {
+
+ ExtractIterSpaceOp::Adaptor adaptor(ops, attr, prop, region);
+ SparseTensorType stt = getSparseTensorType(adaptor.getTensor());
+ ret.push_back(IterSpaceType::get(ctx, stt.getEncoding(), adaptor.getLoLvl(),
+ adaptor.getHiLvl()));
+ return success();
+}
+
+LogicalResult ExtractIterSpaceOp::verify() {
+ if (getLoLvl() >= getHiLvl())
+ return emitOpError("expected smaller level low than level high");
+
+ TypedValue<IteratorType> pIter = getParentIter();
+ if ((pIter && getLoLvl() == 0) || (!pIter && getLoLvl() != 0)) {
+ return emitOpError(
+ "parent iterator should be specified iff level lower bound equals 0");
+ }
+
+ if (pIter) {
+ IterSpaceType spaceTp = getResultSpace().getType();
+ if (pIter.getType().getEncoding() != spaceTp.getEncoding())
+ return emitOpError(
+ "mismatch in parent iterator encoding and iteration space encoding.");
+
+ if (spaceTp.getLoLvl() != pIter.getType().getHiLvl())
+ return emitOpError("parent iterator should be used to extract an "
+ "iteration space from a consecutive level.");
+ }
+
+ return success();
+}
+
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *SparseTensorDialect::materializeConstant(OpBuilder &builder,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index f497be6..3a89720 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -24,6 +24,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index 67080d8..d25efcf 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -289,8 +289,7 @@ static UnpackTileDimInfo getUnpackTileDimInfo(OpBuilder &b, UnPackOp unpackOp,
info.isAlignedToInnerTileSize = false;
FailureOr<int64_t> cstSize = ValueBoundsConstraintSet::computeConstantBound(
- presburger::BoundType::UB,
- getValueOrCreateConstantIndexOp(b, loc, tileSize), /*dim=*/std::nullopt,
+ presburger::BoundType::UB, tileSize,
/*stopCondition=*/nullptr, /*closedUB=*/true);
std::optional<int64_t> cstInnerSize = getConstantIntValue(innerTileSize);
if (!failed(cstSize) && cstInnerSize) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
index 72173086..a89ce20 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp
@@ -28,7 +28,8 @@ static FailureOr<OpFoldResult> makeIndependent(OpBuilder &b, Location loc,
ValueDimList mapOperands;
if (failed(ValueBoundsConstraintSet::computeIndependentBound(
boundMap, mapOperands, presburger::BoundType::UB, value,
- /*dim=*/std::nullopt, independencies, /*closedUB=*/true)))
+ independencies,
+ /*closedUB=*/true)))
return failure();
return mlir::affine::materializeComputedBound(b, loc, boundMap, mapOperands);
}
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 2dd91e2..15381ec 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -154,7 +154,7 @@ bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
- op.getSource(), op.getResult(), srcDim, resultDim);
+ {op.getSource(), srcDim}, {op.getResult(), resultDim});
if (failed(equalDimSize) || !*equalDimSize)
return false;
++srcDim;
@@ -178,7 +178,7 @@ bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
continue;
}
FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
- op.getSource(), op.getResult(), dim, resultDim);
+ {op.getSource(), dim}, {op.getResult(), resultDim});
if (failed(equalDimSize) || !*equalDimSize)
return false;
++resultDim;
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index dc19022..53f958c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -396,6 +396,13 @@ DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne(
static_cast<RewriterBase::Listener *>(rewriter.getListener());
FrozenRewritePatternSet frozenPatterns(std::move(patterns));
+ config.maxIterations = getMaxIterations() == static_cast<uint64_t>(-1)
+ ? GreedyRewriteConfig::kNoLimit
+ : getMaxIterations();
+ config.maxNumRewrites = getMaxNumRewrites() == static_cast<uint64_t>(-1)
+ ? GreedyRewriteConfig::kNoLimit
+ : getMaxNumRewrites();
+
// Apply patterns and CSE repetitively until a fixpoint is reached. If no CSE
// was requested, apply the greedy pattern rewrite only once. (The greedy
// pattern rewrite driver already iterates to a fixpoint internally.)
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 0b3f4b9..24719fe 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -32,6 +32,17 @@ void XeGPUDialect::initialize() {
//===----------------------------------------------------------------------===//
// XeGPU_TensorDescAttr
//===----------------------------------------------------------------------===//
+TensorDescAttr TensorDescAttr::get(mlir::MLIRContext *context,
+ xegpu::MemoryScope memory_scope,
+ int array_length, bool boundary_check,
+ bool scattered) {
+ auto scopeAttr = MemoryScopeAttr::get(context, memory_scope);
+ auto lengthAttr =
+ IntegerAttr::get(IntegerType::get(context, 64), array_length);
+ auto boundaryAttr = BoolAttr::get(context, boundary_check);
+ auto scatteredAttr = BoolAttr::get(context, scattered);
+ return Base::get(context, scopeAttr, lengthAttr, boundaryAttr, scatteredAttr);
+}
//===----------------------------------------------------------------------===//
// XeGPU_TensorDescType
@@ -96,6 +107,16 @@ void TensorDescType::print(::mlir::AsmPrinter &printer) const {
printer << ">";
}
+TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
+ mlir::Type elementType, bool scattered,
+ int array_length, MemoryScope memory_scope,
+ bool boundary_check) {
+ auto context = elementType.getContext();
+ auto attr = TensorDescAttr::get(context, memory_scope, array_length,
+ boundary_check, scattered);
+ return Base::get(context, shape, elementType, attr);
+}
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 02106f2..530c50e 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -9,6 +9,9 @@
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/TypeUtilities.h"
+
+#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "xegpu"
@@ -16,8 +19,8 @@ namespace mlir {
namespace xegpu {
static void transpose(llvm::ArrayRef<int64_t> trans,
- std::vector<int64_t> &shape) {
- std::vector<int64_t> old = shape;
+ SmallVector<int64_t> &shape) {
+ SmallVector<int64_t> old = shape;
for (size_t i = 0; i < trans.size(); i++)
shape[i] = old[trans[i]];
}
@@ -38,6 +41,38 @@ static std::string makeString(T array, bool breakline = false) {
return buf;
}
+static SmallVector<int64_t> getShapeOf(Type type) {
+ SmallVector<int64_t> shape;
+ if (auto ty = llvm::dyn_cast<ShapedType>(type))
+ shape = SmallVector<int64_t>(ty.getShape());
+ else
+ shape.push_back(1);
+ return shape;
+}
+
+static int64_t getRankOf(Value val) {
+ auto type = val.getType();
+ if (auto ty = llvm::dyn_cast<ShapedType>(type))
+ return ty.getRank();
+ return 0;
+}
+
+static bool isReadHintOrNone(const CachePolicyAttr &attr) {
+ if (!attr)
+ return true;
+ auto kind = attr.getValue();
+ return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
+ kind == CachePolicy::STREAMING || kind == CachePolicy::READ_INVALIDATE;
+}
+
+static bool isWriteHintOrNone(const CachePolicyAttr &attr) {
+ if (!attr)
+ return true;
+ auto kind = attr.getValue();
+ return kind == CachePolicy::CACHED || kind == CachePolicy::UNCACHED ||
+ kind == CachePolicy::WRITE_BACK || kind == CachePolicy::WRITE_THROUGH;
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -114,6 +149,29 @@ LogicalResult CreateNdDescOp::verify() {
return emitOpError("TensorDesc should have the same element "
"type with the source if it is a memref.\n");
+ if (getType().getScattered())
+ return emitOpError("Expects a non-scattered TensorDesc.\n");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_PrefetchNdOp
+//===----------------------------------------------------------------------===//
+LogicalResult PrefetchNdOp::verify() {
+ auto tdescTy = getTensorDescType();
+ if (tdescTy.getScattered())
+ return emitOpError("Expects a non-scattered TensorDesc.\n");
+
+ if (!isReadHintOrNone(getL1HintAttr()))
+ return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+
+ if (!isReadHintOrNone(getL2HintAttr()))
+ return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+
+ if (!isReadHintOrNone(getL3HintAttr()))
+ return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+
return success();
}
@@ -125,22 +183,26 @@ LogicalResult LoadNdOp::verify() {
auto valueTy = getType();
if (tdescTy.getRank() != 2)
- return emitOpError(
- "The TensorDesc for LoadNdOp should be a 2D TensorDesc.");
+ return emitOpError("Expecting a 2D TensorDesc.\n");
+
+ if (tdescTy.getScattered())
+ return emitOpError("Expects a non-scattered TensorDesc.\n");
if (!valueTy)
return emitOpError("Invalid result, it should be a VectorType.\n");
- auto tdescElemTy = tdescTy.getElementType();
- auto valueElemTy = valueTy.getElementType();
+ if (!isReadHintOrNone(getL1HintAttr()))
+ return emitOpError("invlid l1_hint: ") << getL1HintAttr();
- if (tdescElemTy != valueElemTy)
- return emitOpError(
- "Value should have the same element type as TensorDesc.");
+ if (!isReadHintOrNone(getL2HintAttr()))
+ return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+
+ if (!isReadHintOrNone(getL3HintAttr()))
+ return emitOpError("invlid l3_hint: ") << getL3HintAttr();
auto array_len = tdescTy.getArrayLength();
- auto tdescShape = tdescTy.getShape().vec();
- auto valueShape = valueTy.getShape().vec();
+ auto tdescShape = getShapeOf(tdescTy);
+ auto valueShape = getShapeOf(valueTy);
if (getTranspose()) {
auto trans = getTranspose().value();
@@ -174,26 +236,174 @@ LogicalResult LoadNdOp::verify() {
// XeGPU_StoreNdOp
//===----------------------------------------------------------------------===//
LogicalResult StoreNdOp::verify() {
- auto dstTy = getTensorDesc().getType(); // Tile
- auto valTy = getValue().getType().cast<VectorType>(); // Vector
+ auto dstTy = getTensorDescType(); // Tile
+ auto valTy = getValueType(); // Vector
if (dstTy.getRank() != 2)
- return emitOpError("Expecting a 2D TensorDesc shape.\n");
+ return emitOpError("Expecting a 2D TensorDesc.\n");
+
+ if (dstTy.getScattered())
+ return emitOpError("Expects a non-scattered TensorDesc.\n");
if (!valTy)
return emitOpError("Exepcting a VectorType result.\n");
- auto dstElemTy = dstTy.getElementType();
- auto valElemTy = valTy.getElementType();
+ if (!isWriteHintOrNone(getL1HintAttr()))
+ return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+
+ if (!isWriteHintOrNone(getL2HintAttr()))
+ return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+
+ if (!isWriteHintOrNone(getL3HintAttr()))
+ return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+
+ return success();
+}
- if (dstElemTy != valElemTy) {
- return emitOpError() << "The element type of the value should "
- "match the elementtype of the TensorDesc.\n";
+//===----------------------------------------------------------------------===//
+// XeGPU_UpdateNDOffsetOp
+//===----------------------------------------------------------------------===//
+LogicalResult UpdateNdOffsetOp::verify() {
+ auto ty = getTensorDescType();
+ if (ty.getScattered())
+ return emitOpError("Expects a non-scattered TensorDesc.\n");
+
+ // number of offsets specified must match the rank of the tensor descriptor
+ if (ty.getRank() != (int64_t)getNumOffsets()) {
+ return emitOpError("Invalid number of offsets.");
}
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_CreateDescOp
+//===----------------------------------------------------------------------===//
+void CreateDescOp::build(OpBuilder &builder, OperationState &state,
+ TensorDescType TensorDesc, Value source,
+ llvm::ArrayRef<OpFoldResult> offsets,
+ uint32_t chunk_size) {
+ llvm::SmallVector<int64_t> staticOffsets;
+ llvm::SmallVector<Value> dynamicOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ build(builder, state, TensorDesc, source, dynamicOffsets, staticOffsets,
+ chunk_size);
+}
+
+LogicalResult CreateDescOp::verify() {
+ auto tdescTy = getTensorDescType();
+ auto chunkSize = getChunkSize();
+
+ if (getRankOf(getSource()) > 1)
+ return emitOpError(
+ "Expecting the source is a 1D memref or pointer (uint64_t).");
+
+ if (!tdescTy.getScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+
+ SmallVector<int64_t> shape({(int64_t)getNumOffsets()});
+ if (chunkSize != 1)
+ shape.push_back(chunkSize);
+
+ auto tdescShape = getShapeOf(tdescTy);
+ if (shape != tdescShape)
+ return emitOpError("Incorrect TensorDesc shape. ")
+ << "Expected is " << makeString(shape) << "\n";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_PrefetchOp
+//===----------------------------------------------------------------------===//
+LogicalResult PrefetchOp::verify() {
+ auto tdescTy = getTensorDescType();
+ if (!tdescTy.getScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+
+ if (!isReadHintOrNone(getL1HintAttr()))
+ return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+
+ if (!isReadHintOrNone(getL2HintAttr()))
+ return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+
+ if (!isReadHintOrNone(getL3HintAttr()))
+ return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadGatherOp
+//===----------------------------------------------------------------------===//
+LogicalResult LoadGatherOp::verify() {
+ auto tdescTy = getTensorDescType();
+ auto maskTy = getMaskType();
+ auto valueTy = getValueType();
+
+ if (!tdescTy.getScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+
+ if (!isReadHintOrNone(getL1HintAttr()))
+ return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+
+ if (!isReadHintOrNone(getL2HintAttr()))
+ return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+
+ if (!isReadHintOrNone(getL3HintAttr()))
+ return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+
+ auto tdescElemTy = tdescTy.getElementType();
+ auto valueElemTy = getElementType();
+ if (tdescElemTy != valueElemTy)
+ return emitOpError(
+ "Value should have the same element type as TensorDesc.");
+
+ auto maskShape = getShapeOf(maskTy);
+ auto valueShape = getShapeOf(valueTy);
+ auto tdescShape = getShapeOf(tdescTy);
+
+ if (tdescShape[0] != maskShape[0])
+ return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
+
+ if (getTransposeAttr()) {
+ auto trans = getTranspose().value();
+ if (tdescShape.size() < trans.size())
+ emitWarning("Invalid transpose attr. It is ignored.");
+ else
+ transpose(trans, tdescShape);
+ }
+
+ if (valueShape != tdescShape)
+ return emitOpError("Unexpected result shape")
+ << "(Expected shape: " << makeString(tdescShape)
+ << ", Given shape: " << makeString(valueShape) << ").\n";
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_StoreScatterOp
+//===----------------------------------------------------------------------===//
+LogicalResult StoreScatterOp::verify() {
+ auto tdescTy = getTensorDescType();
+ if (!tdescTy.getScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+
+ if (!isWriteHintOrNone(getL1HintAttr()))
+ return emitOpError("invlid l1_hint: ") << getL1HintAttr();
+
+ if (!isWriteHintOrNone(getL2HintAttr()))
+ return emitOpError("invlid l2_hint: ") << getL2HintAttr();
+
+ if (!isWriteHintOrNone(getL3HintAttr()))
+ return emitOpError("invlid l3_hint: ") << getL3HintAttr();
+
+ auto maskTy = getMaskType();
+ auto maskShape = getShapeOf(maskTy);
+ auto tdescShape = getShapeOf(tdescTy);
+ if (tdescShape[0] != maskShape[0])
+ return emitOpError("dim-0 of the Mask and TensorDesc should be the same.");
- if (dstTy.getShape() != valTy.getShape())
- return emitOpError()
- << "The result shape should match the TensorDesc shape.\n";
return success();
}