diff options
Diffstat (limited to 'mlir/lib/Dialect')
25 files changed, 526 insertions, 194 deletions
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 7cfd6d3..898d76c 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1282,6 +1282,13 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_OneFloat())) return getLhs(); + if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan | + arith::FastMathFlags::nsz)) { + // mulf(x, 0) -> 0 + if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat())) + return getRhs(); + } + return constFoldBinaryOp<FloatAttr>( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a * b; }); diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 7626d35..c64e10f5 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -123,7 +123,8 @@ void mlir::arith::populateEmulateUnsupportedFloatsLegality( vector::OuterProductOp, vector::ScanOp>( [&](Operation *op) { return converter.isLegal(op); }); target.addLegalOp<arith::BitcastOp, arith::ExtFOp, arith::TruncFOp, - arith::ConstantOp, vector::SplatOp, vector::BroadcastOp>(); + arith::ConstantOp, arith::SelectOp, vector::SplatOp, + vector::BroadcastOp>(); } void EmulateUnsupportedFloatsPass::runOnOperation() { diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index a173cf1..5672942 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -33,6 +33,7 @@ #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" +#include "llvm/Support/InterleavedRange.h" #include <cstddef> #include <iterator> #include <optional> @@ -77,6 +78,232 @@ struct LLVMPointerPointerLikeModel }; } // namespace +/// Generate a name 
of a canonical loop nest of the format +/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region +/// argument index of an operation that has multiple regions, if the operation +/// has multiple regions. +/// `_s<idx>` identifies the position of an operation within a region, where +/// only operations that may potentially contain loops ("container operations" +/// i.e. have region arguments) are counted. Again, it is omitted if there is +/// only one such operation in a region. If there are canonical loops nested +/// inside each other, it may also use the format `_d<num>` where <num> is the +/// nesting depth of the loop. +/// +/// The generated name is a best-effort to make canonical loop unique within an +/// SSA namespace. This also means that regions with IsolatedFromAbove property +/// do not consider any parents or siblings. +static std::string generateLoopNestingName(StringRef prefix, + CanonicalLoopOp op) { + struct Component { + /// If true, this component describes a region operand of an operation (the + /// operand's owner). If false, this component describes an operation located + /// in a parent region + bool isRegionArgOfOp; + bool skip = false; + bool isUnique = false; + + size_t idx; + Operation *op; + Region *parentRegion; + size_t loopDepth; + + Operation *&getOwnerOp() { + assert(isRegionArgOfOp && "Must describe a region operand"); + return op; + } + size_t &getArgIdx() { + assert(isRegionArgOfOp && "Must describe a region operand"); + return idx; + } + + Operation *&getContainerOp() { + assert(!isRegionArgOfOp && "Must describe a operation of a region"); + return op; + } + size_t &getOpPos() { + assert(!isRegionArgOfOp && "Must describe a operation of a region"); + return idx; + } + bool isLoopOp() const { + assert(!isRegionArgOfOp && "Must describe a operation of a region"); + return isa<CanonicalLoopOp>(op); + } + Region *&getParentRegion() { + assert(!isRegionArgOfOp && "Must describe a operation of a region"); + return 
parentRegion; + } + size_t &getLoopDepth() { + assert(!isRegionArgOfOp && "Must describe a operation of a region"); + return loopDepth; + } + + void skipIf(bool v = true) { skip = skip || v; } + }; + + // List of ancestors, from inner to outer. + // Alternates between + // * region argument of an operation + // * operation within a region + SmallVector<Component> components; + + // Gather a list of parent regions and operations, and the position within + // their parent + Operation *o = op.getOperation(); + while (o) { + // Operation within a region + Region *r = o->getParentRegion(); + if (!r) + break; + + llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front()); + size_t idx = 0; + bool found = false; + size_t sequentialIdx = -1; + bool isOnlyContainerOp = true; + for (Block *b : traversal) { + for (Operation &op : *b) { + if (&op == o && !found) { + sequentialIdx = idx; + found = true; + } + if (op.getNumRegions()) { + idx += 1; + if (idx > 1) + isOnlyContainerOp = false; + } + if (found && !isOnlyContainerOp) + break; + } + } + + Component &containerOpInRegion = components.emplace_back(); + containerOpInRegion.isRegionArgOfOp = false; + containerOpInRegion.isUnique = isOnlyContainerOp; + containerOpInRegion.getContainerOp() = o; + containerOpInRegion.getOpPos() = sequentialIdx; + containerOpInRegion.getParentRegion() = r; + + Operation *parent = r->getParentOp(); + + // Region argument of an operation + Component &regionArgOfOperation = components.emplace_back(); + regionArgOfOperation.isRegionArgOfOp = true; + regionArgOfOperation.isUnique = true; + regionArgOfOperation.getArgIdx() = 0; + regionArgOfOperation.getOwnerOp() = parent; + + // The IsolatedFromAbove trait of the parent operation implies that each + // individual region argument has its own separate namespace, so no + // ambiguity. + if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>()) + break; + + // Component only needed if operation has multiple region operands. 
Region + // arguments may be optional, but we currently do not consider this. + if (parent->getRegions().size() > 1) { + auto getRegionIndex = [](Operation *o, Region *r) { + for (auto [idx, region] : llvm::enumerate(o->getRegions())) { + if (&region == r) + return idx; + } + llvm_unreachable("Region not child of its parent operation"); + }; + regionArgOfOperation.isUnique = false; + regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r); + } + + // next parent + o = parent; + } + + // Determine whether a region-argument component is not needed + for (Component &c : components) + c.skipIf(c.isRegionArgOfOp && c.isUnique); + + // Find runs of nested loops and determine each loop's depth in the loop nest + size_t numSurroundingLoops = 0; + for (Component &c : llvm::reverse(components)) { + if (c.skip) + continue; + + // non-skipped multi-argument operands interrupt the loop nest + if (c.isRegionArgOfOp) { + numSurroundingLoops = 0; + continue; + } + + // Multiple loops in a region means each of them is the outermost loop of a + // new loop nest + if (!c.isUnique) + numSurroundingLoops = 0; + + c.getLoopDepth() = numSurroundingLoops; + + // Next loop is surrounded by one more loop + if (isa<CanonicalLoopOp>(c.getContainerOp())) + numSurroundingLoops += 1; + } + + // In loop nests, skip all but the innermost loop that contains the depth + // number + bool isLoopNest = false; + for (Component &c : components) { + if (c.skip || c.isRegionArgOfOp) + continue; + + if (!isLoopNest && c.getLoopDepth() >= 1) { + // Innermost loop of a loop nest of at least two loops + isLoopNest = true; + } else if (isLoopNest) { + // Non-innermost loop of a loop nest + c.skipIf(c.isUnique); + + // If there is no surrounding loop left, this must have been the outermost + // loop; leave loop-nest mode for the next iteration + if (c.getLoopDepth() == 0) + isLoopNest = false; + } + } + + // Skip non-loop unambiguous regions (but they should interrupt loop nests, so + // we mark them as 
skipped only after computing loop nests) + for (Component &c : components) + c.skipIf(!c.isRegionArgOfOp && c.isUnique && + !isa<CanonicalLoopOp>(c.getContainerOp())); + + // Components can be skipped if they are already disambiguated by their parent + // (or does not have a parent) + bool newRegion = true; + for (Component &c : llvm::reverse(components)) { + c.skipIf(newRegion && c.isUnique); + + // non-skipped components disambiguate unique children + if (!c.skip) + newRegion = true; + + // ...except canonical loops that need a suffix for each nest + if (!c.isRegionArgOfOp && c.getContainerOp()) + newRegion = false; + } + + // Compile the nesting name string + SmallString<64> Name{prefix}; + llvm::raw_svector_ostream NameOS(Name); + for (auto &c : llvm::reverse(components)) { + if (c.skip) + continue; + + if (c.isRegionArgOfOp) + NameOS << "_r" << c.getArgIdx(); + else if (c.getLoopDepth() >= 1) + NameOS << "_d" << c.getLoopDepth(); + else + NameOS << "_s" << c.getOpPos(); + } + + return NameOS.str().str(); +} + void OpenMPDialect::initialize() { addOperations< #define GET_OP_LIST @@ -3159,6 +3386,9 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { Value result = getResult(); auto [newCli, gen, cons] = decodeCli(result); + // Structured binding `gen` cannot be captured in lambdas before C++20 + OpOperand *generator = gen; + // Derive the CLI variable name from its generator: // * "canonloop" for omp.canonical_loop // * custom name for loop transformation generatees @@ -3172,71 +3402,29 @@ void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) { cliName = TypeSwitch<Operation *, std::string>(gen->getOwner()) .Case([&](CanonicalLoopOp op) { - // Find the canonical loop nesting: For each ancestor add a - // "+_r<idx>" suffix (in reverse order) - SmallVector<std::string> components; - Operation *o = op.getOperation(); - while (o) { - if (o->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>()) - break; - - Region *r = o->getParentRegion(); - if 
(!r) - break; - - auto getSequentialIndex = [](Region *r, Operation *o) { - llvm::ReversePostOrderTraversal<Block *> traversal( - &r->getBlocks().front()); - size_t idx = 0; - for (Block *b : traversal) { - for (Operation &op : *b) { - if (&op == o) - return idx; - // Only consider operations that are containers as - // possible children - if (!op.getRegions().empty()) - idx += 1; - } - } - llvm_unreachable("Operation not part of the region"); - }; - size_t sequentialIdx = getSequentialIndex(r, o); - components.push_back(("s" + Twine(sequentialIdx)).str()); - - Operation *parent = r->getParentOp(); - if (!parent) - break; - - // If the operation has more than one region, also count in - // which of the regions - if (parent->getRegions().size() > 1) { - auto getRegionIndex = [](Operation *o, Region *r) { - for (auto [idx, region] : - llvm::enumerate(o->getRegions())) { - if (&region == r) - return idx; - } - llvm_unreachable("Region not child its parent operation"); - }; - size_t regionIdx = getRegionIndex(parent, r); - components.push_back(("r" + Twine(regionIdx)).str()); - } - - // next parent - o = parent; - } - - SmallString<64> Name("canonloop"); - for (const std::string &s : reverse(components)) { - Name += '_'; - Name += s; - } - - return Name; + return generateLoopNestingName("canonloop", op); }) .Case([&](UnrollHeuristicOp op) -> std::string { llvm_unreachable("heuristic unrolling does not generate a loop"); }) + .Case([&](TileOp op) -> std::string { + auto [generateesFirst, generateesCount] = + op.getGenerateesODSOperandIndexAndLength(); + unsigned firstGrid = generateesFirst; + unsigned firstIntratile = generateesFirst + generateesCount / 2; + unsigned end = generateesFirst + generateesCount; + unsigned opnum = generator->getOperandNumber(); + // In the OpenMP apply and looprange clauses, indices are 1-based + if (firstGrid <= opnum && opnum < firstIntratile) { + unsigned gridnum = opnum - firstGrid + 1; + return ("grid" + Twine(gridnum)).str(); + } + if 
(firstIntratile <= opnum && opnum < end) { + unsigned intratilenum = opnum - firstIntratile + 1; + return ("intratile" + Twine(intratilenum)).str(); + } + llvm_unreachable("Unexpected generatee argument"); + }) .Default([&](Operation *op) { assert(false && "TODO: Custom name for this operation"); return "transformed"; @@ -3323,7 +3511,8 @@ void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) { void CanonicalLoopOp::getAsmBlockArgumentNames(Region ®ion, OpAsmSetValueNameFn setNameFn) { - setNameFn(region.getArgument(0), "iv"); + std::string ivName = generateLoopNestingName("iv", *this); + setNameFn(region.getArgument(0), ivName); } void CanonicalLoopOp::print(OpAsmPrinter &p) { @@ -3465,6 +3654,138 @@ UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() { } //===----------------------------------------------------------------------===// +// TileOp +//===----------------------------------------------------------------------===// + +static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, + OperandRange generatees, + OperandRange applyees) { + if (!generatees.empty()) + p << '(' << llvm::interleaved(generatees) << ')'; + + if (!applyees.empty()) + p << " <- (" << llvm::interleaved(applyees) << ')'; +} + +static ParseResult parseLoopTransformClis( + OpAsmParser &parser, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &generateesOperands, + SmallVectorImpl<OpAsmParser::UnresolvedOperand> &applyeesOperands) { + if (parser.parseOptionalLess()) { + // Syntax 1: generatees present + + if (parser.parseOperandList(generateesOperands, + mlir::OpAsmParser::Delimiter::Paren)) + return failure(); + + if (parser.parseLess()) + return failure(); + } else { + // Syntax 2: generatees omitted + } + + // Parse `<-` (`<` has already been parsed) + if (parser.parseMinus()) + return failure(); + + if (parser.parseOperandList(applyeesOperands, + mlir::OpAsmParser::Delimiter::Paren)) + return failure(); + + return success(); +} + +LogicalResult 
TileOp::verify() { + if (getApplyees().empty()) + return emitOpError() << "must apply to at least one loop"; + + if (getSizes().size() != getApplyees().size()) + return emitOpError() << "there must be one tile size for each applyee"; + + if (!getGeneratees().empty() && + 2 * getSizes().size() != getGeneratees().size()) + return emitOpError() + << "expecting two times the number of generatees than applyees"; + + DenseSet<Value> parentIVs; + + Value parent = getApplyees().front(); + for (auto &&applyee : llvm::drop_begin(getApplyees())) { + auto [parentCreate, parentGen, parentCons] = decodeCli(parent); + auto [create, gen, cons] = decodeCli(applyee); + + if (!parentGen) + return emitOpError() << "applyee CLI has no generator"; + + auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner()); + if (!parentLoop) + return emitOpError() + << "currently only supports omp.canonical_loop as applyee"; + + parentIVs.insert(parentLoop.getInductionVar()); + + if (!gen) + return emitOpError() << "applyee CLI has no generator"; + auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner()); + if (!loop) + return emitOpError() + << "currently only supports omp.canonical_loop as applyee"; + + // Canonical loop must be perfectly nested, i.e. 
the body of the parent must + // only contain the omp.canonical_loop of the nested loops, and + // omp.terminator + bool isPerfectlyNested = [&]() { + auto &parentBody = parentLoop.getRegion(); + if (!parentBody.hasOneBlock()) + return false; + auto &parentBlock = parentBody.getBlocks().front(); + + auto nestedLoopIt = parentBlock.begin(); + if (nestedLoopIt == parentBlock.end() || + (&*nestedLoopIt != loop.getOperation())) + return false; + + auto termIt = std::next(nestedLoopIt); + if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt)) + return false; + + if (std::next(termIt) != parentBlock.end()) + return false; + + return true; + }(); + if (!isPerfectlyNested) + return emitOpError() << "tiled loop nest must be perfectly nested"; + + if (parentIVs.contains(loop.getTripCount())) + return emitOpError() << "tiled loop nest must be rectangular"; + + parent = applyee; + } + + // TODO: The tile sizes must be computed before the loop, but checking this + // requires dominance analysis. 
For instance: + // + // %canonloop = omp.new_cli + // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) { + // // write to %x + // omp.terminator + // } + // %ts = llvm.load %x + // omp.tile <- (%canonloop) sizes(%ts : i32) + + return success(); +} + +std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() { + return getODSOperandIndexAndLength(odsIndex_applyees); +} + +std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() { + return getODSOperandIndexAndLength(odsIndex_generatees); +} + +//===----------------------------------------------------------------------===// // Critical construct (2.17.1) //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 132ed81..3385b2a 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -616,11 +616,10 @@ DiagnosedSilenceableFailure transform::ApplyConversionPatternsOp::apply( if (diag.succeeded()) { // Tracking failure is the only failure. return trackingFailure; - } else { - diag.attachNote() << "tracking listener also failed: " - << trackingFailure.getMessage(); - (void)trackingFailure.silence(); } + diag.attachNote() << "tracking listener also failed: " + << trackingFailure.getMessage(); + (void)trackingFailure.silence(); } if (!diag.succeeded()) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index eb46869..b0132e8 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -580,7 +580,7 @@ namespace { // ElideSingleElementReduction for ReduceOp. 
struct ElideUnitDimsInMultiDimReduction : public OpRewritePattern<MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp, PatternRewriter &rewriter) const override { @@ -730,7 +730,7 @@ std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() { namespace { struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ReductionOp reductionOp, PatternRewriter &rewriter) const override { @@ -2197,7 +2197,7 @@ namespace { // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast. class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -2220,7 +2220,7 @@ public: // Pattern to rewrite a ExtractOp(CreateMask) -> CreateMask. class ExtractOpFromCreateMask final : public OpRewritePattern<ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -2546,7 +2546,7 @@ rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp, class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(FromElementsOp fromElements, PatternRewriter &rewriter) const override { @@ -2938,7 +2938,7 @@ namespace { // Fold broadcast1(broadcast2(x)) into broadcast1(x). 
struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(BroadcastOp broadcastOp, PatternRewriter &rewriter) const override { @@ -3109,7 +3109,7 @@ namespace { // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector // to a broadcast. struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { @@ -3165,7 +3165,7 @@ static Value getScalarSplatSource(Value value) { /// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v). class ShuffleSplat final : public OpRewritePattern<ShuffleOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { @@ -3182,7 +3182,7 @@ public: /// vector.interleave. class ShuffleInterleave : public OpRewritePattern<ShuffleOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShuffleOp op, PatternRewriter &rewriter) const override { @@ -3326,7 +3326,7 @@ namespace { // broadcast. class InsertToBroadcast final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp insertOp, PatternRewriter &rewriter) const override { @@ -3344,7 +3344,7 @@ public: /// Pattern to rewrite a insert(splat-like(v), splat-like(v)) as broadcast(v). 
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { @@ -3380,7 +3380,7 @@ public: /// %result = vector.from_elements %c1, %c2 : vector<2xi32> class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertOp op, PatternRewriter &rewriter) const override { @@ -3748,7 +3748,7 @@ namespace { class FoldInsertStridedSliceSplat final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -3768,7 +3768,7 @@ public: class FoldInsertStridedSliceOfExtract final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, PatternRewriter &rewriter) const override { @@ -3798,7 +3798,7 @@ public: class InsertStridedSliceConstantFolder final : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; // Do not create constants with more than `vectorSizeFoldThreashold` elements, // unless the source vector constant has a single use. 
@@ -4250,7 +4250,7 @@ namespace { // %mask = vector.create_mask %new_ub : vector<8xi1> class StridedSliceCreateMaskFolder final : public OpRewritePattern<ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; public: LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, @@ -4310,7 +4310,7 @@ public: class StridedSliceConstantMaskFolder final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, PatternRewriter &rewriter) const override { @@ -4365,7 +4365,7 @@ public: class StridedSliceBroadcast final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -4416,7 +4416,7 @@ public: /// Rewrite extract_strided_slice(splat-like(v)) with broadcast(v). class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -4448,7 +4448,7 @@ public: class ContiguousExtractStridedSliceToExtract final : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -5023,7 +5023,7 @@ namespace { /// ``` struct TransferReadAfterWriteToBroadcast : public OpRewritePattern<TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -5458,7 +5458,7 @@ namespace { /// any other uses. 
class FoldWaw final : public OpRewritePattern<TransferWriteOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransferWriteOp writeOp, PatternRewriter &rewriter) const override { if (!llvm::isa<RankedTensorType>(writeOp.getShapedType())) @@ -5514,7 +5514,7 @@ public: struct SwapExtractSliceOfTransferWrite : public OpRewritePattern<tensor::InsertSliceOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -5737,7 +5737,7 @@ LogicalResult MaskedLoadOp::verify() { namespace { class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskedLoadOp load, PatternRewriter &rewriter) const override { switch (getMaskFormat(load.getMask())) { @@ -5794,7 +5794,7 @@ LogicalResult MaskedStoreOp::verify() { namespace { class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskedStoreOp store, PatternRewriter &rewriter) const override { switch (getMaskFormat(store.getMask())) { @@ -5890,7 +5890,7 @@ static LogicalResult isZeroBasedContiguousSeq(Value indexVec) { namespace { class GatherFolder final : public OpRewritePattern<GatherOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(GatherOp gather, PatternRewriter &rewriter) const override { switch (getMaskFormat(gather.getMask())) { @@ -5910,7 +5910,7 @@ public: /// maskedload. Only 1D fixed vectors are supported for now. 
class FoldContiguousGather final : public OpRewritePattern<GatherOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override { if (!isa<MemRefType>(op.getBase().getType())) @@ -5962,7 +5962,7 @@ LogicalResult ScatterOp::verify() { namespace { class ScatterFolder final : public OpRewritePattern<ScatterOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ScatterOp scatter, PatternRewriter &rewriter) const override { switch (getMaskFormat(scatter.getMask())) { @@ -5982,7 +5982,7 @@ public: /// maskedstore. Only 1D fixed vectors are supported for now. class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ScatterOp op, PatternRewriter &rewriter) const override { if (failed(isZeroBasedContiguousSeq(op.getIndices()))) @@ -6030,7 +6030,7 @@ LogicalResult ExpandLoadOp::verify() { namespace { class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExpandLoadOp expand, PatternRewriter &rewriter) const override { switch (getMaskFormat(expand.getMask())) { @@ -6081,7 +6081,7 @@ LogicalResult CompressStoreOp::verify() { namespace { class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(CompressStoreOp compress, PatternRewriter &rewriter) const override { switch (getMaskFormat(compress.getMask())) { @@ -6260,7 +6260,7 @@ static VectorType trimTrailingOneDims(VectorType oldType) { class ShapeCastCreateMaskFolderTrailingOneDim final : public OpRewritePattern<ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult 
matchAndRewrite(ShapeCastOp shapeOp, PatternRewriter &rewriter) const override { @@ -6330,7 +6330,7 @@ public: /// If both (i) and (ii) are possible, (i) is chosen. class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, PatternRewriter &rewriter) const override { @@ -6614,7 +6614,7 @@ namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6646,7 +6646,7 @@ public: /// Replace transpose(splat-like(v)) with broadcast(v) class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6663,7 +6663,7 @@ public: /// Folds transpose(create_mask) into a new transposed create_mask. class FoldTransposeCreateMask final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transpOp, PatternRewriter &rewriter) const override { @@ -6700,7 +6700,7 @@ public: /// Folds transpose(shape_cast) into a new shape_cast. class FoldTransposeShapeCast final : public OpRewritePattern<TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(TransposeOp transposeOp, PatternRewriter &rewriter) const override { @@ -6750,7 +6750,7 @@ public: /// within the groups [0,1] and [3,4], like (1 0 2 4 3 5 6). 
class FoldTransposeBroadcast : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; FoldTransposeBroadcast(MLIRContext *context, PatternBenefit benefit = 1) : OpRewritePattern<vector::TransposeOp>(context, benefit) {} @@ -6971,7 +6971,7 @@ namespace { /// %0 = vector.constant_mask [8, 16] : vector<8x[16]xi1> class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(CreateMaskOp createMaskOp, PatternRewriter &rewriter) const override { @@ -7300,7 +7300,7 @@ LogicalResult MaskOp::fold(FoldAdaptor adaptor, /// %0 = arith.select %mask, %a, %passthru : vector<8xf32> /// class CanonializeEmptyMaskOp : public OpRewritePattern<MaskOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(MaskOp maskOp, PatternRewriter &rewriter) const override { @@ -7410,7 +7410,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { // vector.broadcast. class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> { public: - using OpRewritePattern<SplatOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(SplatOp splatOp, PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(), diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index dedc3b3..61d9357 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -34,7 +34,7 @@ namespace { /// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly. 
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BroadcastOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp index 65702ff..efe8d14 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorContract.cpp @@ -1151,7 +1151,7 @@ FailureOr<Value> ContractionOpLowering::lowerReduction( /// class OuterProductOpLowering : public OpRewritePattern<vector::OuterProductOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::OuterProductOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 1f96a3a..6bc8347 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -50,7 +50,7 @@ namespace { /// /// Supports vector types with a fixed leading dimension. struct UnrollGather : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { @@ -98,7 +98,7 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> { /// ATM this is effectively limited to reading a 1D Vector from a 2D MemRef, /// but should be fairly straightforward to extend beyond that. 
struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { @@ -164,7 +164,7 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> { /// `tensor.extract`s. To avoid out-of-bounds memory accesses, these /// loads/extracts are made conditional using `scf.if` ops. struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp index 9d6a865..479fc0c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorInterleave.cpp @@ -163,7 +163,7 @@ private: /// : vector<7xi16>, vector<7xi16> /// ``` struct InterleaveToShuffle final : OpRewritePattern<vector::InterleaveOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InterleaveOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 5617b06..7730c4e 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -48,7 +48,7 @@ namespace { /// until a one-dimensional vector is reached. class CreateMaskOpLowering : public OpRewritePattern<vector::CreateMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::CreateMaskOp op, PatternRewriter &rewriter) const override { @@ -100,7 +100,7 @@ public: /// will be folded at LLVM IR level. 
class ConstantMaskOpLowering : public OpRewritePattern<vector::ConstantMaskOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ConstantMaskOp op, PatternRewriter &rewriter) const override { @@ -184,7 +184,7 @@ namespace { /// and actually match the traits of its the nested `MaskableOpInterface`. template <class SourceOp> struct MaskOpRewritePattern : OpRewritePattern<MaskOp> { - using OpRewritePattern<MaskOp>::OpRewritePattern; + using Base::Base; private: LogicalResult matchAndRewrite(MaskOp maskOp, diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp index 4773732d..e86e2a9 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp @@ -39,7 +39,7 @@ namespace { class InnerOuterDimReductionConversion : public OpRewritePattern<vector::MultiDimReductionOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; explicit InnerOuterDimReductionConversion( MLIRContext *context, vector::VectorMultiReductionLowering options, @@ -136,7 +136,7 @@ private: class ReduceMultiDimReductionRank : public OpRewritePattern<vector::MultiDimReductionOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; explicit ReduceMultiDimReductionRank( MLIRContext *context, vector::VectorMultiReductionLowering options, @@ -304,7 +304,7 @@ private: /// and combines results struct TwoDimMultiReductionToElementWise : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -359,7 +359,7 @@ struct TwoDimMultiReductionToElementWise /// a sequence of vector.reduction ops. 
struct TwoDimMultiReductionToReduction : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { @@ -420,7 +420,7 @@ struct TwoDimMultiReductionToReduction /// separately. struct OneDimMultiReductionToTwoDim : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp index af4851e..258f2cb 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorScan.cpp @@ -99,7 +99,7 @@ namespace { /// return %7, %8 : vector<2x3xi32>, vector<2xi32> /// ``` struct ScanToArithOps : public OpRewritePattern<vector::ScanOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ScanOp scanOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp index 603ea41..c5f22b2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp @@ -189,7 +189,7 @@ class ShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { } public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { @@ -356,7 +356,7 @@ public: class ScalableShapeCastOpRewritePattern : public OpRewritePattern<vector::ShapeCastOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult 
matchAndRewrite(vector::ShapeCastOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp index 78102f7..8f46ad6 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorShuffle.cpp @@ -44,7 +44,7 @@ namespace { /// struct MixedSizeInputShuffleOpRewrite final : OpRewritePattern<vector::ShuffleOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ShuffleOp shuffleOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp index ee5568a..08e7c89 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorStep.cpp @@ -24,7 +24,7 @@ using namespace mlir::vector; namespace { struct StepToArithConstantOpRewrite final : OpRewritePattern<vector::StepOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::StepOp stepOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp index 6407a86..7521e24 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorToFromElementsToShuffleTree.cpp @@ -667,7 +667,7 @@ getToElementsDefiningOps(FromElementsOp fromElemsOp, struct ToFromElementsToShuffleTreeRewrite final : OpRewritePattern<vector::FromElementsOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::FromElementsOp fromElemsOp, PatternRewriter &rewriter) const override { diff --git 
a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp index 9e7d0ce..c3f7de0 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTranspose.cpp @@ -300,7 +300,7 @@ namespace { /// %x = vector.insert .., .. [.., ..] class TransposeOpLowering : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; TransposeOpLowering(vector::VectorTransposeLowering vectorTransposeLowering, MLIRContext *context, PatternBenefit benefit = 1) @@ -395,7 +395,7 @@ private: class Transpose2DWithUnitDimToShapeCast : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; Transpose2DWithUnitDimToShapeCast(MLIRContext *context, PatternBenefit benefit = 1) @@ -433,7 +433,7 @@ public: class TransposeOp2DToShuffleLowering : public OpRewritePattern<vector::TransposeOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; TransposeOp2DToShuffleLowering( vector::VectorTransposeLowering vectorTransposeLowering, diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index cab1289..963b2c8 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -54,7 +54,7 @@ namespace { // input by inserting vector.broadcast. struct CastAwayExtractStridedSliceLeadingOneDim : public OpRewritePattern<vector::ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { @@ -104,7 +104,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim // inputs by inserting vector.broadcast. 
struct CastAwayInsertStridedSliceLeadingOneDim : public OpRewritePattern<vector::InsertStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InsertStridedSliceOp insertOp, PatternRewriter &rewriter) const override { @@ -145,7 +145,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim // Casts away leading one dimensions in vector.insert's vector inputs by // inserting vector.broadcast. struct CastAwayInsertLeadingOneDim : public OpRewritePattern<vector::InsertOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::InsertOp insertOp, PatternRewriter &rewriter) const override { @@ -221,7 +221,7 @@ static Value dropUnitDimsFromMask(OpBuilder &b, Location loc, Value mask, // 1 dimensions. struct CastAwayTransferReadLeadingOneDim : public OpRewritePattern<vector::TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp read, PatternRewriter &rewriter) const override { @@ -275,7 +275,7 @@ struct CastAwayTransferReadLeadingOneDim // 1 dimensions. struct CastAwayTransferWriteLeadingOneDim : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp write, PatternRewriter &rewriter) const override { @@ -541,7 +541,7 @@ public: // vector.broadcast back to the original shape. 
struct CastAwayConstantMaskLeadingOneDim : public OpRewritePattern<vector::ConstantMaskOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ConstantMaskOp mask, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp index bdbb792..7acc120 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateMaskedLoadStore.cpp @@ -48,7 +48,7 @@ namespace { /// struct VectorMaskedLoadOpConverter final : OpRewritePattern<vector::MaskedLoadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedLoadOp, PatternRewriter &rewriter) const override { @@ -117,7 +117,7 @@ struct VectorMaskedLoadOpConverter final /// struct VectorMaskedStoreOpConverter final : OpRewritePattern<vector::MaskedStoreOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedStoreOp maskedStoreOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index 264cbc1..3a6684f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -548,7 +548,7 @@ namespace { // NOTE: By default, all RMW sequences are atomic. Set `disableAtomicRMW` to // `false` to generate non-atomic RMW sequences. struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; ConvertVectorStore(MLIRContext *context, bool disableAtomicRMW) : OpConversionPattern<vector::StoreOp>(context), @@ -827,7 +827,7 @@ private: /// adjusted mask . 
struct ConvertVectorMaskedStore final : OpConversionPattern<vector::MaskedStoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedStoreOp op, OpAdaptor adaptor, @@ -950,7 +950,7 @@ struct ConvertVectorMaskedStore final /// those cases, loads are converted to byte-aligned, byte-sized loads and the /// target vector is extracted from the loaded vector. struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::LoadOp op, OpAdaptor adaptor, @@ -1059,7 +1059,7 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> { /// bitcasting, since each `i8` container element holds two `i4` values. struct ConvertVectorMaskedLoad final : OpConversionPattern<vector::MaskedLoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MaskedLoadOp op, OpAdaptor adaptor, @@ -1257,7 +1257,7 @@ static bool fitsInMultiByteContainerTy(VectorType subByteVecTy, // TODO: Document-me struct ConvertVectorTransferRead final : OpConversionPattern<vector::TransferReadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp op, OpAdaptor adaptor, @@ -1942,7 +1942,7 @@ namespace { /// advantage of high-level information to avoid leaving LLVM to scramble with /// peephole optimizations. 
struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitCastOp, PatternRewriter &rewriter) const override { @@ -2147,7 +2147,7 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> { /// %5 = vector.bitcast %4 : vector<4xi8> to vector<8xi4> /// struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> { - using OpRewritePattern<arith::TruncIOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(arith::TruncIOp truncOp, PatternRewriter &rewriter) const override { @@ -2200,7 +2200,7 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> { /// %2 = arith.trunci %1 : vector<16x8xi8> to vector<16x8xi4> /// struct RewriteVectorTranspose : OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern<vector::TransposeOp>::OpRewritePattern; + using Base::Base; RewriteVectorTranspose(MLIRContext *context, PatternBenefit benefit) : OpRewritePattern<vector::TransposeOp>(context, benefit) {} diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index f6d6555..9e49873 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -34,7 +34,7 @@ using namespace mlir::vector; class DecomposeDifferentRankInsertStridedSlice : public OpRewritePattern<InsertStridedSliceOp> { public: - using OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(InsertStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -84,7 +84,7 @@ public: class ConvertSameRankInsertStridedSliceIntoShuffle : public OpRewritePattern<InsertStridedSliceOp> { public: - using 
OpRewritePattern<InsertStridedSliceOp>::OpRewritePattern; + using Base::Base; void initialize() { // This pattern creates recursive InsertStridedSliceOp, but the recursion is @@ -183,7 +183,7 @@ public: class Convert1DExtractStridedSliceIntoShuffle : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(ExtractStridedSliceOp op, PatternRewriter &rewriter) const override { @@ -271,7 +271,7 @@ private: class DecomposeNDExtractStridedSlice : public OpRewritePattern<ExtractStridedSliceOp> { public: - using OpRewritePattern<ExtractStridedSliceOp>::OpRewritePattern; + using Base::Base; void initialize() { // This pattern creates recursive ExtractStridedSliceOp, but the recursion diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 82bac8c..71fba71c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -214,7 +214,7 @@ SmallVector<int64_t> static getStridedSliceInsertionIndices( /// vector.extract_strided_slice operation. 
struct LinearizeVectorExtractStridedSlice final : public mlir::OpConversionPattern<mlir::vector::ExtractStridedSliceOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorExtractStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -285,7 +285,7 @@ struct LinearizeVectorExtractStridedSlice final /// struct LinearizeVectorInsertStridedSlice final : public mlir::OpConversionPattern<mlir::vector::InsertStridedSliceOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorInsertStridedSlice(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -348,7 +348,7 @@ struct LinearizeVectorInsertStridedSlice final /// of the original shuffle operation. struct LinearizeVectorShuffle final : public OpConversionPattern<vector::ShuffleOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorShuffle(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -423,7 +423,7 @@ struct LinearizeVectorShuffle final /// struct LinearizeVectorExtract final : public OpConversionPattern<vector::ExtractOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorExtract(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -501,7 +501,7 @@ struct LinearizeVectorExtract final /// struct LinearizeVectorInsert final : public OpConversionPattern<vector::InsertOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorInsert(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -575,7 +575,7 @@ struct LinearizeVectorInsert final /// %out_nd = vector.shape_cast %out_1d: vector<16xf16> 
to vector<4x4xf16> struct LinearizeVectorBitCast final : public OpConversionPattern<vector::BitCastOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorBitCast(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -598,7 +598,7 @@ struct LinearizeVectorBitCast final /// %out_nd = vector.shape_cast %out_1d : vector<16xf32> to vector<4x4xf32> struct LinearizeVectorSplat final : public OpConversionPattern<vector::SplatOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorSplat(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -629,7 +629,7 @@ struct LinearizeVectorSplat final /// %shape_cast = vector.shape_cast %mask : vector<4xi1> to vector<1x4xi1> struct LinearizeVectorCreateMask final : OpConversionPattern<vector::CreateMaskOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorCreateMask(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) @@ -684,7 +684,7 @@ struct LinearizeVectorCreateMask final /// For generic cases, the vector unroll pass should be used to unroll the load /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorLoad(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -731,7 +731,7 @@ struct LinearizeVectorLoad final : public OpConversionPattern<vector::LoadOp> { /// to vector<1x1x...xN> form and then linearized struct LinearizeVectorStore final : public OpConversionPattern<vector::StoreOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorStore(const TypeConverter &typeConverter, 
MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -778,7 +778,7 @@ struct LinearizeVectorStore final /// struct LinearizeVectorFromElements final : public OpConversionPattern<vector::FromElementsOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorFromElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) : OpConversionPattern(typeConverter, context, benefit) {} @@ -814,7 +814,7 @@ struct LinearizeVectorFromElements final /// struct LinearizeVectorToElements final : public OpConversionPattern<vector::ToElementsOp> { - using OpConversionPattern::OpConversionPattern; + using Base::Base; LinearizeVectorToElements(const TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index c364a8b..1121d95 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -1081,7 +1081,7 @@ private: /// Rewrite transfer_writes of vectors of size 1 (e.g., vector<1x1xf32>) /// to memref.store. 
class RewriteScalarWrite : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 866f789..d6a6d7cd 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -78,7 +78,7 @@ namespace { /// ``` struct MultiReduceToContract : public OpRewritePattern<vector::MultiDimReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::MultiDimReductionOp reduceOp, PatternRewriter &rewriter) const override { @@ -138,7 +138,7 @@ struct MultiReduceToContract /// ``` struct CombineContractABTranspose final : public OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -202,7 +202,7 @@ struct CombineContractABTranspose final /// ``` struct CombineContractResultTranspose final : public OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp resTOp, PatternRewriter &rewriter) const override { @@ -568,7 +568,7 @@ static SmallVector<int64_t> getIntValueVector(ArrayAttr arrayAttr) { // %2 = vector.extract %1[1] : f16 from vector<2xf16> struct BubbleDownVectorBitCastForExtract : public OpRewritePattern<vector::ExtractOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp extractOp, PatternRewriter &rewriter) const override { @@ -643,7 +643,7 @@ struct BubbleDownVectorBitCastForExtract // %1 = vector.bitcast %0 : vector<2xf32> to vector<4xf16> 
struct BubbleDownBitCastForStridedSliceExtract : public OpRewritePattern<vector::ExtractStridedSliceOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractStridedSliceOp extractOp, PatternRewriter &rewriter) const override { @@ -721,7 +721,7 @@ struct BubbleDownBitCastForStridedSliceExtract // %2 = vector.insert %0, %1 [4] : vector<16xi8> into vector<8x16xi8> // struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, PatternRewriter &rewriter) const override { @@ -794,7 +794,7 @@ struct BubbleUpBitCastForInsert : public OpRewritePattern<vector::BitCastOp> { // offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32> struct BubbleUpBitCastForStridedSliceInsert : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::BitCastOp bitcastOp, PatternRewriter &rewriter) const override { @@ -892,7 +892,7 @@ struct BubbleUpBitCastForStridedSliceInsert // %7 = vector.insert_strided_slice %6, %cst { // offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32> struct BreakDownVectorBitCast : public OpRewritePattern<vector::BitCastOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; public: BreakDownVectorBitCast(MLIRContext *context, @@ -1131,7 +1131,7 @@ struct ReorderElementwiseOpsOnBroadcast final class ExtractOpFromElementwise final : public OpRewritePattern<vector::ExtractOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { @@ -1206,7 +1206,7 @@ static bool isSupportedMemSinkElementType(Type type) { /// ``` class ExtractOpFromLoad final : public OpRewritePattern<vector::ExtractOp> { public: - using 
OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ExtractOp op, PatternRewriter &rewriter) const override { @@ -1285,7 +1285,7 @@ public: class StoreOpFromSplatOrBroadcast final : public OpRewritePattern<vector::StoreOp> { public: - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::StoreOp op, PatternRewriter &rewriter) const override { @@ -1476,7 +1476,7 @@ static bool allI1ConstantValuesSetTo(arith::ConstantOp constantOp, bool value) { /// InstCombine seems to handle vectors with multiple elements but not the /// single element ones. struct FoldI1Select : public OpRewritePattern<arith::SelectOp> { - using OpRewritePattern<arith::SelectOp>::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(arith::SelectOp selectOp, PatternRewriter &rewriter) const override { @@ -1560,7 +1560,7 @@ getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) { /// Drop inner most contiguous unit dimensions from transfer_read operand. class DropInnerMostUnitDimsTransferRead : public OpRewritePattern<vector::TransferReadOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferReadOp readOp, PatternRewriter &rewriter) const override { @@ -1651,7 +1651,7 @@ class DropInnerMostUnitDimsTransferRead /// Note, this pattern will not collapse "scalable unit" dims (i.e. `[1]`). class DropInnerMostUnitDimsTransferWrite : public OpRewritePattern<vector::TransferWriteOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp, PatternRewriter &rewriter) const override { @@ -1728,7 +1728,7 @@ class DropInnerMostUnitDimsTransferWrite /// with the RHS transposed) lowering. 
struct CanonicalizeContractMatmulToMMT final : OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; using FilterConstraintType = std::function<LogicalResult(vector::ContractionOp op)>; @@ -1845,7 +1845,7 @@ private: template <typename ExtOp> struct FoldArithExtIntoContractionOp : public OpRewritePattern<vector::ContractionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ContractionOp contractOp, PatternRewriter &rewriter) const override { @@ -1878,7 +1878,7 @@ struct FoldArithExtIntoContractionOp /// %b = vector.reduction <add> %a, %acc /// ``` struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { @@ -2033,7 +2033,7 @@ struct DropUnitDimFromElementwiseOps final /// ``` struct DropUnitDimsFromTransposeOp final : OpRewritePattern<vector::TransposeOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override { @@ -2110,7 +2110,7 @@ struct DropUnitDimsFromTransposeOp final /// : vector<[4]x4xf32> to vector<[4]x1x1x4xf32> /// ``` struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(scf::ForOp forOp, PatternRewriter &rewriter) const override { @@ -2155,7 +2155,7 @@ struct DropUnitDimsFromScfForOp final : OpRewritePattern<scf::ForOp> { /// %c = vector.reduction <add> %b, %acc /// ``` struct ReduceRedundantZero final : OpRewritePattern<vector::ReductionOp> { - using OpRewritePattern::OpRewritePattern; + using Base::Base; LogicalResult matchAndRewrite(vector::ReductionOp op, PatternRewriter &rewriter) const override { diff --git 
a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 9413a92..784e5d6 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -824,7 +824,7 @@ struct WgToSgStoreScatterOpWithOffset return failure(); xegpu::DistributeLayoutAttr layout = - xegpu::getDistributeLayoutAttr(op.getValue()); + xegpu::getDistributeLayoutAttr(op.getOperand(0)); if (!layout || !layout.isForWorkgroup()) return failure(); @@ -844,12 +844,19 @@ struct WgToSgStoreScatterOpWithOffset auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); for (auto [val, offs, mask] : llvm::zip( adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { - xegpu::StoreScatterOp::create(rewriter, loc, val, op.getDest(), offs, - mask, chunkSizeAttr, op.getL1HintAttr(), - op.getL2HintAttr(), op.getL3HintAttr()); + auto store = xegpu::StoreScatterOp::create( + rewriter, loc, val, op.getDest(), offs, mask, chunkSizeAttr, + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); // Update the layout attribute to drop sg_layout and sg_data. - if (auto newLayout = layout.dropSgLayoutAndData()) - op->setAttr("layout", newLayout); + if (!layout.getEffectiveLaneLayoutAsInt().empty() || + !layout.getEffectiveInstDataAsInt().empty()) { + for (OpOperand &operand : store->getOpOperands()) { + // Skip for operand one (memref) + if (operand.getOperandNumber() == 1) + continue; + xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData()); + } + } } rewriter.eraseOp(op); return success(); @@ -1247,10 +1254,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() { target.addDynamicallyLegalOp<xegpu::StoreScatterOp>( [=](xegpu::StoreScatterOp op) -> bool { - // Check if the layout attribute is present on the result. 
- auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout"); - if (!layout) - return true; + auto layout = xegpu::getDistributeLayoutAttr(op.getOperand(0)); return isLegal(layout); }); |