diff options
Diffstat (limited to 'mlir/lib')
29 files changed, 548 insertions, 172 deletions
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 8ee6308..0d56259 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -259,22 +259,23 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// -inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { +static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return *(static_cast<mlir::RewritePatternSet *>(module.ptr)); } -inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { +static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { return {module}; } -inline mlir::FrozenRewritePatternSet * +static inline mlir::FrozenRewritePatternSet * unwrap(MlirFrozenRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr); } -inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) { +static inline MlirFrozenRewritePatternSet +wrap(mlir::FrozenRewritePatternSet *module) { return {module}; } @@ -321,12 +322,12 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { //===----------------------------------------------------------------------===// #if MLIR_ENABLE_PDL_IN_PATTERNMATCH -inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { +static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { assert(module.ptr && "unexpected null module"); return static_cast<mlir::PDLPatternModule *>(module.ptr); } -inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { +static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { return {module}; } diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 7626d35..c64e10f5 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -123,7 +123,8 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality( vector::OuterProductOp, vector::ScanOp>( [&](Operation *op) { return converter.isLegal(op); }); target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp, - arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>(); + arith::ConstantOp, arith::SelectOp, vector::SplatOp, + vector::BroadcastOp>(); } void EmulateUnsupportedFloatsPass::runOnOperation() { diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 3bd763e..05fc7cb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1622,12 +1622,12 @@ static void generateCollapsedIndexingRegion( } } -void collapseOperandsAndResults(LinalgOp op, - const CollapsingInfo &collapsingInfo, - RewriterBase &rewriter, - SmallVectorImpl<Value> &inputOperands, - SmallVectorImpl<Value> &outputOperands, - SmallVectorImpl<Type> &resultTypes) { +static void collapseOperandsAndResults(LinalgOp op, + const CollapsingInfo &collapsingInfo, + RewriterBase &rewriter, + SmallVectorImpl<Value> &inputOperands, + SmallVectorImpl<Value> &outputOperands, + SmallVectorImpl<Type> &resultTypes) { Location loc = op->getLoc(); inputOperands = llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) { @@ -1651,8 +1651,8 @@ void collapseOperandsAndResults(LinalgOp op, /// Clone a `LinalgOp` to a collapsed version of same name template <typename OpTy> -OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, - const CollapsingInfo &collapsingInfo) { +static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, + const CollapsingInfo &collapsingInfo) { return nullptr; } @@ -1699,8 +1699,9 @@ GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter, return collapsedOp; } -LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, - RewriterBase &rewriter) { +static LinalgOp createCollapsedOp(LinalgOp op, + const CollapsingInfo &collapsingInfo, + RewriterBase &rewriter) { if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) { return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo); } else { diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 32ebe06..5672942 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -33,6 +33,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Support/InterleavedRange.h" #include <cstddef> #include <iterator> #include <optional> @@ -3385,6 +3386,9 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { Value result = getResult(); auto [newCli, gen, cons] = decodeCli(result); + // Structured binding `gen` cannot be captured in lambdas before C++20 + OpOperand *generator = gen; + // Derive the CLI variable name from its generator: // * "canonloop" for omp.canonical_loop // * custom name for loop transformation generatees @@ -3403,6 +3407,24 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { .Case([&](UnrollHeuristicOp op) -> std::string { llvm_unreachable("heuristic unrolling does not generate a loop"); }) + .Case([&](TileOp op) -> std::string { + auto [generateesFirst, generateesCount] = + op.getGenerateesODSOperandIndexAndLength(); + unsigned firstGrid = generateesFirst; + unsigned firstIntratile = generateesFirst + generateesCount / 2; + unsigned end = generateesFirst + generateesCount; + unsigned opnum = generator->getOperandNumber(); + // In the OpenMP apply and looprange clauses, indices are 1-based + if (firstGrid <= opnum && opnum < firstIntratile) { + unsigned gridnum = opnum - firstGrid + 1; + return ("grid" + Twine(gridnum)).str(); + } + if (firstIntratile <= opnum && opnum < end) { + unsigned intratilenum = opnum - firstIntratile + 1; + return ("intratile" + Twine(intratilenum)).str(); + } + llvm_unreachable("Unexpected generatee argument"); + }) .Default([&](Operation *op) { assert(false && "TODO: Custom name for this operation"); return "transformed"; @@ -3632,6 +3654,138 @@ UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() { } //===----------------------------------------------------------------------===// +// TileOp +//===----------------------------------------------------------------------===// + +static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, + OperandRange generatees, + OperandRange applyees) { + if (!generatees.empty()) + p << '(' << llvm::interleaved(generatees) << ')'; + + if (!applyees.empty()) + p << " <- (" << llvm::interleaved(applyees) << ')'; +} + +static ParseResult parseLoopTransformClis( + OpAsmParser &parser, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &generateesOperands, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &applyeesOperands) { + if (parser.parseOptionalLess()) { + // Syntax 1: generatees present + + if (parser.parseOperandList(generateesOperands, + mlir::OpAsmParser::Delimiter::Paren)) + return failure(); + + if (parser.parseLess()) + return failure(); + } else { + // Syntax 2: generatees omitted + } + + // Parse `<-` (`<` has already been parsed) + if (parser.parseMinus()) + return failure(); + + if (parser.parseOperandList(applyeesOperands, + mlir::OpAsmParser::Delimiter::Paren)) + return failure(); + + return success(); +} + +LogicalResult TileOp::verify() { + if (getApplyees().empty()) + return emitOpError() << "must apply to at least one loop"; + + if (getSizes().size() != getApplyees().size()) + return emitOpError() << "there must be one tile size for each applyee"; + + if (!getGeneratees().empty() && + 2 * getSizes().size() != getGeneratees().size()) + return emitOpError() + << "expecting two times the number of generatees than applyees"; + + DenseSet<Value> parentIVs; + + Value parent = getApplyees().front(); + for (auto &&applyee : llvm::drop_begin(getApplyees())) { + auto [parentCreate, parentGen, parentCons] = decodeCli(parent); + auto [create, gen, cons] = decodeCli(applyee); + + if (!parentGen) + return emitOpError() << "applyee CLI has no generator"; + + auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner()); + if (!parentGen) + return emitOpError() + << "currently only supports omp.canonical_loop as applyee"; + + parentIVs.insert(parentLoop.getInductionVar()); + + if (!gen) + return emitOpError() << "applyee CLI has no generator"; + auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner()); + if (!loop) + return emitOpError() + << "currently only supports omp.canonical_loop as applyee"; + + // Canonical loop must be perfectly nested, i.e. the body of the parent must + // only contain the omp.canonical_loop of the nested loops, and + // omp.terminator + bool isPerfectlyNested = [&]() { + auto &parentBody = parentLoop.getRegion(); + if (!parentBody.hasOneBlock()) + return false; + auto &parentBlock = parentBody.getBlocks().front(); + + auto nestedLoopIt = parentBlock.begin(); + if (nestedLoopIt == parentBlock.end() || + (&*nestedLoopIt != loop.getOperation())) + return false; + + auto termIt = std::next(nestedLoopIt); + if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt)) + return false; + + if (std::next(termIt) != parentBlock.end()) + return false; + + return true; + }(); + if (!isPerfectlyNested) + return emitOpError() << "tiled loop nest must be perfectly nested"; + + if (parentIVs.contains(loop.getTripCount())) + return emitOpError() << "tiled loop nest must be rectangular"; + + parent = applyee; + } + + // TODO: The tile sizes must be computed before the loop, but checking this + // requires dominance analysis. For instance: + // + // %canonloop = omp.new_cli + // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { + // // write to %x + // omp.terminator + // } + // %ts = llvm.load %x + // omp.tile <- (%canonloop) sizes(%ts : i32) + + return success(); +} + +std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() { + return getODSOperandIndexAndLength(odsIndex_applyees); +} + +std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() { + return getODSOperandIndexAndLength(odsIndex_generatees); +} + +//===----------------------------------------------------------------------===// // Critical construct (2.17.1) //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index eb46869..b0132e8 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -580,7 +580,7 @@ namespace { // ElideSingleElementReduction for ReduceOp. struct ElideUnitDimsInMultiDimReduction : public OpRewritePattern<MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { @@ -730,7 +730,7 @@ std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() { namespace { struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ReductionOp reductionOp, PatternRewriter &rewriter) const override { @@ -2197,7 +2197,7 @@ namespace { // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast. class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -2220,7 +2220,7 @@ public: // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask. class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -2546,7 +2546,7 @@ rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(FromElementsOp fromElements, PatternRewriter &rewriter) const override { @@ -2938,7 +2938,7 @@ namespace { // Fold broadcast1(broadcast2(x)) into broadcast1(x). struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(BroadcastOp broadcastOp, PatternRewriter &rewriter) const override { @@ -3109,7 +3109,7 @@ namespace { // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector // to a broadcast. struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { @@ -3165,7 +3165,7 @@ static Value getScalarSplatSource(Value value) { /// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v). class ShuffleSplat final : public OpRewritePattern<ShuffleOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { @@ -3182,7 +3182,7 @@ public: /// vector.interleave. class ShuffleInterleave : public OpRewritePattern<ShuffleOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { @@ -3326,7 +3326,7 @@ namespace { // broadcast. class InsertToBroadcast final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp insertOp, PatternRewriter &rewriter) const override { @@ -3344,7 +3344,7 @@ public: /// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v). class InsertSplatToSplat final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { @@ -3380,7 +3380,7 @@ public: /// %result = vector.from_elements %c1, %c2 : vector<2xi32> class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { @@ -3748,7 +3748,7 @@ namespace { class FoldInsertStridedSliceSplat final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -3768,7 +3768,7 @@ public: class FoldInsertStridedSliceOfExtract final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -3798,7 +3798,7 @@ public: class InsertStridedSliceConstantFolder final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; // Do not create constants with more than `vectorSizeFoldThreashold` elements, // unless the source vector constant has a single use. @@ -4250,7 +4250,7 @@ namespace { // %mask = vector.create_mask %new_ub : vector<8xi1> class StridedSliceCreateMaskFolder final : public OpRewritePattern<ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; public: LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, @@ -4310,7 +4310,7 @@ public: class StridedSliceConstantMaskFolder final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, PatternRewriter &rewriter) const override { @@ -4365,7 +4365,7 @@ public: class StridedSliceBroadcast final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -4416,7 +4416,7 @@ public: /// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v). class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -4448,7 +4448,7 @@ public: class ContiguousExtractStridedSliceToExtract final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -5023,7 +5023,7 @@ namespace { /// ``` struct TransferReadAfterWriteToBroadcast : public OpRewritePattern<TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -5458,7 +5458,7 @@ namespace { /// any other uses. class FoldWaw final : public OpRewritePattern<TransferWriteOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransferWriteOp writeOp, PatternRewriter &rewriter) const override { if (!llvm::isa<RankedTensorType>(writeOp.getShapedType())) @@ -5514,7 +5514,7 @@ public: struct SwapExtractSliceOfTransferWrite : public OpRewritePattern<tensor::InsertSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -5737,7 +5737,7 @@ LogicalResult MaskedLoadOp::verify() { namespace { class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskedLoadOp load, PatternRewriter &rewriter) const override { switch (getMaskFormat(load.getMask())) { @@ -5794,7 +5794,7 @@ LogicalResult MaskedStoreOp::verify() { namespace { class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskedStoreOp store, PatternRewriter &rewriter) const override { switch (getMaskFormat(store.getMask())) { @@ -5890,7 +5890,7 @@ static LogicalResult isZeroBasedContiguousSeq(Value indexVec) { namespace { class GatherFolder final : public OpRewritePattern<GatherOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(GatherOp gather, PatternRewriter &rewriter) const override { switch (getMaskFormat(gather.getMask())) { @@ -5910,7 +5910,7 @@ public: /// maskedload. Only 1D fixed vectors are supported for now. class FoldContiguousGather final : public OpRewritePattern<GatherOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override { if (!isa<MemRefType>(op.getBase().getType())) @@ -5962,7 +5962,7 @@ LogicalResult ScatterOp::verify() { namespace { class ScatterFolder final : public OpRewritePattern<ScatterOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ScatterOp scatter, PatternRewriter &rewriter) const override { switch (getMaskFormat(scatter.getMask())) { @@ -5982,7 +5982,7 @@ public: /// maskedstore. Only 1D fixed vectors are supported for now. class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ScatterOp op, PatternRewriter &rewriter) const override { if (failed(isZeroBasedContiguousSeq(op.getIndices()))) @@ -6030,7 +6030,7 @@ LogicalResult ExpandLoadOp::verify() { namespace { class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExpandLoadOp expand, PatternRewriter &rewriter) const override { switch (getMaskFormat(expand.getMask())) { @@ -6081,7 +6081,7 @@ LogicalResult CompressStoreOp::verify() { namespace { class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(CompressStoreOp compress, PatternRewriter &rewriter) const override { switch (getMaskFormat(compress.getMask())) { @@ -6260,7 +6260,7 @@ static VectorType trimTrailingOneDims(VectorType oldType) { class ShapeCastCreateMaskFolderTrailingOneDim final : public OpRewritePattern<ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShapeCastOp shapeOp, PatternRewriter &rewriter) const override { @@ -6330,7 +6330,7 @@ public: /// If both (i) and (ii) are possible, (i) is chosen. class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { @@ -6614,7 +6614,7 @@ namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6646,7 +6646,7 @@ public: /// Replace transpose(splat-like(v)) with broadcast(v) class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6663,7 +6663,7 @@ public: /// Folds transpose(create_mask) into a new transposed create_mask. class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transpOp, PatternRewriter &rewriter) const override { @@ -6700,7 +6700,7 @@ public: /// Folds transpose(shape_cast) into a new shape_cast. class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6750,7 +6750,7 @@ public: /// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6). class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1) : OpRewritePattern<vector::TransposeOp>(context, benefit) {} @@ -6971,7 +6971,7 @@ namespace { /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1> class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(CreateMaskOp createMaskOp, PatternRewriter &rewriter) const override { @@ -7300,7 +7300,7 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor, /// %0 = arith.select %mask, %a, %passthru : vector<8xf32> /// class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskOp maskOp, PatternRewriter &rewriter) const override { @@ -7410,7 +7410,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { // vector.broadcast. class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> { public: - using OpRewritePattern<SplatOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(SplatOp splatOp, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(), diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index dedc3b3..61d9357 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -34,7 +34,7 @@ namespace { /// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly. class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BroadcastOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 65702ff..efe8d14 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -1151,7 +1151,7 @@ FailureOr<Value> ContractionOpLowering::lowerReduction( /// class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 1f96a3a..6bc8347 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -50,7 +50,7 @@ namespace { /// /// Supports vector types with a fixed leading dimension. struct UnrollGather : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { @@ -98,7 +98,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> { /// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef, /// but should be fairly straightforward to extend beyond that. struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { @@ -164,7 +164,7 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> { /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these /// loads/extracts are made conditional using `scf.if` ops. struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp index 9d6a865..479fc0c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp @@ -163,7 +163,7 @@ private: /// : vector<7xi16>, vector<7xi16> /// ``` struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InterleaveOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 5617b06..7730c4e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -48,7 +48,7 @@ namespace { /// until a one-dimensional vector is reached. class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { @@ -100,7 +100,7 @@ public: /// will be folded at LLVM IR level. class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ConstantMaskOp op, PatternRewriter &rewriter) const override { @@ -184,7 +184,7 @@ namespace { /// and actually match the traits of its the nested `MaskableOpInterface`. template <class SourceOp> struct MaskOpRewritePattern : OpRewritePattern<MaskOp> { - using OpRewritePattern<MaskOp>::OpRewritePattern; + using Base::Base; private: LogicalResult matchAndRewrite(MaskOp maskOp, diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 4773732d..e86e2a9 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -39,7 +39,7 @@ namespace { class InnerOuterDimReductionConversion : public OpRewritePattern<vector::MultiDimReductionOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; explicit InnerOuterDimReductionConversion( MLIRContext *context, vector::VectorMultiReductionLowering options, @@ -136,7 +136,7 @@ private: class ReduceMultiDimReductionRank : public OpRewritePattern<vector::MultiDimReductionOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; explicit ReduceMultiDimReductionRank( MLIRContext *context, vector::VectorMultiReductionLowering options, @@ -304,7 +304,7 @@ private: /// and combines results struct TwoDimMultiReductionToElementWise : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -359,7 +359,7 @@ struct TwoDimMultiReductionToElementWise /// a sequence of vector.reduction ops. struct TwoDimMultiReductionToReduction : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -420,7 +420,7 @@ struct TwoDimMultiReductionToReduction /// separately. struct OneDimMultiReductionToTwoDim : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index af4851e..258f2cb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -99,7 +99,7 @@ namespace { /// return %7, %8 : vector<2x3xi32>, vector<2xi32> /// ``` struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ScanOp scanOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp index 603ea41..c5f22b2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -189,7 +189,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { } public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { @@ -356,7 +356,7 @@ public: class ScalableShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp index 78102f7..8f46ad6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp @@ -44,7 +44,7 @@ namespace { /// struct MixedSizeInputShuffleOpRewrite final : OpRewritePattern<vector::ShuffleOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp index ee5568a..08e7c89 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp @@ -24,7 +24,7 @@ using namespace mlir::vector; namespace { struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::StepOp stepOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index 6407a86..7521e24 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -667,7 +667,7 @@ getToElementsDefiningOps(FromElementsOp fromElemsOp, struct ToFromElementsToShuffleTreeRewrite final : OpRewritePattern<vector::FromElementsOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 9e7d0ce..c3f7de0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -300,7 +300,7 @@ namespace { /// %x = vector.insert .., .. [.., ..] class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering, MLIRContext *context, PatternBenefit benefit = 1) @@ -395,7 +395,7 @@ private: class Transpose2DWithUnitDimToShapeCast : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; Transpose2DWithUnitDimToShapeCast(MLIRContext *context, PatternBenefit benefit = 1) @@ -433,7 +433,7 @@ public: class TransposeOp2DToShuffleLowering : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; TransposeOp2DToShuffleLowering( vector::VectorTransposeLowering vectorTransposeLowering, diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index cab1289..963b2c8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -54,7 +54,7 @@ namespace { // input by inserting vector.broadcast. struct CastAwayExtractStridedSliceLeadingOneDim : public OpRewritePattern<vector::ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { @@ -104,7 +104,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim // inputs by inserting vector.broadcast. struct CastAwayInsertStridedSliceLeadingOneDim : public OpRewritePattern<vector::InsertStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -145,7 +145,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim // Casts away leading one dimensions in vector.insert's vector inputs by // inserting vector.broadcast. struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InsertOp insertOp, PatternRewriter &rewriter) const override { @@ -221,7 +221,7 @@ static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, // 1 dimensions. struct CastAwayTransferReadLeadingOneDim : public OpRewritePattern<vector::TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { @@ -275,7 +275,7 @@ struct CastAwayTransferReadLeadingOneDim // 1 dimensions. struct CastAwayTransferWriteLeadingOneDim : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { @@ -541,7 +541,7 @@ public: // vector.broadcast back to the original shape. struct CastAwayConstantMaskLeadingOneDim : public OpRewritePattern<vector::ConstantMaskOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ConstantMaskOp mask, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp index bdbb792..7acc120 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp @@ -48,7 +48,7 @@ namespace { /// struct VectorMaskedLoadOpConverter final : OpRewritePattern<vector::MaskedLoadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp, PatternRewriter &rewriter) const override { @@ -117,7 +117,7 @@ struct VectorMaskedLoadOpConverter final /// struct VectorMaskedStoreOpConverter final : OpRewritePattern<vector::MaskedStoreOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 264cbc1..3a6684f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -548,7 +548,7 @@ namespace { // NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to // `false` to generate non-atomic RMW sequences. struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW) : OpConversionPattern<vector::StoreOp>(context), @@ -827,7 +827,7 @@ private: /// adjusted mask . struct ConvertVectorMaskedStore final : OpConversionPattern<vector::MaskedStoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor, @@ -950,7 +950,7 @@ struct ConvertVectorMaskedStore final /// those cases, loads are converted to byte-aligned, byte-sized loads and the /// target vector is extracted from the loaded vector. struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor, @@ -1059,7 +1059,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> { /// bitcasting, since each `i8` container element holds two `i4` values. struct ConvertVectorMaskedLoad final : OpConversionPattern<vector::MaskedLoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor, @@ -1257,7 +1257,7 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy, // TODO: Document-me struct ConvertVectorTransferRead final : OpConversionPattern<vector::TransferReadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor, @@ -1942,7 +1942,7 @@ namespace { /// advantage of high-level information to avoid leaving LLVM to scramble with /// peephole optimizations. struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp, PatternRewriter &rewriter) const override { @@ -2147,7 +2147,7 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> { /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4> /// struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> { - using OpRewritePattern<arith::TruncIOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(arith::TruncIOp truncOp, PatternRewriter &rewriter) const override { @@ -2200,7 +2200,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> { /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4> /// struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; + using Base::Base; RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit) : OpRewritePattern<vector::TransposeOp>(context, benefit) {} diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index f6d6555..9e49873 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -34,7 +34,7 @@ using namespace mlir::vector; class DecomposeDifferentRankInsertStridedSlice : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -84,7 +84,7 @@ public: class ConvertSameRankInsertStridedSliceIntoShuffle : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; + using Base::Base; void initialize() { // This pattern creates recursive InsertStridedSliceOp, but the recursion is @@ -183,7 +183,7 @@ public: class Convert1DExtractStridedSliceIntoShuffle : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -271,7 +271,7 @@ private: class DecomposeNDExtractStridedSlice : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; + using Base::Base; void initialize() { // This pattern creates recursive ExtractStridedSliceOp, but the recursion diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 82bac8c..71fba71c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -214,7 +214,7 @@ SmallVector<int64_t> static getStridedSliceInsertionIndices( /// vector.extract_strided_slice operation. struct LinearizeVectorExtractStridedSlice final : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -285,7 +285,7 @@ struct LinearizeVectorExtractStridedSlice final /// struct LinearizeVectorInsertStridedSlice final : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -348,7 +348,7 @@ struct LinearizeVectorInsertStridedSlice final /// of the original shuffle operation. struct LinearizeVectorShuffle final : public OpConversionPattern<vector::ShuffleOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorShuffle(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -423,7 +423,7 @@ struct LinearizeVectorShuffle final /// struct LinearizeVectorExtract final : public OpConversionPattern<vector::ExtractOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorExtract(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -501,7 +501,7 @@ struct LinearizeVectorExtract final /// struct LinearizeVectorInsert final : public OpConversionPattern<vector::InsertOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorInsert(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -575,7 +575,7 @@ struct LinearizeVectorInsert final /// %out_nd = vector.shape_cast %out_1d: vector<16xf16> to vector<4x4xf16> struct LinearizeVectorBitCast final : public OpConversionPattern<vector::BitCastOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorBitCast(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -598,7 +598,7 @@ struct LinearizeVectorBitCast final /// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> struct LinearizeVectorSplat final : public OpConversionPattern<vector::SplatOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -629,7 +629,7 @@ struct LinearizeVectorSplat final /// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> struct LinearizeVectorCreateMask final : OpConversionPattern<vector::CreateMaskOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorCreateMask(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -684,7 +684,7 @@ struct LinearizeVectorCreateMask final /// For generic cases, the vector unroll pass should be used to unroll the load /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -731,7 +731,7 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> { /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorStore final : public OpConversionPattern<vector::StoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorStore(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -778,7 +778,7 @@ struct LinearizeVectorStore final /// struct LinearizeVectorFromElements final : public OpConversionPattern<vector::FromElementsOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorFromElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -814,7 +814,7 @@ struct LinearizeVectorFromElements final /// struct LinearizeVectorToElements final : public OpConversionPattern<vector::ToElementsOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorToElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index c364a8b..1121d95 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -1081,7 +1081,7 @@ private: /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>) /// to memref.store. class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 866f789..d6a6d7cd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -78,7 +78,7 @@ namespace { /// ``` struct MultiReduceToContract : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp, PatternRewriter &rewriter) const override { @@ -138,7 +138,7 @@ struct MultiReduceToContract /// ``` struct CombineContractABTranspose final : public OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -202,7 +202,7 @@ struct CombineContractABTranspose final /// ``` struct CombineContractResultTranspose final : public OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp resTOp, PatternRewriter &rewriter) const override { @@ -568,7 +568,7 @@ static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) { // %2 = vector.extract %1[1] : f16 from vector<2xf16> struct BubbleDownVectorBitCastForExtract : public OpRewritePattern<vector::ExtractOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -643,7 +643,7 @@ struct BubbleDownVectorBitCastForExtract // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> struct BubbleDownBitCastForStridedSliceExtract : public OpRewritePattern<vector::ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { @@ -721,7 +721,7 @@ struct BubbleDownBitCastForStridedSliceExtract // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8> // struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, PatternRewriter &rewriter) const override { @@ -794,7 +794,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> struct BubbleUpBitCastForStridedSliceInsert : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, PatternRewriter &rewriter) const override { @@ -892,7 +892,7 @@ struct BubbleUpBitCastForStridedSliceInsert // %7 = vector.insert_strided_slice %6, %cst { // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; public: BreakDownVectorBitCast(MLIRContext *context, @@ -1131,7 +1131,7 @@ struct ReorderElementwiseOpsOnBroadcast final class ExtractOpFromElementwise final : public OpRewritePattern<vector::ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { @@ -1206,7 +1206,7 @@ static bool isSupportedMemSinkElementType(Type type) { /// ``` class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { @@ -1285,7 +1285,7 @@ public: class StoreOpFromSplatOrBroadcast final : public OpRewritePattern<vector::StoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::StoreOp op, PatternRewriter &rewriter) const override { @@ -1476,7 +1476,7 @@ static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) { /// InstCombine seems to handle vectors with multiple elements but not the /// single element ones. struct FoldI1Select : public OpRewritePattern<arith::SelectOp> { - using OpRewritePattern<arith::SelectOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(arith::SelectOp selectOp, PatternRewriter &rewriter) const override { @@ -1560,7 +1560,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { /// Drop inner most contiguous unit dimensions from transfer_read operand. class DropInnerMostUnitDimsTransferRead : public OpRewritePattern<vector::TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -1651,7 +1651,7 @@ class DropInnerMostUnitDimsTransferRead /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`). class DropInnerMostUnitDimsTransferWrite : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override { @@ -1728,7 +1728,7 @@ class DropInnerMostUnitDimsTransferWrite /// with the RHS transposed) lowering. struct CanonicalizeContractMatmulToMMT final : OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; using FilterConstraintType = std::function<LogicalResult(vector::ContractionOp op)>; @@ -1845,7 +1845,7 @@ private: template <typename ExtOp> struct FoldArithExtIntoContractionOp : public OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -1878,7 +1878,7 @@ struct FoldArithExtIntoContractionOp /// %b = vector.reduction <add> %a, %acc /// ``` struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { @@ -2033,7 +2033,7 @@ struct DropUnitDimFromElementwiseOps final /// ``` struct DropUnitDimsFromTransposeOp final : OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { @@ -2110,7 +2110,7 @@ struct DropUnitDimsFromTransposeOp final /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> /// ``` struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override { @@ -2155,7 +2155,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> { /// %c = vector.reduction <add> %b, %acc /// ``` struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9413a92..784e5d6 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -824,7 +824,7 @@ struct WgToSgStoreScatterOpWithOffset return failure(); xegpu::DistributeLayoutAttr layout = - xegpu::getDistributeLayoutAttr(op.getValue()); + xegpu::getDistributeLayoutAttr(op.getOperand(0)); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -844,12 +844,19 @@ struct WgToSgStoreScatterOpWithOffset auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); for (auto [val, offs, mask] : llvm::zip( adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { - xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs, - mask, chunkSizeAttr, op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + auto store = xegpu::StoreScatterOp::create( + rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr, + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); // Update the layout attribute to drop sg_layout and sg_data. - if (auto newLayout = layout.dropSgLayoutAndData()) - op->setAttr("layout", newLayout); + if (!layout.getEffectiveLaneLayoutAsInt().empty() || + !layout.getEffectiveInstDataAsInt().empty()) { + for (OpOperand &operand : store->getOpOperands()) { + // Skip for operand one (memref) + if (operand.getOperandNumber() == 1) + continue; + xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); + } + } } rewriter.eraseOp(op); return success(); @@ -1247,10 +1254,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { target.addDynamicallyLegalOp<xegpu::StoreScatterOp>( [=](xegpu::StoreScatterOp op) -> bool { - // Check if the layout attribute is present on the result. - auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout"); - if (!layout) - return true; + auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0)); return isLegal(layout); }); diff --git a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp index d6b8a8a..e3f075f 100644 --- a/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp +++ b/mlir/lib/Target/IRDLToCpp/IRDLToCpp.cpp @@ -54,6 +54,7 @@ struct OpStrings { std::string opCppName; SmallVector<std::string> opResultNames; SmallVector<std::string> opOperandNames; + SmallVector<std::string> opRegionNames; }; static std::string joinNameList(llvm::ArrayRef<std::string> names) { @@ -87,8 +88,8 @@ static TypeStrings getStrings(irdl::TypeOp type) { /// Generates OpStrings from an OperatioOp static OpStrings getStrings(irdl::OperationOp op) { auto operandOp = op.getOp<irdl::OperandsOp>(); - auto resultOp = op.getOp<irdl::ResultsOp>(); + auto regionsOp = op.getOp<irdl::RegionsOp>(); OpStrings strings; strings.opName = op.getSymName(); @@ -108,6 +109,13 @@ static OpStrings getStrings(irdl::OperationOp op) { })); } + if (regionsOp) { + strings.opRegionNames = SmallVector<std::string>( + llvm::map_range(regionsOp->getNames(), [](Attribute attr) { + return llvm::formatv("{0}", cast<StringAttr>(attr)); + })); + } + return strings; } @@ -122,6 +130,7 @@ static void fillDict(irdl::detail::dictionary &dict, static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) { const auto operandCount = strings.opOperandNames.size(); const auto resultCount = strings.opResultNames.size(); + const auto regionCount = strings.opRegionNames.size(); dict["OP_NAME"] = strings.opName; dict["OP_CPP_NAME"] = strings.opCppName; @@ -131,6 +140,7 @@ static void fillDict(irdl::detail::dictionary &dict, const OpStrings &strings) { operandCount ? joinNameList(strings.opOperandNames) : "{\"\"}"; dict["OP_RESULT_INITIALIZER_LIST"] = resultCount ? joinNameList(strings.opResultNames) : "{\"\"}"; + dict["OP_REGION_COUNT"] = std::to_string(regionCount); } /// Fills a dictionary with values from DialectStrings @@ -179,6 +189,8 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, const OpStrings &opStrings) { auto opGetters = std::string{}; auto resGetters = std::string{}; + auto regionGetters = std::string{}; + auto regionAdaptorGetters = std::string{}; for (size_t i = 0, end = opStrings.opOperandNames.size(); i < end; ++i) { const auto op = @@ -196,8 +208,23 @@ static void generateOpGetterDeclarations(irdl::detail::dictionary &dict, op, i); } + for (size_t i = 0, end = opStrings.opRegionNames.size(); i < end; ++i) { + const auto op = + llvm::convertToCamelFromSnakeCase(opStrings.opRegionNames[i], true); + regionAdaptorGetters += llvm::formatv( + R"(::mlir::Region &get{0}() { return *getRegions()[{1}]; } + )", + op, i); + regionGetters += llvm::formatv( + R"(::mlir::Region &get{0}() { return (*this)->getRegion({1}); } + )", + op, i); + } + dict["OP_OPERAND_GETTER_DECLS"] = opGetters; dict["OP_RESULT_GETTER_DECLS"] = resGetters; + dict["OP_REGION_ADAPTER_GETTER_DECLS"] = regionAdaptorGetters; + dict["OP_REGION_GETTER_DECLS"] = regionGetters; } static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, @@ -238,6 +265,22 @@ static void generateOpBuilderDeclarations(irdl::detail::dictionary &dict, dict["OP_BUILD_DECLS"] = buildDecls; } +// add traits to the dictionary, return true if any were added +static SmallVector<std::string> generateTraits(irdl::OperationOp op, + const OpStrings &strings) { + SmallVector<std::string> cppTraitNames; + if (!strings.opRegionNames.empty()) { + cppTraitNames.push_back( + llvm::formatv("::mlir::OpTrait::NRegions<{0}>::Impl", + strings.opRegionNames.size()) + .str()); + + // Requires verifyInvariantsImpl is implemented on the op + cppTraitNames.emplace_back("::mlir::OpTrait::OpInvariants"); + } + return cppTraitNames; +} + static LogicalResult generateOperationInclude(irdl::OperationOp op, raw_ostream &output, irdl::detail::dictionary &dict) { @@ -247,6 +290,13 @@ static LogicalResult generateOperationInclude(irdl::OperationOp op, const auto opStrings = getStrings(op); fillDict(dict, opStrings); + SmallVector<std::string> traitNames = generateTraits(op, opStrings); + if (traitNames.empty()) + dict["OP_TEMPLATE_ARGS"] = opStrings.opCppName; + else + dict["OP_TEMPLATE_ARGS"] = llvm::formatv("{0}, {1}", opStrings.opCppName, + llvm::join(traitNames, ", ")); + generateOpGetterDeclarations(dict, opStrings); generateOpBuilderDeclarations(dict, opStrings); @@ -301,6 +351,110 @@ static LogicalResult generateInclude(irdl::DialectOp dialect, return success(); } +static void generateRegionConstraintVerifiers( + irdl::detail::dictionary &dict, irdl::OperationOp op, + const OpStrings &strings, SmallVectorImpl<std::string> &verifierHelpers, + SmallVectorImpl<std::string> &verifierCalls) { + auto regionsOp = op.getOp<irdl::RegionsOp>(); + if (strings.opRegionNames.empty() || !regionsOp) + return; + + for (size_t i = 0; i < strings.opRegionNames.size(); ++i) { + std::string regionName = strings.opRegionNames[i]; + std::string helperFnName = + llvm::formatv("__mlir_irdl_local_region_constraint_{0}_{1}", + strings.opCppName, regionName) + .str(); + + // Extract the actual region constraint from the IRDL RegionOp + std::string condition = "true"; + std::string textualConditionName = "any region"; + + if (auto regionDefOp = + dyn_cast<irdl::RegionOp>(regionsOp->getArgs()[i].getDefiningOp())) { + // Generate constraint condition based on RegionOp attributes + SmallVector<std::string> conditionParts; + SmallVector<std::string> descriptionParts; + + // Check number of blocks constraint + if (auto blockCount = regionDefOp.getNumberOfBlocks()) { + conditionParts.push_back( + llvm::formatv("region.getBlocks().size() == {0}", + blockCount.value()) + .str()); + descriptionParts.push_back( + llvm::formatv("exactly {0} block(s)", blockCount.value()).str()); + } + + // Check entry block arguments constraint + if (regionDefOp.getConstrainedArguments()) { + size_t expectedArgCount = regionDefOp.getEntryBlockArgs().size(); + conditionParts.push_back( + llvm::formatv("region.getNumArguments() == {0}", expectedArgCount) + .str()); + descriptionParts.push_back( + llvm::formatv("{0} entry block argument(s)", expectedArgCount) + .str()); + } + + // Combine conditions + if (!conditionParts.empty()) { + condition = llvm::join(conditionParts, " && "); + } + + // Generate descriptive error message + if (!descriptionParts.empty()) { + textualConditionName = + llvm::formatv("region with {0}", + llvm::join(descriptionParts, " and ")) + .str(); + } + } + + verifierHelpers.push_back(llvm::formatv( + R"(static ::llvm::LogicalResult {0}(::mlir::Operation *op, ::mlir::Region ®ion, ::llvm::StringRef regionName, unsigned regionIndex) {{ + if (!({1})) {{ + return op->emitOpError("region #") << regionIndex + << (regionName.empty() ? " " : " ('" + regionName + "') ") + << "failed to verify constraint: {2}"; + } + return ::mlir::success(); +})", + helperFnName, condition, textualConditionName)); + + verifierCalls.push_back(llvm::formatv(R"( + if (::mlir::failed({0}(*this, (*this)->getRegion({1}), "{2}", {1}))) + return ::mlir::failure();)", + helperFnName, i, regionName) + .str()); + } +} + +static void generateVerifiers(irdl::detail::dictionary &dict, + irdl::OperationOp op, const OpStrings &strings) { + SmallVector<std::string> verifierHelpers; + SmallVector<std::string> verifierCalls; + + generateRegionConstraintVerifiers(dict, op, strings, verifierHelpers, + verifierCalls); + + // Add an overall verifier that sequences the helper calls + std::string verifierDef = + llvm::formatv(R"( +::llvm::LogicalResult {0}::verifyInvariantsImpl() {{ + if(::mlir::failed(verify())) + return ::mlir::failure(); + + {1} + + return ::mlir::success(); +})", + strings.opCppName, llvm::join(verifierCalls, "\n")); + + dict["OP_VERIFIER_HELPERS"] = llvm::join(verifierHelpers, "\n"); + dict["OP_VERIFIER"] = verifierDef; +} + static std::string generateOpDefinition(irdl::detail::dictionary &dict, irdl::OperationOp op) { static const auto perOpDefTemplate = mlir::irdl::detail::Template{ @@ -370,6 +524,8 @@ void {0}::build(::mlir::OpBuilder &opBuilder, ::mlir::OperationState &opState, { dict["OP_BUILD_DEFS"] = buildDefinition; + generateVerifiers(dict, op, opStrings); + std::string str; llvm::raw_string_ostream stream{str}; perOpDefTemplate.render(stream, dict); @@ -427,7 +583,7 @@ static LogicalResult generateLib(irdl::DialectOp dialect, raw_ostream &output, dict["TYPE_PARSER"] = llvm::formatv( R"(static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef *mnemonic, ::mlir::Type &value) { return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser) - {0} + {0} .Default([&](llvm::StringRef keyword, llvm::SMLoc) {{ *mnemonic = keyword; return std::nullopt; @@ -520,6 +676,8 @@ static LogicalResult verifySupported(irdl::DialectOp dialect) { "IRDL C++ translation does not yet support variadic results"); })) .Case<irdl::AnyOp>(([](irdl::AnyOp) { return success(); })) + .Case<irdl::RegionOp>(([](irdl::RegionOp) { return success(); })) + .Case<irdl::RegionsOp>(([](irdl::RegionsOp) { return success(); })) .Default([](mlir::Operation *op) -> LogicalResult { return op->emitError("IRDL C++ translation does not yet support " "translation of ") diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt index e9068e9..93ce0be 100644 --- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt +++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDecl.txt @@ -12,15 +12,15 @@ public: struct Properties { }; public: - __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op) - : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()), - odsRegions(op->getRegions()) + __OP_CPP_NAME__GenericAdaptorBase(::mlir::Operation *op) + : odsAttrs(op->getRawDictionaryAttrs()), odsOpName(op->getName()), + odsRegions(op->getRegions()) {} /// Return the unstructured operand index of a structured operand along with // the amount of unstructured operands it contains. std::pair<unsigned, unsigned> - getStructuredOperandIndexAndLength (unsigned index, + getStructuredOperandIndexAndLength (unsigned index, unsigned odsOperandsSize) { return {index, 1}; } @@ -32,6 +32,12 @@ public: ::mlir::DictionaryAttr getAttributes() { return odsAttrs; } + + __OP_REGION_ADAPTER_GETTER_DECLS__ + + ::mlir::RegionRange getRegions() { + return odsRegions; + } protected: ::mlir::DictionaryAttr odsAttrs; ::std::optional<::mlir::OperationName> odsOpName; @@ -42,28 +48,28 @@ protected: } // namespace detail template <typename RangeT> -class __OP_CPP_NAME__GenericAdaptor +class __OP_CPP_NAME__GenericAdaptor : public detail::__OP_CPP_NAME__GenericAdaptorBase { using ValueT = ::llvm::detail::ValueOfRange<RangeT>; using Base = detail::__OP_CPP_NAME__GenericAdaptorBase; public: __OP_CPP_NAME__GenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, - ::mlir::OpaqueProperties properties, - ::mlir::RegionRange regions = {}) - : __OP_CPP_NAME__GenericAdaptor(values, attrs, - (properties ? *properties.as<::mlir::EmptyProperties *>() + ::mlir::OpaqueProperties properties, + ::mlir::RegionRange regions = {}) + : __OP_CPP_NAME__GenericAdaptor(values, attrs, + (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} - __OP_CPP_NAME__GenericAdaptor(RangeT values, + __OP_CPP_NAME__GenericAdaptor(RangeT values, const __OP_CPP_NAME__GenericAdaptorBase &base) : Base(base), odsOperands(values) {} - // This template parameter allows using __OP_CPP_NAME__ which is declared + // This template parameter allows using __OP_CPP_NAME__ which is declared // later. template <typename LateInst = __OP_CPP_NAME__, typename = std::enable_if_t< std::is_same_v<LateInst, __OP_CPP_NAME__>>> - __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op) + __OP_CPP_NAME__GenericAdaptor(RangeT values, LateInst op) : Base(op), odsOperands(values) {} /// Return the unstructured operand index of a structured operand along with @@ -77,7 +83,7 @@ public: RangeT getStructuredOperands(unsigned index) { auto valueRange = getStructuredOperandIndexAndLength(index); return {std::next(odsOperands.begin(), valueRange.first), - std::next(odsOperands.begin(), + std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; } @@ -91,7 +97,7 @@ private: RangeT odsOperands; }; -class __OP_CPP_NAME__Adaptor +class __OP_CPP_NAME__Adaptor : public __OP_CPP_NAME__GenericAdaptor<::mlir::ValueRange> { public: using __OP_CPP_NAME__GenericAdaptor::__OP_CPP_NAME__GenericAdaptor; @@ -100,7 +106,7 @@ public: ::llvm::LogicalResult verify(::mlir::Location loc); }; -class __OP_CPP_NAME__ : public ::mlir::Op<__OP_CPP_NAME__> { +class __OP_CPP_NAME__ : public ::mlir::Op<__OP_TEMPLATE_ARGS__> { public: using Op::Op; using Op::print; @@ -112,6 +118,8 @@ public: return {}; } + ::llvm::LogicalResult verifyInvariantsImpl(); + static constexpr ::llvm::StringLiteral getOperationName() { return ::llvm::StringLiteral("__DIALECT_NAME__.__OP_NAME__"); } @@ -147,7 +155,7 @@ public: ::mlir::Operation::operand_range getStructuredOperands(unsigned index) { auto valueRange = getStructuredOperandIndexAndLength(index); return {std::next(getOperation()->operand_begin(), valueRange.first), - std::next(getOperation()->operand_begin(), + std::next(getOperation()->operand_begin(), valueRange.first + valueRange.second)}; } @@ -162,18 +170,19 @@ public: ::mlir::Operation::result_range getStructuredResults(unsigned index) { auto valueRange = getStructuredResultIndexAndLength(index); return {std::next(getOperation()->result_begin(), valueRange.first), - std::next(getOperation()->result_begin(), + std::next(getOperation()->result_begin(), valueRange.first + valueRange.second)}; } __OP_OPERAND_GETTER_DECLS__ __OP_RESULT_GETTER_DECLS__ - + __OP_REGION_GETTER_DECLS__ + __OP_BUILD_DECLS__ - static void build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, - ::mlir::TypeRange resultTypes, - ::mlir::ValueRange operands, + static void build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, + ::mlir::TypeRange resultTypes, + ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); static __OP_CPP_NAME__ create(::mlir::OpBuilder &odsBuilder, diff --git a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt index 30ca420..f4a1b7a 100644 --- a/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt +++ b/mlir/lib/Target/IRDLToCpp/Templates/PerOperationDef.txt @@ -6,12 +6,14 @@ R"( __NAMESPACE_OPEN__ +__OP_VERIFIER_HELPERS__ + __OP_BUILD_DEFS__ -void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, - ::mlir::OperationState &odsState, - ::mlir::TypeRange resultTypes, - ::mlir::ValueRange operands, +void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, + ::mlir::OperationState &odsState, + ::mlir::TypeRange resultTypes, + ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) { assert(operands.size() == __OP_OPERAND_COUNT__); @@ -19,6 +21,9 @@ void __OP_CPP_NAME__::build(::mlir::OpBuilder &odsBuilder, odsState.addOperands(operands); odsState.addAttributes(attributes); odsState.addTypes(resultTypes); + for (unsigned i = 0; i != __OP_REGION_COUNT__; ++i) { + (void)odsState.addRegion(); + } } __OP_CPP_NAME__ @@ -44,6 +49,7 @@ __OP_CPP_NAME__::create(::mlir::ImplicitLocOpBuilder &odsBuilder, return create(odsBuilder, odsBuilder.getLoc(), resultTypes, operands, attributes); } +__OP_VERIFIER__ __NAMESPACE_CLOSE__ diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 53209a4..9fcb02e 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -3175,6 +3175,45 @@ applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder, return success(); } +/// Apply a `#pragma omp tile` / `!$omp tile` transformation using the +/// OpenMPIRBuilder. +static LogicalResult applyTile(omp::TileOp op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); + llvm::OpenMPIRBuilder::LocationDescription loc(builder); + + SmallVector<llvm::CanonicalLoopInfo *> translatedLoops; + SmallVector<llvm::Value *> translatedSizes; + + for (Value size : op.getSizes()) { + llvm::Value *translatedSize = moduleTranslation.lookupValue(size); + assert(translatedSize && + "sizes clause arguments must already be translated"); + translatedSizes.push_back(translatedSize); + } + + for (Value applyee : op.getApplyees()) { + llvm::CanonicalLoopInfo *consBuilderCLI = + moduleTranslation.lookupOMPLoop(applyee); + assert(applyee && "Canonical loop must already been translated"); + translatedLoops.push_back(consBuilderCLI); + } + + auto generatedLoops = + ompBuilder->tileLoops(loc.DL, translatedLoops, translatedSizes); + if (!op.getGeneratees().empty()) { + for (auto [mlirLoop, genLoop] : + zip_equal(op.getGeneratees(), generatedLoops)) + moduleTranslation.mapOmpLoop(mlirLoop, genLoop); + } + + // CLIs can only be consumed once + for (Value applyee : op.getApplyees()) + moduleTranslation.invalidateOmpLoop(applyee); + + return success(); +} + /// Convert an Atomic Ordering attribute to llvm::AtomicOrdering. static llvm::AtomicOrdering convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) { @@ -6227,6 +6266,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, // the omp.canonical_loop. return applyUnrollHeuristic(op, builder, moduleTranslation); }) + .Case([&](omp::TileOp op) { + return applyTile(op, builder, moduleTranslation); + }) .Case([&](omp::TargetAllocMemOp) { return convertTargetAllocMemOp(*op, builder, moduleTranslation); }) |