aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/CAPI/Transforms/Rewrite.cpp13
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp3
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp21
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp154
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp82
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp6
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp6
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp10
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp6
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp12
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp16
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp8
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp24
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp40
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp24
-rw-r--r--mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp162
-rw-r--r--mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt53
-rw-r--r--mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt14
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp42
29 files changed, 548 insertions, 172 deletions
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 8ee6308..0d56259 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -259,22 +259,23 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
/// RewritePatternSet and FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//
-inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
+static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
}
-inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
+static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
return {module};
}
-inline mlir::FrozenRewritePatternSet *
+static inline mlir::FrozenRewritePatternSet *
unwrap(MlirFrozenRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
}
-inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
+static inline MlirFrozenRewritePatternSet
+wrap(mlir::FrozenRewritePatternSet *module) {
return {module};
}
@@ -321,12 +322,12 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
//===----------------------------------------------------------------------===//
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
-inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
+static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
assert(module.ptr && "unexpected null module");
return static_cast<mlir::PDLPatternModule *>(module.ptr);
}
-inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
+static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
return {module};
}
diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
index 7626d35..c64e10f5 100644
--- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp
@@ -123,7 +123,8 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality(
vector::OuterProductOp, vector::ScanOp>(
[&](Operation *op) { return converter.isLegal(op); });
target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp,
- arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>();
+ arith::ConstantOp, arith::SelectOp, vector::SplatOp,
+ vector::BroadcastOp>();
}
void EmulateUnsupportedFloatsPass::runOnOperation() {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3bd763e..05fc7cb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1622,12 +1622,12 @@ static void generateCollapsedIndexingRegion(
}
}
-void collapseOperandsAndResults(LinalgOp op,
- const CollapsingInfo &collapsingInfo,
- RewriterBase &rewriter,
- SmallVectorImpl<Value> &inputOperands,
- SmallVectorImpl<Value> &outputOperands,
- SmallVectorImpl<Type> &resultTypes) {
+static void collapseOperandsAndResults(LinalgOp op,
+ const CollapsingInfo &collapsingInfo,
+ RewriterBase &rewriter,
+ SmallVectorImpl<Value> &inputOperands,
+ SmallVectorImpl<Value> &outputOperands,
+ SmallVectorImpl<Type> &resultTypes) {
Location loc = op->getLoc();
inputOperands =
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
@@ -1651,8 +1651,8 @@ void collapseOperandsAndResults(LinalgOp op,
/// Clone a `LinalgOp` to a collapsed version of same name
template <typename OpTy>
-OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
- const CollapsingInfo &collapsingInfo) {
+static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
+ const CollapsingInfo &collapsingInfo) {
return nullptr;
}
@@ -1699,8 +1699,9 @@ GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
return collapsedOp;
}
-LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
- RewriterBase &rewriter) {
+static LinalgOp createCollapsedOp(LinalgOp op,
+ const CollapsingInfo &collapsingInfo,
+ RewriterBase &rewriter) {
if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
} else {
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 32ebe06..5672942 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -33,6 +33,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
+#include "llvm/Support/InterleavedRange.h"
#include <cstddef>
#include <iterator>
#include <optional>
@@ -3385,6 +3386,9 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
Value result = getResult();
auto [newCli, gen, cons] = decodeCli(result);
+ // Structured binding `gen` cannot be captured in lambdas before C++20
+ OpOperand *generator = gen;
+
// Derive the CLI variable name from its generator:
// * "canonloop" for omp.canonical_loop
// * custom name for loop transformation generatees
@@ -3403,6 +3407,24 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
.Case([&](UnrollHeuristicOp op) -> std::string {
llvm_unreachable("heuristic unrolling does not generate a loop");
})
+ .Case([&](TileOp op) -> std::string {
+ auto [generateesFirst, generateesCount] =
+ op.getGenerateesODSOperandIndexAndLength();
+ unsigned firstGrid = generateesFirst;
+ unsigned firstIntratile = generateesFirst + generateesCount / 2;
+ unsigned end = generateesFirst + generateesCount;
+ unsigned opnum = generator->getOperandNumber();
+ // In the OpenMP apply and looprange clauses, indices are 1-based
+ if (firstGrid <= opnum && opnum < firstIntratile) {
+ unsigned gridnum = opnum - firstGrid + 1;
+ return ("grid" + Twine(gridnum)).str();
+ }
+ if (firstIntratile <= opnum && opnum < end) {
+ unsigned intratilenum = opnum - firstIntratile + 1;
+ return ("intratile" + Twine(intratilenum)).str();
+ }
+ llvm_unreachable("Unexpected generatee argument");
+ })
.Default([&](Operation *op) {
assert(false && "TODO: Custom name for this operation");
return "transformed";
@@ -3632,6 +3654,138 @@ UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
}
//===----------------------------------------------------------------------===//
+// TileOp
+//===----------------------------------------------------------------------===//
+
+static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
+ OperandRange generatees,
+ OperandRange applyees) {
+ if (!generatees.empty())
+ p << '(' << llvm::interleaved(generatees) << ')';
+
+ if (!applyees.empty())
+ p << " <- (" << llvm::interleaved(applyees) << ')';
+}
+
+static ParseResult parseLoopTransformClis(
+ OpAsmParser &parser,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &generateesOperands,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &applyeesOperands) {
+ if (parser.parseOptionalLess()) {
+ // Syntax 1: generatees present
+
+ if (parser.parseOperandList(generateesOperands,
+ mlir::OpAsmParser::Delimiter::Paren))
+ return failure();
+
+ if (parser.parseLess())
+ return failure();
+ } else {
+ // Syntax 2: generatees omitted
+ }
+
+ // Parse `<-` (`<` has already been parsed)
+ if (parser.parseMinus())
+ return failure();
+
+ if (parser.parseOperandList(applyeesOperands,
+ mlir::OpAsmParser::Delimiter::Paren))
+ return failure();
+
+ return success();
+}
+
+LogicalResult TileOp::verify() {
+ if (getApplyees().empty())
+ return emitOpError() << "must apply to at least one loop";
+
+ if (getSizes().size() != getApplyees().size())
+ return emitOpError() << "there must be one tile size for each applyee";
+
+ if (!getGeneratees().empty() &&
+ 2 * getSizes().size() != getGeneratees().size())
+ return emitOpError()
+ << "expecting two times the number of generatees than applyees";
+
+ DenseSet<Value> parentIVs;
+
+ Value parent = getApplyees().front();
+ for (auto &&applyee : llvm::drop_begin(getApplyees())) {
+ auto [parentCreate, parentGen, parentCons] = decodeCli(parent);
+ auto [create, gen, cons] = decodeCli(applyee);
+
+ if (!parentGen)
+ return emitOpError() << "applyee CLI has no generator";
+
+ auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
+ if (!parentGen)
+ return emitOpError()
+ << "currently only supports omp.canonical_loop as applyee";
+
+ parentIVs.insert(parentLoop.getInductionVar());
+
+ if (!gen)
+ return emitOpError() << "applyee CLI has no generator";
+ auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
+ if (!loop)
+ return emitOpError()
+ << "currently only supports omp.canonical_loop as applyee";
+
+ // Canonical loop must be perfectly nested, i.e. the body of the parent must
+ // only contain the omp.canonical_loop of the nested loops, and
+ // omp.terminator
+ bool isPerfectlyNested = [&]() {
+ auto &parentBody = parentLoop.getRegion();
+ if (!parentBody.hasOneBlock())
+ return false;
+ auto &parentBlock = parentBody.getBlocks().front();
+
+ auto nestedLoopIt = parentBlock.begin();
+ if (nestedLoopIt == parentBlock.end() ||
+ (&*nestedLoopIt != loop.getOperation()))
+ return false;
+
+ auto termIt = std::next(nestedLoopIt);
+ if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
+ return false;
+
+ if (std::next(termIt) != parentBlock.end())
+ return false;
+
+ return true;
+ }();
+ if (!isPerfectlyNested)
+ return emitOpError() << "tiled loop nest must be perfectly nested";
+
+ if (parentIVs.contains(loop.getTripCount()))
+ return emitOpError() << "tiled loop nest must be rectangular";
+
+ parent = applyee;
+ }
+
+ // TODO: The tile sizes must be computed before the loop, but checking this
+ // requires dominance analysis. For instance:
+ //
+ // %canonloop = omp.new_cli
+ // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
+ // // write to %x
+ // omp.terminator
+ // }
+ // %ts = llvm.load %x
+ // omp.tile <- (%canonloop) sizes(%ts : i32)
+
+ return success();
+}
+
+std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
+ return getODSOperandIndexAndLength(odsIndex_applyees);
+}
+
+std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
+ return getODSOperandIndexAndLength(odsIndex_generatees);
+}
+
+//===----------------------------------------------------------------------===//
// Critical construct (2.17.1)
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eb46869..b0132e8 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -580,7 +580,7 @@ namespace {
// ElideSingleElementReduction for ReduceOp.
struct ElideUnitDimsInMultiDimReduction
: public OpRewritePattern<MultiDimReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
PatternRewriter &rewriter) const override {
@@ -730,7 +730,7 @@ std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
namespace {
struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ReductionOp reductionOp,
PatternRewriter &rewriter) const override {
@@ -2197,7 +2197,7 @@ namespace {
// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
@@ -2220,7 +2220,7 @@ public:
// Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask.
class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractOp extractOp,
PatternRewriter &rewriter) const override {
@@ -2546,7 +2546,7 @@ rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(FromElementsOp fromElements,
PatternRewriter &rewriter) const override {
@@ -2938,7 +2938,7 @@ namespace {
// Fold broadcast1(broadcast2(x)) into broadcast1(x).
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
PatternRewriter &rewriter) const override {
@@ -3109,7 +3109,7 @@ namespace {
// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
// to a broadcast.
struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
PatternRewriter &rewriter) const override {
@@ -3165,7 +3165,7 @@ static Value getScalarSplatSource(Value value) {
/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ShuffleOp op,
PatternRewriter &rewriter) const override {
@@ -3182,7 +3182,7 @@ public:
/// vector.interleave.
class ShuffleInterleave : public OpRewritePattern<ShuffleOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ShuffleOp op,
PatternRewriter &rewriter) const override {
@@ -3326,7 +3326,7 @@ namespace {
// broadcast.
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertOp insertOp,
PatternRewriter &rewriter) const override {
@@ -3344,7 +3344,7 @@ public:
/// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v).
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
@@ -3380,7 +3380,7 @@ public:
/// %result = vector.from_elements %c1, %c2 : vector<2xi32>
class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
@@ -3748,7 +3748,7 @@ namespace {
class FoldInsertStridedSliceSplat final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
@@ -3768,7 +3768,7 @@ public:
class FoldInsertStridedSliceOfExtract final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
@@ -3798,7 +3798,7 @@ public:
class InsertStridedSliceConstantFolder final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
// Do not create constants with more than `vectorSizeFoldThreashold` elements,
// unless the source vector constant has a single use.
@@ -4250,7 +4250,7 @@ namespace {
// %mask = vector.create_mask %new_ub : vector<8xi1>
class StridedSliceCreateMaskFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
public:
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
@@ -4310,7 +4310,7 @@ public:
class StridedSliceConstantMaskFolder final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
PatternRewriter &rewriter) const override {
@@ -4365,7 +4365,7 @@ public:
class StridedSliceBroadcast final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -4416,7 +4416,7 @@ public:
/// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v).
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -4448,7 +4448,7 @@ public:
class ContiguousExtractStridedSliceToExtract final
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -5023,7 +5023,7 @@ namespace {
/// ```
struct TransferReadAfterWriteToBroadcast
: public OpRewritePattern<TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(TransferReadOp readOp,
PatternRewriter &rewriter) const override {
@@ -5458,7 +5458,7 @@ namespace {
/// any other uses.
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
if (!llvm::isa<RankedTensorType>(writeOp.getShapedType()))
@@ -5514,7 +5514,7 @@ public:
struct SwapExtractSliceOfTransferWrite
: public OpRewritePattern<tensor::InsertSliceOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
PatternRewriter &rewriter) const override {
@@ -5737,7 +5737,7 @@ LogicalResult MaskedLoadOp::verify() {
namespace {
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(MaskedLoadOp load,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(load.getMask())) {
@@ -5794,7 +5794,7 @@ LogicalResult MaskedStoreOp::verify() {
namespace {
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(MaskedStoreOp store,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(store.getMask())) {
@@ -5890,7 +5890,7 @@ static LogicalResult isZeroBasedContiguousSeq(Value indexVec) {
namespace {
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(GatherOp gather,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(gather.getMask())) {
@@ -5910,7 +5910,7 @@ public:
/// maskedload. Only 1D fixed vectors are supported for now.
class FoldContiguousGather final : public OpRewritePattern<GatherOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(GatherOp op,
PatternRewriter &rewriter) const override {
if (!isa<MemRefType>(op.getBase().getType()))
@@ -5962,7 +5962,7 @@ LogicalResult ScatterOp::verify() {
namespace {
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ScatterOp scatter,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(scatter.getMask())) {
@@ -5982,7 +5982,7 @@ public:
/// maskedstore. Only 1D fixed vectors are supported for now.
class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ScatterOp op,
PatternRewriter &rewriter) const override {
if (failed(isZeroBasedContiguousSeq(op.getIndices())))
@@ -6030,7 +6030,7 @@ LogicalResult ExpandLoadOp::verify() {
namespace {
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExpandLoadOp expand,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(expand.getMask())) {
@@ -6081,7 +6081,7 @@ LogicalResult CompressStoreOp::verify() {
namespace {
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(CompressStoreOp compress,
PatternRewriter &rewriter) const override {
switch (getMaskFormat(compress.getMask())) {
@@ -6260,7 +6260,7 @@ static VectorType trimTrailingOneDims(VectorType oldType) {
class ShapeCastCreateMaskFolderTrailingOneDim final
: public OpRewritePattern<ShapeCastOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ShapeCastOp shapeOp,
PatternRewriter &rewriter) const override {
@@ -6330,7 +6330,7 @@ public:
/// If both (i) and (ii) are possible, (i) is chosen.
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
PatternRewriter &rewriter) const override {
@@ -6614,7 +6614,7 @@ namespace {
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
@@ -6646,7 +6646,7 @@ public:
/// Replace transpose(splat-like(v)) with broadcast(v)
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
@@ -6663,7 +6663,7 @@ public:
/// Folds transpose(create_mask) into a new transposed create_mask.
class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(TransposeOp transpOp,
PatternRewriter &rewriter) const override {
@@ -6700,7 +6700,7 @@ public:
/// Folds transpose(shape_cast) into a new shape_cast.
class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
@@ -6750,7 +6750,7 @@ public:
/// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6).
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1)
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
@@ -6971,7 +6971,7 @@ namespace {
/// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1>
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
PatternRewriter &rewriter) const override {
@@ -7300,7 +7300,7 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor,
/// %0 = arith.select %mask, %a, %passthru : vector<8xf32>
///
class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(MaskOp maskOp,
PatternRewriter &rewriter) const override {
@@ -7410,7 +7410,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
// vector.broadcast.
class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
public:
- using OpRewritePattern<SplatOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(SplatOp splatOp,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index dedc3b3..61d9357 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -34,7 +34,7 @@ namespace {
/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::BroadcastOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
index 65702ff..efe8d14 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp
@@ -1151,7 +1151,7 @@ FailureOr<Value> ContractionOpLowering::lowerReduction(
///
class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::OuterProductOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index 1f96a3a..6bc8347 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -50,7 +50,7 @@ namespace {
///
/// Supports vector types with a fixed leading dimension.
struct UnrollGather : OpRewritePattern<vector::GatherOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
@@ -98,7 +98,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
/// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef,
/// but should be fairly straightforward to extend beyond that.
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
@@ -164,7 +164,7 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> {
/// `tensor.extract`s. To avoid out-of-bounds memory accesses, these
/// loads/extracts are made conditional using `scf.if` ops.
struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
index 9d6a865..479fc0c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp
@@ -163,7 +163,7 @@ private:
/// : vector<7xi16>, vector<7xi16>
/// ```
struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::InterleaveOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
index 5617b06..7730c4e 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
@@ -48,7 +48,7 @@ namespace {
/// until a one-dimensional vector is reached.
class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::CreateMaskOp op,
PatternRewriter &rewriter) const override {
@@ -100,7 +100,7 @@ public:
/// will be folded at LLVM IR level.
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ConstantMaskOp op,
PatternRewriter &rewriter) const override {
@@ -184,7 +184,7 @@ namespace {
/// and actually match the traits of its the nested `MaskableOpInterface`.
template <class SourceOp>
struct MaskOpRewritePattern : OpRewritePattern<MaskOp> {
- using OpRewritePattern<MaskOp>::OpRewritePattern;
+ using Base::Base;
private:
LogicalResult matchAndRewrite(MaskOp maskOp,
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
index 4773732d..e86e2a9 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp
@@ -39,7 +39,7 @@ namespace {
class InnerOuterDimReductionConversion
: public OpRewritePattern<vector::MultiDimReductionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
explicit InnerOuterDimReductionConversion(
MLIRContext *context, vector::VectorMultiReductionLowering options,
@@ -136,7 +136,7 @@ private:
class ReduceMultiDimReductionRank
: public OpRewritePattern<vector::MultiDimReductionOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
explicit ReduceMultiDimReductionRank(
MLIRContext *context, vector::VectorMultiReductionLowering options,
@@ -304,7 +304,7 @@ private:
/// and combines results
struct TwoDimMultiReductionToElementWise
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
@@ -359,7 +359,7 @@ struct TwoDimMultiReductionToElementWise
/// a sequence of vector.reduction ops.
struct TwoDimMultiReductionToReduction
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
@@ -420,7 +420,7 @@ struct TwoDimMultiReductionToReduction
/// separately.
struct OneDimMultiReductionToTwoDim
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
index af4851e..258f2cb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp
@@ -99,7 +99,7 @@ namespace {
/// return %7, %8 : vector<2x3xi32>, vector<2xi32>
/// ```
struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ScanOp scanOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
index 603ea41..c5f22b2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp
@@ -189,7 +189,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> {
}
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
@@ -356,7 +356,7 @@ public:
class ScalableShapeCastOpRewritePattern
: public OpRewritePattern<vector::ShapeCastOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ShapeCastOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
index 78102f7..8f46ad6 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp
@@ -44,7 +44,7 @@ namespace {
///
struct MixedSizeInputShuffleOpRewrite final
: OpRewritePattern<vector::ShuffleOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
index ee5568a..08e7c89 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp
@@ -24,7 +24,7 @@ using namespace mlir::vector;
namespace {
struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::StepOp stepOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
index 6407a86..7521e24 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp
@@ -667,7 +667,7 @@ getToElementsDefiningOps(FromElementsOp fromElemsOp,
struct ToFromElementsToShuffleTreeRewrite final
: OpRewritePattern<vector::FromElementsOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
index 9e7d0ce..c3f7de0 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp
@@ -300,7 +300,7 @@ namespace {
/// %x = vector.insert .., .. [.., ..]
class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering,
MLIRContext *context, PatternBenefit benefit = 1)
@@ -395,7 +395,7 @@ private:
class Transpose2DWithUnitDimToShapeCast
: public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
Transpose2DWithUnitDimToShapeCast(MLIRContext *context,
PatternBenefit benefit = 1)
@@ -433,7 +433,7 @@ public:
class TransposeOp2DToShuffleLowering
: public OpRewritePattern<vector::TransposeOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
TransposeOp2DToShuffleLowering(
vector::VectorTransposeLowering vectorTransposeLowering,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
index cab1289..963b2c8 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -54,7 +54,7 @@ namespace {
// input by inserting vector.broadcast.
struct CastAwayExtractStridedSliceLeadingOneDim
: public OpRewritePattern<vector::ExtractStridedSliceOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
PatternRewriter &rewriter) const override {
@@ -104,7 +104,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
// inputs by inserting vector.broadcast.
struct CastAwayInsertStridedSliceLeadingOneDim
: public OpRewritePattern<vector::InsertStridedSliceOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp,
PatternRewriter &rewriter) const override {
@@ -145,7 +145,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
// Casts away leading one dimensions in vector.insert's vector inputs by
// inserting vector.broadcast.
struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::InsertOp insertOp,
PatternRewriter &rewriter) const override {
@@ -221,7 +221,7 @@ static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask,
// 1 dimensions.
struct CastAwayTransferReadLeadingOneDim
: public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
@@ -275,7 +275,7 @@ struct CastAwayTransferReadLeadingOneDim
// 1 dimensions.
struct CastAwayTransferWriteLeadingOneDim
: public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
@@ -541,7 +541,7 @@ public:
// vector.broadcast back to the original shape.
struct CastAwayConstantMaskLeadingOneDim
: public OpRewritePattern<vector::ConstantMaskOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ConstantMaskOp mask,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
index bdbb792..7acc120 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp
@@ -48,7 +48,7 @@ namespace {
///
struct VectorMaskedLoadOpConverter final
: OpRewritePattern<vector::MaskedLoadOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp,
PatternRewriter &rewriter) const override {
@@ -117,7 +117,7 @@ struct VectorMaskedLoadOpConverter final
///
struct VectorMaskedStoreOpConverter final
: OpRewritePattern<vector::MaskedStoreOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 264cbc1..3a6684f 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -548,7 +548,7 @@ namespace {
// NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to
// `false` to generate non-atomic RMW sequences.
struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW)
: OpConversionPattern<vector::StoreOp>(context),
@@ -827,7 +827,7 @@ private:
/// adjusted mask .
struct ConvertVectorMaskedStore final
: OpConversionPattern<vector::MaskedStoreOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LogicalResult
matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor,
@@ -950,7 +950,7 @@ struct ConvertVectorMaskedStore final
/// those cases, loads are converted to byte-aligned, byte-sized loads and the
/// target vector is extracted from the loaded vector.
struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LogicalResult
matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor,
@@ -1059,7 +1059,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
/// bitcasting, since each `i8` container element holds two `i4` values.
struct ConvertVectorMaskedLoad final
: OpConversionPattern<vector::MaskedLoadOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LogicalResult
matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor,
@@ -1257,7 +1257,7 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy,
// TODO: Document-me
struct ConvertVectorTransferRead final
: OpConversionPattern<vector::TransferReadOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LogicalResult
matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor,
@@ -1942,7 +1942,7 @@ namespace {
/// advantage of high-level information to avoid leaving LLVM to scramble with
/// peephole optimizations.
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp,
PatternRewriter &rewriter) const override {
@@ -2147,7 +2147,7 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
/// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4>
///
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
- using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
PatternRewriter &rewriter) const override {
@@ -2200,7 +2200,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
/// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4>
///
struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> {
- using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
+ using Base::Base;
RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit)
: OpRewritePattern<vector::TransposeOp>(context, benefit) {}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index f6d6555..9e49873 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -34,7 +34,7 @@ using namespace mlir::vector;
class DecomposeDifferentRankInsertStridedSlice
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(InsertStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -84,7 +84,7 @@ public:
class ConvertSameRankInsertStridedSliceIntoShuffle
: public OpRewritePattern<InsertStridedSliceOp> {
public:
- using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern;
+ using Base::Base;
void initialize() {
// This pattern creates recursive InsertStridedSliceOp, but the recursion is
@@ -183,7 +183,7 @@ public:
class Convert1DExtractStridedSliceIntoShuffle
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
@@ -271,7 +271,7 @@ private:
class DecomposeNDExtractStridedSlice
: public OpRewritePattern<ExtractStridedSliceOp> {
public:
- using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern;
+ using Base::Base;
void initialize() {
// This pattern creates recursive ExtractStridedSliceOp, but the recursion
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 82bac8c..71fba71c 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -214,7 +214,7 @@ SmallVector<int64_t> static getStridedSliceInsertionIndices(
/// vector.extract_strided_slice operation.
struct LinearizeVectorExtractStridedSlice final
: public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter,
MLIRContext *context,
PatternBenefit benefit = 1)
@@ -285,7 +285,7 @@ struct LinearizeVectorExtractStridedSlice final
///
struct LinearizeVectorInsertStridedSlice final
: public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter,
MLIRContext *context,
PatternBenefit benefit = 1)
@@ -348,7 +348,7 @@ struct LinearizeVectorInsertStridedSlice final
/// of the original shuffle operation.
struct LinearizeVectorShuffle final
: public OpConversionPattern<vector::ShuffleOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorShuffle(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -423,7 +423,7 @@ struct LinearizeVectorShuffle final
///
struct LinearizeVectorExtract final
: public OpConversionPattern<vector::ExtractOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorExtract(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -501,7 +501,7 @@ struct LinearizeVectorExtract final
///
struct LinearizeVectorInsert final
: public OpConversionPattern<vector::InsertOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorInsert(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -575,7 +575,7 @@ struct LinearizeVectorInsert final
/// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16>
struct LinearizeVectorBitCast final
: public OpConversionPattern<vector::BitCastOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorBitCast(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -598,7 +598,7 @@ struct LinearizeVectorBitCast final
/// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32>
struct LinearizeVectorSplat final
: public OpConversionPattern<vector::SplatOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
@@ -629,7 +629,7 @@ struct LinearizeVectorSplat final
/// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1>
struct LinearizeVectorCreateMask final
: OpConversionPattern<vector::CreateMaskOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorCreateMask(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
@@ -684,7 +684,7 @@ struct LinearizeVectorCreateMask final
/// For generic cases, the vector unroll pass should be used to unroll the load
/// to vector<1x1x...xN> form and then linearized
struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -731,7 +731,7 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> {
/// to vector<1x1x...xN> form and then linearized
struct LinearizeVectorStore final
: public OpConversionPattern<vector::StoreOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context,
PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -778,7 +778,7 @@ struct LinearizeVectorStore final
///
struct LinearizeVectorFromElements final
: public OpConversionPattern<vector::FromElementsOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorFromElements(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
: OpConversionPattern(typeConverter, context, benefit) {}
@@ -814,7 +814,7 @@ struct LinearizeVectorFromElements final
///
struct LinearizeVectorToElements final
: public OpConversionPattern<vector::ToElementsOp> {
- using OpConversionPattern::OpConversionPattern;
+ using Base::Base;
LinearizeVectorToElements(const TypeConverter &typeConverter,
MLIRContext *context, PatternBenefit benefit = 1)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index c364a8b..1121d95 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -1081,7 +1081,7 @@ private:
/// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>)
/// to memref.store.
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 866f789..d6a6d7cd 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -78,7 +78,7 @@ namespace {
/// ```
struct MultiReduceToContract
: public OpRewritePattern<vector::MultiDimReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp,
PatternRewriter &rewriter) const override {
@@ -138,7 +138,7 @@ struct MultiReduceToContract
/// ```
struct CombineContractABTranspose final
: public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
@@ -202,7 +202,7 @@ struct CombineContractABTranspose final
/// ```
struct CombineContractResultTranspose final
: public OpRewritePattern<vector::TransposeOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransposeOp resTOp,
PatternRewriter &rewriter) const override {
@@ -568,7 +568,7 @@ static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) {
// %2 = vector.extract %1[1] : f16 from vector<2xf16>
struct BubbleDownVectorBitCastForExtract
: public OpRewritePattern<vector::ExtractOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
PatternRewriter &rewriter) const override {
@@ -643,7 +643,7 @@ struct BubbleDownVectorBitCastForExtract
// %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16>
struct BubbleDownBitCastForStridedSliceExtract
: public OpRewritePattern<vector::ExtractStridedSliceOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp,
PatternRewriter &rewriter) const override {
@@ -721,7 +721,7 @@ struct BubbleDownBitCastForStridedSliceExtract
// %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8>
//
struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
PatternRewriter &rewriter) const override {
@@ -794,7 +794,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> {
// offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BubbleUpBitCastForStridedSliceInsert
: public OpRewritePattern<vector::BitCastOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp,
PatternRewriter &rewriter) const override {
@@ -892,7 +892,7 @@ struct BubbleUpBitCastForStridedSliceInsert
// %7 = vector.insert_strided_slice %6, %cst {
// offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
public:
BreakDownVectorBitCast(MLIRContext *context,
@@ -1131,7 +1131,7 @@ struct ReorderElementwiseOpsOnBroadcast final
class ExtractOpFromElementwise final
: public OpRewritePattern<vector::ExtractOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ExtractOp op,
PatternRewriter &rewriter) const override {
@@ -1206,7 +1206,7 @@ static bool isSupportedMemSinkElementType(Type type) {
/// ```
class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ExtractOp op,
PatternRewriter &rewriter) const override {
@@ -1285,7 +1285,7 @@ public:
class StoreOpFromSplatOrBroadcast final
: public OpRewritePattern<vector::StoreOp> {
public:
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::StoreOp op,
PatternRewriter &rewriter) const override {
@@ -1476,7 +1476,7 @@ static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) {
/// InstCombine seems to handle vectors with multiple elements but not the
/// single element ones.
struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
- using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(arith::SelectOp selectOp,
PatternRewriter &rewriter) const override {
@@ -1560,7 +1560,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
/// Drop inner most contiguous unit dimensions from transfer_read operand.
class DropInnerMostUnitDimsTransferRead
: public OpRewritePattern<vector::TransferReadOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
PatternRewriter &rewriter) const override {
@@ -1651,7 +1651,7 @@ class DropInnerMostUnitDimsTransferRead
/// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`).
class DropInnerMostUnitDimsTransferWrite
: public OpRewritePattern<vector::TransferWriteOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
@@ -1728,7 +1728,7 @@ class DropInnerMostUnitDimsTransferWrite
/// with the RHS transposed) lowering.
struct CanonicalizeContractMatmulToMMT final
: OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
using FilterConstraintType =
std::function<LogicalResult(vector::ContractionOp op)>;
@@ -1845,7 +1845,7 @@ private:
template <typename ExtOp>
struct FoldArithExtIntoContractionOp
: public OpRewritePattern<vector::ContractionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
@@ -1878,7 +1878,7 @@ struct FoldArithExtIntoContractionOp
/// %b = vector.reduction <add> %a, %acc
/// ```
struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ReductionOp op,
PatternRewriter &rewriter) const override {
@@ -2033,7 +2033,7 @@ struct DropUnitDimFromElementwiseOps final
/// ```
struct DropUnitDimsFromTransposeOp final
: OpRewritePattern<vector::TransposeOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::TransposeOp op,
PatternRewriter &rewriter) const override {
@@ -2110,7 +2110,7 @@ struct DropUnitDimsFromTransposeOp final
/// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32>
/// ```
struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(scf::ForOp forOp,
PatternRewriter &rewriter) const override {
@@ -2155,7 +2155,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> {
/// %c = vector.reduction <add> %b, %acc
/// ```
struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> {
- using OpRewritePattern::OpRewritePattern;
+ using Base::Base;
LogicalResult matchAndRewrite(vector::ReductionOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 9413a92..784e5d6 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -824,7 +824,7 @@ struct WgToSgStoreScatterOpWithOffset
return failure();
xegpu::DistributeLayoutAttr layout =
- xegpu::getDistributeLayoutAttr(op.getValue());
+ xegpu::getDistributeLayoutAttr(op.getOperand(0));
if (!layout || !layout.isForWorkgroup())
return failure();
@@ -844,12 +844,19 @@ struct WgToSgStoreScatterOpWithOffset
auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize);
for (auto [val, offs, mask] : llvm::zip(
adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) {
- xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs,
- mask, chunkSizeAttr, op.getL1HintAttr(),
- op.getL2HintAttr(), op.getL3HintAttr());
+ auto store = xegpu::StoreScatterOp::create(
+ rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr,
+ op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr());
// Update the layout attribute to drop sg_layout and sg_data.
- if (auto newLayout = layout.dropSgLayoutAndData())
- op->setAttr("layout", newLayout);
+ if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
+ !layout.getEffectiveInstDataAsInt().empty()) {
+ for (OpOperand &operand : store->getOpOperands()) {
+ // Skip for operand one (memref)
+ if (operand.getOperandNumber() == 1)
+ continue;
+ xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
+ }
+ }
}
rewriter.eraseOp(op);
return success();
@@ -1247,10 +1254,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
target.addDynamicallyLegalOp<xegpu::StoreScatterOp>(
[=](xegpu::StoreScatterOp op) -> bool {
- // Check if the layout attribute is present on the result.
- auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout");
- if (!layout)
- return true;
+ auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0));
return isLegal(layout);
});
diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
index d6b8a8a..e3f075f 100644
--- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
+++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp
@@ -54,6 +54,7 @@ struct OpStrings {
std::string opCppName;
SmallVector<std::string> opResultNames;
SmallVector<std::string> opOperandNames;
+ SmallVector<std::string> opRegionNames;
};
static std::string joinNameList(llvm::ArrayRef<std::string> names) {
@@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) {
/// Generates OpStrings from an OperatioOp
static OpStrings getStrings(irdl::OperationOp op) {
auto operandOp = op.getOp<irdl::OperandsOp>();
-
auto resultOp = op.getOp<irdl::ResultsOp>();
+ auto regionsOp = op.getOp<irdl::RegionsOp>();
OpStrings strings;
strings.opName = op.getSymName();
@@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) {
}));
}
+ if (regionsOp) {
+ strings.opRegionNames = SmallVector<std::string>(
+ llvm::map_range(regionsOp->getNames(), [](Attribute attr) {
+ return llvm::formatv("{0}", cast<StringAttr>(attr));
+ }));
+ }
+
return strings;
}
@@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict,
static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
const auto operandCount = strings.opOperandNames.size();
const auto resultCount = strings.opResultNames.size();
+ const auto regionCount = strings.opRegionNames.size();
dict["OP_NAME"] = strings.opName;
dict["OP_CPP_NAME"] = strings.opCppName;
@@ -131,6 +140,7 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) {
operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}";
dict["OP_RESULT_INITIALIZER_LIST"] =
resultCount ? joinNameList(strings.opResultNames) : "{\"\"}";
+ dict["OP_REGION_COUNT"] = std::to_string(regionCount);
}
/// Fills a dictionary with values from DialectStrings
@@ -179,6 +189,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
const OpStrings &opStrings) {
auto opGetters = std::string{};
auto resGetters = std::string{};
+ auto regionGetters = std::string{};
+ auto regionAdaptorGetters = std::string{};
for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) {
const auto op =
@@ -196,8 +208,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict,
op, i);
}
+ for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) {
+ const auto op =
+ llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true);
+ regionAdaptorGetters += llvm::formatv(
+ R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; }
+ )",
+ op, i);
+ regionGetters += llvm::formatv(
+ R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); }
+ )",
+ op, i);
+ }
+
dict["OP_OPERAND_GETTER_DECLS"] = opGetters;
dict["OP_RESULT_GETTER_DECLS"] = resGetters;
+ dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters;
+ dict["OP_REGION_GETTER_DECLS"] = regionGetters;
}
static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
@@ -238,6 +265,22 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict,
dict["OP_BUILD_DECLS"] = buildDecls;
}
+// add traits to the dictionary, return true if any were added
+static SmallVector<std::string> generateTraits(irdl::OperationOp op,
+ const OpStrings &strings) {
+ SmallVector<std::string> cppTraitNames;
+ if (!strings.opRegionNames.empty()) {
+ cppTraitNames.push_back(
+ llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl",
+ strings.opRegionNames.size())
+ .str());
+
+ // Requires verifyInvariantsImpl is implemented on the op
+ cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants");
+ }
+ return cppTraitNames;
+}
+
static LogicalResult generateOperationInclude(irdl::OperationOp op,
raw_ostream &output,
irdl::detail::dictionary &dict) {
@@ -247,6 +290,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op,
const auto opStrings = getStrings(op);
fillDict(dict, opStrings);
+ SmallVector<std::string> traitNames = generateTraits(op, opStrings);
+ if (traitNames.empty())
+ dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName;
+ else
+ dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName,
+ llvm::join(traitNames, ", "));
+
generateOpGetterDeclarations(dict, opStrings);
generateOpBuilderDeclarations(dict, opStrings);
@@ -301,6 +351,110 @@ static LogicalResult generateInclude(irdl::DialectOp dialect,
return success();
}
+static void generateRegionConstraintVerifiers(
+ irdl::detail::dictionary &dict, irdl::OperationOp op,
+ const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers,
+ SmallVectorImpl<std::string> &verifierCalls) {
+ auto regionsOp = op.getOp<irdl::RegionsOp>();
+ if (strings.opRegionNames.empty() || !regionsOp)
+ return;
+
+ for (size_t i = 0; i < strings.opRegionNames.size(); ++i) {
+ std::string regionName = strings.opRegionNames[i];
+ std::string helperFnName =
+ llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}",
+ strings.opCppName, regionName)
+ .str();
+
+ // Extract the actual region constraint from the IRDL RegionOp
+ std::string condition = "true";
+ std::string textualConditionName = "any region";
+
+ if (auto regionDefOp =
+ dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) {
+ // Generate constraint condition based on RegionOp attributes
+ SmallVector<std::string> conditionParts;
+ SmallVector<std::string> descriptionParts;
+
+ // Check number of blocks constraint
+ if (auto blockCount = regionDefOp.getNumberOfBlocks()) {
+ conditionParts.push_back(
+ llvm::formatv("region.getBlocks().size() == {0}",
+ blockCount.value())
+ .str());
+ descriptionParts.push_back(
+ llvm::formatv("exactly {0} block(s)", blockCount.value()).str());
+ }
+
+ // Check entry block arguments constraint
+ if (regionDefOp.getConstrainedArguments()) {
+ size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size();
+ conditionParts.push_back(
+ llvm::formatv("region.getNumArguments() == {0}", expectedArgCount)
+ .str());
+ descriptionParts.push_back(
+ llvm::formatv("{0} entry block argument(s)", expectedArgCount)
+ .str());
+ }
+
+ // Combine conditions
+ if (!conditionParts.empty()) {
+ condition = llvm::join(conditionParts, " && ");
+ }
+
+ // Generate descriptive error message
+ if (!descriptionParts.empty()) {
+ textualConditionName =
+ llvm::formatv("region with {0}",
+ llvm::join(descriptionParts, " and "))
+ .str();
+ }
+ }
+
+ verifierHelpers.push_back(llvm::formatv(
+ R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region &region, ::llvm::StringRef regionName, unsigned regionIndex) {{
+ if (!({1})) {{
+ return op->emitOpError("region #") << regionIndex
+ << (regionName.empty() ? " " : " ('" + regionName + "') ")
+ << "failed to verify constraint: {2}";
+ }
+ return ::mlir::success();
+})",
+ helperFnName, condition, textualConditionName));
+
+ verifierCalls.push_back(llvm::formatv(R"(
+ if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1})))
+ return ::mlir::failure();)",
+ helperFnName, i, regionName)
+ .str());
+ }
+}
+
+static void generateVerifiers(irdl::detail::dictionary &dict,
+ irdl::OperationOp op, const OpStrings &strings) {
+ SmallVector<std::string> verifierHelpers;
+ SmallVector<std::string> verifierCalls;
+
+ generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers,
+ verifierCalls);
+
+ // Add an overall verifier that sequences the helper calls
+ std::string verifierDef =
+ llvm::formatv(R"(
+::llvm::LogicalResult {0}::verifyInvariantsImpl() {{
+ if(::mlir::failed(verify()))
+ return ::mlir::failure();
+
+ {1}
+
+ return ::mlir::success();
+})",
+ strings.opCppName, llvm::join(verifierCalls, "\n"));
+
+ dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n");
+ dict["OP_VERIFIER"] = verifierDef;
+}
+
static std::string generateOpDefinition(irdl::detail::dictionary &dict,
irdl::OperationOp op) {
static const auto perOpDefTemplate = mlir::irdl::detail::Template{
@@ -370,6 +524,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, {
dict["OP_BUILD_DEFS"] = buildDefinition;
+ generateVerifiers(dict, op, opStrings);
+
std::string str;
llvm::raw_string_ostream stream{str};
perOpDefTemplate.render(stream, dict);
@@ -427,7 +583,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output,
dict["TYPE_PARSER"] = llvm::formatv(
R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) {
return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
- {0}
+ {0}
.Default([&](llvm::StringRef keyword, llvm::SMLoc) {{
*mnemonic = keyword;
return std::nullopt;
@@ -520,6 +676,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) {
"IRDL C++ translation does not yet support variadic results");
}))
.Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); }))
+ .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); }))
+ .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); }))
.Default([](mlir::Operation *op) -> LogicalResult {
return op->emitError("IRDL C++ translation does not yet support "
"translation of ")
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
index e9068e9..93ce0be 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt
@@ -12,15 +12,15 @@ public:
struct Properties {
};
public:
- __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
- : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
- odsRegions(op->getRegions())
+ __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op)
+ : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()),
+ odsRegions(op->getRegions())
{}
/// Return the unstructured operand index of a structured operand along with
// the amount of unstructured operands it contains.
std::pair<unsigned, unsigned>
- getStructuredOperandIndexAndLength (unsigned index,
+ getStructuredOperandIndexAndLength (unsigned index,
unsigned odsOperandsSize) {
return {index, 1};
}
@@ -32,6 +32,12 @@ public:
::mlir::DictionaryAttr getAttributes() {
return odsAttrs;
}
+
+ __OP_REGION_ADAPTER_GETTER_DECLS__
+
+ ::mlir::RegionRange getRegions() {
+ return odsRegions;
+ }
protected:
::mlir::DictionaryAttr odsAttrs;
::std::optional<::mlir::OperationName> odsOpName;
@@ -42,28 +48,28 @@ protected:
} // namespace detail
template <typename RangeT>
-class __OP_CPP_NAME__GenericAdaptor
+class __OP_CPP_NAME__GenericAdaptor
: public detail::__OP_CPP_NAME__GenericAdaptorBase {
using ValueT = ::llvm::detail::ValueOfRange<RangeT>;
using Base = detail::__OP_CPP_NAME__GenericAdaptorBase;
public:
__OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs,
- ::mlir::OpaqueProperties properties,
- ::mlir::RegionRange regions = {})
- : __OP_CPP_NAME__GenericAdaptor(values, attrs,
- (properties ? *properties.as<::mlir::EmptyProperties *>()
+ ::mlir::OpaqueProperties properties,
+ ::mlir::RegionRange regions = {})
+ : __OP_CPP_NAME__GenericAdaptor(values, attrs,
+ (properties ? *properties.as<::mlir::EmptyProperties *>()
: ::mlir::EmptyProperties{}), regions) {}
- __OP_CPP_NAME__GenericAdaptor(RangeT values,
+ __OP_CPP_NAME__GenericAdaptor(RangeT values,
const __OP_CPP_NAME__GenericAdaptorBase &base)
: Base(base), odsOperands(values) {}
- // This template parameter allows using __OP_CPP_NAME__ which is declared
+ // This template parameter allows using __OP_CPP_NAME__ which is declared
// later.
template <typename LateInst = __OP_CPP_NAME__,
typename = std::enable_if_t<
std::is_same_v<LateInst, __OP_CPP_NAME__>>>
- __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
+ __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op)
: Base(op), odsOperands(values) {}
/// Return the unstructured operand index of a structured operand along with
@@ -77,7 +83,7 @@ public:
RangeT getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(odsOperands.begin(), valueRange.first),
- std::next(odsOperands.begin(),
+ std::next(odsOperands.begin(),
valueRange.first + valueRange.second)};
}
@@ -91,7 +97,7 @@ private:
RangeT odsOperands;
};
-class __OP_CPP_NAME__Adaptor
+class __OP_CPP_NAME__Adaptor
: public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> {
public:
using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor;
@@ -100,7 +106,7 @@ public:
::llvm::LogicalResult verify(::mlir::Location loc);
};
-class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> {
+class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> {
public:
using Op::Op;
using Op::print;
@@ -112,6 +118,8 @@ public:
return {};
}
+ ::llvm::LogicalResult verifyInvariantsImpl();
+
static constexpr ::llvm::StringLiteral getOperationName() {
return ::llvm::StringLiteral("__DIALECT_NAME__.__OP_NAME__");
}
@@ -147,7 +155,7 @@ public:
::mlir::Operation::operand_range getStructuredOperands(unsigned index) {
auto valueRange = getStructuredOperandIndexAndLength(index);
return {std::next(getOperation()->operand_begin(), valueRange.first),
- std::next(getOperation()->operand_begin(),
+ std::next(getOperation()->operand_begin(),
valueRange.first + valueRange.second)};
}
@@ -162,18 +170,19 @@ public:
::mlir::Operation::result_range getStructuredResults(unsigned index) {
auto valueRange = getStructuredResultIndexAndLength(index);
return {std::next(getOperation()->result_begin(), valueRange.first),
- std::next(getOperation()->result_begin(),
+ std::next(getOperation()->result_begin(),
valueRange.first + valueRange.second)};
}
__OP_OPERAND_GETTER_DECLS__
__OP_RESULT_GETTER_DECLS__
-
+ __OP_REGION_GETTER_DECLS__
+
__OP_BUILD_DECLS__
- static void build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+ static void build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {});
static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder,
diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
index 30ca420..f4a1b7a 100644
--- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
+++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt
@@ -6,12 +6,14 @@ R"(
__NAMESPACE_OPEN__
+__OP_VERIFIER_HELPERS__
+
__OP_BUILD_DEFS__
-void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
- ::mlir::OperationState &odsState,
- ::mlir::TypeRange resultTypes,
- ::mlir::ValueRange operands,
+void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState,
+ ::mlir::TypeRange resultTypes,
+ ::mlir::ValueRange operands,
::llvm::ArrayRef<::mlir::NamedAttribute> attributes)
{
assert(operands.size() == __OP_OPERAND_COUNT__);
@@ -19,6 +21,9 @@ void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder,
odsState.addOperands(operands);
odsState.addAttributes(attributes);
odsState.addTypes(resultTypes);
+ for (unsigned i = 0; i != __OP_REGION_COUNT__; ++i) {
+ (void)odsState.addRegion();
+ }
}
__OP_CPP_NAME__
@@ -44,6 +49,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder,
return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes);
}
+__OP_VERIFIER__
__NAMESPACE_CLOSE__
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 53209a4..9fcb02e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3175,6 +3175,45 @@ applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
return success();
}
+/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the
+/// OpenMPIRBuilder.
+static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
+ llvm::OpenMPIRBuilder::LocationDescription loc(builder);
+
+ SmallVector<llvm::CanonicalLoopInfo *> translatedLoops;
+ SmallVector<llvm::Value *> translatedSizes;
+
+ for (Value size : op.getSizes()) {
+ llvm::Value *translatedSize = moduleTranslation.lookupValue(size);
+ assert(translatedSize &&
+ "sizes clause arguments must already be translated");
+ translatedSizes.push_back(translatedSize);
+ }
+
+ for (Value applyee : op.getApplyees()) {
+ llvm::CanonicalLoopInfo *consBuilderCLI =
+ moduleTranslation.lookupOMPLoop(applyee);
+ assert(applyee && "Canonical loop must already been translated");
+ translatedLoops.push_back(consBuilderCLI);
+ }
+
+ auto generatedLoops =
+ ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes);
+ if (!op.getGeneratees().empty()) {
+ for (auto [mlirLoop, genLoop] :
+ zip_equal(op.getGeneratees(), generatedLoops))
+ moduleTranslation.mapOmpLoop(mlirLoop, genLoop);
+ }
+
+ // CLIs can only be consumed once
+ for (Value applyee : op.getApplyees())
+ moduleTranslation.invalidateOmpLoop(applyee);
+
+ return success();
+}
+
/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
static llvm::AtomicOrdering
convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
@@ -6227,6 +6266,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
// the omp.canonical_loop.
return applyUnrollHeuristic(op, builder, moduleTranslation);
})
+ .Case([&](omp::TileOp op) {
+ return applyTile(op, builder, moduleTranslation);
+ })
.Case([&](omp::TargetAllocMemOp) {
return convertTargetAllocMemOp(*op, builder, moduleTranslation);
})