Diffstat (limited to 'mlir/lib')
329 files changed, 7393 insertions, 6080 deletions
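Most of the changes in this diff migrate MLIR's debug output from the LLVM_DEBUG(llvm::dbgs() << ...) idiom (often wrapped in per-file DBGS()/LDBG(X) macros) to the LDBG() stream helper from llvm/Support/DebugLog.h. A minimal sketch of the two styles, assuming an illustrative pass name "my-pass" and helper debugVisit that are not from this patch:

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "my-pass"

static void debugVisit(mlir::Operation *op) {
  // Old style: guard the stream with LLVM_DEBUG and end the line by hand.
  LLVM_DEBUG(llvm::dbgs() << "visiting: " << *op << "\n");
  // New style: LDBG() emits one line tagged with DEBUG_TYPE and appends the
  // trailing newline itself; like LLVM_DEBUG, it only produces output in
  // debug builds (e.g. under -debug-only=my-pass).
  LDBG() << "visiting: " << *op;
}

This is why the hunks below also delete the local "#define DBGS()" and "#define LDBG(X)" macros that several files previously carried.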
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp index 51fa773..fb5649e 100644 --- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #define DEBUG_TYPE "constant-propagation" @@ -46,7 +47,7 @@ void ConstantValue::print(raw_ostream &os) const { LogicalResult SparseConstantPropagation::visitOperation( Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands, ArrayRef<Lattice<ConstantValue> *> results) { - LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n"); + LDBG() << "SCP: Visiting operation: " << *op; // Don't try to simulate the results of a region operation as we can't // guarantee that folding will be out-of-place. We don't allow in-place @@ -98,12 +99,11 @@ LogicalResult SparseConstantPropagation::visitOperation( // Merge in the result of the fold, either a constant or a value. OpFoldResult foldResult = std::get<1>(it); if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) { - LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n"); + LDBG() << "Folded to constant: " << attr; propagateIfChanged(lattice, lattice->join(ConstantValue(attr, op->getDialect()))); } else { - LLVM_DEBUG(llvm::dbgs() - << "Folded to value: " << cast<Value>(foldResult) << "\n"); + LDBG() << "Folded to value: " << cast<Value>(foldResult); AbstractSparseForwardDataFlowAnalysis::join( lattice, *getLatticeElement(cast<Value>(foldResult))); } diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index 1abdfcb..10874fd 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -23,12 +23,11 @@ #include "mlir/Support/LLVM.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #include <optional> #define DEBUG_TYPE "dead-code-analysis" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; using namespace mlir::dataflow; @@ -127,7 +126,8 @@ DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver) } LogicalResult DeadCodeAnalysis::initialize(Operation *top) { - LDBG("Initializing DeadCodeAnalysis for top-level op: " << top->getName()); + LDBG() << "Initializing DeadCodeAnalysis for top-level op: " + << top->getName(); // Mark the top-level blocks as executable. 
for (Region &region : top->getRegions()) { if (region.empty()) @@ -135,7 +135,7 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) { auto *state = getOrCreate<Executable>(getProgramPointBefore(&region.front())); propagateIfChanged(state, state->setToLive()); - LDBG("Marked entry block live for region in op: " << top->getName()); + LDBG() << "Marked entry block live for region in op: " << top->getName(); } // Mark as overdefined the predecessors of symbol callables with potentially @@ -146,18 +146,18 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) { } void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { - LDBG("[init] Entering initializeSymbolCallables for top-level op: " - << top->getName()); + LDBG() << "[init] Entering initializeSymbolCallables for top-level op: " + << top->getName(); analysisScope = top; auto walkFn = [&](Operation *symTable, bool allUsesVisible) { - LDBG("[init] Processing symbol table op: " << symTable->getName()); + LDBG() << "[init] Processing symbol table op: " << symTable->getName(); Region &symbolTableRegion = symTable->getRegion(0); Block *symbolTableBlock = &symbolTableRegion.front(); bool foundSymbolCallable = false; for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) { - LDBG("[init] Found CallableOpInterface: " - << callable.getOperation()->getName()); + LDBG() << "[init] Found CallableOpInterface: " + << callable.getOperation()->getName(); Region *callableRegion = callable.getCallableRegion(); if (!callableRegion) continue; @@ -171,8 +171,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(callable)); propagateIfChanged(state, state->setHasUnknownPredecessors()); - LDBG("[init] Marked callable as having unknown predecessors: " - << callable.getOperation()->getName()); + LDBG() << "[init] Marked callable as having unknown predecessors: " + << callable.getOperation()->getName(); } foundSymbolCallable = true; } @@ -187,15 +187,15 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { if (!uses) { // If we couldn't gather the symbol uses, conservatively assume that // we can't track information for any nested symbols. 
- LDBG("[init] Could not gather symbol uses, conservatively marking " - "all nested callables as having unknown predecessors"); + LDBG() << "[init] Could not gather symbol uses, conservatively marking " + "all nested callables as having unknown predecessors"; return top->walk([&](CallableOpInterface callable) { auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(callable)); propagateIfChanged(state, state->setHasUnknownPredecessors()); - LDBG("[init] Marked nested callable as " - "having unknown predecessors: " - << callable.getOperation()->getName()); + LDBG() << "[init] Marked nested callable as " + "having unknown predecessors: " + << callable.getOperation()->getName(); }); } @@ -209,15 +209,15 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { continue; auto *state = getOrCreate<PredecessorState>(getProgramPointAfter(symbol)); propagateIfChanged(state, state->setHasUnknownPredecessors()); - LDBG("[init] Found non-call use for symbol, " - "marked as having unknown predecessors: " - << symbol->getName()); + LDBG() << "[init] Found non-call use for symbol, " + "marked as having unknown predecessors: " + << symbol->getName(); } }; SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(), walkFn); - LDBG("[init] Finished initializeSymbolCallables for top-level op: " - << top->getName()); + LDBG() << "[init] Finished initializeSymbolCallables for top-level op: " + << top->getName(); } /// Returns true if the operation is a returning terminator in region @@ -229,14 +229,14 @@ static bool isRegionOrCallableReturn(Operation *op) { } LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { - LDBG("[init] Entering initializeRecursively for op: " << op->getName() - << " at " << op); + LDBG() << "[init] Entering initializeRecursively for op: " << op->getName() + << " at " << op; // Initialize the analysis by visiting every op with control-flow semantics. if (op->getNumRegions() || op->getNumSuccessors() || isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) { - LDBG("[init] Visiting op with control-flow semantics: " << *op); - // When the liveness of the parent block changes, make sure to re-invoke the - // analysis on the op. + LDBG() << "[init] Visiting op with control-flow semantics: " << *op; + // When the liveness of the parent block changes, make sure to + // re-invoke the analysis on the op. if (op->getBlock()) getOrCreate<Executable>(getProgramPointBefore(op->getBlock())) ->blockContentSubscribe(this); @@ -246,21 +246,21 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { } // Recurse on nested operations. 
for (Region &region : op->getRegions()) { - LDBG("[init] Recursing into region of op: " << op->getName()); + LDBG() << "[init] Recursing into region of op: " << op->getName(); for (Operation &nestedOp : region.getOps()) { - LDBG("[init] Recursing into nested op: " << nestedOp.getName() << " at " - << &nestedOp); + LDBG() << "[init] Recursing into nested op: " << nestedOp.getName() + << " at " << &nestedOp; if (failed(initializeRecursively(&nestedOp))) return failure(); } } - LDBG("[init] Finished initializeRecursively for op: " << op->getName() - << " at " << op); + LDBG() << "[init] Finished initializeRecursively for op: " << op->getName() + << " at " << op; return success(); } void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) { - LDBG("Marking edge live from block " << from << " to block " << to); + LDBG() << "Marking edge live from block " << from << " to block " << to; auto *state = getOrCreate<Executable>(getProgramPointBefore(to)); propagateIfChanged(state, state->setToLive()); auto *edgeState = @@ -269,35 +269,35 @@ void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) { } void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) { - LDBG("Marking entry blocks live for op: " << op->getName()); + LDBG() << "Marking entry blocks live for op: " << op->getName(); for (Region &region : op->getRegions()) { if (region.empty()) continue; auto *state = getOrCreate<Executable>(getProgramPointBefore(&region.front())); propagateIfChanged(state, state->setToLive()); - LDBG("Marked entry block live for region in op: " << op->getName()); + LDBG() << "Marked entry block live for region in op: " << op->getName(); } } LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { - LDBG("Visiting program point: " << point << " " << *point); + LDBG() << "Visiting program point: " << point << " " << *point; if (point->isBlockStart()) return success(); Operation *op = point->getPrevOp(); - LDBG("Visiting operation: " << *op); + LDBG() << "Visiting operation: " << *op; // If the parent block is not executable, there is nothing to do. if (op->getBlock() != nullptr && !getOrCreate<Executable>(getProgramPointBefore(op->getBlock())) ->isLive()) { - LDBG("Parent block not live, skipping op: " << *op); + LDBG() << "Parent block not live, skipping op: " << *op; return success(); } // We have a live call op. Add this as a live predecessor of the callee. if (auto call = dyn_cast<CallOpInterface>(op)) { - LDBG("Visiting call operation: " << *op); + LDBG() << "Visiting call operation: " << *op; visitCallOperation(call); } @@ -305,12 +305,12 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { if (op->getNumRegions()) { // Check if we can reason about the region control-flow. if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { - LDBG("Visiting region branch operation: " << *op); + LDBG() << "Visiting region branch operation: " << *op; visitRegionBranchOperation(branch); // Check if this is a callable operation. } else if (auto callable = dyn_cast<CallableOpInterface>(op)) { - LDBG("Visiting callable operation: " << *op); + LDBG() << "Visiting callable operation: " << *op; const auto *callsites = getOrCreateFor<PredecessorState>( getProgramPointAfter(op), getProgramPointAfter(callable)); @@ -322,19 +322,19 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { // Otherwise, conservatively mark all entry blocks as executable. 
} else { - LDBG("Marking all entry blocks live for op: " << *op); + LDBG() << "Marking all entry blocks live for op: " << *op; markEntryBlocksLive(op); } } if (isRegionOrCallableReturn(op)) { if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) { - LDBG("Visiting region terminator: " << *op); + LDBG() << "Visiting region terminator: " << *op; // Visit the exiting terminator of a region. visitRegionTerminator(op, branch); } else if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) { - LDBG("Visiting callable terminator: " << *op); + LDBG() << "Visiting callable terminator: " << *op; // Visit the exiting terminator of a callable. visitCallableTerminator(op, callable); } @@ -343,12 +343,12 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { if (op->getNumSuccessors()) { // Check if we can reason about the control-flow. if (auto branch = dyn_cast<BranchOpInterface>(op)) { - LDBG("Visiting branch operation: " << *op); + LDBG() << "Visiting branch operation: " << *op; visitBranchOperation(branch); // Otherwise, conservatively mark all successors as exectuable. } else { - LDBG("Marking all successors live for op: " << *op); + LDBG() << "Marking all successors live for op: " << *op; for (Block *successor : op->getSuccessors()) markEdgeLive(op->getBlock(), successor); } @@ -358,7 +358,7 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { } void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { - LDBG("visitCallOperation: " << call.getOperation()->getName()); + LDBG() << "visitCallOperation: " << call.getOperation()->getName(); Operation *callableOp = call.resolveCallableInTable(&symbolTable); // A call to a externally-defined callable has unknown predecessors. @@ -381,15 +381,15 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { auto *callsites = getOrCreate<PredecessorState>(getProgramPointAfter(callableOp)); propagateIfChanged(callsites, callsites->join(call)); - LDBG("Added callsite as predecessor for callable: " - << callableOp->getName()); + LDBG() << "Added callsite as predecessor for callable: " + << callableOp->getName(); } else { // Mark this call op's predecessors as overdefined. auto *predecessors = getOrCreate<PredecessorState>(getProgramPointAfter(call)); propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors()); - LDBG("Marked call op's predecessors as unknown for: " - << call.getOperation()->getName()); + LDBG() << "Marked call op's predecessors as unknown for: " + << call.getOperation()->getName(); } } @@ -421,7 +421,7 @@ DeadCodeAnalysis::getOperandValues(Operation *op) { } void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { - LDBG("visitBranchOperation: " << branch.getOperation()->getName()); + LDBG() << "visitBranchOperation: " << branch.getOperation()->getName(); // Try to deduce a single successor for the branch. std::optional<SmallVector<Attribute>> operands = getOperandValues(branch); if (!operands) @@ -429,18 +429,18 @@ void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { if (Block *successor = branch.getSuccessorForOperands(*operands)) { markEdgeLive(branch->getBlock(), successor); - LDBG("Branch has single successor: " << successor); + LDBG() << "Branch has single successor: " << successor; } else { // Otherwise, mark all successors as executable and outgoing edges. 
for (Block *successor : branch->getSuccessors()) markEdgeLive(branch->getBlock(), successor); - LDBG("Branch has multiple/all successors live"); + LDBG() << "Branch has multiple/all successors live"; } } void DeadCodeAnalysis::visitRegionBranchOperation( RegionBranchOpInterface branch) { - LDBG("visitRegionBranchOperation: " << branch.getOperation()->getName()); + LDBG() << "visitRegionBranchOperation: " << branch.getOperation()->getName(); // Try to deduce which regions are executable. std::optional<SmallVector<Attribute>> operands = getOperandValues(branch); if (!operands) @@ -457,19 +457,19 @@ void DeadCodeAnalysis::visitRegionBranchOperation( // Mark the entry block as executable. auto *state = getOrCreate<Executable>(point); propagateIfChanged(state, state->setToLive()); - LDBG("Marked region successor live: " << point); + LDBG() << "Marked region successor live: " << point; // Add the parent op as a predecessor. auto *predecessors = getOrCreate<PredecessorState>(point); propagateIfChanged( predecessors, predecessors->join(branch, successor.getSuccessorInputs())); - LDBG("Added region branch as predecessor for successor: " << point); + LDBG() << "Added region branch as predecessor for successor: " << point; } } void DeadCodeAnalysis::visitRegionTerminator(Operation *op, RegionBranchOpInterface branch) { - LDBG("visitRegionTerminator: " << *op); + LDBG() << "visitRegionTerminator: " << *op; std::optional<SmallVector<Attribute>> operands = getOperandValues(op); if (!operands) return; @@ -488,7 +488,7 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, auto *state = getOrCreate<Executable>(getProgramPointBefore(&region->front())); propagateIfChanged(state, state->setToLive()); - LDBG("Marked region entry block live for region: " << region); + LDBG() << "Marked region entry block live for region: " << region; predecessors = getOrCreate<PredecessorState>( getProgramPointBefore(&region->front())); } else { @@ -498,14 +498,14 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op, } propagateIfChanged(predecessors, predecessors->join(op, successor.getSuccessorInputs())); - LDBG("Added region terminator as predecessor for successor: " - << (successor.getSuccessor() ? "region entry" : "parent op")); + LDBG() << "Added region terminator as predecessor for successor: " + << (successor.getSuccessor() ? "region entry" : "parent op"); } } void DeadCodeAnalysis::visitCallableTerminator(Operation *op, CallableOpInterface callable) { - LDBG("visitCallableTerminator: " << *op); + LDBG() << "visitCallableTerminator: " << *op; // Add as predecessors to all callsites this return op. auto *callsites = getOrCreateFor<PredecessorState>( getProgramPointAfter(op), getProgramPointAfter(callable)); @@ -516,15 +516,15 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op, getOrCreate<PredecessorState>(getProgramPointAfter(predecessor)); if (canResolve) { propagateIfChanged(predecessors, predecessors->join(op)); - LDBG("Added callable terminator as predecessor for callsite: " - << predecessor->getName()); + LDBG() << "Added callable terminator as predecessor for callsite: " + << predecessor->getName(); } else { // If the terminator is not a return-like, then conservatively assume we // can't resolve the predecessor. 
propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors()); - LDBG("Could not resolve callable terminator for callsite: " - << predecessor->getName()); + LDBG() << "Could not resolve callable terminator for callsite: " + << predecessor->getName(); } } } diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp index 6a12fe3..509f520 100644 --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -10,7 +10,7 @@ #include <cassert> #include <mlir/Analysis/DataFlow/LivenessAnalysis.h> -#include <llvm/Support/Debug.h> +#include <llvm/Support/DebugLog.h> #include <mlir/Analysis/DataFlow/SparseAnalysis.h> #include <mlir/Analysis/DataFlow/Utils.h> #include <mlir/Analysis/DataFlowFramework.h> @@ -21,8 +21,6 @@ #include <mlir/Support/LLVM.h> #define DEBUG_TYPE "liveness-analysis" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; using namespace mlir::dataflow; @@ -81,16 +79,15 @@ ChangeResult Liveness::meet(const AbstractSparseLattice &other) { LogicalResult LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands, ArrayRef<const Liveness *> results) { - LLVM_DEBUG(DBGS() << "[visitOperation] Enter: "; - op->print(llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n"); + LDBG() << "[visitOperation] Enter: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // This marks values of type (1.a) and (4) liveness as "live". if (!isMemoryEffectFree(op) || op->hasTrait<OpTrait::ReturnLike>()) { - LDBG("[visitOperation] Operation has memory effects or is " - "return-like, marking operands live"); + LDBG() << "[visitOperation] Operation has memory effects or is " + "return-like, marking operands live"; for (auto *operand : operands) { - LDBG(" [visitOperation] Marking operand live: " - << operand << " (" << operand->isLive << ")"); + LDBG() << " [visitOperation] Marking operand live: " << operand << " (" + << operand->isLive << ")"; propagateIfChanged(operand, operand->markLive()); } } @@ -99,28 +96,28 @@ LivenessAnalysis::visitOperation(Operation *op, ArrayRef<Liveness *> operands, bool foundLiveResult = false; for (const Liveness *r : results) { if (r->isLive && !foundLiveResult) { - LDBG("[visitOperation] Found live result, " - "meeting all operands with result: " - << r); + LDBG() << "[visitOperation] Found live result, " + "meeting all operands with result: " + << r; // It is assumed that each operand is used to compute each result of an // op. Thus, if at least one result is live, each operand is live. 
for (Liveness *operand : operands) { - LDBG(" [visitOperation] Meeting operand: " << operand - << " with result: " << r); + LDBG() << " [visitOperation] Meeting operand: " << operand + << " with result: " << r; meet(operand, *r); } foundLiveResult = true; } - LDBG("[visitOperation] Adding dependency for result: " << r << " after op: " - << *op); + LDBG() << "[visitOperation] Adding dependency for result: " << r + << " after op: " << *op; addDependency(const_cast<Liveness *>(r), getProgramPointAfter(op)); } return success(); } void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { - LDBG("Visiting branch operand: " << operand.get() - << " in op: " << *operand.getOwner()); + LDBG() << "Visiting branch operand: " << operand.get() + << " in op: " << *operand.getOwner(); // We know (at the moment) and assume (for the future) that `operand` is a // non-forwarded branch operand of a `RegionBranchOpInterface`, // `BranchOpInterface`, `RegionBranchTerminatorOpInterface` or return-like op. @@ -152,9 +149,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { for (Value result : op->getResults()) { if (getLatticeElement(result)->isLive) { mayLive = true; - LDBG("[visitBranchOperand] Non-forwarded branch " - "operand may be live due to live result: " - << result); + LDBG() << "[visitBranchOperand] Non-forwarded branch " + "operand may be live due to live result: " + << result; break; } } @@ -174,8 +171,8 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { // Therefore, we conservatively consider the non-forwarded operand of the // branch operation may live. mayLive = true; - LDBG("[visitBranchOperand] Non-forwarded branch operand may " - "be live due to branch op interface"); + LDBG() << "[visitBranchOperand] Non-forwarded branch operand may " + "be live due to branch op interface"; } else { Operation *parentOp = op->getParentOp(); assert(isa<RegionBranchOpInterface>(parentOp) && @@ -191,9 +188,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { for (Value result : parentOp->getResults()) { if (getLatticeElement(result)->isLive) { mayLive = true; - LDBG("[visitBranchOperand] Non-forwarded branch " - "operand may be live due to parent live result: " - << result); + LDBG() << "[visitBranchOperand] Non-forwarded branch " + "operand may be live due to parent live result: " + << result; break; } } @@ -214,9 +211,9 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { for (Operation &nestedOp : *block) { if (!isMemoryEffectFree(&nestedOp)) { mayLive = true; - LDBG("Non-forwarded branch operand may be " - "live due to memory effect in block: " - << block); + LDBG() << "Non-forwarded branch operand may be " + "live due to memory effect in block: " + << block; break; } } @@ -224,7 +221,7 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { if (mayLive) { Liveness *operandLiveness = getLatticeElement(operand.get()); - LDBG("Marking branch operand live: " << operand.get()); + LDBG() << "Marking branch operand live: " << operand.get(); propagateIfChanged(operandLiveness, operandLiveness->markLive()); } @@ -236,7 +233,7 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { SmallVector<const Liveness *, 4> resultsLiveness; for (const Value result : op->getResults()) resultsLiveness.push_back(getLatticeElement(result)); - LDBG("Visiting operation for non-forwarded branch operand: " << *op); + LDBG() << "Visiting operation for non-forwarded branch operand: " << *op; (void)visitOperation(op, operandLiveness, 
resultsLiveness); // We also visit the parent op with the parent's results and this operand if @@ -249,14 +246,14 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { SmallVector<const Liveness *, 4> parentResultsLiveness; for (const Value parentResult : parentOp->getResults()) parentResultsLiveness.push_back(getLatticeElement(parentResult)); - LDBG("Visiting parent operation for non-forwarded branch operand: " - << *parentOp); + LDBG() << "Visiting parent operation for non-forwarded branch operand: " + << *parentOp; (void)visitOperation(parentOp, operandLiveness, parentResultsLiveness); } void LivenessAnalysis::visitCallOperand(OpOperand &operand) { - LDBG("Visiting call operand: " << operand.get() - << " in op: " << *operand.getOwner()); + LDBG() << "Visiting call operand: " << operand.get() + << " in op: " << *operand.getOwner(); // We know (at the moment) and assume (for the future) that `operand` is a // non-forwarded call operand of an op implementing `CallOpInterface`. assert(isa<CallOpInterface>(operand.getOwner()) && @@ -269,18 +266,18 @@ void LivenessAnalysis::visitCallOperand(OpOperand &operand) { // This marks values of type (1.c) liveness as "live". A non-forwarded // call operand is live. Liveness *operandLiveness = getLatticeElement(operand.get()); - LDBG("Marking call operand live: " << operand.get()); + LDBG() << "Marking call operand live: " << operand.get(); propagateIfChanged(operandLiveness, operandLiveness->markLive()); } void LivenessAnalysis::setToExitState(Liveness *lattice) { - LDBG("setToExitState for lattice: " << lattice); + LDBG() << "setToExitState for lattice: " << lattice; if (lattice->isLive) { - LDBG("Lattice already live, nothing to do"); + LDBG() << "Lattice already live, nothing to do"; return; } // This marks values of type (2) liveness as "live". 
- LDBG("Marking lattice live due to exit state"); + LDBG() << "Marking lattice live due to exit state"; (void)lattice->markLive(); propagateIfChanged(lattice, ChangeResult::Change); } @@ -290,14 +287,14 @@ void LivenessAnalysis::setToExitState(Liveness *lattice) { //===----------------------------------------------------------------------===// RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) { - LDBG("Constructing RunLivenessAnalysis for op: " << op->getName()); + LDBG() << "Constructing RunLivenessAnalysis for op: " << op->getName(); SymbolTableCollection symbolTable; loadBaselineAnalyses(solver); solver.load<LivenessAnalysis>(symbolTable); - LDBG("Initializing and running solver"); + LDBG() << "Initializing and running solver"; (void)solver.initializeAndRun(op); - LDBG("Dumping liveness state for op"); + LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName(); } const Liveness *RunLivenessAnalysis::getLiveness(Value val) { diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp index 176d53e..16f7033 100644 --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -14,7 +14,7 @@ #include "llvm/ADT/iterator.h" #include "llvm/Config/abi-breaking.h" #include "llvm/Support/Casting.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "dataflow" @@ -44,9 +44,8 @@ void AnalysisState::addDependency(ProgramPoint *dependent, (void)inserted; DATAFLOW_DEBUG({ if (inserted) { - llvm::dbgs() << "Creating dependency between " << debugName << " of " - << anchor << "\nand " << debugName << " on " << dependent - << "\n"; + LDBG() << "Creating dependency between " << debugName << " of " << anchor + << "\nand " << debugName << " on " << dependent; } }); } @@ -116,8 +115,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { // Initialize the analyses. 
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) { - DATAFLOW_DEBUG(llvm::dbgs() - << "Priming analysis: " << analysis.debugName << "\n"); + DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName); if (failed(analysis.initialize(top))) return failure(); } @@ -129,8 +127,8 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { auto [point, analysis] = worklist.front(); worklist.pop(); - DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName - << "' on: " << point << "\n"); + DATAFLOW_DEBUG(LDBG() << "Invoking '" << analysis->debugName + << "' on: " << point); if (failed(analysis->visit(point))) return failure(); } @@ -143,9 +141,9 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state, assert(isRunning && "DataFlowSolver is not running, should not use propagateIfChanged"); if (changed == ChangeResult::Change) { - DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName - << " of " << state->anchor << "\n" - << "Value: " << *state << "\n"); + DATAFLOW_DEBUG(LDBG() << "Propagating update to " << state->debugName + << " of " << state->anchor << "\n" + << "Value: " << *state); state->onUpdate(this); } } diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp index 239ffe6..ea7dfdc 100644 --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -15,7 +15,6 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #include <cassert> #include <functional> diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp index 9f4a87a..8b14e71 100644 --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -89,6 +89,7 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, nestedPunctuation.pop_back(); return success(); }; + const char *curBufferEnd = state.lex.getBufferEnd(); do { // Handle code completions, which may appear in the middle of the symbol // body. @@ -98,6 +99,12 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, break; } + if (curBufferEnd == curPtr) { + if (!nestedPunctuation.empty()) + return emitPunctError(); + return emitError("unexpected nul or EOF in pretty dialect name"); + } + char c = *curPtr++; switch (c) { case '\0': diff --git a/mlir/lib/AsmParser/Lexer.cpp b/mlir/lib/AsmParser/Lexer.cpp index 751bd63..8f53529 100644 --- a/mlir/lib/AsmParser/Lexer.cpp +++ b/mlir/lib/AsmParser/Lexer.cpp @@ -37,6 +37,18 @@ Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context, AsmParserCodeCompleteContext *codeCompleteContext) : sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) { auto bufferID = sourceMgr.getMainFileID(); + + // Check to see if the main buffer contains the last buffer, and if so the + // last buffer should be used as main file for parsing. 
+ if (sourceMgr.getNumBuffers() > 1) { + unsigned lastFileID = sourceMgr.getNumBuffers(); + const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID); + const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID); + if (main->getBufferStart() <= last->getBufferStart() && + main->getBufferEnd() >= last->getBufferEnd()) { + bufferID = lastFileID; + } + } curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer(); curPtr = curBuffer.begin(); @@ -71,6 +83,7 @@ Token Lexer::emitError(const char *loc, const Twine &message) { } Token Lexer::lexToken() { + const char *curBufferEnd = curBuffer.end(); while (true) { const char *tokStart = curPtr; @@ -78,6 +91,9 @@ Token Lexer::lexToken() { if (tokStart == codeCompleteLoc) return formToken(Token::code_complete, tokStart); + if (tokStart == curBufferEnd) + return formToken(Token::eof, tokStart); + // Lex the next token. switch (*curPtr++) { default: @@ -102,7 +118,7 @@ Token Lexer::lexToken() { case 0: // This may either be a nul character in the source file or may be the EOF // marker that llvm::MemoryBuffer guarantees will be there. - if (curPtr - 1 == curBuffer.end()) + if (curPtr - 1 == curBufferEnd) return formToken(Token::eof, tokStart); continue; @@ -259,7 +275,11 @@ void Lexer::skipComment() { assert(*curPtr == '/'); ++curPtr; + const char *curBufferEnd = curBuffer.end(); while (true) { + if (curPtr == curBufferEnd) + return; + switch (*curPtr++) { case '\n': case '\r': @@ -267,7 +287,7 @@ void Lexer::skipComment() { return; case 0: // If this is the end of the buffer, end the comment. - if (curPtr - 1 == curBuffer.end()) { + if (curPtr - 1 == curBufferEnd) { --curPtr; return; } @@ -405,6 +425,7 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) { Token Lexer::lexString(const char *tokStart) { assert(curPtr[-1] == '"'); + const char *curBufferEnd = curBuffer.end(); while (true) { // Check to see if there is a code completion location within the string. In // these cases we generate a completion location and place the currently @@ -419,7 +440,7 @@ Token Lexer::lexString(const char *tokStart) { case 0: // If this is a random nul character in the middle of a string, just // include it. If it is the end of file, then it is an error. - if (curPtr - 1 != curBuffer.end()) + if (curPtr - 1 != curBufferEnd) continue; [[fallthrough]]; case '\n': diff --git a/mlir/lib/AsmParser/Lexer.h b/mlir/lib/AsmParser/Lexer.h index 4085a9b..670444e 100644 --- a/mlir/lib/AsmParser/Lexer.h +++ b/mlir/lib/AsmParser/Lexer.h @@ -40,6 +40,9 @@ public: /// Returns the start of the buffer. const char *getBufferBegin() { return curBuffer.data(); } + /// Returns the end of the buffer. + const char *getBufferEnd() { return curBuffer.end(); } + /// Return the code completion location of the lexer, or nullptr if there is /// none. 
const char *getCodeCompleteLoc() const { return codeCompleteLoc; } diff --git a/mlir/lib/AsmParser/TypeParser.cpp b/mlir/lib/AsmParser/TypeParser.cpp index 21bb0ec..a461ebe 100644 --- a/mlir/lib/AsmParser/TypeParser.cpp +++ b/mlir/lib/AsmParser/TypeParser.cpp @@ -19,7 +19,6 @@ #include "mlir/IR/TensorEncoding.h" #include "mlir/IR/Types.h" #include "mlir/Support/LLVM.h" -#include "llvm/ADT/STLExtras.h" #include <cassert> #include <cstdint> #include <limits> diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 8f79caf..db84ee1 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -16,8 +16,8 @@ #include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/raw_ostream.h" @@ -1428,6 +1428,12 @@ public: } }; +// Check if the python version is less than 3.13. Py_IsFinalizing is a part +// of stable ABI since 3.13 and before it was available as _Py_IsFinalizing. +#if PY_VERSION_HEX < 0x030d0000 +#define Py_IsFinalizing _Py_IsFinalizing +#endif + class PyDenseResourceElementsAttribute : public PyConcreteAttribute<PyDenseResourceElementsAttribute> { public: @@ -1474,8 +1480,9 @@ public: // The userData is a Py_buffer* that the deleter owns. auto deleter = [](void *userData, const void *data, size_t size, size_t align) { - if (!Py_IsInitialized()) - Py_Initialize(); + if (Py_IsFinalizing()) + return; + assert(Py_IsInitialized() && "expected interpreter to be initialized"); Py_buffer *ownedView = static_cast<Py_buffer *>(userData); nb::gil_scoped_acquire gil; PyBuffer_Release(ownedView); diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt index 8b9a395..ccda668 100644 --- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt +++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt @@ -1,19 +1,16 @@ # Dialect registration. 
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything RegisterEverything.cpp LINK_LIBS PUBLIC - ${dialect_libs} ${translation_libs} - ${conversion_libs} - ${extension_libs} MLIRBuiltinToLLVMIRTranslation MLIRCAPIIR - MLIRLLVMToLLVMIRTranslation MLIRCAPITransforms + MLIRLLVMToLLVMIRTranslation + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses ) diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index d25c84a..191b5ab6 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -20,3 +20,37 @@ add_subdirectory(Target) add_subdirectory(Tools) add_subdirectory(Transforms) add_subdirectory(ExecutionEngine) + +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) + +add_mlir_library(MLIRRegisterAllDialects + RegisterAllDialects.cpp + + PARTIAL_SOURCES_INTENDED + + LINK_LIBS PUBLIC + ${dialect_libs} + ) + +add_mlir_library(MLIRRegisterAllPasses + RegisterAllPasses.cpp + + PARTIAL_SOURCES_INTENDED + + LINK_LIBS PUBLIC + ${dialect_libs} # Some passes are part of the dialect libs + ${conversion_libs} + ) + +add_mlir_library(MLIRRegisterAllExtensions + RegisterAllExtensions.cpp + + PARTIAL_SOURCES_INTENDED + + LINK_LIBS PUBLIC + ${dialect_libs} + ${conversion_libs} + ${extension_libs} + ) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index bc0d9bf..64720bf 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -232,8 +232,8 @@ struct FatRawBufferCastLowering Value result = MemRefDescriptor::poison( rewriter, loc, getTypeConverter()->convertType(op.getResult().getType())); - result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, - kAllocatedPtrPosInMemRefDescriptor); + SmallVector<int64_t> pos{kAllocatedPtrPosInMemRefDescriptor}; + result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, pos); result = LLVM::InsertValueOp::create(rewriter, loc, result, fatPtr, kAlignedPtrPosInMemRefDescriptor); result = LLVM::InsertValueOp::create(rewriter, loc, result, offset, @@ -481,16 +481,16 @@ struct MemoryCounterWaitOpLowering if (chipset.majorVersion >= 12) { Location loc = op.getLoc(); if (std::optional<int> ds = adaptor.getDs()) - rewriter.create<ROCDL::WaitDscntOp>(loc, *ds); + ROCDL::WaitDscntOp::create(rewriter, loc, *ds); if (std::optional<int> load = adaptor.getLoad()) - rewriter.create<ROCDL::WaitLoadcntOp>(loc, *load); + ROCDL::WaitLoadcntOp::create(rewriter, loc, *load); if (std::optional<int> store = adaptor.getStore()) - rewriter.create<ROCDL::WaitStorecntOp>(loc, *store); + ROCDL::WaitStorecntOp::create(rewriter, loc, *store); if (std::optional<int> exp = adaptor.getExp()) - rewriter.create<ROCDL::WaitExpcntOp>(loc, *exp); + ROCDL::WaitExpcntOp::create(rewriter, loc, *exp); rewriter.eraseOp(op); return success(); diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp index 8c68b57..8230591 100644 --- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp +++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp @@ 
-449,7 +449,7 @@ LogicalResult ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); - constexpr int64_t opWidth = 2; + constexpr int64_t opOutWidth = 2; Value in = op.getIn(); Value scale = op.getScale(); @@ -460,6 +460,8 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, Type scaleType = getElementTypeOrSelf(scale); Type outType = getElementTypeOrSelf(out); + int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth(); + VectorType outVecType = dyn_cast<VectorType>(out.getType()); VectorType scaleVecType = dyn_cast<VectorType>(scale.getType()); @@ -473,7 +475,7 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, else if (scaleType.getIntOrFloatBitWidth() > 32) scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale); - VectorType extScaleResultType = VectorType::get(opWidth, outType); + VectorType extScaleResultType = VectorType::get(opOutWidth, outType); if (!outVecType) { Value inCast = vector::BroadcastOp::create(rewriter, loc, @@ -487,10 +489,11 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, VectorType inVecType = cast<VectorType>(in.getType()); Value origScale = getOriginalVectorValue(op.getScale()); + VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType()); ArrayRef<int64_t> inShape = inVecType.getShape(); SmallVector<int64_t> originalScaleShape; - if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType())) + if (origScaleVecType) llvm::append_range(originalScaleShape, origScaleVecType.getShape()); originalScaleShape.insert(originalScaleShape.end(), @@ -524,19 +527,26 @@ ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op, Value blockResult = rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero); - for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i); + for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i); i < blockSize; - i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) { - Value slice = vector::ExtractStridedSliceOp::create( - rewriter, loc, block1D, i, sliceWidth, 1); - // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1 - Value scaleExt = amdgpu::ScaledExtPackedOp::create( - rewriter, loc, extScaleResultType, slice, uniformScale, 0); - if (sliceWidth != opWidth) - scaleExt = vector::ExtractStridedSliceOp::create( - rewriter, loc, scaleExt, 0, sliceWidth, 1); - blockResult = vector::InsertStridedSliceOp::create( - rewriter, loc, scaleExt, blockResult, i, 1); + i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) { + Value inSlice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i, inSliceWidth, 1); + for (int64_t j = 0, + outSliceWidth = std::min(opOutWidth, inSliceWidth - j); + j < inSliceWidth; j += outSliceWidth, + outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) { + // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1 + Value scaleExt = amdgpu::ScaledExtPackedOp::create( + rewriter, loc, extScaleResultType, inSlice, uniformScale, + j / opOutWidth); + if (outSliceWidth < opOutWidth) { + scaleExt = vector::ExtractStridedSliceOp::create( + rewriter, loc, scaleExt, 0, outSliceWidth, 1); + } + blockResult = vector::InsertStridedSliceOp::create( + rewriter, loc, scaleExt, blockResult, i + j, 1); + } } VectorType resultType = VectorType::get(ratio, outType); @@ -555,7 +565,7 @@ LogicalResult ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp 
op, PatternRewriter &rewriter) const { Location loc = op.getLoc(); - constexpr int64_t opWidth = 2; + constexpr int64_t opInWidth = 2; Value in = op.getIn(); Value scale = op.getScale(); @@ -568,7 +578,6 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, VectorType outVecType = dyn_cast<VectorType>(out.getType()); VectorType scaleVecType = dyn_cast<VectorType>(scale.getType()); - if (outVecType && outVecType.isScalable()) return failure(); @@ -581,8 +590,8 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, Value zero = arith::ConstantOp::create(rewriter, loc, outType, rewriter.getFloatAttr(outType, 0.0)); - unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth(); - VectorType truncScaleResultType = VectorType::get(numPackedElem, outType); + int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth(); + VectorType truncScaleResultType = VectorType::get(opOutWidth, outType); if (!outVecType) { Type inVecType = VectorType::get(1, inType); @@ -598,16 +607,16 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, VectorType inVecType = cast<VectorType>(in.getType()); Value origScale = getOriginalVectorValue(op.getScale()); + VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType()); ArrayRef<int64_t> inShape = inVecType.getShape(); - SmallVector<int64_t> originalScaleShape; - if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType())) - llvm::append_range(originalScaleShape, origScaleVecType.getShape()); + SmallVector<int64_t> scaleShape; + if (origScaleVecType) + llvm::append_range(scaleShape, origScaleVecType.getShape()); - originalScaleShape.insert(originalScaleShape.end(), - inShape.size() - originalScaleShape.size(), 1); + scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1); - auto maybeRatio = computeShapeRatio(inShape, originalScaleShape); + auto maybeRatio = computeShapeRatio(inShape, scaleShape); assert(maybeRatio && "failed to derive block size from broadcast or splat operation"); @@ -633,20 +642,36 @@ ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op, Value blockResult = rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero); - for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i); - i < blockSize; - i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) { - Value slice = vector::ExtractStridedSliceOp::create( - rewriter, loc, block1D, i, sliceWidth, 1); - // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1 - Value scaleTrunc = amdgpu::PackedScaledTruncOp::create( - rewriter, loc, truncScaleResultType, slice, uniformScale, 0, - /*existing=*/nullptr); - int64_t packedWidth = - cast<VectorType>(scaleTrunc.getType()).getNumElements(); - if (packedWidth != opWidth) + for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i); + i < blockSize; i += outSliceWidth, + outSliceWidth = std::min(opOutWidth, blockSize - i)) { + Value scaleTrunc; + // Case where <= 2 elements are being truncated. 
+ if (outSliceWidth <= opInWidth) { + Value slice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i, outSliceWidth, 1); + // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1 + scaleTrunc = amdgpu::PackedScaledTruncOp::create( + rewriter, loc, truncScaleResultType, slice, uniformScale, 0, + /*existing=*/nullptr); + } else { + scaleTrunc = vector::BroadcastOp::create(rewriter, loc, + truncScaleResultType, zero); + for (int64_t j = 0, + inSliceWidth = std::min(opInWidth, outSliceWidth - j); + j < outSliceWidth; j += opInWidth, + inSliceWidth = std::min(opInWidth, outSliceWidth - j)) { + Value slice = vector::ExtractStridedSliceOp::create( + rewriter, loc, block1D, i + j, inSliceWidth, 1); + scaleTrunc = amdgpu::PackedScaledTruncOp::create( + rewriter, loc, truncScaleResultType, slice, uniformScale, + j / opInWidth, scaleTrunc); + } + } + if (outSliceWidth != opOutWidth) { scaleTrunc = vector::ExtractStridedSliceOp::create( - rewriter, loc, scaleTrunc, 0, sliceWidth, 1); + rewriter, loc, scaleTrunc, 0, outSliceWidth, 1); + } blockResult = vector::InsertStridedSliceOp::create( rewriter, loc, scaleTrunc, blockResult, i, 1); } diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 59b3fe2..515fe5c 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -402,8 +402,8 @@ public: Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType); // Actual cast (may change bitwidth) - auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(), - castDestType, actualOp); + auto cast = + emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp); // Cast to the expected output type auto result = adaptValueType(cast, rewriter, opReturnType); @@ -507,8 +507,8 @@ public: Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); - Value arithmeticResult = rewriter.template create<EmitCOp>( - op.getLoc(), arithmeticType, lhs, rhs); + Value arithmeticResult = + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); Value result = adaptValueType(arithmeticResult, rewriter, type); @@ -547,8 +547,8 @@ public: Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType); Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType); - Value arithmeticResult = rewriter.template create<EmitCOp>( - op.getLoc(), arithmeticType, lhs, rhs); + Value arithmeticResult = + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); Value result = adaptValueType(arithmeticResult, rewriter, type); @@ -748,8 +748,8 @@ public: } Value fpCastOperand = adaptor.getIn(); if (actualOperandType != operandType) { - fpCastOperand = rewriter.template create<emitc::CastOp>( - castOp.getLoc(), actualOperandType, fpCastOperand); + fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(), + actualOperandType, fpCastOperand); } rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand); diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp index d43e681..265293b 100644 --- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp +++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp @@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType, return builder.getF32FloatAttr(dstVal.convertToFloat()); } +// Get in IntegerAttr 
from FloatAttr while preserving the bits. +// Useful for converting float constants to integer constants while preserving +// the bits. +static IntegerAttr +getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType, + ConversionPatternRewriter &rewriter) { + APFloat floatVal = floatAttr.getValue(); + APInt intVal = floatVal.bitcastToAPInt(); + return rewriter.getIntegerAttr(dstType, intVal); +} + /// Returns true if the given `type` is a boolean scalar or vector type. static bool isBoolScalarOrVector(Type type) { assert(type && "Not a valid type"); @@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final SmallVector<Attribute, 8> elements; if (isa<FloatType>(srcElemType)) { for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) { - FloatAttr dstAttr = - convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter); + Attribute dstAttr = nullptr; + // Handle 8-bit float conversion to 8-bit integer. + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcElemType.getIntOrFloatBitWidth() == 8 && + isa<IntegerType>(dstElemType)) { + dstAttr = + getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter); + } else { + dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), + rewriter); + } if (!dstAttr) return failure(); elements.push_back(dstAttr); @@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final // Floating-point types. if (isa<FloatType>(srcType)) { auto srcAttr = cast<FloatAttr>(cstAttr); - auto dstAttr = srcAttr; + Attribute dstAttr = srcAttr; // Floating-point types not supported in the target environment are all // converted to float type. - if (srcType != dstType) { + auto *typeConverter = getTypeConverter<SPIRVTypeConverter>(); + if (typeConverter->getOptions().emulateUnsupportedFloatTypes && + srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) && + dstType.getIntOrFloatBitWidth() == 8) { + // If the source is an 8-bit float, convert it to a 8-bit integer. 
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter); + if (!dstAttr) + return failure(); + } else if (srcType != dstType) { dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter); if (!dstAttr) return failure(); @@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // Use UnrealizedConversionCast as the bridge so that we don't need to pull diff --git a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp index 1510b0b..e34b368 100644 --- a/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp +++ b/mlir/lib/Conversion/ArmNeon2dToIntr/ArmNeon2dToIntr.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassRegistry.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" namespace mlir { diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp index 79e1683..29e6552 100644 --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -18,7 +18,6 @@ #include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp index 30a7170..3edcbb8 100644 --- a/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp +++ b/mlir/lib/Conversion/BufferizationToMemRef/BufferizationToMemRef.cpp @@ -68,9 +68,8 @@ struct CloneOpConversion : public OpConversionPattern<bufferization::CloneOp> { scf::YieldOp::create(rewriter, loc, acc); }; - auto size = rewriter - .create<scf::ForOp>(loc, zero, rank, one, ValueRange(one), - loopBody) + auto size = scf::ForOp::create(rewriter, loc, zero, rank, one, + ValueRange(one), loopBody) .getResult(0); MemRefType memrefType = MemRefType::get({ShapedType::kDynamic}, diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index f84375b..785cb82 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -43,7 +43,7 @@ add_subdirectory(MathToSPIRV) add_subdirectory(MemRefToEmitC) add_subdirectory(MemRefToLLVM) add_subdirectory(MemRefToSPIRV) -add_subdirectory(MeshToMPI) +add_subdirectory(ShardToMPI) add_subdirectory(MPIToLLVM) add_subdirectory(NVGPUToNVVM) add_subdirectory(NVVMToLLVM) diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 6f0fc29..35ad99c 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_cabs_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>( patterns.getContext(), "__ocml_cabs_f64"); + 
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>( + patterns.getContext(), "__ocml_carg_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>( + patterns.getContext(), "__ocml_carg_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>( + patterns.getContext(), "__ocml_conj_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>( + patterns.getContext(), "__ocml_conj_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>( + patterns.getContext(), "__ocml_ccos_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>( + patterns.getContext(), "__ocml_ccos_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>( patterns.getContext(), "__ocml_cexp_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>( patterns.getContext(), "__ocml_cexp_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>( + patterns.getContext(), "__ocml_clog_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>( + patterns.getContext(), "__ocml_clog_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>( + patterns.getContext(), "__ocml_cpow_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>( + patterns.getContext(), "__ocml_cpow_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>( + patterns.getContext(), "__ocml_csin_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>( + patterns.getContext(), "__ocml_csin_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>( + patterns.getContext(), "__ocml_csqrt_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>( + patterns.getContext(), "__ocml_csqrt_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>( + patterns.getContext(), "__ocml_ctan_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>( + patterns.getContext(), "__ocml_ctan_f64"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>( + patterns.getContext(), "__ocml_ctanh_f32"); + patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>( + patterns.getContext(), "__ocml_ctanh_f64"); } namespace { @@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<func::FuncDialect>(); - target.addIllegalOp<complex::AbsOp, complex::ExpOp>(); + target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp, + complex::CosOp, complex::ExpOp, complex::LogOp, + complex::PowOp, complex::SinOp, complex::SqrtOp, + complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index eeff8a9..5ad514d 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" #include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include <type_traits> diff --git 
a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp index c8311eb..5ac838c 100644 --- a/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp +++ b/mlir/lib/Conversion/ControlFlowToSCF/ControlFlowToSCF.cpp @@ -144,12 +144,11 @@ ControlFlowToSCFTransformation::createUnreachableTerminator(Location loc, return emitError(loc, "Cannot create unreachable terminator for '") << parentOp->getName() << "'"; - return builder - .create<func::ReturnOp>( - loc, llvm::map_to_vector(funcOp.getResultTypes(), - [&](Type type) { - return getUndefValue(loc, builder, type); - })) + return func::ReturnOp::create( + builder, loc, + llvm::map_to_vector( + funcOp.getResultTypes(), + [&](Type type) { return getUndefValue(loc, builder, type); })) .getOperation(); } diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp index 03f4bf4..56b6181 100644 --- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp +++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp @@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); // TODO: We should also take care of block argument type conversion. diff --git a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp index c9b1dc1..ee6d7d5 100644 --- a/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp +++ b/mlir/lib/Conversion/ConvertToEmitC/ConvertToEmitCPass.cpp @@ -9,8 +9,6 @@ #include "mlir/Conversion/ConvertToEmitC/ConvertToEmitCPass.h" #include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h" -#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" -#include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" diff --git a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp index 252245d..c70b5f0 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ToLLVMInterface.cpp @@ -9,7 +9,6 @@ #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" -#include "llvm/ADT/DenseSet.h" using namespace mlir; diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp index 8ed9f65..c0439a4 100644 --- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp +++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp @@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() { SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 63eb6c58..3cfbd89 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -579,8 +579,8 @@ LLVM::CallOp 
FunctionCallBuilder::create(Location loc, OpBuilder &builder, auto function = [&] { if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName)) return function; - return OpBuilder::atBlockEnd(module.getBody()) - .create<LLVM::LLVMFuncOp>(loc, functionName, functionType); + auto builder = OpBuilder::atBlockEnd(module.getBody()); + return LLVM::LLVMFuncOp::create(builder, loc, functionName, functionType); }(); return LLVM::CallOp::create(builder, loc, function, arguments); } diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp index a19194e..1817861 100644 --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -507,25 +507,27 @@ LogicalResult GPURotateConversion::matchAndRewrite( getTypeConverter<SPIRVTypeConverter>()->getTargetEnv(); unsigned subgroupSize = targetEnv.getAttr().getResourceLimits().getSubgroupSize(); - IntegerAttr widthAttr; - if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) || - widthAttr.getValue().getZExtValue() > subgroupSize) + unsigned width = rotateOp.getWidth(); + if (width > subgroupSize) return rewriter.notifyMatchFailure( - rotateOp, - "rotate width is not a constant or larger than target subgroup size"); + rotateOp, "rotate width is larger than target subgroup size"); Location loc = rotateOp.getLoc(); auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup); + Value offsetVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr()); + Value widthVal = + arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr()); Value rotateResult = spirv::GroupNonUniformRotateKHROp::create( - rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(), - adaptor.getWidth()); + rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal); Value validVal; - if (widthAttr.getValue().getZExtValue() == subgroupSize) { + if (width == subgroupSize) { validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter); } else { + IntegerAttr widthAttr = adaptor.getWidthAttr(); Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr); validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, - laneId, adaptor.getWidth()); + laneId, widthVal); } rewriter.replaceOp(rotateOp, {rotateResult, validVal}); @@ -559,8 +561,8 @@ static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc, builder, loc, builder.getI32Type(), builder.getIntegerAttr(builder.getI32Type(), *clusterSize)); - return builder - .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue) + return NonUniformOp::create(builder, loc, type, scope, groupOp, arg, + clusterSizeValue) .getResult(); } diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index ecd5b63..2568044 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -272,14 +272,13 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( // Allocate memory, copy, and free the source if necessary. Value memory = - toDynamic - ? builder - .create<LLVM::CallOp>(loc, mallocFunc.value(), allocationSize) - .getResult() - : LLVM::AllocaOp::create(builder, loc, getPtrType(), - IntegerType::get(getContext(), 8), - allocationSize, - /*alignment=*/0); + toDynamic ? 
LLVM::CallOp::create(builder, loc, mallocFunc.value(), + allocationSize) + .getResult() + : LLVM::AllocaOp::create(builder, loc, getPtrType(), + IntegerType::get(getContext(), 8), + allocationSize, + /*alignment=*/0); Value source = desc.memRefDescPtr(builder, loc); LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); if (!toDynamic) diff --git a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp index 5b68eb8..e5496e5 100644 --- a/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp +++ b/mlir/lib/Conversion/MPIToLLVM/MPIToLLVM.cpp @@ -35,7 +35,7 @@ static Op getOrDefineGlobal(ModuleOp &moduleOp, const Location loc, if (!(ret = moduleOp.lookupSymbol<Op>(name))) { ConversionPatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); - ret = rewriter.template create<Op>(loc, std::forward<Args>(args)...); + ret = Op::create(rewriter, loc, std::forward<Args>(args)...); } return ret; } diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp index 08a4566..cde2340 100644 --- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp +++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp @@ -17,13 +17,12 @@ #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOFUNCS @@ -33,7 +32,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-funcs" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") namespace { // Pattern to convert vector operations to scalar operations. 
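Both MathToFuncs hunks belong to the cleanup that runs through this whole diff: per-file DBGS macros are deleted and debug output goes through the LDBG() helper from llvm/Support/DebugLog.h instead. A minimal sketch of the before/after idiom, assuming DEBUG_TYPE is defined as in the file above (the function and message are illustrative):

#include <cstdint>
#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "math-to-funcs"

static void logBitWidth(int64_t bitWidth) {
  // Old idiom: explicit stream, hand-rolled "[debug-type]" prefix, and an
  // explicit trailing newline.
  LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "]: bit width: " << bitWidth
                          << "\n");
  // New idiom: one statement per output line; the prefix and the newline
  // are handled when the LDBG() statement ends.
  LDBG() << "bit width: " << bitWidth;
}

Both forms are compiled out of builds without assertions, and both honor -debug-only=math-to-funcs.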
@@ -654,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op, /// } static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { if (!isa<IntegerType>(elementType)) { - LLVM_DEBUG({ - DBGS() << "non-integer element type for CtlzFunc; type was: "; - elementType.print(llvm::dbgs()); - }); + LDBG() << "non-integer element type for CtlzFunc; type was: " + << elementType; llvm_unreachable("non-integer element type"); } int64_t bitWidth = elementType.getIntOrFloatBitWidth(); @@ -699,7 +695,8 @@ static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) { scf::IfOp ifOp = scf::IfOp::create(builder, elementType, inputEqZero, /*addThenBlock=*/true, /*addElseBlock=*/true); - ifOp.getThenBodyBuilder().create<scf::YieldOp>(loc, bitWidthValue); + auto thenBuilder = ifOp.getThenBodyBuilder(); + scf::YieldOp::create(thenBuilder, loc, bitWidthValue); auto elseBuilder = ImplicitLocOpBuilder::atBlockEnd(loc, &ifOp.getElseRegion().front()); diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp index 93d8b49..df219f3 100644 --- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp +++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/MathToROCDL/MathToROCDL.h" +#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" #include "mlir/Conversion/LLVMCommon/LoweringOptions.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -21,7 +22,6 @@ #include "../GPUCommon/GPUOpsLowering.h" #include "../GPUCommon/OpToFuncCallLowering.h" -#include "mlir/Conversion/GPUCommon/GPUCommonPass.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMATHTOROCDL @@ -31,7 +31,6 @@ namespace mlir { using namespace mlir; #define DEBUG_TYPE "math-to-rocdl" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") template <typename OpTy> static void populateOpPatterns(const LLVMTypeConverter &converter, diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index e882845..6bd0e2d 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -19,10 +19,18 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include <cstdint> using namespace mlir; +static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) { + return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() && + memRefType.getRank() != 0 && + !llvm::is_contained(memRefType.getShape(), 0); +} + namespace { /// Implement the interface to convert MemRef to EmitC. 
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { @@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = allocOp.getLoc(); + MemRefType memrefType = allocOp.getType(); + if (!isMemRefTypeLegalForEmitC(memrefType)) { + return rewriter.notifyMatchFailure( + loc, "incompatible memref type for EmitC conversion"); + } + + Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); + Type elementType = memrefType.getElementType(); + IndexType indexType = rewriter.getIndexType(); + emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>( + loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); + + int64_t numElements = 1; + for (int64_t dimSize : memrefType.getShape()) { + numElements *= dimSize; + } + Value numElementsValue = rewriter.create<emitc::ConstantOp>( + loc, indexType, rewriter.getIndexAttr(numElements)); + + Value totalSizeBytes = rewriter.create<emitc::MulOp>( + loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); + + emitc::CallOpaqueOp allocCall; + StringAttr allocFunctionName; + Value alignmentValue; + SmallVector<Value, 2> argsVec; + if (allocOp.getAlignment()) { + allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); + alignmentValue = rewriter.create<emitc::ConstantOp>( + loc, sizeTType, + rewriter.getIntegerAttr(indexType, + allocOp.getAlignment().value_or(0))); + argsVec.push_back(alignmentValue); + } else { + allocFunctionName = rewriter.getStringAttr(mallocFunctionName); + } + + argsVec.push_back(totalSizeBytes); + ValueRange args(argsVec); + + allocCall = rewriter.create<emitc::CallOpaqueOp>( + loc, + emitc::PointerType::get( + emitc::OpaqueType::get(rewriter.getContext(), "void")), + allocFunctionName, args); + + emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); + emitc::CastOp castOp = rewriter.create<emitc::CastOp>( + loc, targetPointerType, allocCall.getResult(0)); + + rewriter.replaceOp(allocOp, castOp); + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; @@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { typeConverter.addConversion( [&](MemRefType memRefType) -> std::optional<Type> { - if (!memRefType.hasStaticShape() || - !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || - llvm::is_contained(memRefType.getShape(), 0)) { + if (!isMemRefTypeLegalForEmitC(memRefType)) { return {}; } Type convertedElementType = @@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, - ConvertStore>(converter, patterns.getContext()); + patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, + ConvertLoad, ConvertStore>(converter, patterns.getContext()); } diff --git 
a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index cf25c09..e78dd76 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -28,9 +29,11 @@ using namespace mlir; namespace { struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { + using Base::Base; void runOnOperation() override { TypeConverter converter; - + ConvertMemRefToEmitCOptions options; + options.lowerToCpp = this->lowerToCpp; // Fallback for other types. converter.addConversion([](Type type) -> std::optional<Type> { if (!emitc::isSupportedEmitCType(type)) @@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); + + mlir::ModuleOp module = getOperation(); + module.walk([&](mlir::emitc::CallOpaqueOp callOp) { + if (callOp.getCallee() != alignedAllocFunctionName && + callOp.getCallee() != mallocFunctionName) { + return mlir::WalkResult::advance(); + } + + for (auto &op : *module.getBody()) { + emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op); + if (!includeOp) { + continue; + } + if (includeOp.getIsStandardInclude() && + ((options.lowerToCpp && + includeOp.getInclude() == cppStandardLibraryHeader) || + (!options.lowerToCpp && + includeOp.getInclude() == cStandardLibraryHeader))) { + return mlir::WalkResult::interrupt(); + } + } + + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + StringAttr includeAttr = + builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader); + builder.create<mlir::emitc::IncludeOp>( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); + return mlir::WalkResult::interrupt(); + }); } }; } // namespace diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index 53a1912..dc2035b 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -24,11 +24,12 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IRMapping.h" #include "mlir/Pass/Pass.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/MathExtras.h" + #include <optional> #define DEBUG_TYPE "memref-to-llvm" -#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] " namespace mlir { #define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS @@ -575,8 +576,8 @@ private: Value sizePtr = LLVM::GEPOp::create(rewriter, loc, indexPtrTy, getTypeConverter()->getIndexType(), offsetPtr, idxPlusOne); - return rewriter - .create<LLVM::LoadOp>(loc, getTypeConverter()->getIndexType(), sizePtr) + return LLVM::LoadOp::create(rewriter, loc, + getTypeConverter()->getIndexType(), sizePtr) .getResult(); } @@ -1848,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::xchg; case arith::AtomicRMWKind::maximumf: // TODO: remove this by end of 2025. 
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed " - "from fmax to fmaximum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw maximumf changed " + "from fmax to fmaximum, expect more NaNs"; return LLVM::AtomicBinOp::fmaximum; case arith::AtomicRMWKind::maxnumf: return LLVM::AtomicBinOp::fmax; @@ -1859,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::umax; case arith::AtomicRMWKind::minimumf: // TODO: remove this by end of 2025. - LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed " - "from fmin to fminimum, expect more NaNs"); + LDBG() << "the lowering of memref.atomicrmw minimum changed " + "from fmin to fminimum, expect more NaNs"; return LLVM::AtomicBinOp::fminimum; case arith::AtomicRMWKind::minnumf: return LLVM::AtomicBinOp::fmin; diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 905287e1..2549a9c 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -21,19 +21,17 @@ #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include <optional> #define DEBUG_TYPE "nvgpu-to-nvvm" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define DBGSE() (llvm::dbgs()) namespace mlir { #define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS @@ -1106,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); - LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: " - << "leading_off:" << leadDimVal << "\t" - << "stride_off :" << strideDimVal << "\t" - << "base_offset:" << offsetVal << "\t" - << "layout_type:" << swizzle << " (" - << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) - << ")\n start_addr : " << baseAddr << "\n"); + LDBG() << "Generating warpgroup.descriptor: " + << "leading_off:" << leadDimVal << "\t" + << "stride_off :" << strideDimVal << "\t" + << "base_offset:" << offsetVal << "\t" + << "layout_type:" << swizzle << " (" + << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) + << ")\n start_addr : " << baseAddr; rewriter.replaceOp(op, dsc); return success(); @@ -1282,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering } else { llvm_unreachable("msg: not supported K shape"); } - LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM - << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n"); + LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM + << ", n = " << wgmmaN << ", k = " << wgmmaK << "]"; } /// Generates WGMMATypesAttr from MLIR Type @@ -1367,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering int tileShapeA = matrixTypeA.getDimSize(1); int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k - << "] [wgmma descriptors] Descriptor A + " - << incrementVal << " | \t "); + LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k + << "] [wgmma descriptors] Descriptor A + " << incrementVal + << " | \t "; if (!incrementVal) return desc; return 
makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1392,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering int byte = elemB.getIntOrFloatBitWidth() / 8; int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte; incrementVal = incrementVal >> exclude4LSB; - LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n"); + LDBG() << "Descriptor B + " << incrementVal; if (!incrementVal) return desc; return makeAdd(desc, makeI64Const(b, incrementVal)); @@ -1401,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix /// descriptors and arranges them based on induction variables: i, j, and k. Value generateWgmma(int i, int j, int k, Value matrixC) { - LLVM_DEBUG(DBGS() << "\t wgmma." - << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK - << "(A[" << (iterationM * wgmmaM) << ":" - << (iterationM * wgmmaM) + wgmmaM << "][" - << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "] * " - << " B[" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" - << wgmmaN << "])\n"); + LDBG() << "\t wgmma." + << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A[" + << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM + << "][" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "] * " + << " B[" << (iterationK * wgmmaK) << ":" + << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN + << "])"; Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); @@ -1468,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering totalM = op.getDescriptorA().getType().getTensor().getDimSize(0); totalN = op.getDescriptorB().getType().getTensor().getDimSize(1); totalK = op.getDescriptorA().getType().getTensor().getDimSize(1); - LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN - << "] += A[" << totalM << "][" << totalK << "] * B[" - << totalK << "][" << totalN << "] ---===\n"); + LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A[" + << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN + << "] ---==="; // Find the shape for one wgmma instruction findWgmmaShape( diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp index 662ee9e..91788f9 100644 --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -25,11 +25,10 @@ #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "nvvm-to-llvm" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") namespace mlir { #define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS @@ -52,17 +51,17 @@ struct PtxLowering LogicalResult matchAndRewrite(BasicPtxBuilderInterface op, PatternRewriter &rewriter) const override { if (op.hasIntrinsic()) { - LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n"); + LDBG() << "Ptx Builder does not lower \n\t" << op; return failure(); } SmallVector<std::pair<Value, PTXRegisterMod>> asmValues; - LLVM_DEBUG(DBGS() << op.getPtx() << "\n"); + LDBG() << op.getPtx(); PtxBuilder generator(op, rewriter); op.getAsmValues(rewriter, asmValues); for (auto &[asmValue, modifier] : asmValues) { - LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier); + LDBG() << asmValue << "\t Modifier : " << &modifier; 
generator.insertValue(asmValue, modifier); } diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 3e434ea..5bd1d49 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -49,7 +49,7 @@ static void getTreePredicates(std::vector<PositionalPredicate> &predList, assert(isa<pdl::AttributeType>(val.getType()) && "expected attribute type"); predList.emplace_back(pos, builder.getIsNotNull()); - if (auto attr = dyn_cast<pdl::AttributeOp>(val.getDefiningOp())) { + if (auto attr = val.getDefiningOp<pdl::AttributeOp>()) { // If the attribute has a type or value, add a constraint. if (Value type = attr.getValueType()) getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); diff --git a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp index e1a9fa59..2d9c661f 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/RootOrdering.cpp @@ -14,9 +14,7 @@ #include "RootOrdering.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SmallVector.h" -#include <queue> #include <utility> using namespace mlir; diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 240491a..807be7e 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -582,6 +582,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // block. This should be reconsidered if we allow break/continue in SCF. rewriter.setInsertionPointToEnd(before); auto condOp = cast<ConditionOp>(before->getTerminator()); + SmallVector<Value> args = llvm::to_vector(condOp.getArgs()); rewriter.replaceOpWithNewOp<cf::CondBranchOp>(condOp, condOp.getCondition(), after, condOp.getArgs(), continuation, ValueRange()); @@ -593,7 +594,7 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. 
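// Why the hunk above snapshots the arguments: replaceOpWithNewOp replaces
// and erases condOp, so evaluating condOp.getArgs() after that call would
// read from an operation the rewriter has already replaced. A minimal
// sketch of the safe ordering (illustrative, mirroring this pattern):
//
//   SmallVector<Value> args = llvm::to_vector(condOp.getArgs());
//   rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
//       condOp, condOp.getCondition(), after, args, continuation,
//       ValueRange());
//   rewriter.replaceOp(whileOp, args); // safe: reads the snapshot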
- rewriter.replaceOp(whileOp, condOp.getArgs()); + rewriter.replaceOp(whileOp, args); return success(); } diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index f191f35..badd2f6 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -25,9 +25,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/IRMapping.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/Support/Debug.h" #include <optional> diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index aae3271..9b61540 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -1493,11 +1493,11 @@ public: Value extended; if (op2TypeWidth < dstTypeWidth) { if (isUnsignedIntegerOrVector(op2Type)) { - extended = rewriter.template create<LLVM::ZExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::ZExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } else { - extended = rewriter.template create<LLVM::SExtOp>( - loc, dstType, adaptor.getOperand2()); + extended = + LLVM::SExtOp::create(rewriter, loc, dstType, adaptor.getOperand2()); } } else if (op2TypeWidth == dstTypeWidth) { extended = adaptor.getOperand2(); @@ -1505,8 +1505,8 @@ public: return failure(); } - Value result = rewriter.template create<LLVMOp>( - loc, dstType, adaptor.getOperand1(), extended); + Value result = + LLVMOp::create(rewriter, loc, dstType, adaptor.getOperand1(), extended); rewriter.replaceOp(op, result); return success(); } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp index 7025c5a..0ff9fb3 100644 --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -14,7 +14,6 @@ #include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/STLExtras.h" diff --git a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt index 15560aa..564f36f 100644 --- a/mlir/lib/Conversion/MeshToMPI/CMakeLists.txt +++ b/mlir/lib/Conversion/ShardToMPI/CMakeLists.txt @@ -1,8 +1,8 @@ -add_mlir_conversion_library(MLIRMeshToMPI - MeshToMPI.cpp +add_mlir_conversion_library(MLIRShardToMPI + ShardToMPI.cpp ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MeshToMPI + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ShardToMPI DEPENDS MLIRConversionPassIncGen @@ -17,7 +17,7 @@ add_mlir_conversion_library(MLIRMeshToMPI MLIRLinalgTransforms MLIRMemRefDialect MLIRPass - MLIRMeshDialect + MLIRShardDialect MLIRMPIDialect MLIRTransforms ) diff --git a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index 63b1fda..fa9e544 100644 --- a/mlir/lib/Conversion/MeshToMPI/MeshToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -1,4 +1,4 @@ -//===- MeshToMPI.cpp - Mesh to MPI dialect conversion -----------------===// +//===- ShardToMPI.cpp - Shard to MPI dialect conversion -----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
// See https://llvm.org/LICENSE.txt for license information. @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// // -// This file implements a translation of Mesh communication ops tp MPI ops. +// This file implements a translation of Shard communication ops to MPI ops. // //===----------------------------------------------------------------------===// -#include "mlir/Conversion/MeshToMPI/MeshToMPI.h" +#include "mlir/Conversion/ShardToMPI/ShardToMPI.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" @@ -20,11 +20,11 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MPI/IR/MPI.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Transforms/Simplifications.h" -#include "mlir/Dialect/Mesh/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Transforms/Simplifications.h" +#include "mlir/Dialect/Shard/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Builders.h" @@ -35,16 +35,15 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#define DEBUG_TYPE "mesh-to-mpi" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define DEBUG_TYPE "shard-to-mpi" namespace mlir { -#define GEN_PASS_DEF_CONVERTMESHTOMPIPASS +#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS #include "mlir/Conversion/Passes.h.inc" } // namespace mlir using namespace mlir; -using namespace mesh; +using namespace shard; namespace { /// Converts a vector of OpFoldResults (ints) into vector of Values of the @@ -177,9 +176,8 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { auto type = RankedTensorType::get({nSplits, 2}, i64); Value resHaloSizes = haloSizes.empty() - ? rewriter - .create<tensor::EmptyOp>(loc, std::array<int64_t, 2>{0, 0}, - i64) + ? tensor::EmptyOp::create(rewriter, loc, + std::array<int64_t, 2>{0, 0}, i64) .getResult() : tensor::FromElementsOp::create(rewriter, loc, type, haloSizes) .getResult(); @@ -188,18 +186,18 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { // maxSplitSize+1}. Store the offsets in the tensor but set trailing // elements for smaller split-groups to -1. 
Computing the max size of the // split groups needs using collectiveProcessGroupSize (which needs the - // MeshOp) + // GridOp) Value resOffsets; if (adaptor.getStaticShardedDimsOffsets().empty()) { resOffsets = tensor::EmptyOp::create(rewriter, loc, std::array<int64_t, 2>{0, 0}, i64); } else { SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(op, symbolTableCollection); + auto gridOp = getGrid(op, symbolTableCollection); int64_t maxSplitSize = 0; for (auto axes : splitAxes) { int64_t splitSize = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); assert(splitSize != ShapedType::kDynamic); maxSplitSize = std::max<int64_t>(maxSplitSize, splitSize); } @@ -218,7 +216,7 @@ struct ConvertShardingOp : public OpConversionPattern<ShardingOp> { int64_t curr = 0; for (auto [i, axes] : llvm::enumerate(splitAxes)) { int64_t splitSize = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); assert(splitSize != ShapedType::kDynamic && splitSize < maxSplitSize); ++splitSize; // add one for the total size ArrayRef<Value> values(&offsets[curr], splitSize); @@ -264,20 +262,20 @@ struct ConvertProcessMultiIndexOp SymbolTableCollection symbolTableCollection; Location loc = op.getLoc(); - auto meshOp = getMesh(op, symbolTableCollection); - // For now we only support static mesh shapes - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto gridOp = getGrid(op, symbolTableCollection); + // For now we only support static grid shapes + if (ShapedType::isDynamicShape(gridOp.getShape())) return failure(); SmallVector<Value> dims; llvm::transform( - meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); - Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), meshOp); + Value rank = ProcessLinearIndexOp::create(rewriter, op.getLoc(), gridOp); auto mIdx = linearToMultiIndex(loc, rewriter, rank, dims); - // optionally extract subset of mesh axes + // optionally extract subset of grid axes auto axes = adaptor.getAxes(); if (!axes.empty()) { SmallVector<Value> subIndex; @@ -306,13 +304,11 @@ public: auto ctx = op.getContext(); Value commWorld = mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx)); - auto rank = - rewriter - .create<mpi::CommRankOp>( - loc, - TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, - commWorld) - .getRank(); + auto rank = mpi::CommRankOp::create( + rewriter, loc, + TypeRange{mpi::RetvalType::get(ctx), rewriter.getI32Type()}, + commWorld) + .getRank(); rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, rewriter.getIndexType(), rank); return success(); @@ -338,12 +334,12 @@ struct ConvertNeighborsLinearIndicesOp Location loc = op.getLoc(); SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(op, symbolTableCollection); + auto gridOp = getGrid(op, symbolTableCollection); auto mIdx = adaptor.getDevice(); auto orgIdx = mIdx[axes[0]]; SmallVector<Value> dims; llvm::transform( - meshOp.getShape(), std::back_inserter(dims), [&](int64_t i) { + gridOp.getShape(), std::back_inserter(dims), [&](int64_t i) { return arith::ConstantIndexOp::create(rewriter, loc, i).getResult(); }); Value dimSz = dims[axes[0]]; @@ -394,14 +390,14 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { auto sharding = 
op.getSharding().getDefiningOp<ShardingOp>(); if (!sharding) { return op->emitError() - << "Expected SharingOp as defining op for sharding" + << "Expected ShardingOp as defining op for sharding" << " but found " << adaptor.getSharding()[0].getDefiningOp(); } // Compute the sharded shape by applying the sharding to the input shape. // If shardedDimsOffsets is not defined in the sharding, the shard shape is // computed by dividing the dimension size by the number of shards in that - // dimension (which is given by the size of the mesh axes provided in + // dimension (which is given by the size of the grid axes provided in // split-axes). Odd elements get distributed to trailing shards. If a // shardedDimsOffsets is provided, the shard shape is computed by // subtracting the offset of the current shard from the offset of the next @@ -431,11 +427,11 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { SmallVector<Value> multiIdx = getMixedAsValues(rewriter, loc, adaptor.getDevice(), dynDevice, index); - // Get the MeshOp, the mesh shape is needed to compute the sharded shape. + // Get the GridOp, the grid shape is needed to compute the sharded shape. SymbolTableCollection symbolTableCollection; - auto meshOp = getMesh(sharding, symbolTableCollection); - // For now we only support static mesh shapes - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto gridOp = getGrid(sharding, symbolTableCollection); + // For now we only support static grid shapes + if (ShapedType::isDynamicShape(gridOp.getShape())) return failure(); auto splitAxes = sharding.getSplitAxes().getAxes(); @@ -455,7 +451,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { tmp); } - // With static mesh shape the sizes of the split axes are known. + // With static grid shape the sizes of the split axes are known. // Hence the start/pos for each split axes in shardDimsOffsets can be // computed statically. int64_t pos = 0; @@ -475,10 +471,10 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { // Create a value from the static position in shardDimsOffsets. Value posVal = arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(pos)); - // Get the index of the local shard in the mesh axis. + // Get the index of the local shard in the grid axis. Value idx = multiIdx[axes[0]]; auto numShards = - collectiveProcessGroupSize(axes.asArrayRef(), meshOp.getShape()); + collectiveProcessGroupSize(axes.asArrayRef(), gridOp.getShape()); if (shardedDimsOffs) { // If sharded dims offsets are provided, use them to compute the // sharded shape. 
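The comments in ConvertShardShapeOp describe the default path with no sharded-dims offsets: each sharded dimension is divided by the number of shards along its grid axes, and the odd elements go to the trailing shards. A self-contained sketch of that arithmetic, using illustrative names rather than the pass's own helpers:

#include <cstdint>
#include <vector>

// Shards along one tensor dimension: the product of the grid sizes of the
// split axes assigned to it (cf. collectiveProcessGroupSize above).
static int64_t numShards(const std::vector<int64_t> &gridShape,
                         const std::vector<int> &splitAxes) {
  int64_t n = 1;
  for (int axis : splitAxes)
    n *= gridShape[axis];
  return n;
}

// Size of shard `idx`: an even split, with the remainder handed out one
// element each to the trailing shards.
static int64_t shardDimSize(int64_t dimSize, int64_t shards, int64_t idx) {
  int64_t rem = dimSize % shards;
  return dimSize / shards + (idx >= shards - rem ? 1 : 0);
}

For example, a dimension of size 10 split over 4 shards yields shard sizes 2, 2, 3, 3.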
@@ -556,13 +552,13 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { matchAndRewrite(AllReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SymbolTableCollection symbolTableCollection; - auto mesh = adaptor.getMesh(); - mlir::mesh::MeshOp meshOp = getMesh(op, symbolTableCollection); - if (!meshOp) - return op->emitError() << "No mesh found for AllReduceOp"; - if (ShapedType::isDynamicShape(meshOp.getShape())) + auto grid = adaptor.getGrid(); + mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection); + if (!gridOp) + return op->emitError() << "No grid found for AllReduceOp"; + if (ShapedType::isDynamicShape(gridOp.getShape())) return op->emitError() - << "Dynamic mesh shape not supported in AllReduceOp"; + << "Dynamic grid shape not supported in AllReduceOp"; ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); Value input = adaptor.getInput(); @@ -592,27 +588,27 @@ struct ConvertAllReduceOp : public OpConversionPattern<AllReduceOp> { linalg::CopyOp::create(iBuilder, input, buffer); // Get an MPI_Comm_split for the AllReduce operation. - // The color is the linear index of the process in the mesh along the - // non-reduced axes. The key is the linear index of the process in the mesh + // The color is the linear index of the process in the grid along the + // non-reduced axes. The key is the linear index of the process in the grid // along the reduced axes. - SmallVector<Type> indexResultTypes(meshOp.getShape().size(), + SmallVector<Type> indexResultTypes(gridOp.getShape().size(), iBuilder.getIndexType()); SmallVector<Value> myMultiIndex = - ProcessMultiIndexOp::create(iBuilder, indexResultTypes, mesh) + ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid) .getResult(); Value zero = arith::ConstantIndexOp::create(iBuilder, 0); SmallVector<Value> multiKey(myMultiIndex.size(), zero); - auto redAxes = adaptor.getMeshAxes(); + auto redAxes = adaptor.getGridAxes(); for (auto axis : redAxes) { multiKey[axis] = myMultiIndex[axis]; myMultiIndex[axis] = zero; } Value color = - createProcessLinearIndex(mesh, myMultiIndex, redAxes, iBuilder); + createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder); color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color); - Value key = createProcessLinearIndex(mesh, multiKey, redAxes, iBuilder); + Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder); key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key); // Finally split the communicator @@ -698,15 +694,14 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { } auto rank = cast<ShapedType>(array.getType()).getRank(); auto opSplitAxes = adaptor.getSplitAxes().getAxes(); - auto mesh = adaptor.getMesh(); - auto meshOp = getMesh(op, symbolTableCollection); + auto grid = adaptor.getGrid(); + auto gridOp = getGrid(op, symbolTableCollection); // subviews need Index values for (auto &sz : haloSizes) { if (auto value = dyn_cast<Value>(sz)) - sz = - rewriter - .create<arith::IndexCastOp>(loc, rewriter.getIndexType(), value) - .getResult(); + sz = arith::IndexCastOp::create(rewriter, loc, rewriter.getIndexType(), + value) + .getResult(); } // most of the offset/size/stride data is the same for all dims @@ -745,10 +740,10 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { auto zeroAttr = rewriter.getI32IntegerAttr(0); // for detecting v<0 auto zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); - SmallVector<Type> 
indexResultTypes(meshOp.getShape().size(), + SmallVector<Type> indexResultTypes(gridOp.getShape().size(), rewriter.getIndexType()); auto myMultiIndex = - ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, mesh) + ProcessMultiIndexOp::create(rewriter, loc, indexResultTypes, grid) .getResult(); // traverse all split axes from high to low dim for (ssize_t dim = opSplitAxes.size() - 1; dim >= 0; --dim) { @@ -758,9 +753,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { assert(currHaloDim >= 0 && (size_t)currHaloDim < haloSizes.size() / 2); // Get the linearized ids of the neighbors (down and up) for the // given split - auto tmp = rewriter - .create<NeighborsLinearIndicesOp>(loc, mesh, myMultiIndex, - splitAxes) + auto tmp = NeighborsLinearIndicesOp::create(rewriter, loc, grid, + myMultiIndex, splitAxes) .getResults(); // MPI operates on i32... Value neighbourIDs[2] = { @@ -791,7 +785,7 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { dimSizes[dim] = upperHalo ? haloSizes[currHaloDim * 2 + 1] : haloSizes[currHaloDim * 2]; // Check if we need to send and/or receive - // Processes on the mesh borders have only one neighbor + // Processes on the grid borders have only one neighbor auto to = upperHalo ? neighbourIDs[0] : neighbourIDs[1]; auto from = upperHalo ? neighbourIDs[1] : neighbourIDs[0]; auto hasFrom = arith::CmpIOp::create( @@ -869,8 +863,8 @@ struct ConvertUpdateHaloOp : public OpConversionPattern<UpdateHaloOp> { } }; -struct ConvertMeshToMPIPass - : public impl::ConvertMeshToMPIPassBase<ConvertMeshToMPIPass> { +struct ConvertShardToMPIPass + : public impl::ConvertShardToMPIPassBase<ConvertShardToMPIPass> { using Base::Base; /// Run the dialect converter on the module. @@ -879,12 +873,12 @@ struct ConvertMeshToMPIPass RewritePatternSet patterns(ctxt); ConversionTarget target(getContext()); - // Define a type converter to convert mesh::ShardingType, + // Define a type converter to convert shard::ShardingType, // mostly for use in return operations. TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); - // convert mesh::ShardingType to a tuple of RankedTensorTypes + // convert shard::ShardingType to a tuple of RankedTensorTypes typeConverter.addConversion( [](ShardingType type, SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> { @@ -920,10 +914,10 @@ struct ConvertMeshToMPIPass return results; }); - // No mesh dialect should left after conversion... - target.addIllegalDialect<mesh::MeshDialect>(); - // ...except the global MeshOp. MeshShapeOp which will get folded later. - target.addLegalOp<mesh::MeshOp, mesh::MeshShapeOp>(); + // No shard dialect should be left after conversion... + target.addIllegalDialect<shard::ShardDialect>(); + // ...except the global GridOp and GridShapeOp, which will get folded later. + target.addLegalOp<shard::GridOp, shard::GridShapeOp>(); // Allow all the stuff that our patterns will convert to target.addLegalDialect< BuiltinDialect, mpi::MPIDialect, scf::SCFDialect, arith::ArithDialect, @@ -951,7 +945,7 @@ struct ConvertMeshToMPIPass // Folding patterns cannot be mixed with conversion patterns -> extra pass.
patterns.clear(); SymbolTableCollection symbolTableCollection; - mlir::mesh::populateFoldingPatterns(patterns, symbolTableCollection); + mlir::shard::populateFoldingPatterns(patterns, symbolTableCollection); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } }; diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp index f07386e..8cd650e 100644 --- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp @@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass SPIRVConversionOptions options; options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes; + options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes; SPIRVTypeConverter typeConverter(targetAttr, options); RewritePatternSet patterns(context); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index ec55091..0e3de06 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -22,7 +22,6 @@ #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" @@ -570,10 +569,9 @@ static Value createLinalgBodyCalculationForElementwiseOp( // to UIToFP. if (srcTy.isUnsignedInteger() && isa<FloatType>(dstTy)) { auto unrealizedCast = - rewriter - .create<UnrealizedConversionCastOp>( - loc, rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), - args[0]) + UnrealizedConversionCastOp::create( + rewriter, loc, + rewriter.getIntegerType(srcTy.getIntOrFloatBitWidth()), args[0]) .getResult(0); return arith::UIToFPOp::create(rewriter, loc, resultTypes[0], unrealizedCast); @@ -869,14 +867,13 @@ static Value broadcastDynamicDimension(PatternRewriter &rewriter, Location loc, // Emit 'linalg.generic' op auto resultTensor = - opBuilder - .create<linalg::GenericOp>( - loc, outputTensor.getType(), operand, outputTensor, affineMaps, - getNParallelLoopsAttrs(rank), - [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { - // Emit 'linalg.yield' op - linalg::YieldOp::create(opBuilder, loc, blockArgs.front()); - }) + linalg::GenericOp::create( + opBuilder, loc, outputTensor.getType(), operand, outputTensor, + affineMaps, getNParallelLoopsAttrs(rank), + [&](OpBuilder &opBuilder, Location loc, ValueRange blockArgs) { + // Emit 'linalg.yield' op + linalg::YieldOp::create(opBuilder, loc, blockArgs.front()); + }) .getResult(0); // Cast to original operand type if necessary @@ -1156,11 +1153,9 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, inputs.push_back(input); // First fill the output buffer with the init value. 
- auto emptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, resultTy.getElementType(), - dynDims) - .getResult(); + auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) + .getResult(); auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); if (!fillValueAttr) @@ -1168,10 +1163,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, op, "No initial value found for reduction operation"); auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); - auto filledTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValue}, - ValueRange{emptyTensor}) - .result(); + auto filledTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{emptyTensor}) + .result(); outputs.push_back(filledTensor); bool isNanIgnoreMode = false; @@ -1187,14 +1182,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, auto trueAttr = rewriter.getBoolAttr(true); auto trueValue = arith::ConstantOp::create(rewriter, loc, trueAttr); auto emptyBoolTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, trueValue.getType(), - dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + trueValue.getType(), dynDims) .getResult(); auto allResultsNaNTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{trueValue}, - ValueRange{emptyBoolTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{trueValue}, + ValueRange{emptyBoolTensor}) .result(); // Note that because the linalg::ReduceOp has two variadic arguments // (inputs and outputs) and it has the SameVariadicOperandSize trait we @@ -1262,22 +1255,19 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false)); auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr); auto emptyNanTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, - resultTy.getElementType(), dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) .getResult(); auto nanFilledTensor = - rewriter - .create<linalg::FillOp>(loc, ValueRange{nanValue}, - ValueRange{emptyNanTensor}) + linalg::FillOp::create(rewriter, loc, ValueRange{nanValue}, + ValueRange{emptyNanTensor}) .result(); // Create an empty tensor, no need to fill this since it will be // overwritten by the select. auto finalEmptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, reduceShape, - resultTy.getElementType(), dynDims) + tensor::EmptyOp::create(rewriter, loc, reduceShape, + resultTy.getElementType(), dynDims) .getResult(); // Do a selection between the tensors akin to: @@ -1504,12 +1494,11 @@ public: Value shift = shiftConstant ? 
shiftConstant : blockArgs[shiftArg]; if (valueTy.isUnsignedInteger()) { - value = nestedBuilder - .create<UnrealizedConversionCastOp>( - nestedLoc, - nestedBuilder.getIntegerType( - valueTy.getIntOrFloatBitWidth()), - value) + value = UnrealizedConversionCastOp::create( + nestedBuilder, nestedLoc, + nestedBuilder.getIntegerType( + valueTy.getIntOrFloatBitWidth()), + value) .getResult(0); } if (valueTy.getIntOrFloatBitWidth() < 32) { @@ -1558,9 +1547,8 @@ public: } if (outIntType.isUnsignedInteger()) { - value = nestedBuilder - .create<UnrealizedConversionCastOp>(nestedLoc, - outIntType, value) + value = UnrealizedConversionCastOp::create(nestedBuilder, nestedLoc, + outIntType, value) .getResult(0); } linalg::YieldOp::create(nestedBuilder, loc, value); @@ -2096,10 +2084,9 @@ public: Value axisDimSize = tensor::DimOp::create(rewriter, loc, input, axis); // First fill the output buffer with the init value. - auto emptyTensor = rewriter - .create<tensor::EmptyOp>(loc, inputTy.getShape(), - inputTy.getElementType(), - ArrayRef<Value>({dynDims})) + auto emptyTensor = tensor::EmptyOp::create( + rewriter, loc, inputTy.getShape(), + inputTy.getElementType(), ArrayRef<Value>({dynDims})) .getResult(); SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; @@ -2242,23 +2229,22 @@ public: } // First fill the output buffer for the index. - auto emptyTensorIdx = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - outElementTy, dynDims) - .getResult(); + auto emptyTensorIdx = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + outElementTy, dynDims) + .getResult(); auto fillValueIdx = arith::ConstantOp::create( rewriter, loc, rewriter.getIntegerAttr(outElementTy, 0)); auto filledTensorIdx = - rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValueIdx}, - ValueRange{emptyTensorIdx}) + linalg::FillOp::create(rewriter, loc, ValueRange{fillValueIdx}, + ValueRange{emptyTensorIdx}) .result(); // Second fill the output buffer for the running max. 
- auto emptyTensorMax = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - inElementTy, dynDims) - .getResult(); + auto emptyTensorMax = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), inElementTy, + dynDims) + .getResult(); auto fillValueMaxAttr = createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); @@ -2269,9 +2255,8 @@ public: auto fillValueMax = arith::ConstantOp::create(rewriter, loc, fillValueMaxAttr); auto filledTensorMax = - rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValueMax}, - ValueRange{emptyTensorMax}) + linalg::FillOp::create(rewriter, loc, ValueRange{fillValueMax}, + ValueRange{emptyTensorMax}) .result(); // We need to reduce along the arg-max axis, with parallel operations along @@ -2372,9 +2357,8 @@ public: auto loc = op.getLoc(); auto emptyTensor = - rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), resultElementTy, - dynamicDims) + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultElementTy, dynamicDims) .getResult(); SmallVector<AffineMap, 2> affineMaps = { @@ -2449,10 +2433,10 @@ public: } } - auto emptyTensor = rewriter - .create<tensor::EmptyOp>(loc, resultTy.getShape(), - resultElementTy, dynDims) - .getResult(); + auto emptyTensor = + tensor::EmptyOp::create(rewriter, loc, resultTy.getShape(), + resultElementTy, dynDims) + .getResult(); SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank()), @@ -2586,10 +2570,10 @@ struct RFFT2dConverter final : public OpRewritePattern<RFFT2dOp> { tensor::EmptyOp::create(rewriter, loc, type, dynamicSizes); auto fillValueAttr = rewriter.getZeroAttr(type.getElementType()); auto fillValue = arith::ConstantOp::create(rewriter, loc, fillValueAttr); - auto filledTensor = rewriter - .create<linalg::FillOp>(loc, ValueRange{fillValue}, - ValueRange{emptyTensor}) - .result(); + auto filledTensor = + linalg::FillOp::create(rewriter, loc, ValueRange{fillValue}, + ValueRange{emptyTensor}) + .result(); return filledTensor; } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 3a20524..da1fb20 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -64,19 +64,20 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias, Value conv, Value result, ArrayRef<AffineMap> indexingMaps) { ShapedType resultTy = cast<ShapedType>(conv.getType()); - return rewriter - .create<linalg::GenericOp>( - loc, resultTy, ValueRange({bias, conv}), result, indexingMaps, - getNParallelLoopsAttrs(resultTy.getRank()), - [](OpBuilder &builder, Location loc, ValueRange args) { - Value biasVal = args[0]; - Type resType = args[1].getType(); - if (resType != biasVal.getType()) { - biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal); - } - Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]); - linalg::YieldOp::create(builder, loc, added); - }) + return linalg::GenericOp::create( + rewriter, loc, resultTy, ValueRange({bias, conv}), result, + indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()), + [](OpBuilder &builder, Location loc, ValueRange args) { + Value biasVal = args[0]; + Type resType = args[1].getType(); + if (resType != biasVal.getType()) { + biasVal = + arith::ExtSIOp::create(builder, loc, resType, biasVal); + } + Value added = + arith::AddIOp::create(builder, loc, biasVal, args[1]); + linalg::YieldOp::create(builder, loc, added); + }) 
      .getResult(0);
 }
@@ -124,23 +125,23 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
   indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
   // Build the broadcast-like operation as a linalg.generic.
-  return rewriter
-      .create<linalg::GenericOp>(
-          loc, resultTy, ValueRange({source}), result, indexingMaps,
-          getNParallelLoopsAttrs(resultTy.getRank()),
-          [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
-            Value biasVal = args[0];
-            Type resType = args[1].getType();
-            if (resType != biasVal.getType()) {
-              biasVal =
-                  resultTy.getElementType().isFloat()
-                      ? arith::ExtFOp::create(builder, loc, resType, biasVal)
-                            .getResult()
-                      : arith::ExtSIOp::create(builder, loc, resType, biasVal)
-                            .getResult();
-            }
-            linalg::YieldOp::create(builder, loc, biasVal);
-          })
+  return linalg::GenericOp::create(
+             rewriter, loc, resultTy, ValueRange({source}), result,
+             indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+             [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
+               Value biasVal = args[0];
+               Type resType = args[1].getType();
+               if (resType != biasVal.getType()) {
+                 biasVal =
+                     resultTy.getElementType().isFloat()
+                         ? arith::ExtFOp::create(builder, loc, resType, biasVal)
+                               .getResult()
+                         : arith::ExtSIOp::create(builder, loc, resType,
+                                                  biasVal)
+                               .getResult();
+               }
+               linalg::YieldOp::create(builder, loc, biasVal);
+             })
       .getResult(0);
 }
@@ -397,21 +398,19 @@ public:
       auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
       auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
 
-      Value conv =
-          rewriter
-              .create<LinalgConvQOp>(
-                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
-                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
-              ->getResult(0);
+      Value conv = LinalgConvQOp::create(
+                       rewriter, loc, resultTy,
+                       ValueRange{input, weight, iZpVal, kZpVal},
+                       ValueRange{broadcastBias}, strideAttr, dilationAttr)
+                       ->getResult(0);
 
       rewriter.replaceOp(op, conv);
       return success();
     }
 
-    Value conv = rewriter
-                     .create<LinalgConvOp>(
-                         loc, accTy, ValueRange{input, weight},
-                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
+    Value conv = LinalgConvOp::create(
+                     rewriter, loc, accTy, ValueRange{input, weight},
+                     ValueRange{broadcastBias}, strideAttr, dilationAttr)
                      ->getResult(0);
 
     // We may need to truncate back to the result type if the accumulator was
@@ -529,9 +528,8 @@ public:
     Value emptyTensor = tensor::EmptyOp::create(
         rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
     Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
-    Value zeroTensor = rewriter
-                           .create<linalg::FillOp>(loc, ValueRange{zero},
-                                                   ValueRange{emptyTensor})
+    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+                                              ValueRange{emptyTensor})
                            .result();
 
     Value biasEmptyTensor = tensor::EmptyOp::create(
@@ -544,10 +542,9 @@ public:
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
     if (hasNullZps) {
-      Value conv = rewriter
-                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
-                           loc, linalgConvTy, ValueRange{input, weight},
-                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
+      Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
+                       rewriter, loc, linalgConvTy, ValueRange{input, weight},
+                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
                        .getResult(0);
 
       // We may need to truncate back to the result type if the accumulator was
@@ -565,22 +562,20 @@ public:
           rewriter, loc, resultTy, conv, reassociationMap);
 
       Value result =
-          rewriter
-              .create<linalg::GenericOp>(
-                  loc, resultTy, ValueRange({bias, convReshape}),
-                  biasEmptyTensor, indexingMaps,
-                  getNParallelLoopsAttrs(resultRank),
-                  [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                      ValueRange args) {
-                    Value added;
-                    if (llvm::isa<FloatType>(inputETy))
-                      added = arith::AddFOp::create(nestedBuilder, loc, args[0],
-                                                    args[1]);
-                    else
-                      added = arith::AddIOp::create(nestedBuilder, loc, args[0],
-                                                    args[1]);
-                    linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
-                  })
+          linalg::GenericOp::create(
+              rewriter, loc, resultTy, ValueRange({bias, convReshape}),
+              biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
+              [&](OpBuilder &nestedBuilder, Location nestedLoc,
+                  ValueRange args) {
+                Value added;
+                if (llvm::isa<FloatType>(inputETy))
+                  added = arith::AddFOp::create(nestedBuilder, loc, args[0],
+                                                args[1]);
+                else
+                  added = arith::AddIOp::create(nestedBuilder, loc, args[0],
+                                                args[1]);
+                linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
+              })
              .getResult(0);
       rewriter.replaceOp(op, result);
     } else {
@@ -588,12 +583,11 @@ public:
       IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
       auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
       auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
-      Value conv =
-          rewriter
-              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
-                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
-                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
-              .getResult(0);
+      Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
+                       rewriter, loc, linalgConvTy,
+                       ValueRange{input, weight, iZpVal, kZpVal},
+                       ValueRange{zeroTensor}, strideAttr, dilationAttr)
+                       .getResult(0);
       SmallVector<ReassociationExprs, 4> reassociationMap;
       createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
       Value convReshape = tensor::CollapseShapeOp::create(
@@ -639,9 +633,8 @@ public:
     auto emptyTensor =
         tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
                                 outputTy.getElementType(), filteredDims);
-    Value zeroTensor = rewriter
-                           .create<linalg::FillOp>(loc, ValueRange{zero},
-                                                   ValueRange{emptyTensor})
+    Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+                                              ValueRange{emptyTensor})
                            .result();
 
     FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
@@ -910,20 +903,18 @@ public:
         rewriter, loc, accTy.getShape(), accETy, dynamicDims);
 
     Value filledEmptyTensor =
-        rewriter
-            .create<linalg::FillOp>(loc, ValueRange{initialValue},
-                                    ValueRange{poolEmptyTensor})
+        linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
+                               ValueRange{poolEmptyTensor})
            .result();
 
    Value fakeWindowDims =
        tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
 
    // Sum across the pooled region.
-    Value poolingOp = rewriter
-                          .create<linalg::PoolingNhwcSumOp>(
-                              loc, ArrayRef<Type>{accTy},
-                              ValueRange{paddedInput, fakeWindowDims},
-                              filledEmptyTensor, strideAttr, dilationAttr)
+    Value poolingOp = linalg::PoolingNhwcSumOp::create(
+                          rewriter, loc, ArrayRef<Type>{accTy},
+                          ValueRange{paddedInput, fakeWindowDims},
+                          filledEmptyTensor, strideAttr, dilationAttr)
                           .getResult(0);
 
     // Normalize the summed value by the number of elements grouped in each
@@ -1050,10 +1041,9 @@ public:
       Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
 
       auto scaled =
-          rewriter
-              .create<tosa::ApplyScaleOp>(
-                  loc, rewriter.getI32Type(), poolVal, multiplier, shift,
-                  rewriter.getStringAttr("SINGLE_ROUND"))
+          tosa::ApplyScaleOp::create(
+              rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
+              shift, rewriter.getStringAttr("SINGLE_ROUND"))
              .getResult();
 
      // If we have quantization information we need to apply output
diff --git a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
index b83f5ec9..f8efb34 100644
--- a/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
+++ b/mlir/lib/Conversion/TosaToMLProgram/TosaToMLProgram.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
-#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
 
 using namespace mlir;
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index 77aab85..1d1904f 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -31,10 +31,9 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
 
 #define DEBUG_TYPE "vector-to-gpu"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
 
 namespace mlir {
 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
@@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
         // by all operations.
         if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
               if (!supportsMMaMatrixType(op, useNvGpu)) {
-                LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
+                LDBG() << "cannot convert op: " << *op;
                 return true;
               }
               return false;
@@ -482,14 +481,12 @@ struct CombineTransferReadOpTranspose final
         permutationMap.compose(transferReadOp.getPermutationMap());
 
     auto loc = op.getLoc();
-    Value result =
-        rewriter
-            .create<vector::TransferReadOp>(
-                loc, resultType, transferReadOp.getBase(),
-                transferReadOp.getIndices(), AffineMapAttr::get(newMap),
-                transferReadOp.getPadding(), transferReadOp.getMask(),
-                transferReadOp.getInBoundsAttr())
-            .getResult();
+    Value result = vector::TransferReadOp::create(
+                       rewriter, loc, resultType, transferReadOp.getBase(),
+                       transferReadOp.getIndices(), AffineMapAttr::get(newMap),
+                       transferReadOp.getPadding(), transferReadOp.getMask(),
+                       transferReadOp.getInBoundsAttr())
+                       .getResult();
 
     // Fuse through the integer extend op.
     if (extOp) {
@@ -550,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
   std::optional<int64_t> stride =
       getStaticallyKnownRowStride(op.getShapedType());
   if (!stride.has_value()) {
-    LLVM_DEBUG(DBGS() << "no stride\n");
+    LDBG() << "no stride";
     return rewriter.notifyMatchFailure(op, "no stride");
   }
@@ -585,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
       isTranspose ? rewriter.getUnitAttr() : UnitAttr());
   valueMapping[mappingResult] = load;
 
-  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
+  LDBG() << "transfer read to: " << load;
   return success();
 }
@@ -599,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
   std::optional<int64_t> stride =
       getStaticallyKnownRowStride(op.getShapedType());
   if (!stride.has_value()) {
-    LLVM_DEBUG(DBGS() << "no stride\n");
+    LDBG() << "no stride";
     return rewriter.notifyMatchFailure(op, "no stride");
   }
 
   auto it = valueMapping.find(op.getVector());
   if (it == valueMapping.end()) {
-    LLVM_DEBUG(DBGS() << "no mapping\n");
+    LDBG() << "no mapping";
     return rewriter.notifyMatchFailure(op, "no mapping");
   }
@@ -615,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
                          rewriter.getIndexAttr(*stride),
                          /*transpose=*/UnitAttr());
   (void)store;
-  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
+  LDBG() << "transfer write to: " << store;
 
-  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+  LDBG() << "erase: " << op;
   rewriter.eraseOp(op);
   return success();
 }
@@ -643,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
   if (failed(warpMatrixInfo)) {
-    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+    LDBG() << "no warpMatrixInfo";
     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
   }
 
   FailureOr<nvgpu::FragmentElementInfo> regInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(regInfo)) {
-    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+    LDBG() << "not mma sync reg info";
     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
   }
 
   VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
   auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
   if (!dense) {
-    LLVM_DEBUG(DBGS() << "not a splat\n");
+    LDBG() << "not a splat";
     return rewriter.notifyMatchFailure(op, "not a splat");
   }
@@ -679,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
   mlir::AffineMap map = op.getPermutationMap();
 
   if (map.getNumResults() != 2) {
-    LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
-                         "is not a 2d operand\n");
+    LDBG() << "Failed because the result of `vector.transfer_read` "
+              "is not a 2d operand";
     return failure();
   }
@@ -693,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
   auto exprN = dyn_cast<AffineDimExpr>(dN);
 
   if (!exprM || !exprN) {
-    LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
-                         "expressions, then transpose cannot be determined.\n");
+    LDBG() << "Failed because expressions are not affine dim "
+              "expressions, then transpose cannot be determined.";
     return failure();
   }
@@ -711,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
       nvgpu::getWarpMatrixInfo(op);
   if (failed(warpMatrixInfo)) {
-    LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+    LDBG() << "no warpMatrixInfo";
     return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
   }
 
   FailureOr<nvgpu::FragmentElementInfo> regInfo =
       nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
   if (failed(regInfo)) {
-    LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+    LDBG() << "not mma sync reg info";
     return rewriter.notifyMatchFailure(op, "not mma sync reg info");
   }
 
   FailureOr<bool> transpose = isTransposed(op);
   if (failed(transpose)) {
-    LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
+    LDBG() << "failed to determine the transpose";
     return rewriter.notifyMatchFailure(
         op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
   }
@@ -733,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
       nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
 
   if (failed(params)) {
-    LLVM_DEBUG(
-        DBGS()
-        << "failed to convert vector.transfer_read to ldmatrix. "
-        << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
+    LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
+           << "Op should likely not be converted to a nvgpu.ldmatrix call.";
     return rewriter.notifyMatchFailure(
         op, "failed to convert vector.transfer_read to ldmatrix; this op "
             "likely should not be converted to a nvgpu.ldmatrix call.");
   }
@@ -747,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
   FailureOr<AffineMap> offsets =
       nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
   if (failed(offsets)) {
-    LLVM_DEBUG(DBGS() << "no offsets\n");
+    LDBG() << "no offsets";
     return rewriter.notifyMatchFailure(op, "no offsets");
   }
@@ -936,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
     vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
   }
 
-  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+  LDBG() << "erase: " << op;
   rewriter.eraseOp(op);
   return success();
 }
@@ -1134,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                        loop.getNumResults())))
     rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
 
-  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
-  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
-  LLVM_DEBUG(DBGS() << "erase: " << loop);
+  LDBG() << "newLoop now: " << newLoop;
+  LDBG() << "stripped scf.for: " << loop;
+  LDBG() << "erase: " << loop;
   rewriter.eraseOp(loop);
 
   return newLoop;
@@ -1152,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
   for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
     auto it = valueMapping.find(operand.value());
     if (it == valueMapping.end()) {
-      LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
+      LDBG() << "no value mapping for: " << operand.value();
       continue;
     }
     argMapping.push_back(std::make_pair(
@@ -1170,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
         loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
   }
 
-  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
+  LDBG() << "scf.for to: " << newForOp;
   return success();
 }
@@ -1193,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
   }
   scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
 
-  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+  LDBG() << "erase: " << op;
   rewriter.eraseOp(op);
   return success();
 }
@@ -1246,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
   auto globalRes = LogicalResult::success();
   for (Operation *op : ops) {
-    LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
+    LDBG() << "Process op: " << *op;
     // Apparently callers do not want to early exit on failure here.
     auto res = LogicalResult::success();
     if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 9cd491c..17a79e3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -29,7 +29,9 @@
 #include "mlir/Target/LLVMIR/TypeToLLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/APFloat.h"
+#include "llvm/IR/LLVMContext.h"
 #include "llvm/Support/Casting.h"
+
 #include <optional>
 
 using namespace mlir;
@@ -1068,39 +1070,6 @@ public:
   }
 };
 
-class VectorExtractElementOpConversion
-    : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
-public:
-  using ConvertOpToLLVMPattern<
-      vector::ExtractElementOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto vectorType = extractEltOp.getSourceVectorType();
-    auto llvmType = typeConverter->convertType(vectorType.getElementType());
-
-    // Bail if result type cannot be lowered.
-    if (!llvmType)
-      return failure();
-
-    if (vectorType.getRank() == 0) {
-      Location loc = extractEltOp.getLoc();
-      auto idxType = rewriter.getIndexType();
-      auto zero = LLVM::ConstantOp::create(rewriter, loc,
-                                           typeConverter->convertType(idxType),
-                                           rewriter.getIntegerAttr(idxType, 0));
-      rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
-          extractEltOp, llvmType, adaptor.getVector(), zero);
-      return success();
-    }
-
-    rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
-        extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
-    return success();
-  }
-};
-
 class VectorExtractOpConversion
     : public ConvertOpToLLVMPattern<vector::ExtractOp> {
 public:
@@ -1204,39 +1173,6 @@ public:
   }
 };
 
-class VectorInsertElementOpConversion
-    : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
-public:
-  using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto vectorType = insertEltOp.getDestVectorType();
-    auto llvmType = typeConverter->convertType(vectorType);
-
-    // Bail if result type cannot be lowered.
-    if (!llvmType)
-      return failure();
-
-    if (vectorType.getRank() == 0) {
-      Location loc = insertEltOp.getLoc();
-      auto idxType = rewriter.getIndexType();
-      auto zero = LLVM::ConstantOp::create(rewriter, loc,
-                                           typeConverter->convertType(idxType),
-                                           rewriter.getIntegerAttr(idxType, 0));
-      rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
-          insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
-      return success();
-    }
-
-    rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
-        insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
-        adaptor.getPosition());
-    return success();
-  }
-};
-
 class VectorInsertOpConversion
     : public ConvertOpToLLVMPattern<vector::InsertOp> {
 public:
@@ -2242,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
                VectorGatherOpConversion, VectorScatterOpConversion>(
       converter, useVectorAlignment);
   patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
-               VectorExtractElementOpConversion, VectorExtractOpConversion,
-               VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+               VectorExtractOpConversion, VectorFMAOp1DConversion,
                VectorInsertOpConversion, VectorPrintOpConversion,
                VectorTypeCastOpConversion, VectorScaleOpConversion,
                VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 4c1047a..508f4e2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -24,7 +24,6 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -691,7 +690,7 @@ struct PrepareTransferWriteConversion
 ///   %lastIndex = arith.subi %length, %c1 : index
 ///   vector.print punctuation <open>
 ///   scf.for %i = %c0 to %length step %c1 {
-///     %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+///     %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
 ///     vector.print %el : i32 punctuation <no_punctuation>
 ///     %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
 ///     scf.if %notLastIndex {
@@ -1644,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> {
 /// Is rewritten to approximately the following pseudo-IR:
 /// ```
 /// for i = 0 to 9 {
-///   %t = vector.extractelement %vec[i] : vector<9xf32>
+///   %t = vector.extract %vec[i] : f32 from vector<9xf32>
 ///   memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
 /// }
 /// ```
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 986eae3..a4be7d4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -335,63 +335,6 @@ struct VectorInsertOpConvert final
   }
 };
 
-struct VectorExtractElementOpConvert final
-    : public OpConversionPattern<vector::ExtractElementOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Type resultType = getTypeConverter()->convertType(extractOp.getType());
-    if (!resultType)
-      return failure();
-
-    if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
-      rewriter.replaceOp(extractOp, adaptor.getVector());
-      return success();
-    }
-
-    APInt cstPos;
-    if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
-      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-          extractOp, resultType, adaptor.getVector(),
-          rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
-    else
-      rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
-          extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
-    return success();
-  }
-};
-
-struct VectorInsertElementOpConvert final
-    : public OpConversionPattern<vector::InsertElementOp> {
-  using OpConversionPattern::OpConversionPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Type vectorType = getTypeConverter()->convertType(insertOp.getType());
-    if (!vectorType)
-      return failure();
-
-    if (isa<spirv::ScalarType>(vectorType)) {
-      rewriter.replaceOp(insertOp, adaptor.getSource());
-      return success();
-    }
-
-    APInt cstPos;
-    if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
-      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-          insertOp, adaptor.getSource(), adaptor.getDest(),
-          cstPos.getSExtValue());
-    else
-      rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
-          insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
-          adaptor.getPosition());
-    return success();
-  }
-};
-
 struct VectorInsertStridedSliceOpConvert final
     : public OpConversionPattern<vector::InsertStridedSliceOp> {
   using OpConversionPattern::OpConversionPattern;
@@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final
 void mlir::populateVectorToSPIRVPatterns(
     const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
   patterns.add<
-      VectorBitcastConvert, VectorBroadcastConvert,
-      VectorExtractElementOpConvert, VectorExtractOpConvert,
+      VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
       VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
       VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
-      VectorToElementOpConvert, VectorInsertElementOpConvert,
-      VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+      VectorToElementOpConvert, VectorInsertOpConvert,
+      VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
      VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
      VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
      VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 2411af0..4dfcb2b 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -10,7 +10,6 @@
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
-#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/XeVMDialect.h"
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 18e8270..9a0a230 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -510,6 +510,10 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// GatherToLDSOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult GatherToLDSOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
   MemRefType dstType = cast<MemRefType>(getDst().getType());
@@ -546,6 +550,42 @@ LogicalResult GatherToLDSOp::verify() {
   return success();
 }
 
+namespace {
+/// If the source/target of a GatherToLDSOp is a CastOp that only removes static
+/// information or changes layout, the cast can be skipped.
+struct FoldGatherToLDSOfCast final : OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GatherToLDSOp gatherOp,
+                                PatternRewriter &rewriter) const override {
+    bool modified = false;
+    auto foldCast = [&](OpOperand &operand) {
+      if (auto castOp = operand.get().getDefiningOp<memref::CastOp>()) {
+        if (memref::CastOp::canFoldIntoConsumerOp(castOp)) {
+          rewriter.modifyOpInPlace(gatherOp,
+                                   [&] { operand.assign(castOp.getSource()); });
+          modified = true;
+        }
+      }
+    };
+
+    foldCast(gatherOp.getSrcMutable());
+    foldCast(gatherOp.getDstMutable());
+
+    return success(modified);
+  }
+};
+} // namespace
+
+void GatherToLDSOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                                MLIRContext *context) {
+  results.add<FoldGatherToLDSOfCast>(context);
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeLoadOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult TransposeLoadOp::verify() {
   MemRefType srcType = cast<MemRefType>(getSrc().getType());
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index 17bbe54..729e3da 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,7 +1,8 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
-  ResolveStridedMetadata.cpp
+  FoldMemRefsOps.cpp
   MaskedloadToLoad.cpp
+  ResolveStridedMetadata.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
@@ -12,6 +13,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms
   LINK_LIBS PUBLIC
   MLIRAMDGPUDialect
   MLIRAMDGPUUtils
+  MLIRAffineUtils
   MLIRArithDialect
   MLIRMemRefDialect
   MLIRSCFDialect
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
index 37e0d2a..6d1f64e 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -99,8 +99,8 @@ static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
       vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
   Type allBitsType = rewriter.getIntegerType(bitwidth);
   auto allBitsVecType = VectorType::get({1}, allBitsType);
-  Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
-  Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
+  Value bitcast = vector::BitCastOp::create(rewriter, loc, allBitsVecType, val);
+  Value scalar = vector::ExtractOp::create(rewriter, loc, bitcast, 0);
   return scalar;
 }
@@ -118,27 +118,27 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
   SmallVector<NamedAttribute> loadAttrs;
   patchOperandSegmentSizes(origAttrs, loadAttrs, DataArgAction::Drop);
-  Value initialLoad =
-      rewriter.create<RawBufferLoadOp>(loc, dataType, invariantArgs, loadAttrs);
+  Value initialLoad = RawBufferLoadOp::create(rewriter, loc, dataType,
+                                              invariantArgs, loadAttrs);
   Block *currentBlock = rewriter.getInsertionBlock();
   Block *afterAtomic =
       rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
   Block *loopBlock = rewriter.createBlock(afterAtomic, {dataType}, {loc});
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir::amdgpu {
+#define GEN_PASS_DEF_AMDGPUFOLDMEMREFOPSPASS
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+
+struct AmdgpuFoldMemRefOpsPass final
+    : amdgpu::impl::AmdgpuFoldMemRefOpsPassBase<AmdgpuFoldMemRefOpsPass> {
+  void runOnOperation() override {
+    RewritePatternSet patterns(&getContext());
+    populateAmdgpuFoldMemRefOpsPatterns(patterns);
+    walkAndApplyPatterns(getOperation(), std::move(patterns));
+  }
+};
+
+struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(GatherToLDSOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    Value memrefSource;
+    SmallVector<Value> sourceIndices;
+    auto foldResult =
+        llvm::TypeSwitch<Operation *, LogicalResult>(
+            op.getSrc().getDefiningOp())
+            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
+              // If the source is a SubViewOp, we can directly rewrite the
+              // GatherToLDSOp.
+              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
+                  rewriter, loc, subviewOp.getMixedOffsets(),
+                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
+                  op.getSrcIndices(), sourceIndices);
+              memrefSource = subviewOp.getSource();
+              return success();
+            })
+            .Case<memref::ExpandShapeOp>(
+                [&](memref::ExpandShapeOp expandShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
+                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
+                          sourceIndices, false))) {
+                    return failure();
+                  }
+                  memrefSource = expandShapeOp.getViewSource();
+                  return success();
+                })
+            .Case<memref::CollapseShapeOp>(
+                [&](memref::CollapseShapeOp collapseShapeOp) {
+                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
+                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
+                          sourceIndices))) {
+                    return failure();
+                  }
+                  memrefSource = collapseShapeOp.getViewSource();
+                  return success();
+                })
+            .Default([&](Operation *op) {
+              // If the source is not a SubViewOp, ExpandShapeOp, or
+              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
+              return rewriter.notifyMatchFailure(
+                  op,
+                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
+                  "CollapseShapeOp");
+            });
+
+    if (failed(foldResult)) {
+      return failure();
+    }
+
+    rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
+                                               op.getDst(), op.getDstIndices(),
+                                               op.getTransferType());
+
+    return success();
+  }
+};
+
+void populateAmdgpuFoldMemRefOpsPatterns(RewritePatternSet &patterns,
+                                         PatternBenefit benefit) {
+  patterns.add<FoldMemRefOpsIntoGatherToLDSOp>(patterns.getContext(), benefit);
+}
+} // namespace mlir::amdgpu
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index af8634c..f15c63c 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -54,11 +54,11 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
                                            vector::MaskedLoadOp maskedOp,
                                            bool passthru) {
   VectorType vectorType = maskedOp.getVectorType();
-  Value load = builder.create<vector::LoadOp>(
-      loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
+  Value load = vector::LoadOp::create(
+      builder, loc, vectorType, maskedOp.getBase(), maskedOp.getIndices());
   if (passthru)
-    load = builder.create<arith::SelectOp>(loc, vectorType, maskedOp.getMask(),
-                                           load, maskedOp.getPassThru());
+    load = arith::SelectOp::create(builder, loc, vectorType, maskedOp.getMask(),
+                                   load, maskedOp.getPassThru());
   return load;
 }
@@ -108,7 +108,7 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     SmallVector<OpFoldResult> indices = maskedOp.getIndices();
 
     auto stridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(loc, src);
+        memref::ExtractStridedMetadataOp::create(rewriter, loc, src);
     SmallVector<OpFoldResult> strides =
         stridedMetadata.getConstifiedMixedStrides();
     SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
@@ -122,47 +122,47 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
 
     // delta = bufferSize - linearizedOffset
     Value vectorSizeOffset =
-        rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
+        arith::ConstantIndexOp::create(rewriter, loc, vectorSize);
     Value linearIndex =
         getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
     Value totalSize = getValueOrCreateConstantIndexOp(
         rewriter, loc, linearizedInfo.linearizedSize);
-    Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
+    Value delta = arith::SubIOp::create(rewriter, loc, totalSize, linearIndex);
 
     // 1) check if delta < vectorSize
-    Value isOutofBounds = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
+    Value isOutofBounds = arith::CmpIOp::create(
+        rewriter, loc, arith::CmpIPredicate::ult, delta, vectorSizeOffset);
 
     // 2) check if (delta % elements_per_word != 0)
-    Value elementsPerWord = rewriter.create<arith::ConstantIndexOp>(
-        loc, llvm::divideCeil(32, elementBitWidth));
-    Value isNotWordAligned = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::ne,
-        rewriter.create<arith::RemUIOp>(loc, delta, elementsPerWord),
-        rewriter.create<arith::ConstantIndexOp>(loc, 0));
+    Value elementsPerWord = arith::ConstantIndexOp::create(
+        rewriter, loc, llvm::divideCeil(32, elementBitWidth));
+    Value isNotWordAligned = arith::CmpIOp::create(
+        rewriter, loc, arith::CmpIPredicate::ne,
+        arith::RemUIOp::create(rewriter, loc, delta, elementsPerWord),
+        arith::ConstantIndexOp::create(rewriter, loc, 0));
 
     // We take the fallback of maskedload default lowering only if it is both
     // out-of-bounds and not word aligned. The fallback ensures correct results
     // when loading at the boundary of the buffer since buffer load returns
     // inconsistent zeros for the whole word when boundary is crossed.
     Value ifCondition =
-        rewriter.create<arith::AndIOp>(loc, isOutofBounds, isNotWordAligned);
+        arith::AndIOp::create(rewriter, loc, isOutofBounds, isNotWordAligned);
 
     auto thenBuilder = [&](OpBuilder &builder, Location loc) {
       Operation *read = builder.clone(*maskedOp.getOperation());
       read->setAttr(kMaskedloadNeedsMask, builder.getUnitAttr());
       Value readResult = read->getResult(0);
-      builder.create<scf::YieldOp>(loc, readResult);
+      scf::YieldOp::create(builder, loc, readResult);
     };
 
     auto elseBuilder = [&](OpBuilder &builder, Location loc) {
       Value res = createVectorLoadForMaskedLoad(builder, loc, maskedOp,
                                                 /*passthru=*/true);
-      rewriter.create<scf::YieldOp>(loc, res);
+      scf::YieldOp::create(rewriter, loc, res);
     };
 
     auto ifOp =
-        rewriter.create<scf::IfOp>(loc, ifCondition, thenBuilder, elseBuilder);
+        scf::IfOp::create(rewriter, loc, ifCondition, thenBuilder, elseBuilder);
 
     rewriter.replaceOp(maskedOp, ifOp);
@@ -185,13 +185,13 @@ struct FullMaskedLoadToConditionalLoad
     auto trueBuilder = [&](OpBuilder &builder, Location loc) {
       Value res = createVectorLoadForMaskedLoad(builder, loc, loadOp,
                                                 /*passthru=*/false);
-      rewriter.create<scf::YieldOp>(loc, res);
+      scf::YieldOp::create(rewriter, loc, res);
     };
     auto falseBuilder = [&](OpBuilder &builder, Location loc) {
-      rewriter.create<scf::YieldOp>(loc, loadOp.getPassThru());
+      scf::YieldOp::create(rewriter, loc, loadOp.getPassThru());
     };
-    auto ifOp = rewriter.create<scf::IfOp>(loadOp.getLoc(), cond, trueBuilder,
-                                           falseBuilder);
+    auto ifOp = scf::IfOp::create(rewriter, loadOp.getLoc(), cond, trueBuilder,
+                                  falseBuilder);
     rewriter.replaceOp(loadOp, ifOp);
     return success();
   }
@@ -210,11 +210,12 @@ struct FullMaskedStoreToConditionalStore
     Value cond = maybeCond.value();
 
     auto trueBuilder = [&](OpBuilder &builder, Location loc) {
-      rewriter.create<vector::StoreOp>(loc, storeOp.getValueToStore(),
-                                       storeOp.getBase(), storeOp.getIndices());
-      rewriter.create<scf::YieldOp>(loc);
+      vector::StoreOp::create(rewriter, loc, storeOp.getValueToStore(),
+                              storeOp.getBase(), storeOp.getIndices());
+      scf::YieldOp::create(rewriter, loc);
     };
-    auto ifOp = rewriter.create<scf::IfOp>(storeOp.getLoc(), cond, trueBuilder);
+    auto ifOp =
+        scf::IfOp::create(rewriter, storeOp.getLoc(), cond, trueBuilder);
     rewriter.replaceOp(storeOp, ifOp);
     return success();
   }
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
index 195f59d..f8bab82 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/ResolveStridedMetadata.cpp
@@ -37,8 +37,8 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final
       return rewriter.notifyMatchFailure(metadataOp,
                                          "not a fat raw buffer cast");
     Location loc = castOp.getLoc();
-    auto sourceMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
-        loc, castOp.getSource());
+    auto sourceMetadata = memref::ExtractStridedMetadataOp::create(
+        rewriter, loc, castOp.getSource());
     SmallVector<Value> results;
     if (metadataOp.getBaseBuffer().use_empty()) {
       results.push_back(nullptr);
@@ -48,13 +48,13 @@ struct ExtractStridedMetadataOnFatRawBufferCastFolder final
       if (baseBufferType == castOp.getResult().getType()) {
        results.push_back(castOp.getResult());
       } else {
-        results.push_back(rewriter.create<memref::ReinterpretCastOp>(
-            loc, baseBufferType, castOp.getResult(), /*offset=*/0,
+        results.push_back(memref::ReinterpretCastOp::create(
+            rewriter, loc, baseBufferType, castOp.getResult(), /*offset=*/0,
             /*sizes=*/ArrayRef<int64_t>{}, /*strides=*/ArrayRef<int64_t>{}));
       }
     }
     if (castOp.getResetOffset())
-      results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
+      results.push_back(arith::ConstantIndexOp::create(rewriter, loc, 0));
     else
       results.push_back(sourceMetadata.getOffset());
     llvm::append_range(results, sourceMetadata.getSizes());
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index 12b375b..6f3110c 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -76,8 +76,8 @@ static SmallVector<Value> getTileSizes(Location loc, amx::TileType tType,
   auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
   auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
   return SmallVector<Value>{
-      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
-      rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr)};
+      LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, mattr),
+      LLVM::ConstantOp::create(rewriter, loc, llvmInt16Type, nattr)};
 }
 
 /// Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
@@ -95,15 +95,14 @@ static Value getStride(Location loc, MemRefType mType, Value base,
     // Dynamic stride needs code to compute the stride at runtime.
     MemRefDescriptor memrefDescriptor(base);
     auto attr = rewriter.getI64IntegerAttr(bytes);
-    Value scale = rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr);
-    return rewriter
-        .create<LLVM::MulOp>(loc, llvmInt64Type, scale,
-                             memrefDescriptor.stride(rewriter, loc, preLast))
+    Value scale = LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr);
+    return LLVM::MulOp::create(rewriter, loc, llvmInt64Type, scale,
+                               memrefDescriptor.stride(rewriter, loc, preLast))
        .getResult();
  }
  // Use direct constant for static stride.
  auto attr = rewriter.getI64IntegerAttr(strides[preLast] * bytes);
-  return rewriter.create<LLVM::ConstantOp>(loc, llvmInt64Type, attr)
+  return LLVM::ConstantOp::create(rewriter, loc, llvmInt64Type, attr)
      .getResult();
 }
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8d7053c..22608a1 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -26,7 +26,7 @@
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
 #include <numeric>
@@ -40,7 +40,6 @@ using llvm::divideFloorSigned;
 using llvm::mod;
 
 #define DEBUG_TYPE "affine-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
 
 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
@@ -1062,12 +1061,9 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
                                                            AffineMap *map,
                                                            ValueRange dims,
                                                            ValueRange syms) {
+  LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
   AffineMap affineMinMap = minOp.getAffineMap();
 
-  LLVM_DEBUG({
-    DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n";
-  });
-
   // Check the value is positive.
   for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
     // Compare each expression in the minimum against 0.
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index f18cec5..df39544 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -202,7 +202,7 @@ void AffineDataCopyGeneration::runOnBlock(Block *block,
 void AffineDataCopyGeneration::runOnOperation() {
   func::FuncOp f = getOperation();
   OpBuilder topBuilder(f.getBody());
-  zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0);
+  zeroIndex = arith::ConstantIndexOp::create(topBuilder, f.getLoc(), 0);
 
   // Nests that are copy-in's or copy-out's; the root AffineForOps of those
   // nests are stored herein.
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
index 5430bdc..c0d174a 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineExpandIndexOps.cpp
@@ -58,8 +58,9 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
       // Note: basis elements and their products are, definitionally,
       // non-negative, so `nuw` is justified.
       if (dynamicPart)
-        dynamicPart = rewriter.create<arith::MulIOp>(
-            loc, dynamicPart, dynamicBasis[dynamicIndex - 1], ovflags);
+        dynamicPart =
+            arith::MulIOp::create(rewriter, loc, dynamicPart,
+                                  dynamicBasis[dynamicIndex - 1], ovflags);
       else
         dynamicPart = dynamicBasis[dynamicIndex - 1];
       --dynamicIndex;
@@ -74,7 +75,7 @@ static SmallVector<Value> computeStrides(Location loc, RewriterBase &rewriter,
           rewriter.createOrFold<arith::ConstantIndexOp>(loc, staticPart);
       if (dynamicPart)
         stride =
-            rewriter.create<arith::MulIOp>(loc, dynamicPart, stride, ovflags);
+            arith::MulIOp::create(rewriter, loc, dynamicPart, stride, ovflags);
       result.push_back(stride);
     }
   }
@@ -106,20 +107,20 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
   Value zero = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
 
   Value initialPart =
-      rewriter.create<arith::FloorDivSIOp>(loc, linearIdx, strides.front());
+      arith::FloorDivSIOp::create(rewriter, loc, linearIdx, strides.front());
   results.push_back(initialPart);
 
   auto emitModTerm = [&](Value stride) -> Value {
-    Value remainder = rewriter.create<arith::RemSIOp>(loc, linearIdx, stride);
-    Value remainderNegative = rewriter.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, remainder, zero);
+    Value remainder = arith::RemSIOp::create(rewriter, loc, linearIdx, stride);
+    Value remainderNegative = arith::CmpIOp::create(
+        rewriter, loc, arith::CmpIPredicate::slt, remainder, zero);
     // If the correction is relevant, this term is <= stride, which is known
     // to be positive in `index`. Otherwise, while 2 * stride might overflow,
    // this branch won't be taken, so the risk of `poison` is fine.
-    Value corrected = rewriter.create<arith::AddIOp>(
-        loc, remainder, stride, arith::IntegerOverflowFlags::nsw);
-    Value mod = rewriter.create<arith::SelectOp>(loc, remainderNegative,
-                                                 corrected, remainder);
+    Value corrected = arith::AddIOp::create(rewriter, loc, remainder, stride,
+                                            arith::IntegerOverflowFlags::nsw);
+    Value mod = arith::SelectOp::create(rewriter, loc, remainderNegative,
+                                        corrected, remainder);
     return mod;
   };
@@ -131,7 +132,7 @@ affine::lowerAffineDelinearizeIndexOp(RewriterBase &rewriter,
     // We know both inputs are positive, so floorDiv == div.
     // This could potentially be a divui, but it's not clear if that would
     // cause issues.
-    Value divided = rewriter.create<arith::DivSIOp>(loc, modulus, nextStride);
+    Value divided = arith::DivSIOp::create(rewriter, loc, modulus, nextStride);
     results.push_back(divided);
   }
@@ -167,8 +168,8 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
   // our hands on an `OpOperand&` for the loop invariant counting function.
   for (auto [stride, idxOp] :
        llvm::zip_equal(strides, llvm::drop_end(op.getMultiIndexMutable()))) {
-    Value scaledIdx = rewriter.create<arith::MulIOp>(
-        loc, idxOp.get(), stride, arith::IntegerOverflowFlags::nsw);
+    Value scaledIdx = arith::MulIOp::create(rewriter, loc, idxOp.get(), stride,
+                                            arith::IntegerOverflowFlags::nsw);
     int64_t numHoistableLoops = numEnclosingInvariantLoops(idxOp);
     scaledValues.emplace_back(scaledIdx, numHoistableLoops);
   }
@@ -184,8 +185,8 @@ LogicalResult affine::lowerAffineLinearizeIndexOp(RewriterBase &rewriter,
   Value result = scaledValues.front().first;
   for (auto [scaledValue, numHoistableLoops] : llvm::drop_begin(scaledValues)) {
     std::ignore = numHoistableLoops;
-    result = rewriter.create<arith::AddIOp>(loc, result, scaledValue,
-                                            arith::IntegerOverflowFlags::nsw);
+    result = arith::AddIOp::create(rewriter, loc, result, scaledValue,
+                                   arith::IntegerOverflowFlags::nsw);
   }
   rewriter.replaceOp(op, result);
   return success();
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index 4fd0cf9..6265f46 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -15,13 +15,13 @@
 #include "mlir/Dialect/Affine/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/InterleavedRange.h"
 
 using namespace mlir;
 using namespace mlir::affine;
 
 #define DEBUG_TYPE "decompose-affine-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
 
 /// Count the number of loops surrounding `operand` such that operand could be
 /// hoisted above.
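The logging hunks in this section all follow one pattern: per-file `DBGS()`/`DBGSNL()` macros and explicit `LLVM_DEBUG(llvm::dbgs() << ... << "\n")` calls are replaced by the `LDBG()` stream from `llvm/Support/DebugLog.h`. A minimal sketch of the idiom follows; it is illustrative only, and the function and logged value are stand-ins rather than code from this diff:

  // Sketch, not part of the diff: old vs. new debug-logging idiom.
  #include "llvm/Support/Debug.h"    // old style: LLVM_DEBUG + dbgs()
  #include "llvm/Support/DebugLog.h" // new style: LDBG()

  #define DEBUG_TYPE "decompose-affine-ops"

  static void logExample(int numSubExprs) {
    // Old idiom: hand-rolled prefix macro plus an explicit trailing "\n".
    LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "]: subexprs: " << numSubExprs
                            << "\n");
    // New idiom: LDBG() tags the line with the debug type and terminates it
    // automatically when the temporary logger is destroyed.
    LDBG() << "subexprs: " << numSubExprs;
  }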
@@ -88,8 +88,8 @@ static AffineApplyOp createSubApply(RewriterBase &rewriter,
   auto rhsMap = AffineMap::get(m.getNumDims(), m.getNumSymbols(), expr, ctx);
   SmallVector<Value> rhsOperands = originalOp->getOperands();
   canonicalizeMapAndOperands(&rhsMap, &rhsOperands);
-  return rewriter.create<AffineApplyOp>(originalOp.getLoc(), rhsMap,
-                                        rhsOperands);
+  return AffineApplyOp::create(rewriter, originalOp.getLoc(), rhsMap,
+                               rhsOperands);
 }
 
 FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
@@ -115,7 +115,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
     return rewriter.notifyMatchFailure(
         op, "only add or mul binary expr can be reassociated");
 
-  LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");
+  LDBG() << "Start decomposeIntoFinerGrainedOps: " << op;
 
   // 2. Iteratively extract the RHS subexpressions while the top-level binary
   // expr kind remains the same.
@@ -125,11 +125,11 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
     auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
     if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
       subExpressions.push_back(remainingExp);
-      LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
+      LDBG() << "--terminal: " << subExpressions.back();
       break;
     }
     subExpressions.push_back(currentBinExpr.getRHS());
-    LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
+    LDBG() << "--subExpr: " << subExpressions.back();
     remainingExp = currentBinExpr.getLHS();
   }
@@ -146,9 +146,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
   llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) {
     return getMaxSymbol(e1) < getMaxSymbol(e2);
   });
-  LLVM_DEBUG(
-      llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
-      llvm::dbgs() << "\n");
+  LDBG() << "--sorted subexprs: " << llvm::interleaved(subExpressions);
 
   // 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
   auto s0 = getAffineSymbolExpr(0, ctx);
@@ -160,9 +158,9 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
   auto current = createSubApply(rewriter, op, subExpressions[0]);
   for (int64_t i = 1, e = subExpressions.size(); i < e; ++i) {
     Value tmp = createSubApply(rewriter, op, subExpressions[i]);
-    current = rewriter.create<AffineApplyOp>(op.getLoc(), binMap,
-                                             ValueRange{current, tmp});
-    LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
+    current = AffineApplyOp::create(rewriter, op.getLoc(), binMap,
+                                    ValueRange{current, tmp});
+    LDBG() << "--reassociate into: " << current;
   }
 
   // 5. Replace original op.
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 1d5a665..6c9adff 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -424,7 +424,7 @@ static Value createPrivateMemRef(AffineForOp forOp,
   // consumer loop nests to reduce their live range. Currently they are added
   // at the beginning of the block, because loop nests can be reordered
   // during the fusion pass.
-  Value newMemRef = top.create<memref::AllocOp>(forOp.getLoc(), newMemRefType);
+  Value newMemRef = memref::AllocOp::create(top, forOp.getLoc(), newMemRefType);
 
   // Build an AffineMap to remap access functions based on lower bound offsets.
   SmallVector<AffineExpr, 4> remapExprs;
diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
index 05a352f..c942c02 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
@@ -100,16 +100,16 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
   }
 
   // Create and place the alloc right before the 'affine.for' operation.
-  Value newMemRef = bOuter.create<memref::AllocOp>(
-      forOp.getLoc(), newMemRefType, allocOperands);
+  Value newMemRef = memref::AllocOp::create(bOuter, forOp.getLoc(),
+                                            newMemRefType, allocOperands);
 
   // Create 'iv mod 2' value to index the leading dimension.
   auto d0 = bInner.getAffineDimExpr(0);
   int64_t step = forOp.getStepAsInt();
   auto modTwoMap =
       AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0.floorDiv(step) % 2);
-  auto ivModTwoOp = bInner.create<AffineApplyOp>(forOp.getLoc(), modTwoMap,
-                                                 forOp.getInductionVar());
+  auto ivModTwoOp = AffineApplyOp::create(bInner, forOp.getLoc(), modTwoMap,
+                                          forOp.getInductionVar());
 
   // replaceAllMemRefUsesWith will succeed unless the forOp body has
   // non-dereferencing uses of the memref (dealloc's are fine though).
@@ -130,7 +130,7 @@ static bool doubleBuffer(Value oldMemRef, AffineForOp forOp) {
   }
   // Insert the dealloc op right after the for loop.
   bOuter.setInsertionPointAfter(forOp);
-  bOuter.create<memref::DeallocOp>(forOp.getLoc(), newMemRef);
+  memref::DeallocOp::create(bOuter, forOp.getLoc(), newMemRef);
 
   return true;
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
index 1a266b7..9537d3e 100644
--- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp
@@ -51,10 +51,10 @@ OpFoldResult affine::materializeComputedBound(
            "expected dynamic dim");
     if (isa<RankedTensorType>(value.getType())) {
       // A tensor dimension is used: generate a tensor.dim.
-      operands.push_back(b.create<tensor::DimOp>(loc, value, *dim));
+      operands.push_back(tensor::DimOp::create(b, loc, value, *dim));
     } else if (isa<MemRefType>(value.getType())) {
       // A memref dimension is used: generate a memref.dim.
-      operands.push_back(b.create<memref::DimOp>(loc, value, *dim));
+      operands.push_back(memref::DimOp::create(b, loc, value, *dim));
     } else {
       llvm_unreachable("cannot generate DimOp for unsupported shaped type");
     }
@@ -76,7 +76,7 @@ OpFoldResult affine::materializeComputedBound(
         operands[expr.getPosition() + boundMap.getNumDims()]);
 
   // General case: build affine.apply op.
return static_cast<OpFoldResult>( - b.create<affine::AffineApplyOp>(loc, boundMap, operands).getResult()); + affine::AffineApplyOp::create(b, loc, boundMap, operands).getResult()); } FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound( diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp index 8493b60..2521512 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp @@ -19,11 +19,10 @@ #include "mlir/Interfaces/ValueBoundsOpInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/IntEqClasses.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #define DEBUG_TYPE "affine-min-max" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; using namespace mlir::affine; @@ -39,7 +38,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { ValueRange operands = affineOp.getOperands(); static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>; - LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; }); + LDBG() << "analyzing value: `" << affineOp << "`"; // Create a `Variable` list with values corresponding to each of the results // in the affine map. @@ -48,12 +47,9 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { [&](unsigned i) { return Variable(affineMap.getSliceMap(i, 1), operands); }); - LLVM_DEBUG({ - DBGS() << "- constructed variables are: " - << llvm::interleaved_array(llvm::map_range( - variables, [](const Variable &v) { return v.getMap(); })) - << "`\n"; - }); + LDBG() << "- constructed variables are: " + << llvm::interleaved_array(llvm::map_range( + variables, [](const Variable &v) { return v.getMap(); })); // Get the comparison operation. ComparisonOperator cmpOp = @@ -72,10 +68,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Initialize the bound. Variable *bound = &v; - LLVM_DEBUG({ - DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap() - << "`\n"; - }); + LDBG() << "- inspecting variable: #" << i << ", with map: `" << v.getMap() + << "`"; // Check against the other variables. for (size_t j = i + 1; j < variables.size(); ++j) { @@ -87,10 +81,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Get the bound of the equivalence class or itself. Variable *nv = bounds.lookup_or(jEqClass, &variables[j]); - LLVM_DEBUG({ - DBGS() << "- comparing with variable: #" << jEqClass - << ", with map: " << nv->getMap() << "\n"; - }); + LDBG() << "- comparing with variable: #" << jEqClass + << ", with map: " << nv->getMap(); // Compare the variables. FailureOr<bool> cmpResult = @@ -98,18 +90,14 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // The variables cannot be compared. if (failed(cmpResult)) { - LLVM_DEBUG({ - DBGS() << "-- classes: #" << i << ", #" << jEqClass - << " cannot be merged\n"; - }); + LDBG() << "-- classes: #" << i << ", #" << jEqClass + << " cannot be merged"; continue; } // Join the equivalent classes and update the bound if necessary. 
- LLVM_DEBUG({ - DBGS() << "-- merging classes: #" << i << ", #" << jEqClass - << ", is cmp(lhs, rhs): " << *cmpResult << "`\n"; - }); + LDBG() << "-- merging classes: #" << i << ", #" << jEqClass + << ", is cmp(lhs, rhs): " << *cmpResult; if (*cmpResult) { boundedClasses.join(eqClass, jEqClass); } else { @@ -124,8 +112,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Return if there's no simplification. if (bounds.size() >= affineMap.getNumResults()) { - LLVM_DEBUG( - { DBGS() << "- the affine operation couldn't get simplified\n"; }); + LDBG() << "- the affine operation couldn't get simplified"; return false; } @@ -135,13 +122,11 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { for (auto [k, bound] : bounds) results.push_back(bound->getMap().getResult(0)); - LLVM_DEBUG({ - DBGS() << "- starting from map: " << affineMap << "\n"; - DBGS() << "- creating new map with: \n"; - DBGS() << "--- dims: " << affineMap.getNumDims() << "\n"; - DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n"; - DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n"; - }); + LDBG() << "- starting from map: " << affineMap; + LDBG() << "- creating new map with:"; + LDBG() << "--- dims: " << affineMap.getNumDims(); + LDBG() << "--- syms: " << affineMap.getNumSymbols(); + LDBG() << "--- res: " << llvm::interleaved_array(results); affineMap = AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(), @@ -149,7 +134,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) { // Update the affine op. rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); }); - LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; }); + LDBG() << "- simplified affine op: `" << affineOp << "`"; return true; } diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp index 7fae260..50a0f3d 100644 --- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp @@ -905,8 +905,8 @@ static void computeMemoryOpIndices(Operation *op, AffineMap map, for (auto resultExpr : map.getResults()) { auto singleResMap = AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr); - auto afOp = state.builder.create<AffineApplyOp>(op->getLoc(), singleResMap, - mapOperands); + auto afOp = AffineApplyOp::create(state.builder, op->getLoc(), singleResMap, + mapOperands); results.push_back(afOp); } } @@ -961,7 +961,7 @@ static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp, auto vecForOp = cast<AffineForOp>(parentOp); state.builder.setInsertionPointToStart(vecForOp.getBody()); auto newConstOp = - state.builder.create<arith::ConstantOp>(constOp.getLoc(), vecAttr); + arith::ConstantOp::create(state.builder, constOp.getLoc(), vecAttr); // Register vector replacement for future uses in the scope. state.registerOpVectorReplacement(constOp, newConstOp); @@ -986,8 +986,8 @@ static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp, } } - auto newApplyOp = state.builder.create<AffineApplyOp>( - applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands); + auto newApplyOp = AffineApplyOp::create( + state.builder, applyOp.getLoc(), applyOp.getAffineMap(), updatedOperands); // Register the new affine.apply result. 
state.registerValueScalarReplacement(applyOp.getResult(), @@ -1010,7 +1010,7 @@ static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind, auto vecTy = getVectorType(scalarTy, state.strategy); auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr); auto newConstOp = - state.builder.create<arith::ConstantOp>(oldOperand.getLoc(), vecAttr); + arith::ConstantOp::create(state.builder, oldOperand.getLoc(), vecAttr); return newConstOp; } @@ -1062,11 +1062,11 @@ static Value createMask(AffineForOp vecForOp, VectorizationState &state) { AffineMap ubMap = vecForOp.getUpperBoundMap(); Value ub; if (ubMap.getNumResults() == 1) - ub = state.builder.create<AffineApplyOp>(loc, vecForOp.getUpperBoundMap(), - vecForOp.getUpperBoundOperands()); + ub = AffineApplyOp::create(state.builder, loc, vecForOp.getUpperBoundMap(), + vecForOp.getUpperBoundOperands()); else - ub = state.builder.create<AffineMinOp>(loc, vecForOp.getUpperBoundMap(), - vecForOp.getUpperBoundOperands()); + ub = AffineMinOp::create(state.builder, loc, vecForOp.getUpperBoundMap(), + vecForOp.getUpperBoundOperands()); // Then we compute the number of (original) iterations left in the loop. AffineExpr subExpr = state.builder.getAffineDimExpr(0) - state.builder.getAffineDimExpr(1); @@ -1080,7 +1080,7 @@ static Value createMask(AffineForOp vecForOp, VectorizationState &state) { Type maskTy = VectorType::get(state.strategy->vectorSizes, state.builder.getIntegerType(1)); Value mask = - state.builder.create<vector::CreateMaskOp>(loc, maskTy, itersLeft); + vector::CreateMaskOp::create(state.builder, loc, maskTy, itersLeft); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a mask:\n" << itersLeft << "\n" @@ -1123,8 +1123,8 @@ static Operation *vectorizeUniform(Value uniformVal, state.builder.setInsertionPointAfterValue(uniformScalarRepl); auto vectorTy = getVectorType(uniformVal.getType(), state.strategy); - auto bcastOp = state.builder.create<BroadcastOp>(uniformVal.getLoc(), - vectorTy, uniformScalarRepl); + auto bcastOp = BroadcastOp::create(state.builder, uniformVal.getLoc(), + vectorTy, uniformScalarRepl); state.registerValueVectorReplacement(uniformVal, bcastOp); return bcastOp; } @@ -1256,8 +1256,8 @@ static Operation *vectorizeAffineLoad(AffineLoadOp loadOp, LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); - auto transfer = state.builder.create<vector::TransferReadOp>( - loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, + auto transfer = vector::TransferReadOp::create( + state.builder, loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, /*padding=*/std::nullopt, permutationMap); // Register replacement for future uses in the scope. @@ -1303,9 +1303,9 @@ static Operation *vectorizeAffineStore(AffineStoreOp storeOp, LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); LLVM_DEBUG(permutationMap.print(dbgs())); - auto transfer = state.builder.create<vector::TransferWriteOp>( - storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices, - permutationMap); + auto transfer = vector::TransferWriteOp::create( + state.builder, storeOp.getLoc(), vectorValue, storeOp.getMemRef(), + indices, permutationMap); LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer); // Register replacement for future uses in the scope. 
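Besides the builder calls, the next hunk below also swaps dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp()) for the equivalent value.getDefiningOp<arith::ConstantOp>(): the templated accessor on Value folds the null check (block argument, or a defining op of another kind) and the cast into a single call. A minimal sketch of the idiom, reusing the names from that hunk:

  // Yields the defining arith.constant when there is one, null otherwise.
  if (auto constOp = value.getDefiningOp<arith::ConstantOp>())
    return constOp.getValue() == valueAttr;
  return false;
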
@@ -1322,7 +1322,7 @@ static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind, return false; Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy, state.builder, value.getLoc()); - if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp())) + if (auto constOp = value.getDefiningOp<arith::ConstantOp>()) return constOp.getValue() == valueAttr; return false; } @@ -1387,10 +1387,10 @@ static Operation *vectorizeAffineForOp(AffineForOp forOp, } } - auto vecForOp = state.builder.create<AffineForOp>( - forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(), - forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep, - vecIterOperands, + auto vecForOp = AffineForOp::create( + state.builder, forOp.getLoc(), forOp.getLowerBoundOperands(), + forOp.getLowerBoundMap(), forOp.getUpperBoundOperands(), + forOp.getUpperBoundMap(), newStep, vecIterOperands, /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) { // Make sure we don't create a default terminator in the loop body as // the proper terminator will be added during vectorization. @@ -1512,8 +1512,8 @@ static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp, // IterOperands are neutral element vectors. Value neutralVal = cast<AffineForOp>(newParentOp).getInits()[i]; state.builder.setInsertionPoint(combinerOps.back()); - Value maskedReducedVal = state.builder.create<arith::SelectOp>( - reducedVal.getLoc(), mask, reducedVal, neutralVal); + Value maskedReducedVal = arith::SelectOp::create( + state.builder, reducedVal.getLoc(), mask, reducedVal, neutralVal); LLVM_DEBUG( dbgs() << "\n[early-vect]+++++ masking an input to a binary op that" " produces value for a yield Op: " @@ -1865,7 +1865,6 @@ verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) { return success(); } - /// External utility to vectorize affine loops in 'loops' using the n-D /// vectorization factors in 'vectorSizes'. By default, each vectorization /// factor is applied inner-to-outer to the loops of each loop nest. @@ -1927,4 +1926,4 @@ LogicalResult mlir::affine::vectorizeAffineLoopNest( if (failed(verifyLoopNesting(loops))) return failure(); return vectorizeLoopNest(loops, strategy); -} +}
\ No newline at end of file diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 21f69ad..2de057d 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -54,8 +54,8 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, OpBuilder b(forOp); auto lbMap = forOp.getLowerBoundMap(); - auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap, - forOp.getLowerBoundOperands()); + auto lb = AffineApplyOp::create(b, forOp.getLoc(), lbMap, + forOp.getLowerBoundOperands()); // For each upper bound expr, get the range. // Eg: affine.for %i = lb to min (ub1, ub2), @@ -71,7 +71,7 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor, auto bumpMap = AffineMap::get(tripCountMap.getNumDims(), tripCountMap.getNumSymbols(), bumpExprs[i]); bumpValues[i] = - b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands); + AffineApplyOp::create(b, forOp.getLoc(), bumpMap, tripCountOperands); } SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults()); @@ -134,8 +134,8 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) { builder.setInsertionPointToStart(&func.getFunctionBody().front()); else builder.setInsertionPoint(forOp); - auto constOp = builder.create<arith::ConstantIndexOp>( - forOp.getLoc(), forOp.getConstantLowerBound()); + auto constOp = arith::ConstantIndexOp::create( + builder, forOp.getLoc(), forOp.getConstantLowerBound()); iv.replaceAllUsesWith(constOp); } else { auto lbOperands = forOp.getLowerBoundOperands(); @@ -146,7 +146,7 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) { iv.replaceAllUsesWith(lbOperands[0]); } else { auto affineApplyOp = - builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands); + AffineApplyOp::create(builder, forOp.getLoc(), lbMap, lbOperands); iv.replaceAllUsesWith(affineApplyOp); } } @@ -181,8 +181,8 @@ static AffineForOp generateShiftedLoop( assert(ubMap.getNumInputs() == ubOperands.size()); auto loopChunk = - b.create<AffineForOp>(srcForOp.getLoc(), lbOperands, lbMap, ubOperands, - ubMap, srcForOp.getStepAsInt()); + AffineForOp::create(b, srcForOp.getLoc(), lbOperands, lbMap, ubOperands, + ubMap, srcForOp.getStepAsInt()); auto loopChunkIV = loopChunk.getInductionVar(); auto srcIV = srcForOp.getInductionVar(); @@ -197,8 +197,8 @@ static AffineForOp generateShiftedLoop( // Generate the remapping if the shift is not zero: remappedIV = newIV - // shift. if (!srcIV.use_empty() && shift != 0) { - auto ivRemap = bodyBuilder.create<AffineApplyOp>( - srcForOp.getLoc(), + auto ivRemap = AffineApplyOp::create( + bodyBuilder, srcForOp.getLoc(), bodyBuilder.getSingleDimShiftAffineMap( -static_cast<int64_t>(srcForOp.getStepAsInt() * shift)), loopChunkIV); @@ -433,7 +433,7 @@ static void constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops, for (unsigned i = 0; i < width; i++) { OpBuilder b(topLoop); // Loop bounds will be set later. - AffineForOp pointLoop = b.create<AffineForOp>(loc, 0, 0); + AffineForOp pointLoop = AffineForOp::create(b, loc, 0, 0); pointLoop.getBody()->getOperations().splice( pointLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), topLoop); @@ -447,7 +447,7 @@ static void constructTiledLoopNest(MutableArrayRef<AffineForOp> origLoops, for (unsigned i = width; i < 2 * width; i++) { OpBuilder b(topLoop); // Loop bounds will be set later. 
- AffineForOp tileSpaceLoop = b.create<AffineForOp>(loc, 0, 0); + AffineForOp tileSpaceLoop = AffineForOp::create(b, loc, 0, 0); tileSpaceLoop.getBody()->getOperations().splice( tileSpaceLoop.getBody()->begin(), topLoop->getBlock()->getOperations(), topLoop); @@ -1048,7 +1048,7 @@ LogicalResult mlir::affine::loopUnrollByFactor( // iv' = iv + i * step auto d0 = b.getAffineDimExpr(0); auto bumpMap = AffineMap::get(1, 0, d0 + i * step); - return b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, iv); + return AffineApplyOp::create(b, forOp.getLoc(), bumpMap, iv); }, /*annotateFn=*/annotateFn, /*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues); @@ -1212,7 +1212,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp, auto d0 = builder.getAffineDimExpr(0); auto bumpMap = AffineMap::get(1, 0, d0 + i * step); auto ivUnroll = - builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV); + AffineApplyOp::create(builder, forOp.getLoc(), bumpMap, forOpIV); operandMaps[i - 1].map(forOpIV, ivUnroll); } // Clone the sub-block being unroll-jammed. @@ -1541,8 +1541,8 @@ stripmineSink(AffineForOp forOp, uint64_t factor, for (auto t : targets) { // Insert newForOp before the terminator of `t`. auto b = OpBuilder::atBlockTerminator(t.getBody()); - auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap, - ubOperands, ubMap, originalStep); + auto newForOp = AffineForOp::create(b, t.getLoc(), lbOperands, lbMap, + ubOperands, ubMap, originalStep); auto begin = t.getBody()->begin(); // Skip terminator and `newForOp` which is just before the terminator. auto nOps = t.getBody()->getOperations().size() - 2; @@ -1616,9 +1616,9 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) { // 1. Store the upper bound of the outermost loop in a variable. Value prev; if (!llvm::hasSingleElement(origUbMap.getResults())) - prev = builder.create<AffineMinOp>(loc, origUbMap, ubOperands); + prev = AffineMinOp::create(builder, loc, origUbMap, ubOperands); else - prev = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands); + prev = AffineApplyOp::create(builder, loc, origUbMap, ubOperands); upperBoundSymbols.push_back(prev); // 2. Emit code computing the upper bound of the coalesced loop as product of @@ -1630,16 +1630,16 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) { Value upperBound; // If upper bound map has more than one result, take their minimum. if (!llvm::hasSingleElement(origUbMap.getResults())) - upperBound = builder.create<AffineMinOp>(loc, origUbMap, ubOperands); + upperBound = AffineMinOp::create(builder, loc, origUbMap, ubOperands); else - upperBound = builder.create<AffineApplyOp>(loc, origUbMap, ubOperands); + upperBound = AffineApplyOp::create(builder, loc, origUbMap, ubOperands); upperBoundSymbols.push_back(upperBound); SmallVector<Value, 4> operands; operands.push_back(prev); operands.push_back(upperBound); // Maintain running product of loop upper bounds. 
- prev = builder.create<AffineApplyOp>( - loc, + prev = AffineApplyOp::create( + builder, loc, AffineMap::get(/*dimCount=*/1, /*symbolCount=*/1, builder.getAffineDimExpr(0) * @@ -1668,13 +1668,12 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) { SmallVector<Value, 4> operands; operands.push_back(previous); operands.push_back(upperBoundSymbols[idx]); - previous = builder.create<AffineApplyOp>( - loc, - AffineMap::get( - /*dimCount=*/1, /*symbolCount=*/1, - builder.getAffineDimExpr(0).floorDiv( - builder.getAffineSymbolExpr(0))), - operands); + previous = AffineApplyOp::create(builder, loc, + AffineMap::get( + /*dimCount=*/1, /*symbolCount=*/1, + builder.getAffineDimExpr(0).floorDiv( + builder.getAffineSymbolExpr(0))), + operands); } // Modified value of the induction variables of the nested loops after // coalescing. @@ -1685,8 +1684,8 @@ LogicalResult mlir::affine::coalesceLoops(MutableArrayRef<AffineForOp> loops) { SmallVector<Value, 4> applyOperands; applyOperands.push_back(previous); applyOperands.push_back(upperBoundSymbols[idx - 1]); - inductionVariable = builder.create<AffineApplyOp>( - loc, + inductionVariable = AffineApplyOp::create( + builder, loc, AffineMap::get( /*dimCount=*/1, /*symbolCount=*/1, builder.getAffineDimExpr(0) % builder.getAffineSymbolExpr(0)), @@ -1723,21 +1722,21 @@ void mlir::affine::mapLoopToProcessorIds(scf::ForOp forOp, Value linearIndex = processorId.front(); for (unsigned i = 1, e = processorId.size(); i < e; ++i) { - auto mulApplyOp = b.create<AffineApplyOp>( - loc, mulMap, ValueRange{linearIndex, numProcessors[i]}); - linearIndex = b.create<AffineApplyOp>( - loc, addMap, ValueRange{mulApplyOp, processorId[i]}); + auto mulApplyOp = AffineApplyOp::create( + b, loc, mulMap, ValueRange{linearIndex, numProcessors[i]}); + linearIndex = AffineApplyOp::create(b, loc, addMap, + ValueRange{mulApplyOp, processorId[i]}); } - auto mulApplyOp = b.create<AffineApplyOp>( - loc, mulMap, ValueRange{linearIndex, forOp.getStep()}); - Value lb = b.create<AffineApplyOp>( - loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()}); + auto mulApplyOp = AffineApplyOp::create( + b, loc, mulMap, ValueRange{linearIndex, forOp.getStep()}); + Value lb = AffineApplyOp::create( + b, loc, addMap, ValueRange{mulApplyOp, forOp.getLowerBound()}); forOp.setLowerBound(lb); Value step = forOp.getStep(); for (auto numProcs : numProcessors) - step = b.create<AffineApplyOp>(loc, mulMap, ValueRange{numProcs, step}); + step = AffineApplyOp::create(b, loc, mulMap, ValueRange{numProcs, step}); forOp.setStep(step); } @@ -1874,7 +1873,7 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef, auto fastBufOffsetMap = AffineMap::get(lbOperands.size(), 0, fastBufOffsets[d]); - auto offset = b.create<AffineApplyOp>(loc, fastBufOffsetMap, lbOperands); + auto offset = AffineApplyOp::create(b, loc, fastBufOffsetMap, lbOperands); // Construct the subscript for the fast memref being copied into/from: // x - offset_x. @@ -1901,16 +1900,16 @@ generatePointWiseCopy(Location loc, Value memref, Value fastMemRef, if (!isCopyOut) { // Copy in. - auto load = b.create<AffineLoadOp>(loc, memref, memIndices); - b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufMap, - fastBufMapOperands); + auto load = AffineLoadOp::create(b, loc, memref, memIndices); + AffineStoreOp::create(b, loc, load, fastMemRef, fastBufMap, + fastBufMapOperands); return copyNestRoot; } // Copy out. 
auto load = - b.create<AffineLoadOp>(loc, fastMemRef, fastBufMap, fastBufMapOperands); - b.create<AffineStoreOp>(loc, load, memref, memIndices); + AffineLoadOp::create(b, loc, fastMemRef, fastBufMap, fastBufMapOperands); + AffineStoreOp::create(b, loc, load, memref, memIndices); return copyNestRoot; } @@ -1945,7 +1944,7 @@ static LogicalResult generateCopy( auto f = begin->getParentOfType<FunctionOpInterface>(); OpBuilder topBuilder(f.getFunctionBody()); - Value zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0); + Value zeroIndex = arith::ConstantIndexOp::create(topBuilder, f.getLoc(), 0); *sizeInBytes = 0; @@ -2056,7 +2055,7 @@ static LogicalResult generateCopy( memIndices.push_back(zeroIndex); } else { memIndices.push_back( - top.create<arith::ConstantIndexOp>(loc, indexVal).getResult()); + arith::ConstantIndexOp::create(top, loc, indexVal).getResult()); } } else { // The coordinate for the start location is just the lower bound along the @@ -2070,7 +2069,8 @@ static LogicalResult generateCopy( lbs[d] = lbs[d].replaceDimsAndSymbols( /*dimReplacements=*/{}, symReplacements, lbs[d].getNumSymbols(), /*numResultSyms=*/0); - memIndices.push_back(b.create<AffineApplyOp>(loc, lbs[d], regionSymbols)); + memIndices.push_back( + AffineApplyOp::create(b, loc, lbs[d], regionSymbols)); } // The fast buffer is copied into at location zero; addressing is relative. bufIndices.push_back(zeroIndex); @@ -2094,7 +2094,7 @@ static LogicalResult generateCopy( // Create the fast memory space buffer just before the 'affine.for' // operation. fastMemRef = - prologue.create<memref::AllocOp>(loc, fastMemRefType).getResult(); + memref::AllocOp::create(prologue, loc, fastMemRefType).getResult(); // Record it. fastBufferMap[memref] = fastMemRef; // fastMemRefType is a constant shaped memref. @@ -2111,7 +2111,7 @@ static LogicalResult generateCopy( fastMemRef = fastBufferMap[memref]; } - auto numElementsSSA = top.create<arith::ConstantIndexOp>(loc, *numElements); + auto numElementsSSA = arith::ConstantIndexOp::create(top, loc, *numElements); Value dmaStride; Value numEltPerDmaStride; @@ -2128,9 +2128,9 @@ static LogicalResult generateCopy( if (!dmaStrideInfos.empty()) { dmaStride = - top.create<arith::ConstantIndexOp>(loc, dmaStrideInfos[0].stride); - numEltPerDmaStride = top.create<arith::ConstantIndexOp>( - loc, dmaStrideInfos[0].numEltPerStride); + arith::ConstantIndexOp::create(top, loc, dmaStrideInfos[0].stride); + numEltPerDmaStride = arith::ConstantIndexOp::create( + top, loc, dmaStrideInfos[0].numEltPerStride); } } @@ -2160,21 +2160,21 @@ static LogicalResult generateCopy( // Create a tag (single element 1-d memref) for the DMA. auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {}, copyOptions.tagMemorySpace); - auto tagMemRef = prologue.create<memref::AllocOp>(loc, tagMemRefType); + auto tagMemRef = memref::AllocOp::create(prologue, loc, tagMemRefType); SmallVector<Value, 4> tagIndices({zeroIndex}); auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size()); fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices); if (!region.isWrite()) { // DMA non-blocking read from original buffer to fast buffer. 
- b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices, - fastMemRef, bufAffineMap, bufIndices, - tagMemRef, tagAffineMap, tagIndices, - numElementsSSA, dmaStride, numEltPerDmaStride); + AffineDmaStartOp::create(b, loc, memref, memAffineMap, memIndices, + fastMemRef, bufAffineMap, bufIndices, tagMemRef, + tagAffineMap, tagIndices, numElementsSSA, + dmaStride, numEltPerDmaStride); } else { // DMA non-blocking write from fast buffer to the original memref. - auto op = b.create<AffineDmaStartOp>( - loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap, + auto op = AffineDmaStartOp::create( + b, loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap, memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA, dmaStride, numEltPerDmaStride); // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the @@ -2184,11 +2184,11 @@ static LogicalResult generateCopy( } // Matching DMA wait to block on completion; tag always has a 0 index. - b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex, - numElementsSSA); + AffineDmaWaitOp::create(b, loc, tagMemRef, tagAffineMap, zeroIndex, + numElementsSSA); // Generate dealloc for the tag. - auto tagDeallocOp = epilogue.create<memref::DeallocOp>(loc, tagMemRef); + auto tagDeallocOp = memref::DeallocOp::create(epilogue, loc, tagMemRef); if (*nEnd == end && isCopyOutAtEndOfBlock) // Since new ops are being appended (for outgoing DMAs), adjust the end to // mark end of range of the original. @@ -2197,7 +2197,7 @@ static LogicalResult generateCopy( // Generate dealloc for the buffer. if (!existingBuf) { - auto bufDeallocOp = epilogue.create<memref::DeallocOp>(loc, fastMemRef); + auto bufDeallocOp = memref::DeallocOp::create(epilogue, loc, fastMemRef); // When generating pointwise copies, `nEnd' has to be set to deallocOp on // the fast buffer (since it marks the new end insertion point). if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock) @@ -2567,8 +2567,8 @@ AffineForOp mlir::affine::createCanonicalizedAffineForOp( canonicalizeMapAndOperands(&ubMap, &upperOperands); ubMap = removeDuplicateExprs(ubMap); - return b.create<AffineForOp>(loc, lowerOperands, lbMap, upperOperands, ubMap, - step); + return AffineForOp::create(b, loc, lowerOperands, lbMap, upperOperands, ubMap, + step); } /// Creates an AffineIfOp that encodes the conditional to choose between @@ -2651,8 +2651,8 @@ static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops, SmallVector<Value, 4> setOperands; cst.getValues(0, cst.getNumDimAndSymbolVars(), &setOperands); canonicalizeSetAndOperands(&ifCondSet, &setOperands); - return b.create<AffineIfOp>(loops[0].getLoc(), ifCondSet, setOperands, - /*withElseRegion=*/true); + return AffineIfOp::create(b, loops[0].getLoc(), ifCondSet, setOperands, + /*withElseRegion=*/true); } /// Create the full tile loop nest (along with its body). 
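Every builder change in these files applies the same mechanical rewrite: the member template b.create<OpTy>(loc, args...) becomes the static OpTy::create(b, loc, args...) hook on the op class, with the builder hoisted into the first argument and all other operands left untouched. A minimal before/after sketch with hypothetical operands:

  // Old spelling: create is a template member on the builder.
  Value sum = b.create<arith::AddIOp>(loc, lhs, rhs);

  // New spelling: create is a static method on the op class; only the
  // builder argument moves, the emitted IR is identical.
  Value sum = arith::AddIOp::create(b, loc, lhs, rhs);
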
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 7bb158e..845be20 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -56,7 +56,7 @@ public: auto rhs = visit(expr.getRHS()); if (!lhs || !rhs) return nullptr; - auto op = builder.create<OpTy>(loc, lhs, rhs, overflowFlags); + auto op = OpTy::create(builder, loc, lhs, rhs, overflowFlags); return op.getResult(); } @@ -90,14 +90,14 @@ public: auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value remainder = builder.create<arith::RemSIOp>(loc, lhs, rhs); - Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); - Value isRemainderNegative = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::slt, remainder, zeroCst); + Value remainder = arith::RemSIOp::create(builder, loc, lhs, rhs); + Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0); + Value isRemainderNegative = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::slt, remainder, zeroCst); Value correctedRemainder = - builder.create<arith::AddIOp>(loc, remainder, rhs); - Value result = builder.create<arith::SelectOp>( - loc, isRemainderNegative, correctedRemainder, remainder); + arith::AddIOp::create(builder, loc, remainder, rhs); + Value result = arith::SelectOp::create(builder, loc, isRemainderNegative, + correctedRemainder, remainder); return result; } @@ -129,18 +129,19 @@ public: auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); - Value noneCst = builder.create<arith::ConstantIndexOp>(loc, -1); - Value negative = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::slt, lhs, zeroCst); - Value negatedDecremented = builder.create<arith::SubIOp>(loc, noneCst, lhs); - Value dividend = - builder.create<arith::SelectOp>(loc, negative, negatedDecremented, lhs); - Value quotient = builder.create<arith::DivSIOp>(loc, dividend, rhs); + Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0); + Value noneCst = arith::ConstantIndexOp::create(builder, loc, -1); + Value negative = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::slt, lhs, zeroCst); + Value negatedDecremented = + arith::SubIOp::create(builder, loc, noneCst, lhs); + Value dividend = arith::SelectOp::create(builder, loc, negative, + negatedDecremented, lhs); + Value quotient = arith::DivSIOp::create(builder, loc, dividend, rhs); Value correctedQuotient = - builder.create<arith::SubIOp>(loc, noneCst, quotient); - Value result = builder.create<arith::SelectOp>(loc, negative, - correctedQuotient, quotient); + arith::SubIOp::create(builder, loc, noneCst, quotient); + Value result = arith::SelectOp::create(builder, loc, negative, + correctedQuotient, quotient); return result; } @@ -168,26 +169,26 @@ public: auto rhs = visit(expr.getRHS()); assert(lhs && rhs && "unexpected affine expr lowering failure"); - Value zeroCst = builder.create<arith::ConstantIndexOp>(loc, 0); - Value oneCst = builder.create<arith::ConstantIndexOp>(loc, 1); - Value nonPositive = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sle, lhs, zeroCst); - Value negated = builder.create<arith::SubIOp>(loc, zeroCst, lhs); - Value decremented = builder.create<arith::SubIOp>(loc, lhs, oneCst); - Value dividend = - builder.create<arith::SelectOp>(loc, nonPositive, negated, decremented); - Value quotient = builder.create<arith::DivSIOp>(loc, 
dividend, rhs); + Value zeroCst = arith::ConstantIndexOp::create(builder, loc, 0); + Value oneCst = arith::ConstantIndexOp::create(builder, loc, 1); + Value nonPositive = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::sle, lhs, zeroCst); + Value negated = arith::SubIOp::create(builder, loc, zeroCst, lhs); + Value decremented = arith::SubIOp::create(builder, loc, lhs, oneCst); + Value dividend = arith::SelectOp::create(builder, loc, nonPositive, negated, + decremented); + Value quotient = arith::DivSIOp::create(builder, loc, dividend, rhs); Value negatedQuotient = - builder.create<arith::SubIOp>(loc, zeroCst, quotient); + arith::SubIOp::create(builder, loc, zeroCst, quotient); Value incrementedQuotient = - builder.create<arith::AddIOp>(loc, quotient, oneCst); - Value result = builder.create<arith::SelectOp>( - loc, nonPositive, negatedQuotient, incrementedQuotient); + arith::AddIOp::create(builder, loc, quotient, oneCst); + Value result = arith::SelectOp::create( + builder, loc, nonPositive, negatedQuotient, incrementedQuotient); return result; } Value visitConstantExpr(AffineConstantExpr expr) { - auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue()); + auto op = arith::ConstantIndexOp::create(builder, loc, expr.getValue()); return op.getResult(); } @@ -297,9 +298,9 @@ static AffineIfOp hoistAffineIfOp(AffineIfOp ifOp, Operation *hoistOverOp) { // block. IRMapping operandMap; OpBuilder b(hoistOverOp); - auto hoistedIfOp = b.create<AffineIfOp>(ifOp.getLoc(), ifOp.getIntegerSet(), - ifOp.getOperands(), - /*elseBlock=*/true); + auto hoistedIfOp = AffineIfOp::create(b, ifOp.getLoc(), ifOp.getIntegerSet(), + ifOp.getOperands(), + /*elseBlock=*/true); // Create a clone of hoistOverOp to use for the else branch of the hoisted // conditional. The else block may get optimized away if empty. 
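The two division lowerings converted above expand affine floordiv and ceildiv (round toward negative and positive infinity, respectively) in terms of arith.divsi, which truncates toward zero, plus compares and selects. A scalar model of the emitted logic, assuming a positive divisor as affine semantics require (helper names are illustrative only):

  // Mirrors the select-over-dividend / select-over-quotient structure above.
  int64_t floorDiv(int64_t lhs, int64_t rhs) {
    int64_t q = (lhs < 0 ? -1 - lhs : lhs) / rhs;
    return lhs < 0 ? -1 - q : q;                  // floorDiv(-7, 2) == -4
  }
  int64_t ceilDiv(int64_t lhs, int64_t rhs) {
    int64_t q = (lhs <= 0 ? -lhs : lhs - 1) / rhs;
    return lhs <= 0 ? -q : q + 1;                 // ceilDiv(7, 2) == 4
  }
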
@@ -368,8 +369,8 @@ mlir::affine::affineParallelize(AffineForOp forOp, parallelReductions, [](const LoopReduction &red) { return red.value; })); auto reductionKinds = llvm::to_vector<4>(llvm::map_range( parallelReductions, [](const LoopReduction &red) { return red.kind; })); - AffineParallelOp newPloop = outsideBuilder.create<AffineParallelOp>( - loc, ValueRange(reducedValues).getTypes(), reductionKinds, + AffineParallelOp newPloop = AffineParallelOp::create( + outsideBuilder, loc, ValueRange(reducedValues).getTypes(), reductionKinds, llvm::ArrayRef(lowerBoundMap), lowerBoundOperands, llvm::ArrayRef(upperBoundMap), upperBoundOperands, llvm::ArrayRef(forOp.getStepAsInt())); @@ -540,7 +541,8 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) { SmallVector<Value, 8> applyOperands{dimOperands}; applyOperands.push_back(iv); applyOperands.append(symbolOperands.begin(), symbolOperands.end()); - auto apply = builder.create<AffineApplyOp>(op.getLoc(), map, applyOperands); + auto apply = + AffineApplyOp::create(builder, op.getLoc(), map, applyOperands); iv.replaceAllUsesExcept(apply, apply); } @@ -621,8 +623,9 @@ LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op, AffineValueMap newIvToOldIvMap; AffineValueMap::difference(lbMap, scaleIvValueMap, &newIvToOldIvMap); (void)newIvToOldIvMap.canonicalize(); - auto newIV = opBuilder.create<AffineApplyOp>( - loc, newIvToOldIvMap.getAffineMap(), newIvToOldIvMap.getOperands()); + auto newIV = + AffineApplyOp::create(opBuilder, loc, newIvToOldIvMap.getAffineMap(), + newIvToOldIvMap.getOperands()); op.getInductionVar().replaceAllUsesExcept(newIV->getResult(0), newIV); return success(); } @@ -1186,8 +1189,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( for (auto resultExpr : oldMap.getResults()) { auto singleResMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr); - auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap, - oldMapOperands); + auto afOp = AffineApplyOp::create(builder, op->getLoc(), singleResMap, + oldMapOperands); oldMemRefOperands.push_back(afOp); affineApplyOps.push_back(afOp); } @@ -1213,8 +1216,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( for (auto resultExpr : indexRemap.getResults()) { auto singleResMap = AffineMap::get( indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); - auto afOp = builder.create<AffineApplyOp>(op->getLoc(), singleResMap, - remapOperands); + auto afOp = AffineApplyOp::create(builder, op->getLoc(), singleResMap, + remapOperands); remapOutputs.push_back(afOp); affineApplyOps.push_back(afOp); } @@ -1263,8 +1266,8 @@ LogicalResult mlir::affine::replaceAllMemRefUsesWith( // AffineMapAccessInterface, we need to apply the values of `newMapOperands` // to the `newMap` to get the correct indices. 
for (unsigned i = 0; i < newMemRefRank; i++) { - state.operands.push_back(builder.create<AffineApplyOp>( - op->getLoc(), + state.operands.push_back(AffineApplyOp::create( + builder, op->getLoc(), AffineMap::get(newMap.getNumDims(), newMap.getNumSymbols(), newMap.getResult(i)), newMapOperands)); @@ -1449,8 +1452,8 @@ void mlir::affine::createAffineComputationSlice( for (auto resultExpr : composedMap.getResults()) { auto singleResMap = AffineMap::get(composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr); - sliceOps->push_back(builder.create<AffineApplyOp>( - opInst->getLoc(), singleResMap, composedOpOperands)); + sliceOps->push_back(AffineApplyOp::create( + builder, opInst->getLoc(), singleResMap, composedOpOperands)); } // Construct the new operands that include the results from the composed @@ -1680,7 +1683,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType, // Create ConstantOp for static dimension. auto constantAttr = b.getIntegerAttr(b.getIndexType(), oldMemRefShape[d]); inAffineApply.emplace_back( - b.create<arith::ConstantOp>(allocOp.getLoc(), constantAttr)); + arith::ConstantOp::create(b, allocOp.getLoc(), constantAttr)); } } @@ -1704,7 +1707,7 @@ static void createNewDynamicSizes(MemRefType oldMemRefType, AffineMap newMap = AffineMap::get(map.getNumInputs(), map.getNumSymbols(), newMapOutput); Value affineApp = - b.create<AffineApplyOp>(allocOp.getLoc(), newMap, inAffineApply); + AffineApplyOp::create(b, allocOp.getLoc(), newMap, inAffineApply); newDynamicSizes.emplace_back(affineApp); } newDimIdx++; @@ -1739,12 +1742,11 @@ LogicalResult mlir::affine::normalizeMemRef(AllocLikeOp allocOp) { createNewDynamicSizes(oldMemRefType, newMemRefType, layoutMap, allocOp, b, newDynamicSizes); // Add the new dynamic sizes in new AllocOp. - newAlloc = - b.create<AllocLikeOp>(allocOp.getLoc(), newMemRefType, newDynamicSizes, - allocOp.getAlignmentAttr()); + newAlloc = AllocLikeOp::create(b, allocOp.getLoc(), newMemRefType, + newDynamicSizes, allocOp.getAlignmentAttr()); } else { - newAlloc = b.create<AllocLikeOp>(allocOp.getLoc(), newMemRefType, - allocOp.getAlignmentAttr()); + newAlloc = AllocLikeOp::create(b, allocOp.getLoc(), newMemRefType, + allocOp.getAlignmentAttr()); } // Replace all uses of the old memref. 
if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, @@ -1802,10 +1804,10 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) { for (unsigned i = 0, e = memrefType.getRank(); i < e; i++) { if (memrefType.isDynamicDim(i)) mapOperands[i] = - b.create<arith::SubIOp>(loc, oldSizes[0].getType(), oldSizes[idx++], - b.create<arith::ConstantIndexOp>(loc, 1)); + arith::SubIOp::create(b, loc, oldSizes[0].getType(), oldSizes[idx++], + arith::ConstantIndexOp::create(b, loc, 1)); else - mapOperands[i] = b.create<arith::ConstantIndexOp>(loc, oldShape[i] - 1); + mapOperands[i] = arith::ConstantIndexOp::create(b, loc, oldShape[i] - 1); } for (unsigned i = 0, e = oldStrides.size(); i < e; i++) mapOperands[memrefType.getRank() + i] = oldStrides[i]; @@ -1815,20 +1817,20 @@ mlir::affine::normalizeMemRef(memref::ReinterpretCastOp reinterpretCastOp) { for (unsigned i = 0; i < newRank; i++) { if (!newMemRefType.isDynamicDim(i)) continue; - newSizes.push_back(b.create<AffineApplyOp>( - loc, + newSizes.push_back(AffineApplyOp::create( + b, loc, AffineMap::get(oldLayoutMap.getNumDims(), oldLayoutMap.getNumSymbols(), oldLayoutMap.getResult(i)), mapOperands)); } for (unsigned i = 0, e = newSizes.size(); i < e; i++) { newSizes[i] = - b.create<arith::AddIOp>(loc, newSizes[i].getType(), newSizes[i], - b.create<arith::ConstantIndexOp>(loc, 1)); + arith::AddIOp::create(b, loc, newSizes[i].getType(), newSizes[i], + arith::ConstantIndexOp::create(b, loc, 1)); } // Create the new reinterpret_cast op. - auto newReinterpretCast = b.create<memref::ReinterpretCastOp>( - loc, newMemRefType, reinterpretCastOp.getSource(), + auto newReinterpretCast = memref::ReinterpretCastOp::create( + b, loc, newMemRefType, reinterpretCastOp.getSource(), /*offsets=*/ValueRange(), newSizes, /*strides=*/ValueRange(), /*static_offsets=*/newStaticOffsets, diff --git a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp index ebcb951..e7cbee6 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithDialect.cpp @@ -64,7 +64,7 @@ Operation *arith::ArithDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto poison = dyn_cast<ub::PoisonAttr>(value)) - return builder.create<ub::PoisonOp>(loc, type, poison); + return ub::PoisonOp::create(builder, loc, type, poison); return ConstantOp::materialize(builder, value, type, loc); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 910334b..488c3c3 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2498,7 +2498,7 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) { matchPattern(adaptor.getFalseValue(), m_Zero())) return condition; - if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) { + if (auto cmp = condition.getDefiningOp<arith::CmpIOp>()) { auto pred = cmp.getPredicate(); if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) { auto cmpLhs = cmp.getLhs(); diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp index f2e7732..9199dcc 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.cpp @@ -67,8 +67,8 @@ struct SelectOpInterface return state.getMemrefWithUniqueOwnership(builder, value, 
value.getParentBlock()); - Value ownership = builder.create<arith::SelectOp>( - op->getLoc(), selectOp.getCondition(), + Value ownership = arith::SelectOp::create( + builder, op->getLoc(), selectOp.getCondition(), state.getOwnership(selectOp.getTrueValue(), block).getIndicator(), state.getOwnership(selectOp.getFalseValue(), block).getIndicator()); return {selectOp.getResult(), ownership}; diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index afee162..b073a31 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -170,10 +170,10 @@ struct SelectOpInterface return failure(); if (trueBuffer.getType() != *targetType) trueBuffer = - rewriter.create<memref::CastOp>(loc, *targetType, trueBuffer); + memref::CastOp::create(rewriter, loc, *targetType, trueBuffer); if (falseBuffer.getType() != *targetType) falseBuffer = - rewriter.create<memref::CastOp>(loc, *targetType, falseBuffer); + memref::CastOp::create(rewriter, loc, *targetType, falseBuffer); } replaceOpWithNewBufferizedOp<arith::SelectOp>( diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index f96bda6..93682a9 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -27,7 +27,7 @@ add_mlir_dialect_library(MLIRArithTransforms MLIRInferIntRangeInterface MLIRIR MLIRMemRefDialect - MLIRMeshDialect + MLIRShardDialect MLIRPass MLIRShardingInterface MLIRTensorDialect diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp index 55b757c..7626d35 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateUnsupportedFloats.cpp @@ -75,7 +75,7 @@ LogicalResult EmulateFloatPattern::matchAndRewrite( for (auto [res, oldType, newType] : llvm::zip_equal( MutableArrayRef{newResults}, op->getResultTypes(), resultTypes)) { if (oldType != newType) { - auto truncFOp = rewriter.create<arith::TruncFOp>(loc, oldType, res); + auto truncFOp = arith::TruncFOp::create(rewriter, loc, oldType, res); truncFOp.setFastmath(arith::FastMathFlags::contract); res = truncFOp.getResult(); } @@ -98,7 +98,7 @@ void mlir::arith::populateEmulateUnsupportedFloatsConversions( }); converter.addTargetMaterialization( [](OpBuilder &b, Type target, ValueRange input, Location loc) { - auto extFOp = b.create<arith::ExtFOp>(loc, target, input); + auto extFOp = arith::ExtFOp::create(b, loc, target, input); extFOp.setFastmath(arith::FastMathFlags::contract); return extFOp; }); diff --git a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp index d5d1559..efe6ad2 100644 --- a/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp @@ -72,7 +72,7 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, // Scalarize the result in case of 1D vectors. 
if (shape.size() == 1) - return rewriter.create<vector::ExtractOp>(loc, input, lastOffset); + return vector::ExtractOp::create(rewriter, loc, input, lastOffset); SmallVector<int64_t> offsets(shape.size(), 0); offsets.back() = lastOffset; @@ -80,8 +80,8 @@ static Value extractLastDimSlice(ConversionPatternRewriter &rewriter, sizes.back() = 1; SmallVector<int64_t> strides(shape.size(), 1); - return rewriter.create<vector::ExtractStridedSliceOp>(loc, input, offsets, - sizes, strides); + return vector::ExtractStridedSliceOp::create(rewriter, loc, input, offsets, + sizes, strides); } /// Extracts two vector slices from the `input` whose type is `vector<...x2T>`, @@ -107,7 +107,7 @@ static Value dropTrailingX1Dim(ConversionPatternRewriter &rewriter, assert(shape.back() == 1 && "Expected the last vector dim to be x1"); auto newVecTy = VectorType::get(shape.drop_back(), vecTy.getElementType()); - return rewriter.create<vector::ShapeCastOp>(loc, newVecTy, input); + return vector::ShapeCastOp::create(rewriter, loc, newVecTy, input); } /// Performs a vector shape cast to append an x1 dimension. If the @@ -122,7 +122,7 @@ static Value appendX1Dim(ConversionPatternRewriter &rewriter, Location loc, auto newShape = llvm::to_vector(vecTy.getShape()); newShape.push_back(1); auto newTy = VectorType::get(newShape, vecTy.getElementType()); - return rewriter.create<vector::ShapeCastOp>(loc, newTy, input); + return vector::ShapeCastOp::create(rewriter, loc, newTy, input); } /// Inserts the `source` vector slice into the `dest` vector at offset @@ -136,13 +136,13 @@ static Value insertLastDimSlice(ConversionPatternRewriter &rewriter, // Handle scalar source. if (isa<IntegerType>(source.getType())) - return rewriter.create<vector::InsertOp>(loc, source, dest, lastOffset); + return vector::InsertOp::create(rewriter, loc, source, dest, lastOffset); SmallVector<int64_t> offsets(shape.size(), 0); offsets.back() = lastOffset; SmallVector<int64_t> strides(shape.size(), 1); - return rewriter.create<vector::InsertStridedSliceOp>(loc, source, dest, - offsets, strides); + return vector::InsertStridedSliceOp::create(rewriter, loc, source, dest, + offsets, strides); } /// Constructs a new vector of type `resultType` by creating a series of @@ -254,12 +254,12 @@ struct ConvertAddI final : OpConversionPattern<arith::AddIOp> { extractLastDimHalves(rewriter, loc, adaptor.getRhs()); auto lowSum = - rewriter.create<arith::AddUIExtendedOp>(loc, lhsElem0, rhsElem0); + arith::AddUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0); Value overflowVal = - rewriter.create<arith::ExtUIOp>(loc, newElemTy, lowSum.getOverflow()); + arith::ExtUIOp::create(rewriter, loc, newElemTy, lowSum.getOverflow()); - Value high0 = rewriter.create<arith::AddIOp>(loc, overflowVal, lhsElem1); - Value high = rewriter.create<arith::AddIOp>(loc, high0, rhsElem1); + Value high0 = arith::AddIOp::create(rewriter, loc, overflowVal, lhsElem1); + Value high = arith::AddIOp::create(rewriter, loc, high0, rhsElem1); Value resultVec = constructResultVector(rewriter, loc, newTy, {lowSum.getSum(), high}); @@ -293,8 +293,8 @@ struct ConvertBitwiseBinary final : OpConversionPattern<BinaryOp> { auto [rhsElem0, rhsElem1] = extractLastDimHalves(rewriter, loc, adaptor.getRhs()); - Value resElem0 = rewriter.create<BinaryOp>(loc, lhsElem0, rhsElem0); - Value resElem1 = rewriter.create<BinaryOp>(loc, lhsElem1, rhsElem1); + Value resElem0 = BinaryOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value resElem1 = BinaryOp::create(rewriter, loc, lhsElem1, rhsElem1); Value 
resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); rewriter.replaceOp(op, resultVec); @@ -346,26 +346,26 @@ struct ConvertCmpI final : OpConversionPattern<arith::CmpIOp> { extractLastDimHalves(rewriter, loc, adaptor.getRhs()); Value lowCmp = - rewriter.create<arith::CmpIOp>(loc, lowPred, lhsElem0, rhsElem0); + arith::CmpIOp::create(rewriter, loc, lowPred, lhsElem0, rhsElem0); Value highCmp = - rewriter.create<arith::CmpIOp>(loc, highPred, lhsElem1, rhsElem1); + arith::CmpIOp::create(rewriter, loc, highPred, lhsElem1, rhsElem1); Value cmpResult{}; switch (highPred) { case arith::CmpIPredicate::eq: { - cmpResult = rewriter.create<arith::AndIOp>(loc, lowCmp, highCmp); + cmpResult = arith::AndIOp::create(rewriter, loc, lowCmp, highCmp); break; } case arith::CmpIPredicate::ne: { - cmpResult = rewriter.create<arith::OrIOp>(loc, lowCmp, highCmp); + cmpResult = arith::OrIOp::create(rewriter, loc, lowCmp, highCmp); break; } default: { // Handle inequality checks. - Value highEq = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1); + Value highEq = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, lhsElem1, rhsElem1); cmpResult = - rewriter.create<arith::SelectOp>(loc, highEq, lowCmp, highCmp); + arith::SelectOp::create(rewriter, loc, highEq, lowCmp, highCmp); break; } } @@ -401,14 +401,14 @@ struct ConvertMulI final : OpConversionPattern<arith::MulIOp> { // Multiplying two i2N integers produces (at most) an i4N result, but // because the calculation of top i2N is not necessary, we omit it. auto mulLowLow = - rewriter.create<arith::MulUIExtendedOp>(loc, lhsElem0, rhsElem0); - Value mulLowHi = rewriter.create<arith::MulIOp>(loc, lhsElem0, rhsElem1); - Value mulHiLow = rewriter.create<arith::MulIOp>(loc, lhsElem1, rhsElem0); + arith::MulUIExtendedOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value mulLowHi = arith::MulIOp::create(rewriter, loc, lhsElem0, rhsElem1); + Value mulHiLow = arith::MulIOp::create(rewriter, loc, lhsElem1, rhsElem0); Value resLow = mulLowLow.getLow(); Value resHi = - rewriter.create<arith::AddIOp>(loc, mulLowLow.getHigh(), mulLowHi); - resHi = rewriter.create<arith::AddIOp>(loc, resHi, mulHiLow); + arith::AddIOp::create(rewriter, loc, mulLowLow.getHigh(), mulLowHi); + resHi = arith::AddIOp::create(rewriter, loc, resHi, mulHiLow); Value resultVec = constructResultVector(rewriter, loc, newTy, {resLow, resHi}); @@ -443,10 +443,10 @@ struct ConvertExtSI final : OpConversionPattern<arith::ExtSIOp> { loc, newResultComponentTy, newOperand); Value operandZeroCst = createScalarOrSplatConstant(rewriter, loc, newResultComponentTy, 0); - Value signBit = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::slt, extended, operandZeroCst); + Value signBit = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, extended, operandZeroCst); Value signValue = - rewriter.create<arith::ExtSIOp>(loc, newResultComponentTy, signBit); + arith::ExtSIOp::create(rewriter, loc, newResultComponentTy, signBit); Value resultVec = constructResultVector(rewriter, loc, newTy, {extended, signValue}); @@ -508,7 +508,7 @@ struct ConvertMaxMin final : OpConversionPattern<SourceOp> { // Rewrite Max*I/Min*I as compare and select over original operands. Let // the CmpI and Select emulation patterns handle the final legalization. 
Value cmp = - rewriter.create<arith::CmpIOp>(loc, CmpPred, op.getLhs(), op.getRhs()); + arith::CmpIOp::create(rewriter, loc, CmpPred, op.getLhs(), op.getRhs()); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, op.getLhs(), op.getRhs()); return success(); @@ -587,7 +587,7 @@ struct ConvertIndexCastIndexToInt final : OpConversionPattern<CastOp> { // Sign or zero-extend the result. Let the matching conversion pattern // legalize the extension op. Value underlyingVal = - rewriter.create<CastOp>(loc, narrowTy, adaptor.getIn()); + CastOp::create(rewriter, loc, narrowTy, adaptor.getIn()); rewriter.replaceOpWithNewOp<ExtensionOp>(op, resultType, underlyingVal); return success(); } @@ -616,9 +616,9 @@ struct ConvertSelect final : OpConversionPattern<arith::SelectOp> { Value cond = appendX1Dim(rewriter, loc, adaptor.getCondition()); Value resElem0 = - rewriter.create<arith::SelectOp>(loc, cond, trueElem0, falseElem0); + arith::SelectOp::create(rewriter, loc, cond, trueElem0, falseElem0); Value resElem1 = - rewriter.create<arith::SelectOp>(loc, cond, trueElem1, falseElem1); + arith::SelectOp::create(rewriter, loc, cond, trueElem1, falseElem1); Value resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); rewriter.replaceOp(op, resultVec); @@ -680,33 +680,33 @@ struct ConvertShLI final : OpConversionPattern<arith::ShLIOp> { Value elemBitWidth = createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth); - Value illegalElemShift = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); + Value illegalElemShift = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); Value shiftedElem0 = - rewriter.create<arith::ShLIOp>(loc, lhsElem0, rhsElem0); - Value resElem0 = rewriter.create<arith::SelectOp>(loc, illegalElemShift, - zeroCst, shiftedElem0); + arith::ShLIOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value resElem0 = arith::SelectOp::create(rewriter, loc, illegalElemShift, + zeroCst, shiftedElem0); - Value cappedShiftAmount = rewriter.create<arith::SelectOp>( - loc, illegalElemShift, elemBitWidth, rhsElem0); + Value cappedShiftAmount = arith::SelectOp::create( + rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0); Value rightShiftAmount = - rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount); + arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount); Value shiftedRight = - rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rightShiftAmount); + arith::ShRUIOp::create(rewriter, loc, lhsElem0, rightShiftAmount); Value overshotShiftAmount = - rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth); + arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth); Value shiftedLeft = - rewriter.create<arith::ShLIOp>(loc, lhsElem0, overshotShiftAmount); + arith::ShLIOp::create(rewriter, loc, lhsElem0, overshotShiftAmount); Value shiftedElem1 = - rewriter.create<arith::ShLIOp>(loc, lhsElem1, rhsElem0); - Value resElem1High = rewriter.create<arith::SelectOp>( - loc, illegalElemShift, zeroCst, shiftedElem1); - Value resElem1Low = rewriter.create<arith::SelectOp>( - loc, illegalElemShift, shiftedLeft, shiftedRight); + arith::ShLIOp::create(rewriter, loc, lhsElem1, rhsElem0); + Value resElem1High = arith::SelectOp::create( + rewriter, loc, illegalElemShift, zeroCst, shiftedElem1); + Value resElem1Low = arith::SelectOp::create(rewriter, loc, illegalElemShift, + shiftedLeft, shiftedRight); Value resElem1 = - rewriter.create<arith::OrIOp>(loc, resElem1Low, 
resElem1High); + arith::OrIOp::create(rewriter, loc, resElem1Low, resElem1High); Value resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); @@ -769,33 +769,33 @@ struct ConvertShRUI final : OpConversionPattern<arith::ShRUIOp> { Value elemBitWidth = createScalarOrSplatConstant(rewriter, loc, newOperandTy, newBitWidth); - Value illegalElemShift = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); + Value illegalElemShift = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::uge, rhsElem0, elemBitWidth); Value shiftedElem0 = - rewriter.create<arith::ShRUIOp>(loc, lhsElem0, rhsElem0); - Value resElem0Low = rewriter.create<arith::SelectOp>(loc, illegalElemShift, - zeroCst, shiftedElem0); + arith::ShRUIOp::create(rewriter, loc, lhsElem0, rhsElem0); + Value resElem0Low = arith::SelectOp::create(rewriter, loc, illegalElemShift, + zeroCst, shiftedElem0); Value shiftedElem1 = - rewriter.create<arith::ShRUIOp>(loc, lhsElem1, rhsElem0); - Value resElem1 = rewriter.create<arith::SelectOp>(loc, illegalElemShift, - zeroCst, shiftedElem1); + arith::ShRUIOp::create(rewriter, loc, lhsElem1, rhsElem0); + Value resElem1 = arith::SelectOp::create(rewriter, loc, illegalElemShift, + zeroCst, shiftedElem1); - Value cappedShiftAmount = rewriter.create<arith::SelectOp>( - loc, illegalElemShift, elemBitWidth, rhsElem0); + Value cappedShiftAmount = arith::SelectOp::create( + rewriter, loc, illegalElemShift, elemBitWidth, rhsElem0); Value leftShiftAmount = - rewriter.create<arith::SubIOp>(loc, elemBitWidth, cappedShiftAmount); + arith::SubIOp::create(rewriter, loc, elemBitWidth, cappedShiftAmount); Value shiftedLeft = - rewriter.create<arith::ShLIOp>(loc, lhsElem1, leftShiftAmount); + arith::ShLIOp::create(rewriter, loc, lhsElem1, leftShiftAmount); Value overshotShiftAmount = - rewriter.create<arith::SubIOp>(loc, rhsElem0, elemBitWidth); + arith::SubIOp::create(rewriter, loc, rhsElem0, elemBitWidth); Value shiftedRight = - rewriter.create<arith::ShRUIOp>(loc, lhsElem1, overshotShiftAmount); + arith::ShRUIOp::create(rewriter, loc, lhsElem1, overshotShiftAmount); - Value resElem0High = rewriter.create<arith::SelectOp>( - loc, illegalElemShift, shiftedRight, shiftedLeft); + Value resElem0High = arith::SelectOp::create( + rewriter, loc, illegalElemShift, shiftedRight, shiftedLeft); Value resElem0 = - rewriter.create<arith::OrIOp>(loc, resElem0Low, resElem0High); + arith::OrIOp::create(rewriter, loc, resElem0Low, resElem0High); Value resultVec = constructResultVector(rewriter, loc, newTy, {resElem0, resElem1}); @@ -832,33 +832,33 @@ struct ConvertShRSI final : OpConversionPattern<arith::ShRSIOp> { // Perform as many ops over the narrow integer type as possible and let the // other emulation patterns convert the rest. Value elemZero = createScalarOrSplatConstant(rewriter, loc, narrowTy, 0); - Value signBit = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::slt, lhsElem1, elemZero); + Value signBit = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, lhsElem1, elemZero); signBit = dropTrailingX1Dim(rewriter, loc, signBit); // Create a bit pattern of either all ones or all zeros. Then shift it left // to calculate the sign extension bits created by shifting the original // sign bit right. 
- Value allSign = rewriter.create<arith::ExtSIOp>(loc, oldTy, signBit); + Value allSign = arith::ExtSIOp::create(rewriter, loc, oldTy, signBit); Value maxShift = createScalarOrSplatConstant(rewriter, loc, narrowTy, origBitwidth); Value numNonSignExtBits = - rewriter.create<arith::SubIOp>(loc, maxShift, rhsElem0); + arith::SubIOp::create(rewriter, loc, maxShift, rhsElem0); numNonSignExtBits = dropTrailingX1Dim(rewriter, loc, numNonSignExtBits); numNonSignExtBits = - rewriter.create<arith::ExtUIOp>(loc, oldTy, numNonSignExtBits); + arith::ExtUIOp::create(rewriter, loc, oldTy, numNonSignExtBits); Value signBits = - rewriter.create<arith::ShLIOp>(loc, allSign, numNonSignExtBits); + arith::ShLIOp::create(rewriter, loc, allSign, numNonSignExtBits); // Use original arguments to create the right shift. Value shrui = - rewriter.create<arith::ShRUIOp>(loc, op.getLhs(), op.getRhs()); - Value shrsi = rewriter.create<arith::OrIOp>(loc, shrui, signBits); + arith::ShRUIOp::create(rewriter, loc, op.getLhs(), op.getRhs()); + Value shrsi = arith::OrIOp::create(rewriter, loc, shrui, signBits); // Handle shifting by zero. This is necessary when the `signBits` shift is // invalid. - Value isNoop = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, - rhsElem0, elemZero); + Value isNoop = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, rhsElem0, elemZero); isNoop = dropTrailingX1Dim(rewriter, loc, isNoop); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNoop, op.getLhs(), shrsi); @@ -892,14 +892,14 @@ struct ConvertSubI final : OpConversionPattern<arith::SubIOp> { // Emulates LHS - RHS by [LHS0 - RHS0, LHS1 - RHS1 - CARRY] where // CARRY is 1 or 0. - Value low = rewriter.create<arith::SubIOp>(loc, lhsElem0, rhsElem0); + Value low = arith::SubIOp::create(rewriter, loc, lhsElem0, rhsElem0); // We have a carry if lhsElem0 < rhsElem0. - Value carry0 = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0); - Value carryVal = rewriter.create<arith::ExtUIOp>(loc, newElemTy, carry0); + Value carry0 = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ult, lhsElem0, rhsElem0); + Value carryVal = arith::ExtUIOp::create(rewriter, loc, newElemTy, carry0); - Value high0 = rewriter.create<arith::SubIOp>(loc, lhsElem1, carryVal); - Value high = rewriter.create<arith::SubIOp>(loc, high0, rhsElem1); + Value high0 = arith::SubIOp::create(rewriter, loc, lhsElem1, carryVal); + Value high = arith::SubIOp::create(rewriter, loc, high0, rhsElem1); Value resultVec = constructResultVector(rewriter, loc, newTy, {low, high}); rewriter.replaceOp(op, resultVec); @@ -933,13 +933,13 @@ struct ConvertSIToFP final : OpConversionPattern<arith::SIToFPOp> { // result or not based on that sign bit. We implement negation by // subtracting from zero. Note that this relies on the other conversion // patterns to legalize created ops and narrow the bit widths. 
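// (Illustrative walk-through: sitofp(-5) computes isNeg = true, abs = 5,
// converts to 5.0 via uitofp, and the final select picks negf(5.0) = -5.0.)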
- Value isNeg = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, - in, zeroCst); - Value neg = rewriter.create<arith::SubIOp>(loc, zeroCst, in); - Value abs = rewriter.create<arith::SelectOp>(loc, isNeg, neg, in); + Value isNeg = arith::CmpIOp::create(rewriter, loc, + arith::CmpIPredicate::slt, in, zeroCst); + Value neg = arith::SubIOp::create(rewriter, loc, zeroCst, in); + Value abs = arith::SelectOp::create(rewriter, loc, isNeg, neg, in); - Value absResult = rewriter.create<arith::UIToFPOp>(loc, op.getType(), abs); - Value negResult = rewriter.create<arith::NegFOp>(loc, absResult); + Value absResult = arith::UIToFPOp::create(rewriter, loc, op.getType(), abs); + Value negResult = arith::NegFOp::create(rewriter, loc, absResult); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, negResult, absResult); return success(); @@ -985,13 +985,13 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> { // // Note 2: We do not strictly need the `hi == 0` case, but it makes // constant folding easier. - Value hiEqZero = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, hiInt, zeroCst); + Value hiEqZero = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, hiInt, zeroCst); Type resultTy = op.getType(); Type resultElemTy = getElementTypeOrSelf(resultTy); - Value lowFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, lowInt); - Value hiFp = rewriter.create<arith::UIToFPOp>(loc, resultTy, hiInt); + Value lowFp = arith::UIToFPOp::create(rewriter, loc, resultTy, lowInt); + Value hiFp = arith::UIToFPOp::create(rewriter, loc, resultTy, hiInt); int64_t pow2Int = int64_t(1) << newBitWidth; TypedAttr pow2Attr = @@ -999,10 +999,11 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> { if (auto vecTy = dyn_cast<VectorType>(resultTy)) pow2Attr = SplatElementsAttr::get(vecTy, pow2Attr); - Value pow2Val = rewriter.create<arith::ConstantOp>(loc, resultTy, pow2Attr); + Value pow2Val = + arith::ConstantOp::create(rewriter, loc, resultTy, pow2Attr); - Value hiVal = rewriter.create<arith::MulFOp>(loc, hiFp, pow2Val); - Value result = rewriter.create<arith::AddFOp>(loc, lowFp, hiVal); + Value hiVal = arith::MulFOp::create(rewriter, loc, hiFp, pow2Val); + Value result = arith::AddFOp::create(rewriter, loc, lowFp, hiVal); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, hiEqZero, lowFp, result); return success(); @@ -1037,22 +1038,22 @@ struct ConvertFPToSI final : OpConversionPattern<arith::FPToSIOp> { // result is UB. TypedAttr zeroAttr = rewriter.getZeroAttr(fpTy); - Value zeroCst = rewriter.create<arith::ConstantOp>(loc, zeroAttr); + Value zeroCst = arith::ConstantOp::create(rewriter, loc, zeroAttr); Value zeroCstInt = createScalarOrSplatConstant(rewriter, loc, intTy, 0); // Get the absolute value. One could have used math.absf here, but that // introduces an extra dependency. - Value isNeg = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, - inFp, zeroCst); - Value negInFp = rewriter.create<arith::NegFOp>(loc, inFp); + Value isNeg = arith::CmpFOp::create( + rewriter, loc, arith::CmpFPredicate::OLT, inFp, zeroCst); + Value negInFp = arith::NegFOp::create(rewriter, loc, inFp); - Value absVal = rewriter.create<arith::SelectOp>(loc, isNeg, negInFp, inFp); + Value absVal = arith::SelectOp::create(rewriter, loc, isNeg, negInFp, inFp); // Defer the absolute value to fptoui. 
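// (Illustrative walk-through: fptosi(-3.7) takes abs = 3.7, fptoui gives 3,
// and the select below restores the sign to -3, i.e. truncation toward
// zero.)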
- Value res = rewriter.create<arith::FPToUIOp>(loc, intTy, absVal); + Value res = arith::FPToUIOp::create(rewriter, loc, intTy, absVal); // Negate the value if < 0. - Value neg = rewriter.create<arith::SubIOp>(loc, zeroCstInt, res); + Value neg = arith::SubIOp::create(rewriter, loc, zeroCstInt, res); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNeg, neg, res); return success(); @@ -1109,17 +1110,17 @@ struct ConvertFPToUI final : OpConversionPattern<arith::FPToUIOp> { if (auto vecType = dyn_cast<VectorType>(fpTy)) powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr); Value powBitwidthFloatCst = - rewriter.create<arith::ConstantOp>(loc, powBitwidthAttr); + arith::ConstantOp::create(rewriter, loc, powBitwidthAttr); Value fpDivPowBitwidth = - rewriter.create<arith::DivFOp>(loc, inFp, powBitwidthFloatCst); + arith::DivFOp::create(rewriter, loc, inFp, powBitwidthFloatCst); Value resHigh = - rewriter.create<arith::FPToUIOp>(loc, newHalfType, fpDivPowBitwidth); + arith::FPToUIOp::create(rewriter, loc, newHalfType, fpDivPowBitwidth); // Calculate fp - resHigh * 2^N by getting the remainder of the division. Value remainder = - rewriter.create<arith::RemFOp>(loc, inFp, powBitwidthFloatCst); + arith::RemFOp::create(rewriter, loc, inFp, powBitwidthFloatCst); Value resLow = - rewriter.create<arith::FPToUIOp>(loc, newHalfType, remainder); + arith::FPToUIOp::create(rewriter, loc, newHalfType, remainder); Value high = appendX1Dim(rewriter, loc, resHigh); Value low = appendX1Dim(rewriter, loc, resLow); diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index e842f44..f8fa35c 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -28,10 +28,10 @@ static Value createConst(Location loc, Type type, int value, PatternRewriter &rewriter) { auto attr = rewriter.getIntegerAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast<ShapedType>(type)) { - return rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(shapedTy, attr)); + return arith::ConstantOp::create(rewriter, loc, + DenseElementsAttr::get(shapedTy, attr)); } - return rewriter.create<arith::ConstantOp>(loc, attr); + return arith::ConstantOp::create(rewriter, loc, attr); } /// Create a float constant. 
@@ -39,11 +39,11 @@ static Value createFloatConst(Location loc, Type type, APFloat value, PatternRewriter &rewriter) { auto attr = rewriter.getFloatAttr(getElementTypeOrSelf(type), value); if (auto shapedTy = dyn_cast<ShapedType>(type)) { - return rewriter.create<arith::ConstantOp>( - loc, DenseElementsAttr::get(shapedTy, attr)); + return arith::ConstantOp::create(rewriter, loc, + DenseElementsAttr::get(shapedTy, attr)); } - return rewriter.create<arith::ConstantOp>(loc, attr); + return arith::ConstantOp::create(rewriter, loc, attr); } /// Creates shapedType using shape from cloneFrom and base type from cloneTo @@ -67,11 +67,11 @@ struct CeilDivUIOpConverter : public OpRewritePattern<arith::CeilDivUIOp> { Value b = op.getRhs(); Value zero = createConst(loc, a.getType(), 0, rewriter); Value compare = - rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, a, zero); + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq, a, zero); Value one = createConst(loc, a.getType(), 1, rewriter); - Value minusOne = rewriter.create<arith::SubIOp>(loc, a, one); - Value quotient = rewriter.create<arith::DivUIOp>(loc, minusOne, b); - Value plusOne = rewriter.create<arith::AddIOp>(loc, quotient, one); + Value minusOne = arith::SubIOp::create(rewriter, loc, a, one); + Value quotient = arith::DivUIOp::create(rewriter, loc, minusOne, b); + Value plusOne = arith::AddIOp::create(rewriter, loc, quotient, one); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, compare, zero, plusOne); return success(); } @@ -96,22 +96,22 @@ struct CeilDivSIOpConverter : public OpRewritePattern<arith::CeilDivSIOp> { Value zero = createConst(loc, type, 0, rewriter); Value one = createConst(loc, type, 1, rewriter); - Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b); - Value product = rewriter.create<arith::MulIOp>(loc, quotient, b); - Value notEqualDivisor = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ne, a, product); + Value quotient = arith::DivSIOp::create(rewriter, loc, a, b); + Value product = arith::MulIOp::create(rewriter, loc, quotient, b); + Value notEqualDivisor = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ne, a, product); - Value aNeg = - rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); - Value bNeg = - rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); + Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + a, zero); + Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + b, zero); - Value signEqual = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, aNeg, bNeg); + Value signEqual = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::eq, aNeg, bNeg); Value cond = - rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signEqual); + arith::AndIOp::create(rewriter, loc, notEqualDivisor, signEqual); - Value quotientPlusOne = rewriter.create<arith::AddIOp>(loc, quotient, one); + Value quotientPlusOne = arith::AddIOp::create(rewriter, loc, quotient, one); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientPlusOne, quotient); @@ -135,25 +135,25 @@ struct FloorDivSIOpConverter : public OpRewritePattern<arith::FloorDivSIOp> { Value a = op.getLhs(); Value b = op.getRhs(); - Value quotient = rewriter.create<arith::DivSIOp>(loc, a, b); - Value product = rewriter.create<arith::MulIOp>(loc, quotient, b); - Value notEqualDivisor = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ne, a, product); + Value quotient = 
arith::DivSIOp::create(rewriter, loc, a, b); + Value product = arith::MulIOp::create(rewriter, loc, quotient, b); + Value notEqualDivisor = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ne, a, product); Value zero = createConst(loc, type, 0, rewriter); - Value aNeg = - rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, a, zero); - Value bNeg = - rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, b, zero); + Value aNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + a, zero); + Value bNeg = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, + b, zero); - Value signOpposite = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ne, aNeg, bNeg); + Value signOpposite = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::ne, aNeg, bNeg); Value cond = - rewriter.create<arith::AndIOp>(loc, notEqualDivisor, signOpposite); + arith::AndIOp::create(rewriter, loc, notEqualDivisor, signOpposite); Value minusOne = createConst(loc, type, -1, rewriter); Value quotientMinusOne = - rewriter.create<arith::AddIOp>(loc, quotient, minusOne); + arith::AddIOp::create(rewriter, loc, quotient, minusOne); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cond, quotientMinusOne, quotient); @@ -171,7 +171,7 @@ public: Value lhs = op.getLhs(); Value rhs = op.getRhs(); - Value cmp = rewriter.create<arith::CmpIOp>(op.getLoc(), pred, lhs, rhs); + Value cmp = arith::CmpIOp::create(rewriter, op.getLoc(), pred, lhs, rhs); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmp, lhs, rhs); return success(); } @@ -192,12 +192,12 @@ public: static_assert(pred == arith::CmpFPredicate::UGT || pred == arith::CmpFPredicate::ULT, "pred must be either UGT or ULT"); - Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); - Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs); + Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs); + Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs); // Handle the case where rhs is NaN: 'isNaN(rhs) ? rhs : select'. - Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, - rhs, rhs); + Value isNaN = arith::CmpFOp::create(rewriter, loc, + arith::CmpFPredicate::UNO, rhs, rhs); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select); return success(); } @@ -218,12 +218,12 @@ public: static_assert(pred == arith::CmpFPredicate::UGT || pred == arith::CmpFPredicate::ULT, "pred must be either UGT or ULT"); - Value cmp = rewriter.create<arith::CmpFOp>(loc, pred, lhs, rhs); - Value select = rewriter.create<arith::SelectOp>(loc, cmp, lhs, rhs); + Value cmp = arith::CmpFOp::create(rewriter, loc, pred, lhs, rhs); + Value select = arith::SelectOp::create(rewriter, loc, cmp, lhs, rhs); // Handle the case where lhs is NaN: 'isNaN(lhs) ? rhs : select'. 
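// (cmpf UNO with both operands equal to the value under test is the usual
// NaN check: it holds iff the value is unordered with itself.)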
- Value isNaN = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UNO, - lhs, lhs); + Value isNaN = arith::CmpFOp::create(rewriter, loc, + arith::CmpFPredicate::UNO, lhs, lhs); rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNaN, rhs, select); return success(); } @@ -247,12 +247,12 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> { Type i16Ty = cloneToShapedType(operandTy, b.getI16Type()); Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); - Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand); - Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast); + Value bitcast = arith::BitcastOp::create(b, i16Ty, operand); + Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast); Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); - Value shl = b.create<arith::ShLIOp>(exti, c16); - Value result = b.create<arith::BitcastOp>(resultTy, shl); + Value shl = arith::ShLIOp::create(b, exti, c16); + Value result = arith::BitcastOp::create(b, resultTy, shl); rewriter.replaceOp(op, result); return success(); @@ -296,7 +296,7 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { // exponent bits, that simple truncation is the desired outcome for // infinities. Value isNan = - b.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, operand, operand); + arith::CmpFOp::create(b, arith::CmpFPredicate::UNE, operand, operand); // Constant used to make the rounding bias. Value c7FFF = createConst(op.getLoc(), i32Ty, 0x7fff, rewriter); // Constant used to generate a quiet NaN. @@ -305,30 +305,30 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { Value c16 = createConst(op.getLoc(), i32Ty, 16, rewriter); Value c1 = createConst(op.getLoc(), i32Ty, 1, rewriter); // Reinterpret the input f32 value as bits. - Value bitcast = b.create<arith::BitcastOp>(i32Ty, operand); + Value bitcast = arith::BitcastOp::create(b, i32Ty, operand); // Read bit 16 as a value in {0,1}. Value bit16 = - b.create<arith::AndIOp>(b.create<arith::ShRUIOp>(bitcast, c16), c1); + arith::AndIOp::create(b, arith::ShRUIOp::create(b, bitcast, c16), c1); // Determine the rounding bias to add as either 0x7fff or 0x8000 depending // on bit 16, implementing the tie-breaking "to nearest even". - Value roundingBias = b.create<arith::AddIOp>(bit16, c7FFF); + Value roundingBias = arith::AddIOp::create(b, bit16, c7FFF); // Add the rounding bias. Generally we want this to be added to the // mantissa, but nothing prevents it from carrying into the exponent // bits, which would feel like a bug, but this is the magic trick here: // when that happens, the mantissa gets reset to zero and the exponent // gets incremented by the carry... which is actually exactly what we // want. - Value biased = b.create<arith::AddIOp>(bitcast, roundingBias); + Value biased = arith::AddIOp::create(b, bitcast, roundingBias); // Now that the rounding-bias has been added, truncating the low bits // yields the correctly rounded result. - Value biasedAndShifted = b.create<arith::ShRUIOp>(biased, c16); + Value biasedAndShifted = arith::ShRUIOp::create(b, biased, c16); Value normalCaseResultI16 = - b.create<arith::TruncIOp>(i16Ty, biasedAndShifted); + arith::TruncIOp::create(b, i16Ty, biasedAndShifted); // Select either the above-computed result, or a quiet NaN constant // if the input was NaN. 
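// (Editorial gloss: 0x7fc0 read as bf16 has all exponent bits set and the
// mantissa MSB set, the canonical quiet NaN produced for NaN inputs.)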
Value select = - b.create<arith::SelectOp>(isNan, c7FC0I16, normalCaseResultI16); - Value result = b.create<arith::BitcastOp>(resultTy, select); + arith::SelectOp::create(b, isNan, c7FC0I16, normalCaseResultI16); + Value result = arith::BitcastOp::create(b, resultTy, select); rewriter.replaceOp(op, result); return success(); } @@ -381,7 +381,7 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> { Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); Type i4Ty = cloneToShapedType(operandTy, b.getI4Type()); Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); - Value i4Bits = b.create<arith::BitcastOp>(i4Ty, operand); + Value i4Bits = arith::BitcastOp::create(b, i4Ty, operand); Value c0x0 = createConst(loc, i4Ty, 0x0, rewriter); Value c0x1 = createConst(loc, i4Ty, 0x1, rewriter); @@ -390,38 +390,39 @@ struct F4E2M1ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> { // Set last Exponent bit and Mantissa. Value c0x00000014 = createConst(loc, i32Ty, 0x14, rewriter); - Value bits1To24 = b.create<arith::ShLIOp>(i4Bits, c0x2); + Value bits1To24 = arith::ShLIOp::create(b, i4Bits, c0x2); Value isHalf = - b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x1); - bits1To24 = b.create<arith::SelectOp>(isHalf, c0x0, bits1To24); - bits1To24 = b.create<arith::ExtUIOp>(i32Ty, bits1To24); - bits1To24 = b.create<arith::ShLIOp>(bits1To24, c0x00000014); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x1); + bits1To24 = arith::SelectOp::create(b, isHalf, c0x0, bits1To24); + bits1To24 = arith::ExtUIOp::create(b, i32Ty, bits1To24); + bits1To24 = arith::ShLIOp::create(b, bits1To24, c0x00000014); // Set first 7 bits of Exponent. Value zeroExpBits = createConst(loc, i32Ty, 0x00000000, rewriter); Value highExpBits = createConst(loc, i32Ty, 0x40000000, rewriter); Value lowExpBits = createConst(loc, i32Ty, 0x3f000000, rewriter); Value useLargerExp = - b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x4); + arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x4); Value bits25To31 = - b.create<arith::SelectOp>(useLargerExp, highExpBits, lowExpBits); + arith::SelectOp::create(b, useLargerExp, highExpBits, lowExpBits); Value zeroExp = - b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, i4Bits, c0x0); - bits25To31 = b.create<arith::SelectOp>(zeroExp, zeroExpBits, bits25To31); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, i4Bits, c0x0); + bits25To31 = arith::SelectOp::create(b, zeroExp, zeroExpBits, bits25To31); // Set sign. Value c0x80000000 = createConst(loc, i32Ty, 0x80000000, rewriter); Value c0x8 = createConst(loc, i4Ty, 0x8, rewriter); Value negative = - b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, i4Bits, c0x8); - Value bit32 = b.create<arith::SelectOp>(negative, c0x80000000, zeroExpBits); + arith::CmpIOp::create(b, arith::CmpIPredicate::uge, i4Bits, c0x8); + Value bit32 = + arith::SelectOp::create(b, negative, c0x80000000, zeroExpBits); // Add segments together. 
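// (For reference: F4E2M1 packs 1 sign, 2 exponent and 1 mantissa bit, so
// the representable magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6; the three
// segments added below are the mantissa/low-exponent bits, the upper
// exponent bits, and the sign of the f32 being assembled.)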
- Value bits1To31 = b.create<arith::AddIOp>(bits1To24, bits25To31); - Value bits1To32 = b.create<arith::AddIOp>(bits1To31, bit32); - Value result = b.create<arith::BitcastOp>(f32Ty, bits1To32); + Value bits1To31 = arith::AddIOp::create(b, bits1To24, bits25To31); + Value bits1To32 = arith::AddIOp::create(b, bits1To31, bit32); + Value result = arith::BitcastOp::create(b, f32Ty, bits1To32); if (!isa<Float32Type>(resultETy)) - result = b.create<arith::TruncFOp>(resultTy, result); + result = arith::TruncFOp::create(b, resultTy, result); rewriter.replaceOp(op, result); return success(); @@ -447,25 +448,25 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> { Type i32Ty = cloneToShapedType(operandTy, b.getI32Type()); Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); - Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand); + Value bitcast = arith::BitcastOp::create(b, i8Ty, operand); // create constants for NaNs Value cF8NaN = createConst(op.getLoc(), i8Ty, 0xff, rewriter); Value cF32NaN = createConst(op.getLoc(), i32Ty, 0xffffffff, rewriter); Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); - Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast); - Value f32Bits = b.create<arith::ShLIOp>(exti, cF32MantissaWidth); + Value exti = arith::ExtUIOp::create(b, i32Ty, bitcast); + Value f32Bits = arith::ShLIOp::create(b, exti, cF32MantissaWidth); Value isNan = - b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, bitcast, cF8NaN); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, bitcast, cF8NaN); // select for NaNs - f32Bits = b.create<arith::SelectOp>(isNan, cF32NaN, f32Bits); - Value result = b.create<arith::BitcastOp>(f32Ty, f32Bits); + f32Bits = arith::SelectOp::create(b, isNan, cF32NaN, f32Bits); + Value result = arith::BitcastOp::create(b, f32Ty, f32Bits); if (resultETy.getIntOrFloatBitWidth() < 32) { - result = b.create<arith::TruncFOp>(resultTy, result, nullptr, - op.getFastmathAttr()); + result = arith::TruncFOp::create(b, resultTy, result, nullptr, + op.getFastmathAttr()); } else if (resultETy.getIntOrFloatBitWidth() > 32) { - result = b.create<arith::ExtFOp>(resultTy, result, op.getFastmathAttr()); + result = arith::ExtFOp::create(b, resultTy, result, op.getFastmathAttr()); } rewriter.replaceOp(op, result); return success(); @@ -520,7 +521,7 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { if (!isa<Float4E2M1FNType>(resultETy)) return rewriter.notifyMatchFailure(op, "not a trunc of F4E2M1FN"); if (!isa<Float32Type>(operandETy)) - operand = b.create<arith::ExtFOp>(f32Ty, operand); + operand = arith::ExtFOp::create(b, f32Ty, operand); Value c0x1 = createConst(loc, i4Ty, 1, rewriter); Value c0x3 = createConst(loc, i4Ty, 3, rewriter); @@ -532,65 +533,65 @@ struct F4E2M1TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { // Step 0: Clamp to bounds. Value cHigherBound = createFloatConst(loc, f32Ty, APFloat(6.0f), rewriter); Value cLowerBound = createFloatConst(loc, f32Ty, APFloat(-6.0f), rewriter); - Value operandClamped = b.create<arith::MinNumFOp>(cHigherBound, operand); - operandClamped = b.create<arith::MaxNumFOp>(cLowerBound, operandClamped); - Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operandClamped); + Value operandClamped = arith::MinNumFOp::create(b, cHigherBound, operand); + operandClamped = arith::MaxNumFOp::create(b, cLowerBound, operandClamped); + Value f32Bits = arith::BitcastOp::create(b, i32Ty, operandClamped); // Step 1: Set sign bit. 
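// (Shifting the f32 bits right by 31 leaves the sign in bit 0; the
// truncation to i4 and the shift left by 3 below place it at the F4 sign
// position.)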
Value cF32ExpManWidth = createConst(loc, i32Ty, 31, rewriter); // 31 - Value f32Sign = b.create<arith::ShRUIOp>(f32Bits, cF32ExpManWidth); - Value f4Sign = b.create<arith::TruncIOp>(i4Ty, f32Sign); - Value f4Bits = b.create<arith::ShLIOp>(f4Sign, c0x3); + Value f32Sign = arith::ShRUIOp::create(b, f32Bits, cF32ExpManWidth); + Value f4Sign = arith::TruncIOp::create(b, i4Ty, f32Sign); + Value f4Bits = arith::ShLIOp::create(b, f4Sign, c0x3); // Step 2: Convert exponent by adjusting bias. Value biasAdjustment = createConst(loc, i32Ty, 0x7e, rewriter); Value cF4MantissaWidth = c0x1; // 1 Value cF32MantissaWidth = createConst(loc, i32Ty, 23, rewriter); // 23 - Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth); + Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth); Value biasAdjustedSignExp = - b.create<arith::SubIOp>(f32SignExp, biasAdjustment); - Value f4Exp = b.create<arith::TruncIOp>(i4Ty, biasAdjustedSignExp); - f4Exp = b.create<arith::ShLIOp>(f4Exp, cF4MantissaWidth); - f4Bits = b.create<arith::AddIOp>(f4Bits, f4Exp); + arith::SubIOp::create(b, f32SignExp, biasAdjustment); + Value f4Exp = arith::TruncIOp::create(b, i4Ty, biasAdjustedSignExp); + f4Exp = arith::ShLIOp::create(b, f4Exp, cF4MantissaWidth); + f4Bits = arith::AddIOp::create(b, f4Bits, f4Exp); // Step 3: Set mantissa to first bit. Value cF32FirstBitMask = createConst(loc, i32Ty, 0x400000, rewriter); - Value man1Bit = b.create<arith::AndIOp>(f32Bits, cF32FirstBitMask); - man1Bit = b.create<arith::ShRUIOp>(man1Bit, c0x00000016); - Value f4Man = b.create<arith::TruncIOp>(i4Ty, man1Bit); - f4Bits = b.create<arith::AddIOp>(f4Bits, f4Man); + Value man1Bit = arith::AndIOp::create(b, f32Bits, cF32FirstBitMask); + man1Bit = arith::ShRUIOp::create(b, man1Bit, c0x00000016); + Value f4Man = arith::TruncIOp::create(b, i4Ty, man1Bit); + f4Bits = arith::AddIOp::create(b, f4Bits, f4Man); // Step 4: Special consideration for conversion to 0.5. 
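// (A reading of the intent: f32 magnitudes in (0.25, 0.5) have biased
// exponent 125, i.e. -1 after the 0x7e adjustment, and lie closer to 0.5
// than to 0, so they are forced to the F4 encoding of 0.5 instead of being
// left to underflow.)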
Value cF32MantissaMask = createConst(loc, i32Ty, 0x7fffff, rewriter); - Value f8Exp = b.create<arith::TruncIOp>(i8Ty, biasAdjustedSignExp); + Value f8Exp = arith::TruncIOp::create(b, i8Ty, biasAdjustedSignExp); Value isSubnormal = - b.create<arith::CmpIOp>(arith::CmpIPredicate::sle, f8Exp, c0x00); + arith::CmpIOp::create(b, arith::CmpIPredicate::sle, f8Exp, c0x00); Value isNegOneExp = - b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0xff); - Value man23Bits = b.create<arith::AndIOp>(f32Bits, cF32MantissaMask); - Value isNonZeroMan = b.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, - man23Bits, zeroExpBits); - Value roundToHalf = b.create<arith::AndIOp>(isNegOneExp, isNonZeroMan); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0xff); + Value man23Bits = arith::AndIOp::create(b, f32Bits, cF32MantissaMask); + Value isNonZeroMan = arith::CmpIOp::create(b, arith::CmpIPredicate::ugt, + man23Bits, zeroExpBits); + Value roundToHalf = arith::AndIOp::create(b, isNegOneExp, isNonZeroMan); Value isZeroExp = - b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, f8Exp, c0x00); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, f8Exp, c0x00); Value subnormalF4Bits = createConst(loc, i4Ty, 0xf, rewriter); Value halfF4Bits = createConst(loc, i4Ty, 0x0, rewriter); Value subResult = - b.create<arith::SelectOp>(isSubnormal, subnormalF4Bits, f4Bits); - subResult = b.create<arith::SelectOp>(roundToHalf, halfF4Bits, subResult); - f4Bits = b.create<arith::SelectOp>(isZeroExp, f4Bits, subResult); + arith::SelectOp::create(b, isSubnormal, subnormalF4Bits, f4Bits); + subResult = arith::SelectOp::create(b, roundToHalf, halfF4Bits, subResult); + f4Bits = arith::SelectOp::create(b, isZeroExp, f4Bits, subResult); // Step 5: Round up if necessary. Value cF32Last22BitMask = createConst(loc, i32Ty, 0x3fffff, rewriter); Value cRound = createConst(loc, i32Ty, 0x200000, rewriter); // 010 0000... 
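// (0x200000 is bit 21, the midpoint of the 22 trailing bits kept by the
// 0x3fffff mask; the uge comparison below therefore rounds the discarded
// tail half-up.)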
- Value man22Bits = b.create<arith::AndIOp>(f32Bits, cF32Last22BitMask); + Value man22Bits = arith::AndIOp::create(b, f32Bits, cF32Last22BitMask); Value shouldRound = - b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, man22Bits, cRound); - shouldRound = b.create<arith::OrIOp>(shouldRound, isSubnormal); - Value roundedF4Bits = b.create<arith::AddIOp>(f4Bits, c0x1); - f4Bits = b.create<arith::SelectOp>(shouldRound, roundedF4Bits, f4Bits); + arith::CmpIOp::create(b, arith::CmpIPredicate::uge, man22Bits, cRound); + shouldRound = arith::OrIOp::create(b, shouldRound, isSubnormal); + Value roundedF4Bits = arith::AddIOp::create(b, f4Bits, c0x1); + f4Bits = arith::SelectOp::create(b, shouldRound, roundedF4Bits, f4Bits); - Value result = b.create<arith::BitcastOp>(resultTy, f4Bits); + Value result = arith::BitcastOp::create(b, resultTy, f4Bits); rewriter.replaceOp(op, result); return success(); } @@ -625,16 +626,16 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> { Type f32Ty = cloneToShapedType(operandTy, b.getF32Type()); if (operandETy.getIntOrFloatBitWidth() < 32) { - operand = b.create<arith::ExtFOp>(f32Ty, operand, op.getFastmathAttr()); + operand = arith::ExtFOp::create(b, f32Ty, operand, op.getFastmathAttr()); } else if (operandETy.getIntOrFloatBitWidth() > 32) { - operand = b.create<arith::TruncFOp>( - f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr()); + operand = arith::TruncFOp::create( + b, f32Ty, operand, op.getRoundingmodeAttr(), op.getFastmathAttr()); } - Value f32Bits = b.create<arith::BitcastOp>(i32Ty, operand); + Value f32Bits = arith::BitcastOp::create(b, i32Ty, operand); Value cF32MantissaWidth = createConst(op->getLoc(), i32Ty, 23, rewriter); - Value f32SignExp = b.create<arith::ShRUIOp>(f32Bits, cF32MantissaWidth); - Value exp8Bits = b.create<arith::TruncIOp>(i8Ty, f32SignExp); - Value result = b.create<arith::BitcastOp>(resultTy, exp8Bits); + Value f32SignExp = arith::ShRUIOp::create(b, f32Bits, cF32MantissaWidth); + Value exp8Bits = arith::TruncIOp::create(b, i8Ty, f32SignExp); + Value result = arith::BitcastOp::create(b, resultTy, exp8Bits); rewriter.replaceOp(op, result); return success(); } @@ -653,8 +654,8 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> { if (scaleETy.getIntOrFloatBitWidth() >= 16) { scaleETy = b.getF8E8M0Type(); scaleTy = cloneToShapedType(scaleTy, scaleETy); - scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr, - op.getFastmathAttr()); + scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr, + op.getFastmathAttr()); } // Catch scale types like f8E5M2. 
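// (Only F8E8M0 scales pass the check below: the type carries an 8-bit
// exponent and nothing else, so a scale is by construction a power of two.)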
if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) { @@ -666,11 +667,11 @@ struct ScalingExtFOpConverter : public OpRewritePattern<arith::ScalingExtFOp> { // extf on scale will essentially create a floating point number // of type resultTy that is 2^scale and will also propagate NaNs. Value scaleExt = - b.create<arith::ExtFOp>(resultTy, scaleOperand, op.getFastmathAttr()); + arith::ExtFOp::create(b, resultTy, scaleOperand, op.getFastmathAttr()); Value inputExt = - b.create<arith::ExtFOp>(resultTy, inputOperand, op.getFastmathAttr()); + arith::ExtFOp::create(b, resultTy, inputOperand, op.getFastmathAttr()); Value result = - b.create<arith::MulFOp>(inputExt, scaleExt, op.getFastmathAttr()); + arith::MulFOp::create(b, inputExt, scaleExt, op.getFastmathAttr()); rewriter.replaceOp(op, result); return success(); } @@ -695,8 +696,8 @@ struct ScalingTruncFOpConverter if (scaleETy.getIntOrFloatBitWidth() >= 16) { scaleETy = b.getF8E8M0Type(); scaleTy = cloneToShapedType(scaleTy, scaleETy); - scaleOperand = b.create<arith::TruncFOp>(scaleTy, scaleOperand, nullptr, - op.getFastmathAttr()); + scaleOperand = arith::TruncFOp::create(b, scaleTy, scaleOperand, nullptr, + op.getFastmathAttr()); } if (!llvm::isa<Float8E8M0FNUType>(scaleETy)) { return rewriter.notifyMatchFailure( @@ -708,11 +709,11 @@ struct ScalingTruncFOpConverter // this will create a floating point number of type // inputTy that is 2^scale and will also propagate NaNs. scaleOperand = - b.create<arith::ExtFOp>(inputTy, scaleOperand, op.getFastmathAttr()); - Value result = b.create<arith::DivFOp>(inputOperand, scaleOperand, - op.getFastmathAttr()); - Value resultCast = b.create<arith::TruncFOp>( - resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr()); + arith::ExtFOp::create(b, inputTy, scaleOperand, op.getFastmathAttr()); + Value result = arith::DivFOp::create(b, inputOperand, scaleOperand, + op.getFastmathAttr()); + Value resultCast = arith::TruncFOp::create( + b, resultTy, result, op.getRoundingmodeAttr(), op.getFastmathAttr()); rewriter.replaceOp(op, resultCast); return success(); } diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index f2f9388..777ff0e 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -305,18 +305,18 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType, if (isa<IndexType>(srcElemType) || isa<IndexType>(dstElemType)) { if (castKind == CastKind::Signed) - return builder.create<arith::IndexCastOp>(loc, dstType, src); - return builder.create<arith::IndexCastUIOp>(loc, dstType, src); + return arith::IndexCastOp::create(builder, loc, dstType, src); + return arith::IndexCastUIOp::create(builder, loc, dstType, src); } auto srcInt = cast<IntegerType>(srcElemType); auto dstInt = cast<IntegerType>(dstElemType); if (dstInt.getWidth() < srcInt.getWidth()) - return builder.create<arith::TruncIOp>(loc, dstType, src); + return arith::TruncIOp::create(builder, loc, dstType, src); if (castKind == CastKind::Signed) - return builder.create<arith::ExtSIOp>(loc, dstType, src); - return builder.create<arith::ExtUIOp>(loc, dstType, src); + return arith::ExtSIOp::create(builder, loc, dstType, src); + return arith::ExtUIOp::create(builder, loc, dstType, src); } struct NarrowElementwise final : OpTraitRewritePattern<OpTrait::Elementwise> { diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp 
b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp index 5fb7953..4bdd1e6 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -23,8 +23,8 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map, std::function<Value(AffineExpr)> buildExpr = [&](AffineExpr e) -> Value { switch (e.getKind()) { case AffineExprKind::Constant: - return b.create<ConstantIndexOp>(loc, - cast<AffineConstantExpr>(e).getValue()); + return ConstantIndexOp::create(b, loc, + cast<AffineConstantExpr>(e).getValue()); case AffineExprKind::DimId: return operands[cast<AffineDimExpr>(e).getPosition()]; case AffineExprKind::SymbolId: @@ -32,28 +32,28 @@ static Value buildArithValue(OpBuilder &b, Location loc, AffineMap map, map.getNumDims()]; case AffineExprKind::Add: { auto binaryExpr = cast<AffineBinaryOpExpr>(e); - return b.create<AddIOp>(loc, buildExpr(binaryExpr.getLHS()), - buildExpr(binaryExpr.getRHS())); + return AddIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); } case AffineExprKind::Mul: { auto binaryExpr = cast<AffineBinaryOpExpr>(e); - return b.create<MulIOp>(loc, buildExpr(binaryExpr.getLHS()), - buildExpr(binaryExpr.getRHS())); + return MulIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); } case AffineExprKind::FloorDiv: { auto binaryExpr = cast<AffineBinaryOpExpr>(e); - return b.create<DivSIOp>(loc, buildExpr(binaryExpr.getLHS()), - buildExpr(binaryExpr.getRHS())); + return DivSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); } case AffineExprKind::CeilDiv: { auto binaryExpr = cast<AffineBinaryOpExpr>(e); - return b.create<CeilDivSIOp>(loc, buildExpr(binaryExpr.getLHS()), - buildExpr(binaryExpr.getRHS())); + return CeilDivSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); } case AffineExprKind::Mod: { auto binaryExpr = cast<AffineBinaryOpExpr>(e); - return b.create<RemSIOp>(loc, buildExpr(binaryExpr.getLHS()), - buildExpr(binaryExpr.getRHS())); + return RemSIOp::create(b, loc, buildExpr(binaryExpr.getLHS()), + buildExpr(binaryExpr.getRHS())); } } llvm_unreachable("unsupported AffineExpr kind"); @@ -89,10 +89,10 @@ FailureOr<OpFoldResult> mlir::arith::reifyValueBound( "expected dynamic dim"); if (isa<RankedTensorType>(value.getType())) { // A tensor dimension is used: generate a tensor.dim. - operands.push_back(b.create<tensor::DimOp>(loc, value, *dim)); + operands.push_back(tensor::DimOp::create(b, loc, value, *dim)); } else if (isa<MemRefType>(value.getType())) { // A memref dimension is used: generate a memref.dim. 
- operands.push_back(b.create<memref::DimOp>(loc, value, *dim)); + operands.push_back(memref::DimOp::create(b, loc, value, *dim)); } else { llvm_unreachable("cannot generate DimOp for unsupported shaped type"); } diff --git a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp index 3478adc..3e34246 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ShardingInterfaceImpl.cpp @@ -6,22 +6,22 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/IR/DialectRegistry.h" using namespace mlir; using namespace mlir::arith; -using namespace mlir::mesh; +using namespace mlir::shard; namespace { // Sharding of arith.constant // RankedTensor constants can be sharded like any other tensor. // %cst = arith.constant dense<0.000000e+00> : tensor<1024x1024xf32> -// %sharding = mesh.sharding @mesh4x4 split_axes = [[0]] : !mesh.sharding +// %sharding = shard.sharding @grid4x4 split_axes = [[0]] : !shard.sharding // Scalar constants are always replicated and need no sharding annotation. struct ConstantShardingInterface @@ -48,8 +48,8 @@ struct ConstantShardingInterface // Otherwise mirror result sharding if it is a tensor constant. // Otherwise return replication option. FailureOr<ShardingOption> - getShardingOption(Operation *op, ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings) const { + getShardingOption(Operation *op, ArrayRef<Sharding> operandShardings, + ArrayRef<Sharding> resultShardings) const { assert(resultShardings.size() == 1 && "Expecting exactly one result sharding for arith.constant"); auto resultSharding = resultShardings[0]; @@ -61,17 +61,17 @@ struct ConstantShardingInterface for (auto [i, axes] : llvm::enumerate(resultSharding.getSplitAxes())) { axesArray[i].append(axes.asArrayRef().begin(), axes.asArrayRef().end()); } - return ShardingOption(axesArray, resultSharding.getMeshAttr()); + return ShardingOption(axesArray, resultSharding.getGridAttr()); } - return ShardingOption({}, resultSharding.getMeshAttr()); + return ShardingOption({}, resultSharding.getGridAttr()); } - LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, - ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder) const { + LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands, + ArrayRef<Sharding> operandShardings, + ArrayRef<Sharding> resultShardings, + IRMapping &partitionMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { auto cOp = cast<ConstantOp>(op); if (auto value = dyn_cast<DenseIntOrFPElementsAttr>(cOp.getValue())) { if (!value.isSplat() || !resultShardings[0]) { @@ -80,15 +80,15 @@ struct ConstantShardingInterface } auto sharding = resultShardings[0]; auto newType = cast<RankedTensorType>(shardType( - cOp.getType(), getMesh(op, sharding.getMeshAttr(), symbolTable), + cOp.getType(), getGrid(op, sharding.getGridAttr(), symbolTable), sharding)); auto newValue = 
value.resizeSplat(newType); - auto newOp = builder.create<ConstantOp>(op->getLoc(), newType, newValue); - spmdizationMap.map(op->getResult(0), newOp.getResult()); - spmdizationMap.map(op, newOp.getOperation()); + auto newOp = ConstantOp::create(builder, op->getLoc(), newType, newValue); + partitionMap.map(op->getResult(0), newOp.getResult()); + partitionMap.map(op, newOp.getOperation()); } else { // `clone` will populate the mapping of old to new results. - (void)builder.clone(*op, spmdizationMap); + (void)builder.clone(*op, partitionMap); } return success(); } diff --git a/mlir/lib/Dialect/Arith/Utils/Utils.cpp b/mlir/lib/Dialect/Arith/Utils/Utils.cpp index bdeeccb..b1fc9aa 100644 --- a/mlir/lib/Dialect/Arith/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Arith/Utils/Utils.cpp @@ -67,7 +67,7 @@ mlir::inferExpandShapeOutputShape(OpBuilder &b, Location loc, // dynamism. Value indexGroupSize = cast<Value>(inputShape[inputIndex]); Value indexGroupStaticSizesProduct = - b.create<arith::ConstantIndexOp>(loc, indexGroupStaticSizesProductInt); + arith::ConstantIndexOp::create(b, loc, indexGroupStaticSizesProductInt); Value dynamicDimSize = b.createOrFold<arith::DivSIOp>( loc, indexGroupSize, indexGroupStaticSizesProduct); outputShapeValues.push_back(dynamicDimSize); @@ -104,8 +104,8 @@ Value mlir::getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, if (auto value = dyn_cast_if_present<Value>(ofr)) return value; auto attr = cast<IntegerAttr>(cast<Attribute>(ofr)); - return b.create<arith::ConstantOp>( - loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue())); + return arith::ConstantOp::create( + b, loc, b.getIntegerAttr(attr.getType(), attr.getValue().getSExtValue())); } Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, @@ -113,7 +113,7 @@ Value mlir::getValueOrCreateConstantIndexOp(OpBuilder &b, Location loc, if (auto value = dyn_cast_if_present<Value>(ofr)) return value; auto attr = cast<IntegerAttr>(cast<Attribute>(ofr)); - return b.create<arith::ConstantIndexOp>(loc, attr.getValue().getSExtValue()); + return arith::ConstantIndexOp::create(b, loc, attr.getValue().getSExtValue()); } Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, @@ -124,7 +124,7 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, bool targetIsIndex = targetType.isIndex(); bool valueIsIndex = value.getType().isIndex(); if (targetIsIndex ^ valueIsIndex) - return b.create<arith::IndexCastOp>(loc, targetType, value); + return arith::IndexCastOp::create(b, loc, targetType, value); auto targetIntegerType = dyn_cast<IntegerType>(targetType); auto valueIntegerType = dyn_cast<IntegerType>(value.getType()); @@ -133,8 +133,8 @@ Value mlir::getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, assert(targetIntegerType.getSignedness() == valueIntegerType.getSignedness()); if (targetIntegerType.getWidth() > valueIntegerType.getWidth()) - return b.create<arith::ExtSIOp>(loc, targetIntegerType, value); - return b.create<arith::TruncIOp>(loc, targetIntegerType, value); + return arith::ExtSIOp::create(b, loc, targetIntegerType, value); + return arith::TruncIOp::create(b, loc, targetIntegerType, value); } static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, @@ -142,21 +142,21 @@ static Value convertScalarToIntDtype(ImplicitLocOpBuilder &b, Value operand, // If operand is floating point, cast directly to the int type. 
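// (Illustrative: an f32 -> i8 conversion becomes arith.fptoui when
// isUnsigned and arith.fptosi otherwise, as the branches below show.)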
if (isa<FloatType>(operand.getType())) { if (isUnsigned) - return b.create<arith::FPToUIOp>(toType, operand); - return b.create<arith::FPToSIOp>(toType, operand); + return arith::FPToUIOp::create(b, toType, operand); + return arith::FPToSIOp::create(b, toType, operand); } // Cast index operands directly to the int type. if (operand.getType().isIndex()) - return b.create<arith::IndexCastOp>(toType, operand); + return arith::IndexCastOp::create(b, toType, operand); if (auto fromIntType = dyn_cast<IntegerType>(operand.getType())) { // Either extend or truncate. if (toType.getWidth() > fromIntType.getWidth()) { if (isUnsigned) - return b.create<arith::ExtUIOp>(toType, operand); - return b.create<arith::ExtSIOp>(toType, operand); + return arith::ExtUIOp::create(b, toType, operand); + return arith::ExtSIOp::create(b, toType, operand); } if (toType.getWidth() < fromIntType.getWidth()) - return b.create<arith::TruncIOp>(toType, operand); + return arith::TruncIOp::create(b, toType, operand); return operand; } @@ -169,14 +169,14 @@ static Value convertScalarToFpDtype(ImplicitLocOpBuilder &b, Value operand, // Note that it is unclear how to cast from BF16<->FP16. if (isa<IntegerType>(operand.getType())) { if (isUnsigned) - return b.create<arith::UIToFPOp>(toType, operand); - return b.create<arith::SIToFPOp>(toType, operand); + return arith::UIToFPOp::create(b, toType, operand); + return arith::SIToFPOp::create(b, toType, operand); } if (auto fromFpTy = dyn_cast<FloatType>(operand.getType())) { if (toType.getWidth() > fromFpTy.getWidth()) - return b.create<arith::ExtFOp>(toType, operand); + return arith::ExtFOp::create(b, toType, operand); if (toType.getWidth() < fromFpTy.getWidth()) - return b.create<arith::TruncFOp>(toType, operand); + return arith::TruncFOp::create(b, toType, operand); return operand; } @@ -189,18 +189,18 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, if (auto fromComplexType = dyn_cast<ComplexType>(operand.getType())) { if (isa<FloatType>(targetType.getElementType()) && isa<FloatType>(fromComplexType.getElementType())) { - Value real = b.create<complex::ReOp>(operand); - Value imag = b.create<complex::ImOp>(operand); + Value real = complex::ReOp::create(b, operand); + Value imag = complex::ImOp::create(b, operand); Type targetETy = targetType.getElementType(); if (targetType.getElementType().getIntOrFloatBitWidth() < fromComplexType.getElementType().getIntOrFloatBitWidth()) { - real = b.create<arith::TruncFOp>(targetETy, real); - imag = b.create<arith::TruncFOp>(targetETy, imag); + real = arith::TruncFOp::create(b, targetETy, real); + imag = arith::TruncFOp::create(b, targetETy, imag); } else { - real = b.create<arith::ExtFOp>(targetETy, real); - imag = b.create<arith::ExtFOp>(targetETy, imag); + real = arith::ExtFOp::create(b, targetETy, real); + imag = arith::ExtFOp::create(b, targetETy, imag); } - return b.create<complex::CreateOp>(targetType, real, imag); + return complex::CreateOp::create(b, targetType, real, imag); } } @@ -209,27 +209,27 @@ static Value convertScalarToComplexDtype(ImplicitLocOpBuilder &b, Value operand, auto toBitwidth = toFpTy.getIntOrFloatBitWidth(); Value from = operand; if (from.getType().getIntOrFloatBitWidth() < toBitwidth) { - from = b.create<arith::ExtFOp>(toFpTy, from); + from = arith::ExtFOp::create(b, toFpTy, from); } if (from.getType().getIntOrFloatBitWidth() > toBitwidth) { - from = b.create<arith::TruncFOp>(toFpTy, from); + from = arith::TruncFOp::create(b, toFpTy, from); } - Value zero = 
b.create<mlir::arith::ConstantFloatOp>( - toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0)); - return b.create<complex::CreateOp>(targetType, from, zero); + Value zero = mlir::arith::ConstantFloatOp::create( + b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0)); + return complex::CreateOp::create(b, targetType, from, zero); } if (isa<IntegerType>(operand.getType())) { FloatType toFpTy = cast<FloatType>(targetType.getElementType()); Value from = operand; if (isUnsigned) { - from = b.create<arith::UIToFPOp>(toFpTy, from); + from = arith::UIToFPOp::create(b, toFpTy, from); } else { - from = b.create<arith::SIToFPOp>(toFpTy, from); + from = arith::SIToFPOp::create(b, toFpTy, from); } - Value zero = b.create<mlir::arith::ConstantFloatOp>( - toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0)); - return b.create<complex::CreateOp>(targetType, from, zero); + Value zero = mlir::arith::ConstantFloatOp::create( + b, toFpTy, mlir::APFloat(toFpTy.getFloatSemantics(), 0)); + return complex::CreateOp::create(b, targetType, from, zero); } return {}; @@ -277,7 +277,7 @@ Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, attr = SplatElementsAttr::get(vecTy, value); } - return builder.create<arith::ConstantOp>(loc, attr); + return arith::ConstantOp::create(builder, loc, attr); } Value mlir::createScalarOrSplatConstant(OpBuilder &builder, Location loc, @@ -309,35 +309,35 @@ Type mlir::getType(OpFoldResult ofr) { } Value ArithBuilder::_and(Value lhs, Value rhs) { - return b.create<arith::AndIOp>(loc, lhs, rhs); + return arith::AndIOp::create(b, loc, lhs, rhs); } Value ArithBuilder::add(Value lhs, Value rhs) { if (isa<FloatType>(lhs.getType())) - return b.create<arith::AddFOp>(loc, lhs, rhs); - return b.create<arith::AddIOp>(loc, lhs, rhs, ovf); + return arith::AddFOp::create(b, loc, lhs, rhs); + return arith::AddIOp::create(b, loc, lhs, rhs, ovf); } Value ArithBuilder::sub(Value lhs, Value rhs) { if (isa<FloatType>(lhs.getType())) - return b.create<arith::SubFOp>(loc, lhs, rhs); - return b.create<arith::SubIOp>(loc, lhs, rhs, ovf); + return arith::SubFOp::create(b, loc, lhs, rhs); + return arith::SubIOp::create(b, loc, lhs, rhs, ovf); } Value ArithBuilder::mul(Value lhs, Value rhs) { if (isa<FloatType>(lhs.getType())) - return b.create<arith::MulFOp>(loc, lhs, rhs); - return b.create<arith::MulIOp>(loc, lhs, rhs, ovf); + return arith::MulFOp::create(b, loc, lhs, rhs); + return arith::MulIOp::create(b, loc, lhs, rhs, ovf); } Value ArithBuilder::sgt(Value lhs, Value rhs) { if (isa<FloatType>(lhs.getType())) - return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT, lhs, rhs); - return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, lhs, rhs); + return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OGT, lhs, rhs); + return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::sgt, lhs, rhs); } Value ArithBuilder::slt(Value lhs, Value rhs) { if (isa<FloatType>(lhs.getType())) - return b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, lhs, rhs); - return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt, lhs, rhs); + return arith::CmpFOp::create(b, loc, arith::CmpFPredicate::OLT, lhs, rhs); + return arith::CmpIOp::create(b, loc, arith::CmpIPredicate::slt, lhs, rhs); } Value ArithBuilder::select(Value cmp, Value lhs, Value rhs) { - return b.create<arith::SelectOp>(loc, cmp, lhs, rhs); + return arith::SelectOp::create(b, loc, cmp, lhs, rhs); } namespace mlir::arith { @@ -348,8 +348,8 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> 
values) { Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values, Type resultType) { - Value one = builder.create<ConstantOp>(loc, resultType, - builder.getOneAttr(resultType)); + Value one = ConstantOp::create(builder, loc, resultType, + builder.getOneAttr(resultType)); ArithBuilder arithBuilder(builder, loc); return std::accumulate( values.begin(), values.end(), one, diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp index 5aadaec..1aa8064 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp @@ -49,7 +49,7 @@ std::optional<Value> getExtOperand(Value v) { // If the operand is not defined by an explicit extend operation of the // accepted operation type allow for an implicit sign-extension. - auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp()); + auto extOp = v.getDefiningOp<Op>(); if (!extOp) { if constexpr (std::is_same<Op, arith::ExtSIOp>::value) { auto eltTy = cast<VectorType>(v.getType()).getElementType(); @@ -145,8 +145,8 @@ protected: return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc, lhs, rhs); case MMLA::Bfloat: - return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs, - rhs); + return arm_neon::BfmmlaOp::create(rewriter, loc, acc.getType(), acc, lhs, + rhs); case MMLA::Nop: llvm_unreachable("Uninitialized operation type"); } @@ -226,8 +226,9 @@ public: // Initial accumulator for the final result. This is the un-tiled result if // tiling is done. - Value result = rewriter.create<arith::ConstantOp>( - loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType())); + Value result = + arith::ConstantOp::create(rewriter, loc, op.getResultType(), + rewriter.getZeroAttr(op.getResultType())); SmallVector<int64_t, 3> loopOrder = {0, 1}; if (iterationBounds.size() == 3) @@ -263,8 +264,9 @@ public: if (dimM == 1) { auto expandRowVector = [&](Value tiledOperand, VectorType expandedTypeType) { - auto emptyOperand = rewriter.create<arith::ConstantOp>( - loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType)); + auto emptyOperand = + arith::ConstantOp::create(rewriter, loc, expandedTypeType, + rewriter.getZeroAttr(expandedTypeType)); SmallVector<int64_t> offsets( cast<ShapedType>(emptyOperand.getType()).getRank(), 0); SmallVector<int64_t> strides( @@ -280,8 +282,8 @@ public: // using the instruction for unsigned by signed multiplication with // reversed operands. if (swapOperands) - tiledAcc = rewriter.create<vector::TransposeOp>( - loc, tiledAcc, ArrayRef<int64_t>({1, 0})); + tiledAcc = vector::TransposeOp::create(rewriter, loc, tiledAcc, + ArrayRef<int64_t>({1, 0})); // Collapse tiled operands to 1D vectors required by the ArmNeon ops auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>( @@ -309,8 +311,8 @@ public: // Because of the reversed operands the result is obtained transposed. 
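// (This leans on the identity (A * B)^T == B^T * A^T: evaluating the
// mixed-sign MMLA with swapped operands yields the transposed tile, which
// is undone just below.)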
// Transpose it back. if (swapOperands) - tiledRes = rewriter.create<vector::TransposeOp>( - loc, tiledRes, ArrayRef<int64_t>({1, 0})); + tiledRes = vector::TransposeOp::create(rewriter, loc, tiledRes, + ArrayRef<int64_t>({1, 0})); // With vecmat, only one row of tiled ACC can be inserted into the final // result diff --git a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp index 5f00cef..e5e1312 100644 --- a/mlir/lib/Dialect/ArmSME/IR/Utils.cpp +++ b/mlir/lib/Dialect/ArmSME/IR/Utils.cpp @@ -75,21 +75,21 @@ scf::ForOp createLoopOverTileSlices( PatternRewriter &rewriter, Location loc, Value initTile, std::function<Value(OpBuilder &, Location, Value, Value)> makeLoopBody) { OpBuilder::InsertionGuard g(rewriter); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - auto minTileSlices = rewriter.create<arith::ConstantIndexOp>( - loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0)); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + auto minTileSlices = arith::ConstantIndexOp::create( + rewriter, loc, llvm::cast<VectorType>(initTile.getType()).getDimSize(0)); auto vscale = - rewriter.create<vector::VectorScaleOp>(loc, rewriter.getIndexType()); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); + vector::VectorScaleOp::create(rewriter, loc, rewriter.getIndexType()); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto numTileSlices = - rewriter.create<arith::MulIOp>(loc, minTileSlices, vscale); - auto forOp = rewriter.create<scf::ForOp>(loc, lowerBound, numTileSlices, step, - ValueRange{initTile}); + arith::MulIOp::create(rewriter, loc, minTileSlices, vscale); + auto forOp = scf::ForOp::create(rewriter, loc, lowerBound, numTileSlices, + step, ValueRange{initTile}); rewriter.setInsertionPointToStart(forOp.getBody()); Value nextTile = makeLoopBody(rewriter, loc, /*tileSliceIndex=*/forOp.getInductionVar(), /*currentTile=*/forOp.getRegionIterArg(0)); - rewriter.create<scf::YieldOp>(loc, nextTile); + scf::YieldOp::create(rewriter, loc, nextTile); return forOp; } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp index 23f2c2b..9bf0265 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp @@ -136,7 +136,7 @@ public: auto loc = op.getLoc(); auto packInputs = [&](Value lhs, Value rhs) { - return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs); + return vector::InterleaveOp::create(rewriter, loc, lhs, rhs); }; auto lhs = packInputs(op1.getLhs().getDefiningOp()->getOperand(0), @@ -284,7 +284,7 @@ public: auto loc = op.getLoc(); auto packInputs = [&](Value lhs, Value rhs) { - return rewriter.create<vector::InterleaveOp>(loc, lhs, rhs); + return vector::InterleaveOp::create(rewriter, loc, lhs, rhs); }; auto lhs0 = packInputs(op1.getLhs().getDefiningOp()->getOperand(0), @@ -456,8 +456,8 @@ struct SwapVectorExtractOfArithExtend Value extendSource = extendOp->getOperand(0); // Create new extract from source of extend. - Value newExtract = rewriter.create<vector::ExtractOp>( - loc, extendSource, extractOp.getMixedPosition()); + Value newExtract = vector::ExtractOp::create(rewriter, loc, extendSource, + extractOp.getMixedPosition()); // Extend new extract to original result type. Operation *newExtend = @@ -503,8 +503,9 @@ struct SwapVectorScalableExtractOfArithExtend // Create new extract from source of extend. 
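// (The swap is sound because extract and elementwise extend commute:
// extract(extsi(v)) == extsi(extract(v)); an editorial note, not part of
// the change.)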
VectorType extractResultVectorType = resultType.clone(extendSourceVectorType.getElementType()); - Value newExtract = rewriter.create<vector::ScalableExtractOp>( - loc, extractResultVectorType, extendSource, extractOp.getPos()); + Value newExtract = vector::ScalableExtractOp::create( + rewriter, loc, extractResultVectorType, extendSource, + extractOp.getPos()); // Extend new extract to original result type. Operation *newExtend = diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp index b3c988d..d925c19 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp @@ -210,7 +210,7 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) { auto insertJump = [&](Location loc, Block *source, Block *dest, auto args) { rewriter.setInsertionPointToEnd(source); - rewriter.create<cf::BranchOp>(loc, dest, args); + cf::BranchOp::create(rewriter, loc, dest, args); }; for (auto condBranch : worklist) { @@ -253,7 +253,7 @@ void insertCopiesAtBranches(IRRewriter &rewriter, for (OpOperand &operand : terminator->getOpOperands()) { if (isValidSMETileVectorType(operand.get().getType())) { auto copy = - rewriter.create<CopyTileOp>(terminator->getLoc(), operand.get()); + CopyTileOp::create(rewriter, terminator->getLoc(), operand.get()); rewriter.modifyOpInPlace(terminator, [&] { operand.assign(copy); }); } } diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 1e8e126..1c0eced 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -82,13 +82,14 @@ SmallVector<Value, 2> addConstantScalableOffset(OpBuilder &builder, Location loc, ValueRange indices, ArrayRef<int> scalableOffsets) { - auto vscale = builder.create<vector::VectorScaleOp>(loc); + auto vscale = vector::VectorScaleOp::create(builder, loc); return llvm::map_to_vector( llvm::zip_equal(indices, scalableOffsets), [&](auto pair) -> Value { auto [index, base] = pair; - auto offset = builder.create<arith::MulIOp>( - loc, builder.create<arith::ConstantIndexOp>(loc, base), vscale); - return builder.create<arith::AddIOp>(loc, index, offset); + auto offset = arith::MulIOp::create( + builder, loc, arith::ConstantIndexOp::create(builder, loc, base), + vscale); + return arith::AddIOp::create(builder, loc, index, offset); }); } @@ -132,8 +133,8 @@ Value extractSMEMask(OpBuilder &builder, Location loc, Value mask, // from the mask operands to get the parameters for this sub-tile. 
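`addConstantScalableOffset` (above) adds `base * vscale` to each index; `extractSMEMask` passes the negative tile origin `(-smeTile.row, -smeTile.col)`, which re-expresses the original `vector.create_mask` bounds relative to the sub-tile. In plain integers (all values illustrative):

    // Original create_mask bounds and an assumed runtime vscale of 2.
    int64_t maskDims[2] = {7, 9};
    int64_t vscale = 2;
    // A sub-tile at (1 * vscale, 2 * vscale) sees the mask shifted into
    // its own coordinates: a 5 x 5 live region remains.
    int64_t subMaskDims[2] = {maskDims[0] + (-1) * vscale,
                              maskDims[1] + (-2) * vscale};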
auto smeTileMaskDims = addConstantScalableOffset( builder, loc, createMask.getOperands(), {-smeTile.row, -smeTile.col}); - auto smeTileCreateMask = builder.create<vector::CreateMaskOp>( - loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims); + auto smeTileCreateMask = vector::CreateMaskOp::create( + builder, loc, smeTile.type.clone(builder.getI1Type()), smeTileMaskDims); return smeTileCreateMask.getResult(); } @@ -190,8 +191,8 @@ struct LegalizeArithConstantOpsByDecomposition auto smeTileType = getSMETileTypeForElement(vectorType.getElementType()); auto tileCount = getNumberOfSMETilesForVectorType(vectorType); - auto tileSplat = rewriter.create<arith::ConstantOp>( - constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); + auto tileSplat = arith::ConstantOp::create( + rewriter, constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); SmallVector<Value> repl(tileCount, tileSplat); rewriter.replaceOpWithMultiple(constantOp, {repl}); @@ -237,12 +238,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition decomposeToSMETiles(rewriter, vectorType, smeTileType))) { auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); - auto lhs = rewriter.create<vector::ScalableExtractOp>( - loc, sliceType, outerProductOp.getLhs(), smeTile.row); - auto rhs = rewriter.create<vector::ScalableExtractOp>( - loc, sliceType, outerProductOp.getRhs(), smeTile.col); - auto smeOuterProduct = rewriter.create<vector::OuterProductOp>( - loc, smeTileType, lhs, rhs, + auto lhs = vector::ScalableExtractOp::create( + rewriter, loc, sliceType, outerProductOp.getLhs(), smeTile.row); + auto rhs = vector::ScalableExtractOp::create( + rewriter, loc, sliceType, outerProductOp.getRhs(), smeTile.col); + auto smeOuterProduct = vector::OuterProductOp::create( + rewriter, loc, smeTileType, lhs, rhs, !accSMETiles.empty() ? accSMETiles[index] : Value{}, outerProductOp.getKind()); @@ -314,8 +315,8 @@ struct LegalizeTransferReadOpsByDecomposition for (SMESubTile smeTile : decomposeToSMETiles(rewriter, vectorType, smeTileType, transposed)) { auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); - auto smeRead = rewriter.create<vector::TransferReadOp>( - loc, smeTileType, readOp.getBase(), + auto smeRead = vector::TransferReadOp::create( + rewriter, loc, smeTileType, readOp.getBase(), getSMESubTileIndices(rewriter, loc, readOp.getIndices(), smeTile), readOp.getPermutationMapAttr(), readOp.getPadding(), smeMask, readOp.getInBoundsAttr()); @@ -363,8 +364,8 @@ struct LegalizeTransferWriteOpsByDecomposition for (auto [index, smeTile] : llvm::enumerate(decomposeToSMETiles( rewriter, vectorType, smeTileType, transposed))) { auto smeMask = extractSMEMask(rewriter, loc, mask, smeTile); - auto smeWrite = rewriter.create<vector::TransferWriteOp>( - loc, inputSMETiles[index], destTensorOrMemref, + auto smeWrite = vector::TransferWriteOp::create( + rewriter, loc, inputSMETiles[index], destTensorOrMemref, getSMESubTileIndices(rewriter, loc, writeOp.getIndices(), smeTile), writeOp.getPermutationMapAttr(), smeMask, writeOp.getInBoundsAttr()); if (writeOp.hasPureTensorSemantics()) @@ -456,11 +457,11 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop VectorType::get(minTileSlices, rewriter.getI1Type(), true); // Create loop over all tile slices. 
- auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto upperBound = createVscaleMultiple(minTileSlices); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); auto storeLoop = - rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step); + scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step); rewriter.setInsertionPointToStart(storeLoop.getBody()); // For each sub-tile of the multi-tile `vectorType`. @@ -474,30 +475,31 @@ struct LegalizeMultiTileTransferWriteAsStoreLoop // The current slice of `vectorType` we are processing. auto sliceIndex = - rewriter.create<arith::AddIOp>(loc, tileRow, tileSliceIndex); + arith::AddIOp::create(rewriter, loc, tileRow, tileSliceIndex); // Where in the destination memref the current slice will be stored. - auto storeRow = rewriter.create<arith::AddIOp>(loc, sliceIndex, - writeOp.getIndices()[0]); - auto storeCol = - rewriter.create<arith::AddIOp>(loc, tileCol, writeOp.getIndices()[1]); + auto storeRow = arith::AddIOp::create(rewriter, loc, sliceIndex, + writeOp.getIndices()[0]); + auto storeCol = arith::AddIOp::create(rewriter, loc, tileCol, + writeOp.getIndices()[1]); // Extract the mask for the current slice. Value sliceMask = nullptr; if (mask) { - sliceMask = rewriter.create<vector::ExtractOp>( - loc, mask, OpFoldResult(sliceIndex)); + sliceMask = vector::ExtractOp::create(rewriter, loc, mask, + OpFoldResult(sliceIndex)); if (sliceMaskType != sliceMask.getType()) - sliceMask = rewriter.create<vector::ScalableExtractOp>( - loc, sliceMaskType, sliceMask, smeTile.col); + sliceMask = vector::ScalableExtractOp::create( + rewriter, loc, sliceMaskType, sliceMask, smeTile.col); } // Extract and store the current slice. 
Value tile = inputSMETiles[index]; auto slice = - rewriter.create<vector::ExtractOp>(loc, tile, tileSliceIndex); - rewriter.create<vector::TransferWriteOp>( - loc, slice, writeOp.getBase(), ValueRange{storeRow, storeCol}, + vector::ExtractOp::create(rewriter, loc, tile, tileSliceIndex); + vector::TransferWriteOp::create( + rewriter, loc, slice, writeOp.getBase(), + ValueRange{storeRow, storeCol}, AffineMapAttr::get(writeOp.getPermutationMap().dropResult(0)), sliceMask, rewriter.getBoolArrayAttr( @@ -567,14 +569,15 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks extractOp, "constant vector.create_masks dims should be folded elsewhere"); - auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); auto extractionIndex = getValueOrCreateConstantIndexOp( rewriter, loc, extractOp.getMixedPosition()[0]); - auto extractionInTrueRegion = rewriter.create<arith::CmpIOp>( - loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, extractionIndex, - frontMaskDim); - auto newMaskFrontDim = rewriter.create<arith::SelectOp>( - loc, extractionInTrueRegion, createMaskOp.getOperand(1), zero); + auto extractionInTrueRegion = arith::CmpIOp::create( + rewriter, loc, rewriter.getI1Type(), arith::CmpIPredicate::slt, + extractionIndex, frontMaskDim); + auto newMaskFrontDim = + arith::SelectOp::create(rewriter, loc, extractionInTrueRegion, + createMaskOp.getOperand(1), zero); rewriter.replaceOpWithNewOp<vector::CreateMaskOp>( extractOp, extractedMaskType, @@ -660,8 +663,8 @@ struct LiftIllegalVectorTransposeToMemory illegalRead, "expected read to have identity permutation map"); auto loc = transposeOp.getLoc(); - auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto one = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto one = arith::ConstantIndexOp::create(rewriter, loc, 1); // Create a subview that matches the size of the illegal read vector type. auto readType = illegalRead.getVectorType(); @@ -669,16 +672,16 @@ struct LiftIllegalVectorTransposeToMemory llvm::zip_equal(readType.getShape(), readType.getScalableDims()), [&](auto dim) -> Value { auto [size, isScalable] = dim; - auto dimSize = rewriter.create<arith::ConstantIndexOp>(loc, size); + auto dimSize = arith::ConstantIndexOp::create(rewriter, loc, size); if (!isScalable) return dimSize; - auto vscale = rewriter.create<vector::VectorScaleOp>(loc); - return rewriter.create<arith::MulIOp>(loc, vscale, dimSize); + auto vscale = vector::VectorScaleOp::create(rewriter, loc); + return arith::MulIOp::create(rewriter, loc, vscale, dimSize); }); SmallVector<Value> strides(readType.getRank(), Value(one)); - auto readSubview = rewriter.create<memref::SubViewOp>( - loc, illegalRead.getBase(), illegalRead.getIndices(), readSizes, - strides); + auto readSubview = + memref::SubViewOp::create(rewriter, loc, illegalRead.getBase(), + illegalRead.getIndices(), readSizes, strides); // Apply the transpose to all values/attributes of the transfer_read: // - The mask @@ -686,14 +689,14 @@ struct LiftIllegalVectorTransposeToMemory if (mask) { // Note: The transpose for the mask should fold into the // vector.create_mask/constant_mask op, which will then become legal. 
- mask = rewriter.create<vector::TransposeOp>(loc, mask, - transposeOp.getPermutation()); + mask = vector::TransposeOp::create(rewriter, loc, mask, + transposeOp.getPermutation()); } // - The source memref mlir::AffineMap transposeMap = AffineMap::getPermutationMap( transposeOp.getPermutation(), getContext()); - auto transposedSubview = rewriter.create<memref::TransposeOp>( - loc, readSubview, AffineMapAttr::get(transposeMap)); + auto transposedSubview = memref::TransposeOp::create( + rewriter, loc, readSubview, AffineMapAttr::get(transposeMap)); ArrayAttr inBoundsAttr = illegalRead.getInBoundsAttr(); // - The `in_bounds` attribute if (inBoundsAttr) { @@ -706,8 +709,8 @@ struct LiftIllegalVectorTransposeToMemory VectorType legalReadType = resultType.clone(readType.getElementType()); // Note: The indices are all zero as the subview is already offset. SmallVector<Value> readIndices(illegalRead.getIndices().size(), zero); - auto legalRead = rewriter.create<vector::TransferReadOp>( - loc, legalReadType, transposedSubview, readIndices, + auto legalRead = vector::TransferReadOp::create( + rewriter, loc, legalReadType, transposedSubview, readIndices, illegalRead.getPermutationMapAttr(), illegalRead.getPadding(), mask, inBoundsAttr); @@ -797,12 +800,12 @@ struct LowerIllegalTransposeStoreViaZA AffineMap::getPermutationMap(ArrayRef<int64_t>{1, 0}, getContext())); // Note: We need to use `get_tile` as there's no vector-level `undef`. - Value undefTile = rewriter.create<arm_sme::GetTileOp>(loc, smeTileType); + Value undefTile = arm_sme::GetTileOp::create(rewriter, loc, smeTileType); Value destTensorOrMemref = writeOp.getBase(); auto numSlicesPerTile = std::min(sourceType.getDimSize(0), smeTileType.getDimSize(0)); auto numSlices = - rewriter.create<arith::ConstantIndexOp>(loc, numSlicesPerTile); + arith::ConstantIndexOp::create(rewriter, loc, numSlicesPerTile); for (auto [index, smeTile] : llvm::enumerate( decomposeToSMETiles(rewriter, sourceType, smeTileType))) { // 1. _Deliberately_ drop a scalable dimension and insert a fixed number @@ -811,47 +814,47 @@ struct LowerIllegalTransposeStoreViaZA // rows of the tile after 1*vscale rows. Value tile = undefTile; for (int d = 0; d < numSlicesPerTile; ++d) { - Value vector = rewriter.create<vector::ExtractOp>( - loc, transposeOp.getVector(), - rewriter.getIndexAttr(d + smeTile.row)); + Value vector = + vector::ExtractOp::create(rewriter, loc, transposeOp.getVector(), + rewriter.getIndexAttr(d + smeTile.row)); if (vector.getType() != smeSliceType) { - vector = rewriter.create<vector::ScalableExtractOp>( - loc, smeSliceType, vector, smeTile.col); + vector = vector::ScalableExtractOp::create( + rewriter, loc, smeSliceType, vector, smeTile.col); } - tile = rewriter.create<vector::InsertOp>(loc, vector, tile, d); + tile = vector::InsertOp::create(rewriter, loc, vector, tile, d); } // 2. Transpose the tile position. auto transposedRow = createVscaleMultiple(smeTile.col); auto transposedCol = - rewriter.create<arith::ConstantIndexOp>(loc, smeTile.row); + arith::ConstantIndexOp::create(rewriter, loc, smeTile.row); // 3. Compute mask for tile store. 
Value maskRows; Value maskCols; if (auto mask = writeOp.getMask()) { auto createMask = mask.getDefiningOp<vector::CreateMaskOp>(); - maskRows = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(0), - transposedRow); - maskCols = rewriter.create<arith::SubIOp>(loc, createMask.getOperand(1), - transposedCol); - maskCols = rewriter.create<index::MinSOp>(loc, maskCols, numSlices); + maskRows = arith::SubIOp::create( + rewriter, loc, createMask.getOperand(0), transposedRow); + maskCols = arith::SubIOp::create( + rewriter, loc, createMask.getOperand(1), transposedCol); + maskCols = index::MinSOp::create(rewriter, loc, maskCols, numSlices); } else { maskRows = createVscaleMultiple(smeTileType.getDimSize(0)); maskCols = numSlices; } - auto subMask = rewriter.create<vector::CreateMaskOp>( - loc, smeTileType.clone(rewriter.getI1Type()), + auto subMask = vector::CreateMaskOp::create( + rewriter, loc, smeTileType.clone(rewriter.getI1Type()), ValueRange{maskRows, maskCols}); // 4. Emit a transposed tile write. auto writeIndices = writeOp.getIndices(); Value destRow = - rewriter.create<arith::AddIOp>(loc, transposedRow, writeIndices[0]); + arith::AddIOp::create(rewriter, loc, transposedRow, writeIndices[0]); Value destCol = - rewriter.create<arith::AddIOp>(loc, transposedCol, writeIndices[1]); - auto smeWrite = rewriter.create<vector::TransferWriteOp>( - loc, tile, destTensorOrMemref, ValueRange{destRow, destCol}, + arith::AddIOp::create(rewriter, loc, transposedCol, writeIndices[1]); + auto smeWrite = vector::TransferWriteOp::create( + rewriter, loc, tile, destTensorOrMemref, ValueRange{destRow, destCol}, transposeMap, subMask, writeOp.getInBounds()); if (writeOp.hasPureTensorSemantics()) @@ -934,42 +937,42 @@ struct LowerColumnTransferReadToLoops // Create a loop over all rows and load one element at a time. 
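In steps 2-3 of the transpose-via-ZA pattern above, each SME tile is relocated to its transposed position `(smeTile.col * vscale, smeTile.row)` and the original mask is clipped against it. A plain-integer model of that clipping (function and parameter names invented for illustration):

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    // Rows/cols of the original mask that overlap a tile placed at
    // (tileRow, tileCol); cols are additionally capped at the number of
    // slices actually inserted into the tile.
    std::pair<int64_t, int64_t> clipMask(int64_t rows, int64_t cols,
                                         int64_t tileRow, int64_t tileCol,
                                         int64_t numSlicesPerTile) {
      return {rows - tileRow, std::min(cols - tileCol, numSlicesPerTile)};
    }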
auto loc = readOp.getLoc(); - auto lowerBound = rewriter.create<arith::ConstantIndexOp>(loc, 0); + auto lowerBound = arith::ConstantIndexOp::create(rewriter, loc, 0); auto createVscaleMultiple = vector::makeVscaleConstantBuilder(rewriter, loc); auto upperBound = createVscaleMultiple(numRows); - auto step = rewriter.create<arith::ConstantIndexOp>(loc, 1); - Value init = rewriter.create<arith::ConstantOp>( - loc, newResType, DenseElementsAttr::get(newResType, 0.0f)); + auto step = arith::ConstantIndexOp::create(rewriter, loc, 1); + Value init = arith::ConstantOp::create( + rewriter, loc, newResType, DenseElementsAttr::get(newResType, 0.0f)); scf::ForOp loadLoop; { OpBuilder::InsertionGuard g(rewriter); - loadLoop = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound, step, - ValueRange{init}); + loadLoop = scf::ForOp::create(rewriter, loc, lowerBound, upperBound, step, + ValueRange{init}); rewriter.setInsertionPointToStart(loadLoop.getBody()); auto tileSliceIndex = loadLoop.getInductionVar(); - auto idx0 = rewriter.create<arith::AddIOp>(loc, tileSliceIndex, - readOp.getIndices()[0]); + auto idx0 = arith::AddIOp::create(rewriter, loc, tileSliceIndex, + readOp.getIndices()[0]); auto idx1 = readOp.getIndices()[1]; - Value scalar = rewriter.create<memref::LoadOp>( - loc, readOp.getBase(), SmallVector<Value>({idx0, idx1})); + Value scalar = memref::LoadOp::create(rewriter, loc, readOp.getBase(), + SmallVector<Value>({idx0, idx1})); - Operation *updateInit = rewriter.create<vector::InsertOp>( - loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex); + Operation *updateInit = vector::InsertOp::create( + rewriter, loc, scalar, loadLoop.getRegionIterArg(0), tileSliceIndex); - rewriter.create<scf::YieldOp>(loc, updateInit->getResult(0)); + scf::YieldOp::create(rewriter, loc, updateInit->getResult(0)); } // The read operation has been "legalized", but since the original result // type was a 2D vector, we need to cast before returning the result. This // ShapeCast should cancel-out with some other ShapeCast (i.e. it's a // no-op). - auto sc = rewriter.create<vector::ShapeCastOp>( - loc, readOp.getResult().getType(), loadLoop.getResult(0)); + auto sc = vector::ShapeCastOp::create( + rewriter, loc, readOp.getResult().getType(), loadLoop.getResult(0)); rewriter.replaceOp(readOp, sc); diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp index 7b64e57..a7c6981 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -87,8 +87,8 @@ struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> { VectorType sourceType = source.getType(); VectorType resultType = convertOp.getResult().getType(); - Value result = rewriter.create<arith::ConstantOp>( - loc, resultType, rewriter.getZeroAttr(resultType)); + Value result = arith::ConstantOp::create(rewriter, loc, resultType, + rewriter.getZeroAttr(resultType)); // We want to iterate over the input vector in steps of the trailing // dimension. 
So this creates tile shape where all leading dimensions are 1, @@ -100,15 +100,15 @@ struct SvboolConversionOpLowering : public ConvertOpToLLVMPattern<Op> { for (SmallVector<int64_t> index : StaticTileOffsetRange(sourceType.getShape(), tileShape)) { auto extractOrInsertPosition = ArrayRef(index).drop_back(); - auto sourceVector = rewriter.create<vector::ExtractOp>( - loc, source, extractOrInsertPosition); + auto sourceVector = vector::ExtractOp::create(rewriter, loc, source, + extractOrInsertPosition); VectorType convertedType = VectorType::Builder(llvm::cast<VectorType>(sourceVector.getType())) .setDim(0, resultType.getShape().back()); auto convertedVector = - rewriter.create<IntrOp>(loc, TypeRange{convertedType}, sourceVector); - result = rewriter.create<vector::InsertOp>(loc, convertedVector, result, - extractOrInsertPosition); + IntrOp::create(rewriter, loc, TypeRange{convertedType}, sourceVector); + result = vector::InsertOp::create(rewriter, loc, convertedVector, result, + extractOrInsertPosition); } rewriter.replaceOp(convertOp, result); @@ -135,12 +135,12 @@ struct PselOpLowering : public ConvertOpToLLVMPattern<PselOp> { ConversionPatternRewriter &rewriter) const override { auto svboolType = VectorType::get(16, rewriter.getI1Type(), true); auto loc = pselOp.getLoc(); - auto svboolP1 = rewriter.create<ConvertToSvboolIntrOp>(loc, svboolType, - adaptor.getP1()); - auto indexI32 = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getI32Type(), pselOp.getIndex()); - auto pselIntr = rewriter.create<PselIntrOp>(loc, svboolType, svboolP1, - pselOp.getP2(), indexI32); + auto svboolP1 = ConvertToSvboolIntrOp::create(rewriter, loc, svboolType, + adaptor.getP1()); + auto indexI32 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI32Type(), pselOp.getIndex()); + auto pselIntr = PselIntrOp::create(rewriter, loc, svboolType, svboolP1, + pselOp.getP2(), indexI32); rewriter.replaceOpWithNewOp<ConvertFromSvboolIntrOp>( pselOp, adaptor.getP1().getType(), pselIntr); return success(); @@ -174,7 +174,7 @@ struct CreateMaskOpLowering "not SVE predicate-sized"); auto loc = createMaskOp.getLoc(); - auto zero = rewriter.create<LLVM::ZeroOp>(loc, rewriter.getI64Type()); + auto zero = LLVM::ZeroOp::create(rewriter, loc, rewriter.getI64Type()); rewriter.replaceOpWithNewOp<WhileLTIntrOp>(createMaskOp, maskType, zero, adaptor.getOperands()[0]); return success(); diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp index 3dbb93b..3a409ad 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp @@ -71,8 +71,8 @@ void replaceOpWithUnrealizedConversion(PatternRewriter &rewriter, TOp op, TLegalizerCallback callback) { replaceOpWithLegalizedOp(rewriter, op, [&](TOp newOp) { // Mark our `unrealized_conversion_casts` with a pass label. 
- return rewriter.create<UnrealizedConversionCastOp>( - op.getLoc(), TypeRange{op.getResult().getType()}, + return UnrealizedConversionCastOp::create( + rewriter, op.getLoc(), TypeRange{op.getResult().getType()}, ValueRange{callback(newOp)}, NamedAttribute(rewriter.getStringAttr(kSVELegalizerTag), rewriter.getUnitAttr())); @@ -239,8 +239,8 @@ struct LegalizeSVEMaskStoreConversion auto legalMaskType = widenScalableMaskTypeToSvbool( llvm::cast<VectorType>(valueToStore.getType())); - auto convertToSvbool = rewriter.create<arm_sve::ConvertToSvboolOp>( - loc, legalMaskType, valueToStore); + auto convertToSvbool = arm_sve::ConvertToSvboolOp::create( + rewriter, loc, legalMaskType, valueToStore); // Replace this store with a conversion to a storable svbool mask [1], // followed by a wider store. replaceOpWithLegalizedOp(rewriter, storeOp, @@ -290,8 +290,8 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> { replaceOpWithLegalizedOp(rewriter, loadOp, [&](memref::LoadOp newLoadOp) { newLoadOp.setMemRef(*legalMemref); newLoadOp.getResult().setType(legalMaskType); - return rewriter.create<arm_sve::ConvertFromSvboolOp>( - loc, loadedMask.getType(), newLoadOp); + return arm_sve::ConvertFromSvboolOp::create( + rewriter, loc, loadedMask.getType(), newLoadOp); }); return success(); @@ -408,8 +408,8 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> { reassoc.back().push_back(i); if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc)) return failure(); - Value collapsedMem = rewriter.create<memref::CollapseShapeOp>( - readOp.getLoc(), readOp.getBase(), reassoc); + Value collapsedMem = memref::CollapseShapeOp::create( + rewriter, readOp.getLoc(), readOp.getBase(), reassoc); // Get a vector type with collapsed trailing dimensions. SmallVector<int64_t> shape(origVT.getShape()); @@ -424,14 +424,14 @@ struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> { auto indices = readOp.getIndices().drop_back(numCollapseDims - 1); // Create the new `transfer_read`. - auto newReadOp = rewriter.create<vector::TransferReadOp>( - readOp.getLoc(), collapsedVT, collapsedMem, indices, + auto newReadOp = vector::TransferReadOp::create( + rewriter, readOp.getLoc(), collapsedVT, collapsedMem, indices, readOp.getPadding(), ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1)); // Cast back to the original vector type. - auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(), - origVT, newReadOp); + auto toOrigShape = vector::ShapeCastOp::create(rewriter, readOp.getLoc(), + origVT, newReadOp); rewriter.replaceOp(readOp, toOrigShape); return success(); diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp index ac1df38..35b0bd1 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp @@ -50,7 +50,7 @@ std::optional<Value> getExtOperand(Value v) { // If the operand is not defined by an explicit extend operation of the // accepted operation type allow for an implicit sign-extension. 
- auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp()); + auto extOp = v.getDefiningOp<Op>(); if (!extOp) { if constexpr (std::is_same<Op, arith::ExtSIOp>::value) { auto vTy = cast<VectorType>(v.getType()); @@ -214,13 +214,13 @@ Value VectorContractRewriter::createMMLA(PatternRewriter &rewriter, switch (mmlaOp) { case MMLA::SignedInt: - return rewriter.create<arm_sve::SmmlaOp>(loc, resTy, acc, lhs, rhs); + return arm_sve::SmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs); case MMLA::UnsignedInt: - return rewriter.create<arm_sve::UmmlaOp>(loc, resTy, acc, lhs, rhs); + return arm_sve::UmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs); case MMLA::MixedInt: - return rewriter.create<arm_sve::UsmmlaOp>(loc, resTy, acc, lhs, rhs); + return arm_sve::UsmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs); case MMLA::Bfloat: - return rewriter.create<arm_sve::BfmmlaOp>(loc, resTy, acc, lhs, rhs); + return arm_sve::BfmmlaOp::create(rewriter, loc, resTy, acc, lhs, rhs); default: llvm_unreachable("Uninitialized operation kind"); } @@ -316,62 +316,63 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, for (int64_t i = 0; i < M; i += 2) { // Extract two consecutive rows of the LHS tile. auto r0 = - rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i}); + vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i}); auto r1 = - rewriter.create<vector::ExtractOp>(loc, lhs, ArrayRef<int64_t>{i + 1}); + vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1}); // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile. SmallVector<int64_t> shuffleIdx(2 * K); std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0); - auto t = rewriter.create<vector::ShuffleOp>(loc, r0, r1, shuffleIdx); + auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx); // Turn it into a scalable vector. - auto s = rewriter.create<vector::ScalableInsertOp>( - loc, t, rewriter.create<ub::PoisonOp>(loc, flatLhsType), 0); + auto s = vector::ScalableInsertOp::create( + rewriter, loc, t, ub::PoisonOp::create(rewriter, loc, flatLhsType), 0); // Replicate the sub-tile VSCALE times to fill the entire vector. - auto r = rewriter.create<arm_sve::DupQLaneOp>(loc, s, 0); + auto r = arm_sve::DupQLaneOp::create(rewriter, loc, s, 0); lhsTile.push_back(r); } // "Flatten" the RHS tile from <[N]xK> to <[N*K]>. - auto rhs = rewriter.create<vector::ShapeCastOp>(this->rhs.getLoc(), - flatRhsTileType, this->rhs); + auto rhs = vector::ShapeCastOp::create(rewriter, this->rhs.getLoc(), + flatRhsTileType, this->rhs); // Extract the RHS sub-tiles with logical shape <Kx[2]>. SmallVector<Value> rhsTile; for (int64_t j = 0; j < N; j += 2) - rhsTile.push_back(rewriter.create<vector::ScalableExtractOp>( - loc, flatRhsType, rhs, j * K)); + rhsTile.push_back(vector::ScalableExtractOp::create( + rewriter, loc, flatRhsType, rhs, j * K)); // Extract and pack the ACC sub-tiles. SmallVector<Value> accTile; for (int64_t i = 0; i < M; i += 2) { // Extract two consecutive rows of the accumulator tile. 
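Besides the `create` migration, the `getExtOperand` hunks here and in the Neon file above replace the two-step `dyn_cast_or_null` on the defining op with the equivalent `Value::getDefiningOp<OpTy>()` shorthand, which performs the null-safe cast itself:

    // Before: fetch the defining op, then cast (tolerating null).
    auto extOld = dyn_cast_or_null<arith::ExtSIOp>(v.getDefiningOp());
    // After: one call; yields a null op if `v` is a block argument or is
    // defined by some other operation.
    auto extNew = v.getDefiningOp<arith::ExtSIOp>();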
- auto r0 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(), - ArrayRef<int64_t>{i}); - auto r1 = rewriter.create<vector::ExtractOp>(loc, op.getAcc(), - ArrayRef<int64_t>{i + 1}); + auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(), + ArrayRef<int64_t>{i}); + auto r1 = vector::ExtractOp::create(rewriter, loc, op.getAcc(), + ArrayRef<int64_t>{i + 1}); Value accTileVec; if (swapOperands) { // We are performing the operation with swapped LHS and RHS we need to // transpose each individual 2x2 tile of the accumulator and (later) the // final result. - accTileVec = rewriter.create<vector::InterleaveOp>(loc, r0, r1); + accTileVec = vector::InterleaveOp::create(rewriter, loc, r0, r1); } else { // Bitcast accumulator rows to double-width integer elements, so // subsequent interleave/deinterleave work on pairs of elements. - auto r0I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r0); - auto r1I64 = rewriter.create<vector::BitCastOp>(loc, accRow64Ty, r1); + auto r0I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r0); + auto r1I64 = vector::BitCastOp::create(rewriter, loc, accRow64Ty, r1); // Interleave the rows, effectively flattening each 2x2 tile into 4 // consecutive elements. - auto intrI64 = rewriter.create<vector::InterleaveOp>(loc, r0I64, r1I64); + auto intrI64 = vector::InterleaveOp::create(rewriter, loc, r0I64, r1I64); // Bitcast back to original element type. - accTileVec = rewriter.create<vector::BitCastOp>(loc, accRowX2Ty, intrI64); + accTileVec = + vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64); } // Extract ACC sub-tiles. for (int64_t j = 0; j < N; j += 2) - accTile.push_back(rewriter.create<vector::ScalableExtractOp>( - loc, flatAccType, accTileVec, j * 2)); + accTile.push_back(vector::ScalableExtractOp::create( + rewriter, loc, flatAccType, accTileVec, j * 2)); } // Emit sub-tile matrix multiplications. @@ -384,13 +385,13 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, } // Unpack the OUT sub-tiles and insert into the result. - Value result = rewriter.create<ub::PoisonOp>(loc, op.getResultType()); + Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType()); for (int64_t i = 0; i < M / 2; ++i) { // Collect a number of sub-tiles in a row. - Value row = rewriter.create<ub::PoisonOp>(loc, accRowX2Ty); + Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty); for (int64_t j = 0; j < N / 2; ++j) - row = rewriter.create<vector::ScalableInsertOp>( - loc, outTile[i * N / 2 + j], row, j * 4); + row = vector::ScalableInsertOp::create( + rewriter, loc, outTile[i * N / 2 + j], row, j * 4); // Unpack the row to obtain two rows of the output. If we have the out // sub-tiles transposed we obtain two consecutive output rows by @@ -398,22 +399,22 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, // Otherwise, the interleave is by pairs. Value out0, out1; if (swapOperands) { - auto tmp = rewriter.create<vector::DeinterleaveOp>(loc, row); + auto tmp = vector::DeinterleaveOp::create(rewriter, loc, row); out0 = tmp.getRes1(); out1 = tmp.getRes2(); } else { // Deinterleave by pairs. - auto row64 = rewriter.create<vector::BitCastOp>(loc, accRowX264Ty, row); - auto deintr64 = rewriter.create<vector::DeinterleaveOp>(loc, row64); + auto row64 = vector::BitCastOp::create(rewriter, loc, accRowX264Ty, row); + auto deintr64 = vector::DeinterleaveOp::create(rewriter, loc, row64); // Bitcast back into original element type and insert into the result. 
- out0 = - rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes1()); - out1 = - rewriter.create<vector::BitCastOp>(loc, accRowTy, deintr64.getRes2()); + out0 = vector::BitCastOp::create(rewriter, loc, accRowTy, + deintr64.getRes1()); + out1 = vector::BitCastOp::create(rewriter, loc, accRowTy, + deintr64.getRes2()); } - result = rewriter.create<vector::InsertOp>(loc, out0, result, i * 2); - result = rewriter.create<vector::InsertOp>(loc, out1, result, i * 2 + 1); + result = vector::InsertOp::create(rewriter, loc, out0, result, i * 2); + result = vector::InsertOp::create(rewriter, loc, out1, result, i * 2 + 1); } return result; diff --git a/mlir/lib/Dialect/Async/IR/Async.cpp b/mlir/lib/Dialect/Async/IR/Async.cpp index 08a57db..dc7b07d 100644 --- a/mlir/lib/Dialect/Async/IR/Async.cpp +++ b/mlir/lib/Dialect/Async/IR/Async.cpp @@ -97,7 +97,7 @@ void ExecuteOp::build(OpBuilder &builder, OperationState &result, // expected result is empty. Otherwise, leave this to the caller // because we don't know which values to return from the execute op. if (resultTypes.empty() && !bodyBuilder) { - builder.create<async::YieldOp>(result.location, ValueRange()); + async::YieldOp::create(builder, result.location, ValueRange()); } else if (bodyBuilder) { bodyBuilder(builder, result.location, bodyBlock->getArguments()); } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp index bf6bfe2a..96283cd 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp @@ -190,8 +190,8 @@ static SmallVector<Value> delinearize(ImplicitLocOpBuilder &b, Value index, assert(!tripCounts.empty() && "tripCounts must be not empty"); for (ssize_t i = tripCounts.size() - 1; i >= 0; --i) { - coords[i] = b.create<arith::RemSIOp>(index, tripCounts[i]); - index = b.create<arith::DivSIOp>(index, tripCounts[i]); + coords[i] = arith::RemSIOp::create(b, index, tripCounts[i]); + index = arith::DivSIOp::create(b, index, tripCounts[i]); } return coords; @@ -275,15 +275,15 @@ static ParallelComputeFunction createParallelComputeFunction( BlockArgument blockSize = args.blockSize(); // Constants used below. - Value c0 = b.create<arith::ConstantIndexOp>(0); - Value c1 = b.create<arith::ConstantIndexOp>(1); + Value c0 = arith::ConstantIndexOp::create(b, 0); + Value c1 = arith::ConstantIndexOp::create(b, 1); // Materialize known constants as constant operation in the function body. auto values = [&](ArrayRef<BlockArgument> args, ArrayRef<IntegerAttr> attrs) { return llvm::to_vector( llvm::map_range(llvm::zip(args, attrs), [&](auto tuple) -> Value { if (IntegerAttr attr = std::get<1>(tuple)) - return b.create<arith::ConstantOp>(attr); + return arith::ConstantOp::create(b, attr); return std::get<0>(tuple); })); }; @@ -302,17 +302,17 @@ static ParallelComputeFunction createParallelComputeFunction( // one-dimensional iteration space. 
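`delinearize` above peels per-dimension coordinates off a flat index, innermost dimension first, using `arith.remsi`/`arith.divsi`. The same computation in plain integers; with `tripCounts = {2, 3, 4}` and `index = 17` it yields `coords = {1, 1, 1}`, since `17 = (1 * 3 + 1) * 4 + 1`:

    #include <cstdint>
    #include <vector>
    std::vector<int64_t> delinearizeModel(int64_t index,
                                          const std::vector<int64_t> &tripCounts) {
      std::vector<int64_t> coords(tripCounts.size());
      for (int64_t i = (int64_t)tripCounts.size() - 1; i >= 0; --i) {
        coords[i] = index % tripCounts[i]; // arith.remsi in the rewrite
        index /= tripCounts[i];            // arith.divsi in the rewrite
      }
      return coords;
    }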
Value tripCount = tripCounts[0]; for (unsigned i = 1; i < tripCounts.size(); ++i) - tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]); + tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]); // Find one-dimensional iteration bounds: [blockFirstIndex, blockLastIndex]: // blockFirstIndex = blockIndex * blockSize - Value blockFirstIndex = b.create<arith::MulIOp>(blockIndex, blockSize); + Value blockFirstIndex = arith::MulIOp::create(b, blockIndex, blockSize); // The last one-dimensional index in the block defined by the `blockIndex`: // blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1 - Value blockEnd0 = b.create<arith::AddIOp>(blockFirstIndex, blockSize); - Value blockEnd1 = b.create<arith::MinSIOp>(blockEnd0, tripCount); - Value blockLastIndex = b.create<arith::SubIOp>(blockEnd1, c1); + Value blockEnd0 = arith::AddIOp::create(b, blockFirstIndex, blockSize); + Value blockEnd1 = arith::MinSIOp::create(b, blockEnd0, tripCount); + Value blockLastIndex = arith::SubIOp::create(b, blockEnd1, c1); // Convert one-dimensional indices to multi-dimensional coordinates. auto blockFirstCoord = delinearize(b, blockFirstIndex, tripCounts); @@ -325,7 +325,7 @@ static ParallelComputeFunction createParallelComputeFunction( // dimension when inner compute dimension contains multiple blocks. SmallVector<Value> blockEndCoord(op.getNumLoops()); for (size_t i = 0; i < blockLastCoord.size(); ++i) - blockEndCoord[i] = b.create<arith::AddIOp>(blockLastCoord[i], c1); + blockEndCoord[i] = arith::AddIOp::create(b, blockLastCoord[i], c1); // Construct a loop nest out of scf.for operations that will iterate over // all coordinates in [blockFirstCoord, blockLastCoord] range. @@ -368,21 +368,22 @@ static ParallelComputeFunction createParallelComputeFunction( ImplicitLocOpBuilder b(loc, nestedBuilder); // Compute induction variable for `loopIdx`. - computeBlockInductionVars[loopIdx] = b.create<arith::AddIOp>( - lowerBounds[loopIdx], b.create<arith::MulIOp>(iv, steps[loopIdx])); + computeBlockInductionVars[loopIdx] = + arith::AddIOp::create(b, lowerBounds[loopIdx], + arith::MulIOp::create(b, iv, steps[loopIdx])); // Check if we are inside first or last iteration of the loop. - isBlockFirstCoord[loopIdx] = b.create<arith::CmpIOp>( - arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]); - isBlockLastCoord[loopIdx] = b.create<arith::CmpIOp>( - arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]); + isBlockFirstCoord[loopIdx] = arith::CmpIOp::create( + b, arith::CmpIPredicate::eq, iv, blockFirstCoord[loopIdx]); + isBlockLastCoord[loopIdx] = arith::CmpIOp::create( + b, arith::CmpIPredicate::eq, iv, blockLastCoord[loopIdx]); // Check if the previous loop is in its first or last iteration. if (loopIdx > 0) { - isBlockFirstCoord[loopIdx] = b.create<arith::AndIOp>( - isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]); - isBlockLastCoord[loopIdx] = b.create<arith::AndIOp>( - isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]); + isBlockFirstCoord[loopIdx] = arith::AndIOp::create( + b, isBlockFirstCoord[loopIdx], isBlockFirstCoord[loopIdx - 1]); + isBlockLastCoord[loopIdx] = arith::AndIOp::create( + b, isBlockLastCoord[loopIdx], isBlockLastCoord[loopIdx - 1]); } // Keep building loop nest. @@ -390,24 +391,24 @@ static ParallelComputeFunction createParallelComputeFunction( if (loopIdx + 1 >= op.getNumLoops() - numBlockAlignedInnerLoops) { // For block aligned loops we always iterate starting from 0 up to // the loop trip counts. 
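The bounds computed above carve the flattened range into `blockSize`-sized pieces: `blockFirstIndex = blockIndex * blockSize` and `blockLastIndex = min(blockFirstIndex + blockSize, tripCount) - 1`, so only the final block ends early. For example, `tripCount = 10` with `blockSize = 4` gives blocks covering `[0, 3]`, `[4, 7]`, and `[8, 9]`. A scalar model (names invented):

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    // Inclusive [first, last] flat-index range covered by one block.
    std::pair<int64_t, int64_t> blockBounds(int64_t blockIndex, int64_t blockSize,
                                            int64_t tripCount) {
      int64_t first = blockIndex * blockSize;
      int64_t last = std::min(first + blockSize, tripCount) - 1;
      return {first, last};
    }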
- b.create<scf::ForOp>(c0, tripCounts[loopIdx + 1], c1, ValueRange(), - workLoopBuilder(loopIdx + 1)); + scf::ForOp::create(b, c0, tripCounts[loopIdx + 1], c1, ValueRange(), + workLoopBuilder(loopIdx + 1)); } else { // Select nested loop lower/upper bounds depending on our position in // the multi-dimensional iteration space. - auto lb = b.create<arith::SelectOp>(isBlockFirstCoord[loopIdx], - blockFirstCoord[loopIdx + 1], c0); + auto lb = arith::SelectOp::create(b, isBlockFirstCoord[loopIdx], + blockFirstCoord[loopIdx + 1], c0); - auto ub = b.create<arith::SelectOp>(isBlockLastCoord[loopIdx], - blockEndCoord[loopIdx + 1], - tripCounts[loopIdx + 1]); + auto ub = arith::SelectOp::create(b, isBlockLastCoord[loopIdx], + blockEndCoord[loopIdx + 1], + tripCounts[loopIdx + 1]); - b.create<scf::ForOp>(lb, ub, c1, ValueRange(), - workLoopBuilder(loopIdx + 1)); + scf::ForOp::create(b, lb, ub, c1, ValueRange(), + workLoopBuilder(loopIdx + 1)); } - b.create<scf::YieldOp>(loc); + scf::YieldOp::create(b, loc); return; } @@ -418,13 +419,13 @@ static ParallelComputeFunction createParallelComputeFunction( for (auto &bodyOp : op.getRegion().front().without_terminator()) b.clone(bodyOp, mapping); - b.create<scf::YieldOp>(loc); + scf::YieldOp::create(b, loc); }; }; - b.create<scf::ForOp>(blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(), - workLoopBuilder(0)); - b.create<func::ReturnOp>(ValueRange()); + scf::ForOp::create(b, blockFirstCoord[0], blockEndCoord[0], c1, ValueRange(), + workLoopBuilder(0)); + func::ReturnOp::create(b, ValueRange()); return {op.getNumLoops(), func, std::move(computeFuncType.captures)}; } @@ -484,8 +485,8 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, b.setInsertionPointToEnd(block); Type indexTy = b.getIndexType(); - Value c1 = b.create<arith::ConstantIndexOp>(1); - Value c2 = b.create<arith::ConstantIndexOp>(2); + Value c1 = arith::ConstantIndexOp::create(b, 1); + Value c2 = arith::ConstantIndexOp::create(b, 2); // Get the async group that will track async dispatch completion. Value group = block->getArgument(0); @@ -500,7 +501,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, SmallVector<Location> locations = {loc, loc}; // Create a recursive dispatch loop. 
- scf::WhileOp whileOp = b.create<scf::WhileOp>(types, operands); + scf::WhileOp whileOp = scf::WhileOp::create(b, types, operands); Block *before = b.createBlock(&whileOp.getBefore(), {}, types, locations); Block *after = b.createBlock(&whileOp.getAfter(), {}, types, locations); @@ -510,10 +511,10 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, b.setInsertionPointToEnd(before); Value start = before->getArgument(0); Value end = before->getArgument(1); - Value distance = b.create<arith::SubIOp>(end, start); + Value distance = arith::SubIOp::create(b, end, start); Value dispatch = - b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, distance, c1); - b.create<scf::ConditionOp>(dispatch, before->getArguments()); + arith::CmpIOp::create(b, arith::CmpIPredicate::sgt, distance, c1); + scf::ConditionOp::create(b, dispatch, before->getArguments()); } // Setup the async dispatch loop body: recursively call dispatch function @@ -522,9 +523,9 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, b.setInsertionPointToEnd(after); Value start = after->getArgument(0); Value end = after->getArgument(1); - Value distance = b.create<arith::SubIOp>(end, start); - Value halfDistance = b.create<arith::DivSIOp>(distance, c2); - Value midIndex = b.create<arith::AddIOp>(start, halfDistance); + Value distance = arith::SubIOp::create(b, end, start); + Value halfDistance = arith::DivSIOp::create(b, distance, c2); + Value midIndex = arith::AddIOp::create(b, start, halfDistance); // Call parallel compute function inside the async.execute region. auto executeBodyBuilder = [&](OpBuilder &executeBuilder, @@ -535,16 +536,16 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, operands[1] = midIndex; operands[2] = end; - executeBuilder.create<func::CallOp>(executeLoc, func.getSymName(), - func.getResultTypes(), operands); - executeBuilder.create<async::YieldOp>(executeLoc, ValueRange()); + func::CallOp::create(executeBuilder, executeLoc, func.getSymName(), + func.getResultTypes(), operands); + async::YieldOp::create(executeBuilder, executeLoc, ValueRange()); }; // Create async.execute operation to dispatch half of the block range. 
- auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(), - executeBodyBuilder); - b.create<AddToGroupOp>(indexTy, execute.getToken(), group); - b.create<scf::YieldOp>(ValueRange({start, midIndex})); + auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(), + executeBodyBuilder); + AddToGroupOp::create(b, indexTy, execute.getToken(), group); + scf::YieldOp::create(b, ValueRange({start, midIndex})); } // After dispatching async operations to process the tail of the block range @@ -556,10 +557,9 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc, SmallVector<Value> computeFuncOperands = {blockStart}; computeFuncOperands.append(forwardedInputs.begin(), forwardedInputs.end()); - b.create<func::CallOp>(computeFunc.func.getSymName(), - computeFunc.func.getResultTypes(), - computeFuncOperands); - b.create<func::ReturnOp>(ValueRange()); + func::CallOp::create(b, computeFunc.func.getSymName(), + computeFunc.func.getResultTypes(), computeFuncOperands); + func::ReturnOp::create(b, ValueRange()); return func; } @@ -577,8 +577,8 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, func::FuncOp asyncDispatchFunction = createAsyncDispatchFunction(parallelComputeFunction, rewriter); - Value c0 = b.create<arith::ConstantIndexOp>(0); - Value c1 = b.create<arith::ConstantIndexOp>(1); + Value c0 = arith::ConstantIndexOp::create(b, 0); + Value c1 = arith::ConstantIndexOp::create(b, 1); // Appends operands shared by async dispatch and parallel compute functions to // the given operands vector. @@ -594,7 +594,7 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, // completely. If this will be known statically, then canonicalization will // erase async group operations. Value isSingleBlock = - b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, blockCount, c1); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, blockCount, c1); auto syncDispatch = [&](OpBuilder &nestedBuilder, Location loc) { ImplicitLocOpBuilder b(loc, nestedBuilder); @@ -603,10 +603,10 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, SmallVector<Value> operands = {c0, blockSize}; appendBlockComputeOperands(operands); - b.create<func::CallOp>(parallelComputeFunction.func.getSymName(), - parallelComputeFunction.func.getResultTypes(), - operands); - b.create<scf::YieldOp>(); + func::CallOp::create(b, parallelComputeFunction.func.getSymName(), + parallelComputeFunction.func.getResultTypes(), + operands); + scf::YieldOp::create(b); }; auto asyncDispatch = [&](OpBuilder &nestedBuilder, Location loc) { @@ -615,24 +615,24 @@ static void doAsyncDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, // Create an async.group to wait on all async tokens from the concurrent // execution of multiple parallel compute function. First block will be // executed synchronously in the caller thread. - Value groupSize = b.create<arith::SubIOp>(blockCount, c1); - Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize); + Value groupSize = arith::SubIOp::create(b, blockCount, c1); + Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize); // Launch async dispatch function for [0, blockCount) range. 
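`createAsyncDispatchFunction` above uses recursive work splitting: the `scf.while` keeps halving `[start, end)`, launching the upper half on another thread via `async.execute` and iterating on the lower half until a single block remains, which runs inline in the caller. A plain model, where `dispatchAsync` and `compute` are hypothetical stand-ins for the recursive dispatch call and the parallel compute function:

    // Split [start, end) until one block is left; upper halves run
    // concurrently, the final block runs in the current thread.
    void dispatchModel(int64_t start, int64_t end) {
      while (end - start > 1) {
        int64_t mid = start + (end - start) / 2;
        dispatchAsync(mid, end); // hypothetical async re-entry
        end = mid;
      }
      compute(start);
    }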
SmallVector<Value> operands = {group, c0, blockCount, blockSize}; appendBlockComputeOperands(operands); - b.create<func::CallOp>(asyncDispatchFunction.getSymName(), - asyncDispatchFunction.getResultTypes(), operands); + func::CallOp::create(b, asyncDispatchFunction.getSymName(), + asyncDispatchFunction.getResultTypes(), operands); // Wait for the completion of all parallel compute operations. - b.create<AwaitAllOp>(group); + AwaitAllOp::create(b, group); - b.create<scf::YieldOp>(); + scf::YieldOp::create(b); }; // Dispatch either single block compute function, or launch async dispatch. - b.create<scf::IfOp>(isSingleBlock, syncDispatch, asyncDispatch); + scf::IfOp::create(b, isSingleBlock, syncDispatch, asyncDispatch); } // Dispatch parallel compute functions by submitting all async compute tasks @@ -646,14 +646,14 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, func::FuncOp compute = parallelComputeFunction.func; - Value c0 = b.create<arith::ConstantIndexOp>(0); - Value c1 = b.create<arith::ConstantIndexOp>(1); + Value c0 = arith::ConstantIndexOp::create(b, 0); + Value c1 = arith::ConstantIndexOp::create(b, 1); // Create an async.group to wait on all async tokens from the concurrent // execution of multiple parallel compute function. First block will be // executed synchronously in the caller thread. - Value groupSize = b.create<arith::SubIOp>(blockCount, c1); - Value group = b.create<CreateGroupOp>(GroupType::get(ctx), groupSize); + Value groupSize = arith::SubIOp::create(b, blockCount, c1); + Value group = CreateGroupOp::create(b, GroupType::get(ctx), groupSize); // Call parallel compute function for all blocks. using LoopBodyBuilder = @@ -680,28 +680,27 @@ doSequentialDispatch(ImplicitLocOpBuilder &b, PatternRewriter &rewriter, // Call parallel compute function inside the async.execute region. auto executeBodyBuilder = [&](OpBuilder &executeBuilder, Location executeLoc, ValueRange executeArgs) { - executeBuilder.create<func::CallOp>(executeLoc, compute.getSymName(), - compute.getResultTypes(), - computeFuncOperands(iv)); - executeBuilder.create<async::YieldOp>(executeLoc, ValueRange()); + func::CallOp::create(executeBuilder, executeLoc, compute.getSymName(), + compute.getResultTypes(), computeFuncOperands(iv)); + async::YieldOp::create(executeBuilder, executeLoc, ValueRange()); }; // Create async.execute operation to launch parallel computate function. - auto execute = b.create<ExecuteOp>(TypeRange(), ValueRange(), ValueRange(), - executeBodyBuilder); - b.create<AddToGroupOp>(rewriter.getIndexType(), execute.getToken(), group); - b.create<scf::YieldOp>(); + auto execute = ExecuteOp::create(b, TypeRange(), ValueRange(), ValueRange(), + executeBodyBuilder); + AddToGroupOp::create(b, rewriter.getIndexType(), execute.getToken(), group); + scf::YieldOp::create(b); }; // Iterate over all compute blocks and launch parallel compute operations. - b.create<scf::ForOp>(c1, blockCount, c1, ValueRange(), loopBuilder); + scf::ForOp::create(b, c1, blockCount, c1, ValueRange(), loopBuilder); // Call parallel compute function for the first block in the caller thread. - b.create<func::CallOp>(compute.getSymName(), compute.getResultTypes(), - computeFuncOperands(c0)); + func::CallOp::create(b, compute.getSymName(), compute.getResultTypes(), + computeFuncOperands(c0)); // Wait for the completion of all async compute operations. 
- b.create<AwaitAllOp>(group); + AwaitAllOp::create(b, group); } LogicalResult @@ -737,17 +736,17 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, // for the scf.parallel operation. Value tripCount = tripCounts[0]; for (size_t i = 1; i < tripCounts.size(); ++i) - tripCount = b.create<arith::MulIOp>(tripCount, tripCounts[i]); + tripCount = arith::MulIOp::create(b, tripCount, tripCounts[i]); // Short circuit no-op parallel loops (zero iterations) that can arise from // the memrefs with dynamic dimension(s) equal to zero. - Value c0 = b.create<arith::ConstantIndexOp>(0); + Value c0 = arith::ConstantIndexOp::create(b, 0); Value isZeroIterations = - b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, tripCount, c0); + arith::CmpIOp::create(b, arith::CmpIPredicate::eq, tripCount, c0); // Do absolutely nothing if the trip count is zero. auto noOp = [&](OpBuilder &nestedBuilder, Location loc) { - nestedBuilder.create<scf::YieldOp>(loc); + scf::YieldOp::create(nestedBuilder, loc); }; // Compute the parallel block size and dispatch concurrent tasks computing @@ -797,9 +796,9 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, Value numWorkerThreadsVal; if (numWorkerThreads >= 0) - numWorkerThreadsVal = b.create<arith::ConstantIndexOp>(numWorkerThreads); + numWorkerThreadsVal = arith::ConstantIndexOp::create(b, numWorkerThreads); else - numWorkerThreadsVal = b.create<async::RuntimeNumWorkerThreadsOp>(); + numWorkerThreadsVal = async::RuntimeNumWorkerThreadsOp::create(b); // With large number of threads the value of creating many compute blocks // is reduced because the problem typically becomes memory bound. For this @@ -818,38 +817,38 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, {4, 4.0f}, {8, 2.0f}, {16, 1.0f}, {32, 0.8f}, {64, 0.6f}}; const float initialOvershardingFactor = 8.0f; - Value scalingFactor = b.create<arith::ConstantFloatOp>( - b.getF32Type(), llvm::APFloat(initialOvershardingFactor)); + Value scalingFactor = arith::ConstantFloatOp::create( + b, b.getF32Type(), llvm::APFloat(initialOvershardingFactor)); for (const std::pair<int, float> &p : overshardingBrackets) { - Value bracketBegin = b.create<arith::ConstantIndexOp>(p.first); - Value inBracket = b.create<arith::CmpIOp>( - arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin); - Value bracketScalingFactor = b.create<arith::ConstantFloatOp>( - b.getF32Type(), llvm::APFloat(p.second)); - scalingFactor = b.create<arith::SelectOp>(inBracket, bracketScalingFactor, - scalingFactor); + Value bracketBegin = arith::ConstantIndexOp::create(b, p.first); + Value inBracket = arith::CmpIOp::create( + b, arith::CmpIPredicate::sgt, numWorkerThreadsVal, bracketBegin); + Value bracketScalingFactor = arith::ConstantFloatOp::create( + b, b.getF32Type(), llvm::APFloat(p.second)); + scalingFactor = arith::SelectOp::create( + b, inBracket, bracketScalingFactor, scalingFactor); } Value numWorkersIndex = - b.create<arith::IndexCastOp>(b.getI32Type(), numWorkerThreadsVal); + arith::IndexCastOp::create(b, b.getI32Type(), numWorkerThreadsVal); Value numWorkersFloat = - b.create<arith::SIToFPOp>(b.getF32Type(), numWorkersIndex); + arith::SIToFPOp::create(b, b.getF32Type(), numWorkersIndex); Value scaledNumWorkers = - b.create<arith::MulFOp>(scalingFactor, numWorkersFloat); + arith::MulFOp::create(b, scalingFactor, numWorkersFloat); Value scaledNumInt = - b.create<arith::FPToSIOp>(b.getI32Type(), scaledNumWorkers); + arith::FPToSIOp::create(b, b.getI32Type(), scaledNumWorkers); Value scaledWorkers = - 
b.create<arith::IndexCastOp>(b.getIndexType(), scaledNumInt); + arith::IndexCastOp::create(b, b.getIndexType(), scaledNumInt); - Value maxComputeBlocks = b.create<arith::MaxSIOp>( - b.create<arith::ConstantIndexOp>(1), scaledWorkers); + Value maxComputeBlocks = arith::MaxSIOp::create( + b, arith::ConstantIndexOp::create(b, 1), scaledWorkers); // Compute parallel block size from the parallel problem size: // blockSize = min(tripCount, // max(ceil_div(tripCount, maxComputeBlocks), // minTaskSize)) - Value bs0 = b.create<arith::CeilDivSIOp>(tripCount, maxComputeBlocks); - Value bs1 = b.create<arith::MaxSIOp>(bs0, minTaskSize); - Value blockSize = b.create<arith::MinSIOp>(tripCount, bs1); + Value bs0 = arith::CeilDivSIOp::create(b, tripCount, maxComputeBlocks); + Value bs1 = arith::MaxSIOp::create(b, bs0, minTaskSize); + Value blockSize = arith::MinSIOp::create(b, tripCount, bs1); // Dispatch parallel compute function using async recursive work splitting, // or by submitting compute task sequentially from a caller thread. @@ -859,7 +858,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, // the parallel operation body for a subset of iteration space. // Compute the number of parallel compute blocks. - Value blockCount = b.create<arith::CeilDivSIOp>(tripCount, blockSize); + Value blockCount = arith::CeilDivSIOp::create(b, tripCount, blockSize); // Dispatch parallel compute function without hints to unroll inner loops. auto dispatchDefault = [&](OpBuilder &nestedBuilder, Location loc) { @@ -868,7 +867,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, ImplicitLocOpBuilder b(loc, nestedBuilder); doDispatch(b, rewriter, compute, op, blockSize, blockCount, tripCounts); - b.create<scf::YieldOp>(); + scf::YieldOp::create(b); }; // Dispatch parallel compute function with hints for unrolling inner loops. @@ -879,34 +878,34 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op, ImplicitLocOpBuilder b(loc, nestedBuilder); // Align the block size to be a multiple of the statically known // number of iterations in the inner loops. - Value numIters = b.create<arith::ConstantIndexOp>( - numIterations[op.getNumLoops() - numUnrollableLoops]); - Value alignedBlockSize = b.create<arith::MulIOp>( - b.create<arith::CeilDivSIOp>(blockSize, numIters), numIters); + Value numIters = arith::ConstantIndexOp::create( + b, numIterations[op.getNumLoops() - numUnrollableLoops]); + Value alignedBlockSize = arith::MulIOp::create( + b, arith::CeilDivSIOp::create(b, blockSize, numIters), numIters); doDispatch(b, rewriter, compute, op, alignedBlockSize, blockCount, tripCounts); - b.create<scf::YieldOp>(); + scf::YieldOp::create(b); }; // Dispatch to block aligned compute function only if the computed block // size is larger than the number of iterations in the unrollable inner // loops, because otherwise it can reduce the available parallelism. 
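All of the sizing above reduces to `blockSize = min(tripCount, max(ceilDiv(tripCount, maxComputeBlocks), minTaskSize))`. With, say, `tripCount = 4096`, `maxComputeBlocks = 8`, and `minTaskSize = 256`: `ceilDiv(4096, 8) = 512`, `max(512, 256) = 512`, `min(4096, 512) = 512`, i.e. eight blocks of 512 iterations. A scalar model of the clamp:

    #include <algorithm>
    #include <cstdint>
    int64_t computeBlockSize(int64_t tripCount, int64_t maxComputeBlocks,
                             int64_t minTaskSize) {
      int64_t bs0 = (tripCount + maxComputeBlocks - 1) / maxComputeBlocks;
      int64_t bs1 = std::max(bs0, minTaskSize);
      return std::min(tripCount, bs1);
    }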
if (numUnrollableLoops > 0) { - Value numIters = b.create<arith::ConstantIndexOp>( - numIterations[op.getNumLoops() - numUnrollableLoops]); - Value useBlockAlignedComputeFn = b.create<arith::CmpIOp>( - arith::CmpIPredicate::sge, blockSize, numIters); - - b.create<scf::IfOp>(useBlockAlignedComputeFn, dispatchBlockAligned, - dispatchDefault); - b.create<scf::YieldOp>(); + Value numIters = arith::ConstantIndexOp::create( + b, numIterations[op.getNumLoops() - numUnrollableLoops]); + Value useBlockAlignedComputeFn = arith::CmpIOp::create( + b, arith::CmpIPredicate::sge, blockSize, numIters); + + scf::IfOp::create(b, useBlockAlignedComputeFn, dispatchBlockAligned, + dispatchDefault); + scf::YieldOp::create(b); } else { dispatchDefault(b, loc); } }; // Replace the `scf.parallel` operation with the parallel compute function. - b.create<scf::IfOp>(isZeroIterations, noOp, dispatch); + scf::IfOp::create(b, isZeroIterations, noOp, dispatch); // Parallel operation was replaced with a block iteration loop. rewriter.eraseOp(op); @@ -921,7 +920,7 @@ void AsyncParallelForPass::runOnOperation() { populateAsyncParallelForPatterns( patterns, asyncDispatch, numWorkerThreads, [&](ImplicitLocOpBuilder builder, scf::ParallelOp op) { - return builder.create<arith::ConstantIndexOp>(minTaskSize); + return arith::ConstantIndexOp::create(builder, minTaskSize); }); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp index 0da9b3a..ddc64ea 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp @@ -48,7 +48,7 @@ static LogicalResult dropRefIfNoUses(Value value, unsigned count = 1) { else b.setInsertionPointToStart(value.getParentBlock()); - b.create<RuntimeDropRefOp>(value.getLoc(), value, b.getI64IntegerAttr(1)); + RuntimeDropRefOp::create(b, value.getLoc(), value, b.getI64IntegerAttr(1)); return success(); } @@ -309,7 +309,7 @@ LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { // Add a drop_ref immediately after the last user. builder.setInsertionPointAfter(lastUser); - builder.create<RuntimeDropRefOp>(loc, value, builder.getI64IntegerAttr(1)); + RuntimeDropRefOp::create(builder, loc, value, builder.getI64IntegerAttr(1)); } return success(); @@ -327,7 +327,7 @@ AsyncRuntimeRefCountingPass::addAddRefBeforeFunctionCall(Value value) { // Add a reference before the function call to pass the value at `+1` // reference to the function entry block. builder.setInsertionPoint(user); - builder.create<RuntimeAddRefOp>(loc, value, builder.getI64IntegerAttr(1)); + RuntimeAddRefOp::create(builder, loc, value, builder.getI64IntegerAttr(1)); } return success(); @@ -411,12 +411,12 @@ AsyncRuntimeRefCountingPass::addDropRefInDivergentLivenessSuccessor( refCountingBlock = &successor->getParent()->emplaceBlock(); refCountingBlock->moveBefore(successor); OpBuilder builder = OpBuilder::atBlockEnd(refCountingBlock); - builder.create<cf::BranchOp>(value.getLoc(), successor); + cf::BranchOp::create(builder, value.getLoc(), successor); } OpBuilder builder = OpBuilder::atBlockBegin(refCountingBlock); - builder.create<RuntimeDropRefOp>(value.getLoc(), value, - builder.getI64IntegerAttr(1)); + RuntimeDropRefOp::create(builder, value.getLoc(), value, + builder.getI64IntegerAttr(1)); // No need to update the terminator operation. 
if (successor == refCountingBlock) @@ -507,13 +507,13 @@ AsyncRuntimePolicyBasedRefCountingPass::addRefCounting(Value value) { // Create `add_ref` operation before the operand owner. if (cnt > 0) { b.setInsertionPoint(operand.getOwner()); - b.create<RuntimeAddRefOp>(loc, value, b.getI64IntegerAttr(cnt)); + RuntimeAddRefOp::create(b, loc, value, b.getI64IntegerAttr(cnt)); } // Create `drop_ref` operation after the operand owner. if (cnt < 0) { b.setInsertionPointAfter(operand.getOwner()); - b.create<RuntimeDropRefOp>(loc, value, b.getI64IntegerAttr(-cnt)); + RuntimeDropRefOp::create(b, loc, value, b.getI64IntegerAttr(-cnt)); } } } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp index 44a3837..112d69c 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -186,22 +186,22 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) { std::optional<Value> retToken; if (isStateful) - retToken.emplace(builder.create<RuntimeCreateOp>(TokenType::get(ctx))); + retToken.emplace(RuntimeCreateOp::create(builder, TokenType::get(ctx))); llvm::SmallVector<Value, 4> retValues; ArrayRef<Type> resValueTypes = isStateful ? func.getResultTypes().drop_front() : func.getResultTypes(); for (auto resType : resValueTypes) retValues.emplace_back( - builder.create<RuntimeCreateOp>(resType).getResult()); + RuntimeCreateOp::create(builder, resType).getResult()); // ------------------------------------------------------------------------ // // Initialize coroutine: get coroutine id and coroutine handle. // ------------------------------------------------------------------------ // - auto coroIdOp = builder.create<CoroIdOp>(CoroIdType::get(ctx)); + auto coroIdOp = CoroIdOp::create(builder, CoroIdType::get(ctx)); auto coroHdlOp = - builder.create<CoroBeginOp>(CoroHandleType::get(ctx), coroIdOp.getId()); - builder.create<cf::BranchOp>(originalEntryBlock); + CoroBeginOp::create(builder, CoroHandleType::get(ctx), coroIdOp.getId()); + cf::BranchOp::create(builder, originalEntryBlock); Block *cleanupBlock = func.addBlock(); Block *cleanupBlockForDestroy = func.addBlock(); @@ -212,10 +212,10 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) { // ------------------------------------------------------------------------ // auto buildCleanupBlock = [&](Block *cb) { builder.setInsertionPointToStart(cb); - builder.create<CoroFreeOp>(coroIdOp.getId(), coroHdlOp.getHandle()); + CoroFreeOp::create(builder, coroIdOp.getId(), coroHdlOp.getHandle()); // Branch into the suspend block. - builder.create<cf::BranchOp>(suspendBlock); + cf::BranchOp::create(builder, suspendBlock); }; buildCleanupBlock(cleanupBlock); buildCleanupBlock(cleanupBlockForDestroy); @@ -227,7 +227,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) { builder.setInsertionPointToStart(suspendBlock); // Mark the end of a coroutine: async.coro.end - builder.create<CoroEndOp>(coroHdlOp.getHandle()); + CoroEndOp::create(builder, coroHdlOp.getHandle()); // Return created optional `async.token` and `async.values` from the suspend // block. This will be the return value of a coroutine ramp function. 
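Condensing the coroutine setup above: the ramp section of the outlined function materializes a coroutine id, begins the coroutine to obtain the frame handle, and branches into the original entry block; the cleanup and suspend blocks hang off that skeleton. A minimal sketch, assuming the caller provides an `ImplicitLocOpBuilder` positioned in the new entry block:

#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"

using namespace mlir;
using namespace mlir::async;

// Sketch of the ramp built by setupCoroMachinery: id, frame handle, then
// fall through to the original entry block.
static Value buildCoroRamp(ImplicitLocOpBuilder &builder,
                           Block *originalEntryBlock) {
  MLIRContext *ctx = builder.getContext();
  auto coroIdOp = CoroIdOp::create(builder, CoroIdType::get(ctx));
  auto coroHdlOp =
      CoroBeginOp::create(builder, CoroHandleType::get(ctx), coroIdOp.getId());
  cf::BranchOp::create(builder, originalEntryBlock);
  return coroHdlOp.getHandle();
}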
@@ -235,7 +235,7 @@ static CoroMachinery setupCoroMachinery(func::FuncOp func) { if (retToken) ret.push_back(*retToken); llvm::append_range(ret, retValues); - builder.create<func::ReturnOp>(ret); + func::ReturnOp::create(builder, ret); // `async.await` op lowering will create resume blocks for async // continuations, and will conditionally branch to cleanup or suspend blocks. @@ -272,13 +272,13 @@ static Block *setupSetErrorBlock(CoroMachinery &coro) { // Coroutine set_error block: set error on token and all returned values. if (coro.asyncToken) - builder.create<RuntimeSetErrorOp>(*coro.asyncToken); + RuntimeSetErrorOp::create(builder, *coro.asyncToken); for (Value retValue : coro.returnValues) - builder.create<RuntimeSetErrorOp>(retValue); + RuntimeSetErrorOp::create(builder, retValue); // Branch into the cleanup block. - builder.create<cf::BranchOp>(coro.cleanup); + cf::BranchOp::create(builder, coro.cleanup); return *coro.setError; } @@ -333,13 +333,13 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { // Await on all dependencies before starting to execute the body region. for (size_t i = 0; i < numDependencies; ++i) - builder.create<AwaitOp>(func.getArgument(i)); + AwaitOp::create(builder, func.getArgument(i)); // Await on all async value operands and unwrap the payload. SmallVector<Value, 4> unwrappedOperands(numOperands); for (size_t i = 0; i < numOperands; ++i) { Value operand = func.getArgument(numDependencies + i); - unwrappedOperands[i] = builder.create<AwaitOp>(loc, operand).getResult(); + unwrappedOperands[i] = AwaitOp::create(builder, loc, operand).getResult(); } // Map from function inputs defined above the execute op to the function @@ -366,15 +366,15 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { // Save the coroutine state: async.coro.save auto coroSaveOp = - builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); + CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle); // Pass coroutine to the runtime to be resumed on a runtime managed // thread. - builder.create<RuntimeResumeOp>(coro.coroHandle); + RuntimeResumeOp::create(builder, coro.coroHandle); // Add async.coro.suspend as a suspended block terminator. - builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, - branch.getDest(), coro.cleanupForDestroy); + CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend, + branch.getDest(), coro.cleanupForDestroy); branch.erase(); } @@ -382,8 +382,9 @@ outlineExecuteOp(SymbolTable &symbolTable, ExecuteOp execute) { // Replace the original `async.execute` with a call to outlined function. 
{ ImplicitLocOpBuilder callBuilder(loc, execute); - auto callOutlinedFunc = callBuilder.create<func::CallOp>( - func.getName(), execute.getResultTypes(), functionInputs.getArrayRef()); + auto callOutlinedFunc = func::CallOp::create(callBuilder, func.getName(), + execute.getResultTypes(), + functionInputs.getArrayRef()); execute.replaceAllUsesWith(callOutlinedFunc.getResults()); execute.erase(); } @@ -451,7 +452,7 @@ public: Location loc = op->getLoc(); auto newFuncOp = - rewriter.create<func::FuncOp>(loc, op.getName(), op.getFunctionType()); + func::FuncOp::create(rewriter, loc, op.getName(), op.getFunctionType()); SymbolTable::setSymbolVisibility(newFuncOp, SymbolTable::getSymbolVisibility(op)); @@ -521,16 +522,16 @@ public: for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { Value returnValue = std::get<0>(tuple); Value asyncValue = std::get<1>(tuple); - rewriter.create<RuntimeStoreOp>(loc, returnValue, asyncValue); - rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); + RuntimeStoreOp::create(rewriter, loc, returnValue, asyncValue); + RuntimeSetAvailableOp::create(rewriter, loc, asyncValue); } if (coro.asyncToken) // Switch the coroutine completion token to available state. - rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken); + RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken); rewriter.eraseOp(op); - rewriter.create<cf::BranchOp>(loc, coro.cleanup); + cf::BranchOp::create(rewriter, loc, coro.cleanup); return success(); } @@ -581,16 +582,17 @@ public: // the async object (token, value or group) to become available. if (!isInCoroutine) { ImplicitLocOpBuilder builder(loc, rewriter); - builder.create<RuntimeAwaitOp>(loc, operand); + RuntimeAwaitOp::create(builder, loc, operand); // Assert that the awaited operands is not in the error state. - Value isError = builder.create<RuntimeIsErrorOp>(i1, operand); - Value notError = builder.create<arith::XOrIOp>( - isError, builder.create<arith::ConstantOp>( - loc, i1, builder.getIntegerAttr(i1, 1))); - - builder.create<cf::AssertOp>(notError, - "Awaited async operand is in error state"); + Value isError = RuntimeIsErrorOp::create(builder, i1, operand); + Value notError = arith::XOrIOp::create( + builder, isError, + arith::ConstantOp::create(builder, loc, i1, + builder.getIntegerAttr(i1, 1))); + + cf::AssertOp::create(builder, notError, + "Awaited async operand is in error state"); } // Inside the coroutine we convert await operation into coroutine suspension @@ -605,28 +607,28 @@ public: // Save the coroutine state and resume on a runtime managed thread when // the operand becomes available. auto coroSaveOp = - builder.create<CoroSaveOp>(CoroStateType::get(ctx), coro.coroHandle); - builder.create<RuntimeAwaitAndResumeOp>(operand, coro.coroHandle); + CoroSaveOp::create(builder, CoroStateType::get(ctx), coro.coroHandle); + RuntimeAwaitAndResumeOp::create(builder, operand, coro.coroHandle); // Split the entry block before the await operation. Block *resume = rewriter.splitBlock(suspended, Block::iterator(op)); // Add async.coro.suspend as a suspended block terminator. builder.setInsertionPointToEnd(suspended); - builder.create<CoroSuspendOp>(coroSaveOp.getState(), coro.suspend, resume, - coro.cleanupForDestroy); + CoroSuspendOp::create(builder, coroSaveOp.getState(), coro.suspend, + resume, coro.cleanupForDestroy); // Split the resume block into error checking and continuation. Block *continuation = rewriter.splitBlock(resume, Block::iterator(op)); // Check if the awaited value is in the error state. 
builder.setInsertionPointToStart(resume); - auto isError = builder.create<RuntimeIsErrorOp>(loc, i1, operand); - builder.create<cf::CondBranchOp>(isError, - /*trueDest=*/setupSetErrorBlock(coro), - /*trueArgs=*/ArrayRef<Value>(), - /*falseDest=*/continuation, - /*falseArgs=*/ArrayRef<Value>()); + auto isError = RuntimeIsErrorOp::create(builder, loc, i1, operand); + cf::CondBranchOp::create(builder, isError, + /*trueDest=*/setupSetErrorBlock(coro), + /*trueArgs=*/ArrayRef<Value>(), + /*falseDest=*/continuation, + /*falseArgs=*/ArrayRef<Value>()); // Make sure that replacement value will be constructed in the // continuation block. @@ -672,7 +674,7 @@ public: ConversionPatternRewriter &rewriter) const override { // Load from the async value storage. auto valueType = cast<ValueType>(operand.getType()).getValueType(); - return rewriter.create<RuntimeLoadOp>(op->getLoc(), valueType, operand); + return RuntimeLoadOp::create(rewriter, op->getLoc(), valueType, operand); } }; @@ -713,15 +715,15 @@ public: for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { Value yieldValue = std::get<0>(tuple); Value asyncValue = std::get<1>(tuple); - rewriter.create<RuntimeStoreOp>(loc, yieldValue, asyncValue); - rewriter.create<RuntimeSetAvailableOp>(loc, asyncValue); + RuntimeStoreOp::create(rewriter, loc, yieldValue, asyncValue); + RuntimeSetAvailableOp::create(rewriter, loc, asyncValue); } if (coro.asyncToken) // Switch the coroutine completion token to available state. - rewriter.create<RuntimeSetAvailableOp>(loc, *coro.asyncToken); + RuntimeSetAvailableOp::create(rewriter, loc, *coro.asyncToken); - rewriter.create<cf::BranchOp>(loc, coro.cleanup); + cf::BranchOp::create(rewriter, loc, coro.cleanup); rewriter.eraseOp(op); return success(); @@ -755,11 +757,11 @@ public: Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); rewriter.setInsertionPointToEnd(cont->getPrevNode()); - rewriter.create<cf::CondBranchOp>(loc, adaptor.getArg(), - /*trueDest=*/cont, - /*trueArgs=*/ArrayRef<Value>(), - /*falseDest=*/setupSetErrorBlock(coro), - /*falseArgs=*/ArrayRef<Value>()); + cf::CondBranchOp::create(rewriter, loc, adaptor.getArg(), + /*trueDest=*/cont, + /*trueArgs=*/ArrayRef<Value>(), + /*falseDest=*/setupSetErrorBlock(coro), + /*falseArgs=*/ArrayRef<Value>()); rewriter.eraseOp(op); return success(); diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp index 2bf326a..4dfba74 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferDeallocationOpInterface.cpp @@ -35,7 +35,7 @@ using namespace bufferization; //===----------------------------------------------------------------------===// static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { - return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value)); + return arith::ConstantOp::create(builder, loc, builder.getBoolAttr(value)); } static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); } @@ -150,7 +150,7 @@ DeallocationState::getMemrefWithUniqueOwnership(OpBuilder &builder, // ownerships more intelligently to not end up with an 'Unknown' ownership in // the first place. 
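The await lowering above funnels all errors through one shared `set_error` block, so the per-await code is only a one-bit check plus a conditional branch. A minimal sketch of that check, assuming the two target blocks already exist:

#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"

using namespace mlir;
using namespace mlir::async;

// Sketch: branch to `setError` if the awaited value is in the error state,
// otherwise continue in the continuation block.
static void buildErrorCheck(ImplicitLocOpBuilder &builder, Location loc,
                            Value operand, Block *setError,
                            Block *continuation) {
  Type i1 = builder.getI1Type();
  auto isError = RuntimeIsErrorOp::create(builder, loc, i1, operand);
  cf::CondBranchOp::create(builder, isError,
                           /*trueDest=*/setError,
                           /*trueArgs=*/ArrayRef<Value>(),
                           /*falseDest=*/continuation,
                           /*falseArgs=*/ArrayRef<Value>());
}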
auto cloneOp = - builder.create<bufferization::CloneOp>(memref.getLoc(), memref); + bufferization::CloneOp::create(builder, memref.getLoc(), memref); Value condition = buildBoolValue(builder, memref.getLoc(), true); Value newMemref = cloneOp.getResult(); updateOwnership(newMemref, condition); @@ -196,8 +196,8 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate( // Simply cast unranked MemRefs to ranked memrefs with 0 dimensions such // that we can call extract_strided_metadata on it. if (auto unrankedMemRefTy = dyn_cast<UnrankedMemRefType>(memref.getType())) - memref = builder.create<memref::ReinterpretCastOp>( - loc, memref, + memref = memref::ReinterpretCastOp::create( + builder, loc, memref, /*offset=*/builder.getIndexAttr(0), /*sizes=*/ArrayRef<OpFoldResult>{}, /*strides=*/ArrayRef<OpFoldResult>{}); @@ -207,7 +207,7 @@ LogicalResult DeallocationState::getMemrefsAndConditionsToDeallocate( // alloc operation has to be passed to the dealloc operation. Passing // subviews, etc. to a dealloc operation is not allowed. memrefs.push_back( - builder.create<memref::ExtractStridedMetadataOp>(loc, memref) + memref::ExtractStridedMetadataOp::create(builder, loc, memref) .getResult(0)); conditions.push_back(ownership.getIndicator()); } @@ -296,8 +296,8 @@ FailureOr<Operation *> deallocation_impl::insertDeallocOpForReturnLike( if (memrefs.empty() && toRetain.empty()) return op; - auto deallocOp = builder.create<bufferization::DeallocOp>( - op->getLoc(), memrefs, conditions, toRetain); + auto deallocOp = bufferization::DeallocOp::create( + builder, op->getLoc(), memrefs, conditions, toRetain); // We want to replace the current ownership of the retained values with the // result values of the dealloc operation as they are always unique. diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 8f17a82f..f7b0b87 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -18,7 +18,6 @@ #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/Debug.h" //===----------------------------------------------------------------------===// // BufferizableOpInterface @@ -35,8 +34,6 @@ namespace bufferization { MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::AnalysisState) #define DEBUG_TYPE "bufferizable-op-interface" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << (X)) using namespace mlir; using namespace bufferization; @@ -170,8 +167,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue( if (llvm::isa<RankedTensorType>(shapedValue.getType())) { tensor = shapedValue; } else if (llvm::isa<MemRefType>(shapedValue.getType())) { - tensor = b.create<ToTensorOp>( - loc, memref::getTensorTypeFromMemRefType(shapedValue.getType()), + tensor = ToTensorOp::create( + b, loc, memref::getTensorTypeFromMemRefType(shapedValue.getType()), shapedValue); } else if (llvm::isa<UnrankedTensorType>(shapedValue.getType()) || llvm::isa<UnrankedMemRefType>(shapedValue.getType())) { @@ -209,8 +206,8 @@ FailureOr<Value> bufferization::allocateTensorForShapedValue( } // Create AllocTensorOp. - auto allocTensorOp = b.create<AllocTensorOp>(loc, tensorType, dynamicSizes, - copy ? tensor : Value()); + auto allocTensorOp = AllocTensorOp::create(b, loc, tensorType, dynamicSizes, + copy ? 
tensor : Value()); // Add 'memory_space' attribute. Not needed if 'copy' operand is specified. if (copy) @@ -691,8 +688,8 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value, if (failed(bufferType)) return failure(); ensureToBufferOpIsValid(value, *bufferType); - return rewriter - .create<bufferization::ToBufferOp>(value.getLoc(), *bufferType, value) + return bufferization::ToBufferOp::create(rewriter, value.getLoc(), + *bufferType, value) .getResult(); } @@ -753,8 +750,8 @@ void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter, // ToTensorOp. Throughout bufferization, this ToTensorOp will gradually // loose all of its users and eventually DCE away. rewriter.setInsertionPointAfter(op); - replacement = rewriter.create<bufferization::ToTensorOp>( - replacement.getLoc(), opResult.getType(), replacement); + replacement = bufferization::ToTensorOp::create( + rewriter, replacement.getLoc(), opResult.getType(), replacement); } replacements.push_back(replacement); } @@ -775,11 +772,10 @@ FailureOr<Value> BufferizationOptions::createAlloc(OpBuilder &b, Location loc, // Default bufferallocation via AllocOp. if (bufferAlignment != 0) - return b - .create<memref::AllocOp>(loc, type, dynShape, - b.getI64IntegerAttr(bufferAlignment)) + return memref::AllocOp::create(b, loc, type, dynShape, + b.getI64IntegerAttr(bufferAlignment)) .getResult(); - return b.create<memref::AllocOp>(loc, type, dynShape).getResult(); + return memref::AllocOp::create(b, loc, type, dynShape).getResult(); } /// Create a memory copy between two memref buffers. @@ -788,7 +784,7 @@ LogicalResult BufferizationOptions::createMemCpy(OpBuilder &b, Location loc, if (memCpyFn) return (*memCpyFn)(b, loc, from, to); - b.create<memref::CopyOp>(loc, from, to); + memref::CopyOp::create(b, loc, from, to); return success(); } diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 875a065..7eb729f 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -58,7 +58,7 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue( // a fix extra conditions in `isGuaranteedCastCompatible`. 
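The `createMemCpy` hunk above is the hook that the transform-dialect change further below also exercises: bufferization emits copies through `BufferizationOptions::memCpyFn` when one is installed, and falls back to a plain `memref.copy` otherwise. A minimal sketch of installing a custom copy function (the wrapper name is hypothetical; the lambda body matches the `linalg.copy` variant used below):

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"

using namespace mlir;

// Sketch: route all bufferization-inserted copies through linalg.copy.
static void useLinalgCopy(bufferization::BufferizationOptions &options) {
  options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) {
    linalg::CopyOp::create(b, loc, from, to);
    return success();
  };
}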
if (memref::CastOp::areCastCompatible(srcType, destType) && isGuaranteedCastCompatible(srcType, destType)) { - Value casted = b.create<memref::CastOp>(value.getLoc(), destType, value); + Value casted = memref::CastOp::create(b, value.getLoc(), destType, value); return casted; } @@ -67,7 +67,7 @@ FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue( for (int i = 0; i < destType.getRank(); ++i) { if (destType.getShape()[i] != ShapedType::kDynamic) continue; - Value size = b.create<memref::DimOp>(loc, value, i); + Value size = memref::DimOp::create(b, loc, value, i); dynamicOperands.push_back(size); } @@ -134,10 +134,10 @@ void mlir::bufferization::populateDynamicDimSizes( for (int64_t i = 0; i < shapedType.getRank(); ++i) { if (shapedType.isDynamicDim(i)) { if (llvm::isa<MemRefType>(shapedType)) { - dynamicDims.push_back(b.create<memref::DimOp>(loc, shapedValue, i)); + dynamicDims.push_back(memref::DimOp::create(b, loc, shapedValue, i)); } else { assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor"); - dynamicDims.push_back(b.create<tensor::DimOp>(loc, shapedValue, i)); + dynamicDims.push_back(tensor::DimOp::create(b, loc, shapedValue, i)); } } } @@ -321,8 +321,8 @@ struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> { newShape, op.getType().getElementType(), op.getType().getEncoding()); if (newType == op.getType()) return failure(); - auto newOp = rewriter.create<AllocTensorOp>( - op.getLoc(), newType, newDynamicSizes, /*copy=*/Value()); + auto newOp = AllocTensorOp::create(rewriter, op.getLoc(), newType, + newDynamicSizes, /*copy=*/Value()); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); return success(); } @@ -427,7 +427,7 @@ void AllocTensorOp::print(OpAsmPrinter &p) { Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) { assert(isDynamicDim(idx) && "expected dynamic dim"); if (getCopy()) - return b.create<tensor::DimOp>(getLoc(), getCopy(), idx); + return tensor::DimOp::create(b, getLoc(), getCopy(), idx); return getOperand(getIndexOfDynamicSize(idx)); } @@ -513,8 +513,8 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> { } if (source.getType() != cloneOp.getType()) - source = rewriter.create<memref::CastOp>(cloneOp.getLoc(), - cloneOp.getType(), source); + source = memref::CastOp::create(rewriter, cloneOp.getLoc(), + cloneOp.getType(), source); rewriter.replaceOp(cloneOp, source); rewriter.eraseOp(redundantDealloc); return success(); @@ -538,7 +538,7 @@ LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter, FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state); if (failed(buffer)) return failure(); - rewriter.create<memref::DeallocOp>(getLoc(), *buffer); + memref::DeallocOp::create(rewriter, getLoc(), *buffer); rewriter.eraseOp(getOperation()); return success(); } @@ -643,8 +643,9 @@ Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder, assert(getRestrict() && "expected that ops with memrefs dest have 'restrict'"); setRestrict(false); - return builder.create<ToTensorOp>( - loc, memref::getTensorTypeFromMemRefType(getDest().getType()), getDest(), + return ToTensorOp::create( + builder, loc, memref::getTensorTypeFromMemRefType(getDest().getType()), + getDest(), /*restrict=*/true, getWritable()); } @@ -804,10 +805,18 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> { tensorCastOperand.getOperand().getType()); if (!srcTensorType) return failure(); + auto currentOutputMemRefType = + dyn_cast<MemRefType>(toBuffer.getResult().getType()); + if 
(!currentOutputMemRefType) + return failure(); + auto memrefType = MemRefType::get(srcTensorType.getShape(), - srcTensorType.getElementType()); - Value memref = rewriter.create<ToBufferOp>(toBuffer.getLoc(), memrefType, - tensorCastOperand.getOperand()); + srcTensorType.getElementType(), + currentOutputMemRefType.getLayout(), + currentOutputMemRefType.getMemorySpace()); + Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType, + tensorCastOperand.getOperand(), + toBuffer.getReadOnly()); rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(), memref); return success(); @@ -880,12 +889,12 @@ LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter, std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) { - return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc) + return memref::DeallocOp::create(builder, alloc.getLoc(), alloc) .getOperation(); } std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) { - return builder.create<CloneOp>(alloc.getLoc(), alloc).getResult(); + return CloneOp::create(builder, alloc.getLoc(), alloc).getResult(); } //===----------------------------------------------------------------------===// @@ -959,7 +968,7 @@ struct DeallocRemoveDuplicateDeallocMemrefs Value &newCond = newConditions[memrefToCondition[memref]]; if (newCond != cond) newCond = - rewriter.create<arith::OrIOp>(deallocOp.getLoc(), newCond, cond); + arith::OrIOp::create(rewriter, deallocOp.getLoc(), newCond, cond); } else { memrefToCondition.insert({memref, newConditions.size()}); newMemrefs.push_back(memref); @@ -1014,8 +1023,8 @@ struct DeallocRemoveDuplicateRetainedMemrefs // We need to create a new op because the number of results is always the // same as the number of condition operands. 
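Note that the `ToBufferOfCast` hunk above is a functional fix rather than a pure API migration: the old pattern rebuilt the memref type from the source tensor's shape and element type alone, so folding `tensor.cast` into `to_buffer` silently dropped the layout and memory space of the original result type (it also dropped the `read_only` flag, which the new code forwards). The pattern now additionally bails out when the result is not a ranked `MemRefType`. The type reconstruction, as a self-contained sketch:

#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Sketch: rebuild the memref type for the cast's source tensor while
// keeping the layout and memory space of the existing to_buffer result.
static MemRefType buildMemRefTypeForSource(RankedTensorType srcTensorType,
                                           MemRefType currentOutputType) {
  return MemRefType::get(srcTensorType.getShape(),
                         srcTensorType.getElementType(),
                         currentOutputType.getLayout(),
                         currentOutputType.getMemorySpace());
}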
auto newDeallocOp = - rewriter.create<DeallocOp>(deallocOp.getLoc(), deallocOp.getMemrefs(), - deallocOp.getConditions(), newRetained); + DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(), + deallocOp.getConditions(), newRetained); SmallVector<Value> replacements( llvm::map_range(resultReplacementIdx, [&](unsigned idx) { return newDeallocOp.getUpdatedConditions()[idx]; @@ -1036,8 +1045,8 @@ struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> { LogicalResult matchAndRewrite(DeallocOp deallocOp, PatternRewriter &rewriter) const override { if (deallocOp.getMemrefs().empty()) { - Value constFalse = rewriter.create<arith::ConstantOp>( - deallocOp.getLoc(), rewriter.getBoolAttr(false)); + Value constFalse = arith::ConstantOp::create(rewriter, deallocOp.getLoc(), + rewriter.getBoolAttr(false)); rewriter.replaceOp( deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(), constFalse)); diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index db1eb20..7f495b0 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -70,12 +70,12 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, *getFunctionBoundaryTypeConversion()); if (getMemcpyOp() == "memref.copy") { options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { - b.create<memref::CopyOp>(loc, from, to); + memref::CopyOp::create(b, loc, from, to); return success(); }; } else if (getMemcpyOp() == "linalg.copy") { options.memCpyFn = [](OpBuilder &b, Location loc, Value from, Value to) { - b.create<linalg::CopyOp>(loc, from, to); + linalg::CopyOp::create(b, loc, from, to); return success(); }; } else { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp index c5fab80..8916526 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -167,8 +167,8 @@ struct RemoveDeallocMemrefsContainedInRetained std::optional<bool> analysisResult = analysis.isSameAllocation(retained, memref); if (analysisResult == true) { - auto disjunction = rewriter.create<arith::OrIOp>( - deallocOp.getLoc(), updatedCondition, cond); + auto disjunction = arith::OrIOp::create(rewriter, deallocOp.getLoc(), + updatedCondition, cond); rewriter.replaceAllUsesExcept(updatedCondition, disjunction.getResult(), disjunction); } @@ -247,16 +247,16 @@ struct RemoveRetainedMemrefsGuaranteedToNotAlias continue; } - replacements.push_back(rewriter.create<arith::ConstantOp>( - deallocOp.getLoc(), rewriter.getBoolAttr(false))); + replacements.push_back(arith::ConstantOp::create( + rewriter, deallocOp.getLoc(), rewriter.getBoolAttr(false))); } if (newRetainedMemrefs.size() == deallocOp.getRetained().size()) return failure(); - auto newDeallocOp = rewriter.create<DeallocOp>( - deallocOp.getLoc(), deallocOp.getMemrefs(), deallocOp.getConditions(), - newRetainedMemrefs); + auto newDeallocOp = + DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(), + deallocOp.getConditions(), newRetainedMemrefs); int i = 0; for (auto &repl : replacements) { if (!repl) @@ -326,8 +326,8 @@ struct SplitDeallocWhenNotAliasingAnyOther } // Create new 
bufferization.dealloc op for `memref`. - auto newDeallocOp = rewriter.create<DeallocOp>(loc, memref, cond, - deallocOp.getRetained()); + auto newDeallocOp = DeallocOp::create(rewriter, loc, memref, cond, + deallocOp.getRetained()); updatedConditions.push_back( llvm::to_vector(ValueRange(newDeallocOp.getUpdatedConditions()))); } @@ -337,8 +337,9 @@ struct SplitDeallocWhenNotAliasingAnyOther return failure(); // Create bufferization.dealloc op for all remaining memrefs. - auto newDeallocOp = rewriter.create<DeallocOp>( - loc, remainingMemrefs, remainingConditions, deallocOp.getRetained()); + auto newDeallocOp = + DeallocOp::create(rewriter, loc, remainingMemrefs, remainingConditions, + deallocOp.getRetained()); // Bit-or all conditions. SmallVector<Value> replacements = @@ -347,8 +348,8 @@ struct SplitDeallocWhenNotAliasingAnyOther assert(replacements.size() == additionalConditions.size() && "expected same number of updated conditions"); for (int64_t i = 0, e = replacements.size(); i < e; ++i) { - replacements[i] = rewriter.create<arith::OrIOp>( - loc, replacements[i], additionalConditions[i]); + replacements[i] = arith::OrIOp::create(rewriter, loc, replacements[i], + additionalConditions[i]); } } rewriter.replaceOp(deallocOp, replacements); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp index 6924e88..e30e094 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferResultsToOutParams.cpp @@ -132,7 +132,7 @@ updateReturnOps(func::FuncOp func, ArrayRef<BlockArgument> appendedEntryArgs, return WalkResult::interrupt(); } } - builder.create<func::ReturnOp>(op.getLoc(), keepAsReturnOperands); + func::ReturnOp::create(builder, op.getLoc(), keepAsReturnOperands); op.erase(); return WalkResult::advance(); }); @@ -190,7 +190,7 @@ updateCalls(ModuleOp module, assert(hasFullyDynamicLayoutMap(memrefType) && "layout map not supported"); outParam = - builder.create<memref::CastOp>(op.getLoc(), memrefType, outParam); + memref::CastOp::create(builder, op.getLoc(), memrefType, outParam); } memref.replaceAllUsesWith(outParam); outParams.push_back(outParam); @@ -200,8 +200,8 @@ updateCalls(ModuleOp module, newOperands.append(outParams.begin(), outParams.end()); auto newResultTypes = llvm::to_vector<6>(llvm::map_range( replaceWithNewCallResults, [](Value v) { return v.getType(); })); - auto newCall = builder.create<func::CallOp>(op.getLoc(), op.getCalleeAttr(), - newResultTypes, newOperands); + auto newCall = func::CallOp::create( + builder, op.getLoc(), op.getCalleeAttr(), newResultTypes, newOperands); for (auto t : llvm::zip(replaceWithNewCallResults, newCall.getResults())) std::get<0>(t).replaceAllUsesWith(std::get<1>(t)); op.erase(); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp index a66be7d..c0e0809 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -141,8 +141,9 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(type)); if (memorySpace) memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); - auto global = globalBuilder.create<memref::GlobalOp>( - constantOp.getLoc(), (Twine("__constant_") + os.str()).str(), + auto global = memref::GlobalOp::create( + globalBuilder, 
constantOp.getLoc(), + (Twine("__constant_") + os.str()).str(), /*sym_visibility=*/globalBuilder.getStringAttr("private"), /*type=*/memrefType, /*initial_value=*/cast<ElementsAttr>(constantOp.getValue()), diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 246555d..91f6f25 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -434,8 +434,8 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, // Replace all uses of the original tensor bbArg. rewriter.setInsertionPointToStart(block); if (!bbArgUses.empty()) { - Value toTensorOp = rewriter.create<bufferization::ToTensorOp>( - bbArg.getLoc(), tensorType, bbArg); + Value toTensorOp = bufferization::ToTensorOp::create( + rewriter, bbArg.getLoc(), tensorType, bbArg); for (OpOperand *use : bbArgUses) use->set(toTensorOp); } @@ -466,13 +466,13 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, if (failed(operandBufferType)) return failure(); rewriter.setInsertionPointAfterValue(operand); - Value bufferizedOperand = rewriter.create<bufferization::ToBufferOp>( - operand.getLoc(), *operandBufferType, operand); + Value bufferizedOperand = bufferization::ToBufferOp::create( + rewriter, operand.getLoc(), *operandBufferType, operand); // A cast is needed if the operand and the block argument have different // bufferized types. if (type != *operandBufferType) - bufferizedOperand = rewriter.create<memref::CastOp>( - operand.getLoc(), type, bufferizedOperand); + bufferizedOperand = memref::CastOp::create(rewriter, operand.getLoc(), + type, bufferizedOperand); newOperands.push_back(bufferizedOperand); } operands.getMutableForwardedOperands().assign(newOperands); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp index c10d290..a50ddbe 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/DropEquivalentBufferResults.cpp @@ -118,8 +118,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { // Update function calls. for (func::CallOp callOp : callerMap[funcOp]) { rewriter.setInsertionPoint(callOp); - auto newCallOp = rewriter.create<func::CallOp>(callOp.getLoc(), funcOp, - callOp.getOperands()); + auto newCallOp = func::CallOp::create(rewriter, callOp.getLoc(), funcOp, + callOp.getOperands()); SmallVector<Value> newResults; int64_t nextResult = 0; for (int64_t i = 0; i < callOp.getNumResults(); ++i) { @@ -134,8 +134,8 @@ mlir::bufferization::dropEquivalentBufferResults(ModuleOp module) { Type expectedType = callOp.getResult(i).getType(); if (replacement.getType() != expectedType) { // A cast must be inserted at the call site. 
- replacement = rewriter.create<memref::CastOp>( - callOp.getLoc(), expectedType, replacement); + replacement = memref::CastOp::create(rewriter, callOp.getLoc(), + expectedType, replacement); } newResults.push_back(replacement); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index b7db2e8..1784964 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -168,8 +168,8 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( cast<ShapedType>(v.getType()).getElementType()) continue; rewriter.setInsertionPointAfterValue(replacement); - replacement = rewriter.create<tensor::CastOp>(v.getLoc(), v.getType(), - replacement); + replacement = tensor::CastOp::create(rewriter, v.getLoc(), v.getType(), + replacement); } // Replace the specific use of the tensor::EmptyOp. rewriter.modifyOpInPlace(user, [&]() { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index 2a98203..f69efd1 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -319,8 +319,9 @@ struct CallOpInterface } // 3. Create the new CallOp. - Operation *newCallOp = rewriter.create<func::CallOp>( - callOp.getLoc(), funcOp.getSymName(), resultTypes, newOperands); + Operation *newCallOp = + func::CallOp::create(rewriter, callOp.getLoc(), funcOp.getSymName(), + resultTypes, newOperands); newCallOp->setAttrs(callOp->getAttrs()); // 4. Replace the old op with the new op. @@ -483,8 +484,8 @@ struct FuncOpInterface // Note: If `inferFunctionResultLayout = true`, casts are later folded // away. - Value toBufferOp = rewriter.create<bufferization::ToBufferOp>( - returnOp.getLoc(), bufferizedType, returnVal); + Value toBufferOp = bufferization::ToBufferOp::create( + rewriter, returnOp.getLoc(), bufferizedType, returnVal); returnValues.push_back(toBufferOp); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp index a611126..e9ad13f 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/LowerDeallocations.cpp @@ -64,8 +64,8 @@ class DeallocOpConversion rewriter.replaceOpWithNewOp<scf::IfOp>( op, adaptor.getConditions()[0], [&](OpBuilder &builder, Location loc) { - builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]); - builder.create<scf::YieldOp>(loc); + memref::DeallocOp::create(builder, loc, adaptor.getMemrefs()[0]); + scf::YieldOp::create(builder, loc); }); return success(); } @@ -108,45 +108,46 @@ class DeallocOpConversion // Compute the base pointer indices, compare all retained indices to the // memref index to check if they alias. 
SmallVector<Value> doesNotAliasList; - Value memrefAsIdx = rewriter.create<memref::ExtractAlignedPointerAsIndexOp>( - op->getLoc(), adaptor.getMemrefs()[0]); + Value memrefAsIdx = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, op->getLoc(), adaptor.getMemrefs()[0]); for (Value retained : adaptor.getRetained()) { - Value retainedAsIdx = - rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op->getLoc(), - retained); - Value doesNotAlias = rewriter.create<arith::CmpIOp>( - op->getLoc(), arith::CmpIPredicate::ne, memrefAsIdx, retainedAsIdx); + Value retainedAsIdx = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, op->getLoc(), retained); + Value doesNotAlias = arith::CmpIOp::create(rewriter, op->getLoc(), + arith::CmpIPredicate::ne, + memrefAsIdx, retainedAsIdx); doesNotAliasList.push_back(doesNotAlias); } // AND-reduce the list of booleans from above. Value prev = doesNotAliasList.front(); for (Value doesNotAlias : ArrayRef(doesNotAliasList).drop_front()) - prev = rewriter.create<arith::AndIOp>(op->getLoc(), prev, doesNotAlias); + prev = arith::AndIOp::create(rewriter, op->getLoc(), prev, doesNotAlias); // Also consider the condition given by the dealloc operation and perform a // conditional deallocation guarded by that value. - Value shouldDealloc = rewriter.create<arith::AndIOp>( - op->getLoc(), prev, adaptor.getConditions()[0]); + Value shouldDealloc = arith::AndIOp::create(rewriter, op->getLoc(), prev, + adaptor.getConditions()[0]); - rewriter.create<scf::IfOp>( - op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) { - builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[0]); - builder.create<scf::YieldOp>(loc); - }); + scf::IfOp::create(rewriter, op.getLoc(), shouldDealloc, + [&](OpBuilder &builder, Location loc) { + memref::DeallocOp::create(builder, loc, + adaptor.getMemrefs()[0]); + scf::YieldOp::create(builder, loc); + }); // Compute the replacement values for the dealloc operation results. This // inserts an already canonicalized form of // `select(does_alias_with_memref(r), memref_cond, false)` for each retained // value r. SmallVector<Value> replacements; - Value trueVal = rewriter.create<arith::ConstantOp>( - op->getLoc(), rewriter.getBoolAttr(true)); + Value trueVal = arith::ConstantOp::create(rewriter, op->getLoc(), + rewriter.getBoolAttr(true)); for (Value doesNotAlias : doesNotAliasList) { Value aliases = - rewriter.create<arith::XOrIOp>(op->getLoc(), doesNotAlias, trueVal); - Value result = rewriter.create<arith::AndIOp>(op->getLoc(), aliases, - adaptor.getConditions()[0]); + arith::XOrIOp::create(rewriter, op->getLoc(), doesNotAlias, trueVal); + Value result = arith::AndIOp::create(rewriter, op->getLoc(), aliases, + adaptor.getConditions()[0]); replacements.push_back(result); } @@ -230,108 +231,112 @@ class DeallocOpConversion // Without storing them to memrefs, we could not use for-loops but only a // completely unrolled version of it, potentially leading to code-size // blow-up. 
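The single-memref lowering above has a simple overall shape: compare the memref's aligned base pointer against every retained memref, AND the resulting 'does not alias' bits into the dealloc condition, and guard the actual `memref.dealloc` with an `scf.if`. A condensed sketch of that shape for one memref (helper name and parameter list are illustrative):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Sketch: dealloc `memref` iff `cond` holds and the memref aliases none of
// the retained values, following the lowering above.
static void buildGuardedDealloc(OpBuilder &builder, Location loc, Value memref,
                                Value cond, ValueRange retained) {
  Value memrefAsIdx =
      memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, memref);
  Value shouldDealloc = cond;
  for (Value r : retained) {
    Value rAsIdx =
        memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, r);
    Value doesNotAlias = arith::CmpIOp::create(
        builder, loc, arith::CmpIPredicate::ne, memrefAsIdx, rAsIdx);
    shouldDealloc =
        arith::AndIOp::create(builder, loc, shouldDealloc, doesNotAlias);
  }
  scf::IfOp::create(builder, loc, shouldDealloc,
                    [&](OpBuilder &b, Location l) {
                      memref::DeallocOp::create(b, l, memref);
                      scf::YieldOp::create(b, l);
                    });
}

Base-pointer equality is a conservative aliasing proxy here: views of the same allocation share an aligned base pointer, which is exactly what ownership-based deallocation needs to avoid double frees.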
- Value toDeallocMemref = rewriter.create<memref::AllocOp>( - op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, - rewriter.getIndexType())); - Value conditionMemref = rewriter.create<memref::AllocOp>( - op.getLoc(), MemRefType::get({(int64_t)adaptor.getConditions().size()}, - rewriter.getI1Type())); - Value toRetainMemref = rewriter.create<memref::AllocOp>( - op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, - rewriter.getIndexType())); + Value toDeallocMemref = memref::AllocOp::create( + rewriter, op.getLoc(), + MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, + rewriter.getIndexType())); + Value conditionMemref = memref::AllocOp::create( + rewriter, op.getLoc(), + MemRefType::get({(int64_t)adaptor.getConditions().size()}, + rewriter.getI1Type())); + Value toRetainMemref = memref::AllocOp::create( + rewriter, op.getLoc(), + MemRefType::get({(int64_t)adaptor.getRetained().size()}, + rewriter.getIndexType())); auto getConstValue = [&](uint64_t value) -> Value { - return rewriter.create<arith::ConstantOp>(op.getLoc(), - rewriter.getIndexAttr(value)); + return arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getIndexAttr(value)); }; // Extract the base pointers of the memrefs as indices to check for aliasing // at runtime. for (auto [i, toDealloc] : llvm::enumerate(adaptor.getMemrefs())) { - Value memrefAsIdx = - rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(), - toDealloc); - rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, - toDeallocMemref, getConstValue(i)); + Value memrefAsIdx = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, op.getLoc(), toDealloc); + memref::StoreOp::create(rewriter, op.getLoc(), memrefAsIdx, + toDeallocMemref, getConstValue(i)); } for (auto [i, cond] : llvm::enumerate(adaptor.getConditions())) - rewriter.create<memref::StoreOp>(op.getLoc(), cond, conditionMemref, - getConstValue(i)); + memref::StoreOp::create(rewriter, op.getLoc(), cond, conditionMemref, + getConstValue(i)); for (auto [i, toRetain] : llvm::enumerate(adaptor.getRetained())) { - Value memrefAsIdx = - rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(), - toRetain); - rewriter.create<memref::StoreOp>(op.getLoc(), memrefAsIdx, toRetainMemref, - getConstValue(i)); + Value memrefAsIdx = memref::ExtractAlignedPointerAsIndexOp::create( + rewriter, op.getLoc(), toRetain); + memref::StoreOp::create(rewriter, op.getLoc(), memrefAsIdx, + toRetainMemref, getConstValue(i)); } // Cast the allocated memrefs to dynamic shape because we want only one // helper function no matter how many operands the bufferization.dealloc // has. 
- Value castedDeallocMemref = rewriter.create<memref::CastOp>( - op->getLoc(), + Value castedDeallocMemref = memref::CastOp::create( + rewriter, op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), toDeallocMemref); - Value castedCondsMemref = rewriter.create<memref::CastOp>( - op->getLoc(), + Value castedCondsMemref = memref::CastOp::create( + rewriter, op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), conditionMemref); - Value castedRetainMemref = rewriter.create<memref::CastOp>( - op->getLoc(), + Value castedRetainMemref = memref::CastOp::create( + rewriter, op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getIndexType()), toRetainMemref); - Value deallocCondsMemref = rewriter.create<memref::AllocOp>( - op.getLoc(), MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, - rewriter.getI1Type())); - Value retainCondsMemref = rewriter.create<memref::AllocOp>( - op.getLoc(), MemRefType::get({(int64_t)adaptor.getRetained().size()}, - rewriter.getI1Type())); - - Value castedDeallocCondsMemref = rewriter.create<memref::CastOp>( - op->getLoc(), + Value deallocCondsMemref = memref::AllocOp::create( + rewriter, op.getLoc(), + MemRefType::get({(int64_t)adaptor.getMemrefs().size()}, + rewriter.getI1Type())); + Value retainCondsMemref = memref::AllocOp::create( + rewriter, op.getLoc(), + MemRefType::get({(int64_t)adaptor.getRetained().size()}, + rewriter.getI1Type())); + + Value castedDeallocCondsMemref = memref::CastOp::create( + rewriter, op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), deallocCondsMemref); - Value castedRetainCondsMemref = rewriter.create<memref::CastOp>( - op->getLoc(), + Value castedRetainCondsMemref = memref::CastOp::create( + rewriter, op->getLoc(), MemRefType::get({ShapedType::kDynamic}, rewriter.getI1Type()), retainCondsMemref); Operation *symtableOp = op->getParentWithTrait<OpTrait::SymbolTable>(); - rewriter.create<func::CallOp>( - op.getLoc(), deallocHelperFuncMap.lookup(symtableOp), + func::CallOp::create( + rewriter, op.getLoc(), deallocHelperFuncMap.lookup(symtableOp), SmallVector<Value>{castedDeallocMemref, castedRetainMemref, castedCondsMemref, castedDeallocCondsMemref, castedRetainCondsMemref}); for (unsigned i = 0, e = adaptor.getMemrefs().size(); i < e; ++i) { Value idxValue = getConstValue(i); - Value shouldDealloc = rewriter.create<memref::LoadOp>( - op.getLoc(), deallocCondsMemref, idxValue); - rewriter.create<scf::IfOp>( - op.getLoc(), shouldDealloc, [&](OpBuilder &builder, Location loc) { - builder.create<memref::DeallocOp>(loc, adaptor.getMemrefs()[i]); - builder.create<scf::YieldOp>(loc); - }); + Value shouldDealloc = memref::LoadOp::create( + rewriter, op.getLoc(), deallocCondsMemref, idxValue); + scf::IfOp::create(rewriter, op.getLoc(), shouldDealloc, + [&](OpBuilder &builder, Location loc) { + memref::DeallocOp::create(builder, loc, + adaptor.getMemrefs()[i]); + scf::YieldOp::create(builder, loc); + }); } SmallVector<Value> replacements; for (unsigned i = 0, e = adaptor.getRetained().size(); i < e; ++i) { Value idxValue = getConstValue(i); - Value ownership = rewriter.create<memref::LoadOp>( - op.getLoc(), retainCondsMemref, idxValue); + Value ownership = memref::LoadOp::create(rewriter, op.getLoc(), + retainCondsMemref, idxValue); replacements.push_back(ownership); } // Deallocate above allocated memrefs again to avoid memory leaks. // Deallocation will not be run on code after this stage. 
- rewriter.create<memref::DeallocOp>(op.getLoc(), toDeallocMemref); - rewriter.create<memref::DeallocOp>(op.getLoc(), toRetainMemref); - rewriter.create<memref::DeallocOp>(op.getLoc(), conditionMemref); - rewriter.create<memref::DeallocOp>(op.getLoc(), deallocCondsMemref); - rewriter.create<memref::DeallocOp>(op.getLoc(), retainCondsMemref); + memref::DeallocOp::create(rewriter, op.getLoc(), toDeallocMemref); + memref::DeallocOp::create(rewriter, op.getLoc(), toRetainMemref); + memref::DeallocOp::create(rewriter, op.getLoc(), conditionMemref); + memref::DeallocOp::create(rewriter, op.getLoc(), deallocCondsMemref); + memref::DeallocOp::create(rewriter, op.getLoc(), retainCondsMemref); rewriter.replaceOp(op, replacements); return success(); @@ -349,8 +354,8 @@ public: ConversionPatternRewriter &rewriter) const override { // Lower the trivial case. if (adaptor.getMemrefs().empty()) { - Value falseVal = rewriter.create<arith::ConstantOp>( - op.getLoc(), rewriter.getBoolAttr(false)); + Value falseVal = arith::ConstantOp::create(rewriter, op.getLoc(), + rewriter.getBoolAttr(false)); rewriter.replaceOp( op, SmallVector<Value>(adaptor.getRetained().size(), falseVal)); return success(); @@ -449,93 +454,92 @@ func::FuncOp mlir::bufferization::buildDeallocationLibraryFunction( Value retainCondsMemref = helperFuncOp.getArguments()[4]; // Insert some prerequisites. - Value c0 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(0)); - Value c1 = builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1)); + Value c0 = arith::ConstantOp::create(builder, loc, builder.getIndexAttr(0)); + Value c1 = arith::ConstantOp::create(builder, loc, builder.getIndexAttr(1)); Value trueValue = - builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(true)); + arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true)); Value falseValue = - builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(false)); - Value toDeallocSize = builder.create<memref::DimOp>(loc, toDeallocMemref, c0); - Value toRetainSize = builder.create<memref::DimOp>(loc, toRetainMemref, c0); + arith::ConstantOp::create(builder, loc, builder.getBoolAttr(false)); + Value toDeallocSize = + memref::DimOp::create(builder, loc, toDeallocMemref, c0); + Value toRetainSize = memref::DimOp::create(builder, loc, toRetainMemref, c0); - builder.create<scf::ForOp>( - loc, c0, toRetainSize, c1, ValueRange(), + scf::ForOp::create( + builder, loc, c0, toRetainSize, c1, ValueRange(), [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - builder.create<memref::StoreOp>(loc, falseValue, retainCondsMemref, i); - builder.create<scf::YieldOp>(loc); + memref::StoreOp::create(builder, loc, falseValue, retainCondsMemref, i); + scf::YieldOp::create(builder, loc); }); - builder.create<scf::ForOp>( - loc, c0, toDeallocSize, c1, ValueRange(), + scf::ForOp::create( + builder, loc, c0, toDeallocSize, c1, ValueRange(), [&](OpBuilder &builder, Location loc, Value outerIter, ValueRange iterArgs) { Value toDealloc = - builder.create<memref::LoadOp>(loc, toDeallocMemref, outerIter); + memref::LoadOp::create(builder, loc, toDeallocMemref, outerIter); Value cond = - builder.create<memref::LoadOp>(loc, conditionMemref, outerIter); + memref::LoadOp::create(builder, loc, conditionMemref, outerIter); // Build the first for loop that computes aliasing with retained // memrefs. 
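All loops in this helper use the body-builder overload of `scf::ForOp::create`: the callback receives the induction variable and the iteration arguments and is responsible for emitting the `scf.yield` terminator itself. The initialization loop above, pulled out as a minimal sketch (parameters stand in for the values that are in scope at the call site):

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Sketch: for (i = c0; i < size; i += c1) retainConds[i] = false;
static void initRetainConds(OpBuilder &builder, Location loc, Value c0,
                            Value size, Value c1, Value falseValue,
                            Value retainCondsMemref) {
  scf::ForOp::create(
      builder, loc, c0, size, c1, ValueRange(),
      [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) {
        memref::StoreOp::create(builder, loc, falseValue, retainCondsMemref, i);
        scf::YieldOp::create(builder, loc);
      });
}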
- Value noRetainAlias = - builder - .create<scf::ForOp>( - loc, c0, toRetainSize, c1, trueValue, + Value + noRetainAlias = + scf::ForOp::create( + builder, loc, c0, toRetainSize, c1, trueValue, [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - Value retainValue = builder.create<memref::LoadOp>( - loc, toRetainMemref, i); - Value doesAlias = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, retainValue, + Value retainValue = memref::LoadOp::create( + builder, loc, toRetainMemref, i); + Value doesAlias = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, retainValue, toDealloc); - builder.create<scf::IfOp>( - loc, doesAlias, + scf::IfOp::create( + builder, loc, doesAlias, [&](OpBuilder &builder, Location loc) { - Value retainCondValue = - builder.create<memref::LoadOp>( - loc, retainCondsMemref, i); - Value aggregatedRetainCond = - builder.create<arith::OrIOp>( - loc, retainCondValue, cond); - builder.create<memref::StoreOp>( - loc, aggregatedRetainCond, retainCondsMemref, - i); - builder.create<scf::YieldOp>(loc); + Value retainCondValue = memref::LoadOp::create( + builder, loc, retainCondsMemref, i); + Value aggregatedRetainCond = arith::OrIOp::create( + builder, loc, retainCondValue, cond); + memref::StoreOp::create(builder, loc, + aggregatedRetainCond, + retainCondsMemref, i); + scf::YieldOp::create(builder, loc); }); - Value doesntAlias = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ne, retainValue, + Value doesntAlias = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::ne, retainValue, toDealloc); - Value yieldValue = builder.create<arith::AndIOp>( - loc, iterArgs[0], doesntAlias); - builder.create<scf::YieldOp>(loc, yieldValue); + Value yieldValue = arith::AndIOp::create( + builder, loc, iterArgs[0], doesntAlias); + scf::YieldOp::create(builder, loc, yieldValue); }) - .getResult(0); + .getResult(0); // Build the second for loop that adds aliasing with previously // deallocated memrefs. 
- Value noAlias = - builder - .create<scf::ForOp>( - loc, c0, outerIter, c1, noRetainAlias, + Value + noAlias = + scf::ForOp::create( + builder, loc, c0, outerIter, c1, noRetainAlias, [&](OpBuilder &builder, Location loc, Value i, ValueRange iterArgs) { - Value prevDeallocValue = builder.create<memref::LoadOp>( - loc, toDeallocMemref, i); - Value doesntAlias = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::ne, prevDeallocValue, - toDealloc); - Value yieldValue = builder.create<arith::AndIOp>( - loc, iterArgs[0], doesntAlias); - builder.create<scf::YieldOp>(loc, yieldValue); + Value prevDeallocValue = memref::LoadOp::create( + builder, loc, toDeallocMemref, i); + Value doesntAlias = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::ne, + prevDeallocValue, toDealloc); + Value yieldValue = arith::AndIOp::create( + builder, loc, iterArgs[0], doesntAlias); + scf::YieldOp::create(builder, loc, yieldValue); }) - .getResult(0); + .getResult(0); - Value shouldDealoc = builder.create<arith::AndIOp>(loc, noAlias, cond); - builder.create<memref::StoreOp>(loc, shouldDealoc, deallocCondsMemref, - outerIter); - builder.create<scf::YieldOp>(loc); + Value shouldDealoc = arith::AndIOp::create(builder, loc, noAlias, cond); + memref::StoreOp::create(builder, loc, shouldDealoc, deallocCondsMemref, + outerIter); + scf::YieldOp::create(builder, loc); }); - builder.create<func::ReturnOp>(loc); + func::ReturnOp::create(builder, loc); return helperFuncOp; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index d1d1062..aa53f94 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -1,4 +1,5 @@ -//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===// +//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries +//----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,12 +9,13 @@ // // Module Bufferization is an extension of One-Shot Bufferize that // bufferizes function boundaries. It provides `BufferizableOpInterface` -// implementations for FuncOp, CallOp and ReturnOp. +// implementations for FuncOp, CallOp and ReturnOp. Although it is named +// Module Bufferization, it may operate on any SymbolTable. // -// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`. -// This function analyzes the given module and determines the order of analysis -// and bufferization: Functions that are called are processed before their -// respective callers. +// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp, +// ...)`. This function analyzes the given op and determines the order of +// analysis and bufferization: Functions that are called are processed before +// their respective callers. // // After analyzing a FuncOp, additional information about its bbArgs is // gathered and stored in `FuncAnalysisState`. @@ -309,7 +311,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) { /// Return `failure()` if we are unable to retrieve the called FuncOp from /// any func::CallOp. 
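Both aliasing loops above share one reduction idiom: an `i1` flag is threaded through `iter_args`, one comparison is ANDed in per iteration, and the final flag is read back with `getResult(0)`. Reduced to a sketch for the second loop (the name and signature are illustrative):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Sketch: AND over i in [c0, outerIter) of (toDeallocMemref[i] != toDealloc),
// seeded with `init`.
static Value buildNoAliasReduction(OpBuilder &builder, Location loc, Value c0,
                                   Value outerIter, Value c1, Value init,
                                   Value toDeallocMemref, Value toDealloc) {
  return scf::ForOp::create(
             builder, loc, c0, outerIter, c1, init,
             [&](OpBuilder &builder, Location loc, Value i,
                 ValueRange iterArgs) {
               Value prev =
                   memref::LoadOp::create(builder, loc, toDeallocMemref, i);
               Value doesntAlias = arith::CmpIOp::create(
                   builder, loc, arith::CmpIPredicate::ne, prev, toDealloc);
               Value yieldValue = arith::AndIOp::create(builder, loc,
                                                        iterArgs[0],
                                                        doesntAlias);
               scf::YieldOp::create(builder, loc, yieldValue);
             })
      .getResult(0);
}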
static LogicalResult getFuncOpsOrderedByCalls( - ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, + Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables) { // For each FuncOp, the set of functions called by it (i.e. the union of @@ -317,26 +319,29 @@ static LogicalResult getFuncOpsOrderedByCalls( DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy; // For each FuncOp, the number of func::CallOp it contains. DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp; - - for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) { - // Collect function calls and populate the caller map. - numberCallOpsContainedInFuncOp[funcOp] = 0; - WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { - func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables); - assert(calledFunction && "could not retrieved called func::FuncOp"); - // If the called function does not have any tensors in its signature, then - // it is not necessary to bufferize the callee before the caller. - if (!hasTensorSignature(calledFunction)) - return WalkResult::skip(); - - callerMap[calledFunction].insert(callOp); - if (calledBy[calledFunction].insert(funcOp).second) { - numberCallOpsContainedInFuncOp[funcOp]++; + for (mlir::Region &region : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) { + // Collect function calls and populate the caller map. + numberCallOpsContainedInFuncOp[funcOp] = 0; + WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { + func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables); + assert(calledFunction && "could not retrieved called func::FuncOp"); + // If the called function does not have any tensors in its signature, + // then it is not necessary to bufferize the callee before the caller.
+ if (!hasTensorSignature(calledFunction)) + return WalkResult::skip(); + + callerMap[calledFunction].insert(callOp); + if (calledBy[calledFunction].insert(funcOp).second) { + numberCallOpsContainedInFuncOp[funcOp]++; + } + return WalkResult::advance(); + }); + if (res.wasInterrupted()) + return failure(); } - return WalkResult::advance(); - }); - if (res.wasInterrupted()) - return failure(); + } } // Iteratively remove function operations that do not call any of the @@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) { } LogicalResult -mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, +mlir::bufferization::analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics) { assert(state.getOptions().bufferizeFunctionBoundaries && @@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, } void mlir::bufferization::removeBufferizationAttributesInModule( - ModuleOp moduleOp) { - for (auto op : moduleOp.getOps<func::FuncOp>()) { - for (BlockArgument bbArg : op.getArguments()) - removeBufferizationAttributes(bbArg); + Operation *moduleOp) { + for (mlir::Region &region : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) { + for (BlockArgument bbArg : funcOp.getArguments()) + removeBufferizationAttributes(bbArg); + } + } } } LogicalResult mlir::bufferization::bufferizeModuleOp( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, + Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); - IRRewriter rewriter(moduleOp.getContext()); + IRRewriter rewriter(moduleOp->getContext()); // A list of non-circular functions in the order in which they are analyzed // and bufferized. @@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } // Bufferize all other ops. - for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) { - // Functions were already bufferized. - if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>()) - continue; - if (failed(bufferizeOp(&op, options, state, statistics))) - return failure(); + for (mlir::Region &region : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (mlir::Operation &op : + llvm::make_early_inc_range(block.getOperations())) { + // Functions were already bufferized. + if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>()) + continue; + if (failed(bufferizeOp(&op, options, state, statistics))) + return failure(); + } + } } // Post-pass cleanup of function argument attributes.
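With the entry points generalized from `ModuleOp` to `Operation *`, the convenient `moduleOp.getOps<func::FuncOp>()` iteration is no longer available, so each of the hunks above walks the symbol table op's regions and blocks explicitly. The recurring traversal, factored into a sketch helper (the callback shape is illustrative, not part of the patch):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLFunctionalExtras.h"

using namespace mlir;

// Sketch: visit every func.func directly nested under any region/block of
// `symbolTableOp`, matching the loops used above.
static void forEachNestedFuncOp(Operation *symbolTableOp,
                                llvm::function_ref<void(func::FuncOp)> fn) {
  for (Region &region : symbolTableOp->getRegions())
    for (Block &block : region.getBlocks())
      for (func::FuncOp funcOp : block.getOps<func::FuncOp>())
        fn(funcOp);
}

Only directly nested functions are visited, which preserves the previous `ModuleOp` behavior while letting the bufferization run on any op carrying the `SymbolTable` trait.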
@@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } LogicalResult mlir::bufferization::runOneShotModuleBufferize( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, + Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp index 605a487..b8ddee6 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OptimizeAllocationLiveness.cpp @@ -18,11 +18,9 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "optimize-allocation-liveness" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") namespace mlir { namespace bufferization { @@ -65,8 +63,8 @@ Operation *findUserWithFreeSideEffect(Value value) { for (const auto &effect : effects) { if (isa<MemoryEffects::Free>(effect.getEffect())) { if (freeOpUser) { - LDBG("Multiple users with free effect found: " << *freeOpUser - << " and " << *user); + LDBG() << "Multiple users with free effect found: " << *freeOpUser + << " and " << *user; return nullptr; } freeOpUser = user; @@ -121,7 +119,7 @@ public: return WalkResult::advance(); auto allocOp = memEffectOp; - LDBG("Checking alloc op: " << allocOp); + LDBG() << "Checking alloc op: " << allocOp; SmallVector<OpResult> allocationResults = collectAllocations(allocOp); // Multiple allocations from a single op are not considered here yet. @@ -129,7 +127,7 @@ public: return WalkResult::advance(); OpResult allocResult = allocationResults[0]; - LDBG("On allocation result: " << allocResult); + LDBG() << "On allocation result: " << allocResult; auto *deallocOp = findUserWithFreeSideEffect(allocResult); if (!deallocOp || (deallocOp->getBlock() != allocOp->getBlock())) { @@ -159,12 +157,12 @@ public: if (lastUser == nullptr) { return WalkResult::advance(); } - LDBG("Last user found: " << *lastUser); + LDBG() << "Last user found: " << *lastUser; assert(lastUser->getBlock() == allocOp->getBlock()); assert(lastUser->getBlock() == deallocOp->getBlock()); // Move the dealloc op after the last user. 
deallocOp->moveAfter(lastUser); - LDBG("Moved dealloc op after: " << *lastUser); + LDBG() << "Moved dealloc op after: " << *lastUser; return WalkResult::advance(); }); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index 1eeafc4..725fa24 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -43,7 +43,7 @@ using namespace mlir::bufferization; //===----------------------------------------------------------------------===// static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) { - return builder.create<arith::ConstantOp>(loc, builder.getBoolAttr(value)); + return arith::ConstantOp::create(builder, loc, builder.getBoolAttr(value)); } static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); } @@ -750,19 +750,17 @@ Value BufferDeallocation::materializeMemrefWithGuaranteedOwnership( // Insert a runtime check and only clone if we still don't have ownership at // runtime. - Value maybeClone = - builder - .create<scf::IfOp>( - memref.getLoc(), condition, - [&](OpBuilder &builder, Location loc) { - builder.create<scf::YieldOp>(loc, newMemref); - }, - [&](OpBuilder &builder, Location loc) { - Value clone = - builder.create<bufferization::CloneOp>(loc, newMemref); - builder.create<scf::YieldOp>(loc, clone); - }) - .getResult(0); + Value maybeClone = scf::IfOp::create( + builder, memref.getLoc(), condition, + [&](OpBuilder &builder, Location loc) { + scf::YieldOp::create(builder, loc, newMemref); + }, + [&](OpBuilder &builder, Location loc) { + Value clone = bufferization::CloneOp::create( + builder, loc, newMemref); + scf::YieldOp::create(builder, loc, clone); + }) + .getResult(0); Value trueVal = buildBoolValue(builder, memref.getLoc(), true); state.updateOwnership(maybeClone, trueVal); state.addMemrefToDeallocate(maybeClone, maybeClone.getParentBlock()); @@ -797,8 +795,8 @@ BufferDeallocation::handleInterface(BranchOpInterface op) { state.getMemrefsToRetain(block, op->getSuccessor(0), forwardedOperands, toRetain); - auto deallocOp = builder.create<bufferization::DeallocOp>( - op.getLoc(), memrefs, conditions, toRetain); + auto deallocOp = bufferization::DeallocOp::create( + builder, op.getLoc(), memrefs, conditions, toRetain); // We want to replace the current ownership of the retained values with the // result values of the dealloc operation as they are always unique. 
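Note: this file, like most debug-logging changes in this commit, drops the hand-rolled DBGS()/LDBG(X) macro pair in favor of the shared LDBG() stream from llvm/Support/DebugLog.h. A before/after sketch of the conversion, assuming DEBUG_TYPE is defined in the translation unit:

    // Before: per-file macros layered over llvm/Support/Debug.h.
    #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
    #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
    LDBG("Checking alloc op: " << allocOp);

    // After: llvm/Support/DebugLog.h supplies LDBG(); the "[debug-type] "
    // prefix and the trailing newline come from the helper itself.
    LDBG() << "Checking alloc op: " << allocOp;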
@@ -885,12 +883,11 @@ BufferDeallocation::handleInterface(MemoryEffectOpInterface op) { builder.setInsertionPoint(op); Ownership ownership = state.getOwnership(operand, block); if (ownership.isUnique()) { - Value ownershipInverted = builder.create<arith::XOrIOp>( - op.getLoc(), ownership.getIndicator(), + Value ownershipInverted = arith::XOrIOp::create( + builder, op.getLoc(), ownership.getIndicator(), buildBoolValue(builder, op.getLoc(), true)); - builder.create<cf::AssertOp>( - op.getLoc(), ownershipInverted, - "expected that the block does not have ownership"); + cf::AssertOp::create(builder, op.getLoc(), ownershipInverted, + "expected that the block does not have ownership"); } } } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp index f999c93..a6159ee 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -33,7 +33,7 @@ LogicalResult mlir::bufferization::insertTensorCopies( // analysis depending on whether function boundary bufferization is enabled or // not. if (options.bufferizeFunctionBoundaries) { - if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics))) + if (failed(analyzeModuleOp(op, analysisState, statistics))) return failure(); } else { if (failed(analyzeOp(op, analysisState, statistics))) diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt index 3cc52eb..053ee95 100644 --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -19,7 +19,7 @@ add_subdirectory(Linalg) add_subdirectory(LLVMIR) add_subdirectory(Math) add_subdirectory(MemRef) -add_subdirectory(Mesh) +add_subdirectory(Shard) add_subdirectory(MLProgram) add_subdirectory(MPI) add_subdirectory(NVGPU) diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 568da89..4c09022 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -171,10 +171,9 @@ static LogicalResult verifyInitializationAttribute(Operation *op, /// In the format string, all `{}` are replaced by Placeholders, except if the /// `{` is escaped by `{{` - then it doesn't start a placeholder. template <class ArgType> -FailureOr<SmallVector<ReplacementItem>> -parseFormatString(StringRef toParse, ArgType fmtArgs, - std::optional<llvm::function_ref<mlir::InFlightDiagnostic()>> - emitError = {}) { +FailureOr<SmallVector<ReplacementItem>> parseFormatString( + StringRef toParse, ArgType fmtArgs, + llvm::function_ref<mlir::InFlightDiagnostic()> emitError = {}) { SmallVector<ReplacementItem> items; // If there are no operands, the format string is not interpreted. @@ -197,8 +196,7 @@ parseFormatString(StringRef toParse, ArgType fmtArgs, continue; } if (toParse.size() < 2) { - return (*emitError)() - << "expected '}' after unescaped '{' at end of string"; + return emitError() << "expected '}' after unescaped '{' at end of string"; } // toParse contains at least two characters and starts with `{`.
char nextChar = toParse[1]; @@ -214,8 +212,8 @@ parseFormatString(StringRef toParse, ArgType fmtArgs, continue; } - if (emitError.has_value()) { - return (*emitError)() << "expected '}' after unescaped '{'"; + if (emitError) { + return emitError() << "expected '}' after unescaped '{'"; } return failure(); } diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp index d5fe3b4..3f0690c 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp @@ -62,9 +62,7 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> { continue; for (Value operand : op.getOperands()) { - auto usedExpression = - dyn_cast_if_present<ExpressionOp>(operand.getDefiningOp()); - + auto usedExpression = operand.getDefiningOp<ExpressionOp>(); if (!usedExpression) continue; diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp index 612e809..fa05ad8 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp @@ -31,7 +31,7 @@ struct WrapFuncInClassPass Operation *rootOp = getOperation(); RewritePatternSet patterns(&getContext()); - populateFuncPatterns(patterns, namedAttribute); + populateFuncPatterns(patterns); walkAndApplyPatterns(rootOp, std::move(patterns)); } @@ -43,8 +43,8 @@ struct WrapFuncInClassPass class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> { public: - WrapFuncInClass(MLIRContext *context, StringRef attrName) - : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {} + WrapFuncInClass(MLIRContext *context) + : OpRewritePattern<emitc::FuncOp>(context) {} LogicalResult matchAndRewrite(emitc::FuncOp funcOp, PatternRewriter &rewriter) const override { @@ -101,12 +101,8 @@ public: rewriter.replaceOp(funcOp, newClassOp); return success(); } - -private: - StringRef attributeName; }; -void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns, - StringRef namedAttribute) { - patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute); +void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) { + patterns.add<WrapFuncInClass>(patterns.getContext()); } diff --git a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp index eb6b59b..1b18ef2 100644 --- a/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp +++ b/mlir/lib/Dialect/Func/Extensions/AllExtensions.cpp @@ -8,7 +8,7 @@ #include "mlir/Dialect/Func/Extensions/AllExtensions.h" #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" -#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h" +#include "mlir/Dialect/Func/Extensions/ShardingExtensions.h" using namespace mlir; diff --git a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt index 47363f4..87ef51e 100644 --- a/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt +++ b/mlir/lib/Dialect/Func/Extensions/CMakeLists.txt @@ -1,7 +1,7 @@ set(LLVM_OPTIONAL_SOURCES AllExtensions.cpp InlinerExtension.cpp - MeshShardingExtensions.cpp + ShardingExtensions.cpp ) add_mlir_extension_library(MLIRFuncInlinerExtension @@ -17,8 +17,8 @@ add_mlir_extension_library(MLIRFuncInlinerExtension MLIRFuncDialect ) -add_mlir_extension_library(MLIRFuncMeshShardingExtensions - MeshShardingExtensions.cpp +add_mlir_extension_library(MLIRFuncShardingExtensions + ShardingExtensions.cpp ADDITIONAL_HEADER_DIRS 
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Func/Extensions @@ -38,5 +38,5 @@ add_mlir_extension_library(MLIRFuncAllExtensions LINK_LIBS PUBLIC MLIRFuncInlinerExtension - MLIRFuncMeshShardingExtensions + MLIRFuncShardingExtensions ) diff --git a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp index da508cc..dfd1348 100644 --- a/mlir/lib/Dialect/Func/Extensions/MeshShardingExtensions.cpp +++ b/mlir/lib/Dialect/Func/Extensions/ShardingExtensions.cpp @@ -1,4 +1,4 @@ -//===- MeshShardingExtensions.cpp - ---------------------------------------===// +//===- ShardingExtensions.cpp - ---------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,9 +6,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Func/Extensions/MeshShardingExtensions.h" +#include "mlir/Dialect/Func/Extensions/ShardingExtensions.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" #include "mlir/IR/MLIRContext.h" namespace mlir::func { @@ -16,7 +16,7 @@ void registerShardingInterfaceExternalModels(DialectRegistry &registry) { registry.addExtension(+[](MLIRContext *ctx, FuncDialect *dialect) { ReturnOp::attachInterface< - mesh::IndependentParallelIteratorDomainShardingInterface<ReturnOp>>( + shard::IndependentParallelIteratorDomainShardingInterface<ReturnOp>>( *ctx); }); } diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index d186a48..5a72ef1 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -1395,40 +1395,12 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value, // RotateOp //===----------------------------------------------------------------------===// -void RotateOp::build(OpBuilder &builder, OperationState &result, Value value, - int32_t offset, int32_t width) { - build(builder, result, value, - arith::ConstantOp::create(builder, result.location, - builder.getI32IntegerAttr(offset)), - arith::ConstantOp::create(builder, result.location, - builder.getI32IntegerAttr(width))); -} - LogicalResult RotateOp::verify() { - auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>(); - if (!offsetConstOp) - return emitOpError() << "offset is not a constant value"; - - auto offsetIntAttr = - llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue()); - - auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>(); - if (!widthConstOp) - return emitOpError() << "width is not a constant value"; - - auto widthIntAttr = - llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue()); - - llvm::APInt offsetValue = offsetIntAttr.getValue(); - llvm::APInt widthValue = widthIntAttr.getValue(); - - if (!widthValue.isPowerOf2()) - return emitOpError() << "width must be a power of two"; + uint32_t offset = getOffset(); + uint32_t width = getWidth(); - if (offsetValue.sge(widthValue) || offsetValue.slt(0)) { - int64_t widthValueInt = widthValue.getSExtValue(); - return emitOpError() << "offset must be in the range [0, " << widthValueInt - << ")"; + if (offset >= width) { + return emitOpError() << "offset must be in the range [0, " << width << ")"; } return success(); diff --git
a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 1d8279c..21cb2f6 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -39,7 +39,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/LogicalResult.h" @@ -51,11 +51,6 @@ using namespace mlir::transform; using namespace mlir::transform::gpu; #define DEBUG_TYPE "gpu-transforms" -#define DEBUG_TYPE_ALIAS "gpu-transforms-alias" - -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") -#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") //===----------------------------------------------------------------------===// // Apply...ConversionPatternsOp @@ -471,7 +466,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp, ArrayRef<int64_t> availableMappingSizes, ForallRewriteResult &result, const GpuIdBuilder &gpuIdBuilder) { - LDBG("--start rewriteOneForallCommonImpl"); + LDBG() << "--start rewriteOneForallCommonImpl"; // Step 1. Complete the mapping to a full mapping (with 1s) if necessary. auto numParallelIterations = @@ -506,14 +501,14 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( // Otherwise, we have a new insertion without a size -> use size 1. tmpMappingSizes.push_back(1); } - LDBG("----tmpMappingSizes extracted from scf.forall op: " - << llvm::interleaved(tmpMappingSizes)); + LDBG() << "----tmpMappingSizes extracted from scf.forall op: " + << llvm::interleaved(tmpMappingSizes); // Step 2. sort the values by the corresponding DeviceMappingAttrInterface. SmallVector<int64_t> forallMappingSizes = getValuesSortedByKey( forallMappingAttrs.getArrayRef(), tmpMappingSizes, comparator); - LDBG("----forallMappingSizes: " << llvm::interleaved(forallMappingSizes)); - LDBG("----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs)); + LDBG() << "----forallMappingSizes: " << llvm::interleaved(forallMappingSizes); + LDBG() << "----forallMappingAttrs: " << llvm::interleaved(forallMappingAttrs); // Step 3. Generate the mappingIdOps using the provided generator. 
Location loc = forallOp.getLoc(); @@ -522,24 +517,24 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( SmallVector<int64_t> originalBasis(availableMappingSizes); bool originalBasisWasProvided = !originalBasis.empty(); if (!originalBasisWasProvided) { - LDBG("----originalBasis was not provided, deriving it and there will be no " - "predication"); + LDBG() << "----originalBasis was not provided, deriving it and there will " + "be no " + "predication"; originalBasis = forallMappingSizes; while (originalBasis.size() < 3) originalBasis.push_back(1); } else { - LDBG("----originalBasis was provided, using it, there will be predication"); + LDBG() << "----originalBasis was provided, using it, there will be " + "predication"; } - LLVM_DEBUG( - llvm::interleaveComma(originalBasis, DBGS() << "------originalBasis: "); - llvm::dbgs() << "\n"); + LDBG() << "------originalBasis: " << llvm::interleaved(originalBasis); IdBuilderResult builderResult = gpuIdBuilder.idBuilder(rewriter, loc, forallMappingSizes, originalBasis); if (!builderResult.errorMsg.empty()) return definiteFailureHelper(transformOp, forallOp, builderResult.errorMsg); - LLVM_DEBUG(DBGS() << builderResult); + LDBG() << builderResult; // Step 4. Map the induction variables to the mappingIdOps, this may involve // a permutation. @@ -550,7 +545,7 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( forallMappingAttrs.getArrayRef().take_front(forallOp.getRank()))) { auto mappingAttr = cast<DeviceMappingAttrInterface>(dim); Value peIdOp = mappingIdOps[mappingAttr.getRelativeIndex()]; - LDBG("----map: " << iv << " to " << peIdOp); + LDBG() << "----map: " << iv << " to " << peIdOp; bvm.map(iv, peIdOp); } @@ -596,9 +591,9 @@ static DiagnosedSilenceableFailure rewriteOneForallCommonImpl( // Step 8. Erase old op. rewriter.eraseOp(forallOp); - LDBG("----result forallMappingSizes: " - << llvm::interleaved(forallMappingSizes)); - LDBG("----result mappingIdOps: " << llvm::interleaved(mappingIdOps)); + LDBG() << "----result forallMappingSizes: " + << llvm::interleaved(forallMappingSizes); + LDBG() << "----result mappingIdOps: " << llvm::interleaved(mappingIdOps); result = ForallRewriteResult{forallMappingSizes, mappingIdOps}; return DiagnosedSilenceableFailure::success(); @@ -612,7 +607,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapForallToBlocksImpl( RewriterBase &rewriter, TransformOpInterface transformOp, scf::ForallOp forallOp, SmallVectorImpl<int64_t> &gridDims, const GpuIdBuilder &gpuIdBuilder) { - LDBG("Start mapForallToBlocksImpl"); + LDBG() << "Start mapForallToBlocksImpl"; { // GPU-specific verifications. 
There is no better place to anchor @@ -893,7 +888,7 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl( RewriterBase &rewriter, std::optional<TransformOpInterface> transformOp, Operation *target, ArrayRef<int64_t> blockDims, int64_t warpSize, bool syncAfterDistribute) { - LDBG("Start mapNestedForallToThreadsImpl"); + LDBG() << "Start mapNestedForallToThreadsImpl"; if (blockDims.size() != 3) { return definiteFailureHelper(transformOp, target, "requires size-3 thread mapping"); diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp index 2fba09b..05bd917 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp @@ -27,7 +27,8 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/InterleavedRange.h" using namespace mlir; using namespace mlir::gpu; @@ -36,10 +37,6 @@ using namespace mlir::transform::gpu; #define DEBUG_TYPE "gpu-transforms" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") -#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") - /// Build predicates to filter execution by only the activeIds. Along each /// dimension, 3 cases appear: /// 1. activeMappingSize > availableMappingSize: this is an unsupported case @@ -54,15 +51,9 @@ buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds, ArrayRef<int64_t> activeMappingSizes, ArrayRef<int64_t> availableMappingSizes, std::string &errorMsg) { - // clang-format off - LLVM_DEBUG( - llvm::interleaveComma( - activeMappingSizes, DBGS() << "----activeMappingSizes: "); - DBGS() << "\n"; - llvm::interleaveComma( - availableMappingSizes, DBGS() << "----availableMappingSizes: "); - DBGS() << "\n";); - // clang-format on + LDBG() << "----activeMappingSizes: " << llvm::interleaved(activeMappingSizes); + LDBG() << "----availableMappingSizes: " + << llvm::interleaved(availableMappingSizes); SmallVector<Value> predicateOps; for (auto [activeId, activeMappingSize, availableMappingSize] : @@ -88,10 +79,8 @@ buildPredicates(RewriterBase &rewriter, Location loc, ArrayRef<Value> activeIds, template <typename ThreadOrBlockIdOp> static Value buildLinearId(RewriterBase &rewriter, Location loc, ArrayRef<OpFoldResult> originalBasisOfr) { - LLVM_DEBUG(llvm::interleaveComma( - originalBasisOfr, - DBGS() << "----buildLinearId with originalBasisOfr: "); - llvm::dbgs() << "\n"); + LDBG() << "----buildLinearId with originalBasisOfr: " + << llvm::interleaved(originalBasisOfr); assert(originalBasisOfr.size() == 3 && "expected 3 sizes"); IndexType indexType = rewriter.getIndexType(); AffineExpr tx, ty, tz, bdx, bdy; @@ -157,7 +146,7 @@ commonLinearIdBuilderFn(int64_t multiplicity = 1, mask.createLogicalLinearMappingId(rewriter, scaledLinearIdI64); scaledLinearId = arith::IndexCastUIOp::create( rewriter, loc, rewriter.getIndexType(), logicalLinearIdI64); - LDBG("------adjusting linearId with mask: " << scaledLinearId); + LDBG() << "------adjusting linearId with mask: " << scaledLinearId; } // 3. Compute remapped indices. 
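Note: besides the LDBG() conversion, the GPU transform hunks above swap manual llvm::interleaveComma loops for llvm::interleaved from llvm/Support/InterleavedRange.h, which wraps a range so it can be streamed in a single statement. A sketch, assuming a SmallVector<int64_t> named sizes:

    #include "llvm/Support/InterleavedRange.h"

    // Streams the elements separated by ", ", e.g. "----sizes: 2, 4, 8".
    LDBG() << "----sizes: " << llvm::interleaved(sizes);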
@@ -179,7 +168,7 @@ commonLinearIdBuilderFn(int64_t multiplicity = 1, if (mask) { Value isActiveIdPredicate = mask.createIsActiveIdPredicate(rewriter, scaledLinearIdI64); - LDBG("------adjusting predicate with mask: " << isActiveIdPredicate); + LDBG() << "------adjusting predicate with mask: " << isActiveIdPredicate; predicateOps.push_back(isActiveIdPredicate); } else { // 4.b. Otherwise, handle predicates using physicalLinearId. diff --git a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp index d88f4d5..8e05436 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ShuffleRewriter.cpp @@ -60,14 +60,12 @@ struct GpuShuffleRewriter : public OpRewritePattern<gpu::ShuffleOp> { // Shuffle the values. ValueRange loRes = - rewriter - .create<gpu::ShuffleOp>(op.getLoc(), lo, op.getOffset(), - op.getWidth(), op.getMode()) + gpu::ShuffleOp::create(rewriter, op.getLoc(), lo, op.getOffset(), + op.getWidth(), op.getMode()) .getResults(); ValueRange hiRes = - rewriter - .create<gpu::ShuffleOp>(op.getLoc(), hi, op.getOffset(), - op.getWidth(), op.getMode()) + gpu::ShuffleOp::create(rewriter, op.getLoc(), hi, op.getOffset(), + op.getWidth(), op.getMode()) .getResults(); // Convert lo back to i64. diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp index b9e2dd5..b45fdf3 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp @@ -197,10 +197,9 @@ Value createSubgroupShuffleReduction(OpBuilder &builder, Location loc, // Parallel reduction using butterfly shuffles. for (unsigned i = ci.clusterStride; i < ci.clusterStride * ci.clusterSize; i <<= 1) { - Value shuffled = builder - .create<gpu::ShuffleOp>(loc, packFn(laneVal), i, - /*width=*/ci.subgroupSize, - /*mode=*/gpu::ShuffleMode::XOR) + Value shuffled = gpu::ShuffleOp::create(builder, loc, packFn(laneVal), i, + /*width=*/ci.subgroupSize, + /*mode=*/gpu::ShuffleMode::XOR) .getShuffleResult(); laneVal = vector::makeArithReduction(builder, loc, gpu::convertReductionKind(mode), diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index d987b72..ff55f17 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -21,10 +21,7 @@ add_mlir_dialect_library(MLIRLLVMDialect intrinsics_gen LINK_COMPONENTS - AsmParser BinaryFormat - BitReader - BitWriter Core LINK_LIBS PUBLIC diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 5b01596..422039f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -26,8 +26,7 @@ #include "llvm/ADT/APFloat.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Type.h" +#include "llvm/IR/DataLayout.h" #include "llvm/Support/Error.h" #include <numeric> @@ -2707,7 +2706,7 @@ LogicalResult IFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) { while (alias) { Block &initBlock = alias.getInitializerBlock(); auto returnOp = cast<ReturnOp>(initBlock.getTerminator()); - auto addrOp = dyn_cast<AddressOfOp>(returnOp.getArg().getDefiningOp()); + auto addrOp = returnOp.getArg().getDefiningOp<AddressOfOp>(); // FIXME: This is a best effort solution. The AliasOp body might be more // complex and in that case we bail out with success. 
To completely match // the LLVM IR logic it would be necessary to implement proper alias and @@ -4064,28 +4063,9 @@ void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, } void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, - Value cond, - ArrayRef<llvm::OperandBundleDefT<Value>> opBundles) { - SmallVector<ValueRange> opBundleOperands; - SmallVector<Attribute> opBundleTags; - opBundleOperands.reserve(opBundles.size()); - opBundleTags.reserve(opBundles.size()); - - for (const llvm::OperandBundleDefT<Value> &bundle : opBundles) { - opBundleOperands.emplace_back(bundle.inputs()); - opBundleTags.push_back( - StringAttr::get(builder.getContext(), bundle.getTag())); - } - - auto opBundleTagsAttr = ArrayAttr::get(builder.getContext(), opBundleTags); - return build(builder, state, cond, opBundleOperands, opBundleTagsAttr); -} - -void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, Value cond, llvm::StringRef tag, ValueRange args) { - llvm::OperandBundleDefT<Value> opBundle( - tag.str(), SmallVector<Value>(args.begin(), args.end())); - return build(builder, state, cond, opBundle); + return build(builder, state, cond, ArrayRef<ValueRange>(args), + builder.getStrArrayAttr(tag)); } void LLVM::AssumeOp::build(OpBuilder &builder, OperationState &state, diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 6e29b12..52cd0ce 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -30,15 +30,9 @@ #include "mlir/IR/Types.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/AsmParser/Parser.h" -#include "llvm/IR/Attributes.h" -#include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IntrinsicsNVPTX.h" -#include "llvm/IR/Type.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" #include <cassert> #include <optional> diff --git a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp index 1a9ccf5..17371ec 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/ROCDLDialect.cpp @@ -24,7 +24,6 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/IR/Type.h" using namespace mlir; using namespace ROCDL; diff --git a/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp index bd9d3528..1d4a0af 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/VCIXDialect.cpp @@ -20,11 +20,6 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/AsmParser/Parser.h" -#include "llvm/IR/Attributes.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/Type.h" -#include "llvm/Support/SourceMgr.h" using namespace mlir; using namespace vcix; diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp index 935aa3c..b951df8 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp @@ -22,6 +22,8 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" + #define DEBUG_TYPE "llvm-inliner" using namespace mlir; @@ -670,44 +672,42 @@ struct LLVMInlinerInterface : public 
DialectInlinerInterface { bool wouldBeCloned) const final { auto callOp = dyn_cast<LLVM::CallOp>(call); if (!callOp) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '" - << LLVM::CallOp::getOperationName() << "' op\n"); + LDBG() << "Cannot inline: call is not an '" + << LLVM::CallOp::getOperationName() << "' op"; return false; } if (callOp.getNoInline()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is marked no_inline\n"); + LDBG() << "Cannot inline: call is marked no_inline"; return false; } auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable); if (!funcOp) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline: callable is not an '" - << LLVM::LLVMFuncOp::getOperationName() << "' op\n"); + LDBG() << "Cannot inline: callable is not an '" + << LLVM::LLVMFuncOp::getOperationName() << "' op"; return false; } if (funcOp.isNoInline()) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline: function is marked no_inline\n"); + LDBG() << "Cannot inline: function is marked no_inline"; return false; } if (funcOp.isVarArg()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n"); + LDBG() << "Cannot inline: callable is variadic"; return false; } // TODO: Generate aliasing metadata from noalias result attributes. if (auto attrs = funcOp.getArgAttrs()) { for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) { if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName() - << ": inalloca arguments not supported\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": inalloca arguments not supported"; return false; } } } // TODO: Handle exceptions. if (funcOp.getPersonality()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName() - << ": unhandled function personality\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": unhandled function personality"; return false; } if (funcOp.getPassthrough()) { @@ -717,10 +717,8 @@ struct LLVMInlinerInterface : public DialectInlinerInterface { if (!stringAttr) return false; if (disallowedFunctionAttrs.contains(stringAttr)) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot inline " << funcOp.getSymName() - << ": found disallowed function attribute " - << stringAttr << "\n"); + LDBG() << "Cannot inline " << funcOp.getSymName() + << ": found disallowed function attribute " << stringAttr; return true; } return false; diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp index b6e168e..7f6ecab 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp @@ -15,7 +15,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Interfaces/SubsetOpInterface.h" @@ -119,8 +119,8 @@ void mlir::linalg::LinalgDialect::initialize() { addInterfaces<LinalgInlinerInterface>(); - declarePromisedInterface<mesh::ShardingInterface, GenericOp>(); - declarePromisedInterfaces<mesh::ShardingInterface, + declarePromisedInterface<shard::ShardingInterface, GenericOp>(); + declarePromisedInterfaces<shard::ShardingInterface, #define GET_OP_LIST #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" >(); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp 
b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index f49d9a1..73ae029 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -476,10 +476,10 @@ inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps, SmallVector<unsigned, 2>(ac.begin(), ac.end()), SmallVector<unsigned, 2>(bc.begin(), bc.end()), SmallVector<unsigned, 2>(ra.begin(), ra.end())}; - llvm::sort(dimensions.batch.begin(), dimensions.batch.end()); - llvm::sort(dimensions.m.begin(), dimensions.m.end()); - llvm::sort(dimensions.n.begin(), dimensions.n.end()); - llvm::sort(dimensions.k.begin(), dimensions.k.end()); + llvm::sort(dimensions.batch); + llvm::sort(dimensions.m); + llvm::sort(dimensions.n); + llvm::sort(dimensions.k); return dimensions; } @@ -797,12 +797,12 @@ inferConvolutionDimsImpl(LinalgOp linalgOp, SmallVector<unsigned, 2>(depth.begin(), depth.end()), /*strides=*/SmallVector<int64_t, 2>{}, /*dilations=*/SmallVector<int64_t, 2>{}}; - llvm::sort(dimensions.batch.begin(), dimensions.batch.end()); - llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end()); - llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end()); - llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end()); - llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end()); - llvm::sort(dimensions.depth.begin(), dimensions.depth.end()); + llvm::sort(dimensions.batch); + llvm::sort(dimensions.outputImage); + llvm::sort(dimensions.outputChannel); + llvm::sort(dimensions.filterLoop); + llvm::sort(dimensions.inputChannel); + llvm::sort(dimensions.depth); // Use the op carried strides/dilations attribute if present. auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides"); diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 3aa6ac3..34c63d3 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -32,6 +32,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -62,10 +63,10 @@ static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v, return getAsOpFoldResult( TypeSwitch<Type, Value>(v.getType()) .Case<RankedTensorType>([&](RankedTensorType t) -> Value { - return builder.create<tensor::DimOp>(loc, v, dim); + return tensor::DimOp::create(builder, loc, v, dim); }) .Case<MemRefType>([&](MemRefType t) -> Value { - return builder.create<memref::DimOp>(loc, v, dim); + return memref::DimOp::create(builder, loc, v, dim); })); } @@ -77,12 +78,12 @@ static Operation *getSlice(OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> strides) { return TypeSwitch<Type, Operation *>(source.getType()) .Case<RankedTensorType>([&](RankedTensorType t) -> Operation * { - return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes, - strides); + return tensor::ExtractSliceOp::create(b, loc, source, offsets, sizes, + strides); }) .Case<MemRefType>([&](MemRefType type) -> Operation * { - return b.create<memref::SubViewOp>(loc, source, offsets, sizes, - strides); + return memref::SubViewOp::create(b, loc, source, offsets, sizes, + strides); }) .Default([&](Type t) -> Operation * { return nullptr; }); } @@ -453,35 +454,35 @@ public: builder.setInsertionPointToEnd(&block); switch (unaryFn) { case UnaryFn::exp: - 
return builder.create<math::ExpOp>(arg.getLoc(), arg); + return math::ExpOp::create(builder, arg.getLoc(), arg); case UnaryFn::log: - return builder.create<math::LogOp>(arg.getLoc(), arg); + return math::LogOp::create(builder, arg.getLoc(), arg); case UnaryFn::abs: - return builder.create<math::AbsFOp>(arg.getLoc(), arg); + return math::AbsFOp::create(builder, arg.getLoc(), arg); case UnaryFn::ceil: - return builder.create<math::CeilOp>(arg.getLoc(), arg); + return math::CeilOp::create(builder, arg.getLoc(), arg); case UnaryFn::floor: - return builder.create<math::FloorOp>(arg.getLoc(), arg); + return math::FloorOp::create(builder, arg.getLoc(), arg); case UnaryFn::negf: - return builder.create<arith::NegFOp>(arg.getLoc(), arg); + return arith::NegFOp::create(builder, arg.getLoc(), arg); case UnaryFn::reciprocal: { Attribute oneAttr = builder.getOneAttr(arg.getType()); - auto one = builder.create<arith::ConstantOp>(arg.getLoc(), - ::cast<TypedAttr>(oneAttr)); - return builder.create<arith::DivFOp>(arg.getLoc(), one, arg); + auto one = arith::ConstantOp::create(builder, arg.getLoc(), + ::cast<TypedAttr>(oneAttr)); + return arith::DivFOp::create(builder, arg.getLoc(), one, arg); } case UnaryFn::round: - return builder.create<math::RoundOp>(arg.getLoc(), arg); + return math::RoundOp::create(builder, arg.getLoc(), arg); case UnaryFn::sqrt: - return builder.create<math::SqrtOp>(arg.getLoc(), arg); + return math::SqrtOp::create(builder, arg.getLoc(), arg); case UnaryFn::rsqrt: - return builder.create<math::RsqrtOp>(arg.getLoc(), arg); + return math::RsqrtOp::create(builder, arg.getLoc(), arg); case UnaryFn::square: - return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg); + return arith::MulFOp::create(builder, arg.getLoc(), arg, arg); case UnaryFn::tanh: - return builder.create<math::TanhOp>(arg.getLoc(), arg); + return math::TanhOp::create(builder, arg.getLoc(), arg); case UnaryFn::erf: - return builder.create<math::ErfOp>(arg.getLoc(), arg); + return math::ErfOp::create(builder, arg.getLoc(), arg); } if (emitError) { emitError() << "unsupported unary function"; @@ -516,17 +517,17 @@ public: switch (binaryFn) { case BinaryFn::add: if (allComplex) - return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1); + return complex::AddOp::create(builder, arg0.getLoc(), arg0, arg1); if (allFloatingPoint) - return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1); + return arith::AddFOp::create(builder, arg0.getLoc(), arg0, arg1); if (allBool) - return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1); - return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1); + return arith::OrIOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::AddIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::sub: if (allComplex) - return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1); + return complex::SubOp::create(builder, arg0.getLoc(), arg0, arg1); if (allFloatingPoint) - return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1); + return arith::SubFOp::create(builder, arg0.getLoc(), arg0, arg1); if (allBool) { if (emitError) { emitError() << "unsupported operation: sub with bools"; @@ -534,20 +535,20 @@ public: } llvm_unreachable("unsupported operation: sub with bools"); } - return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1); + return arith::SubIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::mul: if (allComplex) - return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1); + return complex::MulOp::create(builder, 
arg0.getLoc(), arg0, arg1); if (allFloatingPoint) - return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1); + return arith::MulFOp::create(builder, arg0.getLoc(), arg0, arg1); if (allBool) - return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1); - return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1); + return arith::AndIOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MulIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::div: if (allComplex) - return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1); + return complex::DivOp::create(builder, arg0.getLoc(), arg0, arg1); if (allFloatingPoint) - return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1); + return arith::DivFOp::create(builder, arg0.getLoc(), arg0, arg1); if (allBool) { if (emitError) { emitError() << "unsupported operation: div with bools"; @@ -555,7 +556,7 @@ public: } llvm_unreachable("unsupported operation: div with bools"); } - return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1); + return arith::DivSIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::div_unsigned: if (!allInteger || allBool) { if (emitError) { @@ -564,30 +565,30 @@ public: } llvm_unreachable("unsupported operation: unsigned div not on uint"); } - return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1); + return arith::DivUIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::max_signed: assert(!allComplex); if (allFloatingPoint) - return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1); - return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1); + return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MaxSIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::min_signed: assert(!allComplex); if (allFloatingPoint) - return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1); - return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1); + return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MinSIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::max_unsigned: assert(!allComplex); if (allFloatingPoint) - return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1); - return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1); + return arith::MaximumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MaxUIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::min_unsigned: assert(!allComplex); if (allFloatingPoint) - return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1); - return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1); + return arith::MinimumFOp::create(builder, arg0.getLoc(), arg0, arg1); + return arith::MinUIOp::create(builder, arg0.getLoc(), arg0, arg1); case BinaryFn::powf: assert(allFloatingPoint); - return builder.create<math::PowFOp>(arg0.getLoc(), arg0, arg1); + return math::PowFOp::create(builder, arg0.getLoc(), arg0, arg1); } if (emitError) { emitError() << "unsupported binary function"; @@ -610,7 +611,7 @@ public: case TernaryFn::select: if (!headBool && !(tailFloatingPoint || tailInteger)) llvm_unreachable("unsupported non numeric type"); - return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2); + return arith::SelectOp::create(builder, arg0.getLoc(), arg0, arg1, arg2); } if (emitError) { emitError() << "unsupported ternary function"; @@ -639,7 +640,7 @@ public: OpBuilder::InsertionGuard g(builder); 
builder.setInsertionPointToEnd(&block); Location loc = builder.getUnknownLoc(); - builder.create<YieldOp>(loc, values); + YieldOp::create(builder, loc, values); } Value constant(const std::string &value) { @@ -647,13 +648,14 @@ public: builder.setInsertionPointToEnd(&block); Location loc = builder.getUnknownLoc(); Attribute valueAttr = parseAttribute(value, builder.getContext()); - return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr)); + return arith::ConstantOp::create(builder, loc, + ::cast<TypedAttr>(valueAttr)); } Value index(int64_t dim) { OpBuilder::InsertionGuard g(builder); builder.setInsertionPointToEnd(&block); - return builder.create<IndexOp>(builder.getUnknownLoc(), dim); + return IndexOp::create(builder, builder.getUnknownLoc(), dim); } Type getIntegerType(unsigned width) { @@ -749,14 +751,14 @@ struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> { TensorReshapeOp newInit; if constexpr (std::is_same<TensorReshapeOp, tensor::ExpandShapeOp>::value) { - newInit = rewriter.create<TensorReshapeOp>( - loc, reshapeOp.getResultType(), oldFill.output(), + newInit = TensorReshapeOp::create( + rewriter, loc, reshapeOp.getResultType(), oldFill.output(), reshapeOp.getReassociation(), reshapeOp.getOutputShape(), reshapeOp.getStaticOutputShape()); } else { - newInit = rewriter.create<TensorReshapeOp>(loc, reshapeOp.getResultType(), - oldFill.output(), - reshapeOp.getReassociation()); + newInit = TensorReshapeOp::create( + rewriter, loc, reshapeOp.getResultType(), oldFill.output(), + reshapeOp.getReassociation()); } rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()}, ValueRange{newInit}); @@ -786,17 +788,16 @@ struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> { return rewriter.notifyMatchFailure( padOp, "failed to reify tensor.pad op result shape"); - auto emptyTensor = rewriter.create<tensor::EmptyOp>( - padOp.getLoc(), reifiedShape.front(), - padOp.getResultType().getElementType()); + auto emptyTensor = + tensor::EmptyOp::create(rewriter, padOp.getLoc(), reifiedShape.front(), + padOp.getResultType().getElementType()); Value replacement = - rewriter - .create<FillOp>(fillOp.getLoc(), ValueRange{padValue}, - ValueRange{emptyTensor}) + FillOp::create(rewriter, fillOp.getLoc(), ValueRange{padValue}, + ValueRange{emptyTensor}) .getResult(0); if (replacement.getType() != padOp.getResultType()) { - replacement = rewriter.create<tensor::CastOp>( - fillOp.getLoc(), padOp.getResultType(), replacement); + replacement = tensor::CastOp::create(rewriter, fillOp.getLoc(), + padOp.getResultType(), replacement); } rewriter.replaceOp(padOp, replacement); return success(); @@ -889,7 +890,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> { for (int i = 0, e = srcPadType.getRank(); i < e; ++i) { if (srcPadType.isDynamicDim(i)) { newSizes.push_back( - rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i) + tensor::DimOp::create(rewriter, loc, srcPadOp.getSource(), i) .getResult()); } else { newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i))); @@ -942,8 +943,8 @@ static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter, if (!packOpDest.hasOneUse()) return failure(); - return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(), - packOp.getDest()); + return linalg::FillOp::create(rewriter, packOp.getLoc(), fillOp.getInputs(), + packOp.getDest()); } /// Wrapper pattern that applies foldFillPackIntoFillOp method. 
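Note: the Linalg hunks in this region are a mechanical migration from the OpBuilder member template to the static OpTy::create form; the argument list is unchanged except that the builder moves to the front. A before/after sketch for an arbitrary op:

    // Before: op constructed through the OpBuilder member template.
    Value sum = builder.create<arith::AddFOp>(loc, lhs, rhs);

    // After: equivalent static form; the builder becomes the first argument.
    Value sum = arith::AddFOp::create(builder, loc, lhs, rhs);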
@@ -1042,8 +1043,8 @@ struct FoldConcatsOfFill : public OpRewritePattern<tensor::ConcatOp> { concatOp, "not all operands are defined by a compatible fill op"); } - Value outsConcat = rewriter.create<tensor::ConcatOp>( - concatOp.getLoc(), concatOp.getDim(), allOuts); + Value outsConcat = tensor::ConcatOp::create(rewriter, concatOp.getLoc(), + concatOp.getDim(), allOuts); rewriter.replaceOpWithNewOp<linalg::FillOp>( concatOp, firstFillOp.getDpsInputOperand(0)->get(), outsConcat); return success(); @@ -1407,14 +1408,14 @@ struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> { // TODO: unify the two ops? if (sparse_tensor::getSparseTensorEncoding(returnType) || sparse_tensor::getSparseTensorEncoding(resultType)) - returnedArg = rewriter.create<sparse_tensor::ConvertOp>( - linalgOp.getLoc(), resultType, returnedArg); + returnedArg = sparse_tensor::ConvertOp::create( + rewriter, linalgOp.getLoc(), resultType, returnedArg); else { if (!tensor::CastOp::areCastCompatible(returnedArg.getType(), resultType)) return failure(); - returnedArg = rewriter.create<tensor::CastOp>( - linalgOp.getLoc(), resultType, returnedArg); + returnedArg = tensor::CastOp::create(rewriter, linalgOp.getLoc(), + resultType, returnedArg); } } returnedArgs.push_back(returnedArg); @@ -1528,7 +1529,7 @@ static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result, TypeRange{llvm::cast<ShapedType>(result.operands.back().getType()) .getElementType()}, payloadOpAttrs); - b.create<YieldOp>(result.location, payloadOp->getResults()); + YieldOp::create(b, result.location, payloadOp->getResults()); } ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { @@ -1945,7 +1946,7 @@ static void buildIdentityRegion(OpBuilder &builder, Location loc, buildGenericRegion(builder, loc, region, inputs, outputs, [](OpBuilder &b, Location loc, ValueRange args) { if (!args.empty()) - b.create<linalg::YieldOp>(loc, args[0]); + linalg::YieldOp::create(b, loc, args[0]); }); } @@ -2138,7 +2139,7 @@ struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> { unsigned inputRank = broadcastInputTy.getRank(); for (unsigned i = 0; i < inputRank; ++i) { if (broadcastInputTy.isDynamicDim(i)) { - dims.push_back(rewriter.create<tensor::DimOp>(loc, broadcastInput, i) + dims.push_back(tensor::DimOp::create(rewriter, loc, broadcastInput, i) ->getResult(0)); } else { dims.push_back(IntegerAttr::get(IndexType::get(ctx), @@ -2147,15 +2148,14 @@ struct SwapTransposeWithBroadcast : OpRewritePattern<linalg::TransposeOp> { } SmallVector<OpFoldResult> transposeResultShapes = applyPermutation(dims, resultPerms); - Value transposeInit = rewriter.create<tensor::EmptyOp>( - transposeOp.getLoc(), transposeResultShapes, + Value transposeInit = tensor::EmptyOp::create( + rewriter, transposeOp.getLoc(), transposeResultShapes, broadcastInputTy.getElementType()); // Create broadcast(transpose(input)). Value transposeResult = - rewriter - .create<TransposeOp>(loc, broadcastOp.getInput(), transposeInit, - resultPerms) + TransposeOp::create(rewriter, loc, broadcastOp.getInput(), + transposeInit, resultPerms) ->getResult(0); rewriter.replaceOpWithNewOp<BroadcastOp>( transposeOp, transposeResult, transposeOp.getInit(), resultDimensions); @@ -2293,9 +2293,39 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); } +/// Fold back-to-back broadcasts together. 
+struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> { + using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp, + PatternRewriter &rewriter) const override { + auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>(); + if (!defBroadcastOp) + return failure(); + ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions(); + ArrayRef<int64_t> dimensions = broadcastOp.getDimensions(); + SmallVector<int64_t> foldedDims(dimensions); + Value init = broadcastOp.getInit(); + int64_t initRank = cast<ShapedType>(init.getType()).getRank(); + // Mapping from input dims to init dims. + SmallVector<int64_t> dimMap; + for (auto dim : llvm::seq<int64_t>(0, initRank)) { + if (!llvm::is_contained(dimensions, dim)) + dimMap.push_back(dim); + } + for (auto dim : defDimensions) + foldedDims.push_back(dimMap[dim]); + + llvm::sort(foldedDims); + rewriter.replaceOpWithNewOp<BroadcastOp>( + broadcastOp, defBroadcastOp.getInput(), init, foldedDims); + return success(); + } +}; + void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<EraseIdentityLinalgOp<BroadcastOp>>(context); + results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context); } //===----------------------------------------------------------------------===// @@ -2547,7 +2577,7 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> { // continue to propagate as far up the stack as it can go. OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber); Value newOperand = - rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get()); + tensor::CastOp::create(rewriter, loc, resultType, outOperand->get()); SmallVector<Value> newOperands = linalgOp.getDpsInputs(); SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(), linalgOp.getDpsInits().end()); @@ -2560,8 +2590,8 @@ struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> { Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands); // Create a tensor.cast operation back to the original type. - Value castBack = rewriter.create<tensor::CastOp>( - loc, resultValue.getType(), newOp->getResult(resultNumber)); + Value castBack = tensor::CastOp::create( + rewriter, loc, resultValue.getType(), newOp->getResult(resultNumber)); SmallVector<Value> results(newOp->result_begin(), newOp->result_end()); results[resultNumber] = castBack; @@ -2653,7 +2683,7 @@ static void createNewOperandWithStaticSizes( changeNeeded = true; // Get the new operand value given its size and element type by // casting it. - Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src); + Value newOperand = tensor::CastOp::create(rewriter, loc, resultType, src); unsigned index = opOperand->getOperandNumber(); newOperands[index] = newOperand; } @@ -2718,7 +2748,7 @@ struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> { Type oldType = oldResult.getType(); replacements.push_back( (newType != oldType) - ? rewriter.create<tensor::CastOp>(loc, oldType, newResult) + ? 
tensor::CastOp::create(rewriter, loc, oldType, newResult) : newResult); } rewriter.replaceOp(linalgOp, replacements); @@ -2756,8 +2786,8 @@ SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) { int64_t operandRank = getInputOperandRank(); SmallVector<Range> loopBounds(operandRank); Location loc = getLoc(); - Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); - Value one = builder.create<arith::ConstantIndexOp>(loc, 1); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + Value one = arith::ConstantIndexOp::create(builder, loc, 1); Value source = getInput(); for (auto dim : llvm::seq<int64_t>(0, operandRank)) { loopBounds[dim].offset = zero; @@ -2924,11 +2954,11 @@ static Value reduce(OpBuilder &builder, Location loc, Value input, Value output, "We should have two maps: 1 for the input, 1 for the output"); assert(indexingMaps[0].isIdentity() && "input map should be identity"); - auto genericOp = builder.create<linalg::GenericOp>( - loc, output.getType(), input, output, indexingMaps, iteratorTypes, - [&](OpBuilder &b, Location loc, ValueRange args) { - Value result = b.create<T>(loc, args[0], args[1]); - b.create<linalg::YieldOp>(loc, result); + auto genericOp = linalg::GenericOp::create( + builder, loc, output.getType(), input, output, indexingMaps, + iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { + Value result = T::create(b, loc, args[0], args[1]); + linalg::YieldOp::create(b, loc, result); }); return genericOp.getResult(0); } @@ -2947,12 +2977,13 @@ static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input, assert(indexingMaps[0].isIdentity() && "input map should be identity"); // Add the affine map for the output argument. indexingMaps.push_back(indexingMaps[0]); - auto genericOp = builder.create<linalg::GenericOp>( - loc, input.getType(), ValueRange{input, max}, output, indexingMaps, - iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]); - Value result = b.create<math::ExpOp>(loc, diff); - b.create<linalg::YieldOp>(loc, result); + auto genericOp = linalg::GenericOp::create( + builder, loc, input.getType(), ValueRange{input, max}, output, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value diff = arith::SubFOp::create(b, loc, args[0], args[1]); + Value result = math::ExpOp::create(b, loc, diff); + linalg::YieldOp::create(b, loc, result); }); return genericOp.getResult(0); } @@ -2974,12 +3005,12 @@ static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator, assert(indexingMaps[0].isIdentity() && "Numerator map should be identity"); // Add the affine map for the output tensor. 
indexingMaps.push_back(indexingMaps[0]); - auto genericOp = builder.create<linalg::GenericOp>( - loc, numerator.getType(), ValueRange{numerator, denominator}, output, - indexingMaps, iteratorTypes, + auto genericOp = linalg::GenericOp::create( + builder, loc, numerator.getType(), ValueRange{numerator, denominator}, + output, indexingMaps, iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { - Value result = b.create<arith::DivFOp>(loc, args[0], args[1]); - b.create<linalg::YieldOp>(loc, result); + Value result = arith::DivFOp::create(b, loc, args[0], args[1]); + linalg::YieldOp::create(b, loc, result); }); return genericOp.getResult(0); } @@ -3015,12 +3046,12 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) { Value output = getOutput(); dims.erase(dims.begin() + reductionDim); // Step 1: Compute max along dim. - Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType); + Value outputReduce = tensor::EmptyOp::create(b, loc, dims, elementType); Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maxnumf, elementType, b, loc, /*useOnlyFiniteValue=*/true); Value neutralForMaxFInit = - b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce) + linalg::FillOp::create(b, loc, Value{neutralForMaxF}, outputReduce) .result(); Value max = reduce<arith::MaxNumFOp>(b, loc, input, neutralForMaxFInit, reductionDim); @@ -3032,7 +3063,7 @@ FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) { Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType, b, loc, /*useOnlyFiniteValue=*/true); Value zeroInit = - b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result(); + linalg::FillOp::create(b, loc, Value{zero}, outputReduce).result(); Value denominator = reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim); @@ -3153,8 +3184,8 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation( int64_t filterRank = getFilterOperandRank(); SmallVector<OpFoldResult> filterStrides(filterRank, oneAttr); Location loc = getLoc(); - auto filterSlice = builder.create<tensor::ExtractSliceOp>( - loc, getFilter(), sliceOffsets, sliceSizes, filterStrides); + auto filterSlice = tensor::ExtractSliceOp::create( + builder, loc, getFilter(), sliceOffsets, sliceSizes, filterStrides); tiledOperands.emplace_back(filterSlice); SmallVector<OpFoldResult> resultOffsets, resultSizes; @@ -3164,8 +3195,8 @@ FailureOr<TilingResult> WinogradFilterTransformOp::getTiledImplementation( int64_t outputRank = getOutputOperandRank(); SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr); - auto outputSlice = builder.create<tensor::ExtractSliceOp>( - loc, getOutput(), resultOffsets, resultSizes, outputStrides); + auto outputSlice = tensor::ExtractSliceOp::create( + builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides); tiledOperands.emplace_back(outputSlice); SmallVector<Type> resultTypes; @@ -3333,8 +3364,8 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, {sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]}); int64_t inputRank = getInputOperandRank(); SmallVector<OpFoldResult> inputStrides(inputRank, oneAttr); - auto inputSlice = builder.create<tensor::ExtractSliceOp>( - loc, getInput(), sliceOffsets, sliceSizes, inputStrides); + auto inputSlice = tensor::ExtractSliceOp::create( + builder, loc, getInput(), sliceOffsets, sliceSizes, inputStrides); tiledOperands.emplace_back(inputSlice); SmallVector<OpFoldResult> resultOffsets, resultSizes; 
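Stepping back from the builder migration for a moment: the softmax helpers above (reduce, buildSubAndExpOp, buildDivOp) are the four steps of SoftmaxOp::decomposeOperation, i.e. the numerically stable softmax: a max reduction, an elementwise exp(x - max), a sum reduction, and an elementwise divide. A scalar reference of the same computation, illustrative only and independent of the patch:

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

// softmax(x)_i = exp(x_i - m) / sum_j exp(x_j - m), with m = max_j x_j.
// Subtracting the max first is what keeps exp() from overflowing.
static std::vector<double> referenceSoftmax(const std::vector<double> &x) {
  assert(!x.empty() && "softmax of an empty vector is undefined");
  double m = *std::max_element(x.begin(), x.end()); // step 1: max reduction
  std::vector<double> e(x.size());
  double sum = 0.0;
  for (size_t i = 0; i < x.size(); ++i) {
    e[i] = std::exp(x[i] - m); // step 2: exp(x - max)
    sum += e[i];               // step 3: sum reduction
  }
  for (double &v : e)
    v /= sum;                  // step 4: elementwise divide
  return e;
}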
@@ -3344,8 +3375,8 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder, int64_t outputRank = getOutputOperandRank(); SmallVector<OpFoldResult> outputStrides(outputRank, oneAttr); - auto outputSlice = builder.create<tensor::ExtractSliceOp>( - loc, getOutput(), resultOffsets, resultSizes, outputStrides); + auto outputSlice = tensor::ExtractSliceOp::create( + builder, loc, getOutput(), resultOffsets, resultSizes, outputStrides); tiledOperands.emplace_back(outputSlice); SmallVector<Type> resultTypes; @@ -3504,8 +3535,8 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation( sizes[getValueFDim()]}); int64_t valueRank = getValueOperandRank(); SmallVector<OpFoldResult> sliceStrides(valueRank, oneAttr); - auto valueSlice = builder.create<tensor::ExtractSliceOp>( - loc, getValue(), sliceOffsets, sliceSizes, sliceStrides); + auto valueSlice = tensor::ExtractSliceOp::create( + builder, loc, getValue(), sliceOffsets, sliceSizes, sliceStrides); tiledOperands.emplace_back(valueSlice); SmallVector<OpFoldResult> resultOffsets, resultSizes; @@ -3515,8 +3546,8 @@ FailureOr<TilingResult> WinogradOutputTransformOp::getTiledImplementation( int64_t outputRank = getOutputOperandRank(); SmallVector<OpFoldResult> strides(outputRank, oneAttr); - auto outputSlice = builder.create<tensor::ExtractSliceOp>( - loc, getOutput(), resultOffsets, resultSizes, strides); + auto outputSlice = tensor::ExtractSliceOp::create( + builder, loc, getOutput(), resultOffsets, resultSizes, strides); tiledOperands.emplace_back(outputSlice); SmallVector<Type> resultTypes; @@ -4490,6 +4521,29 @@ Speculation::Speculatability ElementwiseOp::getSpeculatability() { //===----------------------------------------------------------------------===// // PackOp/UnPackOp Common //===----------------------------------------------------------------------===// + +template <typename OpTy, typename> +SmallVector<int64_t> +getPackedOuterShapeWithoutTransposition(OpTy packOrUnPack) { + RankedTensorType packedType = (std::is_same<OpTy, PackOp>::value) + ? packOrUnPack.getDestType() + : packOrUnPack.getSourceType(); + RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value) + ? packOrUnPack.getSourceType() + : packOrUnPack.getDestType(); + SmallVector<int64_t> result( + packedType.getShape().take_front(unpackedType.getRank())); + if (!packOrUnPack.getOuterDimsPerm().empty()) { + applyPermutationToVector( + result, invertPermutationVector(packOrUnPack.getOuterDimsPerm())); + } + return result; +} +template SmallVector<int64_t> + getPackedOuterShapeWithoutTransposition<PackOp>(PackOp); +template SmallVector<int64_t> + getPackedOuterShapeWithoutTransposition<UnPackOp>(UnPackOp); + // Given the (potentially) updated packed type, `newPackedTy`, generates an // updated mixed-tile-sizes attribute. A tile size is updated only // when: @@ -4599,22 +4653,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos, }); } -/// Returns true if the dimension of `sourceShape` is smaller than the dimension -/// of the `limitShape`. 
-static bool areAllInBound(ArrayRef<int64_t> sourceShape, - ArrayRef<int64_t> limitShape) { - assert( - sourceShape.size() == limitShape.size() && - "expected source shape rank, and limit of the shape to have same rank"); - return llvm::all_of( - llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) { - int64_t sourceExtent = std::get<0>(it); - int64_t limit = std::get<1>(it); - return ShapedType::isDynamic(sourceExtent) || - ShapedType::isDynamic(limit) || sourceExtent <= limit; - }); -} - template <typename OpTy> static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, @@ -4673,11 +4711,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // represents full tiles. RankedTensorType expectedPackedType = PackOp::inferPackedType( unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); - if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) { - return op->emitError("the shape of output is not large enough to hold the " - "packed data. Expected at least ") - << expectedPackedType << ", got " << packedType; - } if (!llvm::all_of( llvm::zip(packedType.getShape().take_back(mixedTiles.size()), mixedTiles), @@ -4694,6 +4727,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { return op->emitError("mismatch in inner tile sizes specified and shaped of " "tiled dimension in the packed type"); } + if (failed(verifyCompatibleShape(expectedPackedType.getShape(), + packedType.getShape()))) { + return op->emitError("expected ") + << expectedPackedType << " for the packed domain value, got " + << packedType; + } return success(); } @@ -4971,7 +5010,7 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source, llvm::cast<RankedTensorType>(source.getType()).getShape())) { if (ShapedType::isDynamic(value)) mixedSizes.push_back( - b.create<tensor::DimOp>(loc, source, index).getResult()); + tensor::DimOp::create(b, loc, source, index).getResult()); else mixedSizes.push_back(b.getIndexAttr(value)); } @@ -4985,7 +5024,7 @@ Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source, mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end()); auto elemType = llvm::cast<ShapedType>(source.getType()).getElementType(); - return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType); + return tensor::EmptyOp::create(b, loc, mixedSizes, elemType); } PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc, @@ -4996,9 +5035,9 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc, Value transposedDest = createDestinationTensor(b, loc, getSource(), metadata.innerTiles, metadata.innerDimsPos, metadata.outerDimsPerm); - return b.create<PackOp>(loc, getSource(), transposedDest, - metadata.innerDimsPos, metadata.innerTiles, - getPaddingValue(), metadata.outerDimsPerm); + return PackOp::create(b, loc, getSource(), transposedDest, + metadata.innerDimsPos, metadata.innerTiles, + getPaddingValue(), metadata.outerDimsPerm); } /// Returns true if the tiles and the tiled dims are constant. 
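A note on the verifier change above: the removed areAllInBound check accepted any packed shape whose extents were at least as large as those inferred by PackOp::inferPackedType, while the verifyCompatibleShape call that replaces it only accepts extents that match exactly or are dynamic. A minimal sketch of the new acceptance rule; the helper name below is hypothetical and not part of the patch:

#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/ArrayRef.h"

// Shapes are "compatible" iff ranks match and every pair of static
// extents is equal; a dynamic extent on either side is a wildcard.
// The deleted check additionally tolerated got[i] > expected[i],
// which the verifier now rejects.
static bool packedShapeAccepted(llvm::ArrayRef<int64_t> expected,
                                llvm::ArrayRef<int64_t> got) {
  return mlir::succeeded(mlir::verifyCompatibleShape(expected, got));
}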
@@ -5138,7 +5177,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { if (srcShape != packOp.getSourceType().getShape()) { auto newSrcType = packOp.getSourceType().clone(srcShape); source = - rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource()); + tensor::CastOp::create(rewriter, loc, newSrcType, packOp.getSource()); } Value dest = packOp.getDest(); RankedTensorType originalResultType = packOp.getDestType(); @@ -5146,7 +5185,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { if (needUpdateDestType) { auto newDestType = packOp.getDestType().clone(destShape); dest = - rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest()); + tensor::CastOp::create(rewriter, loc, newDestType, packOp.getDest()); } rewriter.modifyOpInPlace(packOp, [&] { packOp.getSourceMutable().assign(source); @@ -5157,7 +5196,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { if (needUpdateDestType) { rewriter.setInsertionPointAfter(packOp); auto castOp = - rewriter.create<tensor::CastOp>(loc, originalResultType, packOp); + tensor::CastOp::create(rewriter, loc, originalResultType, packOp); rewriter.replaceAllUsesExcept(packOp, castOp, castOp); } return success(); @@ -5250,18 +5289,20 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> { // TODO: Strictly speaking, discardable attributes should be _discarded_ at // this point. However, in practice, we use them for things that we'd like // to preserve. Implement a better abstraction. - PackOp newOp = rewriter.create<PackOp>( - op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(), - newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm()); + PackOp newOp = + PackOp::create(rewriter, op.getLoc(), newOperands[0], newOperands[1], + op.getInnerDimsPos(), newMixedTileSizes, + op.getPaddingValue(), op.getOuterDimsPerm()); newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); // Replace op. Value oldResult = op.getResult(); Value newResult = newOp.getResult(); - Value replacement = (newResult.getType() != oldResult.getType()) - ? rewriter.create<tensor::CastOp>( - op->getLoc(), oldResult.getType(), newResult) - : newResult; + Value replacement = + (newResult.getType() != oldResult.getType()) + ? 
tensor::CastOp::create(rewriter, op->getLoc(), + oldResult.getType(), newResult) + : newResult; rewriter.replaceOp(op, {replacement}); @@ -5358,7 +5399,8 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc, for (auto i : llvm::seq<unsigned>(0, srcType.getRank() - innerTileSizes.size())) { if (srcType.isDynamicDim(i)) - mixedSizes.push_back(b.create<tensor::DimOp>(loc, source, i).getResult()); + mixedSizes.push_back( + tensor::DimOp::create(b, loc, source, i).getResult()); else mixedSizes.push_back(b.getIndexAttr(srcType.getDimSize(i))); } @@ -5371,7 +5413,7 @@ Value UnPackOp::createDestinationTensor(OpBuilder &b, Location loc, mixedSizes[dimPos] = dimMul(mixedSizes[dimPos], tileSize); auto elemType = srcType.getElementType(); - return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType); + return tensor::EmptyOp::create(b, loc, mixedSizes, elemType); } UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc, @@ -5380,9 +5422,9 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc, ArrayRef<int64_t> outerPermutation) { PackOrUnPackTransposeResult metadata = commonPermutationOfPackAndUnPackOp( *this, innerPermutation, outerPermutation); - return b.create<UnPackOp>(loc, transposedSource, getDest(), - metadata.innerDimsPos, metadata.innerTiles, - metadata.outerDimsPerm); + return UnPackOp::create(b, loc, transposedSource, getDest(), + metadata.innerDimsPos, metadata.innerTiles, + metadata.outerDimsPerm); } /// Returns true if the `srcShape` or `destShape` is different from the one in @@ -5447,15 +5489,11 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, if (unPackOp->hasOneUse()) { auto extractSliceUser = dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin()); - if (extractSliceUser && - areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) && - areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) && - extractSliceUser.getSourceType().getRank() == - extractSliceUser.getResultType().getRank()) { + if (extractSliceUser && unPackOp.canFoldSliceOp(extractSliceUser)) { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(unPackOp); - auto newDest = rewriter.create<tensor::ExtractSliceOp>( - unPackOp->getLoc(), unPackOp.getDest(), + auto newDest = tensor::ExtractSliceOp::create( + rewriter, unPackOp->getLoc(), unPackOp.getDest(), extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(), extractSliceUser.getMixedStrides()); rewriter.modifyOpInPlace(unPackOp, [&]() { @@ -5474,18 +5512,18 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, Value source = unPackOp.getSource(); if (srcShape != unPackOp.getSourceType().getShape()) { auto newSrcType = unPackOp.getSourceType().clone(srcShape); - source = rewriter.create<tensor::CastOp>(loc, newSrcType, - unPackOp.getSource()); + source = tensor::CastOp::create(rewriter, loc, newSrcType, + unPackOp.getSource()); } Value dest = unPackOp.getDest(); if (destShape != unPackOp.getDestType().getShape()) { auto newDestType = unPackOp.getDestType().clone(destShape); - dest = - rewriter.create<tensor::CastOp>(loc, newDestType, unPackOp.getDest()); + dest = tensor::CastOp::create(rewriter, loc, newDestType, + unPackOp.getDest()); } - Value newOp = rewriter.create<UnPackOp>( - loc, source, dest, unPackOp.getInnerDimsPos(), unPackOp.getMixedTiles(), - unPackOp.getOuterDimsPerm()); + Value newOp = UnPackOp::create( + rewriter, loc, source, dest, unPackOp.getInnerDimsPos(), + unPackOp.getMixedTiles(), unPackOp.getOuterDimsPerm()); 
rewriter.replaceOpWithNewOp<tensor::CastOp>( unPackOp, unPackOp.getResult().getType(), newOp); return success(); @@ -5494,6 +5532,32 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, return failure(); } +bool UnPackOp::canFoldSliceOp(tensor::ExtractSliceOp sliceOp) { + // Rank-reduced folding is not supported. + if (sliceOp.getResultType().getRank() != this->getDestType().getRank()) + return false; + if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) || + !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) + return false; + RankedTensorType unpackedTypeAfterFold = sliceOp.getResultType(); + SmallVector<int64_t> outerShapeWithoutTranspose = + getPackedOuterShapeWithoutTransposition(*this); + for (auto [pos, tileSize] : + llvm::zip_equal(this->getInnerDimsPos(), this->getStaticInnerTiles())) { + if (unpackedTypeAfterFold.isDynamicDim(pos)) + return false; + if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos])) + return false; + if (ShapedType::isDynamic(tileSize)) + return false; + int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize - + unpackedTypeAfterFold.getDimSize(pos); + if (paddingSize >= tileSize) + return false; + } + return true; +} + bool UnPackOp::isLikeUnPad() { RankedTensorType packedTensorType = getSourceType(); return isLikePadUnPad(*this, packedTensorType); @@ -5542,18 +5606,19 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> { // TODO: Strictly speaking, discardable attributes should be _discarded_ at // this point. However, in practice, we use them for things that we'd like // to preserve. Implement a better abstraction. - UnPackOp newOp = rewriter.create<UnPackOp>( - op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(), - newMixedTileSizes, op.getOuterDimsPerm()); + UnPackOp newOp = UnPackOp::create(rewriter, op.getLoc(), sourceTensor, + newOperands[1], op.getInnerDimsPos(), + newMixedTileSizes, op.getOuterDimsPerm()); newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); // Replace op. Value oldResult = op.getResult(); Value newResult = newOp.getResult(); - Value replacement = (newResult.getType() != oldResult.getType()) - ? rewriter.create<tensor::CastOp>( - op->getLoc(), oldResult.getType(), newResult) - : newResult; + Value replacement = + (newResult.getType() != oldResult.getType()) + ? 
tensor::CastOp::create(rewriter, op->getLoc(), + oldResult.getType(), newResult) + : newResult; rewriter.replaceOp(op, {replacement}); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp index ce1b1b9..5c8c2de 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp @@ -12,6 +12,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" @@ -21,8 +22,6 @@ using namespace mlir; #define DEBUG_TYPE "linalg-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") static Attribute linearId0(MLIRContext *ctx) { return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim0); @@ -43,9 +42,8 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx, assert(!copySizes.empty() && copySizes.size() <= 3 && "only 1,2,3-D copies are supported for now"); - LDBG("START CopyMappingInfo, favorPredication: " << favorPredication); - LLVM_DEBUG(DBGS() << "--copy shape: " << llvm::interleaved(copySizes) - << "\n"); + LDBG() << "START CopyMappingInfo, favorPredication: " << favorPredication; + LDBG() << "--copy shape: " << llvm::interleaved(copySizes); // Greedily find the largest vector size that can be used to copy the most // minor dimension: we are in the business of filling kMaxVectorLoadBitWidth @@ -53,20 +51,19 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx, int64_t desiredVectorSize = CopyMappingInfo::maxContiguousElementsToTransfer( desiredBitAlignment, copySizes.back(), elementalBitwidth); - LDBG("--greedily determined vectorSize: " - << desiredVectorSize << " elements of " << elementalBitwidth - << "b each -> " << (desiredVectorSize * elementalBitwidth) - << "b total out of a max of " << kMaxVectorLoadBitWidth << "b"); + LDBG() << "--greedily determined vectorSize: " << desiredVectorSize + << " elements of " << elementalBitwidth << "b each -> " + << (desiredVectorSize * elementalBitwidth) + << "b total out of a max of " << kMaxVectorLoadBitWidth << "b"; status = inferNumThreads(totalNumThreads, copySizes, desiredVectorSize, favorPredication); if (status == Status::Invalid) return; - LLVM_DEBUG(DBGS() << "--copy: " << llvm::interleaved(copySizes) << "\n" - << "--numThreads: " << llvm::interleaved(this->numThreads) - << "\n" - << "--vectorSize: " << this->vectorSize << "\n"); + LDBG() << "--copy: " << llvm::interleaved(copySizes) << "\n" + << "--numThreads: " << llvm::interleaved(this->numThreads) << "\n" + << "--vectorSize: " << this->vectorSize; assert(this->numThreads.size() == copySizes.size() && "compute copy mapping expected same number of threads and copy sizes"); @@ -84,7 +81,7 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx, this->threadMapping = llvm::to_vector(ArrayRef(allThreadMappings) .take_back(this->smallestBoundingTileSizes.size())); - LLVM_DEBUG(this->print(DBGS()); llvm::dbgs() << "\n"); + LDBG() << *this; } int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer( @@ -140,7 +137,7 @@ static SmallVector<int64_t> maximizeNumThreads(ArrayRef<int64_t> sizes, "currentIndex out of bounds"); std::string indent(2 * currentIndex, '-'); if (static_cast<size_t>(currentIndex) == sizes.size() - 1) { - LDBG(indent << "mandated 
globalBest: " << sizes[currentIndex]); + LDBG() << indent << "mandated globalBest: " << sizes[currentIndex]; return SmallVector<int64_t>{sizes[currentIndex]}; } @@ -149,16 +146,16 @@ static SmallVector<int64_t> maximizeNumThreads(ArrayRef<int64_t> sizes, SmallVector<int64_t> factors = getFactors(s); SmallVector<int64_t> localThreadsPerDim; localThreadsPerDim.reserve(sizes.size()); - LDBG(indent << "maximizeNumThreads in " << s - << " with limit: " << maxNumThreads); + LDBG() << indent << "maximizeNumThreads in " << s + << " with limit: " << maxNumThreads; for (auto factor : factors) { auto nestedThreadsPerDim = maximizeNumThreads(sizes, currentIndex + 1, maxNumThreads / factor); int64_t localBest = factor * product(nestedThreadsPerDim); if (localBest > best && localBest <= maxNumThreads) { - LDBG(indent << "new localBest: " << localBest); - LDBG(indent << "nestedThreadsPerDim: " - << llvm::interleaved(nestedThreadsPerDim)); + LDBG() << indent << "new localBest: " << localBest; + LDBG() << indent << "nestedThreadsPerDim: " + << llvm::interleaved(nestedThreadsPerDim); localThreadsPerDim.clear(); localThreadsPerDim.push_back(factor); llvm::append_range(localThreadsPerDim, nestedThreadsPerDim); @@ -166,8 +163,8 @@ static SmallVector<int64_t> maximizeNumThreads(ArrayRef<int64_t> sizes, } } - LDBG(indent << "found globalBest: " << best); - LDBG(indent << "numThreads: " << llvm::interleaved(localThreadsPerDim)); + LDBG() << indent << "found globalBest: " << best; + LDBG() << indent << "numThreads: " << llvm::interleaved(localThreadsPerDim); return localThreadsPerDim; } @@ -192,8 +189,8 @@ transform::gpu::CopyMappingInfo::inferNumThreads(int64_t totalNumThreads, if (status == Status::Success || status == Status::Invalid) return status; - LDBG("requires predication, try reducing vector size to " - << (localVectorSize / 2)); + LDBG() << "requires predication, try reducing vector size to " + << (localVectorSize / 2); } } @@ -210,8 +207,8 @@ transform::gpu::CopyMappingInfo::inferNumThreadsImpl( assert(sizes.back() % desiredVectorSize == 0 && "most-minor size not divisible by actualVectorSize"); - LDBG("inferNumThreadsImpl with totalNumThreads: " - << totalNumThreads << " and vectorSize: " << desiredVectorSize); + LDBG() << "inferNumThreadsImpl with totalNumThreads: " << totalNumThreads + << " and vectorSize: " << desiredVectorSize; // Scale the most minor size to account for the chosen vector size and // maximize the number of threads without exceeding the total number of @@ -219,22 +216,22 @@ transform::gpu::CopyMappingInfo::inferNumThreadsImpl( SmallVector<int64_t> scaledSizes(sizes); scaledSizes.back() /= desiredVectorSize; if (scaledSizes.back() > totalNumThreads) { - LDBG("--Too few threads given the required vector size -> FAIL"); + LDBG() << "--Too few threads given the required vector size -> FAIL"; return Status::Invalid; } SmallVector<int64_t> inferredNumThreads = maximizeNumThreads(scaledSizes, 0, totalNumThreads); - LDBG("inferred numThreads: " << llvm::interleaved(inferredNumThreads)); - LDBG("computed actualVectorSize: " << desiredVectorSize); + LDBG() << "inferred numThreads: " << llvm::interleaved(inferredNumThreads); + LDBG() << "computed actualVectorSize: " << desiredVectorSize; // Corner case: we cannot use more threads than available. If the dimension of // the copy is so bad it is because higher-level tiling did not do its job, we // do not try to recover from it here. 
int64_t totalNumThreadsUsed = product(inferredNumThreads); - LDBG("--totalNumThreadsUsed: " << totalNumThreadsUsed); + LDBG() << "--totalNumThreadsUsed: " << totalNumThreadsUsed; if (totalNumThreadsUsed == 0 || totalNumThreadsUsed > totalNumThreads) { - LDBG("--Too few threads given the required vector size -> FAIL"); + LDBG() << "--Too few threads given the required vector size -> FAIL"; return Status::Invalid; } diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp index 2fe72a3..d4a3e5f 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp @@ -15,14 +15,13 @@ #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/InterleavedRange.h" using namespace mlir; #define DEBUG_TYPE "linalg-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") //===----------------------------------------------------------------------===// // StructuredMatchOp @@ -39,7 +38,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( return emitSilenceableError() << "expected a Linalg op"; } // If errors are suppressed, succeed and set all results to empty lists. - LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op"); + LDBG() << "optional nested matcher expected a Linalg op"; results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation())); return DiagnosedSilenceableFailure::success(); } @@ -75,8 +74,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation( // When they are defined in this block, we additionally check if we have // already applied the operation that defines them. If not, the // corresponding results will be set to empty lists. - LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage() - << "\n"); + LDBG() << "optional nested matcher failed: " << diag.getMessage(); (void)diag.silence(); SmallVector<OpOperand *> undefinedOperands; for (OpOperand &terminatorOperand : diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 109e5b7..bdfc8d0 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -40,7 +40,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/LogicalResult.h" #include <type_traits> @@ -49,9 +49,6 @@ using namespace mlir::linalg; using namespace mlir::transform; #define DEBUG_TYPE "linalg-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") -#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") /// Attempts to apply the pattern specified as template argument to the given /// operation. 
The pattern is expected to have a `returningMatchAndRewrite` @@ -672,9 +669,10 @@ static Operation *replaceForAllWithNewSignature( newOuts.push_back(outputs[resultNumber]); // Create new scf.forall op - auto newforallOp = rewriter.create<scf::ForallOp>( - loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), - forallOp.getMixedStep(), newOuts, forallOp.getMapping()); + auto newforallOp = scf::ForallOp::create( + rewriter, loc, forallOp.getMixedLowerBound(), + forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts, + forallOp.getMapping()); rewriter.eraseBlock(newforallOp.getBody()); newforallOp.getRegion().takeBody(forallOp.getRegion()); @@ -699,8 +697,8 @@ static Operation *replaceForAllWithNewSignature( Value src = tileAndFuseResult.tiledValues[0]; Value dst = newforallOp.getRegionIterArgs().back(); SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1)); - rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src, - dst, offsets, sizes, strides); + tensor::ParallelInsertSliceOp::create(rewriter, firstYieldOp->getLoc(), src, + dst, offsets, sizes, strides); for (auto result : llvm::enumerate(forallOp.getResults())) { rewriter.replaceAllUsesWith(result.value(), @@ -772,7 +770,7 @@ static bool sameOrEquivalentIterArg(Value src, Value dst) { static std::tuple<SmallVector<Operation *>, Operation *> tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { - LLVM_DEBUG(DBGS() << "Try to fuse a direct extract use\n"); + LDBG() << "Try to fuse a direct extract use"; auto tileableProducer = dyn_cast<TilingInterface>(producerOp); if (!tileableProducer) { diag.attachNote(producerOp->getLoc()) @@ -837,7 +835,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, // Tile the producer. int64_t resultNumber = cast<OpResult>(sliceOpToTile.getSource()).getResultNumber(); - LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); + LDBG() << "resultNumber: " << resultNumber; SmallVector<OpFoldResult> offsets = sliceOpToTile.getMixedOffsets(); SmallVector<OpFoldResult> sizes = sliceOpToTile.getMixedSizes(); @@ -854,7 +852,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, #ifndef NDEBUG for (auto *tiledOp : tileAndFuseResult->tiledOps) { - LLVM_DEBUG(DBGS() << "tiledProducer: " << *tiledOp << "\n"); + LDBG() << "tiledProducer: " << *tiledOp; } #endif @@ -893,7 +891,7 @@ static SmallVector<Operation *> tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { - LLVM_DEBUG(DBGS() << "Try to fuse an extract use through block argument\n"); + LDBG() << "Try to fuse an extract use through block argument"; auto tileableProducer = dyn_cast<TilingInterface>(producerOp); if (!tileableProducer) { @@ -946,7 +944,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( // Replace the use in the tileableProducer before tiling: clone, replace and // then tile. int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber(); - LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); + LDBG() << "resultNumber: " << resultNumber; // Gather destination tensors. 
SmallVector<Value> destinationTensors; @@ -995,7 +993,7 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, Operation *containingOp) { - LLVM_DEBUG(DBGS() << "Try to fuse an use by cloning\n"); + LDBG() << "Try to fuse an use by cloning"; // Gather all uses inside the containing op. SmallVector<OpOperand *> uses; @@ -1029,7 +1027,7 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) && "Parallel insert slice is not a valid clone destination"); unsigned resultNumber = cast<OpResult>(use->get()).getResultNumber(); - LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n"); + LDBG() << "resultNumber: " << resultNumber; OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(use->getOwner()); @@ -1112,7 +1110,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, auto [tiledOps, newContainingOp] = tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); if (!tiledOps.empty()) { - LLVM_DEBUG(DBGS() << "\nFused a direct extract use\n" << *containingOp); + LDBG() << "\nFused a direct extract use\n" << *containingOp; fusedOps.append(tiledOps); if (newContainingOp) { // Update handles associated with the containing op so we don't need to @@ -1138,8 +1136,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( rewriter, diag, producerOp, containingOp); if (!tiledContainingOpOperand.empty()) { - LLVM_DEBUG(DBGS() << "\nFused an extract use through block argument\n" - << *containingOp); + LDBG() << "\nFused an extract use through block argument\n" + << *containingOp; fusedOps.append(tiledContainingOpOperand); continue; } @@ -1147,7 +1145,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, Operation *cloned = cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp); if (cloned) { - LLVM_DEBUG(DBGS() << "\nFused an use by cloning\n" << *containingOp); + LDBG() << "\nFused an use by cloning\n" << *containingOp; fusedOps.push_back(cloned); continue; } @@ -1851,7 +1849,7 @@ transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter, assert(!packOp && "packOp must be null on entry when unPackOp is not null"); OpOperand *packUse = linalgOp.getDpsInitOperand( cast<OpResult>(unPackOp.getSource()).getResultNumber()); - packOp = dyn_cast_or_null<linalg::PackOp>(packUse->get().getDefiningOp()); + packOp = packUse->get().getDefiningOp<linalg::PackOp>(); if (!packOp || !packOp.getResult().hasOneUse()) return emitSilenceableError() << "could not find matching pack op"; } @@ -3410,12 +3408,12 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter, for (auto [ofrIdx, ofr] : llvm::enumerate(getMixedSizes())) { if (auto attr = llvm::dyn_cast_if_present<Attribute>(ofr)) { if (scalableSizes[ofrIdx]) { - auto val = b.create<arith::ConstantIndexOp>( - getLoc(), cast<IntegerAttr>(attr).getInt()); + auto val = arith::ConstantIndexOp::create( + b, getLoc(), cast<IntegerAttr>(attr).getInt()); Value vscale = - b.create<vector::VectorScaleOp>(getLoc(), b.getIndexType()); + vector::VectorScaleOp::create(b, getLoc(), b.getIndexType()); sizes.push_back( - b.create<arith::MulIOp>(getLoc(), val, vscale).getResult()); + arith::MulIOp::create(b, getLoc(), val, vscale).getResult()); } else { sizes.push_back(attr); } 
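The TileUsingForOp hunk just above materializes a scalable tile size as the product of a constant and vector.vscale. Isolated from the surrounding pattern, the IR-building idiom looks as follows; a sketch only, with a hypothetical helper name, assuming the arith and vector dialects are available:

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"

// Build `base * vector.vscale` as an index Value, using the static
// create(...) builder form this patch migrates to.
static mlir::Value buildScalableSize(mlir::OpBuilder &b, mlir::Location loc,
                                     int64_t base) {
  auto cst = mlir::arith::ConstantIndexOp::create(b, loc, base);
  auto vscale = mlir::vector::VectorScaleOp::create(b, loc, b.getIndexType());
  return mlir::arith::MulIOp::create(b, loc, cst, vscale).getResult();
}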
@@ -3626,9 +3624,10 @@ static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter, SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(), rewriter.getIndexAttr(1)); - auto normalizedForallOp = rewriter.create<scf::ForallOp>( - loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(), - loop.getMapping(), [](OpBuilder &, Location, ValueRange) {}); + auto normalizedForallOp = scf::ForallOp::create( + rewriter, loc, normalizedLbs, normalizedUbs, normalizedSteps, + loop.getOutputs(), loop.getMapping(), + [](OpBuilder &, Location, ValueRange) {}); auto normalizedLoopIvs = normalizedForallOp.getInductionVars(); OpBuilder::InsertionGuard g(rewriter); @@ -4131,12 +4130,11 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, target->template getParentOfType<scf::InParallelOp>()); } - Value extracted = rewriter.create<tensor::ExtractSliceOp>( - target.getLoc(), target.getDest(), target.getMixedOffsets(), + Value extracted = tensor::ExtractSliceOp::create( + rewriter, target.getLoc(), target.getDest(), target.getMixedOffsets(), target.getMixedSizes(), target.getMixedStrides()); - Value copied = rewriter - .create<linalg::CopyOp>(target.getLoc(), - target.getSource(), extracted) + Value copied = linalg::CopyOp::create(rewriter, target.getLoc(), + target.getSource(), extracted) .getResult(0); // Reset the insertion point. rewriter.setInsertionPoint(target); diff --git a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp index 281d9f2..ba94ad7 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/AllInterfaces.cpp @@ -10,14 +10,14 @@ #include "mlir/Dialect/Linalg/IR/ValueBoundsOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/SubsetInsertionOpInterfaceImpl.h" #include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" void mlir::linalg::registerAllDialectInterfaceImplementations( DialectRegistry &registry) { registerBufferizableOpInterfaceExternalModels(registry); - registerMeshShardingInterfaceExternalModels(registry); + registerShardingInterfaceExternalModels(registry); registerSubsetOpInterfaceExternalModels(registry); registerTilingInterfaceExternalModels(registry); registerValueBoundsOpInterfaceExternalModels(registry); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp index 1f6d96c..3512ecd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -184,9 +184,9 @@ struct SoftmaxOpInterface getBuffer(rewriter, softmaxOp.getOutput(), options, state); if (failed(outputBuffer)) return failure(); - rewriter.create<linalg::SoftmaxOp>(softmaxOp.getLoc(), - /*result=*/TypeRange(), *inputBuffer, - *outputBuffer, softmaxOp.getDimension()); + linalg::SoftmaxOp::create(rewriter, softmaxOp.getLoc(), + /*result=*/TypeRange(), *inputBuffer, + *outputBuffer, softmaxOp.getDimension()); replaceOpWithBufferizedValues(rewriter, op, *outputBuffer); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 69e6fda..70f846e 100644 ---
a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -24,7 +24,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms Interchange.cpp Loops.cpp TransposeMatmul.cpp - MeshShardingInterfaceImpl.cpp + ShardingInterfaceImpl.cpp NamedOpConversions.cpp BlockPackMatmul.cpp PackAndUnpackPatterns.cpp @@ -68,7 +68,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms MLIRIR MLIRMemRefDialect MLIRMemRefTransforms - MLIRMeshTransforms + MLIRShardTransforms MLIRLinalgDialect MLIRLinalgUtils MLIRSCFDialect diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp index a7732b9..d1eb270 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -30,10 +30,10 @@ static bool hasAllOneValues(DenseIntElementsAttr attr) { static Value createAdd(Location loc, Value x, Value y, OpBuilder &builder) { if (isa<IntegerType>(x.getType())) - return builder.create<arith::AddIOp>(loc, x, y); + return arith::AddIOp::create(builder, loc, x, y); if (isa<ComplexType>(x.getType())) - return builder.create<complex::AddOp>(loc, x, y); - return builder.create<arith::AddFOp>(loc, x, y); + return complex::AddOp::create(builder, loc, x, y); + return arith::AddFOp::create(builder, loc, x, y); } static Value createMul(Location loc, Value x, Value y, Type accType, @@ -44,10 +44,10 @@ static Value createMul(Location loc, Value x, Value y, Type accType, Value yConvert = convertScalarToDtype(builder, loc, y, accType, /*isUnsignedCast=*/false); if (isa<ComplexType>(accType)) - return builder.create<complex::MulOp>(loc, xConvert, yConvert); + return complex::MulOp::create(builder, loc, xConvert, yConvert); if (isa<IntegerType>(accType)) - return builder.create<arith::MulIOp>(loc, xConvert, yConvert); - return builder.create<arith::MulFOp>(loc, xConvert, yConvert); + return arith::MulIOp::create(builder, loc, xConvert, yConvert); + return arith::MulFOp::create(builder, loc, xConvert, yConvert); } // Delinearizes the given composite `index` by the basis specified in `factors`. 
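The trailing comment above states the contract that unrollIndex relies on: affine::delinearizeIndex splits a flattened index into per-dimension indices against a basis, with the last dimension fastest-varying. The underlying arithmetic, as a standalone illustration (not the patch's code):

#include <cstdint>
#include <vector>

// With basis {3, 2}, flat index 5 maps to (2, 1): 5 = 2 * 2 + 1.
static std::vector<int64_t> delinearize(int64_t idx,
                                        const std::vector<int64_t> &basis) {
  std::vector<int64_t> result(basis.size());
  for (auto i = static_cast<int64_t>(basis.size()) - 1; i >= 0; --i) {
    result[i] = idx % basis[i]; // peel off the fastest-varying digit
    idx /= basis[i];            // shift the remaining digits down
  }
  return result;
}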
@@ -56,7 +56,7 @@ static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index, assert(!factors.empty() && "empty factor list"); SmallVector<Value> basis; for (int64_t f : factors) - basis.push_back(b.create<arith::ConstantOp>(loc, b.getIndexAttr(f))); + basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f))); FailureOr<SmallVector<Value>> multiIndex = affine::delinearizeIndex(b, loc, index, basis); assert(!failed(multiIndex) && "Failed to linearize img2col index"); @@ -115,18 +115,18 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { SmallVector<ReassociationIndices> filterReassocIndices = {{0, 1, 2}, {3}}; auto reshapedFilterType = RankedTensorType::get({fh * fw * ic, oc}, filterType.getElementType()); - Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>( - loc, reshapedFilterType, filter, filterReassocIndices); + Value reshapedFilter = tensor::CollapseShapeOp::create( + rewriter, loc, reshapedFilterType, filter, filterReassocIndices); SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}}; RankedTensorType reshapedOutputType = RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType()); - Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>( - loc, reshapedOutputType, output, outputReassocIndices); + Value reshapedOutput = tensor::CollapseShapeOp::create( + rewriter, loc, reshapedOutputType, output, outputReassocIndices); SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic}; - Value colTensor = rewriter.create<tensor::EmptyOp>( - loc, colTensorShape, inputType.getElementType()); + Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape, + inputType.getElementType()); // Convert the input to a (BMK) column tensor. auto nloops = colTensorShape.size(); @@ -138,15 +138,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { SmallVector<AffineMap> img2colIndexingMaps = { AffineMap::getMultiDimIdentityMap(nloops, context)}; - auto img2ColTensor = rewriter.create<linalg::GenericOp>( - loc, colTensor.getType(), + auto img2ColTensor = linalg::GenericOp::create( + rewriter, loc, colTensor.getType(), /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0); - Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1); - Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2); + Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); + Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); + Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); // Recover the original iteration indices from the problem/input sizes. 
SmallVector<Value> mIndices = unrollIndex( @@ -170,9 +170,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex}; - Value inputVal = nestedBuilder.create<tensor::ExtractOp>( - loc, input, extractionIndices); - nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal); + Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, + extractionIndices); + linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); }); // Because the filter does not share the same batch dimension, @@ -187,8 +187,8 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { SmallVector<utils::IteratorType> genericIterators = {parallel, parallel, parallel, reduction}; - auto genericOp = rewriter.create<linalg::GenericOp>( - loc, reshapedOutputType, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, reshapedOutputType, /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter}, /*outputs=*/ValueRange{reshapedOutput}, ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators, @@ -196,12 +196,12 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { Value mul = createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder); Value add = createAdd(loc, mul, args[2], nestedBuilder); - nestedBuilder.create<linalg::YieldOp>(nestedLoc, add); + linalg::YieldOp::create(nestedBuilder, nestedLoc, add); }); Value result = genericOp.getResults().front(); - auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>( - loc, outputType, result, outputReassocIndices); + auto reshapedResult = tensor::ExpandShapeOp::create( + rewriter, loc, outputType, result, outputReassocIndices); rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult}); @@ -244,8 +244,8 @@ rewriteInIm2Col(RewriterBase &rewriter, SmallVector<int64_t> targetShape = llvm::to_vector<4>(llvm::map_range( indices, [&](int64_t index) -> int64_t { return inputShape[index]; })); - Value outputTensor = rewriter.create<tensor::EmptyOp>( - loc, targetShape, operandTensorType.getElementType()); + Value outputTensor = tensor::EmptyOp::create( + rewriter, loc, targetShape, operandTensorType.getElementType()); SmallVector<utils::IteratorType> loopAttributeTypes( nloops, utils::IteratorType::parallel); @@ -255,12 +255,12 @@ rewriteInIm2Col(RewriterBase &rewriter, AffineMap::get(nloops, 0, exprs, rewriter.getContext())), AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())}; - auto transposedOp = rewriter.create<linalg::GenericOp>( - loc, outputTensor.getType(), + auto transposedOp = linalg::GenericOp::create( + rewriter, loc, outputTensor.getType(), /*inputs=*/operand, /*outputs=*/outputTensor, indexingMaps, loopAttributeTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]); + linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]); }); return transposedOp.getResult(0); @@ -307,15 +307,15 @@ rewriteInIm2Col(RewriterBase &rewriter, AffineMap::get(nloops, 0, inputExprs, rewriter.getContext()), AffineMap::getMultiDimIdentityMap(nloops, rewriter.getContext())}; - Value colTensor = rewriter.create<tensor::EmptyOp>( - loc, colTensorShape, inputType.getElementType()); + Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape, + inputType.getElementType()); - auto img2ColTensor = rewriter.create<linalg::GenericOp>( - loc, 
colTensor.getType(), + auto img2ColTensor = linalg::GenericOp::create( + rewriter, loc, colTensor.getType(), /*inputs=*/inputT, /*outputs=*/colTensor, indexingMaps, loopAttributeTypes, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - nestedBuilder.create<linalg::YieldOp>(nestedLoc, args[0]); + linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]); }); SmallVector<ReassociationIndices> img2ColTensorReassocIndices = { @@ -331,26 +331,27 @@ rewriteInIm2Col(RewriterBase &rewriter, auto reshapedOutputTensorType = RankedTensorType::get({n * c, oh * ow}, outputType.getElementType()); - Value reshapedImg2ColTensor = rewriter.create<tensor::CollapseShapeOp>( - loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0), + Value reshapedImg2ColTensor = tensor::CollapseShapeOp::create( + rewriter, loc, reshapedImg2ColTensorType, img2ColTensor.getResult(0), img2ColTensorReassocIndices); - Value reshapedFilterTensor = rewriter.create<tensor::CollapseShapeOp>( - loc, reshapedFilterTensorType, filterT, filterReassociationIndice); - Value reshapedoutputTensor = rewriter.create<tensor::CollapseShapeOp>( - loc, reshapedOutputTensorType, transposedOutputTensor, + Value reshapedFilterTensor = + tensor::CollapseShapeOp::create(rewriter, loc, reshapedFilterTensorType, + filterT, filterReassociationIndice); + Value reshapedoutputTensor = tensor::CollapseShapeOp::create( + rewriter, loc, reshapedOutputTensorType, transposedOutputTensor, outputReassociationIndice); - auto batchMatVecResult = rewriter.create<linalg::BatchMatvecOp>( - loc, TypeRange{reshapedoutputTensor.getType()}, + auto batchMatVecResult = linalg::BatchMatvecOp::create( + rewriter, loc, TypeRange{reshapedoutputTensor.getType()}, ValueRange{reshapedImg2ColTensor, reshapedFilterTensor}, ValueRange{reshapedoutputTensor}); SmallVector<ReassociationIndices> batchMatVecReassociationIndice = {{0, 1}, {2, 3}}; - auto batchMatVecResultReshaped = rewriter.create<tensor::ExpandShapeOp>( - loc, transposedOutputTensor.getType(), batchMatVecResult.getResult(0), - batchMatVecReassociationIndice); + auto batchMatVecResultReshaped = tensor::ExpandShapeOp::create( + rewriter, loc, transposedOutputTensor.getType(), + batchMatVecResult.getResult(0), batchMatVecReassociationIndice); Value transposedResult = transposeOperand(batchMatVecResultReshaped, {0, 2, 3, 1}); @@ -400,19 +401,19 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}}; auto reshapedFilterType = RankedTensorType::get({oc, ic * fh * fw}, inputType.getElementType()); - Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>( - loc, reshapedFilterType, filter, filterReassocIndices); + Value reshapedFilter = tensor::CollapseShapeOp::create( + rewriter, loc, reshapedFilterType, filter, filterReassocIndices); SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1}, {2, 3}}; auto reshapedOutputType = RankedTensorType::get({n, oc, oh * ow}, outputType.getElementType()); - Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>( - loc, reshapedOutputType, output, outputReassocIndices); + Value reshapedOutput = tensor::CollapseShapeOp::create( + rewriter, loc, reshapedOutputType, output, outputReassocIndices); // Convert the input to a (BKN) tensor. 
SmallVector<int64_t, 4> colTensorShape = {n, ic * fh * fw, oh * ow}; - Value colTensor = rewriter.create<tensor::EmptyOp>( - loc, colTensorShape, inputType.getElementType()); + Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape, + inputType.getElementType()); auto nloops = colTensorShape.size(); @@ -423,15 +424,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { SmallVector<AffineMap, 4> img2colIndexingMaps = { AffineMap::getMultiDimIdentityMap(nloops, context)}; - auto img2ColTensor = rewriter.create<linalg::GenericOp>( - loc, colTensor.getType(), + auto img2ColTensor = linalg::GenericOp::create( + rewriter, loc, colTensor.getType(), /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0); - Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1); - Value nIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2); + Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); + Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); + Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); // Recover the original iteration indices from the problem/input sizes. SmallVector<Value> kIndices = unrollIndex( @@ -455,9 +456,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw] SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex}; - Value inputVal = nestedBuilder.create<tensor::ExtractOp>( - loc, input, extractionIndices); - nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal); + Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, + extractionIndices); + linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); }); // Because the filter does not share the same batch dimension, @@ -471,8 +472,8 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { auto resultMap = AffineMap::get(4, 0, {bDim, mDim, nDim}, context); SmallVector<utils::IteratorType> genericIterators = {parallel, parallel, parallel, reduction}; - auto genericOp = rewriter.create<linalg::GenericOp>( - loc, reshapedOutputType, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, reshapedOutputType, /*inputs=*/ValueRange{reshapedFilter, img2ColTensor.getResult(0)}, /*outputs=*/ValueRange{reshapedOutput}, ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators, @@ -480,12 +481,12 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { Value mul = createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder); Value add = createAdd(loc, mul, args[2], nestedBuilder); - nestedBuilder.create<linalg::YieldOp>(nestedLoc, add); + linalg::YieldOp::create(nestedBuilder, nestedLoc, add); }); Value result = genericOp.getResults().front(); - auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>( - loc, outputType, result, outputReassocIndices); + auto reshapedResult = tensor::ExpandShapeOp::create( + rewriter, loc, outputType, result, outputReassocIndices); rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult}); @@ -535,18 +536,18 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { SmallVector<ReassociationIndices> filterReassocIndices = {{0}, {1, 2, 3}}; auto reshapedFilterType = RankedTensorType::get({oc, fh * fw * ic}, 
filterType.getElementType()); - Value reshapedFilter = rewriter.create<tensor::CollapseShapeOp>( - loc, reshapedFilterType, filter, filterReassocIndices); + Value reshapedFilter = tensor::CollapseShapeOp::create( + rewriter, loc, reshapedFilterType, filter, filterReassocIndices); SmallVector<ReassociationIndices> outputReassocIndices = {{0}, {1, 2}, {3}}; RankedTensorType reshapedOutputType = RankedTensorType::get({n, oh * ow, oc}, outputType.getElementType()); - Value reshapedOutput = rewriter.create<tensor::CollapseShapeOp>( - loc, reshapedOutputType, output, outputReassocIndices); + Value reshapedOutput = tensor::CollapseShapeOp::create( + rewriter, loc, reshapedOutputType, output, outputReassocIndices); SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic}; - Value colTensor = rewriter.create<tensor::EmptyOp>( - loc, colTensorShape, inputType.getElementType()); + Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape, + inputType.getElementType()); // Convert the input to a (BMK) column tensor. auto nloops = colTensorShape.size(); @@ -558,15 +559,15 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { SmallVector<AffineMap> img2colIndexingMaps = { AffineMap::getMultiDimIdentityMap(nloops, context)}; - auto img2ColTensor = rewriter.create<linalg::GenericOp>( - loc, colTensor.getType(), + auto img2ColTensor = linalg::GenericOp::create( + rewriter, loc, colTensor.getType(), /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = nestedBuilder.create<linalg::IndexOp>(loc, 0); - Value mIndex = nestedBuilder.create<linalg::IndexOp>(loc, 1); - Value kIndex = nestedBuilder.create<linalg::IndexOp>(loc, 2); + Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); + Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); + Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); // Recover the original iteration indices from the problem/input sizes. 
SmallVector<Value> mIndices = unrollIndex( @@ -590,9 +591,9 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex}; - Value inputVal = nestedBuilder.create<tensor::ExtractOp>( - loc, input, extractionIndices); - nestedBuilder.create<linalg::YieldOp>(nestedLoc, inputVal); + Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, + extractionIndices); + linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); }); // Because we didn't transpose the filters we don't actually have a batched @@ -606,8 +607,8 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { SmallVector<utils::IteratorType> genericIterators = {parallel, parallel, parallel, reduction}; - auto genericOp = rewriter.create<linalg::GenericOp>( - loc, reshapedOutputType, + auto genericOp = linalg::GenericOp::create( + rewriter, loc, reshapedOutputType, /*inputs=*/ValueRange{img2ColTensor.getResult(0), reshapedFilter}, /*outputs=*/ValueRange{reshapedOutput}, ArrayRef<AffineMap>{lhsMap, rhsMap, resultMap}, genericIterators, @@ -615,12 +616,12 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { Value mul = createMul(loc, args[0], args[1], args[2].getType(), nestedBuilder); Value add = createAdd(loc, mul, args[2], nestedBuilder); - nestedBuilder.create<linalg::YieldOp>(nestedLoc, add); + linalg::YieldOp::create(nestedBuilder, nestedLoc, add); }); Value result = genericOp.getResults().front(); - auto reshapedResult = rewriter.create<tensor::ExpandShapeOp>( - loc, outputType, result, outputReassocIndices); + auto reshapedResult = tensor::ExpandShapeOp::create( + rewriter, loc, outputType, result, outputReassocIndices); rewriter.replaceOp(convOp, ArrayRef<Value>{reshapedResult}); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index 39e2aac..76ddee4 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -37,8 +37,8 @@ static Value createInserts(RewriterBase &rewriter, Location loc, int dim, if (dim == static_cast<int>(shape.size()) - 1) { for (int i = 0; i < shape.back(); ++i) { indices.back() = constants[i]; - destination = rewriter.create<tensor::InsertOp>(loc, *elementIt, - destination, indices); + destination = tensor::InsertOp::create(rewriter, loc, *elementIt, + destination, indices); ++elementIt; } return destination; @@ -65,27 +65,27 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, MaterializeInDestination: { // Note: This is the preferred way of memcpy'ing because no layout map // and/or memory space must be specified for the source. - auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>( - loc, tensorSource, memrefDest); + auto materializeOp = bufferization::MaterializeInDestinationOp::create( + b, loc, tensorSource, memrefDest); materializeOp.setWritable(true); } break; case linalg::BufferizeToAllocationOptions::MemcpyOp::MemrefCopy: { // TODO: Support custom memory space on source. // We do not know the layout map of the source yet, so use a fully dynamic // layout for best compatibility. 
- Value toBuffer = b.create<bufferization::ToBufferOp>( - loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), + Value toBuffer = bufferization::ToBufferOp::create( + b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), tensorSource, /*readOnly=*/true); - b.create<memref::CopyOp>(loc, toBuffer, memrefDest); + memref::CopyOp::create(b, loc, toBuffer, memrefDest); } break; case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: { // TODO: Support custom memory space on source. // We do not know the layout map of the source yet, so use a fully dynamic // layout for best compatibility. - Value toBuffer = b.create<bufferization::ToBufferOp>( - loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), + Value toBuffer = bufferization::ToBufferOp::create( + b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), tensorSource, /*readOnly=*/true); - b.create<linalg::CopyOp>(loc, toBuffer, memrefDest); + linalg::CopyOp::create(b, loc, toBuffer, memrefDest); } break; }; } @@ -120,15 +120,15 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, ->materializeConstant(rewriter, constYieldedValue, yieldedValue.getType(), yieldedValue.getLoc()) ->getResult(0); - auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(fillValue), - ValueRange(dest)); + auto fillOp = linalg::FillOp::create(rewriter, loc, ValueRange(fillValue), + ValueRange(dest)); return fillOp; } if (invariantYieldedValue) { // Padding with an invariant value. - auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange(yieldedValue), - ValueRange(dest)); + auto fillOp = linalg::FillOp::create( + rewriter, loc, ValueRange(yieldedValue), ValueRange(dest)); return fillOp; } @@ -137,8 +137,8 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, utils::IteratorType::parallel); SmallVector<AffineMap> indexingMaps( 1, rewriter.getMultiDimIdentityMap(resultType.getRank())); - auto genericOp = rewriter.create<linalg::GenericOp>( - loc, resultType, /*inputs=*/ValueRange(), + auto genericOp = linalg::GenericOp::create( + rewriter, loc, resultType, /*inputs=*/ValueRange(), /*outputs=*/ValueRange{dest}, /*indexingMaps=*/ indexingMaps, iteratorTypes); Block *body = rewriter.createBlock(&genericOp->getRegion(0), {}, @@ -146,7 +146,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, rewriter.setInsertionPointToStart(body); SmallVector<Value> bbArgReplacements; for (int64_t i = 0; i < resultType.getRank(); ++i) - bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i)); + bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i)); rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements); // Update terminator. 
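The mechanical change running through these hunks is a builder-API migration: the member spelling `rewriter.create<OpTy>(loc, ...)` becomes the static spelling `OpTy::create(rewriter, loc, ...)`, with the builder hoisted into the leading argument. A minimal sketch of the two equivalent forms (`shape` and `elemType` are assumed to be in scope):

    // Old member-function spelling: the builder is the receiver object.
    Value before = rewriter.create<tensor::EmptyOp>(loc, shape, elemType);
    // New static spelling: the builder is passed as the first argument.
    Value after = tensor::EmptyOp::create(rewriter, loc, shape, elemType);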
@@ -179,8 +179,8 @@ static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b, for (int64_t i = 0; i < tensorType.getRank(); ++i) { if (tensorType.isDynamicDim(i)) dynSizes.push_back( - b.create<DimOp>(value.getLoc(), value, - b.create<arith::ConstantIndexOp>(value.getLoc(), i))); + DimOp::create(b, value.getLoc(), value, + arith::ConstantIndexOp::create(b, value.getLoc(), i))); } return dynSizes; } @@ -201,15 +201,15 @@ createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, Value alloc; if (options.allocOp == linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloc) { - alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes); + alloc = memref::AllocOp::create(rewriter, loc, memrefType, dynamicSizes); if (options.emitDealloc) { // Place deallocation at the end of the block. rewriter.setInsertionPoint(rewriter.getInsertionBlock()->getTerminator()); - rewriter.create<memref::DeallocOp>(loc, alloc); + memref::DeallocOp::create(rewriter, loc, alloc); } } else if (options.allocOp == linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca) { - alloc = rewriter.create<memref::AllocaOp>(loc, memrefType, dynamicSizes); + alloc = memref::AllocaOp::create(rewriter, loc, memrefType, dynamicSizes); // No dealloc is needed. } @@ -243,14 +243,14 @@ Value linalg::bufferizeToAllocation( getMixedSizes(rewriter, loc, padOp.getSource()); SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(), rewriter.getIndexAttr(1)); - Value subview = rewriter.create<memref::SubViewOp>( - loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides); + Value subview = memref::SubViewOp::create( + rewriter, loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides); createMemcpy(rewriter, loc, padOp.getSource(), subview, options); // Create bufferization.to_tensor with "restrict" and "writable". The returned // tensor is a new buffer allocation, so it does not alias with any buffer. - Value toTensorOp = rewriter.create<bufferization::ToTensorOp>( - loc, padOp.getResult().getType(), alloc, /*restrict=*/true, + Value toTensorOp = bufferization::ToTensorOp::create( + rewriter, loc, padOp.getResult().getType(), alloc, /*restrict=*/true, /*writable=*/true); rewriter.replaceOp(padOp, toTensorOp); return alloc; @@ -338,8 +338,9 @@ Value linalg::bufferizeToAllocation( // Create bufferization.to_tensor with "restrict" and "writable". The returned // tensor is a new buffer allocation, so it does not alias with any buffer. - Value toTensorOp = rewriter.create<bufferization::ToTensorOp>( - loc, allocTensorOp.getResult().getType(), alloc, /*restrict=*/true, + Value toTensorOp = bufferization::ToTensorOp::create( + rewriter, loc, allocTensorOp.getResult().getType(), alloc, + /*restrict=*/true, /*writable=*/true); rewriter.replaceOp(allocTensorOp, toTensorOp); return alloc; @@ -354,7 +355,7 @@ FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle( auto shape = tensorType.getShape(); // Create tensor.empty. - auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange()); + auto emptyOp = EmptyOp::create(rewriter, loc, tensorType, ValueRange()); // Case: tensor<elem_type>. 
if (shape.empty()) { @@ -369,7 +370,7 @@ FailureOr<Operation *> mlir::linalg::rewriteInDestinationPassingStyle( SmallVector<Value, 2> constants; constants.reserve(maxDim); for (int i = 0; i < maxDim; ++i) - constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i)); + constants.push_back(arith::ConstantIndexOp::create(rewriter, loc, i)); // Traverse all elements and create tensor.insert ops. auto elementIt = fromElementsOp.getElements().begin(); @@ -394,16 +395,16 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType()); // Create tensor.empty. - auto emptyOp = - rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents()); + auto emptyOp = EmptyOp::create(rewriter, loc, tensorType, + generateOp.getDynamicExtents()); // Create linalg.generic. SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(), utils::IteratorType::parallel); SmallVector<AffineMap> indexingMaps( 1, rewriter.getMultiDimIdentityMap(tensorType.getRank())); - auto genericOp = rewriter.create<linalg::GenericOp>( - loc, tensorType, /*inputs=*/ValueRange(), + auto genericOp = linalg::GenericOp::create( + rewriter, loc, tensorType, /*inputs=*/ValueRange(), /*outputs=*/ValueRange{emptyOp.getResult()}, /*indexingMaps=*/ indexingMaps, iteratorTypes); Block *body = rewriter.createBlock(&genericOp->getRegion(0), {}, @@ -411,7 +412,7 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, rewriter.setInsertionPointToStart(body); SmallVector<Value> bbArgReplacements; for (int64_t i = 0; i < tensorType.getRank(); ++i) - bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i)); + bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i)); rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements); // Update terminator. @@ -450,13 +451,13 @@ mlir::linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, llvm::all_of(padOp.getMixedHighPad(), isZeroInteger)) { using bufferization::AllocTensorOp; Value allocated = - rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes); + AllocTensorOp::create(rewriter, loc, resultType, dynamicSizes); auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>( padOp, padOp.getSource(), allocated); return copyOp.getOperation(); } - Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes); + Value empty = EmptyOp::create(rewriter, loc, resultType, dynamicSizes); // Create linalg.fill or linalg.generic. 
Operation *fillOp = movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty); rewriter.setInsertionPointAfter(fillOp); @@ -567,8 +568,8 @@ Value linalg::bufferizeToAllocation( createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options); } rewriter.modifyOpInPlace(op, [&]() { - auto toTensorOp = rewriter.create<ToTensorOp>( - op->getLoc(), operand->get().getType(), alloc); + auto toTensorOp = ToTensorOp::create(rewriter, op->getLoc(), + operand->get().getType(), alloc); operand->set(toTensorOp); if (options.bufferizeDestinationOnly) { rewriter.modifyOpInPlace(toTensorOp, [&]() { diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 7057490..0a9c176 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -287,8 +287,8 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, auto empty = linalg::PackOp::createDestinationTensor( b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); - auto packedOperand = b.create<linalg::PackOp>( - loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, + auto packedOperand = linalg::PackOp::create( + b, loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, /*padding=*/std::nullopt, outerDimsPerm); return std::make_tuple(packedOperand, indexingMap); } @@ -345,8 +345,9 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp, indexingMaps.push_back(packedOutIndexingMap); - auto newGenericOp = rewriter.create<linalg::GenericOp>( - loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes, + auto newGenericOp = linalg::GenericOp::create( + rewriter, loc, dest.getType(), inputOperands, dest, indexingMaps, + iterTypes, /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), newGenericOp.getRegion().begin()); @@ -457,9 +458,9 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp, if (!packOpDest.hasOneUse()) return failure(); if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) { - packOpDest = rewriter.create<tensor::EmptyOp>( - genericOp->getLoc(), emptyOp.getMixedSizes(), - emptyOp.getType().getElementType()); + packOpDest = tensor::EmptyOp::create(rewriter, genericOp->getLoc(), + emptyOp.getMixedSizes(), + emptyOp.getType().getElementType()); } else { DominanceInfo dom(genericOp); if (!dom.properlyDominates(packOpDest, genericOp)) @@ -562,8 +563,8 @@ public: auto empty = linalg::PackOp::createDestinationTensor( rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos, outerDimsPerm); - auto sourcePack = rewriter.create<linalg::PackOp>( - loc, padOp.getSource(), empty, innerDimsPos, mixedTiles, + auto sourcePack = linalg::PackOp::create( + rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles, /*padding=*/std::nullopt, outerDimsPerm); // If we have `outer_dims_perms` we need to adjust the padded dimensions. 
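A recurring idiom in this file pairs destination inference with the pack itself; a hedged sketch in the new spelling (`rewriter`, `loc`, `source`, and `tiles` are assumed, and the dim positions are illustrative only):

    // Infer a destination tensor with the packed type of `source`.
    Value dest = linalg::PackOp::createDestinationTensor(
        rewriter, loc, source, /*innerTileSizes=*/tiles,
        /*innerDimsPos=*/{0, 1}, /*outerDimsPerm=*/{});
    // Pack with the same tiling parameters; no padding value needed here.
    Value packed = linalg::PackOp::create(
        rewriter, loc, source, dest, /*innerDimsPos=*/{0, 1}, tiles,
        /*paddingValue=*/std::nullopt, /*outerDimsPerm=*/{});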
@@ -579,17 +580,18 @@ public: lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); - auto newPadOp = rewriter.create<tensor::PadOp>( - loc, /*result=*/Type(), sourcePack, lowPad, highPad, paddingVal, - padOp.getNofold()); + auto newPadOp = + tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack, + lowPad, highPad, paddingVal, padOp.getNofold()); // If the pad has more than one user, create an unpack on the new pad to // replace the other uses. if (!padOp->hasOneUse()) { auto unpackEmpty = linalg::UnPackOp::createDestinationTensor( rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm); - Value unpackedPad = rewriter.create<linalg::UnPackOp>( - loc, newPadOp, unpackEmpty, innerDimsPos, mixedTiles, outerDimsPerm); + Value unpackedPad = + linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty, + innerDimsPos, mixedTiles, outerDimsPerm); rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack); } @@ -719,9 +721,10 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, auto emptyOp = linalg::PackOp::createDestinationTensor( rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(), projectedInnerDimsPos, newOuterDimsPerm); - auto newPackOp = rewriter.create<linalg::PackOp>( - packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos, - packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm); + auto newPackOp = linalg::PackOp::create( + rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp, + projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(), + newOuterDimsPerm); SmallVector<ReassociationIndices> newReassocIndices = reassocIndices; // First apply the permutation on the reassociations of the outer dims. 
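For reference, a reassociation such as `{{0}, {1, 2}}` merges dimensions 1 and 2 while keeping dimension 0; sketched with the new spelling (a `tensor<4x8x16xf32>` source `src` is an assumed example):

    SmallVector<ReassociationIndices> reassoc = {{0}, {1, 2}};
    // Collapses tensor<4x8x16xf32> into tensor<4x128xf32>; the result type
    // is inferred from the source type and the reassociation groups.
    Value collapsed =
        tensor::CollapseShapeOp::create(rewriter, loc, src, reassoc);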
@@ -735,8 +738,9 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, nextPos += 1; } - auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>( - collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices); + auto newCollapseOp = tensor::CollapseShapeOp::create( + rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp, + newReassocIndices); rewriter.replaceOp(packOp, newCollapseOp); return success(); @@ -853,13 +857,14 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, Value destTensor = linalg::PackOp::createDestinationTensor( rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(), projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{}); - Value packedVal = rewriter.create<linalg::PackOp>( - packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos, - packOp.getMixedTiles(), packOp.getPaddingValue(), + Value packedVal = linalg::PackOp::create( + rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor, + projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(), /*outerDimsPerm=*/SmallVector<int64_t>{}); - Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>( - packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand); + Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(), + packOp.getDestType(), + packedVal, *reassocExpand); rewriter.replaceOp(packOp, newExpandOp); return success(); @@ -972,15 +977,15 @@ static LogicalResult pushDownUnPackOpThroughExpandShape( RankedTensorType newExpandType = linalg::PackOp::inferPackedType( expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm); - auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>( - expandOp.getLoc(), newExpandType, unPackOp.getSource(), - newReassocIndices); + auto newExpandOp = + tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType, + unPackOp.getSource(), newReassocIndices); auto emptyOp = linalg::UnPackOp::createDestinationTensor( rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(), projectedInnerDimsPos, newOuterDimsPerm); - auto newUnPackOp = rewriter.create<linalg::UnPackOp>( - unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, + auto newUnPackOp = linalg::UnPackOp::create( + rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp, projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm); rewriter.replaceOp(expandOp, newUnPackOp); @@ -1138,10 +1143,9 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp, // Insert an unPackOp right after the packed generic. Value unPackOpRes = - rewriter - .create<linalg::UnPackOp>(genericOp.getLoc(), newResult, - destPack.getSource(), innerDimsPos, - mixedTiles, outerDimsPerm) + linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult, + destPack.getSource(), innerDimsPos, mixedTiles, + outerDimsPerm) .getResult(); return std::make_tuple(newGenericOp, unPackOpRes); @@ -1212,17 +1216,17 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> { lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); highPad.append(pointLoopsSize, rewriter.getIndexAttr(0)); - auto newPadOp = rewriter.create<tensor::PadOp>( - loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad, - paddingVal, padOp.getNofold()); + auto newPadOp = tensor::PadOp::create(rewriter, loc, /*result=*/Type(), + unpackOp.getSource(), lowPad, highPad, + paddingVal, padOp.getNofold()); // Inject the linalg.unpack right after the packed padOp. 
- Value outputUnPack = rewriter.create<tensor::EmptyOp>( - loc, padOp.getResultType().getShape(), - padOp.getResultType().getElementType()); + Value outputUnPack = + tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(), + padOp.getResultType().getElementType()); - Value replacement = rewriter.create<linalg::UnPackOp>( - loc, newPadOp.getResult(), outputUnPack, innerDimsPos, + Value replacement = linalg::UnPackOp::create( + rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos, unpackOp.getMixedTiles(), outerDimsPerm); rewriter.replaceOp(padOp, replacement); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp index 692bf595..b7da20c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeGenericByUnfoldingPermutation.cpp @@ -198,10 +198,10 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite( transposedShape[i] = inputRTType.getShape()[permutation[i]]; Value emptyTensor = - rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType); + tensor::EmptyOp::create(rewriter, loc, transposedShape, elType); - auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i], - emptyTensor, permutation); + auto transposeOp = TransposeOp::create(rewriter, loc, newInitValues[i], + emptyTensor, permutation); newInitValues[i] = transposeOp->getResult(0); isChanged = true; } @@ -209,11 +209,11 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite( // Does it require broadcast? if (!broadcastedDims.empty()) { assert(broadcastedDims.size() && "should have non size broadcast"); - Value emptyTensor = rewriter.create<tensor::EmptyOp>( - loc, outputShape, inputRTType.getElementType()); + Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, outputShape, + inputRTType.getElementType()); - auto broadcastOp = rewriter.create<linalg::BroadcastOp>( - loc, newInitValues[i], emptyTensor, broadcastedDims); + auto broadcastOp = linalg::BroadcastOp::create( + rewriter, loc, newInitValues[i], emptyTensor, broadcastedDims); newInitValues[i] = broadcastOp->getResult(0); isChanged = true; @@ -227,7 +227,8 @@ LogicalResult DecomposeProjectedPermutation::matchAndRewrite( SmallVector<Value> operands = op->getOperands(); ValueRange operandsRef(operands); - auto newOp = rewriter.create<linalg::GenericOp>( + auto newOp = linalg::GenericOp::create( + rewriter, /*location=*/op.getLoc(), /*resultTensorTypes=*/op->getResultTypes(), /*inputs=*/newInitValues, diff --git a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp index 1419175..c92a27f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp @@ -133,13 +133,13 @@ static Value getZero(OpBuilder &b, Location loc, Type elementType) { assert(elementType.isIntOrIndexOrFloat() && "expected scalar type while computing zero value"); if (isa<IntegerType>(elementType)) - return b.create<arith::ConstantIntOp>(loc, elementType, 0); + return arith::ConstantIntOp::create(b, loc, elementType, 0); if (elementType.isIndex()) - return b.create<arith::ConstantIndexOp>(loc, 0); + return arith::ConstantIndexOp::create(b, loc, 0); // Assume float. 
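  // Sketch of the constants this produces, assuming standard semantics:
  //   i32   -> arith.constant 0 : i32
  //   index -> arith.constant 0 : index
  //   f32   -> arith.constant 0.000000e+00 : f32 (via APFloat::getZero below)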
auto floatType = cast<FloatType>(elementType); - return b.create<arith::ConstantFloatOp>( - loc, floatType, APFloat::getZero(floatType.getFloatSemantics())); + return arith::ConstantFloatOp::create( + b, loc, floatType, APFloat::getZero(floatType.getFloatSemantics())); } GenericOp @@ -188,8 +188,8 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, // Fall back path, use an `init_tensor` and identity indexing map. AffineMap indexingMap = rewriter.getMultiDimIdentityMap(domain.size()); - Value emptyTensor = - rewriter.create<tensor::EmptyOp>(loc, domain, scalarOpResult.getType()); + Value emptyTensor = tensor::EmptyOp::create(rewriter, loc, domain, + scalarOpResult.getType()); newInitValues.push_back(emptyTensor); newResultTypes.push_back(emptyTensor.getType()); peeledGenericOpIndexingMaps.push_back(indexingMap); @@ -202,10 +202,10 @@ DecomposeLinalgOp::createPeeledGenericOp(GenericOp genericOp, resultTypes.append(newResultTypes.begin(), newResultTypes.end()); auto indexingMapAttr = rewriter.getAffineMapArrayAttr(peeledGenericOpIndexingMaps); - return rewriter.create<GenericOp>( - loc, resultTypes, genericOp.getInputs(), outsOperands, indexingMapAttr, - genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr, - [](OpBuilder, Location, ValueRange) {}); + return GenericOp::create( + rewriter, loc, resultTypes, genericOp.getInputs(), outsOperands, + indexingMapAttr, genericOp.getIteratorTypes(), /*doc=*/nullptr, + /*libraryCall=*/nullptr, [](OpBuilder, Location, ValueRange) {}); } GenericOp @@ -239,8 +239,8 @@ DecomposeLinalgOp::createResidualGenericOp(GenericOp genericOp, indexingMaps.push_back(genericOp.getMatchingIndexingMap(&outOperand)); auto indexingMapAttr = rewriter.getAffineMapArrayAttr(indexingMaps); - return rewriter.create<GenericOp>( - genericOp->getLoc(), genericOp->getResultTypes(), + return GenericOp::create( + rewriter, genericOp->getLoc(), genericOp->getResultTypes(), residualGenericOpOperands, genericOp.getOutputs(), indexingMapAttr, genericOp.getIteratorTypes(), /*doc=*/nullptr, /*libraryCall=*/nullptr, [](OpBuilder, Location, ValueRange) {}); @@ -324,7 +324,7 @@ DecomposeLinalgOp::matchAndRewrite(GenericOp genericOp, yieldedVals.append(llvm::to_vector( llvm::map_range(peeledScalarOperation->getResults(), [](OpResult opr) -> Value { return opr; }))); - rewriter.create<YieldOp>(genericOp.getLoc(), yieldedVals); + YieldOp::create(rewriter, genericOp.getLoc(), yieldedVals); } /// In the split operations, replace block argument uses that refer to diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index ef24eb8..8309054 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -34,8 +34,8 @@ static Value sourceMaterializationCallback(OpBuilder &builder, Type type, // A detensored value is converted back by creating a new tensor from its // element(s). - return builder.create<tensor::FromElementsOp>( - loc, RankedTensorType::get({}, inputType), inputs[0]); + return tensor::FromElementsOp::create( + builder, loc, RankedTensorType::get({}, inputType), inputs[0]); } namespace { @@ -147,7 +147,7 @@ public: // A tensor value is detensored by extracting its element(s).
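  // Sketched round trip for a 0-d tensor: the target materialization below
  // extracts the scalar (tensor.extract %t[] : tensor<f32>), and the source
  // materialization above rebuilds it (tensor.from_elements %s : tensor<f32>).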
addTargetMaterialization([](OpBuilder &builder, Type type, ValueRange inputs, Location loc) -> Value { - return builder.create<tensor::ExtractOp>(loc, inputs[0], ValueRange{}); + return tensor::ExtractOp::create(builder, loc, inputs[0], ValueRange{}); }); addSourceMaterialization(sourceMaterializationCallback); @@ -480,8 +480,8 @@ struct LinalgDetensorize Block *postEntryBlock = rewriter.splitBlock(entryBlock, entryBlock->begin()); rewriter.setInsertionPointToStart(entryBlock); - auto branch = - rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), postEntryBlock); + auto branch = cf::BranchOp::create(rewriter, rewriter.getUnknownLoc(), + postEntryBlock); if (aggressiveMode.getValue()) { AggressiveDetensoringModel costModel; diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index e0062d1..bf66ed0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -118,16 +118,17 @@ struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointAfterValue(op->get()); auto elemType = cast<ShapedType>(op->get().getType()).getElementType(); - auto empty = rewriter.create<tensor::EmptyOp>( - loc, tensor::getMixedSizes(rewriter, loc, op->get()), elemType); + auto empty = tensor::EmptyOp::create( + rewriter, loc, tensor::getMixedSizes(rewriter, loc, op->get()), + elemType); unsigned start = genericOp.getDpsInits().getBeginOperandIndex(); newOutputOperands[op->getOperandNumber() - start] = empty.getResult(); } - auto newOp = rewriter.create<GenericOp>( - loc, genericOp.getResultTypes(), newInputOperands, newOutputOperands, - newIndexingMaps, genericOp.getIteratorTypesArray(), + auto newOp = GenericOp::create( + rewriter, loc, genericOp.getResultTypes(), newInputOperands, + newOutputOperands, newIndexingMaps, genericOp.getIteratorTypesArray(), /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); OpBuilder::InsertionGuard guard(rewriter); @@ -266,8 +267,8 @@ expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest, assert(rankReductionStrategy == ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape && "unknown rank reduction strategy"); - return rewriter - .create<tensor::ExpandShapeOp>(loc, origResultType, result, reassociation) + return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result, + reassociation) .getResult(); } @@ -295,8 +296,8 @@ static Value collapseValue( MemRefLayoutAttrInterface layout; auto targetType = MemRefType::get(targetShape, memrefType.getElementType(), layout, memrefType.getMemorySpace()); - return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand, - reassociation); + return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand, + reassociation); } if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) { if (rankReductionStrategy == @@ -314,8 +315,8 @@ static Value collapseValue( "unknown rank reduction strategy"); auto targetType = RankedTensorType::get(targetShape, tensorType.getElementType()); - return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand, - reassociation); + return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand, + reassociation); } llvm_unreachable("unsupported operand type"); } @@ -331,14 +332,14 @@ struct UnitExtentReplacementInfo { SmallVector<int64_t> targetShape; }; static UnitExtentReplacementInfo 
dropUnitExtentFromOperandMetadata( - MLIRContext *context, GenericOp genericOp, OpOperand *opOperand, + MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand, llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap, ArrayRef<AffineExpr> dimReplacements) { UnitExtentReplacementInfo info; ReassociationIndices reassociationGroup; SmallVector<AffineExpr> newIndexExprs; - AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand); - ArrayRef<int64_t> operandShape = genericOp.getShape(opOperand); + AffineMap indexingMap = op.getMatchingIndexingMap(opOperand); + SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand); ArrayRef<AffineExpr> exprs = indexingMap.getResults(); auto isUnitDim = [&](unsigned dim) { @@ -380,9 +381,16 @@ static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata( } FailureOr<DropUnitDimsResult> -linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, +linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op, + const DroppedUnitDimsBuilder &droppedUnitDimsBuilder, const ControlDropUnitDims &options) { - SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray(); + auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation()); + if (!dpsOp) { + return rewriter.notifyMatchFailure( + op, "op should implement DestinationStyleOpInterface"); + } + + SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray(); if (indexingMaps.empty()) return failure(); @@ -392,19 +400,19 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, AffineMap invertedMap = inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext())); if (!invertedMap) { - return rewriter.notifyMatchFailure(genericOp, + return rewriter.notifyMatchFailure(op, "invalid indexing maps for operation"); } SmallVector<int64_t> allShapesSizes; - for (OpOperand &opOperand : genericOp->getOpOperands()) - llvm::append_range(allShapesSizes, genericOp.getShape(&opOperand)); + for (OpOperand &opOperand : op->getOpOperands()) + llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand)); // 1a. Get the allowed list of dimensions to drop from the `options`. - SmallVector<unsigned> allowedUnitDims = options.controlFn(genericOp); + SmallVector<unsigned> allowedUnitDims = options.controlFn(op); if (allowedUnitDims.empty()) { return rewriter.notifyMatchFailure( - genericOp, "control function returns no allowed unit dims to prune"); + op, "control function returns no allowed unit dims to prune"); } llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(), allowedUnitDims.end()); @@ -417,19 +425,16 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, } } - // 2. Compute the iterator types of the modified op by dropping the one-trip + // 2. Compute the new loops of the modified op by dropping the one-trip // count loops. 
- SmallVector<utils::IteratorType> newIteratorTypes; llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap; SmallVector<AffineExpr> dimReplacements; unsigned newDims = 0; - for (auto [index, attr] : - llvm::enumerate(genericOp.getIteratorTypesArray())) { + for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) { if (unitDims.count(index)) { dimReplacements.push_back( getAffineConstantExpr(0, rewriter.getContext())); } else { - newIteratorTypes.push_back(attr); oldDimToNewDimMap[index] = newDims; dimReplacements.push_back( getAffineDimExpr(newDims, rewriter.getContext())); @@ -462,9 +467,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, } return false; }; - for (OpOperand &opOperand : genericOp->getOpOperands()) { - auto indexingMap = genericOp.getMatchingIndexingMap(&opOperand); - ArrayRef<int64_t> shape = genericOp.getShape(&opOperand); + for (OpOperand &opOperand : op->getOpOperands()) { + auto indexingMap = op.getMatchingIndexingMap(&opOperand); + SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand); if (!hasCollapsibleType(opOperand)) { AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols( dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0); @@ -474,9 +479,9 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, reassociations.push_back({}); continue; } - auto replacementInfo = dropUnitExtentFromOperandMetadata( - rewriter.getContext(), genericOp, &opOperand, oldDimToNewDimMap, - dimReplacements); + auto replacementInfo = + dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand, + oldDimToNewDimMap, dimReplacements); reassociations.push_back(replacementInfo.reassociation); newIndexingMaps.push_back(replacementInfo.indexMap); targetShapes.push_back(replacementInfo.targetShape); @@ -491,13 +496,13 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, concatAffineMaps(newIndexingMaps, rewriter.getContext()))) return failure(); - Location loc = genericOp.getLoc(); + Location loc = op.getLoc(); // 4. For each of the operands, collapse the operand to convert // from original shape to shape in the modified operation if needed, // either through use of reshapes or rank-reducing slices as // specified in `options`. SmallVector<Value> newOperands; - for (OpOperand &opOperand : genericOp->getOpOperands()) { + for (OpOperand &opOperand : op->getOpOperands()) { int64_t idx = opOperand.getOperandNumber(); if (!collapsed[idx]) { newOperands.push_back(opOperand.get()); @@ -508,31 +513,15 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, options.rankReductionStrategy)); } - // 5. Create the `linalg.generic` operation with the new operands, - // indexing maps, iterator types and result types. - ArrayRef<Value> newInputs = - ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs()); - ArrayRef<Value> newOutputs = - ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits()); - SmallVector<Type> resultTypes; - resultTypes.reserve(genericOp.getNumResults()); - for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults())) - resultTypes.push_back(newOutputs[i].getType()); - GenericOp replacementOp = - rewriter.create<GenericOp>(loc, resultTypes, newInputs, newOutputs, - newIndexingMaps, newIteratorTypes); - rewriter.inlineRegionBefore(genericOp.getRegion(), replacementOp.getRegion(), - replacementOp.getRegion().begin()); - // 5a. Replace `linalg.index` operations that refer to the dropped unit - // dimensions. 
- replaceUnitDimIndexOps(replacementOp, unitDims, rewriter); + IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder( + loc, rewriter, op, newOperands, newIndexingMaps, unitDims); // 6. If any result type changes, insert a reshape/slice to convert from the // original type to the new type. SmallVector<Value> resultReplacements; - for (auto [index, result] : llvm::enumerate(replacementOp.getResults())) { - unsigned opOperandIndex = index + replacementOp.getNumDpsInputs(); - Value origDest = genericOp.getDpsInitOperand(index)->get(); + for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) { + unsigned opOperandIndex = index + dpsOp.getNumDpsInputs(); + Value origDest = dpsOp.getDpsInitOperand(index)->get(); if (!collapsed[opOperandIndex]) { resultReplacements.push_back(result); continue; @@ -546,6 +535,51 @@ linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, return DropUnitDimsResult{replacementOp, resultReplacements}; } +FailureOr<DropUnitDimsResult> +linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp, + const ControlDropUnitDims &options) { + + DroppedUnitDimsBuilder build = + [](Location loc, OpBuilder &b, IndexingMapOpInterface op, + ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps, + const llvm::SmallDenseSet<unsigned> &droppedDims) + -> IndexingMapOpInterface { + auto genericOp = cast<GenericOp>(op); + // Compute the iterator types of the modified op by dropping the one-trip + // count loops. + SmallVector<utils::IteratorType> newIteratorTypes; + for (auto [index, attr] : + llvm::enumerate(genericOp.getIteratorTypesArray())) { + if (!droppedDims.count(index)) + newIteratorTypes.push_back(attr); + } + + // Create the `linalg.generic` operation with the new operands, + // indexing maps, iterator types and result types. + ArrayRef<Value> newInputs = + ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs()); + ArrayRef<Value> newOutputs = + ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits()); + SmallVector<Type> resultTypes; + resultTypes.reserve(genericOp.getNumResults()); + for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults())) + resultTypes.push_back(newOutputs[i].getType()); + GenericOp replacementOp = + GenericOp::create(b, loc, resultTypes, newInputs, newOutputs, + newIndexingMaps, newIteratorTypes); + b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(), + replacementOp.getRegion().begin()); + // 5a. Replace `linalg.index` operations that refer to the dropped unit + // dimensions. 
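      // Note (sketch): IRRewriter can be constructed over the OpBuilder and
      // inherits its insertion point, so replaceUnitDimIndexOps rewrites the
      // cloned region in place.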
+ IRRewriter rewriter(b); + replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter); + + return replacementOp; + }; + + return dropUnitDims(rewriter, genericOp, build, options); +} + namespace { struct DropUnitDims : public OpRewritePattern<GenericOp> { DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {}, @@ -603,6 +637,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { } ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape(); + ArrayRef<int64_t> resultShape = padOp.getResultType().getShape(); int64_t padRank = sourceShape.size(); auto isStaticZero = [](OpFoldResult f) { @@ -613,16 +648,18 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { allowedUnitDims.end()); llvm::SmallDenseSet<unsigned> unitDims; SmallVector<int64_t> newShape; + SmallVector<int64_t> newResultShape; SmallVector<OpFoldResult> newLowPad; SmallVector<OpFoldResult> newHighPad; - for (const auto [dim, size, low, high] : - zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape, - padOp.getMixedLowPad(), padOp.getMixedHighPad())) { + for (const auto [dim, size, outSize, low, high] : zip_equal( + llvm::seq(static_cast<int64_t>(0), padRank), sourceShape, + resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) { if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) && isStaticZero(high)) { unitDims.insert(dim); } else { newShape.push_back(size); + newResultShape.push_back(outSize); newLowPad.push_back(low); newHighPad.push_back(high); } @@ -652,8 +689,10 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape, reassociationMap, options.rankReductionStrategy); + auto newResultType = RankedTensorType::get( + newResultShape, padOp.getResultType().getElementType()); auto newPadOp = rewriter.create<tensor::PadOp>( - padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad, + padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad, newHighPad, paddingVal, padOp.getNofold()); Value dest = padOp.getResult(); @@ -670,9 +709,8 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { expandedSizes.push_back(tensor::getMixedSize( rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims)); } - dest = rewriter.create<tensor::EmptyOp>( - padOp.getLoc(), expandedSizes, - padOp.getResultType().getElementType()); + dest = tensor::EmptyOp::create(rewriter, padOp.getLoc(), expandedSizes, + padOp.getResultType().getElementType()); } Value expandedValue = @@ -713,8 +751,9 @@ struct RankReducedExtractSliceOp strides)); Location loc = sliceOp.getLoc(); - Value newSlice = rewriter.create<tensor::ExtractSliceOp>( - loc, rankReducedType, sliceOp.getSource(), offsets, sizes, strides); + Value newSlice = tensor::ExtractSliceOp::create( + rewriter, loc, rankReducedType, sliceOp.getSource(), offsets, sizes, + strides); rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( sliceOp, resultType, newSlice, *reassociation); return success(); @@ -747,8 +786,8 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> { // parallel case. 
if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value) rewriter.setInsertionPoint(insertSliceOp->getParentOp()); - reshapedSource = rewriter.create<tensor::CollapseShapeOp>( - loc, insertSliceOp.getSource(), *reassociation); + reshapedSource = tensor::CollapseShapeOp::create( + rewriter, loc, insertSliceOp.getSource(), *reassociation); } rewriter.replaceOpWithNewOp<InsertOpTy>( insertSliceOp, reshapedSource, insertSliceOp.getDest(), @@ -898,8 +937,8 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> { /// Expand result tensor. Value expandResult(PatternRewriter &rewriter, Value result, RankedTensorType expandedType, int64_t dim) const { - return rewriter.create<tensor::ExpandShapeOp>( - result.getLoc(), expandedType, result, + return tensor::ExpandShapeOp::create( + rewriter, result.getLoc(), expandedType, result, getReassociationForReshapeAtDim(expandedType.getRank(), dim)); } @@ -934,9 +973,9 @@ struct RankReduceContractionOps : OpRewritePattern<FromOpTy> { SmallVector<Type, 1> collapsedResultTy; if (isa<RankedTensorType>(collapsedInit.getType())) collapsedResultTy.push_back(collapsedInit.getType()); - auto collapsedOp = rewriter.create<ToOpTy>( - loc, collapsedResultTy, ValueRange{collapsedLhs, collapsedRhs}, - ValueRange{collapsedInit}); + auto collapsedOp = ToOpTy::create(rewriter, loc, collapsedResultTy, + ValueRange{collapsedLhs, collapsedRhs}, + ValueRange{collapsedInit}); for (auto attr : contractionOp->getAttrs()) { if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName || attr.getName() == "indexing_maps") diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 8a5c138..3bd763e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -237,12 +237,12 @@ static void generateFusedElementwiseOpRegion( fusedIndices.reserve(numFusedOpLoops); llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops), std::back_inserter(fusedIndices), [&](uint64_t dim) { - return rewriter.create<IndexOp>(producer.getLoc(), dim); + return IndexOp::create(rewriter, producer.getLoc(), dim); }); for (IndexOp indexOp : llvm::make_early_inc_range(producerBlock.getOps<IndexOp>())) { - Value newIndex = rewriter.create<affine::AffineApplyOp>( - producer.getLoc(), + Value newIndex = affine::AffineApplyOp::create( + rewriter, producer.getLoc(), consumerToProducerLoopsMap.getSubMap(indexOp.getDim()), fusedIndices); mapper.map(indexOp.getResult(), newIndex); } @@ -328,7 +328,7 @@ static void generateFusedElementwiseOpRegion( } for (auto consumerYieldVal : consumerYieldOp.getOperands()) fusedYieldValues.push_back(mapper.lookupOrDefault(consumerYieldVal)); - rewriter.create<YieldOp>(fusedOp.getLoc(), fusedYieldValues); + YieldOp::create(rewriter, fusedOp.getLoc(), fusedYieldValues); // Sanity checks. assert(fusedBlock->getNumArguments() == fusedOp.getNumOperands() && @@ -417,8 +417,8 @@ mlir::linalg::fuseElementwiseOps(RewriterBase &rewriter, } // Generate the fused op. 
- auto fusedOp = rewriter.create<GenericOp>( - consumer.getLoc(), fusedResultTypes, fusedInputOperands, + auto fusedOp = GenericOp::create( + rewriter, consumer.getLoc(), fusedResultTypes, fusedInputOperands, fusedOutputOperands, rewriter.getAffineMapArrayAttr(fusedIndexMaps), consumer.getIteratorTypes(), /*doc=*/nullptr, @@ -751,9 +751,9 @@ static void updateExpandedGenericOpRegion(PatternRewriter &rewriter, expandedIndices.reserve(expandedDims.size() - 1); llvm::transform( expandedDims.drop_front(), std::back_inserter(expandedIndices), - [&](int64_t dim) { return rewriter.create<IndexOp>(loc, dim); }); + [&](int64_t dim) { return IndexOp::create(rewriter, loc, dim); }); OpFoldResult newIndex = - rewriter.create<IndexOp>(loc, expandedDims.front()).getResult(); + IndexOp::create(rewriter, loc, expandedDims.front()).getResult(); for (auto [expandedShape, expandedIndex] : llvm::zip(expandedDimsShape, expandedIndices)) { AffineExpr idx, acc, shape; @@ -797,8 +797,8 @@ static Operation *createExpandedTransposeOp(PatternRewriter &rewriter, newPerm.push_back(dim); } } - return rewriter.create<TransposeOp>(transposeOp.getLoc(), expandedInput, - output, invertPermutationVector(newPerm)); + return TransposeOp::create(rewriter, transposeOp.getLoc(), expandedInput, + output, invertPermutationVector(newPerm)); } // Create an expanded generic op. @@ -814,9 +814,9 @@ static Operation *createExpandedGenericOp( for (auto j : expansionInfo.getExpandedDims(i)) iteratorTypes[j] = type; - Operation *fused = rewriter.create<GenericOp>( - linalgOp.getLoc(), resultTypes, expandedOpOperands, outputs, - expandedOpIndexingMaps, iteratorTypes); + Operation *fused = GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes, + expandedOpOperands, outputs, + expandedOpIndexingMaps, iteratorTypes); Region &fusedRegion = fused->getRegion(0); Region &originalRegion = linalgOp->getRegion(0); @@ -934,8 +934,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, reassociation, /*isExpandingReshape=*/true))) return std::nullopt; - expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>( - loc, expandedOperandType, opOperand->get(), reassociation, + expandedOpOperands.push_back(tensor::ExpandShapeOp::create( + rewriter, loc, expandedOperandType, opOperand->get(), reassociation, expandedOperandShape)); continue; } @@ -962,8 +962,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, reassociation, /*isExpandingReshape=*/true))) return std::nullopt; - outputs.push_back(rewriter.create<tensor::ExpandShapeOp>( - loc, expandedOutputType, opOperand.get(), reassociation, + outputs.push_back(tensor::ExpandShapeOp::create( + rewriter, loc, expandedOutputType, opOperand.get(), reassociation, expandedOutputShape)); } else { outputs.push_back(opOperand.get()); @@ -985,8 +985,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, Operation *reshapeOp, linalgOp.getMatchingIndexingMap( linalgOp.getDpsInitOperand(resultNumber)), expansionInfo); - resultVals.push_back(rewriter.create<tensor::CollapseShapeOp>( - linalgOp.getLoc(), opResult.getType(), + resultVals.push_back(tensor::CollapseShapeOp::create( + rewriter, linalgOp.getLoc(), opResult.getType(), fusedOp->getResult(resultNumber), reassociation)); } else { resultVals.push_back(fusedOp->getResult(resultNumber)); @@ -1087,8 +1087,8 @@ public: Location loc = padOp->getLoc(); RankedTensorType expandedPaddedType = paddedType.clone(expandedPaddedShape); - auto newPadOp = rewriter.create<tensor::PadOp>( - loc, expandedPaddedType, reshapeOp.getSrc(), 
newLow, newHigh, + auto newPadOp = tensor::PadOp::create( + rewriter, loc, expandedPaddedType, reshapeOp.getSrc(), newLow, newHigh, padOp.getConstantPaddingValue(), padOp.getNofold()); rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( @@ -1572,12 +1572,12 @@ static Value getCollapsedOpOperand(Location loc, LinalgOp op, // Insert a reshape to collapse the dimensions. if (isa<MemRefType>(operand.getType())) { - return builder - .create<memref::CollapseShapeOp>(loc, operand, operandReassociation) + return memref::CollapseShapeOp::create(builder, loc, operand, + operandReassociation) .getResult(); } - return builder - .create<tensor::CollapseShapeOp>(loc, operand, operandReassociation) + return tensor::CollapseShapeOp::create(builder, loc, operand, + operandReassociation) .getResult(); } @@ -1604,7 +1604,7 @@ static void generateCollapsedIndexingRegion( enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) { ReassociationIndicesRef foldedDimsRef(foldedDims.value()); Value newIndexVal = - rewriter.create<linalg::IndexOp>(loc, foldedDims.index()); + linalg::IndexOp::create(rewriter, loc, foldedDims.index()); for (auto dim : llvm::reverse(foldedDimsRef.drop_front())) { Value loopDim = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange[dim]); @@ -1688,9 +1688,10 @@ GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter, SmallVector<utils::IteratorType> iteratorTypes(getCollapsedOpIteratorTypes( origOp.getIteratorTypesArray(), collapsingInfo)); - GenericOp collapsedOp = rewriter.create<linalg::GenericOp>( - origOp.getLoc(), resultTypes, inputOperands, outputOperands, indexingMaps, - iteratorTypes, [](OpBuilder &builder, Location loc, ValueRange args) {}); + GenericOp collapsedOp = linalg::GenericOp::create( + rewriter, origOp.getLoc(), resultTypes, inputOperands, outputOperands, + indexingMaps, iteratorTypes, + [](OpBuilder &builder, Location loc, ValueRange args) {}); Block *origOpBlock = &origOp->getRegion(0).front(); Block *collapsedOpBlock = &collapsedOp->getRegion(0).front(); rewriter.mergeBlocks(origOpBlock, collapsedOpBlock, @@ -1795,12 +1796,12 @@ FailureOr<CollapseResult> mlir::linalg::collapseOpIterationDims( if (isa<MemRefType>(collapsedOpResult.getType())) { MemRefType expandShapeResultType = MemRefType::get( originalResultType.getShape(), originalResultType.getElementType()); - result = rewriter.create<memref::ExpandShapeOp>( - loc, expandShapeResultType, collapsedOpResult, reassociation, - resultShape); + result = memref::ExpandShapeOp::create( + rewriter, loc, expandShapeResultType, collapsedOpResult, + reassociation, resultShape); } else { - result = rewriter.create<tensor::ExpandShapeOp>( - loc, originalResultType, collapsedOpResult, reassociation, + result = tensor::ExpandShapeOp::create( + rewriter, loc, originalResultType, collapsedOpResult, reassociation, resultShape); } results.push_back(result); @@ -1983,8 +1984,8 @@ public: RankedTensorType collapsedPaddedType = paddedType.clone(collapsedPaddedShape); - auto newPadOp = rewriter.create<tensor::PadOp>( - loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh, + auto newPadOp = tensor::PadOp::create( + rewriter, loc, collapsedPaddedType, reshapeOp.getSrc(), newLow, newHigh, padOp.getConstantPaddingValue(), padOp.getNofold()); rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( @@ -2118,17 +2119,18 @@ public: // Create a constant scalar value from the splat constant. 
Value scalarConstant = - rewriter.create<arith::ConstantOp>(def->getLoc(), constantAttr); + arith::ConstantOp::create(rewriter, def->getLoc(), constantAttr); SmallVector<Value> outputOperands = genericOp.getOutputs(); - auto fusedOp = rewriter.create<GenericOp>( - rewriter.getFusedLoc(fusedLocs), genericOp->getResultTypes(), - /*inputs=*/fusedOperands, - /*outputs=*/outputOperands, - rewriter.getAffineMapArrayAttr(fusedIndexMaps), - genericOp.getIteratorTypes(), - /*doc=*/nullptr, - /*library_call=*/nullptr); + auto fusedOp = + GenericOp::create(rewriter, rewriter.getFusedLoc(fusedLocs), + genericOp->getResultTypes(), + /*inputs=*/fusedOperands, + /*outputs=*/outputOperands, + rewriter.getAffineMapArrayAttr(fusedIndexMaps), + genericOp.getIteratorTypes(), + /*doc=*/nullptr, + /*library_call=*/nullptr); // Map the block argument corresponding to the replaced argument with the // scalar constant. @@ -2184,8 +2186,8 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> { modifiedOutput = true; SmallVector<OpFoldResult> mixedSizes = tensor::getMixedSizes(rewriter, loc, operandVal); - Value emptyTensor = rewriter.create<tensor::EmptyOp>( - loc, mixedSizes, operandType.getElementType()); + Value emptyTensor = tensor::EmptyOp::create( + rewriter, loc, mixedSizes, operandType.getElementType()); op->setOperand(opOperand.getOperandNumber(), emptyTensor); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp index c4af09c..c523153 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -64,8 +64,8 @@ getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) { continue; // Extract static / dynamic shape mix from the first operand. 
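    // Sketch: tensor::getMixedSizes returns one OpFoldResult per dimension,
    // an IntegerAttr for static extents and a tensor.dim Value for dynamic
    // ones, so the tensor.empty below reproduces the mixed shape.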
- res.push_back(b.create<tensor::EmptyOp>( - loc, tensor::getMixedSizes(b, loc, operands.front()), + res.push_back(tensor::EmptyOp::create( + b, loc, tensor::getMixedSizes(b, loc, operands.front()), cast<RankedTensorType>(t).getElementType())); } return res; @@ -104,7 +104,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), resultTypes, op->getAttrs()); - builder.create<linalg::YieldOp>(loc, scalarOp->getResults()); + linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp index d375878..9974ccd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp @@ -259,8 +259,8 @@ mlir::linalg::deduplicateOperandsAndRemoveDeadResults( for (Value v : newOutputOperands) if (isa<TensorType>(v.getType())) newResultTypes.push_back(v.getType()); - auto newOp = rewriter.create<GenericOp>( - loc, newResultTypes, newInputOperands, newOutputOperands, + auto newOp = GenericOp::create( + rewriter, loc, newResultTypes, newInputOperands, newOutputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps), genericOp.getIteratorTypes(), genericOp.getDocAttr(), genericOp.getLibraryCallAttr(), diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp index 44469bc..0ca8904 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusePadOpWithLinalgProducer.cpp @@ -72,14 +72,14 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> { // Create the tensor of same size as output of the pad op. RankedTensorType padResultType = padOp.getResultType(); auto resultSizes = resultShape[0]; - auto emptyTensor = rewriter.create<tensor::EmptyOp>( - loc, resultSizes, padResultType.getElementType()); + auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, resultSizes, + padResultType.getElementType()); // Fill the tensor with the pad value. // TODO: There is an option to fill only the boundaries. For now just // filling the whole tensor. - auto fillTensor = - rewriter.create<linalg::FillOp>(loc, padValue, emptyTensor.getResult()); + auto fillTensor = linalg::FillOp::create(rewriter, loc, padValue, + emptyTensor.getResult()); // Construct a slice of the fill result that is to be replaced with the // result of the generic op. The low pad values are the offsets, the size of @@ -93,15 +93,15 @@ struct FusePadOp : OpRewritePattern<tensor::PadOp> { llvm::enumerate(cast<RankedTensorType>(source.getType()).getShape())) { if (ShapedType::isDynamic(shape.value())) { sizes.push_back( - rewriter.create<tensor::DimOp>(loc, source, shape.index()) + tensor::DimOp::create(rewriter, loc, source, shape.index()) .getResult()); } else { sizes.push_back(rewriter.getIndexAttr(shape.value())); } } SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1)); - auto slice = rewriter.create<tensor::ExtractSliceOp>( - loc, fillTensor.getResult(0), offsets, sizes, strides); + auto slice = tensor::ExtractSliceOp::create( + rewriter, loc, fillTensor.getResult(0), offsets, sizes, strides); // Clone the generic op. 
auto clonedOp = diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 9bc7be2..41252c6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -277,7 +277,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult, // mismatches. Insert a `tensor.cast` op to propagate the transformation // invariant that types are compatible. if (consumerType != def.getType()) - def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def); + def = tensor::CastOp::create(b, fusedProducer.getLoc(), consumerType, def); consumerOpOperand.set(def); return FusionInfo{cast<LinalgOp>(producerOpResult.getOwner()), fusedProducer}; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp index 05f2157..3e31393 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -61,8 +61,9 @@ FailureOr<GenericOp> mlir::linalg::generalizeNamedOp(RewriterBase &rewriter, // All named ops have a region attached that can be inlined. assert(linalgOp->getNumRegions() == 1 && "expect named op to have one region attached"); - GenericOp genericOp = rewriter.create<GenericOp>( - linalgOp.getLoc(), resultTypes, inputs, outputs, indexingMaps, iterators); + GenericOp genericOp = + GenericOp::create(rewriter, linalgOp.getLoc(), resultTypes, inputs, + outputs, indexingMaps, iterators); rewriter.inlineRegionBefore(linalgOp->getRegion(0), genericOp.getRegion(), genericOp.getRegion().begin()); rewriter.replaceOp(linalgOp, genericOp->getResults()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index 94ed464..fd530f2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -591,8 +591,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl( // Create a packing loop that takes `hoistedPackedTensor` as iteration // argument. - auto clonedForOp = rewriter.create<scf::ForOp>( - loc, bvm.lookupOrDefault(forOp.getLowerBound()), + auto clonedForOp = scf::ForOp::create( + rewriter, loc, bvm.lookupOrDefault(forOp.getLowerBound()), bvm.lookupOrDefault(forOp.getUpperBound()), bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor); @@ -640,11 +640,11 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl( TransposeOp maybeTransposeOp; Value paddedTensor = bvm.lookup(opToHoist.getResult()); if (!transposeVector.empty()) { - Value outputTensor = rewriter.create<tensor::ExtractSliceOp>( - loc, transposedTensorType, hoistedPackedTensor, offsets, sizes, - strides); - maybeTransposeOp = rewriter.create<linalg::TransposeOp>( - loc, paddedTensor, outputTensor, transposeVector); + Value outputTensor = tensor::ExtractSliceOp::create( + rewriter, loc, transposedTensorType, hoistedPackedTensor, offsets, + sizes, strides); + maybeTransposeOp = linalg::TransposeOp::create( + rewriter, loc, paddedTensor, outputTensor, transposeVector); paddedTensor = maybeTransposeOp.getResult()[0]; } @@ -652,15 +652,16 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl( if (nPackedLoops > 0) { // Step 4. Create InsertSliceOp at the innermost loop level, inserting an // optionally transposed padded slice into the packed tensor. 
- Value inserted = rewriter.create<tensor::InsertSliceOp>( - loc, paddedTensor, hoistedPackedTensor, offsets, sizes, strides); + Value inserted = tensor::InsertSliceOp::create(rewriter, loc, paddedTensor, + hoistedPackedTensor, offsets, + sizes, strides); // Step 5. Iteratively pop the stack and propagate the yield. Value valueToYield = inserted; for (Value iv : llvm::reverse(clonedLoopIvs)) { auto forOp = scf::getForInductionVarOwner(iv); rewriter.setInsertionPointToEnd(&forOp.getRegion().front()); - rewriter.create<scf::YieldOp>(loc, valueToYield); + scf::YieldOp::create(rewriter, loc, valueToYield); valueToYield = forOp.getResult(0); } } @@ -712,8 +713,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl( rewriter.setInsertionPoint(outerLoop); SmallVector<Value> dynamicTensorSizes = analysis.getHoistedPackedTensorSizes(rewriter, loc); - auto emptyOp = rewriter.create<tensor::EmptyOp>( - loc, hoistedPackedTensorType.getShape(), + auto emptyOp = tensor::EmptyOp::create( + rewriter, loc, hoistedPackedTensorType.getShape(), hoistedPackedTensorType.getElementType(), dynamicTensorSizes); return buildPackingLoopNestImpl(rewriter, bvm, opToHoist, transposeVector, @@ -756,8 +757,7 @@ static bool tracesBackToExpectedValue(tensor::ExtractSliceOp extractSliceOp, Value source = extractSliceOp.getSource(); LLVM_DEBUG(DBGS() << "--with starting source: " << source << "\n"); while (source && source != expectedSource) { - auto destOp = - dyn_cast_or_null<DestinationStyleOpInterface>(source.getDefiningOp()); + auto destOp = source.getDefiningOp<DestinationStyleOpInterface>(); if (!destOp) break; LLVM_DEBUG(DBGS() << "--step dest op: " << destOp << "\n"); @@ -840,8 +840,8 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting, { OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(forOp); - extracted = rewriter.create<tensor::ExtractSliceOp>( - hoistedPackedTensor.getLoc(), hoistedPackedTensor, + extracted = tensor::ExtractSliceOp::create( + rewriter, hoistedPackedTensor.getLoc(), hoistedPackedTensor, outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(), outerSliceOp.getMixedStrides()); rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted); @@ -934,8 +934,8 @@ static Value replaceByPackingResult(RewriterBase &rewriter, // offsets = [maybe_leading_ivs, 0 .. 0]. // sizes = [1 .. 1, transposedShape] (defined above). // strides = [1 .. 1] (defined above) - return rewriter.create<tensor::ExtractSliceOp>( - loc, transposedTensorType, hoistedPackedTensor, offsets, + return tensor::ExtractSliceOp::create( + rewriter, loc, transposedTensorType, hoistedPackedTensor, offsets, packingResult.sizes, packingResult.strides); } @@ -982,10 +982,11 @@ FailureOr<Value> mlir::linalg::hoistPaddingOnTensors( OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPointAfter(newResult.getDefiningOp()); // Transpose the packed tensor back to the original storage order. 
-    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, paddedTensorType.getShape(), paddedTensorType.getElementType());
-    TransposeOp unTransposeOp = rewriter.create<linalg::TransposeOp>(
-        loc, newResult, emptyTensor, transposeVector);
+    Value emptyTensor =
+        tensor::EmptyOp::create(rewriter, loc, paddedTensorType.getShape(),
+                                paddedTensorType.getElementType());
+    TransposeOp unTransposeOp = linalg::TransposeOp::create(
+        rewriter, loc, newResult, emptyTensor, transposeVector);
     newResult = unTransposeOp.getResult()[0];
     transposeOps.push_back(unTransposeOp);
   }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index f2e51c29..58986a6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -53,9 +53,9 @@ static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
   assert(index < inits.size());
   inits[index] = newInitOperand;

-  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
-      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
-      inits, [](OpBuilder &, Location, Value, ValueRange) {});
+  scf::ForOp newLoop = scf::ForOp::create(
+      rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
+      loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});

   // Generate the new yield with the replaced operand.
   auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
@@ -165,8 +165,7 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead,
   Value source = transferRead.getBase();

   // Skip view-like Ops and retrieve the actual source Operation.
-  while (auto srcOp =
-             dyn_cast_or_null<ViewLikeOpInterface>(source.getDefiningOp()))
+  while (auto srcOp = source.getDefiningOp<ViewLikeOpInterface>())
     source = srcOp.getViewSource();

   llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
index 1f3336d..39cc21d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/InlineScalarOperands.cpp
@@ -60,9 +60,9 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
     Location loc = genericOp->getLoc();
     SmallVector<Value> outputOperands = genericOp.getOutputs();

-    auto newOp = rewriter.create<GenericOp>(
-        loc, genericOp->getResultTypes(), newOperands, outputOperands,
-        newIndexingMaps, genericOp.getIteratorTypesArray());
+    auto newOp = GenericOp::create(rewriter, loc, genericOp->getResultTypes(),
+                                   newOperands, outputOperands, newIndexingMaps,
+                                   genericOp.getIteratorTypesArray());
     rewriter.cloneRegionBefore(genericOp.getRegion(), newOp.getRegion(),
                                newOp.getRegion().begin());

@@ -77,11 +77,11 @@ struct InlineScalarOperands : public OpRewritePattern<GenericOp> {
       SmallVector<Value> indicesValues;
       for (auto idx : indices)
         indicesValues.emplace_back(
-            rewriter.create<arith::ConstantIndexOp>(loc, idx));
+            arith::ConstantIndexOp::create(rewriter, loc, idx));
       Value scalarValue = opOperand->get();
       if (isa<RankedTensorType>(scalarValue.getType())) {
-        scalarValue =
-            rewriter.create<tensor::ExtractOp>(loc, scalarValue, indicesValues);
+        scalarValue = tensor::ExtractOp::create(rewriter, loc, scalarValue,
+                                                indicesValues);
       }
       body->getArgument(idx).replaceAllUsesWith(scalarValue);
       body->eraseArgument(idx);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp
b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp index a92a0c8..96e6eee 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -88,7 +88,8 @@ mlir::linalg::interchangeGenericOp(RewriterBase &rewriter, GenericOp genericOp, allIndices.reserve(genericOp.getNumLoops()); llvm::transform(llvm::seq<uint64_t>(0, genericOp.getNumLoops()), std::back_inserter(allIndices), [&](uint64_t dim) { - return rewriter.create<IndexOp>(indexOp->getLoc(), dim); + return IndexOp::create(rewriter, indexOp->getLoc(), + dim); }); rewriter.replaceOpWithNewOp<affine::AffineApplyOp>( indexOp, permutationMap.getSubMap(indexOp.getDim()), allIndices); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index 488041a..38f1a8b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -49,7 +49,7 @@ static SmallVector<Value> makeCanonicalAffineApplies(OpBuilder &b, Location loc, auto exprMap = AffineMap::get(dims, map.getNumSymbols(), e); SmallVector<Value> operands(vals); affine::canonicalizeMapAndOperands(&exprMap, &operands); - res.push_back(b.create<affine::AffineApplyOp>(loc, exprMap, operands)); + res.push_back(affine::AffineApplyOp::create(b, loc, exprMap, operands)); } return res; } @@ -70,8 +70,9 @@ static void inlineRegionAndEmitStore(OpBuilder &b, Location loc, OpType op, Operation *terminator = block.getTerminator(); for (OpOperand &operand : terminator->getOpOperands()) { Value toStore = map.lookupOrDefault(operand.get()); - b.create<StoreOpTy>(loc, toStore, outputBuffers[operand.getOperandNumber()], - indexing[operand.getOperandNumber()]); + StoreOpTy::create(b, loc, toStore, + outputBuffers[operand.getOperandNumber()], + indexing[operand.getOperandNumber()]); } } @@ -145,7 +146,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc, auto indexing = makeCanonicalAffineApplies( b, loc, linalgOp.getMatchingIndexingMap(inputOperand), allIvsPlusDims); indexedValues.push_back( - b.create<LoadOpTy>(loc, inputOperand->get(), indexing)); + LoadOpTy::create(b, loc, inputOperand->get(), indexing)); } // 1.b. Emit load from output views. for (OpOperand &outputOperand : linalgOp.getDpsInitsMutable()) { @@ -153,7 +154,7 @@ static void emitScalarImplementation(OpBuilder &b, Location loc, b, loc, linalgOp.getMatchingIndexingMap(&outputOperand), allIvsPlusDims); indexedValues.push_back( - b.create<LoadOpTy>(loc, outputOperand.get(), indexing)); + LoadOpTy::create(b, loc, outputOperand.get(), indexing)); } // TODO: When a region inliner exists, use it. diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp index bb1e974..a2bd9d9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp @@ -59,8 +59,8 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, auto newKernelTy = RankedTensorType::get( {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)}, kernelTy.getElementType()); - auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>( - loc, newKernelTy, kernel, collapsedKernelDims); + auto collapsedKernel = tensor::CollapseShapeOp::create( + rewriter, loc, newKernelTy, kernel, collapsedKernelDims); // Collapse init dims. 
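  // For illustration (shapes assumed, not from this diff): with a unit channel
  // multiplier, a kernel of type tensor<3x3x8x1xf32> collapses to
  // tensor<3x3x8xf32>, and the 5-D init collapses likewise by folding the
  // trailing unit multiplier dimension into the channel dimension.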
SmallVector<ReassociationIndices, 4> collapsedInitDims = { @@ -70,22 +70,23 @@ matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel, RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1), initTy.getDimSize(2), initTy.getDimSize(3)}, initTy.getElementType()); - auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>( - loc, newInitTy, init, collapsedInitDims); + auto collapsedInit = tensor::CollapseShapeOp::create(rewriter, loc, newInitTy, + init, collapsedInitDims); SmallVector<NamedAttribute> preservedAttrs; Operation *newConv = TypeSwitch<Operation *, Operation *>(operation) .Case<DepthwiseConv2DNhwcHwcmOp>([&](auto op) { preservedAttrs = getPrunedAttributeList(op); - return rewriter.create<DepthwiseConv2DNhwcHwcOp>( - loc, newInitTy, ValueRange{input, collapsedKernel}, + return DepthwiseConv2DNhwcHwcOp::create( + rewriter, loc, newInitTy, ValueRange{input, collapsedKernel}, ValueRange{collapsedInit}, stride, dilation); }) .Case<DepthwiseConv2DNhwcHwcmQOp>([&](auto op) { preservedAttrs = getPrunedAttributeList(op); - return rewriter.create<DepthwiseConv2DNhwcHwcQOp>( - loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp}, + return DepthwiseConv2DNhwcHwcQOp::create( + rewriter, loc, newInitTy, + ValueRange{input, collapsedKernel, iZp, kZp}, ValueRange{collapsedInit}, stride, dilation); }) .Default([](Operation *op) { return nullptr; }); diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp index 2afa2f9..9d7f4e0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/PatternMatch.h" namespace mlir { @@ -81,9 +82,8 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> { ArrayRef<ReassociationIndices> reassociation) const { if (operand.getType() == newOperandType) return operand; - return rewriter - .create<tensor::ExpandShapeOp>(loc, newOperandType, operand, - reassociation) + return tensor::ExpandShapeOp::create(rewriter, loc, newOperandType, operand, + reassociation) .getResult(); } @@ -143,8 +143,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> { Type newOperandType, ArrayAttr reassociation) const { if (operand.getType() == newOperandType) return operand; - return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType, - operand, reassociation); + return tensor::CollapseShapeOp::create(rewriter, loc, newOperandType, + operand, reassociation); } /// Returns success() if it is unpacking on the innermost dimension. @@ -220,6 +220,33 @@ public: if (!isEqualConstantIntOrValue(paddingValue, constantPaddingValue)) return failure(); + // Folding is not allowed if it were to introduce artificial padding. + // Folding is also disabled in the case of dynamic dimensions and/or tile + // sizes - that is because it would be impossible to compute the padding + // size and hence to establish whether "artificial" padding would be + // created. 
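+    // A hedged, worked illustration (shapes assumed): for a source dim of 6
+    // padded high by 1 to 7 and then packed with a tile size of 8, the outer
+    // size is 1 and paddingSize = 1 * 8 - 7 = 1; since 1 + 1 = 2 < 8 the fold
+    // is allowed. If the sum reached the tile size, a tile consisting of
+    // nothing but padding would be created, so the fold is rejected below.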
+ RankedTensorType unpackedType = packOp.getSourceType(); + SmallVector<int64_t> outerShapeWithoutTranspose = + getPackedOuterShapeWithoutTransposition(packOp); + for (auto [pos, tileSize, high] : + llvm::zip_equal(packOp.getInnerDimsPos(), packOp.getStaticInnerTiles(), + padOp.getMixedHighPad())) { + if (unpackedType.isDynamicDim(pos)) + return failure(); + if (ShapedType::isDynamic(outerShapeWithoutTranspose[pos])) + return failure(); + if (ShapedType::isDynamic(tileSize)) + return failure(); + std::optional<int64_t> cstHigh = getConstantIntValue(high); + if (!cstHigh) + return failure(); + int64_t paddingSize = outerShapeWithoutTranspose[pos] * tileSize - + unpackedType.getDimSize(pos); + // Do not fold the op if it requires artificial padding. + if (paddingSize + cstHigh.value() >= tileSize) + return failure(); + } + rewriter.replaceOpWithNewOp<PackOp>( packOp, padOp.getSource(), packOp.getDest(), packOp.getInnerDimsPos(), packOp.getMixedTiles(), constantPaddingValue, @@ -251,22 +278,13 @@ public: if (controlFn && !controlFn(&sliceOp.getSourceMutable())) return failure(); - if (sliceOp.getResultType().getRank() != unpackOp.getDestType().getRank()) { - return rewriter.notifyMatchFailure( - sliceOp, "rank-reduced folding is not supported"); - } - - // Check all offsets are zeros, and all strides are ones. - if (!areAllConstantIntValue(sliceOp.getMixedOffsets(), 0) || - !areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) { - return rewriter.notifyMatchFailure( - sliceOp, "expects offsets to be 0s and strides to be 1s"); - } + if (!unpackOp.canFoldSliceOp(sliceOp)) + return failure(); // Create a new empty output tensor. Type elementType = unpackOp.getDestType().getElementType(); - Value output = rewriter.create<tensor::EmptyOp>( - sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType); + Value output = tensor::EmptyOp::create( + rewriter, sliceOp.getLoc(), sliceOp.getMixedSizes(), elementType); rewriter.replaceOpWithNewOp<UnPackOp>( sliceOp, unpackOp.getSource(), output, unpackOp.getInnerDimsPos(), unpackOp.getMixedTiles(), unpackOp.getOuterDimsPerm()); @@ -529,8 +547,8 @@ public: auto elemType = cast<ShapedType>(unPackOp->getResultTypes()[0]).getElementType(); - Value output = rewriter.create<tensor::EmptyOp>( - unPackOp->getLoc(), unpackOpResultDims[0], elemType); + Value output = tensor::EmptyOp::create(rewriter, unPackOp->getLoc(), + unpackOpResultDims[0], elemType); rewriter.replaceOpWithNewOp<UnPackOp>( unPackOp, linalgOp->getOperand(0), output, newInnerDimsPosVec, diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 5eb3761..2e62523 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes, return paddingSizes; } +/// Extracts the constant multiplier from an affine expression of the form +/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an +/// AffineConstantExpr. Returns 1 if the expression is not a simple +/// multiplication of a dimension and a constant. 
+static int64_t extractConstantMultiplier(AffineExpr expr) {
+  if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
+    if (binOp.getKind() == AffineExprKind::Mul) {
+      auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
+      auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
+      if (lhsD && rhsC) {
+        return rhsC.getValue();
+      }
+      auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
+      auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
+      if (lhsC && rhsD) {
+        return lhsC.getValue();
+      }
+    }
+  }
+  return 1;
+}
+
 /// Compute the padded shape of the given value `v` of `RankedTensorType` given
 ///   - `indexingSizes` a list of OpFoldResult.
 ///   - an `indexingMap` that encodes how the shape of `v` varies with increases
@@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
 /// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
 /// The implementation below iteratively combines increases from contributing
 /// dimensions using affine.apply operations.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
+/// permutation indexing maps are supported, such as
+///   - affine_map<(d0, d1, d2) -> (d0, d1)>
+///   - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+///   - affine_map<(d0, d1) -> (d0 * 3 + d1)>
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
 SmallVector<OpFoldResult> linalg::computePaddedShape(
@@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
           /*compressDims=*/true);

       // If we are padding to the next multiple of `padToMultipleOf`, compose
       // with ceil(sz) * sz.
+      OpFoldResult paddingDimOfr;
       if (options.padToMultipleOf) {
         AffineExpr d0, s0;
         bindDims(rewriter.getContext(), d0);
         bindSymbols(rewriter.getContext(), s0);
         AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
         AffineMap composedMap = projectedMap.compose(ceilMap);
-        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+        paddingDimOfr = affine::makeComposedFoldedAffineApply(
             rewriter, loc, composedMap,
             {indexingSizes[paddingDim], paddingSize},
             /*composeAffineMin=*/true);
-        terms.push_back(paddingDimOfr);
       } else {
         // Otherwise just set to paddingSize.
-        OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+        paddingDimOfr = affine::makeComposedFoldedAffineApply(
             rewriter, loc, projectedMap, paddingSize);
-        terms.push_back(paddingDimOfr);
       }
+      // Adjust for the maximum accessed index, which is (paddingSize - 1) *
+      // multiplier.
+      AffineExpr d0;
+      bindDims(rewriter.getContext(), d0);
+      int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
+      AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
+      OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, subtractMap, {paddingDimOfr});
+      terms.push_back(maxAccessIdx);
+      LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
     }

@@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
     AffineExpr sumExpr = dims.front();
     for (unsigned i = 1; i < dims.size(); ++i)
       sumExpr = sumExpr + dims[i];
-    OpFoldResult paddedDimOfr =
-        affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
+    // Add 1 to the maximum accessed index and get the final padded size.
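+    // Hedged illustration (sizes assumed): for affine_map<(d0, d1) ->
+    // (d0 * 3 + d1)> with sizes (4, 3), the per-term maxima computed above
+    // are (4 - 1) * 3 = 9 and 3 - 1 = 2, so the final padded size below is
+    // 9 + 2 + 1 = 12.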
+ OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply( + rewriter, loc, sumExpr + 1, terms); paddedShape[resultIndex] = paddedDimOfr; } @@ -192,11 +231,11 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, if (auto complexTy = dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) { auto complexAttr = cast<ArrayAttr>(paddingValueAttr); - paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(), - complexTy, complexAttr); + paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), + complexTy, complexAttr); } else { - paddingValue = rewriter.create<arith::ConstantOp>( - opToPad.getLoc(), cast<TypedAttr>(paddingValueAttr)); + paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(), + cast<TypedAttr>(paddingValueAttr)); } // Pad the operand to the bounding box defined by `paddedShape`. @@ -323,8 +362,8 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad, int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank(); SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); - paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>( - loc, paddedResult, offsets, reifiedResultShapes[resultNumber], + paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create( + rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber], strides)); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp index dc9e11e..dd84379 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Padding.cpp @@ -219,11 +219,11 @@ static FailureOr<Value> padOperandToSmallestStaticBoundingBox( if (auto complexTy = dyn_cast<ComplexType>( getElementTypeOrSelf(opOperand->get().getType()))) { auto complexAttr = cast<ArrayAttr>(paddingAttr); - paddingValue = rewriter.create<complex::ConstantOp>(opToPad.getLoc(), - complexTy, complexAttr); + paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(), + complexTy, complexAttr); } else { - paddingValue = rewriter.create<arith::ConstantOp>( - opToPad.getLoc(), cast<TypedAttr>(paddingAttr)); + paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(), + cast<TypedAttr>(paddingAttr)); } // Computes the padded shape. 
@@ -313,8 +313,8 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank(); SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0)); SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1)); - paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>( - loc, paddedResult, offsets, reifiedResultShapes[resultNumber], + paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create( + rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber], strides)); } @@ -333,17 +333,16 @@ linalg::rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad, for (auto it : llvm::zip(paddedSubtensorResults, opToPad.getDpsInitsMutable())) { if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp::LinalgCopy) { - replacements.push_back(rewriter - .create<linalg::CopyOp>(loc, std::get<0>(it), - std::get<1>(it).get()) + replacements.push_back(linalg::CopyOp::create(rewriter, loc, + std::get<0>(it), + std::get<1>(it).get()) .getResult(0)); } else if (options.copyBackOp == LinalgPaddingOptions::CopyBackOp:: BufferizationMaterializeInDestination) { replacements.push_back( - rewriter - .create<bufferization::MaterializeInDestinationOp>( - loc, std::get<0>(it), std::get<1>(it).get()) + bufferization::MaterializeInDestinationOp::create( + rewriter, loc, std::get<0>(it), std::get<1>(it).get()) ->getResult(0)); } else { llvm_unreachable("unsupported copy back op"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index 0433016..f05ffa8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -62,11 +62,11 @@ static Value allocBuffer(ImplicitLocOpBuilder &b, staticBufferType = MemRefType::Builder(staticBufferType).setMemorySpace(memorySpaceAttr); if (options.useAlloca) { - return b.create<memref::AllocaOp>(staticBufferType, ValueRange{}, - alignmentAttr); + return memref::AllocaOp::create(b, staticBufferType, ValueRange{}, + alignmentAttr); } - return b.create<memref::AllocOp>(staticBufferType, ValueRange{}, - alignmentAttr); + return memref::AllocOp::create(b, staticBufferType, ValueRange{}, + alignmentAttr); } // Fallback dynamic buffer. @@ -75,10 +75,10 @@ static Value allocBuffer(ImplicitLocOpBuilder &b, dynamicBufferType = MemRefType::Builder(dynamicBufferType).setMemorySpace(memorySpaceAttr); Value mul = b.createOrFold<arith::MulIOp>( - b.create<arith::ConstantIndexOp>(width), allocSize); + arith::ConstantIndexOp::create(b, width), allocSize); if (options.useAlloca) - return b.create<memref::AllocaOp>(dynamicBufferType, mul, alignmentAttr); - return b.create<memref::AllocOp>(dynamicBufferType, mul, alignmentAttr); + return memref::AllocaOp::create(b, dynamicBufferType, mul, alignmentAttr); + return memref::AllocOp::create(b, dynamicBufferType, mul, alignmentAttr); } /// Default allocation callback function. 
This allocates a promoted buffer when @@ -91,8 +91,8 @@ static std::optional<Value> defaultAllocBufferCallBack( std::optional<unsigned> alignment, DataLayout &layout) { ShapedType viewType = subView.getType(); ImplicitLocOpBuilder b(subView.getLoc(), builder); - auto zero = b.create<arith::ConstantIndexOp>(0); - auto one = b.create<arith::ConstantIndexOp>(1); + auto zero = arith::ConstantIndexOp::create(b, 0); + auto one = arith::ConstantIndexOp::create(b, 1); Attribute memorySpaceAttr; if (options.memorySpace.has_value()) @@ -122,8 +122,8 @@ defaultDeallocBufferCallBack(const LinalgPromotionOptions &options, OpBuilder &b, Value fullLocalView) { if (!options.useAlloca) { auto viewOp = cast<memref::ViewOp>(fullLocalView.getDefiningOp()); - b.create<memref::DeallocOp>(viewOp.getSource().getLoc(), - viewOp.getSource()); + memref::DeallocOp::create(b, viewOp.getSource().getLoc(), + viewOp.getSource()); } return success(); } @@ -210,7 +210,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( Location loc = linalgOp.getLoc(); auto defaultCopyCallBack = [loc](OpBuilder &b, Value src, Value dst) -> LogicalResult { - b.create<linalg::CopyOp>(loc, src, dst); + linalg::CopyOp::create(b, loc, src, dst); return success(); }; copyInFn = (options.copyInFn ? *(options.copyInFn) : defaultCopyCallBack); @@ -264,7 +264,7 @@ FailureOr<PromotionInfo> mlir::linalg::promoteSubviewAsNewBuffer( /*stopCondition=*/nullptr, /*closedUB=*/true); size = failed(upperBound) ? getValueOrCreateConstantIndexOp(b, loc, rangeValue.size) - : b.create<arith::ConstantIndexOp>(loc, *upperBound); + : arith::ConstantIndexOp::create(b, loc, *upperBound); } LLVM_DEBUG(llvm::dbgs() << "Extracted tightest: " << size << "\n"); fullSizes.push_back(size); @@ -309,23 +309,23 @@ promoteSubViews(ImplicitLocOpBuilder &b, Value fillVal = llvm::TypeSwitch<Type, Value>(subviewEltType) .Case([&](FloatType t) { - return b.create<arith::ConstantOp>(FloatAttr::get(t, 0.0)); + return arith::ConstantOp::create(b, FloatAttr::get(t, 0.0)); }) .Case([&](IntegerType t) { - return b.create<arith::ConstantOp>(IntegerAttr::get(t, 0)); + return arith::ConstantOp::create(b, IntegerAttr::get(t, 0)); }) .Case([&](ComplexType t) { Value tmp; if (auto et = dyn_cast<FloatType>(t.getElementType())) - tmp = b.create<arith::ConstantOp>(FloatAttr::get(et, 0.0)); + tmp = arith::ConstantOp::create(b, FloatAttr::get(et, 0.0)); else if (auto et = cast<IntegerType>(t.getElementType())) - tmp = b.create<arith::ConstantOp>(IntegerAttr::get(et, 0)); - return b.create<complex::CreateOp>(t, tmp, tmp); + tmp = arith::ConstantOp::create(b, IntegerAttr::get(et, 0)); + return complex::CreateOp::create(b, t, tmp, tmp); }) .Default([](auto) { return Value(); }); if (!fillVal) return failure(); - b.create<linalg::FillOp>(fillVal, promotionInfo->fullLocalView); + linalg::FillOp::create(b, fillVal, promotionInfo->fullLocalView); } // Copy data into the promoted buffers. Use callback if provided. 
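The copy-in/copy-out callbacks above are user-replaceable. A minimal usage sketch (hedged: `setCopyInOutFns` is the `LinalgPromotionOptions` hook for this, and the lambda bodies are illustrative, mirroring the default callback shown above):

  LinalgPromotionOptions options;
  options.setCopyInOutFns(
      // Copy-in: fill the promoted buffer from the original subview.
      [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
        linalg::CopyOp::create(b, src.getLoc(), src, dst);
        return success();
      },
      // Copy-out: write results back to the original subview.
      [](OpBuilder &b, Value src, Value dst) -> LogicalResult {
        linalg::CopyOp::create(b, src.getLoc(), src, dst);
        return success();
      });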
@@ -458,9 +458,9 @@ static std::optional<Value> allocateSubviewGPUMemoryInAddressSpace(
       gpu::AddressSpaceAttr::get(builder.getContext(), addressSpace));
   Value buffer;
   if (addressSpace == gpu::GPUDialect::getWorkgroupAddressSpace()) {
-    buffer = builder.create<memref::AllocOp>(funcOp.getLoc(), type);
+    buffer = memref::AllocOp::create(builder, funcOp.getLoc(), type);
   } else if (addressSpace == gpu::GPUDialect::getPrivateAddressSpace()) {
-    buffer = builder.create<memref::AllocaOp>(funcOp.getLoc(), type);
+    buffer = memref::AllocaOp::create(builder, funcOp.getLoc(), type);
   } else {
     return std::nullopt;
   }
@@ -486,9 +486,9 @@ LogicalResult mlir::linalg::deallocateWorkgroupMemory(OpBuilder &,
/// the copy operation to ensure data integrity.
LogicalResult mlir::linalg::copyToWorkgroupMemory(OpBuilder &b, Value src,
                                                  Value dst) {
-  b.create<gpu::BarrierOp>(src.getLoc());
-  Operation *copyOp = b.create<memref::CopyOp>(src.getLoc(), src, dst);
-  b.create<gpu::BarrierOp>(copyOp->getLoc());
+  gpu::BarrierOp::create(b, src.getLoc());
+  Operation *copyOp = memref::CopyOp::create(b, src.getLoc(), src, dst);
+  gpu::BarrierOp::create(b, copyOp->getLoc());
   return success();
}
@@ -503,7 +503,7 @@ std::optional<Value> mlir::linalg::allocateGPUPrivateMemory(
/// Normal copy between src and dst.
LogicalResult mlir::linalg::copyToGPUPrivateMemory(OpBuilder &b, Value src,
                                                   Value dst) {
-  b.create<memref::CopyOp>(src.getLoc(), src, dst);
+  memref::CopyOp::create(b, src.getLoc(), src, dst);
   return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
index b30182d..eac0e47 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/RuntimeOpVerification.cpp
@@ -38,8 +38,8 @@ struct StructuredOpInterface
     SmallVector<Range> loopRanges = linalgOp.createLoopRanges(builder, loc);
     auto [starts, ends, _] = getOffsetsSizesAndStrides(loopRanges);

-    auto zero = builder.create<arith::ConstantIndexOp>(loc, 0);
-    auto one = builder.create<arith::ConstantIndexOp>(loc, 1);
+    auto zero = arith::ConstantIndexOp::create(builder, loc, 0);
+    auto one = arith::ConstantIndexOp::create(builder, loc, 1);

     // Subtract one from the loop ends before composing with the indexing map.
     transform(ends, ends.begin(), [&](OpFoldResult end) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
index 24b8765..f277c5f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp
@@ -1,4 +1,4 @@
-//===- MeshShardingInterfaceImpl.cpp --------------------------------------===//
+//===- ShardingInterfaceImpl.cpp --------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -6,18 +6,18 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.h" +#include "mlir/Dialect/Linalg/Transforms/ShardingInterfaceImpl.h" #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" -#include "mlir/Dialect/Mesh/Transforms/Transforms.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" @@ -36,13 +36,13 @@ namespace mlir::linalg { -using MeshAxis = mesh::MeshAxis; -using ReductionKind = mesh::ReductionKind; -using MeshSharding = mesh::MeshSharding; -using ShardingArray = mesh::ShardingArray; -using MeshOp = mesh::MeshOp; +using GridAxis = shard::GridAxis; +using ReductionKind = shard::ReductionKind; +using Sharding = shard::Sharding; +using ShardingArray = shard::ShardingArray; +using GridOp = shard::GridOp; -// Returns the corresponding mesh reduction kind for the given arith op. +// Returns the corresponding grid reduction kind for the given arith op. static ReductionKind getReductionKind(Operation *op) { return llvm::TypeSwitch<Operation *, ReductionKind>(op) // Floating-point operations. @@ -97,18 +97,18 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) { return getReductionKind(reductionOp.value()); } -static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings, +static GridOp getGrid(Operation *op, ArrayRef<Sharding> operandShardings, + ArrayRef<Sharding> resultShardings, SymbolTableCollection &symbolTable) { - for (const MeshSharding &sharding : operandShardings) { + for (const Sharding &sharding : operandShardings) { if (sharding) { - return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable); + return shard::getGrid(op, sharding.getGridAttr(), symbolTable); } } - for (const MeshSharding &sharding : resultShardings) { + for (const Sharding &sharding : resultShardings) { if (sharding) { - return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable); + return shard::getGrid(op, sharding.getGridAttr(), symbolTable); } } @@ -117,29 +117,29 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings, } // Choose the operand based on the current process index along the reduction -// mesh axes. +// grid axes. // We need to use the initial value only once to avoid including it in the // reduction multiple times. // In each process group only the leading process with linear index 0 would use // the original operand. // The other processes would use the reduction operation neutral tensor. 
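// A hedged sketch of the resulting selection (process count illustrative):
// for an add-reduction partitioned over four processes, the scf.if built
// below makes process 0 feed the original init tensor into its local op,
// while processes 1-3 feed a linalg.fill of the neutral element 0, so the
// init value contributes exactly once to the later all-reduce.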
static Value createDestinationPassingStyleInitOperand( - LinalgOp op, int operandNumber, Value spmdizedOperand, - ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp, + LinalgOp op, int operandNumber, Value partitionedOperand, + ArrayRef<GridAxis> reductionGridAxes, GridOp gridOp, ImplicitLocOpBuilder &builder) { - Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex( - meshOp.getSymName(), reductionMeshAxes, builder); - Value zero = builder.create<arith::ConstantIndexOp>(0); - Value isLeadProcess = builder.create<arith::CmpIOp>( - builder.getI1Type(), arith::CmpIPredicate::eq, + Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex( + gridOp.getSymName(), reductionGridAxes, builder); + Value zero = arith::ConstantIndexOp::create(builder, 0); + Value isLeadProcess = arith::CmpIOp::create( + builder, builder.getI1Type(), arith::CmpIPredicate::eq, processLinearIndexInReductionGroup, zero); - scf::IfOp ifOp = builder.create<scf::IfOp>(spmdizedOperand.getType(), - isLeadProcess, true, true); + scf::IfOp ifOp = scf::IfOp::create(builder, partitionedOperand.getType(), + isLeadProcess, true, true); // Then block. { OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPointToEnd(&ifOp.getThenRegion().front()); - builder.create<scf::YieldOp>(spmdizedOperand); + scf::YieldOp::create(builder, partitionedOperand); } // Else block. @@ -147,7 +147,7 @@ static Value createDestinationPassingStyleInitOperand( OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPointToEnd(&ifOp.getElseRegion().front()); SmallVector<OpFoldResult> shape = - tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand); + tensor::getMixedSizes(builder, builder.getLoc(), partitionedOperand); SmallVector<Operation *> combinerOps; matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps); @@ -155,85 +155,84 @@ static Value createDestinationPassingStyleInitOperand( std::optional<TypedAttr> neutralEl = arith::getNeutralElement(combinerOps[0]); - Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape, - neutralEl.value().getType()); + Value init = tensor::EmptyOp::create(builder, op.getLoc(), shape, + neutralEl.value().getType()); Value constant = - builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value()); - Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init) + arith::ConstantOp::create(builder, op.getLoc(), neutralEl.value()); + Value fill = linalg::FillOp::create(builder, op.getLoc(), constant, init) .getResult(0); - builder.create<scf::YieldOp>(fill); + scf::YieldOp::create(builder, fill); } return ifOp.getResult(0); } -// Create the DPS init operands for the spmdized Linalg op. -// Return all the new spmdized operands. +// Create the DPS init operands for the partitioned Linalg op. +// Return all the new partitioned operands. static SmallVector<Value> createDestinationPassingStyleInitOperands( - LinalgOp op, MeshOp meshOp, ArrayRef<Value> spmdizedOperands, - ArrayRef<MeshAxis> reductionMeshAxes, IRMapping &spmdizationMap, + LinalgOp op, GridOp gridOp, ArrayRef<Value> partitionedOperands, + ArrayRef<GridAxis> reductionGridAxes, IRMapping &partitionMap, ImplicitLocOpBuilder &builder) { // TODO: add support for multiple destination passing style initial value // operands. 
assert(op.getNumDpsInits() == 1 && "Multiple initial values not supported."); - SmallVector<Value> newOperands = llvm::to_vector(spmdizedOperands); + SmallVector<Value> newOperands = llvm::to_vector(partitionedOperands); auto operandIdx = op.getDpsInitOperand(0)->getOperandNumber(); - Value spmdizedInitOperand = - spmdizationMap.lookup(op->getOperands()[operandIdx]); + Value partitionedInitOperand = + partitionMap.lookup(op->getOperands()[operandIdx]); newOperands[operandIdx] = createDestinationPassingStyleInitOperand( - op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder); + op, 0, partitionedInitOperand, reductionGridAxes, gridOp, builder); return newOperands; } static void createAllReduceForResultsWithoutPartialShardings( - LinalgOp unshardedOp, ArrayRef<MeshAxis> opReductionMeshAxes, - ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, + LinalgOp unshardedOp, ArrayRef<GridAxis> opReductionGridAxes, + ArrayRef<Sharding> resultShardings, IRMapping &partitionMap, ImplicitLocOpBuilder &builder) { ReductionKind reductionKind = getReductionKindOfLinalgOp(unshardedOp); for (auto [unshardedLinalgOpResult, resultSharding] : llvm::zip_equal(unshardedOp->getResults(), resultShardings)) { - Value spmdizedLinalgOpResult = - spmdizationMap.lookup(unshardedLinalgOpResult); - Value reducedValue = builder.create<mesh::AllReduceOp>( - spmdizedLinalgOpResult, resultSharding.getMesh(), opReductionMeshAxes, - reductionKind); - spmdizationMap.map(unshardedLinalgOpResult, reducedValue); + Value partitionedLinalgOpResult = + partitionMap.lookup(unshardedLinalgOpResult); + Value reducedValue = shard::AllReduceOp::create( + builder, partitionedLinalgOpResult, resultSharding.getGrid(), + opReductionGridAxes, reductionKind); + partitionMap.map(unshardedLinalgOpResult, reducedValue); } } -static void spmdizeLinalgOpWithShardedReduction( - LinalgOp op, ArrayRef<Value> spmdizedOperands, - ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings, +static void partitionLinalgOpWithShardedReduction( + LinalgOp op, ArrayRef<Value> partitionedOperands, + ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings, ArrayRef<utils::IteratorType> loopIteratorTypes, - ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators, - IRMapping &spmdizationMap, SymbolTableCollection &symbolTable, + ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators, + IRMapping &partitionMap, SymbolTableCollection &symbolTable, ImplicitLocOpBuilder &builder) { - MeshOp mesh = getMesh(op, operandShardings, resultShardings, symbolTable); - SmallVector<MeshAxis> reductionMeshAxes = mesh::getReductionMeshAxes( - loopIteratorTypes, meshAxisAssignmentForLoopIterators); - SmallVector<Value> spmdizedLinalgOpOperands = - createDestinationPassingStyleInitOperands(op, mesh, spmdizedOperands, - reductionMeshAxes, - spmdizationMap, builder); - // We must not change the operand mappings of the original spmdizationMap as - // they are the mappings for the whole spmdization blob and may be used by + GridOp grid = getGrid(op, operandShardings, resultShardings, symbolTable); + SmallVector<GridAxis> reductionGridAxes = shard::getReductionGridAxes( + loopIteratorTypes, gridAxisAssignmentForLoopIterators); + SmallVector<Value> partitionedLinalgOpOperands = + createDestinationPassingStyleInitOperands(op, grid, partitionedOperands, + reductionGridAxes, partitionMap, + builder); + // We must not change the operand mappings of the original partitionMap as + // they are the mappings for 
the whole partition blob and may be used by // others. - IRMapping internalSpmdizationMap; - for (auto [unshardedOperand, spmdizedOperand] : - llvm::zip_equal(op->getOperands(), spmdizedLinalgOpOperands)) { - internalSpmdizationMap.map(unshardedOperand, spmdizedOperand); + IRMapping internalPartitionMap; + for (auto [unshardedOperand, partitionedOperand] : + llvm::zip_equal(op->getOperands(), partitionedLinalgOpOperands)) { + internalPartitionMap.map(unshardedOperand, partitionedOperand); } - spmdizeTriviallyShardableOperation( - *op, spmdizedLinalgOpOperands, operandShardings, resultShardings, - internalSpmdizationMap, symbolTable, builder); + partitionTriviallyShardableOperation( + *op, partitionedLinalgOpOperands, operandShardings, resultShardings, + internalPartitionMap, symbolTable, builder); for (Value result : op->getResults()) { - spmdizationMap.map(result, internalSpmdizationMap.lookup(result)); + partitionMap.map(result, internalPartitionMap.lookup(result)); } // Handle partial shardings. createAllReduceForResultsWithoutPartialShardings( - op, reductionMeshAxes, resultShardings, spmdizationMap, builder); + op, reductionGridAxes, resultShardings, partitionMap, builder); } namespace { @@ -243,7 +242,7 @@ namespace { // permutations. template <typename Op> struct StructuredOpShardingInterface - : public mesh::ShardingInterface::ExternalModel< + : public shard::ShardingInterface::ExternalModel< StructuredOpShardingInterface<Op>, Op> { SmallVector<utils::IteratorType> getLoopIteratorTypes(Operation *op) const { return llvm::cast<LinalgOp>(op).getIteratorTypesArray(); @@ -272,16 +271,16 @@ struct StructuredOpShardingInterface [](unsigned count, utils::IteratorType iter) { return count + (iter == utils::IteratorType::reduction); }); - mesh::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp); + shard::ReductionKind reductionKind = getReductionKindOfLinalgOp(linalgOp); return SmallVector<ReductionKind>(reductionItersCount, reductionKind); } - LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, - ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder) const { + LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands, + ArrayRef<Sharding> operandShardings, + ArrayRef<Sharding> resultShardings, + IRMapping &partitionMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { LinalgOp linalgOp = llvm::cast<LinalgOp>(op); SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); @@ -297,20 +296,20 @@ struct StructuredOpShardingInterface SmallVector<utils::IteratorType> loopIteratorTypes = linalgOp.getIteratorTypesArray(); - ShardingArray meshAxisAssignmentForLoopIterators = - getMeshAxisAssignmentForLoopIterators(operandShardings, resultShardings, + ShardingArray gridAxisAssignmentForLoopIterators = + getGridAxisAssignmentForLoopIterators(operandShardings, resultShardings, loopIteratorTypes, indexingMaps); - if (mesh::isAtLeastOneReductionIteratorSharded( - loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + if (shard::isAtLeastOneReductionIteratorSharded( + loopIteratorTypes, gridAxisAssignmentForLoopIterators)) { ImplicitLocOpBuilder implicitLocBuilder(op->getLoc(), builder); - spmdizeLinalgOpWithShardedReduction( - linalgOp, spmdizedOperands, operandShardings, resultShardings, - loopIteratorTypes, meshAxisAssignmentForLoopIterators, spmdizationMap, + partitionLinalgOpWithShardedReduction( + 
linalgOp, partitionedOperands, operandShardings, resultShardings, + loopIteratorTypes, gridAxisAssignmentForLoopIterators, partitionMap, symbolTable, implicitLocBuilder); } else { - spmdizeTriviallyShardableOperation(*op, spmdizedOperands, - operandShardings, resultShardings, - spmdizationMap, symbolTable, builder); + partitionTriviallyShardableOperation(*op, partitionedOperands, + operandShardings, resultShardings, + partitionMap, symbolTable, builder); } return success(); @@ -330,7 +329,7 @@ static void registerAll(MLIRContext *ctx) { (registerOne<OpTypes>(ctx), ...); } -void registerMeshShardingInterfaceExternalModels(DialectRegistry ®istry) { +void registerShardingInterfaceExternalModels(DialectRegistry ®istry) { registry.addExtension(+[](MLIRContext *ctx, LinalgDialect *dialect) { DialectRegistry registry; registry.insert<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp index 671dea8..76d0ba9 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -52,8 +52,8 @@ createSplitPart(RewriterBase &b, Location loc, TilingInterface op, return nullptr; SmallVector<OpFoldResult> resultStrides(resultOffsets.size(), b.getIndexAttr(1)); - Value inserted = b.create<tensor::InsertSliceOp>( - loc, result, resultOperands[index], resultOffsets, resultSizes, + Value inserted = tensor::InsertSliceOp::create( + b, loc, result, resultOperands[index], resultOffsets, resultSizes, resultStrides); results.push_back(inserted); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp index 5bfdbc6..b8f8620 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SplitReduction.cpp @@ -115,8 +115,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction( newShape, cast<RankedTensorType>(operand->get().getType()).getElementType()); - Value newInput = b.create<tensor::ExpandShapeOp>( - loc, newType, operand->get(), reassociation); + Value newInput = tensor::ExpandShapeOp::create( + b, loc, newType, operand->get(), reassociation); newInputs.push_back(newInput); } @@ -140,18 +140,18 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction( } Value emptyOrAllocTensor; if (useAlloc) { - emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>( - loc, + emptyOrAllocTensor = bufferization::AllocTensorOp::create( + b, loc, RankedTensorType::get(newOutputShape, op.getRegionOutputArgs()[0].getType()), ValueRange{}); } else { - emptyOrAllocTensor = b.create<tensor::EmptyOp>( - loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); + emptyOrAllocTensor = tensor::EmptyOp::create( + b, loc, newOutputShape, op.getRegionOutputArgs()[0].getType()); } - Value constantOp = b.create<arith::ConstantOp>(loc, *identity); + Value constantOp = arith::ConstantOp::create(b, loc, *identity); Value identityTensor = - b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor) + linalg::FillOp::create(b, op->getLoc(), constantOp, emptyOrAllocTensor) .getResult(0); newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr, @@ -168,8 +168,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction( } // Create the new op matching the original op with an extra parallel // dimension. 
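  // Hedged worked example (sizes assumed): a sum over k = 16 with a split
  // factor of 4 becomes a generic op over (k1 = 4, k2 = 4) with k1 parallel,
  // producing 4 partial sums into the identity-filled tensor; the second
  // generic op created below then reduces those partials into the original
  // output.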
- GenericOp genericOp = b.create<GenericOp>( - loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs, + GenericOp genericOp = GenericOp::create( + b, loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs, ValueRange({identityTensor}), newMaps, newIteratorTypes); b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(), genericOp.getRegion().begin()); @@ -191,14 +191,14 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReduction( AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext()); SmallVector<AffineMap> reductionMaps = {inputMap, outputMap}; - auto reduction = b.create<GenericOp>( - loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}), + auto reduction = GenericOp::create( + b, loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}), op.getDpsInits(), reductionMaps, reductionIteratorTypes, [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) { Operation *clonedReductionOp = b.clone(*reductionOp); clonedReductionOp->setOperand(0, inputs[0]); clonedReductionOp->setOperand(1, inputs[1]); - b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0)); + linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0)); }); b.replaceOp(op, reduction.getResults()); @@ -318,14 +318,14 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling( Value emptyOrAllocTensor; if (useAlloc) { emptyOrAllocTensor = - b.create<bufferization::AllocTensorOp>(loc, newT, dims); + bufferization::AllocTensorOp::create(b, loc, newT, dims); } else { - emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(), - t.getElementType(), dims); + emptyOrAllocTensor = tensor::EmptyOp::create(b, loc, newT.getShape(), + t.getElementType(), dims); } - Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it)); - fillOps.push_back( - b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)); + Value constantOp = arith::ConstantOp::create(b, loc, std::get<1>(it)); + fillOps.push_back(linalg::FillOp::create(b, op->getLoc(), constantOp, + emptyOrAllocTensor)); newOutputs.push_back(fillOps.back().getResult(0)); emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp()); } @@ -354,8 +354,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling( SmallVector<Value> newInputs = op.getDpsInputs(); // Add a single shape-only tensor to carry the dimensions without resorting to // more complex inversions. - newInputs.push_back(b.create<tensor::EmptyOp>( - loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor}, + newInputs.push_back(tensor::EmptyOp::create( + b, loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor}, b.getIntegerType(1))); // Output tensors are already good to go. 
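  // Hedged illustration of the shape-only input added above (sizes assumed):
  // with reductionDimSize = 16 and splitFactor = 4 it is a tensor<4x4xi1>
  // whose only purpose is to carry the (16 / 4, 4) split of the reduction
  // dimension into the indexing maps; its elements are never read.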
@@ -365,8 +365,8 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling( iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos, utils::IteratorType::parallel); GenericOp genericOp = - b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs, - newOutputs, newMaps, iteratorTypes); + GenericOp::create(b, loc, ValueRange(newOutputs).getTypes(), newInputs, + newOutputs, newMaps, iteratorTypes); b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(), genericOp.getRegion().begin()); genericOp.getRegion().front().insertArgument(reductionDimPos, @@ -396,7 +396,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling( utils::IteratorType::reduction; // clang-format off - auto reductionOp = b.create<GenericOp>( + auto reductionOp = GenericOp::create(b, loc, originalOutputType, reindexedOutput, @@ -407,7 +407,7 @@ FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling( Operation *clonedReductionOp = b.clone(*combinerOp); clonedReductionOp->setOperand(0, bbArgs[0]); clonedReductionOp->setOperand(1, bbArgs[1]); - b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0)); + linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0)); }); // clang-format on diff --git a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp index d35aad5..792ca3e 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SwapExtractSliceWithFillPatterns.cpp @@ -29,10 +29,10 @@ struct SwapExtractSliceOfFill final if (!fillOp || !fillOp->hasOneUse()) return failure(); - auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>( - extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0], - extractOp.getMixedOffsets(), extractOp.getMixedSizes(), - extractOp.getMixedStrides()); + auto newExtractOp = tensor::ExtractSliceOp::create( + rewriter, extractOp.getLoc(), extractOp.getType(), + fillOp.getOutputs()[0], extractOp.getMixedOffsets(), + extractOp.getMixedSizes(), extractOp.getMixedStrides()); rewriter.replaceOpWithNewOp<FillOp>(extractOp, fillOp.getInputs(), ValueRange{newExtractOp.getResult()}); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 4741afe..705d6f2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -94,11 +94,11 @@ static void emitIsPositiveIndexAssertion(ImplicitLocOpBuilder &b, return; } - Value zero = b.create<arith::ConstantIndexOp>(0); - Value condition = b.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, - cast<Value>(value), zero); - b.create<cf::AssertOp>( - condition, + Value zero = arith::ConstantIndexOp::create(b, 0); + Value condition = arith::CmpIOp::create(b, arith::CmpIPredicate::sgt, + cast<Value>(value), zero); + cf::AssertOp::create( + b, condition, b.getStringAttr("expected strictly positive tile size and divisor")); } @@ -317,11 +317,12 @@ mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op, Value coveredSize = apply(s0 * s1 + s2 * s3, {spec.lowTileSize, spec.lowTripCount, spec.highTileSize, spec.highTripCount}); - Value equals = b.create<arith::CmpIOp>(arith::CmpIPredicate::eq, - coveredSize, tripCount); - b.create<cf::AssertOp>( - equals, builder.getStringAttr( - "could not compute dynamic multi-size tile shapes")); + Value equals = arith::CmpIOp::create(b, 
arith::CmpIPredicate::eq, + coveredSize, tripCount); + cf::AssertOp::create( + b, equals, + builder.getStringAttr( + "could not compute dynamic multi-size tile shapes")); } return spec; @@ -656,8 +657,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall( getValueOrCreateConstantIndexOp(b, loc, nonZeroNumThreads); // 2. Create the ForallOp with an empty region. - scf::ForallOp forallOp = b.create<scf::ForallOp>( - loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors, + scf::ForallOp forallOp = scf::ForallOp::create( + b, loc, getAsOpFoldResult(materializedNonZeroNumThreads), initTensors, mapping); // 3. Calculate the tile offsets and sizes for the subsequent loop that will @@ -689,8 +690,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall( sizes[reductionDim] = b.getIndexAttr(1); outOffsets[reductionDim] = forallOp.getInductionVars()[0]; // TODO: use SubsetExtractOpInterface once it is available. - tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>( - loc, cast<RankedTensorType>(initOperand.getType()), + tiledDpsInitOperands.push_back(tensor::ExtractSliceOp::create( + b, loc, cast<RankedTensorType>(initOperand.getType()), destBbArgs[destNum], outOffsets, sizes, strides)); } @@ -768,8 +769,8 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall( // 6.b. Parallel insertions are inserted at the end of the combining // terminator. b.setInsertionPointToEnd(forallOp.getTerminator().getBody()); - b.create<tensor::ParallelInsertSliceOp>( - loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides); + tensor::ParallelInsertSliceOp::create( + b, loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides); } // 7. Merge the partial reductions. 
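The mechanical rewrite running through all of these hunks is the migration from the `OpBuilder::create<OpTy>` member template to the static `OpTy::create` entry point that takes the builder as its first argument. A minimal before/after sketch (illustrative only, mirroring the hunks above):

  // Before: construction goes through the builder's member template.
  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);

  // After: the op's static create takes the builder first; the remaining
  // arguments keep their order.
  Value zero = arith::ConstantIndexOp::create(b, loc, 0);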
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index 28d99b1..57b610b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -45,7 +45,7 @@ static SmallVector<Value> getIndicesForAccess(OpBuilder &b, Location loc, for (auto result : indexingMap.getResults()) { AffineMap m = AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(), result); - Value v = b.create<affine::AffineApplyOp>(loc, m, ivs); + Value v = affine::AffineApplyOp::create(b, loc, m, ivs); indices.push_back(v); } return indices; @@ -73,9 +73,9 @@ static LogicalResult inlinePayload(OpBuilder &b, LinalgOp linalgOp, OpOperand *storeInto = linalgOp.getDpsInitOperand(operand.index()); auto indices = getIndicesForAccess( b, loc, linalgOp.getMatchingIndexingMap(storeInto), ivs); - b.create<memref::StoreOp>( - loc, toStore, linalgOp.getDpsInitOperand(operand.index())->get(), - indices); + memref::StoreOp::create(b, loc, toStore, + linalgOp.getDpsInitOperand(operand.index())->get(), + indices); } return success(); } @@ -352,7 +352,7 @@ struct LinalgOpTilingInterface SmallVector<Value> indices = getIndicesForAccess( builder, linalgOpLoc, linalgOp.getMatchingIndexingMap(&operand), ivs); Value load = - builder.create<memref::LoadOp>(linalgOpLoc, operand.get(), indices); + memref::LoadOp::create(builder, linalgOpLoc, operand.get(), indices); indexedValues.push_back(load); } @@ -520,10 +520,10 @@ struct LinalgOpPartialReductionInterface Type elType = getElementTypeOrSelf(result.getType()); Value emptyTensor = - b.create<tensor::EmptyOp>(loc, partialResultShape, elType); - Value constantOp = b.create<arith::ConstantOp>(loc, *identity); + tensor::EmptyOp::create(b, loc, partialResultShape, elType); + Value constantOp = arith::ConstantOp::create(b, loc, *identity); auto identityTensor = - b.create<linalg::FillOp>(loc, constantOp, emptyTensor); + linalg::FillOp::create(b, loc, constantOp, emptyTensor); inits.push_back(identityTensor.getResult(0)); } @@ -575,9 +575,9 @@ struct LinalgOpPartialReductionInterface RankedTensorType sliceResultType = RankedTensorType::get( sliceInfo.resultShape, valueToTileType.getElementType(), valueToTileType.getEncoding()); - auto sliceOp = b.create<tensor::ExtractSliceOp>( - loc, sliceResultType, valueToTile, sliceInfo.offsets, sliceInfo.sizes, - sliceInfo.strides); + auto sliceOp = tensor::ExtractSliceOp::create( + b, loc, sliceResultType, valueToTile, sliceInfo.offsets, + sliceInfo.sizes, sliceInfo.strides); tiledInits.push_back(sliceOp.getResult()); generatedSlices.push_back(sliceOp); } @@ -604,8 +604,8 @@ struct LinalgOpPartialReductionInterface auto resultTypes = ValueRange(tiledInits).getTypes(); if (tilingStrategy == ReductionTilingStrategy::PartialReductionOuterReduction) { - auto genericOp = b.create<GenericOp>( - loc, resultTypes, tiledInputs, tiledInits, newMaps, newIteratorTypes); + auto genericOp = GenericOp::create(b, loc, resultTypes, tiledInputs, + tiledInits, newMaps, newIteratorTypes); IRMapping mapping; op->getRegion(0).cloneInto(&genericOp.getRegion(), genericOp.getRegion().begin(), mapping); @@ -649,8 +649,8 @@ struct LinalgOpPartialReductionInterface } } - auto reduction = b.create<linalg::ReduceOp>( - loc, partialResult, init, partialReductionDims, + auto reduction = linalg::ReduceOp::create( + b, loc, partialResult, init, partialReductionDims, [&linalgOp, &initIdx](OpBuilder &b, Location loc, ValueRange 
inputs) { // Get the combiner op. SmallVector<Operation *, 4> combinerOps; @@ -660,7 +660,7 @@ struct LinalgOpPartialReductionInterface // Combine the input at idx and output at numInits + idx. clonedReductionOp->setOperand(0, inputs[0]); clonedReductionOp->setOperand(1, inputs[1]); - b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0)); + linalg::YieldOp::create(b, loc, clonedReductionOp->getResult(0)); }); mergeOperations.push_back(reduction); @@ -791,8 +791,8 @@ struct PackOpTiling SmallVector<OpFoldResult> strides(inputRank, oneAttr); SmallVector<Value> tiledOperands; - auto sourceSlice = b.create<tensor::ExtractSliceOp>( - loc, packOp.getSource(), inputIndices, inputSizes, strides); + auto sourceSlice = tensor::ExtractSliceOp::create( + b, loc, packOp.getSource(), inputIndices, inputSizes, strides); tiledOperands.push_back(sourceSlice); SmallVector<OpFoldResult> outputOffsets, outputSizes; @@ -801,8 +801,8 @@ struct PackOpTiling return {}; strides.append(packOp.getDestRank() - inputRank, oneAttr); - auto outSlice = b.create<tensor::ExtractSliceOp>( - loc, packOp.getDest(), outputOffsets, outputSizes, strides); + auto outSlice = tensor::ExtractSliceOp::create( + b, loc, packOp.getDest(), outputOffsets, outputSizes, strides); tiledOperands.push_back(outSlice); if (auto val = packOp.getPaddingValue()) @@ -810,8 +810,8 @@ struct PackOpTiling for (auto tile : packOp.getInnerTiles()) tiledOperands.push_back(tile); - Operation *tiledPackOp = b.create<PackOp>( - loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); + Operation *tiledPackOp = PackOp::create( + b, loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); return TilingResult{ {tiledPackOp}, @@ -932,20 +932,6 @@ struct PackOpTiling continue; } - // If the dimension needs padding, it is not supported because there are - // iterations that only write padding values to the whole tile. The - // consumer fusion is driven by the source, so it is not possible to map - // an empty slice to the tile. - bool needExtraPadding = - ShapedType::isDynamic(destDimSize) || !cstInnerSize || - destDimSize * cstInnerSize.value() != srcDimSize; - // Prioritize the case that the op already says that it does not need - // padding. - if (!packOp.getPaddingValue()) - needExtraPadding = false; - if (needExtraPadding) - return failure(); - // Currently fusing `packOp` as consumer only expects perfect tiling // scenario because even if without padding semantic, the `packOp` may // also yield incomplete tiles. E.g. 
tensor<30xf32> -> tensor<5x6xf32>, @@ -1007,8 +993,8 @@ struct PackOpTiling SmallVector<OpFoldResult> strides(inputRank, oneAttr); SmallVector<Value> tiledOperands; - auto sourceSlice = b.create<tensor::ExtractSliceOp>( - loc, packOp.getSource(), offsets, sizes, strides); + auto sourceSlice = tensor::ExtractSliceOp::create( + b, loc, packOp.getSource(), offsets, sizes, strides); tiledOperands.push_back(sourceSlice); SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes; @@ -1023,8 +1009,8 @@ struct PackOpTiling return failure(); strides.append(packOp.getDestRank() - inputRank, oneAttr); - auto outSlice = b.create<tensor::ExtractSliceOp>( - loc, packOp.getDest(), outputOffsets, outputSizes, strides); + auto outSlice = tensor::ExtractSliceOp::create( + b, loc, packOp.getDest(), outputOffsets, outputSizes, strides); tiledOperands.push_back(outSlice); if (auto val = packOp.getPaddingValue()) @@ -1032,8 +1018,8 @@ struct PackOpTiling for (auto tile : packOp.getInnerTiles()) tiledOperands.push_back(tile); - Operation *tiledPackOp = b.create<PackOp>( - loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); + Operation *tiledPackOp = PackOp::create( + b, loc, TypeRange{outSlice.getType()}, tiledOperands, op->getAttrs()); return TilingResult{ {tiledPackOp}, @@ -1212,37 +1198,37 @@ struct UnPackOpTiling sliceSrcSizes.append(unpackOp.getMixedTiles()); sliceSrcStrides.append(numInnerTiles, oneAttr); SmallVector<Operation *> generatedSlices; - tensor::ExtractSliceOp sliceSource = b.create<tensor::ExtractSliceOp>( - loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes, + tensor::ExtractSliceOp sliceSource = tensor::ExtractSliceOp::create( + b, loc, unpackOp.getSource(), sliceSrcIndices, sliceSrcSizes, sliceSrcStrides); generatedSlices.push_back(sliceSource); SmallVector<OpFoldResult> destStrides(destRank, oneAttr); Value sliceDest; if (isPerfectTilingCase) { - auto destSliceOp = b.create<tensor::ExtractSliceOp>( - loc, unpackOp.getDest(), offsets, sizes, destStrides); + auto destSliceOp = tensor::ExtractSliceOp::create( + b, loc, unpackOp.getDest(), offsets, sizes, destStrides); sliceDest = destSliceOp; generatedSlices.push_back(destSliceOp); } else { - sliceDest = b.create<tensor::EmptyOp>( - loc, destExpandedSizes, unpackOp.getDestType().getElementType()); + sliceDest = tensor::EmptyOp::create( + b, loc, destExpandedSizes, unpackOp.getDestType().getElementType()); } SmallVector<Value> tiledOperands = {sliceSource.getResult(), sliceDest}; for (auto tile : unpackOp.getInnerTiles()) tiledOperands.push_back(tile); - Operation *tiledUnpackOp = b.create<UnPackOp>( - loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs()); + Operation *tiledUnpackOp = UnPackOp::create( + b, loc, TypeRange{sliceDest.getType()}, tiledOperands, op->getAttrs()); if (isPerfectTilingCase) return TilingResult{{tiledUnpackOp}, SmallVector<Value>(tiledUnpackOp->getResults()), generatedSlices}; - auto extractSlice = b.create<tensor::ExtractSliceOp>( - loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes, + auto extractSlice = tensor::ExtractSliceOp::create( + b, loc, tiledUnpackOp->getResult(0), resultOffsetsFromDest, sizes, destStrides); return TilingResult{ {tiledUnpackOp}, {extractSlice.getResult()}, generatedSlices}; @@ -1377,22 +1363,22 @@ struct UnPackOpTiling SmallVector<Value> tiledOperands; // Create slice of the dest operand. 
- auto extractDestSlice = b.create<tensor::ExtractSliceOp>( - loc, unPackOp.getDest(), outputOffsets, outputSizes, strides); + auto extractDestSlice = tensor::ExtractSliceOp::create( + b, loc, unPackOp.getDest(), outputOffsets, outputSizes, strides); tiledOperands.push_back(extractDestSlice); strides.append(unPackOp.getSourceRank() - outputRank, oneAttr); // Create slice of the source operand. - auto extractSourceSlice = b.create<tensor::ExtractSliceOp>( - loc, unPackOp.getSource(), offsets, sizes, strides); + auto extractSourceSlice = tensor::ExtractSliceOp::create( + b, loc, unPackOp.getSource(), offsets, sizes, strides); tiledOperands.insert(tiledOperands.begin(), extractSourceSlice); for (auto tile : unPackOp.getInnerTiles()) tiledOperands.push_back(tile); // Create tiled unpack op. Operation *tiledUnPackOp = - b.create<UnPackOp>(loc, TypeRange{extractDestSlice.getType()}, - tiledOperands, op->getAttrs()); + UnPackOp::create(b, loc, TypeRange{extractDestSlice.getType()}, + tiledOperands, op->getAttrs()); return TilingResult{{tiledUnPackOp}, SmallVector<Value>(tiledUnPackOp->getResults()), diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index eab74da..bb725f2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -269,12 +269,12 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, packingMetadata.reassociations); Value paddingValue = packOp.getPaddingValue(); if (!paddingValue) { - paddingValue = rewriter.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed))); + paddingValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(getElementTypeOrSelf(collapsed))); } auto padOp = - rewriter.create<tensor::PadOp>(loc, collapsed, packOp.getSource(), lows, - highs, paddingValue, /*nofold=*/false); + tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows, + highs, paddingValue, /*nofold=*/false); LLVM_DEBUG( DBGSNL(); DBGSNL(); @@ -313,8 +313,8 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, SmallVector<OpFoldResult> sizes = tensor::getMixedSizes(rewriter, loc, packOp.getDest()); - auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>( - loc, /*source=*/padOp, /*dest=*/packOp.getDest(), + auto insertSliceOp = tensor::InsertSliceOp::create( + rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(), /*offsets=*/zeros, sizes, /*strides=*/ones); LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL();); @@ -329,15 +329,15 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, // 5. Expand from the padded result to the stripMinedShape. auto expandShapeResultType = RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape); - auto reshapeOp = rewriter.create<tensor::ExpandShapeOp>( - loc, expandShapeResultType, padOp.getResult(), + auto reshapeOp = tensor::ExpandShapeOp::create( + rewriter, loc, expandShapeResultType, padOp.getResult(), packingMetadata.reassociations); // 6. Transpose stripMinedShape to packedShape. 
SmallVector<int64_t> transpPerm = invertPermutationVector(packedToStripMinedShapePerm); - auto transposeOp = rewriter.create<linalg::TransposeOp>( - loc, reshapeOp.getResult(), packOp.getDest(), transpPerm); + auto transposeOp = linalg::TransposeOp::create( + rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm); LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); DBGS() << "reshape op: " << reshapeOp; DBGSNL(); @@ -371,8 +371,8 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one); sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest())); - auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>( - loc, destTensorType, unPackOp.getSource(), + auto extractSliceOp = tensor::ExtractSliceOp::create( + rewriter, loc, destTensorType, unPackOp.getSource(), SmallVector<OpFoldResult>(packedRank, zero), sizes, SmallVector<OpFoldResult>(packedRank, one)); @@ -404,10 +404,11 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, SmallVector<OpFoldResult, 4> dims = tensor::getMixedSizes(rewriter, loc, unPackOp.getSource()); applyPermutationToVector(dims, packedToStripMinedShapePerm); - auto emptyOp = rewriter.create<tensor::EmptyOp>( - loc, dims, stripMinedTensorType.getElementType()); - auto transposeOp = rewriter.create<linalg::TransposeOp>( - loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm); + auto emptyOp = tensor::EmptyOp::create(rewriter, loc, dims, + stripMinedTensorType.getElementType()); + auto transposeOp = + linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp, + packedToStripMinedShapePerm); LLVM_DEBUG( DBGSNL(); DBGSNL(); @@ -426,21 +427,21 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL();); // 4. Collapse from the stripMinedShape to the padded result. - auto reshapeOp = rewriter.create<tensor::CollapseShapeOp>( - loc, collapsedType, transposeOp->getResult(0), + auto reshapeOp = tensor::CollapseShapeOp::create( + rewriter, loc, collapsedType, transposeOp->getResult(0), packingMetadata.reassociations); // 5. ExtractSlice. int64_t destRank = destTensorType.getRank(); - auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>( - loc, destTensorType, reshapeOp->getResult(0), + auto extractSliceOp = tensor::ExtractSliceOp::create( + rewriter, loc, destTensorType, reshapeOp->getResult(0), SmallVector<OpFoldResult>(destRank, zero), tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()), SmallVector<OpFoldResult>(destRank, one)); // 6. Inject a copy to preserve DPS. - auto copyOp = rewriter.create<linalg::CopyOp>( - loc, extractSliceOp->getResult(0), unPackOp.getDest()); + auto copyOp = linalg::CopyOp::create( + rewriter, loc, extractSliceOp->getResult(0), unPackOp.getDest()); // 7. Replace unPackOp by copyOp. rewriter.replaceOp(unPackOp, copyOp->getResults()); @@ -554,16 +555,16 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, operandType.getShape(), innerPos, cast<ShapedType>(dest.getType()).getShape(), {}, innerPackSizes)) { - packOps.push_back(rewriter.create<linalg::PackOp>( - loc, operand, dest, innerPos, innerPackSizes)); + packOps.push_back(linalg::PackOp::create(rewriter, loc, operand, dest, + innerPos, innerPackSizes)); } else { // TODO: value of the padding attribute should be determined by // consumers. 
auto zeroAttr = rewriter.getZeroAttr(getElementTypeOrSelf(dest.getType())); - Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr); - packOps.push_back(rewriter.create<linalg::PackOp>( - loc, operand, dest, innerPos, innerPackSizes, zero)); + Value zero = arith::ConstantOp::create(rewriter, loc, zeroAttr); + packOps.push_back(linalg::PackOp::create( + rewriter, loc, operand, dest, innerPos, innerPackSizes, zero)); } inputsAndInits.push_back(packOps.back()); } @@ -574,9 +575,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, ValueRange{inputsAndInits}.take_front(linalgOp.getNumDpsInputs()); ValueRange inits = ValueRange{inputsAndInits}.take_back(linalgOp.getNumDpsInits()); - auto packedLinalgOp = rewriter.create<linalg::GenericOp>( - linalgOp.getLoc(), inits.getTypes(), inputs, inits, indexingMaps, - iteratorTypes); + auto packedLinalgOp = + linalg::GenericOp::create(rewriter, linalgOp.getLoc(), inits.getTypes(), + inputs, inits, indexingMaps, iteratorTypes); packedLinalgOp.getRegion().takeBody(linalgOp->getRegion(0)); // Step 4. Propagate packing to all the op results. @@ -589,8 +590,8 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, continue; } // Build the symmetrical UnPackOp to the existing PackOp. - unPackOps.push_back(rewriter.create<linalg::UnPackOp>( - packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), + unPackOps.push_back(linalg::UnPackOp::create( + rewriter, packedLinalgOp->getLoc(), result, maybePackedInit.getSource(), maybePackedInit.getInnerDimsPos(), maybePackedInit.getMixedTiles())); results.push_back(unPackOps.back()); } @@ -655,7 +656,8 @@ static LinalgOp transposeOneLinalgOperandAndReplace( operands[opOperand.getOperandNumber()] = transposedValue; ValueRange operandsRef(operands); - auto transposedGenericOp = rewriter.create<linalg::GenericOp>( + auto transposedGenericOp = linalg::GenericOp::create( + rewriter, /*location=*/linalgOp->getLoc(), /*resultTensorTypes=*/ operandsRef.drop_front(linalgOp.getNumDpsInputs()).getTypes(), @@ -904,7 +906,7 @@ mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { b.setInsertionPointToStart( &op->getParentOfType<func::FuncOp>().getBody().front()); return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { - Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); + Value v = arith::ConstantIndexOp::create(b, op->getLoc(), s); return v; })); }; @@ -926,12 +928,12 @@ Value DecomposePadOpPattern::createFillOrGenerateOp( // Move the padding value defined inside the PadOp block to outside. if (padValue.getParentBlock() == &padOp.getRegion().front()) rewriter.moveOpBefore(padValue.getDefiningOp(), padOp); - return rewriter.create<FillOp>(padOp.getLoc(), padValue, dest).result(); + return FillOp::create(rewriter, padOp.getLoc(), padValue, dest).result(); } // Fill could not be optimized: Lower to tensor::GenerateOp with region. - auto generateOp = rewriter.create<tensor::GenerateOp>( - padOp.getLoc(), padOp.getResultType(), dynSizes); + auto generateOp = tensor::GenerateOp::create(rewriter, padOp.getLoc(), + padOp.getResultType(), dynSizes); // Copy region to new op. 
IRMapping bvm; padOp.getRegion().cloneInto(&generateOp.getRegion(), bvm); @@ -945,9 +947,9 @@ DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp, auto getIdxValue = [&](OpFoldResult ofr) { if (auto val = llvm::dyn_cast_if_present<Value>(ofr)) return val; - return rewriter - .create<arith::ConstantIndexOp>( - padOp.getLoc(), cast<IntegerAttr>(cast<Attribute>(ofr)).getInt()) + return arith::ConstantIndexOp::create( + rewriter, padOp.getLoc(), + cast<IntegerAttr>(cast<Attribute>(ofr)).getInt()) .getResult(); }; @@ -970,8 +972,9 @@ DecomposePadOpPattern::matchAndRewrite(tensor::PadOp padOp, } // Init tensor and fill it with padding. - Value emptyTensor = rewriter.create<tensor::EmptyOp>( - padOp.getLoc(), staticSizes, resultType.getElementType(), dynSizes); + Value emptyTensor = + tensor::EmptyOp::create(rewriter, padOp.getLoc(), staticSizes, + resultType.getElementType(), dynSizes); Value fill = createFillOrGenerateOp(rewriter, padOp, emptyTensor, dynSizes); // Generate a InsertSliceOp for copying the PadOp source. @@ -1222,12 +1225,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, srcPermForTranspose); - Value empty = rewriter.create<tensor::EmptyOp>( - loc, transShapeForEmptyOp, packOp.getSourceType().getElementType()); + Value empty = + tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp, + packOp.getSourceType().getElementType()); // 2.2 Create linalg.transpose - auto transposedOp = rewriter.create<linalg::TransposeOp>(loc, input, empty, - srcPermForTranspose); + auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, + srcPermForTranspose); // 3. Insert the inner tile to the destination: // %inserted_tile = tensor.insert_slice(%transposed_tile) @@ -1246,9 +1250,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( } // 4. Replace tensor.packOp with tensor.insert_slice created above - auto insert = rewriter.create<tensor::InsertSliceOp>( - loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets, - writeSizes, writeStrides); + auto insert = tensor::InsertSliceOp::create( + rewriter, loc, transposedOp.getResult()[0], packOp.getDest(), + writeOffsets, writeSizes, writeStrides); rewriter.replaceOp(packOp, insert.getResult()); return success(); @@ -1313,7 +1317,7 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( // outer-untiled-dims if (ShapedType::isDynamic(srcShape[i])) { OpFoldResult dynamicDim = - rewriter.create<tensor::DimOp>(loc, source, i).getResult(); + tensor::DimOp::create(rewriter, loc, source, i).getResult(); extractSliceSizes.push_back(dynamicDim); shapeForEmptyOp.push_back(dynamicDim); } else { @@ -1340,8 +1344,8 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( readShapeForExtractSlice.append(tileShape.begin(), tileShape.end()); Type elemType = unpackOp.getSourceType().getElementType(); auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType); - Value innerTile = rewriter.create<tensor::ExtractSliceOp>( - loc, readType, unpackOp.getSource(), extractSliceOffsets, + Value innerTile = tensor::ExtractSliceOp::create( + rewriter, loc, readType, unpackOp.getSource(), extractSliceOffsets, extractSliceSizes, extractSliceStrides); // 2. Transpose the tile to match the outer corresponding tile order. 
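The decompose patterns above share one shape: materialize a destination with tensor.empty, then transpose into it. Sketched in isolation with the new create spelling; the helper name and signature are illustrative assumptions, not code from this patch:

    #include "mlir/Dialect/Linalg/IR/Linalg.h"
    #include "mlir/Dialect/Tensor/IR/Tensor.h"
    #include "mlir/IR/PatternMatch.h"

    // Illustrative helper: transpose input into a fresh tensor.empty dest.
    static mlir::Value transposeIntoEmpty(mlir::RewriterBase &rewriter,
                                          mlir::Location loc, mlir::Value input,
                                          llvm::ArrayRef<int64_t> destShape,
                                          mlir::Type elemType,
                                          llvm::ArrayRef<int64_t> perm) {
      mlir::Value empty =
          mlir::tensor::EmptyOp::create(rewriter, loc, destShape, elemType);
      auto transposed =
          mlir::linalg::TransposeOp::create(rewriter, loc, input, empty, perm);
      // linalg.transpose is destination-passing style; result 0 is the
      // transposed tensor written into the empty destination.
      return transposed.getResult()[0];
    }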
@@ -1352,9 +1356,9 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( applyPermutationToVector<OpFoldResult>(shapeForEmptyOp, perm); Value empty = - rewriter.create<tensor::EmptyOp>(loc, shapeForEmptyOp, elemType); + tensor::EmptyOp::create(rewriter, loc, shapeForEmptyOp, elemType); auto transposedOp = - rewriter.create<linalg::TransposeOp>(loc, innerTile, empty, perm); + linalg::TransposeOp::create(rewriter, loc, innerTile, empty, perm); // 3. Handle in-complete tiles if needed. It truncates trailing data from the // transposed tile. @@ -1369,8 +1373,9 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( tensor::getMixedSize(rewriter, loc, unpackOp.getDest(), i)); } - auto partialTile = rewriter.create<tensor::ExtractSliceOp>( - loc, transposedOp.getResult()[0], tileOffsets, tileSizes, tileStrides); + auto partialTile = + tensor::ExtractSliceOp::create(rewriter, loc, transposedOp.getResult()[0], + tileOffsets, tileSizes, tileStrides); // 4. Insert the result to the destination tensor. SmallVector<OpFoldResult> writeSizes; @@ -1382,9 +1387,9 @@ LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( else writeSizes.push_back(oneIdxAttr); } - auto insert = rewriter.create<tensor::InsertSliceOp>( - loc, partialTile, unpackOp.getDest(), writeOffsets, writeSizes, - writeStrides); + auto insert = tensor::InsertSliceOp::create(rewriter, loc, partialTile, + unpackOp.getDest(), writeOffsets, + writeSizes, writeStrides); rewriter.replaceOp(unpackOp, insert.getResult()); return success(); @@ -1491,8 +1496,8 @@ FailureOr<Conv1DOp> DownscaleSizeOneWindowed2DConvolution<Conv2DOp, Conv1DOp>:: dilations.erase(dilations.begin() + (removeH ? 0 : 1)); auto dilationsAttr = rewriter.getI64VectorAttr(dilations); - auto conv1DOp = rewriter.create<Conv1DOp>( - loc, newOutputType, ValueRange{newInput, newKernel}, + auto conv1DOp = Conv1DOp::create( + rewriter, loc, newOutputType, ValueRange{newInput, newKernel}, ValueRange{newOutput}, stridesAttr, dilationsAttr); // Insert back. @@ -1578,8 +1583,8 @@ DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite( dilations.erase(dilations.begin() + (removeH ? 0 : 1)); auto dilationsAttr = rewriter.getI64VectorAttr(dilations); - auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>( - loc, newOutputType, ValueRange{newInput, newKernel}, + auto conv1DOp = DepthwiseConv1DNwcWcOp::create( + rewriter, loc, newOutputType, ValueRange{newInput, newKernel}, ValueRange{newOutput}, stridesAttr, dilationsAttr); // Insert back. @@ -1635,9 +1640,9 @@ DownscaleConv2DOp::returningMatchAndRewrite(Conv2DOp convOp, Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp( rewriter, loc, output, newOutputType); - auto conv1DOp = rewriter.create<Conv1DOp>(loc, newOutputType, - ValueRange{newInput, newKernel}, - ValueRange{newOutput}); + auto conv1DOp = + Conv1DOp::create(rewriter, loc, newOutputType, + ValueRange{newInput, newKernel}, ValueRange{newOutput}); // Insert back. 
Value inserted = tensor::createCanonicalRankReducingInsertSliceOp( diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp index 092aecc..35453e2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeConv2D.cpp @@ -67,18 +67,17 @@ FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter, Value input; if (isTensorOp) { - input = rewriter.create<tensor::EmptyOp>(loc, newFilterShape, elementTy) + input = tensor::EmptyOp::create(rewriter, loc, newFilterShape, elementTy) .getResult(); } else { - input = rewriter - .create<memref::AllocOp>( - loc, MemRefType::get(newFilterShape, elementTy)) + input = memref::AllocOp::create(rewriter, loc, + MemRefType::get(newFilterShape, elementTy)) .getResult(); } // We can then construct the transposition on our filter. auto transpose = - rewriter.create<linalg::TransposeOp>(loc, filter, input, filterPerm); + linalg::TransposeOp::create(rewriter, loc, filter, input, filterPerm); Value newFilter; if (isTensorOp) { @@ -98,8 +97,8 @@ FailureOr<Operation *> transposeConv2DHelper(RewriterBase &rewriter, resultTy.push_back(op->getResult(0).getType()); } auto newConv = - rewriter.create<HWCFConvOp>(loc, resultTy, newInputs, op.getOutputs(), - op.getStrides(), op.getDilations()); + HWCFConvOp::create(rewriter, loc, resultTy, newInputs, op.getOutputs(), + op.getStrides(), op.getDilations()); rewriter.replaceOp(op, newConv); return newConv.getOperation(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp index 934781d..a2a4335 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp @@ -47,25 +47,25 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter, SmallVector<Value> dynamicDims; if (type.isDynamicDim(1)) - dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1)); if (type.isDynamicDim(0)) - dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0)); ArrayRef<int64_t> shape = type.getShape(); - Value empty = rewriter.create<tensor::EmptyOp>( - loc, ArrayRef<int64_t>{shape[1], shape[0]}, type.getElementType(), - dynamicDims); - auto transposeOp = rewriter.create<linalg::TransposeOp>( - loc, input, empty, ArrayRef<int64_t>{1, 0}); + Value empty = tensor::EmptyOp::create(rewriter, loc, + ArrayRef<int64_t>{shape[1], shape[0]}, + type.getElementType(), dynamicDims); + auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty, + ArrayRef<int64_t>{1, 0}); Operation *newMatmulOp; if (transposeLHS) { - newMatmulOp = rewriter.create<linalg::MatmulTransposeAOp>( - loc, matmulOp.getResultTypes(), + newMatmulOp = linalg::MatmulTransposeAOp::create( + rewriter, loc, matmulOp.getResultTypes(), ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]}, matmulOp.getOutputs()); } else { - newMatmulOp = rewriter.create<linalg::MatmulTransposeBOp>( - loc, matmulOp.getResultTypes(), + newMatmulOp = linalg::MatmulTransposeBOp::create( + rewriter, loc, matmulOp.getResultTypes(), ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)}, matmulOp.getOutputs()); } @@ -102,27 +102,27 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter, SmallVector<Value> dynamicDims; if 
(type.isDynamicDim(0)) - dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 0)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 0)); if (type.isDynamicDim(2)) - dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 2)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 2)); if (type.isDynamicDim(1)) - dynamicDims.push_back(rewriter.create<tensor::DimOp>(loc, input, 1)); + dynamicDims.push_back(tensor::DimOp::create(rewriter, loc, input, 1)); ArrayRef<int64_t> shape = type.getShape(); - Value empty = rewriter.create<tensor::EmptyOp>( - loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]}, + Value empty = tensor::EmptyOp::create( + rewriter, loc, ArrayRef<int64_t>{shape[0], shape[2], shape[1]}, type.getElementType(), dynamicDims); - auto transposeOp = rewriter.create<linalg::TransposeOp>( - loc, input, empty, ArrayRef<int64_t>{0, 2, 1}); + auto transposeOp = linalg::TransposeOp::create(rewriter, loc, input, empty, + ArrayRef<int64_t>{0, 2, 1}); Operation *newMatmulOp; if (transposeLHS) { - newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeAOp>( - loc, batchMatmulOp.getResultTypes(), + newMatmulOp = linalg::BatchMatmulTransposeAOp::create( + rewriter, loc, batchMatmulOp.getResultTypes(), ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]}, batchMatmulOp.getOutputs()); } else { - newMatmulOp = rewriter.create<linalg::BatchMatmulTransposeBOp>( - loc, batchMatmulOp.getResultTypes(), + newMatmulOp = linalg::BatchMatmulTransposeBOp::create( + rewriter, loc, batchMatmulOp.getResultTypes(), ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)}, batchMatmulOp.getOutputs()); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 77c85ab..ea68b1a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -38,7 +38,8 @@ #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/InterleavedRange.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include <optional> @@ -48,9 +49,6 @@ using namespace mlir::linalg; #define DEBUG_TYPE "linalg-vectorization" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") - /// Try to vectorize `convOp` as a convolution. 
static FailureOr<Operation *> vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp, @@ -120,8 +118,9 @@ extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, SmallVector<int64_t> strides = {1}; for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( - loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, strides)); + result.push_back(vector::ExtractStridedSliceOp::create( + rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{w + kw}, sizes, + strides)); } } } else { @@ -131,8 +130,8 @@ extractConvInputSlices(RewriterBase &rewriter, Location loc, Value input, SmallVector<int64_t> strides = {1, 1, 1}; for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( - loc, input, + result.push_back(vector::ExtractStridedSliceOp::create( + rewriter, loc, input, /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0}, sizes, strides)); } @@ -150,8 +149,8 @@ static SmallVector<Value> extractConvFilterSlices(RewriterBase &rewriter, // Extract rhs slice of size [{c, f} for channeled convolutions and {1} for // non-chanelled convolution] @ [kw]. for (int64_t kw = 0; kw < kwSize; ++kw) { - result.push_back(rewriter.create<vector::ExtractOp>( - loc, filter, /*offsets=*/ArrayRef<int64_t>{kw})); + result.push_back(vector::ExtractOp::create( + rewriter, loc, filter, /*offsets=*/ArrayRef<int64_t>{kw})); } return result; } @@ -168,8 +167,9 @@ extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, SmallVector<int64_t> sizes = {wSizeStep}; SmallVector<int64_t> strides = {1}; for (int64_t w = 0; w < wSize; w += wSizeStep) { - result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( - loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, strides)); + result.push_back(vector::ExtractStridedSliceOp::create( + rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{w}, sizes, + strides)); } } else { // Extract res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled @@ -177,8 +177,9 @@ extractConvResultSlices(RewriterBase &rewriter, Location loc, Value res, SmallVector<int64_t> sizes = {nSize, wSizeStep, fSize}; SmallVector<int64_t> strides = {1, 1, 1}; for (int64_t w = 0; w < wSize; w += wSizeStep) { - result.push_back(rewriter.create<vector::ExtractStridedSliceOp>( - loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, strides)); + result.push_back(vector::ExtractStridedSliceOp::create( + rewriter, loc, res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, sizes, + strides)); } } return result; @@ -195,17 +196,18 @@ static Value insertConvResultSlices(RewriterBase &rewriter, Location loc, // This does not depend on kw. SmallVector<int64_t> strides = {1}; for (int64_t w = 0; w < wSize; w += wSizeStep) { - res = rewriter.create<vector::InsertStridedSliceOp>( - loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, strides); + res = vector::InsertStridedSliceOp::create( + rewriter, loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{w}, + strides); } } else { // Write back res slice: {n, wSizeStep, f} @ [0, w, 0] for channeled // convolution. This does not depend on kw. 
SmallVector<int64_t> strides = {1, 1, 1}; for (int64_t w = 0; w < wSize; w += wSizeStep) { - res = rewriter.create<vector::InsertStridedSliceOp>( - loc, resVals[w], res, /*offsets=*/ArrayRef<int64_t>{0, w, 0}, - strides); + res = vector::InsertStridedSliceOp::create( + rewriter, loc, resVals[w], res, + /*offsets=*/ArrayRef<int64_t>{0, w, 0}, strides); } } return res; @@ -347,8 +349,8 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter, for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) { if (ShapedType::isStatic(iterSpaceStaticSizes[vecDim])) { // Create constant index op for static dimensions. - iterSpaceValueSizes.push_back(rewriter.create<arith::ConstantIndexOp>( - linalgOp.getLoc(), iterSpaceStaticSizes[vecDim])); + iterSpaceValueSizes.push_back(arith::ConstantIndexOp::create( + rewriter, linalgOp.getLoc(), iterSpaceStaticSizes[vecDim])); continue; } @@ -360,11 +362,12 @@ VectorizationState::precomputeIterSpaceValueSizes(RewriterBase &rewriter, operandDimPos))) return failure(); - Value dynamicDim = linalgOp.hasPureTensorSemantics() - ? (Value)rewriter.create<tensor::DimOp>( - linalgOp.getLoc(), operand, operandDimPos) - : (Value)rewriter.create<memref::DimOp>( - linalgOp.getLoc(), operand, operandDimPos); + Value dynamicDim = + linalgOp.hasPureTensorSemantics() + ? (Value)tensor::DimOp::create(rewriter, linalgOp.getLoc(), operand, + operandDimPos) + : (Value)memref::DimOp::create(rewriter, linalgOp.getLoc(), operand, + operandDimPos); iterSpaceValueSizes.push_back(dynamicDim); } @@ -398,12 +401,8 @@ LogicalResult VectorizationState::initState(RewriterBase &rewriter, scalableVecDims.append(linalgOp.getNumLoops(), false); } - LDBG("Canonical vector shape: "); - LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); - LDBG("Scalable vector dims: "); - LLVM_DEBUG(llvm::interleaveComma(scalableVecDims, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); + LDBG() << "Canonical vector shape: " << llvm::interleaved(canonicalVecShape); + LDBG() << "Scalable vector dims: " << llvm::interleaved(scalableVecDims); if (ShapedType::isDynamicShape(canonicalVecShape)) return failure(); @@ -447,14 +446,14 @@ Value VectorizationState::getOrCreateMaskFor( : AffineMap::getMultiDimIdentityMap( linalgOp.getNumLoops(), rewriter.getContext()); - LDBG("Masking map: " << maskingMap << "\n"); + LDBG() << "Masking map: " << maskingMap; // Return the active mask for the masking map of this operation if it was // already created. auto activeMaskIt = activeMaskCache.find(maskingMap); if (activeMaskIt != activeMaskCache.end()) { Value mask = activeMaskIt->second; - LDBG("Reusing mask: " << mask << "\n"); + LDBG() << "Reusing mask: " << mask; return mask; } @@ -469,12 +468,10 @@ Value VectorizationState::getOrCreateMaskFor( auto maskType = getCanonicalVecType(rewriter.getI1Type(), maskingMap); auto maskShape = maskType.getShape(); - LDBG("Mask shape: "); - LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); + LDBG() << "Mask shape: " << llvm::interleaved(maskShape); if (permutedStaticSizes == maskShape) { - LDBG("Masking is not needed for masking map: " << maskingMap << "\n"); + LDBG() << "Masking is not needed for masking map: " << maskingMap; activeMaskCache[maskingMap] = Value(); return Value(); } @@ -489,8 +486,9 @@ Value VectorizationState::getOrCreateMaskFor( ? 
true : std::get<0>(it) == std::get<1>(it); })) { - LDBG("Dynamic + static dimensions match vector sizes, masking is not " - "required.\n"); + LDBG() + << "Dynamic + static dimensions match vector sizes, masking is not " + "required."; activeMaskCache[maskingMap] = Value(); return Value(); } @@ -503,9 +501,9 @@ Value VectorizationState::getOrCreateMaskFor( "Masked 0-d vectors are not supported yet"); // Create the mask based on the dimension values. - Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(), - maskType, upperBounds); - LDBG("Creating new mask: " << mask << "\n"); + Value mask = vector::CreateMaskOp::create(rewriter, linalgOp.getLoc(), + maskType, upperBounds); + LDBG() << "Creating new mask: " << mask; activeMaskCache[maskingMap] = mask; return mask; } @@ -514,7 +512,7 @@ Operation * VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp, std::optional<AffineMap> maybeIndexingMap) { - LDBG("Trying to mask: " << *opToMask << "\n"); + LDBG() << "Trying to mask: " << *opToMask; std::optional<AffineMap> maybeMaskingMap = std::nullopt; if (maybeIndexingMap) @@ -525,7 +523,7 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap); if (!mask) { - LDBG("No mask required\n"); + LDBG() << "No mask required"; return opToMask; } @@ -539,7 +537,7 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx), maskOpTerminator); - LDBG("Masked operation: " << *maskOp << "\n"); + LDBG() << "Masked operation: " << *maskOp; return maskOp; } @@ -672,8 +670,8 @@ static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, ArrayRef<bool> dimsToMask) { auto maybeKind = getCombinerOpKind(reduceOp); assert(maybeKind && "Failed precondition: could not get reduction kind"); - return b.create<vector::MultiDimReductionOp>( - reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind); + return vector::MultiDimReductionOp::create( + b, reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind); } static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) { @@ -717,19 +715,20 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, Operation *write; if (vectorType.getRank() > 0) { AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap)); - SmallVector<Value> indices(linalgOp.getRank(outputOperand), - rewriter.create<arith::ConstantIndexOp>(loc, 0)); + SmallVector<Value> indices( + linalgOp.getRank(outputOperand), + arith::ConstantIndexOp::create(rewriter, loc, 0)); value = broadcastIfNeeded(rewriter, value, vectorType); assert(value.getType() == vectorType && "Incorrect type"); - write = rewriter.create<vector::TransferWriteOp>( - loc, value, outputOperand->get(), indices, writeMap); + write = vector::TransferWriteOp::create( + rewriter, loc, value, outputOperand->get(), indices, writeMap); } else { // 0-d case is still special: do not invert the reindexing writeMap. 
if (!isa<VectorType>(value.getType())) - value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value); + value = vector::BroadcastOp::create(rewriter, loc, vectorType, value); assert(value.getType() == vectorType && "Incorrect type"); - write = rewriter.create<vector::TransferWriteOp>( - loc, value, outputOperand->get(), ValueRange{}); + write = vector::TransferWriteOp::create(rewriter, loc, value, + outputOperand->get(), ValueRange{}); } write = state.maskOperation(rewriter, write, linalgOp, opOperandMap); @@ -742,7 +741,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds)); } - LDBG("vectorized op: " << *write << "\n"); + LDBG() << "vectorized op: " << *write; if (!write->getResults().empty()) return write->getResult(0); return Value(); @@ -807,7 +806,7 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter, auto indexVectorType = VectorType::get({targetShape[dim]}, rewriter.getIndexType(), state.getScalableVecDims()[dim]); - auto indexSteps = rewriter.create<vector::StepOp>(loc, indexVectorType); + auto indexSteps = vector::StepOp::create(rewriter, loc, indexVectorType); // Return the one-dimensional index vector if it lives in the trailing // dimension of the iteration space since the vectorization algorithm in this // case can handle the broadcast. @@ -822,14 +821,14 @@ static VectorizationHookResult vectorizeLinalgIndex(RewriterBase &rewriter, auto permMap = AffineMap::getPermutationMap(permPattern, linalgOp.getContext()); - auto broadCastOp = rewriter.create<vector::BroadcastOp>( - loc, state.getCanonicalVecType(rewriter.getIndexType(), permMap), - indexSteps); + auto broadCastOp = vector::BroadcastOp::create( + rewriter, loc, + state.getCanonicalVecType(rewriter.getIndexType(), permMap), indexSteps); SmallVector<int64_t> transposition = llvm::to_vector<16>(llvm::seq<int64_t>(0, linalgOp.getNumLoops())); std::swap(transposition.back(), transposition[dim]); auto transposeOp = - rewriter.create<vector::TransposeOp>(loc, broadCastOp, transposition); + vector::TransposeOp::create(rewriter, loc, broadCastOp, transposition); return VectorizationHookResult{VectorizationHookStatus::NewOp, transposeOp}; } @@ -882,19 +881,19 @@ static Value calculateGatherOffset(RewriterBase &rewriter, const size_t numIndices = extractOp.getIndices().size(); for (size_t i = 1; i < numIndices; i++) { - Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i); + Value dimIdx = arith::ConstantIndexOp::create(rewriter, loc, i); auto dimSize = broadcastIfNeeded( rewriter, - rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx), + tensor::DimOp::create(rewriter, loc, extractOp.getTensor(), dimIdx), indexVecType); - offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize); + offset = arith::MulIOp::create(rewriter, loc, offset, dimSize); auto extractOpIndex = broadcastIfNeeded( rewriter, bvm.lookup(extractOp.getIndices()[i]), indexVecType); - offset = rewriter.create<arith::AddIOp>(loc, extractOpIndex, offset); + offset = arith::AddIOp::create(rewriter, loc, extractOpIndex, offset); } return offset; @@ -1084,7 +1083,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, } if (!leadingIdxsLoopInvariant) { - LDBG("Found gather load: " << extractOp); + LDBG() << "Found gather load: " << extractOp; return VectorMemoryAccessKind::Gather; } @@ -1098,7 +1097,7 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, // If the trailing index is loop invariant 
then this is a scalar load. if (leadingIdxsLoopInvariant && isLoopInvariantIdx(linalgOp, extractOpTrailingIdx, resType)) { - LDBG("Found scalar broadcast load: " << extractOp); + LDBG() << "Found scalar broadcast load: " << extractOp; return VectorMemoryAccessKind::ScalarBroadcast; } @@ -1116,12 +1115,12 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp, isContiguousLoad &= (foundIndexOp && isRowVector); if (isContiguousLoad) { - LDBG("Found contigous load: " << extractOp); + LDBG() << "Found contigous load: " << extractOp; return VectorMemoryAccessKind::Contiguous; } // 4. Fallback case - gather load. - LDBG("Found gather load: " << extractOp); + LDBG() << "Found gather load: " << extractOp; return VectorMemoryAccessKind::Gather; } @@ -1139,18 +1138,18 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, // Compute the static loop sizes of the extract op. auto resultType = state.getCanonicalVecType(extractOp.getResult().getType()); - auto maskConstantOp = rewriter.create<arith::ConstantOp>( - loc, + auto maskConstantOp = arith::ConstantOp::create( + rewriter, loc, DenseIntElementsAttr::get(state.getCanonicalVecType(rewriter.getI1Type()), /*value=*/true)); - auto passThruConstantOp = - rewriter.create<arith::ConstantOp>(loc, rewriter.getZeroAttr(resultType)); + auto passThruConstantOp = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(resultType)); // Base indices are currently set to 0. We will need to re-visit if more // generic scenarios are to be supported. SmallVector<Value> baseIndices( extractOp.getIndices().size(), - rewriter.create<arith::ConstantIndexOp>(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); VectorMemoryAccessKind memAccessKind = getTensorExtractMemoryAccessPattern(extractOp, linalgOp, resultType); @@ -1160,12 +1159,12 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, Value offset = calculateGatherOffset(rewriter, state, extractOp, bvm); // Generate the gather load - Operation *gatherOp = rewriter.create<vector::GatherOp>( - loc, resultType, extractOp.getTensor(), baseIndices, offset, + Operation *gatherOp = vector::GatherOp::create( + rewriter, loc, resultType, extractOp.getTensor(), baseIndices, offset, maskConstantOp, passThruConstantOp); gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp); - LDBG("Vectorised as gather load: " << extractOp << "\n"); + LDBG() << "Vectorised as gather load: " << extractOp; return VectorizationHookResult{VectorizationHookStatus::NewOp, gatherOp}; } @@ -1195,13 +1194,13 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, continue; } - auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>( - loc, + auto indexAs1dVector = vector::ShapeCastOp::create( + rewriter, loc, VectorType::get(resultType.getShape().back(), rewriter.getIndexType(), resultType.getScalableDims().back()), idx); transferReadIdxs.push_back( - rewriter.create<vector::ExtractOp>(loc, indexAs1dVector, 0)); + vector::ExtractOp::create(rewriter, loc, indexAs1dVector, 0)); } // `tensor.extract_element` is always in-bounds, hence the following holds. 
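Vectorization.cpp also migrates its debug output: the per-file DBGS()/LDBG(X) macros deleted above are replaced by the LDBG() stream from llvm/Support/DebugLog.h, which handles the DEBUG_TYPE tagging and the trailing newline itself, and llvm::interleaved from llvm/Support/InterleavedRange.h replaces manual interleaveComma loops. A minimal before/after sketch of the "Mask shape" change; the wrapper function is hypothetical:

    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/Support/Debug.h"
    #include "llvm/Support/DebugLog.h"
    #include "llvm/Support/InterleavedRange.h"

    #define DEBUG_TYPE "linalg-vectorization"

    // Hypothetical function, for illustration only.
    static void logMaskShape(llvm::ArrayRef<int64_t> maskShape) {
      // Old style: three statements and an explicit trailing newline.
      //   LDBG("Mask shape: ");
      //   LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
      //   LLVM_DEBUG(llvm::dbgs() << "\n");
      // New style: one statement; DEBUG_TYPE filtering and the newline are
      // supplied by LDBG() itself, so no "\n" is appended.
      LDBG() << "Mask shape: " << llvm::interleaved(maskShape);
    }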
@@ -1215,8 +1214,8 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, SmallVector<AffineExpr> exprs(dstRank, getAffineConstantExpr(0, ctx)); auto permutationMap = AffineMap::get(srcRank, 0, exprs, ctx); - auto transferReadOp = rewriter.create<vector::TransferReadOp>( - loc, resultType, extractOp.getTensor(), transferReadIdxs, + auto transferReadOp = vector::TransferReadOp::create( + rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs, /*padding=*/std::nullopt, permutationMap, inBounds); // Mask this broadcasting xfer_read here rather than relying on the generic @@ -1224,12 +1223,12 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, // valid here). SmallVector<int64_t> readMaskShape = {1}; auto readMaskType = VectorType::get(readMaskShape, rewriter.getI1Type()); - auto allTrue = rewriter.create<vector::ConstantMaskOp>( - loc, readMaskType, vector::ConstantMaskKind::AllTrue); + auto allTrue = vector::ConstantMaskOp::create( + rewriter, loc, readMaskType, vector::ConstantMaskKind::AllTrue); auto *maskedReadOp = mlir::vector::maskOperation(rewriter, transferReadOp, allTrue); - LDBG("Vectorised as scalar broadcast load: " << extractOp << "\n"); + LDBG() << "Vectorised as scalar broadcast load: " << extractOp; return VectorizationHookResult{VectorizationHookStatus::NewOp, maskedReadOp}; } @@ -1252,11 +1251,11 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state, rankDiff--; } - auto transferReadOp = rewriter.create<vector::TransferReadOp>( - loc, resultType, extractOp.getTensor(), transferReadIdxs, + auto transferReadOp = vector::TransferReadOp::create( + rewriter, loc, resultType, extractOp.getTensor(), transferReadIdxs, /*padding=*/std::nullopt, permutationMap, inBounds); - LDBG("Vectorised as contiguous load: " << extractOp); + LDBG() << "Vectorised as contiguous load: " << extractOp; return VectorizationHookResult{VectorizationHookStatus::NewOp, transferReadOp}; } @@ -1304,7 +1303,7 @@ static VectorizationHookResult vectorizeOneOp(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, Operation *op, const IRMapping &bvm, ArrayRef<CustomVectorizationHook> customVectorizationHooks) { - LDBG("vectorize op " << *op << "\n"); + LDBG() << "vectorize op " << *op; // 1. Try to apply any CustomVectorizationHook. if (!customVectorizationHooks.empty()) { @@ -1419,7 +1418,7 @@ static LogicalResult vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) { - LDBG("Vectorizing operation as linalg generic\n"); + LDBG() << "Vectorizing operation as linalg generic"; Block *block = linalgOp.getBlock(); // 2. Values defined above the region can only be broadcast for now. Make them @@ -1434,7 +1433,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, // 3. Turn all BBArgs into vector.transfer_read / load. 
Location loc = linalgOp.getLoc(); - Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); for (OpOperand *opOperand : linalgOp.getOpOperandsMatchingBBargs()) { BlockArgument bbarg = linalgOp.getMatchingBlockArgument(opOperand); if (linalgOp.isScalar(opOperand)) { @@ -1464,8 +1463,8 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero); - Operation *read = rewriter.create<vector::TransferReadOp>( - loc, readType, opOperand->get(), indices, + Operation *read = vector::TransferReadOp::create( + rewriter, loc, readType, opOperand->get(), indices, /*padding=*/std::nullopt, readMap); read = state.maskOperation(rewriter, read, linalgOp, indexingMap); Value readValue = read->getResult(0); @@ -1481,11 +1480,11 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, // 3.c. Not all ops support 0-d vectors, extract the scalar for now. // TODO: remove this. if (readType.getRank() == 0) - readValue = rewriter.create<vector::ExtractOp>(loc, readValue, - ArrayRef<int64_t>()); + readValue = vector::ExtractOp::create(rewriter, loc, readValue, + ArrayRef<int64_t>()); - LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue - << "\n"); + LDBG() << "New vectorized bbarg(" << bbarg.getArgNumber() + << "): " << readValue; bvm.map(bbarg, readValue); bvm.map(opOperand->get(), readValue); } @@ -1517,13 +1516,13 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state, VectorizationHookResult result = vectorizeOneOp(rewriter, state, linalgOp, &op, bvm, hooks); if (result.status == VectorizationHookStatus::Failure) { - LDBG("failed to vectorize: " << op << "\n"); + LDBG() << "failed to vectorize: " << op; return failure(); } if (result.status == VectorizationHookStatus::NewOp) { Operation *maybeMaskedOp = state.maskOperation(rewriter, result.newOp, linalgOp); - LDBG("New vector op: " << *maybeMaskedOp << "\n"); + LDBG() << "New vector op: " << *maybeMaskedOp; bvm.map(op.getResults(), maybeMaskedOp->getResults()); } } @@ -1689,17 +1688,16 @@ createWriteOrMaskedWrite(OpBuilder &builder, Location loc, Value vecToStore, writeIndices.size() == static_cast<size_t>(destRank)) && "Invalid number of write indices!"); if (writeIndices.empty()) { - auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); writeIndices.assign(destRank, zero); } // Generate the xfer_write Op - Operation *write = - builder.create<vector::TransferWriteOp>(loc, - /*vector=*/vecToStore, - /*source=*/dest, - /*indices=*/writeIndices, - /*inBounds=*/inBoundsVal); + Operation *write = vector::TransferWriteOp::create(builder, loc, + /*vector=*/vecToStore, + /*source=*/dest, + /*indices=*/writeIndices, + /*inBounds=*/inBoundsVal); // If masking is disabled, exit. 
if (useInBoundsInsteadOfMasking) @@ -1774,8 +1772,9 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, Location loc = packOp.getLoc(); auto padValue = packOp.getPaddingValue(); if (!padValue) { - padValue = rewriter.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(packOp.getSourceType().getElementType())); + padValue = arith::ConstantOp::create( + rewriter, loc, + rewriter.getZeroAttr(packOp.getSourceType().getElementType())); } ReifiedRankedShapedTypeDims reifiedReturnShapes; LogicalResult status = @@ -1814,17 +1813,17 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp, auto tiledPackType = VectorType::get(getTiledPackShape(packOp, destShape), packOp.getDestType().getElementType()); auto shapeCastOp = - rewriter.create<vector::ShapeCastOp>(loc, tiledPackType, maskedRead); + vector::ShapeCastOp::create(rewriter, loc, tiledPackType, maskedRead); // Create TransposeOp. auto destPermutation = invertPermutationVector(getPackInverseDestPerm(packOp)); - auto transposeOp = rewriter.create<vector::TransposeOp>( - loc, shapeCastOp.getResult(), destPermutation); + auto transposeOp = vector::TransposeOp::create( + rewriter, loc, shapeCastOp.getResult(), destPermutation); // Create TransferWriteOp. - Value dest = rewriter.create<tensor::EmptyOp>( - loc, reifiedReturnShapes[0], + Value dest = tensor::EmptyOp::create( + rewriter, loc, reifiedReturnShapes[0], transposeOp.getResult().getType().getElementType()); Operation *write = createWriteOrMaskedWrite(rewriter, loc, transposeOp.getResult(), dest); @@ -1914,18 +1913,11 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, readVectorSizes.append(sourceShape.begin() + vectorSizes.size(), sourceShape.end()); - ReifiedRankedShapedTypeDims reifiedRetShapes; - LogicalResult status = - cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation()) - .reifyResultShapes(rewriter, reifiedRetShapes); - if (status.failed()) { - LDBG("Unable to reify result shapes of " << unpackOp); - return failure(); - } Location loc = unpackOp->getLoc(); - auto padValue = rewriter.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(unpackOp.getSourceType().getElementType())); + auto padValue = arith::ConstantOp::create( + rewriter, loc, + rewriter.getZeroAttr(unpackOp.getSourceType().getElementType())); // Read result, mask if necessary. If transferReadOp shape is not equal // to shape of source, then a mask is necessary. @@ -1943,23 +1935,17 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, RankedTensorType stripMineTensorType = RankedTensorType::get(stripMineShape, stripMineElemType); // Transpose the appropriate rows to match output. - vector::TransposeOp transposeOp = rewriter.create<vector::TransposeOp>( - loc, readResult, lastDimToInsertPosPerm); + vector::TransposeOp transposeOp = vector::TransposeOp::create( + rewriter, loc, readResult, lastDimToInsertPosPerm); // Collapse the vector to the size required by result. RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType( stripMineTensorType, packMetadata.reassociations); mlir::VectorType vecCollapsedType = VectorType::get(collapsedType.getShape(), collapsedType.getElementType()); - vector::ShapeCastOp shapeCastOp = rewriter.create<vector::ShapeCastOp>( - loc, vecCollapsedType, transposeOp->getResult(0)); - - // writeVectorSizes had to match the shapecast shape for dynamic sizes, - // otherwise the validator complains that the mask size is invalid. 
- SmallVector<int64_t> writeVectorSizes( - unpackOp.getDestType().hasStaticShape() - ? vectorSizes - : shapeCastOp.getResultVectorType().getShape()); + vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create( + rewriter, loc, vecCollapsedType, transposeOp->getResult(0)); + Operation *write = createWriteOrMaskedWrite( rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(), /*writeIndices=*/{}, useInBoundsInsteadOfMasking); @@ -1992,8 +1978,8 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, /*useInBoundsInsteadOfMasking=*/false); // Create Xfer write Op - Value dest = rewriter.create<tensor::EmptyOp>( - loc, reifiedReturnShapes[0], padOp.getResultType().getElementType()); + Value dest = tensor::EmptyOp::create(rewriter, loc, reifiedReturnShapes[0], + padOp.getResultType().getElementType()); Operation *write = createWriteOrMaskedWrite(rewriter, loc, maskedRead, dest); newResults.push_back(write->getResult(0)); return success(); @@ -2003,7 +1989,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp, // ops that may not commute (e.g. linear reduction + non-linear instructions). static LogicalResult reductionPreconditions(LinalgOp op) { if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) { - LDBG("reduction precondition failed: no reduction iterator\n"); + LDBG() << "reduction precondition failed: no reduction iterator"; return failure(); } for (OpOperand &opOperand : op.getDpsInitsMutable()) { @@ -2013,7 +1999,7 @@ static LogicalResult reductionPreconditions(LinalgOp op) { Operation *reduceOp = matchLinalgReduction(&opOperand); if (!reduceOp || !getCombinerOpKind(reduceOp)) { - LDBG("reduction precondition failed: reduction detection failed\n"); + LDBG() << "reduction precondition failed: reduction detection failed"; return failure(); } } @@ -2024,13 +2010,13 @@ static LogicalResult vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, bool flatten1DDepthwiseConv) { if (flatten1DDepthwiseConv) { - LDBG("Vectorization of flattened convs with dynamic shapes is not " - "supported\n"); + LDBG() << "Vectorization of flattened convs with dynamic shapes is not " + "supported"; return failure(); } if (!isa<linalg::DepthwiseConv1DNwcWcOp>(conv)) { - LDBG("Not a 1D depth-wise WC conv, dynamic shapes are not supported\n"); + LDBG() << "Not a 1D depth-wise WC conv, dynamic shapes are not supported"; return failure(); } @@ -2040,8 +2026,8 @@ vectorizeDynamicConvOpPrecondition(linalg::LinalgOp conv, ArrayRef<int64_t> lhsShape = cast<ShapedType>(lhs.getType()).getShape(); auto shapeWithoutCh = lhsShape.drop_back(1); if (ShapedType::isDynamicShape(shapeWithoutCh)) { - LDBG("Dynamically-shaped op vectorization precondition failed: only " - "channel dim can be dynamic\n"); + LDBG() << "Dynamically-shaped op vectorization precondition failed: only " + "channel dim can be dynamic"; return failure(); } @@ -2064,7 +2050,7 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, op.getOperation())) return failure(); - LDBG("Dynamically-shaped op meets vectorization pre-conditions\n"); + LDBG() << "Dynamically-shaped op meets vectorization pre-conditions"; return success(); } @@ -2076,7 +2062,7 @@ vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) { return !getConstantIntValue(res).has_value(); })) { - LDBG("Inner-tiles must be constant: " << unpackOp << "\n"); + LDBG() << "Inner-tiles must be constant: " << unpackOp; return failure(); } ArrayRef<int64_t> resultShape = 
unpackOp.getDestType().getShape(); @@ -2116,7 +2102,7 @@ vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp, !sourceType.hasStaticShape() && inputVectorSizes.empty(); if (!padValue && isOutOfBoundsRead) { - LDBG("Failed to get a pad value for out-of-bounds read access\n"); + LDBG() << "Failed to get a pad value for out-of-bounds read access"; return failure(); } return success(); @@ -2146,7 +2132,7 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, Operation *reduceOp = matchLinalgReduction(outOperand); auto maybeKind = getCombinerOpKind(reduceOp); if (!maybeKind) { - LDBG("Failed to determine contraction combining kind.\n"); + LDBG() << "Failed to determine contraction combining kind."; return failure(); } @@ -2156,7 +2142,7 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, AffineMap lhsMap = linalgOp.getIndexingMapsArray()[0]; AffineMap rhsMap = linalgOp.getIndexingMapsArray()[1]; if (getUnusedDimsBitVector({lhsMap, rhsMap}).any()) { - LDBG("Contractions with broadcasts are not supported.\n"); + LDBG() << "Contractions with broadcasts are not supported."; return failure(); } @@ -2191,8 +2177,8 @@ vectorizeAsLinalgContraction(RewriterBase &rewriter, VectorizationState &state, } // Create contraction. - Operation *contractOp = rewriter.create<vector::ContractionOp>( - loc, /*lhs=*/vecOperands[0], + Operation *contractOp = vector::ContractionOp::create( + rewriter, loc, /*lhs=*/vecOperands[0], /*rhs=*/vecOperands[1], /*acc=*/vecOperands[2], linalgOp.getIndexingMaps(), rewriter.getArrayAttr(iterAttrs), *maybeKind); contractOp = state.maskOperation(rewriter, contractOp, linalgOp); @@ -2348,7 +2334,7 @@ static LogicalResult vectorizeLinalgOpPrecondition( if (linalgOp.hasDynamicShape() && failed(vectorizeDynamicLinalgOpPrecondition( linalgOp, flatten1DDepthwiseConv))) { - LDBG("Dynamically-shaped op failed vectorization pre-conditions\n"); + LDBG() << "Dynamically-shaped op failed vectorization pre-conditions"; return failure(); } @@ -2390,11 +2376,11 @@ static LogicalResult vectorizeLinalgOpPrecondition( // all indexing maps are projected permutations. For convs and stencils the // logic will need to evolve. 
if (!allIndexingsAreProjectedPermutation(linalgOp)) { - LDBG("precondition failed: not projected permutations\n"); + LDBG() << "precondition failed: not projected permutations"; return failure(); } if (failed(reductionPreconditions(linalgOp))) { - LDBG("precondition failed: reduction preconditions\n"); + LDBG() << "precondition failed: reduction preconditions"; return failure(); } return success(); @@ -2406,7 +2392,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp, auto padValue = packOp.getPaddingValue(); Attribute cstAttr; if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) { - LDBG("pad value is not constant: " << packOp << "\n"); + LDBG() << "pad value is not constant: " << packOp; return failure(); } ArrayRef<int64_t> resultTensorShape = packOp.getDestType().getShape(); @@ -2426,7 +2412,7 @@ vectorizePackOpPrecondition(linalg::PackOp packOp, if (llvm::any_of(packOp.getInnerTiles(), [](OpFoldResult v) { return !getConstantIntValue(v).has_value(); })) { - LDBG("inner_tiles must be constant: " << packOp << "\n"); + LDBG() << "inner_tiles must be constant: " << packOp; return failure(); } @@ -2438,7 +2424,7 @@ vectorizePadOpPrecondition(tensor::PadOp padOp, ArrayRef<int64_t> inputVectorSizes) { auto padValue = padOp.getConstantPaddingValue(); if (!padValue) { - LDBG("pad value is not constant: " << padOp << "\n"); + LDBG() << "pad value is not constant: " << padOp; return failure(); } @@ -2465,7 +2451,7 @@ vectorizePadOpPrecondition(tensor::PadOp padOp, return (!pad.has_value() || pad.value() != 0) && resultTensorShape[pos] != 1; })) { - LDBG("low pad must all be zero for all non unit dims: " << padOp << "\n"); + LDBG() << "low pad must all be zero for all non unit dims: " << padOp; return failure(); } @@ -2534,13 +2520,14 @@ vectorizeScalableVectorPrecondition(Operation *op, case utils::IteratorType::reduction: { // Check 3. above is met. if (iterators.size() != inputVectorSizes.size()) { - LDBG("Non-trailing reduction dim requested for scalable " - "vectorization\n"); + LDBG() << "Non-trailing reduction dim requested for scalable " + "vectorization"; return failure(); } if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) { - LDBG("Scalable vectorization of the reduction dim in Matmul-like ops " - "is not supported\n"); + LDBG() + << "Scalable vectorization of the reduction dim in Matmul-like ops " + "is not supported"; return failure(); } break; @@ -2548,8 +2535,8 @@ vectorizeScalableVectorPrecondition(Operation *op, case utils::IteratorType::parallel: { // Check 1. and 2. above are met. 
if (seenNonUnitParallel) { - LDBG("Inner parallel dim not requested for scalable " - "vectorization\n"); + LDBG() << "Inner parallel dim not requested for scalable " + "vectorization"; return failure(); } break; @@ -2565,8 +2552,9 @@ vectorizeScalableVectorPrecondition(Operation *op, // * iterators = [..., parallel, reduction] // * scalable flags = [..., true, true] if (iterators.back() == utils::IteratorType::reduction) { - LDBG("Higher dim than the trailing reduction dim requested for scalable " - "vectorization\n"); + LDBG() << "Higher dim than the trailing reduction dim requested for " + "scalable " + "vectorization"; return failure(); } scalableFlags.pop_back(); @@ -2649,18 +2637,15 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize( ArrayRef<bool> inputScalableVecDims, bool vectorizeNDExtract, bool flatten1DDepthwiseConv, bool assumeDynamicDimsMatchVecSizes, bool createNamedContraction) { - LDBG("Attempting to vectorize:\n" << *op << "\n"); - LDBG("Input vector sizes: "); - LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); - LDBG("Input scalable vector dims: "); - LLVM_DEBUG(llvm::interleaveComma(inputScalableVecDims, llvm::dbgs())); - LLVM_DEBUG(llvm::dbgs() << "\n"); + LDBG() << "Attempting to vectorize: " << *op; + LDBG() << "Input vector sizes: " << llvm::interleaved(inputVectorSizes); + LDBG() << "Input scalable vector dims: " + << llvm::interleaved(inputScalableVecDims); if (failed(vectorizeOpPrecondition(op, inputVectorSizes, inputScalableVecDims, vectorizeNDExtract, flatten1DDepthwiseConv))) { - LDBG("Vectorization pre-conditions failed\n"); + LDBG() << "Vectorization pre-conditions failed"; return failure(); } @@ -2670,7 +2655,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize( if (failed(state.initState(rewriter, linalgOp, inputVectorSizes, inputScalableVecDims, assumeDynamicDimsMatchVecSizes))) { - LDBG("Vectorization state couldn't be initialized\n"); + LDBG() << "Vectorization state couldn't be initialized"; return failure(); } } @@ -2691,7 +2676,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize( return success(); } - LDBG("Unsupported convolution can't be vectorized.\n"); + LDBG() << "Unsupported convolution can't be vectorized."; return failure(); } @@ -2700,8 +2685,9 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize( return vectorizeAsLinalgContraction(rewriter, state, linalgOp, results); - LDBG("Vectorize generic by broadcasting to the canonical vector " - "shape\n"); + LDBG() + << "Vectorize generic by broadcasting to the canonical vector " + "shape"; // Pre-process before proceeding.
convertAffineApply(rewriter, linalgOp); @@ -2732,7 +2718,7 @@ FailureOr<VectorizationResult> mlir::linalg::vectorize( .Default([](auto) { return failure(); }); if (failed(vectorizeResult)) { - LDBG("Vectorization failed\n"); + LDBG() << "Vectorization failed"; return failure(); } @@ -2756,20 +2742,21 @@ LogicalResult mlir::linalg::vectorizeCopy(RewriterBase &rewriter, auto writeType = VectorType::get(dstType.getShape(), dstElementType); Location loc = copyOp->getLoc(); - Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); SmallVector<Value> indices(srcType.getRank(), zero); - Value readValue = rewriter.create<vector::TransferReadOp>( - loc, readType, copyOp.getSource(), indices, + Value readValue = vector::TransferReadOp::create( + rewriter, loc, readType, copyOp.getSource(), indices, /*padding=*/std::nullopt, rewriter.getMultiDimIdentityMap(srcType.getRank())); if (cast<VectorType>(readValue.getType()).getRank() == 0) { + readValue = vector::ExtractOp::create(rewriter, loc, readValue, + ArrayRef<int64_t>()); readValue = - rewriter.create<vector::ExtractOp>(loc, readValue, ArrayRef<int64_t>()); - readValue = rewriter.create<vector::BroadcastOp>(loc, writeType, readValue); + vector::BroadcastOp::create(rewriter, loc, writeType, readValue); } - Operation *writeValue = rewriter.create<vector::TransferWriteOp>( - loc, readValue, copyOp.getTarget(), indices, + Operation *writeValue = vector::TransferWriteOp::create( + rewriter, loc, readValue, copyOp.getTarget(), indices, rewriter.getMultiDimIdentityMap(srcType.getRank())); rewriter.replaceOp(copyOp, writeValue->getResults()); return success(); @@ -3079,8 +3066,8 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, if (!padValue) { auto elemType = sourceType.getElementType(); - padValue = rewriter.create<arith::ConstantOp>( - sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType)); + padValue = arith::ConstantOp::create(rewriter, sliceOp.getLoc(), elemType, + rewriter.getZeroAttr(elemType)); } // 2. Get the vector shape @@ -3111,7 +3098,7 @@ vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp, // Create read SmallVector<Value> readIndices( - vecType.getRank(), rewriter.create<arith::ConstantIndexOp>(loc, 0)); + vecType.getRank(), arith::ConstantIndexOp::create(rewriter, loc, 0)); Value read = mlir::vector::createReadOrMaskedRead( rewriter, loc, source, vecType.getShape(), padValue, /*useInBoundsInsteadOfMasking=*/inputVectorSizes.empty()); @@ -3198,9 +3185,10 @@ struct PadOpVectorizationWithInsertSlicePattern // Generate TransferReadOp: Read entire source tensor and add high // padding. SmallVector<Value> readIndices( - vecRank, rewriter.create<arith::ConstantIndexOp>(padOp.getLoc(), 0)); - auto read = rewriter.create<vector::TransferReadOp>( - padOp.getLoc(), vecType, padOp.getSource(), readIndices, padValue); + vecRank, arith::ConstantIndexOp::create(rewriter, padOp.getLoc(), 0)); + auto read = vector::TransferReadOp::create(rewriter, padOp.getLoc(), + vecType, padOp.getSource(), + readIndices, padValue); // Generate TransferWriteOp: Write to InsertSliceOp's dest tensor at // specified offsets. 
Write is fully in-bounds because a InsertSliceOp's @@ -3235,8 +3223,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, ValueRange values) { if (firstOp->getBlock() != secondOp->getBlock() || !firstOp->isBeforeInBlock(secondOp)) { - LDBG("interleavedUses precondition failed, firstOp: " - << *firstOp << ", second op: " << *secondOp << "\n"); + LDBG() << "interleavedUses precondition failed, firstOp: " << *firstOp + << ", second op: " << *secondOp; return true; } for (auto v : values) { @@ -3248,8 +3236,8 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp, if (owner->getBlock() == firstOp->getBlock() && (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner))) continue; - LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp - << ", second op: " << *secondOp << "\n"); + LDBG() << " found interleaved op " << *owner << ", firstOp: " << *firstOp + << ", second op: " << *secondOp; return true; } } @@ -3334,8 +3322,8 @@ LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite( // When forwarding to vector.transfer_read, the attribute must be reset // conservatively. auto vectorType = xferOp.getVectorType(); - Value res = rewriter.create<vector::TransferReadOp>( - xferOp.getLoc(), vectorType, in, xferOp.getIndices(), + Value res = vector::TransferReadOp::create( + rewriter, xferOp.getLoc(), vectorType, in, xferOp.getIndices(), xferOp.getPermutationMapAttr(), xferOp.getPadding(), xferOp.getMask(), rewriter.getBoolArrayAttr( SmallVector<bool>(vectorType.getRank(), false))); @@ -3393,8 +3381,8 @@ LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite( // When forwarding to vector.transfer_write, the attribute must be reset // conservatively. auto vector = xferOp.getVector(); - rewriter.create<vector::TransferWriteOp>( - xferOp.getLoc(), vector, out, xferOp.getIndices(), + vector::TransferWriteOp::create( + rewriter, xferOp.getLoc(), vector, out, xferOp.getIndices(), xferOp.getPermutationMapAttr(), xferOp.getMask(), rewriter.getBoolArrayAttr(SmallVector<bool>( dyn_cast<VectorType>(vector.getType()).getRank(), false))); @@ -3589,7 +3577,7 @@ struct Conv1DGenerator } vector::TransferWriteOp write; - Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. // When strideW == 1, we can batch the contiguous loads and avoid @@ -3608,17 +3596,17 @@ struct Conv1DGenerator SmallVector<Value> resPadding(resShape.size(), zero); // Read the whole lhs, rhs and res in one shot (with zero padding). - Value lhs = rewriter.create<vector::TransferReadOp>( - loc, lhsType, lhsShaped, lhsPadding, + Value lhs = vector::TransferReadOp::create( + rewriter, loc, lhsType, lhsShaped, lhsPadding, /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType)); // This is needed only for Conv. 
Value rhs = nullptr; if (oper == ConvOperationKind::Conv) - rhs = rewriter.create<vector::TransferReadOp>( - loc, rhsType, rhsShaped, rhsPadding, + rhs = vector::TransferReadOp::create( + rewriter, loc, rhsType, rhsShaped, rhsPadding, /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType)); - Value res = rewriter.create<vector::TransferReadOp>( - loc, resType, resShaped, resPadding, + Value res = vector::TransferReadOp::create( + rewriter, loc, resType, resShaped, resPadding, /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType)); // The base vectorization case for channeled convolution is input: @@ -3633,16 +3621,16 @@ struct Conv1DGenerator // To match base vectorization case, we pre-transpose current case. // ncw -> nwc static constexpr std::array<int64_t, 3> permLhs = {0, 2, 1}; - lhs = rewriter.create<vector::TransposeOp>(loc, lhs, permLhs); + lhs = vector::TransposeOp::create(rewriter, loc, lhs, permLhs); // fcw -> wcf static constexpr std::array<int64_t, 3> permRhs = {2, 1, 0}; // This is needed only for Conv. if (oper == ConvOperationKind::Conv) - rhs = rewriter.create<vector::TransposeOp>(loc, rhs, permRhs); + rhs = vector::TransposeOp::create(rewriter, loc, rhs, permRhs); // nfw -> nwf static constexpr std::array<int64_t, 3> permRes = {0, 2, 1}; - res = rewriter.create<vector::TransposeOp>(loc, res, permRes); + res = vector::TransposeOp::create(rewriter, loc, res, permRes); break; } } @@ -3707,13 +3695,13 @@ struct Conv1DGenerator case Conv1DOpOrder::Ncw: { // nwf -> nfw static constexpr std::array<int64_t, 3> perm = {0, 2, 1}; - res = rewriter.create<vector::TransposeOp>(loc, res, perm); + res = vector::TransposeOp::create(rewriter, loc, res, perm); break; } } - return rewriter - .create<vector::TransferWriteOp>(loc, res, resShaped, resPadding) + return vector::TransferWriteOp::create(rewriter, loc, res, resShaped, + resPadding) .getOperation(); } @@ -3731,16 +3719,16 @@ struct Conv1DGenerator cast<ShapedType>(val.getType()).cloneWith(std::nullopt, dstElementType); if (isa<IntegerType>(srcElementType) && isa<FloatType>(dstElementType)) { - return rewriter.create<arith::SIToFPOp>(loc, dstType, val); + return arith::SIToFPOp::create(rewriter, loc, dstType, val); } if (isa<FloatType>(srcElementType) && isa<FloatType>(dstElementType) && srcWidth < dstWidth) - return rewriter.create<arith::ExtFOp>(loc, dstType, val); + return arith::ExtFOp::create(rewriter, loc, dstType, val); if (isa<IntegerType>(srcElementType) && isa<IntegerType>(dstElementType) && srcWidth < dstWidth) - return rewriter.create<arith::ExtSIOp>(loc, dstType, val); + return arith::ExtSIOp::create(rewriter, loc, dstType, val); assert(false && "unhandled promotion case"); return nullptr; @@ -3755,8 +3743,8 @@ struct Conv1DGenerator bindDims(ctx, n, w, f, c); lhs = promote(rewriter, loc, lhs, res.getType()); rhs = promote(rewriter, loc, rhs, res.getType()); - auto contrationOp = rewriter.create<vector::ContractionOp>( - loc, lhs, rhs, res, + auto contrationOp = vector::ContractionOp::create( + rewriter, loc, lhs, rhs, res, /*indexingMaps=*/MapList{{n, w, c}, {c, f}, {n, w, f}}, /*iteratorTypes=*/ArrayRef<vector::IteratorType>{par, par, par, red}); contrationOp.setKind(reductionKind); @@ -3767,8 +3755,8 @@ struct Conv1DGenerator // convolution. 
Value conv1dSliceAsOuterProduct(RewriterBase &rewriter, Location loc, Value lhs, Value rhs, Value res) { - return rewriter.create<vector::OuterProductOp>( - loc, res.getType(), lhs, rhs, res, vector::CombiningKind::ADD); + return vector::OuterProductOp::create(rewriter, loc, res.getType(), lhs, + rhs, res, vector::CombiningKind::ADD); } // Create a reduction: lhs{n, w, c} -> res{n, w, c} @@ -3815,7 +3803,7 @@ struct Conv1DGenerator bindShapeDims(resShapedType, nSize, wSize); vector::TransferWriteOp write; - Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); // w is unrolled (i.e. wSizeStep == 1) iff strideW > 1. // When strideW == 1, we can batch the contiguous loads and avoid @@ -3858,29 +3846,29 @@ struct Conv1DGenerator cast<LinalgOp>(op).hasPureTensorSemantics(), opToMask, rewriter); Value maskOp = - rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedDims); + vector::CreateMaskOp::create(rewriter, loc, maskType, mixedDims); return mlir::vector::maskOperation(rewriter, opToMask, maskOp); }; // Read lhs slice of size {n, w * strideW + kw * dilationW, c} @ [0, 0, // 0]. - Value lhs = rewriter.create<vector::TransferReadOp>( - loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}, + Value lhs = vector::TransferReadOp::create( + rewriter, loc, lhsType, lhsShaped, ValueRange{zero, zero, zero}, /*padding=*/arith::getZeroConstant(rewriter, loc, lhsEltType)); auto maybeMaskedLhs = maybeMaskXferOp( lhsType.getShape(), lhsType.getScalableDims(), lhs.getDefiningOp()); // Read rhs slice of size {kw, c} @ [0, 0]. - Value rhs = rewriter.create<vector::TransferReadOp>( - loc, rhsType, rhsShaped, ValueRange{zero, zero}, + Value rhs = vector::TransferReadOp::create( + rewriter, loc, rhsType, rhsShaped, ValueRange{zero, zero}, /*padding=*/arith::getZeroConstant(rewriter, loc, rhsEltType)); auto maybeMaskedRhs = maybeMaskXferOp( rhsType.getShape(), rhsType.getScalableDims(), rhs.getDefiningOp()); // Read res slice of size {n, w, c} @ [0, 0, 0]. - Value res = rewriter.create<vector::TransferReadOp>( - loc, resType, resShaped, ValueRange{zero, zero, zero}, + Value res = vector::TransferReadOp::create( + rewriter, loc, resType, resShaped, ValueRange{zero, zero, zero}, /*padding=*/arith::getZeroConstant(rewriter, loc, resEltType)); auto maybeMaskedRes = maybeMaskXferOp( resType.getShape(), resType.getScalableDims(), res.getDefiningOp()); @@ -3897,22 +3885,22 @@ struct Conv1DGenerator // @ [0, sw * w + dw * kw, 0]. for (int64_t kw = 0; kw < kwSize; ++kw) { for (int64_t w = 0; w < wSize; w += wSizeStep) { - lhsVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>( - loc, maybeMaskedLhs->getResult(0), + lhsVals.push_back(vector::ExtractStridedSliceOp::create( + rewriter, loc, maybeMaskedLhs->getResult(0), /*offsets=*/ArrayRef<int64_t>{0, w * strideW + kw * dilationW, 0}, inOutSliceSizes, inOutStrides)); } } // Extract rhs slice of size {c} @ [kw]. for (int64_t kw = 0; kw < kwSize; ++kw) { - rhsVals.push_back(rewriter.create<vector::ExtractOp>( - loc, maybeMaskedRhs->getResult(0), - /*offsets=*/ArrayRef<int64_t>{kw})); + rhsVals.push_back( + vector::ExtractOp::create(rewriter, loc, maybeMaskedRhs->getResult(0), + /*offsets=*/ArrayRef<int64_t>{kw})); } // Extract res slice: {n, wSizeStep, c} @ [0, w, 0]. 
for (int64_t w = 0; w < wSize; w += wSizeStep) { - resVals.push_back(rewriter.create<vector::ExtractStridedSliceOp>( - loc, maybeMaskedRes->getResult(0), + resVals.push_back(vector::ExtractStridedSliceOp::create( + rewriter, loc, maybeMaskedRes->getResult(0), /*offsets=*/ArrayRef<int64_t>{0, w, 0}, inOutSliceSizes, inOutStrides)); } @@ -3937,17 +3925,19 @@ struct Conv1DGenerator if (flatten) { // Flatten the input and output vectors (collapse the channel // dimension) - lhsVal = rewriter.create<vector::ShapeCastOp>( - loc, lhsTypeAfterFlattening, lhsVals[linearIndex(kw, w)]); - resVal = rewriter.create<vector::ShapeCastOp>( - loc, resTypeAfterFlattening, resVals[w]); + lhsVal = + vector::ShapeCastOp::create(rewriter, loc, lhsTypeAfterFlattening, + lhsVals[linearIndex(kw, w)]); + resVal = vector::ShapeCastOp::create( + rewriter, loc, resTypeAfterFlattening, resVals[w]); } resVals[w] = depthwiseConv1dSliceAsMulAcc(rewriter, loc, lhsVal, rhsVals[kw], resVal, flatten); if (flatten) { // Un-flatten the output vector (restore the channel dimension) - resVals[w] = rewriter.create<vector::ShapeCastOp>( - loc, VectorType::get(inOutSliceSizes, resEltType), resVals[w]); + resVals[w] = vector::ShapeCastOp::create( + rewriter, loc, VectorType::get(inOutSliceSizes, resEltType), + resVals[w]); } } } @@ -3965,8 +3955,8 @@ struct Conv1DGenerator // Write back res slice: {n, wSizeStep, c} @ [0, w, 0]. // This does not depend on kw. for (int64_t w = 0; w < wSize; w += wSizeStep) { - maybeMaskedRes = rewriter.create<vector::InsertStridedSliceOp>( - loc, resVals[w], maybeMaskedRes->getResult(0), + maybeMaskedRes = vector::InsertStridedSliceOp::create( + rewriter, loc, resVals[w], maybeMaskedRes->getResult(0), /*offsets=*/ArrayRef<int64_t>{0, w, 0}, /*strides=*/ArrayRef<int64_t>{1, 1, 1}); } @@ -3975,8 +3965,8 @@ struct Conv1DGenerator //===------------------------------------------------------------------===// // Write back res slice of size {n, w, c} @ [0, 0, 0]. 
- Operation *resOut = rewriter.create<vector::TransferWriteOp>( - loc, maybeMaskedRes->getResult(0), resShaped, + Operation *resOut = vector::TransferWriteOp::create( + rewriter, loc, maybeMaskedRes->getResult(0), resShaped, ValueRange{zero, zero, zero}); return maybeMaskXferOp(resType.getShape(), resType.getScalableDims(), resOut); @@ -4013,11 +4003,11 @@ struct Conv1DGenerator indices.push_back(j); } - rhs = rewriter.create<vector::ShuffleOp>(loc, rhs, rhs, indices); + rhs = vector::ShuffleOp::create(rewriter, loc, rhs, rhs, indices); } // Broadcast the filter to match the output vector - rhs = rewriter.create<vector::BroadcastOp>( - loc, resTy.clone(rhsTy.getElementType()), rhs); + rhs = vector::BroadcastOp::create(rewriter, loc, + resTy.clone(rhsTy.getElementType()), rhs); rhs = promote(rewriter, loc, rhs, resTy); @@ -4025,10 +4015,10 @@ struct Conv1DGenerator return nullptr; if (isa<FloatType>(resTy.getElementType())) - return rewriter.create<vector::FMAOp>(loc, lhs, rhs, res); + return vector::FMAOp::create(rewriter, loc, lhs, rhs, res); - auto mul = rewriter.create<arith::MulIOp>(loc, lhs, rhs); - return rewriter.create<arith::AddIOp>(loc, mul, res); + auto mul = arith::MulIOp::create(rewriter, loc, lhs, rhs); + return arith::AddIOp::create(rewriter, loc, mul, res); } /// Entry point for non-channeled convolution: diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp index 9fd0844..b80b27f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -201,11 +201,12 @@ Value create2DTransformMatrix(OpBuilder &builder, Location loc, TransformMatrix transform, Type type) { ArrayRef<float> constVec(transform.table, transform.rows * transform.cols); - return builder.create<arith::ConstantOp>( - loc, DenseFPElementsAttr::get( - RankedTensorType::get( - SmallVector<int64_t>{transform.rows, transform.cols}, type), - constVec)); + return arith::ConstantOp::create( + builder, loc, + DenseFPElementsAttr::get( + RankedTensorType::get( + SmallVector<int64_t>{transform.rows, transform.cols}, type), + constVec)); } /// Extract height x width data from 4D tensors. 
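The builder changes throughout this diff are one mechanical rewrite: the member template `rewriter.create<OpTy>(loc, args...)` becomes the static `OpTy::create(rewriter, loc, args...)`, i.e. the op class owns the entry point and the builder moves to the first argument, with the remaining arguments unchanged. A minimal before/after sketch of the pattern; `rewriter`, `loc`, `shape`, and `elemType` are assumed to be in scope and are not taken from this patch:

    // Before: the op type is a template argument on the builder.
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    auto empty = rewriter.create<tensor::EmptyOp>(loc, shape, elemType);

    // After: the op type exposes a static create() taking the builder first.
    Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
    auto empty = tensor::EmptyOp::create(rewriter, loc, shape, elemType);

Calls that chain off the result migrate the same way, e.g. `builder.create<tensor::EmptyOp>(loc, ...).getResult()` becomes `tensor::EmptyOp::create(builder, loc, ...).getResult()`, as in the Winograd hunks below.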
@@ -233,8 +234,8 @@ Value extract2DDataFrom4D(OpBuilder &builder, Location loc, Value source, auto extractFilterType = RankedTensorType::get({extractHeight, extractWidth}, elementType); - auto extractFilterOp = builder.create<tensor::ExtractSliceOp>( - loc, extractFilterType, source, offsets, sizes, strides); + auto extractFilterOp = tensor::ExtractSliceOp::create( + builder, loc, extractFilterType, source, offsets, sizes, strides); return extractFilterOp; } @@ -267,8 +268,8 @@ Value extract2DDataFrom6D(OpBuilder &builder, Location loc, Value source, SmallVector<OpFoldResult> strides(srcSize, oneIndex); auto extractFilterType = RankedTensorType::get({height, width}, elementType); - auto extractFilterOp = builder.create<tensor::ExtractSliceOp>( - loc, extractFilterType, source, offsets, sizes, strides); + auto extractFilterOp = tensor::ExtractSliceOp::create( + builder, loc, extractFilterType, source, offsets, sizes, strides); return extractFilterOp; } @@ -293,8 +294,8 @@ Value insert2DDataTo4D(OpBuilder &builder, Location loc, Value source, retSizes[widthIdx] = builder.getIndexAttr(width); SmallVector<OpFoldResult> strides(destSize, oneIndex); - auto insertSliceOp = builder.create<tensor::InsertSliceOp>( - loc, source, dest, retOffsets, retSizes, strides); + auto insertSliceOp = tensor::InsertSliceOp::create( + builder, loc, source, dest, retOffsets, retSizes, strides); return insertSliceOp; } @@ -321,8 +322,8 @@ Value insert2DDataTo6D(OpBuilder &builder, Location loc, Value source, retSizes[widthIdx] = builder.getIndexAttr(width); SmallVector<OpFoldResult> strides(destSize, oneIndex); - auto insertSliceOp = builder.create<tensor::InsertSliceOp>( - loc, source, dest, retOffsets, retSizes, strides); + auto insertSliceOp = tensor::InsertSliceOp::create( + builder, loc, source, dest, retOffsets, retSizes, strides); return insertSliceOp; } @@ -372,7 +373,7 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, if (filterW != r && filterW != 1) return Value(); - Value zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0); + Value zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0); auto buildBody = [&](OpBuilder &builder, Location loc, ValueRange ivs, ValueRange args) -> scf::ValueVector { Value FIter = ivs[0]; @@ -386,8 +387,8 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, int64_t retRows = 1; Value matmulRetValue = extractFilter; - Value zero = builder.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(elementType)); + Value zero = arith::ConstantOp::create(builder, loc, + rewriter.getZeroAttr(elementType)); if (leftTransform) { // Get constant transform matrix G. auto it = GMatrices.find(fmr); @@ -397,16 +398,17 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, retRows = GMatrix.rows; auto matmulType = RankedTensorType::get({retRows, filterW}, elementType); - auto empty = - builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) - .getResult(); - auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); + auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(), + elementType) + .getResult(); + auto init = + linalg::FillOp::create(builder, loc, zero, empty).getResult(0); Value G = create2DTransformMatrix(builder, loc, GMatrix, elementType); // Multiply G x g. 
- auto matmulOp = builder.create<linalg::MatmulOp>( - loc, matmulType, ValueRange{G, extractFilter}, ValueRange{init}); + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, + ValueRange{G, extractFilter}, + ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } @@ -419,16 +421,17 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, auto matmulType = RankedTensorType::get({retRows, GTMatrix.cols}, elementType); - auto empty = - builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) - .getResult(); - auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); + auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(), + elementType) + .getResult(); + auto init = + linalg::FillOp::create(builder, loc, zero, empty).getResult(0); Value GT = create2DTransformMatrix(builder, loc, GTMatrix, elementType); // Multiply u = (G x g) x GT. - auto matmulOp = builder.create<linalg::MatmulOp>( - loc, matmulType, ValueRange{matmulRetValue, GT}, ValueRange{init}); + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, + ValueRange{matmulRetValue, GT}, + ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } @@ -445,9 +448,9 @@ Value filterTransform(RewriterBase &rewriter, Location loc, Value filter, return {insertSliceOp}; }; - auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterF); - auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, filterC); - auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterF); + auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, filterC); + auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1); scf::LoopNest loops = scf::buildLoopNest( rewriter, loc, {zeroIdx, zeroIdx}, {fUpperBound, cUpperBound}, {oneStep, oneStep}, {retValue}, buildBody); @@ -516,10 +519,11 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, auto identityAffineMap = rewriter.getMultiDimIdentityMap(1); auto affineMap = AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); - Value heightOffset = builder.create<affine::AffineApplyOp>( - loc, leftTransform ? affineMap : identityAffineMap, tileHIter); - Value widthOffset = builder.create<affine::AffineApplyOp>( - loc, rightTransform ? affineMap : identityAffineMap, tileWIter); + Value heightOffset = affine::AffineApplyOp::create( + builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter); + Value widthOffset = affine::AffineApplyOp::create( + builder, loc, rightTransform ? affineMap : identityAffineMap, + tileWIter); // Extract (H, W) from (N, H, W, C). auto extractInput = @@ -530,8 +534,8 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, int64_t retRows = 1; int64_t retCols = 1; Value matmulRetValue = extractInput; - Value zero = builder.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(elementType)); + Value zero = arith::ConstantOp::create(builder, loc, + rewriter.getZeroAttr(elementType)); if (leftTransform) { // Get constant transform matrix BT. 
auto it = BTMatrices.find(fmr); @@ -541,17 +545,18 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, retRows = BTMatrix.rows; auto matmulType = RankedTensorType::get({retRows, alphaW}, elementType); - auto empty = - builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) - .getResult(); - auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); + auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(), + elementType) + .getResult(); + auto init = + linalg::FillOp::create(builder, loc, zero, empty).getResult(0); Value BT = create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type()); // Multiply BT x d. - auto matmulOp = builder.create<linalg::MatmulOp>( - loc, matmulType, ValueRange{BT, matmulRetValue}, ValueRange{init}); + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, + ValueRange{BT, matmulRetValue}, + ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } @@ -564,16 +569,17 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, retCols = BMatrix.cols; auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); - auto empty = - builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), elementType) - .getResult(); - auto init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); + auto empty = tensor::EmptyOp::create(builder, loc, matmulType.getShape(), + elementType) + .getResult(); + auto init = + linalg::FillOp::create(builder, loc, zero, empty).getResult(0); Value B = create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type()); // Multiply v = (BT x d) x B. - auto matmulOp = builder.create<linalg::MatmulOp>( - loc, matmulType, ValueRange{matmulRetValue, B}, ValueRange{init}); + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, + ValueRange{matmulRetValue, B}, + ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } @@ -586,12 +592,12 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, return {combinedVal}; }; - auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tileH); - auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW); - auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputN); - auto cUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, inputC); - auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tileH); + auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW); + auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputN); + auto cUpperBound = arith::ConstantIndexOp::create(rewriter, loc, inputC); + auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1); scf::LoopNest loops = scf::buildLoopNest( rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, {tileHBound, tileWBound, nUpperBound, cUpperBound}, @@ -629,8 +635,8 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc, {filterShape[0] * filterShape[1], filterShape[2], filterShape[3]}, filterElementType); SmallVector<ReassociationIndices> filterReassoc = {{0, 1}, {2}, {3}}; - Value collapseFilter = rewriter.create<tensor::CollapseShapeOp>( - loc, filterReassocType, transformedFilter, filterReassoc); + Value collapseFilter = tensor::CollapseShapeOp::create( + rewriter, loc, filterReassocType, 
transformedFilter, filterReassoc); // Convert (alphaH, alphaW, tileH, tileW, N, C) to // (alphaH x alphaW, tileH x tileW x N, C) for input. @@ -643,24 +649,23 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc, inputShape[2] * inputShape[3] * inputShape[4], inputShape[5]}, inputElementType); SmallVector<ReassociationIndices> inputReassoc = {{0, 1}, {2, 3, 4}, {5}}; - Value collapseInput = rewriter.create<tensor::CollapseShapeOp>( - loc, inputReassocType, transformedInput, inputReassoc); + Value collapseInput = tensor::CollapseShapeOp::create( + rewriter, loc, inputReassocType, transformedInput, inputReassoc); // Batched matrix multiply. auto matmulType = RankedTensorType::get( {inputShape[0] * inputShape[1], inputShape[2] * inputShape[3] * inputShape[4], filterShape[3]}, outputElementType); - Value empty = rewriter - .create<tensor::EmptyOp>(loc, matmulType.getShape(), - outputElementType) + Value empty = tensor::EmptyOp::create(rewriter, loc, matmulType.getShape(), + outputElementType) .getResult(); - Value zero = rewriter.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(outputElementType)); - Value init = rewriter.create<linalg::FillOp>(loc, zero, empty).getResult(0); + Value zero = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(outputElementType)); + Value init = linalg::FillOp::create(rewriter, loc, zero, empty).getResult(0); - auto matmulOp = rewriter.create<linalg::BatchMatmulOp>( - loc, matmulType, ValueRange({collapseInput, collapseFilter}), + auto matmulOp = linalg::BatchMatmulOp::create( + rewriter, loc, matmulType, ValueRange({collapseInput, collapseFilter}), ValueRange{init}); // The result shape of batch matmul is (alphaH x alphaW, tileH x tileW x N, F) @@ -670,8 +675,8 @@ static Value matrixMultiply(RewriterBase &rewriter, Location loc, RankedTensorType::get({inputShape[0], inputShape[1], inputShape[2], inputShape[3], inputShape[4], filterShape[3]}, outputElementType); - auto expandOutput = rewriter.create<tensor::ExpandShapeOp>( - loc, outputReassocType, matmulOp.getResult(0), outputReassoc); + auto expandOutput = tensor::ExpandShapeOp::create( + rewriter, loc, outputReassocType, matmulOp.getResult(0), outputReassoc); return expandOutput; } @@ -750,16 +755,17 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, int64_t retRows = leftTransform ? ATMatrix.rows : 1; Value matmulRetValue = extractValue; - Value zero = builder.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(elementType)); + Value zero = arith::ConstantOp::create(builder, loc, + rewriter.getZeroAttr(elementType)); auto identityAffineMap = rewriter.getMultiDimIdentityMap(1); auto affineMap = AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context); - Value heightOffset = builder.create<affine::AffineApplyOp>( - loc, leftTransform ? affineMap : identityAffineMap, tileHIter); - Value widthOffset = builder.create<affine::AffineApplyOp>( - loc, rightTransform ? affineMap : identityAffineMap, tileWIter); + Value heightOffset = affine::AffineApplyOp::create( + builder, loc, leftTransform ? affineMap : identityAffineMap, tileHIter); + Value widthOffset = affine::AffineApplyOp::create( + builder, loc, rightTransform ? 
affineMap : identityAffineMap, + tileWIter); Value outInitVal = extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset, @@ -771,17 +777,17 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, auto matmulType = RankedTensorType::get({retRows, valueW}, elementType); Value init = outInitVal; if (rightTransform || scalarFactor != 1) { - auto empty = builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), - elementType) + auto empty = tensor::EmptyOp::create(builder, loc, + matmulType.getShape(), elementType) .getResult(); - init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); + init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); } Value AT = create2DTransformMatrix(builder, loc, ATMatrix, elementType); // Multiply AT x m. - auto matmulOp = builder.create<linalg::MatmulOp>( - loc, matmulType, ValueRange{AT, matmulRetValue}, ValueRange{init}); + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, + ValueRange{AT, matmulRetValue}, + ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } @@ -790,47 +796,45 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, RankedTensorType::get({retRows, AMatrix.cols}, elementType); Value init = outInitVal; if (scalarFactor != 1) { - auto empty = builder - .create<tensor::EmptyOp>(loc, matmulType.getShape(), - elementType) + auto empty = tensor::EmptyOp::create(builder, loc, + matmulType.getShape(), elementType) .getResult(); - init = builder.create<linalg::FillOp>(loc, zero, empty).getResult(0); + init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); } Value A = create2DTransformMatrix(builder, loc, AMatrix, elementType); // Multiply y = (AT x m) x A. - auto matmulOp = builder.create<linalg::MatmulOp>( - loc, matmulType, ValueRange{matmulRetValue, A}, ValueRange{init}); + auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, + ValueRange{matmulRetValue, A}, + ValueRange{init}); matmulRetValue = matmulOp.getResult(0); } if (scalarFactor != 1) { // Multiply by scalar factor and add outInitVal. 
- Value scalarFactorValue = builder.create<arith::ConstantOp>( - loc, FloatAttr::get(elementType, scalarFactor)); + Value scalarFactorValue = arith::ConstantOp::create( + builder, loc, FloatAttr::get(elementType, scalarFactor)); auto matmulType = RankedTensorType::get({retRows, retCols}, elementType); auto identityAffineMap = rewriter.getMultiDimIdentityMap(2); SmallVector<AffineMap> affineMaps = { AffineMap::get(2, 0, context), identityAffineMap, identityAffineMap}; matmulRetValue = - rewriter - .create<linalg::GenericOp>( - loc, matmulType, - ValueRange{scalarFactorValue, matmulRetValue}, - ValueRange{outInitVal}, affineMaps, - llvm::ArrayRef<utils::IteratorType>{ - utils::IteratorType::parallel, - utils::IteratorType::parallel}, - [&](OpBuilder &nestedBuilder, Location nestedLoc, - ValueRange args) { - auto mulf = nestedBuilder.create<arith::MulFOp>( - nestedLoc, args[0], args[1]); - auto addf = nestedBuilder.create<arith::AddFOp>( - nestedLoc, mulf.getResult(), args[2]); - nestedBuilder.create<linalg::YieldOp>(nestedLoc, - addf.getResult()); - }) + linalg::GenericOp::create( + rewriter, loc, matmulType, + ValueRange{scalarFactorValue, matmulRetValue}, + ValueRange{outInitVal}, affineMaps, + llvm::ArrayRef<utils::IteratorType>{ + utils::IteratorType::parallel, utils::IteratorType::parallel}, + [&](OpBuilder &nestedBuilder, Location nestedLoc, + ValueRange args) { + auto mulf = arith::MulFOp::create(nestedBuilder, nestedLoc, + args[0], args[1]); + auto addf = arith::AddFOp::create(nestedBuilder, nestedLoc, + mulf.getResult(), args[2]); + linalg::YieldOp::create(nestedBuilder, nestedLoc, + addf.getResult()); + }) .getResult(0); } @@ -847,12 +851,12 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value, int64_t tilwH = valueShape[2]; int64_t tileW = valueShape[3]; - auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto tileHBound = rewriter.create<arith::ConstantIndexOp>(loc, tilwH); - auto tileWBound = rewriter.create<arith::ConstantIndexOp>(loc, tileW); - auto nUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueN); - auto fUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, valueF); - auto oneStep = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto zeroIdx = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto tileHBound = arith::ConstantIndexOp::create(rewriter, loc, tilwH); + auto tileWBound = arith::ConstantIndexOp::create(rewriter, loc, tileW); + auto nUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueN); + auto fUpperBound = arith::ConstantIndexOp::create(rewriter, loc, valueF); + auto oneStep = arith::ConstantIndexOp::create(rewriter, loc, 1); scf::LoopNest loops = scf::buildLoopNest( rewriter, loc, {zeroIdx, zeroIdx, zeroIdx, zeroIdx}, {tileHBound, tileWBound, nUpperBound, fUpperBound}, @@ -867,8 +871,8 @@ static Value padToAlignedTensor(RewriterBase &rewriter, Location loc, auto valueType = cast<ShapedType>(value.getType()); Type elementType = valueType.getElementType(); auto alignedType = RankedTensorType::get(alignedShape, elementType); - Value padValue = rewriter.create<arith::ConstantOp>( - loc, elementType, rewriter.getZeroAttr(elementType)); + Value padValue = arith::ConstantOp::create(rewriter, loc, elementType, + rewriter.getZeroAttr(elementType)); return linalg::makeComposedPadHighOp(rewriter, loc, alignedType, value, padValue, false); @@ -887,8 +891,8 @@ static Value extractFromAlignedTensor(RewriterBase &rewriter, Location loc, SmallVector<OpFoldResult> sizes = 
getAsOpFoldResult(rewriter.getI64ArrayAttr(extractedShape)); - return rewriter.create<tensor::ExtractSliceOp>(loc, extractedType, value, - offsets, sizes, strides); + return tensor::ExtractSliceOp::create(rewriter, loc, extractedType, value, + offsets, sizes, strides); } /// Utility function to check all values in the attribute are 1. @@ -979,10 +983,10 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp, int64_t tileW = llvm::divideCeilSigned(outputW, widthM); auto retType = RankedTensorType::get({alphaH, alphaW, filterC, filterF}, filterElementType); - Value retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), - filterElementType); - auto transformedFilter = rewriter.create<linalg::WinogradFilterTransformOp>( - loc, retType, filter, retValue, fmr); + Value retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(), + filterElementType); + auto transformedFilter = linalg::WinogradFilterTransformOp::create( + rewriter, loc, retType, filter, retValue, fmr); // --- Create operation for input transform --- @@ -998,10 +1002,10 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp, retType = RankedTensorType::get( {alphaH, alphaW, tileH, tileW, inputN, inputC}, inputElementType); - retValue = rewriter.create<tensor::EmptyOp>(loc, retType.getShape(), - inputElementType); - auto transformedInput = rewriter.create<linalg::WinogradInputTransformOp>( - loc, retType, input, retValue, fmr); + retValue = tensor::EmptyOp::create(rewriter, loc, retType.getShape(), + inputElementType); + auto transformedInput = linalg::WinogradInputTransformOp::create( + rewriter, loc, retType, input, retValue, fmr); Type outputElementType = outputType.getElementType(); Value matmulRet = matrixMultiply(rewriter, loc, transformedFilter, @@ -1023,8 +1027,8 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp, outputType = alignedOutputType; } - Value transformedOutput = rewriter.create<linalg::WinogradOutputTransformOp>( - loc, outputType, matmulRet, output, fmr); + Value transformedOutput = linalg::WinogradOutputTransformOp::create( + rewriter, loc, outputType, matmulRet, output, fmr); // When output size is not aligned with output tile size, extract the // value from the padded buffer. diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 14d6200..3593b53 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -320,14 +320,14 @@ GenericOp makeMemRefCopyOp(OpBuilder &b, Location loc, Value from, Value to) { AffineMap::getMultiDimIdentityMap(memrefTypeTo.getRank(), b.getContext()); SmallVector<utils::IteratorType> iteratorTypes(memrefTypeTo.getRank(), utils::IteratorType::parallel); - return b.create<linalg::GenericOp>( - loc, + return linalg::GenericOp::create( + b, loc, /*inputs=*/from, /*outputs=*/to, /*indexingMaps=*/llvm::ArrayRef({id, id}), /*iteratorTypes=*/iteratorTypes, [](OpBuilder &b, Location loc, ValueRange args) { - b.create<linalg::YieldOp>(loc, args.front()); + linalg::YieldOp::create(b, loc, args.front()); }); } @@ -483,8 +483,8 @@ static void generateParallelLoopNest( case DistributionMethod::None: { // Generate a single parallel loop-nest operation for all outermost // parallel loops and recurse. 
- b.create<scf::ParallelOp>( - loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed), + scf::ParallelOp::create( + b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed), steps.take_front(numProcessed), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) { ivStorage.append(localIvs.begin(), localIvs.end()); @@ -499,8 +499,8 @@ static void generateParallelLoopNest( case DistributionMethod::Cyclic: { // Generate a single parallel loop-nest operation for all outermost // parallel loops and recurse. - b.create<scf::ParallelOp>( - loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed), + scf::ParallelOp::create( + b, loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed), steps.take_front(numProcessed), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) { ivStorage.append(localIvs.begin(), localIvs.end()); @@ -519,13 +519,13 @@ static void generateParallelLoopNest( for (unsigned i = 1; i < numProcessed; ++i) cond = ab._and(cond, ab.slt(lbs[i], ubs[i])); ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed)); - b.create<scf::IfOp>(loc, cond, [&](OpBuilder &b, Location loc) { + scf::IfOp::create(b, loc, cond, [&](OpBuilder &b, Location loc) { generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed), steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed), remainderProcInfo, bodyBuilderFn, ivStorage); - b.create<scf::YieldOp>(loc, ValueRange{}); + scf::YieldOp::create(b, loc, ValueRange{}); }); return; } @@ -595,13 +595,13 @@ static Operation *materializeTiledShape(OpBuilder &builder, Location loc, auto shapedType = dyn_cast<ShapedType>(valueToTile.getType()); auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType) .Case([&](MemRefType) { - return builder.create<memref::SubViewOp>( - loc, valueToTile, sliceParams.offsets, + return memref::SubViewOp::create( + builder, loc, valueToTile, sliceParams.offsets, sliceParams.sizes, sliceParams.strides); }) .Case([&](RankedTensorType) { - return builder.create<tensor::ExtractSliceOp>( - loc, valueToTile, sliceParams.offsets, + return tensor::ExtractSliceOp::create( + builder, loc, valueToTile, sliceParams.offsets, sliceParams.sizes, sliceParams.strides); }) .Default([](ShapedType) -> Operation * { @@ -793,8 +793,8 @@ SmallVector<Value> insertSlicesBack(OpBuilder &builder, Location loc, // `tiledOperands`. Value outputTensor = operands[opOperand.getOperandNumber()]; if (auto sliceOp = outputTensor.getDefiningOp<tensor::ExtractSliceOp>()) { - Value inserted = builder.create<tensor::InsertSliceOp>( - loc, sliceOp.getSource().getType(), results[resultIdx], + Value inserted = tensor::InsertSliceOp::create( + builder, loc, sliceOp.getSource().getType(), results[resultIdx], sliceOp.getSource(), sliceOp.getOffsets(), sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(), sliceOp.getStaticStrides()); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp index c5643f6..dfa2e4e 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefMemorySlot.cpp @@ -85,11 +85,11 @@ Value memref::AllocaOp::getDefaultValue(const MemorySlot &slot, // TODO: support more types. 
return TypeSwitch<Type, Value>(slot.elemType) .Case([&](MemRefType t) { - return builder.create<memref::AllocaOp>(getLoc(), t); + return memref::AllocaOp::create(builder, getLoc(), t); }) .Default([&](Type t) { - return builder.create<arith::ConstantOp>(getLoc(), t, - builder.getZeroAttr(t)); + return arith::ConstantOp::create(builder, getLoc(), t, + builder.getZeroAttr(t)); }); } @@ -135,7 +135,7 @@ DenseMap<Attribute, MemorySlot> memref::AllocaOp::destructure( for (Attribute usedIndex : usedIndices) { Type elemType = memrefType.getTypeAtIndex(usedIndex); MemRefType elemPtr = MemRefType::get({}, elemType); - auto subAlloca = builder.create<memref::AllocaOp>(getLoc(), elemPtr); + auto subAlloca = memref::AllocaOp::create(builder, getLoc(), elemPtr); newAllocators.push_back(subAlloca); slotMap.try_emplace<MemorySlot>(usedIndex, {subAlloca.getResult(), elemType}); diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 51c8136..74b968c 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -213,9 +213,9 @@ struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> { assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. - auto newAlloc = rewriter.create<AllocLikeOp>( - alloc.getLoc(), newMemRefType, dynamicSizes, alloc.getSymbolOperands(), - alloc.getAlignmentAttr()); + auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType, + dynamicSizes, alloc.getSymbolOperands(), + alloc.getAlignmentAttr()); // Insert a cast so we have the same type as the old alloc. rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc); return success(); @@ -797,7 +797,7 @@ void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { void DimOp::build(OpBuilder &builder, OperationState &result, Value source, int64_t index) { auto loc = result.location; - Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index); + Value indexValue = arith::ConstantIndexOp::create(builder, loc, index); build(builder, result, source, indexValue); } @@ -1044,9 +1044,9 @@ struct DimOfMemRefReshape : public OpRewritePattern<DimOp> { rewriter.setInsertionPointAfter(reshape); Location loc = dim.getLoc(); Value load = - rewriter.create<LoadOp>(loc, reshape.getShape(), dim.getIndex()); + LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex()); if (load.getType() != dim.getType()) - load = rewriter.create<arith::IndexCastOp>(loc, dim.getType(), load); + load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load); rewriter.replaceOp(dim, load); return success(); } @@ -1319,8 +1319,9 @@ static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, assert(isa<Attribute>(maybeConstant) && "The constified value should be either unchanged (i.e., == result) " "or a constant"); - Value constantVal = rewriter.create<arith::ConstantIndexOp>( - loc, llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt()); + Value constantVal = arith::ConstantIndexOp::create( + rewriter, loc, + llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt()); for (Operation *op : llvm::make_early_inc_range(result.getUsers())) { // modifyOpInPlace: lambda cannot capture structured bindings in C++17 // yet. 
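The logging changes follow the same one-for-one shape as the builder changes: the old `LDBG("msg" << value << "\n")` macro call, and the multi-statement `LLVM_DEBUG`/`llvm::interleaveComma` dumps, become a single statement against the stream returned by `LDBG()` from llvm/Support/DebugLog.h, which is still gated on `DEBUG_TYPE` and appends the trailing newline itself. A minimal sketch, assuming a translation unit with an `op` and a `sizes` range in scope (names here are illustrative, not from this patch):

    #include "llvm/Support/DebugLog.h"
    #include "llvm/Support/InterleavedRange.h"

    #define DEBUG_TYPE "my-pass" // hypothetical debug type

    // Before: newline written by hand inside the macro argument.
    //   LDBG("Inner-tiles must be constant: " << op << "\n");
    // After: LDBG() yields a stream; the newline is implicit.
    LDBG() << "Inner-tiles must be constant: " << op;

    // Ranges that needed a separate llvm::interleaveComma statement
    // now print inline via llvm::interleaved.
    LDBG() << "Input vector sizes: " << llvm::interleaved(sizes);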
@@ -2548,8 +2549,9 @@ public: rewriter.modifyOpInPlace( op, [&]() { op.getSrcMutable().assign(cast.getSource()); }); } else { - Value newOp = rewriter.create<CollapseShapeOp>( - op->getLoc(), cast.getSource(), op.getReassociationIndices()); + Value newOp = + CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(), + op.getReassociationIndices()); rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp); } return success(); @@ -3006,15 +3008,15 @@ SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, Value offset = op.isDynamicOffset(idx) ? op.getDynamicOffset(idx) - : b.create<arith::ConstantIndexOp>(loc, op.getStaticOffset(idx)); + : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx)); Value size = op.isDynamicSize(idx) ? op.getDynamicSize(idx) - : b.create<arith::ConstantIndexOp>(loc, op.getStaticSize(idx)); + : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx)); Value stride = op.isDynamicStride(idx) ? op.getDynamicStride(idx) - : b.create<arith::ConstantIndexOp>(loc, op.getStaticStride(idx)); + : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx)); res.emplace_back(Range{offset, size, stride}); } return res; @@ -3173,8 +3175,8 @@ public: if (!resultType) return failure(); - Value newSubView = rewriter.create<SubViewOp>( - subViewOp.getLoc(), resultType, castOp.getSource(), + Value newSubView = SubViewOp::create( + rewriter, subViewOp.getLoc(), resultType, castOp.getSource(), subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(), subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(), subViewOp.getStaticStrides()); @@ -3495,9 +3497,9 @@ struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> { return failure(); // Create new ViewOp. - auto newViewOp = rewriter.create<ViewOp>( - viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), - viewOp.getByteShift(), newOperands); + auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType, + viewOp.getOperand(0), viewOp.getByteShift(), + newOperands); // Insert a cast so we have the same type as the old memref type. rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp); return success(); diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 0c03670..95eb2a9 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -155,9 +155,10 @@ transform::MemRefAllocaToGlobalOp::apply(transform::TransformRewriter &rewriter, Type resultType = alloca.getResult().getType(); OpBuilder builder(rewriter.getContext()); // TODO: Add a better builder for this. 
- globalOp = builder.create<memref::GlobalOp>( - loc, StringAttr::get(ctx, "alloca"), StringAttr::get(ctx, "private"), - TypeAttr::get(resultType), Attribute{}, UnitAttr{}, IntegerAttr{}); + globalOp = memref::GlobalOp::create( + builder, loc, StringAttr::get(ctx, "alloca"), + StringAttr::get(ctx, "private"), TypeAttr::get(resultType), + Attribute{}, UnitAttr{}, IntegerAttr{}); symbolTable.insert(globalOp); } diff --git a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp index c433415..75cc39e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.cpp @@ -22,11 +22,11 @@ struct DefaultAllocationInterface DefaultAllocationInterface, memref::AllocOp> { static std::optional<Operation *> buildDealloc(OpBuilder &builder, Value alloc) { - return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc) + return memref::DeallocOp::create(builder, alloc.getLoc(), alloc) .getOperation(); } static std::optional<Value> buildClone(OpBuilder &builder, Value alloc) { - return builder.create<bufferization::CloneOp>(alloc.getLoc(), alloc) + return bufferization::CloneOp::create(builder, alloc.getLoc(), alloc) .getResult(); } static ::mlir::HoistingKind getHoistingKind() { @@ -35,8 +35,9 @@ struct DefaultAllocationInterface static ::std::optional<::mlir::Operation *> buildPromotedAlloc(OpBuilder &builder, Value alloc) { Operation *definingOp = alloc.getDefiningOp(); - return builder.create<memref::AllocaOp>( - definingOp->getLoc(), cast<MemRefType>(definingOp->getResultTypes()[0]), + return memref::AllocaOp::create( + builder, definingOp->getLoc(), + cast<MemRefType>(definingOp->getResultTypes()[0]), definingOp->getOperands(), definingOp->getAttrs()); } }; @@ -52,7 +53,7 @@ struct DefaultReallocationInterface DefaultAllocationInterface, memref::ReallocOp> { static std::optional<Operation *> buildDealloc(OpBuilder &builder, Value realloc) { - return builder.create<memref::DeallocOp>(realloc.getLoc(), realloc) + return memref::DeallocOp::create(builder, realloc.getLoc(), realloc) .getOperation(); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp index 7c777e8..cce80db 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp @@ -80,10 +80,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> { for (auto &&[opOffset, sourceOffset, sourceStride, opSize] : llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(), sourceOp.getMixedStrides(), op.getMixedSizes())) { - // We only support static sizes. 
- if (isa<Value>(opSize)) { - return failure(); - } sizes.push_back(opSize); Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset), sourceOffsetAttr = @@ -124,8 +120,8 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> { } AffineMap map = AffineMap::get(0, affineApplyOperands.size(), expr); - Value result = rewriter.create<affine::AffineApplyOp>( - op.getLoc(), map, affineApplyOperands); + Value result = affine::AffineApplyOp::create(rewriter, op.getLoc(), map, + affineApplyOperands); offsets.push_back(result); } } diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp index ec2bc95..556ea1a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp @@ -99,7 +99,7 @@ static Value getOffsetForBitwidth(Location loc, OpFoldResult srcIdx, affine::makeComposedFoldedAffineApply(builder, loc, offsetExpr, {srcIdx}); Value bitOffset = getValueOrCreateConstantIndexOp(builder, loc, offsetVal); IntegerType dstType = builder.getIntegerType(targetBits); - return builder.create<arith::IndexCastOp>(loc, dstType, bitOffset); + return arith::IndexCastOp::create(builder, loc, dstType, bitOffset); } /// When writing a subbyte size, masked bitwise operations are used to only @@ -112,14 +112,14 @@ static Value getSubByteWriteMask(Location loc, OpFoldResult linearizedIndices, auto dstIntegerType = builder.getIntegerType(dstBits); auto maskRightAlignedAttr = builder.getIntegerAttr(dstIntegerType, (1 << srcBits) - 1); - Value maskRightAligned = builder.create<arith::ConstantOp>( - loc, dstIntegerType, maskRightAlignedAttr); + Value maskRightAligned = arith::ConstantOp::create( + builder, loc, dstIntegerType, maskRightAlignedAttr); Value writeMaskInverse = - builder.create<arith::ShLIOp>(loc, maskRightAligned, bitwidthOffset); + arith::ShLIOp::create(builder, loc, maskRightAligned, bitwidthOffset); auto flipValAttr = builder.getIntegerAttr(dstIntegerType, -1); Value flipVal = - builder.create<arith::ConstantOp>(loc, dstIntegerType, flipValAttr); - return builder.create<arith::XOrIOp>(loc, writeMaskInverse, flipVal); + arith::ConstantOp::create(builder, loc, dstIntegerType, flipValAttr); + return arith::XOrIOp::create(builder, loc, writeMaskInverse, flipVal); } /// Returns the scaled linearized index based on the `srcBits` and `dstBits` @@ -141,7 +141,7 @@ getLinearizedSrcIndices(OpBuilder &builder, Location loc, int64_t srcBits, const SmallVector<OpFoldResult> &indices, Value memref) { auto stridedMetadata = - builder.create<memref::ExtractStridedMetadataOp>(loc, memref); + memref::ExtractStridedMetadataOp::create(builder, loc, memref); OpFoldResult linearizedIndices; std::tie(std::ignore, linearizedIndices) = memref::getLinearizedMemRefOffsetAndSize( @@ -298,16 +298,16 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> { // Special case 0-rank memref loads. Value bitsLoad; if (convertedType.getRank() == 0) { - bitsLoad = rewriter.create<memref::LoadOp>(loc, adaptor.getMemref(), - ValueRange{}); + bitsLoad = memref::LoadOp::create(rewriter, loc, adaptor.getMemref(), + ValueRange{}); } else { // Linearize the indices of the original load instruction. Do not account // for the scaling yet. This will be accounted for later. 
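The scaling mentioned above is easiest to see with plain integers: when srcBits = 4 is emulated on dstBits = 32, eight narrow elements share one backing word, element i lives in word i / 8, and its bits start at (i % 8) * 4, which is what getOffsetForBitwidth computes. A scalar model of that split; the struct and function names are illustrative, not from the patch:

#include <cstdint>

// Scalar model of the emulation's index split: which wide word holds a
// narrow element, and at which bit offset inside that word it starts.
struct SubByteAddress {
  uint64_t wordIndex; // index into the dstBits-wide backing memref
  uint64_t bitOffset; // shift amount within that word
};

static SubByteAddress locateSubByte(uint64_t linearIndex, unsigned srcBits,
                                    unsigned dstBits) {
  unsigned scale = dstBits / srcBits; // narrow elements per wide word
  return {linearIndex / scale, (linearIndex % scale) * srcBits};
}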
OpFoldResult linearizedIndices = getLinearizedSrcIndices( rewriter, loc, srcBits, adaptor.getIndices(), op.getMemRef()); - Value newLoad = rewriter.create<memref::LoadOp>( - loc, adaptor.getMemref(), + Value newLoad = memref::LoadOp::create( + rewriter, loc, adaptor.getMemref(), getIndicesForLoadOrStore(rewriter, loc, linearizedIndices, srcBits, dstBits)); @@ -315,7 +315,7 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> { // Note, currently only the big-endian is supported. Value bitwidthOffset = getOffsetForBitwidth(loc, linearizedIndices, srcBits, dstBits, rewriter); - bitsLoad = rewriter.create<arith::ShRSIOp>(loc, newLoad, bitwidthOffset); + bitsLoad = arith::ShRSIOp::create(rewriter, loc, newLoad, bitwidthOffset); } // Get the corresponding bits. If the arith computation bitwidth equals @@ -331,17 +331,17 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> { : IntegerType::get(rewriter.getContext(), resultTy.getIntOrFloatBitWidth()); if (conversionTy == convertedElementType) { - auto mask = rewriter.create<arith::ConstantOp>( - loc, convertedElementType, + auto mask = arith::ConstantOp::create( + rewriter, loc, convertedElementType, rewriter.getIntegerAttr(convertedElementType, (1 << srcBits) - 1)); - result = rewriter.create<arith::AndIOp>(loc, bitsLoad, mask); + result = arith::AndIOp::create(rewriter, loc, bitsLoad, mask); } else { - result = rewriter.create<arith::TruncIOp>(loc, conversionTy, bitsLoad); + result = arith::TruncIOp::create(rewriter, loc, conversionTy, bitsLoad); } if (conversionTy != resultTy) { - result = rewriter.create<arith::BitcastOp>(loc, resultTy, result); + result = arith::BitcastOp::create(rewriter, loc, resultTy, result); } rewriter.replaceOp(op, result); @@ -428,20 +428,20 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> { // Pad the input value with 0s on the left. Value input = adaptor.getValue(); if (!input.getType().isInteger()) { - input = rewriter.create<arith::BitcastOp>( - loc, + input = arith::BitcastOp::create( + rewriter, loc, IntegerType::get(rewriter.getContext(), input.getType().getIntOrFloatBitWidth()), input); } Value extendedInput = - rewriter.create<arith::ExtUIOp>(loc, dstIntegerType, input); + arith::ExtUIOp::create(rewriter, loc, dstIntegerType, input); // Special case 0-rank memref stores. No need for masking. 
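For ranked memrefs, the path below turns the store into a read-modify-write: one atomic andi clears the destination bit range with the mask from getSubByteWriteMask, and one atomic ori merges in the shifted value. The same sequence on a plain integer, as a sketch; the function name is illustrative and the value is assumed already zero-extended:

#include <cstdint>

// Scalar equivalent of the two memref.atomic_rmw ops emitted below.
static void storeSubByte(uint32_t &word, uint32_t value, unsigned srcBits,
                         unsigned bitOffset) {
  // Right-aligned mask, shifted into place, then inverted (the pattern
  // inverts by xor'ing with -1).
  uint32_t writeMask = ~(((1u << srcBits) - 1) << bitOffset);
  word &= writeMask;          // atomic_rmw andi: clear the target bits
  word |= value << bitOffset; // atomic_rmw ori: merge in the new bits
}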
if (convertedType.getRank() == 0) { - rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::assign, - extendedInput, adaptor.getMemref(), - ValueRange{}); + memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::assign, + extendedInput, adaptor.getMemref(), + ValueRange{}); rewriter.eraseOp(op); return success(); } @@ -456,16 +456,14 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> { dstBits, bitwidthOffset, rewriter); // Align the value to write with the destination bits Value alignedVal = - rewriter.create<arith::ShLIOp>(loc, extendedInput, bitwidthOffset); + arith::ShLIOp::create(rewriter, loc, extendedInput, bitwidthOffset); // Clear destination bits - rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::andi, - writeMask, adaptor.getMemref(), - storeIndices); + memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::andi, + writeMask, adaptor.getMemref(), storeIndices); // Write srcs bits to destination - rewriter.create<memref::AtomicRMWOp>(loc, arith::AtomicRMWKind::ori, - alignedVal, adaptor.getMemref(), - storeIndices); + memref::AtomicRMWOp::create(rewriter, loc, arith::AtomicRMWKind::ori, + alignedVal, adaptor.getMemref(), storeIndices); rewriter.eraseOp(op); return success(); } @@ -525,8 +523,8 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> { } // Transform the offsets, sizes and strides according to the emulation. - auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>( - loc, subViewOp.getViewSource()); + auto stridedMetadata = memref::ExtractStridedMetadataOp::create( + rewriter, loc, subViewOp.getViewSource()); OpFoldResult linearizedIndices; auto strides = stridedMetadata.getConstifiedMixedStrides(); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp index e6e4c3b0..17a148c 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandOps.cpp @@ -48,15 +48,15 @@ public: Value size; // Load dynamic sizes from the shape input, use constants for static dims. 
if (op.getType().isDynamicDim(i)) { - Value index = rewriter.create<arith::ConstantIndexOp>(loc, i); - size = rewriter.create<memref::LoadOp>(loc, op.getShape(), index); + Value index = arith::ConstantIndexOp::create(rewriter, loc, i); + size = memref::LoadOp::create(rewriter, loc, op.getShape(), index); if (!isa<IndexType>(size.getType())) - size = rewriter.create<arith::IndexCastOp>( - loc, rewriter.getIndexType(), size); + size = arith::IndexCastOp::create(rewriter, loc, + rewriter.getIndexType(), size); sizes[i] = size; } else { auto sizeAttr = rewriter.getIndexAttr(op.getType().getDimSize(i)); - size = rewriter.create<arith::ConstantOp>(loc, sizeAttr); + size = arith::ConstantOp::create(rewriter, loc, sizeAttr); sizes[i] = sizeAttr; } if (stride) @@ -66,10 +66,11 @@ public: if (i > 0) { if (stride) { - stride = rewriter.create<arith::MulIOp>(loc, stride, size); + stride = arith::MulIOp::create(rewriter, loc, stride, size); } else if (op.getType().isDynamicDim(i)) { - stride = rewriter.create<arith::MulIOp>( - loc, rewriter.create<arith::ConstantIndexOp>(loc, staticStride), + stride = arith::MulIOp::create( + rewriter, loc, + arith::ConstantIndexOp::create(rewriter, loc, staticStride), size); } else { staticStride *= op.getType().getDimSize(i); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp index 7475d44..01d3262 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandRealloc.cpp @@ -73,7 +73,7 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> { if (ShapedType::isDynamic(inputSize)) { Value dimZero = getValueOrCreateConstantIndexOp(rewriter, loc, rewriter.getIndexAttr(0)); - currSize = rewriter.create<memref::DimOp>(loc, op.getSource(), dimZero) + currSize = memref::DimOp::create(rewriter, loc, op.getSource(), dimZero) .getResult(); } @@ -88,10 +88,10 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> { // the old buffer is smaller than the requested size. Value lhs = getValueOrCreateConstantIndexOp(rewriter, loc, currSize); Value rhs = getValueOrCreateConstantIndexOp(rewriter, loc, targetSize); - Value cond = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, - lhs, rhs); - auto ifOp = rewriter.create<scf::IfOp>( - loc, cond, + Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult, + lhs, rhs); + auto ifOp = scf::IfOp::create( + rewriter, loc, cond, [&](OpBuilder &builder, Location loc) { // Allocate the new buffer. If it is a dynamic memref we need to pass // an additional operand for the size at runtime, otherwise the static @@ -100,25 +100,26 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> { if (op.getDynamicResultSize()) dynamicSizeOperands.push_back(op.getDynamicResultSize()); - Value newAlloc = builder.create<memref::AllocOp>( - loc, op.getResult().getType(), dynamicSizeOperands, + Value newAlloc = memref::AllocOp::create( + builder, loc, op.getResult().getType(), dynamicSizeOperands, op.getAlignmentAttr()); // Take a subview of the new (bigger) buffer such that we can copy the // old values over (the copy operation requires both operands to have // the same shape). 
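Seen end to end, the expansion built across these two branches is ordinary grow-and-copy logic: compare sizes, allocate, copy, and optionally deallocate when growing, reuse the old buffer otherwise. A C-level analogue under those assumptions; the function name is illustrative:

#include <cstdlib>
#include <cstring>

// Control-flow analogue of the expanded memref.realloc.
static void *expandedRealloc(void *src, size_t currSize, size_t targetSize,
                             bool emitDeallocs) {
  if (currSize < targetSize) {             // arith.cmpi ult feeding scf.if
    void *grown = std::malloc(targetSize); // memref.alloc
    std::memcpy(grown, src, currSize);     // memref.copy through the subview
    if (emitDeallocs)
      std::free(src);                      // memref.dealloc of the old buffer
    return grown;                          // scf.yield of the then-branch
  }
  return src; // else-branch: reinterpret_cast of the existing buffer
}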
- Value subview = builder.create<memref::SubViewOp>( - loc, newAlloc, ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)}, + Value subview = memref::SubViewOp::create( + builder, loc, newAlloc, + ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)}, ArrayRef<OpFoldResult>{currSize}, ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}); - builder.create<memref::CopyOp>(loc, op.getSource(), subview); + memref::CopyOp::create(builder, loc, op.getSource(), subview); // Insert the deallocation of the old buffer only if requested // (enabled by default). if (emitDeallocs) - builder.create<memref::DeallocOp>(loc, op.getSource()); + memref::DeallocOp::create(builder, loc, op.getSource()); - builder.create<scf::YieldOp>(loc, newAlloc); + scf::YieldOp::create(builder, loc, newAlloc); }, [&](OpBuilder &builder, Location loc) { // We need to reinterpret-cast here because either the input or output @@ -126,11 +127,12 @@ struct ExpandReallocOpPattern : public OpRewritePattern<memref::ReallocOp> { // dynamic or vice-versa. If both are static and the original buffer // is already bigger than the requested size, the cast represents a // subview operation. - Value casted = builder.create<memref::ReinterpretCastOp>( - loc, cast<MemRefType>(op.getResult().getType()), op.getSource(), - rewriter.getIndexAttr(0), ArrayRef<OpFoldResult>{targetSize}, + Value casted = memref::ReinterpretCastOp::create( + builder, loc, cast<MemRefType>(op.getResult().getType()), + op.getSource(), rewriter.getIndexAttr(0), + ArrayRef<OpFoldResult>{targetSize}, ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)}); - builder.create<scf::YieldOp>(loc, casted); + scf::YieldOp::create(builder, loc, casted); }); rewriter.replaceOp(op, ifOp.getResult(0)); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 2ba798f..9771bd2 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -66,7 +66,7 @@ resolveSubviewStridedMetadata(RewriterBase &rewriter, unsigned sourceRank = sourceType.getRank(); auto newExtractStridedMetadata = - rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source); + memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source); auto [sourceStrides, sourceOffset] = sourceType.getStridesAndOffset(); #ifndef NDEBUG @@ -577,7 +577,7 @@ static FailureOr<StridedMetadata> resolveReshapeStridedMetadata( unsigned sourceRank = sourceType.getRank(); auto newExtractStridedMetadata = - rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source); + memref::ExtractStridedMetadataOp::create(rewriter, origLoc, source); // Collect statically known information. auto [strides, offset] = sourceType.getStridesAndOffset(); @@ -828,14 +828,14 @@ public: if (allocLikeOp.getType() == baseBufferType) results.push_back(allocLikeOp); else - results.push_back(rewriter.create<memref::ReinterpretCastOp>( - loc, baseBufferType, allocLikeOp, offset, + results.push_back(memref::ReinterpretCastOp::create( + rewriter, loc, baseBufferType, allocLikeOp, offset, /*sizes=*/ArrayRef<int64_t>(), /*strides=*/ArrayRef<int64_t>())); } // Offset. 
- results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset)); for (OpFoldResult size : sizes) results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc, size)); @@ -900,19 +900,19 @@ public: if (getGlobalOp.getType() == baseBufferType) results.push_back(getGlobalOp); else - results.push_back(rewriter.create<memref::ReinterpretCastOp>( - loc, baseBufferType, getGlobalOp, offset, + results.push_back(memref::ReinterpretCastOp::create( + rewriter, loc, baseBufferType, getGlobalOp, offset, /*sizes=*/ArrayRef<int64_t>(), /*strides=*/ArrayRef<int64_t>())); // Offset. - results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, offset)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, offset)); for (auto size : sizes) - results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, size)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, size)); for (auto stride : strides) - results.push_back(rewriter.create<arith::ConstantIndexOp>(loc, stride)); + results.push_back(arith::ConstantIndexOp::create(rewriter, loc, stride)); rewriter.replaceOp(op, results); return success(); @@ -1008,9 +1008,8 @@ class ExtractStridedMetadataOpReinterpretCastFolder SmallVector<OpFoldResult> results; results.resize_for_overwrite(rank * 2 + 2); - auto newExtractStridedMetadata = - rewriter.create<memref::ExtractStridedMetadataOp>( - loc, reinterpretCastOp.getSource()); + auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create( + rewriter, loc, reinterpretCastOp.getSource()); // Register the base_buffer. results[0] = newExtractStridedMetadata.getBaseBuffer(); @@ -1082,9 +1081,8 @@ class ExtractStridedMetadataOpCastFolder SmallVector<OpFoldResult> results; results.resize_for_overwrite(rank * 2 + 2); - auto newExtractStridedMetadata = - rewriter.create<memref::ExtractStridedMetadataOp>(loc, - castOp.getSource()); + auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create( + rewriter, loc, castOp.getSource()); // Register the base_buffer. results[0] = newExtractStridedMetadata.getBaseBuffer(); @@ -1142,9 +1140,8 @@ class ExtractStridedMetadataOpMemorySpaceCastFolder auto memSpaceCastOp = source.getDefiningOp<memref::MemorySpaceCastOp>(); if (!memSpaceCastOp) return failure(); - auto newExtractStridedMetadata = - rewriter.create<memref::ExtractStridedMetadataOp>( - loc, memSpaceCastOp.getSource()); + auto newExtractStridedMetadata = memref::ExtractStridedMetadataOp::create( + rewriter, loc, memSpaceCastOp.getSource()); SmallVector<Value> results(newExtractStridedMetadata.getResults()); // As with most other strided metadata rewrite patterns, don't introduce // a use of the base pointer where none existed.
This needs to happen here, @@ -1158,8 +1155,8 @@ class ExtractStridedMetadataOpMemorySpaceCastFolder MemRefType::Builder newTypeBuilder(baseBufferType); newTypeBuilder.setMemorySpace( memSpaceCastOp.getResult().getType().getMemorySpace()); - results[0] = rewriter.create<memref::MemorySpaceCastOp>( - loc, Type{newTypeBuilder}, baseBuffer); + results[0] = memref::MemorySpaceCastOp::create( + rewriter, loc, Type{newTypeBuilder}, baseBuffer); } else { results[0] = nullptr; } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp index 2f5c943..0946da8e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExtractAddressComputations.cpp @@ -42,8 +42,8 @@ static memref::LoadOp rebuildLoadOp(RewriterBase &rewriter, memref::LoadOp loadOp, Value srcMemRef, ArrayRef<Value> indices) { Location loc = loadOp.getLoc(); - return rewriter.create<memref::LoadOp>(loc, srcMemRef, indices, - loadOp.getNontemporal()); + return memref::LoadOp::create(rewriter, loc, srcMemRef, indices, + loadOp.getNontemporal()); } // Matches getViewSizeForEachDim specs for LoadOp. @@ -72,9 +72,8 @@ static memref::StoreOp rebuildStoreOp(RewriterBase &rewriter, memref::StoreOp storeOp, Value srcMemRef, ArrayRef<Value> indices) { Location loc = storeOp.getLoc(); - return rewriter.create<memref::StoreOp>(loc, storeOp.getValueToStore(), - srcMemRef, indices, - storeOp.getNontemporal()); + return memref::StoreOp::create(rewriter, loc, storeOp.getValueToStore(), + srcMemRef, indices, storeOp.getNontemporal()); } // Matches getViewSizeForEachDim specs for StoreOp. @@ -104,8 +103,8 @@ static nvgpu::LdMatrixOp rebuildLdMatrixOp(RewriterBase &rewriter, Value srcMemRef, ArrayRef<Value> indices) { Location loc = ldMatrixOp.getLoc(); - return rewriter.create<nvgpu::LdMatrixOp>( - loc, ldMatrixOp.getResult().getType(), srcMemRef, indices, + return nvgpu::LdMatrixOp::create( + rewriter, loc, ldMatrixOp.getResult().getType(), srcMemRef, indices, ldMatrixOp.getTranspose(), ldMatrixOp.getNumTiles()); } @@ -132,8 +131,8 @@ rebuildTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp transferReadOp, Value srcMemRef, ArrayRef<Value> indices) { Location loc = transferReadOp.getLoc(); - return rewriter.create<vector::TransferReadOp>( - loc, transferReadOp.getResult().getType(), srcMemRef, indices, + return vector::TransferReadOp::create( + rewriter, loc, transferReadOp.getResult().getType(), srcMemRef, indices, transferReadOp.getPermutationMap(), transferReadOp.getPadding(), transferReadOp.getMask(), transferReadOp.getInBoundsAttr()); } @@ -150,8 +149,8 @@ rebuildTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp transferWriteOp, Value srcMemRef, ArrayRef<Value> indices) { Location loc = transferWriteOp.getLoc(); - return rewriter.create<vector::TransferWriteOp>( - loc, transferWriteOp.getValue(), srcMemRef, indices, + return vector::TransferWriteOp::create( + rewriter, loc, transferWriteOp.getValue(), srcMemRef, indices, transferWriteOp.getPermutationMapAttr(), transferWriteOp.getMask(), transferWriteOp.getInBoundsAttr()); } @@ -182,9 +181,8 @@ static SmallVector<OpFoldResult> getGenericOpViewSizeForEachDim(RewriterBase &rewriter, LoadStoreLikeOp loadStoreLikeOp) { Location loc = loadStoreLikeOp.getLoc(); - auto extractStridedMetadataOp = - rewriter.create<memref::ExtractStridedMetadataOp>( - loc, getSrcMemRef(loadStoreLikeOp)); + auto extractStridedMetadataOp = 
memref::ExtractStridedMetadataOp::create( + rewriter, loc, getSrcMemRef(loadStoreLikeOp)); SmallVector<OpFoldResult> srcSizes = extractStridedMetadataOp.getConstifiedMixedSizes(); SmallVector<OpFoldResult> indices = @@ -267,12 +265,12 @@ struct LoadStoreLikeOpRewriter : public OpRewritePattern<LoadStoreLikeOp> { // apply them properly to the input indices. // Therefore the strides multipliers are simply ones. auto subview = - rewriter.create<memref::SubViewOp>(loc, /*source=*/srcMemRef, - /*offsets=*/indices, - /*sizes=*/sizes, /*strides=*/ones); + memref::SubViewOp::create(rewriter, loc, /*source=*/srcMemRef, + /*offsets=*/indices, + /*sizes=*/sizes, /*strides=*/ones); // Rewrite the load/store with the subview as the base pointer. SmallVector<Value> zeros(loadStoreRank, - rewriter.create<arith::ConstantIndexOp>(loc, 0)); + arith::ConstantIndexOp::create(rewriter, loc, 0)); LoadStoreLikeOp newLoadStore = rebuildOpFromAddressAndIndices( rewriter, loadStoreLikeOp, subview.getResult(), zeros); rewriter.replaceOp(loadStoreLikeOp, newLoadStore->getResults()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp index 76f7788..42be847 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/FlattenMemRefs.cpp @@ -40,8 +40,8 @@ using namespace mlir; static Value getValueFromOpFoldResult(OpBuilder &rewriter, Location loc, OpFoldResult in) { if (Attribute offsetAttr = dyn_cast<Attribute>(in)) { - return rewriter.create<arith::ConstantIndexOp>( - loc, cast<IntegerAttr>(offsetAttr).getInt()); + return arith::ConstantIndexOp::create( + rewriter, loc, cast<IntegerAttr>(offsetAttr).getInt()); } return cast<Value>(in); } @@ -60,7 +60,7 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter, } memref::ExtractStridedMetadataOp stridedMetadata = - rewriter.create<memref::ExtractStridedMetadataOp>(loc, source); + memref::ExtractStridedMetadataOp::create(rewriter, loc, source); auto typeBit = sourceType.getElementType().getIntOrFloatBitWidth(); OpFoldResult linearizedIndices; @@ -74,8 +74,8 @@ static std::pair<Value, Value> getFlattenMemrefAndOffset(OpBuilder &rewriter, getAsOpFoldResult(indices)); return std::make_pair( - rewriter.create<memref::ReinterpretCastOp>( - loc, source, + memref::ReinterpretCastOp::create( + rewriter, loc, source, /* offset = */ linearizedInfo.linearizedOffset, /* shapes = */ ArrayRef<OpFoldResult>{linearizedInfo.linearizedSize}, @@ -111,7 +111,7 @@ template <typename T> static void castAllocResult(T oper, T newOper, Location loc, PatternRewriter &rewriter) { memref::ExtractStridedMetadataOp stridedMetadata = - rewriter.create<memref::ExtractStridedMetadataOp>(loc, oper); + memref::ExtractStridedMetadataOp::create(rewriter, loc, oper); rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>( oper, cast<MemRefType>(oper.getType()), newOper, /*offset=*/rewriter.getIndexAttr(0), @@ -125,63 +125,68 @@ static void replaceOp(T op, PatternRewriter &rewriter, Value flatMemref, Location loc = op->getLoc(); llvm::TypeSwitch<Operation *>(op.getOperation()) .template Case<memref::AllocOp>([&](auto oper) { - auto newAlloc = rewriter.create<memref::AllocOp>( - loc, cast<MemRefType>(flatMemref.getType()), + auto newAlloc = memref::AllocOp::create( + rewriter, loc, cast<MemRefType>(flatMemref.getType()), oper.getAlignmentAttr()); castAllocResult(oper, newAlloc, loc, rewriter); }) .template Case<memref::AllocaOp>([&](auto oper) { - auto newAlloca = 
rewriter.create<memref::AllocaOp>( - loc, cast<MemRefType>(flatMemref.getType()), + auto newAlloca = memref::AllocaOp::create( + rewriter, loc, cast<MemRefType>(flatMemref.getType()), oper.getAlignmentAttr()); castAllocResult(oper, newAlloca, loc, rewriter); }) .template Case<memref::LoadOp>([&](auto op) { - auto newLoad = rewriter.create<memref::LoadOp>( - loc, op->getResultTypes(), flatMemref, ValueRange{offset}); + auto newLoad = + memref::LoadOp::create(rewriter, loc, op->getResultTypes(), + flatMemref, ValueRange{offset}); newLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newLoad.getResult()); }) .template Case<memref::StoreOp>([&](auto op) { - auto newStore = rewriter.create<memref::StoreOp>( - loc, op->getOperands().front(), flatMemref, ValueRange{offset}); + auto newStore = + memref::StoreOp::create(rewriter, loc, op->getOperands().front(), + flatMemref, ValueRange{offset}); newStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newStore); }) .template Case<vector::LoadOp>([&](auto op) { - auto newLoad = rewriter.create<vector::LoadOp>( - loc, op->getResultTypes(), flatMemref, ValueRange{offset}); + auto newLoad = + vector::LoadOp::create(rewriter, loc, op->getResultTypes(), + flatMemref, ValueRange{offset}); newLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newLoad.getResult()); }) .template Case<vector::StoreOp>([&](auto op) { - auto newStore = rewriter.create<vector::StoreOp>( - loc, op->getOperands().front(), flatMemref, ValueRange{offset}); + auto newStore = + vector::StoreOp::create(rewriter, loc, op->getOperands().front(), + flatMemref, ValueRange{offset}); newStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newStore); }) .template Case<vector::MaskedLoadOp>([&](auto op) { - auto newMaskedLoad = rewriter.create<vector::MaskedLoadOp>( - loc, op.getType(), flatMemref, ValueRange{offset}, op.getMask(), - op.getPassThru()); + auto newMaskedLoad = vector::MaskedLoadOp::create( + rewriter, loc, op.getType(), flatMemref, ValueRange{offset}, + op.getMask(), op.getPassThru()); newMaskedLoad->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newMaskedLoad.getResult()); }) .template Case<vector::MaskedStoreOp>([&](auto op) { - auto newMaskedStore = rewriter.create<vector::MaskedStoreOp>( - loc, flatMemref, ValueRange{offset}, op.getMask(), + auto newMaskedStore = vector::MaskedStoreOp::create( + rewriter, loc, flatMemref, ValueRange{offset}, op.getMask(), op.getValueToStore()); newMaskedStore->setAttrs(op->getAttrs()); rewriter.replaceOp(op, newMaskedStore); }) .template Case<vector::TransferReadOp>([&](auto op) { - auto newTransferRead = rewriter.create<vector::TransferReadOp>( - loc, op.getType(), flatMemref, ValueRange{offset}, op.getPadding()); + auto newTransferRead = vector::TransferReadOp::create( + rewriter, loc, op.getType(), flatMemref, ValueRange{offset}, + op.getPadding()); rewriter.replaceOp(op, newTransferRead.getResult()); }) .template Case<vector::TransferWriteOp>([&](auto op) { - auto newTransferWrite = rewriter.create<vector::TransferWriteOp>( - loc, op.getVector(), flatMemref, ValueRange{offset}); + auto newTransferWrite = vector::TransferWriteOp::create( + rewriter, loc, op.getVector(), flatMemref, ValueRange{offset}); rewriter.replaceOp(op, newTransferWrite); }) .Default([&](auto op) { diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp index 89be188..24da447 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp +++ 
b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp @@ -44,97 +44,6 @@ using namespace mlir; // Utility functions //===----------------------------------------------------------------------===// -/// Given the 'indices' of a load/store operation where the memref is a result -/// of a expand_shape op, returns the indices w.r.t to the source memref of the -/// expand_shape op. For example -/// -/// %0 = ... : memref<12x42xf32> -/// %1 = memref.expand_shape %0 [[0, 1], [2]] -/// : memref<12x42xf32> into memref<2x6x42xf32> -/// %2 = load %1[%i1, %i2, %i3] : memref<2x6x42xf32 -/// -/// could be folded into -/// -/// %2 = load %0[6 * i1 + i2, %i3] : -/// memref<12x42xf32> -static LogicalResult resolveSourceIndicesExpandShape( - Location loc, PatternRewriter &rewriter, - memref::ExpandShapeOp expandShapeOp, ValueRange indices, - SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) { - SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape(); - - // Traverse all reassociation groups to determine the appropriate indices - // corresponding to each one of them post op folding. - for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) { - assert(!group.empty() && "association indices groups cannot be empty"); - int64_t groupSize = group.size(); - if (groupSize == 1) { - sourceIndices.push_back(indices[group[0]]); - continue; - } - SmallVector<OpFoldResult> groupBasis = - llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; }); - SmallVector<Value> groupIndices = - llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; }); - Value collapsedIndex = rewriter.create<affine::AffineLinearizeIndexOp>( - loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds); - sourceIndices.push_back(collapsedIndex); - } - return success(); -} - -/// Given the 'indices' of a load/store operation where the memref is a result -/// of a collapse_shape op, returns the indices w.r.t to the source memref of -/// the collapse_shape op. For example -/// -/// %0 = ... : memref<2x6x42xf32> -/// %1 = memref.collapse_shape %0 [[0, 1], [2]] -/// : memref<2x6x42xf32> into memref<12x42xf32> -/// %2 = load %1[%i1, %i2] : memref<12x42xf32> -/// -/// could be folded into -/// -/// %2 = load %0[%i1 / 6, %i1 % 6, %i2] : -/// memref<2x6x42xf32> -static LogicalResult -resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, - memref::CollapseShapeOp collapseShapeOp, - ValueRange indices, - SmallVectorImpl<Value> &sourceIndices) { - // Note: collapse_shape requires a strided memref, we can do this. 
- auto metadata = rewriter.create<memref::ExtractStridedMetadataOp>( - loc, collapseShapeOp.getSrc()); - SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes(); - for (auto [index, group] : - llvm::zip(indices, collapseShapeOp.getReassociationIndices())) { - assert(!group.empty() && "association indices groups cannot be empty"); - int64_t groupSize = group.size(); - - if (groupSize == 1) { - sourceIndices.push_back(index); - continue; - } - - SmallVector<OpFoldResult> basis = - llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; }); - auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>( - loc, index, basis, /*hasOuterBound=*/true); - llvm::append_range(sourceIndices, delinearize.getResults()); - } - if (collapseShapeOp.getReassociationIndices().empty()) { - auto zeroAffineMap = rewriter.getConstantAffineMap(0); - int64_t srcRank = - cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank(); - OpFoldResult ofr = affine::makeComposedFoldedAffineApply( - rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{}); - for (int64_t i = 0; i < srcRank; i++) { - sourceIndices.push_back( - getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); - } - } - return success(); -} - /// Helpers to access the memref operand for each op. template <typename LoadOrStoreOpTy> static Value getMemRefOperand(LoadOrStoreOpTy op) { diff --git a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp index 35c661e..d5e2b97 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/IndependenceTransforms.cpp @@ -51,14 +51,13 @@ FailureOr<Value> memref::buildIndependentOp(OpBuilder &b, // Create a new memref::AllocaOp. Value newAllocaOp = - b.create<AllocaOp>(loc, newSizes, allocaOp.getType().getElementType()); + AllocaOp::create(b, loc, newSizes, allocaOp.getType().getElementType()); // Create a memref::SubViewOp. 
SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0)); SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1)); - return b - .create<SubViewOp>(loc, newAllocaOp, offsets, allocaOp.getMixedSizes(), - strides) + return SubViewOp::create(b, loc, newAllocaOp, offsets, + allocaOp.getMixedSizes(), strides) .getResult(); } @@ -71,11 +70,11 @@ propagateSubViewOp(RewriterBase &rewriter, MemRefType newResultType = SubViewOp::inferRankReducedResultType( op.getType().getShape(), op.getSourceType(), op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); - Value newSubview = rewriter.create<SubViewOp>( - op.getLoc(), newResultType, conversionOp.getOperand(0), + Value newSubview = SubViewOp::create( + rewriter, op.getLoc(), newResultType, conversionOp.getOperand(0), op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()); - auto newConversionOp = rewriter.create<UnrealizedConversionCastOp>( - op.getLoc(), op.getType(), newSubview); + auto newConversionOp = UnrealizedConversionCastOp::create( + rewriter, op.getLoc(), op.getType(), newSubview); rewriter.replaceAllUsesWith(op.getResult(), newConversionOp->getResult(0)); return newConversionOp; } @@ -106,8 +105,8 @@ static void replaceAndPropagateMemRefType(RewriterBase &rewriter, SmallVector<UnrealizedConversionCastOp> unrealizedConversions; for (const auto &it : llvm::enumerate(llvm::zip(from->getResults(), to->getResults()))) { - unrealizedConversions.push_back(rewriter.create<UnrealizedConversionCastOp>( - to->getLoc(), std::get<0>(it.value()).getType(), + unrealizedConversions.push_back(UnrealizedConversionCastOp::create( + rewriter, to->getLoc(), std::get<0>(it.value()).getType(), std::get<1>(it.value()))); rewriter.replaceAllUsesWith(from->getResult(it.index()), unrealizedConversions.back()->getResult(0)); diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp index 0a84962..5d3cec4 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -63,9 +63,10 @@ static void replaceUsesAndPropagateType(RewriterBase &rewriter, subviewUse.getType().getShape(), cast<MemRefType>(val.getType()), subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), subviewUse.getStaticStrides()); - Value newSubview = rewriter.create<memref::SubViewOp>( - subviewUse->getLoc(), newType, val, subviewUse.getMixedOffsets(), - subviewUse.getMixedSizes(), subviewUse.getMixedStrides()); + Value newSubview = memref::SubViewOp::create( + rewriter, subviewUse->getLoc(), newType, val, + subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), + subviewUse.getMixedStrides()); // Ouch recursion ... is this really necessary? replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); @@ -177,8 +178,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, Location loc = allocOp->getLoc(); OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(allocOp); - auto mbAlloc = rewriter.create<memref::AllocOp>( - loc, mbMemRefType, ValueRange{}, allocOp->getAttrs()); + auto mbAlloc = memref::AllocOp::create(rewriter, loc, mbMemRefType, + ValueRange{}, allocOp->getAttrs()); LLVM_DEBUG(DBGS() << "--multi-buffered alloc: " << mbAlloc << "\n"); // 3. Within the loop, build the modular leading index (i.e. each loop @@ -211,8 +212,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, // Strides is [1, 1 ... 1 ]. 
MemRefType dstMemref = memref::SubViewOp::inferRankReducedResultType( originalShape, mbMemRefType, offsets, sizes, strides); - Value subview = rewriter.create<memref::SubViewOp>(loc, dstMemref, mbAlloc, - offsets, sizes, strides); + Value subview = memref::SubViewOp::create(rewriter, loc, dstMemref, mbAlloc, + offsets, sizes, strides); LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to @@ -224,7 +225,7 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(deallocOp); auto newDeallocOp = - rewriter.create<memref::DeallocOp>(deallocOp->getLoc(), mbAlloc); + memref::DeallocOp::create(rewriter, deallocOp->getLoc(), mbAlloc); (void)newDeallocOp; LLVM_DEBUG(DBGS() << "----Created dealloc: " << newDeallocOp << "\n"); rewriter.eraseOp(deallocOp); diff --git a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp index 4ec0432..fa7991e 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/NormalizeMemRefs.cpp @@ -276,8 +276,8 @@ void NormalizeMemRefs::updateFunctionSignature(func::FuncOp funcOp, if (!callOp) continue; Operation *newCallOp = - builder.create<func::CallOp>(userOp->getLoc(), callOp.getCalleeAttr(), - resultTypes, userOp->getOperands()); + func::CallOp::create(builder, userOp->getLoc(), callOp.getCalleeAttr(), + resultTypes, userOp->getOperands()); bool replacingMemRefUsesFailed = false; bool returnTypeChanged = false; for (unsigned resIndex : llvm::seq<unsigned>(0, userOp->getNumResults())) { diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp index 46f9d64e..d65825b 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp @@ -115,10 +115,12 @@ static LogicalResult reifyOpResultShapes(RewriterBase &rewriter, // Update the type. 
newRes.setType(reifiedTy); if (isa<RankedTensorType>(reifiedTy)) { - newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes)); + newResults.push_back( + tensor::CastOp::create(rewriter, loc, oldTy, newRes)); } else { assert(isa<MemRefType>(reifiedTy) && "expected a memref type"); - newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes)); + newResults.push_back( + memref::CastOp::create(rewriter, loc, oldTy, newRes)); } } diff --git a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp index 89a3895..6a81a15 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp @@ -69,7 +69,7 @@ struct DimOfShapedTypeOpInterface : public OpRewritePattern<OpTy> { Location loc = dimOp->getLoc(); rewriter.replaceOpWithNewOp<tensor::ExtractOp>( dimOp, resultShape, - rewriter.create<arith::ConstantIndexOp>(loc, *dimIndex).getResult()); + arith::ConstantIndexOp::create(rewriter, loc, *dimIndex).getResult()); return success(); } }; diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp index d231516..d3a77c0 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp @@ -40,19 +40,18 @@ struct AssumeAlignmentOpInterface void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto assumeOp = cast<AssumeAlignmentOp>(op); - Value ptr = builder.create<ExtractAlignedPointerAsIndexOp>( - loc, assumeOp.getMemref()); - Value rest = builder.create<arith::RemUIOp>( - loc, ptr, - builder.create<arith::ConstantIndexOp>(loc, assumeOp.getAlignment())); - Value isAligned = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, rest, - builder.create<arith::ConstantIndexOp>(loc, 0)); - builder.create<cf::AssertOp>( - loc, isAligned, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "memref is not aligned to " + - std::to_string(assumeOp.getAlignment()))); + Value ptr = ExtractAlignedPointerAsIndexOp::create(builder, loc, + assumeOp.getMemref()); + Value rest = arith::RemUIOp::create( + builder, loc, ptr, + arith::ConstantIndexOp::create(builder, loc, assumeOp.getAlignment())); + Value isAligned = + arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::eq, rest, + arith::ConstantIndexOp::create(builder, loc, 0)); + cf::AssertOp::create(builder, loc, isAligned, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "memref is not aligned to " + + std::to_string(assumeOp.getAlignment()))); } }; @@ -71,15 +70,14 @@ struct CastOpInterface if (isa<UnrankedMemRefType>(srcType)) { // Check rank. 
- Value srcRank = builder.create<RankOp>(loc, castOp.getSource()); + Value srcRank = RankOp::create(builder, loc, castOp.getSource()); Value resultRank = - builder.create<arith::ConstantIndexOp>(loc, resultType.getRank()); - Value isSameRank = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, srcRank, resultRank); - builder.create<cf::AssertOp>( - loc, isSameRank, - RuntimeVerifiableOpInterface::generateErrorMessage(op, - "rank mismatch")); + arith::ConstantIndexOp::create(builder, loc, resultType.getRank()); + Value isSameRank = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, srcRank, resultRank); + cf::AssertOp::create(builder, loc, isSameRank, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "rank mismatch")); } // Get source offset and strides. We do not have an op to get offsets and @@ -95,8 +93,9 @@ struct CastOpInterface MemRefType::get(dynamicShape, resultType.getElementType(), stridedLayout, resultType.getMemorySpace()); Value helperCast = - builder.create<CastOp>(loc, dynStridesType, castOp.getSource()); - auto metadataOp = builder.create<ExtractStridedMetadataOp>(loc, helperCast); + CastOp::create(builder, loc, dynStridesType, castOp.getSource()); + auto metadataOp = + ExtractStridedMetadataOp::create(builder, loc, helperCast); // Check dimension sizes. for (const auto &it : llvm::enumerate(resultType.getShape())) { @@ -110,13 +109,13 @@ struct CastOpInterface continue; Value srcDimSz = - builder.create<DimOp>(loc, castOp.getSource(), it.index()); + DimOp::create(builder, loc, castOp.getSource(), it.index()); Value resultDimSz = - builder.create<arith::ConstantIndexOp>(loc, it.value()); - Value isSameSz = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); - builder.create<cf::AssertOp>( - loc, isSameSz, + arith::ConstantIndexOp::create(builder, loc, it.value()); + Value isSameSz = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, srcDimSz, resultDimSz); + cf::AssertOp::create( + builder, loc, isSameSz, RuntimeVerifiableOpInterface::generateErrorMessage( op, "size mismatch of dim " + std::to_string(it.index()))); } @@ -132,13 +131,12 @@ struct CastOpInterface // Static/dynamic offset -> dynamic offset does not need verification. Value srcOffset = metadataOp.getResult(1); Value resultOffsetVal = - builder.create<arith::ConstantIndexOp>(loc, resultOffset); - Value isSameOffset = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); - builder.create<cf::AssertOp>( - loc, isSameOffset, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "offset mismatch")); + arith::ConstantIndexOp::create(builder, loc, resultOffset); + Value isSameOffset = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, srcOffset, resultOffsetVal); + cf::AssertOp::create(builder, loc, isSameOffset, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "offset mismatch")); } // Check strides. 
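Every check in these interfaces reduces to the same shape: materialize the expected quantity, compare it with the runtime value, and guard the comparison with cf::AssertOp. Condensed into one hypothetical helper that is not part of the patch:

#include <string>

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

// One equality check of the kind the interfaces above emit repeatedly.
static void emitEqualityCheck(mlir::OpBuilder &builder, mlir::Location loc,
                              mlir::Value actual, mlir::Value expected,
                              const std::string &message) {
  mlir::Value ok = mlir::arith::CmpIOp::create(
      builder, loc, mlir::arith::CmpIPredicate::eq, actual, expected);
  mlir::cf::AssertOp::create(builder, loc, ok, message);
}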
@@ -150,11 +148,11 @@ struct CastOpInterface Value srcStride = metadataOp.getResult(2 + resultType.getRank() + it.index()); Value resultStrideVal = - builder.create<arith::ConstantIndexOp>(loc, it.value()); - Value isSameStride = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); - builder.create<cf::AssertOp>( - loc, isSameStride, + arith::ConstantIndexOp::create(builder, loc, it.value()); + Value isSameStride = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, srcStride, resultStrideVal); + cf::AssertOp::create( + builder, loc, isSameStride, RuntimeVerifiableOpInterface::generateErrorMessage( op, "stride mismatch of dim " + std::to_string(it.index()))); } @@ -186,21 +184,19 @@ struct CopyOpInterface auto getDimSize = [&](Value memRef, MemRefType type, int64_t dim) -> Value { return type.isDynamicDim(dim) - ? builder.create<DimOp>(loc, memRef, dim).getResult() - : builder - .create<arith::ConstantIndexOp>(loc, - type.getDimSize(dim)) + ? DimOp::create(builder, loc, memRef, dim).getResult() + : arith::ConstantIndexOp::create(builder, loc, + type.getDimSize(dim)) .getResult(); }; Value sourceDim = getDimSize(copyOp.getSource(), rankedSourceType, i); Value targetDim = getDimSize(copyOp.getTarget(), rankedTargetType, i); - Value sameDimSize = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, sourceDim, targetDim); - builder.create<cf::AssertOp>( - loc, sameDimSize, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "size of " + std::to_string(i) + - "-th source/target dim does not match")); + Value sameDimSize = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, sourceDim, targetDim); + cf::AssertOp::create(builder, loc, sameDimSize, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "size of " + std::to_string(i) + + "-th source/target dim does not match")); } } }; @@ -211,10 +207,11 @@ struct DimOpInterface void generateRuntimeVerification(Operation *op, OpBuilder &builder, Location loc) const { auto dimOp = cast<DimOp>(op); - Value rank = builder.create<RankOp>(loc, dimOp.getSource()); - Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); - builder.create<cf::AssertOp>( - loc, generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), + Value rank = RankOp::create(builder, loc, dimOp.getSource()); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + cf::AssertOp::create( + builder, loc, + generateInBoundsCheck(builder, loc, dimOp.getIndex(), zero, rank), RuntimeVerifiableOpInterface::generateErrorMessage( op, "index is out of bounds")); } @@ -237,7 +234,7 @@ struct LoadStoreOpInterface } auto indices = loadStoreOp.getIndices(); - auto zero = builder.create<arith::ConstantIndexOp>(loc, 0); + auto zero = arith::ConstantIndexOp::create(builder, loc, 0); Value assertCond; for (auto i : llvm::seq<int64_t>(0, rank)) { Value dimOp = builder.createOrFold<memref::DimOp>(loc, memref, i); @@ -247,10 +244,9 @@ struct LoadStoreOpInterface i > 0 ? 
builder.createOrFold<arith::AndIOp>(loc, assertCond, inBounds) : inBounds; } - builder.create<cf::AssertOp>( - loc, assertCond, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "out-of-bounds access")); + cf::AssertOp::create(builder, loc, assertCond, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "out-of-bounds access")); } }; @@ -265,10 +261,10 @@ struct SubViewOpInterface // For each dimension, assert that: // 0 <= offset < dim_size // 0 <= offset + (size - 1) * stride < dim_size - Value zero = builder.create<arith::ConstantIndexOp>(loc, 0); - Value one = builder.create<arith::ConstantIndexOp>(loc, 1); + Value zero = arith::ConstantIndexOp::create(builder, loc, 0); + Value one = arith::ConstantIndexOp::create(builder, loc, 1); auto metadataOp = - builder.create<ExtractStridedMetadataOp>(loc, subView.getSource()); + ExtractStridedMetadataOp::create(builder, loc, subView.getSource()); for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) { Value offset = getValueOrCreateConstantIndexOp( builder, loc, subView.getMixedOffsets()[i]); @@ -281,21 +277,21 @@ struct SubViewOpInterface Value dimSize = metadataOp.getSizes()[i]; Value offsetInBounds = generateInBoundsCheck(builder, loc, offset, zero, dimSize); - builder.create<cf::AssertOp>( - loc, offsetInBounds, + cf::AssertOp::create( + builder, loc, offsetInBounds, RuntimeVerifiableOpInterface::generateErrorMessage( op, "offset " + std::to_string(i) + " is out-of-bounds")); // Verify that slice does not run out-of-bounds. - Value sizeMinusOne = builder.create<arith::SubIOp>(loc, size, one); + Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one); Value sizeMinusOneTimesStride = - builder.create<arith::MulIOp>(loc, sizeMinusOne, stride); + arith::MulIOp::create(builder, loc, sizeMinusOne, stride); Value lastPos = - builder.create<arith::AddIOp>(loc, offset, sizeMinusOneTimesStride); + arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride); Value lastPosInBounds = generateInBoundsCheck(builder, loc, lastPos, zero, dimSize); - builder.create<cf::AssertOp>( - loc, lastPosInBounds, + cf::AssertOp::create( + builder, loc, lastPosInBounds, RuntimeVerifiableOpInterface::generateErrorMessage( op, "subview runs out-of-bounds along dimension " + std::to_string(i))); @@ -315,7 +311,7 @@ struct ExpandShapeOpInterface for (const auto &it : llvm::enumerate(expandShapeOp.getReassociationIndices())) { Value srcDimSz = - builder.create<DimOp>(loc, expandShapeOp.getSrc(), it.index()); + DimOp::create(builder, loc, expandShapeOp.getSrc(), it.index()); int64_t groupSz = 1; bool foundDynamicDim = false; for (int64_t resultDim : it.value()) { @@ -330,18 +326,17 @@ struct ExpandShapeOpInterface groupSz *= expandShapeOp.getResultType().getDimSize(resultDim); } Value staticResultDimSz = - builder.create<arith::ConstantIndexOp>(loc, groupSz); + arith::ConstantIndexOp::create(builder, loc, groupSz); // staticResultDimSz must divide srcDimSz evenly. 
Value mod = - builder.create<arith::RemSIOp>(loc, srcDimSz, staticResultDimSz); - Value isModZero = builder.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::eq, mod, - builder.create<arith::ConstantIndexOp>(loc, 0)); - builder.create<cf::AssertOp>( - loc, isModZero, - RuntimeVerifiableOpInterface::generateErrorMessage( - op, "static result dims in reassoc group do not " - "divide src dim evenly")); + arith::RemSIOp::create(builder, loc, srcDimSz, staticResultDimSz); + Value isModZero = arith::CmpIOp::create( + builder, loc, arith::CmpIPredicate::eq, mod, + arith::ConstantIndexOp::create(builder, loc, 0)); + cf::AssertOp::create(builder, loc, isModZero, + RuntimeVerifiableOpInterface::generateErrorMessage( + op, "static result dims in reassoc group do not " + "divide src dim evenly")); } } }; diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index a50b4cf..5af46a4 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Interfaces/ViewLikeInterface.h" #include "llvm/ADT/STLExtras.h" @@ -217,5 +218,70 @@ MemrefValue skipViewLikeOps(MemrefValue source) { return source; } +LogicalResult resolveSourceIndicesExpandShape( + Location loc, PatternRewriter &rewriter, + memref::ExpandShapeOp expandShapeOp, ValueRange indices, + SmallVectorImpl<Value> &sourceIndices, bool startsInbounds) { + SmallVector<OpFoldResult> destShape = expandShapeOp.getMixedOutputShape(); + + // Traverse all reassociation groups to determine the appropriate indices + // corresponding to each one of them post op folding. + for (ArrayRef<int64_t> group : expandShapeOp.getReassociationIndices()) { + assert(!group.empty() && "association indices groups cannot be empty"); + int64_t groupSize = group.size(); + if (groupSize == 1) { + sourceIndices.push_back(indices[group[0]]); + continue; + } + SmallVector<OpFoldResult> groupBasis = + llvm::map_to_vector(group, [&](int64_t d) { return destShape[d]; }); + SmallVector<Value> groupIndices = + llvm::map_to_vector(group, [&](int64_t d) { return indices[d]; }); + Value collapsedIndex = affine::AffineLinearizeIndexOp::create( + rewriter, loc, groupIndices, groupBasis, /*disjoint=*/startsInbounds); + sourceIndices.push_back(collapsedIndex); + } + return success(); +} + +LogicalResult +resolveSourceIndicesCollapseShape(Location loc, PatternRewriter &rewriter, + memref::CollapseShapeOp collapseShapeOp, + ValueRange indices, + SmallVectorImpl<Value> &sourceIndices) { + // Note: collapse_shape requires a strided memref, we can do this. 
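The delinearization this helper emits via affine.delinearize_index recovers per-dimension indices from one collapsed index over the group's source sizes; for sizes (2, 6), index 7 maps back to (1, 1) since 7 = 1 * 6 + 1. The same computation with plain integers, as a sketch with an illustrative name:

#include <cstdint>
#include <vector>

// Integer model of affine.delinearize_index: the innermost dimension of
// the basis varies fastest.
static std::vector<int64_t> delinearize(int64_t index,
                                        const std::vector<int64_t> &basis) {
  std::vector<int64_t> out(basis.size());
  for (size_t i = basis.size(); i-- > 0;) {
    out[i] = index % basis[i];
    index /= basis[i];
  }
  return out; // e.g. delinearize(7, {2, 6}) == {1, 1}
}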
+ auto metadata = memref::ExtractStridedMetadataOp::create( + rewriter, loc, collapseShapeOp.getSrc()); + SmallVector<OpFoldResult> sourceSizes = metadata.getConstifiedMixedSizes(); + for (auto [index, group] : + llvm::zip(indices, collapseShapeOp.getReassociationIndices())) { + assert(!group.empty() && "association indices groups cannot be empty"); + int64_t groupSize = group.size(); + + if (groupSize == 1) { + sourceIndices.push_back(index); + continue; + } + + SmallVector<OpFoldResult> basis = + llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; }); + auto delinearize = affine::AffineDelinearizeIndexOp::create( + rewriter, loc, index, basis, /*hasOuterBound=*/true); + llvm::append_range(sourceIndices, delinearize.getResults()); + } + if (collapseShapeOp.getReassociationIndices().empty()) { + auto zeroAffineMap = rewriter.getConstantAffineMap(0); + int64_t srcRank = + cast<MemRefType>(collapseShapeOp.getViewSource().getType()).getRank(); + OpFoldResult ofr = affine::makeComposedFoldedAffineApply( + rewriter, loc, zeroAffineMap, ArrayRef<OpFoldResult>{}); + for (int64_t i = 0; i < srcRank; i++) { + sourceIndices.push_back( + getValueOrCreateConstantIndexOp(rewriter, loc, ofr)); + } + } + return success(); +} + } // namespace memref } // namespace mlir diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index f5f0bfa..bc3e8b2 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -38,9 +38,6 @@ using namespace mlir::NVVM; using namespace mlir::transform; #define DEBUG_TYPE "nvgpu-transforms" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") -#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") //===----------------------------------------------------------------------===// // Apply...ConversionPatternsOp diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index e73bdd3..9d5dfc1 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -2957,6 +2957,23 @@ bool acc::LoopOp::hasDefaultGangWorkerVector() { getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static); } +acc::LoopParMode +acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) { + if (hasSeq(deviceType)) + return LoopParMode::loop_seq; + if (hasAuto(deviceType)) + return LoopParMode::loop_auto; + if (hasIndependent(deviceType)) + return LoopParMode::loop_independent; + if (hasSeq()) + return LoopParMode::loop_seq; + if (hasAuto()) + return LoopParMode::loop_auto; + assert(hasIndependent() && + "loop must have default auto, seq, or independent"); + return LoopParMode::loop_independent; +} + void acc::LoopOp::addGangOperands( MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes, llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) { diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 5d6c5499..c1c1767 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -1730,8 +1730,7 @@ static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) { if (!mapOp.getDefiningOp()) return emitError(op->getLoc(), "missing map operation"); - if (auto mapInfoOp = - mlir::dyn_cast<mlir::omp::MapInfoOp>(mapOp.getDefiningOp())) { + if (auto mapInfoOp = 
mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) { uint64_t mapTypeBits = mapInfoOp.getMapType(); bool to = mapTypeToBitFlag( diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index b44dbfd..c5ec0ca 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -53,7 +53,7 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) { Value ptrLike; FromPtrOp fromPtr = *this; while (fromPtr != nullptr) { - auto toPtr = dyn_cast_or_null<ToPtrOp>(fromPtr.getPtr().getDefiningOp()); + auto toPtr = fromPtr.getPtr().getDefiningOp<ToPtrOp>(); // Cannot fold if it's not a `to_ptr` op or the initial and final types are // different. if (!toPtr || toPtr.getPtr().getType() != fromPtr.getType()) @@ -64,13 +64,12 @@ OpFoldResult FromPtrOp::fold(FoldAdaptor adaptor) { ptrLike = toPtr.getPtr(); } else if (md) { // Fold if the metadata can be verified to be equal. - if (auto mdOp = dyn_cast_or_null<GetMetadataOp>(md.getDefiningOp()); + if (auto mdOp = md.getDefiningOp<GetMetadataOp>(); mdOp && mdOp.getPtr() == toPtr.getPtr()) ptrLike = toPtr.getPtr(); } // Check for a sequence of casts. - fromPtr = dyn_cast_or_null<FromPtrOp>(ptrLike ? ptrLike.getDefiningOp() - : nullptr); + fromPtr = ptrLike ? ptrLike.getDefiningOp<FromPtrOp>() : nullptr; } return ptrLike; } @@ -112,13 +111,13 @@ OpFoldResult ToPtrOp::fold(FoldAdaptor adaptor) { Value ptr; ToPtrOp toPtr = *this; while (toPtr != nullptr) { - auto fromPtr = dyn_cast_or_null<FromPtrOp>(toPtr.getPtr().getDefiningOp()); + auto fromPtr = toPtr.getPtr().getDefiningOp<FromPtrOp>(); // Cannot fold if it's not a `from_ptr` op. if (!fromPtr) return ptr; ptr = fromPtr.getPtr(); // Check for chains of casts. - toPtr = dyn_cast_or_null<ToPtrOp>(ptr.getDefiningOp()); + toPtr = ptr.getDefiningOp<ToPtrOp>(); } return ptr; } diff --git a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp index 58cd160..9e37bc5 100644 --- a/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp +++ b/mlir/lib/Dialect/Quant/Transforms/LowerQuantOps.cpp @@ -148,16 +148,14 @@ flattenUnrankedTensorAroundAxis(OpBuilder &builder, Location loc, Value input, auto axisValue = arith::ConstantIndexOp::create(builder, loc, axis); auto axisNextValue = arith::ConstantIndexOp::create(builder, loc, axis + 1); auto shapeLeft = - builder - .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType}, - inputShape, axisValue) + shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType}, + inputShape, axisValue) .getResult(0); auto sizeLeft = shape::NumElementsOp::create(builder, loc, indexType, shapeLeft); auto shapeRight = - builder - .create<shape::SplitAtOp>(loc, TypeRange{shapeType, shapeType}, - inputShape, axisNextValue) + shape::SplitAtOp::create(builder, loc, TypeRange{shapeType, shapeType}, + inputShape, axisNextValue) .getResult(1); auto sizeRight = shape::NumElementsOp::create(builder, loc, indexType, shapeRight); @@ -557,25 +555,24 @@ Value convertPerChannelRanked(OpBuilder &builder, Location loc, Operation *op, SmallVector<AffineMap> indexingMaps{ builder.getMultiDimIdentityMap(inputRank), channelAxisAffineMap, channelAxisAffineMap, builder.getMultiDimIdentityMap(inputRank)}; - auto result = builder - .create<linalg::GenericOp>( - loc, - init.getType(), // resultType - ValueRange{input, scales, zeroPoints}, // inputs - ValueRange{init}, // outputs - indexingMaps, iteratorTypes, - [&](OpBuilder &builder, Location loc, ValueRange args) { - assert(args.size() == 4); 
- auto input = args[0]; - auto scale = args[1]; - auto zeroPoint = args[2]; - - auto result = - convertRanked(builder, loc, op, input, {}, scale, - zeroPoint, quantizedType); - - linalg::YieldOp::create(builder, loc, result); - }) + auto result = linalg::GenericOp::create( + builder, loc, + init.getType(), // resultType + ValueRange{input, scales, zeroPoints}, // inputs + ValueRange{init}, // outputs + indexingMaps, iteratorTypes, + [&](OpBuilder &builder, Location loc, ValueRange args) { + assert(args.size() == 4); + auto input = args[0]; + auto scale = args[1]; + auto zeroPoint = args[2]; + + auto result = + convertRanked(builder, loc, op, input, {}, scale, + zeroPoint, quantizedType); + + linalg::YieldOp::create(builder, loc, result); + }) .getResult(0); return result; @@ -660,25 +657,24 @@ Value convertSubChannel(OpBuilder &builder, Location loc, Operation *op, SmallVector<AffineMap> indexingMaps{ builder.getMultiDimIdentityMap(inputRank), affineMap, affineMap, builder.getMultiDimIdentityMap(inputRank)}; - auto result = builder - .create<linalg::GenericOp>( - loc, - init.getType(), // resultType - ValueRange{input, scales, zeroPoints}, // inputs - ValueRange{init}, // outputs - indexingMaps, iteratorTypes, - [&](OpBuilder &builder, Location loc, ValueRange args) { - assert(args.size() == 4); - auto input = args[0]; - auto scale = args[1]; - auto zeroPoint = args[2]; - - auto result = - convertRanked(builder, loc, op, input, {}, scale, - zeroPoint, quantizedType); - - linalg::YieldOp::create(builder, loc, result); - }) + auto result = linalg::GenericOp::create( + builder, loc, + init.getType(), // resultType + ValueRange{input, scales, zeroPoints}, // inputs + ValueRange{init}, // outputs + indexingMaps, iteratorTypes, + [&](OpBuilder &builder, Location loc, ValueRange args) { + assert(args.size() == 4); + auto input = args[0]; + auto scale = args[1]; + auto zeroPoint = args[2]; + + auto result = + convertRanked(builder, loc, op, input, {}, scale, + zeroPoint, quantizedType); + + linalg::YieldOp::create(builder, loc, result); + }) .getResult(0); return result; diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index e282ca4..0262a1b 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -84,7 +84,7 @@ void SCFDialect::initialize() { /// Default callback for IfOp builders. Inserts a yield without arguments. void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) { - builder.create<scf::YieldOp>(loc); + scf::YieldOp::create(builder, loc); } /// Verifies that the first block of the given `region` is terminated by a @@ -137,6 +137,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, if (parser.parseOptionalArrowTypeList(result.types)) return failure(); + if (succeeded(parser.parseOptionalKeyword("no_inline"))) + result.addAttribute("no_inline", parser.getBuilder().getUnitAttr()); + // Introduce the body region and parse it. 
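// With the parser/printer changes above, `scf.execute_region` gains an
// optional `no_inline` keyword that round-trips as a UnitAttr and, per
// the canonicalization guard below, blocks single-block inlining.
// Illustrative IR sketch (syntax per the printer above):
//
//   %r = scf.execute_region -> i32 no_inline {
//     %c = arith.constant 42 : i32
//     scf.yield %c : i32
//   }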
Region *body = result.addRegion(); if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) || @@ -148,8 +151,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser, void ExecuteRegionOp::print(OpAsmPrinter &p) { p.printOptionalArrowTypeList(getResultTypes()); - p << ' '; + if (getNoInline()) + p << "no_inline "; p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); @@ -184,7 +188,7 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { LogicalResult matchAndRewrite(ExecuteRegionOp op, PatternRewriter &rewriter) const override { - if (!op.getRegion().hasOneBlock()) + if (!op.getRegion().hasOneBlock() || op.getNoInline()) return failure(); replaceOpWithRegion(rewriter, op, op.getRegion()); return success(); @@ -240,13 +244,13 @@ struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> { Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator()); rewriter.setInsertionPointToEnd(prevBlock); - rewriter.create<cf::BranchOp>(op.getLoc(), &op.getRegion().front()); + cf::BranchOp::create(rewriter, op.getLoc(), &op.getRegion().front()); for (Block &blk : op.getRegion()) { if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) { rewriter.setInsertionPoint(yieldOp); - rewriter.create<cf::BranchOp>(yieldOp.getLoc(), postBlock, - yieldOp.getResults()); + cf::BranchOp::create(rewriter, yieldOp.getLoc(), postBlock, + yieldOp.getResults()); rewriter.eraseOp(yieldOp); } } @@ -556,8 +560,8 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter, rewriter.setInsertionPoint(getOperation()); auto inits = llvm::to_vector(getInitArgs()); inits.append(newInitOperands.begin(), newInitOperands.end()); - scf::ForOp newLoop = rewriter.create<scf::ForOp>( - getLoc(), getLowerBound(), getUpperBound(), getStep(), inits, + scf::ForOp newLoop = scf::ForOp::create( + rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); newLoop->setAttrs(getPrunedAttributeList(getOperation(), {})); @@ -672,8 +676,8 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) { Value dst = parallelInsertSliceOp.getDest(); Value src = parallelInsertSliceOp.getSource(); if (llvm::isa<TensorType>(src.getType())) { - results.push_back(rewriter.create<tensor::InsertSliceOp>( - forallOp.getLoc(), dst.getType(), src, dst, + results.push_back(tensor::InsertSliceOp::create( + rewriter, forallOp.getLoc(), dst.getType(), src, dst, parallelInsertSliceOp.getOffsets(), parallelInsertSliceOp.getSizes(), parallelInsertSliceOp.getStrides(), parallelInsertSliceOp.getStaticOffsets(), @@ -721,8 +725,8 @@ LoopNest mlir::scf::buildLoopNest( ValueRange currentIterArgs = iterArgs; Location currentLoc = loc; for (unsigned i = 0, e = lbs.size(); i < e; ++i) { - auto loop = builder.create<scf::ForOp>( - currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs, + auto loop = scf::ForOp::create( + builder, currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs, [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange args) { ivs.push_back(iv); @@ -741,7 +745,7 @@ LoopNest mlir::scf::buildLoopNest( // For all loops but the innermost, yield the results of the nested loop. 
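// The dominant mechanical rewrite in this commit: every
// `builder.create<OpTy>(loc, ...)` becomes the equivalent static
// `OpTy::create(builder, loc, ...)`. A minimal sketch, assuming an
// OpBuilder positioned inside a block:
//
//   // Before:
//   builder.create<scf::YieldOp>(loc, results);
//   // After:
//   scf::YieldOp::create(builder, loc, results);
//
// Both spellings take the same arguments after the builder and location
// and construct the same operation.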
for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) { builder.setInsertionPointToEnd(loops[i].getBody()); - builder.create<scf::YieldOp>(loc, loops[i + 1].getResults()); + scf::YieldOp::create(builder, loc, loops[i + 1].getResults()); } // In the body of the innermost loop, call the body building function if any @@ -755,7 +759,7 @@ LoopNest mlir::scf::buildLoopNest( "loop nest body must return as many values as loop has iteration " "arguments"); builder.setInsertionPointToEnd(loops.back().getBody()); - builder.create<scf::YieldOp>(loc, results); + scf::YieldOp::create(builder, loc, results); // Return the loops. ValueVector nestResults; @@ -800,8 +804,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, } // 2. Create the new forOp shell. - scf::ForOp newForOp = rewriter.create<scf::ForOp>( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + scf::ForOp newForOp = scf::ForOp::create( + rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), newIterOperands); newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.getRegion().front(); @@ -830,7 +834,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, clonedYieldOp.getOperand(yieldIdx)); SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands(); newYieldOperands[yieldIdx] = castOut; - rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands); + scf::YieldOp::create(rewriter, newForOp.getLoc(), newYieldOperands); rewriter.eraseOp(clonedYieldOp); // 6. Inject an outgoing cast op after the forOp. @@ -925,9 +929,9 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { if (!canonicalize) return failure(); - scf::ForOp newForOp = rewriter.create<scf::ForOp>( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newIterArgs); + scf::ForOp newForOp = + scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), + forOp.getUpperBound(), forOp.getStep(), newIterArgs); newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.getRegion().front(); @@ -969,8 +973,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx) if (keepMask[idx]) filteredOperands.push_back(mergedTerminator.getOperand(idx)); - rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(), - filteredOperands); + scf::YieldOp::create(rewriter, mergedTerminator.getLoc(), + filteredOperands); }; rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs); @@ -1110,7 +1114,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> { op, replaceAndCastForOpIterArg( rewriter, op, iterOpOperand, incomingCast.getSource(), [](OpBuilder &b, Location loc, Type type, Value source) { - return b.create<tensor::CastOp>(loc, type, source); + return tensor::CastOp::create(b, loc, type, source); })); return success(); } @@ -1684,8 +1688,8 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> { // Step 3. 
Create a new scf.forall op with the new shared_outs' operands // fetched earlier - auto newForallOp = rewriter.create<scf::ForallOp>( - forallOp.getLoc(), forallOp.getMixedLowerBound(), + auto newForallOp = scf::ForallOp::create( + rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), forallOp.getMixedStep(), newOuts, forallOp.getMapping(), /*bodyBuilderFn =*/[](OpBuilder &, Location, ValueRange) {}); @@ -1781,9 +1785,9 @@ struct ForallOpSingleOrZeroIterationDimsFolder // Replace the loop by a lower-dimensional loop. ForallOp newOp; - newOp = rewriter.create<ForallOp>(loc, newMixedLowerBounds, - newMixedUpperBounds, newMixedSteps, - op.getOutputs(), std::nullopt, nullptr); + newOp = ForallOp::create(rewriter, loc, newMixedLowerBounds, + newMixedUpperBounds, newMixedSteps, + op.getOutputs(), std::nullopt, nullptr); newOp.getBodyRegion().getBlocks().clear(); // The new loop needs to keep all attributes from the old one, except for // "operandSegmentSizes" and static loop bound attributes which capture @@ -1866,16 +1870,17 @@ struct FoldTensorCastOfOutputIntoForallOp // Create new loop. Location loc = forallOp.getLoc(); - auto newForallOp = rewriter.create<ForallOp>( - loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), - forallOp.getMixedStep(), newOutputTensors, forallOp.getMapping(), + auto newForallOp = ForallOp::create( + rewriter, loc, forallOp.getMixedLowerBound(), + forallOp.getMixedUpperBound(), forallOp.getMixedStep(), + newOutputTensors, forallOp.getMapping(), [&](OpBuilder nestedBuilder, Location nestedLoc, ValueRange bbArgs) { auto castBlockArgs = llvm::to_vector(bbArgs.take_back(forallOp->getNumResults())); for (auto [index, cast] : tensorCastProducers) { Value &oldTypeBBArg = castBlockArgs[index]; - oldTypeBBArg = nestedBuilder.create<tensor::CastOp>( - nestedLoc, cast.dstType, oldTypeBBArg); + oldTypeBBArg = tensor::CastOp::create(nestedBuilder, nestedLoc, + cast.dstType, oldTypeBBArg); } // Move old body into new parallel loop. @@ -1901,8 +1906,8 @@ struct FoldTensorCastOfOutputIntoForallOp SmallVector<Value> castResults = newForallOp.getResults(); for (auto &item : tensorCastProducers) { Value &oldTypeResult = castResults[item.first]; - oldTypeResult = rewriter.create<tensor::CastOp>(loc, item.second.dstType, - oldTypeResult); + oldTypeResult = tensor::CastOp::create(rewriter, loc, item.second.dstType, + oldTypeResult); } rewriter.replaceOp(forallOp, castResults); return success(); @@ -2310,7 +2315,7 @@ struct RemoveUnusedResults : public OpRewritePattern<IfOp> { // Create a replacement operation with empty then and else regions. 
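// Condensed sketch of the rebuild step used by the scf.if
// canonicalization below: create the pruned op, then give it fresh empty
// blocks to receive the surviving yields (names as in the pattern):
//
//   auto newOp = IfOp::create(rewriter, op.getLoc(), newTypes,
//                             op.getCondition());
//   rewriter.createBlock(&newOp.getThenRegion());
//   rewriter.createBlock(&newOp.getElseRegion());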
auto newOp = - rewriter.create<IfOp>(op.getLoc(), newTypes, op.getCondition()); + IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition()); rewriter.createBlock(&newOp.getThenRegion()); rewriter.createBlock(&newOp.getElseRegion()); @@ -2373,8 +2378,8 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> { if (nonHoistable.size() == op->getNumResults()) return failure(); - IfOp replacement = rewriter.create<IfOp>(op.getLoc(), nonHoistable, cond, - /*withElseRegion=*/false); + IfOp replacement = IfOp::create(rewriter, op.getLoc(), nonHoistable, cond, + /*withElseRegion=*/false); if (replacement.thenBlock()) rewriter.eraseBlock(replacement.thenBlock()); replacement.getThenRegion().takeBody(op.getThenRegion()); @@ -2399,8 +2404,8 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> { } else if (trueVal == falseVal) results[it.index()] = trueVal; else - results[it.index()] = rewriter.create<arith::SelectOp>( - op.getLoc(), cond, trueVal, falseVal); + results[it.index()] = arith::SelectOp::create(rewriter, op.getLoc(), + cond, trueVal, falseVal); } rewriter.setInsertionPointToEnd(replacement.thenBlock()); @@ -2489,8 +2494,8 @@ struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> { if (!trueVal && falseVal) { if (!opResult.use_empty()) { Dialect *constDialect = trueResult.getDefiningOp()->getDialect(); - Value notCond = rewriter.create<arith::XOrIOp>( - op.getLoc(), op.getCondition(), + Value notCond = arith::XOrIOp::create( + rewriter, op.getLoc(), op.getCondition(), constDialect ->materializeConstant(rewriter, rewriter.getIntegerAttr(i1Ty, 1), i1Ty, @@ -2603,8 +2608,8 @@ struct CombineIfs : public OpRewritePattern<IfOp> { SmallVector<Type> mergedTypes(prevIf.getResultTypes()); llvm::append_range(mergedTypes, nextIf.getResultTypes()); - IfOp combinedIf = rewriter.create<IfOp>( - nextIf.getLoc(), mergedTypes, prevIf.getCondition(), /*hasElse=*/false); + IfOp combinedIf = IfOp::create(rewriter, nextIf.getLoc(), mergedTypes, + prevIf.getCondition(), /*hasElse=*/false); rewriter.eraseBlock(&combinedIf.getThenRegion().back()); rewriter.inlineRegionBefore(prevIf.getThenRegion(), @@ -2619,7 +2624,7 @@ struct CombineIfs : public OpRewritePattern<IfOp> { SmallVector<Value> mergedYields(thenYield.getOperands()); llvm::append_range(mergedYields, thenYield2.getOperands()); - rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields); + YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields); rewriter.eraseOp(thenYield); rewriter.eraseOp(thenYield2); } @@ -2643,7 +2648,7 @@ struct CombineIfs : public OpRewritePattern<IfOp> { SmallVector<Value> mergedElseYields(elseYield.getOperands()); llvm::append_range(mergedElseYields, elseYield2.getOperands()); - rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields); + YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields); rewriter.eraseOp(elseYield); rewriter.eraseOp(elseYield2); } @@ -2765,9 +2770,9 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> { } Location loc = op.getLoc(); - Value newCondition = rewriter.create<arith::AndIOp>( - loc, op.getCondition(), nestedIf.getCondition()); - auto newIf = rewriter.create<IfOp>(loc, op.getResultTypes(), newCondition); + Value newCondition = arith::AndIOp::create(rewriter, loc, op.getCondition(), + nestedIf.getCondition()); + auto newIf = IfOp::create(rewriter, loc, op.getResultTypes(), newCondition); Block *newIfBlock = rewriter.createBlock(&newIf.getThenRegion()); SmallVector<Value> results; @@ -2775,8 +2780,9 @@ struct 
CombineNestedIfs : public OpRewritePattern<IfOp> { rewriter.setInsertionPoint(newIf); for (auto idx : elseYieldsToUpgradeToSelect) - results[idx] = rewriter.create<arith::SelectOp>( - op.getLoc(), op.getCondition(), thenYield[idx], elseYield[idx]); + results[idx] = + arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(), + thenYield[idx], elseYield[idx]); rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock); rewriter.setInsertionPointToEnd(newIf.thenBlock()); @@ -2784,7 +2790,7 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> { if (!elseYield.empty()) { rewriter.createBlock(&newIf.getElseRegion()); rewriter.setInsertionPointToEnd(newIf.elseBlock()); - rewriter.create<YieldOp>(loc, elseYield); + YieldOp::create(rewriter, loc, elseYield); } rewriter.replaceOp(op, results); return success(); @@ -3101,8 +3107,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder } // Replace the parallel loop by lower-dimensional parallel loop. auto newOp = - rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds, - newSteps, op.getInitVals(), nullptr); + ParallelOp::create(rewriter, op.getLoc(), newLowerBounds, + newUpperBounds, newSteps, op.getInitVals(), nullptr); // Erase the empty block that was inserted by the builder. rewriter.eraseBlock(newOp.getBody()); // Clone the loop body and remap the block arguments of the collapsed loops @@ -3482,8 +3488,8 @@ struct WhileConditionTruth : public OpRewritePattern<WhileOp> { if (std::get<0>(yieldedAndBlockArgs) == term.getCondition()) { if (!std::get<1>(yieldedAndBlockArgs).use_empty()) { if (!constantTrue) - constantTrue = rewriter.create<arith::ConstantOp>( - op.getLoc(), term.getCondition().getType(), + constantTrue = arith::ConstantOp::create( + rewriter, op.getLoc(), term.getCondition().getType(), rewriter.getBoolAttr(true)); rewriter.replaceAllUsesWith(std::get<1>(yieldedAndBlockArgs), @@ -3625,8 +3631,8 @@ struct RemoveLoopInvariantArgsFromBeforeBlock rewriter.replaceOpWithNewOp<YieldOp>(yieldOp, newYieldOpArgs); } - auto newWhile = - rewriter.create<WhileOp>(op.getLoc(), op.getResultTypes(), newInitArgs); + auto newWhile = WhileOp::create(rewriter, op.getLoc(), op.getResultTypes(), + newInitArgs); Block &newBeforeBlock = *rewriter.createBlock( &newWhile.getBefore(), /*insertPt*/ {}, @@ -3748,8 +3754,8 @@ struct RemoveLoopInvariantValueYielded : public OpRewritePattern<WhileOp> { newCondOpArgs); } - auto newWhile = rewriter.create<WhileOp>(op.getLoc(), newAfterBlockType, - op.getOperands()); + auto newWhile = WhileOp::create(rewriter, op.getLoc(), newAfterBlockType, + op.getOperands()); Block &newAfterBlock = *rewriter.createBlock(&newWhile.getAfter(), /*insertPt*/ {}, @@ -3855,7 +3861,7 @@ struct WhileUnusedResult : public OpRewritePattern<WhileOp> { } auto newWhile = - rewriter.create<WhileOp>(op.getLoc(), newResultTypes, op.getInits()); + WhileOp::create(rewriter, op.getLoc(), newResultTypes, op.getInits()); Block &newAfterBlock = *rewriter.createBlock( &newWhile.getAfter(), /*insertPt*/ {}, newResultTypes, newArgLocs); @@ -3984,8 +3990,8 @@ struct WhileRemoveUnusedArgs : public OpRewritePattern<WhileOp> { Location loc = op.getLoc(); auto newWhileOp = - rewriter.create<WhileOp>(loc, op.getResultTypes(), newInits, - /*beforeBody*/ nullptr, /*afterBody*/ nullptr); + WhileOp::create(rewriter, loc, op.getResultTypes(), newInits, + /*beforeBody*/ nullptr, /*afterBody*/ nullptr); Block &newBeforeBlock = *newWhileOp.getBeforeBody(); Block &newAfterBlock = *newWhileOp.getAfterBody(); @@ -4032,9 +4038,10 @@ struct 
WhileRemoveDuplicatedResults : public OpRewritePattern<WhileOp> { ValueRange argsRange(newArgs); Location loc = op.getLoc(); - auto newWhileOp = rewriter.create<scf::WhileOp>( - loc, argsRange.getTypes(), op.getInits(), /*beforeBody*/ nullptr, - /*afterBody*/ nullptr); + auto newWhileOp = + scf::WhileOp::create(rewriter, loc, argsRange.getTypes(), op.getInits(), + /*beforeBody*/ nullptr, + /*afterBody*/ nullptr); Block &newBeforeBlock = *newWhileOp.getBeforeBody(); Block &newAfterBlock = *newWhileOp.getAfterBody(); @@ -4128,8 +4135,8 @@ struct WhileOpAlignBeforeArgs : public OpRewritePattern<WhileOp> { for (auto &&[i, j] : llvm::enumerate(*mapping)) newResultTypes[j] = loop.getResult(i).getType(); - auto newLoop = rewriter.create<WhileOp>( - loop.getLoc(), newResultTypes, loop.getInits(), + auto newLoop = WhileOp::create( + rewriter, loop.getLoc(), newResultTypes, loop.getInits(), /*beforeBuilder=*/nullptr, /*afterBuilder=*/nullptr); auto newBefore = newLoop.getBeforeBody(); auto newAfter = newLoop.getAfterBody(); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 9a68565..aea842d 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -160,7 +160,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); scf::ExecuteRegionOp executeRegionOp = - b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes()); + scf::ExecuteRegionOp::create(b, op->getLoc(), op->getResultTypes()); { OpBuilder::InsertionGuard g(b); b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock()); @@ -169,7 +169,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, assert(clonedRegion.empty() && "expected empty region"); b.inlineRegionBefore(op->getRegions().front(), clonedRegion, clonedRegion.end()); - b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults()); + scf::YieldOp::create(b, op->getLoc(), clonedOp->getResults()); } b.replaceOp(op, executeRegionOp.getResults()); return executeRegionOp; diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 8509382..f8799c5 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -41,7 +41,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) { // iter_arg's layout map must be changed (see uses of `castBuffer`). assert(memref::CastOp::areCastCompatible(buffer.getType(), type) && "scf.while op bufferization: cast incompatible"); - return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult(); + return memref::CastOp::create(b, buffer.getLoc(), type, buffer).getResult(); } /// Helper function for loop bufferization. Return "true" if the given value @@ -189,7 +189,7 @@ struct ExecuteRegionOpInterface // Create new op and move over region. auto newOp = - rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes); + scf::ExecuteRegionOp::create(rewriter, op->getLoc(), newResultTypes); newOp.getRegion().takeBody(executeRegionOp.getRegion()); // Bufferize every block. 
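// The `castBuffer` helper above in condensed form; the assertion guards
// that the cast only adjusts layout, never element type or rank:
//
//   static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
//     assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
//            "scf.while op bufferization: cast incompatible");
//     return memref::CastOp::create(b, buffer.getLoc(), type, buffer)
//         .getResult();
//   }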
@@ -203,8 +203,8 @@ struct ExecuteRegionOpInterface SmallVector<Value> newResults; for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) { if (isa<TensorType>(it.value())) { - newResults.push_back(rewriter.create<bufferization::ToTensorOp>( - executeRegionOp.getLoc(), it.value(), + newResults.push_back(bufferization::ToTensorOp::create( + rewriter, executeRegionOp.getLoc(), it.value(), newOp->getResult(it.index()))); } else { newResults.push_back(newOp->getResult(it.index())); @@ -258,9 +258,9 @@ struct IfOpInterface // Create new op. rewriter.setInsertionPoint(ifOp); - auto newIfOp = - rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(), - /*withElseRegion=*/true); + auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes, + ifOp.getCondition(), + /*withElseRegion=*/true); // Move over then/else blocks. rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock()); @@ -372,9 +372,9 @@ struct IndexSwitchOpInterface // Create new op. rewriter.setInsertionPoint(switchOp); - auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>( - switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(), - switchOp.getCases().size()); + auto newSwitchOp = scf::IndexSwitchOp::create( + rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(), + switchOp.getCases(), switchOp.getCases().size()); // Move over blocks. for (auto [src, dest] : @@ -497,10 +497,10 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, size_t idx = it.index(); Value val = it.value(); if (tensorIndices.contains(idx)) { - result.push_back(rewriter - .create<bufferization::ToTensorOp>( - val.getLoc(), oldBbArgs[idx].getType(), val) - .getResult()); + result.push_back( + bufferization::ToTensorOp::create(rewriter, val.getLoc(), + oldBbArgs[idx].getType(), val) + .getResult()); } else { result.push_back(val); } @@ -767,8 +767,8 @@ struct ForOpInterface } // Construct a new scf.for op with memref instead of tensor values. - auto newForOp = rewriter.create<scf::ForOp>( - forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + auto newForOp = scf::ForOp::create( + rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), castedInitArgs); newForOp->setAttrs(forOp->getAttrs()); Block *loopBody = newForOp.getBody(); @@ -1003,8 +1003,8 @@ struct WhileOpInterface // Construct a new scf.while op with memref instead of tensor values. ValueRange argsRangeBefore(castedInitArgs); TypeRange argsTypesBefore(argsRangeBefore); - auto newWhileOp = rewriter.create<scf::WhileOp>( - whileOp.getLoc(), argsTypesAfter, castedInitArgs); + auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(), + argsTypesAfter, castedInitArgs); // Add before/after regions to the new op. SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(), @@ -1263,8 +1263,8 @@ struct ForallOpInterface forallOp.getBody()->getArguments().drop_front(rank), buffers)) { BlockArgument bbArg = std::get<0>(it); Value buffer = std::get<1>(it); - Value bufferAsTensor = rewriter.create<ToTensorOp>( - forallOp.getLoc(), bbArg.getType(), buffer); + Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(), + bbArg.getType(), buffer); bbArg.replaceAllUsesWith(bufferAsTensor); } @@ -1272,8 +1272,8 @@ struct ForallOpInterface // introduced terminator. 
rewriter.setInsertionPoint(forallOp); ForallOp newForallOp; - newForallOp = rewriter.create<ForallOp>( - forallOp.getLoc(), forallOp.getMixedLowerBound(), + newForallOp = ForallOp::create( + rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), forallOp.getMixedStep(), /*outputs=*/ValueRange(), forallOp.getMapping()); diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index 3e93dc8..bee7780 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -50,19 +50,19 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> { SmallVector<Value> initArgs; initArgs.push_back(forOp.getLowerBound()); llvm::append_range(initArgs, forOp.getInitArgs()); - auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs, - forOp->getAttrs()); + auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs, + forOp->getAttrs()); // 'before' region contains the loop condition and forwarding of iteration // arguments to the 'after' region. auto *beforeBlock = rewriter.createBlock( &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs); rewriter.setInsertionPointToStart(whileOp.getBeforeBody()); - auto cmpOp = rewriter.create<arith::CmpIOp>( - whileOp.getLoc(), arith::CmpIPredicate::slt, + auto cmpOp = arith::CmpIOp::create( + rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt, beforeBlock->getArgument(0), forOp.getUpperBound()); - rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(), - beforeBlock->getArguments()); + scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(), + beforeBlock->getArguments()); // Inline for-loop body into an executeRegion operation in the "after" // region. The return type of the execRegionOp does not contain the @@ -72,8 +72,9 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> { // Add induction variable incrementation rewriter.setInsertionPointToEnd(afterBlock); - auto ivIncOp = rewriter.create<arith::AddIOp>( - whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep()); + auto ivIncOp = + arith::AddIOp::create(rewriter, whileOp.getLoc(), + afterBlock->getArgument(0), forOp.getStep()); // Rewrite uses of the for-loop block arguments to the new while-loop // "after" arguments diff --git a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp index 44e6840..b95604f 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp @@ -40,7 +40,7 @@ LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter, SmallVector<Value> steps = forallOp.getStep(rewriter); // Create empty scf.parallel op. 
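// The ForToWhile lowering above produces a while loop of this shape
// (sketch; %lb/%ub/%step are the original scf.for operands):
//
//   scf.while (%iv = %lb, %a = %init) : (index, T) -> (index, T) {
//     %cond = arith.cmpi slt, %iv, %ub : index
//     scf.condition(%cond) %iv, %a : index, T
//   } do {
//   ^bb0(%iv: index, %a: T):
//     ... original loop body ...
//     %next = arith.addi %iv, %step : index
//     scf.yield %next, %a2 : index, T
//   }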
- auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps); + auto parallelOp = scf::ParallelOp::create(rewriter, loc, lbs, ubs, steps); rewriter.eraseBlock(&parallelOp.getRegion().front()); rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(), parallelOp.getRegion().begin()); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index bcecef5..1130538 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -19,12 +19,10 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/MapVector.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/MathExtras.h" #define DEBUG_TYPE "scf-loop-pipelining" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; using namespace mlir::scf; @@ -100,7 +98,7 @@ public: bool LoopPipelinerInternal::initializeLoopInfo( ForOp op, const PipeliningOption &options) { - LDBG("Start initializeLoopInfo"); + LDBG() << "Start initializeLoopInfo"; forOp = op; ub = forOp.getUpperBound(); lb = forOp.getLowerBound(); @@ -109,7 +107,7 @@ bool LoopPipelinerInternal::initializeLoopInfo( std::vector<std::pair<Operation *, unsigned>> schedule; options.getScheduleFn(forOp, schedule); if (schedule.empty()) { - LDBG("--empty schedule -> BAIL"); + LDBG() << "--empty schedule -> BAIL"; return false; } @@ -126,7 +124,7 @@ bool LoopPipelinerInternal::initializeLoopInfo( auto stepCst = getConstantIntValue(step); if (!upperBoundCst || !lowerBoundCst || !stepCst) { if (!options.supportDynamicLoops) { - LDBG("--dynamic loop not supported -> BAIL"); + LDBG() << "--dynamic loop not supported -> BAIL"; return false; } } else { @@ -134,21 +132,21 @@ bool LoopPipelinerInternal::initializeLoopInfo( int64_t lbImm = lowerBoundCst.value(); int64_t stepImm = stepCst.value(); if (stepImm <= 0) { - LDBG("--invalid loop step -> BAIL"); + LDBG() << "--invalid loop step -> BAIL"; return false; } int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm); if (numIteration >= maxStage) { dynamicLoop = false; } else if (!options.supportDynamicLoops) { - LDBG("--fewer loop iterations than pipeline stages -> BAIL"); + LDBG() << "--fewer loop iterations than pipeline stages -> BAIL"; return false; } } peelEpilogue = options.peelEpilogue; predicateFn = options.predicateFn; if ((!peelEpilogue || dynamicLoop) && predicateFn == nullptr) { - LDBG("--no epilogue or predicate set -> BAIL"); + LDBG() << "--no epilogue or predicate set -> BAIL"; return false; } @@ -156,13 +154,13 @@ bool LoopPipelinerInternal::initializeLoopInfo( for (Operation &op : forOp.getBody()->without_terminator()) { if (!stages.contains(&op)) { op.emitOpError("not assigned a pipeline stage"); - LDBG("--op not assigned a pipeline stage: " << op << " -> BAIL"); + LDBG() << "--op not assigned a pipeline stage: " << op << " -> BAIL"; return false; } } if (!verifySchedule()) { - LDBG("--invalid schedule: " << op << " -> BAIL"); + LDBG() << "--invalid schedule: " << op << " -> BAIL"; return false; } @@ -173,15 +171,16 @@ bool LoopPipelinerInternal::initializeLoopInfo( (void)stageNum; if (op == forOp.getBody()->getTerminator()) { op->emitError("terminator should not be assigned a stage"); - LDBG("--terminator should not be assigned stage: " << *op << " -> BAIL"); + LDBG() << "--terminator should not be assigned stage: " << *op + <<
" -> BAIL"; return false; } if (op->getBlock() != forOp.getBody()) { op->emitOpError("the owning Block of all operations assigned a stage " "should be the loop body block"); - LDBG("--the owning Block of all operations assigned a stage " - "should be the loop body block: " - << *op << " -> BAIL"); + LDBG() << "--the owning Block of all operations assigned a stage " + "should be the loop body block: " + << *op << " -> BAIL"; return false; } } @@ -196,8 +195,8 @@ bool LoopPipelinerInternal::initializeLoopInfo( return !def || (!stages.contains(def) && forOp->isAncestor(def)); })) { - LDBG("--only support loop carried dependency with a distance of 1 or " - "defined outside of the loop -> BAIL"); + LDBG() << "--only support loop carried dependency with a distance of 1 or " + "defined outside of the loop -> BAIL"; return false; } annotateFn = options.annotateFn; @@ -279,25 +278,25 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { if (dynamicLoop) { Type t = ub.getType(); // pred = ub > lb + (i * step) - Value iv = rewriter.create<arith::AddIOp>( - loc, lb, - rewriter.create<arith::MulIOp>( - loc, step, - rewriter.create<arith::ConstantOp>( - loc, rewriter.getIntegerAttr(t, i)))); - predicates[i] = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::slt, iv, ub); + Value iv = arith::AddIOp::create( + rewriter, loc, lb, + arith::MulIOp::create( + rewriter, loc, step, + arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(t, i)))); + predicates[i] = arith::CmpIOp::create(rewriter, loc, + arith::CmpIPredicate::slt, iv, ub); } // special handling for induction variable as the increment is implicit. // iv = lb + i * step Type t = lb.getType(); - Value iv = rewriter.create<arith::AddIOp>( - loc, lb, - rewriter.create<arith::MulIOp>( - loc, step, - rewriter.create<arith::ConstantOp>(loc, - rewriter.getIntegerAttr(t, i)))); + Value iv = arith::AddIOp::create( + rewriter, loc, lb, + arith::MulIOp::create( + rewriter, loc, step, + arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(t, i)))); setValueMapping(forOp.getInductionVar(), iv, i); for (Operation *op : opOrder) { if (stages[op] > i) @@ -332,8 +331,8 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) { Value prevValue = valueMapping [forOp.getRegionIterArgs()[operand.getOperandNumber()]] [i - stages[op]]; - source = rewriter.create<arith::SelectOp>( - loc, predicates[predicateIdx], source, prevValue); + source = arith::SelectOp::create( + rewriter, loc, predicates[predicateIdx], source, prevValue); } setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()], source, i - stages[op] + 1); @@ -444,15 +443,15 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop( Type t = ub.getType(); Location loc = forOp.getLoc(); // newUb = ub - maxStage * step - Value maxStageValue = rewriter.create<arith::ConstantOp>( - loc, rewriter.getIntegerAttr(t, maxStage)); + Value maxStageValue = arith::ConstantOp::create( + rewriter, loc, rewriter.getIntegerAttr(t, maxStage)); Value maxStageByStep = - rewriter.create<arith::MulIOp>(loc, step, maxStageValue); - newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep); + arith::MulIOp::create(rewriter, loc, step, maxStageValue); + newUb = arith::SubIOp::create(rewriter, loc, ub, maxStageByStep); } auto newForOp = - rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb, - forOp.getStep(), newLoopArg); + scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), newUb, + forOp.getStep(), 
newLoopArg); // When there are no iter args, the loop body terminator will be created. // Since we always create it below, remove the terminator if it was created. if (!newForOp.getBody()->empty()) @@ -483,16 +482,17 @@ LogicalResult LoopPipelinerInternal::createKernel( Type t = ub.getType(); for (unsigned i = 0; i < maxStage; i++) { // c = ub - (maxStage - i) * step - Value c = rewriter.create<arith::SubIOp>( - loc, ub, - rewriter.create<arith::MulIOp>( - loc, step, - rewriter.create<arith::ConstantOp>( - loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i))))); - - Value pred = rewriter.create<arith::CmpIOp>( - newForOp.getLoc(), arith::CmpIPredicate::slt, - newForOp.getInductionVar(), c); + Value c = arith::SubIOp::create( + rewriter, loc, ub, + arith::MulIOp::create( + rewriter, loc, step, + arith::ConstantOp::create( + rewriter, loc, + rewriter.getIntegerAttr(t, int64_t(maxStage - i))))); + + Value pred = arith::CmpIOp::create(rewriter, newForOp.getLoc(), + arith::CmpIPredicate::slt, + newForOp.getInductionVar(), c); predicates[i] = pred; } } @@ -515,13 +515,13 @@ LogicalResult LoopPipelinerInternal::createKernel( // offset = (maxStage - stages[op]) * step Type t = step.getType(); - Value offset = rewriter.create<arith::MulIOp>( - forOp.getLoc(), step, - rewriter.create<arith::ConstantOp>( - forOp.getLoc(), + Value offset = arith::MulIOp::create( + rewriter, forOp.getLoc(), step, + arith::ConstantOp::create( + rewriter, forOp.getLoc(), rewriter.getIntegerAttr(t, maxStage - stages[op]))); - Value iv = rewriter.create<arith::AddIOp>( - forOp.getLoc(), newForOp.getInductionVar(), offset); + Value iv = arith::AddIOp::create(rewriter, forOp.getLoc(), + newForOp.getInductionVar(), offset); nestedNewOp->setOperand(operand->getOperandNumber(), iv); rewriter.setInsertionPointAfter(newOp); continue; @@ -594,8 +594,8 @@ LogicalResult LoopPipelinerInternal::createKernel( auto defStage = stages.find(def); if (defStage != stages.end() && defStage->second < maxStage) { Value pred = predicates[defStage->second]; - source = rewriter.create<arith::SelectOp>( - pred.getLoc(), pred, source, + source = arith::SelectOp::create( + rewriter, pred.getLoc(), pred, source, newForOp.getBody() ->getArguments()[yieldOperand.getOperandNumber() + 1]); } @@ -638,7 +638,7 @@ LogicalResult LoopPipelinerInternal::createKernel( maxStage - defStage->second + 1); } } - rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands); + scf::YieldOp::create(rewriter, forOp.getLoc(), yieldOperands); return success(); } @@ -652,8 +652,8 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, // removed by dead code if not used. auto createConst = [&](int v) { - return rewriter.create<arith::ConstantOp>(loc, - rewriter.getIntegerAttr(t, v)); + return arith::ConstantOp::create(rewriter, loc, + rewriter.getIntegerAttr(t, v)); }; // total_iterations = cdiv(range_diff, step); @@ -661,42 +661,44 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, // - total_iterations = (range_diff + step + (step < 0 ? 
1 : -1)) / step Value zero = createConst(0); Value one = createConst(1); - Value stepLessZero = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::slt, step, zero); - Value stepDecr = - rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1)); + Value stepLessZero = arith::CmpIOp::create( + rewriter, loc, arith::CmpIPredicate::slt, step, zero); + Value stepDecr = arith::SelectOp::create(rewriter, loc, stepLessZero, one, + createConst(-1)); - Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb); - Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step); + Value rangeDiff = arith::SubIOp::create(rewriter, loc, ub, lb); + Value rangeIncrStep = arith::AddIOp::create(rewriter, loc, rangeDiff, step); Value rangeDecr = - rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr); - Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step); + arith::AddIOp::create(rewriter, loc, rangeIncrStep, stepDecr); + Value totalIterations = + arith::DivSIOp::create(rewriter, loc, rangeDecr, step); // If total_iters < max_stage, start the epilogue at zero to match the // ramp-up in the prologue. // start_iter = max(0, total_iters - max_stage) - Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations, - createConst(maxStage)); - iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI); + Value iterI = arith::SubIOp::create(rewriter, loc, totalIterations, + createConst(maxStage)); + iterI = arith::MaxSIOp::create(rewriter, loc, zero, iterI); // Capture predicates for dynamic loops. SmallVector<Value> predicates(maxStage + 1); for (int64_t i = 1; i <= maxStage; i++) { // newLastIter = lb + step * iterI - Value newlastIter = rewriter.create<arith::AddIOp>( - loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI)); + Value newlastIter = arith::AddIOp::create( + rewriter, loc, lb, arith::MulIOp::create(rewriter, loc, step, iterI)); setValueMapping(forOp.getInductionVar(), newlastIter, i); // increment to next iterI - iterI = rewriter.create<arith::AddIOp>(loc, iterI, one); + iterI = arith::AddIOp::create(rewriter, loc, iterI, one); if (dynamicLoop) { // Disable stages when `i` is greater than total_iters. 
// pred = total_iters >= i - predicates[i] = rewriter.create<arith::CmpIOp>( - loc, arith::CmpIPredicate::sge, totalIterations, createConst(i)); + predicates[i] = + arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge, + totalIterations, createConst(i)); } } @@ -758,8 +760,8 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, unsigned nextVersion = currentVersion + 1; Value pred = predicates[currentVersion]; Value prevValue = valueMapping[mapVal][currentVersion]; - auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(), - prevValue); + auto selOp = arith::SelectOp::create(rewriter, loc, pred, + pair.value(), prevValue); returnValues[ri] = selOp; if (nextVersion <= maxStage) setValueMapping(mapVal, selOp, nextVersion); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index d17cd47..4752c08 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -63,13 +63,13 @@ static void specializeParallelLoopForUnrolling(ParallelOp op) { Value cond; for (auto bound : llvm::zip(op.getUpperBound(), constantIndices)) { Value constant = - b.create<arith::ConstantIndexOp>(op.getLoc(), std::get<1>(bound)); - Value cmp = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq, - std::get<0>(bound), constant); - cond = cond ? b.create<arith::AndIOp>(op.getLoc(), cond, cmp) : cmp; + arith::ConstantIndexOp::create(b, op.getLoc(), std::get<1>(bound)); + Value cmp = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::eq, + std::get<0>(bound), constant); + cond = cond ? arith::AndIOp::create(b, op.getLoc(), cond, cmp) : cmp; map.map(std::get<0>(bound), constant); } - auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true); + auto ifOp = scf::IfOp::create(b, op.getLoc(), cond, /*withElseRegion=*/true); ifOp.getThenBodyBuilder().clone(*op.getOperation(), map); ifOp.getElseBodyBuilder().clone(*op.getOperation()); op.erase(); @@ -94,11 +94,11 @@ static void specializeForLoopForUnrolling(ForOp op) { OpBuilder b(op); IRMapping map; - Value constant = b.create<arith::ConstantIndexOp>(op.getLoc(), minConstant); - Value cond = b.create<arith::CmpIOp>(op.getLoc(), arith::CmpIPredicate::eq, - bound, constant); + Value constant = arith::ConstantIndexOp::create(b, op.getLoc(), minConstant); + Value cond = arith::CmpIOp::create(b, op.getLoc(), arith::CmpIPredicate::eq, + bound, constant); map.map(bound, constant); - auto ifOp = b.create<scf::IfOp>(op.getLoc(), cond, /*withElseRegion=*/true); + auto ifOp = scf::IfOp::create(b, op.getLoc(), cond, /*withElseRegion=*/true); ifOp.getThenBodyBuilder().clone(*op.getOperation(), map); ifOp.getElseBodyBuilder().clone(*op.getOperation()); op.erase(); diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index ad12673..694cd85 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -190,8 +190,8 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, IRRewriter b(builder); b.setInsertionPoint(secondPloop); - auto newSecondPloop = b.create<ParallelOp>( - secondPloop.getLoc(), secondPloop.getLowerBound(), + auto newSecondPloop = ParallelOp::create( + b, secondPloop.getLoc(), secondPloop.getLowerBound(), secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars); Block *newBlock = newSecondPloop.getBody(); 
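// Alongside the builder migration, this commit retires the hand-rolled
// DBGS()/LDBG(X) macro pairs in favour of llvm/Support/DebugLog.h, whose
// LDBG() yields a stream that is tagged with DEBUG_TYPE and
// newline-terminated automatically. Minimal sketch:
//
//   #include "llvm/Support/DebugLog.h"
//   #define DEBUG_TYPE "scf-loop-pipelining"
//
//   // Before: LDBG("--invalid schedule: " << op << " -> BAIL");
//   LDBG() << "--invalid schedule: " << op << " -> BAIL";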
@@ -212,7 +212,7 @@ static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop, SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end()); newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end()); - auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs); + auto newReduceOp = scf::ReduceOp::create(b, term2.getLoc(), newReduceArgs); for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>( term1.getReductions(), term2.getReductions()))) { diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp index 66f7bc2..081f5fb 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopTiling.cpp @@ -58,28 +58,28 @@ std::pair<ParallelOp, ParallelOp> mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes, bool noMinMaxBounds) { OpBuilder b(op); - auto zero = b.create<arith::ConstantIndexOp>(op.getLoc(), 0); + auto zero = arith::ConstantIndexOp::create(b, op.getLoc(), 0); SmallVector<Value, 2> tileSizeConstants; tileSizeConstants.reserve(op.getUpperBound().size()); for (size_t i = 0, end = op.getUpperBound().size(); i != end; ++i) { if (i < tileSizes.size()) tileSizeConstants.push_back( - b.create<arith::ConstantIndexOp>(op.getLoc(), tileSizes[i])); + arith::ConstantIndexOp::create(b, op.getLoc(), tileSizes[i])); else // Just pick 1 for the remaining dimensions. tileSizeConstants.push_back( - b.create<arith::ConstantIndexOp>(op.getLoc(), 1)); + arith::ConstantIndexOp::create(b, op.getLoc(), 1)); } // Create the outer loop with adjusted steps. SmallVector<Value, 2> newSteps; newSteps.reserve(op.getStep().size()); for (auto step : llvm::zip(op.getStep(), tileSizeConstants)) { - newSteps.push_back(b.create<arith::MulIOp>(op.getLoc(), std::get<0>(step), - std::get<1>(step))); + newSteps.push_back(arith::MulIOp::create(b, op.getLoc(), std::get<0>(step), + std::get<1>(step))); } - auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.getLowerBound(), - op.getUpperBound(), newSteps); + auto outerLoop = ParallelOp::create(b, op.getLoc(), op.getLowerBound(), + op.getUpperBound(), newSteps); b.setInsertionPointToStart(outerLoop.getBody()); // Compute min(size, dim - offset) to avoid out-of-bounds accesses. @@ -100,11 +100,10 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes, op.getStep(), tileSizeConstants)) { // Collect the statically known loop bounds auto lowerBoundConstant = - dyn_cast_or_null<arith::ConstantIndexOp>(lowerBound.getDefiningOp()); + lowerBound.getDefiningOp<arith::ConstantIndexOp>(); auto upperBoundConstant = - dyn_cast_or_null<arith::ConstantIndexOp>(upperBound.getDefiningOp()); - auto stepConstant = - dyn_cast_or_null<arith::ConstantIndexOp>(step.getDefiningOp()); + upperBound.getDefiningOp<arith::ConstantIndexOp>(); + auto stepConstant = step.getDefiningOp<arith::ConstantIndexOp>(); auto tileSize = cast<arith::ConstantIndexOp>(tileSizeConstant.getDefiningOp()).value(); // If the loop bounds and the loop step are constant and if the number of @@ -130,45 +129,45 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes, // Otherwise, we dynamically compute the bound for // each iteration of the outer loop. 
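// For tiles that may run past the iteration space, the inner bound is
// clamped as min(tileSize * step, upperBound - outerIV). Assuming the
// three-operand min map this pass builds, the emitted IR looks like:
//
//   %bound = affine.min affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
//            (%newStep, %ub, %iv)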
newBounds.push_back( - b.create<affine::AffineMinOp>(op.getLoc(), b.getIndexType(), minMap, - ValueRange{newStep, upperBound, iv})); + affine::AffineMinOp::create(b, op.getLoc(), b.getIndexType(), minMap, + ValueRange{newStep, upperBound, iv})); } - auto innerLoop = b.create<ParallelOp>( - op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds, + auto innerLoop = ParallelOp::create( + b, op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds, op.getStep()); if (noMinMaxBounds && needInboundCheck) { b.setInsertionPointToStart(innerLoop.getBody()); // Insert in-bound check Value inbound = - b.create<arith::ConstantIntOp>(op.getLoc(), b.getIntegerType(1), 1); + arith::ConstantIntOp::create(b, op.getLoc(), b.getIntegerType(1), 1); for (auto [outerUpperBound, outerIV, innerIV, innerStep] : llvm::zip(outerLoop.getUpperBound(), outerLoop.getInductionVars(), innerLoop.getInductionVars(), innerLoop.getStep())) { // %in_bound = %in_bound && // (%inner_iv * %inner_step + %outer_iv < %outer_upper_bound) - Value index = b.create<arith::AddIOp>( - op.getLoc(), b.create<arith::MulIOp>(op.getLoc(), innerIV, innerStep), - outerIV); - Value dimInbound = b.create<arith::CmpIOp>( - op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound); - inbound = b.create<arith::AndIOp>(op.getLoc(), inbound, dimInbound); + Value index = arith::AddIOp::create( + b, op.getLoc(), + arith::MulIOp::create(b, op.getLoc(), innerIV, innerStep), outerIV); + Value dimInbound = arith::CmpIOp::create( + b, op.getLoc(), arith::CmpIPredicate::ult, index, outerUpperBound); + inbound = arith::AndIOp::create(b, op.getLoc(), inbound, dimInbound); } - auto ifInbound = b.create<IfOp>(op.getLoc(), - /*resultTypes*/ ArrayRef<Type>{}, inbound, - /*hasElseRegion*/ false); + auto ifInbound = IfOp::create(b, op.getLoc(), + /*resultTypes*/ ArrayRef<Type>{}, inbound, + /*hasElseRegion*/ false); ifInbound.getThenRegion().takeBody(op.getRegion()); Block &thenBlock = ifInbound.getThenRegion().front(); // Replace the scf.reduce terminator with an scf.yield terminator. 
Operation *reduceOp = thenBlock.getTerminator(); b.setInsertionPointToEnd(&thenBlock); - b.create<scf::YieldOp>(reduceOp->getLoc()); + scf::YieldOp::create(b, reduceOp->getLoc()); reduceOp->erase(); b.setInsertionPointToStart(innerLoop.getBody()); for (const auto &ivs : llvm::enumerate(llvm::zip( innerLoop.getInductionVars(), outerLoop.getInductionVars()))) { - auto newIndex = b.create<arith::AddIOp>( - op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value())); + auto newIndex = arith::AddIOp::create( + b, op.getLoc(), std::get<0>(ivs.value()), std::get<1>(ivs.value())); thenBlock.getArgument(ivs.index()) .replaceAllUsesExcept(newIndex, newIndex); } @@ -179,8 +178,8 @@ mlir::scf::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes, for (auto ivs : llvm::zip(innerLoop.getInductionVars(), outerLoop.getInductionVars())) { Value innerIndex = std::get<0>(ivs); - auto newIndex = b.create<arith::AddIOp>(op.getLoc(), std::get<0>(ivs), - std::get<1>(ivs)); + auto newIndex = arith::AddIOp::create(b, op.getLoc(), std::get<0>(ivs), + std::get<1>(ivs)); innerIndex.replaceAllUsesExcept(newIndex, newIndex); } } diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 0932624..1b07b77 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -112,11 +112,11 @@ public: // We can not do clone as the number of result types after conversion // might be different. - ForOp newOp = rewriter.create<ForOp>( - op.getLoc(), llvm::getSingleElement(adaptor.getLowerBound()), - llvm::getSingleElement(adaptor.getUpperBound()), - llvm::getSingleElement(adaptor.getStep()), - flattenValues(adaptor.getInitArgs())); + ForOp newOp = ForOp::create(rewriter, op.getLoc(), + llvm::getSingleElement(adaptor.getLowerBound()), + llvm::getSingleElement(adaptor.getUpperBound()), + llvm::getSingleElement(adaptor.getStep()), + flattenValues(adaptor.getInitArgs())); // Reserve whatever attributes in the original op. newOp->setAttrs(op->getAttrs()); @@ -142,9 +142,9 @@ public: ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - IfOp newOp = rewriter.create<IfOp>( - op.getLoc(), dstTypes, llvm::getSingleElement(adaptor.getCondition()), - true); + IfOp newOp = + IfOp::create(rewriter, op.getLoc(), dstTypes, + llvm::getSingleElement(adaptor.getCondition()), true); newOp->setAttrs(op->getAttrs()); // We do not need the empty blocks created by rewriter. 
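// The structural 1:N conversions in this file rebuild the op rather than
// cloning it because type conversion may change the number of results.
// Condensed from the ForOp case above:
//
//   ForOp newOp = ForOp::create(
//       rewriter, op.getLoc(),
//       llvm::getSingleElement(adaptor.getLowerBound()),
//       llvm::getSingleElement(adaptor.getUpperBound()),
//       llvm::getSingleElement(adaptor.getStep()),
//       flattenValues(adaptor.getInitArgs()));
//   newOp->setAttrs(op->getAttrs()); // carry over the original attributes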
@@ -171,8 +171,8 @@ public: std::optional<WhileOp> convertSourceOp(WhileOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter, TypeRange dstTypes) const { - auto newOp = rewriter.create<WhileOp>(op.getLoc(), dstTypes, - flattenValues(adaptor.getOperands())); + auto newOp = WhileOp::create(rewriter, op.getLoc(), dstTypes, + flattenValues(adaptor.getOperands())); for (auto i : {0u, 1u}) { if (failed(rewriter.convertRegionTypes(&op.getRegion(i), *typeConverter))) diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 484b03d..c0e47ee 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -447,9 +447,9 @@ static LogicalResult generateLoopNestUsingForOp( SmallVector<Value> ivs; for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) { auto loop = - rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors, - [](OpBuilder &bodyBuilder, Location bodyLoc, - Value iv, ValueRange /*iterArgs*/) {}); + scf::ForOp::create(rewriter, loc, lb, ub, step, destinationTensors, + [](OpBuilder &bodyBuilder, Location bodyLoc, + Value iv, ValueRange /*iterArgs*/) {}); loops.push_back(loop); ivs.push_back(loop.getInductionVar()); rewriter.setInsertionPointToEnd(loop.getBody()); @@ -476,12 +476,12 @@ static LogicalResult generateLoopNestUsingForOp( resultSizes)) { SmallVector<OpFoldResult> resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); - auto insertSlice = rewriter.create<tensor::InsertSliceOp>( - loc, tiledValue, destinationTensor, resultOffset, resultSize, + auto insertSlice = tensor::InsertSliceOp::create( + rewriter, loc, tiledValue, destinationTensor, resultOffset, resultSize, resultStride); yieldedValues.push_back(insertSlice); } - rewriter.create<scf::YieldOp>(loc, yieldedValues); + scf::YieldOp::create(rewriter, loc, yieldedValues); // Add the scf.yield operations for all the outer loops. 
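// Result plumbing for the generated tile loop nest, as IR (sketch):
//
//   %0 = scf.for ... iter_args(%d0 = %dest) -> (tensor<...>) {
//     %1 = scf.for ... iter_args(%d1 = %d0) -> (tensor<...>) {
//       %tile = <tiled computation>
//       %ins = tensor.insert_slice %tile into %d1[...] [...] [...]
//       scf.yield %ins : tensor<...>
//     }
//     scf.yield %1 : tensor<...> // outer loops forward the inner results
//   }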
for (auto [outerLoop, innerLoop] : @@ -489,7 +489,7 @@ static LogicalResult generateLoopNestUsingForOp( MutableArrayRef(loops).drop_front())) { rewriter.setInsertionPointToEnd( cast<scf::ForOp>(outerLoop.getOperation()).getBody()); - rewriter.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop->getResults()); + scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults()); } return success(); } @@ -530,14 +530,14 @@ static LogicalResult generateLoopNestUsingForallOp( continue; nonZeroNumThreads.push_back(nt); } - forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads, - destinationTensors, mappingAttr); + forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads, + destinationTensors, mappingAttr); } else { SmallVector<OpFoldResult> lbs, ubs, steps; std::tie(lbs, ubs, steps) = getLoopBounds(rewriter, loc, loopRanges, tileSizes); - forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps, - destinationTensors, mappingAttr); + forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps, + destinationTensors, mappingAttr); } loops.push_back(forallOp); @@ -558,9 +558,9 @@ static LogicalResult generateLoopNestUsingForallOp( SmallVector<OpFoldResult> resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); - rewriter.create<tensor::ParallelInsertSliceOp>( - loc, tiledValue, destinationTensor, resultOffset, resultSize, - resultStride); + tensor::ParallelInsertSliceOp::create(rewriter, loc, tiledValue, + destinationTensor, resultOffset, + resultSize, resultStride); } return success(); } @@ -795,9 +795,9 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( auto inits = llvm::to_vector(loopOp.getInitArgs()); inits.append(newInitOperands.begin(), newInitOperands.end()); - auto newLoop = rewriter.create<scf::ForOp>( - loc, loopOp.getLowerBound(), loopOp.getUpperBound(), loopOp.getStep(), - inits, [](OpBuilder &, Location, Value, ValueRange) {}); + auto newLoop = scf::ForOp::create( + rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(), + loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); // Move the loop body to the new op. 
Block *loopBody = loopOp.getBody(); @@ -826,9 +826,9 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( resultSizes)) { SmallVector<OpFoldResult> resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); - Value insert = rewriter.create<tensor::InsertSliceOp>( - yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, resultSize, - resultStride); + Value insert = tensor::InsertSliceOp::create( + rewriter, yieldOp->getLoc(), tiledValue, regionIterArg, resultOffset, + resultSize, resultStride); newYieldValues.push_back(insert); } @@ -848,8 +848,8 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( rewriter.setInsertionPoint(loopOp); auto inits = llvm::to_vector(loopOp.getOutputs()); inits.append(newInitOperands.begin(), newInitOperands.end()); - auto newLoop = rewriter.create<scf::ForallOp>( - loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), + auto newLoop = scf::ForallOp::create( + rewriter, loc, loopOp.getMixedLowerBound(), loopOp.getMixedUpperBound(), loopOp.getMixedStep(), inits, loopOp.getMapping(), [](OpBuilder &, Location, ValueRange) {}); @@ -881,9 +881,9 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForallOp>( tiledValues, regionIterArgs, resultOffsets, resultSizes)) { SmallVector<OpFoldResult> resultStride(resultOffset.size(), rewriter.getIndexAttr(1)); - rewriter.create<tensor::ParallelInsertSliceOp>( - terminator.getLoc(), tiledValue, iterArg, resultOffset, resultSize, - resultStride); + tensor::ParallelInsertSliceOp::create(rewriter, terminator.getLoc(), + tiledValue, iterArg, resultOffset, + resultSize, resultStride); } rewriter.replaceOp(loopOp, @@ -932,9 +932,9 @@ static LogicalResult addInitOperandsToLoopNest( // Create a new loop with the new init values for this loop. SmallVector<Value> newInits = llvm::to_vector(forLoop.getInitArgs()); newInits.append(newInitValues.begin(), newInitValues.end()); - auto newLoop = rewriter.create<scf::ForOp>( - forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), - forLoop.getStep(), newInits, + auto newLoop = scf::ForOp::create( + rewriter, forLoop.getLoc(), forLoop.getLowerBound(), + forLoop.getUpperBound(), forLoop.getStep(), newInits, [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); // Merge the body of the new loop with the body of the old loops. 
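// The replace-loop helpers above share one move: extend the init list,
// rebuild the loop with an empty body builder, then splice the old body
// in. Condensed from the scf.for variant:
//
//   auto inits = llvm::to_vector(loopOp.getInitArgs());
//   inits.append(newInitOperands.begin(), newInitOperands.end());
//   auto newLoop = scf::ForOp::create(
//       rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
//       loopOp.getStep(), inits,
//       [](OpBuilder &, Location, Value, ValueRange) {});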
@@ -1416,8 +1416,8 @@ FailureOr<SmallVector<Operation *>> mlir::scf::yieldReplacementForFusedProducer( rewriter.setInsertionPoint(tiledDestStyleOp); for (const auto &&[index, newRegionArg] : llvm::enumerate(newRegionIterArgs)) { - auto destSlice = rewriter.create<tensor::ExtractSliceOp>( - loc, newRegionArg, offsetList[index], sizesList[index], + auto destSlice = tensor::ExtractSliceOp::create( + rewriter, loc, newRegionArg, offsetList[index], sizesList[index], SmallVector<OpFoldResult>(offsetList[index].size(), rewriter.getIndexAttr(1))); generatedSlices.push_back(destSlice); @@ -2089,8 +2089,8 @@ cloneAsInsertSlice<tensor::InsertSliceOp>(RewriterBase &rewriter, template <> tensor::InsertSliceOp cloneAsInsertSlice<tensor::ParallelInsertSliceOp>( RewriterBase &rewriter, tensor::ParallelInsertSliceOp insertSliceOp) { - return rewriter.create<tensor::InsertSliceOp>( - insertSliceOp->getLoc(), insertSliceOp.getSource(), + return tensor::InsertSliceOp::create( + rewriter, insertSliceOp->getLoc(), insertSliceOp.getSource(), insertSliceOp.getDest(), insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides()); } @@ -2311,8 +2311,9 @@ mlir::scf::tileAndFuseConsumerOfSlices( rewriter.setInsertionPoint(tiledDestStyleOp); for (const auto &&[index, newRegionArg] : llvm::enumerate(newRegionIterArgs)) { - auto destSlice = rewriter.create<tensor::ExtractSliceOp>( - loc, newRegionArg, resultOffsets[index], resultSizes[index], + auto destSlice = tensor::ExtractSliceOp::create( + rewriter, loc, newRegionArg, resultOffsets[index], + resultSizes[index], SmallVector<OpFoldResult>(resultOffsets[index].size(), rewriter.getIndexAttr(1))); // Make a copy of index to avoid a capturing structured binding, which @@ -2388,8 +2389,8 @@ mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); Value strideVal = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); - auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, - strideVal, ValueRange{}); + auto loop = scf::ForOp::create(rewriter, op.getLoc(), offsetVal, sizeVal, + strideVal, ValueRange{}); loops.push_back(loop); ivs.push_back(loop.getInductionVar()); rewriter.setInsertionPoint(loop.getBody()->getTerminator()); diff --git a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp index 7e9a4d7..ec1044a 100644 --- a/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/UpliftWhileToFor.cpp @@ -189,7 +189,7 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, // dummy builder instead. 
auto emptyBuilder = [](OpBuilder &, Location, Value, ValueRange) {}; auto newLoop = - rewriter.create<scf::ForOp>(loc, lb, ub, step, newArgs, emptyBuilder); + scf::ForOp::create(rewriter, loc, lb, ub, step, newArgs, emptyBuilder); Block *newBody = newLoop.getBody(); @@ -236,18 +236,18 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, rewriter.setInsertionPointAfter(newLoop); Value one; if (isa<IndexType>(step.getType())) { - one = rewriter.create<arith::ConstantIndexOp>(loc, 1); + one = arith::ConstantIndexOp::create(rewriter, loc, 1); } else { - one = rewriter.create<arith::ConstantIntOp>(loc, step.getType(), 1); + one = arith::ConstantIntOp::create(rewriter, loc, step.getType(), 1); } - Value stepDec = rewriter.create<arith::SubIOp>(loc, step, one); - Value len = rewriter.create<arith::SubIOp>(loc, ub, lb); - len = rewriter.create<arith::AddIOp>(loc, len, stepDec); - len = rewriter.create<arith::DivSIOp>(loc, len, step); - len = rewriter.create<arith::SubIOp>(loc, len, one); - Value res = rewriter.create<arith::MulIOp>(loc, len, step); - res = rewriter.create<arith::AddIOp>(loc, lb, res); + Value stepDec = arith::SubIOp::create(rewriter, loc, step, one); + Value len = arith::SubIOp::create(rewriter, loc, ub, lb); + len = arith::AddIOp::create(rewriter, loc, len, stepDec); + len = arith::DivSIOp::create(rewriter, loc, len, step); + len = arith::SubIOp::create(rewriter, loc, len, one); + Value res = arith::MulIOp::create(rewriter, loc, len, step); + res = arith::AddIOp::create(rewriter, loc, lb, res); // Reconstruct `scf.while` results, inserting final induction var value // into proper place. diff --git a/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp b/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp index f829208..db504fe 100644 --- a/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/WrapInZeroTripCheck.cpp @@ -96,8 +96,8 @@ FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck( condOp.getArgs(), [&](Value arg) { return mapper.lookupOrDefault(arg); }); // Create rotated while loop. - auto newLoopOp = rewriter.create<scf::WhileOp>( - whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs, + auto newLoopOp = scf::WhileOp::create( + rewriter, whileOp.getLoc(), whileOp.getResultTypes(), clonedCondArgs, [&](OpBuilder &builder, Location loc, ValueRange args) { // Rotate and move the loop body into before block. auto newBlock = builder.getBlock(); @@ -109,21 +109,21 @@ FailureOr<scf::WhileOp> mlir::scf::wrapWhileLoopInZeroTripCheck( }, [&](OpBuilder &builder, Location loc, ValueRange args) { // Pass through values. - builder.create<scf::YieldOp>(loc, args); + scf::YieldOp::create(builder, loc, args); }); // Create zero-trip-check and move the while loop in. - auto ifOp = rewriter.create<scf::IfOp>( - whileOp.getLoc(), clonedCondition, + auto ifOp = scf::IfOp::create( + rewriter, whileOp.getLoc(), clonedCondition, [&](OpBuilder &builder, Location loc) { // Then runs the while loop. rewriter.moveOpBefore(newLoopOp, builder.getInsertionBlock(), builder.getInsertionPoint()); - builder.create<scf::YieldOp>(loc, newLoopOp.getResults()); + scf::YieldOp::create(builder, loc, newLoopOp.getResults()); }, [&](OpBuilder &builder, Location loc) { // Else returns the results from precondition. 
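The arith chain emitted after the new loop in upliftWhileToForLoop above has a closed form: the final induction value is lb + (ceilDiv(ub - lb, step) - 1) * step, i.e. the IV of the last executed iteration. A self-contained sketch with a worked example, assuming a positive step as the signed division in the pass suggests:

#include <cassert>
#include <cstdint>

int64_t lastInductionValue(int64_t lb, int64_t ub, int64_t step) {
  int64_t len = (ub - lb + (step - 1)) / step; // trip count (ceil division)
  return lb + (len - 1) * step;
}

int main() {
  // lb=0, ub=10, step=3 executes iterations 0, 3, 6, 9.
  assert(lastInductionValue(0, 10, 3) == 9);
  return 0;
}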
- builder.create<scf::YieldOp>(loc, clonedCondArgs); + scf::YieldOp::create(builder, loc, clonedCondArgs); }); rewriter.replaceOp(whileOp, ifOp); diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 062268a..5731795 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -24,14 +24,12 @@ #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cstdint> using namespace mlir; #define DEBUG_TYPE "scf-utils" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields( RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest, @@ -149,7 +147,7 @@ FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter, FunctionType::get(rewriter.getContext(), outlinedFuncArgTypes, originalTerminator->getOperandTypes()); auto outlinedFunc = - rewriter.create<func::FuncOp>(loc, funcName, outlinedFuncType); + func::FuncOp::create(rewriter, loc, funcName, outlinedFuncType); Block *outlinedFuncBody = outlinedFunc.addEntryBlock(); // Merge blocks while replacing the original block operands. @@ -164,8 +162,8 @@ FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter, outlinedFuncBlockArgs.take_front(numOriginalBlockArguments)); // Explicitly set up a new ReturnOp terminator. rewriter.setInsertionPointToEnd(outlinedFuncBody); - rewriter.create<func::ReturnOp>(loc, originalTerminator->getResultTypes(), - originalTerminator->getOperands()); + func::ReturnOp::create(rewriter, loc, originalTerminator->getResultTypes(), + originalTerminator->getOperands()); } // Reconstruct the block that was deleted and add a @@ -181,7 +179,7 @@ FailureOr<func::FuncOp> mlir::outlineSingleBlockRegion(RewriterBase &rewriter, SmallVector<Value> callValues; llvm::append_range(callValues, newBlock->getArguments()); llvm::append_range(callValues, outlinedValues); - auto call = rewriter.create<func::CallOp>(loc, outlinedFunc, callValues); + auto call = func::CallOp::create(rewriter, loc, outlinedFunc, callValues); if (callOp) *callOp = call; @@ -270,12 +268,12 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, assert(dividend.getType().isIntOrIndex() && "expected integer or index-typed value"); - Value divisorMinusOneCst = builder.create<arith::ConstantOp>( - loc, builder.getIntegerAttr(dividend.getType(), divisor - 1)); - Value divisorCst = builder.create<arith::ConstantOp>( - loc, builder.getIntegerAttr(dividend.getType(), divisor)); - Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOneCst); - return builder.create<arith::DivUIOp>(loc, sum, divisorCst); + Value divisorMinusOneCst = arith::ConstantOp::create( + builder, loc, builder.getIntegerAttr(dividend.getType(), divisor - 1)); + Value divisorCst = arith::ConstantOp::create( + builder, loc, builder.getIntegerAttr(dividend.getType(), divisor)); + Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOneCst); + return arith::DivUIOp::create(builder, loc, sum, divisorCst); } // Build the IR that performs ceil division of a positive value by another @@ -286,11 +284,11 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, Value divisor) { assert(dividend.getType().isIntOrIndex() && "expected integer or index-typed value"); - Value cstOne = 
builder.create<arith::ConstantOp>( - loc, builder.getOneAttr(dividend.getType())); - Value divisorMinusOne = builder.create<arith::SubIOp>(loc, divisor, cstOne); - Value sum = builder.create<arith::AddIOp>(loc, dividend, divisorMinusOne); - return builder.create<arith::DivUIOp>(loc, sum, divisor); + Value cstOne = arith::ConstantOp::create( + builder, loc, builder.getOneAttr(dividend.getType())); + Value divisorMinusOne = arith::SubIOp::create(builder, loc, divisor, cstOne); + Value sum = arith::AddIOp::create(builder, loc, dividend, divisorMinusOne); + return arith::DivUIOp::create(builder, loc, sum, divisor); } /// Returns the trip count of `forOp` if its low bound, high bound and step are @@ -400,18 +398,20 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor( // Create constant for 'upperBoundUnrolled' and set epilogue loop flag. generateEpilogueLoop = upperBoundUnrolledCst < ubCst; if (generateEpilogueLoop) - upperBoundUnrolled = boundsBuilder.create<arith::ConstantOp>( - loc, boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(), - upperBoundUnrolledCst)); + upperBoundUnrolled = arith::ConstantOp::create( + boundsBuilder, loc, + boundsBuilder.getIntegerAttr(forOp.getUpperBound().getType(), + upperBoundUnrolledCst)); else upperBoundUnrolled = forOp.getUpperBound(); // Create constant for 'stepUnrolled'. - stepUnrolled = stepCst == stepUnrolledCst - ? step - : boundsBuilder.create<arith::ConstantOp>( - loc, boundsBuilder.getIntegerAttr( - step.getType(), stepUnrolledCst)); + stepUnrolled = + stepCst == stepUnrolledCst + ? step + : arith::ConstantOp::create(boundsBuilder, loc, + boundsBuilder.getIntegerAttr( + step.getType(), stepUnrolledCst)); } else { // Dynamic loop bounds computation. // TODO: Add dynamic asserts for negative lb/ub/step, or @@ -419,22 +419,23 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor( auto lowerBound = forOp.getLowerBound(); auto upperBound = forOp.getUpperBound(); Value diff = - boundsBuilder.create<arith::SubIOp>(loc, upperBound, lowerBound); + arith::SubIOp::create(boundsBuilder, loc, upperBound, lowerBound); Value tripCount = ceilDivPositive(boundsBuilder, loc, diff, step); - Value unrollFactorCst = boundsBuilder.create<arith::ConstantOp>( - loc, boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor)); + Value unrollFactorCst = arith::ConstantOp::create( + boundsBuilder, loc, + boundsBuilder.getIntegerAttr(tripCount.getType(), unrollFactor)); Value tripCountRem = - boundsBuilder.create<arith::RemSIOp>(loc, tripCount, unrollFactorCst); + arith::RemSIOp::create(boundsBuilder, loc, tripCount, unrollFactorCst); // Compute tripCountEvenMultiple = tripCount - (tripCount % unrollFactor) Value tripCountEvenMultiple = - boundsBuilder.create<arith::SubIOp>(loc, tripCount, tripCountRem); + arith::SubIOp::create(boundsBuilder, loc, tripCount, tripCountRem); // Compute upperBoundUnrolled = lowerBound + tripCountEvenMultiple * step - upperBoundUnrolled = boundsBuilder.create<arith::AddIOp>( - loc, lowerBound, - boundsBuilder.create<arith::MulIOp>(loc, tripCountEvenMultiple, step)); + upperBoundUnrolled = arith::AddIOp::create( + boundsBuilder, loc, lowerBound, + arith::MulIOp::create(boundsBuilder, loc, tripCountEvenMultiple, step)); // Scale 'step' by 'unrollFactor'.
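The dynamic-bounds path above computes where the unrolled main loop must stop so that an epilogue can cover the leftover iterations; a self-contained sketch of the same arithmetic (struct and function names are mine, and as the TODO notes, negative bounds or steps are not handled):

#include <cassert>
#include <cstdint>

struct UnrolledBounds {
  int64_t upperBoundUnrolled;
  int64_t stepUnrolled;
};

UnrolledBounds unrollBounds(int64_t lb, int64_t ub, int64_t step,
                            int64_t unrollFactor) {
  int64_t tripCount = (ub - lb + step - 1) / step; // ceilDivPositive
  int64_t evenMultiple = tripCount - tripCount % unrollFactor;
  return {lb + evenMultiple * step, step * unrollFactor};
}

int main() {
  // 10 iterations unrolled by 4: the main loop covers [0, 8) with step 4
  // (two unrolled bodies); the epilogue handles iterations 8 and 9.
  UnrolledBounds b = unrollBounds(/*lb=*/0, /*ub=*/10, /*step=*/1,
                                  /*unrollFactor=*/4);
  assert(b.upperBoundUnrolled == 8 && b.stepUnrolled == 4);
  return 0;
}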
stepUnrolled = - boundsBuilder.create<arith::MulIOp>(loc, step, unrollFactorCst); + arith::MulIOp::create(boundsBuilder, loc, step, unrollFactorCst); } UnrolledLoopInfo resultLoops; @@ -470,11 +471,11 @@ FailureOr<UnrolledLoopInfo> mlir::loopUnrollByFactor( forOp.getBody(), forOp.getInductionVar(), unrollFactor, [&](unsigned i, Value iv, OpBuilder b) { // iv' = iv + step * i; - auto stride = b.create<arith::MulIOp>( - loc, step, - b.create<arith::ConstantOp>(loc, - b.getIntegerAttr(iv.getType(), i))); - return b.create<arith::AddIOp>(loc, iv, stride); + auto stride = arith::MulIOp::create( + b, loc, step, + arith::ConstantOp::create(b, loc, + b.getIntegerAttr(iv.getType(), i))); + return arith::AddIOp::create(b, loc, iv, stride); }, annotateFn, iterArgs, yieldedValues); // Promote the loop body up if this has turned into a single iteration loop. @@ -522,13 +523,13 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp, // If any control operand of any inner loop of `forOp` is defined within // `forOp`, no unroll jam. if (!areInnerBoundsInvariant(forOp)) { - LDBG("failed to unroll and jam: inner bounds are not invariant"); + LDBG() << "failed to unroll and jam: inner bounds are not invariant"; return failure(); } // Currently, `for` operations with results are not supported. if (forOp->getNumResults() > 0) { - LDBG("failed to unroll and jam: unsupported loop with results"); + LDBG() << "failed to unroll and jam: unsupported loop with results"; return failure(); } @@ -537,16 +538,17 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp, std::optional<uint64_t> tripCount = getConstantTripCount(forOp); if (!tripCount.has_value()) { // If the trip count is dynamic, do not unroll & jam. - LDBG("failed to unroll and jam: trip count could not be determined"); + LDBG() << "failed to unroll and jam: trip count could not be determined"; return failure(); } if (unrollJamFactor > *tripCount) { - LDBG("unroll and jam factor is greater than trip count, set factor to trip " "count"); + LDBG() << "unroll and jam factor is greater than trip count, set factor to " "trip count"; unrollJamFactor = *tripCount; } else if (*tripCount % unrollJamFactor != 0) { - LDBG("failed to unroll and jam: unsupported trip count that is not a " "multiple of unroll jam factor"); + LDBG() << "failed to unroll and jam: unsupported trip count that is not a " "multiple of unroll jam factor"; return failure(); } @@ -777,13 +779,13 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc, if (!isStepOne) { Value origStepValue = getValueOrCreateConstantIntOp(rewriter, loc, origStep); - scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue); + scaled = arith::MulIOp::create(rewriter, loc, normalizedIv, origStepValue); preserve.insert(scaled.getDefiningOp()); } denormalizedIv = scaled; if (!isZeroBased) { Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb); - denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue); + denormalizedIv = arith::AddIOp::create(rewriter, loc, scaled, origLbValue); preserve.insert(denormalizedIv.getDefiningOp()); } @@ -819,15 +821,14 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc, if (vOne && vOne.value() == 1) continue; if (productOf) - productOf = - rewriter.create<arith::MulIOp>(loc, productOf.value(), v).getResult(); + productOf = arith::MulIOp::create(rewriter, loc, productOf.value(), v) .getResult(); else productOf = v; } if (!productOf) { - productOf = rewriter -
.create<arith::ConstantOp>( - loc, rewriter.getOneAttr(getType(values.front()))) + productOf = arith::ConstantOp::create( + rewriter, loc, rewriter.getOneAttr(getType(values.front()))) .getResult(); } return productOf.value(); @@ -846,9 +847,8 @@ delinearizeInductionVariable(RewriterBase &rewriter, Location loc, Value linearizedIv, ArrayRef<Value> ubs) { if (linearizedIv.getType().isIndex()) { - Operation *delinearizedOp = - rewriter.create<affine::AffineDelinearizeIndexOp>(loc, linearizedIv, - ubs); + Operation *delinearizedOp = affine::AffineDelinearizeIndexOp::create( + rewriter, loc, linearizedIv, ubs); auto resultVals = llvm::map_to_vector( delinearizedOp->getResults(), [](OpResult r) -> Value { return r; }); return {resultVals, SmallPtrSet<Operation *, 2>{delinearizedOp}}; @@ -870,8 +870,8 @@ delinearizeInductionVariable(RewriterBase &rewriter, Location loc, if (!isUbOne.test(index)) { break; } - delinearizedIvs[index] = rewriter.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(ub.getType())); + delinearizedIvs[index] = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(ub.getType())); numLeadingOneUbs++; } @@ -879,17 +879,17 @@ delinearizeInductionVariable(RewriterBase &rewriter, Location loc, for (unsigned i = numLeadingOneUbs, e = ubs.size(); i < e; ++i) { unsigned idx = ubs.size() - (i - numLeadingOneUbs) - 1; if (i != numLeadingOneUbs && !isUbOne.test(idx + 1)) { - previous = rewriter.create<arith::DivSIOp>(loc, previous, ubs[idx + 1]); + previous = arith::DivSIOp::create(rewriter, loc, previous, ubs[idx + 1]); preservedUsers.insert(previous.getDefiningOp()); } Value iv = previous; if (i != e - 1) { if (!isUbOne.test(idx)) { - iv = rewriter.create<arith::RemSIOp>(loc, previous, ubs[idx]); + iv = arith::RemSIOp::create(rewriter, loc, previous, ubs[idx]); preservedUsers.insert(iv.getDefiningOp()); } else { - iv = rewriter.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(ubs[idx].getType())); + iv = arith::ConstantOp::create( + rewriter, loc, rewriter.getZeroAttr(ubs[idx].getType())); } } delinearizedIvs[idx] = iv; @@ -1089,13 +1089,13 @@ void mlir::collapseParallelLoops( // Combine iteration spaces. SmallVector<Value, 3> lowerBounds, upperBounds, steps; - auto cst0 = rewriter.create<arith::ConstantIndexOp>(loc, 0); - auto cst1 = rewriter.create<arith::ConstantIndexOp>(loc, 1); + auto cst0 = arith::ConstantIndexOp::create(rewriter, loc, 0); + auto cst1 = arith::ConstantIndexOp::create(rewriter, loc, 1); for (auto &sortedDimension : sortedDimensions) { - Value newUpperBound = rewriter.create<arith::ConstantIndexOp>(loc, 1); + Value newUpperBound = arith::ConstantIndexOp::create(rewriter, loc, 1); for (auto idx : sortedDimension) { - newUpperBound = rewriter.create<arith::MulIOp>( - loc, newUpperBound, normalizedUpperBounds[idx]); + newUpperBound = arith::MulIOp::create(rewriter, loc, newUpperBound, + normalizedUpperBounds[idx]); } lowerBounds.push_back(cst0); steps.push_back(cst1); @@ -1108,8 +1108,8 @@ void mlir::collapseParallelLoops( // value. The remainders then determine based on that range, which iteration // of the original induction value this represents. This is a normalized value // that is un-normalized already by the previous logic. 
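Both delinearizeInductionVariable above and the collapsed loop body below recover per-dimension induction variables from one linear variable through a remainder/quotient chain; a self-contained sketch of that arithmetic, with the innermost (last) dimension varying fastest (function name is mine):

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> delinearize(int64_t linear,
                                 const std::vector<int64_t> &ubs) {
  if (ubs.empty())
    return {};
  std::vector<int64_t> ivs(ubs.size());
  for (size_t i = ubs.size(); i-- > 1;) {
    ivs[i] = linear % ubs[i]; // This dimension's iteration.
    linear /= ubs[i];         // Strip it off before the next dimension.
  }
  ivs[0] = linear; // The remaining quotient is the outermost IV.
  return ivs;
}

int main() {
  // ubs = {2, 3, 4}: linear IV 17 = 1*(3*4) + 1*4 + 1 -> (1, 1, 1).
  std::vector<int64_t> ivs = delinearize(17, {2, 3, 4});
  assert(ivs[0] == 1 && ivs[1] == 1 && ivs[2] == 1);
  return 0;
}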
- auto newPloop = rewriter.create<scf::ParallelOp>( - loc, lowerBounds, upperBounds, steps, + auto newPloop = scf::ParallelOp::create( + rewriter, loc, lowerBounds, upperBounds, steps, [&](OpBuilder &insideBuilder, Location, ValueRange ploopIVs) { for (unsigned i = 0, e = combinedDimensions.size(); i < e; ++i) { Value previous = ploopIVs[i]; @@ -1119,15 +1119,15 @@ void mlir::collapseParallelLoops( unsigned idx = combinedDimensions[i][j]; // Determine the current induction value's current loop iteration - Value iv = insideBuilder.create<arith::RemSIOp>( - loc, previous, normalizedUpperBounds[idx]); + Value iv = arith::RemSIOp::create(insideBuilder, loc, previous, + normalizedUpperBounds[idx]); replaceAllUsesInRegionWith(loops.getBody()->getArgument(idx), iv, loops.getRegion()); // Remove the effect of the current induction value to prepare for // the next value. - previous = insideBuilder.create<arith::DivSIOp>( - loc, previous, normalizedUpperBounds[idx]); + previous = arith::DivSIOp::create(insideBuilder, loc, previous, + normalizedUpperBounds[idx]); } // The final induction value is just the remaining value. @@ -1237,7 +1237,7 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor, auto iv = forOp.getInductionVar(); OpBuilder b(forOp); - forOp.setStep(b.create<arith::MulIOp>(forOp.getLoc(), originalStep, factor)); + forOp.setStep(arith::MulIOp::create(b, forOp.getLoc(), originalStep, factor)); Loops innerLoops; for (auto t : targets) { @@ -1247,12 +1247,12 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor, // Insert newForOp before the terminator of `t`. auto b = OpBuilder::atBlockTerminator((t.getBody())); - Value stepped = b.create<arith::AddIOp>(t.getLoc(), iv, forOp.getStep()); + Value stepped = arith::AddIOp::create(b, t.getLoc(), iv, forOp.getStep()); Value ub = - b.create<arith::MinSIOp>(t.getLoc(), forOp.getUpperBound(), stepped); + arith::MinSIOp::create(b, t.getLoc(), forOp.getUpperBound(), stepped); // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses. - auto newForOp = b.create<scf::ForOp>(t.getLoc(), iv, ub, originalStep); + auto newForOp = scf::ForOp::create(b, t.getLoc(), iv, ub, originalStep); newForOp.getBody()->getOperations().splice( newForOp.getBody()->getOperations().begin(), t.getBody()->getOperations(), begin, std::next(begin, nOps - 1)); @@ -1339,8 +1339,8 @@ TileLoops mlir::extractFixedOuterLoops(scf::ForOp rootForOp, auto forOp = forOps[i]; OpBuilder builder(forOp); auto loc = forOp.getLoc(); - Value diff = builder.create<arith::SubIOp>(loc, forOp.getUpperBound(), - forOp.getLowerBound()); + Value diff = arith::SubIOp::create(builder, loc, forOp.getUpperBound(), + forOp.getLowerBound()); Value numIterations = ceilDivPositive(builder, loc, diff, forOp.getStep()); Value iterationsPerBlock = ceilDivPositive(builder, loc, numIterations, sizes[i]); @@ -1372,9 +1372,10 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, // Create a new scf.forall op after the source loop. rewriter.setInsertionPointAfter(source); - scf::ForallOp fusedLoop = rewriter.create<scf::ForallOp>( - source.getLoc(), source.getMixedLowerBound(), source.getMixedUpperBound(), - source.getMixedStep(), fusedOuts, source.getMapping()); + scf::ForallOp fusedLoop = scf::ForallOp::create( + rewriter, source.getLoc(), source.getMixedLowerBound(), + source.getMixedUpperBound(), source.getMixedStep(), fusedOuts, + source.getMapping()); // Map control operands. 
IRMapping mapping; @@ -1425,8 +1426,8 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, // Create a new scf.for op after the source loop (with scf.yield terminator // (without arguments) only in case its init_args is empty). rewriter.setInsertionPointAfter(source); - scf::ForOp fusedLoop = rewriter.create<scf::ForOp>( - source.getLoc(), source.getLowerBound(), source.getUpperBound(), + scf::ForOp fusedLoop = scf::ForOp::create( + rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(), source.getStep(), fusedInitArgs); // Map original induction variables and operands to those of the fused loop. @@ -1452,7 +1453,7 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, for (Value operand : source.getBody()->getTerminator()->getOperands()) yieldResults.push_back(mapping.lookupOrDefault(operand)); if (!yieldResults.empty()) - rewriter.create<scf::YieldOp>(source.getLoc(), yieldResults); + scf::YieldOp::create(rewriter, source.getLoc(), yieldResults); // Replace old loops by substituting their uses by results of the fused loop. rewriter.replaceOp(target, fusedLoop.getResults().take_front(numTargetOuts)); @@ -1483,8 +1484,8 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter, // Use the normalized builder since the lower bounds are always 0 and the // steps are always 1. - auto normalizedForallOp = rewriter.create<scf::ForallOp>( - loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(), + auto normalizedForallOp = scf::ForallOp::create( + rewriter, loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {}); rewriter.inlineRegionBefore(forallOp.getBodyRegion(), diff --git a/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp b/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp index 66eed86..48c0b1e 100644 --- a/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp +++ b/mlir/lib/Dialect/SMT/IR/SMTDialect.cpp @@ -30,14 +30,14 @@ Operation *SMTDialect::materializeConstant(OpBuilder &builder, Attribute value, if (auto attrValue = dyn_cast<BitVectorAttr>(value)) { assert(bvType == attrValue.getType() && "attribute and desired result types have to match"); - return builder.create<BVConstantOp>(loc, attrValue); + return BVConstantOp::create(builder, loc, attrValue); } } // BoolType constants can materialize into smt.constant if (auto boolType = dyn_cast<BoolType>(type)) { if (auto attrValue = dyn_cast<BoolAttr>(value)) - return builder.create<BoolConstantOp>(loc, attrValue); + return BoolConstantOp::create(builder, loc, attrValue); } return nullptr; diff --git a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp index 8977a3a..c517ef2 100644 --- a/mlir/lib/Dialect/SMT/IR/SMTOps.cpp +++ b/mlir/lib/Dialect/SMT/IR/SMTOps.cpp @@ -405,7 +405,7 @@ static void buildQuantifier( SmallVector<Location>(boundVarTypes.size(), odsState.location)); Value returnVal = bodyBuilder(odsBuilder, odsState.location, block->getArguments()); - odsBuilder.create<smt::YieldOp>(odsState.location, returnVal); + smt::YieldOp::create(odsBuilder, odsState.location, returnVal); } if (patternBuilder) { Region *region = odsState.addRegion(); @@ -416,7 +416,7 @@ static void buildQuantifier( SmallVector<Location>(boundVarTypes.size(), odsState.location)); ValueRange returnVals = patternBuilder(odsBuilder, odsState.location, block->getArguments()); - odsBuilder.create<smt::YieldOp>(odsState.location, returnVals); + smt::YieldOp::create(odsBuilder, odsState.location, returnVals); } } diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp 
b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp index c9a8e97..fcf1526 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp @@ -92,11 +92,13 @@ struct SPIRVInlinerInterface : public DialectInlinerInterface { /// as necessary. void handleTerminator(Operation *op, Block *newDest) const final { if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) { - OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest); + auto builder = OpBuilder(op); + spirv::BranchOp::create(builder, op->getLoc(), newDest); op->erase(); } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) { - OpBuilder(op).create<spirv::BranchOp>(retValOp->getLoc(), newDest, - retValOp->getOperands()); + auto builder = OpBuilder(op); + spirv::BranchOp::create(builder, retValOp->getLoc(), newDest, + retValOp->getOperands()); op->erase(); } } @@ -665,19 +667,17 @@ static ParseResult parseStructMemberDecorations( // Parse member decoration value if it exists. if (succeeded(parser.parseOptionalEqual())) { - auto memberDecorationValue = - parseAndVerifyInteger<uint32_t>(dialect, parser); - - if (!memberDecorationValue) + Attribute memberDecorationValue; + if (failed(parser.parseAttribute(memberDecorationValue))) return failure(); memberDecorationInfo.emplace_back( - static_cast<uint32_t>(memberTypes.size() - 1), 1, - memberDecoration.value(), memberDecorationValue.value()); + static_cast<uint32_t>(memberTypes.size() - 1), + memberDecoration.value(), memberDecorationValue); } else { memberDecorationInfo.emplace_back( - static_cast<uint32_t>(memberTypes.size() - 1), 0, - memberDecoration.value(), 0); + static_cast<uint32_t>(memberTypes.size() - 1), + memberDecoration.value(), UnitAttr::get(dialect.getContext())); } return success(); }; @@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations( // `!spirv.struct<` (id `,`)? // `(` // (spirv-type (`[` struct-member-decoration `]`)?)* -// `)>` +// `)` +// (`,` struct-decoration)? +// `>` static Type parseStructType(SPIRVDialect const &dialect, DialectAsmParser &parser) { // TODO: This function is quite lengthy. Break it down into smaller chunks. @@ -767,17 +769,48 @@ static Type parseStructType(SPIRVDialect const &dialect, return Type(); } - if (failed(parser.parseRParen()) || failed(parser.parseGreater())) + if (failed(parser.parseRParen())) + return Type(); + + SmallVector<StructType::StructDecorationInfo, 1> structDecorationInfo; + + auto parseStructDecoration = [&]() { + std::optional<spirv::Decoration> decoration = + parseAndVerify<spirv::Decoration>(dialect, parser); + if (!decoration) + return failure(); + + // Parse decoration value if it exists. 
+ if (succeeded(parser.parseOptionalEqual())) { + Attribute decorationValue; + if (failed(parser.parseAttribute(decorationValue))) + return failure(); + + structDecorationInfo.emplace_back(decoration.value(), decorationValue); + } else { + structDecorationInfo.emplace_back(decoration.value(), + UnitAttr::get(dialect.getContext())); + } + return success(); + }; + + while (succeeded(parser.parseOptionalComma())) + if (failed(parseStructDecoration())) + return Type(); + + if (failed(parser.parseGreater())) return Type(); if (!identifier.empty()) { if (failed(idStructTy.trySetBody(memberTypes, offsetInfo, - memberDecorationInfo))) + memberDecorationInfo, + structDecorationInfo))) return Type(); return idStructTy; } - return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); + return StructType::get(memberTypes, offsetInfo, memberDecorationInfo, + structDecorationInfo); } // spirv-type ::= array-type @@ -882,8 +915,9 @@ static void print(StructType type, DialectAsmPrinter &os) { } auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) { os << stringifyDecoration(decoration.decoration); - if (decoration.hasValue) { - os << "=" << decoration.decorationValue; + if (decoration.hasValue()) { + os << "="; + os.printAttributeWithoutType(decoration.decorationValue); } }; llvm::interleaveComma(decorations, os, eachFn); @@ -892,7 +926,23 @@ static void print(StructType type, DialectAsmPrinter &os) { }; llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os, printMember); - os << ")>"; + os << ")"; + + SmallVector<spirv::StructType::StructDecorationInfo, 1> decorations; + type.getStructDecorations(decorations); + if (!decorations.empty()) { + os << ", "; + auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) { + os << stringifyDecoration(decoration.decoration); + if (decoration.hasValue()) { + os << "="; + os.printAttributeWithoutType(decoration.decorationValue); + } + }; + llvm::interleaveComma(decorations, os, eachFn); + } + + os << ">"; } static void print(CooperativeMatrixType type, DialectAsmPrinter &os) { diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index 46739bc..ddb3426 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -835,12 +835,14 @@ void SampledImageType::getCapabilities( /// - for literal structs: /// - a list of member types; /// - a list of member offset info; -/// - a list of member decoration info. +/// - a list of member decoration info; +/// - a list of struct decoration info. /// /// Identified structures only have a mutable component consisting of: /// - a list of member types; /// - a list of member offset info; -/// - a list of member decoration info. +/// - a list of member decoration info; +/// - a list of struct decoration info. struct spirv::detail::StructTypeStorage : public TypeStorage { /// Construct a storage object for an identified struct type. A struct type /// associated with such storage must call StructType::trySetBody(...) later @@ -848,6 +850,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage(StringRef identifier) : memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr), numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr), + numStructDecorations(0), structDecorationsInfo(nullptr), identifier(identifier) {} /// Construct a storage object for a literal struct type. 
A struct type @@ -855,10 +858,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { StructTypeStorage( unsigned numMembers, Type const *memberTypes, StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations, - StructType::MemberDecorationInfo const *memberDecorationsInfo) + StructType::MemberDecorationInfo const *memberDecorationsInfo, + unsigned numStructDecorations, + StructType::StructDecorationInfo const *structDecorationsInfo) : memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo), numMembers(numMembers), numMemberDecorations(numMemberDecorations), - memberDecorationsInfo(memberDecorationsInfo) {} + memberDecorationsInfo(memberDecorationsInfo), + numStructDecorations(numStructDecorations), + structDecorationsInfo(structDecorationsInfo) {} /// A storage key is divided into 2 parts: /// - for identified structs: @@ -867,16 +874,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { /// - an ArrayRef<Type> for member types; /// - an ArrayRef<StructType::OffsetInfo> for member offset info; /// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration + /// info; + /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration /// info. /// /// An identified struct type is uniqued only by the first part (field 0) /// of the key. /// - /// A literal struct type is uniqued only by the second part (fields 1, 2, and - /// 3) of the key. The identifier field (field 0) must be empty. + /// A literal struct type is uniqued only by the second part (fields 1, 2, 3 + /// and 4) of the key. The identifier field (field 0) must be empty. using KeyTy = std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>, - ArrayRef<StructType::MemberDecorationInfo>>; + ArrayRef<StructType::MemberDecorationInfo>, + ArrayRef<StructType::StructDecorationInfo>>; /// For identified structs, return true if the given key contains the same /// identifier. 
@@ -890,7 +900,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { } return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(), - getMemberDecorationsInfo()); + getMemberDecorationsInfo(), getStructDecorationsInfo()); } /// If the given key contains a non-empty identifier, this method constructs @@ -937,9 +947,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { memberDecorationList = allocator.copyInto(keyMemberDecorations).data(); } - return new (allocator.allocate<StructTypeStorage>()) - StructTypeStorage(keyTypes.size(), typesList, offsetInfoList, - numMemberDecorations, memberDecorationList); + const StructType::StructDecorationInfo *structDecorationList = nullptr; + unsigned numStructDecorations = 0; + if (!std::get<4>(key).empty()) { + auto keyStructDecorations = std::get<4>(key); + numStructDecorations = keyStructDecorations.size(); + structDecorationList = allocator.copyInto(keyStructDecorations).data(); + } + + return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage( + keyTypes.size(), typesList, offsetInfoList, numMemberDecorations, + memberDecorationList, numStructDecorations, structDecorationList); } ArrayRef<Type> getMemberTypes() const { @@ -961,6 +979,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { return {}; } + ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const { + if (structDecorationsInfo) + return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo, + numStructDecorations); + return {}; + } + StringRef getIdentifier() const { return identifier; } bool isIdentified() const { return !identifier.empty(); } @@ -973,17 +998,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { /// - If called for an identified struct whose body was set before (through a /// call to this method) but with different contents from the passed /// arguments. 
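For reference, the grammar extension parsed and printed in the SPIRVDialect.cpp hunks above admits struct-level decorations after the member list; an illustrative type string built from that grammar (my own example, not copied from the commit's tests) is:

!spirv.struct<(f32 [0]), Block>

Here Block is a value-less struct decoration (printed without '='), while a valued decoration would print as Decoration=<attr> via printAttributeWithoutType.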
- LogicalResult mutate( - TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes, - ArrayRef<StructType::OffsetInfo> structOffsetInfo, - ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) { + LogicalResult + mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes, + ArrayRef<StructType::OffsetInfo> structOffsetInfo, + ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo, + ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) { if (!isIdentified()) return failure(); if (memberTypesAndIsBodySet.getInt() && (getMemberTypes() != structMemberTypes || getOffsetInfo() != structOffsetInfo || - getMemberDecorationsInfo() != structMemberDecorationInfo)) + getMemberDecorationsInfo() != structMemberDecorationInfo || + getStructDecorationsInfo() != structDecorationInfo)) return failure(); memberTypesAndIsBodySet.setInt(true); @@ -1007,6 +1034,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { allocator.copyInto(structMemberDecorationInfo).data(); } + if (!structDecorationInfo.empty()) { + numStructDecorations = structDecorationInfo.size(); + structDecorationsInfo = allocator.copyInto(structDecorationInfo).data(); + } + return success(); } @@ -1015,21 +1047,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage { unsigned numMembers; unsigned numMemberDecorations; StructType::MemberDecorationInfo const *memberDecorationsInfo; + unsigned numStructDecorations; + StructType::StructDecorationInfo const *structDecorationsInfo; StringRef identifier; }; StructType StructType::get(ArrayRef<Type> memberTypes, ArrayRef<StructType::OffsetInfo> offsetInfo, - ArrayRef<StructType::MemberDecorationInfo> memberDecorations) { + ArrayRef<StructType::MemberDecorationInfo> memberDecorations, + ArrayRef<StructType::StructDecorationInfo> structDecorations) { assert(!memberTypes.empty() && "Struct needs at least one member type"); // Sort the decorations. - SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations( + SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations( memberDecorations); - llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end()); + llvm::array_pod_sort(sortedMemberDecorations.begin(), + sortedMemberDecorations.end()); + SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations( + structDecorations); + llvm::array_pod_sort(sortedStructDecorations.begin(), + sortedStructDecorations.end()); + return Base::get(memberTypes.vec().front().getContext(), /*identifier=*/StringRef(), memberTypes, offsetInfo, - sortedDecorations); + sortedMemberDecorations, sortedStructDecorations); } StructType StructType::getIdentified(MLIRContext *context, @@ -1039,18 +1080,21 @@ StructType StructType::getIdentified(MLIRContext *context, return Base::get(context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()); + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()); } StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) { StructType newStructType = Base::get( context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()); + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()); // Set an empty body in case this is an identified struct.
if (newStructType.isIdentified() && failed(newStructType.trySetBody( ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(), - ArrayRef<StructType::MemberDecorationInfo>()))) + ArrayRef<StructType::MemberDecorationInfo>(), + ArrayRef<StructType::StructDecorationInfo>()))) return StructType(); return newStructType; @@ -1074,6 +1118,15 @@ TypeRange StructType::getElementTypes() const { bool StructType::hasOffset() const { return getImpl()->offsetInfo; } +bool StructType::hasDecoration(spirv::Decoration decoration) const { + for (StructType::StructDecorationInfo info : + getImpl()->getStructDecorationsInfo()) + if (info.decoration == decoration) + return true; + + return false; +} + uint64_t StructType::getMemberOffset(unsigned index) const { assert(getNumElements() > index && "member index out of range"); return getImpl()->offsetInfo[index]; @@ -1105,11 +1158,21 @@ void StructType::getMemberDecorations( } } +void StructType::getStructDecorations( + SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations) + const { + structDecorations.clear(); + auto implDecorations = getImpl()->getStructDecorationsInfo(); + structDecorations.append(implDecorations.begin(), implDecorations.end()); +} + LogicalResult StructType::trySetBody(ArrayRef<Type> memberTypes, ArrayRef<OffsetInfo> offsetInfo, - ArrayRef<MemberDecorationInfo> memberDecorations) { - return Base::mutate(memberTypes, offsetInfo, memberDecorations); + ArrayRef<MemberDecorationInfo> memberDecorations, + ArrayRef<StructDecorationInfo> structDecorations) { + return Base::mutate(memberTypes, offsetInfo, memberDecorations, + structDecorations); } void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, @@ -1131,6 +1194,11 @@ llvm::hash_code spirv::hash_value( memberDecorationInfo.decoration); } +llvm::hash_code spirv::hash_value( + const StructType::StructDecorationInfo &structDecorationInfo) { + return llvm::hash_value(structDecorationInfo.decoration); +} + //===----------------------------------------------------------------------===// // MatrixType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 81365b4..3911ec0 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -58,7 +58,17 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp, spirv::PointerType::get(spirv::StructType::get(varType), *storageClass); } auto varPtrType = cast<spirv::PointerType>(varType); - auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType()); + Type pointeeType = varPtrType.getPointeeType(); + + // Images are an opaque type and so we can just return a pointer to an image. + // Note that currently only sampled images are supported in the SPIR-V + // lowering. + if (isa<spirv::SampledImageType>(pointeeType)) + return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType, + varName, abiInfo.getDescriptorSet(), + abiInfo.getBinding()); + + auto varPointeeType = cast<spirv::StructType>(pointeeType); // Set the offset information. 
varPointeeType = diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp index 35ec019..8f4c4cc 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) { return bitWidth / 8; } + // Handle 8-bit floats. + if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) { + auto bitWidth = type.getIntOrFloatBitWidth(); + if (bitWidth == 8) + return bitWidth / 8; + return std::nullopt; + } + if (auto complexType = dyn_cast<ComplexType>(type)) { auto elementSize = getTypeNumBytes(options, complexType.getElementType()); if (!elementSize) @@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options, type.getSignedness()); } +/// Converts 8-bit float types to integer types with the same bit width. +/// Returns a nullptr for unsupported 8-bit float types. +static Type convert8BitFloatType(const SPIRVConversionOptions &options, + FloatType type) { + if (!options.emulateUnsupportedFloatTypes) + return nullptr; + // F8 types are converted to integer types with the same bit width. + if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, + Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type, + Float8E8M0FNUType>(type)) + return IntegerType::get(type.getContext(), type.getWidth()); + LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n"); + return nullptr; +} + +/// Returns a type with the same shape but with any 8-bit float element type +/// converted to the same bit width integer type. This is a noop when the +/// element type is not the 8-bit float type or emulation flag is set to false. +static ShapedType +convertShaped8BitFloatType(ShapedType type, + const SPIRVConversionOptions &options) { + if (!options.emulateUnsupportedFloatTypes) + return type; + Type srcElementType = type.getElementType(); + Type convertedElementType = nullptr; + // F8 types are converted to integer types with the same bit width. + if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType, + Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type, + Float8E8M0FNUType>(srcElementType)) + convertedElementType = IntegerType::get( + type.getContext(), srcElementType.getIntOrFloatBitWidth()); + + if (!convertedElementType) + return type; + + return type.clone(convertedElementType); +} + /// Returns a type with the same shape but with any index element type converted /// to the matching integer type. This is a noop when the element type is not /// the index type. 
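A condensed sketch of the element-type rewrite performed by convertShaped8BitFloatType above when options.emulateUnsupportedFloatTypes is set; only two of the FP8 variants are spelled out here for brevity:

#include "mlir/IR/BuiltinTypes.h"

// Rewrites e.g. vector<4xf8E4M3> to vector<4xi8>; non-FP8 element types are
// returned unchanged. The pass additionally checks the emulation option and
// covers the FN/FNUZ/B11FNUZ/E8M0FNU variants.
static mlir::ShapedType emulateF8Elements(mlir::ShapedType type) {
  mlir::Type elem = type.getElementType();
  if (mlir::isa<mlir::Float8E5M2Type, mlir::Float8E4M3Type>(elem))
    return type.clone(mlir::IntegerType::get(elem.getContext(),
                                             elem.getIntOrFloatBitWidth()));
  return type;
}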
@@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv, const SPIRVConversionOptions &options, VectorType type, std::optional<spirv::StorageClass> storageClass = {}) { type = cast<VectorType>(convertIndexElementType(type, options)); + type = cast<VectorType>(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType()); if (!scalarType) { // If this is not a spec allowed scalar type, try to handle sub-byte integer @@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv, } type = cast<TensorType>(convertIndexElementType(type, options)); + type = cast<TensorType>(convertShaped8BitFloatType(type, options)); auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType()); if (!scalarType) { LLVM_DEBUG(llvm::dbgs() @@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv, } else if (auto indexType = dyn_cast<IndexType>(elementType)) { type = cast<MemRefType>(convertIndexElementType(type, options)); arrayElemType = type.getElementType(); + } else if (auto floatType = dyn_cast<FloatType>(elementType)) { + // Handle 8-bit float types. + type = cast<MemRefType>(convertShaped8BitFloatType(type, options)); + arrayElemType = type.getElementType(); } else { LLVM_DEBUG( llvm::dbgs() @@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr, addConversion([this](FloatType floatType) -> std::optional<Type> { if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType)) return convertScalarType(this->targetEnv, this->options, scalarType); + if (floatType.getWidth() == 8) + return convert8BitFloatType(this->options, floatType); return Type(); }); diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp index 6a9b951..a53d0a7 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp @@ -174,6 +174,21 @@ void UpdateVCEPass::runOnOperation() { if (walkResult.wasInterrupted()) return signalPassFailure(); + // Update min version requirement for capabilities after deducing them. + for (spirv::Capability cap : deducedCapabilities) { + if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) { + deducedVersion = std::max(deducedVersion, *minVersion); + if (deducedVersion > allowedVersion) { + module.emitError("Capability '") + << spirv::stringifyCapability(cap) << "' requires min version " + << spirv::stringifyVersion(deducedVersion) + << " but target environment allows up to " + << spirv::stringifyVersion(allowedVersion); + return signalPassFailure(); + } + } + } + // TODO: verify that the deduced version is consistent with // SPIR-V ops' maximal version requirements.
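The new UpdateVCEPass loop above folds each deduced capability's minimal SPIR-V version into the deduced version and fails if the target environment cannot satisfy it; a self-contained sketch of that fold with stand-in types (spirv::Version is an enum in the real pass):

#include <algorithm>
#include <optional>
#include <vector>

using Version = int; // Stand-in for spirv::Version.

// Returns the raised version, or std::nullopt where the pass would emit an
// error and signal failure.
std::optional<Version>
foldMinVersions(Version deduced, Version allowed,
                const std::vector<Version> &capabilityMinVersions) {
  for (Version required : capabilityMinVersions) {
    deduced = std::max(deduced, required);
    if (deduced > allowed)
      return std::nullopt;
  }
  return deduced;
}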
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 7805599..5ba8289 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -150,17 +150,17 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto poison = dyn_cast<ub::PoisonAttr>(value)) - return builder.create<ub::PoisonOp>(loc, type, poison); + return ub::PoisonOp::create(builder, loc, type, poison); if (llvm::isa<ShapeType>(type) || isExtentTensorType(type)) - return builder.create<ConstShapeOp>( - loc, type, llvm::cast<DenseIntElementsAttr>(value)); + return ConstShapeOp::create(builder, loc, type, + llvm::cast<DenseIntElementsAttr>(value)); if (llvm::isa<SizeType>(type)) - return builder.create<ConstSizeOp>(loc, type, - llvm::cast<IntegerAttr>(value)); + return ConstSizeOp::create(builder, loc, type, + llvm::cast<IntegerAttr>(value)); if (llvm::isa<WitnessType>(type)) - return builder.create<ConstWitnessOp>(loc, type, - llvm::cast<BoolAttr>(value)); + return ConstWitnessOp::create(builder, loc, type, + llvm::cast<BoolAttr>(value)); return arith::ConstantOp::materialize(builder, value, type, loc); } @@ -315,8 +315,8 @@ struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> { auto newYieldOp = rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands); rewriter.setInsertionPoint(op); - auto newOp = rewriter.create<AssumingOp>( - op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); + auto newOp = AssumingOp::create( + rewriter, op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness()); newOp.getDoRegion().takeBody(op.getDoRegion()); // Use the new results to replace the previously used ones. @@ -384,7 +384,7 @@ void AssumingOp::build( // Build body. 
SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location); - builder.create<AssumingYieldOp>(result.location, yieldValues); + AssumingYieldOp::create(builder, result.location, yieldValues); SmallVector<Type, 2> assumingTypes; for (Value v : yieldValues) @@ -735,13 +735,13 @@ struct BroadcastForwardSingleOperandPattern if (replacement.getType() != op.getType()) { auto loc = op.getLoc(); if (llvm::isa<ShapeType>(op.getType())) { - replacement = rewriter.create<FromExtentTensorOp>(loc, replacement); + replacement = FromExtentTensorOp::create(rewriter, loc, replacement); } else { assert(!llvm::isa<ShapeType>(op.getType()) && !llvm::isa<ShapeType>(replacement.getType()) && "expect extent tensor cast"); replacement = - rewriter.create<tensor::CastOp>(loc, op.getType(), replacement); + tensor::CastOp::create(rewriter, loc, op.getType(), replacement); } } @@ -779,9 +779,9 @@ struct BroadcastFoldConstantOperandsPattern auto foldedConstantOperandsTy = RankedTensorType::get( {static_cast<int64_t>(foldedConstantShape.size())}, rewriter.getIndexType()); - newShapeOperands.push_back(rewriter.create<ConstShapeOp>( - op.getLoc(), foldedConstantOperandsTy, - rewriter.getIndexTensorAttr(foldedConstantShape))); + newShapeOperands.push_back( + ConstShapeOp::create(rewriter, op.getLoc(), foldedConstantOperandsTy, + rewriter.getIndexTensorAttr(foldedConstantShape))); rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), newShapeOperands); return success(); @@ -844,9 +844,9 @@ struct BroadcastConcretizeResultTypePattern } } - auto newOp = rewriter.create<BroadcastOp>( - op.getLoc(), getExtentTensorType(getContext(), maxRank), - op.getShapes()); + auto newOp = BroadcastOp::create(rewriter, op.getLoc(), + getExtentTensorType(getContext(), maxRank), + op.getShapes()); rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp); return success(); } @@ -1353,11 +1353,11 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, auto loc = result.location; auto dimAttr = builder.getIndexAttr(dim); if (llvm::isa<ShapeType>(shape.getType())) { - Value dim = builder.create<ConstSizeOp>(loc, dimAttr); + Value dim = ConstSizeOp::create(builder, loc, dimAttr); build(builder, result, builder.getType<SizeType>(), shape, dim); } else { - Value dim = - builder.create<arith::ConstantOp>(loc, builder.getIndexType(), dimAttr); + Value dim = arith::ConstantOp::create(builder, loc, builder.getIndexType(), + dimAttr); build(builder, result, builder.getIndexType(), shape, dim); } } @@ -1702,13 +1702,12 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> { return failure(); Location loc = op.getLoc(); Value constShape = - rewriter - .create<ConstShapeOp>(loc, - rewriter.getIndexTensorAttr(type.getShape())) + ConstShapeOp::create(rewriter, loc, + rewriter.getIndexTensorAttr(type.getShape())) .getResult(); if (constShape.getType() != op.getResult().getType()) - constShape = rewriter.create<tensor::CastOp>( - loc, op.getResult().getType(), constShape); + constShape = tensor::CastOp::create(rewriter, loc, + op.getResult().getType(), constShape); rewriter.replaceOp(op, constShape); return success(); } @@ -1750,10 +1749,11 @@ struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> { if (opTensorTy != shapeTensorTy) { if (opTensorTy.getElementType() == shapeTensorTy.getElementType()) - shape = rewriter.create<tensor::CastOp>(op.getLoc(), opTensorTy, shape); - else if (!isExtentTensorType(shapeTensorTy)) shape = - 
rewriter.create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape); + tensor::CastOp::create(rewriter, op.getLoc(), opTensorTy, shape); + else if (!isExtentTensorType(shapeTensorTy)) + shape = arith::IndexCastOp::create(rewriter, op.getLoc(), opTensorTy, + shape); } rewriter.replaceOp(op, shape); diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp index e405475..f6bc225 100644 --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -55,8 +55,8 @@ struct AssumingOpInterface // Create new op and move over region. TypeRange newResultTypes(yieldOp.getOperands()); - auto newOp = rewriter.create<shape::AssumingOp>( - op->getLoc(), newResultTypes, assumingOp.getWitness()); + auto newOp = shape::AssumingOp::create( + rewriter, op->getLoc(), newResultTypes, assumingOp.getWitness()); newOp.getDoRegion().takeBody(assumingOp.getRegion()); // Update all uses of the old op. @@ -64,8 +64,9 @@ struct AssumingOpInterface SmallVector<Value> newResults; for (const auto &it : llvm::enumerate(assumingOp->getResultTypes())) { if (isa<TensorType>(it.value())) { - newResults.push_back(rewriter.create<bufferization::ToTensorOp>( - assumingOp.getLoc(), it.value(), newOp->getResult(it.index()))); + newResults.push_back(bufferization::ToTensorOp::create( + rewriter, assumingOp.getLoc(), it.value(), + newOp->getResult(it.index()))); } else { newResults.push_back(newOp->getResult(it.index())); } diff --git a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp index 0fe1072..b636797 100644 --- a/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/OutlineShapeComputation.cpp @@ -66,7 +66,7 @@ createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster, cluster.empty() ? 
b.getFunctionType(shape.getType(), shape.getType()) : b.getFunctionType(ValueRange(inputs).getTypes(), shape.getType()); - shape::FuncOp fnOp = b.create<shape::FuncOp>(loc, fnName, fnType); + shape::FuncOp fnOp = shape::FuncOp::create(b, loc, fnName, fnType); Block *block = fnOp.addEntryBlock(); b.setInsertionPointToEnd(block); IRMapping bvm; @@ -82,7 +82,7 @@ createFuncFromCluster(OpBuilder &b, const SmallVector<Operation *, 8> &cluster, llvm::SmallVector<Value, 4> fnReturns; fnReturns.push_back(bvm.lookupOrDefault(shape)); - b.create<shape::ReturnOp>(loc, fnReturns); + shape::ReturnOp::create(b, loc, fnReturns); fnOp.setPrivate(); return std::make_pair(fnOp, inputs); } @@ -184,7 +184,7 @@ class TensorDimOpRewriter : public OpRewritePattern<tensor::DimOp> { LogicalResult matchAndRewrite(tensor::DimOp op, PatternRewriter &rewriter) const override { auto shapeOf = - rewriter.create<shape::ShapeOfOp>(op.getLoc(), op.getSource()); + shape::ShapeOfOp::create(rewriter, op.getLoc(), op.getSource()); rewriter.replaceOpWithNewOp<shape::GetExtentOp>(op, op.getType(), shapeOf, op.getIndex()); return success(); diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp index d83ceab..3c363f3 100644 --- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp @@ -43,14 +43,14 @@ NumElementsOpConverter::matchAndRewrite(NumElementsOp op, ->materializeConstant(rewriter, rewriter.getIndexAttr(1), valueType, loc) ->getResult(0); - ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.getShape(), init); + ReduceOp reduce = ReduceOp::create(rewriter, loc, op.getShape(), init); // Generate reduce operator. Block *body = reduce.getBody(); OpBuilder b = OpBuilder::atBlockEnd(body); - Value product = b.create<MulOp>(loc, valueType, body->getArgument(1), - body->getArgument(2)); - b.create<shape::YieldOp>(loc, product); + Value product = MulOp::create(b, loc, valueType, body->getArgument(1), + body->getArgument(2)); + shape::YieldOp::create(b, loc, product); rewriter.replaceOp(op, reduce.getResult()); return success(); diff --git a/mlir/lib/Dialect/Mesh/CMakeLists.txt b/mlir/lib/Dialect/Shard/CMakeLists.txt index fa8842f..fa8842f 100644 --- a/mlir/lib/Dialect/Mesh/CMakeLists.txt +++ b/mlir/lib/Dialect/Shard/CMakeLists.txt diff --git a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt b/mlir/lib/Dialect/Shard/IR/CMakeLists.txt index 3fea4d6..70c604988 100644 --- a/mlir/lib/Dialect/Mesh/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Shard/IR/CMakeLists.txt @@ -1,11 +1,11 @@ -add_mlir_dialect_library(MLIRMeshDialect - MeshOps.cpp +add_mlir_dialect_library(MLIRShardDialect + ShardOps.cpp ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard DEPENDS - MLIRMeshIncGen + MLIRShardIncGen LINK_LIBS PUBLIC MLIRArithDialect diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp index 28608cb..08fccfa 100644 --- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp +++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp @@ -1,4 +1,4 @@ -//===- MeshOps.cpp - Mesh Dialect Operations ------------------------------===// +//===- ShardOps.cpp - Shard Dialect Operations ----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" @@ -37,13 +37,12 @@ #include <optional> #include <utility> -#define DEBUG_TYPE "mesh-ops" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") +#define DEBUG_TYPE "shard-ops" using namespace mlir; -using namespace mlir::mesh; +using namespace mlir::shard; -#include "mlir/Dialect/Mesh/IR/MeshDialect.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardDialect.cpp.inc" namespace { @@ -74,11 +73,10 @@ static DimensionSize operator*(DimensionSize lhs, DimensionSize rhs) { return lhs.value() * rhs.value(); } -SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b, - const Location &loc, - llvm::ArrayRef<int64_t> statics, - ValueRange dynamics, - Type type) { +SmallVector<Value> +mlir::shard::getMixedAsValues(OpBuilder b, const Location &loc, + llvm::ArrayRef<int64_t> statics, + ValueRange dynamics, Type type) { SmallVector<Value> values; auto dyn = dynamics.begin(); Type i64 = b.getI64Type(); @@ -102,7 +100,7 @@ SmallVector<Value> mlir::mesh::getMixedAsValues(OpBuilder b, //===----------------------------------------------------------------------===// namespace { -struct MeshInlinerInterface : public DialectInlinerInterface { +struct ShardInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; // Currently no restrictions are encoded for inlining. bool isLegalToInline(Operation *, Operation *, bool) const final { @@ -118,44 +116,45 @@ struct MeshInlinerInterface : public DialectInlinerInterface { } // namespace //===----------------------------------------------------------------------===// -// Mesh dialect +// Shard dialect //===----------------------------------------------------------------------===// -void MeshDialect::initialize() { +void ShardDialect::initialize() { addOperations< #define GET_OP_LIST -#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc" >(); addAttributes< #define GET_ATTRDEF_LIST -#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc" >(); addTypes< #define GET_TYPEDEF_LIST -#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc" >(); - addInterface<MeshInlinerInterface>(); + addInterface<ShardInlinerInterface>(); } -Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value, - Type type, Location loc) { +Operation *ShardDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { return arith::ConstantOp::materialize(builder, value, type, loc); } //===----------------------------------------------------------------------===// -// Mesh utilities +// Shard utilities //===----------------------------------------------------------------------===// -static FailureOr<MeshOp> getMeshAndVerify(Operation *op, - FlatSymbolRefAttr meshSymbol, +static FailureOr<GridOp> getGridAndVerify(Operation *op, + FlatSymbolRefAttr gridSymbol, SymbolTableCollection &symbolTable) { - mesh::MeshOp mesh = getMeshOrNull(op, meshSymbol, symbolTable); - if (!mesh) { - return op->emitError() <<
"Undefined required mesh symbol \"" - << meshSymbol.getValue() << "\"."; + shard::GridOp grid = getGridOrNull(op, gridSymbol, symbolTable); + if (!grid) { + return op->emitError() << "Undefined required grid symbol \"" + << gridSymbol.getValue() << "\"."; } - return mesh; + return grid; } template <typename It> @@ -175,20 +174,20 @@ bool isUnique(It begin, It end) { return true; } -static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, - MeshOp mesh) { - SmallVector<MeshAxis> sorted = llvm::to_vector(axes); +static LogicalResult verifyGridAxes(Location loc, ArrayRef<GridAxis> axes, + GridOp grid) { + SmallVector<GridAxis> sorted = llvm::to_vector(axes); llvm::sort(sorted); if (!isUnique(sorted.begin(), sorted.end())) { - return emitError(loc) << "Mesh axes contains duplicate elements."; + return emitError(loc) << "Grid axes contains duplicate elements."; } - MeshAxis rank = mesh.getRank(); + GridAxis rank = grid.getRank(); for (auto axis : axes) { if (axis >= rank || axis < 0) { return emitError(loc) - << "0-based mesh axis index " << axis - << " is out of bounds. The referenced mesh \"" << mesh.getSymName() + << "0-based grid axis index " << axis + << " is out of bounds. The referenced grid \"" << grid.getSymName() << "\" is of rank " << rank << "."; } } @@ -197,22 +196,22 @@ static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes, } template <typename Op> -static FailureOr<MeshOp> -getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) { - auto mesh = - ::getMeshAndVerify(op.getOperation(), op.getMeshAttr(), symbolTable); - if (failed(mesh)) { +static FailureOr<GridOp> +getGridAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) { + auto grid = + ::getGridAndVerify(op.getOperation(), op.getGridAttr(), symbolTable); + if (failed(grid)) { return failure(); } - if (failed(verifyMeshAxes(op.getLoc(), op.getMeshAxes(), mesh.value()))) { + if (failed(verifyGridAxes(op.getLoc(), op.getGridAxes(), grid.value()))) { return failure(); } - return mesh; + return grid; } -template <typename InShape, typename MeshShape, typename SplitAxes, +template <typename InShape, typename GridShape, typename SplitAxes, typename OutShape> -static void shardShape(const InShape &inShape, const MeshShape &meshShape, +static void shardShape(const InShape &inShape, const GridShape &gridShape, const SplitAxes &splitAxes, OutShape &outShape, ArrayRef<int64_t> shardedDimsOffsets = {}, ArrayRef<int64_t> haloSizes = {}) { @@ -226,7 +225,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape, llvm::adl_begin(outShape)); if (!shardedDimsOffsets.empty()) { - auto isDynShape = ShapedType::isDynamicShape(meshShape); + auto isDynShape = ShapedType::isDynamicShape(gridShape); uint64_t pos = 1; for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { if (!innerSplitAxes.empty()) { @@ -238,7 +237,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape, // non-uniform offs in shardedDimsOffsets. 
uint64_t numShards = 0; for (auto i : innerSplitAxes.asArrayRef()) { - numShards += meshShape[i]; + numShards += gridShape[i]; } for (size_t i = 1; i < numShards; ++i) { if (shardedDimsOffsets[pos + i] - shardedDimsOffsets[pos + i - 1] != @@ -256,7 +255,7 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape, for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) { outShape[tensorAxis] = shardDimension( inShape[tensorAxis], - collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), meshShape)); + collectiveProcessGroupSize(innerSplitAxes.asArrayRef(), gridShape)); } if (!haloSizes.empty()) { @@ -279,25 +278,25 @@ static void shardShape(const InShape &inShape, const MeshShape &meshShape, } } -ShapedType mesh::shardShapedType(ShapedType shape, MeshOp mesh, - MeshSharding sharding) { +ShapedType shard::shardShapedType(ShapedType shape, GridOp grid, + Sharding sharding) { using Dim = std::decay_t<decltype(shape.getDimSize(0))>; SmallVector<Dim> resShapeArr(shape.getShape().size()); - shardShape(shape.getShape(), mesh.getShape(), sharding.getSplitAxes(), + shardShape(shape.getShape(), grid.getShape(), sharding.getSplitAxes(), resShapeArr, sharding.getStaticShardedDimsOffsets(), sharding.getStaticHaloSizes()); return shape.clone(resShapeArr); } -Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) { +Type shard::shardType(Type type, GridOp grid, Sharding sharding) { RankedTensorType rankedTensorType = dyn_cast<RankedTensorType>(type); if (rankedTensorType && !rankedTensorType.getShape().empty()) { - return shardShapedType(rankedTensorType, mesh, sharding); + return shardShapedType(rankedTensorType, grid, sharding); } return type; } -static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding, +static void maybeInsertTargetShardingAnnotationImpl(Sharding sharding, Value &operandValue, Operation *operandOp, OpBuilder &builder, @@ -336,9 +335,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding, newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2); } -void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, - OpResult result, - OpBuilder &builder) { +void mlir::shard::maybeInsertTargetShardingAnnotation(Sharding sharding, + OpResult result, + OpBuilder &builder) { ShardOp newShardOp; SmallVector<std::pair<Value, Operation *>> uses; for (auto &use : result.getUses()) { @@ -350,9 +349,9 @@ void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding, } } -void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, - OpOperand &operand, - OpBuilder &builder) { +void mlir::shard::maybeInsertSourceShardingAnnotation(Sharding sharding, + OpOperand &operand, + OpBuilder &builder) { OpBuilder::InsertionGuard insertionGuard(builder); Value operandValue = operand.get(); Operation *operandSrcOp = operandValue.getDefiningOp(); @@ -404,18 +403,18 @@ void mlir::mesh::maybeInsertSourceShardingAnnotation(MeshSharding sharding, } //===----------------------------------------------------------------------===// -// mesh.mesh op +// shard.grid op //===----------------------------------------------------------------------===// -LogicalResult MeshOp::verify() { +LogicalResult GridOp::verify() { int64_t rank = getRank(); if (rank <= 0) - return emitOpError("rank of mesh is expected to be a positive integer"); + return emitOpError("rank of grid is expected to be a positive integer"); for (int64_t dimSize : getShape()) { if (dimSize < 0 && 
ShapedType::isStatic(dimSize)) - return emitOpError("dimension size of a mesh is expected to be " + return emitOpError("dimension size of a grid is expected to be " "non-negative or dynamic"); } @@ -423,21 +422,21 @@ LogicalResult MeshOp::verify() { } //===----------------------------------------------------------------------===// -// mesh.mesh_shape op +// shard.grid_shape op //===----------------------------------------------------------------------===// LogicalResult -MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); - if (failed(mesh)) { +GridShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable); + if (failed(grid)) { return failure(); } - if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { + if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) { return failure(); } size_t expectedResultsCount = - getAxes().empty() ? mesh->getRank() : getAxes().size(); + getAxes().empty() ? grid->getRank() : getAxes().size(); if (getResult().size() != expectedResultsCount) { return emitError() << "Unexpected number of results " << getResult().size() << ". Expected " << expectedResultsCount << "."; @@ -446,53 +445,53 @@ MeshShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } -void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, - MeshOp mesh) { - build(odsBuilder, odsState, mesh, SmallVector<MeshAxis>()); +void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, + GridOp grid) { + build(odsBuilder, odsState, grid, SmallVector<GridAxis>()); } -void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, - MeshOp mesh, ArrayRef<MeshAxis> axes) { +void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, + GridOp grid, ArrayRef<GridAxis> axes) { build(odsBuilder, odsState, - SmallVector<Type>(axes.empty() ? mesh.getRank() : axes.size(), + SmallVector<Type>(axes.empty() ? 
grid.getRank() : axes.size(), odsBuilder.getIndexType()), - mesh.getSymName(), MeshAxesAttr::get(odsBuilder.getContext(), axes)); + grid.getSymName(), GridAxesAttr::get(odsBuilder.getContext(), axes)); } -void MeshShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef mesh, ArrayRef<MeshAxis> axes) { +void GridShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState, + StringRef grid, ArrayRef<GridAxis> axes) { assert(!axes.empty()); build(odsBuilder, odsState, - SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, - MeshAxesAttr::get(odsBuilder.getContext(), axes)); + SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid, + GridAxesAttr::get(odsBuilder.getContext(), axes)); } -void MeshShapeOp::getAsmResultNames( +void GridShapeOp::getAsmResultNames( function_ref<void(Value, StringRef)> setNameFn) { - setNameFn(getResults()[0], "mesh_shape"); + setNameFn(getResults()[0], "grid_shape"); } //===----------------------------------------------------------------------===// -// mesh.sharding +// shard.sharding //===----------------------------------------------------------------------===// void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - FlatSymbolRefAttr mesh, - ArrayRef<MeshAxesAttr> split_axes, + FlatSymbolRefAttr grid, + ArrayRef<GridAxesAttr> split_axes, ArrayRef<int64_t> static_halos, ArrayRef<int64_t> static_offsets) { return build( - b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), + b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes), ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); } void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - llvm::StringRef mesh, ArrayRef<MeshAxesAttr> split_axes, + llvm::StringRef grid, ArrayRef<GridAxesAttr> split_axes, ArrayRef<int64_t> static_halos, ArrayRef<int64_t> static_offsets) { - return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), mesh), - MeshAxesArrayAttr::get(b.getContext(), split_axes), + return build(b, odsState, FlatSymbolRefAttr::get(b.getContext(), grid), + GridAxesArrayAttr::get(b.getContext(), split_axes), ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_halos), {}, ::mlir::DenseI64ArrayAttr::get(b.getContext(), static_offsets), {}); @@ -500,7 +499,7 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, void ShardingOp::build( ::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - FlatSymbolRefAttr mesh, ArrayRef<MeshAxesAttr> split_axes, + FlatSymbolRefAttr grid, ArrayRef<GridAxesAttr> split_axes, ::mlir::ArrayRef<::mlir::OpFoldResult> halo_sizes, ::mlir::ArrayRef<::mlir::OpFoldResult> sharded_dims_offsets) { mlir::SmallVector<int64_t> staticHalos, staticDims; @@ -508,16 +507,16 @@ void ShardingOp::build( dispatchIndexOpFoldResults(halo_sizes, dynamicHalos, staticHalos); dispatchIndexOpFoldResults(sharded_dims_offsets, dynamicDims, staticDims); return build( - b, odsState, mesh, MeshAxesArrayAttr::get(b.getContext(), split_axes), + b, odsState, grid, GridAxesArrayAttr::get(b.getContext(), split_axes), ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticHalos), dynamicHalos, ::mlir::DenseI64ArrayAttr::get(b.getContext(), staticDims), dynamicDims); } void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, - mlir::mesh::MeshSharding from) { + mlir::shard::Sharding from) { - build(b, odsState, ShardingType::get(b.getContext()), 
from.getMeshAttr(), - MeshAxesArrayAttr::get(b.getContext(), from.getSplitAxes()), + build(b, odsState, ShardingType::get(b.getContext()), from.getGridAttr(), + GridAxesArrayAttr::get(b.getContext(), from.getSplitAxes()), from.getStaticShardedDimsOffsets().empty() ? DenseI64ArrayAttr() : b.getDenseI64ArrayAttr(from.getStaticShardedDimsOffsets()), @@ -529,21 +528,21 @@ void ShardingOp::build(::mlir::OpBuilder &b, ::mlir::OperationState &odsState, } LogicalResult ShardingOp::verify() { - llvm::SmallSet<MeshAxis, 4> visitedAxes; + llvm::SmallSet<GridAxis, 4> visitedAxes; - auto checkMeshAxis = [&](ArrayRef<MeshAxis> axesArray) -> LogicalResult { - for (MeshAxis axis : axesArray) { + auto checkGridAxis = [&](ArrayRef<GridAxis> axesArray) -> LogicalResult { + for (GridAxis axis : axesArray) { if (axis < 0) - return emitError() << "mesh axis is expected to be non-negative"; + return emitError() << "grid axis is expected to be non-negative"; if (!visitedAxes.insert(axis).second) - return emitError() << "mesh axis duplicated"; + return emitError() << "grid axis duplicated"; } return success(); }; for (auto subAxes : getSplitAxes().getAxes()) { - ArrayRef<MeshAxis> subAxesArray = subAxes.asArrayRef(); - if (failed(checkMeshAxis(subAxesArray))) + ArrayRef<GridAxis> subAxesArray = subAxes.asArrayRef(); + if (failed(checkGridAxis(subAxesArray))) return failure(); } @@ -572,26 +571,26 @@ void ShardingOp::getAsmResultNames( } LogicalResult ShardingOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); - if (failed(mesh)) { + auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable); + if (failed(grid)) { return failure(); } - if (mlir::ShapedType::isDynamicShape(mesh->getShape()) && + if (mlir::ShapedType::isDynamicShape(grid->getShape()) && getStaticShardedDimsOffsets().size() > 0) { return emitError() << "sharded dims offsets are not allowed for " - "devices meshes with dynamic shape."; + "device grids with dynamic shape."; } auto shardedDimsOffsets = getStaticShardedDimsOffsets(); if (!shardedDimsOffsets.empty()) { - auto meshShape = mesh.value().getShape(); - assert(ShapedType::isStaticShape(meshShape)); + auto gridShape = grid.value().getShape(); + assert(ShapedType::isStaticShape(gridShape)); uint64_t pos = 0; for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(getSplitAxes())) { if (!innerSplitAxes.empty()) { int64_t numShards = 0, off = 0; for (auto i : innerSplitAxes.asArrayRef()) { - numShards += meshShape[i]; + numShards += gridShape[i]; } for (int64_t i = 0; i <= numShards; ++i) { if (shardedDimsOffsets.size() <= pos + i) { @@ -684,11 +683,11 @@ void ShardingOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, } //===----------------------------------------------------------------------===// -// MeshSharding +// Sharding //===----------------------------------------------------------------------===// -bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const { - if (getMesh() != rhs.getMesh()) { +bool Sharding::equalSplitAxes(const Sharding &rhs) const { + if (getGrid() != rhs.getGrid()) { return false; } @@ -701,16 +700,16 @@ bool MeshSharding::equalSplitAxes(const MeshSharding &rhs) const { } return llvm::all_of(llvm::drop_begin(getSplitAxes(), minSize), - std::mem_fn(&MeshAxesAttr::empty)) && + std::mem_fn(&GridAxesAttr::empty)) && llvm::all_of(llvm::drop_begin(rhs.getSplitAxes(), minSize), - std::mem_fn(&MeshAxesAttr::empty)); + std::mem_fn(&GridAxesAttr::empty)); } 
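For orientation, the comparison helpers keep their pre-rename semantics; only the types change. A minimal sketch under the new names, where a and b are assumed to be values produced by shard.sharding ops:

  // Illustrative only; Sharding(Value) asserts the value is defined by a
  // ShardingOp.
  Sharding lhs(a), rhs(b);
  // Sharding::operator== (below) is exactly this conjunction:
  bool same = lhs.equalSplitAxes(rhs) && lhs.equalHaloAndShardSizes(rhs);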
-bool MeshSharding::equalHaloAndShardSizes(const MeshSharding &rhs) const { +bool Sharding::equalHaloAndShardSizes(const Sharding &rhs) const { return equalShardSizes(rhs) && equalHaloSizes(rhs); } -bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const { +bool Sharding::equalShardSizes(const Sharding &rhs) const { if (rhs.getStaticShardedDimsOffsets().size() != getStaticShardedDimsOffsets().size() || !llvm::equal(getStaticShardedDimsOffsets(), @@ -726,7 +725,7 @@ bool MeshSharding::equalShardSizes(const MeshSharding &rhs) const { return true; } -bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const { +bool Sharding::equalHaloSizes(const Sharding &rhs) const { if (rhs.getStaticHaloSizes().size() != getStaticHaloSizes().size() || !llvm::equal(getStaticHaloSizes(), rhs.getStaticHaloSizes())) { return false; @@ -738,45 +737,43 @@ bool MeshSharding::equalHaloSizes(const MeshSharding &rhs) const { return true; } -bool MeshSharding::operator==(Value rhs) const { +bool Sharding::operator==(Value rhs) const { return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs); } -bool MeshSharding::operator!=(Value rhs) const { return !(*this == rhs); } +bool Sharding::operator!=(Value rhs) const { return !(*this == rhs); } -bool MeshSharding::operator==(const MeshSharding &rhs) const { +bool Sharding::operator==(const Sharding &rhs) const { return equalSplitAxes(rhs) && equalHaloAndShardSizes(rhs); } -bool MeshSharding::operator!=(const MeshSharding &rhs) const { - return !(*this == rhs); -} +bool Sharding::operator!=(const Sharding &rhs) const { return !(*this == rhs); } -MeshSharding::MeshSharding(::mlir::FlatSymbolRefAttr mesh_) : mesh(mesh_) {} +Sharding::Sharding(::mlir::FlatSymbolRefAttr grid_) : grid(grid_) {} -MeshSharding::MeshSharding(Value rhs) { - auto shardingOp = mlir::dyn_cast<ShardingOp>(rhs.getDefiningOp()); +Sharding::Sharding(Value rhs) { + auto shardingOp = rhs.getDefiningOp<ShardingOp>(); assert(shardingOp && "expected sharding op"); auto splitAxes = shardingOp.getSplitAxes().getAxes(); // If splitAxes are empty, use "empty" constructor. 
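// (An empty split-axes list shards nothing, i.e. the value stays replicated,
// so only the grid symbol needs to be retained; see the
// Sharding(FlatSymbolRefAttr) constructor above.)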
if (splitAxes.empty()) { - *this = MeshSharding(shardingOp.getMeshAttr()); + *this = Sharding(shardingOp.getGridAttr()); return; } *this = - get(shardingOp.getMeshAttr(), splitAxes, shardingOp.getStaticHaloSizes(), + get(shardingOp.getGridAttr(), splitAxes, shardingOp.getStaticHaloSizes(), shardingOp.getStaticShardedDimsOffsets(), SmallVector<Value>(shardingOp.getDynamicHaloSizes()), SmallVector<Value>(shardingOp.getDynamicShardedDimsOffsets())); } -MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_, - ArrayRef<MeshAxesAttr> split_axes_, - ArrayRef<int64_t> static_halo_sizes_, - ArrayRef<int64_t> static_sharded_dims_offsets_, - ArrayRef<Value> dynamic_halo_sizes_, - ArrayRef<Value> dynamic_sharded_dims_offsets_) { - MeshSharding res(mesh_); +Sharding Sharding::get(::mlir::FlatSymbolRefAttr grid_, + ArrayRef<GridAxesAttr> split_axes_, + ArrayRef<int64_t> static_halo_sizes_, + ArrayRef<int64_t> static_sharded_dims_offsets_, + ArrayRef<Value> dynamic_halo_sizes_, + ArrayRef<Value> dynamic_sharded_dims_offsets_) { + Sharding res(grid_); if (split_axes_.empty()) { return res; } @@ -784,7 +781,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_, res.split_axes.resize(split_axes_.size()); for (auto [i, axis] : llvm::enumerate(split_axes_)) { res.split_axes[i] = - MeshAxesAttr::get(mesh_.getContext(), axis.asArrayRef()); + GridAxesAttr::get(grid_.getContext(), axis.asArrayRef()); } auto clone = [](const auto src, auto &dst) { @@ -801,7 +798,7 @@ MeshSharding MeshSharding::get(::mlir::FlatSymbolRefAttr mesh_, } //===----------------------------------------------------------------------===// -// mesh.shard_shape +// shard.shard_shape //===----------------------------------------------------------------------===// void ShardShapeOp::getAsmResultNames( @@ -820,7 +817,7 @@ void ShardShapeOp::build(::mlir::OpBuilder &odsBuilder, } //===----------------------------------------------------------------------===// -// mesh.shard op +// shard.shard op //===----------------------------------------------------------------------===// void ShardOp::getAsmResultNames( @@ -850,10 +847,10 @@ public: if (!otherOp || !otherOp->isBeforeInBlock(op)) { return failure(); } - // Create a MeshSharding object for the current and the other ShardOp + // Create a Sharding object for the current and the other ShardOp // If the two are equal replace current op with the other op. 
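// Equality here is Sharding::operator==, i.e. equal split axes plus equal
// halo and shard sizes, so the two annotations are interchangeable.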
- MeshSharding currentSharding(op.getSharding()); - MeshSharding otherSharding(otherOp.getSharding()); + Sharding currentSharding(op.getSharding()); + Sharding otherSharding(otherOp.getSharding()); if (currentSharding == otherSharding) { b.replaceAllUsesWith(op.getResult(), otherOp.getResult()); b.eraseOp(op.getOperation()); @@ -876,21 +873,21 @@ void ShardOp::getCanonicalizationPatterns(mlir::RewritePatternSet &results, } //===----------------------------------------------------------------------===// -// mesh.process_multi_index op +// shard.process_multi_index op //===----------------------------------------------------------------------===// LogicalResult ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); - if (failed(mesh)) { + auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable); + if (failed(grid)) { return failure(); } - if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) { + if (failed(verifyGridAxes(getLoc(), getAxes(), grid.value()))) { return failure(); } size_t expectedResultsCount = - getAxes().empty() ? mesh->getRank() : getAxes().size(); + getAxes().empty() ? grid->getRank() : getAxes().size(); if (getResult().size() != expectedResultsCount) { return emitError() << "Unexpected number of results " << getResult().size() << ". Expected " << expectedResultsCount << "."; @@ -900,17 +897,17 @@ ProcessMultiIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, - MeshOp mesh) { + GridOp grid) { build(odsBuilder, odsState, - SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()), - mesh.getSymName(), ArrayRef<MeshAxis>()); + SmallVector<Type>(grid.getRank(), odsBuilder.getIndexType()), + grid.getSymName(), ArrayRef<GridAxis>()); } void ProcessMultiIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState, - StringRef mesh, ArrayRef<MeshAxis> axes) { + StringRef grid, ArrayRef<GridAxis> axes) { build(odsBuilder, odsState, - SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh, - MeshAxesAttr::get(odsBuilder.getContext(), axes)); + SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), grid, + GridAxesAttr::get(odsBuilder.getContext(), axes)); } void ProcessMultiIndexOp::getAsmResultNames( @@ -919,21 +916,21 @@ void ProcessMultiIndexOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.process_linear_index op +// shard.process_linear_index op //===----------------------------------------------------------------------===// LogicalResult ProcessLinearIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); - if (failed(mesh)) { + auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable); + if (failed(grid)) { return failure(); } return success(); } void ProcessLinearIndexOp::build(OpBuilder &odsBuilder, - OperationState &odsState, MeshOp mesh) { - build(odsBuilder, odsState, mesh.getSymName()); + OperationState &odsState, GridOp grid) { + build(odsBuilder, odsState, grid.getSymName()); } void ProcessLinearIndexOp::getAsmResultNames( @@ -942,13 +939,13 @@ void ProcessLinearIndexOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.neighbors_linear_indices op +// shard.neighbors_linear_indices op 
//===----------------------------------------------------------------------===// LogicalResult NeighborsLinearIndicesOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = ::getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); - if (failed(mesh)) { + auto grid = ::getGridAndVerify(getOperation(), getGridAttr(), symbolTable); + if (failed(grid)) { return failure(); } return success(); @@ -967,12 +964,12 @@ void NeighborsLinearIndicesOp::getAsmResultNames( namespace { template <typename Op> -struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> { +struct EmptyGridAxesCanonicalizationPattern : OpRewritePattern<Op> { using OpRewritePattern<Op>::OpRewritePattern; LogicalResult matchAndRewrite(Op op, PatternRewriter &rewriter) const override { - auto meshAxes = op.getMeshAxes(); - if (!meshAxes.empty()) { + auto gridAxes = op.getGridAxes(); + if (!gridAxes.empty()) { return failure(); } if (op.getInput().getType() != op.getResult().getType()) { @@ -990,24 +987,24 @@ struct EmptyMeshAxesCanonicalizationPattern : OpRewritePattern<Op> { static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName, ArrayRef<int64_t> device, Operation::operand_range deviceDynamic, - ArrayRef<MeshAxis> meshAxes, - ArrayRef<int64_t> meshShape) { - if (device.size() != meshAxes.size()) { + ArrayRef<GridAxis> gridAxes, + ArrayRef<int64_t> gridShape) { + if (device.size() != gridAxes.size()) { return emitError(loc) << "In-group device \"" << deviceName << "\" has unexpected multi-index size " - << device.size() << ". Expected " << meshAxes.size() + << device.size() << ". Expected " << gridAxes.size() << "."; } for (size_t i = 0; i < device.size(); ++i) { if (ShapedType::isStatic(device[i]) && - ShapedType::isStatic(meshShape[meshAxes[i]]) && - meshShape[meshAxes[i]] <= device[i]) { + ShapedType::isStatic(gridShape[gridAxes[i]]) && + gridShape[gridAxes[i]] <= device[i]) { return emitError(loc) << "Out of bounds coordinate " << i << " for in-group device \"" << deviceName << "\"." 
<< " Got " << device[i] << ", but expected value in the range [0, " - << (meshShape[meshAxes[i]] - 1) << "]."; + << (gridShape[gridAxes[i]] - 1) << "]."; } } return success(); @@ -1043,7 +1040,7 @@ static LogicalResult verifyDimensionCompatibility(Location loc, static LogicalResult verifyGatherOperandAndResultShape( Value operand, Value result, int64_t gatherAxis, - ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { + ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) { auto resultRank = cast<ShapedType>(result.getType()).getRank(); if (gatherAxis < 0 || gatherAxis >= resultRank) { return emitError(result.getLoc()) @@ -1054,7 +1051,7 @@ static LogicalResult verifyGatherOperandAndResultShape( ShapedType operandType = cast<ShapedType>(operand.getType()); ShapedType resultType = cast<ShapedType>(result.getType()); auto deviceGroupSize = - DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); + DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape)); for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { auto operandDimSize = DimensionSize(operandType.getDimSize(axis)); auto resultDimSize = DimensionSize(resultType.getDimSize(axis)); @@ -1070,7 +1067,7 @@ static LogicalResult verifyGatherOperandAndResultShape( static LogicalResult verifyAllToAllOperandAndResultShape( Value operand, Value result, int64_t splitAxis, int64_t concatAxis, - ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { + ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) { ShapedType operandType = cast<ShapedType>(operand.getType()); ShapedType resultType = cast<ShapedType>(result.getType()); for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { @@ -1088,7 +1085,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape( } auto deviceGroupSize = - DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); + DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape)); auto operandConcatDimSize = DimensionSize(operandType.getDimSize(concatAxis)); auto operandSplitDimSize = DimensionSize(operandType.getDimSize(splitAxis)); DimensionSize expectedResultConcatDimSize = @@ -1115,7 +1112,7 @@ static LogicalResult verifyAllToAllOperandAndResultShape( static LogicalResult verifyScatterOrSliceOperandAndResultShape( Value operand, Value result, int64_t tensorAxis, - ArrayRef<MeshAxis> meshAxes, ArrayRef<int64_t> meshShape) { + ArrayRef<GridAxis> gridAxes, ArrayRef<int64_t> gridShape) { ShapedType operandType = cast<ShapedType>(operand.getType()); ShapedType resultType = cast<ShapedType>(result.getType()); for (int64_t axis = 0; axis < operandType.getRank(); ++axis) { @@ -1129,7 +1126,7 @@ static LogicalResult verifyScatterOrSliceOperandAndResultShape( } auto deviceGroupSize = - DimensionSize(collectiveProcessGroupSize(meshAxes, meshShape)); + DimensionSize(collectiveProcessGroupSize(gridAxes, gridShape)); auto operandScatterDimSize = DimensionSize(operandType.getDimSize(tensorAxis)); if (!operandScatterDimSize.isDynamic() && !deviceGroupSize.isDynamic() && @@ -1151,8 +1148,8 @@ static LogicalResult verifyScatterOrSliceOperandAndResultShape( return success(); } -static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, - ArrayRef<MeshAxis> meshAxes, +static RankedTensorType sliceResultType(Type operandType, GridOp grid, + ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) { RankedTensorType operandRankedTensorType = cast<RankedTensorType>(operandType); @@ -1163,29 +1160,29 @@ static RankedTensorType sliceResultType(Type operandType, MeshOp mesh, 
resultShape[sliceAxis] = operandSliceAxisSize / - DimensionSize(collectiveProcessGroupSize(meshAxes, mesh)); + DimensionSize(collectiveProcessGroupSize(gridAxes, grid)); return operandRankedTensorType.clone(resultShape); } //===----------------------------------------------------------------------===// -// mesh.all_gather op +// shard.all_gather op //===----------------------------------------------------------------------===// LogicalResult AllGatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerifyAxes(*this, symbolTable); + if (failed(grid)) { return failure(); } auto gatherAxis = getGatherAxis().getSExtValue(); return verifyGatherOperandAndResultShape(getOperand(), getResult(), - gatherAxis, getMeshAxes(), - mesh.value().getShape()); + gatherAxis, getGridAxes(), + grid.value().getShape()); } void AllGatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<EmptyMeshAxesCanonicalizationPattern<AllGatherOp>>(context); + patterns.add<EmptyGridAxesCanonicalizationPattern<AllGatherOp>>(context); } void AllGatherOp::getAsmResultNames( @@ -1194,23 +1191,23 @@ void AllGatherOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.all_reduce op +// shard.all_reduce op //===----------------------------------------------------------------------===// LogicalResult AllReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - return getMeshAndVerifyAxes(*this, symbolTable); + return getGridAndVerifyAxes(*this, symbolTable); } void AllReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<EmptyMeshAxesCanonicalizationPattern<AllReduceOp>>(context); + patterns.add<EmptyGridAxesCanonicalizationPattern<AllReduceOp>>(context); } void AllReduceOp::build(OpBuilder &odsBuilder, OperationState &odsState, - Value input, StringRef mesh, - ArrayRef<MeshAxis> meshAxes, ReductionKind reduction) { - build(odsBuilder, odsState, input.getType(), mesh, meshAxes, input, + Value input, StringRef grid, + ArrayRef<GridAxis> gridAxes, ReductionKind reduction) { + build(odsBuilder, odsState, input.getType(), grid, gridAxes, input, reduction); } @@ -1220,36 +1217,36 @@ void AllReduceOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.all_slice op +// shard.all_slice op //===----------------------------------------------------------------------===// LogicalResult AllSliceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerifyAxes(*this, symbolTable); + if (failed(grid)) { return failure(); } return verifyScatterOrSliceOperandAndResultShape( - getOperand(), getResult(), getSliceAxis().getSExtValue(), getMeshAxes(), - mesh.value().getShape()); + getOperand(), getResult(), getSliceAxis().getSExtValue(), getGridAxes(), + grid.value().getShape()); } void AllSliceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<EmptyMeshAxesCanonicalizationPattern<AllSliceOp>>(context); + patterns.add<EmptyGridAxesCanonicalizationPattern<AllSliceOp>>(context); } void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState, - Value input, MeshOp mesh, ArrayRef<MeshAxis> meshAxes, + Value input, GridOp grid, ArrayRef<GridAxis> gridAxes, 
int64_t sliceAxis) { - Type resultType = sliceResultType(input.getType(), mesh, meshAxes, sliceAxis); - build(odsBuilder, odsState, resultType, input, mesh.getSymName(), meshAxes, + Type resultType = sliceResultType(input.getType(), grid, gridAxes, sliceAxis); + build(odsBuilder, odsState, resultType, input, grid.getSymName(), gridAxes, sliceAxis); } void AllSliceOp::build(OpBuilder &odsBuilder, OperationState &odsState, - Type resultType, Value input, StringRef mesh, - ArrayRef<MeshAxis> meshAxes, int64_t sliceAxis) { - build(odsBuilder, odsState, resultType, mesh, meshAxes, input, + Type resultType, Value input, StringRef grid, + ArrayRef<GridAxis> gridAxes, int64_t sliceAxis) { + build(odsBuilder, odsState, resultType, grid, gridAxes, input, APInt(sizeof(sliceAxis) * CHAR_BIT, sliceAxis)); } @@ -1259,23 +1256,23 @@ void AllSliceOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.all_to_all op +// shard.all_to_all op //===----------------------------------------------------------------------===// LogicalResult AllToAllOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerifyAxes(*this, symbolTable); + if (failed(grid)) { return failure(); } return verifyAllToAllOperandAndResultShape( getOperand(), getResult(), getSplitAxis().getSExtValue(), - getConcatAxis().getSExtValue(), getMeshAxes(), mesh.value().getShape()); + getConcatAxis().getSExtValue(), getGridAxes(), grid.value().getShape()); } void AllToAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<EmptyMeshAxesCanonicalizationPattern<AllToAllOp>>(context); + patterns.add<EmptyGridAxesCanonicalizationPattern<AllToAllOp>>(context); } void AllToAllOp::getAsmResultNames( @@ -1284,18 +1281,18 @@ void AllToAllOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.broadcast op +// shard.broadcast op //===----------------------------------------------------------------------===// LogicalResult BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerifyAxes(*this, symbolTable); + if (failed(grid)) { return failure(); } if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), - getRootDynamic(), getMeshAxes(), - mesh.value().getShape()))) { + getRootDynamic(), getGridAxes(), + grid.value().getShape()))) { return failure(); } @@ -1304,7 +1301,7 @@ BroadcastOp::verifySymbolUses(SymbolTableCollection &symbolTable) { void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<EmptyMeshAxesCanonicalizationPattern<BroadcastOp>>(context); + patterns.add<EmptyGridAxesCanonicalizationPattern<BroadcastOp>>(context); } void BroadcastOp::getAsmResultNames( @@ -1313,29 +1310,29 @@ void BroadcastOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.gather op +// shard.gather op //===----------------------------------------------------------------------===// LogicalResult GatherOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerifyAxes(*this, symbolTable); + if (failed(grid)) { return failure(); } if 
(failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), - getRootDynamic(), getMeshAxes(), - mesh.value().getShape()))) { + getRootDynamic(), getGridAxes(), + grid.value().getShape()))) { return failure(); } auto gatherAxis = getGatherAxis().getSExtValue(); return verifyGatherOperandAndResultShape(getInput(), getResult(), gatherAxis, - getMeshAxes(), - mesh.value().getShape()); + getGridAxes(), + grid.value().getShape()); } void GatherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<EmptyMeshAxesCanonicalizationPattern<GatherOp>>(context); + patterns.add<EmptyGridAxesCanonicalizationPattern<GatherOp>>(context); } void GatherOp::getAsmResultNames( @@ -1344,18 +1341,18 @@ void GatherOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.recv op +// shard.recv op //===----------------------------------------------------------------------===// LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerifyAxes(*this, symbolTable); + if (failed(grid)) { return failure(); } if (getSource() && failed(verifyInGroupDevice(getLoc(), getSourceAttrName(), getSource().value(), getSourceDynamic(), - getMeshAxes(), mesh.value().getShape()))) { + getGridAxes(), grid.value().getShape()))) { return failure(); } return success(); @@ -1363,7 +1360,7 @@ LogicalResult RecvOp::verifySymbolUses(SymbolTableCollection &symbolTable) { void RecvOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<EmptyMeshAxesCanonicalizationPattern<RecvOp>>(context); + patterns.add<EmptyGridAxesCanonicalizationPattern<RecvOp>>(context); } void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { @@ -1371,17 +1368,17 @@ void RecvOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { } //===----------------------------------------------------------------------===// -// mesh.reduce op +// shard.reduce op //===----------------------------------------------------------------------===// LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerifyAxes(*this, symbolTable); + if (failed(grid)) { return failure(); } if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), - getRootDynamic(), getMeshAxes(), - mesh.value().getShape()))) { + getRootDynamic(), getGridAxes(), + grid.value().getShape()))) { return failure(); } @@ -1390,7 +1387,7 @@ LogicalResult ReduceOp::verifySymbolUses(SymbolTableCollection &symbolTable) { void ReduceOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceOp>>(context); + patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceOp>>(context); } void ReduceOp::getAsmResultNames( @@ -1399,24 +1396,24 @@ void ReduceOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.reduce_scatter op +// shard.reduce_scatter op //===----------------------------------------------------------------------===// LogicalResult ReduceScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerifyAxes(*this, 
symbolTable); + if (failed(grid)) { return failure(); } return verifyScatterOrSliceOperandAndResultShape( - getOperand(), getResult(), getScatterAxis().getSExtValue(), getMeshAxes(), - mesh.value().getShape()); + getOperand(), getResult(), getScatterAxis().getSExtValue(), getGridAxes(), + grid.value().getShape()); } void ReduceScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<EmptyMeshAxesCanonicalizationPattern<ReduceScatterOp>>(context); + patterns.add<EmptyGridAxesCanonicalizationPattern<ReduceScatterOp>>(context); } void ReduceScatterOp::getAsmResultNames( @@ -1425,29 +1422,29 @@ void ReduceScatterOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.scatter op +// shard.scatter op //===----------------------------------------------------------------------===// LogicalResult ScatterOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerifyAxes(*this, symbolTable); + if (failed(grid)) { return failure(); } if (failed(verifyInGroupDevice(getLoc(), getRootAttrName(), getRoot(), - getRootDynamic(), getMeshAxes(), - mesh.value().getShape()))) { + getRootDynamic(), getGridAxes(), + grid.value().getShape()))) { return failure(); } auto scatterAxis = getScatterAxis().getSExtValue(); return verifyScatterOrSliceOperandAndResultShape(getInput(), getResult(), - scatterAxis, getMeshAxes(), - mesh.value().getShape()); + scatterAxis, getGridAxes(), + grid.value().getShape()); } void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<EmptyMeshAxesCanonicalizationPattern<ScatterOp>>(context); + patterns.add<EmptyGridAxesCanonicalizationPattern<ScatterOp>>(context); } void ScatterOp::getAsmResultNames( @@ -1456,17 +1453,17 @@ void ScatterOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.send op +// shard.send op //===----------------------------------------------------------------------===// LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerifyAxes(*this, symbolTable); + if (failed(grid)) { return failure(); } if (failed(verifyInGroupDevice(getLoc(), getDestinationAttrName(), getDestination(), getDestinationDynamic(), - getMeshAxes(), mesh.value().getShape()))) { + getGridAxes(), grid.value().getShape()))) { return failure(); } return success(); @@ -1474,7 +1471,7 @@ LogicalResult SendOp::verifySymbolUses(SymbolTableCollection &symbolTable) { void SendOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add<EmptyMeshAxesCanonicalizationPattern<SendOp>>(context); + patterns.add<EmptyGridAxesCanonicalizationPattern<SendOp>>(context); } void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { @@ -1482,20 +1479,20 @@ void SendOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { } //===----------------------------------------------------------------------===// -// mesh.shift op +// shard.shift op //===----------------------------------------------------------------------===// LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerifyAxes(*this, symbolTable); - if (failed(mesh)) { + auto grid = 
getGridAndVerifyAxes(*this, symbolTable); + if (failed(grid)) { return failure(); } - auto meshAxes = getMeshAxes(); + auto gridAxes = getGridAxes(); auto shiftAxis = getShiftAxis().getZExtValue(); - if (!llvm::is_contained(meshAxes, shiftAxis)) { + if (!llvm::is_contained(gridAxes, shiftAxis)) { return emitError() << "Invalid shift axis " << shiftAxis - << ". It must be one of the grouping mesh axes."; + << ". It must be one of the grouping grid axes."; } return success(); @@ -1504,7 +1501,7 @@ LogicalResult ShiftOp::verifySymbolUses(SymbolTableCollection &symbolTable) { void ShiftOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { // TODO: remove op when offset is 0 or if it is a rotate with and - // offset % shift_axis_mesh_dim_size == 0. + // offset % shift_axis_grid_dim_size == 0. } void ShiftOp::getAsmResultNames( @@ -1513,13 +1510,13 @@ void ShiftOp::getAsmResultNames( } //===----------------------------------------------------------------------===// -// mesh.update_halo op +// shard.update_halo op //===----------------------------------------------------------------------===// LogicalResult UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto mesh = getMeshAndVerify(getOperation(), getMeshAttr(), symbolTable); - if (failed(mesh)) { + auto grid = getGridAndVerify(getOperation(), getGridAttr(), symbolTable); + if (failed(grid)) { return failure(); } @@ -1531,12 +1528,12 @@ UpdateHaloOp::verifySymbolUses(SymbolTableCollection &symbolTable) { //===----------------------------------------------------------------------===// #define GET_OP_CLASSES -#include "mlir/Dialect/Mesh/IR/MeshOps.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardOps.cpp.inc" #define GET_ATTRDEF_CLASSES -#include "mlir/Dialect/Mesh/IR/MeshAttributes.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardAttributes.cpp.inc" #define GET_TYPEDEF_CLASSES -#include "mlir/Dialect/Mesh/IR/MeshTypes.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardTypes.cpp.inc" -#include "mlir/Dialect/Mesh/IR/MeshEnums.cpp.inc" +#include "mlir/Dialect/Shard/IR/ShardEnums.cpp.inc" diff --git a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt index afe76b5..01e8e56 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/CMakeLists.txt +++ b/mlir/lib/Dialect/Shard/Interfaces/CMakeLists.txt @@ -2,7 +2,7 @@ add_mlir_library(MLIRShardingInterface ShardingInterface.cpp ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard DEPENDS MLIRShardingInterfaceIncGen @@ -10,7 +10,7 @@ add_mlir_library(MLIRShardingInterface LINK_LIBS PUBLIC MLIRDialectUtils MLIRIR - MLIRMeshDialect + MLIRShardDialect MLIRTensorDialect MLIRSupport ) diff --git a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp index 6b3d49e..d4e7618 100644 --- a/mlir/lib/Dialect/Mesh/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/IRMapping.h" #include 
"mlir/Support/LLVM.h" @@ -24,9 +24,9 @@ #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; -using namespace mlir::mesh; +using namespace mlir::shard; -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.cpp.inc" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.cpp.inc" //===----------------------------------------------------------------------===// // common util functions @@ -93,40 +93,39 @@ checkOperandAffineExpr(AffineExpr expr, unsigned numDims) { } template <typename T> -SmallVector<MeshAxesAttr> +SmallVector<GridAxesAttr> fromArrayOfVector(MLIRContext *ctxt, const SmallVector<SmallVector<T>> &vec) { - SmallVector<MeshAxesAttr> res; + SmallVector<GridAxesAttr> res; for (const auto &v : vec) { - res.emplace_back(MeshAxesAttr::get(ctxt, v)); + res.emplace_back(GridAxesAttr::get(ctxt, v)); } return res; } //===----------------------------------------------------------------------===// -// mesh::getMeshSharding +// shard::getSharding //===----------------------------------------------------------------------===// -FailureOr<std::pair<bool, MeshSharding>> -mesh::getMeshSharding(OpResult result) { +FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpResult result) { Value val = cast<Value>(result); bool anyShardedForDef = llvm::any_of(val.getUsers(), [](Operation *user) { - auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user); + auto shardOp = llvm::dyn_cast<shard::ShardOp>(user); if (!shardOp) return false; return !shardOp.getAnnotateForUsers(); }); if (anyShardedForDef) { - // expected to have exact one use if it has a use of `mesh.shard` without + // expected to have exact one use if it has a use of `shard.shard` without // unit attr annotate_for_users if (!val.hasOneUse()) return failure(); - auto shardOp = llvm::cast<mesh::ShardOp>(*val.getUsers().begin()); - return std::make_pair(false, MeshSharding(shardOp.getSharding())); + auto shardOp = llvm::cast<shard::ShardOp>(*val.getUsers().begin()); + return std::make_pair(false, Sharding(shardOp.getSharding())); } bool anyShardedForUsers = llvm::any_of(val.getUsers(), [](Operation *user) { - auto shardOp = llvm::dyn_cast<mesh::ShardOp>(user); + auto shardOp = llvm::dyn_cast<shard::ShardOp>(user); if (!shardOp) return false; return shardOp.getAnnotateForUsers(); @@ -138,24 +137,23 @@ mesh::getMeshSharding(OpResult result) { if (shardOp) shardOps.push_back(shardOp); } - MeshSharding shardForDef = shardOps[0].getSharding(); + Sharding shardForDef = shardOps[0].getSharding(); for (size_t i = 1; i < shardOps.size(); ++i) { - // TODO: Deduce a reasonable mesh sharding attr for def when they are + // TODO: Deduce a reasonable grid sharding attr for def when they are // different assert(shardForDef == shardOps[i].getSharding() && - "only support all shard ops have the same mesh sharding attr"); + "only support all shard ops have the same grid sharding attr"); } return std::make_pair(true, shardForDef); } return failure(); } -FailureOr<std::pair<bool, MeshSharding>> -mesh::getMeshSharding(OpOperand &opOperand) { +FailureOr<std::pair<bool, Sharding>> shard::getSharding(OpOperand &opOperand) { Value val = opOperand.get(); if (ShardOp shardOp = val.getDefiningOp<ShardOp>()) return std::make_pair(shardOp.getAnnotateForUsers(), - MeshSharding(shardOp.getSharding())); + Sharding(shardOp.getSharding())); return failure(); } @@ -164,7 +162,7 @@ mesh::getMeshSharding(OpOperand &opOperand) { // ShardingInterface::verifyShardingInterfaceImpl 
//===----------------------------------------------------------------------===// -LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { +LogicalResult shard::ShardingInterface::verifyShardingInterfaceImpl() { Operation *op = getOperation(); // check operands and results type @@ -201,7 +199,7 @@ LogicalResult mesh::ShardingInterface::verifyShardingInterfaceImpl() { // ShardingInterface::printLoopTypesAndIndexingMaps //===----------------------------------------------------------------------===// -void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) { +void shard::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) { os << "print loop types and indexing maps for: \n"; getOperation()->print(os); os << "\n"; @@ -222,15 +220,15 @@ void mesh::ShardingInterface::printLoopTypesAndIndexingMaps(raw_ostream &os) { namespace { -// Update the given `shardingOption` according to `meshAxes` and `loopIdx` +// Update the given `shardingOption` according to `gridAxes` and `loopIdx` static LogicalResult fillShardingOption(Operation *op, ShardingOption &shardingOption, - FlatSymbolRefAttr mesh, - ArrayRef<MeshAxis> meshAxes, + FlatSymbolRefAttr grid, + ArrayRef<GridAxis> gridAxes, unsigned loopIdx) { - if ((shardingOption.mesh && mesh && shardingOption.mesh != mesh) || + if ((shardingOption.grid && grid && shardingOption.grid != grid) || (!shardingOption.shardingArray[loopIdx].empty() && - shardingOption.shardingArray[loopIdx] != meshAxes)) { + shardingOption.shardingArray[loopIdx] != gridAxes)) { LLVM_DEBUG(DBGS() << "sharding option conflicts on loop iterator " << loopIdx << "\n"); return failure(); @@ -239,28 +237,28 @@ static LogicalResult fillShardingOption(Operation *op, if (i == loopIdx) continue; - for (MeshAxis axis : meshAxes) { + for (GridAxis axis : gridAxes) { if (llvm::is_contained(shardingOption.shardingArray[i], axis)) { - LLVM_DEBUG(DBGS() << "sharding option conflicts because mesh axes " + LLVM_DEBUG(DBGS() << "sharding option conflicts because grid axes " << axis << " duplicate"); return failure(); } } } - if (mesh) - shardingOption.mesh = mesh; + if (grid) + shardingOption.grid = grid; if (shardingOption.shardingArray[loopIdx].empty()) - shardingOption.shardingArray[loopIdx].append(meshAxes.begin(), - meshAxes.end()); + shardingOption.shardingArray[loopIdx].append(gridAxes.begin(), + gridAxes.end()); return success(); } } // namespace FailureOr<ShardingOption> -mesh::detail::defaultGetShardingOption(Operation *op, - ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings) { +shard::detail::defaultGetShardingOption(Operation *op, + ArrayRef<Sharding> operandShardings, + ArrayRef<Sharding> resultShardings) { ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); ShardingOption shardingOption; @@ -276,25 +274,25 @@ mesh::detail::defaultGetShardingOption(Operation *op, // 1. 
Fill sharding option based on op results for (auto shardingIt : llvm::enumerate(resultShardings)) { - MeshSharding shardAttr = shardingIt.value(); + Sharding shardAttr = shardingIt.value(); if (!shardAttr) continue; AffineMap map = maps[numOperands + shardingIt.index()]; anyShardingInResultsOrOperands = true; if (shardAttr.getSplitAxes().empty() || map.getResults().empty()) { - shardingOption.mesh = shardAttr.getMeshAttr(); + shardingOption.grid = shardAttr.getGridAttr(); } else { // Handle the split axes: calculate the corresponding loop index for each // split axes sub-array, and then store the sub-array to // shardingOption[index] for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { AffineExpr expr = std::get<0>(it); - ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef(); + ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef(); auto dim = cast<AffineDimExpr>(expr); unsigned index = dim.getPosition(); visitedLoopIndices.insert(index); if (failed(fillShardingOption(op, shardingOption, - shardAttr.getMeshAttr(), axes, index))) + shardAttr.getGridAttr(), axes, index))) return failure(); } } @@ -302,7 +300,7 @@ mesh::detail::defaultGetShardingOption(Operation *op, // 2. Fill sharding option based on operands for (auto shardingIt : llvm::enumerate(operandShardings)) { - MeshSharding shardAttr = shardingIt.value(); + Sharding shardAttr = shardingIt.value(); if (!shardAttr) continue; @@ -316,7 +314,7 @@ mesh::detail::defaultGetShardingOption(Operation *op, // then the operands with multiple loop indices. for (auto it : llvm::zip(map.getResults(), shardAttr.getSplitAxes())) { AffineExpr expr = std::get<0>(it); - ArrayRef<MeshAxis> axes = std::get<1>(it).asArrayRef(); + ArrayRef<GridAxis> axes = std::get<1>(it).asArrayRef(); FailureOr<llvm::SmallSet<unsigned, 2>> loopIndices = checkOperandAffineExpr(expr, numDims); if (failed(loopIndices)) @@ -329,7 +327,7 @@ mesh::detail::defaultGetShardingOption(Operation *op, unsigned loopIdx = *loopIndices->begin(); visitedLoopIndices.insert(loopIdx); if (failed(fillShardingOption(op, shardingOption, - shardAttr.getMeshAttr(), axes, loopIdx))) + shardAttr.getGridAttr(), axes, loopIdx))) return failure(); } // If multiple loop indices correspond to a dimension of an operand, it is @@ -361,11 +359,11 @@ mesh::detail::defaultGetShardingOption(Operation *op, } // Get the sharding attributed for the given result and sharding option. 
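// Illustrative example (not from the patch): with shardingOption.shardingArray
// = [[0], [1]] and the transposed map (d0, d1) -> (d1, d0), result dimension 0
// picks up the axes of loop d1 and dimension 1 those of loop d0, so the
// computed split axes are [[1], [0]]; trailing empty sub-arrays are then
// stripped by removeTrailingEmptySubArray.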
-MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, - AffineMap map, - ArrayRef<utils::IteratorType> loopTypes) { +static Sharding getSharding(OpResult result, + const ShardingOption &shardingOption, AffineMap map, + ArrayRef<utils::IteratorType> loopTypes) { auto resultType = cast<RankedTensorType>(result.getType()); - SmallVector<SmallVector<MeshAxis>> splitAxes(resultType.getRank()); + SmallVector<SmallVector<GridAxis>> splitAxes(resultType.getRank()); // process the split axes for (auto it : llvm::enumerate(map.getResults())) { @@ -379,25 +377,25 @@ MeshSharding getSharding(OpResult result, const ShardingOption &shardingOption, } removeTrailingEmptySubArray(splitAxes); - return MeshSharding::get(shardingOption.mesh, - fromArrayOfVector(result.getContext(), splitAxes)); + return Sharding::get(shardingOption.grid, + fromArrayOfVector(result.getContext(), splitAxes)); } -static FailureOr<MeshSharding> getSharding(OpOperand &opOperand, - const ShardingOption &shardingOption, - AffineMap map) { +static FailureOr<Sharding> getSharding(OpOperand &opOperand, + const ShardingOption &shardingOption, + AffineMap map) { Value operandValue = opOperand.get(); auto operandType = dyn_cast<RankedTensorType>(operandValue.getType()); if (!operandType) { if (operandValue.getType().isIntOrIndexOrFloat()) - return MeshSharding(); + return Sharding(); return failure(); } // 0d tensors cannot be sharded and must get replicated if (operandType.getRank() == 0) { - return MeshSharding(shardingOption.mesh); + return Sharding(shardingOption.grid); } - SmallVector<SmallVector<MeshAxis>> splitAxes(operandType.getRank()); + SmallVector<SmallVector<GridAxis>> splitAxes(operandType.getRank()); unsigned numDims = map.getNumDims(); for (auto it : llvm::enumerate(map.getResults())) { int64_t idx = it.index(); @@ -422,15 +420,14 @@ static FailureOr<MeshSharding> getSharding(OpOperand &opOperand, } removeTrailingEmptySubArray(splitAxes); - return MeshSharding::get( - shardingOption.mesh, + return Sharding::get( + shardingOption.grid, fromArrayOfVector(opOperand.get().getContext(), splitAxes)); } -FailureOr<std::vector<MeshSharding>> -mesh::detail::defaultGetShardingAnnotations( +FailureOr<std::vector<Sharding>> shard::detail::defaultGetShardingAnnotations( Operation *op, const ShardingOption &shardingOption) { - std::vector<MeshSharding> res; + std::vector<Sharding> res; ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); SmallVector<utils::IteratorType> loopTypes = @@ -439,7 +436,7 @@ mesh::detail::defaultGetShardingAnnotations( unsigned numOperands = op->getNumOperands(); for (OpOperand &opOperand : op->getOpOperands()) { - FailureOr<MeshSharding> shardingAttr = getSharding( + FailureOr<Sharding> shardingAttr = ::getSharding( opOperand, shardingOption, maps[opOperand.getOperandNumber()]); if (failed(shardingAttr)) return failure(); @@ -447,9 +444,9 @@ mesh::detail::defaultGetShardingAnnotations( } for (OpResult result : op->getResults()) { - res.push_back(getSharding(result, shardingOption, - maps[numOperands + result.getResultNumber()], - loopTypes)); + res.push_back(::getSharding(result, shardingOption, + maps[numOperands + result.getResultNumber()], + loopTypes)); } return res; @@ -459,26 +456,25 @@ mesh::detail::defaultGetShardingAnnotations( // detail::defaultAddShardingAnnotations //===----------------------------------------------------------------------===// -// To add a `mesh.shard` op for the given result, based on the details provided +// To add a `shard.shard` op for 
the given result, based on the details provided // in `shardingOption`, `map`, and `loopTypes`. static LogicalResult addShardOp(OpBuilder &b, OpResult result, const ShardingOption &shardingOption, AffineMap map, ArrayRef<utils::IteratorType> loopTypes) { - MeshSharding sharding = getSharding(result, shardingOption, map, loopTypes); + Sharding sharding = getSharding(result, shardingOption, map, loopTypes); maybeInsertTargetShardingAnnotation(sharding, result, b); return success(); } -// To add a `mesh.shard` op for the given operand, based on the details provided -// in `shardingOption`, `map`, and `loopTypes`. +// To add a `shard.shard` op for the given operand, based on the details +// provided in `shardingOption`, `map`, and `loopTypes`. static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand, const ShardingOption &shardingOption, AffineMap map) { - FailureOr<MeshSharding> sharding = - getSharding(opOperand, shardingOption, map); + FailureOr<Sharding> sharding = getSharding(opOperand, shardingOption, map); if (failed(sharding)) { return failure(); } @@ -488,9 +484,9 @@ static LogicalResult addShardOp(OpBuilder &b, OpOperand &opOperand, return success(); } -LogicalResult mesh::detail::defaultAddShardingAnnotations( +LogicalResult shard::detail::defaultAddShardingAnnotations( Operation *op, OpBuilder &b, const ShardingOption &shardingOption) { - assert(!shardingOption.empty && shardingOption.mesh); + assert(!shardingOption.empty && shardingOption.grid); ShardingInterface shardingOp = llvm::cast<ShardingInterface>(op); SmallVector<utils::IteratorType> loopTypes = @@ -498,7 +494,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations( SmallVector<AffineMap> maps = shardingOp.getIndexingMaps(); unsigned numOperands = op->getNumOperands(); - // 1. add mesh.shard ops for all op results + // 1. add shard.shard ops for all op results for (OpResult result : op->getResults()) { if (failed(addShardOp(b, result, shardingOption, maps[numOperands + result.getResultNumber()], @@ -506,7 +502,7 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations( return failure(); } - // 2. add mesh.shard ops for all operands + // 2. 
add shard.shard ops for all operands for (OpOperand &opOperand : op->getOpOperands()) { if (failed(addShardOp(b, opOperand, shardingOption, maps[opOperand.getOperandNumber()]))) @@ -517,9 +513,8 @@ LogicalResult mesh::detail::defaultAddShardingAnnotations( } #ifndef NDEBUG -static bool -isValueCompatibleWithFullReplicationSharding(Value value, - MeshSharding sharding) { +static bool isValueCompatibleWithFullReplicationSharding(Value value, + Sharding sharding) { if (isa<RankedTensorType>(value.getType())) { return isFullReplication(sharding); } @@ -527,60 +522,59 @@ isValueCompatibleWithFullReplicationSharding(Value value, return !sharding; } -template <typename ValueRange, typename MeshShardingRage> +template <typename ValueRange, typename ShardingRage> static bool areValuesCompatibleWithFullReplicationShardings(ValueRange &&values, - MeshShardingRage &&shardings) { + ShardingRage &&shardings) { if (std::size(values) != std::size(shardings)) { return false; } - return llvm::all_of( - llvm::zip_equal(std::forward<ValueRange>(values), - std::forward<MeshShardingRage>(shardings)), - [](auto valueAndSharding) { - return isValueCompatibleWithFullReplicationSharding( - std::get<0>(valueAndSharding), std::get<1>(valueAndSharding)); - }); + return llvm::all_of(llvm::zip_equal(std::forward<ValueRange>(values), + std::forward<ShardingRage>(shardings)), + [](auto valueAndSharding) { + return isValueCompatibleWithFullReplicationSharding( + std::get<0>(valueAndSharding), + std::get<1>(valueAndSharding)); + }); } #endif // NDEBUG -void mesh::spmdizeFullyReplicatedOperation( - Operation &op, ArrayRef<Value> spmdizedOperands, - ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, OpBuilder &builder) { - assert(spmdizedOperands.size() == operandShardings.size()); +void shard::partitionFullyReplicatedOperation( + Operation &op, ArrayRef<Value> partitionedOperands, + ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings, + IRMapping &partitionMap, SymbolTableCollection &symbolTable, + OpBuilder &builder) { + assert(partitionedOperands.size() == operandShardings.size()); assert(areValuesCompatibleWithFullReplicationShardings(op.getOperands(), operandShardings)); assert(areValuesCompatibleWithFullReplicationShardings(op.getResults(), resultShardings)); // `clone` will populate the mapping of old to new results. 
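// E.g. an arith.addf of two fully replicated tensor<8xf32> values needs no
// communication: every device already holds both complete operands, so the
// clone below computes the full 8-element result locally.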
- builder.clone(op, spmdizationMap); + builder.clone(op, partitionMap); } -static void updateMeshAxisAssignmentForLoopIterators( - ArrayRef<MeshAxis> meshAxesAssignmentForTensorAxis, AffineExpr indexingExpr, - SmallVector<std::optional<SmallVector<MeshAxis>>> - &meshAxesAssignmentForLoopIterators) { +static void updateGridAxisAssignmentForLoopIterators( + ArrayRef<GridAxis> gridAxesAssignmentForTensorAxis, AffineExpr indexingExpr, + SmallVector<std::optional<SmallVector<GridAxis>>> + &gridAxesAssignmentForLoopIterators) { AffineDimExpr affineDimExpr = cast<AffineDimExpr>(indexingExpr); unsigned loopIteratorIdx = affineDimExpr.getPosition(); - if (meshAxesAssignmentForLoopIterators[loopIteratorIdx]) { - assert(llvm::equal(meshAxesAssignmentForTensorAxis, - *meshAxesAssignmentForLoopIterators[loopIteratorIdx])); + if (gridAxesAssignmentForLoopIterators[loopIteratorIdx]) { + assert(llvm::equal(gridAxesAssignmentForTensorAxis, + *gridAxesAssignmentForLoopIterators[loopIteratorIdx])); } else { - meshAxesAssignmentForLoopIterators[loopIteratorIdx] = - llvm::to_vector(meshAxesAssignmentForTensorAxis); + gridAxesAssignmentForLoopIterators[loopIteratorIdx] = + llvm::to_vector(gridAxesAssignmentForTensorAxis); } } -ShardingArray mesh::getMeshAxisAssignmentForLoopIterators( - ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings, +ShardingArray shard::getGridAxisAssignmentForLoopIterators( + ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings, ArrayRef<utils::IteratorType> loopIteratorTypes, ArrayRef<AffineMap> indexingMaps) { - SmallVector<std::optional<SmallVector<MeshAxis>>> - meshAxisAssignmentForLoopIterators(loopIteratorTypes.size()); - std::vector<MeshSharding> operatorAndResultShardings; + SmallVector<std::optional<SmallVector<GridAxis>>> + gridAxisAssignmentForLoopIterators(loopIteratorTypes.size()); + std::vector<Sharding> operatorAndResultShardings; operatorAndResultShardings.reserve(operandShardings.size() + resultShardings.size()); llvm::append_range(operatorAndResultShardings, operandShardings); @@ -589,69 +583,69 @@ ShardingArray mesh::getMeshAxisAssignmentForLoopIterators( if (!sharding) { continue; } - for (auto [meshAxesAssignmentForTensorAxis, indexingExpr] : + for (auto [gridAxesAssignmentForTensorAxis, indexingExpr] : llvm::zip(sharding.getSplitAxes(), affineMap.getResults())) { - updateMeshAxisAssignmentForLoopIterators( - meshAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr, - meshAxisAssignmentForLoopIterators); + updateGridAxisAssignmentForLoopIterators( + gridAxesAssignmentForTensorAxis.asArrayRef(), indexingExpr, + gridAxisAssignmentForLoopIterators); } // Missing trailing split axes means replication on those tensor dimensions. 
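// E.g. (illustrative) split axes [[0]] with a map (d0, d1) -> (d0, d1): only
// tensor axis 0 carries grid axes, so the padding loop below records {}
// (replicated) for the iterator behind d1.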
for (unsigned i = sharding.getSplitAxes().size(); i < affineMap.getNumResults(); ++i) { - updateMeshAxisAssignmentForLoopIterators( - {}, affineMap.getResults()[i], meshAxisAssignmentForLoopIterators); + updateGridAxisAssignmentForLoopIterators( + {}, affineMap.getResults()[i], gridAxisAssignmentForLoopIterators); } } ShardingArray res; - llvm::transform(meshAxisAssignmentForLoopIterators, std::back_inserter(res), - [](std::optional<SmallVector<MeshAxis>> &axes) { + llvm::transform(gridAxisAssignmentForLoopIterators, std::back_inserter(res), + [](std::optional<SmallVector<GridAxis>> &axes) { if (!axes) { - return SmallVector<MeshAxis>(); + return SmallVector<GridAxis>(); }; return std::move(*axes); }); return res; } -bool mesh::isAtLeastOneReductionIteratorSharded( +bool shard::isAtLeastOneReductionIteratorSharded( ArrayRef<utils::IteratorType> loopIteratorTypes, - ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { - for (auto [loopIteratorType, meshAxisAssignment] : - llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) { + for (auto [loopIteratorType, gridAxisAssignment] : + llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) { if (loopIteratorType == utils::IteratorType::reduction && - !meshAxisAssignment.empty()) { + !gridAxisAssignment.empty()) { return true; } } return false; } -SmallVector<MeshAxis> mesh::getReductionMeshAxes( +SmallVector<GridAxis> shard::getReductionGridAxes( ArrayRef<utils::IteratorType> loopIteratorTypes, - ArrayRef<SmallVector<MeshAxis>> meshAxisAssignmentForLoopIterators) { - SmallVector<MeshAxis> meshAxes; - for (auto [loopIteratorType, meshAxisAssignment] : - llvm::zip_equal(loopIteratorTypes, meshAxisAssignmentForLoopIterators)) { + ArrayRef<SmallVector<GridAxis>> gridAxisAssignmentForLoopIterators) { + SmallVector<GridAxis> gridAxes; + for (auto [loopIteratorType, gridAxisAssignment] : + llvm::zip_equal(loopIteratorTypes, gridAxisAssignmentForLoopIterators)) { if (loopIteratorType == utils::IteratorType::reduction) { - llvm::append_range(meshAxes, meshAxisAssignment); + llvm::append_range(gridAxes, gridAxisAssignment); } } - return meshAxes; + return gridAxes; } -void mesh::spmdizeTriviallyShardableOperation( - Operation &op, ArrayRef<Value> spmdizedOperands, - ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, OpBuilder &builder) { +void shard::partitionTriviallyShardableOperation( + Operation &op, ArrayRef<Value> partitionedOperands, + ArrayRef<Sharding> operandShardings, ArrayRef<Sharding> resultShardings, + IRMapping &partitionMap, SymbolTableCollection &symbolTable, + OpBuilder &builder) { // `clone` will populate the mapping of old to new results. - Operation *newOp = builder.clone(op, spmdizationMap); + Operation *newOp = builder.clone(op, partitionMap); // Set the result types to the sharded counterparts. 
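// E.g. (illustrative) a tensor<8x16xf32> result split over axis 0 of a
// 4-device 1D grid gets its type rewritten to tensor<2x16xf32> below.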
for (auto [oldResult, newResult, sharding] : llvm::zip_equal(op.getResults(), newOp->getResults(), resultShardings)) { newResult.setType(shardType( newResult.getType(), - getMeshOrNull(&op, sharding.getMeshAttr(), symbolTable), sharding)); + getGridOrNull(&op, sharding.getGridAttr(), symbolTable), sharding)); } } diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt index 381bc9a..a884764 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Shard/Transforms/CMakeLists.txt @@ -1,14 +1,14 @@ -add_mlir_dialect_library(MLIRMeshTransforms +add_mlir_dialect_library(MLIRShardTransforms Simplifications.cpp ShardingPropagation.cpp - Spmdization.cpp + Partition.cpp Transforms.cpp ADDITIONAL_HEADER_DIRS - ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shard DEPENDS - MLIRMeshPassIncGen + MLIRShardPassIncGen MLIRShardingInterface LINK_LIBS PUBLIC @@ -21,7 +21,7 @@ add_mlir_dialect_library(MLIRMeshTransforms MLIRFuncDialect MLIRFunctionInterfaces MLIRIR - MLIRMeshDialect + MLIRShardDialect MLIRPass MLIRSupport MLIRTensorDialect diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index c6e76ec..3e3d476 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -1,4 +1,4 @@ -//===- Spmdization.cpp --------------------------------------------- C++ --===// +//===- Partition.cpp --------------------------------------------- C++ --===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Transforms/Spmdization.h" +#include "mlir/Dialect/Shard/Transforms/Partition.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" @@ -33,7 +33,7 @@ #include <optional> #include <tuple> -namespace mlir::mesh { +namespace mlir::shard { template <typename SourceAxes, typename TargetAxes> static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, @@ -43,52 +43,49 @@ static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, }); } -static MeshSharding targetShardingInSplitLastAxis(MLIRContext *ctx, - MeshSharding sourceSharding, - int64_t splitTensorAxis, - MeshAxis splitMeshAxis) { - SmallVector<MeshAxesAttr> targetShardingSplitAxes = +static Sharding targetShardingInSplitLastAxis(MLIRContext *ctx, + Sharding sourceSharding, + int64_t splitTensorAxis, + GridAxis splitGridAxis) { + SmallVector<GridAxesAttr> targetShardingSplitAxes = llvm::to_vector(sourceSharding.getSplitAxes()); while (static_cast<int64_t>(targetShardingSplitAxes.size()) <= splitTensorAxis) { - targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {})); + targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {})); } auto targetSplitAxes = llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef()); - targetSplitAxes.push_back(splitMeshAxis); + targetSplitAxes.push_back(splitGridAxis); 
targetShardingSplitAxes[splitTensorAxis] = - MeshAxesAttr::get(ctx, targetSplitAxes); - return MeshSharding::get(sourceSharding.getMeshAttr(), - targetShardingSplitAxes); + GridAxesAttr::get(ctx, targetSplitAxes); + return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes); } -// Split a replicated tensor along a mesh axis. +// Split a replicated tensor along a grid axis. // E.g. [[0, 1]] -> [[0, 1, 2]]. -// Returns the spmdized target value with its sharding. -static std::tuple<TypedValue<ShapedType>, MeshSharding> +// Returns the partitioned target value with its sharding. +static std::tuple<TypedValue<ShapedType>, Sharding> splitLastAxisInResharding(ImplicitLocOpBuilder &builder, - MeshSharding sourceSharding, - TypedValue<ShapedType> sourceShard, MeshOp mesh, - int64_t splitTensorAxis, MeshAxis splitMeshAxis) { + Sharding sourceSharding, + TypedValue<ShapedType> sourceShard, GridOp grid, + int64_t splitTensorAxis, GridAxis splitGridAxis) { TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( - builder - .create<AllSliceOp>(sourceShard, mesh, - ArrayRef<MeshAxis>(splitMeshAxis), - splitTensorAxis) + AllSliceOp::create(builder, sourceShard, grid, + ArrayRef<GridAxis>(splitGridAxis), splitTensorAxis) .getResult()); - MeshSharding targetSharding = targetShardingInSplitLastAxis( - builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis); + Sharding targetSharding = targetShardingInSplitLastAxis( + builder.getContext(), sourceSharding, splitTensorAxis, splitGridAxis); return {targetShard, targetSharding}; } // Detect if the resharding is of type e.g. // [[0, 1]] -> [[0, 1, 2]]. -// If detected, returns the corresponding tensor axis mesh axis pair. +// If detected, returns the corresponding tensor axis grid axis pair. // Does not detect insertions like // [[0, 1]] -> [[0, 2, 1]]. -static std::optional<std::tuple<int64_t, MeshAxis>> -detectSplitLastAxisInResharding(MeshSharding sourceSharding, - MeshSharding targetSharding) { +static std::optional<std::tuple<int64_t, GridAxis>> +detectSplitLastAxisInResharding(Sharding sourceSharding, + Sharding targetSharding) { for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size(); ++tensorAxis) { if (sourceSharding.getSplitAxes().size() > tensorAxis) { @@ -118,16 +115,15 @@ detectSplitLastAxisInResharding(MeshSharding sourceSharding, return std::nullopt; } -static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>> -trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, - MeshSharding targetSharding, +static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>> +trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, Sharding targetSharding, TypedValue<ShapedType> sourceShard) { if (auto detectRes = detectSplitLastAxisInResharding(sourceSharding, targetSharding)) { - auto [tensorAxis, meshAxis] = detectRes.value(); - return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh, - tensorAxis, meshAxis); + auto [tensorAxis, gridAxis] = detectRes.value(); + return splitLastAxisInResharding(builder, sourceSharding, sourceShard, grid, + tensorAxis, gridAxis); } return std::nullopt; @@ -135,10 +131,10 @@ trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, // Detect if the resharding is of type e.g. // [[0, 1, 2]] -> [[0, 1]]. -// If detected, returns the corresponding tensor axis mesh axis pair. 
-static std::optional<std::tuple<int64_t, MeshAxis>> -detectUnsplitLastAxisInResharding(MeshSharding sourceSharding, - MeshSharding targetSharding) { +// If detected, returns the corresponding tensor axis grid axis pair. +static std::optional<std::tuple<int64_t, GridAxis>> +detectUnsplitLastAxisInResharding(Sharding sourceSharding, + Sharding targetSharding) { for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size(); ++tensorAxis) { if (targetSharding.getSplitAxes().size() > tensorAxis) { @@ -165,10 +161,10 @@ detectUnsplitLastAxisInResharding(MeshSharding sourceSharding, return std::nullopt; } -static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, - MeshSharding sourceSharding, - int64_t splitTensorAxis) { - SmallVector<MeshAxesAttr> targetShardingSplitAxes = +static Sharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, + Sharding sourceSharding, + int64_t splitTensorAxis) { + SmallVector<GridAxesAttr> targetShardingSplitAxes = llvm::to_vector(sourceSharding.getSplitAxes()); assert(static_cast<int64_t>(targetShardingSplitAxes.size()) > splitTensorAxis); @@ -177,9 +173,8 @@ static MeshSharding targetShardingInUnsplitLastAxis(MLIRContext *ctx, targetSplitAxes.pop_back(); targetShardingSplitAxes[splitTensorAxis] = - MeshAxesAttr::get(ctx, targetSplitAxes); - return MeshSharding::get(sourceSharding.getMeshAttr(), - targetShardingSplitAxes); + GridAxesAttr::get(ctx, targetSplitAxes); + return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes); } static ShapedType allGatherResultShapeInUnsplitLastAxis( @@ -190,45 +185,42 @@ static ShapedType allGatherResultShapeInUnsplitLastAxis( return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); } -static std::tuple<TypedValue<ShapedType>, MeshSharding> -unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, - MeshSharding sourceSharding, - ShapedType sourceUnshardedShape, - TypedValue<ShapedType> sourceShard, MeshOp mesh, - int64_t splitTensorAxis, MeshAxis splitMeshAxis) { +static std::tuple<TypedValue<ShapedType>, Sharding> unsplitLastAxisInResharding( + ImplicitLocOpBuilder &builder, Sharding sourceSharding, + ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard, + GridOp grid, int64_t splitTensorAxis, GridAxis splitGridAxis) { MLIRContext *ctx = builder.getContext(); builder.setInsertionPointAfterValue(sourceShard); - MeshSharding targetSharding = + Sharding targetSharding = targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis); ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis( - sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis); + sourceShard.getType(), grid.getShape()[splitGridAxis], splitTensorAxis); Value allGatherResult = AllGatherOp::create( builder, RankedTensorType::get(allGatherResultShape.getShape(), allGatherResultShape.getElementType()), - mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard, + grid.getSymName(), SmallVector<GridAxis>({splitGridAxis}), sourceShard, APInt(64, splitTensorAxis)); ShapedType targetShape = - shardShapedType(sourceUnshardedShape, mesh, targetSharding); + shardShapedType(sourceUnshardedShape, grid, targetSharding); TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( tensor::CastOp::create(builder, targetShape, allGatherResult) .getResult()); return {targetShard, targetSharding}; } -static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>> -tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, 
- MeshSharding sourceSharding, - MeshSharding targetSharding, +static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>> +tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard) { if (auto detectRes = detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) { - auto [tensorAxis, meshAxis] = detectRes.value(); + auto [tensorAxis, gridAxis] = detectRes.value(); return unsplitLastAxisInResharding(builder, sourceSharding, - sourceUnshardedShape, sourceShard, mesh, - tensorAxis, meshAxis); + sourceUnshardedShape, sourceShard, grid, + tensorAxis, gridAxis); } return std::nullopt; @@ -238,10 +230,10 @@ tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, // [[0, 1], [2]] -> [[0], [1, 2]]. // Only moving the last axis counts. // If detected, returns the corresponding (source_tensor_axis, -// target_tensor_axis, mesh_axis) tuple. -static std::optional<std::tuple<int64_t, int64_t, MeshAxis>> -detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding, - MeshSharding targetSharding) { +// target_tensor_axis, grid_axis) tuple. +static std::optional<std::tuple<int64_t, int64_t, GridAxis>> +detectMoveLastSplitAxisInResharding(Sharding sourceSharding, + Sharding targetSharding) { for (size_t sourceTensorAxis = 0; sourceTensorAxis < sourceSharding.getSplitAxes().size(); ++sourceTensorAxis) { @@ -281,33 +273,32 @@ detectMoveLastSplitAxisInResharding(MeshSharding sourceSharding, return std::nullopt; } -static MeshSharding targetShardingInMoveLastAxis(MLIRContext *ctx, - MeshSharding sourceSharding, - int64_t sourceTensorAxis, - int64_t targetTensorAxis) { - SmallVector<MeshAxesAttr> targetShardingSplitAxes = +static Sharding targetShardingInMoveLastAxis(MLIRContext *ctx, + Sharding sourceSharding, + int64_t sourceTensorAxis, + int64_t targetTensorAxis) { + SmallVector<GridAxesAttr> targetShardingSplitAxes = llvm::to_vector(sourceSharding.getSplitAxes()); while (static_cast<int64_t>(targetShardingSplitAxes.size()) <= targetTensorAxis) { - targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {})); + targetShardingSplitAxes.push_back(GridAxesAttr::get(ctx, {})); } auto sourceSplitAxes = llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef()); assert(!sourceSplitAxes.empty()); - auto meshAxis = sourceSplitAxes.back(); + auto gridAxis = sourceSplitAxes.back(); sourceSplitAxes.pop_back(); targetShardingSplitAxes[sourceTensorAxis] = - MeshAxesAttr::get(ctx, sourceSplitAxes); + GridAxesAttr::get(ctx, sourceSplitAxes); auto targetSplitAxes = llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef()); - targetSplitAxes.push_back(meshAxis); + targetSplitAxes.push_back(gridAxis); targetShardingSplitAxes[targetTensorAxis] = - MeshAxesAttr::get(ctx, targetSplitAxes); + GridAxesAttr::get(ctx, targetSplitAxes); - return MeshSharding::get(sourceSharding.getMeshAttr(), - targetShardingSplitAxes); + return Sharding::get(sourceSharding.getGridAttr(), targetShardingSplitAxes); } static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, @@ -322,46 +313,46 @@ static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); } -static std::tuple<TypedValue<ShapedType>, MeshSharding> -moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, +static 
std::tuple<TypedValue<ShapedType>, Sharding> +moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard, int64_t sourceTensorAxis, - int64_t targetTensorAxis, MeshAxis meshAxis) { + int64_t targetTensorAxis, GridAxis gridAxis) { MLIRContext *ctx = builder.getContext(); builder.setInsertionPointAfterValue(sourceShard); - MeshSharding targetSharding = targetShardingInMoveLastAxis( + Sharding targetSharding = targetShardingInMoveLastAxis( ctx, sourceSharding, sourceTensorAxis, targetTensorAxis); ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis( - sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis, + sourceShard.getType(), grid.getShape()[gridAxis], sourceTensorAxis, targetTensorAxis); Value allToAllResult = AllToAllOp::create( builder, RankedTensorType::get(allToAllResultShape.getShape(), allToAllResultShape.getElementType()), - mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard, + grid.getSymName(), SmallVector<GridAxis>({gridAxis}), sourceShard, APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); ShapedType targetShape = - shardShapedType(sourceUnshardedShape, mesh, targetSharding); + shardShapedType(sourceUnshardedShape, grid, targetSharding); TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( tensor::CastOp::create(builder, targetShape, allToAllResult).getResult()); return {targetShard, targetSharding}; } -static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>> -tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, - MeshSharding targetSharding, +static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>> +tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, + Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard) { if (auto detectRes = detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) { - auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value(); + auto [sourceTensorAxis, targetTensorAxis, gridAxis] = detectRes.value(); return moveLastSplitAxisInResharding( - builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard, - sourceTensorAxis, targetTensorAxis, meshAxis); + builder, grid, sourceSharding, sourceUnshardedShape, sourceShard, - sourceTensorAxis, targetTensorAxis, meshAxis); + sourceTensorAxis, targetTensorAxis, gridAxis); } return std::nullopt; @@ -371,10 +362,9 @@ tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, // needed. Changed halo sizes require copying the "core" of the source tensor // into the "core" of the destination tensor followed by an update halo // operation.
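// Roughly: growing the halo of an 8-element shard from no halo to one line
// per side means copying the 8-element core into a 10-element destination
// buffer and letting the UpdateHaloOp built below fill both halo lines from
// the neighboring shards.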
-static std::optional<std::tuple<TypedValue<ShapedType>, MeshSharding>> -tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, - MeshSharding targetSharding, +static std::optional<std::tuple<TypedValue<ShapedType>, Sharding>> +tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, Sharding targetSharding, ShapedType sourceUnshardedShape, TypedValue<ShapedType> sourceShard) { // Currently handles only cases where halo sizes differ but everything else @@ -392,7 +382,7 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, assert(((srcHaloSizes.empty() || ShapedType::isStaticShape(srcHaloSizes)) && ShapedType::isStaticShape(tgtHaloSizes) && sourceShard.getType().hasStaticShape()) && - "dynamic shapes/halos are not supported yet for mesh-spmdization"); + "dynamic shapes/halos are not supported yet for shard-partition"); auto rank = sourceShard.getType().getRank(); auto splitAxes = sourceSharding.getSplitAxes(); SmallVector<int64_t> srcCoreOffs(rank, 0), tgtCoreOffs(rank, 0), @@ -428,56 +418,55 @@ tryUpdateHaloInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, // Finally update the halo. auto updateHaloResult = - builder - .create<UpdateHaloOp>( - sourceShard.getLoc(), - RankedTensorType::get(outShape, - sourceShard.getType().getElementType()), - initOprnd, mesh.getSymName(), - MeshAxesArrayAttr::get(builder.getContext(), - sourceSharding.getSplitAxes()), - targetSharding.getDynamicHaloSizes(), - targetSharding.getStaticHaloSizes()) + UpdateHaloOp::create( + builder, sourceShard.getLoc(), + RankedTensorType::get(outShape, + sourceShard.getType().getElementType()), + initOprnd, grid.getSymName(), + GridAxesArrayAttr::get(builder.getContext(), + sourceSharding.getSplitAxes()), + targetSharding.getDynamicHaloSizes(), + targetSharding.getStaticHaloSizes()) .getResult(); return std::make_tuple(cast<TypedValue<ShapedType>>(updateHaloResult), targetSharding); } -// Handles only resharding on a 1D mesh. +// Handles only resharding on a 1D grid. // Currently the sharded tensor axes must be exactly divisible by the single -// mesh axis size. +// grid axis size.
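As a standalone illustration of the axis bookkeeping performed by the detect*/target* helpers above and dispatched by reshardOn1DGrid below, here is a minimal sketch using plain std::vector<int> in place of GridAxesAttr; the helper names and the at-most-one-change policy are assumptions made for the example, not the exact upstream logic:

#include <algorithm>
#include <cassert>
#include <optional>
#include <tuple>
#include <vector>

using Axes = std::vector<std::vector<int>>;

// Detect a pure "split last axis" resharding such as [[0, 1]] -> [[0, 1, 2]]:
// exactly one tensor axis gains exactly one grid axis at the end of its split
// list, and every other axis is unchanged.
static std::optional<std::tuple<size_t, int>>
detectSplitLastAxis(const Axes &src, const Axes &tgt) {
  std::optional<std::tuple<size_t, int>> found;
  for (size_t axis = 0; axis < tgt.size(); ++axis) {
    std::vector<int> s = axis < src.size() ? src[axis] : std::vector<int>();
    const std::vector<int> &t = tgt[axis];
    if (t.size() == s.size() + 1 && std::equal(s.begin(), s.end(), t.begin())) {
      if (found)
        return std::nullopt; // more than one axis changed
      found = std::make_tuple(axis, t.back());
    } else if (t != s) {
      return std::nullopt; // some other kind of resharding
    }
  }
  return found;
}

// Mirror targetShardingInMoveLastAxis: the last grid axis splitting
// srcTensorAxis is popped and appended to the split list of tgtTensorAxis.
static Axes moveLastSplitAxis(Axes split, size_t srcTensorAxis,
                              size_t tgtTensorAxis) {
  if (split.size() <= tgtTensorAxis)
    split.resize(tgtTensorAxis + 1); // pad with replicated (empty) axes
  assert(!split[srcTensorAxis].empty());
  int gridAxis = split[srcTensorAxis].back();
  split[srcTensorAxis].pop_back();
  split[tgtTensorAxis].push_back(gridAxis);
  return split;
}

int main() {
  auto r = detectSplitLastAxis({{0, 1}}, {{0, 1, 2}});
  assert(r && std::get<0>(*r) == 0 && std::get<1>(*r) == 2);
  assert(!detectSplitLastAxis({{0, 1}}, {{0, 2, 1}})); // insertion, not split
  assert((moveLastSplitAxis({{0, 1}, {2}}, 0, 1) == Axes{{0}, {2, 1}}));
}

Keeping detection separate from the rewrite is what lets reshardOn1DGrid below try each pattern in a fixed order and fall through cheaply when one does not apply.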
static TypedValue<ShapedType> -reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, MeshSharding targetSharding, +reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue, TypedValue<ShapedType> sourceShard) { assert(sourceShard.getType() == - shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding)); + shardShapedType(sourceUnshardedValue.getType(), grid, sourceSharding)); [[maybe_unused]] ShapedType targetShardType = - shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding); + shardShapedType(sourceUnshardedValue.getType(), grid, targetSharding); assert(sourceShard.getType().getRank() == targetShardType.getRank()); - assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported."); + assert(grid.getRank() == 1 && "Only 1D grids are currently supported."); if (sourceSharding == targetSharding) { return sourceShard; } TypedValue<ShapedType> targetShard; - MeshSharding actualTargetSharding; + Sharding actualTargetSharding; if (sourceSharding.getStaticShardedDimsOffsets().empty() && targetSharding.getStaticShardedDimsOffsets().empty() && sourceSharding.getStaticHaloSizes().empty() && targetSharding.getStaticHaloSizes().empty()) { if (auto tryRes = tryMoveLastSplitAxisInResharding( - builder, mesh, sourceSharding, targetSharding, + builder, grid, sourceSharding, targetSharding, sourceUnshardedValue.getType(), sourceShard)) { std::tie(targetShard, actualTargetSharding) = tryRes.value(); } else if (auto tryRes = - trySplitLastAxisInResharding(builder, mesh, sourceSharding, + trySplitLastAxisInResharding(builder, grid, sourceSharding, targetSharding, sourceShard)) { std::tie(targetShard, actualTargetSharding) = tryRes.value(); } else if (auto tryRes = tryUnsplitLastAxisInResharding( - builder, mesh, sourceSharding, targetSharding, + builder, grid, sourceSharding, targetSharding, sourceUnshardedValue.getType(), sourceShard)) { std::tie(targetShard, actualTargetSharding) = tryRes.value(); } @@ -488,9 +477,8 @@ reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, return targetShard; } -TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, - MeshSharding sourceSharding, - MeshSharding targetSharding, +TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, GridOp grid, + Sharding sourceSharding, Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue, TypedValue<ShapedType> sourceShard) { // If source and destination sharding are the same, no need to do anything. @@ -500,28 +488,28 @@ TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, } // Tries to handle the case where the resharding is needed because the halo - // sizes are different. Supports arbitrary mesh dimensionality. + // sizes are different. Supports arbitrary grid dimensionality. if (auto tryRes = tryUpdateHaloInResharding( - builder, mesh, sourceSharding, targetSharding, + builder, grid, sourceSharding, targetSharding, sourceUnshardedValue.getType(), sourceShard)) { return std::get<0>(tryRes.value()); // targetShard } - // Resort to handling only 1D meshes since the general case is complicated if + // Resort to handling only 1D grids since the general case is complicated if // it needs to be communication efficient in terms of minimizing the data // transferred between devices.
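// E.g. (illustrative) [[0], [1]] -> [[1], [0]] on a 2x2 grid would need an
// exchange across both grid axes at once; such multi-axis reshardings are
// beyond the 1D fallback below.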
- return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding, + return reshardOn1DGrid(builder, grid, sourceSharding, targetSharding, sourceUnshardedValue, sourceShard); } -TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, +TypedValue<ShapedType> reshard(OpBuilder &builder, GridOp grid, ShardOp source, ShardOp target, TypedValue<ShapedType> sourceShardValue) { assert(source.getResult() == target.getSrc()); auto sourceSharding = source.getSharding(); auto targetSharding = target.getSharding(); ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); - return reshard(implicitLocOpBuilder, mesh, sourceSharding, targetSharding, + return reshard(implicitLocOpBuilder, grid, sourceSharding, targetSharding, cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue); } @@ -530,21 +518,21 @@ TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source, ShardOp target, TypedValue<ShapedType> sourceShardValue, SymbolTableCollection &symbolTableCollection) { - MeshOp srcMesh = getMesh(source, symbolTableCollection); - assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection)); - return reshard(builder, srcMesh, source, target, sourceShardValue); + GridOp srcGrid = getGrid(source, symbolTableCollection); + assert(srcGrid && srcGrid == getGrid(target, symbolTableCollection)); + return reshard(builder, srcGrid, source, target, sourceShardValue); } void reshardingRegisterDependentDialects(DialectRegistry &registry) { - registry.insert<mesh::MeshDialect, tensor::TensorDialect>(); + registry.insert<shard::ShardDialect, tensor::TensorDialect>(); } -#define GEN_PASS_DEF_SPMDIZATION -#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" +#define GEN_PASS_DEF_PARTITION +#include "mlir/Dialect/Shard/Transforms/Passes.h.inc" using UnshardedToShardedValueMap = DenseMap<Value, Value>; -// Get the types of block arguments for an spmdized block. +// Get the types of block arguments for a partitioned block. // Reads the sharding annotations of the arguments to deduce the sharded types. // Types that are not ranked tensors are left unchanged. SmallVector<Type> @@ -563,35 +551,36 @@ shardedBlockArgumentTypes(Block &block, Operation *useOp = *rankedTensorArg.getUsers().begin(); ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp); assert(shardOp); - MeshOp mesh = getMesh(shardOp, symbolTableCollection); - return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh, + GridOp grid = getGrid(shardOp, symbolTableCollection); + return cast<Type>(shardShapedType(rankedTensorArg.getType(), grid, shardOp.getSharding())); }); return res; } -static LogicalResult spmdizeOperation( - Operation &op, ArrayRef<Value> spmdizedOperands, - ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { +static LogicalResult +partitionOperation(Operation &op, ArrayRef<Value> partitionedOperands, + ArrayRef<Sharding> operandShardings, + ArrayRef<Sharding> resultShardings, IRMapping &partitionMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op); if (!shardingInterface) { // If there is no sharding interface we are conservative and assume that // the op should be fully replicated on all devices.
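// (E.g. an op from an unregistered dialect producing a small constant table:
// every device materializes the whole table locally rather than a shard.)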
- spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings, - resultShardings, spmdizationMap, - symbolTableCollection, builder); + partitionFullyReplicatedOperation(op, partitionedOperands, operandShardings, + resultShardings, partitionMap, + symbolTableCollection, builder); } else { - if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings, - resultShardings, spmdizationMap, - symbolTableCollection, builder))) { + if (failed(shardingInterface.partition( + partitionedOperands, operandShardings, resultShardings, + partitionMap, symbolTableCollection, builder))) { return failure(); } } - assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) { - return spmdizationMap.contains(result); + assert(llvm::all_of(op.getResults(), [&partitionMap](OpResult result) { + return partitionMap.contains(result); })); return success(); @@ -599,88 +588,87 @@ static LogicalResult spmdizeOperation( // Retrieve the sharding annotations for the operands of the given operation. // If the type is not a ranked tensor it is not required to have an annotation. -static std::vector<MeshSharding> getOperandShardings(Operation &op) { - std::vector<MeshSharding> res; +static std::vector<Sharding> getOperandShardings(Operation &op) { + std::vector<Sharding> res; res.reserve(op.getNumOperands()); llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) { TypedValue<RankedTensorType> rankedTensor = dyn_cast<TypedValue<RankedTensorType>>(operand); if (!rankedTensor || rankedTensor.getType().getRank() == 0) { - return MeshSharding(); + return Sharding(); } Operation *definingOp = operand.getDefiningOp(); assert(definingOp); ShardOp shardOp = llvm::cast<ShardOp>(definingOp); - return MeshSharding(shardOp.getSharding()); + return Sharding(shardOp.getSharding()); }); return res; } // Retrieve the sharding annotations for the results of the given operation. // If the type is not a ranked tensor it is not required to have an annotation. -static std::vector<MeshSharding> getResultShardings(Operation &op) { - std::vector<MeshSharding> res; +static std::vector<Sharding> getResultShardings(Operation &op) { - std::vector<MeshSharding> res; + std::vector<Sharding> res; res.reserve(op.getNumResults()); llvm::transform( op.getResults(), std::back_inserter(res), [&op](OpResult result) { if (!result.hasOneUse() || result.use_empty()) { - return MeshSharding(); + return Sharding(); } TypedValue<RankedTensorType> rankedTensor = dyn_cast<TypedValue<RankedTensorType>>(result); if (!rankedTensor) { - return MeshSharding(); + return Sharding(); } Operation *userOp = *result.getUsers().begin(); ShardOp shardOp = llvm::dyn_cast<ShardOp>(userOp); if (shardOp) { - return MeshSharding(shardOp.getSharding()); + return Sharding(shardOp.getSharding()); } if (rankedTensor.getType().getRank() == 0) { // This is a 0d tensor result without explicit sharding. - // Find mesh symbol from operands, if any. - // Shardings without mesh are not always fully supported yet. + // Find grid symbol from operands, if any. + // Shardings without grid are not always fully supported yet.
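// E.g. a 0d result of a full reduction carries no split axes, but the grid
// symbol can still be recovered below from a ShardingOp defining one of the
// operands, so later stages know which grid the value belongs to.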
for (auto operand : op.getOperands()) { if (auto sharding = operand.getDefiningOp<ShardingOp>()) { - return MeshSharding(sharding.getMeshAttr()); + return Sharding(sharding.getGridAttr()); } } } - return MeshSharding(); + return Sharding(); }); return res; } static LogicalResult -spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTableCollection, - OpBuilder &builder) { - Value targetSpmdValue; +partitionOperation(ShardOp shardOp, IRMapping &partitionMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { + Value targetPartitionValue; // Check if 2 shard ops are chained. If not, there is no need for resharding // as the source and target share the same sharding. - ShardOp srcShardOp = - dyn_cast_or_null<ShardOp>(shardOp.getSrc().getDefiningOp()); + ShardOp srcShardOp = shardOp.getSrc().getDefiningOp<ShardOp>(); if (!srcShardOp) { - targetSpmdValue = spmdizationMap.lookup(shardOp.getSrc()); + targetPartitionValue = partitionMap.lookup(shardOp.getSrc()); } else { // Insert resharding. - TypedValue<ShapedType> srcSpmdValue = - cast<TypedValue<ShapedType>>(spmdizationMap.lookup(srcShardOp)); - targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, - symbolTableCollection); + TypedValue<ShapedType> srcPartitionValue = + cast<TypedValue<ShapedType>>(partitionMap.lookup(srcShardOp)); + targetPartitionValue = reshard(builder, srcShardOp, shardOp, + srcPartitionValue, symbolTableCollection); } - assert(!spmdizationMap.contains(shardOp.getResult())); - spmdizationMap.map(shardOp.getResult(), targetSpmdValue); + assert(!partitionMap.contains(shardOp.getResult())); + partitionMap.map(shardOp.getResult(), targetPartitionValue); return success(); } static LogicalResult -spmdizeOperation(Operation &op, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTableCollection, - OpBuilder &builder) { +partitionOperation(Operation &op, IRMapping &partitionMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { if (isa<ShardingOp>(op)) { return success(); } @@ -690,30 +678,31 @@ spmdizeOperation(Operation &op, IRMapping &spmdizationMap, return op.emitError("expected a shard op as source of get_sharding"); } auto newSharding = builder.clone(*shardOp.getSharding().getDefiningOp()); - spmdizationMap.map(op.getResult(0), newSharding->getResult(0)); + partitionMap.map(op.getResult(0), newSharding->getResult(0)); return success(); } ShardOp shardOp = llvm::dyn_cast<ShardOp>(op); if (shardOp) { - return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, - builder); + return partitionOperation(shardOp, partitionMap, symbolTableCollection, + builder); } - SmallVector<Value> spmdizedOperands; - llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands), - [&spmdizationMap](Value operand) { - assert(spmdizationMap.contains(operand)); - return spmdizationMap.lookup(operand); + SmallVector<Value> partitionedOperands; + llvm::transform(op.getOperands(), std::back_inserter(partitionedOperands), - [&spmdizationMap](Value operand) { + [&partitionMap](Value operand) { + assert(partitionMap.contains(operand)); + return partitionMap.lookup(operand); }); - return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op), - getResultShardings(op), spmdizationMap, - symbolTableCollection, builder); + return partitionOperation(op, partitionedOperands, getOperandShardings(op), + getResultShardings(op), partitionMap, + symbolTableCollection, builder); } -static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, -
SymbolTableCollection &symbolTableCollection, - OpBuilder &builder) { +static LogicalResult +partitionBlock(Block &block, IRMapping &partitionMap, + SymbolTableCollection &symbolTableCollection, + OpBuilder &builder) { SmallVector<Location> argLocations; llvm::transform(block.getArguments(), std::back_inserter(argLocations), @@ -721,16 +710,16 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, Block *newBlock = builder.createBlock( block.getParent(), {}, shardedBlockArgumentTypes(block, symbolTableCollection), argLocations); - for (auto [unshardedBlockArg, spmdizedBlockArg] : + for (auto [unshardedBlockArg, partitionedBlockArg] : llvm::zip(block.getArguments(), newBlock->getArguments())) { - spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg); + partitionMap.map(unshardedBlockArg, partitionedBlockArg); } OpBuilder::InsertionGuard insertionGuard(builder); builder.setInsertionPointToEnd(newBlock); for (Operation &op : block.getOperations()) { - if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection, - builder))) { + if (failed(partitionOperation(op, partitionMap, symbolTableCollection, + builder))) { return failure(); } } @@ -739,8 +728,8 @@ } static LogicalResult -spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, - SymbolTableCollection &symbolTableCollection) { +partitionFuncOp(FunctionOpInterface op, IRMapping &partitionMap, + SymbolTableCollection &symbolTableCollection) { OpBuilder builder(op.getFunctionBody()); // Snapshot the original blocks to not mess up the iteration when adding new @@ -754,8 +743,8 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, } for (Block *block : originalBlocks) { - if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection, - builder))) { + if (failed(partitionBlock(*block, partitionMap, symbolTableCollection, + builder))) { return failure(); } } @@ -788,22 +777,22 @@ spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, namespace { -struct Spmdization : public impl::SpmdizationBase<Spmdization> { +struct Partition : public impl::PartitionBase<Partition> { void runOnOperation() override { - IRMapping spmdizationMap; + IRMapping partitionMap; SymbolTableCollection symbolTableCollection; - if (failed(spmdizeFuncOp(getOperation(), spmdizationMap, - symbolTableCollection))) { + if (failed(partitionFuncOp(getOperation(), partitionMap, + symbolTableCollection))) { return signalPassFailure(); } } void getDependentDialects(DialectRegistry &registry) const override { reshardingRegisterDependentDialects(registry); - registry.insert<mesh::MeshDialect>(); + registry.insert<shard::ShardDialect>(); } }; } // namespace -} // namespace mlir::mesh +} // namespace mlir::shard diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp index 09c754d..a647128c 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp @@ -6,11 +6,11 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Transforms/Passes.h" +#include "mlir/Dialect/Shard/Transforms/Passes.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h"
+#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "llvm/ADT/STLExtras.h" @@ -21,17 +21,17 @@ #include <vector> namespace mlir { -namespace mesh { +namespace shard { #define GEN_PASS_DEF_SHARDINGPROPAGATION -#include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" -} // namespace mesh +#include "mlir/Dialect/Shard/Transforms/Passes.h.inc" +} // namespace shard } // namespace mlir #define DEBUG_TYPE "sharding-propagation" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") using namespace mlir; -using namespace mlir::mesh; +using namespace mlir::shard; enum class ReshardingRquirementKind { NO_RESHARDING = 0, @@ -68,7 +68,7 @@ static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, [[maybe_unused]] static llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, const ShardingOption &v) { - return stream << "{empty = " << v.empty << ", mesh" << v.mesh + return stream << "{empty = " << v.empty << ", grid" << v.grid << ", shardingArray = " << v.shardingArray << "}"; } @@ -105,15 +105,15 @@ operator<<(llvm::raw_ostream &stream, ReshardingRquirementKind v) { // specific shardings. For example, mustShardings = [shard0, None] and // optionalShardings = [None, shard1], the result will be [[shard0, shard1], // [shard0, None]] -static SmallVector<std::vector<MeshSharding>> -getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings, - ArrayRef<MeshSharding> optionalShardings) { - SmallVector<std::vector<MeshSharding>> allShardingAttrs; - std::vector<MeshSharding> curShardingAttrs; +static SmallVector<std::vector<Sharding>> +getOrderedPossibleShardingAttrs(ArrayRef<Sharding> mustShardings, + ArrayRef<Sharding> optionalShardings) { + SmallVector<std::vector<Sharding>> allShardingAttrs; + std::vector<Sharding> curShardingAttrs; std::function<void(size_t)> dfsCreateShardingAttrs = [&](size_t i) { if (i == mustShardings.size()) { - allShardingAttrs.push_back(std::vector<MeshSharding>(curShardingAttrs)); + allShardingAttrs.push_back(std::vector<Sharding>(curShardingAttrs)); return; } @@ -147,14 +147,14 @@ getOrderedPossibleShardingAttrs(ArrayRef<MeshSharding> mustShardings, // 1. No resharding is required (all existing annotations are compatible). // 2. No resharding for operands/results that have annotation specifically // targeting this operation. This means -// * operands that are the result of `mesh.shard` ops marked with +// * operands that are the result of `shard.shard` ops marked with // `annotate_for_users`. -// * results that are annotated with `mesh.shard` ops without +// * results that are annotated with `shard.shard` ops without // `annotate_for_users`. // 3. All other cases. Resharding is required for operands/results with // annotation targeting explicitly this operation. 
ReshardingRquirementKind getReshardingRquirementKind( - Operation *op, const std::vector<MeshSharding> &operandAndResultShardings) { + Operation *op, const std::vector<Sharding> &operandAndResultShardings) { ReshardingRquirementKind res = ReshardingRquirementKind::NO_RESHARDING; size_t operandsCount = op->getOperands().size(); @@ -167,7 +167,7 @@ ReshardingRquirementKind getReshardingRquirementKind( for (auto [operand, sharding] : llvm::zip_equal(op->getOperands(), operandShardings)) { - ShardOp shardOp = llvm::dyn_cast_or_null<ShardOp>(operand.getDefiningOp()); + ShardOp shardOp = operand.getDefiningOp<ShardOp>(); if (!shardOp) { continue; } @@ -213,14 +213,13 @@ ReshardingRquirementKind getReshardingRquirementKind( // 3. Resharding of existing explicit sharding annotations for this op. static FailureOr<ShardingOption> selectShardingOption( ShardingInterface shardingOp, - ArrayRef<std::vector<MeshSharding>> possibleOperandShardingAttrs, - ArrayRef<std::vector<MeshSharding>> possibleResultShardingAttrs) { + ArrayRef<std::vector<Sharding>> possibleOperandShardingAttrs, + ArrayRef<std::vector<Sharding>> possibleResultShardingAttrs) { SmallVector<std::tuple<ShardingOption, ReshardingRquirementKind>> shardingOptionsAndReshardingRequirements; - for (ArrayRef<MeshSharding> resultShardings : possibleResultShardingAttrs) { - for (ArrayRef<MeshSharding> operandShardings : - possibleOperandShardingAttrs) { + for (ArrayRef<Sharding> resultShardings : possibleResultShardingAttrs) { + for (ArrayRef<Sharding> operandShardings : possibleOperandShardingAttrs) { FailureOr<ShardingOption> shardingOption = shardingOp.getShardingOption(operandShardings, resultShardings); if (failed(shardingOption) || shardingOption->empty) { @@ -231,7 +230,7 @@ static FailureOr<ShardingOption> selectShardingOption( // They may be missing some annotations. // Whatever is returned by getShardingAnnotations is exactly what the op // needs. - FailureOr<std::vector<MeshSharding>> operandAndResultShardings = + FailureOr<std::vector<Sharding>> operandAndResultShardings = shardingOp.getShardingAnnotations(*shardingOption); if (failed(operandAndResultShardings)) { return failure(); @@ -276,13 +275,13 @@ static FailureOr<ShardingOption> selectShardingOption( // For each operation that implements the ShardingInterface, infer the sharding // option of the operation from its operands and/or results using the // `getShardingOption` method. If the inferred sharding option is not empty, add -// a `mesh.shard` operation for all remaining operands and results that do not +// a `shard.shard` operation for all remaining operands and results that do not // have sharding annotations. 
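The candidate enumeration feeding this step (getOrderedPossibleShardingAttrs above) can also be modeled standalone. In this sketch, std::optional<std::string> stands in for Sharding, the function name is invented for the example, and the main() reproduces the example documented with that helper:

#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <vector>

using MaybeSharding = std::optional<std::string>;

// Standalone model of the enumeration: a "must" sharding is fixed; an empty
// slot first tries the optional sharding for that position, then none.
static std::vector<std::vector<MaybeSharding>>
orderedCombinations(const std::vector<MaybeSharding> &must,
                    const std::vector<MaybeSharding> &optional) {
  std::vector<std::vector<MaybeSharding>> all;
  std::vector<MaybeSharding> cur(must.size());
  auto dfs = [&](auto &&self, size_t i) -> void {
    if (i == must.size()) {
      all.push_back(cur);
      return;
    }
    if (must[i]) {
      cur[i] = must[i];
      self(self, i + 1);
      return;
    }
    if (optional[i]) { // prefer the more constrained variant first
      cur[i] = optional[i];
      self(self, i + 1);
    }
    cur[i] = std::nullopt; // then the unconstrained variant
    self(self, i + 1);
  };
  dfs(dfs, 0);
  return all;
}

int main() {
  // mustShardings = [shard0, None], optionalShardings = [None, shard1]
  // yields [[shard0, shard1], [shard0, None]], as documented above.
  auto all = orderedCombinations({"shard0", std::nullopt},
                                 {std::nullopt, "shard1"});
  assert(all.size() == 2);
  assert(all[0][1] == "shard1" && !all[1][1]);
}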
static LogicalResult visitOp(Operation *op, OpBuilder &builder) { ShardingInterface shardingOp = llvm::dyn_cast<ShardingInterface>(op); if (op->hasTrait<OpTrait::IsTerminator>() || (op->hasTrait<OpTrait::ConstantLike>() && !shardingOp) || - llvm::isa<mesh::ShardOp, mesh::ShardingOp, mesh::GetShardingOp>(op)) + llvm::isa<shard::ShardOp, shard::ShardingOp, shard::GetShardingOp>(op)) return success(); if (!shardingOp) { @@ -290,14 +289,13 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) { return failure(); } - // collect MeshSharding from results - std::vector<MeshSharding> allowConflictsResultShardings; + // collect Sharding from results + std::vector<Sharding> allowConflictsResultShardings; allowConflictsResultShardings.resize(op->getNumResults()); - std::vector<MeshSharding> resultMustShardings; + std::vector<Sharding> resultMustShardings; resultMustShardings.resize(op->getNumResults()); for (OpResult result : op->getResults()) { - FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr = - getMeshSharding(result); + FailureOr<std::pair<bool, Sharding>> maybeShardAttr = getSharding(result); if (failed(maybeShardAttr)) continue; if (!maybeShardAttr->first) @@ -307,14 +305,14 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) { maybeShardAttr->second; } - // collect MeshSharding from operands - std::vector<MeshSharding> allowConflictsOperandShardings; + // collect Sharding from operands + std::vector<Sharding> allowConflictsOperandShardings; allowConflictsOperandShardings.resize(op->getNumOperands()); - std::vector<MeshSharding> operandMustShardings; + std::vector<Sharding> operandMustShardings; operandMustShardings.resize(op->getNumOperands()); for (OpOperand &opOperand : op->getOpOperands()) { - FailureOr<std::pair<bool, MeshSharding>> maybeShardAttr = - getMeshSharding(opOperand); + FailureOr<std::pair<bool, Sharding>> maybeShardAttr = + getSharding(opOperand); if (failed(maybeShardAttr)) continue; @@ -327,10 +325,10 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) { } // try to get the sharding option - SmallVector<std::vector<MeshSharding>> possibleOperandShardingAttrs = + SmallVector<std::vector<Sharding>> possibleOperandShardingAttrs = getOrderedPossibleShardingAttrs(operandMustShardings, allowConflictsOperandShardings); - SmallVector<std::vector<MeshSharding>> possibleResultShardingAttrs = + SmallVector<std::vector<Sharding>> possibleResultShardingAttrs = getOrderedPossibleShardingAttrs(resultMustShardings, allowConflictsResultShardings); FailureOr<ShardingOption> shardingOption = selectShardingOption( @@ -358,7 +356,7 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) { // ShardingPropagation //===----------------------------------------------------------------------===// struct ShardingPropagation - : public mesh::impl::ShardingPropagationBase<ShardingPropagation> { + : public shard::impl::ShardingPropagationBase<ShardingPropagation> { using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase; @@ -376,8 +374,7 @@ struct ShardingPropagation LLVM_DEBUG( DBGS() << "print all the ops' iterator types and indexing maps in the " "block.\n"; - for (Operation &op - : block.getOperations()) { + for (Operation &op : block.getOperations()) { if (auto shardingOp = llvm::dyn_cast<ShardingInterface>(&op)) shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs()); }); diff --git a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp b/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp index 1315502..a17671e 
100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Simplifications.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Simplifications.cpp @@ -1,4 +1,4 @@ -//===- Simplifications.cpp - Mesh Simplifications ---------------*- C++ -*-===// +//===- Simplifications.cpp - Shard Simplifications --------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Transforms/Simplifications.h" +#include "mlir/Dialect/Shard/Transforms/Simplifications.h" #include "TransformsDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" @@ -18,7 +18,7 @@ #include <numeric> namespace mlir { -namespace mesh { +namespace shard { void populateSimplificationPatterns( RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { @@ -52,53 +52,53 @@ namespace { // DialectFoldInterface, because it needs a SymbolTableCollection to cache the // symbol tables. // We can't use DialectFoldInterface since the cache may be invalidated by some -// pass changing the referenced MeshOp ops. -struct MeshShapeFolder - : OpRewritePatternWithSymbolTableCollection<MeshShapeOp> { +// pass changing the referenced GridOp ops. +struct GridShapeFolder + : OpRewritePatternWithSymbolTableCollection<GridShapeOp> { using OpRewritePatternWithSymbolTableCollection:: OpRewritePatternWithSymbolTableCollection; - LogicalResult matchAndRewrite(MeshShapeOp op, + LogicalResult matchAndRewrite(GridShapeOp op, PatternRewriter &rewriter) const override { ImplicitLocOpBuilder builder(op->getLoc(), rewriter); - MeshOp mesh = symbolTableCollection.lookupNearestSymbolFrom<mesh::MeshOp>( - op.getOperation(), op.getMeshAttr()); - if (!mesh) { + GridOp grid = symbolTableCollection.lookupNearestSymbolFrom<shard::GridOp>( + op.getOperation(), op.getGridAttr()); + if (!grid) { return failure(); } - ArrayRef<MeshAxis> opMeshAxes = op.getAxes(); - SmallVector<MeshAxis> opAxesIota; - if (opMeshAxes.empty()) { - opAxesIota.resize(mesh.getRank()); + ArrayRef<GridAxis> opGridAxes = op.getAxes(); + SmallVector<GridAxis> opAxesIota; + if (opGridAxes.empty()) { + opAxesIota.resize(grid.getRank()); std::iota(opAxesIota.begin(), opAxesIota.end(), 0); - opMeshAxes = opAxesIota; + opGridAxes = opAxesIota; } - if (llvm::all_of(opMeshAxes, [&mesh](MeshAxis axis) { - return ShapedType::isDynamic(mesh.getShape()[axis]); + if (llvm::all_of(opGridAxes, [&grid](GridAxis axis) { + return ShapedType::isDynamic(grid.getShape()[axis]); })) { - // All mesh dimensions are dynamic. Nothing to fold. + // All grid dimensions are dynamic. Nothing to fold.
return failure(); } SmallVector<Value> newResults(op->getResults().size()); - SmallVector<MeshAxis> newShapeOpMeshAxes; + SmallVector<GridAxis> newShapeOpGridAxes; SmallVector<size_t> newToOldResultsIndexMap; - for (size_t i = 0; i < opMeshAxes.size(); ++i) { - auto meshAxisSize = mesh.getShape()[opMeshAxes[i]]; - if (ShapedType::isDynamic(meshAxisSize)) { + for (size_t i = 0; i < opGridAxes.size(); ++i) { + auto gridAxisSize = grid.getShape()[opGridAxes[i]]; + if (ShapedType::isDynamic(gridAxisSize)) { newToOldResultsIndexMap.push_back(i); - newShapeOpMeshAxes.push_back(opMeshAxes[i]); + newShapeOpGridAxes.push_back(opGridAxes[i]); } else { - // Fold static mesh axes. + // Fold static grid axes. newResults[i] = arith::ConstantOp::create( - builder, builder.getIndexAttr(meshAxisSize)); + builder, builder.getIndexAttr(gridAxisSize)); } } - // Leave only the dynamic mesh axes to be queried. - if (!newShapeOpMeshAxes.empty()) { - MeshShapeOp newShapeOp = - MeshShapeOp::create(builder, mesh.getSymName(), newShapeOpMeshAxes); + // Leave only the dynamic grid axes to be queried. + if (!newShapeOpGridAxes.empty()) { + GridShapeOp newShapeOp = + GridShapeOp::create(builder, grid.getSymName(), newShapeOpGridAxes); for (size_t i = 0; i < newShapeOp->getResults().size(); ++i) { newResults[newToOldResultsIndexMap[i]] = newShapeOp->getResults()[i]; } @@ -113,8 +113,8 @@ struct MeshShapeFolder void populateFoldingPatterns(RewritePatternSet &patterns, SymbolTableCollection &symbolTableCollection) { - patterns.add<MeshShapeFolder>(symbolTableCollection, patterns.getContext()); + patterns.add<GridShapeFolder>(symbolTableCollection, patterns.getContext()); } -} // namespace mesh +} // namespace shard } // namespace mlir diff --git a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp index 1bde1af..772e66f 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Transforms/Transforms.h" +#include "mlir/Dialect/Shard/Transforms/Transforms.h" #include "TransformsDetail.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/Utils.h" @@ -14,8 +14,8 @@ #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" -#include "mlir/Dialect/Mesh/IR/MeshDialect.h" -#include "mlir/Dialect/Mesh/IR/MeshOps.h" +#include "mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinTypes.h" @@ -28,12 +28,12 @@ #include <iterator> #include <numeric> -namespace mlir::mesh { +namespace mlir::shard { namespace { -/// Lower `mesh.process_multi_index` into expression using -/// `mesh.process_linear_index` and `mesh.mesh_shape`. +/// Lower `shard.process_multi_index` into expression using +/// `shard.process_linear_index` and `shard.grid_shape`. 
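For reference, the delinearization performed by affine.delinearize_index in this lowering recovers the per-axis coordinates from the linear process index, with the innermost axis varying fastest. A small standalone sketch of that computation (plain C++, not the actual lowering code):

  #include <cstdint>
  #include <vector>

  std::vector<int64_t> delinearize(int64_t linear,
                                   const std::vector<int64_t> &gridShape) {
    std::vector<int64_t> multi(gridShape.size());
    // Peel axes from the innermost (fastest-varying) dimension outwards.
    for (size_t i = gridShape.size(); i-- > 0;) {
      multi[i] = linear % gridShape[i];
      linear /= gridShape[i];
    }
    return multi; // e.g. linear index 5 on a 2x3 grid -> {1, 2}
  }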
struct ProcessMultiIndexOpLowering : OpRewritePatternWithSymbolTableCollection<ProcessMultiIndexOp> { using OpRewritePatternWithSymbolTableCollection:: @@ -41,30 +41,30 @@ struct ProcessMultiIndexOpLowering LogicalResult matchAndRewrite(ProcessMultiIndexOp op, PatternRewriter &rewriter) const override { - MeshOp mesh = getMesh(op, symbolTableCollection); - if (!mesh) { + GridOp grid = getGrid(op, symbolTableCollection); + if (!grid) { return failure(); } ImplicitLocOpBuilder builder(op->getLoc(), rewriter); builder.setInsertionPointAfter(op.getOperation()); - Value linearIndex = ProcessLinearIndexOp::create(builder, mesh); - ValueRange meshShape = MeshShapeOp::create(builder, mesh).getResults(); + Value linearIndex = ProcessLinearIndexOp::create(builder, grid); + ValueRange gridShape = GridShapeOp::create(builder, grid).getResults(); SmallVector<Value> completeMultiIndex = affine::AffineDelinearizeIndexOp::create(builder, linearIndex, - meshShape) + gridShape) .getMultiIndex(); SmallVector<Value> multiIndex; - ArrayRef<MeshAxis> opMeshAxes = op.getAxes(); - SmallVector<MeshAxis> opAxesIota; - if (opMeshAxes.empty()) { - opAxesIota.resize(mesh.getRank()); + ArrayRef<GridAxis> opGridAxes = op.getAxes(); + SmallVector<GridAxis> opAxesIota; + if (opGridAxes.empty()) { + opAxesIota.resize(grid.getRank()); std::iota(opAxesIota.begin(), opAxesIota.end(), 0); - opMeshAxes = opAxesIota; + opGridAxes = opAxesIota; } - llvm::transform(opMeshAxes, std::back_inserter(multiIndex), - [&completeMultiIndex](MeshAxis meshAxis) { - return completeMultiIndex[meshAxis]; + llvm::transform(opGridAxes, std::back_inserter(multiIndex), + [&completeMultiIndex](GridAxis gridAxis) { + return completeMultiIndex[gridAxis]; }); rewriter.replaceAllUsesWith(op.getResults(), multiIndex); return success(); @@ -86,15 +86,15 @@ struct AllSliceOpLowering // axis. // The slice axis is split into equisized parts with count // the number of processes in the collective process group induced by - // the mesh axes. + // the grid axes. // The part for each process is determined by the corresponding // linear-index in the process group. // // There are no collectives that require communication. // Each process operates on its local tensor. 
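The per-process arithmetic sketched below mirrors what this lowering emits: the divisibility check becomes a cf.assert, the part size a DivUI, and the in-group linear index selects the offset (function and parameter names here are illustrative):

  #include <cassert>
  #include <cstdint>
  #include <utility>

  // Returns {offset, size} of one process's part along the slice axis.
  std::pair<int64_t, int64_t> localPart(int64_t sliceAxisSize,
                                        int64_t processGroupSize,
                                        int64_t processLinearIndex) {
    assert(sliceAxisSize % processGroupSize == 0 &&
           "slice axis size must divide evenly among the process group");
    int64_t partSize = sliceAxisSize / processGroupSize;
    // e.g. axis size 8 across a 4-process group: process 2 owns [4, 6).
    return {processLinearIndex * partSize, partSize};
  }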
- MeshOp mesh = getMesh(op, symbolTableCollection); - if (!mesh) { + GridOp grid = getGrid(op, symbolTableCollection); + if (!grid) { return failure(); } @@ -104,15 +104,15 @@ struct AllSliceOpLowering Value zero = arith::ConstantOp::create(builder, builder.getIndexAttr(0)); Operation::result_range processInGroupMultiIndex = - ProcessMultiIndexOp::create(builder, mesh.getSymName(), - op.getMeshAxes()) + ProcessMultiIndexOp::create(builder, grid.getSymName(), + op.getGridAxes()) .getResults(); Operation::result_range processGroupShape = - MeshShapeOp::create(builder, mesh.getSymName(), op.getMeshAxes()) + GridShapeOp::create(builder, grid.getSymName(), op.getGridAxes()) .getResult(); Value processGroupSize = - createCollectiveProcessGroupSize(mesh, op.getMeshAxes(), builder); + createCollectiveProcessGroupSize(grid, op.getGridAxes(), builder); int64_t sliceAxis = op.getSliceAxis().getSExtValue(); Value operandSliceAxisSize = @@ -125,7 +125,7 @@ struct AllSliceOpLowering cf::AssertOp::create(builder, isTargetShapeExactlyDivisible, "Slicing a tensor with axis size that is " "not exactly divisible by the " - "mesh process group size is not supported."); + "grid process group size is not supported."); Value resultSliceAxisSize = arith::DivUIOp::create(builder, operandSliceAxisSize, processGroupSize); OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( @@ -172,7 +172,7 @@ void populateProcessMultiIndexOpLoweringPatterns( } void registerProcessMultiIndexOpLoweringDialects(DialectRegistry ®istry) { - registry.insert<affine::AffineDialect, mesh::MeshDialect>(); + registry.insert<affine::AffineDialect, shard::ShardDialect>(); } void populateAllSliceOpLoweringPatterns( @@ -183,7 +183,7 @@ void populateAllSliceOpLoweringPatterns( void registerAllSliceOpLoweringDialects(DialectRegistry ®istry) { registry.insert<affine::AffineDialect, arith::ArithDialect, - cf::ControlFlowDialect, mesh::MeshDialect, + cf::ControlFlowDialect, shard::ShardDialect, tensor::TensorDialect>(); } @@ -199,21 +199,21 @@ void registerAllOpLoweringDialects(DialectRegistry ®istry) { } TypedValue<IndexType> -createCollectiveProcessGroupSize(MeshOp mesh, ArrayRef<MeshAxis> axes, +createCollectiveProcessGroupSize(GridOp grid, ArrayRef<GridAxis> axes, ImplicitLocOpBuilder &builder) { - Operation::result_range meshShape = - mesh::MeshShapeOp::create(builder, mesh, axes).getResults(); + Operation::result_range gridShape = + GridShapeOp::create(builder, grid, axes).getResults(); return cast<TypedValue<IndexType>>(arith::createProduct( - builder, builder.getLoc(), llvm::to_vector_of<Value>(meshShape), + builder, builder.getLoc(), llvm::to_vector_of<Value>(gridShape), builder.getIndexType())); } TypedValue<IndexType> -createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex, - ArrayRef<MeshAxis> meshAxes, +createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex, + ArrayRef<GridAxis> gridAxes, ImplicitLocOpBuilder &builder) { Operation::result_range processGroupShape = - MeshShapeOp::create(builder, mesh, meshAxes).getResult(); + GridShapeOp::create(builder, grid, gridAxes).getResult(); OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( llvm::to_vector_of<OpFoldResult>(processInGroupMultiIndex), llvm::to_vector_of<OpFoldResult>(processGroupShape), builder); @@ -225,11 +225,11 @@ createProcessLinearIndex(StringRef mesh, ValueRange processInGroupMultiIndex, return cast<TypedValue<IndexType>>(res); } -TypedValue<IndexType> createProcessLinearIndex(StringRef mesh, - 
ArrayRef<MeshAxis> meshAxes, +TypedValue<IndexType> createProcessLinearIndex(StringRef grid, + ArrayRef<GridAxis> gridAxes, ImplicitLocOpBuilder &builder) { return createProcessLinearIndex( - mesh, ProcessMultiIndexOp::create(builder, mesh, meshAxes).getResults(), - meshAxes, builder); + grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(), + gridAxes, builder); } -} // namespace mlir::mesh +} // namespace mlir::shard diff --git a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h b/mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h index 3e3f584..60c9828 100644 --- a/mlir/lib/Dialect/Mesh/Transforms/TransformsDetail.h +++ b/mlir/lib/Dialect/Shard/Transforms/TransformsDetail.h @@ -6,14 +6,14 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H -#define MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H +#ifndef MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H +#define MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H #include "mlir/IR/PatternMatch.h" #include "mlir/IR/SymbolTable.h" namespace mlir { -namespace mesh { +namespace shard { template <typename Op> struct OpRewritePatternWithSymbolTableCollection : OpRewritePattern<Op> { @@ -29,7 +29,7 @@ protected: SymbolTableCollection &symbolTableCollection; }; -} // namespace mesh +} // namespace shard } // namespace mlir -#endif // MLIR_DIALECT_MESH_TRANSFORMS_TRANSFORMSDETAIL_H +#endif // MLIR_DIALECT_SHARD_TRANSFORMS_TRANSFORMSDETAIL_H diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index 0262319..3b4140e 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -931,10 +931,9 @@ createQuickSort(OpBuilder &builder, ModuleOp module, func::FuncOp func, FlatSymbolRefAttr partitionFunc = getMangledSortHelperFunc( builder, func, {IndexType::get(context)}, kPartitionFuncNamePrefix, xPerm, ny, args.drop_back(nTrailingP), createPartitionFunc); - Value p = builder - .create<func::CallOp>(loc, partitionFunc, - TypeRange{IndexType::get(context)}, - args.drop_back(nTrailingP)) + Value p = func::CallOp::create(builder, loc, partitionFunc, + TypeRange{IndexType::get(context)}, + args.drop_back(nTrailingP)) .getResult(0); Value lenLow = arith::SubIOp::create(builder, loc, p, lo); @@ -1028,9 +1027,8 @@ static void createSortStableFunc(OpBuilder &builder, ModuleOp module, FlatSymbolRefAttr searchFunc = getMangledSortHelperFunc( builder, func, {IndexType::get(context)}, kBinarySearchFuncNamePrefix, xPerm, ny, operands, createBinarySearchFunc); - Value p = builder - .create<func::CallOp>(loc, searchFunc, TypeRange{c1.getType()}, - operands) + Value p = func::CallOp::create(builder, loc, searchFunc, + TypeRange{c1.getType()}, operands) .getResult(0); // Move the value at data[i] to a temporary location. @@ -1317,7 +1315,7 @@ public: Value n = op.getN() ? 
op.getN() : constantIndex(rewriter, loc, 1); Value newSize = arith::AddIOp::create(rewriter, loc, size, n); - auto nValue = dyn_cast_or_null<arith::ConstantIndexOp>(n.getDefiningOp()); + auto nValue = n.getDefiningOp<arith::ConstantIndexOp>(); bool nIsOne = (nValue && nValue.value() == 1); if (!op.getInbounds()) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp index a317abd..0bd1d34 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp @@ -98,10 +98,10 @@ static Value genLaunchGPUFunc(OpBuilder &builder, gpu::GPUFuncOp gpuFunc, Value numT = constantIndex(builder, loc, numThreads); gpu::KernelDim3 gridSize = {one, one, one}; gpu::KernelDim3 blckSize = {numT, one, one}; - return builder - .create<gpu::LaunchFuncOp>(loc, gpuFunc, gridSize, blckSize, - /*dynSharedMemSz*/ none, args, - builder.getType<gpu::AsyncTokenType>(), tokens) + return gpu::LaunchFuncOp::create(builder, loc, gpuFunc, gridSize, blckSize, + /*dynSharedMemSz*/ none, args, + builder.getType<gpu::AsyncTokenType>(), + tokens) .getAsyncToken(); } @@ -1168,7 +1168,7 @@ struct ForallRewriter : public OpRewritePattern<scf::ParallelOp> { using OpRewritePattern<scf::ParallelOp>::OpRewritePattern; ForallRewriter(MLIRContext *context, unsigned nT) - : OpRewritePattern(context), numThreads(nT){}; + : OpRewritePattern(context), numThreads(nT) {}; LogicalResult matchAndRewrite(scf::ParallelOp forallOp, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp index dfb1274..9cd4896 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp @@ -443,8 +443,8 @@ mlir::SparseIterationTypeConverter::SparseIterationTypeConverter() { addSourceMaterialization([](OpBuilder &builder, IterSpaceType spTp, ValueRange inputs, Location loc) -> Value { - return builder - .create<UnrealizedConversionCastOp>(loc, TypeRange(spTp), inputs) + return UnrealizedConversionCastOp::create(builder, loc, TypeRange(spTp), + inputs) .getResult(0); }); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp index 70795e2..7a26cd3 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -412,13 +412,13 @@ static Value genSliceToSize(OpBuilder &builder, Location loc, Value mem, if (memTp.getRank() > 1) return mem; // Truncate linear memrefs to given size. 
- return builder - .create<memref::SubViewOp>( - loc, MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()), - mem, ValueRange{}, ValueRange{sz}, ValueRange{}, - ArrayRef<int64_t>{0}, // static offset - ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size - ArrayRef<int64_t>{1}) // static stride + return memref::SubViewOp::create( + builder, loc, + MemRefType::get({ShapedType::kDynamic}, memTp.getElementType()), + mem, ValueRange{}, ValueRange{sz}, ValueRange{}, + ArrayRef<int64_t>{0}, // static offset + ArrayRef<int64_t>{ShapedType::kDynamic}, // dynamic size + ArrayRef<int64_t>{1}) // static stride .getResult(); } @@ -449,7 +449,7 @@ class SparseInsertGenerator public: SparseInsertGenerator(TensorType rtp, TypeRange retTypes, ValueRange params, bool genCall) - : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp){}; + : FuncCallOrInlineGenerator(retTypes, params, genCall), rtp(rtp) {}; /// Generates code along an insertion path without the need for a "cursor". /// This current insertion strategy comes at the expense of some testing diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp index b444ac5..79f4e7f 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -904,9 +904,8 @@ public: dstTp->withoutDimToLvl(), !srcTp->isAllOrdered() || !srcTp->isIdentity() || !dstTp->isIdentity()); SmallVector<Value> dynSizes; - Value buffer = rewriter - .create<AllocTensorOp>(loc, bufferTp, dynSizes, Value(), - nnz, Attribute()) + Value buffer = AllocTensorOp::create(rewriter, loc, bufferTp, dynSizes, + Value(), nnz, Attribute()) .getResult(); // Convert src coordinates to dst coordinates by first collapsing it to 1D @@ -1013,9 +1012,8 @@ public: !srcTp.isAllOrdered() || !srcTp.isIdentity() || !dstTp.isIdentity()); Value buffer = - rewriter - .create<AllocTensorOp>(loc, bufferTp, dstDynSizes, Value(), - /*sizeHint=*/nnz, Attribute()) + AllocTensorOp::create(rewriter, loc, bufferTp, dstDynSizes, Value(), + /*sizeHint=*/nnz, Attribute()) .getResult(); // Implement the sparse2sparse reshape as follows: diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 0e96b59..869d27a 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -115,8 +115,7 @@ public: bufferization::BufferizationState bufferizationState; - if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()), - updatedOptions, + if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions, bufferizationState))) return failure(); diff --git a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp index 0421a6c..0784615 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp +++ b/mlir/lib/Dialect/Tensor/Extensions/AllExtensions.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tensor/Extensions/AllExtensions.h" -#include "mlir/Dialect/Tensor/Extensions/MeshShardingExtensions.h" +#include "mlir/Dialect/Tensor/Extensions/ShardingExtensions.h" using namespace mlir; diff --git 
a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt index dba5933..8f0b7da 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Extensions/CMakeLists.txt @@ -1,10 +1,10 @@ set(LLVM_OPTIONAL_SOURCES AllExtensions.cpp - MeshShardingExtensions.cpp + ShardingExtensions.cpp ) -add_mlir_extension_library(MLIRTensorMeshShardingExtensions - MeshShardingExtensions.cpp +add_mlir_extension_library(MLIRTensorShardingExtensions + ShardingExtensions.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions @@ -22,5 +22,5 @@ add_mlir_extension_library(MLIRTensorAllExtensions ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Extensions LINK_LIBS PUBLIC - MLIRTensorMeshShardingExtensions + MLIRTensorShardingExtensions )
\ No newline at end of file diff --git a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp index 7e4a5ac..ca7287c 100644 --- a/mlir/lib/Dialect/Tensor/Extensions/MeshShardingExtensions.cpp +++ b/mlir/lib/Dialect/Tensor/Extensions/ShardingExtensions.cpp @@ -6,15 +6,15 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/ShardingInterfaceImpl.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/DialectRegistry.h" using namespace mlir; using namespace mlir::tensor; -using namespace mlir::mesh; +using namespace mlir::shard; namespace { @@ -40,20 +40,20 @@ struct CreatorOpShardingInterface {AffineMap::getMultiDimIdentityMap(type.getRank(), ctx)}); } - LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, - ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder) const { + LogicalResult partition(Operation *op, ArrayRef<Value> partitionedOperands, + ArrayRef<Sharding> operandShardings, + ArrayRef<Sharding> resultShardings, + IRMapping &partitionMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { assert(resultShardings.size() == 1); auto resType = cast<RankedTensorType>(op->getResult(0).getType()); - mlir::mesh::MeshOp mesh; + mlir::shard::GridOp grid; ShapedType shardType; if (resType.getRank() > 0) { - mesh = mesh::getMesh(op, resultShardings[0].getMeshAttr(), symbolTable); + grid = shard::getGrid(op, resultShardings[0].getGridAttr(), symbolTable); shardType = - cast<ShapedType>(mesh::shardType(resType, mesh, resultShardings[0])); + cast<ShapedType>(shard::shardType(resType, grid, resultShardings[0])); } else { shardType = resType; } @@ -67,7 +67,7 @@ struct CreatorOpShardingInterface auto oldType = cast<ShapedType>(resType); assert(oldType.getRank() == shardType.getRank()); int currOldOprndNum = -1; - mesh::ShardShapeOp shapeForDevice; + shard::ShardShapeOp shapeForDevice; ValueRange device; Operation *newSharding = nullptr; for (auto i = 0; i < oldType.getRank(); ++i) { @@ -76,23 +76,23 @@ struct CreatorOpShardingInterface newSharding = ShardingOp::create(builder, op->getLoc(), resultShardings[0]); device = - mesh::ProcessMultiIndexOp::create(builder, op->getLoc(), mesh) + shard::ProcessMultiIndexOp::create(builder, op->getLoc(), grid) .getResults(); - shapeForDevice = mesh::ShardShapeOp::create( - builder, op->getLoc(), oldType.getShape(), spmdizedOperands, + shapeForDevice = shard::ShardShapeOp::create( + builder, op->getLoc(), oldType.getShape(), partitionedOperands, newSharding->getResult(0), device); } newOperands.emplace_back(shapeForDevice.getResult()[i]); } else if (oldType.isDynamicDim(i)) { assert(shardType.isDynamicDim(i)); - newOperands.emplace_back(spmdizedOperands[++currOldOprndNum]); + newOperands.emplace_back(partitionedOperands[++currOldOprndNum]); } } newOp = OpTy::create(builder, op->getLoc(), shardType, newOperands); - spmdizationMap.map(op->getResult(0), newOp->getResult(0)); + partitionMap.map(op->getResult(0), newOp->getResult(0)); } else { // `clone` will populate the mapping of old to new results. 
- newOp = builder.clone(*op, spmdizationMap); + newOp = builder.clone(*op, partitionMap); } newOp->getResult(0).setType(shardType); diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index bc11e56..c3356c1 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -784,8 +784,8 @@ struct PadOpInterface auto toValue = [&](OpFoldResult ofr) { if (auto value = dyn_cast<Value>(ofr)) return value; - return rewriter - .create<arith::ConstantIndexOp>(loc, *getConstantIntValue(ofr)) + return arith::ConstantIndexOp::create(rewriter, loc, + *getConstantIntValue(ofr)) .getResult(); }; @@ -919,9 +919,8 @@ struct ReshapeOpInterface auto memrefType = MemRefType::get( srcType.getShape(), srcType.getElementType(), AffineMap(), cast<BaseMemRefType>(srcBuffer->getType()).getMemorySpace()); - srcBuffer = rewriter - .create<bufferization::ToBufferOp>( - op->getLoc(), memrefType, *tensorAlloc) + srcBuffer = bufferization::ToBufferOp::create(rewriter, op->getLoc(), + memrefType, *tensorAlloc) .getResult(); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp index 43d9d70..9fd27d3 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/IndependenceTransforms.cpp @@ -130,8 +130,7 @@ FailureOr<Value> tensor::buildIndependentOp(OpBuilder &b, // Create a tensor::ExtractSliceOp. SmallVector<OpFoldResult> offsets(newSizes.size(), b.getIndexAttr(0)); SmallVector<OpFoldResult> strides(newSizes.size(), b.getIndexAttr(1)); - return b - .create<ExtractSliceOp>(loc, newEmptyOp, offsets, emptyOp.getMixedSizes(), - strides) + return ExtractSliceOp::create(b, loc, newEmptyOp, offsets, + emptyOp.getMixedSizes(), strides) .getResult(); } diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index e0af2f7..2ec23e1 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -385,10 +385,9 @@ struct BubbleUpExpandShapeThroughExtractSlice return getValueOrCreateConstantIndexOp(rewriter, loc, ofr); }); OpFoldResult collapsedOffset = - rewriter - .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, - reassocGroupSizes, - /*disjoint=*/true) + affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals, + reassocGroupSizes, + /*disjoint=*/true) .getResult(); collapsedOffsets.push_back(collapsedOffset); collapsedSizes.push_back(collapsedSize); diff --git a/mlir/lib/Dialect/Tosa/CMakeLists.txt b/mlir/lib/Dialect/Tosa/CMakeLists.txt index b1fac8c..c6a438d 100644 --- a/mlir/lib/Dialect/Tosa/CMakeLists.txt +++ b/mlir/lib/Dialect/Tosa/CMakeLists.txt @@ -36,7 +36,7 @@ add_mlir_dialect_library(MLIRTosaShardingInterfaceImpl LINK_LIBS PUBLIC MLIRIR - MLIRMeshDialect + MLIRShardDialect MLIRShardingInterface MLIRSupport MLIRTosaDialect diff --git a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp index d3a5f44..45994a7 100644 --- a/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tosa/IR/ShardingInterfaceImpl.cpp @@ -7,9 +7,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h" -#include 
"mlir/Dialect/Mesh/IR/MeshOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Shard/IR/ShardOps.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/DialectRegistry.h" @@ -19,7 +19,7 @@ using namespace mlir; using namespace mlir::tosa; -using namespace mlir::mesh; +using namespace mlir::shard; namespace { @@ -87,15 +87,15 @@ struct NegateOpSharding return maps; } - LogicalResult spmdize(Operation *op, ArrayRef<Value> spmdizedOperands, - ArrayRef<MeshSharding> operandShardings, - ArrayRef<MeshSharding> resultShardings, - IRMapping &spmdizationMap, - SymbolTableCollection &symbolTable, - OpBuilder &builder) const { - spmdizeTriviallyShardableOperation(*op, spmdizedOperands, operandShardings, - resultShardings, spmdizationMap, - symbolTable, builder); + LogicalResult partition(Operation *op, ArrayRef<Value> partitiondOperands, + ArrayRef<Sharding> operandShardings, + ArrayRef<Sharding> resultShardings, + IRMapping &partitionMap, + SymbolTableCollection &symbolTable, + OpBuilder &builder) const { + partitionTriviallyShardableOperation(*op, partitiondOperands, + operandShardings, resultShardings, + partitionMap, symbolTable, builder); return success(); } }; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 606626d..6d2cbb5 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -554,7 +554,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> { Value input = op.getInput(); // Check the input to the CLAMP op is itself a CLAMP. - auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp()); + auto clampOp = input.getDefiningOp<tosa::ClampOp>(); if (!clampOp) return failure(); @@ -707,9 +707,8 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> { auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes); replaceWithSlice = - rewriter - .create<tosa::SliceOp>(sliceOp.getLoc(), sliceOp.getType(), - input, start_op, size_op) + tosa::SliceOp::create(rewriter, sliceOp.getLoc(), sliceOp.getType(), + input, start_op, size_op) .getResult(); break; } @@ -1302,9 +1301,11 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { auto intVal = operand.getSplatValue<APInt>(); auto bitwidth = outETy.getIntOrFloatBitWidth(); - if (trunc) { + // i1 types are boolean in TOSA + if (outETy.isInteger(1)) { + intVal = APInt(bitwidth, intVal.isZero() ? 
0 : 1); + } else if (trunc) { intVal = intVal.trunc(bitwidth); - // i1 types are boolean in TOSA } else if (unsignIn || inIntType.isInteger(1)) { intVal = intVal.zext(bitwidth); } else { @@ -1634,7 +1635,7 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { for (Value operand : getOperands()) { concatOperands.emplace_back(operand); - auto producer = dyn_cast_or_null<ConcatOp>(operand.getDefiningOp()); + auto producer = operand.getDefiningOp<ConcatOp>(); if (!producer) continue; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 648e508a9..3cafb19 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -13,8 +13,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" @@ -166,7 +166,7 @@ void TosaDialect::initialize() { >(); addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>(); declarePromisedInterfaces< - mesh::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp, + shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp, ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp, LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp, LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp, @@ -3647,6 +3647,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() { return std::nullopt; } +static void printInitializationList(OpAsmPrinter &parser, + Block::BlockArgListType blocksArgs, + ValueRange initializers, + StringRef prefix = "") { + assert(blocksArgs.size() == initializers.size() && + "expected same length of arguments and initializers"); + if (initializers.empty()) + return; + + parser << prefix << '('; + llvm::interleaveComma( + llvm::zip(blocksArgs, initializers), parser, + [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); + parser << ")"; +} + // parse and print of IfOp refer to the implementation of SCF dialect. ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { // Create the regions for 'then'. @@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); - auto &builder = parser.getBuilder(); OpAsmParser::UnresolvedOperand cond; - // Create a i1 tensor type for the boolean condition. - Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1)); - if (parser.parseOperand(cond) || - parser.resolveOperand(cond, i1Type, result.operands)) + + if (parser.parseOperand(cond)) return failure(); - // Parse optional results type list. - if (parser.parseOptionalArrowTypeList(result.types)) + + SmallVector<OpAsmParser::Argument, 4> regionArgs; + SmallVector<OpAsmParser::UnresolvedOperand, 4> operands; + + // Parse the optional block arguments + OptionalParseResult listResult = + parser.parseOptionalAssignmentList(regionArgs, operands); + if (listResult.has_value() && failed(listResult.value())) return failure(); + + // Parse a colon. 
+ if (failed(parser.parseColon())) + return parser.emitError(parser.getCurrentLocation(), + "expected type for condition operand"); + + // Parse the type of the condition operand + Type condType; + if (failed(parser.parseType(condType))) + return parser.emitError(parser.getCurrentLocation(), + "expected type for condition operand"); + + // Resolve operand with provided type + if (failed(parser.resolveOperand(cond, condType, result.operands))) + return failure(); + + // Parse optional block arg types + if (listResult.has_value()) { + FunctionType functionType; + + if (failed(parser.parseType(functionType))) + return parser.emitError(parser.getCurrentLocation()) + << "expected list of types for block arguments " + << "followed by arrow type and list of return types"; + + result.addTypes(functionType.getResults()); + + if (functionType.getNumInputs() != operands.size()) { + return parser.emitError(parser.getCurrentLocation()) + << "expected as many input types as operands " + << "(expected " << operands.size() << " got " + << functionType.getNumInputs() << ")"; + } + + // Resolve input operands. + if (failed(parser.resolveOperands(operands, functionType.getInputs(), + parser.getCurrentLocation(), + result.operands))) + return failure(); + } else { + // Parse optional results type list. + if (parser.parseOptionalArrowTypeList(result.types)) + return failure(); + } + // Parse the 'then' region. if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); @@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { } void IfOp::print(OpAsmPrinter &p) { - bool printBlockTerminators = false; - p << " " << getCondition(); - if (!getResults().empty()) { - p << " -> (" << getResultTypes() << ")"; - // Print yield explicitly if the op defines values. - printBlockTerminators = true; + + printInitializationList(p, getThenGraph().front().getArguments(), + getInputList(), " "); + p << " : "; + p << getCondition().getType(); + + if (!getInputList().empty()) { + p << " ("; + llvm::interleaveComma(getInputList().getTypes(), p); + p << ")"; } - p << ' '; - p.printRegion(getThenGraph(), - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + p.printArrowTypeList(getResultTypes()); + p << " "; + + p.printRegion(getThenGraph()); // Print the 'else' regions if it exists and has a block. 
auto &elseRegion = getElseGraph(); if (!elseRegion.empty()) { p << " else "; - p.printRegion(elseRegion, - /*printEntryBlockArgs=*/false, - /*printBlockTerminators=*/printBlockTerminators); + p.printRegion(elseRegion); } p.printOptionalAttrDict((*this)->getAttrs()); @@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { parser.parseOptionalAttrDictWithKeyword(result.attributes)); } -static void printInitializationList(OpAsmPrinter &parser, - Block::BlockArgListType blocksArgs, - ValueRange initializers, - StringRef prefix = "") { - assert(blocksArgs.size() == initializers.size() && - "expected same length of arguments and initializers"); - if (initializers.empty()) - return; - - parser << prefix << '('; - llvm::interleaveComma( - llvm::zip(blocksArgs, initializers), parser, - [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); - parser << ")"; -} - void WhileOp::print(OpAsmPrinter &parser) { printInitializationList(parser, getCondGraph().front().getArguments(), getInputList(), " "); diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 9474299..0bec0da 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -81,9 +81,8 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { dyn_cast<RankedTensorType>(input.getType()).getElementType()); auto revisedInputShapeValue = getTosaConstShape(rewriter, op.getLoc(), revisedInputShape); - input = rewriter - .create<tosa::ReshapeOp>(op.getLoc(), inputType, input, - revisedInputShapeValue) + input = tosa::ReshapeOp::create(rewriter, op.getLoc(), inputType, input, + revisedInputShapeValue) .getResult(); Type resultETy = resultType.getElementType(); @@ -162,9 +161,8 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> { shiftType, rewriter.getIntegerAttr(shiftElementType, 0)); Value constZero = tosa::ConstOp::create(rewriter, op.getLoc(), shiftType, shiftZeroAttr); - Value mulValue = rewriter - .create<tosa::MulOp>(op.getLoc(), mulShapeType, input, - weight, constZero) + Value mulValue = tosa::MulOp::create(rewriter, op.getLoc(), mulShapeType, + input, weight, constZero) .getResult(); // Reshape output to [N, H, W, C * M]. 
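The DepthwiseConv2DIsMul rewrite above rests on a standard observation: a 1x1 depthwise convolution has no spatial reduction, so it is an elementwise multiply plus reshapes. A scalar-loop sketch of that equivalence, assuming float elements and the usual TOSA layouts of [N,H,W,C] for the input and [1,1,C,M] for the kernel (the real pattern additionally handles quantization zero points and builds the reshape ops shown in the diff):

  #include <cstdint>
  #include <vector>

  // input: N*H*W*C, weight: C*M (the 1x1xCxM kernel flattened),
  // output: N*H*W*(C*M), matching the final reshape to [N, H, W, C * M].
  std::vector<float> depthwise1x1(const std::vector<float> &input,
                                  const std::vector<float> &weight,
                                  int64_t N, int64_t H, int64_t W,
                                  int64_t C, int64_t M) {
    std::vector<float> out(N * H * W * C * M);
    for (int64_t n = 0; n < N; ++n)
      for (int64_t h = 0; h < H; ++h)
        for (int64_t w = 0; w < W; ++w)
          for (int64_t c = 0; c < C; ++c)
            for (int64_t m = 0; m < M; ++m) {
              int64_t pixel = (n * H + h) * W + w;
              // Output channel c*M + m depends on exactly one input channel,
              // so the "convolution" is a pointwise product.
              out[pixel * C * M + c * M + m] =
                  input[pixel * C + c] * weight[c * M + m];
            }
    return out;
  }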
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp index 88b0f36..9543fa1 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp @@ -464,9 +464,12 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) { CheckCondition condition = CheckCondition::invalid; const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition); const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition); + if (failed(maybeProfDef) && failed(maybeExtDef)) + return success(); - if (!failed(maybeProfDef) && !failed(maybeExtDef) && - !maybeProfDef.value().size() && !maybeExtDef.value().size()) { + const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) || + (succeeded(maybeExtDef) && !maybeExtDef->empty()); + if (!hasEntry) { std::string message; llvm::raw_string_ostream os(message); os << "illegal: operation operand/result data types did not align with any " diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index 32b5fb6..8ec7765 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) { // }) // // Simplified: - // %0 = tosa.cond_if %arg2 { - // tosa.yield %arg0 + // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) { + // ^bb0(%arg3, %arg4): + // tosa.yield %arg3 // } else { - // tosa.yield %arg1 + // ^bb0(%arg3, %arg4): + // tosa.yield %arg4 // } - // - // Unfortunately, the simplified syntax does not encapsulate values - // used in then/else regions (see 'simplified' example above), so it - // must be rewritten to use the generic syntax in order to be conformant - // to the specification. 
+ return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) || failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else")); } diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index 4662836..14a4fdf 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -16,15 +16,13 @@ #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/iterator.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/InterleavedRange.h" #define DEBUG_TYPE "transform-dialect" -#define DEBUG_TYPE_FULL "transform-dialect-full" #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << (X)) -#define FULL_LDBG(X) DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, (DBGS() << (X))) +#define FULL_LDBG() LDBG(4) using namespace mlir; @@ -486,24 +484,20 @@ void transform::TransformState::recordOpHandleInvalidationOne( newlyInvalidated.count(otherHandle)) return; - FULL_LDBG("--recordOpHandleInvalidationOne\n"); - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { - (DBGS() << "--ancestors: " - << llvm::interleaved(llvm::make_pointee_range(potentialAncestors)) - << "\n"); - }); + FULL_LDBG() << "--recordOpHandleInvalidationOne"; + FULL_LDBG() << "--ancestors: " + << llvm::interleaved( + llvm::make_pointee_range(potentialAncestors)); Operation *owner = consumingHandle.getOwner(); unsigned operandNo = consumingHandle.getOperandNumber(); for (Operation *ancestor : potentialAncestors) { // clang-format off - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, - { (DBGS() << "----handle one ancestor: " << *ancestor << "\n"); }); - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, - { (DBGS() << "----of payload with name: " - << payloadOp->getName().getIdentifier() << "\n"); }); - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, - { (DBGS() << "----of payload: " << *payloadOp << "\n"); }); + FULL_LDBG() << "----handle one ancestor: " << *ancestor;; + + FULL_LDBG() << "----of payload with name: " + << payloadOp->getName().getIdentifier(); + FULL_LDBG() << "----of payload: " << *payloadOp; // clang-format on if (!ancestor->isAncestor(payloadOp)) continue; @@ -609,10 +603,8 @@ void transform::TransformState::recordOpHandleInvalidation( transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { if (potentialAncestors.empty()) { - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { - (DBGS() << "----recording invalidation for empty handle: " << handle.get() - << "\n"); - }); + FULL_LDBG() << "----recording invalidation for empty handle: " + << handle.get(); Operation *owner = handle.getOwner(); unsigned operandNo = handle.getOperandNumber(); @@ -709,7 +701,7 @@ void transform::TransformState::recordValueHandleInvalidation( LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( transform::TransformOpInterface transform, transform::TransformState::InvalidatedHandleMap &newlyInvalidated) const { - FULL_LDBG("--Start checkAndRecordHandleInvalidation\n"); + FULL_LDBG() << "--Start checkAndRecordHandleInvalidation"; auto memoryEffectsIface = cast<MemoryEffectOpInterface>(transform.getOperation()); SmallVector<MemoryEffects::EffectInstance> effects; @@ -717,9 +709,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( transform::TransformMappingResource::get(), effects); for (OpOperand &target : 
transform->getOpOperands()) { - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { - (DBGS() << "----iterate on handle: " << target.get() << "\n"); - }); + FULL_LDBG() << "----iterate on handle: " << target.get(); // If the operand uses an invalidated handle, report it. If the operation // allows handles to point to repeated payload operations, only report // pre-existing invalidation errors. Otherwise, also report invalidations @@ -727,14 +717,14 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( auto it = invalidatedHandles.find(target.get()); auto nit = newlyInvalidated.find(target.get()); if (it != invalidatedHandles.end()) { - FULL_LDBG("--End checkAndRecordHandleInvalidation, found already " - "invalidated -> FAILURE\n"); + FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found already " + "invalidated -> FAILURE"; return it->getSecond()(transform->getLoc()), failure(); } if (!transform.allowsRepeatedHandleOperands() && nit != newlyInvalidated.end()) { - FULL_LDBG("--End checkAndRecordHandleInvalidation, found newly " - "invalidated (by this op) -> FAILURE\n"); + FULL_LDBG() << "--End checkAndRecordHandleInvalidation, found newly " + "invalidated (by this op) -> FAILURE"; return nit->getSecond()(transform->getLoc()), failure(); } @@ -745,27 +735,28 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidationImpl( effect.getValue() == target.get(); }; if (llvm::any_of(effects, consumesTarget)) { - FULL_LDBG("----found consume effect\n"); + FULL_LDBG() << "----found consume effect"; if (llvm::isa<transform::TransformHandleTypeInterface>( target.get().getType())) { - FULL_LDBG("----recordOpHandleInvalidation\n"); + FULL_LDBG() << "----recordOpHandleInvalidation"; SmallVector<Operation *> payloadOps = llvm::to_vector(getPayloadOps(target.get())); recordOpHandleInvalidation(target, payloadOps, nullptr, newlyInvalidated); } else if (llvm::isa<transform::TransformValueHandleTypeInterface>( target.get().getType())) { - FULL_LDBG("----recordValueHandleInvalidation\n"); + FULL_LDBG() << "----recordValueHandleInvalidation"; recordValueHandleInvalidation(target, newlyInvalidated); } else { - FULL_LDBG("----not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); + FULL_LDBG() + << "----not a TransformHandle -> SKIP AND DROP ON THE FLOOR"; } } else { - FULL_LDBG("----no consume effect -> SKIP\n"); + FULL_LDBG() << "----no consume effect -> SKIP"; } } - FULL_LDBG("--End checkAndRecordHandleInvalidation -> SUCCESS\n"); + FULL_LDBG() << "--End checkAndRecordHandleInvalidation -> SUCCESS"; return success(); } @@ -818,18 +809,14 @@ void transform::TransformState::compactOpHandles() { DiagnosedSilenceableFailure transform::TransformState::applyTransform(TransformOpInterface transform) { - LLVM_DEBUG({ - DBGS() << "applying: "; - transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n"; - }); - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, - DBGS() << "Top-level payload before application:\n" - << *getTopLevel() << "\n"); + LDBG() << "applying: " + << OpWithFlags(transform, OpPrintingFlags().skipRegions()); + FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel(); auto printOnFailureRAII = llvm::make_scope_exit([this] { (void)this; - LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print( - llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm());); + LDBG() << "Failing Top-level payload:\n" + << OpWithFlags(getTopLevel(), + OpPrintingFlags().printGenericOpForm()); }); // Set current transform op. 
@@ -837,47 +824,45 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { // Expensive checks to detect invalid transform IR. if (options.getExpensiveChecksEnabled()) { - FULL_LDBG("ExpensiveChecksEnabled\n"); + FULL_LDBG() << "ExpensiveChecksEnabled"; if (failed(checkAndRecordHandleInvalidation(transform))) return DiagnosedSilenceableFailure::definiteFailure(); for (OpOperand &operand : transform->getOpOperands()) { - DEBUG_WITH_TYPE(DEBUG_TYPE_FULL, { - (DBGS() << "iterate on handle: " << operand.get() << "\n"); - }); + FULL_LDBG() << "iterate on handle: " << operand.get(); if (!isHandleConsumed(operand.get(), transform)) { - FULL_LDBG("--handle not consumed -> SKIP\n"); + FULL_LDBG() << "--handle not consumed -> SKIP"; continue; } if (transform.allowsRepeatedHandleOperands()) { - FULL_LDBG("--op allows repeated handles -> SKIP\n"); + FULL_LDBG() << "--op allows repeated handles -> SKIP"; continue; } - FULL_LDBG("--handle is consumed\n"); + FULL_LDBG() << "--handle is consumed"; Type operandType = operand.get().getType(); if (llvm::isa<TransformHandleTypeInterface>(operandType)) { - FULL_LDBG("--checkRepeatedConsumptionInOperand for Operation*\n"); + FULL_LDBG() << "--checkRepeatedConsumptionInOperand for Operation*"; DiagnosedSilenceableFailure check = checkRepeatedConsumptionInOperand<Operation *>( getPayloadOpsView(operand.get()), transform, operand.getOperandNumber()); if (!check.succeeded()) { - FULL_LDBG("----FAILED\n"); + FULL_LDBG() << "----FAILED"; return check; } } else if (llvm::isa<TransformValueHandleTypeInterface>(operandType)) { - FULL_LDBG("--checkRepeatedConsumptionInOperand For Value\n"); + FULL_LDBG() << "--checkRepeatedConsumptionInOperand For Value"; DiagnosedSilenceableFailure check = checkRepeatedConsumptionInOperand<Value>( getPayloadValuesView(operand.get()), transform, operand.getOperandNumber()); if (!check.succeeded()) { - FULL_LDBG("----FAILED\n"); + FULL_LDBG() << "----FAILED"; return check; } } else { - FULL_LDBG("--not a TransformHandle -> SKIP AND DROP ON THE FLOOR\n"); + FULL_LDBG() << "--not a TransformHandle -> SKIP AND DROP ON THE FLOOR"; } } } @@ -999,8 +984,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { printOnFailureRAII.release(); DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, { - DBGS() << "Top-level payload:\n"; - getTopLevel()->print(llvm::dbgs()); + LDBG() << "Top-level payload:\n" << *getTopLevel(); }); return result; } @@ -1277,7 +1261,7 @@ void transform::TrackingListener::notifyMatchFailure( LLVM_DEBUG({ Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); - DBGS() << "Match Failure : " << diag.str() << "\n"; + LDBG() << "Match Failure : " << diag.str(); }); } diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt index d464230..0248896 100644 --- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt @@ -26,6 +26,7 @@ add_mlir_dialect_library(MLIRVectorDialect MLIRMemRefDialect MLIRSideEffectInterfaces MLIRTensorDialect + MLIRUBDialect MLIRValueBoundsOpInterface MLIRVectorInterfaces ) diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8c97aed..86fbb76 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -372,9 +372,8 @@ SmallVector<Value> vector::getAsValues(OpBuilder &builder, Location loc, llvm::transform(foldResults, std::back_inserter(values), [&](OpFoldResult foldResult) { if (auto 
attr = dyn_cast<Attribute>(foldResult)) - return builder - .create<arith::ConstantIndexOp>( - loc, cast<IntegerAttr>(attr).getInt()) + return arith::ConstantIndexOp::create( + builder, loc, cast<IntegerAttr>(attr).getInt()) .getResult(); return cast<Value>(foldResult); @@ -1259,63 +1258,6 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results, CanonicalizeContractAdd<arith::AddFOp>>(context); } -//===----------------------------------------------------------------------===// -// ExtractElementOp -//===----------------------------------------------------------------------===// - -void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges.front()); -} - -void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, - Value source) { - result.addOperands({source}); - result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType()); -} - -LogicalResult vector::ExtractElementOp::verify() { - VectorType vectorType = getSourceVectorType(); - if (vectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (vectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here now. - if (!adaptor.getPosition()) - return {}; - - // Fold extractelement (splat X) -> X. - if (auto splat = getVector().getDefiningOp<vector::SplatOp>()) - return splat.getInput(); - - // Fold extractelement(broadcast(X)) -> X. - if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>()) - if (!llvm::isa<VectorType>(broadcast.getSource().getType())) - return broadcast.getSource(); - - auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!pos || !src) - return {}; - - auto srcElements = src.getValues<Attribute>(); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= srcElements.size()) - return {}; - - return srcElements[posIdx]; -} - // Returns `true` if `index` is either within [0, maxIndex) or equal to // `poisonValue`. static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, @@ -2591,8 +2533,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> { llvm::enumerate(fromElements.getElements())) { // Check that the element is from a vector.extract operation. 
- auto extractOp = - dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp()); + auto extractOp = element.getDefiningOp<vector::ExtractOp>(); if (!extractOp) { return rewriter.notifyMatchFailure(fromElements, "element not from vector.extract"); @@ -3186,60 +3127,6 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// -// InsertElementOp -//===----------------------------------------------------------------------===// - -void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1])); -} - -void InsertElementOp::build(OpBuilder &builder, OperationState &result, - Value source, Value dest) { - build(builder, result, source, dest, {}); -} - -LogicalResult InsertElementOp::verify() { - auto dstVectorType = getDestVectorType(); - if (dstVectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (dstVectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here. - if (!adaptor.getPosition()) - return {}; - - auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource()); - auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!src || !dst || !pos) - return {}; - - if (src.getType() != getDestVectorType().getElementType()) - return {}; - - auto dstElements = dst.getValues<Attribute>(); - - SmallVector<Attribute> results(dstElements); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= results.size()) - return {}; - results[posIdx] = src; - - return DenseElementsAttr::get(getDestVectorType(), results); -} - -//===----------------------------------------------------------------------===// // InsertOp //===----------------------------------------------------------------------===// @@ -6429,6 +6316,11 @@ std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() { return llvm::to_vector<4>(getResultVectorType().getShape()); } +void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges.front()); +} + namespace { // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. 
@@ -7311,6 +7203,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, } //===----------------------------------------------------------------------===// +// StepOp +//===----------------------------------------------------------------------===// + +void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, + SetIntRangeFn setResultRanges) { + auto resultType = cast<VectorType>(getType()); + if (resultType.isScalable()) { + return; + } + unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType); + APInt zero(bitwidth, 0); + APInt high(bitwidth, resultType.getDimSize(0) - 1); + ConstantIntRanges result = {zero, high, zero, high}; + setResultRanges(getResult(), result); +} + +//===----------------------------------------------------------------------===// // Vector Masking Utilities //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp index cb8e566..dedc3b3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp @@ -28,7 +28,10 @@ using namespace mlir; using namespace mlir::vector; namespace { -/// Progressive lowering of BroadcastOp. + +/// Convert a vector.broadcast with a vector operand to a lower rank +/// vector.broadcast. vector.broadcast with a scalar operand is expected to be +/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly. class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> { public: using OpRewritePattern::OpRewritePattern; @@ -40,20 +43,23 @@ public: VectorType srcType = dyn_cast<VectorType>(op.getSourceType()); Type eltType = dstType.getElementType(); - // Scalar to any vector can use splat. - if (!srcType) { - rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource()); - return success(); - } + // A broadcast from a scalar is considered to be in the lowered form. + if (!srcType) + return rewriter.notifyMatchFailure( + op, "broadcast from scalar already in lowered form"); // Determine rank of source and destination. int64_t srcRank = srcType.getRank(); int64_t dstRank = dstType.getRank(); - // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat. + // Here we are broadcasting to a rank-1 vector. Ensure that the source is a + // scalar. 
if (srcRank <= 1 && dstRank == 1) { - Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource()); - rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext); + SmallVector<int64_t> fullRankPosition(srcRank, 0); + Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(), + fullRankPosition); + assert(!isa<VectorType>(ext.getType()) && "expected scalar"); + rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index 2484670..e062f55 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -248,11 +248,10 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> { scf::YieldOp::create(b, loc, result); }; - result = - rewriter - .create<scf::IfOp>(loc, condition, /*thenBuilder=*/loadBuilder, + result = scf::IfOp::create(rewriter, loc, condition, + /*thenBuilder=*/loadBuilder, /*elseBuilder=*/passThruBuilder) - .getResult(0); + .getResult(0); } rewriter.replaceOp(op, result); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp index e910932..2cf8f0b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp @@ -142,8 +142,8 @@ struct TransferReadPermutationLowering // Transpose result of transfer_read. SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end()); - return rewriter - .create<vector::TransposeOp>(op.getLoc(), newRead, transposePerm) + return vector::TransposeOp::create(rewriter, op.getLoc(), newRead, + transposePerm) .getResult(); } }; @@ -371,8 +371,8 @@ struct TransferOpReduceRank rewriter, op.getLoc(), newReadType, op.getBase(), op.getIndices(), AffineMapAttr::get(newMap), op.getPadding(), op.getMask(), newInBoundsAttr); - return rewriter - .create<vector::BroadcastOp>(op.getLoc(), originalVecType, newRead) + return vector::BroadcastOp::create(rewriter, op.getLoc(), originalVecType, + newRead) .getVector(); } }; @@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering read, "vector type is not rank 1, can't create masked load, needs " "VectorToSCF"); - Value fill = vector::SplatOp::create( + Value fill = vector::BroadcastOp::create( rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding()); res = vector::MaskedLoadOp::create( rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(), diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 58e94ea..bb0f339 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -451,10 +451,9 @@ struct WarpOpTransferWrite : public WarpDistributionPattern { } SmallVector<Value> delinearized; if (map.getNumResults() > 1) { - delinearized = rewriter - .create<mlir::affine::AffineDelinearizeIndexOp>( - newWarpOp.getLoc(), newWarpOp.getLaneid(), - delinearizedIdSizes) + delinearized = mlir::affine::AffineDelinearizeIndexOp::create( + rewriter, newWarpOp.getLoc(), newWarpOp.getLaneid(), + delinearizedIdSizes) .getResults(); } else { // If there is only one map result, we can elide the delinearization @@ -1538,19 +1537,18 @@ struct WarpOpInsertScalar : public WarpDistributionPattern { arith::CmpIOp::create(rewriter, loc, 
arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane); Value newResult = - rewriter - .create<scf::IfOp>( - loc, isInsertingLane, - /*thenBuilder=*/ - [&](OpBuilder &builder, Location loc) { - Value newInsert = vector::InsertOp::create( - builder, loc, newSource, distributedVec, newPos); - scf::YieldOp::create(builder, loc, newInsert); - }, - /*elseBuilder=*/ - [&](OpBuilder &builder, Location loc) { - scf::YieldOp::create(builder, loc, distributedVec); - }) + scf::IfOp::create( + rewriter, loc, isInsertingLane, + /*thenBuilder=*/ + [&](OpBuilder &builder, Location loc) { + Value newInsert = vector::InsertOp::create( + builder, loc, newSource, distributedVec, newPos); + scf::YieldOp::create(builder, loc, newInsert); + }, + /*elseBuilder=*/ + [&](OpBuilder &builder, Location loc) { + scf::YieldOp::create(builder, loc, distributedVec); + }) .getResult(0); rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber), newResult); return success(); @@ -1661,10 +1659,9 @@ struct WarpOpInsert : public WarpDistributionPattern { auto nonInsertingBuilder = [&](OpBuilder &builder, Location loc) { scf::YieldOp::create(builder, loc, distributedDest); }; - newResult = rewriter - .create<scf::IfOp>(loc, isInsertingLane, - /*thenBuilder=*/insertingBuilder, - /*elseBuilder=*/nonInsertingBuilder) + newResult = scf::IfOp::create(rewriter, loc, isInsertingLane, + /*thenBuilder=*/insertingBuilder, + /*elseBuilder=*/nonInsertingBuilder) .getResult(0); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index 73388a5..9889d7f2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -466,9 +466,9 @@ mlir::vector::castAwayContractionLeadingOneDim(vector::ContractionOp contractOp, newOp = mlir::vector::maskOperation(rewriter, newOp, newMask); } - return rewriter - .create<vector::BroadcastOp>(loc, contractOp->getResultTypes()[0], - newOp->getResults()[0]) + return vector::BroadcastOp::create(rewriter, loc, + contractOp->getResultTypes()[0], + newOp->getResults()[0]) .getResult(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp index e6bb96f..f78e579 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp @@ -32,7 +32,7 @@ #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include <cstdint> @@ -41,9 +41,6 @@ using namespace mlir; #define DEBUG_TYPE "vector-narrow-type-emulation" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using VectorValue = TypedValue<VectorType>; using MemRefValue = TypedValue<MemRefType>; @@ -135,17 +132,16 @@ static FailureOr<Operation *> getCompressedMaskOp(OpBuilder &rewriter, return vector::CreateMaskOp::create(rewriter, loc, newMaskType, newMaskOperands); }) - .Case<vector::ConstantMaskOp>( - [&](auto constantMaskOp) -> std::optional<Operation *> { - // Take the shape of mask, compress its trailing dimension: - SmallVector<int64_t> maskDimSizes( - constantMaskOp.getMaskDimSizes()); - int64_t &maskIndex = maskDimSizes.back(); - 
maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex, - numSrcElemsPerDest); - return vector::ConstantMaskOp::create( - rewriter, loc, newMaskType, maskDimSizes); - }) + .Case<vector::ConstantMaskOp>([&](auto constantMaskOp) + -> std::optional<Operation *> { + // Take the shape of mask, compress its trailing dimension: + SmallVector<int64_t> maskDimSizes(constantMaskOp.getMaskDimSizes()); + int64_t &maskIndex = maskDimSizes.back(); + maskIndex = llvm::divideCeil(numFrontPadElems + maskIndex, + numSrcElemsPerDest); + return vector::ConstantMaskOp::create(rewriter, loc, newMaskType, + maskDimSizes); + }) .Case<arith::ConstantOp>([&](auto constantOp) -> std::optional<Operation *> { // TODO: Support multiple dimensions. @@ -232,9 +228,8 @@ static Value staticallyExtractSubvector(OpBuilder &rewriter, Location loc, auto resultVectorType = VectorType::get({numElemsToExtract}, vectorType.getElementType()); - return rewriter - .create<vector::ExtractStridedSliceOp>(loc, resultVectorType, src, - offsets, sizes, strides) + return vector::ExtractStridedSliceOp::create(rewriter, loc, resultVectorType, + src, offsets, sizes, strides) ->getResult(0); } @@ -1526,11 +1521,11 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType, "requires -D non-scalable vector type"); int64_t sourceBitWidth = sourceVectorType.getElementTypeBitWidth(); int64_t mostMinorSourceDim = sourceVectorType.getShape().back(); - LDBG("sourceVectorType: " << sourceVectorType); + LDBG() << "sourceVectorType: " << sourceVectorType; int64_t targetBitWidth = targetVectorType.getElementTypeBitWidth(); int64_t mostMinorTargetDim = targetVectorType.getShape().back(); - LDBG("targetVectorType: " << targetVectorType); + LDBG() << "targetVectorType: " << targetVectorType; int64_t bitwidth = targetBitWidth * mostMinorTargetDim; (void)mostMinorSourceDim; @@ -1555,7 +1550,7 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType, BitCastRewriter::BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType) : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) { - LDBG("\n" << enumerator.sourceElementRanges); + LDBG() << "\n" << enumerator.sourceElementRanges; } /// Verify that the precondition type meets the common preconditions for any diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp index 72352d7..cbb9d4b 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp @@ -303,7 +303,7 @@ public: // Extract/insert on a lower ranked extract strided slice op. 
Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = SplatOp::create(rewriter, loc, dstType, zero); + Value res = BroadcastOp::create(rewriter, loc, dstType, zero); for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e; off += stride, ++idx) { Value one = ExtractOp::create(rewriter, loc, op.getVector(), off); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index 2676d25..c707f38 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -25,12 +25,10 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "vector-transfer-opt" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") - using namespace mlir; /// Return the ancestor op in the region or nullptr if the region is not @@ -88,8 +86,7 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) { /// transfer_write is dead if all reads that can be reached from the potentially /// dead transfer_write are dominated by the overwriting transfer_write. void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { - LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation() - << "\n"); + LDBG() << "Candidate for dead store: " << *write.getOperation(); llvm::SmallVector<Operation *, 8> blockingAccesses; Operation *firstOverwriteCandidate = nullptr; Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase())); @@ -150,13 +147,12 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { !isReachable(writeAncestor, accessAncestor)) continue; if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) { - LLVM_DEBUG(DBGS() << "Store may not be dead due to op: " - << *accessAncestor << "\n"); + LDBG() << "Store may not be dead due to op: " << *accessAncestor; return; } } - LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation() - << " overwritten by: " << *firstOverwriteCandidate << "\n"); + LDBG() << "Found dead store: " << *write.getOperation() + << " overwritten by: " << *firstOverwriteCandidate; opToErase.push_back(write.getOperation()); } @@ -174,8 +170,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { if (read.hasOutOfBoundsDim()) return; - LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation() - << "\n"); + LDBG() << "Candidate for Forwarding: " << *read.getOperation(); SmallVector<Operation *, 8> blockingWrites; vector::TransferWriteOp lastwrite = nullptr; Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase())); @@ -230,14 +225,13 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor)) continue; if (!postDominators.postDominates(lastwrite, write)) { - LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: " - << *write << "\n"); + LDBG() << "Fail to do write to read forwarding due to op: " << *write; return; } } - LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation() - << " to: " << *read.getOperation() << "\n"); + LDBG() << "Forward value from " << 
*lastwrite.getOperation() + << " to: " << *read.getOperation(); read.replaceAllUsesWith(lastwrite.getVector()); opToErase.push_back(read.getOperation()); } @@ -330,8 +324,8 @@ createMaskDropNonScalableUnitDims(PatternRewriter &rewriter, Location loc, } reducedOperands.push_back(operand); } - return rewriter - .create<vector::CreateMaskOp>(loc, reducedType, reducedOperands) + return vector::CreateMaskOp::create(rewriter, loc, reducedType, + reducedOperands) .getResult(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp index 05b0074..5e12dc4 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -348,24 +348,23 @@ getLocationToWriteFullVec(RewriterBase &b, vector::TransferWriteOp xferOp, Location loc = xferOp.getLoc(); Value zero = arith::ConstantIndexOp::create(b, loc, 0); Value memref = xferOp.getBase(); - return b - .create<scf::IfOp>( - loc, inBoundsCond, - [&](OpBuilder &b, Location loc) { - Value res = - castToCompatibleMemRefType(b, memref, compatibleMemRefType); - scf::ValueVector viewAndIndices{res}; - llvm::append_range(viewAndIndices, xferOp.getIndices()); - scf::YieldOp::create(b, loc, viewAndIndices); - }, - [&](OpBuilder &b, Location loc) { - Value casted = - castToCompatibleMemRefType(b, alloc, compatibleMemRefType); - scf::ValueVector viewAndIndices{casted}; - viewAndIndices.insert(viewAndIndices.end(), - xferOp.getTransferRank(), zero); - scf::YieldOp::create(b, loc, viewAndIndices); - }) + return scf::IfOp::create( + b, loc, inBoundsCond, + [&](OpBuilder &b, Location loc) { + Value res = + castToCompatibleMemRefType(b, memref, compatibleMemRefType); + scf::ValueVector viewAndIndices{res}; + llvm::append_range(viewAndIndices, xferOp.getIndices()); + scf::YieldOp::create(b, loc, viewAndIndices); + }, + [&](OpBuilder &b, Location loc) { + Value casted = + castToCompatibleMemRefType(b, alloc, compatibleMemRefType); + scf::ValueVector viewAndIndices{casted}; + viewAndIndices.insert(viewAndIndices.end(), + xferOp.getTransferRank(), zero); + scf::YieldOp::create(b, loc, viewAndIndices); + }) ->getResults(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 73ca327..2269a40 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -410,9 +410,8 @@ FailureOr<Value> combineContractAndBroadcast(vector::ContractionOp contractOp, oldMaskType.getScalableDims().drop_front(unusedDimsBitVector.count()); VectorType maskOpType = VectorType::get(newShape, rewriter.getI1Type(), newShapeScalableDims); - mask = rewriter - .create<vector::ShapeCastOp>(contractOp.getLoc(), maskOpType, - maskingOp.getMask()) + mask = vector::ShapeCastOp::create(rewriter, contractOp.getLoc(), + maskOpType, maskingOp.getMask()) .getResult(); } @@ -940,7 +939,7 @@ public: Value zero = arith::ConstantOp::create(rewriter, loc, elemType, rewriter.getZeroAttr(elemType)); - Value res = SplatOp::create(rewriter, loc, castDstType, zero); + Value res = BroadcastOp::create(rewriter, loc, castDstType, zero); SmallVector<int64_t> sliceShape = {castDstLastDim}; SmallVector<int64_t> strides = {1}; @@ -966,6 +965,45 @@ private: std::function<bool(BitCastOp)> controlFn; }; +static bool haveSameShapeAndScaling(Type t, Type u) { + auto tVec = 
dyn_cast<VectorType>(t); + auto uVec = dyn_cast<VectorType>(u); + if (!tVec) { + return !uVec; + } + if (!uVec) { + return false; + } + return tVec.getShape() == uVec.getShape() && + tVec.getScalableDims() == uVec.getScalableDims(); +} + +/// If `type` is shaped, clone it with `newElementType`. Otherwise, +/// return `newElementType`. +static Type cloneOrReplace(Type type, Type newElementType) { + if (auto shapedType = dyn_cast<ShapedType>(type)) { + return shapedType.clone(newElementType); + } + return newElementType; +} + +/// If `value` is the result of a splat or broadcast operation, return the input +/// of the splat/broadcast operation. +static Value getBroadcastLikeSource(Value value) { + + Operation *op = value.getDefiningOp(); + if (!op) + return {}; + + if (auto broadcast = dyn_cast<vector::BroadcastOp>(op)) + return broadcast.getSource(); + + if (auto splat = dyn_cast<vector::SplatOp>(op)) + return splat.getInput(); + + return {}; +} + /// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex: /// /// Example: @@ -989,16 +1027,14 @@ struct ReorderElementwiseOpsOnBroadcast final PatternRewriter &rewriter) const override { if (op->getNumResults() != 1) return failure(); - if (!llvm::isa<ShapedType>(op->getResults()[0].getType())) + auto resultType = dyn_cast<VectorType>(op->getResult(0).getType()); + if (!resultType) return failure(); if (!OpTrait::hasElementwiseMappableTraits(op)) return rewriter.notifyMatchFailure( op, "Op doesn't have ElementwiseMappableTraits"); if (op->getNumOperands() == 0) return failure(); - if (op->getResults()[0].getType() != op->getOperand(0).getType()) - return rewriter.notifyMatchFailure(op, - "result and operand type mismatch"); if (isa<vector::FMAOp>(op)) { return rewriter.notifyMatchFailure( op, @@ -1006,45 +1042,71 @@ struct ReorderElementwiseOpsOnBroadcast final "might be a scalar"); } - // Get the type of the lhs operand - auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp(); - if (!lhsBcastOrSplat || - !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat)) + Type resultElemType = resultType.getElementType(); + + // Get the type of the first non-constant operand + Value splatSource; + for (Value operand : op->getOperands()) { + Operation *definingOp = operand.getDefiningOp(); + if (!definingOp) + return failure(); + if (definingOp->hasTrait<OpTrait::ConstantLike>()) + continue; + splatSource = getBroadcastLikeSource(operand); + break; + } + if (!splatSource) return failure(); - auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType(); + Type unbroadcastResultType = + cloneOrReplace(splatSource.getType(), resultElemType); - // Make sure that all operands are broadcast from identical types: + // Make sure that all operands are broadcast from identically-shaped types: // * scalar (`vector.broadcast` + `vector.splat`), or // * vector (`vector.broadcast`). // Otherwise the re-ordering wouldn't be safe. 
- if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) { - auto bcast = val.getDefiningOp<vector::BroadcastOp>(); - if (bcast) - return (bcast.getOperand().getType() == lhsBcastOrSplatType); - auto splat = val.getDefiningOp<vector::SplatOp>(); - if (splat) - return (splat.getOperand().getType() == lhsBcastOrSplatType); - return false; + if (!llvm::all_of(op->getOperands(), [splatSource](Value val) { + if (auto source = getBroadcastLikeSource(val)) + return haveSameShapeAndScaling(source.getType(), + splatSource.getType()); + SplatElementsAttr splatConst; + return matchPattern(val, m_Constant(&splatConst)); })) { - return failure(); + return rewriter.notifyMatchFailure( + op, + "not all operands are constants or broadcasts from the same type"); } // Collect the source values before broadcasting SmallVector<Value> srcValues; srcValues.reserve(op->getNumOperands()); for (Value operand : op->getOperands()) { - srcValues.push_back(operand.getDefiningOp()->getOperand(0)); + SplatElementsAttr splatConst; + if (matchPattern(operand, m_Constant(&splatConst))) { + Attribute newConst; + Type elementType = getElementTypeOrSelf(operand.getType()); + Type newType = cloneOrReplace(unbroadcastResultType, elementType); + if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) { + newConst = splatConst.resizeSplat(newTypeShaped); + } else { + newConst = splatConst.getSplatValue<Attribute>(); + } + Operation *newConstOp = + operand.getDefiningOp()->getDialect()->materializeConstant( + rewriter, newConst, newType, operand.getLoc()); + srcValues.push_back(newConstOp->getResult(0)); + } else { + srcValues.push_back(operand.getDefiningOp()->getOperand(0)); + } } // Create the "elementwise" Op Operation *elementwiseOp = rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues, - lhsBcastOrSplatType, op->getAttrs()); + unbroadcastResultType, op->getAttrs()); // Replace the original Op with the elementwise Op - auto vectorType = op->getResultTypes()[0]; rewriter.replaceOpWithNewOp<vector::BroadcastOp>( - op, vectorType, elementwiseOp->getResults()); + op, resultType, elementwiseOp->getResults()); return success(); } @@ -1240,15 +1302,17 @@ public: return rewriter.notifyMatchFailure( op, "only 1-element vectors are supported"); - Operation *splat = op.getValueToStore().getDefiningOp(); - if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat)) - return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast"); + Value toStore = op.getValueToStore(); + Value source = getBroadcastLikeSource(toStore); + if (!source) + return rewriter.notifyMatchFailure( + op, "value to store is not from a broadcast"); // Checking for single use so we can remove splat. + Operation *splat = toStore.getDefiningOp(); if (!splat->hasOneUse()) return rewriter.notifyMatchFailure(op, "expected single op use"); - Value source = splat->getOperand(0); Value base = op.getBase(); ValueRange indices = op.getIndices(); @@ -1298,13 +1362,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op, // Add in an offset if requested. if (off) { Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off); - Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o); + Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o); indices = arith::AddIOp::create(rewriter, loc, ov, indices); } // Construct the vector comparison. 
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b); Value bounds = - vector::SplatOp::create(rewriter, loc, indices.getType(), bound); + vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound); return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt, indices, bounds); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index fceba65..501abec 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -16,13 +16,11 @@ #include "mlir/Interfaces/VectorInterfaces.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #include <optional> #define DEBUG_TYPE "vector-unroll" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; using namespace mlir::vector; @@ -90,10 +88,9 @@ static Operation *cloneOpWithOperandsAndTypes(OpBuilder &builder, Location loc, /// std::nullopt if the op shouldn't be or cannot be unrolled. static std::optional<SmallVector<int64_t>> getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { - LDBG(""); - LDBG("Get unroll shape for op " << op->getName().getStringRef()); + LDBG() << "Get unroll shape for op " << op->getName().getStringRef(); if (options.filterConstraint && failed(options.filterConstraint(op))) { - LDBG("--no filter constraint -> BAIL"); + LDBG() << "--no filter constraint -> BAIL"; return std::nullopt; } assert(options.nativeShape && @@ -101,33 +98,33 @@ getTargetShape(const vector::UnrollVectorOptions &options, Operation *op) { "shape call back function to be set"); auto unrollableVectorOp = dyn_cast<VectorUnrollOpInterface>(op); if (!unrollableVectorOp) { - LDBG("--not an unrollable op -> BAIL"); + LDBG() << "--not an unrollable op -> BAIL"; return std::nullopt; } auto maybeUnrollShape = unrollableVectorOp.getShapeForUnroll(); if (!maybeUnrollShape) { - LDBG("--could not get shape of op " << *op << " -> BAIL"); + LDBG() << "--could not get shape of op " << *op << " -> BAIL"; return std::nullopt; } - LDBG("--vector op shape: " << llvm::interleaved(*maybeUnrollShape)); + LDBG() << "--vector op shape: " << llvm::interleaved(*maybeUnrollShape); std::optional<SmallVector<int64_t>> targetShape = options.nativeShape(op); if (!targetShape) { - LDBG("--no unrolling target shape defined " << *op << "-> SKIP"); + LDBG() << "--no unrolling target shape defined " << *op << "-> SKIP"; return std::nullopt; } - LDBG("--target shape: " << llvm::interleaved(*targetShape)); + LDBG() << "--target shape: " << llvm::interleaved(*targetShape); auto maybeShapeRatio = computeShapeRatio(*maybeUnrollShape, *targetShape); if (!maybeShapeRatio) { - LDBG("--could not compute integral shape ratio -> BAIL"); + LDBG() << "--could not compute integral shape ratio -> BAIL"; return std::nullopt; } if (llvm::all_of(*maybeShapeRatio, [](int64_t v) { return v == 1; })) { - LDBG("--no unrolling needed -> SKIP"); + LDBG() << "--no unrolling needed -> SKIP"; return std::nullopt; } - LDBG("--found an integral shape ratio to unroll to -> SUCCESS"); + LDBG() << "--found an integral shape ratio to unroll to -> SUCCESS"; return targetShape; } @@ -169,7 +166,7 @@ struct UnrollTransferReadPattern auto sourceVectorType = readOp.getVectorType(); SmallVector<int64_t> strides(targetShape->size(), 1); Location loc = 
readOp.getLoc(); - ArrayRef<int64_t> originalSize = readOp.getVectorType().getShape(); + ArrayRef<int64_t> originalSize = sourceVectorType.getShape(); // Prepare the result vector; Value result = @@ -225,6 +222,14 @@ struct UnrollTransferWritePattern SmallVector<int64_t> strides(targetShape->size(), 1); Location loc = writeOp.getLoc(); ArrayRef<int64_t> originalSize = sourceVectorType.getShape(); + // Bail-out if rank(source) != rank(target). The main limitation here is the + // fact that `ExtractStridedSlice` requires the rank for the input and + // output to match. If needed, we can relax this later. + if (originalSize.size() != targetShape->size()) + return rewriter.notifyMatchFailure( + writeOp, + "expected source input vector rank to match target shape rank"); + SmallVector<Value> originalIndices(writeOp.getIndices().begin(), writeOp.getIndices().end()); SmallVector<int64_t> loopOrder = diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index c045063..10ed2bc 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -27,13 +27,11 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #define DEBUG_TYPE "vector-utils" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") - using namespace mlir; /// Helper function that creates a memref::DimOp or tensor::DimOp depending on @@ -369,14 +367,14 @@ Value vector::createReadOrMaskedRead(OpBuilder &builder, Location loc, LogicalResult vector::isValidMaskedInputVector(ArrayRef<int64_t> shape, ArrayRef<int64_t> inputVectorSizes) { - LDBG("Iteration space static sizes:" << llvm::interleaved(shape)); + LDBG() << "Iteration space static sizes:" << llvm::interleaved(shape); if (inputVectorSizes.size() != shape.size()) { - LDBG("Input vector sizes don't match the number of loops"); + LDBG() << "Input vector sizes don't match the number of loops"; return failure(); } if (ShapedType::isDynamicShape(inputVectorSizes)) { - LDBG("Input vector sizes can't have dynamic dimensions"); + LDBG() << "Input vector sizes can't have dynamic dimensions"; return failure(); } if (!llvm::all_of(llvm::zip(shape, inputVectorSizes), @@ -386,8 +384,9 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape, return ShapedType::isDynamic(staticSize) || staticSize <= inputSize; })) { - LDBG("Input vector sizes must be greater than or equal to iteration space " - "static sizes"); + LDBG() << "Input vector sizes must be greater than or equal to iteration " + "space " + "static sizes"; return failure(); } return success(); diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index 704deea..33450f3 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -110,6 +110,34 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, return success(); } +static LogicalResult +isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, + int64_t chunkSize, + function_ref<InFlightDiagnostic()> emitError) { + + if (!valueTy) + return emitError() << "Expecting a vector type result."; + + auto maskShape = getShapeOf(maskTy); + auto valueShape = getShapeOf(valueTy); + + // a valid shape for SIMT case + if (valueTy.getRank() == 1) { + if (valueTy.getNumElements() != chunkSize) + return emitError() << "value elements must match chunk size " << 
chunkSize + << " for SIMT code."; + return success(); + } + + llvm::SmallVector<int64_t> expectedMaskShape(valueShape); + if (chunkSize > 1) + expectedMaskShape.pop_back(); + if (expectedMaskShape != maskShape) + return emitError() << "Mask should match value except the chunk size dim."; + + return success(); +} + //===----------------------------------------------------------------------===// // XeGPU_CreateNdDescOp //===----------------------------------------------------------------------===// @@ -644,9 +672,14 @@ LogicalResult CreateDescOp::verify() { //===----------------------------------------------------------------------===// LogicalResult PrefetchOp::verify() { auto tdescTy = getTensorDescType(); - if (!tdescTy.isScattered()) + + if (tdescTy && !tdescTy.isScattered()) return emitOpError("Expects a scattered TensorDesc."); + if (!tdescTy && getRankOf(getSource()) > 1) + return emitOpError( + "Expecting the source to be a 1D memref or pointer (uint64_t)."); + if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -659,6 +692,13 @@ LogicalResult PrefetchOp::verify() { return success(); } +void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint); +} + //===----------------------------------------------------------------------===// // XeGPU_LoadGatherOp //===----------------------------------------------------------------------===// @@ -667,6 +707,13 @@ LogicalResult LoadGatherOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc."); + + if (!tdescTy && getRankOf(getSource()) > 1) + return emitOpError( + "Expecting the source to be a 1D memref or pointer (uint64_t)."); + if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -676,8 +723,27 @@ if (!isReadHintOrNone(getL2HintAttr())) return emitOpError("invalid l2_hint: ") << getL2HintAttr(); if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - return isValidGatherScatterParams(maskTy, valueTy, tdescTy, - [&]() { return emitOpError(); }); + if (tdescTy) + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, + [&]() { return emitOpError(); }); + auto srcTy = getSourceType(); + uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); + auto memTy = dyn_cast<MemRefType>(srcTy); + + if (memTy && (valueTy.getElementType() != memTy.getElementType())) + return emitError() << "Value should have the same element type as MemRef."; + + return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + [&]() { return emitOpError(); }); +} + +void LoadGatherOp::build(OpBuilder &builder, OperationState &state, + Type valueType, Value source, Value mask, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, valueType, source, Value(), mask, IntegerAttr(), + l1_hint, l2_hint, l3_hint); } //===----------------------------------------------------------------------===// @@ -688,6 +754,13 @@ LogicalResult StoreScatterOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc."); + + if (!tdescTy && getRankOf(getDest()) > 1) + return emitOpError( + "Expecting the dest to be
a 1D memref or pointer (uint64_t)."); + if (!isWriteHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -697,8 +770,28 @@ if (!isWriteHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); - return isValidGatherScatterParams(maskTy, valueTy, tdescTy, - [&]() { return emitOpError(); }); + if (tdescTy) + return isValidGatherScatterParams(maskTy, valueTy, tdescTy, + [&]() { return emitOpError(); }); + + auto destTy = getDestType(); + uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); + auto memTy = dyn_cast<MemRefType>(destTy); + + if (memTy && (valueTy.getElementType() != memTy.getElementType())) + return emitError() << "Value should have the same element type as MemRef."; + + return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + [&]() { return emitOpError(); }); +} + +void StoreScatterOp::build(OpBuilder &builder, OperationState &state, + Value value, Value dest, Value mask, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint, + l2_hint, l3_hint); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index 4656f11..d82c541 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -17,6 +17,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" namespace mlir { namespace xegpu { @@ -26,8 +27,6 @@ namespace xegpu { } // namespace mlir #define DEBUG_TYPE "xegpu-blocking" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; @@ -53,7 +52,7 @@ resolveUnrealizedConversionCastOp(UnrealizedConversionCastOp castOp) { // We are only interested in the case where all inputs and outputs have // identical VectorTypes. if (!hasIdenticalVectorTypes(inputs) || !hasIdenticalVectorTypes(outputs)) { - LDBG("skip unrealized conversion cast op not emulating pack/unpack."); + LDBG() << "skip unrealized conversion cast op not emulating pack/unpack."; return; } @@ -149,7 +148,7 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const { if (auto type = dyn_cast<ShapedType>(value.getType())) return llvm::to_vector(type.getShape()); } - LDBG("failed to getTileShape for: " << value); + LDBG() << "failed to getTileShape for: " << value; return std::nullopt; } @@ -214,7 +213,7 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const { return layout && layout.isWgLayout(); }); if (hasWgLayoutOperands || hasWgLayoutResults) { - LDBG("skip unrolling for op with workgroup level layout: " << *op); + LDBG() << "skip unrolling for op with workgroup level layout: " << *op; return false; } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp index a6208b4..c793b71 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp @@ -17,7 +17,7 @@ #include "mlir/Dialect/XeGPU/Transforms/Transforms.h" #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" +#include
"llvm/Support/DebugLog.h" namespace mlir { namespace xegpu { @@ -27,8 +27,6 @@ namespace xegpu { } // namespace mlir #define DEBUG_TYPE "xegpu-unroll" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; @@ -44,11 +42,10 @@ protected: /// Return the target shape for the given `op`. Return std::nullopt if the /// op shouldn't be or cannot be unrolled. std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const { - LDBG(""); - LDBG("Get unroll shape for: " << *op); + LDBG() << "Get unroll shape for: " << *op; if (options.filterConstraint && failed(options.filterConstraint(op))) { - LDBG("--no filter constraint -> BAIL"); + LDBG() << "--no filter constraint -> BAIL"; return std::nullopt; } @@ -484,7 +481,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> { VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType()); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructured source case (!tdescTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); @@ -546,7 +544,8 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> { Location loc = op.getLoc(); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructured source case (!tdescTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); @@ -575,7 +574,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> { VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType()); xegpu::TensorDescType tdescTy = op.getTensorDescType(); - if (!tdescTy.isScattered()) + // TODO: handle the unstructured source case (!tdescTy) + if (!tdescTy || op.getOffsets()) return failure(); std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 229a289..850f70c 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -207,7 +207,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { // Subtract startOfRange from the original subgroup id to get the adjusted // sg id Value startOfRangeVal = - rewriter.create<arith::ConstantIndexOp>(loc, startOfRange); + arith::ConstantIndexOp::create(rewriter, loc, startOfRange); adjustedSgId = rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal); } @@ -431,8 +431,8 @@ struct WgToSgVectorBroadcastOp SmallVector<Value> newBroadcastOps; for (auto operand : adaptor.getOperands().front()) { - auto newBroadcast = rewriter.create<vector::BroadcastOp>( - op.getLoc(), newResultType, operand); + auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), + newResultType, operand); xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData()); newBroadcastOps.push_back(newBroadcast.getResult()); @@ -563,8 +563,8 @@ struct WgToSgConvertLayoutOp if (input && target) { // Keep the ConvertLayoutOp for the remaining fields, e.g., inst_data.
for (auto [i, src] : llvm::enumerate(adaptor.getSource())) { - auto newOp = rewriter.create<xegpu::ConvertLayoutOp>( - op.getLoc(), src.getType(), src, input, target); + auto newOp = xegpu::ConvertLayoutOp::create( + rewriter, op.getLoc(), src.getType(), src, input, target); newOps[i] = newOp; } } diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp index 0652202..e55a666 100644 --- a/mlir/lib/IR/AffineExpr.cpp +++ b/mlir/lib/IR/AffineExpr.cpp @@ -8,7 +8,6 @@ #include <cmath> #include <cstdint> -#include <limits> #include <utility> #include "AffineExprDetail.h" @@ -16,7 +15,6 @@ #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/IntegerSet.h" -#include "mlir/Support/TypeID.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/MathExtras.h" #include <numeric> diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index f95ad29..de52fbd 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -40,7 +40,7 @@ #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/Endian.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Regex.h" @@ -2070,9 +2070,8 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op, return failure(); }); if (failed(verify(op))) { - LLVM_DEBUG(llvm::dbgs() - << DEBUG_TYPE << ": '" << op->getName() - << "' failed to verify and will be printed in generic form\n"); + LDBG() << "'" << op->getName() + << "' failed to verify and will be printed in generic form"; printerFlags.printGenericOpForm(); } diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp index 3e33795..776b5c6 100644 --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -821,15 +821,7 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler( for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i) (void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1)); - // Register a handler to verify the diagnostics. - setHandler([&](Diagnostic &diag) { - // Process the main diagnostics. - process(diag); - - // Process each of the notes. - for (auto &note : diag.getNotes()) - process(note); - }); + registerInContext(ctx); } SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler( @@ -862,6 +854,17 @@ LogicalResult SourceMgrDiagnosticVerifierHandler::verify() { return impl->status; } +void SourceMgrDiagnosticVerifierHandler::registerInContext(MLIRContext *ctx) { + ctx->getDiagEngine().registerHandler([&](Diagnostic &diag) { + // Process the main diagnostics. + process(diag); + + // Process each of the notes. + for (auto &note : diag.getNotes()) + process(note); + }); +} + /// Process a single diagnostic.
void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) { return process(diag.getLocation(), diag.str(), diag.getSeverity()); diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp index f897546..23e70c6 100644 --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -18,13 +18,9 @@ #include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" -#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Casting.h" #include "llvm/Support/TrailingObjects.h" #include <cassert> -#include <iterator> -#include <memory> -#include <optional> #include <tuple> #include <utility> diff --git a/mlir/lib/IR/PDL/PDLPatternMatch.cpp b/mlir/lib/IR/PDL/PDLPatternMatch.cpp index 28b39dd..62a71aa 100644 --- a/mlir/lib/IR/PDL/PDLPatternMatch.cpp +++ b/mlir/lib/IR/PDL/PDLPatternMatch.cpp @@ -7,10 +7,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/Iterators.h" #include "mlir/IR/PatternMatch.h" -#include "mlir/IR/RegionKindInterface.h" #include "llvm/Support/InterleavedRange.h" using namespace mlir; diff --git a/mlir/lib/IR/PatternLoggingListener.cpp b/mlir/lib/IR/PatternLoggingListener.cpp index ce2123a..0db13ab 100644 --- a/mlir/lib/IR/PatternLoggingListener.cpp +++ b/mlir/lib/IR/PatternLoggingListener.cpp @@ -1,50 +1,48 @@ #include "mlir/IR/PatternMatch.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "pattern-logging-listener" -#define DBGS() (llvm::dbgs() << "[" << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") using namespace mlir; void RewriterBase::PatternLoggingListener::notifyOperationInserted( Operation *op, InsertPoint previous) { - LDBG(patternName << " | notifyOperationInserted" - << " | " << op->getName()); + LDBG() << patternName << " | notifyOperationInserted" + << " | " << op->getName(); ForwardingListener::notifyOperationInserted(op, previous); } void RewriterBase::PatternLoggingListener::notifyOperationModified( Operation *op) { - LDBG(patternName << " | notifyOperationModified" - << " | " << op->getName()); + LDBG() << patternName << " | notifyOperationModified" + << " | " << op->getName(); ForwardingListener::notifyOperationModified(op); } void RewriterBase::PatternLoggingListener::notifyOperationReplaced( Operation *op, Operation *newOp) { - LDBG(patternName << " | notifyOperationReplaced (with op)" - << " | " << op->getName() << " | " << newOp->getName()); + LDBG() << patternName << " | notifyOperationReplaced (with op)" + << " | " << op->getName() << " | " << newOp->getName(); ForwardingListener::notifyOperationReplaced(op, newOp); } void RewriterBase::PatternLoggingListener::notifyOperationReplaced( Operation *op, ValueRange replacement) { - LDBG(patternName << " | notifyOperationReplaced (with values)" - << " | " << op->getName()); + LDBG() << patternName << " | notifyOperationReplaced (with values)" + << " | " << op->getName(); ForwardingListener::notifyOperationReplaced(op, replacement); } void RewriterBase::PatternLoggingListener::notifyOperationErased( Operation *op) { - LDBG(patternName << " | notifyOperationErased" - << " | " << op->getName()); + LDBG() << patternName << " | notifyOperationErased" + << " | " << op->getName(); ForwardingListener::notifyOperationErased(op); } void RewriterBase::PatternLoggingListener::notifyPatternBegin( const Pattern &pattern, Operation *op) { - LDBG(patternName << " | notifyPatternBegin" - << " | " << op->getName()); + LDBG() << 
patternName << " | notifyPatternBegin" + << " | " << op->getName(); ForwardingListener::notifyPatternBegin(pattern, op); } diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 1e60848..9332f55 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -7,8 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/PatternMatch.h" -#include "mlir/Config/mlir-config.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Iterators.h" #include "mlir/IR/RegionKindInterface.h" #include "llvm/ADT/SmallPtrSet.h" @@ -158,6 +156,11 @@ void RewriterBase::eraseOp(Operation *op) { assert(op->use_empty() && "expected 'op' to have no uses"); auto *rewriteListener = dyn_cast_if_present<Listener>(listener); + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + // Fast path: If no listener is attached, the op can be dropped in one go. if (!rewriteListener) { op->erase(); @@ -322,6 +325,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. assert(source->empty() && "expected 'source' to be empty"); eraseBlock(source); diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index 07c311b..87b4799 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -10,7 +10,6 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/OpImplementation.h" #include "llvm/ADT/SetVector.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringSwitch.h" #include <optional> diff --git a/mlir/lib/IR/Value.cpp b/mlir/lib/IR/Value.cpp index 7b3a946..fa550e4 100644 --- a/mlir/lib/IR/Value.cpp +++ b/mlir/lib/IR/Value.cpp @@ -8,9 +8,7 @@ #include "mlir/IR/Value.h" #include "mlir/IR/Block.h" -#include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Operation.h" -#include "llvm/ADT/SmallPtrSet.h" using namespace mlir; using namespace mlir::detail; diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index e9b5e92..310680b 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -17,14 +17,32 @@ using namespace mlir; +static std::pair<int64_t, int64_t> +getLineAndColStart(const llvm::SourceMgr &sourceMgr) { + unsigned lastFileID = sourceMgr.getNumBuffers(); + if (lastFileID == 1) + return {0, 0}; + + auto bufferID = sourceMgr.getMainFileID(); + const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID); + const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID); + // Exclude same start. 
+ if (main->getBufferStart() < last->getBufferStart() && + main->getBufferEnd() >= last->getBufferEnd()) { + return sourceMgr.getLineAndColumn( + llvm::SMLoc::getFromPointer(last->getBufferStart()), bufferID); + } + return {0, 0}; +} + LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr, Block *block, const ParserConfig &config, LocationAttr *sourceFileLoc) { const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID()); if (sourceFileLoc) { - *sourceFileLoc = FileLineColLoc::get(config.getContext(), - sourceBuf->getBufferIdentifier(), - /*line=*/0, /*column=*/0); + auto [line, column] = getLineAndColStart(sourceMgr); + *sourceFileLoc = FileLineColLoc::get( + config.getContext(), sourceBuf->getBufferIdentifier(), line, column); } if (isBytecode(*sourceBuf)) return readBytecodeFile(*sourceBuf, block, config); @@ -37,9 +55,9 @@ mlir::parseSourceFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, const auto *sourceBuf = sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()); if (sourceFileLoc) { - *sourceFileLoc = FileLineColLoc::get(config.getContext(), - sourceBuf->getBufferIdentifier(), - /*line=*/0, /*column=*/0); + auto [line, column] = getLineAndColStart(*sourceMgr); + *sourceFileLoc = FileLineColLoc::get( + config.getContext(), sourceBuf->getBufferIdentifier(), line, column); } if (isBytecode(*sourceBuf)) return readBytecodeFile(sourceMgr, block, config); diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 0db9808..7094c8e 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -901,7 +901,7 @@ LogicalResult PassManager::run(Operation *op) { if (failed(initialize(context, impl->initializationGeneration + 1))) return failure(); initializationKey = newInitKey; - pipelineKey = pipelineInitializationKey; + pipelineInitializationKey = pipelineKey; } // Construct a top level analysis manager for the pipeline. diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp index 7c294f0..bc766d4 100644 --- a/mlir/lib/Pass/PassRegistry.cpp +++ b/mlir/lib/Pass/PassRegistry.cpp @@ -10,7 +10,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Format.h" diff --git a/mlir/lib/Query/Matcher/MatchersInternal.cpp b/mlir/lib/Query/Matcher/MatchersInternal.cpp index 01f412a..21524f0 100644 --- a/mlir/lib/Query/Matcher/MatchersInternal.cpp +++ b/mlir/lib/Query/Matcher/MatchersInternal.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Query/Matcher/MatchersInternal.h" -#include "llvm/ADT/SetVector.h" namespace mlir::query::matcher { diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp new file mode 100644 index 0000000..7a345ed --- /dev/null +++ b/mlir/lib/RegisterAllDialects.cpp @@ -0,0 +1,207 @@ +//===- RegisterAllDialects.cpp - MLIR Dialects Registration -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a helper to trigger the registration of all dialects and +// passes to the system. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/InitAllDialects.h" + +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h" +#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h" +#include "mlir/Dialect/ArmSME/IR/ArmSME.h" +#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h" +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Complex/IR/Complex.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/IRDL/IR/IRDL.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" +#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/MPI/IR/MPI.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h" +#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h" +#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Dialect/Ptr/IR/PtrDialect.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h" +#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" +#include "mlir/Dialect/Shape/IR/Shape.h" +#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h" +#include 
"mlir/Dialect/Shard/IR/ShardDialect.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" +#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" +#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h" +#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h" +#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" +#include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" +#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h" +#include "mlir/Dialect/X86Vector/X86VectorDialect.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/IR/Dialect.h" +#include "mlir/Interfaces/CastInterfaces.h" +#include "mlir/Target/LLVM/NVVM/Target.h" +#include "mlir/Target/LLVM/ROCDL/Target.h" +#include "mlir/Target/SPIRV/Target.h" + +/// Add all the MLIR dialects to the provided registry. +void mlir::registerAllDialects(DialectRegistry ®istry) { + // clang-format off + registry.insert<acc::OpenACCDialect, + affine::AffineDialect, + amdgpu::AMDGPUDialect, + amx::AMXDialect, + arith::ArithDialect, + arm_neon::ArmNeonDialect, + arm_sme::ArmSMEDialect, + arm_sve::ArmSVEDialect, + async::AsyncDialect, + bufferization::BufferizationDialect, + cf::ControlFlowDialect, + complex::ComplexDialect, + DLTIDialect, + emitc::EmitCDialect, + func::FuncDialect, + gpu::GPUDialect, + index::IndexDialect, + irdl::IRDLDialect, + linalg::LinalgDialect, + LLVM::LLVMDialect, + math::MathDialect, + memref::MemRefDialect, + shard::ShardDialect, + ml_program::MLProgramDialect, + mpi::MPIDialect, + nvgpu::NVGPUDialect, + NVVM::NVVMDialect, + omp::OpenMPDialect, + pdl::PDLDialect, + pdl_interp::PDLInterpDialect, + ptr::PtrDialect, + quant::QuantDialect, + ROCDL::ROCDLDialect, + scf::SCFDialect, + shape::ShapeDialect, + smt::SMTDialect, + sparse_tensor::SparseTensorDialect, + spirv::SPIRVDialect, + tensor::TensorDialect, + tosa::TosaDialect, + transform::TransformDialect, + ub::UBDialect, + vector::VectorDialect, + x86vector::X86VectorDialect, + xegpu::XeGPUDialect, + xevm::XeVMDialect>(); + // clang-format on + + // Register all external models. 
+ affine::registerValueBoundsOpInterfaceExternalModels(registry); + arith::registerBufferDeallocationOpInterfaceExternalModels(registry); + arith::registerBufferizableOpInterfaceExternalModels(registry); + arith::registerBufferViewFlowOpInterfaceExternalModels(registry); + arith::registerShardingInterfaceExternalModels(registry); + arith::registerValueBoundsOpInterfaceExternalModels(registry); + bufferization::func_ext::registerBufferizableOpInterfaceExternalModels( + registry); + builtin::registerCastOpInterfaceExternalModels(registry); + cf::registerBufferizableOpInterfaceExternalModels(registry); + cf::registerBufferDeallocationOpInterfaceExternalModels(registry); + gpu::registerBufferDeallocationOpInterfaceExternalModels(registry); + gpu::registerValueBoundsOpInterfaceExternalModels(registry); + LLVM::registerInlinerInterface(registry); + NVVM::registerInlinerInterface(registry); + linalg::registerAllDialectInterfaceImplementations(registry); + linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + memref::registerAllocationOpInterfaceExternalModels(registry); + memref::registerBufferViewFlowOpInterfaceExternalModels(registry); + memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + memref::registerValueBoundsOpInterfaceExternalModels(registry); + memref::registerMemorySlotExternalModels(registry); + ml_program::registerBufferizableOpInterfaceExternalModels(registry); + scf::registerBufferDeallocationOpInterfaceExternalModels(registry); + scf::registerBufferizableOpInterfaceExternalModels(registry); + scf::registerValueBoundsOpInterfaceExternalModels(registry); + shape::registerBufferizableOpInterfaceExternalModels(registry); + sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerBufferizableOpInterfaceExternalModels(registry); + tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry); + tensor::registerInferTypeOpInterfaceExternalModels(registry); + tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry); + tensor::registerSubsetOpInterfaceExternalModels(registry); + tensor::registerTilingInterfaceExternalModels(registry); + tensor::registerValueBoundsOpInterfaceExternalModels(registry); + tosa::registerShardingInterfaceExternalModels(registry); + vector::registerBufferizableOpInterfaceExternalModels(registry); + vector::registerSubsetOpInterfaceExternalModels(registry); + vector::registerValueBoundsOpInterfaceExternalModels(registry); + NVVM::registerNVVMTargetInterfaceExternalModels(registry); + ROCDL::registerROCDLTargetInterfaceExternalModels(registry); + spirv::registerSPIRVTargetInterfaceExternalModels(registry); +} + +/// Append all the MLIR dialects to the registry contained in the given context. +void mlir::registerAllDialects(MLIRContext &context) { + DialectRegistry registry; + registerAllDialects(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp new file mode 100644 index 0000000..8f7c67c --- /dev/null +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -0,0 +1,115 @@ +//===- RegisterAllExtensions.cpp - MLIR Extension Registration --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
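A hedged usage sketch for the two `registerAllDialects` entry points defined above (the tool scaffolding is hypothetical):

```cpp
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"

int main() {
  // Tools and tests typically load everything; a production compiler would
  // instead registry.insert<>() only the dialects its pipeline touches.
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);
  mlir::MLIRContext context(registry);
  // Dialects are loaded lazily: registration only records how to construct
  // them when the context first needs them.
  return 0;
}
```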
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a helper to trigger the registration of all dialect +// extensions to the system. +// +//===----------------------------------------------------------------------===// + +#include "mlir/InitAllExtensions.h" + +#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h" +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" +#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h" +#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" +#include "mlir/Conversion/GPUCommon/GPUToLLVM.h" +#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h" +#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h" +#include "mlir/Conversion/MathToLLVM/MathToLLVM.h" +#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" +#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" +#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" +#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" +#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" +#include "mlir/Conversion/UBToLLVM/UBToLLVM.h" +#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" +#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h" +#include "mlir/Dialect/AMX/Transforms.h" +#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h" +#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h" +#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h" +#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h" +#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h" +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" +#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h" +#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h" +#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h" +#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h" +#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h" +#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h" +#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h" +#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h" +#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h" +#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h" +#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h" +#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" +#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h" +#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" + +/// This function may be called to register all MLIR dialect extensions with the +/// provided registry. +/// If you're building a compiler, you generally shouldn't use this: you would +/// individually register the specific extensions that are useful for the +/// pipelines and transformations you are using. 
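Ahead of the definition that follows, a sketch of the selective style this comment recommends, assuming a pipeline that only needs the func-to-LLVM conversion interface from the list below:

```cpp
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/InitAllExtensions.h"

// everything=true mirrors what opt-style tools do; the selective branch
// registers a single conversion interface instead of all of them.
void configureExtensions(mlir::DialectRegistry &registry, bool everything) {
  if (everything)
    mlir::registerAllExtensions(registry);
  else
    mlir::registerConvertFuncToLLVMInterface(registry);
}
```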
+void mlir::registerAllExtensions(DialectRegistry ®istry) { + // Register all conversions to LLVM extensions. + registerConvertArithToEmitCInterface(registry); + arith::registerConvertArithToLLVMInterface(registry); + registerConvertComplexToLLVMInterface(registry); + cf::registerConvertControlFlowToLLVMInterface(registry); + func::registerAllExtensions(registry); + tensor::registerAllExtensions(registry); + registerConvertFuncToEmitCInterface(registry); + registerConvertFuncToLLVMInterface(registry); + index::registerConvertIndexToLLVMInterface(registry); + registerConvertMathToLLVMInterface(registry); + mpi::registerConvertMPIToLLVMInterface(registry); + registerConvertMemRefToEmitCInterface(registry); + registerConvertMemRefToLLVMInterface(registry); + registerConvertNVVMToLLVMInterface(registry); + registerConvertOpenMPToLLVMInterface(registry); + registerConvertSCFToEmitCInterface(registry); + ub::registerConvertUBToLLVMInterface(registry); + registerConvertAMXToLLVMInterface(registry); + gpu::registerConvertGpuToLLVMInterface(registry); + NVVM::registerConvertGpuToNVVMInterface(registry); + vector::registerConvertVectorToLLVMInterface(registry); + registerConvertXeVMToLLVMInterface(registry); + + // Register all transform dialect extensions. + affine::registerTransformDialectExtension(registry); + bufferization::registerTransformDialectExtension(registry); + dlti::registerTransformDialectExtension(registry); + func::registerTransformDialectExtension(registry); + gpu::registerTransformDialectExtension(registry); + linalg::registerTransformDialectExtension(registry); + memref::registerTransformDialectExtension(registry); + nvgpu::registerTransformDialectExtension(registry); + scf::registerTransformDialectExtension(registry); + sparse_tensor::registerTransformDialectExtension(registry); + tensor::registerTransformDialectExtension(registry); + transform::registerDebugExtension(registry); + transform::registerIRDLExtension(registry); + transform::registerLoopExtension(registry); + transform::registerPDLExtension(registry); + transform::registerTuneExtension(registry); + vector::registerTransformDialectExtension(registry); + arm_neon::registerTransformDialectExtension(registry); + arm_sve::registerTransformDialectExtension(registry); + + // Translation extensions need to be registered by calling + // `registerAllToLLVMIRTranslations` (see All.h). +} diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp new file mode 100644 index 0000000..1ed3a37 --- /dev/null +++ b/mlir/lib/RegisterAllPasses.cpp @@ -0,0 +1,99 @@ +//===- RegisterAllPasses.cpp - MLIR Registration ----------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines a helper to trigger the registration of all passes to the +// system. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/InitAllPasses.h" + +#include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/AMDGPU/Transforms/Passes.h" +#include "mlir/Dialect/Affine/Passes.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/ArmSME/Transforms/Passes.h" +#include "mlir/Dialect/ArmSVE/Transforms/Passes.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Dialect/Bufferization/Pipelines/Passes.h" +#include "mlir/Dialect/Bufferization/Transforms/Passes.h" +#include "mlir/Dialect/EmitC/Transforms/Passes.h" +#include "mlir/Dialect/Func/Transforms/Passes.h" +#include "mlir/Dialect/GPU/Pipelines/Passes.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Dialect/LLVMIR/Transforms/Passes.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Dialect/MLProgram/Transforms/Passes.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/Dialect/MemRef/Transforms/Passes.h" +#include "mlir/Dialect/NVGPU/Transforms/Passes.h" +#include "mlir/Dialect/OpenACC/Transforms/Passes.h" +#include "mlir/Dialect/Quant/Transforms/Passes.h" +#include "mlir/Dialect/SCF/Transforms/Passes.h" +#include "mlir/Dialect/SPIRV/Transforms/Passes.h" +#include "mlir/Dialect/Shape/Transforms/Passes.h" +#include "mlir/Dialect/Shard/Transforms/Passes.h" +#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Dialect/Tensor/Transforms/Passes.h" +#include "mlir/Dialect/Tosa/Transforms/Passes.h" +#include "mlir/Dialect/Transform/Transforms/Passes.h" +#include "mlir/Dialect/Vector/Transforms/Passes.h" +#include "mlir/Dialect/XeGPU/Transforms/Passes.h" +#include "mlir/Transforms/Passes.h" + +// This function may be called to register the MLIR passes with the +// global registry. +// If you're building a compiler, you likely don't need this: you would build a +// pipeline programmatically without the need to register with the global +// registry, since it would already be calling the creation routines of the +// individual passes. +// The global registry is mainly useful for interacting with the command-line tools.
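A hedged sketch of that command-line use case, ahead of the definition that follows; the tool name is made up and the driver is assumed to be the standard `MlirOptMain` entry point:

```cpp
#include "mlir/IR/DialectRegistry.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

// Global pass registration is what lets an opt-style driver resolve
// textual -pass-pipeline / pass names from the command line.
int main(int argc, char **argv) {
  mlir::registerAllPasses();
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);
  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "demo-opt", registry));
}
```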
+void mlir::registerAllPasses() { + // General passes + registerTransformsPasses(); + + // Conversion passes + registerConversionPasses(); + + // Dialect passes + acc::registerOpenACCPasses(); + affine::registerAffinePasses(); + amdgpu::registerAMDGPUPasses(); + registerAsyncPasses(); + arith::registerArithPasses(); + bufferization::registerBufferizationPasses(); + func::registerFuncPasses(); + registerGPUPasses(); + registerLinalgPasses(); + registerNVGPUPasses(); + registerSparseTensorPasses(); + LLVM::registerLLVMPasses(); + math::registerMathPasses(); + memref::registerMemRefPasses(); + shard::registerShardPasses(); + ml_program::registerMLProgramPasses(); + quant::registerQuantPasses(); + registerSCFPasses(); + registerShapePasses(); + spirv::registerSPIRVPasses(); + tensor::registerTensorPasses(); + tosa::registerTosaOptPasses(); + transform::registerTransformPasses(); + vector::registerVectorPasses(); + arm_sme::registerArmSMEPasses(); + arm_sve::registerArmSVEPasses(); + emitc::registerEmitCPasses(); + xegpu::registerXeGPUPasses(); + + // Dialect pipelines + bufferization::registerBufferizationPipelines(); + sparse_tensor::registerSparseTensorPipelines(); + tosa::registerTosaToLinalgPipelines(); + gpu::registerGPUToNVVMPipeline(); +} diff --git a/mlir/lib/Support/ToolUtilities.cpp b/mlir/lib/Support/ToolUtilities.cpp index 748f928..2cf33eb 100644 --- a/mlir/lib/Support/ToolUtilities.cpp +++ b/mlir/lib/Support/ToolUtilities.cpp @@ -14,6 +14,8 @@ #include "mlir/Support/LLVM.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/raw_ostream.h" +#include <string> +#include <utility> using namespace mlir; @@ -22,18 +24,18 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, ChunkBufferHandler processChunkBuffer, raw_ostream &os, llvm::StringRef inputSplitMarker, llvm::StringRef outputSplitMarker) { + llvm::MemoryBufferRef originalBufferRef = originalBuffer->getMemBufferRef(); // If splitting is disabled, we process the full input buffer. if (inputSplitMarker.empty()) - return processChunkBuffer(std::move(originalBuffer), os); + return processChunkBuffer(std::move(originalBuffer), originalBufferRef, os); const int inputSplitMarkerLen = inputSplitMarker.size(); - auto *origMemBuffer = originalBuffer.get(); SmallVector<StringRef, 8> rawSourceBuffers; const int checkLen = 2; // Split dropping the last checkLen chars to enable flagging near misses. - origMemBuffer->getBuffer().split(rawSourceBuffers, - inputSplitMarker.drop_back(checkLen)); + originalBufferRef.getBuffer().split(rawSourceBuffers, + inputSplitMarker.drop_back(checkLen)); if (rawSourceBuffers.empty()) return success(); @@ -79,11 +81,17 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, auto interleaveFn = [&](StringRef subBuffer) { auto splitLoc = SMLoc::getFromPointer(subBuffer.data()); unsigned splitLine = fileSourceMgr.getLineAndColumn(splitLoc).first; - auto subMemBuffer = llvm::MemoryBuffer::getMemBufferCopy( - subBuffer, Twine("within split at ") + - origMemBuffer->getBufferIdentifier() + ":" + - Twine(splitLine) + " offset "); + std::string name((Twine("within split at ") + + originalBufferRef.getBufferIdentifier() + ":" + + Twine(splitLine) + " offset ") + .str()); + // Use MemoryBufferRef to avoid copying the buffer and keep it at the same + // location relative to the original buffer.
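A hedged sketch of a client of the new `splitAndProcessBuffer` signature shown in this ToolUtilities.cpp hunk; the handler body and split markers are illustrative:

```cpp
#include "mlir/Support/ToolUtilities.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

// The handler's new second parameter is the unsplit input, so a chunk can be
// related back to its position in the original buffer.
LogicalResult processAllChunks(std::unique_ptr<llvm::MemoryBuffer> input) {
  auto handler = [](std::unique_ptr<llvm::MemoryBuffer> chunk,
                    const llvm::MemoryBufferRef &originalBuffer,
                    raw_ostream &os) -> LogicalResult {
    os << "chunk within " << originalBuffer.getBufferIdentifier() << "\n";
    return success();
  };
  return splitAndProcessBuffer(std::move(input), handler, llvm::outs(),
                               /*inputSplitMarker=*/"// -----",
                               /*outputSplitMarker=*/"// -----");
}
```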
+ auto subMemBuffer = + llvm::MemoryBuffer::getMemBuffer(llvm::MemoryBufferRef(subBuffer, name), + /*RequiresNullTerminator=*/false); + if (failed( + processChunkBuffer(std::move(subMemBuffer), originalBufferRef, os))) hadFailure = true; }; llvm::interleave(sourceBuffers, os, interleaveFn, @@ -92,3 +100,16 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, // If any fails, then return a failure of the tool. return failure(hadFailure); } + +LogicalResult +mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer, + NoSourceChunkBufferHandler processChunkBuffer, + raw_ostream &os, llvm::StringRef inputSplitMarker, + llvm::StringRef outputSplitMarker) { + auto process = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer, + const llvm::MemoryBufferRef &, raw_ostream &os) { + return processChunkBuffer(std::move(chunkBuffer), os); + }; + return splitAndProcessBuffer(std::move(originalBuffer), process, os, + inputSplitMarker, outputSplitMarker); +} diff --git a/mlir/lib/Support/TypeID.cpp b/mlir/lib/Support/TypeID.cpp index 01ad910..304253c 100644 --- a/mlir/lib/Support/TypeID.cpp +++ b/mlir/lib/Support/TypeID.cpp @@ -27,9 +27,6 @@ namespace { struct ImplicitTypeIDRegistry { /// Lookup or insert a TypeID for the given type name. TypeID lookupOrInsert(StringRef typeName) { - LLVM_DEBUG(llvm::dbgs() << "ImplicitTypeIDRegistry::lookupOrInsert(" - << typeName << ")\n"); - // Perform a heuristic check to see if this type is in an anonymous // namespace. String equality is not valid for anonymous types, so we try to // abort whenever we see them. diff --git a/mlir/lib/TableGen/Successor.cpp b/mlir/lib/TableGen/Successor.cpp index ce0aafb..cd0677d 100644 --- a/mlir/lib/TableGen/Successor.cpp +++ b/mlir/lib/TableGen/Successor.cpp @@ -12,7 +12,6 @@ //===----------------------------------------------------------------------===// #include "mlir/TableGen/Successor.h" -#include "llvm/ADT/TypeSwitch.h" #include "llvm/TableGen/Record.h" using namespace mlir; diff --git a/mlir/lib/TableGen/Type.cpp b/mlir/lib/TableGen/Type.cpp index 4f74056..b31377e 100644 --- a/mlir/lib/TableGen/Type.cpp +++ b/mlir/lib/TableGen/Type.cpp @@ -12,7 +12,6 @@ #include "mlir/TableGen/Type.h" #include "mlir/TableGen/Dialect.h" -#include "llvm/ADT/Twine.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/TableGen/Record.h" diff --git a/mlir/lib/Target/Cpp/TranslateRegistration.cpp b/mlir/lib/Target/Cpp/TranslateRegistration.cpp index 2108ffd..7dae03e 100644 --- a/mlir/lib/Target/Cpp/TranslateRegistration.cpp +++ b/mlir/lib/Target/Cpp/TranslateRegistration.cpp @@ -9,8 +9,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Dialect.h" #include "mlir/Target/Cpp/CppEmitter.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "llvm/Support/CommandLine.h" diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index a393d88..dcd2e11 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -17,15 +17,12 @@ #include "mlir/Support/IndentedOstream.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/Cpp/CppEmitter.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/StringExtras.h" -#include "llvm/ADT/StringMap.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" #include 
<stack> -#include <utility> #define DEBUG_TYPE "translate-to-cpp" @@ -903,8 +900,7 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) { // inlined, and as such should be wrapped in parentheses in order to guarantee // its precedence and associativity. auto requiresParentheses = [&](Value value) { - auto expressionOp = - dyn_cast_if_present<ExpressionOp>(value.getDefiningOp()); + auto expressionOp = value.getDefiningOp<ExpressionOp>(); if (!expressionOp) return false; return shouldBeInlined(expressionOp); @@ -1545,7 +1541,7 @@ LogicalResult CppEmitter::emitOperand(Value value) { return success(); } - auto expressionOp = dyn_cast_if_present<ExpressionOp>(value.getDefiningOp()); + auto expressionOp = value.getDefiningOp<ExpressionOp>(); if (expressionOp && shouldBeInlined(expressionOp)) return emitExpression(expressionOp); diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt index 83fbf7a..f6e44c6 100644 --- a/mlir/lib/Target/LLVM/CMakeLists.txt +++ b/mlir/lib/Target/LLVM/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_library(MLIRTargetLLVM intrinsics_gen LINK_COMPONENTS + BitWriter Core IPO IRReader @@ -59,7 +60,7 @@ if ("NVPTX" IN_LIST LLVM_TARGETS_TO_BUILD) # See: https://gitlab.kitware.com/cmake/cmake/-/issues/24858 # TODO: Bump the MLIR CMake version to 3.26.4 and switch to # ${CUDAToolkit_LIBRARY_ROOT} - if(NOT DEFINED ${CUDAToolkit_LIBRARY_ROOT}) + if(NOT DEFINED CUDAToolkit_LIBRARY_ROOT) get_filename_component(MLIR_CUDAToolkit_ROOT ${CUDAToolkit_BIN_DIR} DIRECTORY ABSOLUTE) else() diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt index af22a7f..9ea5c683 100644 --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -60,6 +60,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration MLIRROCDLToLLVMIRTranslation MLIRSPIRVToLLVMIRTranslation MLIRVCIXToLLVMIRTranslation + MLIRXeVMToLLVMIRTranslation ) add_mlir_translation_library(MLIRTargetLLVMIRImport diff --git a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp index 75170bf..8e6f5c7 100644 --- a/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp +++ b/mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/Target/LLVMIR/Dialect/All.h" #include "mlir/Target/LLVMIR/Export.h" #include "mlir/Tools/mlir-translate/Translation.h" diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt index f030fa7..86c731a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -10,3 +10,4 @@ add_subdirectory(OpenMP) add_subdirectory(ROCDL) add_subdirectory(SPIRV) add_subdirectory(VCIX) +add_subdirectory(XeVM) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index ff34a08..0f675a0 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -13,6 +13,7 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Operation.h" +#include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" @@ -136,46 +137,6 @@ 
convertOperandBundles(OperandRangeRange bundleOperands, return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation); } -static LogicalResult -convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray, - ArrayAttr resAttrsArray, llvm::CallBase *call, - LLVM::ModuleTranslation &moduleTranslation) { - if (argAttrsArray) { - for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) { - if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr); - !argAttrs.empty()) { - FailureOr<llvm::AttrBuilder> attrBuilder = - moduleTranslation.convertParameterAttrs(loc, argAttrs); - if (failed(attrBuilder)) - return failure(); - call->addParamAttrs(argIdx, *attrBuilder); - } - } - } - - if (resAttrsArray && resAttrsArray.size() > 0) { - if (resAttrsArray.size() != 1) - return mlir::emitError(loc, "llvm.func cannot have multiple results"); - if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]); - !resAttrs.empty()) { - FailureOr<llvm::AttrBuilder> attrBuilder = - moduleTranslation.convertParameterAttrs(loc, resAttrs); - if (failed(attrBuilder)) - return failure(); - call->addRetAttrs(*attrBuilder); - } - } - return success(); -} - -static LogicalResult -convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call, - LLVM::ModuleTranslation &moduleTranslation) { - return convertParameterAndResultAttrs( - callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call, - moduleTranslation); -} - /// Builder for LLVM_CallIntrinsicOp static LogicalResult convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, @@ -243,9 +204,7 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder, convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(), moduleTranslation)); - if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(), - op.getResAttrsAttr(), inst, - moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(op, inst))) return failure(); if (op.getNumResults() == 1) @@ -455,7 +414,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, if (callOp.getInlineHintAttr()) call->addFnAttr(llvm::Attribute::InlineHint); - if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(callOp, call))) return failure(); if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) { @@ -569,8 +528,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, operandsRef.drop_front(), opBundles); } result->setCallingConv(convertCConvToLLVM(invOp.getCConv())); - if (failed( - convertParameterAndResultAttrs(invOp, result, moduleTranslation))) + if (failed(moduleTranslation.convertArgAndResultAttrs(invOp, result))) return failure(); moduleTranslation.mapBranch(invOp, result); // InvokeOp can only have 0 or 1 result diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp index ad01a64..55e73e8 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp @@ -13,7 +13,6 @@ #include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Target/LLVMIR/ModuleImport.h" - #include "llvm/IR/ConstantRange.h" using namespace mlir; diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp 
b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp index d162afd..97c6b4e 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenACC/OpenACCToLLVMIRTranslation.cpp @@ -151,8 +151,7 @@ processDataOperands(llvm::IRBuilderBase &builder, // Copyin operands are handled as `to` call. llvm::SmallVector<mlir::Value> create, copyin; for (mlir::Value dataOp : op.getDataClauseOperands()) { - if (auto createOp = - mlir::dyn_cast_or_null<acc::CreateOp>(dataOp.getDefiningOp())) { + if (auto createOp = dataOp.getDefiningOp<acc::CreateOp>()) { create.push_back(createOp.getVarPtr()); } else if (auto copyinOp = mlir::dyn_cast_or_null<acc::CopyinOp>( dataOp.getDefiningOp())) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index da39b19..49e1e55 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -16,15 +16,12 @@ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/OpenMP/OpenMPInterfaces.h" -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/LLVMIR/Dialect/OpenMPCommon.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" -#include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Frontend/OpenMP/OMPConstants.h" @@ -39,7 +36,6 @@ #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/ModuleUtils.h" -#include <any> #include <cstdint> #include <iterator> #include <numeric> @@ -3541,8 +3537,7 @@ getDeclareTargetRefPtrSuffix(LLVM::GlobalOp globalOp, } static bool isDeclareTargetLink(mlir::Value value) { - if (auto addressOfOp = - llvm::dyn_cast_if_present<LLVM::AddressOfOp>(value.getDefiningOp())) { + if (auto addressOfOp = value.getDefiningOp<LLVM::AddressOfOp>()) { auto modOp = addressOfOp->getParentOfType<mlir::ModuleOp>(); Operation *gOp = modOp.lookupSymbol(addressOfOp.getGlobalName()); if (auto declareTargetGlobal = @@ -3882,29 +3877,28 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo, llvm::SmallVector<size_t> indices(indexAttr.size()); std::iota(indices.begin(), indices.end(), 0); - llvm::sort(indices.begin(), indices.end(), - [&](const size_t a, const size_t b) { - auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); - auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); - for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) { - int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt(); - int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt(); + llvm::sort(indices, [&](const size_t a, const size_t b) { + auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]); + auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]); + for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) { + int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt(); + int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt(); - if (aIndex == bIndex) - continue; + if (aIndex == bIndex) + continue; - if (aIndex < bIndex) - return first; + if (aIndex < bIndex) + return first; - if (aIndex > bIndex) - return !first; - } + if (aIndex > bIndex) + return !first; + } - // Iterated the up until the end of the smallest member and - 
// they were found to be equal up to that point, so select - // the member with the lowest index count, so the "parent" - return memberIndicesA.size() < memberIndicesB.size(); - }); + // Iterated up to the end of the smallest member list and found the + // indices equal up to that point, so select the member with the + // lowest index count, i.e. the "parent". + return memberIndicesA.size() < memberIndicesB.size(); + }); return llvm::cast<omp::MapInfoOp>( mapInfo.getMembers()[indices.front()].getDefiningOp()); @@ -4502,8 +4496,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, ifCond = moduleTranslation.lookupValue(ifVar); if (auto devId = dataOp.getDevice()) - if (auto constOp = - dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp())) + if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>()) if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) deviceID = intAttr.getInt(); @@ -4520,8 +4513,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, ifCond = moduleTranslation.lookupValue(ifVar); if (auto devId = enterDataOp.getDevice()) - if (auto constOp = - dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp())) + if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>()) if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) deviceID = intAttr.getInt(); RTLFn = @@ -4540,8 +4532,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, ifCond = moduleTranslation.lookupValue(ifVar); if (auto devId = exitDataOp.getDevice()) - if (auto constOp = - dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp())) + if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>()) if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) deviceID = intAttr.getInt(); @@ -4560,8 +4551,7 @@ convertOmpTargetData(Operation *op, llvm::IRBuilderBase &builder, ifCond = moduleTranslation.lookupValue(ifVar); if (auto devId = updateDataOp.getDevice()) - if (auto constOp = - dyn_cast<LLVM::ConstantOp>(devId.getDefiningOp())) + if (auto constOp = devId.getDefiningOp<LLVM::ConstantOp>()) if (auto intAttr = dyn_cast<IntegerAttr>(constOp.getValue())) deviceID = intAttr.getInt(); @@ -5202,8 +5192,7 @@ static std::optional<int64_t> extractConstInteger(Value value) { if (!value) return std::nullopt; - if (auto constOp = - dyn_cast_if_present<LLVM::ConstantOp>(value.getDefiningOp())) + if (auto constOp = value.getDefiningOp<LLVM::ConstantOp>()) if (auto constAttr = dyn_cast<IntegerAttr>(constOp.getValue())) return constAttr.getInt(); diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt new file mode 100644 index 0000000..6308d7e --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt @@ -0,0 +1,21 @@ +set(LLVM_OPTIONAL_SOURCES + XeVMToLLVMIRTranslation.cpp +) + +add_mlir_translation_library(MLIRXeVMToLLVMIRTranslation + XeVMToLLVMIRTranslation.cpp + + DEPENDS + MLIRXeVMConversionsIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRDialectUtils + MLIRIR + MLIRLLVMDialect + MLIRXeVMDialect + MLIRSupport + MLIRTargetLLVMIRExport +) diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp new file mode 100644 index 0000000..73b166d --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp @@ -0,0 +1,103 @@ +//===-- XeVMToLLVMIRTranslation.cpp - Translate XeVM to LLVM IR -*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR XeVM dialect and +// LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" + +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::LLVM; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the XeVM dialect to LLVM IR. +class XeVMDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Attaches module-level metadata for functions marked as kernels. + LogicalResult + amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const final { + StringRef attrName = attribute.getName().getValue(); + if (attrName == mlir::xevm::XeVMDialect::getCacheControlsAttrName()) { + auto cacheControlsArray = dyn_cast<ArrayAttr>(attribute.getValue()); + if (cacheControlsArray.size() != 2) { + return op->emitOpError( + "Expected both L1 and L3 cache control attributes!"); + } + if (instructions.size() != 1) { + return op->emitOpError("Expecting a single instruction"); + } + return handleDecorationCacheControl(instructions.front(), + cacheControlsArray.getValue()); + } + auto func = dyn_cast<LLVM::LLVMFuncOp>(op); + if (!func) + return failure(); + + return success(); + } + +private: + static LogicalResult handleDecorationCacheControl(llvm::Instruction *inst, + ArrayRef<Attribute> attrs) { + SmallVector<llvm::Metadata *> decorations; + llvm::LLVMContext &ctx = inst->getContext(); + llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx); + llvm::transform( + attrs, std::back_inserter(decorations), + [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * { + auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue(); + std::array<llvm::Metadata *, 4> metadata; + llvm::transform( + valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) { + return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get( + i32Ty, cast<IntegerAttr>(valueAttr).getValue())); + }); + return llvm::MDNode::get(ctx, metadata); + }); + constexpr llvm::StringLiteral decorationCacheControlMDName = + "spirv.DecorationCacheControlINTEL"; + inst->setMetadata(decorationCacheControlMDName, + llvm::MDNode::get(ctx, decorations)); + return success(); + } +}; +} // namespace + +void mlir::registerXeVMDialectTranslation(::mlir::DialectRegistry ®istry) { + registry.insert<xevm::XeVMDialect>(); + registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) { + dialect->addInterfaces<XeVMDialectLLVMIRTranslationInterface>(); + }); +} + +void mlir::registerXeVMDialectTranslation(::mlir::MLIRContext &context) { + DialectRegistry registry; + registerXeVMDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff 
--git a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp index 580afdd..cb1f234 100644 --- a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp +++ b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp @@ -33,7 +33,9 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic( SmallVector<Value> mlirOperands; SmallVector<NamedAttribute> mlirAttrs; if (failed(moduleImport.convertIntrinsicArguments( - llvmOperands, llvmOpBundles, false, {}, {}, mlirOperands, mlirAttrs))) + llvmOperands, llvmOpBundles, /*requiresOpBundles=*/false, + /*immArgPositions=*/{}, /*immArgAttrNames=*/{}, mlirOperands, + mlirAttrs))) return failure(); Type resultType = moduleImport.convertType(inst->getType()); @@ -44,11 +46,7 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic( ValueRange{mlirOperands}, FastmathFlagsAttr{}); moduleImport.setFastmathFlagsAttr(inst, op); - - ArrayAttr argsAttr, resAttr; - moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder); - op.setArgAttrsAttr(argsAttr); - op.setResAttrsAttr(resAttr); + moduleImport.convertArgAndResultAttrs(inst, op); // Update importer tracking of results. unsigned numRes = op.getNumResults(); diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 94db7f8..6325480 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -30,6 +30,7 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/IR/Comdat.h" #include "llvm/IR/Constants.h" @@ -142,6 +143,7 @@ static LogicalResult convertInstructionImpl(OpBuilder &odsBuilder, // TODO: Implement the `convertInstruction` hooks in the // `LLVMDialectLLVMIRImportInterface` and move the following include there. #include "mlir/Dialect/LLVMIR/LLVMOpFromLLVMIRConversions.inc" + return failure(); } @@ -1062,6 +1064,18 @@ void ModuleImport::convertTargetTriple() { builder.getStringAttr(llvmModule->getTargetTriple().str())); } +void ModuleImport::convertModuleLevelAsm() { + llvm::StringRef asmStr = llvmModule->getModuleInlineAsm(); + llvm::SmallVector<mlir::Attribute> asmArrayAttr; + + for (llvm::StringRef line : llvm::split(asmStr, '\n')) + if (!line.empty()) + asmArrayAttr.push_back(builder.getStringAttr(line)); + + mlirModule->setAttr(LLVM::LLVMDialect::getModuleLevelAsmAttrName(), + builder.getArrayAttr(asmArrayAttr)); +} + LogicalResult ModuleImport::convertFunctions() { for (llvm::Function &func : llvmModule->functions()) if (failed(processFunction(&func))) @@ -1626,12 +1640,11 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) { // Convert dso_local_equivalent. 
if (auto *dsoLocalEquivalent = dyn_cast<llvm::DSOLocalEquivalent>(constant)) { Type type = convertType(dsoLocalEquivalent->getType()); - return builder - .create<DSOLocalEquivalentOp>( - loc, type, - FlatSymbolRefAttr::get( - builder.getContext(), - dsoLocalEquivalent->getGlobalValue()->getName())) + return DSOLocalEquivalentOp::create( + builder, loc, type, + FlatSymbolRefAttr::get( + builder.getContext(), + dsoLocalEquivalent->getGlobalValue()->getName())) .getResult(); } @@ -1736,9 +1749,9 @@ FailureOr<Value> ModuleImport::convertConstant(llvm::Constant *constant) { FlatSymbolRefAttr::get(context, blockAddr->getFunction()->getName()); auto blockTag = BlockTagAttr::get(context, blockAddr->getBasicBlock()->getNumber()); - return builder - .create<BlockAddressOp>(loc, convertType(blockAddr->getType()), - BlockAddressAttr::get(context, fnSym, blockTag)) + return BlockAddressOp::create( + builder, loc, convertType(blockAddr->getType()), + BlockAddressAttr::get(context, fnSym, blockTag)) .getRes(); } @@ -2228,17 +2241,16 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { if (!resultTy) return failure(); ArrayAttr operandAttrs = convertAsmInlineOperandAttrs(*callInst); - return builder - .create<InlineAsmOp>( - loc, resultTy, *operands, - builder.getStringAttr(asmI->getAsmString()), - builder.getStringAttr(asmI->getConstraintString()), - asmI->hasSideEffects(), asmI->isAlignStack(), - convertTailCallKindFromLLVM(callInst->getTailCallKind()), - AsmDialectAttr::get( - mlirModule.getContext(), - convertAsmDialectFromLLVM(asmI->getDialect())), - operandAttrs) + return InlineAsmOp::create( + builder, loc, resultTy, *operands, + builder.getStringAttr(asmI->getAsmString()), + builder.getStringAttr(asmI->getConstraintString()), + asmI->hasSideEffects(), asmI->isAlignStack(), + convertTailCallKindFromLLVM(callInst->getTailCallKind()), + AsmDialectAttr::get( + mlirModule.getContext(), + convertAsmDialectFromLLVM(asmI->getDialect())), + operandAttrs) .getOperation(); } bool isIncompatibleCall; @@ -2268,7 +2280,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // Handle parameter and result attributes unless it's an incompatible // call. if (!isIncompatibleCall) - convertParameterAttributes(callInst, callOp, builder); + convertArgAndResultAttrs(callInst, callOp); return callOp.getOperation(); }(); @@ -2365,7 +2377,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) { // Handle parameter and result attributes unless it's an incompatible // invoke. if (!isIncompatibleInvoke) - convertParameterAttributes(invokeInst, invokeOp, builder); + convertArgAndResultAttrs(invokeInst, invokeOp); if (!invokeInst->getType()->isVoidTy()) mapValue(inst, invokeOp.getResults().front()); @@ -2731,11 +2743,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func, } DictionaryAttr -ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, - OpBuilder &builder) { +ModuleImport::convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet) { SmallVector<NamedAttribute> paramAttrs; for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) { - auto llvmAttr = llvmParamAttrs.getAttribute(llvmKind); + auto llvmAttr = llvmAttrSet.getAttribute(llvmKind); // Skip attributes that are not attached. 
if (!llvmAttr.isValid()) continue; @@ -2770,13 +2781,12 @@ ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs, return builder.getDictionaryAttr(paramAttrs); } -void ModuleImport::convertParameterAttributes(llvm::Function *func, - LLVMFuncOp funcOp, - OpBuilder &builder) { +void ModuleImport::convertArgAndResultAttrs(llvm::Function *func, + LLVMFuncOp funcOp) { auto llvmAttrs = func->getAttributes(); for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) { llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(i); - funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder)); + funcOp.setArgAttrs(i, convertArgOrResultAttrSet(llvmArgAttrs)); } // Convert the result attributes and attach them wrapped in an ArrayAttribute // to the funcOp. @@ -2784,17 +2794,23 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func, if (!llvmResAttr.hasAttributes()) return; funcOp.setResAttrsAttr( - builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder))); + builder.getArrayAttr({convertArgOrResultAttrSet(llvmResAttr)})); } -void ModuleImport::convertParameterAttributes(llvm::CallBase *call, - ArrayAttr &argsAttr, - ArrayAttr &resAttr, - OpBuilder &builder) { +void ModuleImport::convertArgAndResultAttrs( + llvm::CallBase *call, ArgAndResultAttrsOpInterface attrsOp, + ArrayRef<unsigned> immArgPositions) { + // Compute the set of immediate argument positions. + llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(), + immArgPositions.end()); + // Convert the argument attributes and filter out immediate arguments. llvm::AttributeList llvmAttrs = call->getAttributes(); SmallVector<llvm::AttributeSet> llvmArgAttrsSet; bool anyArgAttrs = false; for (size_t i = 0, e = call->arg_size(); i < e; ++i) { + // Skip immediate arguments. + if (immArgPositionsSet.contains(i)) + continue; llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i)); if (llvmArgAttrsSet.back().hasAttributes()) anyArgAttrs = true; @@ -2808,24 +2824,16 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call, if (anyArgAttrs) { SmallVector<DictionaryAttr> argAttrs; for (auto &llvmArgAttrs : llvmArgAttrsSet) - argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder)); - argsAttr = getArrayAttr(argAttrs); + argAttrs.emplace_back(convertArgOrResultAttrSet(llvmArgAttrs)); + attrsOp.setArgAttrsAttr(getArrayAttr(argAttrs)); } + // Convert the result attributes. 
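The immediate-argument filtering above matters for intrinsics whose operands are compile-time constants: those positions carry no parameter attributes and therefore get no entry in the resulting arg_attrs array. A hedged call-site sketch (the position value is illustrative, and the interface/namespace spellings are assumed from the hunks above):

```cpp
#include "mlir/Target/LLVMIR/ModuleImport.h"

// Import attributes for a call whose operand 1 is an immediate; entry i of
// the resulting arg_attrs then corresponds to the i-th non-immediate operand.
void importCallAttrs(mlir::LLVM::ModuleImport &moduleImport,
                     llvm::CallBase *call,
                     mlir::LLVM::ArgAndResultAttrsOpInterface attrsOp) {
  moduleImport.convertArgAndResultAttrs(call, attrsOp,
                                        /*immArgPositions=*/{1});
}
```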
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs(); if (!llvmResAttr.hasAttributes()) return; - DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder); - resAttr = getArrayAttr({resAttrs}); -} - -void ModuleImport::convertParameterAttributes(llvm::CallBase *call, - CallOpInterface callOp, - OpBuilder &builder) { - ArrayAttr argsAttr, resAttr; - convertParameterAttributes(call, argsAttr, resAttr, builder); - callOp.setArgAttrsAttr(argsAttr); - callOp.setResAttrsAttr(resAttr); + DictionaryAttr resAttrs = convertArgOrResultAttrSet(llvmResAttr); + attrsOp.setResAttrsAttr(getArrayAttr({resAttrs})); } template <typename Op> @@ -2893,7 +2901,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) { builder, loc, func->getName(), functionType, convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv); - convertParameterAttributes(func, funcOp, builder); + convertArgAndResultAttrs(func, funcOp); if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func)) funcOp.setPersonalityAttr(personality); @@ -3200,5 +3208,6 @@ OwningOpRef<ModuleOp> mlir::translateLLVMIRToModule( if (failed(moduleImport.convertIFuncs())) return {}; moduleImport.convertTargetTriple(); + moduleImport.convertModuleLevelAsm(); return module; } diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index b997e55..b3a06e2 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1758,6 +1758,48 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx, return attrBuilder; } +LogicalResult ModuleTranslation::convertArgAndResultAttrs( + ArgAndResultAttrsOpInterface attrsOp, llvm::CallBase *call, + ArrayRef<unsigned> immArgPositions) { + // Convert the argument attributes. + if (ArrayAttr argAttrsArray = attrsOp.getArgAttrsAttr()) { + unsigned argAttrIdx = 0; + llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(), + immArgPositions.end()); + for (unsigned argIdx : llvm::seq<unsigned>(call->arg_size())) { + if (argAttrIdx >= argAttrsArray.size()) + break; + // Skip immediate arguments (they have no entries in argAttrsArray). + if (immArgPositionsSet.contains(argIdx)) + continue; + // Skip empty argument attributes. + auto argAttrs = cast<DictionaryAttr>(argAttrsArray[argAttrIdx++]); + if (argAttrs.empty()) + continue; + // Convert and add attributes to the call instruction. + FailureOr<llvm::AttrBuilder> attrBuilder = + convertParameterAttrs(attrsOp->getLoc(), argAttrs); + if (failed(attrBuilder)) + return failure(); + call->addParamAttrs(argIdx, *attrBuilder); + } + } + + // Convert the result attributes. 
+ if (ArrayAttr resAttrsArray = attrsOp.getResAttrsAttr()) { + if (!resAttrsArray.empty()) { + auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]); + FailureOr<llvm::AttrBuilder> attrBuilder = + convertParameterAttrs(attrsOp->getLoc(), resAttrs); + if (failed(attrBuilder)) + return failure(); + call->addRetAttrs(*attrBuilder); + } + } + + return success(); +} + FailureOr<llvm::AttrBuilder> ModuleTranslation::convertParameterAttrs(Location loc, DictionaryAttr paramAttrs) { @@ -2276,6 +2318,25 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext, llvmModule->setTargetTriple( llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue())); + if (auto asmAttr = m->getDiscardableAttr( + LLVM::LLVMDialect::getModuleLevelAsmAttrName())) { + auto asmArrayAttr = dyn_cast<ArrayAttr>(asmAttr); + if (!asmArrayAttr) { + m->emitError("expected an array attribute for a module level asm"); + return nullptr; + } + + for (Attribute elt : asmArrayAttr) { + auto asmStrAttr = dyn_cast<StringAttr>(elt); + if (!asmStrAttr) { + m->emitError( + "expected a string attribute for each entry of a module level asm"); + return nullptr; + } + llvmModule->appendModuleInlineAsm(asmStrAttr.getValue()); + } + } + return llvmModule; } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 88799a5..88931b5 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) { return emitError(unknownLoc, "OpDecoration with ") << decorationName << "needs a single target <id>"; } - // Block decoration does not affect spirv.struct type, but is still stored - // for verification. - // TODO: Update StructType to contain this information since - // it is needed for many validation rules. 
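The module-level asm export added to `prepareLLVMModule()` above has a natural producer-side counterpart; a hedged sketch, with the asm text made up and the attribute spelling taken from `getModuleLevelAsmAttrName()`:

```cpp
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"

using namespace mlir;

// Attach module-level asm so that prepareLLVMModule() forwards each string
// to llvm::Module::appendModuleInlineAsm(); anything that is not an array
// of strings is rejected with emitError, as in the hunk above.
void attachModuleAsm(ModuleOp module) {
  OpBuilder b(module.getContext());
  module->setAttr(LLVM::LLVMDialect::getModuleLevelAsmAttrName(),
                  b.getArrayAttr({b.getStringAttr(".globl counter"),
                                  b.getStringAttr("counter: .quad 0")}));
}
```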
decorations[words[0]].set(symbol, opBuilder.getUnitAttr()); break; case spirv::Decoration::Location: @@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) { if (failed(structType.trySetBody( deferredStructIt->memberTypes, deferredStructIt->offsetInfo, - deferredStructIt->memberDecorationsInfo))) + deferredStructIt->memberDecorationsInfo, + deferredStructIt->structDecorationsInfo))) return failure(); deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt); @@ -1188,13 +1185,14 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { } offsetInfo[memberIndex] = memberDecoration.second[0]; } else { + auto intType = mlir::IntegerType::get(context, 32); if (!memberDecoration.second.empty()) { - memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/1, - memberDecoration.first, - memberDecoration.second[0]); + memberDecorationsInfo.emplace_back( + memberIndex, memberDecoration.first, + IntegerAttr::get(intType, memberDecoration.second[0])); } else { - memberDecorationsInfo.emplace_back(memberIndex, /*hasValue=*/0, - memberDecoration.first, 0); + memberDecorationsInfo.emplace_back( + memberIndex, memberDecoration.first, UnitAttr::get(context)); } } } @@ -1202,24 +1200,37 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) { } } + SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo; + if (decorations.count(operands[0])) { + NamedAttrList &allDecorations = decorations[operands[0]]; + for (NamedAttribute &decorationAttr : allDecorations) { + std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration( + llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true)); + assert(decoration.has_value()); + structDecorationsInfo.emplace_back(decoration.value(), + decorationAttr.getValue()); + } + } + uint32_t structID = operands[0]; std::string structIdentifier = nameMap.lookup(structID).str(); if (structIdentifier.empty()) { assert(unresolvedMemberTypes.empty() && "didn't expect unresolved member types"); - typeMap[structID] = - spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo); + typeMap[structID] = spirv::StructType::get( + memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo); } else { auto structTy = spirv::StructType::getIdentified(context, structIdentifier); typeMap[structID] = structTy; if (!unresolvedMemberTypes.empty()) - deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes, - memberTypes, offsetInfo, - memberDecorationsInfo}); + deferredStructTypesInfos.push_back( + {structTy, unresolvedMemberTypes, memberTypes, offsetInfo, + memberDecorationsInfo, structDecorationsInfo}); else if (failed(structTy.trySetBody(memberTypes, offsetInfo, - memberDecorationsInfo))) + memberDecorationsInfo, + structDecorationsInfo))) return failure(); } diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index 20482bd..db1cc3f 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -95,6 +95,7 @@ struct DeferredStructTypeInfo { SmallVector<Type, 4> memberTypes; SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo; SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo; + SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo; }; /// A struct that collects the info needed to materialize/emit a diff --git 
a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 3400fcf..737f296 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -19,7 +19,6 @@ #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/ADT/bit.h" @@ -319,6 +318,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID, case spirv::Decoration::RestrictPointer: case spirv::Decoration::NoContraction: case spirv::Decoration::Constant: + case spirv::Decoration::Block: // For unit attributes and decoration attributes, the args list // has no values so we do nothing. if (isa<UnitAttr, DecorationAttr>(attr)) @@ -406,8 +406,9 @@ LogicalResult Serializer::processMemberDecoration( SmallVector<uint32_t, 4> args( {structID, memberDecoration.memberIndex, static_cast<uint32_t>(memberDecoration.decoration)}); - if (memberDecoration.hasValue) { - args.push_back(memberDecoration.decorationValue); + if (memberDecoration.hasValue()) { + args.push_back( + cast<IntegerAttr>(memberDecoration.decorationValue).getInt()); } encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, args); return success(); @@ -446,6 +447,19 @@ LogicalResult Serializer::processType(Location loc, Type type, LogicalResult Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, SetVector<StringRef> &serializationCtx) { + + // Map unsigned integer types to signless integer types. + // Otherwise the generated SPIR-V assembly would contain the same type + // declaration twice (e.g. OpTypeInt 32 0), which is not permitted, and such + // a module fails validation. At the MLIR level the two types are distinct, + // so the cache lookup below would miss. + // Note: This conversion needs to happen here before the type is looked up in + // the cache. + if (type.isUnsignedInteger()) { + type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(), + IntegerType::SignednessSemantics::Signless); + } + typeID = getTypeID(type); if (typeID) return success(); @@ -617,11 +631,16 @@ LogicalResult Serializer::prepareBasicType( operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); operands.push_back(pointeeTypeID); + // TODO: Now that struct decorations are supported, this code may not be + // necessary. However, it is kept for backwards compatibility. + // Ideally, Block decorations should be inserted when converting to SPIR-V.
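To illustrate the signless canonicalization above in isolation (a sketch; `ctx` stands for the MLIRContext in scope):

  // ui32 and i32 are distinct MLIR types, yet both must serialize to the
  // single SPIR-V declaration "OpTypeInt 32 0". Without the remapping above,
  // the type cache would see two keys and emit the declaration twice.
  mlir::Type u32 = mlir::IntegerType::get(ctx, 32, mlir::IntegerType::Unsigned);
  mlir::Type s32 = mlir::IntegerType::get(ctx, 32, mlir::IntegerType::Signless);
  assert(u32 != s32); // distinct at the MLIR level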
if (isInterfaceStructPtrType(ptrType)) { - if (failed(emitDecoration(getTypeID(pointeeStruct), - spirv::Decoration::Block))) - return emitError(loc, "cannot decorate ") - << pointeeStruct << " with Block decoration"; + auto structType = cast<spirv::StructType>(ptrType.getPointeeType()); + if (!structType.hasDecoration(spirv::Decoration::Block)) + if (failed(emitDecoration(getTypeID(pointeeStruct), + spirv::Decoration::Block))) + return emitError(loc, "cannot decorate ") + << pointeeStruct << " with Block decoration"; } return success(); @@ -666,10 +685,12 @@ LogicalResult Serializer::prepareBasicType( } operands.push_back(elementTypeID); if (hasOffset) { + auto intType = IntegerType::get(structType.getContext(), 32); // Decorate each struct member with an offset spirv::StructType::MemberDecorationInfo offsetDecoration{ - elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, - static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; + elementIndex, spirv::Decoration::Offset, + IntegerAttr::get(intType, + structType.getMemberOffset(elementIndex))}; if (failed(processMemberDecoration(resultID, offsetDecoration))) { return emitError(loc, "cannot decorate ") << elementIndex << "-th member of " << structType @@ -689,6 +710,20 @@ LogicalResult Serializer::prepareBasicType( } } + SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations; + structType.getStructDecorations(structDecorations); + + for (spirv::StructType::StructDecorationInfo &structDecoration : + structDecorations) { + if (failed(processDecorationAttr(loc, resultID, + structDecoration.decoration, + structDecoration.decorationValue))) { + return emitError(loc, "cannot decorate struct ") + << structType << " with " + << stringifyDecoration(structDecoration.decoration); + } + } + typeEnum = spirv::Opcode::OpTypeStruct; if (structType.isIdentified()) @@ -923,6 +958,25 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType, } else { return 0; } + } else if (isa<spirv::TensorArmType>(constType)) { + numberOfConstituents = shapedType.getNumElements(); + operands.reserve(numberOfConstituents + 2); + for (int i = 0; i < numberOfConstituents; ++i) { + uint32_t elementID = 0; + if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) { + elementID = + elementType.isInteger(1) + ? 
prepareConstantBool(loc, attr.getValues<BoolAttr>()[i]) + : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]); + } + if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) { + elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]); + } + if (!elementID) { + return 0; + } + operands.push_back(elementID); + } } else { operands.reserve(numberOfConstituents + 2); for (int i = 0; i < numberOfConstituents; ++i) { diff --git a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp index 04f02f2..e2c987a 100644 --- a/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp +++ b/mlir/lib/Tools/PDLL/AST/NodePrinter.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Tools/PDLL/AST/Context.h" #include "mlir/Tools/PDLL/AST/Nodes.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/TypeSwitch.h" diff --git a/mlir/lib/Tools/PDLL/ODS/Operation.cpp b/mlir/lib/Tools/PDLL/ODS/Operation.cpp index 7e708be..b836ece 100644 --- a/mlir/lib/Tools/PDLL/ODS/Operation.cpp +++ b/mlir/lib/Tools/PDLL/ODS/Operation.cpp @@ -7,8 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Tools/PDLL/ODS/Operation.h" -#include "mlir/Support/IndentedOstream.h" -#include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::pdll::ods; diff --git a/mlir/lib/Tools/lsp-server-support/Protocol.cpp b/mlir/lib/Tools/lsp-server-support/Protocol.cpp index 33cdd28..9828704 100644 --- a/mlir/lib/Tools/lsp-server-support/Protocol.cpp +++ b/mlir/lib/Tools/lsp-server-support/Protocol.cpp @@ -284,11 +284,11 @@ bool mlir::lsp::fromJSON(const llvm::json::Value &value, if (codeAction->getObject("codeActionLiteralSupport")) result.codeActionStructure = true; } - if (auto *window = textDocument->getObject("window")) { - if (std::optional<bool> workDoneProgressSupport = - window->getBoolean("workDoneProgress")) - result.workDoneProgress = *workDoneProgressSupport; - } + } + if (auto *window = o->getObject("window")) { + if (std::optional<bool> workDoneProgressSupport = + window->getBoolean("workDoneProgress")) + result.workDoneProgress = *workDoneProgressSupport; } return true; } diff --git a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp index 2504123..9b937db 100644 --- a/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/LSPServer.cpp @@ -11,8 +11,6 @@ #include "Protocol.h" #include "mlir/Tools/lsp-server-support/Logging.h" #include "mlir/Tools/lsp-server-support/Transport.h" -#include "llvm/ADT/FunctionExtras.h" -#include "llvm/ADT/StringMap.h" #include <optional> #define DEBUG_TYPE "mlir-lsp-server" diff --git a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp index b1bbf98..f1dc326 100644 --- a/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/MlirLspServerMain.cpp @@ -9,7 +9,6 @@ #include "mlir/Tools/mlir-lsp-server/MlirLspServerMain.h" #include "LSPServer.h" #include "MLIRServer.h" -#include "mlir/IR/Dialect.h" #include "mlir/Tools/lsp-server-support/Logging.h" #include "mlir/Tools/lsp-server-support/Transport.h" #include "llvm/Support/CommandLine.h" diff --git a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp index 4ba76fb..a56e9a1 100644 --- a/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp +++ b/mlir/lib/Tools/mlir-lsp-server/Protocol.cpp @@ -11,14 +11,7 @@ 
//===----------------------------------------------------------------------===// #include "Protocol.h" -#include "llvm/ADT/Hashing.h" -#include "llvm/ADT/SmallString.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/Format.h" -#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/JSON.h" -#include "llvm/Support/Path.h" -#include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace mlir::lsp; diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp index 8f78590..bdcdaa4 100644 --- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp +++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp @@ -508,13 +508,20 @@ performActions(raw_ostream &os, /// Parses the memory buffer. If successful, runs a series of passes against /// it and prints the result. -static LogicalResult processBuffer(raw_ostream &os, - std::unique_ptr<MemoryBuffer> ownedBuffer, - const MlirOptMainConfig &config, - DialectRegistry &registry, - llvm::ThreadPoolInterface *threadPool) { +static LogicalResult +processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer, + llvm::MemoryBufferRef sourceBuffer, + const MlirOptMainConfig &config, DialectRegistry &registry, + SourceMgrDiagnosticVerifierHandler *verifyHandler, + llvm::ThreadPoolInterface *threadPool) { // Tell sourceMgr about this buffer, which is what the parser will pick up. auto sourceMgr = std::make_shared<SourceMgr>(); + // Add the original buffer to the source manager to use for determining + // locations. + sourceMgr->AddNewSourceBuffer( + llvm::MemoryBuffer::getMemBuffer(sourceBuffer, + /*RequiresNullTerminator=*/false), + SMLoc()); sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc()); // Create a context just for the current buffer. Disable threading on creation @@ -522,6 +529,8 @@ static LogicalResult processBuffer(raw_ostream &os, MLIRContext context(registry, MLIRContext::Threading::DISABLED); if (threadPool) context.setThreadPool(*threadPool); + if (verifyHandler) + verifyHandler->registerInContext(&context); StringRef irdlFile = config.getIrdlFile(); if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context))) @@ -545,17 +554,12 @@ static LogicalResult processBuffer(raw_ostream &os, return performActions(os, sourceMgr, &context, config); } - SourceMgrDiagnosticVerifierHandler sourceMgrHandler( - *sourceMgr, &context, config.verifyDiagnosticsLevel()); - // Do any processing requested by command line flags. We don't care whether // these actions succeed or fail, we only care what diagnostics they produce // and whether they match our expectations. (void)performActions(os, sourceMgr, &context, config); - // Verify the diagnostic handler to make sure that each of the diagnostics - // matched. - return sourceMgrHandler.verify(); + return success(); } std::pair<std::string, std::string> @@ -624,14 +628,31 @@ LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream, if (threadPoolCtx.isMultithreadingEnabled()) threadPool = &threadPoolCtx.getThreadPool(); + SourceMgr sourceMgr; + sourceMgr.AddNewSourceBuffer( + llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(), + /*RequiresNullTerminator=*/false), + SMLoc()); + // Note: this creates a verifier handler independent of the flag set, as + // internally if the flag is not set, a new scoped diagnostic handler is + // created which would intercept the diagnostics and verify them.
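The MlirOptMain() changes above move diagnostic verification out of the per-chunk processBuffer(): a single handler now spans all split chunks, and each per-chunk context registers with it. A condensed sketch of the resulting flow (the chunk loop and names are illustrative):

  // One verifier handler over the full input; it outlives all chunk contexts.
  SourceMgrDiagnosticVerifierHandler handler(sourceMgr, &mainCtx, level);
  for (auto &chunk : chunks) {
    MLIRContext chunkCtx(registry, MLIRContext::Threading::DISABLED);
    handler.registerInContext(&chunkCtx); // the API used by processBuffer()
    // ... parse `chunk` and run the configured passes in chunkCtx ...
  }
  // Check all expected diagnostics once, after every chunk has run.
  LogicalResult status = handler.verify();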
+ SourceMgrDiagnosticVerifierHandler sourceMgrHandler( + sourceMgr, &threadPoolCtx, config.verifyDiagnosticsLevel()); auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer, - raw_ostream &os) { - return processBuffer(os, std::move(chunkBuffer), config, registry, - threadPool); + llvm::MemoryBufferRef sourceBuffer, raw_ostream &os) { + return processBuffer( + os, std::move(chunkBuffer), sourceBuffer, config, registry, + config.shouldVerifyDiagnostics() ? &sourceMgrHandler : nullptr, + threadPool); }; - return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream, - config.inputSplitMarker(), - config.outputSplitMarker()); + LogicalResult status = splitAndProcessBuffer( + llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(), + /*RequiresNullTerminator=*/false), + chunkFn, outputStream, config.inputSplitMarker(), + config.outputSplitMarker()); + if (config.shouldVerifyDiagnostics() && failed(sourceMgrHandler.verify())) + status = failure(); + return status; } LogicalResult mlir::MlirOptMain(int argc, char **argv, diff --git a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp index 97b8288..685e794 100644 --- a/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp +++ b/mlir/lib/Tools/mlir-tblgen/MlirTblgenMain.cpp @@ -15,7 +15,6 @@ #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/GenNameParser.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/InitLLVM.h" #include "llvm/Support/Signals.h" #include "llvm/TableGen/Error.h" diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp index f2a81cc..e1c8afb 100644 --- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp +++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp @@ -8,9 +8,6 @@ #include "mlir/Tools/mlir-translate/MlirTranslateMain.h" #include "mlir/IR/AsmState.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/Verifier.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/FileUtilities.h" #include "mlir/Support/Timing.h" @@ -138,6 +135,13 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv, // Processes the memory buffer with a new MLIRContext. auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer, raw_ostream &os) { + // Many of the translations expect a null-terminated buffer while splitting + // the buffer does not guarantee null-termination. Make a copy of the buffer + // to ensure null-termination. + if (!ownedBuffer->getBuffer().ends_with('\0')) { + ownedBuffer = llvm::MemoryBuffer::getMemBufferCopy( + ownedBuffer->getBuffer(), ownedBuffer->getBufferIdentifier()); + } // Temporary buffers for chained translation processing. 
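The null-termination guard in mlirTranslateMain() above exists because buffer splitting hands out non-owning windows into the parent buffer, which generally do not end in '\0', while many translations assume C-string termination. A standalone sketch of the restored invariant:

  // A split chunk is only a StringRef window into the parent buffer.
  llvm::StringRef chunk = ownedBuffer->getBuffer();
  if (!chunk.ends_with('\0'))
    // getMemBufferCopy() allocates a fresh, null-terminated copy.
    ownedBuffer = llvm::MemoryBuffer::getMemBufferCopy(
        chunk, ownedBuffer->getBufferIdentifier());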
std::string dataIn; std::string dataOut; diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 3a8088b..058039e 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -37,5 +37,4 @@ add_mlir_library(MLIRTransforms MLIRSideEffectInterfaces MLIRSupport MLIRTransformUtils - MLIRUBDialect ) diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 4d09c5f..09e5a02 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -19,7 +19,6 @@ #include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/DenseMapInfo.h" -#include "llvm/ADT/Hashing.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/RecyclingAllocator.h" diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp index 4b0ac28..7a99fe8 100644 --- a/mlir/lib/Transforms/Canonicalizer.cpp +++ b/mlir/lib/Transforms/Canonicalizer.cpp @@ -13,7 +13,6 @@ #include "mlir/Transforms/Passes.h" -#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" diff --git a/mlir/lib/Transforms/OpStats.cpp b/mlir/lib/Transforms/OpStats.cpp index 0dc3fe9..9ebf310 100644 --- a/mlir/lib/Transforms/OpStats.cpp +++ b/mlir/lib/Transforms/OpStats.cpp @@ -8,10 +8,8 @@ #include "mlir/Transforms/Passes.h" -#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/Support/Format.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 608bdcb..4ccb83f 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -36,6 +36,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" @@ -51,6 +52,7 @@ #include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #include <cstddef> #include <memory> @@ -58,8 +60,6 @@ #include <vector> #define DEBUG_TYPE "remove-dead-values" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") namespace mlir { #define GEN_PASS_DEF_REMOVEDEADVALUES @@ -119,21 +119,21 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet, RunLivenessAnalysis &la) { for (Value value : values) { if (nonLiveSet.contains(value)) { - LDBG("Value " << value << " is already marked non-live (dead)"); + LDBG() << "Value " << value << " is already marked non-live (dead)"; continue; } const Liveness *liveness = la.getLiveness(value); if (!liveness) { - LDBG("Value " << value - << " has no liveness info, conservatively considered live"); + LDBG() << "Value " << value + << " has no liveness info, conservatively considered live"; return true; } if (liveness->isLive) { - LDBG("Value " << value << " is live according to liveness analysis"); + LDBG() << "Value " << value << " is live according to liveness analysis"; return true; } else { - LDBG("Value " << value << " is dead according to liveness analysis"); + LDBG() << "Value " << value << " is dead according to liveness analysis"; } } return false; @@ -148,8 +148,8 @@ static BitVector 
markLives(ValueRange values, const DenseSet<Value> &nonLiveSet, for (auto [index, value] : llvm::enumerate(values)) { if (nonLiveSet.contains(value)) { lives.reset(index); - LDBG("Value " << value << " is already marked non-live (dead) at index " - << index); + LDBG() << "Value " << value + << " is already marked non-live (dead) at index " << index; continue; } @@ -161,17 +161,17 @@ static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet, // (because they weren't erased) and also their liveness is null because // liveness analysis ran before their creation. if (!liveness) { - LDBG("Value " << value << " at index " << index - << " has no liveness info, conservatively considered live"); + LDBG() << "Value " << value << " at index " << index + << " has no liveness info, conservatively considered live"; continue; } if (!liveness->isLive) { lives.reset(index); - LDBG("Value " << value << " at index " << index - << " is dead according to liveness analysis"); + LDBG() << "Value " << value << " at index " << index + << " is dead according to liveness analysis"; } else { - LDBG("Value " << value << " at index " << index - << " is live according to liveness analysis"); + LDBG() << "Value " << value << " at index " << index + << " is live according to liveness analysis"; } } @@ -187,8 +187,8 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range, if (!nonLive[index]) continue; nonLiveSet.insert(result); - LDBG("Marking value " << result << " as non-live (dead) at index " - << index); + LDBG() << "Marking value " << result << " as non-live (dead) at index " + << index; } } @@ -258,16 +258,18 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG("Processing simple op: " << *op); + LDBG() << "Processing simple op: " << *op; if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { - LDBG("Simple op is not memory effect free or has live results, skipping: " - << *op); + LDBG() + << "Simple op is not memory effect free or has live results, skipping: " + << *op; return; } - LDBG("Simple op has all dead results and is memory effect free, scheduling " - "for removal: " - << *op); + LDBG() + << "Simple op has all dead results and is memory effect free, scheduling " + "for removal: " + << *op; cl.operations.push_back(op); collectNonLiveValues(nonLiveSet, op->getResults(), BitVector(op->getNumResults(), true)); @@ -286,10 +288,10 @@ static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, static void processFuncOp(FunctionOpInterface funcOp, Operation *module, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG("Processing function op: " << funcOp.getOperation()->getName()); + LDBG() << "Processing function op: " << funcOp.getOperation()->getName(); if (funcOp.isPublic() || funcOp.isExternal()) { - LDBG("Function is public or external, skipping: " - << funcOp.getOperation()->getName()); + LDBG() << "Function is public or external, skipping: " + << funcOp.getOperation()->getName(); return; } @@ -345,8 +347,6 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, // since it forwards only to non-live value(s) (%1#1). 
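markLives() above returns one bit per value, cleared exactly for the values the analysis proved dead. A sketch of how its result is typically consumed (names are illustrative; compare collectNonLiveValues() above):

  llvm::BitVector lives = markLives(op->getResults(), nonLiveSet, la);
  for (auto [idx, result] : llvm::enumerate(op->getResults()))
    if (!lives[idx])             // cleared bit => value at `idx` is dead
      nonLiveSet.insert(result); // remembered for the final cleanup list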
Operation *lastReturnOp = funcOp.back().getTerminator(); size_t numReturns = lastReturnOp->getNumOperands(); - if (numReturns == 0) - return; BitVector nonLiveRets(numReturns, true); for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); @@ -368,6 +368,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, cl.functions.push_back({funcOp, nonLiveArgs, nonLiveRets}); // Do (5) and (6). + if (numReturns == 0) + return; for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); assert(isa<CallOpInterface>(callOp) && "expected a call-like user"); @@ -409,9 +411,8 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LLVM_DEBUG(DBGS() << "Processing region branch op: "; regionBranchOp->print( - llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n"); + LDBG() << "Processing region branch op: " + << OpWithFlags(regionBranchOp, OpPrintingFlags().skipRegions()); // Mark live results of `regionBranchOp` in `liveResults`. auto markLiveResults = [&](BitVector &liveResults) { liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la); @@ -697,7 +698,7 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG("Processing branch op: " << *branchOp); + LDBG() << "Processing branch op: " << *branchOp; unsigned numSuccessors = branchOp->getNumSuccessors(); for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index d224f73..08803e0 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -14,8 +14,10 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Iterators.h" +#include "mlir/IR/Operation.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/FormatVariadic.h" @@ -130,11 +132,6 @@ struct ConversionValueMapping { /// value. ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; - /// Lookup the given value within the map, or return an empty vector if the - /// value is not mapped. If it is mapped, this follows the same behavior - /// as `lookupOrDefault`. - ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; - template <typename T> struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; @@ -237,15 +234,6 @@ ConversionValueMapping::lookupOrDefault(Value from, return !desiredValue.empty() ? 
std::move(desiredValue) : std::move(current); } -ValueVector ConversionValueMapping::lookupOrNull(Value from, - TypeRange desiredTypes) const { - ValueVector result = lookupOrDefault(from, desiredTypes); - if (result == ValueVector{from} || - (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) - return {}; - return result; -} - //===----------------------------------------------------------------------===// // Rewriter and Translation State //===----------------------------------------------------------------------===// @@ -521,9 +509,11 @@ private: class MoveBlockRewrite : public BlockRewrite { public: MoveBlockRewrite(ConversionPatternRewriterImpl &rewriterImpl, Block *block, - Region *region, Block *insertBeforeBlock) - : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), region(region), - insertBeforeBlock(insertBeforeBlock) {} + Region *previousRegion, Region::iterator previousIt) + : BlockRewrite(Kind::MoveBlock, rewriterImpl, block), + region(previousRegion), + insertBeforeBlock(previousIt == previousRegion->end() ? nullptr + : &*previousIt) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::MoveBlock; @@ -630,9 +620,12 @@ protected: class MoveOperationRewrite : public OperationRewrite { public: MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Operation *op, Block *block, Operation *insertBeforeOp) - : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block), - insertBeforeOp(insertBeforeOp) {} + Operation *op, OpBuilder::InsertPoint previous) + : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), + block(previous.getBlock()), + insertBeforeOp(previous.getPoint() == previous.getBlock()->end() + ? nullptr + : &*previous.getPoint()) {} static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() == Kind::MoveOperation; @@ -926,6 +919,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Return "true" if the given operation was replaced or erased. bool wasOpReplaced(Operation *op) const; + /// Lookup the most recently mapped values with the desired types in the + /// mapping. + /// + /// Special cases: + /// - If the desired type range is empty, simply return the most recently + /// mapped values. + /// - If there is no mapping to the desired types, also return the most + /// recently mapped values. + /// - If there is no mapping for the given values at all, return the given + /// value. + ValueVector lookupOrDefault(Value from, TypeRange desiredTypes = {}) const; + + /// Lookup the given value within the map, or return an empty vector if the + /// value is not mapped. If it is mapped, this follows the same behavior + /// as `lookupOrDefault`. 
+ ValueVector lookupOrNull(Value from, TypeRange desiredTypes = {}) const; + //===--------------------------------------------------------------------===// // IR Rewrites / Type Conversion //===--------------------------------------------------------------------===// @@ -1248,6 +1258,22 @@ void ConversionPatternRewriterImpl::applyRewrites() { // State Management //===----------------------------------------------------------------------===// +ValueVector +ConversionPatternRewriterImpl::lookupOrDefault(Value from, + TypeRange desiredTypes) const { + return mapping.lookupOrDefault(from, desiredTypes); +} + +ValueVector +ConversionPatternRewriterImpl::lookupOrNull(Value from, + TypeRange desiredTypes) const { + ValueVector result = lookupOrDefault(from, desiredTypes); + if (result == ValueVector{from} || + (!desiredTypes.empty() && TypeRange(ValueRange(result)) != desiredTypes)) + return {}; + return result; +} + RewriterState ConversionPatternRewriterImpl::getCurrentState() { return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size()); } @@ -1295,7 +1321,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // The current pattern does not have a type converter. I.e., it does not // distinguish between legal and illegal types. For each operand, simply // pass through the most recently mapped values. - remapped.push_back(mapping.lookupOrDefault(operand)); + remapped.push_back(lookupOrDefault(operand)); continue; } @@ -1314,7 +1340,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( continue; } - ValueVector repl = mapping.lookupOrDefault(operand, legalTypes); + ValueVector repl = lookupOrDefault(operand, legalTypes); if (!repl.empty() && TypeRange(ValueRange(repl)) == legalTypes) { // Mapped values have the correct type or there is an existing // materialization. Or the operand is not mapped at all and has the @@ -1324,7 +1350,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( } // Create a materialization for the most recently mapped values. - repl = mapping.lookupOrDefault(operand); + repl = lookupOrDefault(operand); ValueRange castValues = buildUnresolvedMaterialization( MaterializationKind::Target, computeInsertPoint(repl), operandLoc, /*valuesToMap=*/repl, /*inputs=*/repl, /*outputTypes=*/legalTypes, @@ -1519,7 +1545,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // Try to find a replacement value with the same type in the conversion value // mapping. This includes cached materializations. We try to reuse those // instead of generating duplicate IR. - ValueVector repl = mapping.lookupOrNull(value, value.getType()); + ValueVector repl = lookupOrNull(value, value.getType()); if (!repl.empty()) return repl.front(); @@ -1535,7 +1561,7 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( // No replacement value was found. Get the latest replacement value // (regardless of the type) and build a source materialization to the // original type. - repl = mapping.lookupOrNull(value); + repl = lookupOrNull(value); if (repl.empty()) { // No replacement value is registered in the mapping. This means that the // value is dropped and no longer needed. (If the value were still needed, @@ -1568,23 +1594,30 @@ Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( void ConversionPatternRewriterImpl::notifyOperationInserted( Operation *op, OpBuilder::InsertPoint previous) { + // If no previous insertion point is provided, the op used to be detached. 
+ bool wasDetached = !previous.isSet(); LLVM_DEBUG({ - logger.startLine() << "** Insert : '" << op->getName() << "'(" << op - << ")\n"; + logger.startLine() << "** Insert : '" << op->getName() << "' (" << op + << ")"; + if (wasDetached) + logger.getOStream() << " (was detached)"; + logger.getOStream() << "\n"; }); assert(!wasOpReplaced(op->getParentOp()) && "attempting to insert into a block within a replaced/erased op"); - if (!previous.isSet()) { - // This is a newly created op. + if (wasDetached) { + // If the op was detached, it is most likely a newly created op. + // TODO: If the same op is inserted multiple times from a detached state, + // the rollback mechanism may erase the same op multiple times. This is a + // bug in the rollback-based dialect conversion driver. appendRewrite<CreateOperationRewrite>(op); patternNewOps.insert(op); return; } - Operation *prevOp = previous.getPoint() == previous.getBlock()->end() - ? nullptr - : &*previous.getPoint(); - appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp); + + // The op was moved from one place to another. + appendRewrite<MoveOperationRewrite>(op, previous); } void ConversionPatternRewriterImpl::replaceOp( @@ -1649,29 +1682,40 @@ void ConversionPatternRewriterImpl::eraseBlock(Block *block) { void ConversionPatternRewriterImpl::notifyBlockInserted( Block *block, Region *previous, Region::iterator previousIt) { - assert(!wasOpReplaced(block->getParentOp()) && - "attempting to insert into a region within a replaced/erased op"); + // If no previous insertion point is provided, the block used to be detached. + bool wasDetached = !previous; + Operation *newParentOp = block->getParentOp(); LLVM_DEBUG( { - Operation *parent = block->getParentOp(); + Operation *parent = newParentOp; if (parent) { logger.startLine() << "** Insert Block into : '" << parent->getName() - << "'(" << parent << ")\n"; + << "' (" << parent << ")"; } else { logger.startLine() - << "** Insert Block into detached Region (nullptr parent op)'\n"; + << "** Insert Block into detached Region (nullptr parent op)"; } + if (wasDetached) + logger.getOStream() << " (was detached)"; + logger.getOStream() << "\n"; }); + assert(!wasOpReplaced(newParentOp) && + "attempting to insert into a region within a replaced/erased op"); + (void)newParentOp; patternInsertedBlocks.insert(block); - if (!previous) { - // This is a newly created block. + if (wasDetached) { + // If the block was detached, it is most likely a newly created block. + // TODO: If the same block is inserted multiple times from a detached state, + // the rollback mechanism may erase the same block multiple times. This is a + // bug in the rollback-based dialect conversion driver. appendRewrite<CreateBlockRewrite>(block); return; } - Block *prevBlock = previousIt == previous->end() ? nullptr : &*previousIt; - appendRewrite<MoveBlockRewrite>(block, previous, prevBlock); + + // The block was moved from one place to another. + appendRewrite<MoveBlockRewrite>(block, previous, previousIt); } void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, @@ -1716,6 +1760,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. 
+ if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> newVals = llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> { return v ? SmallVector<Value>{v} : SmallVector<Value>(); @@ -1731,6 +1781,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple( impl->logger.startLine() << "** Replace : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + impl->replaceOp(op, std::move(newValues)); } @@ -1739,6 +1795,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) { impl->logger.startLine() << "** Erase : '" << op->getName() << "'(" << op << ")\n"; }); + + // If the current insertion point is before the erased operation, we adjust + // the insertion point to be after the operation. + if (getInsertionPoint() == op->getIterator()) + setInsertionPointAfter(op); + SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {}); impl->replaceOp(op, std::move(nullRepls)); } @@ -1845,6 +1907,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, moveOpBefore(&source->front(), dest, before); } + // If the current insertion point is within the source block, adjust the + // insertion point to the destination block. + if (getInsertionBlock() == source) + setInsertionPoint(dest, getInsertionPoint()); + // Erase the source block. eraseBlock(source); } @@ -1976,6 +2043,7 @@ private: /// Legalize the resultant IR after successfully applying the given pattern. LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, + const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks); @@ -2092,8 +2160,9 @@ OperationLegalizer::legalize(Operation *op, // If the operation has no regions, just print it here. if (!isIgnored && op->getNumRegions() == 0) { - op->print(logger.startLine(), OpPrintingFlags().printGenericOpForm()); - logger.getOStream() << "\n\n"; + logger.startLine() << OpWithFlags(op, + OpPrintingFlags().printGenericOpForm()) + << "\n"; } }); @@ -2172,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op, rewriterImpl.logger.startLine() << "* Fold {\n"; rewriterImpl.logger.indent(); }); - (void)rewriterImpl; + + // Clear pattern state, so that the next pattern application starts with a + // clean slate. (The op/block sets are populated by listener notifications.) + auto cleanup = llvm::make_scope_exit([&]() { + rewriterImpl.patternNewOps.clear(); + rewriterImpl.patternModifiedOps.clear(); + rewriterImpl.patternInsertedBlocks.clear(); + }); + + // Upon failure, undo all changes made by the folder. + RewriterState curState = rewriterImpl.getCurrentState(); // Try to fold the operation. StringRef opName = op->getName().getStringRef(); SmallVector<Value, 2> replacementValues; SmallVector<Operation *, 2> newOps; rewriter.setInsertionPoint(op); + rewriter.startOpModification(op); if (failed(rewriter.tryFold(op, replacementValues, &newOps))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold")); + rewriter.cancelOpModification(op); return failure(); } + rewriter.finalizeOpModification(op); // An empty list of replacement values indicates that the fold was in-place. // As the operation changed, a new legalization needs to be attempted. 
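The legalizeWithFold() change above brackets tryFold() in an explicit modification scope, so the rollback machinery observes either a committed or a cancelled update instead of an untracked mutation. The protocol reduced to its essentials (a sketch, not the exact driver code):

  rewriter.startOpModification(op);    // announce a pending in-place change
  if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
    rewriter.cancelOpModification(op); // fold did nothing; retract the record
    return failure();
  }
  rewriter.finalizeOpModification(op); // commit; rollback can now undo it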
if (replacementValues.empty()) return legalize(op, rewriter); + // Insert a replacement for 'op' with the folded replacement values. + rewriter.replaceOp(op, replacementValues); + // Recursively legalize any new constant operations. for (Operation *newOp : newOps) { if (failed(legalize(newOp, rewriter))) { @@ -2201,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op, "op '" + opName + "' folder rollback of IR modifications requested"); } - // Legalization failed: erase all materialized constants. - for (Operation *op : newOps) - rewriter.eraseOp(op); + rewriterImpl.resetState( + curState, std::string(op->getName().getStringRef()) + " folder"); return failure(); } } - // Insert a replacement for 'op' with the folded replacement values. - rewriter.replaceOp(op, replacementValues); - LLVM_DEBUG(logSuccess(rewriterImpl.logger, "")); return success(); } @@ -2220,6 +2301,32 @@ OperationLegalizer::legalizeWithPattern(Operation *op, ConversionPatternRewriter &rewriter) { auto &rewriterImpl = rewriter.getImpl(); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + Operation *checkOp; + std::optional<OperationFingerPrint> topLevelFingerPrint; + if (!rewriterImpl.config.allowPatternRollback) { + // The op may be getting erased, so we have to check the parent op. + // (In rare cases, a pattern may even erase the parent op, which will cause + // a crash here. Expensive checks are "best effort".) Skip the check if the + // op does not have a parent op. + if ((checkOp = op->getParentOp())) { + if (!op->getContext()->isMultithreadingEnabled()) { + topLevelFingerPrint = OperationFingerPrint(checkOp); + } else { + // Another thread may be modifying a sibling operation. Therefore, the + // fingerprinting mechanism of the parent op works only in + // single-threaded mode. + LLVM_DEBUG({ + rewriterImpl.logger.startLine() + << "WARNING: Multi-threading is enabled. Some dialect " + "conversion expensive checks are skipped in multithreading " + "mode!\n"; + }); + } + } + } +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Functor that returns whether the given pattern may be applied. auto canApply = [&](const Pattern &pattern) { bool canApply = canApplyPattern(op, pattern, rewriter); @@ -2232,6 +2339,17 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + if (!rewriterImpl.config.allowPatternRollback) { + // Returning "failure" after modifying IR is not allowed.
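The expensive-checks block above takes a fingerprint of the parent op before pattern application so that a pattern which mutates IR and then returns failure is caught. The check, reduced to a sketch (single-threaded case only, as the code itself notes; `applyPattern` is hypothetical):

  OperationFingerPrint before(parentOp); // hash over the whole subtree
  LogicalResult result = applyPattern(op);
  if (failed(result) && OperationFingerPrint(parentOp) != before)
    llvm::report_fatal_error("pattern returned failure but IR did change");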
+ if (checkOp) { + OperationFingerPrint fingerPrintAfterPattern(checkOp); + if (fingerPrintAfterPattern != *topLevelFingerPrint) + llvm::report_fatal_error("pattern '" + pattern.getDebugName() + + "' returned failure but IR did change"); + } + } +#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); rewriterImpl.patternInsertedBlocks.clear(); @@ -2260,7 +2378,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, moveAndReset(rewriterImpl.patternModifiedOps); SetVector<Block *> insertedBlocks = moveAndReset(rewriterImpl.patternInsertedBlocks); - auto result = legalizePatternResult(op, pattern, rewriter, newOps, + auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps, modifiedOps, insertedBlocks); appliedPatterns.erase(&pattern); if (failed(result)) { @@ -2303,7 +2421,7 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, LogicalResult OperationLegalizer::legalizePatternResult( Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, - const SetVector<Operation *> &newOps, + const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks) { auto &impl = rewriter.getImpl(); @@ -2319,7 +2437,8 @@ LogicalResult OperationLegalizer::legalizePatternResult( return hasRewrite<ModifyOperationRewrite>(newRewrites, op); }; if (!replacedRoot() && !updatedRootInPlace()) - llvm::report_fatal_error("expected pattern to replace the root operation"); + llvm::report_fatal_error( + "expected pattern to replace the root operation or modify it in place"); #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Legalize each of the actions registered during application. diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index b82d850..607b86c 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp index b639e87f..26c965c 100644 --- a/mlir/lib/Transforms/Utils/Inliner.cpp +++ b/mlir/lib/Transforms/Utils/Inliner.cpp @@ -21,7 +21,7 @@ #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "inlining" @@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks, // InlinerInterfaceImpl //===----------------------------------------------------------------------===// -#ifndef NDEBUG static std::string getNodeName(CallOpInterface op) { if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee())) return debugString(op); return "_unnamed_callee_"; } -#endif /// Return true if the specified `inlineHistoryID` indicates an inline history /// that already includes `node`. 
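For the Inliner hunk that follows: each inlined call records an optional history index, with the empty optional denoting a root call that was not itself produced by inlining. A sketch of the convention behind the historyToString lambda below (the alias is assumed from its use of has_value(), not quoted from the patch):

  using InlineHistoryT = std::optional<size_t>; // assumed alias
  auto historyToString = [](InlineHistoryT h) {
    return h.has_value() ? std::to_string(*h) : std::string("root");
  };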
@@ -614,10 +612,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{}); LLVM_DEBUG({ - llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n"; + LDBG() << "* Inliner: Initial calls in SCC are: {"; for (unsigned i = 0, e = calls.size(); i < e; ++i) - llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n"; - llvm::dbgs() << "}\n"; + LDBG() << " " << i << ". " << calls[i].call << ","; + LDBG() << "}"; }); // Try to inline each of the call operations. Don't cache the end iterator @@ -635,9 +633,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, CallOpInterface call = it.call; LLVM_DEBUG({ if (doInline) - llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Inlining call: " << i << ". " << call; else - llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n"; + LDBG() << "* Not inlining call: " << i << ". " << call; }); if (!doInline) continue; @@ -654,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, cast<CallableOpInterface>(targetRegion->getParentOp()), targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace); if (failed(inlineResult)) { - LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n"); + LDBG() << "** Failed to inline"; continue; } inlinedAnyCalls = true; @@ -667,19 +665,16 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface, auto historyToString = [](InlineHistoryT h) { return h.has_value() ? std::to_string(*h) : "root"; }; - (void)historyToString; - LLVM_DEBUG(llvm::dbgs() - << "* new inlineHistory entry: " << newInlineHistoryID << ". [" - << getNodeName(call) << ", " << historyToString(inlineHistoryID) - << "]\n"); + LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". [" + << getNodeName(call) << ", " << historyToString(inlineHistoryID) + << "]"; for (unsigned k = prevSize; k != calls.size(); ++k) { callHistory.push_back(newInlineHistoryID); - LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call - << "}\n with historyID = " << newInlineHistoryID - << ", added due to inlining of\n call {" << call - << "}\n with historyID = " - << historyToString(inlineHistoryID) << "\n"); + LDBG() << "* new call " << k << " {" << calls[k].call + << "}\n with historyID = " << newInlineHistoryID + << ", added due to inlining of\n call {" << call + << "}\n with historyID = " << historyToString(inlineHistoryID); } // If the inlining was successful, merge the new uses into the source node.
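Finally, a pattern that recurs across this diff (RemoveDeadValues.cpp, Inliner.cpp, and others): LLVM_DEBUG(llvm::dbgs() << ...) statements are migrated to the LDBG() stream from llvm/Support/DebugLog.h, which is gated on DEBUG_TYPE and appends the trailing newline itself. A generic before/after sketch (the pass name is hypothetical):

  #include "llvm/Support/DebugLog.h"
  #define DEBUG_TYPE "my-pass"

  // Before: explicit stream and explicit trailing newline.
  LLVM_DEBUG(llvm::dbgs() << "Visiting: " << *op << "\n");
  // After: same gating on DEBUG_TYPE; the newline is implicit.
  LDBG() << "Visiting: " << *op;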