Diffstat (limited to 'mlir/lib'): 271 files changed, 13316 insertions(+), 3761 deletions(-)
diff --git a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp index 6cece46..8062b474 100644 --- a/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp +++ b/mlir/lib/Analysis/AliasAnalysis/LocalAliasAnalysis.cpp @@ -127,9 +127,12 @@ static void collectUnderlyingAddressValues(OpResult result, unsigned maxDepth, Operation *op = result.getOwner(); // If this is a view, unwrap to the source. - if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op)) - return collectUnderlyingAddressValues(view.getViewSource(), maxDepth, - visited, output); + if (ViewLikeOpInterface view = dyn_cast<ViewLikeOpInterface>(op)) { + if (result == view.getViewDest()) { + return collectUnderlyingAddressValues(view.getViewSource(), maxDepth, + visited, output); + } + } // Check to see if we can reason about the control flow of this op. if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { return collectUnderlyingAddressValues(branch, /*region=*/nullptr, result, diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index 10874fd..9424eff 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/SymbolTable.h" #include "mlir/IR/Value.h" #include "mlir/IR/ValueRange.h" @@ -78,9 +79,17 @@ void Executable::onUpdate(DataFlowSolver *solver) const { void PredecessorState::print(raw_ostream &os) const { if (allPredecessorsKnown()) os << "(all) "; - os << "predecessors:\n"; - for (Operation *op : getKnownPredecessors()) - os << " " << *op << "\n"; + os << "predecessors:"; + if (getKnownPredecessors().empty()) + os << " (none)"; + else + os << "\n"; + llvm::interleave( + getKnownPredecessors(), os, + [&](Operation *op) { + os << " " << OpWithFlags(op, OpPrintingFlags().skipRegions()); + }, + "\n"); } ChangeResult PredecessorState::join(Operation *predecessor) { @@ -127,7 +136,7 @@ DeadCodeAnalysis::DeadCodeAnalysis(DataFlowSolver &solver) LogicalResult DeadCodeAnalysis::initialize(Operation *top) { LDBG() << "Initializing DeadCodeAnalysis for top-level op: " - << top->getName(); + << OpWithFlags(top, OpPrintingFlags().skipRegions()); // Mark the top-level blocks as executable. 
for (Region ®ion : top->getRegions()) { if (region.empty()) @@ -135,7 +144,8 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) { auto *state = getOrCreate<Executable>(getProgramPointBefore(®ion.front())); propagateIfChanged(state, state->setToLive()); - LDBG() << "Marked entry block live for region in op: " << top->getName(); + LDBG() << "Marked entry block live for region in op: " + << OpWithFlags(top, OpPrintingFlags().skipRegions()); } // Mark as overdefined the predecessors of symbol callables with potentially @@ -147,17 +157,19 @@ LogicalResult DeadCodeAnalysis::initialize(Operation *top) { void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { LDBG() << "[init] Entering initializeSymbolCallables for top-level op: " - << top->getName(); + << OpWithFlags(top, OpPrintingFlags().skipRegions()); analysisScope = top; auto walkFn = [&](Operation *symTable, bool allUsesVisible) { - LDBG() << "[init] Processing symbol table op: " << symTable->getName(); + LDBG() << "[init] Processing symbol table op: " + << OpWithFlags(symTable, OpPrintingFlags().skipRegions()); Region &symbolTableRegion = symTable->getRegion(0); Block *symbolTableBlock = &symbolTableRegion.front(); bool foundSymbolCallable = false; for (auto callable : symbolTableBlock->getOps<CallableOpInterface>()) { LDBG() << "[init] Found CallableOpInterface: " - << callable.getOperation()->getName(); + << OpWithFlags(callable.getOperation(), + OpPrintingFlags().skipRegions()); Region *callableRegion = callable.getCallableRegion(); if (!callableRegion) continue; @@ -172,7 +184,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { getOrCreate<PredecessorState>(getProgramPointAfter(callable)); propagateIfChanged(state, state->setHasUnknownPredecessors()); LDBG() << "[init] Marked callable as having unknown predecessors: " - << callable.getOperation()->getName(); + << OpWithFlags(callable.getOperation(), + OpPrintingFlags().skipRegions()); } foundSymbolCallable = true; } @@ -195,7 +208,8 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { propagateIfChanged(state, state->setHasUnknownPredecessors()); LDBG() << "[init] Marked nested callable as " "having unknown predecessors: " - << callable.getOperation()->getName(); + << OpWithFlags(callable.getOperation(), + OpPrintingFlags().skipRegions()); }); } @@ -211,13 +225,13 @@ void DeadCodeAnalysis::initializeSymbolCallables(Operation *top) { propagateIfChanged(state, state->setHasUnknownPredecessors()); LDBG() << "[init] Found non-call use for symbol, " "marked as having unknown predecessors: " - << symbol->getName(); + << OpWithFlags(symbol, OpPrintingFlags().skipRegions()); } }; SymbolTable::walkSymbolTables(top, /*allSymUsesVisible=*/!top->getBlock(), walkFn); LDBG() << "[init] Finished initializeSymbolCallables for top-level op: " - << top->getName(); + << OpWithFlags(top, OpPrintingFlags().skipRegions()); } /// Returns true if the operation is a returning terminator in region @@ -229,12 +243,13 @@ static bool isRegionOrCallableReturn(Operation *op) { } LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { - LDBG() << "[init] Entering initializeRecursively for op: " << op->getName() - << " at " << op; + LDBG() << "[init] Entering initializeRecursively for op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // Initialize the analysis by visiting every op with control-flow semantics. 
if (op->getNumRegions() || op->getNumSuccessors() || isRegionOrCallableReturn(op) || isa<CallOpInterface>(op)) { - LDBG() << "[init] Visiting op with control-flow semantics: " << *op; + LDBG() << "[init] Visiting op with control-flow semantics: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // When the liveness of the parent block changes, make sure to // re-invoke the analysis on the op. if (op->getBlock()) @@ -246,16 +261,17 @@ LogicalResult DeadCodeAnalysis::initializeRecursively(Operation *op) { } // Recurse on nested operations. for (Region ®ion : op->getRegions()) { - LDBG() << "[init] Recursing into region of op: " << op->getName(); + LDBG() << "[init] Recursing into region of op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); for (Operation &nestedOp : region.getOps()) { - LDBG() << "[init] Recursing into nested op: " << nestedOp.getName() - << " at " << &nestedOp; + LDBG() << "[init] Recursing into nested op: " + << OpWithFlags(&nestedOp, OpPrintingFlags().skipRegions()); if (failed(initializeRecursively(&nestedOp))) return failure(); } } - LDBG() << "[init] Finished initializeRecursively for op: " << op->getName() - << " at " << op; + LDBG() << "[init] Finished initializeRecursively for op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); return success(); } @@ -269,35 +285,40 @@ void DeadCodeAnalysis::markEdgeLive(Block *from, Block *to) { } void DeadCodeAnalysis::markEntryBlocksLive(Operation *op) { - LDBG() << "Marking entry blocks live for op: " << op->getName(); + LDBG() << "Marking entry blocks live for op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); for (Region ®ion : op->getRegions()) { if (region.empty()) continue; auto *state = getOrCreate<Executable>(getProgramPointBefore(®ion.front())); propagateIfChanged(state, state->setToLive()); - LDBG() << "Marked entry block live for region in op: " << op->getName(); + LDBG() << "Marked entry block live for region in op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); } } LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { - LDBG() << "Visiting program point: " << point << " " << *point; + LDBG() << "Visiting program point: " << *point; if (point->isBlockStart()) return success(); Operation *op = point->getPrevOp(); - LDBG() << "Visiting operation: " << *op; + LDBG() << "Visiting operation: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // If the parent block is not executable, there is nothing to do. if (op->getBlock() != nullptr && !getOrCreate<Executable>(getProgramPointBefore(op->getBlock())) ->isLive()) { - LDBG() << "Parent block not live, skipping op: " << *op; + LDBG() << "Parent block not live, skipping op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); return success(); } // We have a live call op. Add this as a live predecessor of the callee. if (auto call = dyn_cast<CallOpInterface>(op)) { - LDBG() << "Visiting call operation: " << *op; + LDBG() << "Visiting call operation: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); visitCallOperation(call); } @@ -305,12 +326,14 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { if (op->getNumRegions()) { // Check if we can reason about the region control-flow. if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { - LDBG() << "Visiting region branch operation: " << *op; + LDBG() << "Visiting region branch operation: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); visitRegionBranchOperation(branch); // Check if this is a callable operation. 
} else if (auto callable = dyn_cast<CallableOpInterface>(op)) { - LDBG() << "Visiting callable operation: " << *op; + LDBG() << "Visiting callable operation: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); const auto *callsites = getOrCreateFor<PredecessorState>( getProgramPointAfter(op), getProgramPointAfter(callable)); @@ -322,19 +345,22 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { // Otherwise, conservatively mark all entry blocks as executable. } else { - LDBG() << "Marking all entry blocks live for op: " << *op; + LDBG() << "Marking all entry blocks live for op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); markEntryBlocksLive(op); } } if (isRegionOrCallableReturn(op)) { if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) { - LDBG() << "Visiting region terminator: " << *op; + LDBG() << "Visiting region terminator: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // Visit the exiting terminator of a region. visitRegionTerminator(op, branch); } else if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) { - LDBG() << "Visiting callable terminator: " << *op; + LDBG() << "Visiting callable terminator: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // Visit the exiting terminator of a callable. visitCallableTerminator(op, callable); } @@ -343,12 +369,14 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { if (op->getNumSuccessors()) { // Check if we can reason about the control-flow. if (auto branch = dyn_cast<BranchOpInterface>(op)) { - LDBG() << "Visiting branch operation: " << *op; + LDBG() << "Visiting branch operation: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); visitBranchOperation(branch); // Otherwise, conservatively mark all successors as exectuable. } else { - LDBG() << "Marking all successors live for op: " << *op; + LDBG() << "Marking all successors live for op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); for (Block *successor : op->getSuccessors()) markEdgeLive(op->getBlock(), successor); } @@ -358,7 +386,8 @@ LogicalResult DeadCodeAnalysis::visit(ProgramPoint *point) { } void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { - LDBG() << "visitCallOperation: " << call.getOperation()->getName(); + LDBG() << "visitCallOperation: " + << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions()); Operation *callableOp = call.resolveCallableInTable(&symbolTable); // A call to a externally-defined callable has unknown predecessors. @@ -382,14 +411,14 @@ void DeadCodeAnalysis::visitCallOperation(CallOpInterface call) { getOrCreate<PredecessorState>(getProgramPointAfter(callableOp)); propagateIfChanged(callsites, callsites->join(call)); LDBG() << "Added callsite as predecessor for callable: " - << callableOp->getName(); + << OpWithFlags(callableOp, OpPrintingFlags().skipRegions()); } else { // Mark this call op's predecessors as overdefined. 
auto *predecessors = getOrCreate<PredecessorState>(getProgramPointAfter(call)); propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors()); LDBG() << "Marked call op's predecessors as unknown for: " - << call.getOperation()->getName(); + << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions()); } } @@ -421,7 +450,8 @@ DeadCodeAnalysis::getOperandValues(Operation *op) { } void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { - LDBG() << "visitBranchOperation: " << branch.getOperation()->getName(); + LDBG() << "visitBranchOperation: " + << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions()); // Try to deduce a single successor for the branch. std::optional<SmallVector<Attribute>> operands = getOperandValues(branch); if (!operands) @@ -440,7 +470,8 @@ void DeadCodeAnalysis::visitBranchOperation(BranchOpInterface branch) { void DeadCodeAnalysis::visitRegionBranchOperation( RegionBranchOpInterface branch) { - LDBG() << "visitRegionBranchOperation: " << branch.getOperation()->getName(); + LDBG() << "visitRegionBranchOperation: " + << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions()); // Try to deduce which regions are executable. std::optional<SmallVector<Attribute>> operands = getOperandValues(branch); if (!operands) @@ -517,14 +548,14 @@ void DeadCodeAnalysis::visitCallableTerminator(Operation *op, if (canResolve) { propagateIfChanged(predecessors, predecessors->join(op)); LDBG() << "Added callable terminator as predecessor for callsite: " - << predecessor->getName(); + << OpWithFlags(predecessor, OpPrintingFlags().skipRegions()); } else { // If the terminator is not a return-like, then conservatively assume we // can't resolve the predecessor. propagateIfChanged(predecessors, predecessors->setHasUnknownPredecessors()); LDBG() << "Could not resolve callable terminator for callsite: " - << predecessor->getName(); + << OpWithFlags(predecessor, OpPrintingFlags().skipRegions()); } } } diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index c7a950d..e79f6a8 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -19,6 +19,8 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" @@ -28,6 +30,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #include <optional> #include <utility> @@ -87,7 +90,8 @@ LogicalResult IntegerRangeAnalysis::visitOperation( return success(); } - LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); + LDBG() << "Inferring ranges for " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); auto argRanges = llvm::map_to_vector( operands, [](const IntegerValueRangeLattice *lattice) { return lattice->getValue(); @@ -99,7 +103,7 @@ LogicalResult IntegerRangeAnalysis::visitOperation( return; assert(llvm::is_contained(op->getResults(), result)); - LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); + LDBG() << "Inferred range " << attrs; IntegerValueRangeLattice *lattice = results[result.getResultNumber()]; IntegerValueRange oldRange = lattice->getValue(); @@ -114,7 +118,7 @@ LogicalResult 
IntegerRangeAnalysis::visitOperation( }); if (isYieldedResult && !oldRange.isUninitialized() && !(lattice->getValue() == oldRange)) { - LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + LDBG() << "Loop variant loop result detected"; changed |= lattice->join(IntegerValueRange::getMaxRange(v)); } propagateIfChanged(lattice, changed); @@ -128,7 +132,8 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( Operation *op, const RegionSuccessor &successor, ArrayRef<IntegerValueRangeLattice *> argLattices, unsigned firstIndex) { if (auto inferrable = dyn_cast<InferIntRangeInterface>(op)) { - LLVM_DEBUG(llvm::dbgs() << "Inferring ranges for " << *op << "\n"); + LDBG() << "Inferring ranges for " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); auto argRanges = llvm::map_to_vector(op->getOperands(), [&](Value value) { return getLatticeElementFor(getProgramPointAfter(op), value)->getValue(); @@ -141,7 +146,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( if (!llvm::is_contained(successor.getSuccessor()->getArguments(), arg)) return; - LLVM_DEBUG(llvm::dbgs() << "Inferred range " << attrs << "\n"); + LDBG() << "Inferred range " << attrs; IntegerValueRangeLattice *lattice = argLattices[arg.getArgNumber()]; IntegerValueRange oldRange = lattice->getValue(); @@ -156,7 +161,7 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( }); if (isYieldedValue && !oldRange.isUninitialized() && !(lattice->getValue() == oldRange)) { - LLVM_DEBUG(llvm::dbgs() << "Loop variant loop result detected\n"); + LDBG() << "Loop variant loop result detected"; changed |= lattice->join(IntegerValueRange::getMaxRange(v)); } propagateIfChanged(lattice, changed); diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp index 509f520..65df355 100644 --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -294,7 +294,34 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) { solver.load<LivenessAnalysis>(symbolTable); LDBG() << "Initializing and running solver"; (void)solver.initializeAndRun(op); - LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName(); + LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName() + << " check on unreachable code now:"; + // The framework doesn't visit operations in dead blocks, so we need to + // explicitly mark them as dead. 
+ op->walk([&](Operation *op) { + if (op->getNumResults() == 0) + return; + for (auto result : llvm::enumerate(op->getResults())) { + if (getLiveness(result.value())) + continue; + LDBG() << "Result: " << result.index() << " of " + << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " has no liveness info (unreachable), mark dead"; + solver.getOrCreateState<Liveness>(result.value()); + } + for (auto ®ion : op->getRegions()) { + for (auto &block : region) { + for (auto blockArg : llvm::enumerate(block.getArguments())) { + if (getLiveness(blockArg.value())) + continue; + LDBG() << "Block argument: " << blockArg.index() << " of " + << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " has no liveness info, mark dead"; + solver.getOrCreateState<Liveness>(blockArg.value()); + } + } + } + }); } const Liveness *RunLivenessAnalysis::getLiveness(Value val) { diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp index e625f62..13a3e14 100644 --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -19,12 +19,15 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #include <optional> using namespace mlir; using namespace mlir::dataflow; +#define DEBUG_TYPE "dataflow" + //===----------------------------------------------------------------------===// // AbstractSparseLattice //===----------------------------------------------------------------------===// @@ -64,22 +67,36 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) { LogicalResult AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) { + LDBG() << "Initializing recursively for operation: " << op->getName(); + // Initialize the analysis by visiting every owner of an SSA value (all // operations and blocks). - if (failed(visitOperation(op))) + if (failed(visitOperation(op))) { + LDBG() << "Failed to visit operation: " << op->getName(); return failure(); + } for (Region ®ion : op->getRegions()) { + LDBG() << "Processing region with " << region.getBlocks().size() + << " blocks"; for (Block &block : region) { + LDBG() << "Processing block with " << block.getNumArguments() + << " arguments"; getOrCreate<Executable>(getProgramPointBefore(&block)) ->blockContentSubscribe(this); visitBlock(&block); - for (Operation &op : block) - if (failed(initializeRecursively(&op))) + for (Operation &op : block) { + LDBG() << "Recursively initializing nested operation: " << op.getName(); + if (failed(initializeRecursively(&op))) { + LDBG() << "Failed to initialize nested operation: " << op.getName(); return failure(); + } + } } } + LDBG() << "Successfully completed recursive initialization for operation: " + << op->getName(); return success(); } @@ -409,11 +426,20 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) { LogicalResult AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { + LDBG() << "Visiting operation: " << op->getName() << " with " + << op->getNumOperands() << " operands and " << op->getNumResults() + << " results"; + // If we're in a dead block, bail out. 
if (op->getBlock() != nullptr && - !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive()) + !getOrCreate<Executable>(getProgramPointBefore(op->getBlock())) + ->isLive()) { + LDBG() << "Operation is in dead block, bailing out"; return success(); + } + LDBG() << "Creating lattice elements for " << op->getNumOperands() + << " operands and " << op->getNumResults() << " results"; SmallVector<AbstractSparseLattice *> operandLattices = getLatticeElements(op->getOperands()); SmallVector<const AbstractSparseLattice *> resultLattices = @@ -422,11 +448,15 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // Block arguments of region branch operations flow back into the operands // of the parent op if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { + LDBG() << "Processing RegionBranchOpInterface operation"; visitRegionSuccessors(branch, operandLattices); return success(); } if (auto branch = dyn_cast<BranchOpInterface>(op)) { + LDBG() << "Processing BranchOpInterface operation with " + << op->getNumSuccessors() << " successors"; + // Block arguments of successor blocks flow back into our operands. // We remember all operands not forwarded to any block in a BitVector. @@ -463,6 +493,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // For function calls, connect the arguments of the entry blocks to the // operands of the call op that are forwarded to these arguments. if (auto call = dyn_cast<CallOpInterface>(op)) { + LDBG() << "Processing CallOpInterface operation"; Operation *callableOp = call.resolveCallableInTable(&symbolTable); if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) { // Not all operands of a call op forward to arguments. Such operands are @@ -513,6 +544,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { // of this op itself and the operands of the terminators of the regions of // this op. if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) { + LDBG() << "Processing RegionBranchTerminatorOpInterface operation"; if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) { visitRegionSuccessorsFromTerminator(terminator, branch); return success(); @@ -520,12 +552,16 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) { } if (op->hasTrait<OpTrait::ReturnLike>()) { + LDBG() << "Processing ReturnLike operation"; // Going backwards, the operands of the return are derived from the // results of all CallOps calling this CallableOp. 
- if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) + if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) { + LDBG() << "Callable parent found, visiting callable operation"; return visitCallableOperation(op, callable, operandLattices); + } } + LDBG() << "Using default visitOperationImpl for operation: " << op->getName(); return visitOperationImpl(op, operandLattices, resultLattices); } diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp index 16f7033..7e1b405 100644 --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -45,7 +45,7 @@ void AnalysisState::addDependency(ProgramPoint *dependent, DATAFLOW_DEBUG({ if (inserted) { LDBG() << "Creating dependency between " << debugName << " of " << anchor - << "\nand " << debugName << " on " << dependent; + << "\nand " << debugName << " on " << *dependent; } }); } @@ -62,11 +62,12 @@ void ProgramPoint::print(raw_ostream &os) const { return; } if (!isBlockStart()) { - os << "<after operation>:"; - return getPrevOp()->print(os, OpPrintingFlags().skipRegions()); + os << "<after operation>:" + << OpWithFlags(getPrevOp(), OpPrintingFlags().skipRegions()); + return; } - os << "<before operation>:"; - return getNextOp()->print(os, OpPrintingFlags().skipRegions()); + os << "<before operation>:" + << OpWithFlags(getNextOp(), OpPrintingFlags().skipRegions()); } //===----------------------------------------------------------------------===// @@ -78,8 +79,8 @@ void LatticeAnchor::print(raw_ostream &os) const { os << "<NULL POINT>"; return; } - if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this)) - return LatticeAnchor->print(os); + if (auto *latticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this)) + return latticeAnchor->print(os); if (auto value = llvm::dyn_cast<Value>(*this)) { return value.print(os, OpPrintingFlags().skipRegions()); } @@ -88,8 +89,8 @@ void LatticeAnchor::print(raw_ostream &os) const { } Location LatticeAnchor::getLoc() const { - if (auto *LatticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this)) - return LatticeAnchor->getLoc(); + if (auto *latticeAnchor = llvm::dyn_cast<GenericLatticeAnchor *>(*this)) + return latticeAnchor->getLoc(); if (auto value = llvm::dyn_cast<Value>(*this)) return value.getLoc(); @@ -128,7 +129,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) { worklist.pop(); DATAFLOW_DEBUG(LDBG() << "Invoking '" << analysis->debugName - << "' on: " << point); + << "' on: " << *point); if (failed(analysis->visit(point))) return failure(); } diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp index f4b02b4..30ce1fb 100644 --- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp +++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp @@ -60,7 +60,7 @@ private: AffineExpr localExpr) override { SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr); // Update localVarCst. - localVarCst.addLocalFloorDiv(dividend, divisor); + (void)localVarCst.addLocalFloorDiv(dividend, divisor); } LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs, diff --git a/mlir/lib/Analysis/Presburger/Barvinok.cpp b/mlir/lib/Analysis/Presburger/Barvinok.cpp index 4546e49..75d592e 100644 --- a/mlir/lib/Analysis/Presburger/Barvinok.cpp +++ b/mlir/lib/Analysis/Presburger/Barvinok.cpp @@ -554,7 +554,7 @@ QuasiPolynomial mlir::presburger::detail::getCoefficientInRationalFunction( /// t^num / \prod_j (1 - t^dens[j]). 
/// v represents the affine functions whose floors are multiplied by the /// generators, and ds represents the list of generators. -std::pair<QuasiPolynomial, std::vector<Fraction>> +static std::pair<QuasiPolynomial, std::vector<Fraction>> substituteMuInTerm(unsigned numParams, const ParamPoint &v, const std::vector<Point> &ds, const Point &mu) { unsigned numDims = mu.size(); @@ -606,8 +606,8 @@ substituteMuInTerm(unsigned numParams, const ParamPoint &v, /// Here, sign = ± 1, /// num is a QuasiPolynomial, and /// each dens[j] is a Fraction. -void normalizeDenominatorExponents(int &sign, QuasiPolynomial &num, - std::vector<Fraction> &dens) { +static void normalizeDenominatorExponents(int &sign, QuasiPolynomial &num, + std::vector<Fraction> &dens) { // We track the number of exponents that are negative in the // denominator, and convert them to their absolute values. unsigned numNegExps = 0; @@ -634,8 +634,8 @@ void normalizeDenominatorExponents(int &sign, QuasiPolynomial &num, /// Compute the binomial coefficients nCi for 0 ≤ i ≤ r, /// where n is a QuasiPolynomial. -std::vector<QuasiPolynomial> getBinomialCoefficients(const QuasiPolynomial &n, - unsigned r) { +static std::vector<QuasiPolynomial> +getBinomialCoefficients(const QuasiPolynomial &n, unsigned r) { unsigned numParams = n.getNumInputs(); std::vector<QuasiPolynomial> coefficients; coefficients.reserve(r + 1); @@ -651,8 +651,8 @@ std::vector<QuasiPolynomial> getBinomialCoefficients(const QuasiPolynomial &n, /// Compute the binomial coefficients nCi for 0 ≤ i ≤ r, /// where n is a QuasiPolynomial. -std::vector<Fraction> getBinomialCoefficients(const Fraction &n, - const Fraction &r) { +static std::vector<Fraction> getBinomialCoefficients(const Fraction &n, + const Fraction &r) { std::vector<Fraction> coefficients; coefficients.reserve((int64_t)floor(r)); coefficients.emplace_back(1); diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 5c4d4d1..0dcdd5b 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -1500,12 +1500,13 @@ void IntegerRelation::addBound(BoundType type, ArrayRef<DynamicAPInt> expr, /// respect to a positive constant 'divisor'. Two constraints are added to the /// system to capture equivalence with the floordiv. /// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1. -void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend, - const DynamicAPInt &divisor) { +/// Returns the column position of the new local variable. 
+unsigned IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend, + const DynamicAPInt &divisor) { assert(dividend.size() == getNumCols() && "incorrect dividend size"); assert(divisor > 0 && "positive divisor expected"); - appendVar(VarKind::Local); + unsigned newVar = appendVar(VarKind::Local); SmallVector<DynamicAPInt, 8> dividendCopy(dividend); dividendCopy.insert(dividendCopy.end() - 1, DynamicAPInt(0)); @@ -1513,6 +1514,28 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend, getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2)); addInequality( getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2)); + return newVar; +} + +unsigned IntegerRelation::addLocalModulo(ArrayRef<DynamicAPInt> exprs, + const DynamicAPInt &modulus) { + assert(exprs.size() == getNumCols() && "incorrect exprs size"); + assert(modulus > 0 && "positive modulus expected"); + + /// Add a local variable for q = expr floordiv modulus + addLocalFloorDiv(exprs, modulus); + + /// Add a local var to represent the result + auto resultIndex = appendVar(VarKind::Local); + + SmallVector<DynamicAPInt, 8> exprsCopy(exprs); + /// Insert the two new locals before the constant + /// Add locals that correspond to `q` and `result` to compute + /// 0 = (expr - modulus * q) - result + exprsCopy.insert(exprsCopy.end() - 1, + {DynamicAPInt(-modulus), DynamicAPInt(-1)}); + addEquality(exprsCopy); + return resultIndex; } int IntegerRelation::findEqualityToConstant(unsigned pos, bool symbolic) const { diff --git a/mlir/lib/Analysis/Presburger/Matrix.cpp b/mlir/lib/Analysis/Presburger/Matrix.cpp index 9fc6205..bb60564 100644 --- a/mlir/lib/Analysis/Presburger/Matrix.cpp +++ b/mlir/lib/Analysis/Presburger/Matrix.cpp @@ -402,10 +402,10 @@ void Matrix<T>::print(raw_ostream &os) const { for (unsigned row = 0; row < nRows; ++row) for (unsigned column = 0; column < nColumns; ++column) updatePrintMetrics<T>(at(row, column), ptm); - unsigned MIN_SPACING = 1; + unsigned minSpacing = 1; for (unsigned row = 0; row < nRows; ++row) { for (unsigned column = 0; column < nColumns; ++column) { - printWithPrintMetrics<T>(os, at(row, column), MIN_SPACING, ptm); + printWithPrintMetrics<T>(os, at(row, column), minSpacing, ptm); } os << "\n"; } @@ -721,7 +721,7 @@ FracMatrix FracMatrix::gramSchmidt() const { // Otherwise, we swap b_k and b_{k-1} and decrement k. // // We repeat this until k = n and return. -void FracMatrix::LLL(Fraction delta) { +void FracMatrix::LLL(const Fraction &delta) { DynamicAPInt nearest; Fraction mu; diff --git a/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp b/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp index 84d885f..4e374d0 100644 --- a/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp +++ b/mlir/lib/Analysis/Presburger/QuasiPolynomial.cpp @@ -112,14 +112,14 @@ QuasiPolynomial QuasiPolynomial::simplify() { // A term is zero if its coefficient is zero, or if (coefficients[i] == Fraction(0, 1)) continue; - bool product_is_zero = + bool productIsZero = // if any of the affine functions in the product - llvm::any_of(affine[i], [](const SmallVector<Fraction> &affine_ij) { + llvm::any_of(affine[i], [](const SmallVector<Fraction> &affineIj) { // has all its coefficients as zero. - return llvm::all_of(affine_ij, + return llvm::all_of(affineIj, [](const Fraction &f) { return f == 0; }); }); - if (product_is_zero) + if (productIsZero) continue; // Now, we know the term is nonzero. 
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp index 08290db..a1cbe29 100644 --- a/mlir/lib/Analysis/Presburger/Simplex.cpp +++ b/mlir/lib/Analysis/Presburger/Simplex.cpp @@ -433,7 +433,7 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) { normalizeDiv(divCoeffs, divDenom); domainSimplex.addDivisionVariable(divCoeffs, divDenom); - domainPoly.addLocalFloorDiv(divCoeffs, divDenom); + (void)domainPoly.addLocalFloorDiv(divCoeffs, divDenom); // Update `this` to account for the additional symbol we just added. appendSymbol(); @@ -1663,7 +1663,7 @@ public: /// First pushes a snapshot for the current simplex state to the stack so /// that this can be rolled back later. void addEqualityForDirection(ArrayRef<DynamicAPInt> dir) { - assert(llvm::any_of(dir, [](const DynamicAPInt &x) { return x != 0; }) && + assert(llvm::any_of(dir, [](const DynamicAPInt &X) { return X != 0; }) && "Direction passed is the zero vector!"); snapshotStack.emplace_back(simplex.getSnapshot()); simplex.addEquality(getCoeffsForDirection(dir)); @@ -2156,10 +2156,10 @@ void SimplexBase::print(raw_ostream &os) const { for (unsigned row = 0, numRows = getNumRows(); row < numRows; ++row) for (unsigned col = 0, numCols = getNumColumns(); col < numCols; ++col) updatePrintMetrics<DynamicAPInt>(tableau(row, col), ptm); - unsigned MIN_SPACING = 1; + unsigned minSpacing = 1; for (unsigned row = 0, numRows = getNumRows(); row < numRows; ++row) { for (unsigned col = 0, numCols = getNumColumns(); col < numCols; ++col) { - printWithPrintMetrics<DynamicAPInt>(os, tableau(row, col), MIN_SPACING, + printWithPrintMetrics<DynamicAPInt>(os, tableau(row, col), minSpacing, ptm); } os << '\n'; diff --git a/mlir/lib/Analysis/TopologicalSortUtils.cpp b/mlir/lib/Analysis/TopologicalSortUtils.cpp index a2fd149..99546e7 100644 --- a/mlir/lib/Analysis/TopologicalSortUtils.cpp +++ b/mlir/lib/Analysis/TopologicalSortUtils.cpp @@ -101,12 +101,7 @@ bool mlir::sortTopologically( bool mlir::sortTopologically( Block *block, function_ref<bool(Value, Operation *)> isOperandReady) { - if (block->empty()) - return true; - if (block->back().hasTrait<OpTrait::IsTerminator>()) - return sortTopologically(block, block->without_terminator(), - isOperandReady); - return sortTopologically(block, *block, isOperandReady); + return sortTopologically(block, block->without_terminator(), isOperandReady); } bool mlir::computeTopologicalSorting( diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp index e5045cf..a21176fff 100644 --- a/mlir/lib/Bindings/Python/DialectGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectGPU.cpp @@ -9,8 +9,8 @@ #include "mlir-c/Dialect/GPU.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace nanobind::literals; @@ -34,7 +34,7 @@ NB_MODULE(_mlirDialectsGPU, m) { mlirGPUAsyncTokenType.def_classmethod( "get", - [](nb::object cls, MlirContext ctx) { + [](const nb::object &cls, MlirContext ctx) { return cls(mlirGPUAsyncTokenTypeGet(ctx)); }, "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"), @@ -47,8 +47,9 @@ NB_MODULE(_mlirDialectsGPU, m) { mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr) .def_classmethod( "get", - [](nb::object cls, MlirAttribute target, uint32_t format, - nb::bytes object, std::optional<MlirAttribute> 
mlirObjectProps, + [](const nb::object &cls, MlirAttribute target, uint32_t format, + const nb::bytes &object, + std::optional<MlirAttribute> mlirObjectProps, std::optional<MlirAttribute> mlirKernelsAttr) { MlirStringRef objectStrRef = mlirStringRefCreate( static_cast<char *>(const_cast<void *>(object.data())), diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index f211e76..ee106c0 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -12,8 +12,8 @@ #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Diagnostics.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; @@ -24,7 +24,7 @@ using namespace mlir; using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; -void populateDialectLLVMSubmodule(const nanobind::module_ &m) { +static void populateDialectLLVMSubmodule(const nanobind::module_ &m) { //===--------------------------------------------------------------------===// // StructType @@ -35,8 +35,8 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) { llvmStructType.def_classmethod( "get_literal", - [](nb::object cls, const std::vector<MlirType> &elements, bool packed, - MlirLocation loc) { + [](const nb::object &cls, const std::vector<MlirType> &elements, + bool packed, MlirLocation loc) { CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc)); MlirType type = mlirLLVMStructTypeLiteralGetChecked( @@ -51,7 +51,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) { llvmStructType.def_classmethod( "get_identified", - [](nb::object cls, const std::string &name, MlirContext context) { + [](const nb::object &cls, const std::string &name, MlirContext context) { return cls(mlirLLVMStructTypeIdentifiedGet( context, mlirStringRefCreate(name.data(), name.size()))); }, @@ -59,7 +59,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) { llvmStructType.def_classmethod( "get_opaque", - [](nb::object cls, const std::string &name, MlirContext context) { + [](const nb::object &cls, const std::string &name, MlirContext context) { return cls(mlirLLVMStructTypeOpaqueGet( context, mlirStringRefCreate(name.data(), name.size()))); }, @@ -79,7 +79,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) { llvmStructType.def_classmethod( "new_identified", - [](nb::object cls, const std::string &name, + [](const nb::object &cls, const std::string &name, const std::vector<MlirType> &elements, bool packed, MlirContext ctx) { return cls(mlirLLVMStructTypeIdentifiedNewGet( ctx, mlirStringRefCreate(name.data(), name.length()), @@ -123,7 +123,7 @@ void populateDialectLLVMSubmodule(const nanobind::module_ &m) { mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType) .def_classmethod( "get", - [](nb::object cls, std::optional<unsigned> addressSpace, + [](const nb::object &cls, std::optional<unsigned> addressSpace, MlirContext context) { CollectDiagnosticsToStringScope scope(context); MlirType type = mlirLLVMPointerTypeGet( diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp index a0d6a4b..bb3f519c 100644 --- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp @@ -8,8 +8,8 @@ #include "mlir-c/Dialect/NVGPU.h" #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include 
"mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace llvm; @@ -23,8 +23,8 @@ static void populateDialectNVGPUSubmodule(const nb::module_ &m) { nvgpuTensorMapDescriptorType.def_classmethod( "get", - [](nb::object cls, MlirType tensorMemrefType, int swizzle, int l2promo, - int oobFill, int interleave, MlirContext ctx) { + [](const nb::object &cls, MlirType tensorMemrefType, int swizzle, + int l2promo, int oobFill, int interleave, MlirContext ctx) { return cls(mlirNVGPUTensorMapDescriptorTypeGet( ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave)); }, diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp index bcc6ff4..2acedbc 100644 --- a/mlir/lib/Bindings/Python/DialectPDL.cpp +++ b/mlir/lib/Bindings/Python/DialectPDL.cpp @@ -8,8 +8,8 @@ #include "mlir-c/Dialect/PDL.h" #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace llvm; @@ -17,7 +17,7 @@ using namespace mlir; using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; -void populateDialectPDLSubmodule(const nanobind::module_ &m) { +static void populateDialectPDLSubmodule(const nanobind::module_ &m) { //===-------------------------------------------------------------------===// // PDLType //===-------------------------------------------------------------------===// @@ -32,7 +32,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) { mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType); attributeType.def_classmethod( "get", - [](nb::object cls, MlirContext ctx) { + [](const nb::object &cls, MlirContext ctx) { return cls(mlirPDLAttributeTypeGet(ctx)); }, "Get an instance of AttributeType in given context.", nb::arg("cls"), @@ -46,7 +46,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) { mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType); operationType.def_classmethod( "get", - [](nb::object cls, MlirContext ctx) { + [](const nb::object &cls, MlirContext ctx) { return cls(mlirPDLOperationTypeGet(ctx)); }, "Get an instance of OperationType in given context.", nb::arg("cls"), @@ -59,7 +59,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) { auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType); rangeType.def_classmethod( "get", - [](nb::object cls, MlirType elementType) { + [](const nb::object &cls, MlirType elementType) { return cls(mlirPDLRangeTypeGet(elementType)); }, "Gets an instance of RangeType in the same context as the provided " @@ -77,7 +77,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) { auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType); typeType.def_classmethod( "get", - [](nb::object cls, MlirContext ctx) { + [](const nb::object &cls, MlirContext ctx) { return cls(mlirPDLTypeTypeGet(ctx)); }, "Get an instance of TypeType in given context.", nb::arg("cls"), @@ -90,7 +90,7 @@ void populateDialectPDLSubmodule(const nanobind::module_ &m) { auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType); valueType.def_classmethod( "get", - [](nb::object cls, MlirContext ctx) { + [](const nb::object &cls, MlirContext ctx) { return cls(mlirPDLValueTypeGet(ctx)); }, "Get an instance of TypeType in given context.", nb::arg("cls"), diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp 
b/mlir/lib/Bindings/Python/DialectQuant.cpp index 55571cd..a5220fc 100644 --- a/mlir/lib/Bindings/Python/DialectQuant.cpp +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -165,7 +165,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) { quantizedType.get_class()); anyQuantizedType.def_classmethod( "get", - [](nb::object cls, unsigned flags, MlirType storageType, + [](const nb::object &cls, unsigned flags, MlirType storageType, MlirType expressedType, int64_t storageTypeMin, int64_t storageTypeMax) { return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, @@ -186,7 +186,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) { quantizedType.get_class()); uniformQuantizedType.def_classmethod( "get", - [](nb::object cls, unsigned flags, MlirType storageType, + [](const nb::object &cls, unsigned flags, MlirType storageType, MlirType expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) { return cls(mlirUniformQuantizedTypeGet(flags, storageType, @@ -221,7 +221,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) { quantizedType.get_class()); uniformQuantizedPerAxisType.def_classmethod( "get", - [](nb::object cls, unsigned flags, MlirType storageType, + [](const nb::object &cls, unsigned flags, MlirType storageType, MlirType expressedType, std::vector<double> scales, std::vector<int64_t> zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax) { @@ -293,7 +293,7 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) { mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class()); uniformQuantizedSubChannelType.def_classmethod( "get", - [](nb::object cls, unsigned flags, MlirType storageType, + [](const nb::object &cls, unsigned flags, MlirType storageType, MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints, std::vector<int32_t> quantizedDimensions, std::vector<int64_t> blockSizes, int64_t storageTypeMin, @@ -367,7 +367,8 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) { quantizedType.get_class()); calibratedQuantizedType.def_classmethod( "get", - [](nb::object cls, MlirType expressedType, double min, double max) { + [](const nb::object &cls, MlirType expressedType, double min, + double max) { return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); }, "Gets an instance of CalibratedQuantizedType in the same context as the " diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp index 4e76477..cab4219 100644 --- a/mlir/lib/Bindings/Python/DialectSMT.cpp +++ b/mlir/lib/Bindings/Python/DialectSMT.cpp @@ -24,7 +24,7 @@ using namespace mlir; using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; -void populateDialectSMTSubmodule(nanobind::module_ &m) { +static void populateDialectSMTSubmodule(nanobind::module_ &m) { auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool) .def_classmethod( diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 97cebcc..9d7dc11 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -12,8 +12,8 @@ #include "mlir-c/AffineMap.h" #include "mlir-c/Dialect/SparseTensor.h" #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; 
using namespace llvm; @@ -38,7 +38,8 @@ static void populateDialectSparseTensorSubmodule(const nb::module_ &m) { mlirAttributeIsASparseTensorEncodingAttr) .def_classmethod( "get", - [](nb::object cls, std::vector<MlirSparseTensorLevelType> lvlTypes, + [](const nb::object &cls, + std::vector<MlirSparseTensorLevelType> lvlTypes, std::optional<MlirAffineMap> dimToLvl, std::optional<MlirAffineMap> lvlToDim, int posWidth, int crdWidth, std::optional<MlirAttribute> explicitVal, @@ -58,7 +59,7 @@ static void populateDialectSparseTensorSubmodule(const nb::module_ &m) { "Gets a sparse_tensor.encoding from parameters.") .def_classmethod( "build_level_type", - [](nb::object cls, MlirSparseTensorLevelFormat lvlFmt, + [](const nb::object &cls, MlirSparseTensorLevelFormat lvlFmt, const std::vector<MlirSparseTensorLevelPropertyNondefault> &properties, unsigned n, unsigned m) { diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp index 59a030a..1a62b06 100644 --- a/mlir/lib/Bindings/Python/DialectTransform.cpp +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ -11,15 +11,15 @@ #include "mlir-c/Dialect/Transform.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace mlir; using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; -void populateDialectTransformSubmodule(const nb::module_ &m) { +static void populateDialectTransformSubmodule(const nb::module_ &m) { //===-------------------------------------------------------------------===// // AnyOpType //===-------------------------------------------------------------------===// @@ -29,7 +29,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) { mlirTransformAnyOpTypeGetTypeID); anyOpType.def_classmethod( "get", - [](nb::object cls, MlirContext ctx) { + [](const nb::object &cls, MlirContext ctx) { return cls(mlirTransformAnyOpTypeGet(ctx)); }, "Get an instance of AnyOpType in the given context.", nb::arg("cls"), @@ -44,7 +44,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) { mlirTransformAnyParamTypeGetTypeID); anyParamType.def_classmethod( "get", - [](nb::object cls, MlirContext ctx) { + [](const nb::object &cls, MlirContext ctx) { return cls(mlirTransformAnyParamTypeGet(ctx)); }, "Get an instance of AnyParamType in the given context.", nb::arg("cls"), @@ -59,7 +59,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) { mlirTransformAnyValueTypeGetTypeID); anyValueType.def_classmethod( "get", - [](nb::object cls, MlirContext ctx) { + [](const nb::object &cls, MlirContext ctx) { return cls(mlirTransformAnyValueTypeGet(ctx)); }, "Get an instance of AnyValueType in the given context.", nb::arg("cls"), @@ -74,7 +74,8 @@ void populateDialectTransformSubmodule(const nb::module_ &m) { mlirTransformOperationTypeGetTypeID); operationType.def_classmethod( "get", - [](nb::object cls, const std::string &operationName, MlirContext ctx) { + [](const nb::object &cls, const std::string &operationName, + MlirContext ctx) { MlirStringRef cOperationName = mlirStringRefCreate(operationName.data(), operationName.size()); return cls(mlirTransformOperationTypeGet(ctx, cOperationName)); @@ -101,7 +102,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) { mlirTransformParamTypeGetTypeID); paramType.def_classmethod( "get", - [](nb::object cls, MlirType type, 
MlirContext ctx) { + [](const nb::object &cls, MlirType type, MlirContext ctx) { return cls(mlirTransformParamTypeGet(ctx, type)); }, "Get an instance of ParamType for the given type in the given context.", diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 81dada3..8bb493e 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir-c/ExecutionEngine.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace mlir; @@ -45,7 +45,7 @@ public: referencedObjects.push_back(obj); } - static nb::object createFromCapsule(nb::object capsule) { + static nb::object createFromCapsule(const nb::object &capsule) { MlirExecutionEngine rawPm = mlirPythonCapsuleToExecutionEngine(capsule.ptr()); if (mlirExecutionEngineIsNull(rawPm)) @@ -113,7 +113,7 @@ NB_MODULE(_mlirExecutionEngine, m) { .def( "raw_register_runtime", [](PyExecutionEngine &executionEngine, const std::string &name, - nb::object callbackObj) { + const nb::object &callbackObj) { executionEngine.addReferencedObject(callbackObj); uintptr_t rawSym = nb::cast<uintptr_t>(nb::getattr(callbackObj, "value")); @@ -125,6 +125,17 @@ NB_MODULE(_mlirExecutionEngine, m) { nb::arg("name"), nb::arg("callback"), "Register `callback` as the runtime symbol `name`.") .def( + "initialize", + [](PyExecutionEngine &executionEngine) { + mlirExecutionEngineInitialize(executionEngine.get()); + }, + "Initialize the ExecutionEngine. Global constructors specified by " + "`llvm.mlir.global_ctors` will be run. One common scenario is that " + "kernel binary compiled from `gpu.module` gets loaded during " + "initialization. 
Make sure all symbols are resolvable before " + "initialization by calling `register_runtime` or including " + "shared libraries.") + .def( "dump_to_object_file", [](PyExecutionEngine &executionEngine, const std::string &fileName) { mlirExecutionEngineDumpToObjectFile( diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 826a34a..71a051c 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -10,15 +10,19 @@ #define MLIR_BINDINGS_PYTHON_GLOBALS_H #include <optional> +#include <regex> #include <string> +#include <unordered_set> #include <vector> #include "NanobindUtils.h" #include "mlir-c/IR.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/Regex.h" namespace mlir { namespace python { @@ -114,6 +118,39 @@ public: std::optional<nanobind::object> lookupOperationClass(llvm::StringRef operationName); + class TracebackLoc { + public: + bool locTracebacksEnabled(); + + void setLocTracebacksEnabled(bool value); + + size_t locTracebackFramesLimit(); + + void setLocTracebackFramesLimit(size_t value); + + void registerTracebackFileInclusion(const std::string &file); + + void registerTracebackFileExclusion(const std::string &file); + + bool isUserTracebackFilename(llvm::StringRef file); + + static constexpr size_t kMaxFrames = 512; + + private: + nanobind::ft_mutex mutex; + bool locTracebackEnabled_ = false; + size_t locTracebackFramesLimit_ = 10; + std::unordered_set<std::string> userTracebackIncludeFiles; + std::unordered_set<std::string> userTracebackExcludeFiles; + std::regex userTracebackIncludeRegex; + bool rebuildUserTracebackIncludeRegex = false; + std::regex userTracebackExcludeRegex; + bool rebuildUserTracebackExcludeRegex = false; + llvm::StringMap<bool> isUserTracebackFilenameCache; + }; + + TracebackLoc &getTracebackLoc() { return tracebackLoc; } + private: static PyGlobals *instance; @@ -134,6 +171,8 @@ private: /// Set of dialect namespaces that we have attempted to import implementation /// modules for. llvm::StringSet<> loadedDialectModules; + + TracebackLoc tracebackLoc; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 50f2a4f..a6499c9 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -17,9 +17,9 @@ #include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. #include "mlir-c/IntegerSet.h" #include "mlir/Bindings/Python/Nanobind.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 
#include "mlir/Support/LLVM.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/SmallVector.h" @@ -64,7 +64,7 @@ static void pyListToVector(const nb::list &list, } template <typename PermutationTy> -static bool isPermutation(std::vector<PermutationTy> permutation) { +static bool isPermutation(const std::vector<PermutationTy> &permutation) { llvm::SmallVector<bool, 8> seen(permutation.size(), false); for (auto val : permutation) { if (val < permutation.size()) { @@ -366,7 +366,7 @@ nb::object PyAffineExpr::getCapsule() { return nb::steal<nb::object>(mlirPythonAffineExprToCapsule(*this)); } -PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) { +PyAffineExpr PyAffineExpr::createFromCapsule(const nb::object &capsule) { MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); if (mlirAffineExprIsNull(rawAffineExpr)) throw nb::python_error(); @@ -424,7 +424,7 @@ nb::object PyAffineMap::getCapsule() { return nb::steal<nb::object>(mlirPythonAffineMapToCapsule(*this)); } -PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) { +PyAffineMap PyAffineMap::createFromCapsule(const nb::object &capsule) { MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); if (mlirAffineMapIsNull(rawAffineMap)) throw nb::python_error(); @@ -500,7 +500,7 @@ nb::object PyIntegerSet::getCapsule() { return nb::steal<nb::object>(mlirPythonIntegerSetToCapsule(*this)); } -PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) { +PyIntegerSet PyIntegerSet::createFromCapsule(const nb::object &capsule) { MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); if (mlirIntegerSetIsNull(rawIntegerSet)) throw nb::python_error(); @@ -708,7 +708,8 @@ void mlir::python::populateIRAffine(nb::module_ &m) { return static_cast<size_t>(llvm::hash_value(self.get().ptr)); }) .def_static("compress_unused_symbols", - [](nb::list affineMaps, DefaultingPyMlirContext context) { + [](const nb::list &affineMaps, + DefaultingPyMlirContext context) { SmallVector<MlirAffineMap> maps; pyListToVector<PyAffineMap, MlirAffineMap>( affineMaps, maps, "attempting to create an AffineMap"); @@ -734,7 +735,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { kDumpDocstring) .def_static( "get", - [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs, + [](intptr_t dimCount, intptr_t symbolCount, const nb::list &exprs, DefaultingPyMlirContext context) { SmallVector<MlirAffineExpr> affineExprs; pyListToVector<PyAffineExpr, MlirAffineExpr>( @@ -869,7 +870,8 @@ void mlir::python::populateIRAffine(nb::module_ &m) { .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) .def("__eq__", [](PyIntegerSet &self, PyIntegerSet &other) { return self == other; }) - .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; }) + .def("__eq__", + [](PyIntegerSet &self, const nb::object &other) { return false; }) .def("__str__", [](PyIntegerSet &self) { PyPrintAccumulator printAccum; @@ -898,7 +900,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { kDumpDocstring) .def_static( "get", - [](intptr_t numDims, intptr_t numSymbols, nb::list exprs, + [](intptr_t numDims, intptr_t numSymbols, const nb::list &exprs, std::vector<bool> eqFlags, DefaultingPyMlirContext context) { if (exprs.size() != eqFlags.size()) throw nb::value_error( @@ -934,8 +936,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) { nb::arg("context").none() = nb::none()) .def( "get_replaced", - [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs, - intptr_t 
numResultDims, intptr_t numResultSymbols) { + [](PyIntegerSet &self, const nb::list &dimExprs, + const nb::list &symbolExprs, intptr_t numResultDims, + intptr_t numResultSymbols) { if (static_cast<intptr_t>(dimExprs.size()) != mlirIntegerSetGetNumDims(self)) throw nb::value_error( diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index db84ee1..f2eafa7 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -505,7 +505,7 @@ public: static void bindDerived(ClassTy &c) { c.def_static( "get", - [](nb::list attributes, DefaultingPyMlirContext context) { + [](const nb::list &attributes, DefaultingPyMlirContext context) { SmallVector<MlirAttribute> mlirAttributes; mlirAttributes.reserve(nb::len(attributes)); for (auto attribute : attributes) { @@ -530,7 +530,7 @@ public: .def("__iter__", [](const PyArrayAttribute &arr) { return PyArrayAttributeIterator(arr); }); - c.def("__add__", [](PyArrayAttribute arr, nb::list extras) { + c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) { std::vector<MlirAttribute> attributes; intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); attributes.reserve(numOldElements + nb::len(extras)); @@ -708,7 +708,7 @@ public: static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::string value, DefaultingPyMlirContext context) { + [](const std::string &value, DefaultingPyMlirContext context) { MlirAttribute attr = mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); return PyFlatSymbolRefAttribute(context->getRef(), attr); @@ -736,8 +736,8 @@ public: static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::string dialectNamespace, nb_buffer buffer, PyType &type, - DefaultingPyMlirContext context) { + [](const std::string &dialectNamespace, const nb_buffer &buffer, + PyType &type, DefaultingPyMlirContext context) { const nb_buffer_info bufferInfo = buffer.request(); intptr_t bufferSize = bufferInfo.size; MlirAttribute attr = mlirOpaqueAttrGet( @@ -775,7 +775,7 @@ public: static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::string value, DefaultingPyMlirContext context) { + [](const std::string &value, DefaultingPyMlirContext context) { MlirAttribute attr = mlirStringAttrGet(context->get(), toMlirStringRef(value)); return PyStringAttribute(context->getRef(), attr); @@ -784,7 +784,7 @@ public: "Gets a uniqued string attribute"); c.def_static( "get", - [](nb::bytes value, DefaultingPyMlirContext context) { + [](const nb::bytes &value, DefaultingPyMlirContext context) { MlirAttribute attr = mlirStringAttrGet(context->get(), toMlirStringRef(value)); return PyStringAttribute(context->getRef(), attr); @@ -793,7 +793,7 @@ public: "Gets a uniqued string attribute"); c.def_static( "get_typed", - [](PyType &type, std::string value) { + [](PyType &type, const std::string &value) { MlirAttribute attr = mlirStringAttrTypedGet(type, toMlirStringRef(value)); return PyStringAttribute(type.getContext(), attr); @@ -826,7 +826,7 @@ public: using PyConcreteAttribute::PyConcreteAttribute; static PyDenseElementsAttribute - getFromList(nb::list attributes, std::optional<PyType> explicitType, + getFromList(const nb::list &attributes, std::optional<PyType> explicitType, DefaultingPyMlirContext contextWrapper) { const size_t numAttributes = nb::len(attributes); if (numAttributes == 0) @@ -878,8 +878,8 @@ public: } static PyDenseElementsAttribute - getFromBuffer(nb_buffer array, bool signless, - std::optional<PyType> 
explicitType, + getFromBuffer(const nb_buffer &array, bool signless, + const std::optional<PyType> &explicitType, std::optional<std::vector<int64_t>> explicitShape, DefaultingPyMlirContext contextWrapper) { // Request a contiguous view. In exotic cases, this will cause a copy. @@ -894,8 +894,8 @@ public: auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); MlirContext context = contextWrapper->get(); - MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, - explicitShape, context); + MlirAttribute attr = getAttributeFromBuffer( + view, signless, explicitType, std::move(explicitShape), context); if (mlirAttributeIsNull(attr)) { throw std::invalid_argument( "DenseElementsAttr could not be constructed from the given buffer. " @@ -1092,16 +1092,16 @@ private: "when the type is not a shaped type."); } return *bulkLoadElementType; - } else { - MlirAttribute encodingAttr = mlirAttributeGetNull(); - return mlirRankedTensorTypeGet(shape.size(), shape.data(), - *bulkLoadElementType, encodingAttr); } + MlirAttribute encodingAttr = mlirAttributeGetNull(); + return mlirRankedTensorTypeGet(shape.size(), shape.data(), + *bulkLoadElementType, encodingAttr); } static MlirAttribute getAttributeFromBuffer( Py_buffer &view, bool signless, std::optional<PyType> explicitType, - std::optional<std::vector<int64_t>> explicitShape, MlirContext &context) { + const std::optional<std::vector<int64_t>> &explicitShape, + MlirContext &context) { // Detect format codes that are suitable for bulk loading. This includes // all byte aligned integer and floating point types up to 8 bytes. // Notably, this excludes exotics types which do not have a direct @@ -1125,7 +1125,7 @@ private: bulkLoadElementType = mlirF16TypeGet(context); } else if (format == "?") { // i1 - // The i1 type needs to be bit-packed, so we will handle it seperately + // The i1 type needs to be bit-packed, so we will handle it separately return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, context); } else if (isSignedIntegerFormat(format)) { @@ -1205,8 +1205,8 @@ private: packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little"); nb_buffer_info pythonBuffer = nb::cast<nb_buffer>(packedBooleans).request(); - MlirType bitpackedType = - getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); + MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1), + std::move(explicitShape), view); assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8"); // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of // packedBooleans, hence the MlirAttribute will remain valid even when @@ -1443,9 +1443,9 @@ public: using PyConcreteAttribute::PyConcreteAttribute; static PyDenseResourceElementsAttribute - getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type, - std::optional<size_t> alignment, bool isMutable, - DefaultingPyMlirContext contextWrapper) { + getFromBuffer(const nb_buffer &buffer, const std::string &name, + const PyType &type, std::optional<size_t> alignment, + bool isMutable, DefaultingPyMlirContext contextWrapper) { if (!mlirTypeIsAShaped(type)) { throw std::invalid_argument( "Constructing a DenseResourceElementsAttr requires a ShapedType."); @@ -1534,7 +1534,7 @@ public: c.def("__len__", &PyDictAttribute::dunderLen); c.def_static( "get", - [](nb::dict attributes, DefaultingPyMlirContext context) { + [](const nb::dict &attributes, DefaultingPyMlirContext context) { SmallVector<MlirNamedAttribute> mlirNamedAttributes; 
mlirNamedAttributes.reserve(attributes.size()); for (std::pair<nb::handle, nb::handle> it : attributes) { @@ -1618,7 +1618,7 @@ public: static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyType value, DefaultingPyMlirContext context) { + [](const PyType &value, DefaultingPyMlirContext context) { MlirAttribute attr = mlirTypeAttrGet(value.get()); return PyTypeAttribute(context->getRef(), attr); }, @@ -1663,7 +1663,7 @@ public: static void bindDerived(ClassTy &c) { c.def_static( "get", - [](int64_t offset, const std::vector<int64_t> strides, + [](int64_t offset, const std::vector<int64_t> &strides, DefaultingPyMlirContext ctx) { MlirAttribute attr = mlirStridedLayoutAttrGet( ctx->get(), offset, strides.size(), strides.data()); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 5feed95..2df2a73 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -20,11 +20,8 @@ #include "nanobind/nanobind.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" #include <optional> -#include <system_error> -#include <utility> namespace nb = nanobind; using namespace nb::literals; @@ -70,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails. See also: https://mlir.llvm.org/docs/LangRef/ )"; +static const char kModuleCAPICreate[] = + R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr). +Note this returns a new object BUT _clear_mlir_module(module) must be called to +prevent double-frees (of the underlying mlir::Module). +)"; + static const char kOperationCreateDocstring[] = R"(Creates a new operation. @@ -199,7 +202,7 @@ operations. /// Helper for creating an @classmethod. template <class Func, typename... Args> -nb::object classmethod(Func f, Args... args) { +static nb::object classmethod(Func f, Args... args) { nb::object cf = nb::cpp_function(f, args...); return nb::borrow<nb::object>((PyClassMethod_New(cf.ptr()))); } @@ -705,84 +708,6 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } -size_t PyMlirContext::getLiveOperationCount() { - nb::ft_lock_guard lock(liveOperationsMutex); - return liveOperations.size(); -} - -std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() { - std::vector<PyOperation *> liveObjects; - nb::ft_lock_guard lock(liveOperationsMutex); - for (auto &entry : liveOperations) - liveObjects.push_back(entry.second.second); - return liveObjects; -} - -size_t PyMlirContext::clearLiveOperations() { - - LiveOperationMap operations; - { - nb::ft_lock_guard lock(liveOperationsMutex); - std::swap(operations, liveOperations); - } - for (auto &op : operations) - op.second.second->setInvalid(); - size_t numInvalidated = operations.size(); - return numInvalidated; -} - -void PyMlirContext::clearOperation(MlirOperation op) { - PyOperation *py_op; - { - nb::ft_lock_guard lock(liveOperationsMutex); - auto it = liveOperations.find(op.ptr); - if (it == liveOperations.end()) { - return; - } - py_op = it->second.second; - liveOperations.erase(it); - } - py_op->setInvalid(); -} - -void PyMlirContext::clearOperationsInside(PyOperationBase &op) { - typedef struct { - PyOperation &rootOp; - bool rootSeen; - } callBackData; - callBackData data{op.getOperation(), false}; - // Mark all ops below the op that the passmanager will be rooted - // at (but not op itself - note the preorder) as invalid. 
- MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, - void *userData) { - callBackData *data = static_cast<callBackData *>(userData); - if (LLVM_LIKELY(data->rootSeen)) - data->rootOp.getOperation().getContext()->clearOperation(op); - else - data->rootSeen = true; - return MlirWalkResult::MlirWalkResultAdvance; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - static_cast<void *>(&data), MlirWalkPreOrder); -} -void PyMlirContext::clearOperationsInside(MlirOperation op) { - PyOperationRef opRef = PyOperation::forOperation(getRef(), op); - clearOperationsInside(opRef->getOperation()); -} - -void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { - MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, - void *userData) { - PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData); - contextRef->clearOperation(op); - return MlirWalkResult::MlirWalkResultAdvance; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - &op.getOperation().getContext(), MlirWalkPreOrder); -} - -size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } - nb::object PyMlirContext::contextEnter(nb::object context) { return PyThreadContextEntry::pushContext(context); } @@ -1154,38 +1079,23 @@ PyLocation &DefaultingPyLocation::resolve() { PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) : BaseContextObject(std::move(contextRef)), module(module) {} -PyModule::~PyModule() { - nb::gil_scoped_acquire acquire; - auto &liveModules = getContext()->liveModules; - assert(liveModules.count(module.ptr) == 1 && - "destroying module not in live map"); - liveModules.erase(module.ptr); - mlirModuleDestroy(module); -} +PyModule::~PyModule() { mlirModuleDestroy(module); } PyModuleRef PyModule::forModule(MlirModule module) { MlirContext context = mlirModuleGetContext(module); PyMlirContextRef contextRef = PyMlirContext::forContext(context); - nb::gil_scoped_acquire acquire; - auto &liveModules = contextRef->liveModules; - auto it = liveModules.find(module.ptr); - if (it == liveModules.end()) { - // Create. - PyModule *unownedModule = new PyModule(std::move(contextRef), module); - // Note that the default return value policy on cast is automatic_reference, - // which does not take ownership (delete will not be called). - // Just be explicit. - nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); - unownedModule->handle = pyRef; - liveModules[module.ptr] = - std::make_pair(unownedModule->handle, unownedModule); - return PyModuleRef(unownedModule, std::move(pyRef)); - } - // Use existing. - PyModule *existing = it->second.second; - nb::object pyRef = nb::borrow<nb::object>(it->second.first); - return PyModuleRef(existing, std::move(pyRef)); + // Create. + PyModule *unownedModule = new PyModule(std::move(contextRef), module); + // Note that the default return value policy on cast is `automatic_reference`, + // which means "does not take ownership, does not call delete/dtor". + // We use `take_ownership`, which means "Python will call the C++ destructor + // and delete operator when the Python wrapper is garbage collected", because + // MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse + // etc). 
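// A minimal sketch of the nanobind ownership policy relied on in the comment above
// (illustrative names, not the actual bindings code): returning a freshly allocated
// object with rv_policy::take_ownership makes the Python wrapper responsible for
// deleting it, which is why the PyModule destructor can unconditionally destroy the
// wrapped MlirModule.

#include <nanobind/nanobind.h>
namespace nb = nanobind;

// Illustrative stand-in for a wrapper that owns a C resource, the way PyModule owns
// its MlirModule (which in turn wraps an OwningOpRef<ModuleOp>).
struct Handle {
  ~Handle() { /* release the underlying resource here, e.g. mlirModuleDestroy */ }
};

NB_MODULE(ownership_sketch, m) {
  nb::class_<Handle>(m, "Handle");
  m.def("make_handle", [] {
    Handle *h = new Handle();
    // take_ownership: Python calls the C++ destructor and delete when the wrapper is
    // collected; automatic_reference would leave the C++ side responsible for it.
    return nb::cast(h, nb::rv_policy::take_ownership);
  });
}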
+ nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); + unownedModule->handle = pyRef; + return PyModuleRef(unownedModule, std::move(pyRef)); } nb::object PyModule::createFromCapsule(nb::object capsule) { @@ -1210,15 +1120,11 @@ PyOperation::~PyOperation() { // If the operation has already been invalidated there is nothing to do. if (!valid) return; - - // Otherwise, invalidate the operation and remove it from live map when it is - // attached. - if (isAttached()) { - getContext()->clearOperation(*this); - } else { - // And destroy it when it is detached, i.e. owned by Python, in which case - // all nested operations must be invalidated at removed from the live map as - // well. + // Otherwise, invalidate the operation when it is attached. + if (isAttached()) + setInvalid(); + else { + // And destroy it when it is detached, i.e. owned by Python. erase(); } } @@ -1255,35 +1161,15 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - nb::ft_lock_guard lock(contextRef->liveOperationsMutex); - auto &liveOperations = contextRef->liveOperations; - auto it = liveOperations.find(operation.ptr); - if (it == liveOperations.end()) { - // Create. - PyOperationRef result = createInstance(std::move(contextRef), operation, - std::move(parentKeepAlive)); - liveOperations[operation.ptr] = - std::make_pair(result.getObject(), result.get()); - return result; - } - // Use existing. - PyOperation *existing = it->second.second; - nb::object pyRef = nb::borrow<nb::object>(it->second.first); - return PyOperationRef(existing, std::move(pyRef)); + return createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); } PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - nb::ft_lock_guard lock(contextRef->liveOperationsMutex); - auto &liveOperations = contextRef->liveOperations; - assert(liveOperations.count(operation.ptr) == 0 && - "cannot create detached operation that already exists"); - (void)liveOperations; PyOperationRef created = createInstance(std::move(contextRef), operation, std::move(parentKeepAlive)); - liveOperations[operation.ptr] = - std::make_pair(created.getObject(), created.get()); created->attached = false; return created; } @@ -1523,7 +1409,7 @@ nb::object PyOperation::create(std::string_view name, llvm::ArrayRef<MlirValue> operands, std::optional<nb::dict> attributes, std::optional<std::vector<PyBlock *>> successors, - int regions, DefaultingPyLocation location, + int regions, PyLocation &location, const nb::object &maybeIp, bool inferType) { llvm::SmallVector<MlirType, 4> mlirResults; llvm::SmallVector<MlirBlock, 4> mlirSuccessors; @@ -1627,7 +1513,7 @@ nb::object PyOperation::create(std::string_view name, if (!operation.ptr) throw nb::value_error("Operation creation failed"); PyOperationRef created = - PyOperation::createDetached(location->getContext(), operation); + PyOperation::createDetached(location.getContext(), operation); maybeInsertOperation(created, maybeIp); return created.getObject(); @@ -1655,7 +1541,7 @@ nb::object PyOperation::createOpView() { void PyOperation::erase() { checkValid(); - getContext()->clearOperationAndInside(*this); + setInvalid(); mlirOperationDestroy(operation); } @@ -1937,9 +1823,9 @@ nb::object PyOpView::buildGeneric( std::optional<nb::list> resultTypeList, nb::list operandList, 
std::optional<nb::dict> attributes, std::optional<std::vector<PyBlock *>> successors, - std::optional<int> regions, DefaultingPyLocation location, + std::optional<int> regions, PyLocation &location, const nb::object &maybeIp) { - PyMlirContextRef context = location->getContext(); + PyMlirContextRef context = location.getContext(); // Class level operation construction metadata. // Operand and result segment specs are either none, which does no @@ -2108,7 +1994,7 @@ nb::object PyOpView::buildGeneric( // Delegate to create. return PyOperation::create(name, /*results=*/std::move(resultTypes), - /*operands=*/std::move(operands), + /*operands=*/operands, /*attributes=*/std::move(attributes), /*successors=*/std::move(successors), /*regions=*/*regions, location, maybeIp, @@ -2789,6 +2675,156 @@ private: PyOperationRef operation; }; +// see +// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h + +#ifndef _Py_CAST +#define _Py_CAST(type, expr) ((type)(expr)) +#endif + +// Static inline functions should use _Py_NULL rather than using directly NULL +// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer, +// _Py_NULL is defined as nullptr. +#ifndef _Py_NULL +#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \ + (defined(__cplusplus) && __cplusplus >= 201103) +#define _Py_NULL nullptr +#else +#define _Py_NULL NULL +#endif +#endif + +// Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 + +// bpo-42262 added Py_XNewRef() +#if !defined(Py_XNewRef) +[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) { + Py_XINCREF(obj); + return obj; +} +#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj)) +#endif + +// bpo-42262 added Py_NewRef() +#if !defined(Py_NewRef) +[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) { + Py_INCREF(obj); + return obj; +} +#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj)) +#endif + +#endif // Python 3.10.0a3 + +// Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) + +// bpo-40429 added PyThreadState_GetFrame() +PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) { + assert(tstate != _Py_NULL && "expected tstate != _Py_NULL"); + return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame)); +} + +// bpo-40421 added PyFrame_GetBack() +PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) { + assert(frame != _Py_NULL && "expected frame != _Py_NULL"); + return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back)); +} + +// bpo-40421 added PyFrame_GetCode() +PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) { + assert(frame != _Py_NULL && "expected frame != _Py_NULL"); + assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL"); + return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code)); +} + +#endif // Python 3.9.0b1 + +MlirLocation tracebackToLocation(MlirContext ctx) { + size_t framesLimit = + PyGlobals::get().getTracebackLoc().locTracebackFramesLimit(); + // Use a thread_local here to avoid requiring a large amount of space. + thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames> + frames; + size_t count = 0; + + nb::gil_scoped_acquire acquire; + PyThreadState *tstate = PyThreadState_GET(); + PyFrameObject *next; + PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate); + // In the increment expression: + // 1. get the next prev frame; + // 2. decrement the ref count on the current frame (in order that it can get + // gc'd, along with any objects in its closure and etc); + // 3. set current = next. 
+ for (; pyFrame != nullptr && count < framesLimit; + next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) { + PyCodeObject *code = PyFrame_GetCode(pyFrame); + auto fileNameStr = + nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename)); + llvm::StringRef fileName(fileNameStr); + if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName)) + continue; + + // co_qualname and PyCode_Addr2Location added in py3.11 +#if PY_VERSION_HEX < 0x030B00F0 + std::string name = + nb::cast<std::string>(nb::borrow<nb::str>(code->co_name)); + llvm::StringRef funcName(name); + int startLine = PyFrame_GetLineNumber(pyFrame); + MlirLocation loc = + mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0); +#else + std::string name = + nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname)); + llvm::StringRef funcName(name); + int startLine, startCol, endLine, endCol; + int lasti = PyFrame_GetLasti(pyFrame); + if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine, + &endCol)) { + throw nb::python_error(); + } + MlirLocation loc = mlirLocationFileLineColRangeGet( + ctx, wrap(fileName), startLine, startCol, endLine, endCol); +#endif + + frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc); + ++count; + } + // When the loop breaks (after the last iter), current frame (if non-null) + // is leaked without this. + Py_XDECREF(pyFrame); + + if (count == 0) + return mlirLocationUnknownGet(ctx); + + MlirLocation callee = frames[0]; + assert(!mlirLocationIsNull(callee) && "expected non-null callee location"); + if (count == 1) + return callee; + + MlirLocation caller = frames[count - 1]; + assert(!mlirLocationIsNull(caller) && "expected non-null caller location"); + for (int i = count - 2; i >= 1; i--) + caller = mlirLocationCallSiteGet(frames[i], caller); + + return mlirLocationCallSiteGet(callee, caller); +} + +PyLocation +maybeGetTracebackLocation(const std::optional<PyLocation> &location) { + if (location.has_value()) + return location.value(); + if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled()) + return DefaultingPyLocation::resolve(); + + PyMlirContext &ctx = DefaultingPyMlirContext::resolve(); + MlirLocation mlirLoc = tracebackToLocation(ctx.get()); + PyMlirContextRef ref = PyMlirContext::forContext(ctx.get()); + return {ref, mlirLoc}; +} + } // namespace //------------------------------------------------------------------------------ @@ -2876,14 +2912,6 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyMlirContextRef ref = PyMlirContext::forContext(self.get()); return ref.releaseObject(); }) - .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) - .def("_get_live_operation_objects", - &PyMlirContext::getLiveOperationObjects) - .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) - .def("_clear_live_operations_inside", - nb::overload_cast<MlirOperation>( - &PyMlirContext::clearOperationsInside)) - .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) .def("__enter__", &PyMlirContext::contextEnter) @@ -3052,10 +3080,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) .def_prop_ro_static( "current", - [](nb::object & /*class*/) { + [](nb::object & /*class*/) -> std::optional<PyLocation *> { auto *loc = PyThreadContextEntry::getDefaultLocation(); if (!loc) - throw 
nb::value_error("No current Location"); + return std::nullopt; return loc; }, "Gets the Location bound to the current thread or raises ValueError") @@ -3201,7 +3229,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable()) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule, + kModuleCAPICreate) + .def("_clear_mlir_module", &PyModule::clearMlirModule) .def_static( "parse", [](const std::string &moduleAsm, DefaultingPyMlirContext context) { @@ -3240,8 +3270,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { kModuleParseDocstring) .def_static( "create", - [](DefaultingPyLocation loc) { - MlirModule module = mlirModuleCreateEmpty(loc); + [](const std::optional<PyLocation> &loc) { + PyLocation pyLoc = maybeGetTracebackLocation(loc); + MlirModule module = mlirModuleCreateEmpty(pyLoc.get()); return PyModule::forModule(module).releaseObject(); }, nb::arg("loc").none() = nb::none(), "Creates an empty module") @@ -3280,7 +3311,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Defer to the operation's __str__. return self.attr("operation").attr("__str__")(); }, - kOperationStrDunderDocstring); + kOperationStrDunderDocstring) + .def( + "__eq__", + [](PyModule &self, PyModule &other) { + return mlirModuleEqual(self.get(), other.get()); + }, + "other"_a); //---------------------------------------------------------------------------- // Mapping of Operation. @@ -3292,7 +3329,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { }) .def("__eq__", [](PyOperationBase &self, PyOperationBase &other) { - return &self.getOperation() == &other.getOperation(); + return mlirOperationEqual(self.getOperation().get(), + other.getOperation().get()); }) .def("__eq__", [](PyOperationBase &self, nb::object other) { return false; }) @@ -3442,6 +3480,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { return operation.createOpView(); }, "Detaches the operation from its parent block.") + .def_prop_ro( + "attached", + [](PyOperationBase &self) { + PyOperation &operation = self.getOperation(); + operation.checkValid(); + return operation.isAttached(); + }, + "Reports if the operation is attached to its parent block.") .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) .def("walk", &PyOperationBase::walk, nb::arg("callback"), nb::arg("walk_order") = MlirWalkPostOrder); @@ -3454,8 +3500,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { std::optional<std::vector<PyValue *>> operands, std::optional<nb::dict> attributes, std::optional<std::vector<PyBlock *>> successors, int regions, - DefaultingPyLocation location, const nb::object &maybeIp, - bool inferType) { + const std::optional<PyLocation> &location, + const nb::object &maybeIp, bool inferType) { // Unpack/validate operands. 
llvm::SmallVector<MlirValue, 4> mlirOperands; if (operands) { @@ -3467,8 +3513,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { } } + PyLocation pyLoc = maybeGetTracebackLocation(location); return PyOperation::create(name, results, mlirOperands, attributes, - successors, regions, location, maybeIp, + successors, regions, pyLoc, maybeIp, inferType); }, nb::arg("name"), nb::arg("results").none() = nb::none(), @@ -3498,7 +3545,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); }, - "Returns the list of Operation successors."); + "Returns the list of Operation successors.") + .def("_set_invalid", &PyOperation::setInvalid, + "Invalidate the operation."); auto opViewClass = nb::class_<PyOpView, PyOperationBase>(m, "OpView") @@ -3512,12 +3561,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { std::optional<nb::list> resultTypeList, nb::list operandList, std::optional<nb::dict> attributes, std::optional<std::vector<PyBlock *>> successors, - std::optional<int> regions, DefaultingPyLocation location, + std::optional<int> regions, + const std::optional<PyLocation> &location, const nb::object &maybeIp) { + PyLocation pyLoc = maybeGetTracebackLocation(location); new (self) PyOpView(PyOpView::buildGeneric( name, opRegionSpec, operandSegmentSpecObj, resultSegmentSpecObj, resultTypeList, operandList, - attributes, successors, regions, location, maybeIp)); + attributes, successors, regions, pyLoc, maybeIp)); }, nb::arg("name"), nb::arg("opRegionSpec"), nb::arg("operandSegmentSpecObj").none() = nb::none(), @@ -3540,7 +3591,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); }, - "Returns the list of Operation successors."); + "Returns the list of Operation successors.") + .def( + "_set_invalid", + [](PyOpView &self) { self.getOperation().setInvalid(); }, + "Invalidate the operation."); opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); @@ -3551,17 +3606,18 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](nb::handle cls, std::optional<nb::list> resultTypeList, nb::list operandList, std::optional<nb::dict> attributes, std::optional<std::vector<PyBlock *>> successors, - std::optional<int> regions, DefaultingPyLocation location, + std::optional<int> regions, std::optional<PyLocation> location, const nb::object &maybeIp) { std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME")); std::tuple<int, bool> opRegionSpec = nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS")); nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS"); nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS"); + PyLocation pyLoc = maybeGetTracebackLocation(location); return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec, resultSegmentSpec, resultTypeList, operandList, attributes, successors, - regions, location, maybeIp); + regions, pyLoc, maybeIp); }, nb::arg("cls"), nb::arg("results").none() = nb::none(), nb::arg("operands").none() = nb::none(), diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index e600f1b..0de2f17 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -13,9 +13,9 @@ #include "Globals.h" #include "NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" #include 
"mlir-c/Support.h" #include "mlir/Bindings/Python/Nanobind.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. namespace nb = nanobind; using namespace mlir; @@ -197,3 +197,71 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) { // Not found and loading did not yield a registration. return std::nullopt; } + +bool PyGlobals::TracebackLoc::locTracebacksEnabled() { + nanobind::ft_lock_guard lock(mutex); + return locTracebackEnabled_; +} + +void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) { + nanobind::ft_lock_guard lock(mutex); + locTracebackEnabled_ = value; +} + +size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() { + nanobind::ft_lock_guard lock(mutex); + return locTracebackFramesLimit_; +} + +void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) { + nanobind::ft_lock_guard lock(mutex); + locTracebackFramesLimit_ = std::min(value, kMaxFrames); +} + +void PyGlobals::TracebackLoc::registerTracebackFileInclusion( + const std::string &file) { + nanobind::ft_lock_guard lock(mutex); + auto reg = "^" + llvm::Regex::escape(file); + if (userTracebackIncludeFiles.insert(reg).second) + rebuildUserTracebackIncludeRegex = true; + if (userTracebackExcludeFiles.count(reg)) { + if (userTracebackExcludeFiles.erase(reg)) + rebuildUserTracebackExcludeRegex = true; + } +} + +void PyGlobals::TracebackLoc::registerTracebackFileExclusion( + const std::string &file) { + nanobind::ft_lock_guard lock(mutex); + auto reg = "^" + llvm::Regex::escape(file); + if (userTracebackExcludeFiles.insert(reg).second) + rebuildUserTracebackExcludeRegex = true; + if (userTracebackIncludeFiles.count(reg)) { + if (userTracebackIncludeFiles.erase(reg)) + rebuildUserTracebackIncludeRegex = true; + } +} + +bool PyGlobals::TracebackLoc::isUserTracebackFilename( + const llvm::StringRef file) { + nanobind::ft_lock_guard lock(mutex); + if (rebuildUserTracebackIncludeRegex) { + userTracebackIncludeRegex.assign( + llvm::join(userTracebackIncludeFiles, "|")); + rebuildUserTracebackIncludeRegex = false; + isUserTracebackFilenameCache.clear(); + } + if (rebuildUserTracebackExcludeRegex) { + userTracebackExcludeRegex.assign( + llvm::join(userTracebackExcludeFiles, "|")); + rebuildUserTracebackExcludeRegex = false; + isUserTracebackFilenameCache.clear(); + } + if (!isUserTracebackFilenameCache.contains(file)) { + std::string fileStr = file.str(); + bool include = std::regex_search(fileStr, userTracebackIncludeRegex); + bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex); + isUserTracebackFilenameCache[file] = include || !exclude; + } + return isUserTracebackFilenameCache[file]; +} diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 9c22dea..0cc0459 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -192,16 +192,6 @@ public: PyMlirContext(const PyMlirContext &) = delete; PyMlirContext(PyMlirContext &&) = delete; - /// For the case of a python __init__ (nanobind::init) method, pybind11 is - /// quite strict about needing to return a pointer that is not yet associated - /// to an nanobind::object. Since the forContext() method acts like a pool, - /// possibly returning a recycled context, it does not satisfy this need. The - /// usual way in python to accomplish such a thing is to override __new__, but - /// that is also not supported by pybind11. 
Instead, we use this entry - /// point which always constructs a fresh context (which cannot alias an - /// existing one because it is fresh). - static PyMlirContext *createNewContextForInit(); - /// Returns a context reference for the singleton PyMlirContext wrapper for /// the given context. static PyMlirContextRef forContext(MlirContext context); @@ -228,40 +218,6 @@ public: /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); - /// Get a list of Python objects which are still in the live context map. - std::vector<PyOperation *> getLiveOperationObjects(); - - /// Gets the count of live operations associated with this context. - /// Used for testing. - size_t getLiveOperationCount(); - - /// Clears the live operations map, returning the number of entries which were - /// invalidated. To be used as a safety mechanism so that API end-users can't - /// corrupt by holding references they shouldn't have accessed in the first - /// place. - size_t clearLiveOperations(); - - /// Removes an operation from the live operations map and sets it invalid. - /// This is useful for when some non-bindings code destroys the operation and - /// the bindings need to made aware. For example, in the case when pass - /// manager is run. - /// - /// Note that this does *NOT* clear the nested operations. - void clearOperation(MlirOperation op); - - /// Clears all operations nested inside the given op using - /// `clearOperation(MlirOperation)`. - void clearOperationsInside(PyOperationBase &op); - void clearOperationsInside(MlirOperation op); - - /// Clears the operaiton _and_ all operations inside using - /// `clearOperation(MlirOperation)`. - void clearOperationAndInside(PyOperationBase &op); - - /// Gets the count of live modules associated with this context. - /// Used for testing. - size_t getLiveModuleCount(); - /// Enter and exit the context manager. static nanobind::object contextEnter(nanobind::object context); void contextExit(const nanobind::object &excType, @@ -288,25 +244,6 @@ private: static nanobind::ft_mutex live_contexts_mutex; static LiveContextMap &getLiveContexts(); - // Interns all live modules associated with this context. Modules tracked - // in this map are valid. When a module is invalidated, it is removed - // from this map, and while it still exists as an instance, any - // attempt to access it will raise an error. - using LiveModuleMap = - llvm::DenseMap<const void *, std::pair<nanobind::handle, PyModule *>>; - LiveModuleMap liveModules; - - // Interns all live operations associated with this context. Operations - // tracked in this map are valid. When an operation is invalidated, it is - // removed from this map, and while it still exists as an instance, any - // attempt to access it will raise an error. - using LiveOperationMap = - llvm::DenseMap<void *, std::pair<nanobind::handle, PyOperation *>>; - nanobind::ft_mutex liveOperationsMutex; - - // Guarded by liveOperationsMutex in free-threading mode. - LiveOperationMap liveOperations; - bool emitErrorDiagnostics = false; MlirContext context; @@ -558,8 +495,8 @@ class PyModule; using PyModuleRef = PyObjectRef<PyModule>; class PyModule : public BaseContextObject { public: - /// Returns a PyModule reference for the given MlirModule. This may return - /// a pre-existing or new object. + /// Returns a PyModule reference for the given MlirModule. This always returns + /// a new object. 
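// With forModule always creating a fresh wrapper, wrapper identity no longer implies
// module identity; the new mlirModuleEqual (and the Module.__eq__ binding) compares the
// wrapped module instead. A small C-API sketch of what that equality means (the helper
// name is made up; the C API calls are the existing mlir-c/IR.h ones):

#include "mlir-c/IR.h"

// Two MlirModule handles obtained by different routes still compare equal if they refer
// to the same module operation.
void moduleEqualitySketch(MlirContext ctx) {
  MlirLocation loc = mlirLocationUnknownGet(ctx);
  MlirModule a = mlirModuleCreateEmpty(loc);
  MlirModule b = mlirModuleFromOperation(mlirModuleGetOperation(a));
  bool same = mlirModuleEqual(a, b); // true: same underlying ModuleOp
  (void)same;
  mlirModuleDestroy(a);              // destroy once; b is just another handle to it
}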
static PyModuleRef forModule(MlirModule module); PyModule(PyModule &) = delete; PyModule(PyMlirContext &&) = delete; @@ -580,11 +517,12 @@ public: nanobind::object getCapsule(); /// Creates a PyModule from the MlirModule wrapped by a capsule. - /// Note that PyModule instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirModule - /// is taken by calling this function. + /// Note this returns a new object BUT clearMlirModule() must be called to + /// prevent double-frees (of the underlying mlir::Module). static nanobind::object createFromCapsule(nanobind::object capsule); + void clearMlirModule() { module = {nullptr}; } + private: PyModule(PyMlirContextRef contextRef, MlirModule module); MlirModule module; @@ -722,8 +660,7 @@ public: llvm::ArrayRef<MlirValue> operands, std::optional<nanobind::dict> attributes, std::optional<std::vector<PyBlock *>> successors, int regions, - DefaultingPyLocation location, const nanobind::object &ip, - bool inferType); + PyLocation &location, const nanobind::object &ip, bool inferType); /// Creates an OpView suitable for this operation. nanobind::object createOpView(); @@ -781,7 +718,7 @@ public: nanobind::list operandList, std::optional<nanobind::dict> attributes, std::optional<std::vector<PyBlock *>> successors, - std::optional<int> regions, DefaultingPyLocation location, + std::optional<int> regions, PyLocation &location, const nanobind::object &maybeIp); /// Construct an instance of a class deriving from OpView, bypassing its @@ -1227,7 +1164,7 @@ public: /// Note that PyAffineExpr instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr /// is taken by calling this function. - static PyAffineExpr createFromCapsule(nanobind::object capsule); + static PyAffineExpr createFromCapsule(const nanobind::object &capsule); PyAffineExpr add(const PyAffineExpr &other) const; PyAffineExpr mul(const PyAffineExpr &other) const; @@ -1254,7 +1191,7 @@ public: /// Note that PyAffineMap instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineMap /// is taken by calling this function. - static PyAffineMap createFromCapsule(nanobind::object capsule); + static PyAffineMap createFromCapsule(const nanobind::object &capsule); private: MlirAffineMap affineMap; @@ -1274,7 +1211,7 @@ public: /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. /// Note that PyIntegerSet instances may be uniqued, so the returned object /// may be a pre-existing object. Integer sets are owned by the context. 
- static PyIntegerSet createFromCapsule(nanobind::object capsule); + static PyIntegerSet createFromCapsule(const nanobind::object &capsule); private: MlirIntegerSet integerSet; diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index b11e3f7..a9b1259 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -963,7 +963,7 @@ public: static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::string dialectNamespace, std::string typeData, + [](const std::string &dialectNamespace, const std::string &typeData, DefaultingPyMlirContext context) { MlirType type = mlirOpaqueTypeGet(context->get(), toMlirStringRef(dialectNamespace), diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 6f49431..278847e 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// - #include "Globals.h" #include "IRModule.h" #include "NanobindUtils.h" @@ -44,7 +43,27 @@ NB_MODULE(_mlir, m) { .def("_register_operation_impl", &PyGlobals::registerOperationImpl, "operation_name"_a, "operation_class"_a, nb::kw_only(), "replace"_a = false, - "Testing hook for directly registering an operation"); + "Testing hook for directly registering an operation") + .def("loc_tracebacks_enabled", + [](PyGlobals &self) { + return self.getTracebackLoc().locTracebacksEnabled(); + }) + .def("set_loc_tracebacks_enabled", + [](PyGlobals &self, bool enabled) { + self.getTracebackLoc().setLocTracebacksEnabled(enabled); + }) + .def("set_loc_tracebacks_frame_limit", + [](PyGlobals &self, int n) { + self.getTracebackLoc().setLocTracebackFramesLimit(n); + }) + .def("register_traceback_file_inclusion", + [](PyGlobals &self, const std::string &filename) { + self.getTracebackLoc().registerTracebackFileInclusion(filename); + }) + .def("register_traceback_file_exclusion", + [](PyGlobals &self, const std::string &filename) { + self.getTracebackLoc().registerTracebackFileExclusion(filename); + }); // Aside from making the globals accessible to python, having python manage // it is necessary to make sure it is destroyed (and releases its python diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 20017e2..88e28dc 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -39,7 +39,7 @@ public: return nb::steal<nb::object>(mlirPythonPassManagerToCapsule(get())); } - static nb::object createFromCapsule(nb::object capsule) { + static nb::object createFromCapsule(const nb::object &capsule) { MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); if (mlirPassManagerIsNull(rawPm)) throw nb::python_error(); @@ -159,11 +159,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "ValueError if the pipeline can't be parsed.") .def( "run", - [](PyPassManager &passManager, PyOperationBase &op, - bool invalidateOps) { - if (invalidateOps) { - op.getOperation().getContext()->clearOperationsInside(op); - } + [](PyPassManager &passManager, PyOperationBase &op) { // Actually run the pass manager. 
PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirPassManagerRunOnOp( @@ -172,7 +168,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { throw MLIRError("Failure while executing pass pipeline", errors.take()); }, - "operation"_a, "invalidate_ops"_a = true, + "operation"_a, "Run the pass manager on the provided operation, raising an " "MLIRError on failure.") .def( diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp index 3ba42be..3edcb09 100644 --- a/mlir/lib/Bindings/Python/RegisterEverything.cpp +++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir-c/RegisterEverything.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" NB_MODULE(_mlirRegisterEverything, m) { m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration"; diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index f9b0fed..920bca8 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) { // root. This is awkward, but we don't have access to PyMlirContext // object here otherwise. nb::object obj = nb::cast(payloadRoot); - obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); MlirLogicalResult result = mlirTransformApplyNamedSequence( payloadRoot, transformRoot, transformModule, options.options); diff --git a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp index 6ebeac5..eacb936 100644 --- a/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp +++ b/mlir/lib/Bytecode/Writer/BytecodeWriter.cpp @@ -19,6 +19,7 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/Endian.h" #include "llvm/Support/raw_ostream.h" #include <optional> @@ -150,8 +151,7 @@ public: /// Backpatch a byte in the result buffer at the given offset. void patchByte(uint64_t offset, uint8_t value, StringLiteral desc) { - LLVM_DEBUG(llvm::dbgs() << "patchByte(" << offset << ',' << uint64_t(value) - << ")\t" << desc << '\n'); + LDBG() << "patchByte(" << offset << ',' << uint64_t(value) << ")\t" << desc; assert(offset < size() && offset >= prevResultSize && "cannot patch previously emitted data"); currentResult[offset - prevResultSize] = value; @@ -160,8 +160,7 @@ public: /// Emit the provided blob of data, which is owned by the caller and is /// guaranteed to not die before the end of the bytecode process. void emitOwnedBlob(ArrayRef<uint8_t> data, StringLiteral desc) { - LLVM_DEBUG(llvm::dbgs() - << "emitOwnedBlob(" << data.size() << "b)\t" << desc << '\n'); + LDBG() << "emitOwnedBlob(" << data.size() << "b)\t" << desc; // Push the current buffer before adding the provided data. appendResult(std::move(currentResult)); appendOwnedResult(data); @@ -209,15 +208,13 @@ public: /// Emit a single byte. 
template <typename T> void emitByte(T byte, StringLiteral desc) { - LLVM_DEBUG(llvm::dbgs() - << "emitByte(" << uint64_t(byte) << ")\t" << desc << '\n'); + LDBG() << "emitByte(" << uint64_t(byte) << ")\t" << desc; currentResult.push_back(static_cast<uint8_t>(byte)); } /// Emit a range of bytes. void emitBytes(ArrayRef<uint8_t> bytes, StringLiteral desc) { - LLVM_DEBUG(llvm::dbgs() - << "emitBytes(" << bytes.size() << "b)\t" << desc << '\n'); + LDBG() << "emitBytes(" << bytes.size() << "b)\t" << desc; llvm::append_range(currentResult, bytes); } @@ -229,7 +226,7 @@ public: /// additional bytes, provide the value of the integer encoded in /// little-endian order. void emitVarInt(uint64_t value, StringLiteral desc) { - LLVM_DEBUG(llvm::dbgs() << "emitVarInt(" << value << ")\t" << desc << '\n'); + LDBG() << "emitVarInt(" << value << ")\t" << desc; // In the most common case, the value can be represented in a single byte. // Given how hot this case is, explicitly handle that here. diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 306cebd..2dbb993 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, return wrap(jitOrError->release()); } +extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) { + unwrap(jit)->initialize(); +} + extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) { delete (unwrap(jit)); } @@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, void *sym) { unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) { llvm::orc::SymbolMap symbolMap; - symbolMap[interner(unwrap(name))] = - { llvm::orc::ExecutorAddr::fromPtr(sym), - llvm::JITSymbolFlags::Exported }; + symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym), + llvm::JITSymbolFlags::Exported}; return symbolMap; }); } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 9d8554a..f5f4ed3 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -465,10 +465,6 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType))); } -MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) { - return wrap(llvm::cast<UnrankedTensorType>(unwrap(type)).getElementType()); -} - //===----------------------------------------------------------------------===// // Ranked / Unranked MemRef type. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 8491553..c7069f0 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -465,6 +465,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) { return wrap(dyn_cast<ModuleOp>(unwrap(op))); } +bool mlirModuleEqual(MlirModule lhs, MlirModule rhs) { + return unwrap(lhs) == unwrap(rhs); +} + //===----------------------------------------------------------------------===// // Operation state API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt index 191b5ab6..91ed05f 100644 --- a/mlir/lib/CMakeLists.txt +++ b/mlir/lib/CMakeLists.txt @@ -13,6 +13,7 @@ add_subdirectory(Parser) add_subdirectory(Pass) add_subdirectory(Query) add_subdirectory(Reducer) +add_subdirectory(Remark) add_subdirectory(Rewrite) add_subdirectory(Support) add_subdirectory(TableGen) diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 64720bf..203790e 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/TypeUtilities.h" @@ -1876,6 +1877,54 @@ struct AMDGPUSwizzleBitModeLowering } }; +struct AMDGPUPermlaneLowering : public ConvertOpToLLVMPattern<PermlaneSwapOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + AMDGPUPermlaneLowering(const LLVMTypeConverter &converter, Chipset chipset) + : ConvertOpToLLVMPattern<PermlaneSwapOp>(converter), chipset(chipset) {} + Chipset chipset; + + LogicalResult + matchAndRewrite(PermlaneSwapOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (chipset < kGfx950) + return op->emitOpError("permlane_swap is only supported on gfx950+"); + + Location loc = op.getLoc(); + Type i32 = rewriter.getI32Type(); + Value src = adaptor.getSrc(); + unsigned rowLength = op.getRowLength(); + bool fi = op.getFetchInactive(); + bool boundctrl = op.getBoundCtrl(); + + SmallVector<Value> decomposed = + LLVM::decomposeValue(rewriter, loc, src, i32); + + SmallVector<Value> permuted; + for (Value v : decomposed) { + Value res; + Type i32pair = LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), {v.getType(), v.getType()}); + + if (rowLength == 16) + res = ROCDL::Permlane16SwapOp::create(rewriter, loc, i32pair, v, v, fi, + boundctrl); + else if (rowLength == 32) + res = ROCDL::Permlane32SwapOp::create(rewriter, loc, i32pair, v, v, fi, + boundctrl); + else + llvm_unreachable("unsupported row length"); + + Value vdstNew = LLVM::ExtractValueOp::create(rewriter, loc, res, {0}); + permuted.emplace_back(vdstNew); + } + + Value result = LLVM::composeValue(rewriter, loc, permuted, src.getType()); + rewriter.replaceOp(op, result); + return success(); + } +}; + struct ConvertAMDGPUToROCDLPass : public impl::ConvertAMDGPUToROCDLPassBase<ConvertAMDGPUToROCDLPass> { using Base::Base; @@ -1944,6 +1993,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter, WMMAOpLowering, ExtPackedFp8OpLowering, ScaledExtPackedOpLowering, PackedScaledTruncOpLowering, PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering, GatherToLDSOpLowering, - TransposeLoadOpLowering>(converter, chipset); + TransposeLoadOpLowering, AMDGPUPermlaneLowering>(converter, chipset); patterns.add<AMDGPUSwizzleBitModeLowering>(converter); } diff --git a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp index 515fe5c..b68933d 100644 --- a/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp +++ b/mlir/lib/Conversion/ArithToEmitC/ArithToEmitC.cpp @@ -610,16 +610,19 @@ public: ? 
rewriter.getIntegerAttr(arithmeticType, 0) : rewriter.getIndexAttr(0))); - emitc::ExpressionOp ternary = emitc::ExpressionOp::create( - rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false); - Block &bodyBlock = ternary.getBodyRegion().emplaceBlock(); + emitc::ExpressionOp ternary = + emitc::ExpressionOp::create(rewriter, op.getLoc(), arithmeticType, + ValueRange({lhs, rhs, excessCheck, poison}), + /*do_not_inline=*/false); + Block &bodyBlock = ternary.createBody(); auto currentPoint = rewriter.getInsertionPoint(); rewriter.setInsertionPointToStart(&bodyBlock); Value arithmeticResult = - EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs); - Value resultOrPoison = - emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType, - excessCheck, arithmeticResult, poison); + EmitCOp::create(rewriter, op.getLoc(), arithmeticType, + bodyBlock.getArgument(0), bodyBlock.getArgument(1)); + Value resultOrPoison = emitc::ConditionalOp::create( + rewriter, op.getLoc(), arithmeticType, bodyBlock.getArgument(2), + arithmeticResult, bodyBlock.getArgument(3)); emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison); rewriter.setInsertionPoint(op->getBlock(), currentPoint); diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 18e857c..cb0c829 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> { ConversionPatternRewriter &rewriter) const override; }; +struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + using Adaptor = + typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor; + + LogicalResult + matchAndRewrite(arith::SelectOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + } // namespace //===----------------------------------------------------------------------===// @@ -480,6 +490,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor, } //===----------------------------------------------------------------------===// +// SelectOpOneToNLowering +//===----------------------------------------------------------------------===// + +/// Pattern for arith.select where the true/false values lower to multiple +/// SSA values (1:N conversion). This pattern generates multiple arith.select +/// than can be lowered by the 1:1 arith.select pattern. +LogicalResult SelectOpOneToNLowering::matchAndRewrite( + arith::SelectOp op, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // In case of a 1:1 conversion, the 1:1 pattern will match. 
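// Context for the 1:N case handled here (a hedged sketch, not code from this patch):
// a type converter may map one source value to several converted values, e.g. a rule
// that pretends complex<f64> converts to two f64 values. With such a rule installed,
// OneToNOpAdaptor hands each original arith.select operand to the pattern as a
// ValueRange of size 2, and the pattern emits one arith.select per piece. The helper
// name and the conversion rule itself are purely illustrative.

#include <optional>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

void addIllustrativeOneToNRule(TypeConverter &converter) {
  converter.addConversion(
      [](ComplexType type,
         SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
        if (!type.getElementType().isF64())
          return std::nullopt; // not handled by this rule
        // One source value becomes two converted values.
        results.push_back(type.getElementType());
        results.push_back(type.getElementType());
        return success();
      });
}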
+ if (llvm::hasSingleElement(adaptor.getTrueValue())) + return rewriter.notifyMatchFailure( + op, "not a 1:N conversion, 1:1 pattern will match"); + if (!op.getCondition().getType().isInteger(1)) + return rewriter.notifyMatchFailure(op, + "non-i1 conditions are not supported"); + SmallVector<Value> results; + for (auto [trueValue, falseValue] : + llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue())) + results.push_back(arith::SelectOp::create( + rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue)); + rewriter.replaceOpWithMultiple(op, {results}); + return success(); +} + +//===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns( RemSIOpLowering, RemUIOpLowering, SelectOpLowering, + SelectOpOneToNLowering, ShLIOpLowering, ShRSIOpLowering, ShRUIOpLowering, diff --git a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp index e28d5122..c69ede9 100644 --- a/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp +++ b/mlir/lib/Conversion/ArmSMEToSCF/ArmSMEToSCF.cpp @@ -333,7 +333,7 @@ struct TileLoadOpWithMaskAndPadNonZeroConversion auto loadSlice = vector::MaskedLoadOp::create(rewriter, loc, tileSliceType, tileLoadOp.getBase(), memrefIndices, maskOp1D, - /*passthru=*/pad1DOp); + /*passthrough=*/pad1DOp); // Create 'arm_sme.insert_tile_slice' to insert slice into tile. auto insertSlice = arm_sme::InsertTileSliceOp::create( diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt index 785cb82..71986f8 100644 --- a/mlir/lib/Conversion/CMakeLists.txt +++ b/mlir/lib/Conversion/CMakeLists.txt @@ -50,6 +50,7 @@ add_subdirectory(NVVMToLLVM) add_subdirectory(OpenACCToSCF) add_subdirectory(OpenMPToLLVM) add_subdirectory(PDLToPDLInterp) +add_subdirectory(PtrToLLVM) add_subdirectory(ReconcileUnrealizedCasts) add_subdirectory(SCFToControlFlow) add_subdirectory(SCFToEmitC) @@ -68,6 +69,7 @@ add_subdirectory(TosaToSCF) add_subdirectory(TosaToTensor) add_subdirectory(UBToLLVM) add_subdirectory(UBToSPIRV) +add_subdirectory(VectorToAMX) add_subdirectory(VectorToArmSME) add_subdirectory(VectorToGPU) add_subdirectory(VectorToLLVM) @@ -75,3 +77,4 @@ add_subdirectory(VectorToSCF) add_subdirectory(VectorToSPIRV) add_subdirectory(VectorToXeGPU) add_subdirectory(XeVMToLLVM) +add_subdirectory(XeGPUToXeVM) diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp index 35ad99c..7a3a7fd 100644 --- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp +++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp @@ -56,22 +56,30 @@ struct ComplexOpToROCDLLibraryCalls : public OpRewritePattern<Op> { private: std::string funcName; }; + +// Rewrite complex.pow(z, w) -> complex.exp(w * complex.log(z)) +struct PowOpToROCDLLibraryCalls : public OpRewritePattern<complex::PowOp> { + using OpRewritePattern<complex::PowOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(complex::PowOp op, + PatternRewriter &rewriter) const final { + Location loc = op.getLoc(); + Value logBase = rewriter.create<complex::LogOp>(loc, op.getLhs()); + Value mul = rewriter.create<complex::MulOp>(loc, op.getRhs(), logBase); + Value exp = rewriter.create<complex::ExpOp>(loc, mul); + rewriter.replaceOp(op, 
exp); + return success(); + } +}; } // namespace void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( RewritePatternSet &patterns) { + patterns.add<PowOpToROCDLLibraryCalls>(patterns.getContext()); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float32Type>>( patterns.getContext(), "__ocml_cabs_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>( patterns.getContext(), "__ocml_cabs_f64"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>( - patterns.getContext(), "__ocml_carg_f32"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>( - patterns.getContext(), "__ocml_carg_f64"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>( - patterns.getContext(), "__ocml_conj_f32"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>( - patterns.getContext(), "__ocml_conj_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>( patterns.getContext(), "__ocml_ccos_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>( @@ -84,10 +92,6 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns( patterns.getContext(), "__ocml_clog_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>( patterns.getContext(), "__ocml_clog_f64"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>( - patterns.getContext(), "__ocml_cpow_f32"); - patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>( - patterns.getContext(), "__ocml_cpow_f64"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>( patterns.getContext(), "__ocml_csin_f32"); patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>( @@ -122,10 +126,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect<func::FuncDialect>(); - target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp, - complex::CosOp, complex::ExpOp, complex::LogOp, - complex::PowOp, complex::SinOp, complex::SqrtOp, - complex::TanOp, complex::TanhOp>(); + target.addLegalOp<complex::MulOp>(); + target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp, + complex::LogOp, complex::PowOp, complex::SinOp, + complex::SqrtOp, complex::TanOp, complex::TanhOp>(); if (failed(applyPartialConversion(op, target, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp index ff6d369..798d8b0 100644 --- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp +++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp @@ -125,22 +125,33 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter, return rewriter.applySignatureConversion(block, *conversion, converter); } +/// Flatten the given value ranges into a single vector of values. +static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) { + SmallVector<Value> result; + for (const ValueRange &vals : values) + llvm::append_range(result, vals); + return result; +} + /// Convert the destination block signature (if necessary) and lower the branch /// op to llvm.br. 
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern; + using Adaptor = + typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor; LogicalResult - matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor, + matchAndRewrite(cf::BranchOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands()); FailureOr<Block *> convertedBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(), - TypeRange(adaptor.getOperands())); + TypeRange(ValueRange(flattenedAdaptor))); if (failed(convertedBlock)) return failure(); DictionaryAttr attrs = op->getAttrDictionary(); Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>( - op, adaptor.getOperands(), *convertedBlock); + op, flattenedAdaptor, *convertedBlock); // TODO: We should not just forward all attributes like that. But there are // existing Flang tests that depend on this behavior. newOp->setAttrs(attrs); @@ -152,29 +163,37 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> { /// branch op to llvm.cond_br. struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> { using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern; + using Adaptor = + typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor; LogicalResult - matchAndRewrite(cf::CondBranchOp op, - typename cf::CondBranchOp::Adaptor adaptor, + matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector<Value> flattenedAdaptorTrue = + flattenValues(adaptor.getTrueDestOperands()); + SmallVector<Value> flattenedAdaptorFalse = + flattenValues(adaptor.getFalseDestOperands()); + if (!llvm::hasSingleElement(adaptor.getCondition())) + return rewriter.notifyMatchFailure(op, + "expected single element condition"); FailureOr<Block *> convertedTrueBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(), - TypeRange(adaptor.getTrueDestOperands())); + TypeRange(ValueRange(flattenedAdaptorTrue))); if (failed(convertedTrueBlock)) return failure(); FailureOr<Block *> convertedFalseBlock = getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(), - TypeRange(adaptor.getFalseDestOperands())); + TypeRange(ValueRange(flattenedAdaptorFalse))); if (failed(convertedFalseBlock)) return failure(); - DictionaryAttr attrs = op->getAttrDictionary(); + DictionaryAttr attrs = op->getDiscardableAttrDictionary(); auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>( - op, adaptor.getCondition(), adaptor.getTrueDestOperands(), - adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(), + op, llvm::getSingleElement(adaptor.getCondition()), + flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(), *convertedTrueBlock, *convertedFalseBlock); // TODO: We should not just forward all attributes like that. But there are // existing Flang tests that depend on this behavior. 
- newOp->setAttrs(attrs); + newOp->setDiscardableAttrs(attrs); return success(); } }; diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp index ed5d6d4..764ad2e 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/Support/DebugLog.h" #include <memory> #define DEBUG_TYPE "convert-to-llvm" @@ -31,7 +32,8 @@ namespace { class ConvertToLLVMPassInterface { public: ConvertToLLVMPassInterface(MLIRContext *context, - ArrayRef<std::string> filterDialects); + ArrayRef<std::string> filterDialects, + bool allowPatternRollback = true); virtual ~ConvertToLLVMPassInterface() = default; /// Get the dependent dialects used by `convert-to-llvm`. @@ -60,6 +62,9 @@ protected: MLIRContext *context; /// List of dialects names to use as filters. ArrayRef<std::string> filterDialects; + /// An experimental flag to disallow pattern rollback. This is more efficient + /// but not supported by all lowering patterns. + bool allowPatternRollback; }; /// This DialectExtension can be attached to the context, which will invoke the @@ -75,13 +80,13 @@ public: void apply(MLIRContext *context, MutableArrayRef<Dialect *> dialects) const final { - LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM extension load\n"); + LDBG() << "Convert to LLVM extension load"; for (Dialect *dialect : dialects) { auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); if (!iface) continue; - LLVM_DEBUG(llvm::dbgs() << "Convert to LLVM found dialect interface for " - << dialect->getNamespace() << "\n"); + LDBG() << "Convert to LLVM found dialect interface for " + << dialect->getNamespace(); iface->loadDependentDialects(context); } } @@ -128,7 +133,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface { /// Apply the conversion driver. LogicalResult transform(Operation *op, AnalysisManager manager) const final { - if (failed(applyPartialConversion(op, *target, *patterns))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, *target, *patterns, config))) return failure(); return success(); } @@ -179,7 +186,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface { patterns); // Apply the conversion. - if (failed(applyPartialConversion(op, target, std::move(patterns)))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, target, std::move(patterns), config))) return failure(); return success(); } @@ -206,9 +215,11 @@ public: std::shared_ptr<ConvertToLLVMPassInterface> impl; // Choose the pass implementation. 
if (useDynamic) - impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects); + impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects, + allowPatternRollback); else - impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects); + impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects, + allowPatternRollback); if (failed(impl->initialize())) return failure(); this->impl = impl; @@ -228,8 +239,10 @@ public: //===----------------------------------------------------------------------===// ConvertToLLVMPassInterface::ConvertToLLVMPassInterface( - MLIRContext *context, ArrayRef<std::string> filterDialects) - : context(context), filterDialects(filterDialects) {} + MLIRContext *context, ArrayRef<std::string> filterDialects, + bool allowPatternRollback) + : context(context), filterDialects(filterDialects), + allowPatternRollback(allowPatternRollback) {} void ConvertToLLVMPassInterface::getDependentDialects( DialectRegistry &registry) { diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 67bb1c1..42c76ed 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -527,19 +527,21 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern; using Super = CallOpInterfaceLowering<CallOpType>; using Base = ConvertOpToLLVMPattern<CallOpType>; + using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor; - LogicalResult matchAndRewriteImpl(CallOpType callOp, - typename CallOpType::Adaptor adaptor, + LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor, ConversionPatternRewriter &rewriter, bool useBarePtrCallConv = false) const { // Pack the result types into a struct. Type packedResult = nullptr; + SmallVector<SmallVector<Type>> groupedResultTypes; unsigned numResults = callOp.getNumResults(); auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); - + int64_t numConvertedTypes = 0; if (numResults != 0) { if (!(packedResult = this->getTypeConverter()->packFunctionResults( - resultTypes, useBarePtrCallConv))) + resultTypes, useBarePtrCallConv, &groupedResultTypes, + &numConvertedTypes))) return failure(); } @@ -565,34 +567,64 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> { static_cast<int32_t>(promoted.size()), 0}; newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({}); - SmallVector<Value, 4> results; - if (numResults < 2) { - // If < 2 results, packing did not do anything and we can just return. - results.append(newOp.result_begin(), newOp.result_end()); - } else { - // Otherwise, it had been converted to an operation producing a structure. - // Extract individual results from the structure and return them as list. - results.reserve(numResults); - for (unsigned i = 0; i < numResults; ++i) { - results.push_back(LLVM::ExtractValueOp::create( - rewriter, callOp.getLoc(), newOp->getResult(0), i)); + // Helper function that extracts an individual result from the return value + // of the new call op. llvm.call ops support only 0 or 1 result. In case of + // 2 or more results, the results are packed into a structure. + // + // The new call op may have more than 2 results because: + // a. The original call op has more than 2 results. + // b. An original op result type-converted to more than 1 result.
+ auto getUnpackedResult = [&](unsigned i) -> Value { + assert(numConvertedTypes > 0 && "convert op has no results"); + if (numConvertedTypes == 1) { + assert(i == 0 && "out of bounds: converted op has only one result"); + return newOp->getResult(0); } + // Results have been converted to a structure. Extract individual results + // from the structure. + return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(), + newOp->getResult(0), i); + }; + + // Group the results into a vector of vectors, such that it is clear which + // original op result is replaced with which range of values. (In case of a + // 1:N conversion, there can be multiple replacements for a single result.) + SmallVector<SmallVector<Value>> results; + results.reserve(numResults); + unsigned counter = 0; + for (unsigned i = 0; i < numResults; ++i) { + SmallVector<Value> &group = results.emplace_back(); + for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j) + group.push_back(getUnpackedResult(counter++)); } - if (useBarePtrCallConv) { - // For the bare-ptr calling convention, promote memref results to - // descriptors. - assert(results.size() == resultTypes.size() && - "The number of arguments and types doesn't match"); - this->getTypeConverter()->promoteBarePtrsToDescriptors( - rewriter, callOp.getLoc(), resultTypes, results); - } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(), - resultTypes, results, - /*toDynamic=*/false))) { - return failure(); + // Special handling for MemRef types. + for (unsigned i = 0; i < numResults; ++i) { + Type origType = resultTypes[i]; + auto memrefType = dyn_cast<MemRefType>(origType); + auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType); + if (useBarePtrCallConv && memrefType) { + // For the bare-ptr calling convention, promote memref results to + // descriptors. 
+ assert(results[i].size() == 1 && "expected one converted result"); + results[i].front() = MemRefDescriptor::fromStaticShape( + rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType, + results[i].front()); + } + if (unrankedMemrefType) { + assert(!useBarePtrCallConv && "unranked memref is not supported in the " + "bare-ptr calling convention"); + assert(results[i].size() == 1 && "expected one converted result"); + Value desc = this->copyUnrankedDescriptor( + rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(), + /*toDynamic=*/false); + if (!desc) + return failure(); + results[i].front() = desc; + } } - rewriter.replaceOp(callOp, results); + rewriter.replaceOpWithMultiple(callOp, results); return success(); } }; @@ -606,7 +638,7 @@ public: symbolTables(symbolTables) {} LogicalResult - matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor, + matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { bool useBarePtrCallConv = false; if (getTypeConverter()->getOptions().useBarePtrCallConv) { @@ -636,7 +668,7 @@ struct CallIndirectOpLowering using Super::Super; LogicalResult - matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor, + matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter); } @@ -679,41 +711,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> { using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor, + matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - unsigned numArguments = op.getNumOperands(); SmallVector<Value, 4> updatedOperands; auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>(); bool useBarePtrCallConv = shouldUseBarePtrCallConv(funcOp, this->getTypeConverter()); - if (useBarePtrCallConv) { - // For the bare-ptr calling convention, extract the aligned pointer to - // be returned from the memref descriptor. - for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { - Type oldTy = std::get<0>(it).getType(); - Value newOperand = std::get<1>(it); - if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr( - cast<BaseMemRefType>(oldTy))) { - MemRefDescriptor memrefDesc(newOperand); - newOperand = memrefDesc.allocatedPtr(rewriter, loc); - } else if (isa<UnrankedMemRefType>(oldTy)) { + + for (auto [oldOperand, newOperands] : + llvm::zip_equal(op->getOperands(), adaptor.getOperands())) { + Type oldTy = oldOperand.getType(); + if (auto memRefType = dyn_cast<MemRefType>(oldTy)) { + assert(newOperands.size() == 1 && "expected one converted result"); + if (useBarePtrCallConv && + getTypeConverter()->canConvertToBarePtr(memRefType)) { + // For the bare-ptr calling convention, extract the aligned pointer to + // be returned from the memref descriptor. + MemRefDescriptor memrefDesc(newOperands.front()); + updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc)); + continue; + } + } else if (auto unrankedMemRefType = + dyn_cast<UnrankedMemRefType>(oldTy)) { + assert(newOperands.size() == 1 && "expected one converted result"); + if (useBarePtrCallConv) { // Unranked memref is not supported in the bare pointer calling // convention. 
return failure(); } - updatedOperands.push_back(newOperand); + Value updatedDesc = + copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType, + newOperands.front(), /*toDynamic=*/true); + if (!updatedDesc) + return failure(); + updatedOperands.push_back(updatedDesc); + continue; } - } else { - updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); - (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), - updatedOperands, - /*toDynamic=*/true); + + llvm::append_range(updatedOperands, newOperands); } // If ReturnOp has 0 or 1 operand, create it and return immediately. - if (numArguments <= 1) { + if (updatedOperands.size() <= 1) { rewriter.replaceOpWithNewOp<LLVM::ReturnOp>( op, TypeRange(), updatedOperands, op->getAttrs()); return success(); diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp index 3cfbd89..e516118 100644 --- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp +++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp @@ -532,6 +532,9 @@ void GpuToLLVMConversionPass::runOnOperation() { // Vector transfer ops with rank > 1 should be lowered with VectorToSCF. vector::populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1); + // Transform N-D vector.from_elements to 1-D vector.from_elements before + // conversion. + vector::populateVectorFromElementsLoweringPatterns(patterns); if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 317bfc2..93e370d 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -27,6 +27,7 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -369,6 +370,9 @@ struct LowerGpuOpsToNVVMOpsPass final { RewritePatternSet patterns(m.getContext()); populateGpuRewritePatterns(patterns); + // Transform N-D vector.from_elements to 1-D vector.from_elements before + // conversion. + vector::populateVectorFromElementsLoweringPatterns(patterns); if (failed(applyPatternsGreedily(m, std::move(patterns)))) return signalPassFailure(); } @@ -394,7 +398,7 @@ struct LowerGpuOpsToNVVMOpsPass final if (!allowedDialectsSet.empty() && !allowed) continue; - auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); + auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect); if (!iface) { // Error out if dialect was explicily specified but doesn't implement // conversion interface. 
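The two GPU lowering passes above now run the same greedy pre-conversion step so that only 1-D vector.from_elements ops ever reach the LLVM conversion patterns. A minimal sketch of that step, assuming a pass body with a ModuleOp named m and an MLIRContext named ctx in scope (both placeholder names, not taken from the patch):

    RewritePatternSet patterns(ctx);
    // Decompose N-D vector.from_elements into 1-D vector.from_elements before
    // the conversion driver runs; only the 1-D form has an LLVM lowering.
    vector::populateVectorFromElementsLoweringPatterns(patterns);
    if (failed(applyPatternsGreedily(m, std::move(patterns))))
      return signalPassFailure();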
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp index d22364e..8994905 100644 --- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp +++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp @@ -79,17 +79,30 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) { return canBeBare; } -static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc, - const unsigned indexBitwidth) { +static Value getLaneId(RewriterBase &rewriter, Location loc) { auto int32Type = IntegerType::get(rewriter.getContext(), 32); Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32); - Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type, - ValueRange{minus1, zero}); - Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type, - ValueRange{minus1, mbcntLo}); + NamedAttribute noundef = rewriter.getNamedAttr( + LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr()); + NamedAttribute lowRange = rewriter.getNamedAttr( + LLVM::LLVMDialect::getRangeAttrName(), + LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32), + APInt(32, 32))); + NamedAttribute highRange = rewriter.getNamedAttr( + LLVM::LLVMDialect::getRangeAttrName(), + LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32), + APInt(32, 64))); + Value mbcntLo = ROCDL::MbcntLoOp::create( + rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{}, + /*res_attrs=*/ + rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange}))); + Value laneId = ROCDL::MbcntHiOp::create( + rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{}, + rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange}))); return laneId; } + static constexpr StringLiteral amdgcnDataLayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32" "-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:" @@ -104,18 +117,16 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> { LogicalResult matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto loc = op->getLoc(); + Location loc = op.getLoc(); MLIRContext *context = rewriter.getContext(); - // convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0) - // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo) - - Type intTy = IntegerType::get(context, 32); - Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32); - Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32); - Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, intTy, - ValueRange{minus1, zero}); - Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, intTy, - ValueRange{minus1, mbcntLo}); + // convert to: + // %mlo = call noundef range(i32 0, 32) + // @llvm.amdgcn.mbcnt.lo(-1, 0) + // followed by: + // %lid = call noundef range(i32 0, 64) + // @llvm.amdgcn.mbcnt.hi(-1, %mlo) + + Value laneId = getLaneId(rewriter, loc); // Truncate or extend the result depending on the index bitwidth specified // by the LLVMTypeConverter options. const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); @@ -160,6 +171,38 @@ struct GPUSubgroupSizeOpToROCDL : ConvertOpToLLVMPattern<gpu::SubgroupSizeOp> { const amdgpu::Chipset chipset; }; +static bool isSupportedReadLaneType(Type type) { + // read(first)lane also supports some vector types, but limit it for scalars + // for now. 
+ return type.isInteger(16) || type.isInteger(32) || type.isInteger(64) || + isa<Float16Type, BFloat16Type, Float32Type, Float64Type, + LLVM::LLVMPointerType>(type); +} + +struct GPUSubgroupBroadcastOpToROCDL + : public ConvertOpToLLVMPattern<gpu::SubgroupBroadcastOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(gpu::SubgroupBroadcastOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Value src = adaptor.getSrc(); + if (!isSupportedReadLaneType(src.getType())) + return rewriter.notifyMatchFailure(op, "unsupported readlane type"); + + if (adaptor.getBroadcastType() == gpu::BroadcastType::specific_lane) { + rewriter.replaceOpWithNewOp<ROCDL::ReadlaneOp>(op, src.getType(), src, + adaptor.getLane()); + } else { // first_active_lane or any_lane + // any_lane is lowered to readfirstlane too, to force value into scalar + // register. + rewriter.replaceOpWithNewOp<ROCDL::ReadfirstlaneOp>(op, src.getType(), + src); + } + return success(); + } +}; + struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { using ConvertOpToLLVMPattern<gpu::ShuffleOp>::ConvertOpToLLVMPattern; @@ -185,8 +228,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> { Location loc = op->getLoc(); Value initShflValue = adaptor.getValue(); - const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth(); - Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth); + Value srcLaneId = getLaneId(rewriter, loc); auto int32Type = IntegerType::get(rewriter.getContext(), 32); Value width = adaptor.getWidth(); @@ -317,7 +359,7 @@ struct LowerGpuOpsToROCDLOpsPass final { RewritePatternSet patterns(ctx); populateGpuRewritePatterns(patterns); - populateGpuPromoteShuffleToAMDGPUPatterns(patterns); + populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset); (void)applyPatternsGreedily(m, std::move(patterns)); } @@ -453,7 +495,8 @@ void mlir::populateGpuToROCDLConversionPatterns( // TODO: Add alignment for workgroup memory patterns.add<GPUDynamicSharedMemoryOpLowering>(converter); - patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter); + patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL, + GPUSubgroupBroadcastOpToROCDL>(converter); patterns.add<GPUSubgroupSizeOpToROCDL>(converter, chipset); populateMathToROCDLConversionPatterns(converter, patterns); diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp index fce7a3f..522e914 100644 --- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp +++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp @@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc, results.push_back(d.memRefDescPtr(builder, loc)); } -void UnrankedMemRefDescriptor::computeSizes( +Value UnrankedMemRefDescriptor::computeSize( OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter, - ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces, - SmallVectorImpl<Value> &sizes) { - if (values.empty()) - return; - assert(values.size() == addressSpaces.size() && - "must provide address space for each descriptor"); + UnrankedMemRefDescriptor desc, unsigned addressSpace) { // Cache the index type. 
Type indexType = typeConverter.getIndexType(); @@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes( builder, loc, indexType, llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8)); - sizes.reserve(sizes.size() + values.size()); - for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) { - // Emit IR computing the memory necessary to store the descriptor. This - // assumes the descriptor to be - // { type*, type*, index, index[rank], index[rank] } - // and densely packed, so the total size is - // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). - // TODO: consider including the actual size (including eventual padding due - // to data layout) into the unranked descriptor. - Value pointerSize = createIndexAttrConstant( - builder, loc, indexType, - llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8)); - Value doublePointerSize = - LLVM::MulOp::create(builder, loc, indexType, two, pointerSize); - - // (1 + 2 * rank) * sizeof(index) - Value rank = desc.rank(builder, loc); - Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank); - Value doubleRankIncremented = - LLVM::AddOp::create(builder, loc, indexType, doubleRank, one); - Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType, - doubleRankIncremented, indexSize); - - // Total allocation size. - Value allocationSize = LLVM::AddOp::create( - builder, loc, indexType, doublePointerSize, rankIndexSize); - sizes.push_back(allocationSize); - } + // Emit IR computing the memory necessary to store the descriptor. This + // assumes the descriptor to be + // { type*, type*, index, index[rank], index[rank] } + // and densely packed, so the total size is + // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). + // TODO: consider including the actual size (including eventual padding due + // to data layout) into the unranked descriptor. + Value pointerSize = createIndexAttrConstant( + builder, loc, indexType, + llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8)); + Value doublePointerSize = + LLVM::MulOp::create(builder, loc, indexType, two, pointerSize); + + // (1 + 2 * rank) * sizeof(index) + Value rank = desc.rank(builder, loc); + Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank); + Value doubleRankIncremented = + LLVM::AddOp::create(builder, loc, indexType, doubleRank, one); + Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType, + doubleRankIncremented, indexSize); + + // Total allocation size. + Value allocationSize = LLVM::AddOp::create(builder, loc, indexType, + doublePointerSize, rankIndexSize); + return allocationSize; } Value UnrankedMemRefDescriptor::allocatedPtr( diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp index 2568044..48a0319 100644 --- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp +++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp @@ -216,34 +216,14 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor( return memRefDescriptor; } -LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( - OpBuilder &builder, Location loc, TypeRange origTypes, - SmallVectorImpl<Value> &operands, bool toDynamic) const { - assert(origTypes.size() == operands.size() && - "expected as may original types as operands"); - - // Find operands of unranked memref type and store them. 
- SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs; - SmallVector<unsigned> unrankedAddressSpaces; - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) { - unrankedMemrefs.emplace_back(operands[i]); - FailureOr<unsigned> addressSpace = - getTypeConverter()->getMemRefAddressSpace(memRefType); - if (failed(addressSpace)) - return failure(); - unrankedAddressSpaces.emplace_back(*addressSpace); - } - } - - if (unrankedMemrefs.empty()) - return success(); - - // Compute allocation sizes. - SmallVector<Value> sizes; - UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(), - unrankedMemrefs, unrankedAddressSpaces, - sizes); +Value ConvertToLLVMPattern::copyUnrankedDescriptor( + OpBuilder &builder, Location loc, UnrankedMemRefType memRefType, + Value operand, bool toDynamic) const { + // Convert memory space. + FailureOr<unsigned> addressSpace = + getTypeConverter()->getMemRefAddressSpace(memRefType); + if (failed(addressSpace)) + return {}; // Get frequently used types. Type indexType = getTypeConverter()->getIndexType(); @@ -254,52 +234,61 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( if (toDynamic) { mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType); if (failed(mallocFunc)) - return failure(); + return {}; } if (!toDynamic) { freeFunc = LLVM::lookupOrCreateFreeFn(builder, module); if (failed(freeFunc)) - return failure(); + return {}; } - unsigned unrankedMemrefPos = 0; - for (unsigned i = 0, e = operands.size(); i < e; ++i) { - Type type = origTypes[i]; - if (!isa<UnrankedMemRefType>(type)) - continue; - Value allocationSize = sizes[unrankedMemrefPos++]; - UnrankedMemRefDescriptor desc(operands[i]); - - // Allocate memory, copy, and free the source if necessary. - Value memory = - toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(), - allocationSize) - .getResult() - : LLVM::AllocaOp::create(builder, loc, getPtrType(), - IntegerType::get(getContext(), 8), - allocationSize, - /*alignment=*/0); - Value source = desc.memRefDescPtr(builder, loc); - LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); - if (!toDynamic) - LLVM::CallOp::create(builder, loc, freeFunc.value(), source); - - // Create a new descriptor. The same descriptor can be returned multiple - // times, attempting to modify its pointer can lead to memory leaks - // (allocated twice and overwritten) or double frees (the caller does not - // know if the descriptor points to the same memory). - Type descriptorType = getTypeConverter()->convertType(type); - if (!descriptorType) - return failure(); - auto updatedDesc = - UnrankedMemRefDescriptor::poison(builder, loc, descriptorType); - Value rank = desc.rank(builder, loc); - updatedDesc.setRank(builder, loc, rank); - updatedDesc.setMemRefDescPtr(builder, loc, memory); + UnrankedMemRefDescriptor desc(operand); + Value allocationSize = UnrankedMemRefDescriptor::computeSize( + builder, loc, *getTypeConverter(), desc, *addressSpace); + + // Allocate memory, copy, and free the source if necessary. + Value memory = toDynamic + ? 
LLVM::CallOp::create(builder, loc, mallocFunc.value(), + allocationSize) + .getResult() + : LLVM::AllocaOp::create(builder, loc, getPtrType(), + IntegerType::get(getContext(), 8), + allocationSize, + /*alignment=*/0); + Value source = desc.memRefDescPtr(builder, loc); + LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false); + if (!toDynamic) + LLVM::CallOp::create(builder, loc, freeFunc.value(), source); + + // Create a new descriptor. The same descriptor can be returned multiple + // times, attempting to modify its pointer can lead to memory leaks + // (allocated twice and overwritten) or double frees (the caller does not + // know if the descriptor points to the same memory). + Type descriptorType = getTypeConverter()->convertType(memRefType); + if (!descriptorType) + return {}; + auto updatedDesc = + UnrankedMemRefDescriptor::poison(builder, loc, descriptorType); + Value rank = desc.rank(builder, loc); + updatedDesc.setRank(builder, loc, rank); + updatedDesc.setMemRefDescPtr(builder, loc, memory); + return updatedDesc; +} - operands[i] = updatedDesc; +LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors( + OpBuilder &builder, Location loc, TypeRange origTypes, + SmallVectorImpl<Value> &operands, bool toDynamic) const { + assert(origTypes.size() == operands.size() && + "expected as may original types as operands"); + for (unsigned i = 0, e = operands.size(); i < e; ++i) { + if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) { + Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType, + operands[i], toDynamic); + if (!updatedDesc) + return failure(); + operands[i] = updatedDesc; + } } - return success(); } diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp index 1a9bf56..cb9dea1 100644 --- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp +++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp @@ -365,6 +365,7 @@ Type LLVMTypeConverter::convertFunctionSignatureImpl( useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv; auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter : structFuncArgTypeConverter; + // Convert argument types one by one and check for errors. for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) { SmallVector<Type, 8> converted; @@ -658,27 +659,19 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const { /// UnrankedMemRefType, are converted following the specific rules for the /// calling convention. Calling convention independent types are converted /// following the default LLVM type conversions. -Type LLVMTypeConverter::convertCallingConventionType( - Type type, bool useBarePtrCallConv) const { - if (useBarePtrCallConv) - if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) - return convertMemRefToBarePtr(memrefTy); - - return convertType(type); -} +LogicalResult LLVMTypeConverter::convertCallingConventionType( + Type type, SmallVectorImpl<Type> &result, bool useBarePtrCallConv) const { + if (useBarePtrCallConv) { + if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) { + Type converted = convertMemRefToBarePtr(memrefTy); + if (!converted) + return failure(); + result.push_back(converted); + return success(); + } + } -/// Promote the bare pointers in 'values' that resulted from memrefs to -/// descriptors. 'stdTypes' holds they types of 'values' before the conversion -/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type). 
-void LLVMTypeConverter::promoteBarePtrsToDescriptors( - ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes, - SmallVectorImpl<Value> &values) const { - assert(stdTypes.size() == values.size() && - "The number of types and values doesn't match"); - for (unsigned i = 0, end = values.size(); i < end; ++i) - if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i])) - values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this, - memrefTy, values[i]); + return convertType(type, result); } /// Convert a non-empty list of types of values produced by an operation into an @@ -706,23 +699,35 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) const { /// LLVM-compatible type. In particular, if more than one value is returned, /// create an LLVM dialect structure type with elements that correspond to each /// of the types converted with `convertCallingConventionType`. -Type LLVMTypeConverter::packFunctionResults(TypeRange types, - bool useBarePtrCallConv) const { +Type LLVMTypeConverter::packFunctionResults( + TypeRange types, bool useBarePtrCallConv, + SmallVector<SmallVector<Type>> *groupedTypes, + int64_t *numConvertedTypes) const { assert(!types.empty() && "expected non-empty list of type"); + assert((!groupedTypes || groupedTypes->empty()) && + "expected groupedTypes to be empty"); useBarePtrCallConv |= options.useBarePtrCallConv; - if (types.size() == 1) - return convertCallingConventionType(types.front(), useBarePtrCallConv); - SmallVector<Type> resultTypes; resultTypes.reserve(types.size()); + size_t sizeBefore = 0; for (auto t : types) { - auto converted = convertCallingConventionType(t, useBarePtrCallConv); - if (!converted || !LLVM::isCompatibleType(converted)) + if (failed( + convertCallingConventionType(t, resultTypes, useBarePtrCallConv))) return {}; - resultTypes.push_back(converted); + if (groupedTypes) { + SmallVector<Type> &group = groupedTypes->emplace_back(); + llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore)); + } + sizeBefore = resultTypes.size(); } + if (numConvertedTypes) + *numConvertedTypes = resultTypes.size(); + if (resultTypes.size() == 1) + return resultTypes.front(); + if (resultTypes.empty()) + return {}; return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes); } @@ -740,40 +745,50 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand, return allocated; } -SmallVector<Value, 4> -LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands, - ValueRange operands, OpBuilder &builder, - bool useBarePtrCallConv) const { +SmallVector<Value, 4> LLVMTypeConverter::promoteOperands( + Location loc, ValueRange opOperands, ValueRange adaptorOperands, + OpBuilder &builder, bool useBarePtrCallConv) const { + SmallVector<ValueRange> ranges; + for (size_t i = 0, e = adaptorOperands.size(); i < e; i++) + ranges.push_back(adaptorOperands.slice(i, 1)); + return promoteOperands(loc, opOperands, ranges, builder, useBarePtrCallConv); +} + +SmallVector<Value, 4> LLVMTypeConverter::promoteOperands( + Location loc, ValueRange opOperands, ArrayRef<ValueRange> adaptorOperands, + OpBuilder &builder, bool useBarePtrCallConv) const { SmallVector<Value, 4> promotedOperands; - promotedOperands.reserve(operands.size()); + promotedOperands.reserve(adaptorOperands.size()); useBarePtrCallConv |= options.useBarePtrCallConv; - for (auto it : llvm::zip(opOperands, operands)) { - auto operand = std::get<0>(it); - auto llvmOperand = std::get<1>(it); - + for (auto [operand, llvmOperand] : + 
llvm::zip_equal(opOperands, adaptorOperands)) { if (useBarePtrCallConv) { // For the bare-ptr calling convention, we only have to extract the // aligned pointer of a memref. if (isa<MemRefType>(operand.getType())) { - MemRefDescriptor desc(llvmOperand); - llvmOperand = desc.alignedPtr(builder, loc); + assert(llvmOperand.size() == 1 && "Expected a single operand"); + MemRefDescriptor desc(llvmOperand.front()); + promotedOperands.push_back(desc.alignedPtr(builder, loc)); + continue; } else if (isa<UnrankedMemRefType>(operand.getType())) { llvm_unreachable("Unranked memrefs are not supported"); } } else { if (isa<UnrankedMemRefType>(operand.getType())) { - UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand, + assert(llvmOperand.size() == 1 && "Expected a single operand"); + UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand.front(), promotedOperands); continue; } if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) { - MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType, + assert(llvmOperand.size() == 1 && "Expected a single operand"); + MemRefDescriptor::unpack(builder, loc, llvmOperand.front(), memrefType, promotedOperands); continue; } } - promotedOperands.push_back(llvmOperand); + llvm::append_range(promotedOperands, llvmOperand); } return promotedOperands; } @@ -802,11 +817,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, result.append(converted.begin(), converted.end()); return success(); } - auto converted = converter.convertType(type); - if (!converted) - return failure(); - result.push_back(converted); - return success(); + return converter.convertType(type, result); } /// Callback to convert function argument types. It converts MemRef function @@ -814,11 +825,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, LogicalResult mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type, SmallVectorImpl<Type> &result) { - auto llvmTy = converter.convertCallingConventionType( - type, /*useBarePointerCallConv=*/true); - if (!llvmTy) - return failure(); - - result.push_back(llvmTy); - return success(); + return converter.convertCallingConventionType( + type, result, + /*useBarePointerCallConv=*/true); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index 6bd0e2d..2b7bdc9 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -17,11 +17,13 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" #include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" #include <cstdint> +#include <numeric> using namespace mlir; @@ -97,6 +99,48 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType, + OpBuilder &builder) { + assert(isMemRefTypeLegalForEmitC(memrefType) && + "incompatible memref type for EmitC conversion"); + emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create( + builder, loc, emitc::SizeTType::get(builder.getContext()), + builder.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(builder.getContext(), + {TypeAttr::get(memrefType.getElementType())})); + + IndexType indexType = builder.getIndexType(); + int64_t numElements = 
std::accumulate(memrefType.getShape().begin(), + memrefType.getShape().end(), int64_t{1}, + std::multiplies<int64_t>()); + emitc::ConstantOp numElementsValue = emitc::ConstantOp::create( + builder, loc, indexType, builder.getIndexAttr(numElements)); + + Type sizeTType = emitc::SizeTType::get(builder.getContext()); + emitc::MulOp totalSizeBytes = emitc::MulOp::create( + builder, loc, sizeTType, elementSize.getResult(0), numElementsValue); + + return totalSizeBytes.getResult(); +} + +static emitc::ApplyOp +createPointerFromEmitcArray(Location loc, OpBuilder &builder, + TypedValue<emitc::ArrayType> arrayValue) { + + emitc::ConstantOp zeroIndex = emitc::ConstantOp::create( + builder, loc, builder.getIndexType(), builder.getIndexAttr(0)); + + emitc::ArrayType arrayType = arrayValue.getType(); + llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex); + emitc::SubscriptOp subPtr = + emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices)); + emitc::ApplyOp ptr = emitc::ApplyOp::create( + builder, loc, emitc::PointerType::get(arrayType.getElementType()), + builder.getStringAttr("&"), subPtr); + + return ptr; +} + struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -112,19 +156,21 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); Type elementType = memrefType.getElementType(); IndexType indexType = rewriter.getIndexType(); - emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>( - loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, + emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create( + rewriter, loc, sizeTType, rewriter.getStringAttr("sizeof"), + ValueRange{}, ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); int64_t numElements = 1; for (int64_t dimSize : memrefType.getShape()) { numElements *= dimSize; } - Value numElementsValue = rewriter.create<emitc::ConstantOp>( - loc, indexType, rewriter.getIndexAttr(numElements)); + Value numElementsValue = emitc::ConstantOp::create( + rewriter, loc, indexType, rewriter.getIndexAttr(numElements)); - Value totalSizeBytes = rewriter.create<emitc::MulOp>( - loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); + Value totalSizeBytes = + emitc::MulOp::create(rewriter, loc, sizeTType, + sizeofElementOp.getResult(0), numElementsValue); emitc::CallOpaqueOp allocCall; StringAttr allocFunctionName; @@ -132,8 +178,8 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { SmallVector<Value, 2> argsVec; if (allocOp.getAlignment()) { allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); - alignmentValue = rewriter.create<emitc::ConstantOp>( - loc, sizeTType, + alignmentValue = emitc::ConstantOp::create( + rewriter, loc, sizeTType, rewriter.getIntegerAttr(indexType, allocOp.getAlignment().value_or(0))); argsVec.push_back(alignmentValue); @@ -144,21 +190,62 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { argsVec.push_back(totalSizeBytes); ValueRange args(argsVec); - allocCall = rewriter.create<emitc::CallOpaqueOp>( - loc, + allocCall = emitc::CallOpaqueOp::create( + rewriter, loc, emitc::PointerType::get( emitc::OpaqueType::get(rewriter.getContext(), "void")), allocFunctionName, args); emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); - emitc::CastOp castOp = rewriter.create<emitc::CastOp>( - 
loc, targetPointerType, allocCall.getResult(0)); + emitc::CastOp castOp = emitc::CastOp::create( + rewriter, loc, targetPointerType, allocCall.getResult(0)); rewriter.replaceOp(allocOp, castOp); return success(); } }; +struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = copyOp.getLoc(); + MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType()); + MemRefType targetMemrefType = + cast<MemRefType>(copyOp.getTarget().getType()); + + if (!isMemRefTypeLegalForEmitC(srcMemrefType)) + return rewriter.notifyMatchFailure( + loc, "incompatible source memref type for EmitC conversion"); + + if (!isMemRefTypeLegalForEmitC(targetMemrefType)) + return rewriter.notifyMatchFailure( + loc, "incompatible target memref type for EmitC conversion"); + + auto srcArrayValue = + cast<TypedValue<emitc::ArrayType>>(operands.getSource()); + emitc::ApplyOp srcPtr = + createPointerFromEmitcArray(loc, rewriter, srcArrayValue); + + auto targetArrayValue = + cast<TypedValue<emitc::ArrayType>>(operands.getTarget()); + emitc::ApplyOp targetPtr = + createPointerFromEmitcArray(loc, rewriter, targetArrayValue); + + emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create( + rewriter, loc, TypeRange{}, "memcpy", + ValueRange{ + targetPtr.getResult(), srcPtr.getResult(), + calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)}); + + rewriter.replaceOp(copyOp, memCpyCall.getResults()); + + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; @@ -320,6 +407,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, - ConvertLoad, ConvertStore>(converter, patterns.getContext()); + patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal, + ConvertGetGlobal, ConvertLoad, ConvertStore>( + converter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index e78dd76..a073a9a 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -18,6 +18,8 @@ #include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/StringRef.h" namespace mlir { #define GEN_PASS_DEF_CONVERTMEMREFTOEMITC @@ -27,6 +29,15 @@ namespace mlir { using namespace mlir; namespace { + +emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module, + StringRef headerName) { + StringAttr includeAttr = builder.getStringAttr(headerName); + return emitc::IncludeOp::create( + builder, module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); +} + struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { using Base::Base; @@ -55,34 +66,29 @@ struct ConvertMemRefToEmitCPass return signalPassFailure(); mlir::ModuleOp module = getOperation(); + llvm::SmallSet<StringRef, 4> existingHeaders; + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + 
module.walk([&](mlir::emitc::IncludeOp includeOp) { + if (includeOp.getIsStandardInclude()) + existingHeaders.insert(includeOp.getInclude()); + }); + module.walk([&](mlir::emitc::CallOpaqueOp callOp) { - if (callOp.getCallee() != alignedAllocFunctionName && - callOp.getCallee() != mallocFunctionName) { + StringRef expectedHeader; + if (callOp.getCallee() == alignedAllocFunctionName || + callOp.getCallee() == mallocFunctionName) + expectedHeader = options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader; + else if (callOp.getCallee() == memcpyFunctionName) + expectedHeader = + options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader; + else return mlir::WalkResult::advance(); + if (!existingHeaders.contains(expectedHeader)) { + addStandardHeader(builder, module, expectedHeader); + existingHeaders.insert(expectedHeader); } - - for (auto &op : *module.getBody()) { - emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op); - if (!includeOp) { - continue; - } - if (includeOp.getIsStandardInclude() && - ((options.lowerToCpp && - includeOp.getInclude() == cppStandardLibraryHeader) || - (!options.lowerToCpp && - includeOp.getInclude() == cStandardLibraryHeader))) { - return mlir::WalkResult::interrupt(); - } - } - - mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); - StringAttr includeAttr = - builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader - : cStandardLibraryHeader); - builder.create<mlir::emitc::IncludeOp>( - module.getLoc(), includeAttr, - /*is_standard_include=*/builder.getUnitAttr()); - return mlir::WalkResult::interrupt(); + return mlir::WalkResult::advance(); }); } }; diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp index d6bdd34..262e0e7 100644 --- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp +++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp @@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering auto result = UnrankedMemRefDescriptor::poison( rewriter, loc, typeConverter->convertType(resultTypeU)); result.setRank(rewriter, loc, rank); - SmallVector<Value, 1> sizes; - UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), - result, resultAddrSpace, sizes); - Value resultUnderlyingSize = sizes.front(); + Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize( + rewriter, loc, *getTypeConverter(), result, resultAddrSpace); Value resultUnderlyingDesc = LLVM::AllocaOp::create(rewriter, loc, getPtrType(), rewriter.getI8Type(), resultUnderlyingSize); @@ -1530,12 +1528,11 @@ private: auto targetDesc = UnrankedMemRefDescriptor::poison( rewriter, loc, typeConverter->convertType(targetType)); targetDesc.setRank(rewriter, loc, resultRank); - SmallVector<Value, 4> sizes; - UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(), - targetDesc, addressSpace, sizes); + Value allocationSize = UnrankedMemRefDescriptor::computeSize( + rewriter, loc, *getTypeConverter(), targetDesc, addressSpace); Value underlyingDescPtr = LLVM::AllocaOp::create( rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8), - sizes.front()); + allocationSize); targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr); // Extract pointers and offset from the source memref. 
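In the two MemRefToLLVM hunks above, the value returned by UnrankedMemRefDescriptor::computeSize follows the densely packed descriptor layout { type*, type*, index, index[rank], index[rank] }, that is, 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index) bytes. A standalone sketch of that arithmetic, assuming 64-bit pointers and a 64-bit index type (illustrative values only, not taken from the patch):

    // Rank-3 unranked descriptor: 2 pointer words plus (1 + 2*3) index words.
    constexpr unsigned ptrBytes = 8, indexBytes = 8, rank = 3;
    constexpr unsigned allocBytes = 2 * ptrBytes + (1 + 2 * rank) * indexBytes;
    static_assert(allocBytes == 72, "scratch bytes for a rank-3 descriptor");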
@@ -1872,6 +1869,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) { return LLVM::AtomicBinOp::umin; case arith::AtomicRMWKind::ori: return LLVM::AtomicBinOp::_or; + case arith::AtomicRMWKind::xori: + return LLVM::AtomicBinOp::_xor; case arith::AtomicRMWKind::andi: return LLVM::AtomicBinOp::_and; default: diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp index 2549a9c..37d12ba 100644 --- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp +++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp @@ -283,11 +283,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> { Value srcPtr = getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType, adaptor.getSrcMemref(), adaptor.getIndices()); + auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8); Value ldMatrixResult = NVVM::LdMatrixOp::create( b, ldMatrixResultType, srcPtr, /*num=*/op.getNumTiles(), /*layout=*/op.getTranspose() ? NVVM::MMALayout::col - : NVVM::MMALayout::row); + : NVVM::MMALayout::row, + /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16); // The ldmatrix operation returns either a single i32 value or a struct of // i32 values. Here we unpack those values and cast them back to their @@ -394,11 +396,6 @@ struct ConvertNVGPUToNVVMPass : public impl::ConvertNVGPUToNVVMPassBase<ConvertNVGPUToNVVMPass> { using Base::Base; - void getDependentDialects(DialectRegistry &registry) const override { - registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect, - arith::ArithDialect>(); - } - void runOnOperation() override { LowerToLLVMOptions options(&getContext()); RewritePatternSet patterns(&getContext()); @@ -1029,8 +1026,10 @@ struct NVGPUTmaAsyncStoreOpLowering coords[index] = truncToI32(b, value); } + // TODO: Enhance the NVGPU Op for other modes too rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>( - op, adaptor.getTensorMapDescriptor(), dest, coords, + op, adaptor.getTensorMapDescriptor(), dest, coords, Value{}, + NVVM::TMAStoreMode::TILE, // default is TILE mode adaptor.getPredicate()); return success(); } @@ -1104,12 +1103,10 @@ struct NVGPUGenerateWarpgroupDescriptorLowering // // [0,14) start_address dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit); - LDBG() << "Generating warpgroup.descriptor: " - << "leading_off:" << leadDimVal << "\t" - << "stride_off :" << strideDimVal << "\t" - << "base_offset:" << offsetVal << "\t" - << "layout_type:" << swizzle << " (" - << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) + LDBG() << "Generating warpgroup.descriptor: " << "leading_off:" + << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t" + << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle + << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind) ")\n start_addr : " << baseAddr; rewriter.replaceOp(op, dsc); @@ -1399,14 +1396,12 @@ struct NVGPUWarpgroupMmaOpLowering /// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix /// descriptors and arranges them based on induction variables: i, j, and k. Value generateWgmma(int i, int j, int k, Value matrixC) { - LDBG() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A[" - << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM - << "][" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "] * " - << " B[" << (iterationK * wgmmaK) << ":" - << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN - << "])"; + LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK + << "(A[" << (iterationM * wgmmaM) << ":" + << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK) + << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B[" + << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK) + << "][" << 0 << ":" << wgmmaN << "])"; Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k); Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k); @@ -1700,8 +1695,10 @@ struct NVGPUTmaPrefetchOpLowering LogicalResult matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>( - op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate()); + rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>( + op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr, + adaptor.getTensorMapDescriptor(), adaptor.getPredicate(), + /* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext())); return success(); } }; diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp index 91788f9..314cbed 100644 --- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp +++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp @@ -26,6 +26,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "llvm/Support/DebugLog.h" +#include "llvm/Support/LogicalResult.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "nvvm-to-llvm" @@ -57,12 +58,13 @@ struct PtxLowering SmallVector<std::pair<Value, PTXRegisterMod>> asmValues; LDBG() << op.getPtx(); - PtxBuilder generator(op, rewriter); - op.getAsmValues(rewriter, asmValues); + bool needsManualMapping = op.getAsmValues(rewriter, asmValues); + PtxBuilder generator(op, rewriter, needsManualMapping); for (auto &[asmValue, modifier] : asmValues) { - LDBG() << asmValue << "\t Modifier : " << &modifier; - generator.insertValue(asmValue, modifier); + LDBG() << asmValue << "\t Modifier : " << modifier; + if (failed(generator.insertValue(asmValue, modifier))) + return failure(); } generator.buildAndReplaceOp(); diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp index 5bd1d49..d57926ec 100644 --- a/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp +++ b/mlir/lib/Conversion/PDLToPDLInterp/PredicateTree.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <queue> #define DEBUG_TYPE "pdl-predicate-tree" @@ -544,7 +545,7 @@ static void visitUpward(std::vector<PositionalPredicate> &predList, Value value = opIndex.parent; TypeSwitch<Operation *>(value.getDefiningOp()) .Case<pdl::OperationOp>([&](auto operationOp) { - LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + LDBG() << " * Value: " << value; // Get users and iterate over them. 
Position *usersPos = builder.getUsers(pos, /*useRepresentative=*/true); @@ -618,19 +619,15 @@ static Value buildPredicateList(pdl::PatternOp pattern, RootOrderingGraph graph; ParentMaps parentMaps; buildCostGraph(roots, graph, parentMaps); - LLVM_DEBUG({ - llvm::dbgs() << "Graph:\n"; - for (auto &target : graph) { - llvm::dbgs() << " * " << target.first.getLoc() << " " << target.first - << "\n"; - for (auto &source : target.second) { - RootOrderingEntry &entry = source.second; - llvm::dbgs() << " <- " << source.first << ": " << entry.cost.first - << ":" << entry.cost.second << " via " - << entry.connector.getLoc() << "\n"; - } + LDBG() << "Graph:"; + for (auto &target : graph) { + LDBG() << " * " << target.first.getLoc() << " " << target.first; + for (auto &source : target.second) { + RootOrderingEntry &entry = source.second; + LDBG() << " <- " << source.first << ": " << entry.cost.first << ":" + << entry.cost.second << " via " << entry.connector.getLoc(); } - }); + } // Solve the optimal branching problem for each candidate root, or use the // provided one. @@ -638,11 +635,11 @@ static Value buildPredicateList(pdl::PatternOp pattern, OptimalBranching::EdgeList bestEdges; if (!bestRoot) { unsigned bestCost = 0; - LLVM_DEBUG(llvm::dbgs() << "Candidate roots:\n"); + LDBG() << "Candidate roots:"; for (Value root : roots) { OptimalBranching solver(graph, root); unsigned cost = solver.solve(); - LLVM_DEBUG(llvm::dbgs() << " * " << root << ": " << cost << "\n"); + LDBG() << " * " << root << ": " << cost; if (!bestRoot || bestCost > cost) { bestCost = cost; bestRoot = root; @@ -656,18 +653,15 @@ static Value buildPredicateList(pdl::PatternOp pattern, } // Print the best solution. - LLVM_DEBUG({ - llvm::dbgs() << "Best tree:\n"; - for (const std::pair<Value, Value> &edge : bestEdges) { - llvm::dbgs() << " * " << edge.first; - if (edge.second) - llvm::dbgs() << " <- " << edge.second; - llvm::dbgs() << "\n"; - } - }); + LDBG() << "Best tree:"; + for (const std::pair<Value, Value> &edge : bestEdges) { + if (edge.second) + LDBG() << " * " << edge.first << " <- " << edge.second; + else + LDBG() << " * " << edge.first; + } - LLVM_DEBUG(llvm::dbgs() << "Calling key getTreePredicates:\n"); - LLVM_DEBUG(llvm::dbgs() << " * Value: " << bestRoot << "\n"); + LDBG() << "Calling key getTreePredicates (Value: " << bestRoot << ")"; // The best root is the starting point for the traversal. Get the tree // predicates for the DAG rooted at bestRoot. @@ -691,7 +685,7 @@ static Value buildPredicateList(pdl::PatternOp pattern, // Determine the connector. Value connector = graph[target][source].connector; assert(connector && "invalid edge"); - LLVM_DEBUG(llvm::dbgs() << " * Connector: " << connector.getLoc() << "\n"); + LDBG() << " * Connector: " << connector.getLoc(); DenseMap<Value, OpIndex> parentMap = parentMaps.lookup(target); Position *pos = valueToPosition.lookup(connector); assert(pos && "connector has not been traversed yet"); @@ -806,9 +800,9 @@ static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) { /// Get or insert a child matcher for the given parent switch node, given a /// predicate and parent pattern. 
-std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node, - OrderedPredicate *predicate, - pdl::PatternOp pattern) { +static std::unique_ptr<MatcherNode> & +getOrCreateChild(SwitchNode *node, OrderedPredicate *predicate, + pdl::PatternOp pattern) { assert(isSamePredicate(node, predicate) && "expected matcher to equal the given predicate"); diff --git a/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt new file mode 100644 index 0000000..2d416be --- /dev/null +++ b/mlir/lib/Conversion/PtrToLLVM/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_conversion_library(MLIRPtrToLLVM + PtrToLLVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/PtrToLLVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRPtrDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + ) diff --git a/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp new file mode 100644 index 0000000..a0758aa --- /dev/null +++ b/mlir/lib/Conversion/PtrToLLVM/PtrToLLVM.cpp @@ -0,0 +1,440 @@ +//===- PtrToLLVM.cpp - Ptr to LLVM dialect conversion ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h" + +#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h" +#include "mlir/Conversion/LLVMCommon/ConversionTarget.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Conversion/LLVMCommon/TypeConverter.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/IR/TypeUtilities.h" +#include <type_traits> + +using namespace mlir; + +namespace { +//===----------------------------------------------------------------------===// +// FromPtrOpConversion +//===----------------------------------------------------------------------===// +struct FromPtrOpConversion : public ConvertOpToLLVMPattern<ptr::FromPtrOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ptr::FromPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +//===----------------------------------------------------------------------===// +// GetMetadataOpConversion +//===----------------------------------------------------------------------===// +struct GetMetadataOpConversion + : public ConvertOpToLLVMPattern<ptr::GetMetadataOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ptr::GetMetadataOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +//===----------------------------------------------------------------------===// +// PtrAddOpConversion +//===----------------------------------------------------------------------===// +struct PtrAddOpConversion : public ConvertOpToLLVMPattern<ptr::PtrAddOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +//===----------------------------------------------------------------------===// +// ToPtrOpConversion 
+//===----------------------------------------------------------------------===// +struct ToPtrOpConversion : public ConvertOpToLLVMPattern<ptr::ToPtrOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; + +//===----------------------------------------------------------------------===// +// TypeOffsetOpConversion +//===----------------------------------------------------------------------===// +struct TypeOffsetOpConversion + : public ConvertOpToLLVMPattern<ptr::TypeOffsetOp> { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(ptr::TypeOffsetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override; +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Internal functions +//===----------------------------------------------------------------------===// + +// Function to create an LLVM struct type representing a memref metadata. +static FailureOr<LLVM::LLVMStructType> +createMemRefMetadataType(MemRefType type, + const LLVMTypeConverter &typeConverter) { + MLIRContext *context = type.getContext(); + // Get the address space. + FailureOr<unsigned> addressSpace = typeConverter.getMemRefAddressSpace(type); + if (failed(addressSpace)) + return failure(); + + // Get pointer type (using address space 0 by default) + auto ptrType = LLVM::LLVMPointerType::get(context, *addressSpace); + + // Get the strides offsets and shape. + SmallVector<int64_t> strides; + int64_t offset; + if (failed(type.getStridesAndOffset(strides, offset))) + return failure(); + ArrayRef<int64_t> shape = type.getShape(); + + // Use index type from the type converter for the descriptor elements + Type indexType = typeConverter.getIndexType(); + + // For a ranked memref, the descriptor contains: + // 1. The pointer to the allocated data + // 2. The pointer to the aligned data + // 3. The dynamic offset? + // 4. The dynamic sizes? + // 5. The dynamic strides? + SmallVector<Type, 5> elements; + + // Allocated pointer. + elements.push_back(ptrType); + + // Potentially add the dynamic offset. + if (offset == ShapedType::kDynamic) + elements.push_back(indexType); + + // Potentially add the dynamic sizes. + for (int64_t dim : shape) { + if (dim == ShapedType::kDynamic) + elements.push_back(indexType); + } + + // Potentially add the dynamic strides. 
+ for (int64_t stride : strides) { + if (stride == ShapedType::kDynamic) + elements.push_back(indexType); + } + return LLVM::LLVMStructType::getLiteral(context, elements); +} + +//===----------------------------------------------------------------------===// +// FromPtrOpConversion +//===----------------------------------------------------------------------===// + +LogicalResult FromPtrOpConversion::matchAndRewrite( + ptr::FromPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Get the target memref type + auto mTy = dyn_cast<MemRefType>(op.getResult().getType()); + if (!mTy) + return rewriter.notifyMatchFailure(op, "Expected memref result type"); + + if (!op.getMetadata() && op.getType().hasPtrMetadata()) { + return rewriter.notifyMatchFailure( + op, "Can convert only memrefs with metadata"); + } + + // Convert the result type + Type descriptorTy = getTypeConverter()->convertType(mTy); + if (!descriptorTy) + return rewriter.notifyMatchFailure(op, "Failed to convert result type"); + + // Get the strides, offsets and shape. + SmallVector<int64_t> strides; + int64_t offset; + if (failed(mTy.getStridesAndOffset(strides, offset))) { + return rewriter.notifyMatchFailure(op, + "Failed to get the strides and offset"); + } + ArrayRef<int64_t> shape = mTy.getShape(); + + // Create a new memref descriptor + Location loc = op.getLoc(); + auto desc = MemRefDescriptor::poison(rewriter, loc, descriptorTy); + + // Set the allocated and aligned pointers. + desc.setAllocatedPtr( + rewriter, loc, + rewriter.create<LLVM::ExtractValueOp>(loc, adaptor.getMetadata(), 0)); + desc.setAlignedPtr(rewriter, loc, adaptor.getPtr()); + + // Extract metadata from the passed struct. + unsigned fieldIdx = 1; + + // Set dynamic offset if needed. + if (offset == ShapedType::kDynamic) { + Value offsetValue = rewriter.create<LLVM::ExtractValueOp>( + loc, adaptor.getMetadata(), fieldIdx++); + desc.setOffset(rewriter, loc, offsetValue); + } else { + desc.setConstantOffset(rewriter, loc, offset); + } + + // Set dynamic sizes if needed. + for (auto [i, dim] : llvm::enumerate(shape)) { + if (dim == ShapedType::kDynamic) { + Value sizeValue = rewriter.create<LLVM::ExtractValueOp>( + loc, adaptor.getMetadata(), fieldIdx++); + desc.setSize(rewriter, loc, i, sizeValue); + } else { + desc.setConstantSize(rewriter, loc, i, dim); + } + } + + // Set dynamic strides if needed. + for (auto [i, stride] : llvm::enumerate(strides)) { + if (stride == ShapedType::kDynamic) { + Value strideValue = rewriter.create<LLVM::ExtractValueOp>( + loc, adaptor.getMetadata(), fieldIdx++); + desc.setStride(rewriter, loc, i, strideValue); + } else { + desc.setConstantStride(rewriter, loc, i, stride); + } + } + + rewriter.replaceOp(op, static_cast<Value>(desc)); + return success(); +} + +//===----------------------------------------------------------------------===// +// GetMetadataOpConversion +//===----------------------------------------------------------------------===// + +LogicalResult GetMetadataOpConversion::matchAndRewrite( + ptr::GetMetadataOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto mTy = dyn_cast<MemRefType>(op.getPtr().getType()); + if (!mTy) + return rewriter.notifyMatchFailure(op, "Only memref metadata is supported"); + + // Get the metadata type. + FailureOr<LLVM::LLVMStructType> mdTy = + createMemRefMetadataType(mTy, *getTypeConverter()); + if (failed(mdTy)) { + return rewriter.notifyMatchFailure(op, + "Failed to create the metadata type"); + } + + // Get the memref descriptor. 
+ MemRefDescriptor descriptor(adaptor.getPtr()); + + // Get the strides offsets and shape. + SmallVector<int64_t> strides; + int64_t offset; + if (failed(mTy.getStridesAndOffset(strides, offset))) { + return rewriter.notifyMatchFailure(op, + "Failed to get the strides and offset"); + } + ArrayRef<int64_t> shape = mTy.getShape(); + + // Create a new LLVM struct to hold the metadata + Location loc = op.getLoc(); + Value sV = rewriter.create<LLVM::UndefOp>(loc, *mdTy); + + // First element is the allocated pointer. + sV = rewriter.create<LLVM::InsertValueOp>( + loc, sV, descriptor.allocatedPtr(rewriter, loc), 0); + + // Track the current field index. + unsigned fieldIdx = 1; + + // Add dynamic offset if needed. + if (offset == ShapedType::kDynamic) { + sV = rewriter.create<LLVM::InsertValueOp>( + loc, sV, descriptor.offset(rewriter, loc), fieldIdx++); + } + + // Add dynamic sizes if needed. + for (auto [i, dim] : llvm::enumerate(shape)) { + if (dim != ShapedType::kDynamic) + continue; + sV = rewriter.create<LLVM::InsertValueOp>( + loc, sV, descriptor.size(rewriter, loc, i), fieldIdx++); + } + + // Add dynamic strides if needed + for (auto [i, stride] : llvm::enumerate(strides)) { + if (stride != ShapedType::kDynamic) + continue; + sV = rewriter.create<LLVM::InsertValueOp>( + loc, sV, descriptor.stride(rewriter, loc, i), fieldIdx++); + } + rewriter.replaceOp(op, sV); + return success(); +} + +//===----------------------------------------------------------------------===// +// PtrAddOpConversion +//===----------------------------------------------------------------------===// + +LogicalResult +PtrAddOpConversion::matchAndRewrite(ptr::PtrAddOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Get and check the base. + Value base = adaptor.getBase(); + if (!isa<LLVM::LLVMPointerType>(base.getType())) + return rewriter.notifyMatchFailure(op, "Incompatible pointer type"); + + // Get the offset. + Value offset = adaptor.getOffset(); + + // Ptr assumes the offset is in bytes. + Type elementType = IntegerType::get(rewriter.getContext(), 8); + + // Convert the `ptradd` flags. + LLVM::GEPNoWrapFlags flags; + switch (op.getFlags()) { + case ptr::PtrAddFlags::none: + flags = LLVM::GEPNoWrapFlags::none; + break; + case ptr::PtrAddFlags::nusw: + flags = LLVM::GEPNoWrapFlags::nusw; + break; + case ptr::PtrAddFlags::nuw: + flags = LLVM::GEPNoWrapFlags::nuw; + break; + case ptr::PtrAddFlags::inbounds: + flags = LLVM::GEPNoWrapFlags::inbounds; + break; + } + + // Create the GEP operation with appropriate arguments + rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, base.getType(), elementType, + base, ValueRange{offset}, flags); + return success(); +} + +//===----------------------------------------------------------------------===// +// ToPtrOpConversion +//===----------------------------------------------------------------------===// + +LogicalResult +ToPtrOpConversion::matchAndRewrite(ptr::ToPtrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Bail if it's not a memref. + if (!isa<MemRefType>(op.getPtr().getType())) + return rewriter.notifyMatchFailure(op, "Expected a memref input"); + + // Extract the aligned pointer from the memref descriptor. 
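A brief sketch of what the to_ptr/from_ptr pair in this new file amounts to, assuming the usual converted memref descriptor layout (field 0 allocated pointer, field 1 aligned pointer):

// ptr.to_ptr of a memref lowers to the descriptor's aligned pointer, i.e. the
// equivalent of llvm.extractvalue %desc[1].
// ptr.from_ptr is the inverse: the incoming pointer becomes the aligned
// pointer, while the allocated pointer and any dynamic offset/sizes/strides
// are re-read from the metadata struct, as FromPtrOpConversion does above.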
+ rewriter.replaceOp( + op, MemRefDescriptor(adaptor.getPtr()).alignedPtr(rewriter, op.getLoc())); + return success(); +} + +//===----------------------------------------------------------------------===// +// TypeOffsetOpConversion +//===----------------------------------------------------------------------===// + +LogicalResult TypeOffsetOpConversion::matchAndRewrite( + ptr::TypeOffsetOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Convert the type attribute. + Type type = getTypeConverter()->convertType(op.getElementType()); + if (!type) + return rewriter.notifyMatchFailure(op, "Couldn't convert the type"); + + // Convert the result type. + Type rTy = getTypeConverter()->convertType(op.getResult().getType()); + if (!rTy) + return rewriter.notifyMatchFailure(op, "Couldn't convert the result type"); + + // TODO: Use MLIR's data layout. We don't use it because overall support is + // still flaky. + + // Create an LLVM pointer type for the GEP operation. + auto ptrTy = LLVM::LLVMPointerType::get(getContext()); + + // Create a GEP operation to compute the offset of the type. + auto offset = + LLVM::GEPOp::create(rewriter, op.getLoc(), ptrTy, type, + LLVM::ZeroOp::create(rewriter, op.getLoc(), ptrTy), + ArrayRef<LLVM::GEPArg>({LLVM::GEPArg(1)})); + + // Replace the original op with a PtrToIntOp using the computed offset. + rewriter.replaceOpWithNewOp<LLVM::PtrToIntOp>(op, rTy, offset.getRes()); + return success(); +} + +//===----------------------------------------------------------------------===// +// ConvertToLLVMPatternInterface implementation +//===----------------------------------------------------------------------===// + +namespace { +/// Implement the interface to convert Ptr to LLVM. +struct PtrToLLVMDialectInterface : public ConvertToLLVMPatternInterface { + using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface; + void loadDependentDialects(MLIRContext *context) const final { + context->loadDialect<LLVM::LLVMDialect>(); + } + + /// Hook for derived dialect interface to provide conversion patterns + /// and mark dialect legal for the conversion target. + void populateConvertToLLVMConversionPatterns( + ConversionTarget &target, LLVMTypeConverter &converter, + RewritePatternSet &patterns) const final { + ptr::populatePtrToLLVMConversionPatterns(converter, patterns); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// API +//===----------------------------------------------------------------------===// + +void mlir::ptr::populatePtrToLLVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { + // Add address space conversions. + converter.addTypeAttributeConversion( + [&](PtrLikeTypeInterface type, ptr::GenericSpaceAttr memorySpace) + -> TypeConverter::AttributeConversionResult { + if (type.getMemorySpace() != memorySpace) + return TypeConverter::AttributeConversionResult::na(); + return IntegerAttr::get(IntegerType::get(type.getContext(), 32), 0); + }); + + // Add type conversions. + converter.addConversion([&](ptr::PtrType type) -> Type { + std::optional<Attribute> maybeAttr = + converter.convertTypeAttribute(type, type.getMemorySpace()); + auto memSpace = + maybeAttr ? dyn_cast_or_null<IntegerAttr>(*maybeAttr) : IntegerAttr(); + if (!memSpace) + return {}; + return LLVM::LLVMPointerType::get(type.getContext(), + memSpace.getValue().getSExtValue()); + }); + + // Convert ptr metadata of memref type. 
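A worked example of the struct produced by createMemRefMetadataType, assuming the generic memory space maps to LLVM address space 0 and index converts to i64:

// memref<?x4xf32, strided<[4, 1], offset: ?>>
//   -> !llvm.struct<(ptr, i64, i64)>
// i.e. the allocated pointer, the dynamic offset, and the one dynamic size;
// the static size (4) and the static strides are not materialized.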
+ converter.addConversion([&](ptr::PtrMetadataType type) -> Type { + auto mTy = dyn_cast<MemRefType>(type.getType()); + if (!mTy) + return {}; + FailureOr<LLVM::LLVMStructType> res = + createMemRefMetadataType(mTy, converter); + return failed(res) ? Type() : res.value(); + }); + + // Add conversion patterns. + patterns.add<FromPtrOpConversion, GetMetadataOpConversion, PtrAddOpConversion, + ToPtrOpConversion, TypeOffsetOpConversion>(converter); +} + +void mlir::ptr::registerConvertPtrToLLVMInterface(DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) { + dialect->addInterfaces<PtrToLLVMDialectInterface>(); + }); +} diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index ba448e4..37cfc9f 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -382,8 +382,11 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, // With the body block done, we can fill in the condition block. rewriter.setInsertionPointToEnd(conditionBlock); - auto comparison = arith::CmpIOp::create( - rewriter, loc, arith::CmpIPredicate::slt, iv, upperBound); + arith::CmpIPredicate predicate = forOp.getUnsignedCmp() + ? arith::CmpIPredicate::ult + : arith::CmpIPredicate::slt; + auto comparison = + arith::CmpIOp::create(rewriter, loc, predicate, iv, upperBound); cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock, ArrayRef<Value>(), endBlock, ArrayRef<Value>()); diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 84cbd86..1f239aa 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -154,6 +154,10 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = forOp.getLoc(); + if (forOp.getUnsignedCmp()) + return rewriter.notifyMatchFailure(forOp, + "unsigned loops are not supported"); + // Create an emitc::variable op for each result. These variables will be // assigned to by emitc::assign ops within the loop body. 
SmallVector<Value> resultVariables; diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp index badd2f6..7d0a236 100644 --- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp +++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp @@ -27,7 +27,7 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/RegionUtils.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <optional> #define DEBUG_TYPE "loops-to-gpu" @@ -134,7 +134,7 @@ static LogicalResult checkAffineLoopNestMappable(AffineForOp forOp, unsigned numBlockDims, unsigned numThreadDims) { if (numBlockDims < 1 || numThreadDims < 1) { - LLVM_DEBUG(llvm::dbgs() << "nothing to map"); + LDBG() << "nothing to map"; return success(); } diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 34f372a..c4a9fc2 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -22,7 +22,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" namespace mlir { #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS @@ -538,15 +538,18 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /// Applies the conversion patterns in the given function. static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) { - ConversionTarget target(*module.getContext()); - target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(); - target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect, - memref::MemRefDialect>(); - RewritePatternSet patterns(module.getContext()); patterns.add<ParallelOpLowering>(module.getContext(), numThreads); FrozenRewritePatternSet frozen(std::move(patterns)); - return applyPartialConversion(module, target, frozen); + walkAndApplyPatterns(module, frozen); + auto status = module.walk([](Operation *op) { + if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) { + op->emitError("unconverted operation found"); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + return failure(status.wasInterrupted()); } /// A pass converting SCF operations to OpenMP operations. diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp index dc92367f..55ed31e 100644 --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -178,8 +178,14 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> { // Generate the rest of the loop header. 
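A short illustration of why the unsigned-comparison flag (getUnsignedCmp) selects an unsigned predicate here and in the SCFToControlFlow and SCFToEmitC hunks above, assuming i32 bounds:

// With lb = 0 and ub = 0x80000000:
//   signed   comparison: 0 <s INT32_MIN (-2147483648) -> false, loop never runs
//   unsigned comparison: 0 <u 2147483648              -> true,  loop executes
// Hence ULessThanOp / arith::CmpIPredicate::ult when the flag is set.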
rewriter.setInsertionPointToEnd(header); auto *mergeBlock = loopOp.getMergeBlock(); - auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(), - newIndVar, adaptor.getUpperBound()); + Value cmpOp; + if (forOp.getUnsignedCmp()) { + cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.getI1Type(), + newIndVar, adaptor.getUpperBound()); + } else { + cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(), + newIndVar, adaptor.getUpperBound()); + } spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index fa9e544..398ab88 100644 --- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -301,7 +301,7 @@ public: ConversionPatternRewriter &rewriter) const override { // Create mpi::CommRankOp Location loc = op.getLoc(); - auto ctx = op.getContext(); + auto *ctx = op.getContext(); Value commWorld = mpi::CommWorldOp::create(rewriter, loc, mpi::CommType::get(ctx)); auto rank = mpi::CommRankOp::create( @@ -520,7 +520,7 @@ struct ConvertShardShapeOp : public OpConversionPattern<ShardShapeOp> { }; static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) { - auto ctx = kind.getContext(); + auto *ctx = kind.getContext(); auto getReductionOp = [ctx](mpi::MPI_ReductionOpEnum redOp) { return mpi::MPI_ReductionOpEnumAttr::get(ctx, redOp); }; diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp index 044b725..e568660 100644 --- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp +++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp @@ -64,8 +64,9 @@ public: LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, PatternRewriter &rewriter) const final { - StringRef roundingMode = op.getRoundingMode(); - if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") { + RoundingMode roundingMode = op.getRoundingMode(); + if (roundingMode != RoundingMode::DOUBLE_ROUND && + roundingMode != RoundingMode::SINGLE_ROUND) { return failure(); } @@ -100,7 +101,7 @@ public: multiply64 = arith::AddIOp::create(rewriter, loc, multiply64, round); // Apply double rounding if necessary. - if (op.getRoundingMode() == "DOUBLE_ROUND") { + if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) { int64_t roundInt = 1 << 30; Value roundUp = getConstantValue(loc, i64Ty, roundInt, rewriter); Value roundDown = getConstantValue(loc, i64Ty, -roundInt, rewriter); @@ -129,8 +130,9 @@ public: LogicalResult matchAndRewrite(tosa::ApplyScaleOp op, PatternRewriter &rewriter) const final { - StringRef roundingMode = op.getRoundingMode(); - if (roundingMode != "DOUBLE_ROUND" && roundingMode != "SINGLE_ROUND") { + RoundingMode roundingMode = op.getRoundingMode(); + if (roundingMode != RoundingMode::DOUBLE_ROUND && + roundingMode != RoundingMode::SINGLE_ROUND) { return failure(); } @@ -179,7 +181,7 @@ public: arith::SelectOp::create(rewriter, loc, shiftOver32, shiftHighR, zero32); // Conditionally perform our double round. 
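For reference, the integer scaling these TosaToArith hunks implement can be summarized as follows (a rough sketch of tosa.apply_scale semantics, not the exact spec wording):

// SINGLE_ROUND:
//   result ~= (int64_t(value) * multiplier + (1ll << (shift - 1))) >> shift
// DOUBLE_ROUND: additionally biases the product by +(1 << 30) or -(1 << 30)
//   depending on the sign of the value, matching the roundUp / roundDown
//   selection in the 64-bit path above.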
- if (op.getRoundingMode() == "DOUBLE_ROUND") { + if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND) { Value negOne32 = getConstantValue(loc, i32Ty, -1, rewriter); Value valuePositive = arith::CmpIOp::create( rewriter, loc, arith::CmpIPredicate::sge, value32, zero32); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 0e3de06..d0a431b 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -65,7 +65,7 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter, return result; auto nanMode = op.getNanMode(); - if (nanMode == "PROPAGATE") + if (nanMode == NanPropagationMode::PROPAGATE) return result; // Unordered comparison of NaN against itself will always return true. @@ -126,12 +126,12 @@ static Value createLinalgBodyCalculationForElementwiseOp( if (isa<tosa::MulOp>(op)) { auto shiftVal = cast<tosa::MulOp>(op).getShift(); DenseElementsAttr shiftElem; - if (!matchPattern(shiftVal, m_Constant(&shiftElem))) { - (void)rewriter.notifyMatchFailure(op, "shift value of mul not found"); - return nullptr; - } - - int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt(); + bool shiftIsConstant = true; + int32_t shift = 0; + if (matchPattern(shiftVal, m_Constant(&shiftElem))) + shift = shiftElem.getValues<IntegerAttr>()[0].getInt(); + else + shiftIsConstant = false; if (isa<FloatType>(elementTy)) { if (shift != 0) { @@ -147,23 +147,26 @@ static Value createLinalgBodyCalculationForElementwiseOp( Value a = args[0]; Value b = args[1]; - if (shift > 0) { - auto shiftConst = - arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8); + if (shift > 0 || !shiftIsConstant) { + Value shiftConst; + if (shiftIsConstant) + shiftConst = + rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8); + if (!a.getType().isInteger(32)) a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a); if (!b.getType().isInteger(32)) b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b); - auto result = tosa::ApplyScaleOp::create( - rewriter, loc, rewriter.getI32Type(), a, b, shiftConst, - rewriter.getStringAttr("SINGLE_ROUND")); - - if (elementTy.isInteger(32)) - return result; + auto shiftAmount = shiftIsConstant ? shiftConst : args[2]; + auto roundingAttr = RoundingModeAttr::get(rewriter.getContext(), + RoundingMode::SINGLE_ROUND); + auto result = + tosa::ApplyScaleOp::create(rewriter, loc, rewriter.getI32Type(), a, + b, shiftAmount, roundingAttr); - return arith::TruncIOp::create(rewriter, loc, elementTy, result); + return result; } int aWidth = a.getType().getIntOrFloatBitWidth(); @@ -464,7 +467,7 @@ static Value createLinalgBodyCalculationForElementwiseOp( // In the case of "PROPAGATE" semantics no compare and selection is // required. 
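For the IGNORE NaN mode (as opposed to PROPAGATE, which returns early below), a sketch of the compare-and-select the surrounding code materializes, using the unordered self-comparison mentioned above; the value names are illustrative:

Value lhsIsNaN = arith::CmpFOp::create(rewriter, loc,
                                       arith::CmpFPredicate::UNO, lhs, lhs);
Value rhsIsNaN = arith::CmpFOp::create(rewriter, loc,
                                       arith::CmpFPredicate::UNO, rhs, rhs);
// If one operand is NaN, fall back to the other operand instead of the result.
result = arith::SelectOp::create(rewriter, loc, lhsIsNaN, rhs, result);
result = arith::SelectOp::create(rewriter, loc, rhsIsNaN, lhs, result);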
- if (nanMode == "PROPAGATE") + if (nanMode == NanPropagationMode::PROPAGATE) return result; // In the case of "IGNORE" semantics materialize a comparison @@ -918,6 +921,18 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, if (operands.size() == 1) return operands; + // No need to broadcast for static shape + bool hasDynamic = false; + for (auto op : operands) { + const auto tType = dyn_cast<RankedTensorType>(op.getType()); + if (tType && !tType.hasStaticShape()) { + hasDynamic = true; + break; + } + } + if (!hasDynamic) + return operands; + // Broadcast dynamic dimensions operand by operand return llvm::map_to_vector(operands, [&](Value operand) { return broadcastDynamicDimensions(rewriter, loc, indexPool, operand, @@ -990,8 +1005,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, static ValueRange getBroadcastableOperands(Operation *operation, ValueRange operands) { // Shift cannot broadcast - if (isa<tosa::MulOp>(operation)) - return operands.take_front(2); + if (isa<tosa::MulOp>(operation)) { + DenseElementsAttr shiftElems; + // Shift cannot broadcast when it is constant + if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems))) + return operands.take_front(2); + else + return operands.take_front(3); + } // Input1_zp and output_zp cannot broadcast if (isa<tosa::NegateOp>(operation)) return operands.take_front(1); @@ -1173,7 +1194,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, if constexpr (std::is_same_v<OpTy, tosa::ReduceMinOp> || std::is_same_v<OpTy, tosa::ReduceMaxOp>) { // NaN propagation has no meaning for non floating point types. - if (isa<FloatType>(elementTy) && op.getNanMode() == "IGNORE") { + if (isa<FloatType>(elementTy) && + op.getNanMode() == NanPropagationMode::IGNORE) { isNanIgnoreMode = true; // Because the TOSA spec requires the result be NaN iff all elements in // the reduction are NaN we can't simply perform a compare and select. @@ -1336,11 +1358,11 @@ public: unsigned rank = inputTy.getRank(); // This is an illegal configuration. terminate and log an error - if (op.getRoundingMode() == "INEXACT_ROUND") + if (op.getRoundingMode() == RoundingMode::INEXACT_ROUND) return rewriter.notifyMatchFailure( op, "tosa.rescale with rounding mode = 'INEXACT_ROUND' is not " "currently supported"); - if (op.getRoundingMode() == "DOUBLE_ROUND" && !op.getScale32()) + if (op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !op.getScale32()) return rewriter.notifyMatchFailure( op, "tosa.rescale requires scale32 for double_round to be true"); @@ -1386,11 +1408,10 @@ public: // is ever true. bool doubleRound = - op.getRoundingMode() == "DOUBLE_ROUND" && + op.getRoundingMode() == RoundingMode::DOUBLE_ROUND && llvm::any_of(shiftValues, [](int32_t v) { return v > 31; }); - StringAttr roundingMode = doubleRound - ? rewriter.getStringAttr("DOUBLE_ROUND") - : rewriter.getStringAttr("SINGLE_ROUND"); + RoundingMode roundingMode = + doubleRound ? 
RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND; SmallVector<AffineMap> indexingMaps = { rewriter.getMultiDimIdentityMap(rank)}; @@ -1573,7 +1594,7 @@ public: auto input = op.getInput(); auto inputTy = cast<RankedTensorType>(input.getType()); auto resultTy = cast<RankedTensorType>(op.getType()); - const bool isBilinear = op.getMode() == "BILINEAR"; + const bool isBilinear = op.getMode() == ResizeMode::BILINEAR; auto inputH = inputTy.getDimSize(1); auto inputW = inputTy.getDimSize(2); @@ -1584,8 +1605,8 @@ public: return rewriter.notifyMatchFailure( op, "tosa.resize is not a pure 1x1->1x1 image operation"); - // TODO(suderman): These string values should be declared the TOSA dialect. - if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR") + if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR && + op.getMode() != ResizeMode::BILINEAR) return rewriter.notifyMatchFailure( op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR"); @@ -1785,7 +1806,8 @@ public: return rewriter.notifyMatchFailure( op, "unable to get dynamic dimensions of tosa.resize"); - if (op.getMode() != "NEAREST_NEIGHBOR" && op.getMode() != "BILINEAR") + if (op.getMode() != ResizeMode::NEAREST_NEIGHBOR && + op.getMode() != ResizeMode::BILINEAR) return rewriter.notifyMatchFailure( op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR"); @@ -1890,7 +1912,7 @@ public: getIndexAndDeltaInt(ix, dx, inX, xScaleN, xScaleD, xOffset, imageW, b); } - if (op.getMode() == "NEAREST_NEIGHBOR") { + if (op.getMode() == ResizeMode::NEAREST_NEIGHBOR) { auto one = arith::ConstantOp::create(b, b.getI32IntegerAttr(1)); auto getNearestIndexAndClamp = [&](Value val, Value dval, Value scale, @@ -1926,7 +1948,7 @@ public: linalg::YieldOp::create(b, result); } else { // The mode here must be BILINEAR. - assert(op.getMode() == "BILINEAR"); + assert(op.getMode() == ResizeMode::BILINEAR); auto oneVal = arith::ConstantOp::create(b, b.getI32IntegerAttr(1)); @@ -2291,7 +2313,7 @@ public: Value predicate; if (isa<FloatType>(inElementTy)) { - if (argmaxOp.getNanMode() == "IGNORE") { + if (argmaxOp.getNanMode() == NanPropagationMode::IGNORE) { // Only update index & max value for non NaN values. If all // values are NaNs, the initial index will be return which is 0. predicate = arith::CmpFOp::create(rewriter, nestedLoc, diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 12d85ca..6f28849 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -803,7 +803,7 @@ public: dilationAttr); rewriter.setInsertionPointAfter(op); - StringRef nanMode = op.getNanMode(); + NanPropagationMode nanMode = op.getNanMode(); rewriter.replaceOp(op, resultOp); // NaN propagation has no meaning for non floating point types. @@ -817,7 +817,7 @@ public: // we've already produced a named op we will just take its body and modify // it to include the appropriate checks. If the current value is NaN the // old value of pool will be taken otherwise we use the result. 
- if (nanMode == "IGNORE") { + if (nanMode == NanPropagationMode::IGNORE) { auto genericOp = linalg::GenericOp::create( rewriter, loc, resultOp.getType(0), resultOp.getInputs(), resultOp.getOutputs(), resultOp.getIndexingMapsArray(), @@ -1040,11 +1040,13 @@ public: rewriter, loc, rewriter.getI8IntegerAttr(30)); Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8); - auto scaled = - tosa::ApplyScaleOp::create( - rewriter, loc, rewriter.getI32Type(), poolVal, multiplier, - shift, rewriter.getStringAttr("SINGLE_ROUND")) - .getResult(); + auto roundingAttr = RoundingModeAttr::get( + rewriter.getContext(), RoundingMode::SINGLE_ROUND); + + auto scaled = tosa::ApplyScaleOp::create( + rewriter, loc, rewriter.getI32Type(), poolVal, + multiplier, shift, roundingAttr) + .getResult(); // If we have quantization information we need to apply output // zeropoint. diff --git a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt new file mode 100644 index 0000000..2d4b2b6 --- /dev/null +++ b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt @@ -0,0 +1,19 @@ +add_mlir_conversion_library(MLIRVectorToAMX + VectorToAMX.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMX + + DEPENDS + MLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRAMXDialect + MLIRAffineUtils + MLIRArithDialect + MLIRLinalgUtils + MLIRMemRefDialect + MLIRSCFDialect + MLIRTransforms + MLIRVectorDialect + ) diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp new file mode 100644 index 0000000..7b9ed1d --- /dev/null +++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp @@ -0,0 +1,429 @@ +//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/VectorToAMX/VectorToAMX.h" + +#include "mlir/Dialect/AMX/AMXDialect.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "llvm/Support/DebugLog.h" + +#include <numeric> + +namespace mlir { +#define GEN_PASS_DEF_CONVERTVECTORTOAMX +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +#define DEBUG_TYPE "vector-to-amx" + +namespace { + +/// Return true if vector shape is compatible with AMX tiles. +/// The validation accounts for VNNI packing. +static bool verifyAmxShape(VectorType vec) { + // Check overall shape: + // - 2D for plain layout input or output + // - 3D for VNNI packed input + if (vec.getRank() != 2 && vec.getRank() != 3) + return false; + + ArrayRef<int64_t> shape = vec.getShape(); + int64_t rows = shape[0]; + int64_t cols = shape[1]; + unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth(); + + // 3D shape indicates VNNI packed layout. 
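A worked shape check for verifyAmxShape, using a typical VNNI-packed bf16 operand:

// vector<16x16x2xbf16>: elemBitWidth = 16, so vnniFactor = 32 / 16 = 2 and the
// trailing dimension matches; rows = 16, cols = 16 * 2 = 32, and
// 32 * 16 bits = 512 bits per row, within the 16-row x 64-byte AMX tile limit.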
+ if (vec.getRank() == 3) { + int64_t vnniFactor = 32 / elemBitWidth; + if (shape.back() != vnniFactor) { + LDBG() << "invalid VNNI packing factor"; + return false; + } + cols *= vnniFactor; + } + + // AMX tile supports up to 16 rows of 64 bytes each. + constexpr unsigned maxRows = 16; + constexpr unsigned maxBitsPerRow = 64 * 8; + return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow; +} + +/// Check if contraction operands are in AMX-compatible packed VNNI layout. +static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter, + vector::ContractionOp contractOp) { + VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType()); + if (!accType || accType.getRank() != 2) + return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector"); + + // Expect 3D inputs for VNNI packed data. + VectorType lhsType = contractOp.getLhs().getType(); + VectorType rhsType = contractOp.getRhs().getType(); + if (lhsType.getRank() != 3 || rhsType.getRank() != 3) + return rewriter.notifyMatchFailure(contractOp, + "Expects lhs and rhs 3D vectors"); + + // Check if shapes are compatible with AMX tile. + if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) || + !verifyAmxShape(accType)) + return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape"); + + // Validate affine maps. + // + // Iterators can be ordered arbitrarily. Indexing map positions are based on + // operands' target shapes. + // The matrix layouts must match the following: + // - matrix A - [M]x[K/vnniFactor]x[vnniFactor] + // - matrix B - [K/vnniFactor]x[N]x[vnniFactor] + // - matrix C - [M]x[N] + SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray(); + AffineMap mapA = indexingMaps[0]; + AffineMap mapB = indexingMaps[1]; + if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 || + mapB.getNumResults() != 3) + return rewriter.notifyMatchFailure(contractOp, + "Invalid input indexing maps"); + FailureOr<linalg::ContractionDimensions> dims = + linalg::inferContractionDims(indexingMaps); + if (failed(dims)) + return rewriter.notifyMatchFailure(contractOp, + "Failed to infer contraction dims"); + // Two reduction dimensions are expected: + // - one for the K dimension + // - one for the VNNI factor + if (dims->k.size() != 2) + return rewriter.notifyMatchFailure(contractOp, + "Expected two reduction dims"); + assert(dims->m.size() == 1 && dims->n.size() == 1 && + "Invalid parallel contraction dims"); + + SmallVector<vector::IteratorType> iteratorTypes = + contractOp.getIteratorTypesArray(); + // Check VNNI dim maps - the innermost dim for A and B inputs. + auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2)); + auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2)); + if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB || + iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction) + return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map"); + // Check K dim maps - non-transposed row-major layout. + auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1)); + auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0)); + if (!redDimA || !redDimB || redDimA != redDimB || + iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction) + return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map"); + // Check M and N dim maps - map to non-transposed output. 
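As a concrete instance of the indexing maps this validation accepts (d0 = M, d1 = N, d2 = K, d3 = VNNI; iterators [parallel, parallel, reduction, reduction]):

//   A: affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>   // M x K/vnni x vnni
//   B: affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>   // K/vnni x N x vnni
//   C: affine_map<(d0, d1, d2, d3) -> (d0, d1)>       // M x N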
+ AffineMap mapC = indexingMaps[2]; + auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0)); + auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1)); + if (!mDimC || !nDimC) + return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps"); + auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0)); + if (!parallelDimA || + iteratorTypes[parallelDimA.getPosition()] != + vector::IteratorType::parallel || + parallelDimA != mDimC) + return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map"); + auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1)); + if (!parallelDimB || + iteratorTypes[parallelDimB.getPosition()] != + vector::IteratorType::parallel || + parallelDimB != nDimC) + return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map"); + + return success(); +} + +/// Validate contraction operands for AMX lowering. +static LogicalResult validateOperands(PatternRewriter &rewriter, + vector::ContractionOp contractOp) { + VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType()); + if (!accType) + return rewriter.notifyMatchFailure(contractOp, "Expects vector acc"); + + // Check if operand types are compatible with AMX compute ops. + bool validElemTypes = false; + Type lhsElemType = contractOp.getLhs().getType().getElementType(); + Type rhsElemType = contractOp.getRhs().getType().getElementType(); + Type accElemType = accType.getElementType(); + if (accElemType.isInteger(32)) { + validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8); + } else if (accElemType.isF32()) { + validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) || + (lhsElemType.isBF16() && rhsElemType.isBF16()); + } + if (!validElemTypes) + return rewriter.notifyMatchFailure(contractOp, + "Invalid combination of operand types"); + + if (failed(isAmxVnniLayout(rewriter, contractOp))) + return failure(); + + return success(); +} + +/// Collapse the two innermost dimensions together. +static TypedValue<MemRefType> collapseLastDim(PatternRewriter &rewriter, + TypedValue<MemRefType> memref) { + int64_t rank = memref.getType().getRank(); + SmallVector<ReassociationIndices> reassocIndices; + for (auto i : llvm::seq<int64_t>(0, rank - 2)) + reassocIndices.push_back({i}); + reassocIndices.push_back({rank - 2, rank - 1}); + return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref, + reassocIndices); +} + +/// Attempt to create an AMX tile load/store operation equivalent to the given +/// vector transfer `xfer` op. +/// This approach allows to skip longer route through registers and a temporary +/// buffer otherwise required to move data to/from an AMX tile. +static Operation * +loadStoreFromTransfer(PatternRewriter &rewriter, + VectorTransferOpInterface xferOp, bool isPacked, + TypedValue<amx::TileType> tileToStore = nullptr) { + if (!xferOp || !isa<vector::TransferReadOp, vector::TransferWriteOp>(xferOp)) + return nullptr; + if (xferOp.hasOutOfBoundsDim() || + !xferOp.getPermutationMap().isMinorIdentity()) + return nullptr; + + // Extra checks in case of a write op. + // Stores must not be packed. + if (isa<vector::TransferWriteOp>(xferOp) && + (!tileToStore || isPacked || + tileToStore.getType().getShape() != xferOp.getVectorType().getShape())) + return nullptr; + + // Check for a memref source buffer. + // AMX data transfer requires at least 2D shape to correctly + // infer stride between rows. 
+ Value base = xferOp.getBase(); + auto memTy = dyn_cast<MemRefType>(base.getType()); + int64_t memRank = memTy.getRank(); + if (!memTy || memRank < 2) + return nullptr; + + // Check that the source buffer has enough contiguous elements to load whole + // AMX tile row. + // + // To ensure correctness, the validation is conservative and expects the + // buffer's innermost dimensions to be statically known, equal to or larger + // than the vector row length, and equal to the VNNI dimension if applicable. + // + // This check could be relaxed to accept more arbitrarily shaped buffers as + // long as there are enough contiguous elements to load a whole row. + if (!memTy.areTrailingDimsContiguous(isPacked ? 2 : 1)) + return nullptr; + VectorType vecTy = xferOp.getVectorType(); + ArrayRef<int64_t> vecShape = vecTy.getShape(); + ArrayRef<int64_t> memShape = memTy.getShape(); + if (memShape.back() == ShapedType::kDynamic || + memShape.back() < vecShape.back()) + return nullptr; + if (isPacked && + (memShape.back() != vecShape.back() || + memShape[memShape.size() - 2] == ShapedType::kDynamic || + memShape[memShape.size() - 2] < vecShape[vecShape.size() - 2])) + return nullptr; + + // Load values directly from the buffer to an AMX tile. + PatternRewriter::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(xferOp); + Location loc = xferOp.getLoc(); + + // Create a subview of the source buffer based on the transfer op to resolve + // offsets. + SmallVector<OpFoldResult> strides(memRank, rewriter.getIndexAttr(1)); + int64_t vecRank = vecTy.getRank(); + assert(memRank >= vecRank && + "Expects buffer to be the same or greater rank than vector"); + SmallVector<int64_t> shape(memRank - vecRank, 1); + shape.append(vecShape.begin(), vecShape.end()); + TypedValue<MemRefType> src = + memref::SubViewOp::create( + rewriter, loc, base, getAsOpFoldResult(xferOp.getIndices()), + getAsOpFoldResult(rewriter.getI64ArrayAttr(shape)), strides) + .getResult(); + + // Collapse the VNNI dimension in case of packing. + if (isPacked) + src = collapseLastDim(rewriter, src); + int64_t rows = vecShape[0]; + int64_t cols = std::accumulate(vecShape.begin() + 1, vecShape.end(), 1, + std::multiplies<int64_t>()); + auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); + + Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); + SmallVector<Value> tileIndicides(src.getType().getRank(), zeroIndex); + + Operation *amxTileOp = nullptr; + if (isa<vector::TransferReadOp>(xferOp)) { + amxTileOp = + amx::TileLoadOp::create(rewriter, loc, tileType, src, tileIndicides); + } else if (isa<vector::TransferWriteOp>(xferOp)) { + amxTileOp = amx::TileStoreOp::create(rewriter, loc, src, tileIndicides, + tileToStore); + } else { + llvm_unreachable("unsupported vector transfer op"); + } + + return amxTileOp; +} + +/// Attempt to create an AMX tile load operation equivalent to the given +/// vector transfer `readOp`. +/// Returns loaded AMX tile if successful. +static FailureOr<TypedValue<amx::TileType>> +loadFromTransfer(PatternRewriter &rewriter, vector::TransferReadOp readOp, + bool isPacked) { + amx::TileLoadOp loadOp = dyn_cast_if_present<amx::TileLoadOp>( + loadStoreFromTransfer(rewriter, readOp, isPacked)); + if (!loadOp) + return failure(); + return loadOp.getRes(); +} + +/// Attempt to create an AMX tile store operation equivalent to the given +/// vector transfer `writeOp`. 
+static LogicalResult storeFromTransfer(PatternRewriter &rewriter, + vector::TransferWriteOp writeOp, + TypedValue<amx::TileType> tileToStore) { + return success(loadStoreFromTransfer(rewriter, writeOp, /*isPacked=*/false, + tileToStore)); +} + +/// Load vector values to an AMX tile. +static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter, + TypedValue<VectorType> vec) { + Location loc = vec.getLoc(); + + VectorType vecTy = vec.getType(); + bool isPacked = vecTy.getRank() == 3; + + // Try to load tile directly from vector producer's buffer. + auto readOp = vec.getDefiningOp<vector::TransferReadOp>(); + FailureOr<TypedValue<amx::TileType>> tile = + loadFromTransfer(rewriter, readOp, isPacked); + if (succeeded(tile)) + return *tile; + + // Transfer the vector to a tile through an intermediate buffer. + Value buf = memref::AllocaOp::create( + rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType())); + Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); + SmallVector<Value> indices(vecTy.getRank(), zeroIndex); + vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices); + + // Collapse the VNNI dimension in case of packing. + if (isPacked) + buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf)); + + ArrayRef<int64_t> shape = vecTy.getShape(); + int64_t rows = shape[0]; + int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1, + std::multiplies<int64_t>()); + auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType()); + + return amx::TileLoadOp::create(rewriter, loc, tileType, buf, + {zeroIndex, zeroIndex}); +} + +/// Store an AMX tile in a vector. +static TypedValue<VectorType> storeTile(PatternRewriter &rewriter, + TypedValue<amx::TileType> tile) { + Location loc = tile.getLoc(); + + // Transfer the tile to a vector through an intermediate buffer. 
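To make the buffer round trip in loadTile/storeTile concrete, continuing the vector<16x16x2xbf16> example (a sketch of the emitted sequence):

// loadTile:
//   alloca memref<16x16x2xbf16> -> vector.transfer_write the source vector
//   -> collapse the last two dims to memref<16x32xbf16>
//   -> amx tile load yielding a 16x32 bf16 tile (rows = 16, cols = 16 * 2).
// storeTile is the inverse: an amx tile store into an alloca of the tile
// shape, followed by a vector.transfer_read back to a 2-D vector.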
+ amx::TileType tileTy = tile.getType(); + Value buf = memref::AllocaOp::create( + rewriter, loc, + MemRefType::get(tileTy.getShape(), tileTy.getElementType())); + Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0); + SmallVector<Value> indices(2, zeroIndex); + amx::TileStoreOp::create(rewriter, loc, buf, indices, tile); + + auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType()); + return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {}); +} + +struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::ContractionOp contractOp, + PatternRewriter &rewriter) const override { + Location loc = contractOp.getLoc(); + + if (contractOp.getKind() != vector::CombiningKind::ADD) + return rewriter.notifyMatchFailure(contractOp, + "Expects add combining kind"); + if (failed(validateOperands(rewriter, contractOp))) + return failure(); + + TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs()); + TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs()); + auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc()); + assert(acc && "Invalid accumulator type"); + TypedValue<amx::TileType> accTile = loadTile(rewriter, acc); + + TypedValue<amx::TileType> tileMul; + if (acc.getType().getElementType().isFloat()) { + tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(), + lhsTile, rhsTile, accTile); + } else { + tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(), + lhsTile, rhsTile, accTile); + } + + // If the contraction result is only written back to memory, try to replace + // the vector op with an AMX store directly. + Value res = contractOp.getResult(); + if (res.hasOneUse()) { + auto writeOp = dyn_cast<vector::TransferWriteOp>(*res.getUsers().begin()); + LogicalResult storeRes = storeFromTransfer(rewriter, writeOp, tileMul); + if (succeeded(storeRes)) { + rewriter.eraseOp(writeOp); + rewriter.eraseOp(contractOp); + return success(); + } + } + + // Load the result back into a vector. + Value newResult = storeTile(rewriter, tileMul); + rewriter.replaceOp(contractOp, newResult); + + return success(); + } +}; + +struct ConvertVectorToAMXPass + : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> { + void runOnOperation() override { + MLIRContext &ctx = getContext(); + RewritePatternSet patterns(&ctx); + populateVectorToAMXConversionPatterns(patterns); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) { + patterns.add<ContractionToAMX>(patterns.getContext()); +} diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index f9e2a01..1ff7d5d 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -306,11 +306,11 @@ public: // Resolve address. Value ptr = getStridedElementPtr(rewriter, loc, memRefType, - adaptor.getBase(), adaptor.getIndices()); + adaptor.getBase(), adaptor.getOffsets()); Value base = adaptor.getBase(); Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType, - base, ptr, adaptor.getIndexVec(), vType); + base, ptr, adaptor.getIndices(), vType); // Replace with the gather intrinsic. 
rewriter.replaceOpWithNewOp<LLVM::masked_gather>( @@ -362,10 +362,10 @@ public: // Resolve address. Value ptr = getStridedElementPtr(rewriter, loc, memRefType, - adaptor.getBase(), adaptor.getIndices()); + adaptor.getBase(), adaptor.getOffsets()); Value ptrs = getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType, - adaptor.getBase(), ptr, adaptor.getIndexVec(), vType); + adaptor.getBase(), ptr, adaptor.getIndices(), vType); // Replace with the scatter intrinsic. rewriter.replaceOpWithNewOp<LLVM::masked_scatter>( @@ -1891,15 +1891,21 @@ struct VectorFromElementsLowering ConversionPatternRewriter &rewriter) const override { Location loc = fromElementsOp.getLoc(); VectorType vectorType = fromElementsOp.getType(); - // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>. - // Such ops should be handled in the same way as vector.insert. + // Only support 1-D vectors. Multi-dimensional vectors should have been + // transformed to 1-D vectors by the vector-to-vector transformations before + // this. if (vectorType.getRank() > 1) return rewriter.notifyMatchFailure(fromElementsOp, "rank > 1 vectors are not supported"); Type llvmType = typeConverter->convertType(vectorType); + Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType()); Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType); - for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) - result = vector::InsertOp::create(rewriter, loc, val, result, idx); + for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) { + auto constIdx = + LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx); + result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result, + val, constIdx); + } rewriter.replaceOp(fromElementsOp, result); return success(); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index cf10869..9852df6 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorStepLoweringPatterns(patterns); populateVectorRankReducingFMAPattern(patterns); populateVectorGatherLoweringPatterns(patterns); + populateVectorFromElementsLoweringPatterns(patterns); if (armI8MM) { if (armNeon) arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index a4be7d4..036cbad 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -743,6 +743,22 @@ struct VectorLoadOpConverter final auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass); + std::optional<uint64_t> alignment = loadOp.getAlignment(); + if (alignment > std::numeric_limits<uint32_t>::max()) { + return rewriter.notifyMatchFailure(loadOp, + "invalid alignment requirement"); + } + + auto memoryAccess = spirv::MemoryAccess::None; + spirv::MemoryAccessAttr memoryAccessAttr; + IntegerAttr alignmentAttr; + if (alignment.has_value()) { + memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; + memoryAccessAttr = + spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess); + alignmentAttr = rewriter.getI32IntegerAttr(alignment.value()); + } + // For single element vectors, we don't need to bitcast the access chain to // the original vector type. 
Both is going to be the same, a pointer // to a scalar. @@ -753,7 +769,8 @@ struct VectorLoadOpConverter final accessChain); rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType, - castedAccessChain); + castedAccessChain, + memoryAccessAttr, alignmentAttr); return success(); } @@ -782,6 +799,12 @@ struct VectorStoreOpConverter final return rewriter.notifyMatchFailure( storeOp, "failed to get memref element pointer"); + std::optional<uint64_t> alignment = storeOp.getAlignment(); + if (alignment > std::numeric_limits<uint32_t>::max()) { + return rewriter.notifyMatchFailure(storeOp, + "invalid alignment requirement"); + } + spirv::StorageClass storageClass = attr.getValue(); auto vectorType = storeOp.getVectorType(); auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass); @@ -795,8 +818,19 @@ struct VectorStoreOpConverter final : spirv::BitcastOp::create(rewriter, loc, vectorPtrType, accessChain); - rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain, - adaptor.getValueToStore()); + auto memoryAccess = spirv::MemoryAccess::None; + spirv::MemoryAccessAttr memoryAccessAttr; + IntegerAttr alignmentAttr; + if (alignment.has_value()) { + memoryAccess = memoryAccess | spirv::MemoryAccess::Aligned; + memoryAccessAttr = + spirv::MemoryAccessAttr::get(rewriter.getContext(), memoryAccess); + alignmentAttr = rewriter.getI32IntegerAttr(alignment.value()); + } + + rewriter.replaceOpWithNewOp<spirv::StoreOp>( + storeOp, castedAccessChain, adaptor.getValueToStore(), memoryAccessAttr, + alignmentAttr); return success(); } diff --git a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt index 567083d..e9ad67c5 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt +++ b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt @@ -13,4 +13,5 @@ add_mlir_conversion_library(MLIRVectorToXeGPU MLIRTransforms MLIRVectorDialect MLIRXeGPUDialect + MLIRXeGPUUtils ) diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp index 8010755..819c2e5 100644 --- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp +++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp @@ -14,9 +14,11 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" @@ -68,11 +70,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter, if (!srcTy) return rewriter.notifyMatchFailure(xferOp, "Expects memref source"); - // Perform common data transfer checks. - VectorType vecTy = xferOp.getVectorType(); - if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy))) - return failure(); - // Validate further transfer op semantics. 
SmallVector<int64_t> strides; int64_t offset; @@ -80,6 +77,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter, return rewriter.notifyMatchFailure( xferOp, "Buffer must be contiguous in the innermost dimension"); + VectorType vecTy = xferOp.getVectorType(); unsigned vecRank = vecTy.getRank(); if (xferOp.hasOutOfBoundsDim() && vecRank < 2) return rewriter.notifyMatchFailure( @@ -155,6 +153,277 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc, return ndDesc; } +// Adjusts the strides of a memref according to a given permutation map for +// vector operations. +// +// This function updates the innermost strides in the `strides` array to +// reflect the permutation specified by `permMap`. The permutation is computed +// using the inverse and broadcasting-aware version of the permutation map, +// and is applied to the relevant strides. This ensures that memory accesses +// are consistent with the logical permutation of vector elements. +// +// Example: +// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`. +// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1, +// 0]), then after calling this function, the last two strides will be +// swapped: +// Original strides: [s0, s1, s2, s3] +// After permutation: [s0, s1, s3, s2] +// +static void adjustStridesForPermutation(AffineMap permMap, + SmallVectorImpl<Value> &strides) { + + AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap); + SmallVector<unsigned> perms; + invMap.isPermutationOfMinorIdentityWithBroadcasting(perms); + SmallVector<int64_t> perms64(perms.begin(), perms.end()); + strides = applyPermutation(strides, perms64); +} + +// Computes memory strides for vector transfer operations, handling both +// static and dynamic memrefs while applying permutation transformations +// for XeGPU lowering. 
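+//
+// For example (illustrative, static case): a memref<8x16x32xf16> has strides
+// [512, 32, 1], which are materialized below as arith.constant index values;
+// for a dynamic memref the same values come from
+// memref.extract_strided_metadata instead.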
+static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+                                         PatternRewriter &rewriter) {
+  SmallVector<Value> strides;
+  Value baseMemref = xferOp.getBase();
+  AffineMap permMap = xferOp.getPermutationMap();
+  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+
+  Location loc = xferOp.getLoc();
+  if (memrefType.hasStaticShape()) {
+    int64_t offset;
+    SmallVector<int64_t> intStrides;
+    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
+      return {};
+    // Wrap static strides as MLIR values
+    for (int64_t s : intStrides)
+      strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+  } else {
+    // For dynamic shape memref, use memref.extract_strided_metadata to get
+    // stride values
+    unsigned rank = memrefType.getRank();
+    Type indexType = rewriter.getIndexType();
+
+    // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
+    // size0, size1, ..., sizeN-1]
+    SmallVector<Type> resultTypes;
+    resultTypes.push_back(MemRefType::get(
+        {}, memrefType.getElementType())); // base memref (unranked)
+    resultTypes.push_back(indexType);      // offset
+
+    for (unsigned i = 0; i < rank; ++i)
+      resultTypes.push_back(indexType); // strides
+
+    for (unsigned i = 0; i < rank; ++i)
+      resultTypes.push_back(indexType); // sizes
+
+    auto meta = memref::ExtractStridedMetadataOp::create(
+        rewriter, loc, resultTypes, baseMemref);
+    strides.append(meta.getStrides().begin(), meta.getStrides().end());
+  }
+  // Adjust strides according to the permutation map (e.g., for transpose)
+  adjustStridesForPermutation(permMap, strides);
+  return strides;
+}
+
+// This function computes the vectors of localOffsets for scattered
+// load/stores. It is used in the lowering of vector.transfer_read/write to
+// load_gather/store_scatter. Example:
+//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
+//              %cst {in_bounds = [true, true, true, true]}>} :
+//              memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+//
+//   %6 = vector.step: vector<4xindex>
+//   %7 = vector.step: vector<2xindex>
+//   %8 = vector.step: vector<6xindex>
+//   %9 = vector.step: vector<32xindex>
+//   %10 = arith.mul %6, 384
+//   %11 = arith.mul %7, 192
+//   %12 = arith.mul %8, 32
+//   %13 = arith.mul %9, 1
+//   %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xindex>
+//   %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xindex>
+//   %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xindex>
+//   %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xindex>
+//   %18 = vector.broadcast %14: vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
+//   %19 = vector.broadcast %15: vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
+//   %20 = vector.broadcast %16: vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
+//   %21 = vector.broadcast %17: vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
+//   %22 = arith.add %18, %19
+//   %23 = arith.add %20, %21
+//   %local_offsets = arith.add %22, %23
+//   %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
+//   %offsets = orig_offset + local_offsets
+static Value computeOffsets(VectorTransferOpInterface xferOp,
+                            PatternRewriter &rewriter,
+                            ArrayRef<Value> strides) {
+  Location loc = xferOp.getLoc();
+  VectorType vectorType = xferOp.getVectorType();
+  SmallVector<Value> indices(xferOp.getIndices().begin(),
+                             xferOp.getIndices().end());
+  ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+  // Create vector.step operations for each dimension
+  SmallVector<Value> stepVectors;
+  llvm::map_to_vector(vectorShape,
[&](int64_t dim) { + auto stepType = VectorType::get({dim}, rewriter.getIndexType()); + auto stepOp = vector::StepOp::create(rewriter, loc, stepType); + stepVectors.push_back(stepOp); + return stepOp; + }); + + // Multiply step vectors by corresponding strides + size_t memrefRank = strides.size(); + size_t vectorRank = vectorShape.size(); + SmallVector<Value> strideMultiplied; + for (size_t i = 0; i < vectorRank; ++i) { + size_t memrefDim = memrefRank - vectorRank + i; + Value strideValue = strides[memrefDim]; + auto mulType = dyn_cast<VectorType>(stepVectors[i].getType()); + auto bcastOp = + vector::BroadcastOp::create(rewriter, loc, mulType, strideValue); + auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp); + strideMultiplied.push_back(mulOp); + } + + // Shape cast each multiplied vector to add singleton dimensions + SmallVector<Value> shapeCasted; + for (size_t i = 0; i < vectorRank; ++i) { + SmallVector<int64_t> newShape(vectorRank, 1); + newShape[i] = vectorShape[i]; + auto newType = VectorType::get(newShape, rewriter.getIndexType()); + auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType, + strideMultiplied[i]); + shapeCasted.push_back(castOp); + } + + // Broadcast each shape-casted vector to full vector shape + SmallVector<Value> broadcasted; + auto fullIndexVectorType = + VectorType::get(vectorShape, rewriter.getIndexType()); + for (Value shapeCastVal : shapeCasted) { + auto broadcastOp = vector::BroadcastOp::create( + rewriter, loc, fullIndexVectorType, shapeCastVal); + broadcasted.push_back(broadcastOp); + } + + // Add all broadcasted vectors together to compute local offsets + Value localOffsets = broadcasted[0]; + for (size_t i = 1; i < broadcasted.size(); ++i) + localOffsets = + arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]); + + // Compute base offset from transfer read indices + Value baseOffset = nullptr; + if (!indices.empty()) { + baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0); + for (size_t i = 0; i < indices.size(); ++i) { + Value strideVal = strides[i]; + Value offsetContrib = + arith::MulIOp::create(rewriter, loc, indices[i], strideVal); + baseOffset = + arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib); + } + // Broadcast base offset to match vector shape + Value bcastBase = vector::BroadcastOp::create( + rewriter, loc, fullIndexVectorType, baseOffset); + localOffsets = + arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets); + } + return localOffsets; +} + +// Collapse memref shape to 1D +static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp, + PatternRewriter &rewriter) { + Location loc = xferOp.getLoc(); + + Value baseMemref = xferOp.getBase(); + MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType()); + Type elementType = memrefType.getElementType(); + + // Compute the total number of elements in the memref + MemRefType flatMemrefType; + if (memrefType.hasStaticShape()) { + auto totalElements = memrefType.getNumElements(); + flatMemrefType = MemRefType::get({totalElements}, elementType); + } else { + flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType); + } + + SmallVector<ReassociationIndices> reassociation; + ReassociationIndices allDims = + llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank())); + reassociation.push_back(allDims); + + auto collapseOp = memref::CollapseShapeOp::create( + rewriter, loc, flatMemrefType, baseMemref, reassociation); + return collapseOp; +} + +static LogicalResult 
lowerToScatteredLoadOp(vector::TransferReadOp readOp, + PatternRewriter &rewriter) { + + Location loc = readOp.getLoc(); + VectorType vectorType = readOp.getVectorType(); + ArrayRef<int64_t> vectorShape = vectorType.getShape(); + auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType()); + if (!memrefType) + return rewriter.notifyMatchFailure(readOp, "Expected memref source"); + + SmallVector<Value> strides = computeStrides(readOp, rewriter); + if (strides.empty()) + return rewriter.notifyMatchFailure(readOp, "Failed to compute strides"); + + Value localOffsets = computeOffsets(readOp, rewriter, strides); + + Value flatMemref = collapseMemrefTo1D(readOp, rewriter); + + Value mask = vector::ConstantMaskOp::create( + rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()), + vectorShape); + auto gatherOp = xegpu::LoadGatherOp::create( + rewriter, loc, vectorType, flatMemref, localOffsets, mask, + /*chunk_size=*/IntegerAttr{}, + /*l1_hint=*/xegpu::CachePolicyAttr{}, + /*l2_hint=*/xegpu::CachePolicyAttr{}, + /*l3_hint=*/xegpu::CachePolicyAttr{}); + + rewriter.replaceOp(readOp, gatherOp.getResult()); + return success(); +} + +static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp, + PatternRewriter &rewriter) { + + Location loc = writeOp.getLoc(); + VectorType vectorType = writeOp.getVectorType(); + ArrayRef<int64_t> vectorShape = vectorType.getShape(); + + auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType()); + if (!memrefType) + return rewriter.notifyMatchFailure(writeOp, "Expected memref source"); + + SmallVector<Value> strides = computeStrides(writeOp, rewriter); + + Value localOffsets = computeOffsets(writeOp, rewriter, strides); + + Value flatMemref = collapseMemrefTo1D(writeOp, rewriter); + + Value mask = vector::ConstantMaskOp::create( + rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()), + vectorShape); + xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref, + localOffsets, mask, + /*chunk_size=*/IntegerAttr{}, + /*l1_hint=*/xegpu::CachePolicyAttr{}, + /*l2_hint=*/xegpu::CachePolicyAttr{}, + /*l3_hint=*/xegpu::CachePolicyAttr{}); + rewriter.eraseOp(writeOp); + return success(); +} + struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern; @@ -165,6 +434,22 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { if (failed(transferPreconditions(rewriter, readOp))) return failure(); + // TODO:This check needs to be replaced with proper uArch capability check + auto chip = xegpu::getChipStr(readOp); + if (chip != "pvc" && chip != "bmg") { + // lower to scattered load Op if the target HW doesn't have 2d block load + // support + // TODO: add support for OutOfBound access + if (readOp.hasOutOfBoundsDim()) + return failure(); + return lowerToScatteredLoadOp(readOp, rewriter); + } + + // Perform common data transfer checks. 
+ VectorType vecTy = readOp.getVectorType(); + if (failed(storeLoadPreconditions(rewriter, readOp, vecTy))) + return failure(); + bool isOutOfBounds = readOp.hasOutOfBoundsDim(); if (isOutOfBounds && !isZeroConstant(readOp.getPadding())) return rewriter.notifyMatchFailure( @@ -173,7 +458,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> { AffineMap readMap = readOp.getPermutationMap(); bool isTransposeLoad = !readMap.isMinorIdentity(); - VectorType vecTy = readOp.getVectorType(); Type elementType = vecTy.getElementType(); unsigned minTransposeBitWidth = 32; if (isTransposeLoad && @@ -221,11 +505,26 @@ struct TransferWriteLowering if (failed(transferPreconditions(rewriter, writeOp))) return failure(); + // TODO:This check needs to be replaced with proper uArch capability check + auto chip = xegpu::getChipStr(writeOp); + if (chip != "pvc" && chip != "bmg") { + // lower to scattered store Op if the target HW doesn't have 2d block + // store support + // TODO: add support for OutOfBound access + if (writeOp.hasOutOfBoundsDim()) + return failure(); + return lowerToScatteredStoreOp(writeOp, rewriter); + } + + // Perform common data transfer checks. + VectorType vecTy = writeOp.getVectorType(); + if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy))) + return failure(); + AffineMap map = writeOp.getPermutationMap(); if (!map.isMinorIdentity()) return rewriter.notifyMatchFailure(writeOp, "Expects identity map"); - VectorType vecTy = writeOp.getVectorType(); auto descType = xegpu::TensorDescType::get( vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(), diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt new file mode 100644 index 0000000..84b2580 --- /dev/null +++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt @@ -0,0 +1,27 @@ +add_mlir_conversion_library(MLIRXeGPUToXeVM + XeGPUToXeVM.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeGPUToXeVM + + DEPENDS + MLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRFuncDialect + MLIRGPUDialect + MLIRLLVMCommonConversion + MLIRLLVMDialect + MLIRXeVMDialect + MLIRVectorDialect + MLIRArithDialect + MLIRIndexDialect + MLIRSCFDialect + MLIRXeGPUDialect + MLIRPass + MLIRTransforms + MLIRSCFTransforms +) diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp new file mode 100644 index 0000000..a7f2dc2 --- /dev/null +++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp @@ -0,0 +1,1026 @@ +//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" + +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Transforms/Patterns.h" +#include "mlir/Dialect/XeGPU/IR/XeGPU.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LLVM.h" +#include "llvm/Support/FormatVariadic.h" + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Types.h" + +#include "llvm/ADT/TypeSwitch.h" + +#include <numeric> + +namespace mlir { +#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS +#include "mlir/Conversion/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { + +// TODO: Below are uArch dependent values, should move away from hardcoding +static constexpr int32_t systolicDepth{8}; +static constexpr int32_t executionSize{16}; + +// Offsets to individual fields of the 8xi32 layout nd tensor descriptor. +enum class NdTdescOffset : uint32_t { + BasePtr = 0, // Base pointer (i64) + BaseShapeW = 2, // Base shape width (i32) + BaseShapeH = 3, // Base shape height (i32) + TensorOffsetW = 4, // Tensor offset W (i32) + TensorOffsetH = 5 // Tensor offset H (i32) +}; + +static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { + switch (xeGpuMemspace) { + case xegpu::MemorySpace::Global: + return static_cast<int>(xevm::AddrSpace::GLOBAL); + case xegpu::MemorySpace::SLM: + return static_cast<int>(xevm::AddrSpace::SHARED); + } +} + +// Get same bitwidth flat vector type of new element type. 
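+// For example, vector<8x16xf16> (2048 bits in total) re-encoded with an i32
+// element type becomes the flat vector<64xi32>.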
+static VectorType encodeVectorTypeTo(VectorType currentVecType, + Type toElemType) { + auto elemType = currentVecType.getElementType(); + auto currentBitWidth = elemType.getIntOrFloatBitWidth(); + auto newBitWidth = toElemType.getIntOrFloatBitWidth(); + const int size = + currentVecType.getNumElements() * currentBitWidth / newBitWidth; + return VectorType::get(size, toElemType); +} + +static xevm::LoadCacheControl +translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint, + std::optional<xegpu::CachePolicy> L3hint) { + auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED); + auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED); + switch (L1hintVal) { + case xegpu::CachePolicy::CACHED: + if (L3hintVal == xegpu::CachePolicy::CACHED) + return xevm::LoadCacheControl::L1C_L2UC_L3C; + else if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::LoadCacheControl::L1C_L2UC_L3UC; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::UNCACHED: + if (L3hintVal == xegpu::CachePolicy::CACHED) + return xevm::LoadCacheControl::L1UC_L2UC_L3C; + else if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::LoadCacheControl::L1UC_L2UC_L3UC; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::STREAMING: + if (L3hintVal == xegpu::CachePolicy::CACHED) + return xevm::LoadCacheControl::L1S_L2UC_L3C; + else if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::LoadCacheControl::L1S_L2UC_L3UC; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::READ_INVALIDATE: + return xevm::LoadCacheControl::INVALIDATE_READ; + default: + llvm_unreachable("Unsupported cache control."); + } +} + +static xevm::StoreCacheControl +translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint, + std::optional<xegpu::CachePolicy> L3hint) { + auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED); + auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED); + switch (L1hintVal) { + case xegpu::CachePolicy::UNCACHED: + if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::StoreCacheControl::L1UC_L2UC_L3UC; + else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) + return xevm::StoreCacheControl::L1UC_L2UC_L3WB; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::STREAMING: + if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::StoreCacheControl::L1S_L2UC_L3UC; + else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) + return xevm::StoreCacheControl::L1S_L2UC_L3WB; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::WRITE_BACK: + if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::StoreCacheControl::L1WB_L2UC_L3UC; + else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) + return xevm::StoreCacheControl::L1WB_L2UC_L3WB; + else + llvm_unreachable("Unsupported cache control."); + case xegpu::CachePolicy::WRITE_THROUGH: + if (L3hintVal == xegpu::CachePolicy::UNCACHED) + return xevm::StoreCacheControl::L1WT_L2UC_L3UC; + else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK) + return xevm::StoreCacheControl::L1WT_L2UC_L3WB; + else + llvm_unreachable("Unsupported cache control."); + default: + llvm_unreachable("Unsupported cache control."); + } +} + +class CreateNdDescToXeVMPattern + : public OpConversionPattern<xegpu::CreateNdDescOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateNdDescOp op, + xegpu::CreateNdDescOp::Adaptor adaptor, + 
ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto source = op.getSource(); + // Op is lowered to a code sequence that populates payload. + // Payload is a 8xi32 vector. Offset to individual fields are defined in + // NdTdescOffset enum. + Type payloadElemTy = rewriter.getI32Type(); + VectorType payloadTy = VectorType::get(8, payloadElemTy); + Type i64Ty = rewriter.getI64Type(); + // 4xi64 view is used for inserting the base pointer. + VectorType payloadI64Ty = VectorType::get(4, i64Ty); + // Initialize payload to zero. + Value payload = arith::ConstantOp::create( + rewriter, loc, + DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0))); + + Value baseAddr; + Value baseShapeW; + Value baseShapeH; + Value offsetW; + Value offsetH; + + // Source can be a memref or a pointer (ui64, ui32, i64 or i32). + SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes(); + SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets(); + // Descriptor shape is expected to be 2D. + int64_t rank = mixedSizes.size(); + if (rank != 2) + return rewriter.notifyMatchFailure(op, "Expected 2D shape."); + auto sourceTy = source.getType(); + auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy); + // If source is a memref, we need to extract the aligned pointer as index. + // Pointer type is passed as i32 or i64 by type converter. + if (sourceMemrefTy) { + if (!sourceMemrefTy.hasStaticShape()) { + return rewriter.notifyMatchFailure(op, "Expected static memref shape."); + } + baseAddr = + memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source); + } else { + baseAddr = adaptor.getSource(); + } + // Utility for creating offset values from op fold result. + auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec, + unsigned idx) -> Value { + Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]); + val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val); + return val; + }; + // Offsets can be either 2D or not provided (0 is used). + if (mixedOffsets.size() == 2) { + offsetW = createOffset(mixedOffsets, 1); + offsetH = createOffset(mixedOffsets, 0); + } else if (mixedOffsets.size() == 0) { + offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); + offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0); + } else { + return rewriter.notifyMatchFailure(op, + "Expected 2D offsets or no offsets."); + } + // Get shape values from op fold results. + baseShapeW = createOffset(mixedSizes, 1); + baseShapeH = createOffset(mixedSizes, 0); + if (sourceMemrefTy) { + // Cast index to i64. + baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr); + } else if (baseAddr.getType() != i64Ty) { + // Pointer type may be i32. Cast to i64 if needed. + baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr); + } + // Populate payload. 
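+    // Resulting 8xi32 payload layout (see NdTdescOffset):
+    //   i32 lanes 0-1: base pointer (as one i64)
+    //   i32 lane 2: base shape width,  i32 lane 3: base shape height
+    //   i32 lane 4: tensor offset W,   i32 lane 5: tensor offset H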
+ Value payLoadAsI64 = + vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload); + payLoadAsI64 = + vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64, + static_cast<int>(NdTdescOffset::BasePtr)); + payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64); + payload = + vector::InsertOp::create(rewriter, loc, baseShapeW, payload, + static_cast<int>(NdTdescOffset::BaseShapeW)); + payload = + vector::InsertOp::create(rewriter, loc, baseShapeH, payload, + static_cast<int>(NdTdescOffset::BaseShapeH)); + payload = vector::InsertOp::create( + rewriter, loc, offsetW, payload, + static_cast<int>(NdTdescOffset::TensorOffsetW)); + payload = vector::InsertOp::create( + rewriter, loc, offsetH, payload, + static_cast<int>(NdTdescOffset::TensorOffsetH)); + rewriter.replaceOp(op, payload); + return success(); + } +}; + +class UpdateNdOffsetToXeVMPattern + : public OpConversionPattern<xegpu::UpdateNdOffsetOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::UpdateNdOffsetOp op, + xegpu::UpdateNdOffsetOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto mixedOffsets = op.getMixedOffsets(); + // Only 2D offsets are supported for now. + if (mixedOffsets.size() != 2) + return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); + auto payload = adaptor.getTensorDesc(); + // Utility for updating payload offset values from op fold result. + auto updateOffset = [&](unsigned idx, int payloadPos) -> Value { + Value offset = + getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]); + offset = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offset); + Value oldOffset = + vector::ExtractOp::create(rewriter, loc, payload, payloadPos); + Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset); + return vector::InsertOp::create(rewriter, loc, newOffset, payload, + payloadPos); + }; + // Update offsets in the payload. 
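+    // Offsets accumulate; e.g. a descriptor currently at (H = 8, W = 16)
+    // updated with offsets [2, 4] ends up at (H = 10, W = 20) (illustrative
+    // values).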
+ payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH)); + payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW)); + rewriter.replaceOp(op, payload); + return success(); + } +}; + +template < + typename OpType, + typename = std::enable_if_t<llvm::is_one_of< + OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>> +class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + + auto tdesc = adaptor.getTensorDesc(); + auto tdescTy = op.getTensorDescType(); + if (tdescTy.getRank() != 2) + return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor."); + auto elemType = tdescTy.getElementType(); + auto elemBitSize = elemType.getIntOrFloatBitWidth(); + if (elemBitSize % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + + VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type()); + Value payLoadAsI64 = + vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc); + Value basePtr = vector::ExtractOp::create( + rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr)); + Value baseShapeW = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW)); + Value baseShapeH = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH)); + // Offsets provided in two ways: + // 1. Offsets are extracted from the tensor descriptor. + // 2. (Mixed) offsets which are provided by the op. + Value offsetW; + Value offsetH; + auto mixedOffsets = op.getMixedOffsets(); + int64_t opOffsetsSize = mixedOffsets.size(); + if (opOffsetsSize != 0 && opOffsetsSize != 2) + return rewriter.notifyMatchFailure(op, + "Expected 2D offsets or no offsets."); + if (opOffsetsSize) { + // If mixed offsets are provided by the op convert them to i32. + offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]); + offsetW = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetW); + offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]); + offsetH = getValueOrCreateCastToIndexLike(rewriter, loc, + rewriter.getI32Type(), offsetH); + } else { + // If offsets are not available, we need to extract them from the tensor + // descriptor. + offsetW = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW)); + offsetH = vector::ExtractOp::create( + rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH)); + } + // Get address space from tensor descriptor memory space. + auto ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + // Convert base pointer (i64) to LLVM pointer type. + Value basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr); + // Compute element byte size and surface width in bytes. + Value elemByteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI32Type(), elemBitSize / 8); + Value surfaceW = + arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize); + + // Get tile sizes and vblocks from the tensor descriptor type. 
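+    // E.g. for a 2D descriptor of shape 8x16 with array_length = 2 this
+    // yields tileH = 8, tileW = 16, vblocks = 2 (illustrative values).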
+ auto tileW = tdescTy.getDimSize(1); + auto tileH = tdescTy.getDimSize(0); + int32_t vblocks = tdescTy.getArrayLength(); + if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) { + Value src = adaptor.getValue(); + // If store value is a scalar, get value from op instead of adaptor. + // Adaptor might have optimized away single element vector + if (src.getType().isIntOrFloat()) { + src = op.getValue(); + } + VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); + if (!srcVecTy) + return rewriter.notifyMatchFailure( + op, "Expected store value to be a vector type."); + // Get flat vector type of integer type with matching element bit size. + VectorType newSrcVecTy = + encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize)); + if (srcVecTy != newSrcVecTy) + src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src); + auto storeCacheControl = + translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + xevm::BlockStore2dOp::create( + rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, + offsetH, elemBitSize, tileW, tileH, src, + xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl)); + rewriter.eraseOp(op); + } else { + auto loadCacheControl = + translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()); + if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) { + xevm::BlockPrefetch2dOp::create( + rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW, + offsetH, elemBitSize, tileW, tileH, vblocks, + xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + rewriter.eraseOp(op); + } else { + VectorType dstVecTy = cast<VectorType>(op.getValue().getType()); + const bool vnni = op.getPacked().value_or(false); + auto transposeValue = op.getTranspose(); + bool transpose = + transposeValue.has_value() && transposeValue.value()[0] == 1; + VectorType loadedTy = encodeVectorTypeTo( + dstVecTy, vnni ? rewriter.getI32Type() + : rewriter.getIntegerType(elemBitSize)); + + Value resultFlatVec = xevm::BlockLoad2dOp::create( + rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH, + surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks, + transpose, vnni, + xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl)); + resultFlatVec = vector::BitCastOp::create( + rewriter, loc, + encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()), + resultFlatVec); + rewriter.replaceOp(op, resultFlatVec); + } + } + return success(); + } +}; + +// Add a builder that creates +// offset * elemByteSize + baseAddr +static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, + Value baseAddr, Value offset, int64_t elemByteSize) { + Value byteSize = arith::ConstantIntOp::create( + rewriter, loc, rewriter.getI64Type(), elemByteSize); + Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); + Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); + return newAddr; +} + +class CreateDescToXeVMPattern + : public OpConversionPattern<xegpu::CreateDescOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto eTy = op.getTensorDescType().getElementType(); + auto eBw = eTy.getIntOrFloatBitWidth(); + if (eBw % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + auto loc = op.getLoc(); + // Offsets are provided as scalar i64 by type converter. 
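+    // The per-lane address computed below is source + offset * (eBw / 8);
+    // e.g. with f32 elements (eBw = 32) an offset of 5 adds 20 bytes
+    // (illustrative values).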
+    auto offsets = adaptor.getOffsets();
+    // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
+    // But type converter will convert them to integer types.
+    Value addr = adaptor.getSource();
+    // ui32 or i32 are passed as i32 so they need to be casted to i64.
+    if (addr.getType() != rewriter.getI64Type())
+      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
+    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
+    rewriter.replaceOp(op, laneAddr);
+    return success();
+  }
+};
+
+class UpdateOffsetToXeVMPattern
+    : public OpConversionPattern<xegpu::UpdateOffsetOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(xegpu::UpdateOffsetOp op,
+                  xegpu::UpdateOffsetOp::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto eTy = op.getTensorDescType().getElementType();
+    auto eBw = eTy.getIntOrFloatBitWidth();
+    if (eBw % 8 != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+    auto loc = op.getLoc();
+    // Scatter descriptor is provided as scalar i64 by type converter.
+    // Offsets are provided as scalar i64 by type converter.
+    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
+                                adaptor.getOffsets(), eBw / 8);
+    rewriter.replaceOp(op, newOffset);
+    return success();
+  }
+};
+
+template <typename OpType,
+          typename = std::enable_if_t<llvm::is_one_of<
+              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
+class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctxt = rewriter.getContext();
+    auto tdescTy = op.getTensorDescType();
+    Value basePtrI64;
+    // Load result or store value type can be a vector or a scalar.
+    Type valOrResTy;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
+      valOrResTy = op.getResult().getType();
+    else
+      valOrResTy = adaptor.getValue().getType();
+    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
+    bool hasScalarVal = !valOrResVecTy;
+    int64_t elemBitWidth =
+        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
+                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
+    // Element type must be multiple of 8 bits.
+    if (elemBitWidth % 8 != 0)
+      return rewriter.notifyMatchFailure(
+          op, "Expected element type bit width to be multiple of 8.");
+    int64_t elemByteSize = elemBitWidth / 8;
+    // Default memory space is global.
+    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
+        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
+    // If tensor descriptor is available, we use its memory space.
+    if (tdescTy)
+      ptrTypeLLVM = LLVM::LLVMPointerType::get(
+          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+    // Base pointer can come from source (load) or dest (store).
+    // If they are memrefs, we use their memory space.
+    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+      basePtrI64 = adaptor.getSource();
+      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
+        auto addrSpace = memRefTy.getMemorySpaceAsInt();
+        if (addrSpace != 0)
+          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+      }
+    } else {
+      basePtrI64 = adaptor.getDest();
+      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
+        auto addrSpace = memRefTy.getMemorySpaceAsInt();
+        if (addrSpace != 0)
+          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
+      }
+    }
+    // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
+    if (basePtrI64.getType() != rewriter.getI64Type()) {
+      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
+                                          basePtrI64);
+    }
+    Value offsets = adaptor.getOffsets();
+    Value mask = adaptor.getMask();
+    if (offsets) {
+      if (dyn_cast<VectorType>(offsets.getType())) {
+        // Offsets need to be a scalar here. A single element vector is
+        // converted to a scalar by the type converter.
+        return rewriter.notifyMatchFailure(op,
+                                           "Expected offsets to be a scalar.");
+      } else {
+        // If offsets are provided, we add them to the base pointer.
+        // Offsets are in number of elements, we need to multiply by
+        // element byte size.
+        basePtrI64 =
+            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+      }
+    }
+    // Convert base pointer (i64) to LLVM pointer type.
+    Value basePtrLLVM =
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+
+    Value maskForLane;
+    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
+    if (maskVecTy) {
+      // The mask needs to be a scalar here. A single element vector is
+      // converted to a scalar by the type converter.
+      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
+    } else
+      maskForLane = mask;
+    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
+      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
+                                         maskForLane, true, true);
+      // If mask is true (then clause), load from memory and yield.
+      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
+      if (!hasScalarVal)
+        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
+                                     valOrResVecTy.getElementType());
+      Value loaded =
+          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
+      // Set cache control attribute on the load operation.
+      loaded.getDefiningOp()->setAttr(
+          "cache_control", xevm::LoadCacheControlAttr::get(
+                               ctxt, translateLoadXeGPUCacheHint(
+                                         op.getL1Hint(), op.getL3Hint())));
+      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
+      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
+      // If mask is false (else clause), yield a vector of zeros.
+      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
+      TypedAttr eVal;
+      if (eTy.isFloat())
+        eVal = FloatAttr::get(eTy, 0.0);
+      else
+        eVal = IntegerAttr::get(eTy, 0);
+      if (hasScalarVal)
+        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
+      else
+        loaded = arith::ConstantOp::create(
+            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
+      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
+      rewriter.replaceOp(op, ifOp.getResult(0));
+    } else {
+      // If mask is true, perform the store.
+      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
+      auto body = ifOp.getBody();
+      rewriter.setInsertionPointToStart(body);
+      auto storeOp =
+          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
+      // Set cache control attribute on the store operation.
+ storeOp.getOperation()->setAttr( + "cache_control", xevm::StoreCacheControlAttr::get( + ctxt, translateStoreXeGPUCacheHint( + op.getL1Hint(), op.getL3Hint()))); + rewriter.eraseOp(op); + } + return success(); + } +}; + +class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + auto tdescTy = op.getTensorDescType(); + Value basePtrI64 = adaptor.getSource(); + // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed. + if (basePtrI64.getType() != rewriter.getI64Type()) + basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), + basePtrI64); + Value offsets = adaptor.getOffsets(); + if (offsets) { + VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType()); + if (offsetsVecTy) { + // Offset needs be scalar. + return rewriter.notifyMatchFailure(op, + "Expected offsets to be a scalar."); + } else { + int64_t elemBitWidth{0}; + int64_t elemByteSize; + // Element byte size can come from three sources: + if (tdescTy) { + // If tensor descriptor is available, we use its element type to + // determine element byte size. + elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth(); + } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) { + // If memref is available, we use its element type to + // determine element byte size. + elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth(); + } else { + // Otherwise, we use the provided offset byte alignment. + elemByteSize = *op.getOffsetAlignByte(); + } + if (elemBitWidth != 0) { + if (elemBitWidth % 8 != 0) + return rewriter.notifyMatchFailure( + op, "Expected element type bit width to be multiple of 8."); + elemByteSize = elemBitWidth / 8; + } + basePtrI64 = + addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize); + } + } + // Default memory space is global. + LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global)); + // If tensor descriptor is available, we use its memory space. + if (tdescTy) + ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace())); + // If source is a memref, we use its memory space. + if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) { + auto addrSpace = memRefTy.getMemorySpaceAsInt(); + if (addrSpace != 0) + ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace); + } + // Convert base pointer (i64) to LLVM pointer type. + Value ptrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + // Create the prefetch op with cache control attribute. 
+ xevm::PrefetchOp::create( + rewriter, loc, ptrLLVM, + xevm::LoadCacheControlAttr::get( + ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint()))); + rewriter.eraseOp(op); + return success(); + } +}; + +class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + xevm::MemScope memScope{xevm::MemScope::WORKGROUP}; + switch (op.getFenceScope()) { + case xegpu::FenceScope::Workgroup: + memScope = xevm::MemScope::WORKGROUP; + break; + case xegpu::FenceScope::GPU: + memScope = xevm::MemScope::DEVICE; + break; + } + xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL}; + switch (op.getMemoryKind()) { + case xegpu::MemorySpace::Global: + addrSpace = xevm::AddrSpace::GLOBAL; + break; + case xegpu::MemorySpace::SLM: + addrSpace = xevm::AddrSpace::SHARED; + break; + } + xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace); + rewriter.eraseOp(op); + return success(); + } +}; + +class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + auto aTy = cast<VectorType>(op.getLhs().getType()); + auto bTy = cast<VectorType>(op.getRhs().getType()); + auto resultType = cast<VectorType>(op.getResultType()); + + auto encodePrecision = [&](Type type) -> xevm::ElemType { + if (type == rewriter.getBF16Type()) + return xevm::ElemType::BF16; + else if (type == rewriter.getF16Type()) + return xevm::ElemType::F16; + else if (type == rewriter.getTF32Type()) + return xevm::ElemType::TF32; + else if (type.isInteger(8)) { + if (type.isUnsignedInteger()) + return xevm::ElemType::U8; + return xevm::ElemType::S8; + } else if (type == rewriter.getF32Type()) + return xevm::ElemType::F32; + else if (type.isInteger(32)) + return xevm::ElemType::S32; + llvm_unreachable("add more support for ElemType"); + }; + xevm::ElemType precATy = encodePrecision(aTy.getElementType()); + xevm::ElemType precBTy = encodePrecision(bTy.getElementType()); + Value c = op.getAcc(); + if (!c) { + auto elementTy = resultType.getElementType(); + Attribute initValueAttr; + if (isa<FloatType>(elementTy)) + initValueAttr = FloatAttr::get(elementTy, 0.0); + else + initValueAttr = IntegerAttr::get(elementTy, 0); + c = arith::ConstantOp::create( + rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr)); + } + + Value aVec = op.getLhs(); + Value bVec = op.getRhs(); + auto cvecty = cast<VectorType>(c.getType()); + xevm::ElemType precCTy = encodePrecision(cvecty.getElementType()); + xevm::ElemType precDTy = encodePrecision(resultType.getElementType()); + VectorType cNty = + VectorType::get(cvecty.getNumElements(), cvecty.getElementType()); + if (cvecty != cNty) + c = vector::ShapeCastOp::create(rewriter, loc, cNty, c); + Value dpasRes = xevm::MMAOp::create( + rewriter, loc, cNty, aVec, bVec, c, + xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize, + systolicDepth * + getNumOperandsPerDword(precATy)), + xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy)); + if (cvecty != cNty) + dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes); + rewriter.replaceOp(op, dpasRes); + return success(); + } + 
+private: + static unsigned getNumOperandsPerDword(xevm::ElemType pTy) { + switch (pTy) { + case xevm::ElemType::TF32: + return 1; + case xevm::ElemType::BF16: + case xevm::ElemType::F16: + return 2; + case xevm::ElemType::U8: + case xevm::ElemType::S8: + return 4; + default: + llvm_unreachable("unsupported xevm::ElemType"); + } + } +}; + +static std::optional<LLVM::AtomicBinOp> +matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) { + switch (arithKind) { + case arith::AtomicRMWKind::addf: + return LLVM::AtomicBinOp::fadd; + case arith::AtomicRMWKind::addi: + return LLVM::AtomicBinOp::add; + case arith::AtomicRMWKind::assign: + return LLVM::AtomicBinOp::xchg; + case arith::AtomicRMWKind::maximumf: + return LLVM::AtomicBinOp::fmax; + case arith::AtomicRMWKind::maxs: + return LLVM::AtomicBinOp::max; + case arith::AtomicRMWKind::maxu: + return LLVM::AtomicBinOp::umax; + case arith::AtomicRMWKind::minimumf: + return LLVM::AtomicBinOp::fmin; + case arith::AtomicRMWKind::mins: + return LLVM::AtomicBinOp::min; + case arith::AtomicRMWKind::minu: + return LLVM::AtomicBinOp::umin; + case arith::AtomicRMWKind::ori: + return LLVM::AtomicBinOp::_or; + case arith::AtomicRMWKind::andi: + return LLVM::AtomicBinOp::_and; + default: + return std::nullopt; + } +} + +class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto ctxt = rewriter.getContext(); + auto tdesc = op.getTensorDesc().getType(); + auto ptrTypeLLVM = LLVM::LLVMPointerType::get( + ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace())); + Value basePtrI64 = arith::IndexCastOp::create( + rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc()); + Value basePtrLLVM = + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); + VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType()); + VectorType srcOrDstFlatVecTy = VectorType::get( + srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType()); + Value srcFlatVec = vector::ShapeCastOp::create( + rewriter, loc, srcOrDstFlatVecTy, op.getValue()); + auto atomicKind = matchSimpleAtomicOp(op.getKind()); + assert(atomicKind.has_value()); + Value resVec = srcFlatVec; + for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) { + auto val = vector::ExtractOp::create(rewriter, loc, resVec, i); + Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), + rewriter.getIndexAttr(i)); + Value currPtr = + LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM, + srcOrDstVecTy.getElementType(), basePtrLLVM, idx); + Value newVal = + LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr, + val, LLVM::AtomicOrdering::seq_cst); + resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i); + } + rewriter.replaceOp(op, resVec); + return success(); + } +}; + +//===----------------------------------------------------------------------===// +// Pass Definition +//===----------------------------------------------------------------------===// + +struct ConvertXeGPUToXeVMPass + : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> { + using Base::Base; + + void runOnOperation() override { + LLVMTypeConverter typeConverter(&getContext()); + typeConverter.addConversion([&](VectorType type) -> Type { + unsigned rank = type.getRank(); + auto elemType = type.getElementType(); + // If the element type 
is index, convert it to i64. + if (llvm::isa<IndexType>(elemType)) + elemType = IntegerType::get(&getContext(), 64); + // If the vector is a scalar or has a single element, return the element + if (rank < 1 || type.getNumElements() == 1) + return elemType; + // Otherwise, convert the vector to a flat vector type. + int64_t sum = + std::accumulate(type.getShape().begin(), type.getShape().end(), + int64_t{1}, std::multiplies<int64_t>()); + return VectorType::get(sum, elemType); + }); + typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type { + if (type.isScattered()) + return IntegerType::get(&getContext(), 64); + auto i32Type = IntegerType::get(&getContext(), 32); + return VectorType::get(8, i32Type); + }); + typeConverter.addConversion([&](MemRefType type) -> Type { + // Convert MemRefType to i64 type. + return IntegerType::get(&getContext(), 64); + }); + + // LLVM type converter puts unrealized casts for the following cases: + // add materialization casts to handle them. + + // Materialization to convert memref to i64 + auto memrefMaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) { + + Value addr = + memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input); + return arith::IndexCastUIOp::create(builder, loc, type, addr) + .getResult(); + } + return {}; + }; + + // Materialization to convert ui64 to i64 + auto ui64MaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (input.getType() == builder.getIntegerType(64, false)) { + Value cast = + index::CastUOp::create(builder, loc, builder.getIndexType(), input) + .getResult(); + return arith::IndexCastUIOp::create(builder, loc, type, cast) + .getResult(); + } + return {}; + }; + + // Materialization to convert ui32 to i32 + auto ui32MaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (input.getType() == builder.getIntegerType(32, false)) { + Value cast = + index::CastUOp::create(builder, loc, builder.getIndexType(), input) + .getResult(); + return arith::IndexCastUIOp::create(builder, loc, type, cast) + .getResult(); + } + return {}; + }; + + // Materialization to convert + // - single element 1D vector to scalar + // - bitcast vector of same rank + // - shape vector of different rank but same element type + auto vectorMaterializationCast = [](OpBuilder &builder, Type type, + ValueRange inputs, + Location loc) -> Value { + if (inputs.size() != 1) + return {}; + auto input = inputs.front(); + if (auto vecTy = dyn_cast<VectorType>(input.getType())) { + if (vecTy.getNumElements() == 1) { + // If the vector has a single element, return the element type. + Value cast = + vector::ExtractOp::create(builder, loc, input, 0).getResult(); + if (vecTy.getElementType() == builder.getIndexType()) + cast = arith::IndexCastUIOp::create(builder, loc, type, cast) + .getResult(); + return cast; + } else if (auto targetVecTy = dyn_cast<VectorType>(type)) { + // If the target type is a vector of same rank, + // bitcast to the target type. 
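+          // E.g. vector<8xf16> -> vector<4xi32>: same rank and same total
+          // bit width, so a plain bitcast suffices (illustrative types).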
+ if (targetVecTy.getRank() == vecTy.getRank()) + return vector::BitCastOp::create(builder, loc, targetVecTy, input) + .getResult(); + else if (targetVecTy.getElementType() == vecTy.getElementType()) { + // If the target type is a vector of different rank but same element + // type, reshape to the target type. + return vector::ShapeCastOp::create(builder, loc, targetVecTy, input) + .getResult(); + } + } + } + return {}; + }; + typeConverter.addSourceMaterialization(memrefMaterializationCast); + typeConverter.addSourceMaterialization(ui64MaterializationCast); + typeConverter.addSourceMaterialization(ui32MaterializationCast); + typeConverter.addSourceMaterialization(vectorMaterializationCast); + typeConverter.addTargetMaterialization(memrefMaterializationCast); + typeConverter.addTargetMaterialization(ui32MaterializationCast); + typeConverter.addTargetMaterialization(ui64MaterializationCast); + typeConverter.addTargetMaterialization(vectorMaterializationCast); + ConversionTarget target(getContext()); + target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect, + vector::VectorDialect, arith::ArithDialect, + memref::MemRefDialect, gpu::GPUDialect, + index::IndexDialect>(); + target.addIllegalDialect<xegpu::XeGPUDialect>(); + + RewritePatternSet patterns(&getContext()); + populateXeGPUToXeVMConversionPatterns(typeConverter, patterns); + scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter, + patterns, target); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + signalPassFailure(); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Pattern Population +//===----------------------------------------------------------------------===// +void mlir::populateXeGPUToXeVMConversionPatterns( + const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) { + patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern, + LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>, + LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>, + LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>( + typeConverter, patterns.getContext()); + patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern, + AtomicRMWToXeVMPattern, PrefetchToXeVMPattern, + LoadStoreToXeVMPattern<xegpu::LoadGatherOp>, + LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>( + typeConverter, patterns.getContext()); + patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter, + patterns.getContext()); +} diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp index d7ffdcb..11a40d6 100644 --- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp +++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp @@ -511,6 +511,18 @@ LogicalResult DPPOp::verify() { } //===----------------------------------------------------------------------===// +// PermlaneSwapOp +//===----------------------------------------------------------------------===// +LogicalResult PermlaneSwapOp::verify() { + unsigned rowLength = getRowLength(); + + if (rowLength != 16 && rowLength != 32) + return emitOpError("row_length attribute must either be 16 or 32."); + + return success(); +} + +//===----------------------------------------------------------------------===// // GatherToLDSOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt index 729e3da..d35853b 100644 --- 
a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt @@ -5,7 +5,7 @@ add_mlir_dialect_library(MLIRAMDGPUTransforms ResolveStridedMetadata.cpp ADDITIONAL_HEADER_DIRS - {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms DEPENDS MLIRAMDGPUTransformsIncGen diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index 6f3110c..68990ef 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -271,7 +271,9 @@ Type amx::TileType::parse(AsmParser &parser) { if (parser.parseGreater()) return nullptr; - return TileType::get(shape, elementType); + return TileType::getChecked( + [&] { return parser.emitError(parser.getNameLoc()); }, shape, + elementType); } void amx::TileType::print(AsmPrinter &os) const { diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp index 86edc2b..b405ec2 100644 --- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp @@ -93,13 +93,13 @@ FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) { int64_t lb = forOp.getConstantLowerBound(); dividend[pos] = 1; dividend.back() -= lb; - addLocalFloorDiv(dividend, step); + unsigned qPos = addLocalFloorDiv(dividend, step); // Second constraint: (iv - lb) - step * q = 0. SmallVector<int64_t, 8> eq(getNumCols(), 0); eq[pos] = 1; eq.back() -= lb; // For the local var just added above. - eq[getNumCols() - 2] = -step; + eq[qPos] = -step; addEquality(eq); } } diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp index 2f85e0b..166d39e 100644 --- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp @@ -21,6 +21,7 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <numeric> #include <optional> @@ -548,19 +549,19 @@ bool mlir::affine::isTilingValid(ArrayRef<AffineForOp> loops) { // Check whether there is any negative direction vector in the // dependence components found above, which means that dependence is // violated by the default hyper-rect tiling method. 
- LLVM_DEBUG(llvm::dbgs() << "Checking whether tiling legality violated " - "for dependence at depth: " - << Twine(d) << " between:\n";); - LLVM_DEBUG(srcAccess.opInst->dump()); - LLVM_DEBUG(dstAccess.opInst->dump()); + LDBG() << "Checking whether tiling legality violated " + << "for dependence at depth: " << Twine(d) << " between:" + << OpWithFlags(srcAccess.opInst, OpPrintingFlags().skipRegions()) + << "\nand:\n" + << OpWithFlags(dstAccess.opInst, + OpPrintingFlags().skipRegions()); for (const DependenceComponent &depComp : depComps) { if (depComp.lb.has_value() && depComp.ub.has_value() && *depComp.lb < *depComp.ub && *depComp.ub < 0) { - LLVM_DEBUG(llvm::dbgs() - << "Dependence component lb = " << Twine(*depComp.lb) - << " ub = " << Twine(*depComp.ub) - << " is negative at depth: " << Twine(d) - << " and thus violates the legality rule.\n"); + LDBG() << "Dependence component lb = " << Twine(*depComp.lb) + << " ub = " << Twine(*depComp.ub) + << " is negative at depth: " << Twine(d) + << " and thus violates the legality rule."; return false; } } diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp index a89c1ae..99ea20b 100644 --- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/IntegerSet.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <optional> @@ -241,7 +242,7 @@ addNodeToMDG(Operation *nodeOp, MemRefDependenceGraph &mdg, } bool MemRefDependenceGraph::init() { - LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n"); + LDBG() << "--- Initializing MDG ---"; // Map from a memref to the set of ids of the nodes that have ops accessing // the memref. DenseMap<Value, SetVector<unsigned>> memrefAccesses; @@ -288,8 +289,7 @@ bool MemRefDependenceGraph::init() { // Return false if non-handled/unknown region-holding ops are found. We // won't know what such ops do or what its regions mean; for e.g., it may // not be an imperative op. - LLVM_DEBUG(llvm::dbgs() - << "MDG init failed; unknown region-holding op found!\n"); + LDBG() << "MDG init failed; unknown region-holding op found!"; return false; } // We aren't creating nodes for memory-effect free ops either with no @@ -297,7 +297,7 @@ bool MemRefDependenceGraph::init() { // interface. } - LLVM_DEBUG(llvm::dbgs() << "Created " << nodes.size() << " nodes\n"); + LDBG() << "Created " << nodes.size() << " nodes"; // Add dependence edges between nodes which produce SSA values and their // users. Load ops can be considered as the ones producing SSA values. @@ -556,9 +556,8 @@ MemRefDependenceGraph::getFusedLoopNestInsertionPoint(unsigned srcId, gatherDefiningNodes(dstId, definingNodes); if (llvm::any_of(definingNodes, [&](unsigned id) { return hasDependencePath(srcId, id); })) { - LLVM_DEBUG(llvm::dbgs() - << "Can't fuse: a defining op with a user in the dst " - "loop has dependence from the src loop\n"); + LDBG() << "Can't fuse: a defining op with a user in the dst " + << "loop has dependence from the src loop"; return nullptr; } @@ -957,20 +956,20 @@ std::optional<bool> ComputationSliceState::isSliceValid() const { FlatAffineValueConstraints srcConstraints; // TODO: Store the source's domain to avoid computation at each depth. 
if (failed(getSourceAsConstraints(srcConstraints))) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute source's domain\n"); + LDBG() << "Unable to compute source's domain"; return std::nullopt; } // As the set difference utility currently cannot handle symbols in its // operands, validity of the slice cannot be determined. if (srcConstraints.getNumSymbolVars() > 0) { - LLVM_DEBUG(llvm::dbgs() << "Cannot handle symbols in source domain\n"); + LDBG() << "Cannot handle symbols in source domain"; return std::nullopt; } // TODO: Handle local vars in the source domains while using the 'projectOut' // utility below. Currently, aligning is not done assuming that there will be // no local vars in the source domain. if (srcConstraints.getNumLocalVars() != 0) { - LLVM_DEBUG(llvm::dbgs() << "Cannot handle locals in source domain\n"); + LDBG() << "Cannot handle locals in source domain"; return std::nullopt; } @@ -978,7 +977,7 @@ std::optional<bool> ComputationSliceState::isSliceValid() const { // fusion succeeds. FlatAffineValueConstraints sliceConstraints; if (failed(getAsConstraints(&sliceConstraints))) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute slice's domain\n"); + LDBG() << "Unable to compute slice's domain"; return std::nullopt; } @@ -987,11 +986,11 @@ std::optional<bool> ComputationSliceState::isSliceValid() const { sliceConstraints.projectOut(ivs.size(), sliceConstraints.getNumVars() - ivs.size()); - LLVM_DEBUG(llvm::dbgs() << "Domain of the source of the slice:\n"); - LLVM_DEBUG(srcConstraints.dump()); - LLVM_DEBUG(llvm::dbgs() << "Domain of the slice if this fusion succeeds " - "(expressed in terms of its source's IVs):\n"); - LLVM_DEBUG(sliceConstraints.dump()); + LDBG() << "Domain of the source of the slice:\n" + << "Source constraints:" << srcConstraints + << "\nDomain of the slice if this fusion succeeds " + << "(expressed in terms of its source's IVs):\n" + << "Slice constraints:" << sliceConstraints; // TODO: Store 'srcSet' to avoid recalculating for each depth. PresburgerSet srcSet(srcConstraints); @@ -999,7 +998,7 @@ std::optional<bool> ComputationSliceState::isSliceValid() const { PresburgerSet diffSet = sliceSet.subtract(srcSet); if (!diffSet.isIntegerEmpty()) { - LLVM_DEBUG(llvm::dbgs() << "Incorrect slice\n"); + LDBG() << "Incorrect slice"; return false; } return true; @@ -1172,8 +1171,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, unsigned rank = access.getRank(); - LLVM_DEBUG(llvm::dbgs() << "MemRefRegion::compute: " << *op - << "\ndepth: " << loopDepth << "\n";); + LDBG() << "MemRefRegion::compute: " << *op << " depth: " << loopDepth; // 0-d memrefs. if (rank == 0) { @@ -1236,7 +1234,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, if (auto constVal = getConstantIntValue(symbol)) cst.addBound(BoundType::EQ, symbol, constVal.value()); } else { - LLVM_DEBUG(llvm::dbgs() << "unknown affine dimensional value"); + LDBG() << "unknown affine dimensional value"; return failure(); } } @@ -1260,7 +1258,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, // Add access function equalities to connect loop IVs to data dimensions. 
if (failed(cst.composeMap(&accessValueMap))) { op->emitError("getMemRefRegion: compose affine map failed"); - LLVM_DEBUG(accessValueMap.getAffineMap().dump()); + LDBG() << "Access map: " << accessValueMap.getAffineMap(); return failure(); } @@ -1317,8 +1315,7 @@ LogicalResult MemRefRegion::compute(Operation *op, unsigned loopDepth, } cst.removeTrivialRedundancy(); - LLVM_DEBUG(llvm::dbgs() << "Memory region:\n"); - LLVM_DEBUG(cst.dump()); + LDBG() << "Memory region: " << cst; return success(); } @@ -1346,14 +1343,14 @@ std::optional<int64_t> MemRefRegion::getRegionSize() { auto memRefType = cast<MemRefType>(memref.getType()); if (!memRefType.getLayout().isIdentity()) { - LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); + LDBG() << "Non-identity layout map not yet supported"; return false; } // Compute the extents of the buffer. std::optional<int64_t> numElements = getConstantBoundingSizeAndShape(); if (!numElements) { - LLVM_DEBUG(llvm::dbgs() << "Dynamic shapes not yet supported\n"); + LDBG() << "Dynamic shapes not yet supported"; return std::nullopt; } auto eltSize = getMemRefIntOrFloatEltSizeInBytes(memRefType); @@ -1397,8 +1394,7 @@ LogicalResult mlir::affine::boundCheckLoadOrStoreOp(LoadOrStoreOp loadOrStoreOp, /*addMemRefDimBounds=*/false))) return success(); - LLVM_DEBUG(llvm::dbgs() << "Memory region"); - LLVM_DEBUG(region.getConstraints()->dump()); + LDBG() << "Memory region: " << region.getConstraints(); bool outOfBounds = false; unsigned rank = loadOrStoreOp.getMemRefType().getRank(); @@ -1558,7 +1554,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, // Check if 'loopDepth' exceeds nesting depth of src/dst ops. if ((!isBackwardSlice && loopDepth > getNestingDepth(a)) || (isBackwardSlice && loopDepth > getNestingDepth(b))) { - LLVM_DEBUG(llvm::dbgs() << "Invalid loop depth\n"); + LDBG() << "Invalid loop depth"; return SliceComputationResult::GenericFailure; } @@ -1571,7 +1567,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, &dependenceConstraints, /*dependenceComponents=*/nullptr, /*allowRAR=*/readReadAccesses); if (result.value == DependenceResult::Failure) { - LLVM_DEBUG(llvm::dbgs() << "Dependence check failed\n"); + LDBG() << "Dependence check failed"; return SliceComputationResult::GenericFailure; } if (result.value == DependenceResult::NoDependence) @@ -1586,8 +1582,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, if (sliceUnionCst.getNumDimAndSymbolVars() == 0) { // Initialize 'sliceUnionCst' with the bounds computed in previous step. if (failed(tmpSliceState.getAsConstraints(&sliceUnionCst))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute slice bound constraints\n"); + LDBG() << "Unable to compute slice bound constraints"; return SliceComputationResult::GenericFailure; } assert(sliceUnionCst.getNumDimAndSymbolVars() > 0); @@ -1597,8 +1592,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, // Compute constraints for 'tmpSliceState' in 'tmpSliceCst'. 
FlatAffineValueConstraints tmpSliceCst; if (failed(tmpSliceState.getAsConstraints(&tmpSliceCst))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute slice bound constraints\n"); + LDBG() << "Unable to compute slice bound constraints"; return SliceComputationResult::GenericFailure; } @@ -1630,8 +1624,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, if (sliceUnionCst.getNumLocalVars() > 0 || tmpSliceCst.getNumLocalVars() > 0 || failed(sliceUnionCst.unionBoundingBox(tmpSliceCst))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute union bounding box of slice bounds\n"); + LDBG() << "Unable to compute union bounding box of slice bounds"; return SliceComputationResult::GenericFailure; } } @@ -1639,7 +1632,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, // Empty union. if (sliceUnionCst.getNumDimAndSymbolVars() == 0) { - LLVM_DEBUG(llvm::dbgs() << "empty slice union - unexpected\n"); + LDBG() << "empty slice union - unexpected"; return SliceComputationResult::GenericFailure; } @@ -1652,7 +1645,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, unsigned innermostCommonLoopDepth = getInnermostCommonLoopDepth(ops, &surroundingLoops); if (loopDepth > innermostCommonLoopDepth) { - LLVM_DEBUG(llvm::dbgs() << "Exceeds max loop depth\n"); + LDBG() << "Exceeds max loop depth"; return SliceComputationResult::GenericFailure; } @@ -1696,7 +1689,7 @@ mlir::affine::computeSliceUnion(ArrayRef<Operation *> opsA, // that the slice is valid, otherwise return appropriate failure status. std::optional<bool> isSliceValid = sliceUnion->isSliceValid(); if (!isSliceValid) { - LLVM_DEBUG(llvm::dbgs() << "Cannot determine if the slice is valid\n"); + LDBG() << "Cannot determine if the slice is valid"; return SliceComputationResult::GenericFailure; } if (!*isSliceValid) @@ -2050,7 +2043,8 @@ static std::optional<int64_t> getMemoryFootprintBytes(Block &block, if (failed( region->compute(opInst, /*loopDepth=*/getNestingDepth(&*block.begin())))) { - LLVM_DEBUG(opInst->emitError("error obtaining memory region")); + LDBG() << "Error obtaining memory region"; + opInst->emitError("error obtaining memory region"); return failure(); } @@ -2058,9 +2052,11 @@ static std::optional<int64_t> getMemoryFootprintBytes(Block &block, if (inserted) { it->second = std::move(region); } else if (failed(it->second->unionBoundingBox(*region))) { - LLVM_DEBUG(opInst->emitWarning( + LDBG() << "getMemoryFootprintBytes: unable to perform a union on a " + "memory region"; + opInst->emitWarning( "getMemoryFootprintBytes: unable to perform a union on a memory " - "region")); + "region"); return failure(); } return WalkResult::advance(); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 22608a1..7e5ce26 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -427,6 +427,21 @@ bool mlir::affine::isValidSymbol(Value value) { return false; } +/// A utility function to check if a value is defined at the top level of +/// `region` or is an argument of `region` or is defined above the region. 
+static bool isTopLevelValueOrAbove(Value value, Region *region) { + Region *parentRegion = value.getParentRegion(); + do { + if (parentRegion == region) + return true; + Operation *regionOp = region->getParentOp(); + if (regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) + break; + region = region->getParentOp()->getParentRegion(); + } while (region); + return false; +} + /// A value can be used as a symbol for `region` iff it meets one of the /// following conditions: /// *) It is a constant. @@ -445,19 +460,12 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) { return false; // A top-level value is a valid symbol. - if (region && ::isTopLevelValue(value, region)) + if (region && isTopLevelValueOrAbove(value, region)) return true; auto *defOp = value.getDefiningOp(); - if (!defOp) { - // A block argument that is not a top-level value is a valid symbol if it - // dominates region's parent op. - Operation *regionOp = region ? region->getParentOp() : nullptr; - if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) - if (auto *parentOpRegion = region->getParentOp()->getParentRegion()) - return isValidSymbol(value, parentOpRegion); + if (!defOp) return false; - } // Constant operation is ok. Attribute operandCst; @@ -475,12 +483,6 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) { if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp)) return isDimOpValidSymbol(dimOp, region); - // Check for values dominating `region`'s parent op. - Operation *regionOp = region ? region->getParentOp() : nullptr; - if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>()) - if (auto *parentRegion = region->getParentOp()->getParentRegion()) - return isValidSymbol(value, parentRegion); - return false; } diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp index 6c9adff..ff0157e 100644 --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <iomanip> #include <optional> @@ -95,8 +96,8 @@ static bool canRemoveSrcNodeAfterFusion( // Otherwise, the src loop can't be removed. if (fusedLoopInsPoint != depNodeOp && !fusedLoopInsPoint->isBeforeInBlock(depNodeOp)) { - LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: dst loop doesn't " - "dominate dependence\n"); + LDBG() << "Src loop can't be removed: dst loop doesn't " + << "dominate dependence"; return false; } @@ -109,14 +110,13 @@ static bool canRemoveSrcNodeAfterFusion( if (hasOutDepsAfterFusion || !escapingMemRefs.empty()) { std::optional<bool> isMaximal = fusionSlice.isMaximal(); if (!isMaximal) { - LLVM_DEBUG(llvm::dbgs() << "Src loop can't be removed: can't determine " - "if fusion is maximal\n"); + LDBG() << "Src loop can't be removed: can't determine " + << "if fusion is maximal"; return false; } if (!*isMaximal) { - LLVM_DEBUG(llvm::dbgs() - << "Src loop can't be removed: fusion is not maximal\n"); + LDBG() << "Src loop can't be removed: fusion is not maximal"; return false; } } @@ -190,7 +190,8 @@ static bool isEscapingMemref(Value memref, Block *block) { // Check if this is defined to be an alias of another memref. 
if (auto viewOp = dyn_cast<mlir::ViewLikeOpInterface>(defOp)) - if (isEscapingMemref(viewOp.getViewSource(), block)) + if (memref == viewOp.getViewDest() && + isEscapingMemref(viewOp.getViewSource(), block)) return true; // Any op besides allocating ops wouldn't guarantee alias freedom @@ -279,19 +280,19 @@ static std::optional<double> getAdditionalComputeFraction( AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth, ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost, int64_t &fusedLoopNestComputeCost) { - LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";); + LDBG() << "Determining additional compute fraction..."; // Compute cost of sliced and unsliced src loop nest. // Walk src loop nest and collect stats. LoopNestStats srcLoopNestStats; if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) { - LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n"); + LDBG() << "Failed to get source loop nest stats."; return std::nullopt; } // Compute cost of dst loop nest. LoopNestStats dstLoopNestStats; if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) { - LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n"); + LDBG() << "Failed to get destination loop nest stats."; return std::nullopt; } @@ -304,14 +305,14 @@ static std::optional<double> getAdditionalComputeFraction( const ComputationSliceState &slice = depthSliceUnions[depth - 1]; // Skip slice union if it wasn't computed for this depth. if (slice.isEmpty()) { - LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n"); + LDBG() << "Slice wasn't computed."; return std::nullopt; } if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp, dstLoopNestStats, slice, &fusedLoopNestComputeCost)) { - LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n"); + LDBG() << "Unable to compute fusion compute cost"; return std::nullopt; } @@ -348,9 +349,8 @@ static Value createPrivateMemRef(AffineForOp forOp, MemRefAccess bM(cast<AffineWriteOpInterface>(b)); return aM == bM; })) { - LLVM_DEBUG(llvm::dbgs() - << "Private memref creation unsupported for multiple producer " - "stores with different access functions.\n"); + LDBG() << "Private memref creation unsupported for multiple producer " + << "stores with different access functions."; return nullptr; } @@ -455,8 +455,7 @@ static Value createPrivateMemRef(AffineForOp forOp, assert(succeeded(res) && "replaceAllMemrefUsesWith should always succeed here"); (void)res; - LLVM_DEBUG(llvm::dbgs() << "Created private memref of type: " << newMemRefType - << '\n'); + LDBG() << "Created private memref of type: " << newMemRefType; return newMemRef; } @@ -505,15 +504,12 @@ static bool isFusionProfitable(AffineForOp srcForOp, unsigned maxLegalFusionDepth, unsigned *dstLoopDepth, double computeToleranceThreshold) { - LLVM_DEBUG({ - llvm::dbgs() - << "Checking whether fusion is profitable between source nest:\n"; - llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n"; - llvm::dbgs() << dstForOp << "\n"; - }); + LDBG() << "Checking whether fusion is profitable between source nest:"; + LDBG() << ' ' << srcForOp << " and destination nest:"; + LDBG() << dstForOp; if (maxLegalFusionDepth == 0) { - LLVM_DEBUG(llvm::dbgs() << "Can't fuse: maxLegalFusionDepth is 0\n"); + LDBG() << "Can't fuse: maxLegalFusionDepth is 0"; return false; } @@ -537,8 +533,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, // TODO: Suppport multiple producer stores in profitability // analysis. 
if (producerStores.size() > 1) { - LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not " - "supported for multiple producer store case.\n"); + LDBG() << "Limited profitability analysis. Not " + << "supported for multiple producer store case."; int64_t sliceCost; int64_t fusedLoopNestComputeCost; // We will still fuse if fusion obeys the specified compute @@ -547,12 +543,11 @@ static bool isFusionProfitable(AffineForOp srcForOp, srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost, fusedLoopNestComputeCost); if (!fraction || fraction > computeToleranceThreshold) { - LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds " - "compute tolerance. Not fusing.\n"); + LDBG() << "Additional computation exceeds " + << "compute tolerance. Not fusing."; return false; } - LLVM_DEBUG(llvm::dbgs() - << "Considering fusion profitable at max legal depth.\n"); + LDBG() << "Considering fusion profitable at max legal depth."; return true; } @@ -574,8 +569,7 @@ static bool isFusionProfitable(AffineForOp srcForOp, // Compute src loop nest write region size. MemRefRegion srcWriteRegion(srcStoreOp->getLoc()); if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) { - LLVM_DEBUG(llvm::dbgs() - << "Unable to compute MemRefRegion for source operation\n"); + LDBG() << "Unable to compute MemRefRegion for source operation"; return false; } @@ -609,8 +603,7 @@ static bool isFusionProfitable(AffineForOp srcForOp, getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions, sliceCost, fusedLoopNestComputeCost); if (!mayAdditionalComputeFraction) { - LLVM_DEBUG(llvm::dbgs() - << "Can't determine additional compute fraction.\n"); + LDBG() << "Can't determine additional compute fraction."; continue; } double additionalComputeFraction = *mayAdditionalComputeFraction; @@ -620,9 +613,7 @@ static bool isFusionProfitable(AffineForOp srcForOp, // depth 'i'. MemRefRegion sliceWriteRegion(srcStoreOp->getLoc()); if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to compute slice write region at loopDepth: " << i - << "\n"); + LDBG() << "Failed to compute slice write region at loopDepth: " << i; continue; } @@ -630,9 +621,7 @@ static bool isFusionProfitable(AffineForOp srcForOp, sliceWriteRegion.getRegionSize(); if (!maybeSliceWriteRegionSizeBytes.has_value() || *maybeSliceWriteRegionSizeBytes == 0) { - LLVM_DEBUG(llvm::dbgs() - << "Failed to get slice write region size at loopDepth: " << i - << "\n"); + LDBG() << "Failed to get slice write region size at loopDepth: " << i; continue; } int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes; @@ -649,9 +638,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, << " storage reduction factor: " << storageReduction << "x\n" << " fused nest cost: " << fusedLoopNestComputeCost << "\n" << " src write region size: " << srcWriteRegionSizeBytes << "\n" - << " slice write region size: " << sliceWriteRegionSizeBytes - << "\n"; - llvm::dbgs() << msg.str(); + << " slice write region size: " << sliceWriteRegionSizeBytes; + LDBG() << msg.str(); }); // TODO: This is a placeholder cost model. @@ -670,28 +658,24 @@ static bool isFusionProfitable(AffineForOp srcForOp, // A simple cost model: fuse if it reduces the memory footprint. 
if (!bestDstLoopDepth) { - LLVM_DEBUG( - llvm::dbgs() - << "All fusion choices involve more than the threshold amount of " - "redundant computation; NOT fusing.\n"); + LDBG() << "All fusion choices involve more than the threshold amount of " + << "redundant computation; NOT fusing."; return false; } if (!bestDstLoopDepth) { - LLVM_DEBUG(llvm::dbgs() << "no fusion depth could be evaluated.\n"); + LDBG() << "no fusion depth could be evaluated."; return false; } // Set dstLoopDepth based on best values from search. *dstLoopDepth = *bestDstLoopDepth; - LLVM_DEBUG( - llvm::dbgs() << " LoopFusion fusion stats:" - << "\n best loop depth: " << bestDstLoopDepth - << "\n src loop nest compute cost: " << srcLoopNestCost - << "\n dst loop nest compute cost: " << dstLoopNestCost - << "\n fused loop nest compute cost: " - << minFusedLoopNestComputeCost << "\n"); + LDBG() << " LoopFusion fusion stats:"; + LDBG() << " best loop depth: " << bestDstLoopDepth; + LDBG() << " src loop nest compute cost: " << srcLoopNestCost; + LDBG() << " dst loop nest compute cost: " << dstLoopNestCost; + LDBG() << " fused loop nest compute cost: " << minFusedLoopNestComputeCost; auto dstMemSize = getMemoryFootprintBytes(dstForOp); auto srcMemSize = getMemoryFootprintBytes(srcForOp); @@ -699,8 +683,7 @@ static bool isFusionProfitable(AffineForOp srcForOp, std::optional<double> storageReduction; if (!dstMemSize || !srcMemSize) { - LLVM_DEBUG(llvm::dbgs() - << " fusion memory benefit cannot be evaluated; NOT fusing.\n"); + LDBG() << " fusion memory benefit cannot be evaluated; NOT fusing."; return false; } @@ -710,13 +693,13 @@ static bool isFusionProfitable(AffineForOp srcForOp, assert(sliceMemEstimate && "expected value"); auto fusedMem = dstMemSizeVal + *sliceMemEstimate; - LLVM_DEBUG(llvm::dbgs() << " src mem: " << srcMemSizeVal << "\n" - << " dst mem: " << dstMemSizeVal << "\n" - << " fused mem: " << fusedMem << "\n" - << " slice mem: " << sliceMemEstimate << "\n"); + LDBG() << " src mem: " << srcMemSizeVal; + LDBG() << " dst mem: " << dstMemSizeVal; + LDBG() << " fused mem: " << fusedMem; + LDBG() << " slice mem: " << sliceMemEstimate; if (static_cast<long>(fusedMem) > srcMemSizeVal + dstMemSizeVal) { - LLVM_DEBUG(llvm::dbgs() << "Fusion is not profitable; NOT fusing.\n"); + LDBG() << "Fusion is not profitable; NOT fusing."; return false; } storageReduction = @@ -734,8 +717,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, << std::setprecision(2) << additionalComputeFraction << "% redundant computation and a "; msg << (storageReduction ? std::to_string(*storageReduction) : "<unknown>"); - msg << "% storage reduction.\n"; - llvm::dbgs() << msg.str(); + msg << "% storage reduction."; + LDBG() << msg.str(); }); return true; @@ -895,7 +878,7 @@ public: /// No fusion is performed when producers with a user count greater than /// `maxSrcUserCount` for any of the memrefs involved. void performFusionsIntoDest(unsigned dstId, unsigned maxSrcUserCount) { - LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); + LDBG() << "Evaluating dst loop " << dstId; // Skip if this node was removed (fused into another node). if (mdg->nodes.count(dstId) == 0) return; @@ -909,7 +892,7 @@ public: if (dstNode->op->getNumResults() > 0) return; - LLVM_DEBUG(llvm::dbgs() << "Evaluating dst loop " << dstId << "\n"); + LDBG() << "Evaluating dst loop " << dstId; // Sink sequential loops in 'dstNode' (and thus raise parallel loops) // while preserving relative order. 
This can increase the maximum loop @@ -936,18 +919,14 @@ public: auto *srcNode = mdg->getNode(srcId); auto srcAffineForOp = cast<AffineForOp>(srcNode->op); - LLVM_DEBUG(llvm::dbgs() - << "Trying to fuse producer loop nest " << srcId - << " with consumer loop nest " << dstId << "\n"); - LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: " - << computeToleranceThreshold << '\n'); - LLVM_DEBUG(llvm::dbgs() - << "Producer loop nest:\n" - << *srcNode->op << "\n and consumer loop nest:\n" - << *dstNode->op << '\n'); + LDBG() << "Trying to fuse producer loop nest " << srcId + << " with consumer loop nest " << dstId; + LDBG() << "Compute tolerance threshold: " << computeToleranceThreshold; + LDBG() << "Producer loop nest:"; + LDBG() << *srcNode->op << " and consumer loop nest:"; + LDBG() << *dstNode->op; - LLVM_DEBUG(llvm::dbgs() << "Evaluating src loop " << srcId - << " for dst loop " << dstId << "\n"); + LDBG() << "Evaluating src loop " << srcId << " for dst loop " << dstId; // Skip if 'srcNode' is a loop nest returning values. // TODO: support loop nests that return values. @@ -1018,19 +997,16 @@ public: &depthSliceUnions[i - 1], strategy); if (result.value == FusionResult::Success) { maxLegalFusionDepth = i; - LLVM_DEBUG(llvm::dbgs() - << "Found valid slice for depth: " << i << '\n'); + LDBG() << "Found valid slice for depth: " << i; } } if (maxLegalFusionDepth == 0) { - LLVM_DEBUG(llvm::dbgs() - << "Can't fuse: fusion is not legal at any depth\n"); + LDBG() << "Can't fuse: fusion is not legal at any depth"; continue; } - LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: " - << maxLegalFusionDepth << '\n'); + LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth; double computeToleranceThresholdToUse = computeToleranceThreshold; @@ -1040,7 +1016,7 @@ public: // producer-consumer memref access for example). Check this and allow // fusion accordingly. if (hasCyclicDependence(srcAffineForOp)) { - LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n"); + LDBG() << "Source nest has a cyclic dependence."; // Maximal fusion does not check for compute tolerance threshold; so // perform the maximal fusion only when the redundanation computation // is zero. @@ -1053,18 +1029,15 @@ public: srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost, fusedLoopNestComputeCost); if (!fraction || fraction > 0) { - LLVM_DEBUG( - llvm::dbgs() - << "Can't perform maximal fusion with a cyclic dependence " - "and non-zero additional compute.\n"); + LDBG() << "Can't perform maximal fusion with a cyclic dependence " + << "and non-zero additional compute."; return; } } else { // Set redundant computation tolerance to zero regardless of what // the user specified. Without this, fusion would be invalid. - LLVM_DEBUG(llvm::dbgs() - << "Setting compute tolerance to zero since " - "source has a cylic dependence.\n"); + LDBG() << "Setting compute tolerance to zero since " + << "source has a cylic dependence."; computeToleranceThresholdToUse = 0; } } @@ -1107,8 +1080,7 @@ public: if (canCreatePrivateMemRef(memref, srcEscapingMemRefs, srcId, dstId, removeSrcNode)) { // Create a private version of this memref. - LLVM_DEBUG(llvm::dbgs() - << "Creating private memref for " << memref << '\n'); + LDBG() << "Creating private memref for " << memref; // Create a private version of this memref. 
privateMemrefs.insert(memref); } @@ -1118,10 +1090,9 @@ public: fuseLoops(srcAffineForOp, dstAffineForOp, bestSlice); dstNodeChanged = true; - LLVM_DEBUG(llvm::dbgs() - << "Fused src loop " << srcId << " into dst loop " << dstId - << " at depth " << bestDstLoopDepth << ":\n" - << dstAffineForOp << "\n"); + LDBG() << "Fused src loop " << srcId << " into dst loop " << dstId + << " at depth " << bestDstLoopDepth << ":"; + LDBG() << dstAffineForOp; // Move 'dstAffineForOp' before 'insertPointInst' if needed. if (fusedLoopInsPoint != dstAffineForOp) @@ -1179,8 +1150,7 @@ public: dstLoopCollector.memrefFrees); if (removeSrcNode) { - LLVM_DEBUG(llvm::dbgs() - << "Removing src loop " << srcId << " after fusion\n"); + LDBG() << "Removing src loop " << srcId << " after fusion"; // srcNode is no longer valid after it is removed from mdg. srcAffineForOp.erase(); mdg->removeNode(srcId); @@ -1195,7 +1165,7 @@ public: /// user count greater than `maxSrcUserCount` for any of the memrefs involved /// are encountered. void fuseProducerConsumerNodes(unsigned maxSrcUserCount) { - LLVM_DEBUG(llvm::dbgs() << "--- Producer/Consumer Fusion ---\n"); + LDBG() << "--- Producer/Consumer Fusion ---"; init(); while (!worklist.empty()) { unsigned dstId = worklist.back(); @@ -1207,7 +1177,7 @@ public: // Visits each node in the graph, and for each node, attempts to fuse it with // its sibling nodes (nodes which share a parent, but no dependence edges). void fuseSiblingNodes() { - LLVM_DEBUG(llvm::dbgs() << "--- Sibling Fusion ---\n"); + LDBG() << "--- Sibling Fusion ---"; init(); while (!worklist.empty()) { unsigned dstId = worklist.back(); @@ -1289,8 +1259,7 @@ public: maxLegalFusionDepth = i; } - LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: " - << maxLegalFusionDepth << '\n'); + LDBG() << "Max legal depth for fusion: " << maxLegalFusionDepth; // Skip if fusion is not feasible at any loop depths. if (maxLegalFusionDepth == 0) @@ -1304,7 +1273,7 @@ public: // producer-consumer memref access for example). Check this and allow // fusion accordingly. if (hasCyclicDependence(sibAffineForOp)) { - LLVM_DEBUG(llvm::dbgs() << "Source nest has a cyclic dependence.\n"); + LDBG() << "Source nest has a cyclic dependence."; // Maximal fusion does not check for compute tolerance threshold; so // perform the maximal fusion only when the redundanation computation is // zero. @@ -1316,17 +1285,15 @@ public: sibAffineForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost, fusedLoopNestComputeCost); if (!fraction || fraction > 0) { - LLVM_DEBUG( - llvm::dbgs() - << "Can't perform maximal fusion with a cyclic dependence " - "and non-zero additional compute.\n"); + LDBG() << "Can't perform maximal fusion with a cyclic dependence " + << "and non-zero additional compute."; return; } } else { // Set redundant computation tolerance to zero regardless of what the // user specified. Without this, fusion would be invalid. - LLVM_DEBUG(llvm::dbgs() << "Setting compute tolerance to zero since " - "source has a cyclic dependence.\n"); + LDBG() << "Setting compute tolerance to zero since " + << "source has a cyclic dependence."; computeToleranceThresholdToUse = 0.0; } } @@ -1356,8 +1323,7 @@ public: // slice is used in the destination. 
auto isMaximal = bestSlice.isMaximal(); if (!isMaximal.value_or(false)) { - LLVM_DEBUG(llvm::dbgs() - << "Slice isn't maximal; not performing sibling fusion.\n"); + LDBG() << "Slice isn't maximal; not performing sibling fusion."; continue; } @@ -1374,10 +1340,9 @@ public: if (insertPointInst != dstForInst) dstForInst->moveBefore(insertPointInst); - LLVM_DEBUG(llvm::dbgs() - << "Fused sibling nest " << sibId << " into destination nest " - << dstNode->id << " at depth " << bestDstLoopDepth << ":\n" - << dstAffineForOp << "\n"); + LDBG() << "Fused sibling nest " << sibId << " into destination nest " + << dstNode->id << " at depth " << bestDstLoopDepth << ":"; + LDBG() << dstAffineForOp; // Update data dependence graph state post fusion. updateStateAfterSiblingFusion(sibNode, dstNode); @@ -1555,7 +1520,7 @@ public: void LoopFusion::runOnBlock(Block *block) { MemRefDependenceGraph g(*block); if (!g.init()) { - LLVM_DEBUG(llvm::dbgs() << "MDG init failed\n"); + LDBG() << "MDG init failed"; return; } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp index 41cd739..c6abb0d 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <optional> @@ -251,20 +252,20 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp, FusionStrategy fusionStrategy) { // Return 'failure' if 'dstLoopDepth == 0'. if (dstLoopDepth == 0) { - LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests at depth 0\n"); + LDBG() << "Cannot fuse loop nests at depth 0"; return FusionResult::FailPrecondition; } // Return 'failure' if 'srcForOp' and 'dstForOp' are not in the same block. auto *block = srcForOp->getBlock(); if (block != dstForOp->getBlock()) { - LLVM_DEBUG(llvm::dbgs() << "Cannot fuse loop nests in different blocks\n"); + LDBG() << "Cannot fuse loop nests in different blocks"; return FusionResult::FailPrecondition; } // Return 'failure' if no valid insertion point for fused loop nest in 'block' // exists which would preserve dependences. if (!getFusedLoopNestInsertionPoint(srcForOp, dstForOp)) { - LLVM_DEBUG(llvm::dbgs() << "Fusion would violate dependences in block\n"); + LDBG() << "Fusion would violate dependences in block"; return FusionResult::FailBlockDependence; } @@ -277,14 +278,14 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp, // Gather all load and store from 'forOpA' which precedes 'forOpB' in 'block'. SmallVector<Operation *, 4> opsA; if (!gatherLoadsAndStores(forOpA, opsA)) { - LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n"); + LDBG() << "Fusing loops with affine.if unsupported"; return FusionResult::FailPrecondition; } // Gather all load and store from 'forOpB' which succeeds 'forOpA' in 'block'. SmallVector<Operation *, 4> opsB; if (!gatherLoadsAndStores(forOpB, opsB)) { - LLVM_DEBUG(llvm::dbgs() << "Fusing loops with affine.if unsupported\n"); + LDBG() << "Fusing loops with affine.if unsupported"; return FusionResult::FailPrecondition; } @@ -296,7 +297,7 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp, // TODO: 'getMaxLoopDepth' does not support forward slice fusion. 
assert(isSrcForOpBeforeDstForOp && "Unexpected forward slice fusion"); if (getMaxLoopDepth(opsA, opsB) < dstLoopDepth) { - LLVM_DEBUG(llvm::dbgs() << "Fusion would violate loop dependences\n"); + LDBG() << "Fusion would violate loop dependences"; return FusionResult::FailFusionDependence; } } @@ -339,12 +340,12 @@ FusionResult mlir::affine::canFuseLoops(AffineForOp srcForOp, strategyOpsA, opsB, dstLoopDepth, numCommonLoops, isSrcForOpBeforeDstForOp, srcSlice); if (sliceComputationResult.value == SliceComputationResult::GenericFailure) { - LLVM_DEBUG(llvm::dbgs() << "computeSliceUnion failed\n"); + LDBG() << "computeSliceUnion failed"; return FusionResult::FailPrecondition; } if (sliceComputationResult.value == SliceComputationResult::IncorrectSliceFailure) { - LLVM_DEBUG(llvm::dbgs() << "Incorrect slice computation\n"); + LDBG() << "Incorrect slice computation"; return FusionResult::FailIncorrectSlice; } @@ -477,7 +478,7 @@ bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot, auto *parentForOp = forOp->getParentOp(); if (forOp != forOpRoot) { if (!isa<AffineForOp>(parentForOp)) { - LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n"); + LDBG() << "Expected parent AffineForOp"; return WalkResult::interrupt(); } // Add mapping to 'forOp' from its parent AffineForOp. @@ -498,7 +499,7 @@ bool mlir::affine::getLoopNestStats(AffineForOp forOpRoot, std::optional<uint64_t> maybeConstTripCount = getConstantTripCount(forOp); if (!maybeConstTripCount) { // Currently only constant trip count loop nests are supported. - LLVM_DEBUG(llvm::dbgs() << "Non-constant trip count unsupported\n"); + LDBG() << "Non-constant trip count unsupported"; return WalkResult::interrupt(); } diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp index 2de057d..cd216ef 100644 --- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp @@ -21,9 +21,11 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/IntegerSet.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/MapVector.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <optional> @@ -365,12 +367,11 @@ checkIfHyperRectangular(MutableArrayRef<AffineForOp> input) { if (input.size() <= 1) return success(); if (failed(getIndexSet(ops, &cst))) { - LLVM_DEBUG(llvm::dbgs() << "Index set computation failed!\n"); + LDBG() << "Index set computation failed!"; return failure(); } if (!cst.isHyperRectangular(0, input.size())) { - LLVM_DEBUG(llvm::dbgs() - << "Non-hyperrectangular nests not supported for tiling!\n"); + LDBG() << "Non-hyperrectangular nests not supported for tiling!"; return failure(); } return success(); @@ -385,14 +386,13 @@ static LogicalResult performPreTilingChecks(MutableArrayRef<AffineForOp> input, if (llvm::any_of(input, [](AffineForOp op) { return op.getNumResults() > 0; })) { - LLVM_DEBUG(llvm::dbgs() - << "Cannot tile nest where a loop has yield values\n"); + LDBG() << "Cannot tile nest where a loop has yield values"; return failure(); } // Check if the supplied `for` ops are all successively nested. 
if (!isPerfectlyNested(input)) { - LLVM_DEBUG(llvm::dbgs() << "input loops not perfectly nested"); + LDBG() << "input loops not perfectly nested"; return failure(); } @@ -1098,7 +1098,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp, // If the trip count is lower than the unroll jam factor, no unroll jam. if (mayBeConstantTripCount && *mayBeConstantTripCount < unrollJamFactor) { - LLVM_DEBUG(llvm::dbgs() << "[failed] trip count < unroll-jam factor\n"); + LDBG() << "[failed] trip count < unroll-jam factor"; return failure(); } @@ -1339,6 +1339,15 @@ bool mlir::affine::isValidLoopInterchangePermutation( unsigned maxLoopDepth = loops.size(); if (maxLoopDepth == 1) return true; + + // We cannot guarantee the validity of the interchange if the loops have + // iter_args, since the dependence analysis does not take them into account. + // Conservatively return false in such cases. + if (llvm::any_of(loops, [](AffineForOp loop) { + return loop.getNumIterOperands() > 0; + })) + return false; + // Gather dependence components for dependences between all ops in loop nest // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth]. std::vector<SmallVector<DependenceComponent, 2>> depCompsVec; @@ -1766,9 +1775,7 @@ findHighestBlockForPlacement(const MemRefRegion &region, Block &block, // We can't hoist past the definition of the memref being copied. Value memref = region.memref; if (!memref.getParentRegion()->isAncestor(enclosingOp->getParentRegion())) { - LLVM_DEBUG( - llvm::dbgs() - << "memref definition will end up not dominating hoist location\n"); + LDBG() << "memref definition will end up not dominating hoist location"; break; } @@ -1977,7 +1984,7 @@ static LogicalResult generateCopy( auto memRefType = cast<MemRefType>(memref.getType()); if (!memRefType.getLayout().isIdentity()) { - LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n"); + LDBG() << "Non-identity layout map not yet supported"; return failure(); } @@ -1989,7 +1996,7 @@ static LogicalResult generateCopy( unsigned rank = memRefType.getRank(); if (rank == 0) { - LLVM_DEBUG(llvm::dbgs() << "Non-zero ranked memrefs supported\n"); + LDBG() << "Non-zero ranked memrefs supported"; return failure(); } @@ -2001,19 +2008,18 @@ static LogicalResult generateCopy( std::optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(&fastBufferShape, &lbs); if (!numElements) { - LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n"); + LDBG() << "Non-constant region size not supported"; return failure(); } if (llvm::any_of(lbs, [](AffineMap lb) { return lb.getNumResults() > 1; })) { // This can be supported in the future if needed. 
- LLVM_DEBUG(llvm::dbgs() - << "Max lower bound for memref region start not supported\n"); + LDBG() << "Max lower bound for memref region start not supported"; return failure(); } if (*numElements == 0) { - LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n"); + LDBG() << "Nothing to copy"; return success(); } @@ -2021,9 +2027,8 @@ static LogicalResult generateCopy( for (unsigned i = 0; i < rank; ++i) { region.getLowerAndUpperBound(i, lbMaps[i], ubMaps[i]); if (lbMaps[i].getNumResults() == 0 || ubMaps[i].getNumResults() == 0) { - LLVM_DEBUG(llvm::dbgs() - << "Missing lower or upper bound for region along dimension: " - << i << '\n'); + LDBG() << "Missing lower or upper bound for region along dimension: " + << i; return failure(); } } @@ -2122,7 +2127,7 @@ static LogicalResult generateCopy( // TODO: use all stride levels once DmaStartOp is extended for // multi-level strides. if (dmaStrideInfos.size() > 1) { - LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n"); + LDBG() << "Only up to one level of stride supported"; return failure(); } @@ -2309,10 +2314,11 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, // surrounding the this block range. unsigned copyDepth = getNestingDepth(&*begin); - LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth - << "\n"); - LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n"); - LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n"); + LDBG() << "Generating copies at depth " << copyDepth; + LDBG() << "from begin: " + << OpWithFlags(&*begin, OpPrintingFlags().skipRegions()); + LDBG() << "to inclusive end: " + << OpWithFlags(&*std::prev(end), OpPrintingFlags().skipRegions()); // List of memory regions to copy for. We need a map vector to have a // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here @@ -2349,8 +2355,8 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, return; if (!memref.getParentRegion()->isAncestor(block->getParent())) { - LLVM_DEBUG(llvm::dbgs() << "memref definition is inside of the depth at " - "which copy-in/copy-out would happen\n"); + LDBG() << "memref definition is inside of the depth at " + << "which copy-in/copy-out would happen"; return; } @@ -2358,12 +2364,10 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, auto region = std::make_unique<MemRefRegion>(opInst->getLoc()); if (failed(region->compute(opInst, copyDepth, /*sliceState=*/nullptr, /*addMemRefDimBounds=*/false))) { - LLVM_DEBUG(llvm::dbgs() - << "Error obtaining memory region: semi-affine maps?\n"); - LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n"); + LDBG() << "Error obtaining memory region: semi-affine maps?"; + LDBG() << "over-approximating to the entire memref"; if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) { - LLVM_DEBUG( - opInst->emitError("non-constant memref sizes not yet supported")); + LDBG() << "non-constant memref sizes not yet supported"; error = true; return; } @@ -2392,13 +2396,11 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, // Perform a union with the existing region. if (failed(it->second->unionBoundingBox(*region))) { - LLVM_DEBUG(llvm::dbgs() - << "Memory region bounding box failed; " - "over-approximating to the entire memref\n"); + LDBG() << "Memory region bounding box failed; " + << "over-approximating to the entire memref"; // If the union fails, we will overapproximate. 
if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) { - LLVM_DEBUG(opInst->emitError( - "non-constant memref sizes not yet supported")); + LDBG() << "non-constant memref sizes not yet supported"; error = true; return true; } @@ -2428,8 +2430,7 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, }); if (error) { - LLVM_DEBUG(begin->emitError( - "copy generation failed for one or more memref's in this block\n")); + LDBG() << "copy generation failed for one or more memref's in this block"; return failure(); } @@ -2466,8 +2467,7 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end, processRegions(writeRegions); if (!ret) { - LLVM_DEBUG(begin->emitError( - "copy generation failed for one or more memref's in this block\n")); + LDBG() << "copy generation failed for one or more memref's in this block"; return failure(); } @@ -2608,7 +2608,7 @@ static AffineIfOp createSeparationCondition(MutableArrayRef<AffineForOp> loops, /*boundFloorDivisor=*/nullptr, /*ub=*/nullptr, &fullTileLbPos, &fullTileUbPos)) { - LLVM_DEBUG(llvm::dbgs() << "Can't get constant diff pair for a loop\n"); + LDBG() << "Can't get constant diff pair for a loop"; return nullptr; } @@ -2667,8 +2667,7 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest, for (auto loop : inputNest) { // TODO: straightforward to generalize to a non-unit stride. if (loop.getStepAsInt() != 1) { - LLVM_DEBUG(llvm::dbgs() - << "[tile separation] non-unit stride not implemented\n"); + LDBG() << "[tile separation] non-unit stride not implemented"; return failure(); } SmallVector<Operation *, 1> loopOp{loop.getOperation()}; @@ -2682,8 +2681,8 @@ createFullTiles(MutableArrayRef<AffineForOp> inputNest, /*boundFloorDivisor=*/nullptr, /*ub=*/nullptr, &lbPos, &ubPos) || lbPos == ubPos) { - LLVM_DEBUG(llvm::dbgs() << "[tile separation] Can't get constant diff / " - "equalities not yet handled\n"); + LDBG() << "[tile separation] Can't get constant diff / " + << "equalities not yet handled"; return failure(); } @@ -2741,8 +2740,8 @@ mlir::affine::separateFullTiles(MutableArrayRef<AffineForOp> inputNest, AffineIfOp ifOp = createSeparationCondition(inputNest, b); if (!ifOp) { fullTileLoops.front().erase(); - LLVM_DEBUG(llvm::dbgs() << "All tiles are full tiles, or failure creating " - "separation condition\n"); + LDBG() << "All tiles are full tiles, or failure creating " + << "separation condition"; return failure(); } diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 488c3c3..7d4d818 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType, case AtomicRMWKind::addi: case AtomicRMWKind::maxu: case AtomicRMWKind::ori: + case AtomicRMWKind::xori: return builder.getZeroAttr(resultType); case AtomicRMWKind::andi: return builder.getIntegerAttr( @@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) { // Integer operations. 
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; }) .Case([](arith::OrIOp op) { return AtomicRMWKind::ori; }) - .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; }) + .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; }) .Case([](arith::AndIOp op) { return AtomicRMWKind::andi; }) .Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; }) .Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; }) @@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, return arith::OrIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::andi: return arith::AndIOp::create(builder, loc, lhs, rhs); + case AtomicRMWKind::xori: + return arith::XOrIOp::create(builder, loc, lhs, rhs); // TODO: Add remaining reduction operations. default: (void)emitOptionalError(loc, "Reduction operation type not supported"); diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 93682a9..4780dbb 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -12,7 +12,7 @@ add_mlir_dialect_library(MLIRArithTransforms UnsignedWhenEquivalent.cpp ADDITIONAL_HEADER_DIRS - {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms DEPENDS MLIRArithTransformsIncGen diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp index 1aa8064..35365f2 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp @@ -158,13 +158,11 @@ protected: PatternRewriter &rewriter) { // Check iterator types for matrix multiplication. SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray(); - if (!((itTypes.size() == 3 && - (itTypes[0] == vector::IteratorType::parallel && - itTypes[1] == vector::IteratorType::parallel && - itTypes[2] == vector::IteratorType::reduction)) || - (itTypes.size() == 2 && - (itTypes[0] == vector::IteratorType::parallel && - itTypes[1] == vector::IteratorType::reduction)))) + if ((itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel || + itTypes[1] != vector::IteratorType::parallel || + itTypes[2] != vector::IteratorType::reduction) && + (itTypes.size() != 2 || itTypes[0] != vector::IteratorType::parallel || + itTypes[1] != vector::IteratorType::reduction)) return rewriter.notifyMatchFailure( op, "iterator types do not correspond to matrix multiplication"); diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp index 35b0bd1..6cb2a56 100644 --- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractToSVEPatterns.cpp @@ -183,9 +183,9 @@ protected: Value acc; // Conventional names for matrix dimensions. - int64_t M = 0; - int64_t N = 0; - int64_t K = 0; + int64_t m = 0; + int64_t n = 0; + int64_t k = 0; // Create the matrix mulitply and accumulate operation according to // `mmlaOp`. @@ -286,41 +286,41 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, // Single-dimension vector type for the entire RHS tile. 
- auto flatRhsTileType = VectorType::get(/*shape=*/K * N, operandEltType, + auto flatRhsTileType = VectorType::get(/*shape=*/k * n, operandEltType, /*scalableDims=*/{true}); // Vector type having the same number of elements as a row in the // accumulator/output tile and the same element type. - auto accRowTy = VectorType::get(/*shape=*/N, resultEltType, + auto accRowTy = VectorType::get(/*shape=*/n, resultEltType, /*scalableDims=*/{true}); // Vector type having twice the number of elements as a row in the // accumulator/output tile the same element type. - auto accRowX2Ty = VectorType::get(/*shape=*/2 * N, resultEltType, + auto accRowX2Ty = VectorType::get(/*shape=*/2 * n, resultEltType, /*scalableDims=*/{true}); // Vector type having half the number of elements as a row in the // accumulator/output tile and an integer element type with twice the bit // width. - auto accRow64Ty = VectorType::get(/*shape=*/N / 2, rewriter.getI64Type(), + auto accRow64Ty = VectorType::get(/*shape=*/n / 2, rewriter.getI64Type(), /*scalableDims=*/{true}); // Vector type having the same the number of elements as a row in the // accumulator/output tile and an integer element type with twice the bit // width. - auto accRowX264Ty = VectorType::get(/*shape=*/N, rewriter.getI64Type(), + auto accRowX264Ty = VectorType::get(/*shape=*/n, rewriter.getI64Type(), /*scalableDims=*/{true}); Location loc = op.getLoc(); // Extract LHS sub-tiles with logical shape <2xK>. SmallVector<Value> lhsTile; - for (int64_t i = 0; i < M; i += 2) { + for (int64_t i = 0; i < m; i += 2) { // Extract two consecutive rows of the LHS tile. auto r0 = vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i}); auto r1 = vector::ExtractOp::create(rewriter, loc, lhs, ArrayRef<int64_t>{i + 1}); // Concatenate to obtain a 2 x K x <input-type> flattened sub-tile. - SmallVector<int64_t> shuffleIdx(2 * K); + SmallVector<int64_t> shuffleIdx(2 * k); std::iota(shuffleIdx.begin(), shuffleIdx.end(), 0); auto t = vector::ShuffleOp::create(rewriter, loc, r0, r1, shuffleIdx); // Turn it into a scalable vector. @@ -337,13 +337,13 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, // Extract the RHS sub-tiles with logical shape <Kx[2]>. SmallVector<Value> rhsTile; - for (int64_t j = 0; j < N; j += 2) + for (int64_t j = 0; j < n; j += 2) rhsTile.push_back(vector::ScalableExtractOp::create( - rewriter, loc, flatRhsType, rhs, j * K)); + rewriter, loc, flatRhsType, rhs, j * k)); // Extract and pack the ACC sub-tiles. SmallVector<Value> accTile; - for (int64_t i = 0; i < M; i += 2) { + for (int64_t i = 0; i < m; i += 2) { // Extract two consecutive rows of the accumulator tile. auto r0 = vector::ExtractOp::create(rewriter, loc, op.getAcc(), ArrayRef<int64_t>{i}); @@ -370,28 +370,28 @@ Value VectorContractRewriter::lower(vector::ContractionOp op, vector::BitCastOp::create(rewriter, loc, accRowX2Ty, intrI64); } // Extract ACC sub-tiles. - for (int64_t j = 0; j < N; j += 2) + for (int64_t j = 0; j < n; j += 2) accTile.push_back(vector::ScalableExtractOp::create( rewriter, loc, flatAccType, accTileVec, j * 2)); } // Emit sub-tile matrix multiplications. 
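In the LHS packing above, shuffleIdx is filled with 0..2K-1 via std::iota, and vector.shuffle indexes into the concatenation of its two operands, so the shuffle simply yields r0 followed by r1. A small model of that index convention (hypothetical helper, plain C++):

#include <cassert>
#include <numeric>
#include <vector>

int main() {
  const int K = 4;
  std::vector<int> r0 = {0, 1, 2, 3}, r1 = {10, 11, 12, 13};
  std::vector<int> idx(2 * K);
  std::iota(idx.begin(), idx.end(), 0);
  // Model vector.shuffle: index i < K selects r0[i], otherwise r1[i - K].
  std::vector<int> out;
  for (int i : idx)
    out.push_back(i < K ? r0[i] : r1[i - K]);
  assert(out.size() == idx.size() && out[0] == 0 && out[K] == 10);
  return 0;
}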
SmallVector<Value> outTile; - for (int64_t i = 0; i < M / 2; ++i) - for (int64_t j = 0; j < N / 2; ++j) { - Value mmla = createMMLA(rewriter, loc, accTile[i * N / 2 + j], lhsTile[i], + for (int64_t i = 0; i < m / 2; ++i) + for (int64_t j = 0; j < n / 2; ++j) { + Value mmla = createMMLA(rewriter, loc, accTile[i * n / 2 + j], lhsTile[i], rhsTile[j]); outTile.push_back(mmla); } // Unpack the OUT sub-tiles and insert into the result. Value result = ub::PoisonOp::create(rewriter, loc, op.getResultType()); - for (int64_t i = 0; i < M / 2; ++i) { + for (int64_t i = 0; i < m / 2; ++i) { // Collect a number of sub-tiles in a row. Value row = ub::PoisonOp::create(rewriter, loc, accRowX2Ty); - for (int64_t j = 0; j < N / 2; ++j) + for (int64_t j = 0; j < n / 2; ++j) row = vector::ScalableInsertOp::create( - rewriter, loc, outTile[i * N / 2 + j], row, j * 4); + rewriter, loc, outTile[i * n / 2 + j], row, j * 4); // Unpack the row to obtain two rows of the output. If we have the out // sub-tiles transposed we obtain two consecutive output rows by @@ -432,9 +432,9 @@ public: VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); - M = lhsType.getDimSize(0); - N = rhsType.getDimSize(0); - K = rhsType.getDimSize(1); + m = lhsType.getDimSize(0); + n = rhsType.getDimSize(0); + k = rhsType.getDimSize(1); // Check the operands have the expected shape: // * for LHS: fixed vector MxK @@ -442,8 +442,8 @@ public: // * K == 8 // * M and N even and at least 2 if (lhsType.isScalable() || !rhsType.getScalableDims()[0] || - rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 8 || - M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 || + rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 8 || + m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 || !rhsType.getScalableDims()[0]) return rewriter.notifyMatchFailure(op, "non-matching operand shape"); @@ -504,9 +504,9 @@ public: VectorType lhsType = op.getLhsType(); VectorType rhsType = op.getRhsType(); - M = lhsType.getDimSize(0); - N = rhsType.getDimSize(0); - K = rhsType.getDimSize(1); + m = lhsType.getDimSize(0); + n = rhsType.getDimSize(0); + k = rhsType.getDimSize(1); // Check the operands have the expected shape: // * for LHS: fixed vector MxK @@ -514,8 +514,8 @@ public: // * K == 4 // * M and N even and at least 2 if (lhsType.isScalable() || !rhsType.getScalableDims()[0] || - rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != K || K != 4 || - M < 2 || M % 2 != 0 || N < 2 || N % 2 != 0 || + rhsType.getScalableDims()[1] || lhsType.getDimSize(1) != k || k != 4 || + m < 2 || m % 2 != 0 || n < 2 || n % 2 != 0 || !rhsType.getScalableDims()[0]) return rewriter.notifyMatchFailure(op, "non-matching operand shape"); diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp index ddc64ea..91e37dd 100644 --- a/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncRuntimeRefCounting.cpp @@ -248,7 +248,7 @@ LogicalResult AsyncRuntimeRefCountingPass::addDropRefAfterLastUse(Value value) { Region *definingRegion = value.getParentRegion(); // Last users of the `value` inside all blocks where the value dies. 
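The matcher above (the 8-bit integer MMLA case) requires a fixed MxK LHS with K equal to 8, and M and N even and at least 2, since the lowering walks the operands in 2x2 sub-tiles. The same shape precondition, restated as an illustrative predicate:

#include <cassert>

// Sketch only: mirrors the "K == 8, M and N even and at least 2" checks above.
static bool validShape(long m, long n, long k) {
  return k == 8 && m >= 2 && m % 2 == 0 && n >= 2 && n % 2 == 0;
}

int main() {
  assert(validShape(4, 6, 8));
  assert(!validShape(3, 6, 8));  // odd M
  assert(!validShape(4, 6, 4));  // wrong K for this pattern
  return 0;
}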
- llvm::SmallSet<Operation *, 4> lastUsers; + llvm::SmallPtrSet<Operation *, 4> lastUsers; // Find blocks in the `definingRegion` that have users of the `value` (if // there are multiple users in the block, which one will be selected is diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index f1f12f4..56ff212 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -463,8 +463,12 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> { // which otherwise could prevent removal of unnecessary allocs. Value canonicalSource = source; while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>( - canonicalSource.getDefiningOp())) + canonicalSource.getDefiningOp())) { + if (canonicalSource != iface.getViewDest()) { + break; + } canonicalSource = iface.getViewSource(); + } std::optional<Operation *> maybeCloneDeallocOp = memref::findDealloc(cloneOp.getOutput()); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp index 8916526..a465c95 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferDeallocationSimplification.cpp @@ -37,8 +37,12 @@ using namespace mlir::bufferization; /// Given a memref value, return the "base" value by skipping over all /// ViewLikeOpInterface ops (if any) in the reverse use-def chain. static Value getViewBase(Value value) { - while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) + while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) { + if (value != viewLikeOp.getViewDest()) { + break; + } value = viewLikeOp.getViewSource(); + } return value; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp index 8f983ab..0b2e080 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferViewFlowAnalysis.cpp @@ -121,7 +121,7 @@ void BufferViewFlowAnalysis::build(Operation *op) { // Add additional dependencies created by view changes to the alias list. if (auto viewInterface = dyn_cast<ViewLikeOpInterface>(op)) { registerDependencies(viewInterface.getViewSource(), - viewInterface->getResult(0)); + viewInterface.getViewDest()); return WalkResult::advance(); } @@ -231,8 +231,12 @@ static bool isFunctionArgument(Value v) { /// Given a memref value, return the "base" value by skipping over all /// ViewLikeOpInterface ops (if any) in the reverse use-def chain. 
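The clone and dealloc canonicalizations above now walk the use-def chain through a view-like op only when the value being traced is that op's view destination; when the result is not a view of the source, the walk stops. A minimal model of that guarded walk (ViewRecord is a hypothetical stand-in, not the MLIR interface):

#include <cassert>
#include <map>

struct ViewRecord { int source; int dest; };  // defining op's view source/dest

static int getViewBase(int value, const std::map<int, ViewRecord> &defs) {
  auto it = defs.find(value);
  while (it != defs.end()) {
    if (value != it->second.dest)
      break;                    // this result is not the op's view destination
    value = it->second.source;  // keep walking toward the underlying buffer
    it = defs.find(value);
  }
  return value;
}

int main() {
  // %1 views %0, %2 views %1, %3 comes from an op whose view dest is not %3.
  std::map<int, ViewRecord> defs = {{1, {0, 1}}, {2, {1, 2}}, {3, {2, 99}}};
  assert(getViewBase(2, defs) == 0);
  assert(getViewBase(3, defs) == 3);
  return 0;
}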
static Value getViewBase(Value value) { - while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) + while (auto viewLikeOp = value.getDefiningOp<ViewLikeOpInterface>()) { + if (value != viewLikeOp.getViewDest()) { + break; + } value = viewLikeOp.getViewSource(); + } return value; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 91f6f25..68ef519 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -20,6 +20,7 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/PassManager.h" +#include "llvm/Support/DebugLog.h" #include <optional> namespace mlir { @@ -328,20 +329,16 @@ LogicalResult bufferization::bufferizeOp(Operation *op, "blocks"); // Bufferize the op. - LLVM_DEBUG(llvm::dbgs() - << "//===-------------------------------------------===//\n" - << "IR after bufferizing: " << nextOp->getName() << "\n"); + LDBG(3) << "//===-------------------------------------------===//\n" + << "IR after bufferizing: " << nextOp->getName(); rewriter.setInsertionPoint(nextOp); if (failed( bufferizableOp.bufferize(rewriter, options, bufferizationState))) { - LLVM_DEBUG(llvm::dbgs() - << "failed to bufferize\n" - << "//===-------------------------------------------===//\n"); + LDBG(2) << "failed to bufferize\n" + << "//===-------------------------------------------===//"; return nextOp->emitError("failed to bufferize op"); } - LLVM_DEBUG(llvm::dbgs() - << *op - << "\n//===-------------------------------------------===//\n"); + LDBG(3) << *op << "\n//===-------------------------------------------===//"; } // Return early if the top-level op is entirely gone. 
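Most of the logging changes in this patch swap LLVM_DEBUG(llvm::dbgs() << ... << "\n") for the LDBG stream from llvm/Support/DebugLog.h, optionally with a verbosity level such as LDBG(2) or LDBG(3). A minimal usage sketch, assuming that header and the usual DEBUG_TYPE convention ("my-pass" is a hypothetical debug type):

#define DEBUG_TYPE "my-pass"  // hypothetical debug type for this sketch
#include "llvm/Support/DebugLog.h"

static void debugTrace(int step) {
  LDBG() << "visiting step " << step;             // default verbosity, no "\n"
  LDBG(3) << "verbose detail for step " << step;  // only at higher debug level
}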
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index a8e8353..fb7f2bb 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -56,6 +56,7 @@ #include "mlir/Interfaces/SubsetOpInterface.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" +#include "llvm/Support/DebugLog.h" MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::bufferization::OneShotAnalysisState) @@ -616,13 +617,10 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, if (getParallelRegion(def.getParentRegion(), options) != getParallelRegion(uConflictingWrite->getOwner()->getParentRegion(), options)) { - LLVM_DEBUG( - llvm::dbgs() - << "\n- bufferizes out-of-place due to parallel region:\n"); - LLVM_DEBUG(llvm::dbgs() - << " unConflictingWrite = operand " - << uConflictingWrite->getOperandNumber() << " of " - << *uConflictingWrite->getOwner() << "\n"); + LDBG() << "\n- bufferizes out-of-place due to parallel region:\n" + << " unConflictingWrite = operand " + << uConflictingWrite->getOperandNumber() << " of " + << *uConflictingWrite->getOwner(); return true; } } @@ -631,9 +629,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, for (OpOperand *uRead : usesRead) { Operation *readingOp = uRead->getOwner(); - LLVM_DEBUG(llvm::dbgs() << "\n- check conflict:\n"); - LLVM_DEBUG(llvm::dbgs() << " uRead = operand " << uRead->getOperandNumber() - << " of " << *readingOp << "\n"); + LDBG() << "\n- check conflict:\n" + << " uRead = operand " << uRead->getOperandNumber() << " of " + << *readingOp; // Find the definition of uRead by following the SSA use-def chain. // E.g.: @@ -648,23 +646,22 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, const SetVector<Value> &definitions = state.findDefinitionsCached(uRead); if (definitions.empty()) { // Fast path: No conflict if there are no definitions. - LLVM_DEBUG(llvm::dbgs() - << " no conflict: read value has no definitions\n"); + LDBG() << " no conflict: read value has no definitions"; continue; } // Look for conflicting memory writes. Potential conflicts are writes to an // alias that have been decided to bufferize inplace. for (OpOperand *uConflictingWrite : usesWrite) { - LLVM_DEBUG(llvm::dbgs() << " unConflictingWrite = operand " - << uConflictingWrite->getOperandNumber() << " of " - << *uConflictingWrite->getOwner() << "\n"); + LDBG() << " unConflictingWrite = operand " + << uConflictingWrite->getOperandNumber() << " of " + << *uConflictingWrite->getOwner(); // Check if op dominance can be used to rule out read-after-write // conflicts. bool useDominance = canUseOpDominance(uRead, uConflictingWrite, definitions, state); - LLVM_DEBUG(llvm::dbgs() << "\n- useDominance = " << useDominance << "\n"); + LDBG() << "\n- useDominance = " << useDominance; // Throughout this loop, check for multiple requirements that have to be // met for uConflictingWrite to be an actual conflict. @@ -680,8 +677,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // inside a loop), there may be no meaningful `happensBefore` // relationship. 
if (happensBefore(readingOp, conflictingWritingOp, domInfo)) { - LLVM_DEBUG(llvm::dbgs() - << " no conflict: read happens before write\n"); + LDBG() << " no conflict: read happens before write"; continue; } @@ -693,8 +689,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // Note: If the op is executed multiple times (e.g., because it is // inside a loop), it may be conflicting with itself. if (uConflictingWrite == uRead) { - LLVM_DEBUG(llvm::dbgs() - << " no conflict: read and write are same use\n"); + LDBG() << " no conflict: read and write are same use"; continue; } @@ -705,8 +700,8 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // multiple times. if (state.insideMutuallyExclusiveRegions(readingOp, conflictingWritingOp)) { - LLVM_DEBUG(llvm::dbgs() << " no conflict: read and write are in " - "mutually exclusive regions\n"); + LDBG() << " no conflict: read and write are in " + "mutually exclusive regions"; continue; } @@ -721,9 +716,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, state, uRead, uConflictingWrite->get()) || hasEquivalentValueInReverseUseDefChain( state, uConflictingWrite, uRead->get())) { - LLVM_DEBUG( - llvm::dbgs() - << " no conflict: op bufferizes to element-wise access\n"); + LDBG() << " no conflict: op bufferizes to element-wise access"; continue; } } @@ -733,15 +726,14 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // No conflict if the operands are non-conflicting subsets. if (areNonConflictingSubsets(uRead, uConflictingWrite, state)) { - LLVM_DEBUG(llvm::dbgs() << " no conflict: non-conflicting subsets\n"); + LDBG() << " no conflict: non-conflicting subsets"; continue; } // No conflict if the op interface says so. if (auto bufferizableOp = options.dynCastBufferizableOp(readingOp)) { if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) { - LLVM_DEBUG(llvm::dbgs() - << " no conflict: op interace of reading op says 'no'\n"); + LDBG() << " no conflict: op interace of reading op says 'no'"; continue; } } @@ -751,9 +743,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, options.dynCastBufferizableOp(conflictingWritingOp)) { if (bufferizableOp.isNotConflicting(uRead, uConflictingWrite, state)) { - LLVM_DEBUG( - llvm::dbgs() - << " no conflict: op interace of writing op says 'no'\n"); + LDBG() << " no conflict: op interace of writing op says 'no'"; continue; } } @@ -761,29 +751,26 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // Check all possible definitions. for (Value definition : definitions) { - LLVM_DEBUG(llvm::dbgs() << " * definition = " << definition << "\n"); + LDBG() << " * definition = " << definition; // No conflict if the conflicting write happens before the definition. if (Operation *defOp = definition.getDefiningOp()) { if (happensBefore(conflictingWritingOp, defOp, domInfo)) { // conflictingWritingOp happens before defOp. No conflict. - LLVM_DEBUG(llvm::dbgs() - << " no conflict: write happens before definition\n"); + LDBG() << " no conflict: write happens before definition"; continue; } // No conflict if conflictingWritingOp is contained in defOp. 
if (defOp->isProperAncestor(conflictingWritingOp)) { - LLVM_DEBUG( - llvm::dbgs() - << " no conflict: write is contained in definition\n"); + LDBG() << " no conflict: write is contained in definition"; continue; } } else { auto bbArg = cast<BlockArgument>(definition); Block *block = bbArg.getOwner(); if (!block->findAncestorOpInBlock(*conflictingWritingOp)) { - LLVM_DEBUG(llvm::dbgs() << " no conflict: definition is bbArg " - "and write happens outside of block\n"); + LDBG() << " no conflict: definition is bbArg " + "and write happens outside of block"; // conflictingWritingOp happens outside of the block. No // conflict. continue; @@ -795,8 +782,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, AliasingValueList aliases = state.getAliasingValues(*uConflictingWrite); if (aliases.getNumAliases() == 1 && aliases.getAliases()[0].value == definition) { - LLVM_DEBUG(llvm::dbgs() - << " no conflict: definition and write are same\n"); + LDBG() << " no conflict: definition and write are same"; continue; } @@ -804,7 +790,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, if (options.printConflicts) annotateConflict(uRead, uConflictingWrite, definition); - LLVM_DEBUG(llvm::dbgs() << " => RaW CONFLICT FOUND\n"); + LDBG() << " => RaW CONFLICT FOUND"; return true; } } @@ -958,7 +944,7 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand, for (AliasingValue alias : state.getAliasingValues(operand)) state.applyOnAliases(alias.value, checkReadOnly); if (foundReadOnly) { - LLVM_DEBUG(llvm::dbgs() << "=> NOT WRITABLE\n"); + LDBG() << "=> NOT WRITABLE"; return true; } @@ -987,10 +973,9 @@ void OneShotAnalysisState::resetCache() { static LogicalResult bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state, const DominanceInfo &domInfo) { - LLVM_DEBUG( - llvm::dbgs() << "//===-------------------------------------------===//\n" - << "Analyzing operand #" << operand.getOperandNumber() - << " of " << *operand.getOwner() << "\n"); + LDBG() << "//===-------------------------------------------===//\n" + << "Analyzing operand #" << operand.getOperandNumber() << " of " + << *operand.getOwner(); bool foundInterference = wouldCreateWriteToNonWritableBuffer(operand, state) || @@ -1001,8 +986,7 @@ bufferizableInPlaceAnalysisImpl(OpOperand &operand, OneShotAnalysisState &state, else state.bufferizeInPlace(operand); - LLVM_DEBUG(llvm::dbgs() - << "//===-------------------------------------------===//\n"); + LDBG() << "//===-------------------------------------------===//"; return success(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp index 725fa24..b593cca 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp @@ -51,14 +51,8 @@ static bool isMemref(Value v) { return isa<BaseMemRefType>(v.getType()); } /// Return "true" if the given op is guaranteed to have neither "Allocate" nor /// "Free" side effects. static bool hasNeitherAllocateNorFreeSideEffect(Operation *op) { - if (isa<MemoryEffectOpInterface>(op)) - return !hasEffect<MemoryEffects::Allocate>(op) && - !hasEffect<MemoryEffects::Free>(op); - // If the op does not implement the MemoryEffectOpInterface but has has - // recursive memory effects, then this op in isolation (without its body) does - // not have any side effects. 
All the ops inside the regions of this op will - // be processed separately. - return op->hasTrait<OpTrait::HasRecursiveMemoryEffects>(); + return !mightHaveEffect<MemoryEffects::Allocate>(op) && + !mightHaveEffect<MemoryEffects::Free>(op); } /// Return "true" if the given op has buffer semantics. I.e., it has buffer @@ -517,9 +511,7 @@ LogicalResult BufferDeallocation::verifyOperationPreconditions(Operation *op) { // MemoryEffectOpInterface. They usually do not have side effects apart // from the callee, which will be analyzed separately. (This is similar to // "recursive memory effects".) - if (!isa<MemoryEffectOpInterface>(op) && - !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>() && - !isa<CallOpInterface>(op)) + if (hasUnknownEffects(op) && !isa<CallOpInterface>(op)) return op->emitError( "ops with unknown memory side effects are not supported"); diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt index 37b4cfc..47740d3 100644 --- a/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ControlFlow/Transforms/CMakeLists.txt @@ -3,7 +3,7 @@ add_mlir_dialect_library(MLIRControlFlowTransforms BufferizableOpInterfaceImpl.cpp ADDITIONAL_HEADER_DIRS - {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/ControlFlow/Transforms LINK_LIBS PUBLIC MLIRBufferizationDialect diff --git a/mlir/lib/Dialect/DLTI/Traits.cpp b/mlir/lib/Dialect/DLTI/Traits.cpp index 34f2dd5..3f6dd29 100644 --- a/mlir/lib/Dialect/DLTI/Traits.cpp +++ b/mlir/lib/Dialect/DLTI/Traits.cpp @@ -24,7 +24,7 @@ LogicalResult mlir::impl::verifyHasDefaultDLTIDataLayoutTrait(Operation *op) { } DataLayoutSpecInterface mlir::impl::getDataLayoutSpec(Operation *op) { - return op->getAttrOfType<DataLayoutSpecAttr>( + return op->getAttrOfType<DataLayoutSpecInterface>( DLTIDialect::kDataLayoutAttrName); } diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index e79da92..00ce3b5 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -131,6 +131,12 @@ bool mlir::emitc::isPointerWideType(Type type) { type); } +bool mlir::emitc::isFundamentalType(Type type) { + return llvm::isa<IndexType>(type) || isPointerWideType(type) || + isSupportedIntegerType(type) || isSupportedFloatType(type) || + isa<emitc::PointerType>(type); +} + /// Check that the type of the initial value is compatible with the operations /// result type. 
static LogicalResult verifyInitializationAttribute(Operation *op, @@ -375,6 +381,52 @@ OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); } // ExpressionOp //===----------------------------------------------------------------------===// +ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) { + SmallVector<OpAsmParser::UnresolvedOperand> operands; + if (parser.parseOperandList(operands)) + return parser.emitError(parser.getCurrentLocation()) << "expected operands"; + if (succeeded(parser.parseOptionalKeyword("noinline"))) + result.addAttribute(ExpressionOp::getDoNotInlineAttrName(result.name), + parser.getBuilder().getUnitAttr()); + Type type; + if (parser.parseColonType(type)) + return parser.emitError(parser.getCurrentLocation(), + "expected function type"); + auto fnType = llvm::dyn_cast<FunctionType>(type); + if (!fnType) + return parser.emitError(parser.getCurrentLocation(), + "expected function type"); + if (parser.resolveOperands(operands, fnType.getInputs(), + parser.getCurrentLocation(), result.operands)) + return failure(); + if (fnType.getNumResults() != 1) + return parser.emitError(parser.getCurrentLocation(), + "expected single return type"); + result.addTypes(fnType.getResults()); + Region *body = result.addRegion(); + SmallVector<OpAsmParser::Argument> argsInfo; + for (auto [unresolvedOperand, operandType] : + llvm::zip(operands, fnType.getInputs())) { + OpAsmParser::Argument argInfo; + argInfo.ssaName = unresolvedOperand; + argInfo.type = operandType; + argsInfo.push_back(argInfo); + } + if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true)) + return failure(); + return success(); +} + +void emitc::ExpressionOp::print(OpAsmPrinter &p) { + p << ' '; + p.printOperands(getDefs()); + p << " : "; + p.printFunctionalType(getOperation()); + p.shadowRegionArgs(getRegion(), getDefs()); + p << ' '; + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); +} + Operation *ExpressionOp::getRootOp() { auto yieldOp = cast<YieldOp>(getBody()->getTerminator()); Value yieldedValue = yieldOp.getResult(); @@ -1395,6 +1447,7 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) { //===----------------------------------------------------------------------===// // FieldOp //===----------------------------------------------------------------------===// + static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op, TypeAttr type, Attribute initialValue) { @@ -1452,6 +1505,15 @@ LogicalResult FieldOp::verify() { //===----------------------------------------------------------------------===// // GetFieldOp //===----------------------------------------------------------------------===// + +LogicalResult GetFieldOp::verify() { + auto parentClassOp = getOperation()->getParentOfType<emitc::ClassOp>(); + if (!parentClassOp.getOperation()) + return emitOpError(" must be nested within an emitc.class operation"); + + return success(); +} + LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) { mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr(); FieldOp fieldOp = diff --git a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp index 3f0690c..f8469b8 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/Transforms.cpp @@ -9,7 +9,9 @@ #include "mlir/Dialect/EmitC/Transforms/Transforms.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/IR/IRMapping.h" +#include 
"mlir/IR/Location.h" #include "mlir/IR/PatternMatch.h" +#include "llvm/ADT/STLExtras.h" namespace mlir { namespace emitc { @@ -24,20 +26,24 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder) { Location loc = op->getLoc(); builder.setInsertionPointAfter(op); - auto expressionOp = emitc::ExpressionOp::create(builder, loc, resultType); + auto expressionOp = + emitc::ExpressionOp::create(builder, loc, resultType, op->getOperands()); // Replace all op's uses with the new expression's result. result.replaceAllUsesWith(expressionOp.getResult()); - // Create an op to yield op's value. - Region ®ion = expressionOp.getRegion(); - Block &block = region.emplaceBlock(); + Block &block = expressionOp.createBody(); + IRMapping mapper; + for (auto [operand, arg] : + llvm::zip(expressionOp.getOperands(), block.getArguments())) + mapper.map(operand, arg); builder.setInsertionPointToEnd(&block); - auto yieldOp = emitc::YieldOp::create(builder, loc, result); - // Move op into the new expression. - op->moveBefore(yieldOp); + Operation *rootOp = builder.clone(*op, mapper); + op->erase(); + // Create an op to yield op's value. + emitc::YieldOp::create(builder, loc, rootOp->getResults()[0]); return expressionOp; } @@ -53,51 +59,93 @@ struct FoldExpressionOp : public OpRewritePattern<ExpressionOp> { using OpRewritePattern<ExpressionOp>::OpRewritePattern; LogicalResult matchAndRewrite(ExpressionOp expressionOp, PatternRewriter &rewriter) const override { - bool anythingFolded = false; - for (Operation &op : llvm::make_early_inc_range( - expressionOp.getBody()->without_terminator())) { - // Don't fold expressions whose result value has its address taken. - auto applyOp = dyn_cast<emitc::ApplyOp>(op); - if (applyOp && applyOp.getApplicableOperator() == "&") - continue; - - for (Value operand : op.getOperands()) { - auto usedExpression = operand.getDefiningOp<ExpressionOp>(); - if (!usedExpression) - continue; - - // Don't fold expressions with multiple users: assume any - // re-materialization was done separately. - if (!usedExpression.getResult().hasOneUse()) - continue; - - // Don't fold expressions with side effects. - if (usedExpression.hasSideEffects()) - continue; - - // Fold the used expression into this expression by cloning all - // instructions in the used expression just before the operation using - // its value. 
- rewriter.setInsertionPoint(&op); - IRMapping mapper; - for (Operation &opToClone : - usedExpression.getBody()->without_terminator()) { - Operation *clone = rewriter.clone(opToClone, mapper); - mapper.map(&opToClone, clone); - } - - Operation *expressionRoot = usedExpression.getRootOp(); - Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot); - assert(clonedExpressionRootOp && - "Expected cloned expression root to be in mapper"); - assert(clonedExpressionRootOp->getNumResults() == 1 && - "Expected cloned root to have a single result"); - - rewriter.replaceOp(usedExpression, clonedExpressionRootOp); - anythingFolded = true; - } + Block *expressionBody = expressionOp.getBody(); + ExpressionOp usedExpression; + SetVector<Value> foldedOperands; + + auto takesItsOperandsAddress = [](Operation *user) { + auto applyOp = dyn_cast<emitc::ApplyOp>(user); + return applyOp && applyOp.getApplicableOperator() == "&"; + }; + + // Select as expression to fold the first operand expression that + // - doesn't have its result value's address taken, + // - has a single user: assume any re-materialization was done separately, + // - has no side effects, + // and save all other operands to be used later as operands in the folded + // expression. + for (auto [operand, arg] : llvm::zip(expressionOp.getOperands(), + expressionBody->getArguments())) { + ExpressionOp operandExpression = operand.getDefiningOp<ExpressionOp>(); + if (usedExpression || !operandExpression || + llvm::any_of(arg.getUsers(), takesItsOperandsAddress) || + !operandExpression.getResult().hasOneUse() || + operandExpression.hasSideEffects()) + foldedOperands.insert(operand); + else + usedExpression = operandExpression; } - return anythingFolded ? success() : failure(); + + // If no operand expression was selected, bail out. + if (!usedExpression) + return failure(); + + // Collect additional operands from the folded expression. + for (Value operand : usedExpression.getOperands()) + foldedOperands.insert(operand); + + // Create a new expression to hold the folding result. + rewriter.setInsertionPointAfter(expressionOp); + auto foldedExpression = emitc::ExpressionOp::create( + rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(), + foldedOperands.getArrayRef(), expressionOp.getDoNotInline()); + Block &foldedExpressionBody = foldedExpression.createBody(); + + // Map each operand of the new expression to its matching block argument. + IRMapping mapper; + for (auto [operand, arg] : llvm::zip(foldedExpression.getOperands(), + foldedExpressionBody.getArguments())) + mapper.map(operand, arg); + + // Prepare to fold the used expression and the matched expression into the + // newly created folded expression. + auto foldExpression = [&rewriter, &mapper](ExpressionOp expressionToFold, + bool withTerminator) { + Block *expressionToFoldBody = expressionToFold.getBody(); + for (auto [operand, arg] : + llvm::zip(expressionToFold.getOperands(), + expressionToFoldBody->getArguments())) { + mapper.map(arg, mapper.lookup(operand)); + } + + for (Operation &opToClone : expressionToFoldBody->without_terminator()) + rewriter.clone(opToClone, mapper); + + if (withTerminator) + rewriter.clone(*expressionToFoldBody->getTerminator(), mapper); + }; + rewriter.setInsertionPointToStart(&foldedExpressionBody); + + // First, fold the used expression into the new expression and map its + // result to the clone of its root operation within the new expression. 
+ foldExpression(usedExpression, /*withTerminator=*/false); + Operation *expressionRoot = usedExpression.getRootOp(); + Operation *clonedExpressionRootOp = mapper.lookup(expressionRoot); + assert(clonedExpressionRootOp && + "Expected cloned expression root to be in mapper"); + assert(clonedExpressionRootOp->getNumResults() == 1 && + "Expected cloned root to have a single result"); + mapper.map(usedExpression.getResult(), + clonedExpressionRootOp->getResults()[0]); + + // Now fold the matched expression into the new expression. + foldExpression(expressionOp, /*withTerminator=*/true); + + // Complete the rewrite. + rewriter.replaceOp(expressionOp, foldedExpression); + rewriter.eraseOp(usedExpression); + + return success(); } }; diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp index c55e26e..06d7e07 100644 --- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp +++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp @@ -64,8 +64,8 @@ public: TypeAttr typeAttr = TypeAttr::get(val.getType()); fields.push_back({fieldName, typeAttr}); - FieldOp fieldop = rewriter.create<emitc::FieldOp>( - funcOp->getLoc(), fieldName, typeAttr, nullptr); + FieldOp fieldop = emitc::FieldOp::create(rewriter, funcOp->getLoc(), + fieldName, typeAttr, nullptr); if (argAttrs && idx < argAttrs->size()) { fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx)); diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 2503ccb..b87b4f4 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -2486,8 +2486,7 @@ LogicalResult WarpExecuteOnLane0Op::verify() { if (getArgs().size() != getWarpRegion().getNumArguments()) return emitOpError( "expected same number op arguments and block arguments."); - auto yield = - cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = getTerminator(); if (yield.getNumOperands() != getNumResults()) return emitOpError( "expected same number of yield operands and return values."); @@ -2511,6 +2510,50 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) { verifyDistributedType(lhs, rhs, getWarpSize(), getOperation())); } +gpu::YieldOp WarpExecuteOnLane0Op::getTerminator() { + return cast<gpu::YieldOp>(getBody()->getTerminator()); +} + +//===----------------------------------------------------------------------===// +// GPU_SubgroupBroadcastOp +//===----------------------------------------------------------------------===// + +void gpu::SubgroupBroadcastOp::inferResultRanges( + ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRange) { + setResultRange(getResult(), argRanges.front()); +} + +Speculation::Speculatability gpu::SubgroupBroadcastOp::getSpeculatability() { + switch (getBroadcastType()) { + case BroadcastType::first_active_lane: + // Cannot speculate first_lane broadcast, because speculating it across + // control flow can change the active lanes. + return Speculation::NotSpeculatable; + case BroadcastType::any_lane: + LLVM_FALLTHROUGH; + case BroadcastType::specific_lane: + // Speculation should be safe as long as we inside structured control flow. 
+ return Speculation::Speculatable; + } +} + +LogicalResult gpu::SubgroupBroadcastOp::verify() { + switch (getBroadcastType()) { + case BroadcastType::first_active_lane: + LLVM_FALLTHROUGH; + case BroadcastType::any_lane: + if (getLane()) + return emitOpError() + << "lane can only be specified for `specific_lane` broadcast"; + return success(); + case BroadcastType::specific_lane: + if (!getLane()) + return emitOpError() + << "lane must be specified for `specific_lane` broadcast"; + return success(); + } +} + //===----------------------------------------------------------------------===// // GPU KernelMetadataAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 21cb2f6..c766539 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" #include "mlir/Conversion/LLVMCommon/TypeConverter.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/GPU/TransformOps/Utils.h" @@ -43,6 +44,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/LogicalResult.h" +#include <optional> #include <type_traits> using namespace mlir; @@ -170,7 +172,16 @@ void ApplyGPURewritePatternsOp::populatePatterns(RewritePatternSet &patterns) { void transform::ApplyGPUPromoteShuffleToAMDGPUPatternsOp::populatePatterns( RewritePatternSet &patterns) { - populateGpuPromoteShuffleToAMDGPUPatterns(patterns); + std::optional<StringRef> chipsetName = getChipset(); + std::optional<amdgpu::Chipset> maybeChipset; + if (chipsetName) { + FailureOr<amdgpu::Chipset> parsedChipset = + amdgpu::Chipset::parse(*chipsetName); + assert(llvm::succeeded(parsedChipset) && "expected valid chipset"); + maybeChipset = parsedChipset; + } + + populateGpuPromoteShuffleToAMDGPUPatterns(patterns, maybeChipset); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp index 9bf11c7..d2c2138 100644 --- a/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/EliminateBarriers.cpp @@ -25,6 +25,7 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_GPUELIMINATEBARRIERS @@ -37,9 +38,6 @@ using namespace mlir::gpu; #define DEBUG_TYPE "gpu-erase-barriers" #define DEBUG_TYPE_ALIAS "gpu-erase-barries-alias" -#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") - // The functions below provide interface-like verification, but are too specific // to barrier elimination to become interfaces. @@ -424,27 +422,18 @@ static bool maybeCaptured(Value v) { /// everything. This seems sufficient to achieve barrier removal in structured /// control flow, more complex cases would require a proper dataflow analysis. 
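The SubgroupBroadcastOp verifier above enforces that the lane operand appears exactly when the broadcast kind is specific_lane, and is absent for first_active_lane and any_lane. That invariant, restated as a tiny predicate (illustrative enum, not the generated one):

#include <cassert>

enum class BroadcastType { FirstActiveLane, AnyLane, SpecificLane };

static bool laneOperandIsValid(BroadcastType kind, bool hasLane) {
  return hasLane == (kind == BroadcastType::SpecificLane);
}

int main() {
  assert(laneOperandIsValid(BroadcastType::SpecificLane, true));
  assert(!laneOperandIsValid(BroadcastType::AnyLane, true));
  assert(!laneOperandIsValid(BroadcastType::SpecificLane, false));
  return 0;
}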
static bool mayAlias(Value first, Value second) { - DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, { - DBGS_ALIAS() << "checking aliasing between "; - DBGS_ALIAS() << first << "\n"; - DBGS_ALIAS() << " and "; - DBGS_ALIAS() << second << "\n"; - }); + LDBG(DEBUG_TYPE_ALIAS, 1) + << "checking aliasing between " << first << " and " << second; first = getBase(first); second = getBase(second); - DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, { - DBGS_ALIAS() << "base "; - DBGS_ALIAS() << first << "\n"; - DBGS_ALIAS() << " and "; - DBGS_ALIAS() << second << "\n"; - }); + LDBG(DEBUG_TYPE_ALIAS, 1) << "base " << first << " and " << second; // Values derived from the same base memref do alias (unless we do a more // advanced analysis to prove non-overlapping accesses). if (first == second) { - DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> do alias!\n"); + LDBG(DEBUG_TYPE_ALIAS, 1) << "-> do alias!"; return true; } @@ -493,7 +482,7 @@ static bool mayAlias(Value first, Value second) { return false; // Otherwise, conservatively assume aliasing. - DEBUG_WITH_TYPE(DEBUG_TYPE_ALIAS, DBGS_ALIAS() << "-> may alias!\n"); + LDBG(DEBUG_TYPE_ALIAS, 1) << "-> may alias!"; return true; } @@ -567,20 +556,16 @@ haveConflictingEffects(ArrayRef<MemoryEffects::EffectInstance> beforeEffects, continue; // Other kinds of effects create a conflict, e.g. read-after-write. - LLVM_DEBUG( - DBGS() << "found a conflict between (before): " << before.getValue() - << " read:" << isa<MemoryEffects::Read>(before.getEffect()) - << " write:" << isa<MemoryEffects::Write>(before.getEffect()) - << " alloc:" - << isa<MemoryEffects::Allocate>(before.getEffect()) << " free:" - << isa<MemoryEffects::Free>(before.getEffect()) << "\n"); - LLVM_DEBUG( - DBGS() << "and (after): " << after.getValue() - << " read:" << isa<MemoryEffects::Read>(after.getEffect()) - << " write:" << isa<MemoryEffects::Write>(after.getEffect()) - << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect()) - << " free:" << isa<MemoryEffects::Free>(after.getEffect()) - << "\n"); + LDBG() << "found a conflict between (before): " << before.getValue() + << " read:" << isa<MemoryEffects::Read>(before.getEffect()) + << " write:" << isa<MemoryEffects::Write>(before.getEffect()) + << " alloc:" << isa<MemoryEffects::Allocate>(before.getEffect()) + << " free:" << isa<MemoryEffects::Free>(before.getEffect()); + LDBG() << "and (after): " << after.getValue() + << " read:" << isa<MemoryEffects::Read>(after.getEffect()) + << " write:" << isa<MemoryEffects::Write>(after.getEffect()) + << " alloc:" << isa<MemoryEffects::Allocate>(after.getEffect()) + << " free:" << isa<MemoryEffects::Free>(after.getEffect()); return true; } } @@ -595,8 +580,8 @@ public: LogicalResult matchAndRewrite(BarrierOp barrier, PatternRewriter &rewriter) const override { - LLVM_DEBUG(DBGS() << "checking the necessity of: " << barrier << " " - << barrier.getLoc() << "\n"); + LDBG() << "checking the necessity of: " << barrier << " " + << barrier.getLoc(); SmallVector<MemoryEffects::EffectInstance> beforeEffects; getEffectsBefore(barrier, beforeEffects, /*stopAtBarrier=*/true); @@ -605,14 +590,12 @@ public: getEffectsAfter(barrier, afterEffects, /*stopAtBarrier=*/true); if (!haveConflictingEffects(beforeEffects, afterEffects)) { - LLVM_DEBUG(DBGS() << "the surrounding barriers are sufficient, removing " - << barrier << "\n"); + LDBG() << "the surrounding barriers are sufficient, removing " << barrier; rewriter.eraseOp(barrier); return success(); } - LLVM_DEBUG(DBGS() << "barrier is necessary: " << barrier << " " - << 
barrier.getLoc() << "\n"); + LDBG() << "barrier is necessary: " << barrier << " " << barrier.getLoc(); return failure(); } }; diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp index d4978ca..97adad6 100644 --- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp @@ -431,8 +431,7 @@ private: if (std::optional<SymbolTable::UseRange> symbolUses = SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) { for (SymbolTable::SymbolUse symbolUse : *symbolUses) { - StringRef symbolName = - cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue(); + StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference(); if (symbolTable.lookup(symbolName)) continue; diff --git a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp index 18c69f5..67cef8a 100644 --- a/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/PromoteShuffleToAMDGPU.cpp @@ -11,16 +11,21 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/IR/PatternMatch.h" +#include <optional> using namespace mlir; namespace { + +constexpr amdgpu::Chipset kGfx950 = amdgpu::Chipset(9, 5, 0); + /// Try to promote `gpu.shuffle` to `amdgpu.swizzle_bitmode`, width must be 64 /// and offset must be a constant integer in the range [0, 31]. struct PromoteShuffleToSwizzlePattern @@ -56,9 +61,48 @@ struct PromoteShuffleToSwizzlePattern return success(); } }; + +/// Try to promote `gpu.shuffle` to `amdgpu.permlane_swap`, width must be 64 +/// and offset must be a constant integer in the set {16, 32}. 
+struct PromoteShuffleToPermlanePattern + : public OpRewritePattern<gpu::ShuffleOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(gpu::ShuffleOp op, + PatternRewriter &rewriter) const override { + if (op.getMode() != gpu::ShuffleMode::XOR) + return rewriter.notifyMatchFailure(op, + "only xor shuffle mode is supported"); + + if (!isConstantIntValue(op.getWidth(), 64)) + return rewriter.notifyMatchFailure(op, + "only 64 width shuffle is supported"); + + std::optional<int64_t> offset = getConstantIntValue(op.getOffset()); + if (!offset) + return rewriter.notifyMatchFailure(op, + "offset must be a constant integer"); + + int64_t offsetValue = *offset; + if (offsetValue != 16 && offsetValue != 32) + return rewriter.notifyMatchFailure(op, "offset must be either 15 or 31"); + + Location loc = op.getLoc(); + Value res = amdgpu::PermlaneSwapOp::create( + rewriter, loc, op.getResult(0).getType(), op.getValue(), offsetValue); + Value valid = arith::ConstantIntOp::create(rewriter, loc, 1, /*width*/ 1); + rewriter.replaceOp(op, {res, valid}); + return success(); + } +}; + } // namespace void mlir::populateGpuPromoteShuffleToAMDGPUPatterns( - RewritePatternSet &patterns) { - patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext()); + RewritePatternSet &patterns, std::optional<amdgpu::Chipset> maybeChipset) { + patterns.add<PromoteShuffleToSwizzlePattern>(patterns.getContext(), + /*benefit*/ 1); + if (maybeChipset && *maybeChipset >= kGfx950) + patterns.add<PromoteShuffleToPermlanePattern>(patterns.getContext(), + /*benefit*/ 2); } diff --git a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp index e9cf493..6da76e9 100644 --- a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" +#include "mlir/Target/LLVM/XeVM/Target.h" #include "llvm/Support/Regex.h" namespace mlir { diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp index 384d1a0..88f531f 100644 --- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp +++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/IR/Value.h" +#include "llvm/ADT/DenseMap.h" #include <numeric> @@ -55,28 +56,30 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns( SmallVector<size_t> &indices) const { SmallVector<Type> types(warpOp.getResultTypes().begin(), warpOp.getResultTypes().end()); - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); - llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(), - yield.getOperands().end()); + gpu::YieldOp yield = warpOp.getTerminator(); + SmallVector<Value> yieldValues(yield.getOperands().begin(), + yield.getOperands().end()); + llvm::SmallDenseMap<Value, unsigned> indexLookup; + // Record the value -> first index mapping for faster lookup. + for (auto [i, v] : llvm::enumerate(yieldValues)) { + if (!indexLookup.count(v)) + indexLookup[v] = i; + } + for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) { - if (yieldValues.insert(value)) { + // If the value already exists in the yield, don't create a new output. 
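populateGpuPromoteShuffleToAMDGPUPatterns above registers the permlane-swap pattern only when a chipset is supplied and compares at least gfx950, and gives it a higher benefit than the swizzle pattern so it is preferred when both match. A hypothetical sketch of the chipset gate (Chipset here is a stand-in for amdgpu::Chipset):

#include <cassert>
#include <optional>
#include <tuple>

struct Chipset {
  int major, minor, stepping;
  friend bool operator>=(const Chipset &a, const Chipset &b) {
    return std::tie(a.major, a.minor, a.stepping) >=
           std::tie(b.major, b.minor, b.stepping);
  }
};

static bool enablePermlanePattern(std::optional<Chipset> chipset) {
  constexpr Chipset kGfx950{9, 5, 0};
  return chipset && *chipset >= kGfx950;
}

int main() {
  assert(enablePermlanePattern(Chipset{9, 5, 0}));
  assert(!enablePermlanePattern(Chipset{9, 4, 2}));  // e.g. an older gfx9 part
  assert(!enablePermlanePattern(std::nullopt));      // no chipset given
  return 0;
}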
+ if (indexLookup.count(value)) { + indices.push_back(indexLookup[value]); + } else { + // If the value is new, add it to the yield and to the types. + yieldValues.push_back(value); types.push_back(type); indices.push_back(yieldValues.size() - 1); - } else { - // If the value already exit the region don't create a new output. - for (auto [idx, yieldOperand] : - llvm::enumerate(yieldValues.getArrayRef())) { - if (yieldOperand == value) { - indices.push_back(idx); - break; - } - } } } - yieldValues.insert_range(newYieldedValues); + WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns( - rewriter, warpOp, yieldValues.getArrayRef(), types); + rewriter, warpOp, yieldValues, types); rewriter.replaceOp(warpOp, newWarpOp.getResults().take_front(warpOp.getNumResults())); return newWarpOp; @@ -85,8 +88,7 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns( OpOperand *WarpDistributionPattern::getWarpResult( WarpExecuteOnLane0Op warpOp, llvm::function_ref<bool(Operation *)> fn) const { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); for (OpOperand &yieldOperand : yield->getOpOperands()) { Value yieldValues = yieldOperand.get(); Operation *definedOp = yieldValues.getDefiningOp(); diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index ff55f17..ec581ac 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -32,6 +32,7 @@ add_mlir_dialect_library(MLIRLLVMDialect MLIRInferTypeOpInterface MLIRIR MLIRMemorySlotInterfaces + MLIRPtrMemorySpaceInterfaces MLIRSideEffectInterfaces MLIRSupport ) diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp index 894de44..7220e10 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp @@ -12,10 +12,20 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/BasicPtxBuilderInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/MLIRContext.h" + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LogicalResult.h" +#include "llvm/Support/Regex.h" #define DEBUG_TYPE "ptx-builder" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") //===----------------------------------------------------------------------===// // BasicPtxBuilderInterface @@ -28,50 +38,122 @@ using namespace NVVM; static constexpr int64_t kSharedMemorySpace = 3; -static char getRegisterType(Type type) { - if (type.isInteger(1)) - return 'b'; - if (type.isInteger(16)) - return 'h'; - if (type.isInteger(32)) - return 'r'; - if (type.isInteger(64)) - return 'l'; - if (type.isF32()) - return 'f'; - if (type.isF64()) - return 'd'; - if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) { - // Shared address spaces is addressed with 32-bit pointers. 
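moveRegionToNewWarpOpAndAppendReturns above now records the first index of every value already yielded, so re-requesting a value reuses its slot instead of appending a duplicate result. The bookkeeping, reduced to a standalone sketch with std containers:

#include <cassert>
#include <unordered_map>
#include <vector>

int main() {
  std::vector<int> yielded = {7, 9};               // existing yield operands
  std::unordered_map<int, size_t> firstIndex;      // value -> first position
  for (size_t i = 0; i < yielded.size(); ++i)
    firstIndex.emplace(yielded[i], i);

  std::vector<size_t> indices;                     // slot of each new request
  for (int v : {9, 11}) {                          // newly requested values
    auto it = firstIndex.find(v);
    if (it != firstIndex.end()) {
      indices.push_back(it->second);               // value already yielded
    } else {
      yielded.push_back(v);                        // genuinely new value
      firstIndex.emplace(v, yielded.size() - 1);
      indices.push_back(yielded.size() - 1);
    }
  }
  assert(indices[0] == 1 && indices[1] == 2 && yielded.size() == 3);
  return 0;
}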
- if (ptr.getAddressSpace() == kSharedMemorySpace) { +static FailureOr<char> getRegisterType(Type type, Location loc) { + MLIRContext *ctx = type.getContext(); + auto i16 = IntegerType::get(ctx, 16); + auto i32 = IntegerType::get(ctx, 32); + auto f32 = Float32Type::get(ctx); + + auto getRegisterTypeForScalar = [&](Type type) -> FailureOr<char> { + if (type.isInteger(1)) + return 'b'; + if (type.isInteger(16)) + return 'h'; + if (type.isInteger(32)) return 'r'; + if (type.isInteger(64)) + return 'l'; + if (type.isF32()) + return 'f'; + if (type.isF64()) + return 'd'; + if (auto ptr = dyn_cast<LLVM::LLVMPointerType>(type)) { + // Shared address spaces is addressed with 32-bit pointers. + if (ptr.getAddressSpace() == kSharedMemorySpace) { + return 'r'; + } + return 'l'; } - return 'l'; + // register type for struct is not supported. + mlir::emitError( + loc, "The register type could not be deduced from MLIR type. The ") + << type + << " is not supported. Supported types are:" + "i1, i16, i32, i64, f32, f64," + "pointers.\nPlease use llvm.bitcast if you have different type. " + "\nSee the constraints from here: " + "https://docs.nvidia.com/cuda/inline-ptx-assembly/" + "index.html#constraints"; + return failure(); + }; + + // Packed registers + if (auto v = dyn_cast<VectorType>(type)) { + assert(v.getNumDynamicDims() == 0 && "Dynamic vectors are not supported"); + + int64_t lanes = v.getNumElements(); + Type elem = v.getElementType(); + + // Case 1. Single vector + if (lanes <= 1) + return getRegisterTypeForScalar(elem); + + // Case 2. Packed registers + Type widened = elem; + switch (lanes) { + + case 2: + if (elem.isF16() || elem.isBF16()) // vector<2xf16> + widened = f32; + else if (elem.isFloat(8)) // vector<2xf8> + widened = i16; + break; + case 4: + if (elem.isInteger(8)) // vector<i8x4> + widened = i32; + else if (elem.isFloat(8)) // vector<f8x4> + widened = f32; + else if (elem.isFloat(4)) // vector<f4x4> + widened = i16; + break; + // Other packing is not supported + default: + break; + } + return getRegisterTypeForScalar(widened); } - // register type for struct is not supported. - llvm_unreachable("The register type could not deduced from MLIR type"); - return '?'; + + return getRegisterTypeForScalar(type); } -static char getRegisterType(Value v) { +static FailureOr<char> getRegisterType(Value v, Location loc) { if (v.getDefiningOp<LLVM::ConstantOp>()) return 'n'; - return getRegisterType(v.getType()); + return getRegisterType(v.getType(), loc); } -void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { - LLVM_DEBUG(DBGS() << v << "\t Modifier : " << &itype << "\n"); +/// Extract every element of a struct value. +static SmallVector<Value> extractStructElements(PatternRewriter &rewriter, + Location loc, Value structVal) { + auto structTy = dyn_cast<LLVM::LLVMStructType>(structVal.getType()); + assert(structTy && "expected LLVM struct"); + + SmallVector<Value> elems; + for (unsigned i : llvm::seq<unsigned>(0, structTy.getBody().size())) + elems.push_back(rewriter.create<LLVM::ExtractValueOp>(loc, structVal, i)); + + return elems; +} + +LogicalResult PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { + LDBG() << v << "\t Modifier : " << itype << "\n"; + registerModifiers.push_back(itype); + + Location loc = interfaceOp->getLoc(); auto getModifier = [&]() -> const char * { - if (itype == PTXRegisterMod::ReadWrite) { - assert(false && "Read-Write modifier is not supported. 
Try setting the " - "same value as Write and Read separately."); - return "+"; - } - if (itype == PTXRegisterMod::Write) { + switch (itype) { + case PTXRegisterMod::Read: + return ""; + case PTXRegisterMod::Write: return "="; + case PTXRegisterMod::ReadWrite: + // "Read-Write modifier is not actually supported + // Interface will change it to "=" later and add integer mapping + return "+"; } - return ""; + llvm_unreachable("Unknown PTX register modifier"); }; + auto addValue = [&](Value v) { if (itype == PTXRegisterMod::Read) { ptxOperands.push_back(v); @@ -90,35 +172,273 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) { } for (auto [idx, t] : llvm::enumerate(stype.getBody())) { if (itype != PTXRegisterMod::Write) { - Value extractValue = LLVM::ExtractValueOp::create( - rewriter, interfaceOp->getLoc(), v, idx); + Value extractValue = + LLVM::ExtractValueOp::create(rewriter, loc, v, idx); addValue(extractValue); } if (itype == PTXRegisterMod::ReadWrite) { ss << idx << ","; } else { - ss << getModifier() << getRegisterType(t) << ","; + FailureOr<char> regType = getRegisterType(t, loc); + if (failed(regType)) + return rewriter.notifyMatchFailure(loc, + "failed to get register type"); + ss << getModifier() << regType.value() << ","; } } - return; + return success(); } // Handle Scalars addValue(v); - ss << getModifier() << getRegisterType(v) << ","; + FailureOr<char> regType = getRegisterType(v, loc); + if (failed(regType)) + return rewriter.notifyMatchFailure(loc, "failed to get register type"); + ss << getModifier() << regType.value() << ","; + return success(); +} + +/// Check if the operation needs to pack and unpack results. +static bool +needsPackUnpack(BasicPtxBuilderInterface interfaceOp, + bool needsManualRegisterMapping, + SmallVectorImpl<PTXRegisterMod> ®isterModifiers) { + if (needsManualRegisterMapping) + return false; + const unsigned writeOnlyVals = interfaceOp->getNumResults(); + const unsigned readWriteVals = + llvm::count_if(registerModifiers, [](PTXRegisterMod m) { + return m == PTXRegisterMod::ReadWrite; + }); + return (writeOnlyVals + readWriteVals) > 1; +} + +/// Pack the result types of the interface operation. +/// If the operation has multiple results, it packs them into a struct +/// type. Otherwise, it returns the original result types. +static SmallVector<Type> +packResultTypes(BasicPtxBuilderInterface interfaceOp, + bool needsManualRegisterMapping, + SmallVectorImpl<PTXRegisterMod> ®isterModifiers, + SmallVectorImpl<Value> &ptxOperands) { + MLIRContext *ctx = interfaceOp->getContext(); + TypeRange resultRange = interfaceOp->getResultTypes(); + + if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping, + registerModifiers)) { + // Single value path: + if (interfaceOp->getResults().size() == 1) + return SmallVector<Type>{resultRange.front()}; + + // No declared results: if there is an RW, forward its type. 
+ for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) + if (m == PTXRegisterMod::ReadWrite) + return SmallVector<Type>{v.getType()}; + } + + SmallVector<Type> packed; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) + if (m == PTXRegisterMod::ReadWrite) + packed.push_back(v.getType()); + for (Type t : resultRange) + packed.push_back(t); + + if (packed.empty()) + return {}; + + auto sTy = LLVM::LLVMStructType::getLiteral(ctx, packed, /*isPacked=*/false); + return SmallVector<Type>{sTy}; +} + +/// Canonicalize the register constraints: +/// - Turn every "+X" into "=X" +/// - Append (at the very end) the 0-based indices of tokens that were "+X" +/// Examples: +/// "+f,+f,+r,=r,=r,r,r" -> "=f,=f,=r,=r,=r,r,r,0,1,2" +/// "+f,+f,+r,=r,=r" -> "=f,=f,=r,=r,=r,0,1,2" +static std::string canonicalizeRegisterConstraints(llvm::StringRef csv) { + SmallVector<llvm::StringRef> toks; + SmallVector<std::string> out; + SmallVector<unsigned> plusIdx; + + csv.split(toks, ','); + out.reserve(toks.size() + 8); + + for (unsigned i = 0, e = toks.size(); i < e; ++i) { + StringRef t = toks[i].trim(); + if (t.consume_front("+")) { + plusIdx.push_back(i); + out.push_back(("=" + t).str()); + } else { + out.push_back(t.str()); + } + } + + // Append indices of original "+X" tokens. + for (unsigned idx : plusIdx) + out.push_back(std::to_string(idx)); + + // Join back to CSV. + std::string result; + result.reserve(csv.size() + plusIdx.size() * 2); + llvm::raw_string_ostream os(result); + for (size_t i = 0; i < out.size(); ++i) { + if (i) + os << ','; + os << out[i]; + } + return os.str(); +} + +constexpr llvm::StringLiteral kReadWritePrefix{"rw"}; +constexpr llvm::StringLiteral kWriteOnlyPrefix{"w"}; +constexpr llvm::StringLiteral kReadOnlyPrefix{"r"}; + +/// Returns a regex that matches {$rwN}, {$wN}, {$rN}. +static llvm::Regex getPredicateMappingRegex() { + llvm::Regex rx(llvm::formatv(R"(\{\$({0}|{1}|{2})([0-9]+)\})", + kReadWritePrefix, kWriteOnlyPrefix, + kReadOnlyPrefix) + .str()); + return rx; +} + +void mlir::NVVM::countPlaceholderNumbers( + StringRef ptxCode, llvm::SmallDenseSet<unsigned int> &seenRW, + llvm::SmallDenseSet<unsigned int> &seenW, + llvm::SmallDenseSet<unsigned int> &seenR, + llvm::SmallVectorImpl<unsigned int> &rwNums, + llvm::SmallVectorImpl<unsigned int> &wNums, + llvm::SmallVectorImpl<unsigned int> &rNums) { + + llvm::Regex rx = getPredicateMappingRegex(); + StringRef rest = ptxCode; + + SmallVector<StringRef, 3> m; // 0: full, 1: kind, 2: number + while (!rest.empty() && rx.match(rest, &m)) { + unsigned num = 0; + (void)m[2].getAsInteger(10, num); + // Insert it into the vector only the first time we see this number. + if (m[1].equals_insensitive(kReadWritePrefix)) { + if (seenRW.insert(num).second) + rwNums.push_back(num); + } else if (m[1].equals_insensitive(kWriteOnlyPrefix)) { + if (seenW.insert(num).second) + wNums.push_back(num); + } else { + if (seenR.insert(num).second) + rNums.push_back(num); + } + + const size_t advance = (size_t)(m[0].data() - rest.data()) + m[0].size(); + rest = rest.drop_front(advance); + } +} + +/// Rewrites `{$rwN}`, `{$wN}`, and `{$rN}` placeholders in `ptxCode` into +/// compact `$K` indices: +/// - All `rw*` first (sorted by N), +/// - Then `w*`, +/// - Then `r*`. +/// If there is a predicate, it always comes at the end. +/// Each number is assigned once; duplicates are ignored. 
+/// +/// Example Input: +/// "{ +/// reg .pred p; +/// setp.ge.s32 p, {$r0}, {$r1};" +/// selp.s32 {$rw0}, {$r0}, {$r1}, p; +/// selp.s32 {$rw1}, {$r0}, {$r1}, p; +/// selp.s32 {$w0}, {$r0}, {$r1}, p; +/// selp.s32 {$w1}, {$r0}, {$r1}, p; +/// }\n" +/// Example Output: +/// "{ +/// reg .pred p; +/// setp.ge.s32 p, $4, $5;" +/// selp.s32 $0, $4, $5, p; +/// selp.s32 $1, $4, $5, p; +/// selp.s32 $2, $4, $5, p; +/// selp.s32 $3, $4, $5, p; +/// }\n" +static std::string rewriteAsmPlaceholders(llvm::StringRef ptxCode) { + llvm::SmallDenseSet<unsigned> seenRW, seenW, seenR; + llvm::SmallVector<unsigned> rwNums, wNums, rNums; + + // Step 1. Count Register Placeholder numbers + countPlaceholderNumbers(ptxCode, seenRW, seenW, seenR, rwNums, wNums, rNums); + + // Step 2. Sort the Register Placeholder numbers + llvm::sort(rwNums); + llvm::sort(wNums); + llvm::sort(rNums); + + // Step 3. Create mapping from original to new IDs + llvm::DenseMap<unsigned, unsigned> rwMap, wMap, rMap; + unsigned nextId = 0; + for (unsigned n : rwNums) + rwMap[n] = nextId++; + for (unsigned n : wNums) + wMap[n] = nextId++; + for (unsigned n : rNums) + rMap[n] = nextId++; + + // Step 4. Rewrite the PTX code with new IDs + std::string out; + out.reserve(ptxCode.size()); + size_t prev = 0; + StringRef rest = ptxCode; + SmallVector<StringRef, 3> matches; + llvm::Regex rx = getPredicateMappingRegex(); + while (!rest.empty() && rx.match(rest, &matches)) { + // Compute absolute match bounds in the original buffer. + size_t absStart = (size_t)(matches[0].data() - ptxCode.data()); + size_t absEnd = absStart + matches[0].size(); + + // Emit text before the match. + out.append(ptxCode.data() + prev, ptxCode.data() + absStart); + + // Emit compact $K + unsigned num = 0; + (void)matches[2].getAsInteger(10, num); + unsigned id = 0; + if (matches[1].equals_insensitive(kReadWritePrefix)) + id = rwMap.lookup(num); + else if (matches[1].equals_insensitive(kWriteOnlyPrefix)) + id = wMap.lookup(num); + else + id = rMap.lookup(num); + + out.push_back('$'); + out += std::to_string(id); + + prev = absEnd; + + const size_t advance = + (size_t)(matches[0].data() - rest.data()) + matches[0].size(); + rest = rest.drop_front(advance); + } + + // Step 5. Tail. + out.append(ptxCode.data() + prev, ptxCode.data() + ptxCode.size()); + return out; } LLVM::InlineAsmOp PtxBuilder::build() { auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(), LLVM::AsmDialect::AD_ATT); - auto resultTypes = interfaceOp->getResultTypes(); + SmallVector<Type> resultTypes = packResultTypes( + interfaceOp, needsManualRegisterMapping, registerModifiers, ptxOperands); // Remove the last comma from the constraints string. if (!registerConstraints.empty() && registerConstraints[registerConstraints.size() - 1] == ',') registerConstraints.pop_back(); + registerConstraints = canonicalizeRegisterConstraints(registerConstraints); std::string ptxInstruction = interfaceOp.getPtx(); + if (!needsManualRegisterMapping) + ptxInstruction = rewriteAsmPlaceholders(ptxInstruction); // Add the predicate to the asm string. 
if (interfaceOp.getPredicate().has_value() && @@ -136,7 +456,7 @@ LLVM::InlineAsmOp PtxBuilder::build() { rewriter, interfaceOp->getLoc(), /*result types=*/resultTypes, /*operands=*/ptxOperands, - /*asm_string=*/llvm::StringRef(ptxInstruction), + /*asm_string=*/ptxInstruction, /*constraints=*/registerConstraints.data(), /*has_side_effects=*/interfaceOp.hasSideEffect(), /*is_align_stack=*/false, LLVM::TailCallKind::None, @@ -146,10 +466,89 @@ LLVM::InlineAsmOp PtxBuilder::build() { void PtxBuilder::buildAndReplaceOp() { LLVM::InlineAsmOp inlineAsmOp = build(); - LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n"); - if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) { - rewriter.replaceOp(interfaceOp, inlineAsmOp); - } else { + LDBG() << "\n Generated PTX \n\t" << inlineAsmOp; + + // Case 0: no result at all → just erase wrapper op. + if (!hasResult) { rewriter.eraseOp(interfaceOp); + return; + } + + if (needsManualRegisterMapping) { + rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults()); + return; + } + + // Case 1: Simple path, return single scalar + if (!needsPackUnpack(interfaceOp, needsManualRegisterMapping, + registerModifiers)) { + if (inlineAsmOp->getNumResults() > 0) { + rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults()); + } else { + // RW-only case with no declared results: forward the RW value. + SmallVector<Value> results; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) + if (m == PTXRegisterMod::ReadWrite) { + results.push_back(v); + break; + } + rewriter.replaceOp(interfaceOp, results); + } + return; + } + + const bool hasRW = llvm::any_of(registerModifiers, [](PTXRegisterMod m) { + return m == PTXRegisterMod::ReadWrite; + }); + + // All multi-value paths produce a single struct result we need to unpack. + assert(LLVM::LLVMStructType::classof(inlineAsmOp.getResultTypes().front()) && + "expected struct return for multi-result inline asm"); + Value structVal = inlineAsmOp.getResult(0); + SmallVector<Value> unpacked = + extractStructElements(rewriter, interfaceOp->getLoc(), structVal); + + // Case 2: only declared results (no RW): replace the op with all unpacked. + if (!hasRW && interfaceOp->getResults().size() > 0) { + rewriter.replaceOp(interfaceOp, unpacked); + return; + } + + // Case 3: RW-only (no declared results): update RW uses and erase wrapper. + if (hasRW && interfaceOp->getResults().size() == 0) { + unsigned idx = 0; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) { + if (m != PTXRegisterMod::ReadWrite) + continue; + Value repl = unpacked[idx++]; + v.replaceUsesWithIf(repl, [&](OpOperand &use) { + Operation *owner = use.getOwner(); + return owner != interfaceOp && owner != inlineAsmOp; + }); + } + rewriter.eraseOp(interfaceOp); + return; + } + + // Case 4: mixed (RW + declared results). + { + // First rewrite RW operands in place. + unsigned idx = 0; + for (auto [m, v] : llvm::zip(registerModifiers, ptxOperands)) { + if (m != PTXRegisterMod::ReadWrite) + continue; + Value repl = unpacked[idx++]; + v.replaceUsesWithIf(repl, [&](OpOperand &use) { + Operation *owner = use.getOwner(); + return owner != interfaceOp && owner != inlineAsmOp; + }); + } + // The remaining unpacked values correspond to the declared results. 
+ SmallVector<Value> tail; + tail.reserve(unpacked.size() - idx); + for (unsigned i = idx, e = unpacked.size(); i < e; ++i) + tail.push_back(unpacked[i]); + + rewriter.replaceOp(interfaceOp, tail); } } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index 1e02bfe..e268e8f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -12,6 +12,8 @@ #include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "mlir/Dialect/Ptr/IR/PtrEnums.h" #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/Interfaces/FunctionInterfaces.h" @@ -51,6 +53,87 @@ void LLVMDialect::registerAttributes() { } //===----------------------------------------------------------------------===// +// AddressSpaceAttr +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is an LLVM type that can be loaded or stored. +static bool isValidLoadStoreImpl(Type type, ptr::AtomicOrdering ordering, + std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) { + if (!isLoadableType(type)) { + if (emitError) + emitError() << "type must be LLVM type with size, but got " << type; + return false; + } + if (ordering == ptr::AtomicOrdering::not_atomic) + return true; + + // To check atomic validity we need a datalayout. + if (!dataLayout) { + if (emitError) + emitError() << "expected a valid data layout"; + return false; + } + if (!isTypeCompatibleWithAtomicOp(type, *dataLayout)) { + if (emitError) + emitError() << "unsupported type " << type << " for atomic access"; + return false; + } + return true; +} + +bool AddressSpaceAttr::isValidLoad( + Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) const { + return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError); +} + +bool AddressSpaceAttr::isValidStore( + Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) const { + return isValidLoadStoreImpl(type, ordering, alignment, dataLayout, emitError); +} + +bool AddressSpaceAttr::isValidAtomicOp( + ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering, + std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) const { + // TODO: update this method once `ptr.atomic_rmw` is implemented. + assert(false && "unimplemented, see TODO in the source."); + return false; +} + +bool AddressSpaceAttr::isValidAtomicXchg( + Type type, ptr::AtomicOrdering successOrdering, + ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) const { + // TODO: update this method once `ptr.atomic_cmpxchg` is implemented. + assert(false && "unimplemented, see TODO in the source."); + return false; +} + +bool AddressSpaceAttr::isValidAddrSpaceCast( + Type tgt, Type src, function_ref<InFlightDiagnostic()> emitError) const { + // TODO: update this method once the `ptr.addrspace_cast` op is added to the + // dialect. 
+ assert(false && "unimplemented, see TODO in the source."); + return false; +} + +bool AddressSpaceAttr::isValidPtrIntCast( + Type intLikeTy, Type ptrLikeTy, + function_ref<InFlightDiagnostic()> emitError) const { + // TODO: update this method once the int-cast ops are added to the `ptr` + // dialect. + assert(false && "unimplemented, see TODO in the source."); + return false; +} + +//===----------------------------------------------------------------------===// // AliasScopeAttr //===----------------------------------------------------------------------===// @@ -374,6 +457,43 @@ TargetFeaturesAttr TargetFeaturesAttr::featuresAt(Operation *op) { getAttributeName()); } +FailureOr<Attribute> TargetFeaturesAttr::query(DataLayoutEntryKey key) { + auto stringKey = dyn_cast<StringAttr>(key); + if (!stringKey) + return failure(); + + if (contains(stringKey)) + return UnitAttr::get(getContext()); + + if (contains((std::string("+") + stringKey.strref()).str())) + return BoolAttr::get(getContext(), true); + + if (contains((std::string("-") + stringKey.strref()).str())) + return BoolAttr::get(getContext(), false); + + return failure(); +} + +//===----------------------------------------------------------------------===// +// TargetAttr +//===----------------------------------------------------------------------===// + +FailureOr<::mlir::Attribute> TargetAttr::query(DataLayoutEntryKey key) { + if (auto stringAttrKey = dyn_cast<StringAttr>(key)) { + if (stringAttrKey.getValue() == "triple") + return getTriple(); + if (stringAttrKey.getValue() == "chip") + return getChip(); + if (stringAttrKey.getValue() == "features" && getFeatures()) + return getFeatures(); + } + return failure(); +} + +//===----------------------------------------------------------------------===// +// ModuleFlagAttr +//===----------------------------------------------------------------------===// + LogicalResult ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError, LLVM::ModFlagBehavior flagBehavior, StringAttr key, diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 422039f..ef27070 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -141,6 +141,38 @@ static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) { return success(); } +static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder, + bool isExpandLoad, + uint64_t alignment = 1) { + // From + // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics + // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics + // + // The pointer alignment defaults to 1. + if (alignment == 1) { + return nullptr; + } + + auto emptyDictAttr = builder.getDictionaryAttr({}); + auto alignmentAttr = builder.getI64IntegerAttr(alignment); + auto namedAttr = + builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr); + SmallVector<mlir::NamedAttribute> attrs = {namedAttr}; + auto alignDictAttr = builder.getDictionaryAttr(attrs); + // From + // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics + // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics + // + // The align parameter attribute can be provided for [expandload]'s first + // argument. The align parameter attribute can be provided for + // [compressstore]'s second argument. + int pos = isExpandLoad ? 0 : 1; + return pos == 0 ? 
builder.getArrayAttr( + {alignDictAttr, emptyDictAttr, emptyDictAttr}) + : builder.getArrayAttr( + {emptyDictAttr, alignDictAttr, emptyDictAttr}); +} + //===----------------------------------------------------------------------===// // Operand bundle helpers. //===----------------------------------------------------------------------===// @@ -821,8 +853,8 @@ void LoadOp::getEffects( /// Returns true if the given type is supported by atomic operations. All /// integer, float, and pointer types with a power-of-two bitsize and a minimal /// size of 8 bits are supported. -static bool isTypeCompatibleWithAtomicOp(Type type, - const DataLayout &dataLayout) { +bool LLVM::isTypeCompatibleWithAtomicOp(Type type, + const DataLayout &dataLayout) { if (!isa<IntegerType, LLVMPointerType>(type)) if (!isCompatibleFloatingPointType(type)) return false; @@ -836,8 +868,9 @@ static bool isTypeCompatibleWithAtomicOp(Type type, /// Verifies the attributes and the type of atomic memory access operations. template <typename OpTy> -LogicalResult verifyAtomicMemOp(OpTy memOp, Type valueType, - ArrayRef<AtomicOrdering> unsupportedOrderings) { +static LogicalResult +verifyAtomicMemOp(OpTy memOp, Type valueType, + ArrayRef<AtomicOrdering> unsupportedOrderings) { if (memOp.getOrdering() != AtomicOrdering::not_atomic) { DataLayout dataLayout = DataLayout::closest(memOp); if (!isTypeCompatibleWithAtomicOp(valueType, dataLayout)) @@ -1087,7 +1120,7 @@ static LogicalResult verifyCallOpDebugInfo(CallOp callOp, LLVMFuncOp callee) { /// Verify that the parameter and return types of the variadic callee type match /// the `callOp` argument and result types. template <typename OpTy> -LogicalResult verifyCallOpVarCalleeType(OpTy callOp) { +static LogicalResult verifyCallOpVarCalleeType(OpTy callOp) { std::optional<LLVMFunctionType> varCalleeType = callOp.getVarCalleeType(); if (!varCalleeType) return success(); @@ -2500,7 +2533,7 @@ LogicalResult GlobalOp::verifyRegions() { // LLVM::GlobalCtorsOp //===----------------------------------------------------------------------===// -LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) { +static LogicalResult checkGlobalXtorData(Operation *op, ArrayAttr data) { if (data.empty()) return success(); @@ -4117,6 +4150,32 @@ LogicalResult LLVM::masked_scatter::verify() { } //===----------------------------------------------------------------------===// +// masked_expandload (intrinsic) +//===----------------------------------------------------------------------===// + +void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state, + mlir::TypeRange resTys, Value ptr, + Value mask, Value passthru, + uint64_t align) { + ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align); + build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs, + /*res_attrs=*/nullptr); +} + +//===----------------------------------------------------------------------===// +// masked_compressstore (intrinsic) +//===----------------------------------------------------------------------===// + +void LLVM::masked_compressstore::build(OpBuilder &builder, + OperationState &state, Value value, + Value ptr, Value mask, uint64_t align) { + ArrayAttr argAttrs = + getLLVMAlignParamForCompressExpand(builder, false, align); + build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs, + /*res_attrs=*/nullptr); +} + +//===----------------------------------------------------------------------===// // InlineAsmOp 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index e7d5dad..ef38027 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -19,6 +19,7 @@ #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/DebugLog.h" #define DEBUG_TYPE "sroa" @@ -734,9 +735,8 @@ static std::optional<uint64_t> gepToByteOffset(const DataLayout &dataLayout, return false; }) .Default([&](Type type) { - LLVM_DEBUG(llvm::dbgs() - << "[sroa] Unsupported type for offset computations" - << type << "\n"); + LDBG() << "[sroa] Unsupported type for offset computations" + << type; return true; }); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp index 78b4411..297640c 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -24,7 +24,9 @@ using namespace mlir::LLVM; /// prints it as usual. static void dispatchPrint(AsmPrinter &printer, Type type) { if (isCompatibleType(type) && - !llvm::isa<IntegerType, FloatType, VectorType>(type)) + !(llvm::isa<IntegerType, FloatType, VectorType>(type) || + (llvm::isa<PtrLikeTypeInterface>(type) && + !llvm::isa<LLVMPointerType>(type)))) return mlir::LLVM::detail::printType(type, printer); printer.printType(type); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index fee2d3e..2dd0132 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -13,6 +13,7 @@ #include "TypeDetail.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinTypes.h" @@ -701,6 +702,17 @@ const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const { // Utility functions. //===----------------------------------------------------------------------===// +/// Check whether type is a compatible ptr type. These are pointer-like types +/// with no element type, no metadata, and using the LLVM AddressSpaceAttr +/// memory space. 
+static bool isCompatiblePtrType(Type type) { + auto ptrTy = dyn_cast<PtrLikeTypeInterface>(type); + if (!ptrTy) + return false; + return !ptrTy.hasPtrMetadata() && ptrTy.getElementType() == nullptr && + isa<AddressSpaceAttr>(ptrTy.getMemorySpace()); +} + bool mlir::LLVM::isCompatibleOuterType(Type type) { // clang-format off if (llvm::isa< @@ -734,7 +746,7 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) { if (auto vecType = llvm::dyn_cast<VectorType>(type)) return vecType.getRank() == 1; - return false; + return isCompatiblePtrType(type); } static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) { @@ -784,6 +796,8 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) { LLVMX86AMXType >([](Type) { return true; }) // clang-format on + .Case<PtrLikeTypeInterface>( + [](Type type) { return isCompatiblePtrType(type); }) .Default([](Type) { return false; }); if (!result) @@ -805,6 +819,18 @@ bool mlir::LLVM::isCompatibleType(Type type) { return LLVMDialect::isCompatibleType(type); } +bool mlir::LLVM::isLoadableType(Type type) { + return /*LLVM_PrimitiveType*/ ( + LLVM::isCompatibleOuterType(type) && + !isa<LLVM::LLVMVoidType, LLVM::LLVMFunctionType>(type)) && + /*LLVM_OpaqueStruct*/ + !(isa<LLVM::LLVMStructType>(type) && + cast<LLVM::LLVMStructType>(type).isOpaque()) && + /*LLVM_AnyTargetExt*/ + !(isa<LLVM::LLVMTargetExtType>(type) && + !cast<LLVM::LLVMTargetExtType>(type).supportsMemOps()); +} + bool mlir::LLVM::isCompatibleFloatingPointType(Type type) { return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type, Float80Type, Float128Type, LLVMPPCFP128Type>(type); @@ -818,7 +844,8 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) { if (auto intType = llvm::dyn_cast<IntegerType>(elementType)) return intType.isSignless(); return llvm::isa<BFloat16Type, Float16Type, Float32Type, Float64Type, - Float80Type, Float128Type, LLVMPointerType>(elementType); + Float80Type, Float128Type, LLVMPointerType>(elementType) || + isCompatiblePtrType(elementType); } return false; } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 7ad429e..77ec1eb 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -33,6 +33,7 @@ #include "llvm/IR/IRBuilder.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/NVPTXAddrSpace.h" #include "llvm/Support/raw_ostream.h" #include <cassert> #include <optional> @@ -50,7 +51,6 @@ using namespace NVVM; // This verifier is shared among the following Ops: // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) -// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) // CpAsyncBulkTensorReduceOp (TMA Store-Reduce) static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims, bool isIm2Col, @@ -82,8 +82,27 @@ LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() { } LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() { - if (getCoordinates().size() > 5) - return emitError("Maximum 5 coordinates and dimension is supported."); + TMAStoreMode mode = getMode(); + // We lower through inline-ptx when getPredicate() is true. 
+ // a) Only TILE mode is supported + // b) Cache-hint is not supported + if (getPredicate()) { + if (mode != TMAStoreMode::TILE) + return emitError("Inline-ptx lowering supported only for Tile mode."); + if (getL2CacheHint()) + return emitError("Inline-ptx lowering unsupported with L2 cache-hint."); + } + + size_t dims = getCoordinates().size(); + switch (mode) { + case TMAStoreMode::TILE: + return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc()); + case TMAStoreMode::IM2COL: + return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc()); + case TMAStoreMode::TILE_SCATTER4: + if (dims != 5) + return emitError("Scatter4 mode expects 5 coordinates"); + } return success(); } @@ -98,17 +117,59 @@ LogicalResult CpAsyncOp::verify() { return success(); } +// This verify params can be shared across TMA Load and Prefetch Ops. +static LogicalResult verifyTMALoadParams(size_t tensorDims, size_t numIm2colOff, + TMALoadMode mode, Location loc) { + if (tensorDims < 1 || tensorDims > 5) + return emitError(loc, "expects coordinates between 1 to 5 dimension"); + + auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col, + size_t expectedIm2colOff) -> LogicalResult { + if (isIm2col && (tensorDims < 3)) + return emitError(loc) + << "to use " << stringifyEnum(mode) + << " mode, the tensor has to be at least 3-dimensional"; + + if (numIm2colOff != expectedIm2colOff) + return emitError(loc) << " im2col offsets expected " << expectedIm2colOff + << " (provided " << numIm2colOff << ")"; + + return success(); + }; + + switch (mode) { + case TMALoadMode::TILE: + return checkTMALoadParams(mode, false, 0); + case TMALoadMode::IM2COL: + return checkTMALoadParams(mode, true, tensorDims - 2); + case TMALoadMode::IM2COL_W: + case TMALoadMode::IM2COL_W_128: + return checkTMALoadParams(mode, true, 2); + case TMALoadMode::TILE_GATHER4: + return (tensorDims == 5) + ? 
checkTMALoadParams(mode, false, 0) + : emitError(loc, "Gather4 mode expects 5 coordinates"); + } + return success(); +} + LogicalResult CpAsyncBulkTensorPrefetchOp::verify() { - size_t numIm2ColOffsets = getIm2colOffsets().size(); - bool isIm2Col = numIm2ColOffsets > 0; - return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, - numIm2ColOffsets, getLoc()); + return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(), + getMode(), getLoc()); } LogicalResult CpAsyncBulkTensorReduceOp::verify() { - bool isIm2Col = (getMode() == TMAStoreMode::IM2COL); - return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0, - getLoc()); + TMAStoreMode mode = getMode(); + size_t dims = getCoordinates().size(); + switch (mode) { + case TMAStoreMode::TILE: + return cpAsyncBulkTensorCommonVerifier(dims, false, 0, getLoc()); + case TMAStoreMode::IM2COL: + return cpAsyncBulkTensorCommonVerifier(dims, true, 0, getLoc()); + case TMAStoreMode::TILE_SCATTER4: + return emitError("Scatter mode unsupported for CpAsyncBulkTensorReduceOp"); + } + return success(); } LogicalResult ConvertFloatToTF32Op::verify() { @@ -811,24 +872,58 @@ LogicalResult NVVM::WMMAMmaOp::verify() { } LogicalResult NVVM::LdMatrixOp::verify() { - unsigned addressSpace = - llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace(); - if (addressSpace != NVVM::kSharedMemorySpace) - return emitOpError("expected source pointer in memory space 3"); - - if (getNum() != 1 && getNum() != 2 && getNum() != 4) - return emitOpError("expected num attribute to be 1, 2 or 4"); + uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN(); + if (m == 8 && n == 8) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 " + "matrix"); + } + if (getEltType() != LdStMatrixEltType::B16) { + return emitOpError("expected element type to be b16 for 8x8 matrix"); + } + } else if (m == 8 && n == 16) { + if (num != 1 && num != 2 && num != 4) { + return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 " + "matrix"); + } + if (getLayout() != MMALayout::row) { + return emitOpError("expected layout to be row for 8x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { + return emitOpError("expected element type to be b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 8x16 matrix"); + } + } else if (m == 16 && n == 16) { + if (num != 1 && num != 2) { + return emitOpError("expected num attribute to be 1 or 2 for 16x16 " + "matrix"); + } + if (getLayout() != MMALayout::col) { + return emitOpError("expected layout to be col for 16x16 matrix"); + } + if (getEltType() != LdStMatrixEltType::B8 && + getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 && + getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) { + return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or " + "b8x16.b6x16_p32 for 16x16 matrix"); + } + } else { + return emitOpError("expected shape to be 8x8, 8x16 or 16x16"); + } Type i32 = IntegerType::get(getContext(), 32); - if (getNum() == 1 && getType() != i32) + uint32_t numElements = (m == 16 && n == 16 ? 
num * 2 : num); + if (numElements == 1 && getType() != i32) return emitOpError("expected destination type is i32"); - if (getNum() == 2 || getNum() == 4) { + if (numElements == 2 || numElements == 4) { Type dstType = LLVM::LLVMStructType::getLiteral( - getContext(), SmallVector<Type>(getNum(), i32)); + getContext(), SmallVector<Type>(numElements, i32)); if (getType() != dstType) return emitOpError("expected destination type is a structure of ") - << getNum() << " elements of type i32"; + << numElements << " elements of type i32"; } + return success(); } @@ -1089,7 +1184,7 @@ std::string NVVM::WgmmaMmaAsyncOp::getPtx() { return ptx; } -void NVVM::WgmmaMmaAsyncOp::getAsmValues( +bool NVVM::WgmmaMmaAsyncOp::getAsmValues( RewriterBase &rewriter, llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> &asmValues) { @@ -1120,7 +1215,9 @@ void NVVM::WgmmaMmaAsyncOp::getAsmValues( {makeConstantI32(rewriter, 1 - static_cast<int>(getLayoutB())), mlir::NVVM::PTXRegisterMod::Read}); } + return true; // Has manual mapping } + LogicalResult NVVM::FenceProxyOp::verify() { if (getKind() == NVVM::ProxyKind::TENSORMAP) return emitOpError() << "tensormap proxy is not a supported proxy kind"; @@ -1236,30 +1333,70 @@ LogicalResult NVVM::PrefetchOp::verify() { unsigned addressSpace = llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace(); std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority(); + std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel(); - if (getUniform()) { - if (getCacheLevel() != CacheLevel::L1) - return emitOpError("unsupported cache level, the only supported uniform " - "cache level is L1"); + if (getTensormap() && cacheLevel) + return emitOpError("cannot specify both tensormap and cache level"); - if (addressSpace != MemSpace::kGenericMemorySpace) + if (getTensormap()) { + if (addressSpace != MemSpace::kGenericMemorySpace && + addressSpace != MemSpace::kConstantMemorySpace) { return emitOpError( - "prefetch to uniform cache requires a generic pointer"); - } + "prefetch tensormap requires a generic or constant pointer"); + } - if (evictPriority) { - if (getCacheLevel() != CacheLevel::L2) + if (evictPriority) { return emitOpError( - "cache eviction priority supported only for cache level L2"); - - if (addressSpace != MemSpace::kGlobalMemorySpace) - return emitOpError("cache eviction priority requires a global pointer"); + "prefetch tensormap does not support eviction priority"); + } - if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal && - *evictPriority != NVVM::CacheEvictionPriority::EvictLast) + if (getInParamSpace() && addressSpace != MemSpace::kGenericMemorySpace) { return emitOpError( - "unsupported cache eviction priority, only evict_last and " - "evict_normal are supported"); + "in_param_space can only be specified for a generic pointer"); + } + + } else if (cacheLevel) { + if (addressSpace != MemSpace::kGenericMemorySpace && + addressSpace != MemSpace::kGlobalMemorySpace && + addressSpace != MemSpace::kLocalMemorySpace) { + return emitOpError("prefetch to cache level requires a generic, global, " + "or local pointer"); + } + + if (getUniform()) { + if (*cacheLevel != CacheLevel::L1) { + return emitOpError( + "unsupported cache level, the only supported uniform " + "cache level is L1"); + } + + if (addressSpace != MemSpace::kGenericMemorySpace) { + return emitOpError( + "prefetch to uniform cache requires a generic pointer"); + } + } + + if (evictPriority) { + if (*cacheLevel != CacheLevel::L2) + 
return emitOpError( + "cache eviction priority supported only for cache level L2"); + + if (addressSpace != MemSpace::kGlobalMemorySpace) + return emitOpError("cache eviction priority requires a global pointer"); + + if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal && + *evictPriority != NVVM::CacheEvictionPriority::EvictLast) + return emitOpError( + "unsupported cache eviction priority, only evict_last and " + "evict_normal are supported"); + } + + if (getPredicate()) + return emitOpError("predicate supported only on prefetch tensormap"); + + } else { + return emitOpError( + "requires specification of either cache level or tensormap"); } return success(); @@ -1399,28 +1536,102 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs( return {id, std::move(args)}; } -llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims, - bool isIm2Col) { - switch (tensorDims) { - case 1: - return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d; - case 2: - return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d; - case 3: - return isIm2Col - ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d - : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d; - case 4: - return isIm2Col - ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d - : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d; - case 5: - return isIm2Col - ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d - : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d; - default: - llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp."); - } +mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op); + llvm::SmallVector<llvm::Value *> args; + + // Fill the Intrinsic Args + args.push_back(mt.lookupValue(thisOp.getTmaDescriptor())); + + for (auto v : thisOp.getCoordinates()) + args.push_back(mt.lookupValue(v)); + for (auto v : thisOp.getIm2colOffsets()) + args.push_back(mt.lookupValue(v)); + + mlir::Value cacheHint = thisOp.getL2CacheHint(); + const bool hasCacheHint = static_cast<bool>(cacheHint); + llvm::Value *i64Unused = + llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0); + args.push_back(hasCacheHint ? 
mt.lookupValue(cacheHint) : i64Unused); + args.push_back(builder.getInt1(hasCacheHint)); + + const unsigned NI = llvm::Intrinsic::not_intrinsic; + static constexpr llvm::Intrinsic::ID IDTable[][6] = { + {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d}, + {NI, NI, NI, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d}, + {NI, NI, NI, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d}, + {NI, NI, NI, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d}, + {NI, NI, NI, NI, NI, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d}}; + + static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1, + "TMALoadModes must match number of rows in IDTable"); + size_t mode = static_cast<size_t>(thisOp.getMode()); + size_t dim = thisOp.getCoordinates().size(); + llvm::Intrinsic::ID id = IDTable[mode][dim]; + if (id == llvm::Intrinsic::not_intrinsic) + llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp."); + + return {id, std::move(args)}; +} + +mlir::NVVM::IDArgPair +CpAsyncBulkTensorSharedCTAToGlobalOp::getIntrinsicIDAndArgs( + Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { + auto thisOp = cast<NVVM::CpAsyncBulkTensorSharedCTAToGlobalOp>(op); + llvm::SmallVector<llvm::Value *> args; + + // Fill the Intrinsic Args + args.push_back(mt.lookupValue(thisOp.getSrcMem())); + args.push_back(mt.lookupValue(thisOp.getTmaDescriptor())); + + for (auto v : thisOp.getCoordinates()) + args.push_back(mt.lookupValue(v)); + + mlir::Value cacheHint = thisOp.getL2CacheHint(); + const bool hasCacheHint = static_cast<bool>(cacheHint); + llvm::Value *i64Unused = + llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0); + args.push_back(hasCacheHint ? 
mt.lookupValue(cacheHint) : i64Unused); + args.push_back(builder.getInt1(hasCacheHint)); + + const unsigned NI = llvm::Intrinsic::not_intrinsic; + static constexpr llvm::Intrinsic::ID IDTable[][6] = { + {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_1d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_2d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_5d}, + {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_3d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_4d, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_im2col_5d}, + {NI, NI, NI, NI, NI, + llvm::Intrinsic::nvvm_cp_async_bulk_tensor_s2g_tile_scatter4_2d}}; + + static_assert(getMaxEnumValForTMAStoreMode() == std::size(IDTable) - 1, + "TMAStoreModes must match number of rows in IDTable"); + size_t mode = static_cast<size_t>(thisOp.getMode()); + size_t dim = thisOp.getCoordinates().size(); + llvm::Intrinsic::ID id = IDTable[mode][dim]; + if (id == llvm::Intrinsic::not_intrinsic) + llvm_unreachable( + "Invalid intrinsic for CpAsyncBulkTensorSharedCTAToGlobalOp."); + + return {id, std::move(args)}; } #define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \ @@ -1794,26 +2005,47 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs( return {ids[type], args}; } -llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) { +static llvm::Value *getParamCastedAddr(llvm::Value *addr, + llvm::IRBuilderBase &builder) { + return builder.CreateAddrSpaceCast( + addr, + llvm::PointerType::get(builder.getContext(), + llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM)); +} + +NVVM::IDArgPair +PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op, + LLVM::ModuleTranslation &mt, + llvm::IRBuilderBase &builder) { using MemSpace = NVVM::NVVMMemorySpace; using CacheLevel = NVVM::PrefetchCacheLevel; - NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel(); + std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel(); std::optional<NVVM::CacheEvictionPriority> evictPriority = op.getEvictPriority(); unsigned addressSpace = llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType()) .getAddressSpace(); - if (op.getUniform() && cacheLevel == CacheLevel::L1) - return llvm::Intrinsic::nvvm_prefetchu_L1; + llvm::SmallVector<llvm::Value *> args; + llvm::Value *addr = mt.lookupValue(op.getAddr()); + args.push_back(op.getInParamSpace() ? 
getParamCastedAddr(addr, builder) + : addr); - if (evictPriority && cacheLevel == CacheLevel::L2) { + if (op.getTensormap()) + return {llvm::Intrinsic::nvvm_prefetch_tensormap, args}; + + assert(cacheLevel && "expected cache level for non-tensormap prefetch"); + + if (op.getUniform() && *cacheLevel == CacheLevel::L1) + return {llvm::Intrinsic::nvvm_prefetchu_L1, args}; + + if (evictPriority && *cacheLevel == CacheLevel::L2) { switch (*evictPriority) { case NVVM::CacheEvictionPriority::EvictLast: - return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last; + return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args}; case NVVM::CacheEvictionPriority::EvictNormal: - return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal; + return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args}; default: llvm_unreachable("Invalid cache eviction priority"); } @@ -1821,21 +2053,41 @@ llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) { switch (addressSpace) { case MemSpace::kGenericMemorySpace: - return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1 - : llvm::Intrinsic::nvvm_prefetch_L2; + return *cacheLevel == CacheLevel::L1 + ? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args}) + : NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args}); case MemSpace::kGlobalMemorySpace: - return cacheLevel == CacheLevel::L1 - ? llvm::Intrinsic::nvvm_prefetch_global_L1 - : llvm::Intrinsic::nvvm_prefetch_global_L2; + return *cacheLevel == CacheLevel::L1 + ? NVVM::IDArgPair( + {llvm::Intrinsic::nvvm_prefetch_global_L1, args}) + : NVVM::IDArgPair( + {llvm::Intrinsic::nvvm_prefetch_global_L2, args}); case MemSpace::kLocalMemorySpace: - return cacheLevel == CacheLevel::L1 - ? llvm::Intrinsic::nvvm_prefetch_local_L1 - : llvm::Intrinsic::nvvm_prefetch_local_L2; + return *cacheLevel == CacheLevel::L1 + ? NVVM::IDArgPair( + {llvm::Intrinsic::nvvm_prefetch_local_L1, args}) + : NVVM::IDArgPair( + {llvm::Intrinsic::nvvm_prefetch_local_L2, args}); default: llvm_unreachable("Invalid pointer address space"); } } +bool NVVM::InlinePtxOp::getAsmValues( + RewriterBase &rewriter, + llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>> + &asmValues) { + for (auto arg : getReadWriteArgs()) + asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::ReadWrite}); + for (auto arg : getResults()) + asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Write}); + for (auto arg : getReadOnlyArgs()) + asmValues.push_back({arg, mlir::NVVM::PTXRegisterMod::Read}); + if (getPredicate()) + asmValues.push_back({getPredicate(), mlir::NVVM::PTXRegisterMod::Read}); + return false; // No manual mapping needed +} + //===----------------------------------------------------------------------===// // NVVMDialect initialization, type parsing, and registration. 
//===----------------------------------------------------------------------===// @@ -1874,19 +2126,31 @@ LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, attrName == NVVMDialect::getReqntidAttrName() || attrName == NVVMDialect::getClusterDimAttrName()) { auto values = llvm::dyn_cast<DenseI32ArrayAttr>(attr.getValue()); - if (!values || values.empty() || values.size() > 3) + if (!values || values.empty() || values.size() > 3) { return op->emitError() << "'" << attrName << "' attribute must be integer array with maximum 3 index"; + } } // If minctasm / maxnreg / cluster_max_blocks exist, it must be an integer // attribute if (attrName == NVVMDialect::getMinctasmAttrName() || attrName == NVVMDialect::getMaxnregAttrName() || attrName == NVVMDialect::getClusterMaxBlocksAttrName()) { - if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) + if (!llvm::dyn_cast<IntegerAttr>(attr.getValue())) { return op->emitError() << "'" << attrName << "' attribute must be integer constant"; + } + } + // blocksareclusters must be used along with reqntid and cluster_dim + if (attrName == NVVMDialect::getBlocksAreClustersAttrName()) { + if (!op->hasAttr(NVVMDialect::getReqntidAttrName()) || + !op->hasAttr(NVVMDialect::getClusterDimAttrName())) { + return op->emitError() + << "'" << attrName << "' attribute must be used along with " + << "'" << NVVMDialect::getReqntidAttrName() << "' and " + << "'" << NVVMDialect::getClusterDimAttrName() << "'"; + } } return success(); diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp index 8317b67..23b4130 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/DIExpressionRewriter.cpp @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/LLVMIR/Transforms/DIExpressionRewriter.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" using namespace mlir; using namespace LLVM; @@ -63,9 +63,8 @@ DIExpressionRewriter::simplify(DIExpressionAttr expr, } if (maxNumRewrites && numRewrites >= *maxNumRewrites) { - LLVM_DEBUG(llvm::dbgs() - << "LLVMDIExpressionSimplifier exceeded max num rewrites (" - << maxNumRewrites << ")\n"); + LDBG() << "LLVMDIExpressionSimplifier exceeded max num rewrites (" + << maxNumRewrites << ")"; // Skip rewriting the rest. result.append(inputs.begin(), inputs.end()); } diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp index 18f85b6..4ea2ac9 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp +++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp @@ -235,8 +235,10 @@ getUnderlyingObjectSet(Value pointerValue) { WalkContinuation walkResult = walkSlice(pointerValue, [&](Value val) { // Attempt to advance to the source of the underlying view-like operation. // Examples of view-like operations include GEPOp and AddrSpaceCastOp. - if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>()) - return WalkContinuation::advanceTo(viewOp.getViewSource()); + if (auto viewOp = val.getDefiningOp<ViewLikeOpInterface>()) { + if (val == viewOp.getViewDest()) + return WalkContinuation::advanceTo(viewOp.getViewSource()); + } // Attempt to advance to control flow predecessors. 
std::optional<SmallVector<Value>> controlFlowPredecessors = diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 34c63d3..578931e 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -194,9 +194,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state, ArrayRef<AffineMap> indexingMaps) { // Initialize indexingMaps attribute, for MatmulOp. SmallVector<Attribute, 3> indexingMapsAttrVal; - indexingMapsAttrVal = llvm::map_to_vector( - MatmulOp::getDefaultIndexingMaps(b.getContext()), - [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); }); + indexingMapsAttrVal = + llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + }); state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal)); return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs, attributes, regionBuilder); @@ -1569,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -// Retrieve the operation from the body, if it is the only one (except -// yield) and if it gets the same amount of arguments as the body does. -// If initFirst flag is enabled, we check that init takes the first position in -// operands of payload. -static Operation *findPayloadOp(Block *body, bool initFirst = false) { +static bool canUseShortForm(Block *body, bool initFirst = false) { + // Check if the body can be printed in short form. The following 4 conditions + // must be satisfied: + + // 1) The body must contain exactly 2 operations: the payload op and a yield. if (body->getOperations().size() != 2) - return nullptr; + return false; Operation &payload = body->getOperations().front(); - assert(isa<YieldOp>(body->getOperations().back())); + // 2) The payload op must have the same number of operands as the number of + // block arguments. if (payload.getNumOperands() == 0 || payload.getNumOperands() != body->getNumArguments()) - return nullptr; + return false; + + // 3) If `initFirst` is true (e.g., for reduction ops), the init block + // must be the first operand of the payload op, otherwise, the operands + // must match the block arguments in order. if (initFirst) { // check init if (payload.getOperands().back() != body->getArgument(0)) - return nullptr; + return false; // check rest for (const auto &[operand, bbArg] : llvm::zip(payload.getOperands(), body->getArguments().drop_front())) { if (bbArg != operand) - return nullptr; + return false; } } else { for (const auto &[operand, bbArg] : llvm::zip(payload.getOperands(), body->getArguments())) { if (bbArg != operand) - return nullptr; + return false; } } - return &payload; + + // 4) The `yield` operand must be the result of the payload op. 
+ auto yieldOp = cast<YieldOp>(body->getTerminator()); + return yieldOp.getNumOperands() == 1 && + yieldOp.getOperand(0).getDefiningOp() && + yieldOp.getOperand(0).getDefiningOp() == &payload; } -void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { +static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { SmallVector<StringRef> elidedAttrs; std::string attrToElide; p << " { " << payloadOp->getName().getStringRef(); @@ -1621,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) { void MapOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - Operation *payloadOp = findPayloadOp(mapper); - if (payloadOp) { - printShortForm(p, payloadOp); + bool useShortForm = canUseShortForm(mapper); + if (useShortForm) { + printShortForm(p, &mapper->getOperations().front()); } printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); p.printOptionalAttrDict((*this)->getAttrs()); - if (!payloadOp) { + if (!useShortForm) { // Print region if the payload op was not detected. p.increaseIndent(); p.printNewline(); @@ -1828,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName, void ReduceOp::print(OpAsmPrinter &p) { Block *mapper = getBody(); - Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true); - if (payloadOp) { - printShortForm(p, payloadOp); + bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true); + if (useShortForm) { + printShortForm(p, &mapper->getOperations().front()); } printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits()); printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions()); p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()}); - if (!payloadOp) { + if (!useShortForm) { // Print region if the payload op was not detected. p.increaseIndent(); p.printNewline(); @@ -3749,6 +3760,25 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) { // MatMulOp //===----------------------------------------------------------------------===// +static FailureOr<SmallVector<SmallVector<int64_t>>> +getAffineResultPositions(ArrayAttr maps) { + SmallVector<SmallVector<int64_t>> positions; + for (auto map : maps) { + AffineMapAttr attr = dyn_cast<AffineMapAttr>(map); + if (!attr) + return failure(); + SmallVector<int64_t> pos; + for (auto result : attr.getAffineMap().getResults()) { + auto dim = dyn_cast<AffineDimExpr>(result); + if (!dim) + return failure(); + pos.push_back(dim.getPosition()); + } + positions.push_back(pos); + } + return positions; +} + /// Returns a list of AffineMap with the typical matmul indexing charactristic. 
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) { AffineExpr d0, d1, d2; @@ -3760,6 +3790,20 @@ SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) { return indexingMaps; } +bool MatmulOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 2} && + (*positions)[1] == SmallVector<int64_t>{2, 1} && + (*positions)[2] == SmallVector<int64_t>{0, 1}; +} + SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() { return SmallVector<utils::IteratorType>{utils::IteratorType::parallel, utils::IteratorType::parallel, @@ -3836,7 +3880,7 @@ bool MatmulOp::isValidLhsRhsBroadcastMap(AffineMap bcastMap) { return expr.isFunctionOfDim(bcastMap.getNumDims() - 1); } -FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) { +static FailureOr<ArrayAttr> parseIndexingMapsAttr(OpAsmParser &parser) { if (parser.parseOptionalKeyword("indexing_maps")) return ArrayAttr{ nullptr}; // Success in case indexing_maps was not provided. @@ -3912,6 +3956,380 @@ Speculation::Speculatability MatmulOp::getSpeculatability() { return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation())); } +SmallVector<AffineMap> +MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2); + AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context); + AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context); + AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{2, 0} && + (*positions)[1] == SmallVector<int64_t>{2, 1} && + (*positions)[2] == SmallVector<int64_t>{0, 1}; +} + +void linalg::MatmulTransposeAOp::build(OpBuilder &builder, + OperationState &result, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeAOp +MatmulTransposeAOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeAOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeAOp +MatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + 
OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeAOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeAOp +MatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool MatmulTransposeAOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::MatmulOp>(op) && + MatmulTransposeAOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + +SmallVector<AffineMap> +MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2); + AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context); + AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context); + AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 2} && + (*positions)[1] == SmallVector<int64_t>{1, 2} && + (*positions)[2] == SmallVector<int64_t>{0, 1}; +} + +void linalg::MatmulTransposeBOp::build(OpBuilder &builder, + OperationState &result, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeBOp +MatmulTransposeBOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeBOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeBOp +MatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + 
build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::MatmulTransposeBOp::build(OpBuilder &builder, + OperationState &result, + TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder)); +} + +MatmulTransposeBOp +MatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool MatmulTransposeBOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::MatmulOp>(op) && + MatmulTransposeBOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + +SmallVector<AffineMap> +BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2, d3; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2, d3); + AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context); + AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context); + AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} && + (*positions)[1] == SmallVector<int64_t>{0, 3, 2} && + (*positions)[2] == SmallVector<int64_t>{0, 1, 2}; +} + +void linalg::BatchMatmulTransposeAOp::build( + OpBuilder &builder, OperationState &result, ValueRange inputs, + ValueRange outputs, ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeAOp +BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeAOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeAOp +BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState 
state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeAOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeAOp +BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool BatchMatmulTransposeAOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::BatchMatmulOp>(op) && + BatchMatmulTransposeAOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + +SmallVector<AffineMap> +BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) { + AffineExpr d0, d1, d2, d3; + MLIRContext *context = builder.getContext(); + bindDims(context, d0, d1, d2, d3); + AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context); + AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context); + AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context); + return {mapLHS, mapRHS, mapOut}; +} + +bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} && + (*positions)[1] == SmallVector<int64_t>{0, 2, 3} && + (*positions)[2] == SmallVector<int64_t>{0, 1, 2}; +} + +void linalg::BatchMatmulTransposeBOp::build( + OpBuilder &builder, OperationState &result, ValueRange inputs, + ValueRange outputs, ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeBOp +BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeBOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeBOp +BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + 
ValueRange outputs, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, attributes); + auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +void linalg::BatchMatmulTransposeBOp::build( + OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, + ValueRange inputs, ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + result.addAttribute("cast", cast); + buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes, + BatchMatmulOp::getRegionBuilder(), + getDefaultIndexingMaps(builder)); +} + +BatchMatmulTransposeBOp +BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location, + TypeRange resultTensorTypes, ValueRange inputs, + ValueRange outputs, Attribute cast, + ArrayRef<NamedAttribute> attributes) { + OperationState state(location, getOperationName()); + build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes); + auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state)); + assert(res && "builder didn't return the right type"); + return res; +} + +bool BatchMatmulTransposeBOp::classof(Operation *op) { + return dyn_cast_or_null<linalg::BatchMatmulOp>(op) && + BatchMatmulTransposeBOp::isDefaultIndexingMaps( + op->getAttr("indexing_maps")); +} + //===----------------------------------------------------------------------===// // ContractOp //===----------------------------------------------------------------------===// @@ -4120,6 +4538,20 @@ BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) { return indexingMaps; } +bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} && + (*positions)[1] == SmallVector<int64_t>{0, 3, 2} && + (*positions)[2] == SmallVector<int64_t>{0, 1, 2}; +} + SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() { return SmallVector<utils::IteratorType>{ utils::IteratorType::parallel, utils::IteratorType::parallel, @@ -5042,7 +5474,7 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc, /// Returns true if the tiles and the tiled dims are constant. template <typename OpTy> -bool areTilesAndTiledDimsAllConstant(OpTy op) { +static bool areTilesAndTiledDimsAllConstant(OpTy op) { static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, "applies to only pack or unpack operations"); ShapedType packedType = (std::is_same<OpTy, PackOp>::value) @@ -5345,11 +5777,18 @@ ArrayRef<int64_t> UnPackOp::getAllOuterDims() { SmallVector<int64_t> UnPackOp::getTiledOuterDims() { auto innerDimsPos = getInnerDimsPos(); - auto packedShape = getSourceType().getShape(); + SmallVector<int64_t> outerDims(getAllOuterDims()); SmallVector<int64_t> res; + // Recover the original order of the outer dims. + SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm()); + invertPermutationVector(outerDimPermInv); + if (!outerDimPermInv.empty()) + applyPermutationToVector(outerDims, outerDimPermInv); + + // Collect the outer dims corresponding to the tilled inner dims. 
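Editorial illustration of the fix above, with made-up shapes: consider a `linalg.unpack` whose packed source is `tensor<4x8x16xf32>`, with `inner_dims_pos = [1]`, `inner_tiles = [16]`, and `outer_dims_perm = [1, 0]`, unpacking into `tensor<8x64xf32>`.

// Illustrative walk-through (editorial, not from the patch):
//   getAllOuterDims()        -> [4, 8]   // packed/source order
//   after inverse outer perm -> [8, 4]   // destination-dim order
//   getTiledOuterDims()      -> [4]      // outer dim of tiled dest dim 1
// The previous code indexed the packed shape directly and returned 8 here
// whenever outer_dims_perm was non-identity.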
for (auto index : innerDimsPos) - res.push_back(packedShape[index]); + res.push_back(outerDims[index]); return res; } @@ -5646,6 +6085,19 @@ BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) { return indexingMaps; } +bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) { + ArrayAttr maps = dyn_cast<ArrayAttr>(attr); + if (!maps) + return false; + if (maps.size() != 3) + return false; + auto positions = getAffineResultPositions(maps); + if (failed(positions)) + return false; + return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} && + (*positions)[1] == SmallVector<int64_t>{0, 3, 2} && + (*positions)[2] == SmallVector<int64_t>{1, 2}; +} unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; } std::string BatchReduceMatmulOp::getLibraryCallName() { diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 639e0fe..f0c1f44 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -70,12 +70,7 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); // We want to discourage direct use of PatternRewriter in APIs but In this // very specific case, an IRRewriter is not enough. - struct TrivialPatternRewriter : public PatternRewriter { - public: - explicit TrivialPatternRewriter(MLIRContext *context) - : PatternRewriter(context) {} - }; - TrivialPatternRewriter rewriter(operation->getContext()); + PatternRewriter rewriter(operation->getContext()); rewriter.setInsertionPoint(operation); auto result = pattern.returningMatchAndRewrite(op, rewriter); if (failed(result)) diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp index 3908d73..6912da3f 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp @@ -55,8 +55,8 @@ static bool validateFullTilesOnDims(linalg::LinalgOp linalgOp, // Skip the batch dimension if present. // Offset all dimensions accordingly. 
SmallVector<int64_t, 3> offsetDims(dims); - for (size_t i = 0; i < offsetDims.size(); i++) - offsetDims[i] += batchDimsOffset; + for (int64_t &offsetDim : offsetDims) + offsetDim += batchDimsOffset; auto tileOp = cast<TilingInterface>(linalgOp.getOperation()); OpBuilder builder(tileOp); @@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns( RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) { patterns.add<BlockPackMatmul<linalg::GenericOp>, BlockPackMatmul<linalg::MatmulOp>, - BlockPackMatmul<linalg::BatchMatmulOp>, - BlockPackMatmul<linalg::MatmulTransposeAOp>, - BlockPackMatmul<linalg::BatchMatmulTransposeAOp>, - BlockPackMatmul<linalg::MatmulTransposeBOp>, - BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>( - patterns.getContext(), controlFn); + BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(), + controlFn); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt index 6ec2e9fd..fb39e186 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt @@ -26,7 +26,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms MorphOps.cpp TransposeMatmul.cpp ShardingInterfaceImpl.cpp - NamedOpConversions.cpp + SimplifyDepthwiseConv.cpp NamedToElementwise.cpp BlockPackMatmul.cpp PackAndUnpackPatterns.cpp diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp index d1eb270..108abe8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertConv2DToImg2Col.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/IR/AffineExpr.h" @@ -50,28 +51,71 @@ static Value createMul(Location loc, Value x, Value y, Type accType, return arith::MulFOp::create(builder, loc, xConvert, yConvert); } -// Delinearizes the given composite `index` by the basis specified in `factors`. -static SmallVector<Value> unrollIndex(OpBuilder &b, Location loc, Value index, - ArrayRef<int64_t> factors) { - assert(!factors.empty() && "empty factor list"); - SmallVector<Value> basis; - for (int64_t f : factors) - basis.push_back(arith::ConstantOp::create(b, loc, b.getIndexAttr(f))); - FailureOr<SmallVector<Value>> multiIndex = - affine::delinearizeIndex(b, loc, index, basis); - assert(!failed(multiIndex) && "Failed to linearize img2col index"); - return *multiIndex; +// Generate the affine expression to compute the convolved index +// for the input as `oIndex * stride + fIndex`, +// where oIndex: output iterator; fIndex: filter iterator. +static AffineExpr getConvolvedExpr(OpBuilder &b, int64_t stride, + bool useSymbols = true) { + AffineExpr oExpr, fExpr; + if (useSymbols) + bindSymbols(b.getContext(), oExpr, fExpr); + else + bindDims(b.getContext(), oExpr, fExpr); + return AffineExpr(stride * oExpr + fExpr); } -// Given indices corresponding to iterators in the output (oIndex) and filter -// (fIndex) for a convolution, compute the convolved index for the -// input as `oIndex * stride + fIndex`. 
-static Value getConvolvedIndex(OpBuilder &b, Location loc, Value oIndex, - Value fIndex, int64_t stride) { - AffineExpr oExpr, fExpr; - bindSymbols(b.getContext(), oExpr, fExpr); - AffineMap convMap = AffineMap::get(0, 2, stride * oExpr + fExpr); - return affine::makeComposedAffineApply(b, loc, convMap, {oIndex, fIndex}); +// Stores the affine expressions to map the iteration space of the im2col matrix +// to the corresponding indices of the output and filter matrices +struct Im2ColToOperandsExprs { + AffineExpr fhIndex; + AffineExpr fwIndex; + AffineExpr icIndex; + AffineExpr ohIndex; + AffineExpr owIndex; +}; + +// Stores the affine expressions to map the iteration space of the im2col matrix +// to the input matrix indices +struct Im2ColToInputDimsExprs { + AffineExpr bIndex; + AffineExpr hIndex; + AffineExpr wIndex; + AffineExpr cIndex; +}; + +/// Construct the affine expressions that map the indices of the im2col matrix +/// to the corresponding input tensor indices for a 2D convolution with the the +/// provided strides. +/// +/// @param exprs Affine expressions for output and filter indices. +/// @param strides [height, width] stride values for the convolution. +/// @param rewriter Pattern rewriter. +/// @return Affine expressions mapping im2col matrix indices to input +/// offsets. +static Im2ColToInputDimsExprs +getIm2ColInputExpressions(Im2ColToOperandsExprs exprs, + ArrayRef<int64_t> strides, RewriterBase &rewriter) { + // maps the iteration space of the im2col matrix to (output_y, filter_y) + auto hIndicesMap = AffineMap::inferFromExprList( + {ArrayRef{exprs.ohIndex, exprs.fhIndex}}, rewriter.getContext())[0]; + // maps the iteration space of the im2col matrix to (output_x, filter_x) + auto wIndicesMap = AffineMap::inferFromExprList( + {ArrayRef{exprs.owIndex, exprs.fwIndex}}, rewriter.getContext())[0]; + // Compute the input indexing map, to map the indices of the im2col matrix to + // the original input offsets. Each element of the im2col matrix corresponds + // to a pair of (out_element, filter_element). First, we build the expressions + // to compute the input (ix, iy) indices from [out_x/y, filter_x/y] pairs; + // then we compose them with the maps that map the im2col matrix elements to + // the (out_element, filter_element) pairs. 
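Editorial sketch of the resulting index arithmetic (derived from the delinearization bases used at the NHWC call site below; `sh`/`sw` are the strides): for an im2col element at `(n, m, k)`,

//   oh  = m floordiv ow              fh  = k floordiv (fw * ic)
//   ow' = m mod ow                   fw' = (k mod (fw * ic)) floordiv ic
//                                    c   = k mod ic
// so the composed map realizes
//   im2col[n, m, k] = input[n, sh * oh + fh, sw * ow' + fw', c]
// which matches the im2col relation stated in the surrounding comments.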
+ auto bIndexExpr = rewriter.getAffineDimExpr(0U); + auto hIndexExpr = getConvolvedExpr(rewriter, strides[0], + /*useSymbols*/ false); + hIndexExpr = hIndexExpr.compose(hIndicesMap); + auto wIndexExpr = getConvolvedExpr(rewriter, strides[1], + /*useSymbols*/ false); + wIndexExpr = wIndexExpr.compose(wIndicesMap); + auto cIndexExpr = exprs.icIndex; + return {bIndexExpr, hIndexExpr, wIndexExpr, cIndexExpr}; } FailureOr<std::pair<Operation *, Operation *>> @@ -135,44 +179,37 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcHwcfOp convOp) { auto reduction = utils::IteratorType::reduction; SmallVector<utils::IteratorType> img2colIterators(nloops, parallel); + // Given an index of the im2col matrix, retrieve the corresponding indices of + // the output and filter matrices + auto mIndicesExprs = + delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1}); + auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U), + ArrayRef<int64_t>{fw * ic, ic, 1}); + Im2ColToOperandsExprs i2cToOperExprs; + i2cToOperExprs.fhIndex = kIndicesExprs[0]; + i2cToOperExprs.fwIndex = kIndicesExprs[1]; + i2cToOperExprs.icIndex = kIndicesExprs[2]; + i2cToOperExprs.ohIndex = mIndicesExprs[0]; + i2cToOperExprs.owIndex = mIndicesExprs[1]; + + // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] + Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions( + i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()), + rewriter); + auto inMap = + AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex, + inExprs.wIndex, inExprs.cIndex}}, + rewriter.getContext())[0]; + SmallVector<AffineMap> img2colIndexingMaps = { - AffineMap::getMultiDimIdentityMap(nloops, context)}; + inMap, AffineMap::getMultiDimIdentityMap(nloops, context)}; auto img2ColTensor = linalg::GenericOp::create( rewriter, loc, colTensor.getType(), - /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, + /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); - Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); - Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); - - // Recover the original iteration indices from the problem/input sizes. - SmallVector<Value> mIndices = unrollIndex( - nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow}); - auto ohIndex = mIndices[0]; - auto owIndex = mIndices[1]; - - SmallVector<Value> kIndices = unrollIndex( - nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic}); - auto fhIndex = kIndices[0]; - auto fwIndex = kIndices[1]; - auto icIndex = kIndices[2]; - - // Extract the input element corresponding to the expanded indices. 
- Value hIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex, - convOp.getStrides().getValues<int64_t>()[0]); - Value wIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex, - convOp.getStrides().getValues<int64_t>()[1]); - - // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] - SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex}; - Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, - extractionIndices); - linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); + linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]); }); // Because the filter does not share the same batch dimension, @@ -421,44 +458,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNchwFchwOp convOp) { auto reduction = utils::IteratorType::reduction; SmallVector<utils::IteratorType, 3> img2colIterators(nloops, parallel); - SmallVector<AffineMap, 4> img2colIndexingMaps = { - AffineMap::getMultiDimIdentityMap(nloops, context)}; + // Recover the original iteration indices from the problem/input sizes: + // given an index of the im2col matrix, retrieve the corresponding indices of + // the output and filter matrices + auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(1U), + ArrayRef<int64_t>{fh * fw, fw, 1}); + auto mIndicesExprs = + delinearize(rewriter.getAffineDimExpr(2U), ArrayRef<int64_t>{ow, 1}); + Im2ColToOperandsExprs i2cToOperExprs; + i2cToOperExprs.icIndex = kIndicesExprs[0]; + i2cToOperExprs.fhIndex = kIndicesExprs[1]; + i2cToOperExprs.fwIndex = kIndicesExprs[2]; + i2cToOperExprs.ohIndex = mIndicesExprs[0]; + i2cToOperExprs.owIndex = mIndicesExprs[1]; + Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions( + i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()), + rewriter); + auto inMap = + AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.cIndex, + inExprs.hIndex, inExprs.wIndex}}, + rewriter.getContext())[0]; + // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw] + SmallVector<AffineMap> img2colIndexingMaps = { + inMap, AffineMap::getMultiDimIdentityMap(nloops, context)}; auto img2ColTensor = linalg::GenericOp::create( rewriter, loc, colTensor.getType(), - /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, + /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); - Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); - Value nIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); - - // Recover the original iteration indices from the problem/input sizes. - SmallVector<Value> kIndices = unrollIndex( - nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{ic, fh, fw}); - auto icIndex = kIndices[0]; - auto fhIndex = kIndices[1]; - auto fwIndex = kIndices[2]; - - SmallVector<Value> nIndices = unrollIndex( - nestedBuilder, nestedLoc, nIndex, ArrayRef<int64_t>{oh, ow}); - auto ohIndex = nIndices[0]; - auto owIndex = nIndices[1]; - - // Extract the input element corresponding to the expanded indices. 
- Value hIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex, - convOp.getStrides().getValues<int64_t>()[0]); - Value wIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex, - convOp.getStrides().getValues<int64_t>()[1]); - - // im2col[n, ic*fh*fw, oh*ow] = input[n, ic, sh*oh + fh, sw*ow + fw] - SmallVector<Value> extractionIndices{bIndex, icIndex, hIndex, wIndex}; - Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, - extractionIndices); - linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); + linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]); }); // Because the filter does not share the same batch dimension, @@ -545,6 +574,7 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { Value reshapedOutput = tensor::CollapseShapeOp::create( rewriter, loc, reshapedOutputType, output, outputReassocIndices); + // Shape of the Toeplitz matrix produced by Im2col. SmallVector<int64_t> colTensorShape = {n, oh * ow, fh * fw * ic}; Value colTensor = tensor::EmptyOp::create(rewriter, loc, colTensorShape, inputType.getElementType()); @@ -556,44 +586,36 @@ rewriteInIm2Col(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp) { auto reduction = utils::IteratorType::reduction; SmallVector<utils::IteratorType> img2colIterators(nloops, parallel); + // Given an index of the im2col matrix, retrieve the corresponding indices of + // the output and filter matrices + auto mIndicesExprs = + delinearize(rewriter.getAffineDimExpr(1U), ArrayRef<int64_t>{ow, 1}); + auto kIndicesExprs = delinearize(rewriter.getAffineDimExpr(2U), + ArrayRef<int64_t>{fw * ic, ic, 1}); + Im2ColToOperandsExprs i2cToOperExprs; + i2cToOperExprs.fhIndex = kIndicesExprs[0]; + i2cToOperExprs.fwIndex = kIndicesExprs[1]; + i2cToOperExprs.icIndex = kIndicesExprs[2]; + i2cToOperExprs.ohIndex = mIndicesExprs[0]; + i2cToOperExprs.owIndex = mIndicesExprs[1]; + + // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] + Im2ColToInputDimsExprs inExprs = getIm2ColInputExpressions( + i2cToOperExprs, llvm::to_vector(convOp.getStrides().getValues<int64_t>()), + rewriter); + auto inMap = + AffineMap::inferFromExprList({ArrayRef{inExprs.bIndex, inExprs.hIndex, + inExprs.wIndex, inExprs.cIndex}}, + rewriter.getContext())[0]; SmallVector<AffineMap> img2colIndexingMaps = { - AffineMap::getMultiDimIdentityMap(nloops, context)}; + inMap, AffineMap::getMultiDimIdentityMap(nloops, context)}; auto img2ColTensor = linalg::GenericOp::create( rewriter, loc, colTensor.getType(), - /*inputs=*/ValueRange{}, /*outputs=*/colTensor, img2colIndexingMaps, + /*inputs=*/input, /*outputs=*/colTensor, img2colIndexingMaps, img2colIterators, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { - // Get the iterators named based on the matmul (batch, m, k). - Value bIndex = linalg::IndexOp::create(nestedBuilder, loc, 0); - Value mIndex = linalg::IndexOp::create(nestedBuilder, loc, 1); - Value kIndex = linalg::IndexOp::create(nestedBuilder, loc, 2); - - // Recover the original iteration indices from the problem/input sizes. - SmallVector<Value> mIndices = unrollIndex( - nestedBuilder, nestedLoc, mIndex, ArrayRef<int64_t>{oh, ow}); - auto ohIndex = mIndices[0]; - auto owIndex = mIndices[1]; - - SmallVector<Value> kIndices = unrollIndex( - nestedBuilder, nestedLoc, kIndex, ArrayRef<int64_t>{fh, fw, ic}); - auto fhIndex = kIndices[0]; - auto fwIndex = kIndices[1]; - auto icIndex = kIndices[2]; - - // Extract the input element corresponding to the expanded indices. 
- Value hIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, ohIndex, fhIndex, - convOp.getStrides().getValues<int64_t>()[0]); - Value wIndex = - getConvolvedIndex(nestedBuilder, nestedLoc, owIndex, fwIndex, - convOp.getStrides().getValues<int64_t>()[1]); - - // im2col[n, oh*ow, fh*fw*ic] = input[n, sh*oh + fh, sw*ow + fw, ic] - SmallVector<Value> extractionIndices{bIndex, hIndex, wIndex, icIndex}; - Value inputVal = tensor::ExtractOp::create(nestedBuilder, loc, input, - extractionIndices); - linalg::YieldOp::create(nestedBuilder, nestedLoc, inputVal); + linalg::YieldOp::create(nestedBuilder, nestedLoc, args[0]); }); // Because we didn't transpose the filters we don't actually have a batched diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp index 76ddee4..2ff7f46 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp @@ -75,7 +75,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, // layout for best compatibility. Value toBuffer = bufferization::ToBufferOp::create( b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), - tensorSource, /*readOnly=*/true); + tensorSource, /*read_only=*/true); memref::CopyOp::create(b, loc, toBuffer, memrefDest); } break; case linalg::BufferizeToAllocationOptions::MemcpyOp::LinalgCopy: { @@ -84,7 +84,7 @@ static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, // layout for best compatibility. Value toBuffer = bufferization::ToBufferOp::create( b, loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType), - tensorSource, /*readOnly=*/true); + tensorSource, /*read_only=*/true); linalg::CopyOp::create(b, loc, toBuffer, memrefDest); } break; }; diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 0a9c176..40085a2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -6,10 +6,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/UB/IR/UBOps.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/Dominance.h" #include "llvm/ADT/SetOperations.h" @@ -1236,6 +1238,272 @@ private: ControlPropagationFn controlFn; }; +// This struct contains infomation about extract_slice dims. +struct SliceDimInfo { + OpFoldResult offset; + OpFoldResult sliceSize; + OpFoldResult outputSize; +}; + +/// Return the first input extract slice operand, if present, for the current +/// generic op. +static FailureOr<OpOperand *> getSliceOperand(GenericOp genericOp) { + OpOperand *sliceOperand = nullptr; + for (auto operand : genericOp.getDpsInputOperands()) { + auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>(); + if (!extractOp) + continue; + sliceOperand = operand; + break; + } + if (!sliceOperand) { + return failure(); + } + return sliceOperand; +} + +// Return a map of dims that have partial slices on them so that other operands +// can use this information. 
Also return a bool mentioning if a reduction dim +// has a non full slice as that can be used to fold the original extract slice. +static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>> +getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) { + tensor::ExtractSliceOp producerSliceOp = + sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); + assert(producerSliceOp && "expect a valid ExtractSliceOp"); + llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap; + SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets(); + SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes(); + + SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult( + genericOp.getContext(), producerSliceOp.getSourceType().getShape()); + + for (auto [idx, expr] : llvm::enumerate( + genericOp.getMatchingIndexingMap(sliceOperand).getResults())) { + // If we have a full slice in a dimension then we dont need to add it to + // the partial slice map. + if (isConstantIntValue(offsets[idx], 0) && + isEqualConstantIntOrValue(sizes[idx], shape[idx])) { + continue; + } + // We only support partial slices of AffineDimExprs so bail-out if thats not + // the case. + if (!isa<AffineDimExpr>(expr)) { + return failure(); + } + SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]}; + int64_t dimPos = cast<AffineDimExpr>(expr).getPosition(); + partialSliceDimMap[dimPos] = sliceDimInfo; + } + // Next check if the dims with partial slice info are used in non + // AffineDimExpr in other operands and if they are then bail-out. + for (OpOperand &operand : genericOp->getOpOperands()) { + if (operand == *sliceOperand) { + continue; + } + AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand); + if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) { + if (isa<AffineDimExpr>(expr)) { + return false; + } + WalkResult status = expr.walk([&](AffineExpr expr) { + if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) { + if (partialSliceDimMap.contains(dimExpr.getPosition())) { + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + if (status.wasInterrupted()) { + return true; + } + return false; + })) { + return failure(); + } + } + return partialSliceDimMap; +} + +static FailureOr<std::tuple<GenericOp, Value>> +pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter, + GenericOp genericOp, + ControlPropagationFn controlFn) { + if (genericOp.getNumResults() != 1) + return rewriter.notifyMatchFailure( + genericOp, "propagation through multi-result generic is unsupported."); + if (hasGatherSemantics(genericOp)) + return rewriter.notifyMatchFailure( + genericOp, + "propagation through generic with gather semantics is unsupported."); + // Collect the sliced operand, if present. 
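Editorial summary of the rewrite implemented by this function (a sketch; value names are illustrative): a partial `tensor.extract_slice` feeding a `linalg.generic` is sunk below the generic by running the generic on the full footprint and slicing its result instead.

//   %s = tensor.extract_slice %src[%off][%sz][1]
//   %r = linalg.generic ins(%s, %other) outs(%init)
// becomes, when no reduction dimension is partially sliced:
//   %p = tensor.pad %other low[%off] high[%full - %off - %sz] { ub.poison }
//   %o = tensor.pad %init ...        // or tensor.empty if the init is unused
//   %g = linalg.generic ins(%src, %p) outs(%o)   // full iteration space
//   %r = tensor.extract_slice %g[%off][%sz][1]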
+ auto maybeSliceOperand = getSliceOperand(genericOp); + if (failed(maybeSliceOperand)) + return failure(); + OpOperand *sliceOperand = *maybeSliceOperand; + unsigned OperandIndex = sliceOperand->getOperandNumber(); + + if (!controlFn(sliceOperand)) + return failure(); + + tensor::ExtractSliceOp producerSliceOp = + sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>(); + assert(producerSliceOp && "expect a valid ExtractSliceOp"); + + if (producerSliceOp.getSource().getType().getRank() != + producerSliceOp.getResult().getType().getRank()) { + return rewriter.notifyMatchFailure( + genericOp, + "propagation of rank-reducing extract slice is unsupported."); + } + + SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides(); + if (!areAllConstantIntValue(strides, 1)) + return rewriter.notifyMatchFailure( + genericOp, "propagation of strided extract slice is unsupported."); + + // check if we can support the propagation of this extractSlice + // through the generic op and if so return the dimensions that + + auto maybePartialSliceDimMap = + getPartialSliceDimInfo(genericOp, sliceOperand); + + if (failed(maybePartialSliceDimMap)) { + return failure(); + } + + auto partialSliceDimMap = *maybePartialSliceDimMap; + + SmallVector<utils::IteratorType> iterators = + genericOp.getIteratorTypesArray(); + bool hasPartialReductionDimSlice = + llvm::any_of(partialSliceDimMap, [&](const auto &slice) { + int64_t sliceDim = slice.first; + return iterators[sliceDim] == utils::IteratorType::reduction; + }); + + // Store the padding information as (dimPos, lowPad, highPad, PaddedShape). + Location loc = genericOp->getLoc(); + AffineExpr dim0, dim1; + bindDims(rewriter.getContext(), dim0, dim1); + auto subMap = AffineMap::get(2, 0, {dim0 - dim1}); + auto sub = [&](OpFoldResult v1, OpFoldResult v2) { + return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap, + {v1, v2}); + }; + + MLIRContext *ctx = genericOp.getContext(); + SmallVector<Value> paddedInputs; + for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) { + if (idx == OperandIndex && !hasPartialReductionDimSlice) { + paddedInputs.push_back(producerSliceOp.getSource()); + continue; + } + AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand); + SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) { + if (!isa<AffineDimExpr>(expr)) { + continue; + } + AffineDimExpr dimExpr = cast<AffineDimExpr>(expr); + if (!partialSliceDimMap.contains(dimExpr.getPosition())) { + continue; + } + SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()]; + operandLowPads[idx] = sliceDimInfo.offset; + operandHighPads[idx] = + sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), + sliceDimInfo.sliceSize); + } + auto paddingValue = ub::PoisonOp::create( + rewriter, loc, getElementTypeOrSelf(operand->get().getType())); + auto paddedOperand = tensor::PadOp::create( + rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads, + paddingValue, /*nofold=*/false); + paddedInputs.push_back(paddedOperand); + } + AffineMap outputIndexingMap = + genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0)); + + auto outputShapeType = + llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType()); + SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector( + 
outputShapeType.getShape(), + [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); }); + SmallVector<OpFoldResult> newSizes = OutputShape; + SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 0)); + SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(), + getAsIndexOpFoldResult(ctx, 1)); + for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) { + if (!isa<AffineDimExpr>(expr)) { + continue; + } + AffineDimExpr dimExpr = cast<AffineDimExpr>(expr); + if (!partialSliceDimMap.contains(dimExpr.getPosition())) { + continue; + } + SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()]; + outputLowPads[idx] = sliceDimInfo.offset; + outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset), + sliceDimInfo.sliceSize); + OutputShape[idx] = sliceDimInfo.outputSize; + newSizes[idx] = sliceDimInfo.sliceSize; + } + Value newPadOutput; + auto outputElType = + getElementTypeOrSelf(genericOp.getDpsInits()[0].getType()); + if (isGenericOutsNotUsed(genericOp)) { + newPadOutput = + tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType); + } else { + auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType); + newPadOutput = tensor::PadOp::create( + rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads, + outputHighPads, paddingValue, /*nofold=*/false); + } + + auto newGenericOp = linalg::GenericOp::create( + rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput}, + genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(), + /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp)); + rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(), + newGenericOp.getRegion().begin()); + + auto extractOp = tensor::ExtractSliceOp::create( + rewriter, loc, + newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)), + outputLowPads, newSizes, newStrides); + Value extractRes = extractOp.getResult(); + + return std::make_tuple(newGenericOp, extractRes); +} + +class PushDownExtractSliceOpThroughGenericOp final + : public OpRewritePattern<GenericOp> { +public: + PushDownExtractSliceOpThroughGenericOp(MLIRContext *context, + ControlPropagationFn fun) + : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + auto genericAndRepl = + pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn); + if (failed(genericAndRepl)) + return failure(); + rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl)); + return success(); + } + +private: + ControlPropagationFn controlFn; +}; + } // namespace void mlir::linalg::populateDataLayoutPropagationPatterns( @@ -1247,3 +1515,10 @@ void mlir::linalg::populateDataLayoutPropagationPatterns( PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>( patterns.getContext(), controlPackUnPackPropagation); } + +void mlir::linalg::populateExtractSliceSinkingPatterns( + RewritePatternSet &patterns, + const ControlPropagationFn &controlPackUnPackPropagation) { + patterns.insert<PushDownExtractSliceOpThroughGenericOp>( + patterns.getContext(), controlPackUnPackPropagation); +} diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index bf66ed0..22690da 100644 --- 
a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -691,9 +691,9 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> { auto newResultType = RankedTensorType::get( newResultShape, padOp.getResultType().getElementType()); - auto newPadOp = rewriter.create<tensor::PadOp>( - padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad, - newHighPad, paddingVal, padOp.getNofold()); + auto newPadOp = tensor::PadOp::create( + rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource, + newLowPad, newHighPad, paddingVal, padOp.getNofold()); Value dest = padOp.getResult(); if (options.rankReductionStrategy == @@ -1052,12 +1052,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> { static bool constexpr reduceLeft = (std::is_same_v<FromOpTy, BatchMatmulOp> && std::is_same_v<ToOpTy, BatchVecmatOp>) || - (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> && - std::is_same_v<ToOpTy, BatchVecmatOp>) || (std::is_same_v<FromOpTy, MatmulOp> && std::is_same_v<ToOpTy, VecmatOp>) || - (std::is_same_v<FromOpTy, MatmulTransposeAOp> && - std::is_same_v<ToOpTy, VecmatOp>) || (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>); /// Look for non-batch spatial dims to collapse. @@ -1113,27 +1109,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns( MLIRContext *context = patterns.getContext(); // Unbatching patterns for unit batch size patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context); - patterns - .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>( - context); - patterns - .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>( - context); patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context); patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context); // Non-batch rank 1 reducing patterns patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context); patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context); - patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context); - patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context); // Batch rank 1 reducing patterns patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context); patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context); - patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>( - context); - patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>( - context); // Non-batch rank 0 reducing patterns patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp index c523153..baf4083 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp @@ -20,13 +20,26 @@ namespace mlir { using namespace mlir; +static inline bool isScalarLike(Type t) { + return isa<IntegerType, FloatType, IndexType, ComplexType>(t); +} + static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { if (!OpTrait::hasElementwiseMappableTraits(op)) return false; - // TODO: The conversion pattern can be made to work for `any_of` here, but - // it's more complex as it requires tracking which operands are scalars. - return llvm::all_of(op->getOperandTypes(), llvm::IsaPred<RankedTensorType>); + auto types = op->getOperandTypes(); + + // We want at least one ranked tensor. 
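Editorial illustration of the relaxed predicate being assembled here (the operand kinds are examples, not code from the patch):

//   addf(tensor<4xf32>, tensor<4xf32>) -> convertible (all ranked tensors)
//   addf(tensor<4xf32>, f32)           -> convertible (tensor plus scalar-like)
//   addf(f32, f32)                     -> rejected    (no ranked tensor operand)
//   addf(tensor<4xf32>, memref<4xf32>) -> rejected    (memref is neither)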
+ bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred<RankedTensorType>); + + // No invalid operands (i.e., every operand is a ranked tensor or + // scalar-like). + bool noneInvalid = llvm::none_of(types, [](Type t) { + return !(isa<RankedTensorType>(t) || isScalarLike(t)); + }); + + return anyRankedTensor && noneInvalid; } /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over @@ -81,13 +94,41 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { return rewriter.notifyMatchFailure( op, "requires elementwise op on ranked tensors"); - auto rank = cast<RankedTensorType>(op->getResult(0).getType()).getRank(); - SmallVector<AffineMap, 3> indexingMaps( - op->getNumResults() + op->getNumOperands(), - rewriter.getMultiDimIdentityMap(rank)); - SmallVector<utils::IteratorType, 6> iteratorTypes( + auto resTy = cast<RankedTensorType>(op->getResult(0).getType()); + auto rank = resTy.getRank(); + + // Maps: identity for tensors (rank > 0), scalar map for scalars. + AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, + /*results=*/{}, rewriter.getContext()); + AffineMap idMap = rewriter.getMultiDimIdentityMap(rank); + + // Match phase. + SmallVector<bool> isScalarOperand; + isScalarOperand.reserve(op->getNumOperands()); + for (Type ty : op->getOperandTypes()) { + if (isScalarLike(ty)) + isScalarOperand.push_back(true); + else if (auto rt = dyn_cast<RankedTensorType>(ty)) + isScalarOperand.push_back(false); + else + return rewriter.notifyMatchFailure( + op, + "unsupported operand type (expected scalar-like or ranked tensor)"); + } + + // Create indexing maps. + SmallVector<AffineMap> indexingMaps; + indexingMaps.reserve(op->getNumOperands() + op->getNumResults()); + + for (bool isScalar : isScalarOperand) + indexingMaps.push_back(isScalar ? 
scalarMap : idMap); + + indexingMaps.append(op->getNumResults(), idMap); + + SmallVector<utils::IteratorType> iteratorTypes( rank, utils::IteratorType::parallel); - auto outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op); + SmallVector<Value> outputs = + getOrCreateOperandsMatchingResultTypes(rewriter, op); rewriter.replaceOpWithNewOp<linalg::GenericOp>( op, /*resultTensorTypes=*/op->getResultTypes(), /*inputs=*/op->getOperands(), @@ -96,14 +137,14 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { /*iteratorTypes=*/iteratorTypes, /*bodyBuilder=*/ [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { - auto resultTypes = llvm::to_vector<6>( + SmallVector<Type> resultEltTys = llvm::to_vector<6>( llvm::map_range(op->getResultTypes(), [](Type type) { return cast<TensorType>(type).getElementType(); })); - auto *scalarOp = + Operation *scalarOp = builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), - resultTypes, op->getAttrs()); + resultEltTys, op->getAttrs()); linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index fd530f2..9436f1c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -594,7 +594,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl( auto clonedForOp = scf::ForOp::create( rewriter, loc, bvm.lookupOrDefault(forOp.getLowerBound()), bvm.lookupOrDefault(forOp.getUpperBound()), - bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor); + bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); // Map the induction var, region args and results to the `clonedForOp`. bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp index 58986a6..36434cf 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp @@ -55,7 +55,8 @@ static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter, scf::ForOp newLoop = scf::ForOp::create( rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), - loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); + loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}, + loop.getUnsignedCmp()); // Generate the new yield with the replaced operand. 
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator()); @@ -165,8 +166,12 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, Value source = transferRead.getBase(); // Skip view-like Ops and retrive the actual soruce Operation - while (auto srcOp = source.getDefiningOp<ViewLikeOpInterface>()) - source = srcOp.getViewSource(); + while (auto viewLike = source.getDefiningOp<ViewLikeOpInterface>()) { + if (viewLike.getViewDest() != source) { + break; + } + source = viewLike.getViewSource(); + } llvm::SmallVector<Operation *, 32> users(source.getUsers().begin(), source.getUsers().end()); @@ -177,7 +182,8 @@ static bool noAliasingUseInLoop(vector::TransferReadOp transferRead, if (!processed.insert(user).second) continue; if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) { - users.append(viewLike->getUsers().begin(), viewLike->getUsers().end()); + Value viewDest = viewLike.getViewDest(); + users.append(viewDest.getUsers().begin(), viewDest.getUsers().end()); continue; } if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp index 3d12bc3..8942670 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp @@ -263,11 +263,11 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad, paddingValue, /*nofold=*/false, dynDims); } -FailureOr<TilingInterface> -linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad, - const PadTilingInterfaceOptions &constOptions, - SmallVector<tensor::PadOp> &padOps, - PadSizeComputationFunction computePaddingSizeFun) { +FailureOr<TilingInterface> linalg::rewriteAsPaddedOp( + RewriterBase &rewriter, TilingInterface opToPad, + const PadTilingInterfaceOptions &constOptions, + SmallVector<tensor::PadOp> &padOps, + const PadSizeComputationFunction &computePaddingSizeFun) { LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n"); Location loc = opToPad.getLoc(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp index a2bd9d9..27ccf3c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/NamedOpConversions.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/SimplifyDepthwiseConv.cpp @@ -21,7 +21,7 @@ #include "llvm/ADT/TypeSwitch.h" namespace mlir { -#define GEN_PASS_DEF_LINALGNAMEDOPCONVERSIONPASS +#define GEN_PASS_DEF_SIMPLIFYDEPTHWISECONVPASS #include "mlir/Dialect/Linalg/Passes.h.inc" } // namespace mlir @@ -143,23 +143,22 @@ struct SimplifyDepthwiseConvQOp } }; -struct LinalgNamedOpConversionPass - : public impl::LinalgNamedOpConversionPassBase< - LinalgNamedOpConversionPass> { - using impl::LinalgNamedOpConversionPassBase< - LinalgNamedOpConversionPass>::LinalgNamedOpConversionPassBase; +struct SimplifyDepthwiseConvPass + : public impl::SimplifyDepthwiseConvPassBase<SimplifyDepthwiseConvPass> { + using impl::SimplifyDepthwiseConvPassBase< + SimplifyDepthwiseConvPass>::SimplifyDepthwiseConvPassBase; void runOnOperation() override { Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); - populateLinalgNamedOpConversionPatterns(patterns); + populateSimplifyDepthwiseConvPatterns(patterns); if (failed(applyPatternsGreedily(op, std::move(patterns)))) return signalPassFailure(); } }; } // namespace -void mlir::linalg::populateLinalgNamedOpConversionPatterns( +void 
mlir::linalg::populateSimplifyDepthwiseConvPatterns( RewritePatternSet &patterns) { patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>( patterns.getContext()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 455e1a6..35ba4f15 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter, /// Codegen the different matmul variants. if (numOfBatchDims) { - if (a == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter, - genericOp); - if (b == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter, - genericOp); return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp); } - - if (a == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp); - if (b == IndexMatchResult::Transposed) - return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp); return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index bb725f2..e9a8b25 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -29,6 +29,7 @@ #include "mlir/Support/LLVM.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/InterleavedRange.h" #include "llvm/Support/raw_ostream.h" #include <utility> @@ -38,9 +39,6 @@ using namespace mlir; using namespace mlir::linalg; -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ") -#define DBGSNL() (llvm::dbgs() << "\n") - //===----------------------------------------------------------------------===// // Transformations exposed as functional-style API calls. //===----------------------------------------------------------------------===// @@ -91,11 +89,11 @@ static bool hasAtMostOneResultFunctionOfDim(AffineMap map, int64_t dim) { } return true; } +#endif // NDEBUG static std::string stringifyReassocIndices(ReassociationIndicesRef ri) { return llvm::interleaved(ri, ", ", /*Prefix=*/"|", /*Suffix=*/""); } -#endif // NDEBUG /// Return the index of the first result of `map` that is a function of /// AffineDimExpr(dim), std::nullopt otherwise. 
@@ -276,23 +274,18 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, tensor::PadOp::create(rewriter, loc, collapsed, packOp.getSource(), lows, highs, paddingValue, /*nofold=*/false); - LLVM_DEBUG( - DBGSNL(); DBGSNL(); - DBGS() << "insertPositions: " - << llvm::interleaved(packingMetadata.insertPositions); - DBGSNL(); DBGS() << "outerPositions: " - << llvm::interleaved(packingMetadata.outerPositions); - DBGSNL(); DBGS() << "packedShape: " - << llvm::interleaved(packedTensorType.getShape()); - DBGSNL(); DBGS() << "packedToStripMinedShapePerm: " - << llvm::interleaved(packedToStripMinedShapePerm); - DBGSNL(); - DBGS() << "reassociations: " - << llvm::interleaved(llvm::map_range( - packingMetadata.reassociations, stringifyReassocIndices)); - DBGSNL(); - DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape); - DBGSNL(); DBGS() << "collapsed type: " << collapsed; DBGSNL();); + LDBG() << "insertPositions: " + << llvm::interleaved(packingMetadata.insertPositions); + LDBG() << "outerPositions: " + << llvm::interleaved(packingMetadata.outerPositions); + LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape()); + LDBG() << "packedToStripMinedShapePerm: " + << llvm::interleaved(packedToStripMinedShapePerm); + LDBG() << "reassociations: " + << llvm::interleaved(llvm::map_range(packingMetadata.reassociations, + stringifyReassocIndices)); + LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape); + LDBG() << "collapsed type: " << collapsed; if (lowerPadLikeWithInsertSlice && packOp.isLikePad()) { // Pack ops which operate as simple pads may not produce legal @@ -317,7 +310,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, rewriter, loc, /*source=*/padOp, /*dest=*/packOp.getDest(), /*offsets=*/zeros, sizes, /*strides=*/ones); - LLVM_DEBUG(DBGS() << "insert_slice op: " << insertSliceOp; DBGSNL();); + LDBG() << "insert_slice op: " << insertSliceOp; rewriter.replaceOp(packOp, insertSliceOp->getResults()); @@ -339,10 +332,9 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter, auto transposeOp = linalg::TransposeOp::create( rewriter, loc, reshapeOp.getResult(), packOp.getDest(), transpPerm); - LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); - DBGS() << "reshape op: " << reshapeOp; DBGSNL(); - DBGS() << "transpPerm: " << llvm::interleaved(transpPerm); - DBGSNL(); DBGS() << "transpose op: " << transposeOp; DBGSNL();); + LDBG() << "reshape op: " << reshapeOp; + LDBG() << "transpPerm: " << llvm::interleaved(transpPerm); + LDBG() << "transpose op: " << transposeOp; // 7. Replace packOp by transposeOp. 
rewriter.replaceOp(packOp, transposeOp->getResults()); @@ -410,21 +402,16 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, linalg::TransposeOp::create(rewriter, loc, unPackOp.getSource(), emptyOp, packedToStripMinedShapePerm); - LLVM_DEBUG( - DBGSNL(); DBGSNL(); - DBGS() << "insertPositions: " - << llvm::interleaved(packingMetadata.insertPositions); - DBGSNL(); DBGS() << "packedShape: " - << llvm::interleaved(packedTensorType.getShape()); - DBGSNL(); DBGS() << "packedToStripMinedShapePerm: " - << llvm::interleaved(packedToStripMinedShapePerm); - DBGSNL(); - DBGS() << "reassociations: " - << llvm::interleaved(llvm::map_range( - packingMetadata.reassociations, stringifyReassocIndices)); - DBGSNL(); - DBGS() << "stripMinedShape: " << llvm::interleaved(stripMinedShape); - DBGSNL(); DBGS() << "collapsed type: " << collapsedType; DBGSNL();); + LDBG() << "insertPositions: " + << llvm::interleaved(packingMetadata.insertPositions); + LDBG() << "packedShape: " << llvm::interleaved(packedTensorType.getShape()); + LDBG() << "packedToStripMinedShapePerm: " + << llvm::interleaved(packedToStripMinedShapePerm); + LDBG() << "reassociations: " + << llvm::interleaved(llvm::map_range(packingMetadata.reassociations, + stringifyReassocIndices)); + LDBG() << "stripMinedShape: " << llvm::interleaved(stripMinedShape); + LDBG() << "collapsed type: " << collapsedType; // 4. Collapse from the stripMinedShape to the padded result. auto reshapeOp = tensor::CollapseShapeOp::create( @@ -486,10 +473,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray(); SmallVector<utils::IteratorType> iteratorTypes = linalgOp.getIteratorTypesArray(); - LLVM_DEBUG(DBGS() << "Start packing: " << linalgOp << "\n" - << "maps: " << llvm::interleaved(indexingMaps) << "\n" - << "iterators: " << llvm::interleaved(iteratorTypes) - << "\n"); + LDBG() << "Start packing: " << linalgOp; + LDBG() << "maps: " << llvm::interleaved(indexingMaps); + LDBG() << "iterators: " << llvm::interleaved(iteratorTypes); SmallVector<linalg::PackOp> packOps; SmallVector<linalg::UnPackOp> unPackOps; @@ -511,14 +497,11 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, packedOperandsDims.packedDimForEachOperand = *maybePackedDimForEachOperand; listOfPackedOperandsDim.pushBack(std::move(packedOperandsDims)); - LLVM_DEBUG( - DBGS() << "++++ After pack size #" << i << ": " << packedSizes[i] - << "\n" - << "maps: " << llvm::interleaved(indexingMaps) << "\n" - << "iterators: " << llvm::interleaved(iteratorTypes) << "\n" - << "packedDimForEachOperand: " - << llvm::interleaved(packedOperandsDims.packedDimForEachOperand) - << "\n"); + LDBG() << "++++ After pack size #" << i << ": " << packedSizes[i]; + LDBG() << "maps: " << llvm::interleaved(indexingMaps); + LDBG() << "iterators: " << llvm::interleaved(iteratorTypes); + LDBG() << "packedDimForEachOperand: " + << llvm::interleaved(packedOperandsDims.packedDimForEachOperand); } // Step 2. Propagate packing to all LinalgOp operands. 
@@ -534,10 +517,9 @@ FailureOr<PackResult> linalg::pack(RewriterBase &rewriter, listOfPackedOperandsDim.extractPackedDimsForOperand(pos); SmallVector<OpFoldResult> innerPackSizes = listOfPackedOperandsDim.extractPackSizesForOperand(pos); - LLVM_DEBUG(DBGS() << "operand: " << operand << "\n" - << "innerPos: " << llvm::interleaved(innerPos) << "\n" - << "innerPackSizes: " - << llvm::interleaved(innerPackSizes) << "\n"); + LDBG() << "operand: " << operand; + LDBG() << "innerPos: " << llvm::interleaved(innerPos); + LDBG() << "innerPackSizes: " << llvm::interleaved(innerPackSizes); if (innerPackSizes.empty()) { inputsAndInits.push_back(operand); continue; @@ -776,8 +758,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, int64_t numLoops = linalgOp.getNumLoops(); if (numLoops <= 2) { - LLVM_DEBUG(DBGS() << "need 3+ loops to find a matmul to pack, got " - << numLoops << "\nin: " << linalgOp << "\n"); + LDBG() << "need 3+ loops to find a matmul to pack, got " << numLoops + << " in: " << linalgOp; return rewriter.notifyMatchFailure( linalgOp, "need 3+ loops to find a matmul to pack"); } @@ -801,8 +783,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, FailureOr<ContractionDimensions> maybeDimensions = inferContractionDims(linalgOp); if (failed(maybeDimensions)) { - LLVM_DEBUG(DBGS() << "couldn't infer matmul iterators in: " << linalgOp - << "\n"); + LDBG() << "couldn't infer matmul iterators in: " << linalgOp; return rewriter.notifyMatchFailure(linalgOp, "couldn't infer matmul iterators"); } @@ -814,10 +795,8 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, // to plug a heuristic. int64_t mPos = maybeDimensions->m.back(), nPos = maybeDimensions->n.back(), kPos = maybeDimensions->k.back(); - LLVM_DEBUG(DBGSNL(); DBGSNL(); DBGSNL(); - DBGS() << "Start packing generic op greedily with (m@" << mPos - << ", n@" << nPos << ", k@" << kPos << "): " << linalgOp - << "\n";); + LDBG() << "Start packing generic op greedily with (m@" << mPos << ", n@" + << nPos << ", k@" << kPos << "): " << linalgOp; // 2.a. Rewrite as a generic. auto genericOp = dyn_cast<GenericOp>(linalgOp.getOperation()); @@ -833,14 +812,14 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, // not change the indexings of any operand. SmallVector<int64_t> permutation = computePermutationVector(numLoops, {mPos, nPos, kPos}, mmnnkkPos); - LLVM_DEBUG(DBGS() << "perm: " << llvm::interleaved(permutation) << "\n"); + LDBG() << "perm: " << llvm::interleaved(permutation); // Sign .. unsigned pollution. SmallVector<unsigned> unsignedPerm(permutation.begin(), permutation.end()); FailureOr<GenericOp> interchangeResult = interchangeGenericOp(rewriter, genericOp, unsignedPerm); assert(succeeded(interchangeResult) && "unexpected failure interchanging op"); genericOp = *interchangeResult; - LLVM_DEBUG(DBGS() << "Generalized Op to pack: " << genericOp << "\n";); + LDBG() << "Generalized Op to pack: " << genericOp; // At this point, the op iterators are normalized to {leading, k, m, n}. // The layouts induced by packing will always be: @@ -862,12 +841,11 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, // Add leading zeros to match numLoops, we only pack the last 3 dimensions // post interchange. 
- LLVM_DEBUG(DBGS() << "paddedSizesNextMultipleOf: " - << llvm::interleaved(paddedSizesNextMultipleOf) << "\n" - << "loopRanges: " - << llvm::interleaved(llvm::map_range( - loopRanges, [](Range r) { return r.size; })) - << "\n"); + LDBG() << "paddedSizesNextMultipleOf: " + << llvm::interleaved(paddedSizesNextMultipleOf); + LDBG() << "loopRanges: " + << llvm::interleaved( + llvm::map_range(loopRanges, [](Range r) { return r.size; })); SmallVector<OpFoldResult> adjustedPackedSizes(numLoops - packedSizes.size(), rewriter.getIndexAttr(0)); for (int64_t i = 0, e = numPackedDims; i < e; ++i) { @@ -883,8 +861,7 @@ linalg::packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, {loopRanges[adjustedPackedSizes.size()].size, rewriter.getIndexAttr(paddedSizesNextMultipleOf[i])})); } - LLVM_DEBUG(DBGS() << "adjustedPackedSizes: " - << llvm::interleaved(adjustedPackedSizes) << "\n"); + LDBG() << "adjustedPackedSizes: " << llvm::interleaved(adjustedPackedSizes); // TODO: If we wanted to give the genericOp a name after packing, after // calling `pack` would be a good time. One would still need to check that @@ -1214,9 +1191,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( } srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end()); - LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n" - << "perm: " << llvm::interleaved(srcPermForTranspose) - << "\n"); + LDBG() << "Pack permutation: " << packOp; + LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose); // 2.1 Create tensor.empty (init value for TransposeOp) SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles, diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp index a2a4335..2650488 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp @@ -59,12 +59,12 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter, ArrayRef<int64_t>{1, 0}); Operation *newMatmulOp; if (transposeLHS) { - newMatmulOp = linalg::MatmulTransposeAOp::create( + newMatmulOp = MatmulTransposeAOp::create( rewriter, loc, matmulOp.getResultTypes(), ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]}, matmulOp.getOutputs()); } else { - newMatmulOp = linalg::MatmulTransposeBOp::create( + newMatmulOp = MatmulTransposeBOp::create( rewriter, loc, matmulOp.getResultTypes(), ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)}, matmulOp.getOutputs()); @@ -116,12 +116,12 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter, ArrayRef<int64_t>{0, 2, 1}); Operation *newMatmulOp; if (transposeLHS) { - newMatmulOp = linalg::BatchMatmulTransposeAOp::create( + newMatmulOp = BatchMatmulTransposeAOp::create( rewriter, loc, batchMatmulOp.getResultTypes(), ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]}, batchMatmulOp.getOutputs()); } else { - newMatmulOp = linalg::BatchMatmulTransposeBOp::create( + newMatmulOp = BatchMatmulTransposeBOp::create( rewriter, loc, batchMatmulOp.getResultTypes(), ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)}, batchMatmulOp.getOutputs()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index cf65e67..406f05c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -2563,7 +2563,7 @@ vectorizeScalableVectorPrecondition(Operation *op, 
"vectorization"; return failure(); } - if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) { + if (isa<linalg::MatmulOp>(op)) { LDBG() << "Scalable vectorization of the reduction dim in Matmul-like ops " "is not supported"; @@ -2604,17 +2604,12 @@ vectorizeScalableVectorPrecondition(Operation *op, return failure(); } - // Check to not let go the matmul with extended semantic, through this - // transform. - if (linalgOp.hasUserDefinedMaps()) - return failure(); - // Cond 4: Only the following ops are supported in the // presence of scalable vectors return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) || - isa<linalg::MatmulTransposeAOp>(op) || isa<linalg::DepthwiseConv1DNwcWcOp>(op) || isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) || + isa<linalg::BatchMmt4DOp>(op) || hasReductionIterator(linalgOp)); } diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index e1c0c24..d37a056 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -1,6 +1,6 @@ add_mlir_dialect_library(MLIRMathTransforms AlgebraicSimplification.cpp - ExpandPatterns.cpp + ExpandOps.cpp ExtendToSupportedTypes.cpp PolynomialApproximation.cpp UpliftToFMA.cpp diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp index 4a40a30..cd68039 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp @@ -13,14 +13,18 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/TypeUtilities.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; +namespace mlir::math { +#define GEN_PASS_DEF_MATHEXPANDOPSPASS +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +} // namespace mlir::math + /// Create a float constant. 
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b) { @@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op, return success(); } -void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { - patterns.add(convertCtlzOp); -} - -void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) { - patterns.add(convertSinhOp); -} - -void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) { - patterns.add(convertCoshOp); -} - -void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { - patterns.add(convertTanOp); -} - -void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { - patterns.add(convertTanhOp); -} - -void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) { - patterns.add(convertAsinhOp); -} - -void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) { - patterns.add(convertAcoshOp); -} - -void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) { - patterns.add(convertAtanhOp); -} - -void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { - patterns.add(convertFmaFOp); -} - -void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { - patterns.add(convertCeilOp); -} - -void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { - patterns.add(convertExp2fOp); -} - -void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { - patterns.add(convertPowfOp); -} - -void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { - patterns.add(convertFPowIOp); -} - -void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { - patterns.add(convertRoundOp); +// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf` +static LogicalResult convertClampfOp(math::ClampFOp op, + PatternRewriter &rewriter) { + auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(), + op.getMin(), op.getFastmath()); + rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(), + op.getFastmath()); + return success(); } -void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) { - patterns.add(convertRoundEvenOp); +void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns, + ArrayRef<StringRef> opMnemonics) { + auto filter = [&](StringRef name) { + // This should be a static assert and `consume_front` take a twine, but none + // is currently possible. TODO: augment `StringRef::consume_front` and make + // `getDialectNamespace` use `std::string_view`. 
+ assert("math" == MathDialect::getDialectNamespace()); + name.consume_front("math."); + return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0); + }; + if (filter(CountLeadingZerosOp::getOperationName())) + patterns.add(convertCtlzOp); + if (filter(SinhOp::getOperationName())) + patterns.add(convertSinhOp); + if (filter(CoshOp::getOperationName())) + patterns.add(convertCoshOp); + if (filter(TanOp::getOperationName())) + patterns.add(convertTanOp); + if (filter(TanhOp::getOperationName())) + patterns.add(convertTanhOp); + if (filter(AsinhOp::getOperationName())) + patterns.add(convertAsinhOp); + if (filter(AcoshOp::getOperationName())) + patterns.add(convertAcoshOp); + if (filter(AtanhOp::getOperationName())) + patterns.add(convertAtanhOp); + if (filter(FmaOp::getOperationName())) + patterns.add(convertFmaFOp); + if (filter(CeilOp::getOperationName())) + patterns.add(convertCeilOp); + if (filter(Exp2Op::getOperationName())) + patterns.add(convertExp2fOp); + if (filter(PowFOp::getOperationName())) + patterns.add(convertPowfOp); + if (filter(FPowIOp::getOperationName())) + patterns.add(convertFPowIOp); + if (filter(RoundOp::getOperationName())) + patterns.add(convertRoundOp); + if (filter(RoundEvenOp::getOperationName())) + patterns.add(convertRoundEvenOp); + if (filter(RsqrtOp::getOperationName())) + patterns.add(convertRsqrtOp); + if (filter(ClampFOp::getOperationName())) + patterns.add(convertClampfOp); } -void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) { - patterns.add(convertRsqrtOp); -} +//===----------------------------------------------------------------------===// +// MathExpandOpsPass pass +//===----------------------------------------------------------------------===// +namespace { +struct MathExpandOpsPass final + : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> { + using MathExpandOpsPassBase::MathExpandOpsPassBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + SmallVector<StringRef> mnemonics = + llvm::to_vector_of<StringRef>(opMnemonics); + math::populateExpansionPatterns(patterns, mnemonics); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) + return signalPassFailure(); + } +}; +} // namespace diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 74b968c..b59d73d 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() { case arith::AtomicRMWKind::minu: case arith::AtomicRMWKind::muli: case arith::AtomicRMWKind::ori: + case arith::AtomicRMWKind::xori: case arith::AtomicRMWKind::andi: if (!llvm::isa<IntegerType>(getValue().getType())) return emitOpError() << "with kind '" diff --git a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp index bbb269b..1939195 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.cpp @@ -21,9 +21,9 @@ namespace { struct ReallocOpInterface : public BufferViewFlowOpInterface::ExternalModel<ReallocOpInterface, ReallocOp> { - void - populateDependencies(Operation *op, - RegisterDependenciesFn registerDependenciesFn) const { + void populateDependencies( + Operation *op, + const RegisterDependenciesFn &registerDependenciesFn) const { auto reallocOp = cast<ReallocOp>(op); // memref.realloc may return the source operand. 
registerDependenciesFn(reallocOp.getSource(), reallocOp.getResult()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp index 9771bd2..d35566a 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp @@ -959,7 +959,7 @@ class RewriteExtractAlignedPointerAsIndexOfViewLikeOp PatternRewriter &rewriter) const override { auto viewLikeOp = extractOp.getSource().getDefiningOp<ViewLikeOpInterface>(); - if (!viewLikeOp) + if (!viewLikeOp || extractOp.getSource() != viewLikeOp.getViewDest()) return rewriter.notifyMatchFailure(extractOp, "not a ViewLike source"); rewriter.modifyOpInPlace(extractOp, [&]() { extractOp.getSourceMutable().assign(viewLikeOp.getViewSource()); diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp index 5d3cec4..860384f 100644 --- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp +++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp @@ -43,50 +43,34 @@ static bool overrideBuffer(Operation *op, Value buffer) { /// propagate the type change and erase old subview ops. static void replaceUsesAndPropagateType(RewriterBase &rewriter, Operation *oldOp, Value val) { - SmallVector<Operation *> opsToDelete; - SmallVector<OpOperand *> operandsToReplace; - - // Save the operand to replace / delete later (avoid iterator invalidation). - // TODO: can we use an early_inc iterator? - for (OpOperand &use : oldOp->getUses()) { - // Non-subview ops will be replaced by `val`. - auto subviewUse = dyn_cast<memref::SubViewOp>(use.getOwner()); - if (!subviewUse) { - operandsToReplace.push_back(&use); + // Iterate with early_inc to erase current user inside the loop. + for (OpOperand &use : llvm::make_early_inc_range(oldOp->getUses())) { + Operation *user = use.getOwner(); + if (auto subviewUse = dyn_cast<memref::SubViewOp>(user)) { + // `subview(old_op)` is replaced by a new `subview(val)`. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(subviewUse); + MemRefType newType = memref::SubViewOp::inferRankReducedResultType( + subviewUse.getType().getShape(), cast<MemRefType>(val.getType()), + subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), + subviewUse.getStaticStrides()); + Value newSubview = memref::SubViewOp::create( + rewriter, subviewUse->getLoc(), newType, val, + subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), + subviewUse.getMixedStrides()); + + // Ouch recursion ... is this really necessary? + replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); + + // Safe to erase. + rewriter.eraseOp(subviewUse); continue; } - - // `subview(old_op)` is replaced by a new `subview(val)`. - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(subviewUse); - MemRefType newType = memref::SubViewOp::inferRankReducedResultType( - subviewUse.getType().getShape(), cast<MemRefType>(val.getType()), - subviewUse.getStaticOffsets(), subviewUse.getStaticSizes(), - subviewUse.getStaticStrides()); - Value newSubview = memref::SubViewOp::create( - rewriter, subviewUse->getLoc(), newType, val, - subviewUse.getMixedOffsets(), subviewUse.getMixedSizes(), - subviewUse.getMixedStrides()); - - // Ouch recursion ... is this really necessary? - replaceUsesAndPropagateType(rewriter, subviewUse, newSubview); - - opsToDelete.push_back(use.getOwner()); + // Non-subview: replace with new value. 
+ rewriter.startOpModification(user); + use.set(val); + rewriter.finalizeOpModification(user); } - - // Perform late replacement. - // TODO: can we use an early_inc iterator? - for (OpOperand *operand : operandsToReplace) { - Operation *op = operand->getOwner(); - rewriter.startOpModification(op); - operand->set(val); - rewriter.finalizeOpModification(op); - } - - // Perform late op erasure. - // TODO: can we use an early_inc iterator? - for (Operation *op : opsToDelete) - rewriter.eraseOp(op); } // Transformation to do multi-buffering/array expansion to remove dependencies @@ -216,8 +200,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp, offsets, sizes, strides); LLVM_DEBUG(DBGS() << "--multi-buffered slice: " << subview << "\n"); - // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need to - // handle dealloc uses separately.. + // 5. Due to the recursive nature of replaceUsesAndPropagateType , we need + // to handle dealloc uses separately.. for (OpOperand &use : llvm::make_early_inc_range(allocOp->getUses())) { auto deallocOp = dyn_cast<memref::DeallocOp>(use.getOwner()); if (!deallocOp) diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp index 5af46a4..3de9c38 100644 --- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp +++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp @@ -210,8 +210,10 @@ MemrefValue skipFullyAliasingOperations(MemrefValue source) { MemrefValue skipViewLikeOps(MemrefValue source) { while (auto op = source.getDefiningOp()) { if (auto viewLike = dyn_cast<ViewLikeOpInterface>(op)) { - source = cast<MemrefValue>(viewLike.getViewSource()); - continue; + if (source == viewLike.getViewDest()) { + source = cast<MemrefValue>(viewLike.getViewSource()); + continue; + } } return source; } diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index 34c95e3..8474244 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -422,6 +422,12 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref( << descMemref << " != " << dstMemref; } + int lastDimBytes = + descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8; + if (lastDimBytes % 16 != 0) { + return op->emitError() << "the bytes in the last dimension of the tensor " + "map must be a multiple of 16"; + } return std::nullopt; } diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 485bb73..ded4c7a 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -173,9 +173,7 @@ void OpenACCDialect::initialize() { //===----------------------------------------------------------------------===// static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) { - if (arrayAttr && *arrayAttr && arrayAttr->size() > 0) - return true; - return false; + return arrayAttr && *arrayAttr && arrayAttr->size() > 0; } static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr, @@ -1390,6 +1388,36 @@ void acc::ParallelOp::addPrivatization(MLIRContext *context, setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } +void acc::ParallelOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + getFirstprivateOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getFirstprivatizationRecipesAttr()) + 
llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + +void acc::ParallelOp::addReduction(MLIRContext *context, + mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + getReductionOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getReductionRecipesAttr()) + llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + static ParseResult parseNumGangs( mlir::OpAsmParser &parser, llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands, @@ -2041,6 +2069,36 @@ void acc::SerialOp::addPrivatization(MLIRContext *context, setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } +void acc::SerialOp::addFirstPrivatization( + MLIRContext *context, mlir::acc::FirstprivateOp op, + mlir::acc::FirstprivateRecipeOp recipe) { + getFirstprivateOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getFirstprivatizationRecipesAttr()) + llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + +void acc::SerialOp::addReduction(MLIRContext *context, + mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + getReductionOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getReductionRecipesAttr()) + llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + //===----------------------------------------------------------------------===// // KernelsOp //===----------------------------------------------------------------------===// @@ -3059,6 +3117,20 @@ void acc::LoopOp::addPrivatization(MLIRContext *context, setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes)); } +void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op, + mlir::acc::ReductionRecipeOp recipe) { + getReductionOperandsMutable().append(op.getResult()); + + llvm::SmallVector<mlir::Attribute> recipes; + + if (getReductionRecipesAttr()) + llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes)); + + recipes.push_back( + mlir::SymbolRefAttr::get(context, recipe.getSymName().str())); + setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes)); +} + //===----------------------------------------------------------------------===// // DataOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index c1c1767..6e43f28 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -3874,6 +3874,159 @@ LogicalResult AllocateDirOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// TargetAllocMemOp 
+//===----------------------------------------------------------------------===// + +mlir::Type omp::TargetAllocMemOp::getAllocatedType() { + return getInTypeAttr().getValue(); +} + +/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype, +/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )? +/// attr-dict-without-keyword +static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + auto &builder = parser.getBuilder(); + bool hasOperands = false; + std::int32_t typeparamsSize = 0; + + // Parse device number as a new operand + mlir::OpAsmParser::UnresolvedOperand deviceOperand; + mlir::Type deviceType; + if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType)) + return mlir::failure(); + if (parser.resolveOperand(deviceOperand, deviceType, result.operands)) + return mlir::failure(); + if (parser.parseComma()) + return mlir::failure(); + + mlir::Type intype; + if (parser.parseType(intype)) + return mlir::failure(); + result.addAttribute("in_type", mlir::TypeAttr::get(intype)); + llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands; + llvm::SmallVector<mlir::Type> typeVec; + if (!parser.parseOptionalLParen()) { + // parse the LEN params of the derived type. (<params> : <types>) + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) || + parser.parseColonTypeList(typeVec) || parser.parseRParen()) + return mlir::failure(); + typeparamsSize = operands.size(); + hasOperands = true; + } + std::int32_t shapeSize = 0; + if (!parser.parseOptionalComma()) { + // parse size to scale by, vector of n dimensions of type index + if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None)) + return mlir::failure(); + shapeSize = operands.size() - typeparamsSize; + auto idxTy = builder.getIndexType(); + for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i) + typeVec.push_back(idxTy); + hasOperands = true; + } + if (hasOperands && + parser.resolveOperands(operands, typeVec, parser.getNameLoc(), + result.operands)) + return mlir::failure(); + + mlir::Type restype = builder.getIntegerType(64); + if (!restype) { + parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype; + return mlir::failure(); + } + llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize}; + result.addAttribute("operandSegmentSizes", + builder.getDenseI32ArrayAttr(segmentSizes)); + if (parser.parseOptionalAttrDict(result.attributes) || + parser.addTypeToList(restype, result.types)) + return mlir::failure(); + return mlir::success(); +} + +mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseTargetAllocMemOp(parser, result); +} + +void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) { + p << " "; + p.printOperand(getDevice()); + p << " : "; + p << getDevice().getType(); + p << ", "; + p << getInType(); + if (!getTypeparams().empty()) { + p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')'; + } + for (auto sh : getShape()) { + p << ", "; + p.printOperand(sh); + } + p.printOptionalAttrDict((*this)->getAttrs(), + {"in_type", "operandSegmentSizes"}); +} + +llvm::LogicalResult omp::TargetAllocMemOp::verify() { + mlir::Type outType = getType(); + if (!mlir::dyn_cast<IntegerType>(outType)) + return emitOpError("must be a integer type"); + return mlir::success(); +} + +//===----------------------------------------------------------------------===// +// WorkdistributeOp 
+//===----------------------------------------------------------------------===// + +LogicalResult WorkdistributeOp::verify() { + // Check that region exists and is not empty + Region &region = getRegion(); + if (region.empty()) + return emitOpError("region cannot be empty"); + // Verify single entry point. + Block &entryBlock = region.front(); + if (entryBlock.empty()) + return emitOpError("region must contain a structured block"); + // Verify single exit point. + bool hasTerminator = false; + for (Block &block : region) { + if (isa<TerminatorOp>(block.back())) { + if (hasTerminator) { + return emitOpError("region must have exactly one terminator"); + } + hasTerminator = true; + } + } + if (!hasTerminator) { + return emitOpError("region must be terminated with omp.terminator"); + } + auto walkResult = region.walk([&](Operation *op) -> WalkResult { + // No implicit barrier at end + if (isa<BarrierOp>(op)) { + return emitOpError( + "explicit barriers are not allowed in workdistribute region"); + } + // Check for invalid nested constructs + if (isa<ParallelOp>(op)) { + return emitOpError( + "nested parallel constructs not allowed in workdistribute"); + } + if (isa<TeamsOp>(op)) { + return emitOpError( + "nested teams constructs not allowed in workdistribute"); + } + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return failure(); + + Operation *parentOp = (*this)->getParentOp(); + if (!llvm::dyn_cast<TeamsOp>(parentOp)) + return emitOpError("workdistribute must be nested under teams"); + return success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt index 497468b..bd1e655 100644 --- a/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Ptr/IR/CMakeLists.txt @@ -1,3 +1,22 @@ +set(LLVM_OPTIONAL_SOURCES + MemorySpaceInterfaces.cpp + PtrAttrs.cpp + PtrTypes.cpp + PtrDialect.cpp +) + +add_mlir_dialect_library( + MLIRPtrMemorySpaceInterfaces + MemorySpaceInterfaces.cpp + + DEPENDS + MLIRPtrOpsEnumsGen + MLIRPtrMemorySpaceInterfacesIncGen + LINK_LIBS + PUBLIC + MLIRIR +) + add_mlir_dialect_library( MLIRPtrDialect PtrAttrs.cpp @@ -15,4 +34,5 @@ add_mlir_dialect_library( MLIRDataLayoutInterfaces MLIRMemorySlotInterfaces MLIRViewLikeInterface + MLIRPtrMemorySpaceInterfaces ) diff --git a/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp new file mode 100644 index 0000000..059e67f --- /dev/null +++ b/mlir/lib/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp @@ -0,0 +1,15 @@ +//===-- MemorySpaceInterfaces.cpp - ptr memory space interfaces -*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the ptr dialect memory space interfaces. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.h" + +#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc" diff --git a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp index 772d25d..ac3bcd6 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrAttrs.cpp @@ -22,26 +22,30 @@ constexpr const static unsigned kBitsInByte = 8; //===----------------------------------------------------------------------===// bool GenericSpaceAttr::isValidLoad( - Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment, + Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, function_ref<InFlightDiagnostic()> emitError) const { return true; } bool GenericSpaceAttr::isValidStore( - Type type, ptr::AtomicOrdering ordering, IntegerAttr alignment, + Type type, ptr::AtomicOrdering ordering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, function_ref<InFlightDiagnostic()> emitError) const { return true; } bool GenericSpaceAttr::isValidAtomicOp( ptr::AtomicBinOp op, Type type, ptr::AtomicOrdering ordering, - IntegerAttr alignment, function_ref<InFlightDiagnostic()> emitError) const { + std::optional<int64_t> alignment, const ::mlir::DataLayout *dataLayout, + function_ref<InFlightDiagnostic()> emitError) const { return true; } bool GenericSpaceAttr::isValidAtomicXchg( Type type, ptr::AtomicOrdering successOrdering, - ptr::AtomicOrdering failureOrdering, IntegerAttr alignment, + ptr::AtomicOrdering failureOrdering, std::optional<int64_t> alignment, + const ::mlir::DataLayout *dataLayout, function_ref<InFlightDiagnostic()> emitError) const { return true; } diff --git a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp index c5ec0ca..d5976b9 100644 --- a/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp +++ b/mlir/lib/Dialect/Ptr/IR/PtrDialect.cpp @@ -85,6 +85,124 @@ LogicalResult FromPtrOp::verify() { } //===----------------------------------------------------------------------===// +// LoadOp +//===----------------------------------------------------------------------===// + +/// Verifies the attributes and the type of atomic memory access operations. +template <typename OpTy> +static LogicalResult +verifyAtomicMemOp(OpTy memOp, ArrayRef<AtomicOrdering> unsupportedOrderings) { + if (memOp.getOrdering() != AtomicOrdering::not_atomic) { + if (llvm::is_contained(unsupportedOrderings, memOp.getOrdering())) + return memOp.emitOpError("unsupported ordering '") + << stringifyAtomicOrdering(memOp.getOrdering()) << "'"; + if (!memOp.getAlignment()) + return memOp.emitOpError("expected alignment for atomic access"); + return success(); + } + if (memOp.getSyncscope()) { + return memOp.emitOpError( + "expected syncscope to be null for non-atomic access"); + } + return success(); +} + +/// Verifies that the alignment attribute is a power of 2 if present. 
+static LogicalResult +verifyAlignment(std::optional<int64_t> alignment, + function_ref<InFlightDiagnostic()> emitError) { + if (!alignment) + return success(); + if (alignment.value() <= 0) + return emitError() << "alignment must be positive"; + if (!llvm::isPowerOf2_64(alignment.value())) + return emitError() << "alignment must be a power of 2"; + return success(); +} + +void LoadOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Read::get(), &getPtrMutable()); + // Volatile operations can have target-specific read-write effects on + // memory besides the one referred to by the pointer operand. + // Similarly, atomic operations that are monotonic or stricter cause + // synchronization that from a language point-of-view, are arbitrary + // read-writes into memory. + if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic && + getOrdering() != AtomicOrdering::unordered)) { + effects.emplace_back(MemoryEffects::Write::get()); + effects.emplace_back(MemoryEffects::Read::get()); + } +} + +LogicalResult LoadOp::verify() { + auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); }; + MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace(); + DataLayout dataLayout = DataLayout::closest(*this); + if (!ms.isValidLoad(getResult().getType(), getOrdering(), getAlignment(), + &dataLayout, emitDiag)) + return failure(); + if (failed(verifyAlignment(getAlignment(), emitDiag))) + return failure(); + return verifyAtomicMemOp(*this, + {AtomicOrdering::release, AtomicOrdering::acq_rel}); +} + +void LoadOp::build(OpBuilder &builder, OperationState &state, Type type, + Value addr, unsigned alignment, bool isVolatile, + bool isNonTemporal, bool isInvariant, bool isInvariantGroup, + AtomicOrdering ordering, StringRef syncscope) { + build(builder, state, type, addr, + alignment ? std::optional<int64_t>(alignment) : std::nullopt, + isVolatile, isNonTemporal, isInvariant, isInvariantGroup, ordering, + syncscope.empty() ? nullptr : builder.getStringAttr(syncscope)); +} + +//===----------------------------------------------------------------------===// +// StoreOp +//===----------------------------------------------------------------------===// + +void StoreOp::getEffects( + SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> + &effects) { + effects.emplace_back(MemoryEffects::Write::get(), &getPtrMutable()); + // Volatile operations can have target-specific read-write effects on + // memory besides the one referred to by the pointer operand. + // Similarly, atomic operations that are monotonic or stricter cause + // synchronization that from a language point-of-view, are arbitrary + // read-writes into memory. 
+ if (getVolatile_() || (getOrdering() != AtomicOrdering::not_atomic && + getOrdering() != AtomicOrdering::unordered)) { + effects.emplace_back(MemoryEffects::Write::get()); + effects.emplace_back(MemoryEffects::Read::get()); + } +} + +LogicalResult StoreOp::verify() { + auto emitDiag = [&]() -> InFlightDiagnostic { return emitError(); }; + MemorySpaceAttrInterface ms = getPtr().getType().getMemorySpace(); + DataLayout dataLayout = DataLayout::closest(*this); + if (!ms.isValidStore(getValue().getType(), getOrdering(), getAlignment(), + &dataLayout, emitDiag)) + return failure(); + if (failed(verifyAlignment(getAlignment(), emitDiag))) + return failure(); + return verifyAtomicMemOp(*this, + {AtomicOrdering::acquire, AtomicOrdering::acq_rel}); +} + +void StoreOp::build(OpBuilder &builder, OperationState &state, Value value, + Value addr, unsigned alignment, bool isVolatile, + bool isNonTemporal, bool isInvariantGroup, + AtomicOrdering ordering, StringRef syncscope) { + build(builder, state, value, addr, + alignment ? std::optional<int64_t>(alignment) : std::nullopt, + isVolatile, isNonTemporal, isInvariantGroup, ordering, + syncscope.empty() ? nullptr : builder.getStringAttr(syncscope)); +} + +//===----------------------------------------------------------------------===// // PtrAddOp //===----------------------------------------------------------------------===// @@ -152,10 +270,6 @@ llvm::TypeSize TypeOffsetOp::getTypeSize(std::optional<DataLayout> layout) { #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Ptr/IR/PtrOpsAttrs.cpp.inc" -#include "mlir/Dialect/Ptr/IR/MemorySpaceInterfaces.cpp.inc" - -#include "mlir/Dialect/Ptr/IR/MemorySpaceAttrInterfaces.cpp.inc" - #include "mlir/Dialect/Ptr/IR/PtrOpsEnums.cpp.inc" #define GET_TYPEDEF_CLASSES diff --git a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt index 825d119..deb7109 100644 --- a/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Quant/Transforms/CMakeLists.txt @@ -4,7 +4,7 @@ add_mlir_dialect_library(MLIRQuantTransforms StripFuncQuantTypes.cpp ADDITIONAL_HEADER_DIRS - {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Quant/Transforms DEPENDS MLIRQuantTransformsIncGen diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 0262a1b..84f9777 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -157,8 +157,7 @@ void ExecuteRegionOp::print(OpAsmPrinter &p) { p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); - - p.printOptionalAttrDict((*this)->getAttrs()); + p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"no_inline"}); } LogicalResult ExecuteRegionOp::verify() { @@ -318,9 +317,12 @@ void ConditionOp::getSuccessorRegions( void ForOp::build(OpBuilder &builder, OperationState &result, Value lb, Value ub, Value step, ValueRange initArgs, - BodyBuilderFn bodyBuilder) { + BodyBuilderFn bodyBuilder, bool unsignedCmp) { OpBuilder::InsertionGuard guard(builder); + if (unsignedCmp) + result.addAttribute(getUnsignedCmpAttrName(result.name), + builder.getUnitAttr()); result.addOperands({lb, ub, step}); result.addOperands(initArgs); for (Value v : initArgs) @@ -450,6 +452,9 @@ static void printInitializationList(OpAsmPrinter &p, } void ForOp::print(OpAsmPrinter &p) { + if (getUnsignedCmp()) + p << " unsigned"; + p << " " << getInductionVar() << " = " << getLowerBound() << " to " << getUpperBound() << " step " << 
getStep(); @@ -462,7 +467,8 @@ void ForOp::print(OpAsmPrinter &p) { p.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/!getInitArgs().empty()); - p.printOptionalAttrDict((*this)->getAttrs()); + p.printOptionalAttrDict((*this)->getAttrs(), + /*elidedAttrs=*/getUnsignedCmpAttrName().strref()); } ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { @@ -472,6 +478,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::Argument inductionVariable; OpAsmParser::UnresolvedOperand lb, ub, step; + if (succeeded(parser.parseOptionalKeyword("unsigned"))) + result.addAttribute(getUnsignedCmpAttrName(result.name), + builder.getUnitAttr()); + // Parse the induction variable followed by '='. if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() || // Parse loop bounds. @@ -562,7 +572,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter, inits.append(newInitOperands.begin(), newInitOperands.end()); scf::ForOp newLoop = scf::ForOp::create( rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits, - [](OpBuilder &, Location, Value, ValueRange) {}); + [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp()); newLoop->setAttrs(getPrunedAttributeList(getOperation(), {})); // Generate the new yield values and append them to the scf.yield operation. @@ -806,7 +816,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, // 2. Create the new forOp shell. scf::ForOp newForOp = scf::ForOp::create( rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newIterOperands); + forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr, + forOp.getUnsignedCmp()); newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.getRegion().front(); SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(), @@ -931,7 +942,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { scf::ForOp newForOp = scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), - forOp.getUpperBound(), forOp.getStep(), newIterArgs); + forOp.getUpperBound(), forOp.getStep(), newIterArgs, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); newForOp->setAttrs(forOp->getAttrs()); Block &newBlock = newForOp.getRegion().front(); @@ -989,12 +1001,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { /// Util function that tries to compute a constant diff between u and l. /// Returns std::nullopt when the difference between two AffineValueMap is /// dynamic. 
-static std::optional<int64_t> computeConstDiff(Value l, Value u) { +static std::optional<APInt> computeConstDiff(Value l, Value u) { IntegerAttr clb, cub; if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) { llvm::APInt lbValue = clb.getValue(); llvm::APInt ubValue = cub.getValue(); - return (ubValue - lbValue).getSExtValue(); + return ubValue - lbValue; } // Else a simple pattern match for x + c or c + x @@ -1003,7 +1015,7 @@ static std::optional<int64_t> computeConstDiff(Value l, Value u) { u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) || matchPattern( u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l)))) - return diff.getSExtValue(); + return diff; return std::nullopt; } @@ -1022,13 +1034,15 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> { return success(); } - std::optional<int64_t> diff = + std::optional<APInt> diff = computeConstDiff(op.getLowerBound(), op.getUpperBound()); if (!diff) return failure(); // If the loop is known to have 0 iterations, remove it. - if (*diff <= 0) { + bool zeroOrLessIterations = + diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative()); + if (zeroOrLessIterations) { rewriter.replaceOp(op, op.getInitArgs()); return success(); } @@ -3384,9 +3398,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) { if (functionType.getNumInputs() != operands.size()) { return parser.emitError(typeLoc) - << "expected as many input types as operands " - << "(expected " << operands.size() << " got " - << functionType.getNumInputs() << ")"; + << "expected as many input types as operands " << "(expected " + << operands.size() << " got " << functionType.getNumInputs() << ")"; } // Resolve input operands. @@ -4222,14 +4235,15 @@ LogicalResult scf::IndexSwitchOp::verify() { << "see yield operation here"; } for (auto [idx, result, operand] : - llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(), - yield.getOperandTypes())) { - if (result == operand) + llvm::enumerate(getResultTypes(), yield.getOperands())) { + if (!operand) + return yield.emitOpError() << "operand " << idx << " is null\n"; + if (result == operand.getType()) continue; return (emitOpError("expected result #") << idx << " of each region to be " << result) .attachNote(yield.getLoc()) - << name << " returns " << operand << " here"; + << name << " returns " << operand.getType() << " here"; } return success(); }; diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index f8799c5..fb179e6 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -769,7 +769,8 @@ struct ForOpInterface // Construct a new scf.for op with memref instead of tensor values. 
auto newForOp = scf::ForOp::create( rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), castedInitArgs); + forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr, + forOp.getUnsignedCmp()); newForOp->setAttrs(forOp->getAttrs()); Block *loopBody = newForOp.getBody(); diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp index bee7780..ae52af5 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp @@ -58,9 +58,12 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> { auto *beforeBlock = rewriter.createBlock( &whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs); rewriter.setInsertionPointToStart(whileOp.getBeforeBody()); - auto cmpOp = arith::CmpIOp::create( - rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt, - beforeBlock->getArgument(0), forOp.getUpperBound()); + arith::CmpIPredicate predicate = forOp.getUnsignedCmp() + ? arith::CmpIPredicate::ult + : arith::CmpIPredicate::slt; + auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), predicate, + beforeBlock->getArgument(0), + forOp.getUpperBound()); scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(), beforeBlock->getArguments()); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp index 1130538..7e7fba4 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp @@ -791,6 +791,11 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp, bool *modifiedIR) { if (modifiedIR) *modifiedIR = false; + + // TODO: Add support for unsigned loops. + if (forOp.getUnsignedCmp()) + return failure(); + LoopPipelinerInternal pipeliner; if (!pipeliner.initializeLoopInfo(forOp, options)) return failure(); diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp index 4752c08..f1203b2 100644 --- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp @@ -256,6 +256,10 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> { LogicalResult matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const override { + if (forOp.getUnsignedCmp()) + return rewriter.notifyMatchFailure(forOp, + "unsigned loops are not supported"); + // Do not peel already peeled loops. 
if (forOp->hasAttr(kPeeledLoopLabel)) return failure(); diff --git a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp index 694cd85..4ea8321 100644 --- a/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp @@ -269,10 +269,10 @@ namespace { struct ParallelLoopFusion : public impl::SCFParallelLoopFusionBase<ParallelLoopFusion> { void runOnOperation() override { - auto &AA = getAnalysis<AliasAnalysis>(); + auto &aa = getAnalysis<AliasAnalysis>(); auto mayAlias = [&](Value val1, Value val2) -> bool { - return !AA.alias(val1, val2).isNo(); + return !aa.alias(val1, val2).isNo(); }; getOperation()->walk([&](Operation *child) { diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp index 1b07b77..072bc50 100644 --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -52,8 +52,8 @@ public: SmallVector<unsigned> offsets; offsets.push_back(0); // Do the type conversion and record the offsets. - for (Type type : op.getResultTypes()) { - if (failed(typeConverter->convertTypes(type, dstTypes))) + for (Value v : op.getResults()) { + if (failed(typeConverter->convertType(v, dstTypes))) return rewriter.notifyMatchFailure(op, "could not convert result type"); offsets.push_back(dstTypes.size()); } @@ -116,7 +116,8 @@ public: llvm::getSingleElement(adaptor.getLowerBound()), llvm::getSingleElement(adaptor.getUpperBound()), llvm::getSingleElement(adaptor.getStep()), - flattenValues(adaptor.getInitArgs())); + flattenValues(adaptor.getInitArgs()), + /*bodyBuilder=*/nullptr, op.getUnsignedCmp()); // Reserve whatever attributes in the original op. newOp->setAttrs(op->getAttrs()); @@ -126,7 +127,6 @@ public: // Inline the type converted region from the original operation. rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), newOp.getRegion().end()); - return newOp; } }; @@ -225,15 +225,14 @@ void mlir::scf::populateSCFStructuralTypeConversions( void mlir::scf::populateSCFStructuralTypeConversionTarget( const TypeConverter &typeConverter, ConversionTarget &target) { - target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) { - return typeConverter.isLegal(op->getResultTypes()); - }); + target.addDynamicallyLegalOp<ForOp, IfOp>( + [&](Operation *op) { return typeConverter.isLegal(op->getResults()); }); target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) { // We only have conversions for a subset of ops that use scf.yield // terminators. 
if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp())) return true; - return typeConverter.isLegal(op.getOperandTypes()); + return typeConverter.isLegal(op.getOperands()); }); target.addDynamicallyLegalOp<WhileOp, ConditionOp>( [&](Operation *op) { return typeConverter.isLegal(op); }); diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index c0e47ee..834c021 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -797,7 +797,8 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>( inits.append(newInitOperands.begin(), newInitOperands.end()); auto newLoop = scf::ForOp::create( rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(), - loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}); + loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {}, + loopOp.getUnsignedCmp()); // Move the loop body to the new op. Block *loopBody = loopOp.getBody(); @@ -935,7 +936,8 @@ static LogicalResult addInitOperandsToLoopNest( auto newLoop = scf::ForOp::create( rewriter, forLoop.getLoc(), forLoop.getLowerBound(), forLoop.getUpperBound(), forLoop.getStep(), newInits, - [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}); + [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {}, + forLoop.getUnsignedCmp()); // Merge the body of the new loop with the body of the old loops. SmallVector<Value> sourceBlockArgs; @@ -1914,63 +1916,6 @@ static FailureOr<OpOperand *> getConsumerFromLoopUses(RewriterBase &rewriter, return failure(); } -/// Check that the loop is perfectly nested. -/// The loops are expected to be ordered from outer most to inner most. -/// For example: -/// ``` -/// %0 = scf.for() -/// %1 = scf.for() -/// %2 = scf.for() -/// %3 = ... -/// yield %3 -/// yield %2 -/// yield %1 -/// ``` -/// Here loops should be [%0, %1]. -static bool -isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops) { - assert(!loops.empty() && "unexpected empty loop nest"); - if (loops.size() == 1) { - return isa_and_nonnull<scf::ForOp>(loops.front().getOperation()); - } - for (auto [outerLoop, innerLoop] : - llvm::zip_equal(loops.drop_back(), loops.drop_front())) { - auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation()); - auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation()); - if (!outerFor || !innerFor) { - return false; - } - auto outerBBArgs = outerFor.getRegionIterArgs(); - auto innerIterArgs = innerFor.getInitArgs(); - if (outerBBArgs.size() != innerIterArgs.size()) { - return false; - } - - for (auto [outerBBArg, innerIterArg] : - llvm::zip_equal(outerBBArgs, innerIterArgs)) { - if (!llvm::hasSingleElement(outerBBArg.getUses()) || - innerIterArg != outerBBArg) { - return false; - } - } - - ValueRange outerYields = - cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands(); - ValueRange innerResults = innerFor.getResults(); - if (outerYields.size() != innerResults.size()) { - return false; - } - for (auto [outerYield, innerResult] : - llvm::zip_equal(outerYields, innerResults)) { - if (!llvm::hasSingleElement(innerResult.getUses()) || - outerYield != innerResult) { - return false; - } - } - } - return true; -} - /// Fetch the untiled consumer of the outermost scf.for's result which is /// yielded by a tensor.insert_slice from the innermost scf.for. 
This function /// makes the following assumptions : diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 5731795..684dff8 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1233,6 +1233,7 @@ static void getPerfectlyNestedLoopsImpl( static Loops stripmineSink(scf::ForOp forOp, Value factor, ArrayRef<scf::ForOp> targets) { + assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported"); auto originalStep = forOp.getStep(); auto iv = forOp.getInductionVar(); @@ -1241,6 +1242,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor, Loops innerLoops; for (auto t : targets) { + assert(!t.getUnsignedCmp() && "unsigned loops are not supported"); + // Save information for splicing ops out of t when done auto begin = t.getBody()->begin(); auto nOps = t.getBody()->getOperations().size(); @@ -1415,6 +1418,8 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target, scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, RewriterBase &rewriter) { + assert(source.getUnsignedCmp() == target.getUnsignedCmp() && + "incompatible signedness"); unsigned numTargetOuts = target.getNumResults(); unsigned numSourceOuts = source.getNumResults(); @@ -1428,7 +1433,8 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target, rewriter.setInsertionPointAfter(source); scf::ForOp fusedLoop = scf::ForOp::create( rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(), - source.getStep(), fusedInitArgs); + source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr, + source.getUnsignedCmp()); // Map original induction variables and operands to those of the fused loop. IRMapping mapping; @@ -1506,3 +1512,41 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter, rewriter.replaceOp(forallOp, normalizedForallOp); return normalizedForallOp; } + +bool mlir::isPerfectlyNestedForLoops( + MutableArrayRef<LoopLikeOpInterface> loops) { + assert(!loops.empty() && "unexpected empty loop nest"); + if (loops.size() == 1) + return isa_and_nonnull<scf::ForOp>(loops.front().getOperation()); + for (auto [outerLoop, innerLoop] : + llvm::zip_equal(loops.drop_back(), loops.drop_front())) { + auto outerFor = dyn_cast_or_null<scf::ForOp>(outerLoop.getOperation()); + auto innerFor = dyn_cast_or_null<scf::ForOp>(innerLoop.getOperation()); + if (!outerFor || !innerFor) + return false; + auto outerBBArgs = outerFor.getRegionIterArgs(); + auto innerIterArgs = innerFor.getInitArgs(); + if (outerBBArgs.size() != innerIterArgs.size()) + return false; + + for (auto [outerBBArg, innerIterArg] : + llvm::zip_equal(outerBBArgs, innerIterArgs)) { + if (!llvm::hasSingleElement(outerBBArg.getUses()) || + innerIterArg != outerBBArg) + return false; + } + + ValueRange outerYields = + cast<scf::YieldOp>(outerFor.getBody()->getTerminator())->getOperands(); + ValueRange innerResults = innerFor.getResults(); + if (outerYields.size() != innerResults.size()) + return false; + for (auto [outerYield, innerResult] : + llvm::zip_equal(outerYields, innerResults)) { + if (!llvm::hasSingleElement(innerResult.getUses()) || + outerYield != innerResult) + return false; + } + } + return true; +} diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp index ddb3426..369b953 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp @@ -1322,7 +1322,7 @@ struct spirv::detail::TensorArmTypeStorage final : 
TypeStorage { } TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType) - : shape(std::move(shape)), elementType(std::move(elementType)) {} + : shape(shape), elementType(elementType) {} ArrayRef<int64_t> shape; Type elementType; diff --git a/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp index d4e7618..7a05dfe 100644 --- a/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp +++ b/mlir/lib/Dialect/Shard/Interfaces/ShardingInterface.cpp @@ -513,8 +513,9 @@ LogicalResult shard::detail::defaultAddShardingAnnotations( } #ifndef NDEBUG -static bool isValueCompatibleWithFullReplicationSharding(Value value, - Sharding sharding) { +static bool +isValueCompatibleWithFullReplicationSharding(Value value, + const Sharding &sharding) { if (isa<RankedTensorType>(value.getType())) { return isFullReplication(sharding); } diff --git a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp index 3e3d476..5dc61a2 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Partition.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Partition.cpp @@ -477,10 +477,10 @@ reshardOn1DGrid(ImplicitLocOpBuilder &builder, GridOp grid, return targetShard; } -TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, GridOp grid, - Sharding sourceSharding, Sharding targetSharding, - TypedValue<ShapedType> sourceUnshardedValue, - TypedValue<ShapedType> sourceShard) { +static TypedValue<ShapedType> +reshard(ImplicitLocOpBuilder &builder, GridOp grid, Sharding sourceSharding, + Sharding targetSharding, TypedValue<ShapedType> sourceUnshardedValue, + TypedValue<ShapedType> sourceShard) { // If source and destination sharding are the same, no need to do anything. if (sourceSharding == targetSharding || (isFullReplication(sourceSharding) && isFullReplication(targetSharding))) { @@ -535,7 +535,7 @@ using UnshardedToShardedValueMap = DenseMap<Value, Value>; // Get the types of block arguments for an partitioned block. // Reads the sharding annotations of the arguments to deduce the sharded types. // Types that are not ranked tensors are left unchanged. -SmallVector<Type> +static SmallVector<Type> shardedBlockArgumentTypes(Block &block, SymbolTableCollection &symbolTableCollection) { SmallVector<Type> res; diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp index 56b435c..9694a40 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/DimLvlMapParser.cpp @@ -231,7 +231,9 @@ ParseResult DimLvlMapParser::parseLvlSpecList() { const auto loc = parser.getCurrentLocation(); const auto res = parser.parseCommaSeparatedList( mlir::OpAsmParser::Delimiter::Paren, - [=]() -> ParseResult { return parseLvlSpec(requireLvlVarBinding); }, + [this, requireLvlVarBinding]() -> ParseResult { + return parseLvlSpec(requireLvlVarBinding); + }, " in level-specifier list"); FAILURE_IF_FAILED(res) const auto specLvlRank = lvlSpecs.size(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp index 9e2e6ab..a1711a6 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/Detail/Var.cpp @@ -156,13 +156,14 @@ minSMLoc(AsmParser &parser, llvm::SMLoc sm1, llvm::SMLoc sm2) { return pair1 <= pair2 ? 
sm1 : sm2; } -bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, StringRef name) { +static bool isInternalConsistent(VarEnv const &env, VarInfo::ID id, + StringRef name) { const auto &var = env.access(id); return (var.getName() == name && var.getID() == id); } -bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, llvm::SMLoc loc, - VarKind vk) { +static bool isUsageConsistent(VarEnv const &env, VarInfo::ID id, + llvm::SMLoc loc, VarKind vk) { const auto &var = env.access(id); return var.getKind() == vk; } diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp index 3b97786..dabbea1 100644 --- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp +++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp @@ -71,7 +71,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm, pm.addPass(createLowerAffinePass()); pm.addPass( createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions())); - pm.addPass(createFinalizeMemRefToLLVMConversionPass()); pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass()); pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass()); pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass()); @@ -79,12 +78,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm, pm.addPass(createConvertComplexToLibm()); pm.addPass( createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions())); - pm.addPass(createConvertComplexToLLVMPass()); - pm.addPass( - createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions())); - pm.addPass(createConvertFuncToLLVMPass()); - pm.addPass(createArithToLLVMConversionPass()); - pm.addPass(createConvertControlFlowToLLVMPass()); // Finalize GPU code generation. if (gpuCodegen) { @@ -99,8 +92,8 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm, pm.addPass(createGpuModuleToBinaryPass(gpuModuleToBinaryPassOptions)); } - // Convert poison values. - pm.addPass(createUBToLLVMConversionPass()); + // Convert to LLVM. + pm.addPass(createConvertToLLVMPass()); // Ensure all casts are realized. pm.addPass(createReconcileUnrealizedCastsPass()); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp index 3b4140e..ae7eef2 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseBufferRewriting.cpp @@ -1219,8 +1219,9 @@ static void createQuickSortFunc(OpBuilder &builder, ModuleOp module, /// Implements the rewriting for operator sort and sort_coo. 
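The sparsifier pipeline change above drops the chain of per-dialect *-to-LLVM passes in favor of the generic dialect-conversion driver, which covers the dialects that plug into the ConvertToLLVM interface; only cast reconciliation still follows it. A minimal sketch of the resulting pipeline tail, assuming an OpPassManager pm set up as in buildSparsifier:

    // One generic lowering pass replaces the individual X-to-LLVM passes,
    // then leftover unrealized_conversion_cast ops are reconciled.
    pm.addPass(createConvertToLLVMPass());
    pm.addPass(createReconcileUnrealizedCastsPass());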
template <typename OpTy> -LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, AffineMap xPerm, - uint64_t ny, PatternRewriter &rewriter) { +static LogicalResult matchAndRewriteSortOp(OpTy op, ValueRange xys, + AffineMap xPerm, uint64_t ny, + PatternRewriter &rewriter) { Location loc = op.getLoc(); SmallVector<Value> operands{constantIndex(rewriter, loc, 0), op.getN()}; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 134aef3..0e88d31d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -730,9 +730,9 @@ public: {tensor, lvlCoords, values, filled, added, count}, EmitCInterface::On); Operation *parent = getTop(op); + rewriter.setInsertionPointAfter(parent); rewriter.replaceOp(op, adaptor.getTensor()); // Deallocate the buffers on exit of the loop nest. - rewriter.setInsertionPointAfter(parent); memref::DeallocOp::create(rewriter, loc, values); memref::DeallocOp::create(rewriter, loc, filled); memref::DeallocOp::create(rewriter, loc, added); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp index 4464450..febec6d 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp @@ -533,8 +533,10 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl, VectorType vtp = vectorType(vl, init.getType()); Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0), forOp.getRegionIterArg(0), init, vtp); - forOpNew = scf::ForOp::create(rewriter, loc, forOp.getLowerBound(), - forOp.getUpperBound(), step, vinit); + forOpNew = + scf::ForOp::create(rewriter, loc, forOp.getLowerBound(), + forOp.getUpperBound(), step, vinit, + /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp()); forOpNew->setAttr( LoopEmitter::getLoopEmitterLoopAttrName(), forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName())); @@ -605,8 +607,8 @@ public: ForOpRewriter(MLIRContext *context, unsigned vectorLength, bool enableVLAVectorization, bool enableSIMDIndex32) - : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization, - enableSIMDIndex32} {} + : OpRewritePattern(context), + vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {} LogicalResult matchAndRewrite(scf::ForOp op, PatternRewriter &rewriter) const override { diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 7d4b112..68584ec 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -3200,20 +3200,6 @@ void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) { setNameFn(getResult(), "padded"); } -// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it -// supports optional types. 
-void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand, - Type typeToInfer, Type typeToInferFrom) {} - -ParseResult -parseInferType(OpAsmParser &parser, - std::optional<OpAsmParser::UnresolvedOperand> optOperand, - Type &typeToInfer, Type typeToInferFrom) { - if (optOperand) - typeToInfer = typeToInferFrom; - return success(); -} - LogicalResult PadOp::verify() { auto sourceType = llvm::cast<RankedTensorType>(getSource().getType()); auto resultType = llvm::cast<RankedTensorType>(getResult().getType()); @@ -4059,7 +4045,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// // Common Canonicalizers and Folders. //===----------------------------------------------------------------------===// -bool foldTensorCastPrecondition(DestinationStyleOpInterface op) { +static bool foldTensorCastPrecondition(DestinationStyleOpInterface op) { // 1. InsertSliceOp has its own logic about folding tensor.cast ops. // 2. Exclude DPS ops that are also LoopLike from this interface as they // might need special handling of attached regions. diff --git a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp index 2ec23e1..dfce835 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp @@ -327,172 +327,31 @@ struct BubbleUpExpandShapeThroughExtractSlice PatternRewriter &rewriter) const override { auto expandShapeOp = sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>(); + if (!expandShapeOp) { + return rewriter.notifyMatchFailure( + sliceOp, "tensor.extract_slice source not produced by expand_shape"); + } + SmallVector<ReassociationIndices> reassociation = + expandShapeOp.getReassociationIndices(); - if (checkPreconditionForBubbleUpExtractSlice(sliceOp, expandShapeOp, - rewriter) - .failed()) + SmallVector<OpFoldResult> offsets, sizes, strides; + if (failed(getCollapsedExtractSliceInfo(rewriter, sliceOp, reassociation, + offsets, sizes, strides))) return failure(); - // The tensor.extract_slice before applying the pattern works on the result - // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp) - // referring to the state before applying the pattern are named with the - // prefix "expanded", and ones referring to the state after applying the - // pattern are named with the prefix "collapsed". - SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets(); - SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes(); - SmallVector<OpFoldResult> expandedShape = - getMixedValues(expandShapeOp.getStaticOutputShape(), - expandShapeOp.getOutputShape(), rewriter); - - // Helper variables and function for accumulating the size values. - Location loc = expandShapeOp->getLoc(); - AffineExpr d0, d1, d2; - bindDims(rewriter.getContext(), d0, d1, d2); - // Multiply two integers. - auto mul = [&](OpFoldResult v1, OpFoldResult v2) { - auto mulMap = AffineMap::get(2, 0, {d0 * d1}); - return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap, - {v1, v2}); - }; - - // Compute new offsets, sizes, and strides for tensor.extract_slice. - // The new tensor.extract_slice will work on a tensor that has has a rank of - // ReassociationIndices.size(). In the loop a single offset, size, and - // stride value is computed per reassociation group. 
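The refactor above exposes the offset/size/stride computation as a standalone entry point, so other patterns can reuse it without going through this rewrite pattern. A minimal usage sketch, assuming a builder b and an ExtractSliceOp sliceOp whose source is produced by a tensor.expand_shape named expandShapeOp:

    SmallVector<OpFoldResult> offsets, sizes, strides;
    if (failed(tensor::getCollapsedExtractSliceInfo(
            b, sliceOp, expandShapeOp.getReassociationIndices(), offsets,
            sizes, strides)))
      return failure(); // Non-contiguous or rank-reducing slices are rejected.
    // The equivalent slice can now be taken directly on the un-expanded source.
    Value newSlice = tensor::ExtractSliceOp::create(
        b, sliceOp.getLoc(), expandShapeOp.getSrc(), offsets, sizes, strides);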
- SmallVector<OpFoldResult> collapsedOffsets, collapsedSizes, - collapsedStrides; - for (const ReassociationIndices &indices : - expandShapeOp.getReassociationIndices()) { - // collapsedSize will hold the size of the single dim that represents the - // reassociation group in the non expanded tensor. - OpFoldResult collapsedSize = rewriter.getIndexAttr(1); - // The reassocGroupSizes and reassocGroupOffsets are used to create an - // affine.linearize_index op to linearize the single offset value required - // for this reassociation group. - SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets; - - for (long expandedDim : indices) { - // reassocGroupSizes and reassocGroupOffsets can be obtained directly - // from the expanded state, but the collapsed size requires calculation - // as it did not previously exist. - reassocGroupSizes.push_back(expandedShape[expandedDim]); - reassocGroupOffsets.push_back(expandedOffsets[expandedDim]); - collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]); - } - - SmallVector<Value> offsetVals = - llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) { - return getValueOrCreateConstantIndexOp(rewriter, loc, ofr); - }); - OpFoldResult collapsedOffset = - affine::AffineLinearizeIndexOp::create(rewriter, loc, offsetVals, - reassocGroupSizes, - /*disjoint=*/true) - .getResult(); - collapsedOffsets.push_back(collapsedOffset); - collapsedSizes.push_back(collapsedSize); - - // Only unit stride is supported. - collapsedStrides.push_back(rewriter.getIndexAttr(1)); - } - // The shape of the result can be obtained from the sizes passed in. - SmallVector<Value> dynDims; - SmallVector<int64_t> shape; - dispatchIndexOpFoldResults(expandedSizes, dynDims, shape); - RankedTensorType resultType = RankedTensorType::get( - shape, expandShapeOp.getResultType().getElementType()); + SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes(); + RankedTensorType resultType = sliceOp.getResultType(); // Create a new ExtractSliceOp and ExpandShapeOp. + Location loc = sliceOp.getLoc(); Value newSliceOp = tensor::ExtractSliceOp::create( - rewriter, loc, expandShapeOp.getSrc(), collapsedOffsets, collapsedSizes, - collapsedStrides); + rewriter, loc, expandShapeOp.getSrc(), offsets, sizes, strides); rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>( sliceOp, resultType, newSliceOp, expandShapeOp.getReassociationIndices(), expandedSizes); return success(); } - - // Helper function to check if all the required conditions for the - // tensor.extract_slice to be bubbled up through the tensor.expand_shape are - // met. - LogicalResult - checkPreconditionForBubbleUpExtractSlice(tensor::ExtractSliceOp sliceOp, - tensor::ExpandShapeOp expandShapeOp, - PatternRewriter &rewriter) const { - - if (!expandShapeOp) { - return rewriter.notifyMatchFailure( - sliceOp, "tensor.extract_slice source not produced by expand_shape"); - } - - if (!sliceOp.hasUnitStride()) { - return rewriter.notifyMatchFailure( - sliceOp, "unsupported: non-unit stride. 
Only contiguous slices can " - "be supported in this transformation."); - } - - SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets(); - SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); - - if (static_cast<size_t>(sliceOp.getResultType().getRank()) != - sizes.size()) { - return rewriter.notifyMatchFailure(sliceOp, - "unimplemented: rank reducing slice"); - } - - SmallVector<OpFoldResult> outputShape = - getMixedValues(expandShapeOp.getStaticOutputShape(), - expandShapeOp.getOutputShape(), rewriter); - - std::function<bool(OpFoldResult, OpFoldResult, OpFoldResult)> - isZeroOffsetAndFullSize = - [](OpFoldResult offset, OpFoldResult sliceSize, OpFoldResult size) { - if (!isZeroInteger(offset)) - return false; - FailureOr<bool> maybeEqual = - ValueBoundsConstraintSet::areEqual(sliceSize, size); - return llvm::succeeded(maybeEqual) && maybeEqual.value(); - }; - - // Check that the slice is contiguous within each reassociation group. - // The slice is contiguous only if after the first dimension where a non - // unit slice is taken, the slice size on all subsequent dimensions of the - // group is equal to the entire size of the dimension. - // Examples of contiguous slices: - // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10] - // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10] - // Examples of non contiguous slices: - // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5] - // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5] - for (const ReassociationIndices &indices : - expandShapeOp.getReassociationIndices()) { - int64_t i = 0; - int64_t e = indices.size(); - // Find the first expanded dim after the first dim with non-unit extracted - // size. - for (; i < e; ++i) { - if (!isOneInteger(sizes[indices[i]])) { - // +1 to skip the first non-unit size dim. - i++; - break; - } - } - - // Verify that all subsequent dimensions extract the full size of the - // source tensor. - for (; i < e; ++i) { - int64_t expandedDim = indices[i]; - if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim], - outputShape[expandedDim])) { - return rewriter.notifyMatchFailure( - sliceOp, "Not a contiguous slice of the expanded tensor."); - } - } - } - - return success(); - } }; /// Converts `tensor.extract_slice(tensor.collapse_shape)` to @@ -582,170 +441,281 @@ struct BubbleUpCollapseShapeThroughExtractSlice "tensor.extract_slice source not produced by tensor.collapse_shape"); } - if (!sliceOp.hasUnitStride()) { - return rewriter.notifyMatchFailure( - sliceOp, "unsupported: non-unit stride. Only contiguous slices can " - "be supported in this transformation."); - } + SmallVector<OpFoldResult> offsets, sizes, strides; + if (failed(getExpandedExtractSliceInfo( + rewriter, sliceOp, collapseShapeOp.getReassociationIndices(), + collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides))) + return failure(); - // The tensor.extract_slice before applying the pattern works on the result - // of the tensor.collapse_shape, so variables (i.e. inputs for - // ExtractSliceOp) referring to the state before applying the pattern are - // named with the prefix "collapsed", and ones referring to the state after - // applying the pattern are named with the prefix "expanded". 
- SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets(); - SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes(); - - if (static_cast<size_t>(sliceOp.getResultType().getRank()) != - collapsedSizes.size()) { - return rewriter.notifyMatchFailure(sliceOp, - "unimplemented: rank reducing slice"); - } + Value newSliceOp = tensor::ExtractSliceOp::create( + rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), offsets, + sizes, strides); + rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( + sliceOp, sliceOp.getResultType(), newSliceOp, + collapseShapeOp.getReassociationIndices()); - ArrayRef<int64_t> srcShape = collapseShapeOp.getSrcType().getShape(); - SmallVector<ReassociationIndices, 4> reassociationIndices = - collapseShapeOp.getReassociationIndices(); - - // Compute new offsets, sizes, and strides for tensor.extract_slice. - // The new tensor.extract_slice will work on a tensor that has has a rank - // equal to the rank of the src of the collapse_shape. In each iteration of - // the loop, the offsets and sizes will be computed per reassociation group. - SmallVector<OpFoldResult> expandedOffsets, expandedSizes; - SmallVector<OpFoldResult> expandedStrides(srcShape.size(), - rewriter.getIndexAttr(1)); - - for (auto [collapsedSize, collapsedOffset, reassocIndices] : - llvm::zip_equal(collapsedSizes, collapsedOffsets, - collapseShapeOp.getReassociationIndices())) { - // CASE #1 - size and/or offset are dynamic. - // In this case, the slice can be represented as a contiguous slice only - // if there is a single dimension in the reassociation group that has a - // size not equal to 1. - if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) { - int nonUnitSizeCount = 0; - for (int64_t expandedShapeIdx : reassocIndices) { - if (srcShape[expandedShapeIdx] != 1) { - nonUnitSizeCount++; - expandedSizes.push_back(collapsedSize); - expandedOffsets.push_back(collapsedOffset); - continue; - } - - expandedSizes.push_back(rewriter.getIndexAttr(1)); - expandedOffsets.push_back(rewriter.getIndexAttr(0)); - } + return success(); + } +}; - if (nonUnitSizeCount != 1) { - return rewriter.notifyMatchFailure( - sliceOp, - "unsupported: slice cannot be verified to be contiguous"); - } - continue; - } +} // namespace - // CASE #2 = size and offset are static. - // Verify that the slice can be represented as a contiguous slice of the - // src of the collapse_shape. - // Checking this is done on order of most internal dimensions first, - // so traversal is done in reverse order of the reassociation group. - // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2, - // ...,An] then we first find the size and offset for n...k+1 then for k - // and then for k-1...0. - - // currentCollapsedsize and currentCollapsedOffset are initialized with - // the original collapsed size and offset and divided by the expanded - // shape size in each dimension as we go along the reassociation group. - // In essence we are spreading the original collapsed size and offset over - // the various expanded slice dimensions. - // The variables are used both to check the validity of the slice and to - // compute the expanded sizes and offsets. 
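Symmetrically, the collapse_shape case is exposed as getExpandedExtractSliceInfo: callers supply the reassociation and the pre-collapse source shape and receive per-expanded-dimension offsets, sizes, and unit strides. A minimal usage sketch, assuming b, sliceOp, and a defining tensor.collapse_shape named collapseShapeOp:

    SmallVector<OpFoldResult> offsets, sizes, strides;
    if (failed(tensor::getExpandedExtractSliceInfo(
            b, sliceOp, collapseShapeOp.getReassociationIndices(),
            collapseShapeOp.getSrcType().getShape(), offsets, sizes, strides)))
      return failure(); // Slices that cannot be proven contiguous are rejected.
    Value newSlice = tensor::ExtractSliceOp::create(
        b, sliceOp.getLoc(), collapseShapeOp.getSrc(), offsets, sizes, strides);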
- int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value(); - int64_t currentCollapsedOffset = - getConstantIntValue(collapsedOffset).value(); - - SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets; - - ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(), - reassocIndices.rend()); - int64_t idx = 0; - int64_t reassocGroupSize = reassocIndices.size(); - - // First handle the trailing dimensions where the slice size should be - // equal to the tensor shape and the offset should be 0 (n...k+1). - for (; idx < reassocGroupSize; ++idx) { - int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - - if (currentCollapsedsize < expandedShapeSize) - break; - - // We need to make sure that the slice size can be set to the shape size - // and the offset to 0. - if ((currentCollapsedsize % expandedShapeSize) != 0 || - (currentCollapsedOffset % expandedShapeSize) != 0) { - return rewriter.notifyMatchFailure( - sliceOp, "unsupported: cannot be extracted as a contiguous slice " - "of the src of the collapse_shape"); - } +LogicalResult mlir::tensor::getCollapsedExtractSliceInfo( + OpBuilder &b, tensor::ExtractSliceOp sliceOp, + ArrayRef<ReassociationIndices> reassociation, + SmallVectorImpl<OpFoldResult> &collapsedOffsets, + SmallVectorImpl<OpFoldResult> &collapsedSizes, + SmallVectorImpl<OpFoldResult> &collapsedStrides) { + if (!sliceOp.hasUnitStride()) { + return failure(); + } + + SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets(); + SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes(); - groupExpandedSizes.push_back(rewriter.getIndexAttr(expandedShapeSize)); - groupExpandedOffsets.push_back(rewriter.getIndexAttr(0)); + if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) { + return failure(); + } - currentCollapsedsize /= expandedShapeSize; - currentCollapsedOffset /= expandedShapeSize; + auto isZeroOffsetAndFullSize = [&](OpFoldResult offset, + OpFoldResult sliceSize, int64_t inputDim) { + if (!isZeroInteger(offset)) + return false; + ValueBoundsConstraintSet::Variable inputSize(sliceOp.getSource(), inputDim); + FailureOr<bool> maybeEqual = + ValueBoundsConstraintSet::areEqual(sliceSize, inputSize); + return llvm::succeeded(maybeEqual) && maybeEqual.value(); + }; + + // Check that the slice is contiguous within each reassociation group. + // The slice is contiguous only if after the first dimension where a non + // unit slice is taken, the slice size on all subsequent dimensions of the + // group is equal to the entire size of the dimension. + // Examples of contiguous slices: + // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 1, 10] + // full sizes: [5, 10] slice offsets: [3, 0] slice sizes: [2, 10] + // Examples of non contiguous slices: + // full sizes: [8, 8, 10] slice offsets: [0, 0, 0] slice sizes: [1, 2, 5] + // full sizes: [5, 10] slice offsets: [0, 4] slice sizes: [2, 5] + for (const ReassociationIndices &indices : reassociation) { + int64_t i = 0; + int64_t e = indices.size(); + // Find the first expanded dim after the first dim with non-unit extracted + // size. + for (; i < e; ++i) { + if (!isOneInteger(sizes[indices[i]])) { + // +1 to skip the first non-unit size dim. + i++; + break; } + } + + // Verify that all subsequent dimensions extract the full size of the + // source tensor. 
+ for (; i < e; ++i) { + int64_t expandedDim = indices[i]; + if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim], + expandedDim)) { + return failure(); + } + } + } + + // The tensor.extract_slice before applying the pattern works on the result + // of the tensor.expand_shape, so variables (i.e. inputs for ExtractSliceOp) + // referring to the state before applying the pattern are named with the + // prefix "expanded", and ones referring to the state after applying the + // pattern are named with the prefix "collapsed". + Location loc = sliceOp.getLoc(); + SmallVector<OpFoldResult> expandedOffsets = sliceOp.getMixedOffsets(); + SmallVector<OpFoldResult> expandedSizes = sliceOp.getMixedSizes(); + SmallVector<OpFoldResult> expandedShape = + getMixedSizes(b, loc, sliceOp.getSource()); + + // Helper variables and function for accumulating the size values. + AffineExpr d0, d1, d2; + bindDims(b.getContext(), d0, d1, d2); + // Multiply two integers. + auto mul = [&](OpFoldResult v1, OpFoldResult v2) { + auto mulMap = AffineMap::get(2, 0, {d0 * d1}); + return affine::makeComposedFoldedAffineApply(b, loc, mulMap, {v1, v2}); + }; + + // Compute new offsets, sizes, and strides for tensor.extract_slice. + // The new tensor.extract_slice will work on a tensor that has has a rank of + // ReassociationIndices.size(). In the loop a single offset, size, and + // stride value is computed per reassociation group. + for (const ReassociationIndices &indices : reassociation) { + // collapsedSize will hold the size of the single dim that represents the + // reassociation group in the non expanded tensor. + OpFoldResult collapsedSize = b.getIndexAttr(1); + // The reassocGroupSizes and reassocGroupOffsets are used to create an + // affine.linearize_index op to linearize the single offset value required + // for this reassociation group. + SmallVector<OpFoldResult> reassocGroupSizes, reassocGroupOffsets; + + for (long expandedDim : indices) { + // reassocGroupSizes and reassocGroupOffsets can be obtained directly + // from the expanded state, but the collapsed size requires calculation + // as it did not previously exist. + reassocGroupSizes.push_back(expandedShape[expandedDim]); + reassocGroupOffsets.push_back(expandedOffsets[expandedDim]); + collapsedSize = mul(collapsedSize, expandedSizes[expandedDim]); + } + + SmallVector<Value> offsetVals = + llvm::map_to_vector(reassocGroupOffsets, [&](OpFoldResult ofr) { + return getValueOrCreateConstantIndexOp(b, loc, ofr); + }); + OpFoldResult collapsedOffset = affine::AffineLinearizeIndexOp::create( + b, loc, offsetVals, reassocGroupSizes, + /*disjoint=*/true) + .getResult(); + collapsedOffsets.push_back(collapsedOffset); + collapsedSizes.push_back(collapsedSize); + + // Only unit stride is supported. + collapsedStrides.push_back(b.getIndexAttr(1)); + } + return success(); +} + +LogicalResult mlir::tensor::getExpandedExtractSliceInfo( + OpBuilder &b, tensor::ExtractSliceOp sliceOp, + ArrayRef<ReassociationIndices> reassociation, + ArrayRef<int64_t> expandedShape, + SmallVectorImpl<OpFoldResult> &expandedOffsets, + SmallVectorImpl<OpFoldResult> &expandedSizes, + SmallVectorImpl<OpFoldResult> &expandedStrides) { + if (!sliceOp.hasUnitStride()) { + return failure(); + } + + // The tensor.extract_slice before applying the pattern works on the result + // of the tensor.collapse_shape, so variables (i.e. 
inputs for + // ExtractSliceOp) referring to the state before applying the pattern are + // named with the prefix "collapsed", and ones referring to the state after + // applying the pattern are named with the prefix "expanded". + SmallVector<OpFoldResult> collapsedOffsets = sliceOp.getMixedOffsets(); + SmallVector<OpFoldResult> collapsedSizes = sliceOp.getMixedSizes(); + if (static_cast<size_t>(sliceOp.getResultType().getRank()) != + collapsedSizes.size()) { + return failure(); + } - // Now handle the first dim where slicing occurs on (k). - if (idx < reassocGroupSize) { - int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; - // We need to make sure that the slice size in this dim + offset will - // not exceed the shape size. - if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) { - return rewriter.notifyMatchFailure( - sliceOp, "unsupported: slice cannot be extracted as a contiguous " - "slice of the src of the collapse_shape"); + // Compute new offsets, sizes, and strides for tensor.extract_slice. + // The new tensor.extract_slice will work on a tensor that has has a rank + // equal to the rank of the src of the collapse_shape. In each iteration of + // the loop, the offsets and sizes will be computed per reassociation group. + expandedStrides.resize(expandedShape.size(), b.getIndexAttr(1)); + for (auto [collapsedSize, collapsedOffset, reassocIndices] : + llvm::zip_equal(collapsedSizes, collapsedOffsets, reassociation)) { + // CASE #1 - size and/or offset are dynamic. + // In this case, the slice can be represented as a contiguous slice only + // if there is a single dimension in the reassociation group that has a + // size not equal to 1. + if (isa<Value>(collapsedSize) || isa<Value>(collapsedOffset)) { + int nonUnitSizeCount = 0; + for (int64_t expandedShapeIdx : reassocIndices) { + if (expandedShape[expandedShapeIdx] != 1) { + nonUnitSizeCount++; + expandedSizes.push_back(collapsedSize); + expandedOffsets.push_back(collapsedOffset); + continue; } - groupExpandedSizes.push_back( - rewriter.getIndexAttr(currentCollapsedsize)); - groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); + expandedSizes.push_back(b.getIndexAttr(1)); + expandedOffsets.push_back(b.getIndexAttr(0)); + } - currentCollapsedOffset /= expandedShapeSize; + if (nonUnitSizeCount != 1) { + return failure(); } + continue; + } - // Now handle the leading dimensions where the slice size is equal to 1 - // (k-1...0). - // The size for these dimensions must be 1 because of how we constructed - // the slice size of the expanded shape. We spread the original collapsed - // size over the expanded shape sizes until we reached dimension k where - // the remaining size was smaller than the expanded shape size, and spread - // the remaining size on it. So, now we are left with only 1s. - for (idx++; idx < reassocGroupSize; ++idx) { - int64_t expandedShapeSize = srcShape[reversedReassocIndices[idx]]; - int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; - groupExpandedSizes.push_back(rewriter.getIndexAttr(1)); - groupExpandedOffsets.push_back(rewriter.getIndexAttr(offsetInDim)); - currentCollapsedOffset /= expandedShapeSize; + // CASE #2 = size and offset are static. + // Verify that the slice can be represented as a contiguous slice of the + // src of the collapse_shape. + // Checking this is done on order of most internal dimensions first, + // so traversal is done in reverse order of the reassociation group. 
+ // If the expected slice shape is [1, 1, ..., 1, Sk, Ak + 1, Ak + 2, + // ...,An] then we first find the size and offset for n...k+1 then for k + // and then for k-1...0. + + // currentCollapsedsize and currentCollapsedOffset are initialized with + // the original collapsed size and offset and divided by the expanded + // shape size in each dimension as we go along the reassociation group. + // In essence we are spreading the original collapsed size and offset over + // the various expanded slice dimensions. + // The variables are used both to check the validity of the slice and to + // compute the expanded sizes and offsets. + int64_t currentCollapsedsize = getConstantIntValue(collapsedSize).value(); + int64_t currentCollapsedOffset = + getConstantIntValue(collapsedOffset).value(); + SmallVector<OpFoldResult> groupExpandedSizes, groupExpandedOffsets; + ReassociationIndices reversedReassocIndices(reassocIndices.rbegin(), + reassocIndices.rend()); + int64_t idx = 0; + int64_t reassocGroupSize = reassocIndices.size(); + + // First handle the trailing dimensions where the slice size should be + // equal to the tensor shape and the offset should be 0 (n...k+1). + for (; idx < reassocGroupSize; ++idx) { + int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]]; + + if (currentCollapsedsize < expandedShapeSize) + break; + + // We need to make sure that the slice size can be set to the shape size + // and the offset to 0. + if ((currentCollapsedsize % expandedShapeSize) != 0 || + (currentCollapsedOffset % expandedShapeSize) != 0) { + return failure(); } - expandedSizes.append(groupExpandedSizes.rbegin(), - groupExpandedSizes.rend()); - expandedOffsets.append(groupExpandedOffsets.rbegin(), - groupExpandedOffsets.rend()); + groupExpandedSizes.push_back(b.getIndexAttr(expandedShapeSize)); + groupExpandedOffsets.push_back(b.getIndexAttr(0)); + + currentCollapsedsize /= expandedShapeSize; + currentCollapsedOffset /= expandedShapeSize; } - Value newSliceOp = tensor::ExtractSliceOp::create( - rewriter, collapseShapeOp->getLoc(), collapseShapeOp.getSrc(), - expandedOffsets, expandedSizes, expandedStrides); - rewriter.replaceOpWithNewOp<tensor::CollapseShapeOp>( - sliceOp, sliceOp.getResultType(), newSliceOp, - collapseShapeOp.getReassociationIndices()); + // Now handle the first dim where slicing occurs on (k). + if (idx < reassocGroupSize) { + int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]]; + int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; + // We need to make sure that the slice size in this dim + offset will + // not exceed the shape size. + if ((currentCollapsedsize + offsetInDim) >= expandedShapeSize) { + return failure(); + } + groupExpandedSizes.push_back(b.getIndexAttr(currentCollapsedsize)); + groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim)); + currentCollapsedOffset /= expandedShapeSize; + } - return success(); + // Now handle the leading dimensions where the slice size is equal to 1 + // (k-1...0). + // The size for these dimensions must be 1 because of how we constructed + // the slice size of the expanded shape. We spread the original collapsed + // size over the expanded shape sizes until we reached dimension k where + // the remaining size was smaller than the expanded shape size, and spread + // the remaining size on it. So, now we are left with only 1s. 
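A short worked example of the static case may help, with values chosen purely for illustration: take a group whose expanded shape is [4, 8] (collapsed size 32) and a collapsed slice of size 16 at offset 8. Walking the group innermost-first, the size-8 dimension absorbs a full extent (expanded size 8, offset 0), leaving a remaining size of 2 and offset of 1; the size-4 dimension then takes the remainder (expanded size 2, offset 1 % 4 = 1). Reversed back to outer-to-inner order this yields sizes [2, 8] and offsets [1, 0], i.e. rows 1 and 2 of a 4x8 view, which is exactly the contiguous region the collapsed slice described.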
+ for (idx++; idx < reassocGroupSize; ++idx) { + int64_t expandedShapeSize = expandedShape[reversedReassocIndices[idx]]; + int64_t offsetInDim = currentCollapsedOffset % expandedShapeSize; + groupExpandedSizes.push_back(b.getIndexAttr(1)); + groupExpandedOffsets.push_back(b.getIndexAttr(offsetInDim)); + currentCollapsedOffset /= expandedShapeSize; + } + expandedSizes.append(groupExpandedSizes.rbegin(), + groupExpandedSizes.rend()); + expandedOffsets.append(groupExpandedOffsets.rbegin(), + groupExpandedOffsets.rend()); } -}; - -} // namespace + return success(); +} void mlir::tensor::populateReassociativeReshapeFoldingPatterns( RewritePatternSet &patterns) { diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index e3cba388..8d63646 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -122,8 +122,9 @@ struct PoolPadFoldAdaptor<tosa::MaxPool2dOp> { const APFloat lowestVal = APFloat::getLargest(padConstVal.getSemantics(), true); return padConstVal == lowestVal; - } else if (auto padConstIntAttr = - mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) { + } + if (auto padConstIntAttr = + mlir::dyn_cast<DenseIntElementsAttr>(padConstAttr)) { const APInt padConstVal = *padConstIntAttr.begin(); const unsigned int bitWidth = padConstVal.getBitWidth(); const APInt lowestVal = @@ -555,7 +556,8 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> { // Check we have a valid NaN propagation combination. const auto opNanMode = op.getNanMode(); const auto clampNanMode = clampOp.getNanMode(); - if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE") + if (opNanMode == NanPropagationMode::IGNORE && + clampNanMode == NanPropagationMode::PROPAGATE) return failure(); auto maxValAttr = op.getMaxValAttr(); @@ -636,10 +638,16 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> { } } + auto newMode = (opNanMode != clampNanMode) + ? tosa::NanPropagationMode::IGNORE + : opNanMode; + + auto newModeAttr = + NanPropagationModeAttr::get(rewriter.getContext(), newMode); + rewriter.replaceOpWithNewOp<tosa::ClampOp>( op, op.getType(), clampOp.getInput(), newMinValAttr, newMaxValAttr, - rewriter.getStringAttr((opNanMode != clampNanMode) ? 
"IGNORE" - : opNanMode)); + newModeAttr); return success(); } }; @@ -1120,13 +1128,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { } if (rhsTy == resultTy) { - if (isSplatZero(resultETy, lhsAttr)) + if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape()) + // constant values can only be resized if resulting type is static return lhsAttr.resizeSplat(resultTy); if (isSplatOne(resultETy, lhsAttr, shift)) return rhs; } if (lhsTy == resultTy) { - if (isSplatZero(resultETy, rhsAttr)) + if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape()) return rhsAttr.resizeSplat(resultTy); if (isSplatOne(resultETy, rhsAttr, shift)) return lhs; diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 3cafb19..bd7aee5 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -270,6 +270,244 @@ void mlir::tosa::printVariableOpTypeOrInitialValue( } } +namespace { + +// parse attributes with special handling for tosa enum attributes +template <typename EnumType> +ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser, + NamedAttrList &outAttrs) { + llvm::StringRef name; + if (parser.parseOptionalKeyword(&name) || parser.parseEqual()) + return failure(); + + // special handling: rounding_mode accepts a *bare* RoundingMode enum + // keyword. + llvm::StringRef kw; + if constexpr (std::is_same_v<EnumType, tosa::RoundingMode>) { + if (name == "rounding_mode" && + succeeded(parser.parseOptionalKeyword(&kw))) { + auto sym = symbolizeRoundingMode(kw); + if (!sym) + return parser.emitError(parser.getCurrentLocation()) + << "invalid rounding_mode value: " << kw; + auto attr = RoundingModeAttr::get(parser.getContext(), sym.value()); + outAttrs.push_back(NamedAttribute(name, attr)); + return success(); + } + } + // special handling: mode accepts a *bare* ResizeMode enum keyword. + if constexpr (std::is_same_v<EnumType, tosa::ResizeMode>) { + if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) { + auto sym = symbolizeResizeMode(kw); + if (!sym) + return parser.emitError(parser.getCurrentLocation()) + << "invalid resize mode value: " << kw; + auto attr = ResizeModeAttr::get(parser.getContext(), sym.value()); + outAttrs.push_back(NamedAttribute(name, attr)); + return success(); + } + } + // special handling: nan_mode accepts a *bare* NanPropagationMode enum + // keyword. 
+ if constexpr (std::is_same_v<EnumType, tosa::NanPropagationMode>) { + if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) { + auto sym = symbolizeNanPropagationMode(kw); + if (!sym) + return parser.emitError(parser.getCurrentLocation()) + << "invalid nan_mode value: " << kw; + auto attr = NanPropagationModeAttr::get(parser.getContext(), sym.value()); + outAttrs.push_back(NamedAttribute(name, attr)); + return success(); + } + } + + // Default path: parse any normal attribute literal, including fully qualified + // enum keyword + Attribute attr; + return parser.parseAttribute(attr, name, outAttrs); +} + +template <typename EnumType> +ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) { + // parse operands + SmallVector<OpAsmParser::UnresolvedOperand, 5> operands; + if (parser.parseCommaSeparatedList( + [&]() { return parser.parseOperand(operands.emplace_back()); })) + return failure(); + + // Parse { attr-dict } with special handling for enum bare token + NamedAttrList attrs; + if (succeeded(parser.parseOptionalLBrace()) && + failed(parser.parseOptionalRBrace())) { + do { + if (parseAttrEntryWithEnumHandling<EnumType>(parser, attrs)) + return failure(); + } while (succeeded(parser.parseOptionalComma())); + if (parser.parseRBrace()) + return failure(); + } + + FunctionType fnTy; + if (parser.parseColonType(fnTy)) + return failure(); + + // Resolve operands and types + if (failed(parser.resolveOperands(operands, fnTy.getInputs(), + parser.getCurrentLocation(), + result.operands))) + return failure(); + + result.addTypes(fnTy.getResult(0)); + result.addAttributes(attrs); + + return success(); +} + +void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) { + parser << namedAttr.getName().strref() << " = "; + auto attr = namedAttr.getValue(); + if (auto roundingModeAttr = dyn_cast<tosa::RoundingModeAttr>(attr)) { + parser << roundingModeAttr.getValue(); + } else if (auto resizeModeAttr = dyn_cast<tosa::ResizeModeAttr>(attr)) { + parser << resizeModeAttr.getValue(); + } else if (auto nanPropagationModeAttr = + dyn_cast<tosa::NanPropagationModeAttr>(attr)) { + parser << nanPropagationModeAttr.getValue(); + } else { + parser.printAttribute(attr); + } +} + +// print with special handling for default valued NanPropagationMode attribute +void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) { + parser << " "; + parser.printOperands(op->getOperands()); + + NamedAttrList toPrint(op->getAttrs()); + // remove default NanPropagate attribute + const auto kDefaultNanValue = NanPropagationMode::PROPAGATE; + for (auto attr : op->getAttrs()) { + if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) { + if (nanAttr.getValue() == kDefaultNanValue) { + // elide from toPrint + toPrint.erase(attr.getName()); + break; + } + } + } + + if (!toPrint.empty()) { + parser << " {"; + llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) { + printNamedAttr(parser, namedAttr); + }); + parser << "}"; + } + + parser << " : "; + parser.printFunctionalType(op); +} + +// print with special handling for enums: RoundingMode, ResizeMode +void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) { + parser << " "; + parser.printOperands(op->getOperands()); + + if (!op->getAttrs().empty()) { + parser << " {"; + llvm::interleaveComma(op->getAttrs(), parser, + [&](const NamedAttribute namedAttr) { + printNamedAttr(parser, namedAttr); + }); + parser << "}"; + } + + parser << " : "; + 
parser.printFunctionalType(op); +} + +} // namespace + +ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::RoundingMode>(parser, result); +} + +void RescaleOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + +ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::RoundingMode>(parser, result); +} + +void ApplyScaleOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + +ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::ResizeMode>(parser, result); +} + +void ResizeOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + +ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void ArgMaxOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void MaxPool2dOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void ClampOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void MaximumOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void MinimumOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void ReduceMaxOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + +ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseWithEnumHandling<tosa::NanPropagationMode>(parser, result); +} + +void ReduceMinOp::print(OpAsmPrinter &parser) { + printWithNanPropagationHandling(parser, *this); +} + //===----------------------------------------------------------------------===// // Tosa utilities. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp index c7b9534..790bbf7 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp @@ -508,14 +508,15 @@ private: bool attributeCheckRescale(Operation *op) { if (auto rescale = dyn_cast<tosa::RescaleOp>(op)) { - if (rescale.getRoundingMode() == "DOUBLE_ROUND" && + if (rescale.getRoundingMode() == RoundingMode::DOUBLE_ROUND && !targetEnv.allows(Extension::doubleround)) { op->emitOpError() << "failed attribute check: rounding_mode = DOUBLE_ROUND " << "requires extension [doubleround]"; return false; - } else if (rescale.getRoundingMode() == "INEXACT_ROUND" && - !targetEnv.allows(Extension::inexactround)) { + } + if (rescale.getRoundingMode() == RoundingMode::INEXACT_ROUND && + !targetEnv.allows(Extension::inexactround)) { op->emitOpError() << "failed attribute check: rounding_mode = INEXACT_ROUND " << "requires extension [inexactround]"; @@ -1122,7 +1123,7 @@ bool checkErrorIfRescale(Operation *op) { } // ERROR_IF(!scale32 && (rounding_mode == DOUBLE_ROUND)) - if (!scale32 && roundingMode == "DOUBLE_ROUND") { + if (!scale32 && roundingMode == RoundingMode::DOUBLE_ROUND) { op->emitOpError() << "DOUBLE_ROUND is only allowed with scale32=true."; return false; } @@ -1307,7 +1308,8 @@ bool TosaValidation::isValidElementType(Type type, const bool allowUnsigned) { if (isa<FloatType>(type)) { return isa<Float32Type, Float16Type, BFloat16Type, Float8E4M3FNType, Float8E5M2Type>(type); - } else if (auto intTy = dyn_cast<IntegerType>(type)) { + } + if (auto intTy = dyn_cast<IntegerType>(type)) { if (intTy.isSignless()) { switch (intTy.getWidth()) { case 1: diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 9266a63..48df1a0 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -37,16 +37,13 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/InterleavedRange.h" #include <optional> #define DEBUG_TYPE "transform-dialect" -#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ") - #define DEBUG_TYPE_MATCHER "transform-matcher" -#define DBGS_MATCHER() (llvm::dbgs() << "[" DEBUG_TYPE_MATCHER "] ") -#define DEBUG_MATCHER(x) DEBUG_WITH_TYPE(DEBUG_TYPE_MATCHER, x) using namespace mlir; @@ -182,8 +179,7 @@ transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, DiagnosedSilenceableFailure result = state.applyTransform(cast<TransformOpInterface>(transform)); if (result.isSilenceableFailure()) { - LLVM_DEBUG(DBGS() << "alternative failed: " << result.getMessage() - << "\n"); + LDBG() << "alternative failed: " << result.getMessage(); failed = true; break; } @@ -1155,12 +1151,10 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, std::optional<DiagnosedSilenceableFailure> maybeFailure; for (Operation *root : state.getPayloadOps(getRoot())) { WalkResult walkResult = root->walk([&](Operation *op) { - DEBUG_MATCHER({ - DBGS_MATCHER() << "matching "; - op->print(llvm::dbgs(), - OpPrintingFlags().assumeVerified().skipRegions()); - llvm::dbgs() << " @" << op << "\n"; - }); + LDBG(1, DEBUG_TYPE_MATCHER) + << "matching " + << OpWithFlags(op, 
OpPrintingFlags().assumeVerified().skipRegions()) + << " @" << op; // Try matching. SmallVector<SmallVector<MappedValue>> mappings; @@ -1172,8 +1166,8 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, if (diag.isDefiniteFailure()) return WalkResult::interrupt(); if (diag.isSilenceableFailure()) { - DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() - << " failed: " << diag.getMessage()); + LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName() + << " failed: " << diag.getMessage(); return WalkResult::advance(); } @@ -1304,12 +1298,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, if (!getRestrictRoot() && op == root) return WalkResult::advance(); - DEBUG_MATCHER({ - DBGS_MATCHER() << "matching "; - op->print(llvm::dbgs(), - OpPrintingFlags().assumeVerified().skipRegions()); - llvm::dbgs() << " @" << op << "\n"; - }); + LDBG(1, DEBUG_TYPE_MATCHER) + << "matching " + << OpWithFlags(op, OpPrintingFlags().assumeVerified().skipRegions()) + << " @" << op; firstMatchArgument.clear(); firstMatchArgument.push_back(op); @@ -1322,8 +1314,8 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, if (diag.isDefiniteFailure()) return WalkResult::interrupt(); if (diag.isSilenceableFailure()) { - DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() - << " failed: " << diag.getMessage()); + LDBG(1, DEBUG_TYPE_MATCHER) << "matcher " << matcher.getName() + << " failed: " << diag.getMessage(); continue; } @@ -2173,10 +2165,10 @@ DiagnosedSilenceableFailure transform::MatchOperationEmptyOp::matchOperation( ::std::optional<::mlir::Operation *> maybeCurrent, transform::TransformResults &results, transform::TransformState &state) { if (!maybeCurrent.has_value()) { - DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp success\n"; }); + LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp success"; return DiagnosedSilenceableFailure::success(); } - DEBUG_MATCHER({ DBGS_MATCHER() << "MatchOperationEmptyOp failure\n"; }); + LDBG(1, DEBUG_TYPE_MATCHER) << "MatchOperationEmptyOp failure"; return emitSilenceableError() << "operation is not empty"; } diff --git a/mlir/lib/Dialect/Transform/IR/Utils.cpp b/mlir/lib/Dialect/Transform/IR/Utils.cpp index d666390..773eb13 100644 --- a/mlir/lib/Dialect/Transform/IR/Utils.cpp +++ b/mlir/lib/Dialect/Transform/IR/Utils.cpp @@ -11,6 +11,7 @@ #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/FunctionInterfaces.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" using namespace mlir; @@ -90,7 +91,7 @@ transform::detail::mergeSymbolsInto(Operation *target, // // Rename private symbols in both ops in order to resolve conflicts that can // be resolved that way. - LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n"); + LDBG() << "renaming private symbols to resolve conflicts:"; // TODO: Do we *actually* need to test in both directions? for (auto &&[symbolTable, otherSymbolTable] : llvm::zip( SmallVector<SymbolTable *, 2>{&targetSymbolTable, &otherSymbolTable}, @@ -102,7 +103,7 @@ transform::detail::mergeSymbolsInto(Operation *target, if (!symbolOp) continue; StringAttr name = symbolOp.getNameAttr(); - LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n"); + LDBG() << " found @" << name.getValue(); // Check if there is a colliding op in the other module. 
auto collidingOp = @@ -110,7 +111,7 @@ transform::detail::mergeSymbolsInto(Operation *target, if (!collidingOp) continue; - LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue()); + LDBG() << " collision found for @" << name.getValue(); // Collisions are fine if both ops are functions and can be merged. if (auto funcOp = dyn_cast<FunctionOpInterface>(op), @@ -119,13 +120,12 @@ transform::detail::mergeSymbolsInto(Operation *target, funcOp && collidingFuncOp) { if (canMergeInto(funcOp, collidingFuncOp) || canMergeInto(collidingFuncOp, funcOp)) { - LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and " - "will be merged\n"); + LDBG() << " but both ops are functions and will be merged"; continue; } // If they can't be merged, proceed like any other collision. - LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions"); + LDBG() << " and both ops are function definitions"; } // Collision can be resolved by renaming if one of the ops is private. @@ -133,7 +133,7 @@ transform::detail::mergeSymbolsInto(Operation *target, [&](SymbolOpInterface op, SymbolOpInterface otherOp, SymbolTable &symbolTable, SymbolTable &otherSymbolTable) -> InFlightDiagnostic { - LLVM_DEBUG(llvm::dbgs() << ", renaming\n"); + LDBG() << ", renaming"; FailureOr<StringAttr> maybeNewName = symbolTable.renameToUnique(op, {&otherSymbolTable}); if (failed(maybeNewName)) { @@ -142,8 +142,7 @@ transform::detail::mergeSymbolsInto(Operation *target, << "attempted renaming due to collision with this op"; return diag; } - LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue() - << "\n"); + LDBG() << " renamed to @" << maybeNewName->getValue(); return InFlightDiagnostic(); }; @@ -161,7 +160,7 @@ transform::detail::mergeSymbolsInto(Operation *target, return diag; continue; } - LLVM_DEBUG(llvm::dbgs() << ", emitting error\n"); + LDBG() << ", emitting error"; InFlightDiagnostic diag = symbolOp.emitError() << "doubly defined symbol @" << name.getValue(); diag.attachNote(collidingOp->getLoc()) << "previously defined here"; @@ -179,7 +178,7 @@ transform::detail::mergeSymbolsInto(Operation *target, // Step 2: // // Move all ops from `other` into target and merge public symbols. - LLVM_DEBUG(DBGS() << "moving all symbols into target\n"); + LDBG() << "moving all symbols into target"; { SmallVector<SymbolOpInterface> opsToMove; for (Operation &op : other->getRegion(0).front()) { @@ -193,13 +192,13 @@ transform::detail::mergeSymbolsInto(Operation *target, targetSymbolTable.lookup(op.getNameAttr())); // Move op even if we get a collision. - LLVM_DEBUG(DBGS() << " moving @" << op.getName()); + LDBG() << " moving @" << op.getName(); op->moveBefore(&target->getRegion(0).front(), target->getRegion(0).front().end()); // If there is no collision, we are done. if (!collidingOp) { - LLVM_DEBUG(llvm::dbgs() << " without collision\n"); + LDBG() << " without collision"; continue; } @@ -217,9 +216,9 @@ transform::detail::mergeSymbolsInto(Operation *target, } assert(canMergeInto(funcOp, collidingFuncOp)); - LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at " - << collidingFuncOp.getLoc() << ":\n" - << collidingFuncOp << "\n"); + LDBG() << " with collision, trying to keep op at " + << collidingFuncOp.getLoc() << ":\n" + << collidingFuncOp; // Update symbol table. This works with or without the previous `swap`. 
targetSymbolTable.remove(funcOp); @@ -239,6 +238,6 @@ transform::detail::mergeSymbolsInto(Operation *target, return target->emitError() << "failed to verify target op after merging symbols"; - LLVM_DEBUG(DBGS() << "done merging ops\n"); + LDBG() << "done merging ops"; return InFlightDiagnostic(); } diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index 14a4fdf..4f4620a 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -312,7 +312,7 @@ LogicalResult transform::TransformState::setParams(Value value, } template <typename Mapping, typename Key, typename Mapped> -void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) { +static void dropMappingEntry(Mapping &mapping, Key key, Mapped mapped) { auto it = mapping.find(key); if (it == mapping.end()) return; @@ -771,7 +771,7 @@ LogicalResult transform::TransformState::checkAndRecordHandleInvalidation( } template <typename T> -DiagnosedSilenceableFailure +static DiagnosedSilenceableFailure checkRepeatedConsumptionInOperand(ArrayRef<T> payload, transform::TransformOpInterface transform, unsigned operandNumber) { diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp index 41955c8..3ced1a6 100644 --- a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp @@ -100,12 +100,7 @@ LogicalResult PatternApplicatorExtension::findAllMatches( PatternApplicator applicator(it->second); // We want to discourage direct use of PatternRewriter in APIs but In this // very specific case, an IRRewriter is not enough. 
- struct TrivialPatternRewriter : public PatternRewriter { - public: - explicit TrivialPatternRewriter(MLIRContext *context) - : PatternRewriter(context) {} - }; - TrivialPatternRewriter rewriter(root->getContext()); + PatternRewriter rewriter(root->getContext()); applicator.applyDefaultCostModel(); root->walk([&](Operation *op) { if (succeeded(applicator.matchAndRewrite(op, rewriter))) diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp index e6ef028..34385d7 100644 --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -276,7 +276,7 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub, if (!ubConstant) return std::nullopt; std::optional<int64_t> stepConstant = getConstantIntValue(step); - if (!stepConstant) + if (!stepConstant || *stepConstant == 0) return std::nullopt; return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index cb4783d..9b2a455 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -2402,6 +2402,16 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor, return foldToElementsFromElements(*this, results); } +LogicalResult +ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, + ToElementsOp::Adaptor adaptor, + SmallVectorImpl<Type> &inferredReturnTypes) { + auto vecType = cast<VectorType>(adaptor.getSource().getType()); + Type elType = vecType.getElementType(); + inferredReturnTypes.append(vecType.getNumElements(), elType); + return success(); +} + //===----------------------------------------------------------------------===// // FromElementsOp //===----------------------------------------------------------------------===// @@ -2456,8 +2466,12 @@ static OpFoldResult foldFromElementsToConstant(FromElementsOp fromElementsOp, if (llvm::any_of(elements, [](Attribute attr) { return !attr; })) return {}; + // DenseElementsAttr only supports int/index/float/complex types. auto destVecType = fromElementsOp.getDest().getType(); auto destEltType = destVecType.getElementType(); + if (!destEltType.isIntOrIndexOrFloat() && !isa<ComplexType>(destEltType)) + return {}; + // Constant attributes might have a different type than the return type. // Convert them before creating the dense elements attribute. auto convertedElements = llvm::map_to_vector(elements, [&](Attribute attr) { @@ -2768,8 +2782,8 @@ BroadcastableToResult mlir::vector::isBroadcastableTo( Type srcType, VectorType dstVectorType, std::pair<VectorDim, VectorDim> *mismatchingDims) { // Broadcast scalar to vector of the same element type. - if (srcType.isIntOrIndexOrFloat() && dstVectorType && - getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType)) + if (isa<VectorElementTypeInterface>(srcType) && dstVectorType && + srcType == getElementTypeOrSelf(dstVectorType)) return BroadcastableToResult::Success; // From now on, only vectors broadcast. VectorType srcVectorType = llvm::dyn_cast<VectorType>(srcType); @@ -3276,6 +3290,18 @@ LogicalResult InsertOp::verify() { return success(); } +// Calculate the linearized position of the continuous chunk of elements to +// insert, based on the shape of the value to insert and the positions to insert +// at. 
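+// For example (illustrative values): with destTy = vector<2x3xf32> and +// positions = [1], the completed position is [1, 0], computeStrides gives +// strides [3, 1], and the returned linear position is 1 * 3 + 0 * 1 = 3.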
+static int64_t calculateInsertPosition(VectorType destTy, + ArrayRef<int64_t> positions) { + llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0); + assert(positions.size() <= completePositions.size() && + "positions size must be less than or equal to destTy rank"); + copy(positions, completePositions.begin()); + return linearize(completePositions, computeStrides(destTy.getShape())); +} + namespace { // If insertOp is only inserting unit dimensions it can be transformed to a @@ -3313,6 +3339,132 @@ public: return success(); } }; + +/// Pattern to optimize a chain of insertions. +/// +/// This pattern identifies chains of vector.insert operations that: +/// 1. Only insert values at static positions. +/// 2. Completely initialize all elements in the resulting vector. +/// 3. All intermediate insert operations have only one use. +/// +/// When these conditions are met, the entire chain can be replaced with a +/// single vector.from_elements operation. +/// +/// To keep this pattern simple, and avoid spending too much time on matching +/// fragmented insert chains, this pattern only considers the last insert op in +/// the chain. +/// +/// Example transformation: +/// %poison = ub.poison : vector<2xi32> +/// %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32> +/// %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32> +/// -> +/// %result = vector.from_elements %c1, %c2 : vector<2xi32> +class InsertChainFullyInitialized final : public OpRewritePattern<InsertOp> { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(InsertOp op, + PatternRewriter &rewriter) const override { + + VectorType destTy = op.getDestVectorType(); + if (destTy.isScalable()) + return failure(); + // Ensure this is the trailing vector.insert op in a chain of inserts. + for (Operation *user : op.getResult().getUsers()) + if (auto insertOp = dyn_cast<InsertOp>(user)) + if (insertOp.getDest() == op.getResult()) + return failure(); + + InsertOp currentOp = op; + SmallVector<InsertOp> chainInsertOps; + while (currentOp) { + // Check cond 1: Dynamic position is not supported. + if (currentOp.hasDynamicPosition()) + return failure(); + + chainInsertOps.push_back(currentOp); + currentOp = currentOp.getDest().getDefiningOp<InsertOp>(); + // Check cond 3: Intermediate inserts have only one use to avoid an + // explosion of vectors. + if (currentOp && !currentOp->hasOneUse()) + return failure(); + } + + int64_t vectorSize = destTy.getNumElements(); + int64_t initializedCount = 0; + SmallVector<bool> initializedDestIdxs(vectorSize, false); + SmallVector<int64_t> pendingInsertPos; + SmallVector<int64_t> pendingInsertSize; + SmallVector<Value> pendingInsertValues; + + for (auto insertOp : chainInsertOps) { + // This pattern can do nothing with poison index. + if (is_contained(insertOp.getStaticPosition(), InsertOp::kPoisonIndex)) + return failure(); + + // Calculate the linearized position for inserting elements. + int64_t insertBeginPosition = + calculateInsertPosition(destTy, insertOp.getStaticPosition()); + + // The valueToStore operand may be a vector or a scalar. Need to handle + // both cases. 
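+      // (Illustration: a scalar source occupies a single linearized slot, so +      // insertSize stays 1, while a vector<3xf32> source occupies three +      // consecutive slots starting at insertBeginPosition.)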
+ int64_t insertSize = 1; + if (auto srcVectorType = + llvm::dyn_cast<VectorType>(insertOp.getValueToStoreType())) + insertSize = srcVectorType.getNumElements(); + + assert(insertBeginPosition + insertSize <= vectorSize && + "insert would overflow the vector"); + + for (auto index : llvm::seq<int64_t>(insertBeginPosition, + insertBeginPosition + insertSize)) { + if (initializedDestIdxs[index]) + continue; + initializedDestIdxs[index] = true; + ++initializedCount; + } + + // Defer the creation of ops before we can make sure the pattern can + // succeed. + pendingInsertPos.push_back(insertBeginPosition); + pendingInsertSize.push_back(insertSize); + pendingInsertValues.push_back(insertOp.getValueToStore()); + + if (initializedCount == vectorSize) + break; + } + + // Check cond 2: all positions must be initialized. + if (initializedCount != vectorSize) + return failure(); + + SmallVector<Value> elements(vectorSize); + for (auto [insertBeginPosition, insertSize, valueToStore] : + llvm::reverse(llvm::zip(pendingInsertPos, pendingInsertSize, + pendingInsertValues))) { + auto srcVectorType = llvm::dyn_cast<VectorType>(valueToStore.getType()); + + if (!srcVectorType) { + elements[insertBeginPosition] = valueToStore; + continue; + } + + SmallVector<Type> elementToInsertTypes(insertSize, + srcVectorType.getElementType()); + // Get all elements from the vector in row-major order. + auto elementsToInsert = rewriter.create<vector::ToElementsOp>( + op.getLoc(), elementToInsertTypes, valueToStore); + for (int64_t linearIdx = 0; linearIdx < insertSize; linearIdx++) { + elements[insertBeginPosition + linearIdx] = + elementsToInsert.getResult(linearIdx); + } + } + + rewriter.replaceOpWithNewOp<vector::FromElementsOp>(op, destTy, elements); + return success(); + } +}; + } // namespace static Attribute @@ -3339,13 +3491,9 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr, !insertOp->hasOneUse()) return {}; - // Calculate the linearized position of the continuous chunk of elements to - // insert. - llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0); - copy(insertOp.getStaticPosition(), completePositions.begin()); + // Calculate the linearized position for inserting elements. 
int64_t insertBeginPosition = - linearize(completePositions, computeStrides(destTy.getShape())); - + calculateInsertPosition(destTy, insertOp.getStaticPosition()); SmallVector<Attribute> insertedValues; Type destEltType = destTy.getElementType(); @@ -3381,7 +3529,8 @@ static Value foldInsertUseChain(InsertOp insertOp) { void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context); + results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat, + InsertChainFullyInitialized>(context); } OpFoldResult InsertOp::fold(FoldAdaptor adaptor) { @@ -5637,7 +5786,7 @@ LogicalResult GatherOp::verify() { if (resVType.getElementType() != baseType.getElementType()) return emitOpError("base and result element type should match"); - if (llvm::size(getIndices()) != baseType.getRank()) + if (llvm::size(getOffsets()) != baseType.getRank()) return emitOpError("requires ") << baseType.getRank() << " indices"; if (resVType.getShape() != indVType.getShape()) return emitOpError("expected result dim to match indices dim"); @@ -5709,11 +5858,11 @@ public: if (!isa<MemRefType>(op.getBase().getType())) return rewriter.notifyMatchFailure(op, "base must be of memref type"); - if (failed(isZeroBasedContiguousSeq(op.getIndexVec()))) + if (failed(isZeroBasedContiguousSeq(op.getIndices()))) return failure(); rewriter.replaceOpWithNewOp<MaskedLoadOp>(op, op.getType(), op.getBase(), - op.getIndices(), op.getMask(), + op.getOffsets(), op.getMask(), op.getPassThru()); return success(); } @@ -5737,7 +5886,7 @@ LogicalResult ScatterOp::verify() { if (valueVType.getElementType() != memType.getElementType()) return emitOpError("base and valueToStore element type should match"); - if (llvm::size(getIndices()) != memType.getRank()) + if (llvm::size(getOffsets()) != memType.getRank()) return emitOpError("requires ") << memType.getRank() << " indices"; if (valueVType.getShape() != indVType.getShape()) return emitOpError("expected valueToStore dim to match indices dim"); @@ -5772,11 +5921,11 @@ public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ScatterOp op, PatternRewriter &rewriter) const override { - if (failed(isZeroBasedContiguousSeq(op.getIndexVec()))) + if (failed(isZeroBasedContiguousSeq(op.getIndices()))) return failure(); rewriter.replaceOpWithNewOp<MaskedStoreOp>( - op, op.getBase(), op.getIndices(), op.getMask(), op.getValueToStore()); + op, op.getBase(), op.getOffsets(), op.getMask(), op.getValueToStore()); return success(); } }; diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 2d5cc07..fe066dc 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns( vector::populateVectorGatherLoweringPatterns(patterns); } +void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorFromElementsLoweringPatterns(patterns); +} + void transform::ApplyLowerScanPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorScanLoweringPatterns(patterns); diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 6619619..546099c 100644 --- 
a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -162,7 +162,7 @@ struct GatherOpInterface return failure(); replaceOpWithNewBufferizedOp<vector::GatherOp>( rewriter, gatherOp, gatherOp.getVectorType(), *buffer, - gatherOp.getIndices(), gatherOp.getIndexVec(), gatherOp.getMask(), + gatherOp.getOffsets(), gatherOp.getIndices(), gatherOp.getMask(), gatherOp.getPassThru()); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt index 9e287fc..acbf2b7 100644 --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms LowerVectorBitCast.cpp LowerVectorBroadcast.cpp LowerVectorContract.cpp + LowerVectorFromElements.cpp LowerVectorGather.cpp LowerVectorInterleave.cpp LowerVectorMask.cpp diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp new file mode 100644 index 0000000..c22fd54 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp @@ -0,0 +1,65 @@ +//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements target-independent rewrites and utilities to lower the +// 'vector.from_elements' operation. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" + +#define DEBUG_TYPE "lower-vector-from-elements" + +using namespace mlir; + +namespace { + +/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the +/// outermost dimension. For example: +/// ``` +/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32> +/// +/// ==> +/// +/// %0 = ub.poison : vector<2x3xf32> +/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32> +/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32> +/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32> +/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32> +/// ``` +/// +/// When applied exhaustively, this will produce a sequence of 1-d from_elements +/// ops. 
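+/// The loop structure itself is produced by the `vector::unrollVectorOp` +/// helper added by this change in VectorUtils.cpp (see further below): it +/// creates the initial `ub.poison` value, invokes a callback once per +/// outermost index to build each sub-vector, and chains the pieces together +/// with `vector.insert`.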
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(vector::FromElementsOp op, + PatternRewriter &rewriter) const override { + ValueRange allElements = op.getElements(); + + auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + size_t subTyNumElements = subTy.getNumElements(); + assert((index + 1) * subTyNumElements <= allElements.size() && + "out of bounds"); + ValueRange subElements = + allElements.slice(index * subTyNumElements, subTyNumElements); + return vector::FromElementsOp::create(rewriter, loc, subTy, subElements); + }; + + return unrollVectorOp(op, rewriter, unrollFromElementsFn); + } +}; + +} // namespace + +void mlir::vector::populateVectorFromElementsLoweringPatterns( + RewritePatternSet &patterns, PatternBenefit benefit) { + patterns.add<UnrollFromElements>(patterns.getContext(), benefit); +} diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp index e062f55..9830189 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp @@ -54,27 +54,13 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> { LogicalResult matchAndRewrite(vector::GatherOp op, PatternRewriter &rewriter) const override { - VectorType resultTy = op.getType(); - if (resultTy.getRank() < 2) - return rewriter.notifyMatchFailure(op, "already 1-D"); - - // Unrolling doesn't take vscale into account. Pattern is disabled for - // vectors with leading scalable dim(s). - if (resultTy.getScalableDims().front()) - return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); - - Location loc = op.getLoc(); - Value indexVec = op.getIndexVec(); + Value indexVec = op.getIndices(); Value maskVec = op.getMask(); Value passThruVec = op.getPassThru(); - Value result = arith::ConstantOp::create(rewriter, loc, resultTy, - rewriter.getZeroAttr(resultTy)); - - VectorType subTy = VectorType::Builder(resultTy).dropDim(0); - - for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { - int64_t thisIdx[1] = {i}; + auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc, + VectorType subTy, int64_t index) { + int64_t thisIdx[1] = {index}; Value indexSubVec = vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx); @@ -82,15 +68,12 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> { vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx); Value passThruSubVec = vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx); - Value subGather = vector::GatherOp::create( - rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec, - maskSubVec, passThruSubVec); - result = - vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx); - } + return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(), + op.getOffsets(), indexSubVec, maskSubVec, + passThruSubVec); + }; - rewriter.replaceOp(op, result); - return success(); + return unrollVectorOp(op, rewriter, unrollGatherFn); } }; @@ -158,18 +141,18 @@ struct RemoveStrideFromGatherSource : OpRewritePattern<vector::GatherOp> { // 2. Generate new gather indices that will model the // strided access. 
IntegerAttr stride = rewriter.getIndexAttr(srcTrailingDim); - VectorType vType = op.getIndexVec().getType(); + VectorType vType = op.getIndices().getType(); Value mulCst = arith::ConstantOp::create( rewriter, op.getLoc(), vType, DenseElementsAttr::get(vType, stride)); Value newIdxs = - arith::MulIOp::create(rewriter, op.getLoc(), op.getIndexVec(), mulCst); + arith::MulIOp::create(rewriter, op.getLoc(), op.getIndices(), mulCst); // 3. Create an updated gather op with the collapsed input memref and the // updated indices. Value newGather = vector::GatherOp::create( rewriter, op.getLoc(), op.getResult().getType(), collapsed, - op.getIndices(), newIdxs, op.getMask(), op.getPassThru()); + op.getOffsets(), newIdxs, op.getMask(), op.getPassThru()); rewriter.replaceOp(op, newGather); return success(); @@ -212,8 +195,8 @@ struct Gather1DToConditionalLoads : OpRewritePattern<vector::GatherOp> { Value indexVec = rewriter.createOrFold<arith::IndexCastOp>( loc, op.getIndexVectorType().clone(rewriter.getIndexType()), - op.getIndexVec()); - auto baseOffsets = llvm::to_vector(op.getIndices()); + op.getIndices()); + auto baseOffsets = llvm::to_vector(op.getOffsets()); Value lastBaseOffset = baseOffsets.back(); Value result = op.getPassThru(); diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp index 45ef7f0..5617b06 100644 --- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp @@ -269,7 +269,7 @@ public: // Replace the `vector.mask` operation. rewriter.replaceOpWithNewOp<GatherOp>( maskingOp.getOperation(), gatherOp.getVectorType(), gatherOp.getBase(), - gatherOp.getIndices(), gatherOp.getIndexVec(), maskingOp.getMask(), + gatherOp.getOffsets(), gatherOp.getIndices(), maskingOp.getMask(), passthru); return success(); } diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index bb0f339..c84eb2c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -528,8 +528,7 @@ struct WarpOpTransferWrite : public WarpDistributionPattern { LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); Operation *lastNode = yield->getPrevNode(); auto writeOp = dyn_cast_or_null<vector::TransferWriteOp>(lastNode); if (!writeOp) @@ -706,6 +705,52 @@ struct WarpOpConstant : public WarpDistributionPattern { } }; +/// Sink out step op feeding into a warp op yield. +/// Vector step op is treated similarly to arith.constant, apart from +/// the result that represents a sequence [0, vec_size). +/// Due to the vec_size == warp_size limitation, +/// we can simply wrap the lane id into a vector (i.e., broadcast). +/// Supporting vec_size != warp_size may involve preserving the step +/// result and using additional arith ops (the exact details are TBD). +/// ``` +/// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xindex>) { +/// ... +/// %cst = vector.step : vector<32xindex> +/// gpu.yield %cst : vector<32xindex> +/// } +/// ``` +/// To +/// ``` +/// gpu.warp_execute_on_lane_0(%arg0) { +/// ... 
+/// } +/// %lane_id_vec = vector.broadcast %arg0 : index to vector<1xindex> +struct WarpOpStep final : public WarpDistributionPattern { + using Base::Base; + LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, + PatternRewriter &rewriter) const override { + OpOperand *yieldOperand = + getWarpResult(warpOp, llvm::IsaPred<vector::StepOp>); + if (!yieldOperand) + return failure(); + const unsigned operandIdx = yieldOperand->getOperandNumber(); + auto stepOp = yieldOperand->get().getDefiningOp<vector::StepOp>(); + VectorType resTy = stepOp.getResult().getType(); + if (resTy.getNumElements() != static_cast<int64_t>(warpOp.getWarpSize())) + return rewriter.notifyMatchFailure( + warpOp, + llvm::formatv("Expected result size ({0}) to be of warp size ({1})", + resTy.getNumElements(), warpOp.getWarpSize())); + VectorType newVecTy = + cast<VectorType>(warpOp.getResult(operandIdx).getType()); + rewriter.setInsertionPointAfter(warpOp); + Value laneIdVec = vector::BroadcastOp::create(rewriter, warpOp.getLoc(), + newVecTy, warpOp.getLaneid()); + rewriter.replaceAllUsesWith(warpOp.getResult(operandIdx), laneIdVec); + return success(); + } +}; + /// Sink out transfer_read op feeding into a warp op yield. /// ``` /// %0 = gpu.warp_execute_on_lane_0(%arg0) -> (vector<1xf32>) { @@ -846,8 +891,7 @@ struct WarpOpDeadResult : public WarpDistributionPattern { newYieldValues.reserve(warpOp->getNumResults()); DenseMap<Value, int64_t> dedupYieldOperandPositionMap; DenseMap<OpResult, int64_t> dedupResultPositionMap; - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); // Some values may be yielded multiple times and correspond to multiple // results. Deduplicating occurs by taking each result with its matching @@ -901,8 +945,7 @@ struct WarpOpForwardOperand : public WarpDistributionPattern { using Base::Base; LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); Value valForwarded; unsigned resultIndex; for (OpOperand &operand : yield->getOpOperands()) { @@ -1708,8 +1751,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern { : WarpDistributionPattern(ctx, b), distributionMapFn(std::move(fn)) {} LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto warpOpYield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp warpOpYield = warpOp.getTerminator(); // Only pick up `ForOp` if it is the last op in the region. Operation *lastNode = warpOpYield->getPrevNode(); auto forOp = dyn_cast_or_null<scf::ForOp>(lastNode); @@ -1826,7 +1868,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern { rewriter.setInsertionPointAfter(newWarpOp); auto newForOp = scf::ForOp::create( rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), - forOp.getStep(), newForOpOperands); + forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr, + forOp.getUnsignedCmp()); // Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the // newly created `ForOp`. This `WarpOp` will contain all ops that were // contained within the original `ForOp` body. 
@@ -2019,7 +2062,7 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns( .add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast, WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant, WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask, - WarpOpExtractStridedSlice, WarpOpInsertStridedSlice>( + WarpOpExtractStridedSlice, WarpOpInsertStridedSlice, WarpOpStep>( patterns.getContext(), benefit); patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn, benefit); diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp index 491b448..7dde631 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp @@ -762,6 +762,42 @@ struct LinearizeVectorStore final } }; +/// This pattern linearizes `vector.from_elements` operations by converting +/// the result type to a 1-D vector while preserving all element values. +/// The transformation creates a linearized `vector.from_elements` followed by +/// a `vector.shape_cast` to restore the original multidimensional shape. +/// +/// Example: +/// +/// %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32> +/// +/// is converted to: +/// +/// %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32> +/// %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32> +/// +struct LinearizeVectorFromElements final + : public OpConversionPattern<vector::FromElementsOp> { + using OpConversionPattern::OpConversionPattern; + LinearizeVectorFromElements(const TypeConverter &typeConverter, + MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(typeConverter, context, benefit) {} + LogicalResult + matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + VectorType dstTy = + getTypeConverter()->convertType<VectorType>(fromElementsOp.getType()); + assert(dstTy && "vector type destination expected."); + + OperandRange elements = fromElementsOp.getElements(); + assert(elements.size() == static_cast<size_t>(dstTy.getNumElements()) && + "expected same number of elements"); + rewriter.replaceOpWithNewOp<vector::FromElementsOp>(fromElementsOp, dstTy, + elements); + return success(); + } +}; + } // namespace /// This method defines the set of operations that are linearizable, and hence @@ -854,7 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns( patterns .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast, LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad, - LinearizeVectorStore>(typeConverter, patterns.getContext()); + LinearizeVectorStore, LinearizeVectorFromElements>( + typeConverter, patterns.getContext()); } void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns( diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp index c707f38..369857f 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp @@ -98,8 +98,9 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) { // If the user has already been processed skip. 
if (!processed.insert(user).second) continue; - if (isa<ViewLikeOpInterface>(user)) { - users.append(user->getUsers().begin(), user->getUsers().end()); + if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) { + Value viewDest = viewLike.getViewDest(); + users.append(viewDest.getUsers().begin(), viewDest.getUsers().end()); continue; } if (isMemoryEffectFree(user)) @@ -182,8 +183,9 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) { // If the user has already been processed skip. if (!processed.insert(user).second) continue; - if (isa<ViewLikeOpInterface>(user)) { - users.append(user->getUsers().begin(), user->getUsers().end()); + if (auto viewLike = dyn_cast<ViewLikeOpInterface>(user)) { + Value viewDest = viewLike.getViewDest(); + users.append(viewDest.getUsers().begin(), viewDest.getUsers().end()); continue; } if (isMemoryEffectFree(user) || isa<vector::TransferReadOp>(user)) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index 2269a40..dbb5eb3 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -600,7 +600,7 @@ struct BubbleDownVectorBitCastForExtract // Get the first element of the mixed position as integer. auto mixedPos = extractOp.getMixedPosition(); - if (mixedPos.size() > 0 && !isa<Attribute>(mixedPos[0])) + if (!mixedPos.empty() && !isa<Attribute>(mixedPos[0])) return failure(); uint64_t index = cast<IntegerAttr>(cast<Attribute>(mixedPos[0])).getInt(); @@ -2274,7 +2274,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> { LogicalResult matchAndRewrite(MulOpType mulOp, PatternRewriter &rewriter) const override { - auto resType = llvm::cast<VectorType>(mulOp.getResult().getType()); + auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType()); if (!resType) return failure(); if (resType.getRank() != 2) diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp index 501abec..e8ecb0c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp @@ -640,7 +640,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> { // decomposed shape from each of the index, mask, and pass-through // vectors. 
Value indexSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( - loc, gatherOp.getIndexVec(), elementOffsets, *targetShape, strides); + loc, gatherOp.getIndices(), elementOffsets, *targetShape, strides); Value maskSubVec = rewriter.createOrFold<vector::ExtractStridedSliceOp>( loc, gatherOp.getMask(), elementOffsets, *targetShape, strides); Value passThruSubVec = @@ -648,7 +648,7 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> { loc, gatherOp.getPassThru(), elementOffsets, *targetShape, strides); auto slicedGather = vector::GatherOp::create( - rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getIndices(), + rewriter, loc, targetType, gatherOp.getBase(), gatherOp.getOffsets(), indexSubVec, maskSubVec, passThruSubVec); result = rewriter.createOrFold<vector::InsertStridedSliceOp>( diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 6e2fa35..841e138 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -392,3 +392,29 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape, } return success(); } + +LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter, + vector::UnrollVectorOpFn unrollFn) { + assert(op->getNumResults() == 1 && "expected single result"); + assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type"); + VectorType resultTy = cast<VectorType>(op->getResult(0).getType()); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(op, "already 1-D"); + + // Unrolling doesn't take vscale into account. Pattern is disabled for + // vectors with leading scalable dim(s). + if (resultTy.getScalableDims().front()) + return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim"); + + Location loc = op->getLoc(); + Value result = ub::PoisonOp::create(rewriter, loc, resultTy); + VectorType subTy = VectorType::Builder(resultTy).dropDim(0); + + for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) { + Value subVector = unrollFn(rewriter, loc, subTy, i); + result = vector::InsertOp::create(rewriter, loc, subVector, result, i); + } + + rewriter.replaceOp(op, result); + return success(); +} diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt index 7c6a4f3..7869a28 100644 --- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt @@ -17,6 +17,8 @@ add_mlir_dialect_library(MLIRXeGPUDialect MLIRAffineUtils MLIRArithUtils MLIRDialectUtils + MLIRGPUDialect + MLIRXeVMDialect MLIRIR MLIRViewLikeInterface MLIRVectorDialect diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp index d997296..7f3be7f 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp @@ -67,7 +67,7 @@ genOffsetsComputingInsts(OpBuilder &builder, Location loc, StaticTileOffsetRange(sizePerWg, distUnit)) { SmallVector<Value> base = llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value { - return builder.create<arith::ConstantIndexOp>(loc, d); + return arith::ConstantIndexOp::create(builder, loc, d); }); SmallVector<Value> adds = llvm::map_to_vector( @@ -80,7 +80,7 @@ genOffsetsComputingInsts(OpBuilder &builder, Location loc, llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value { return builder.createOrFold<index::RemUOp>( loc, std::get<0>(t), - builder.create<arith::ConstantIndexOp>(loc, std::get<1>(t))); + 
arith::ConstantIndexOp::create(builder, loc, std::get<1>(t))); }); offsets.push_back(mods); @@ -91,7 +91,7 @@ genOffsetsComputingInsts(OpBuilder &builder, Location loc, // Checks if the given shape can be evenly distributed based on the layout // and data factors provided by the LayoutAttr. bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, - xegpu::LayoutAttr attr) { + xegpu::DistributeLayoutAttr attr) { assert(attr && "Layout attribute is missing."); // Checks whether the given shape can be evenly distributed using the @@ -104,52 +104,51 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, // smaller than `layout[i] * data[i]`, allowing multiple compute units to // share the data. auto tryDistribute = [&](llvm::ArrayRef<int64_t> shape, - DenseI32ArrayAttr layout, DenseI32ArrayAttr data, + SmallVector<int64_t> layout, + SmallVector<int64_t> data, bool rr = true) -> optional<SmallVector<int64_t>> { llvm::SmallVector<int64_t> newShape(shape); - if (layout) { - auto vec = llvm::to_vector_of<int64_t>(layout.asArrayRef()); - if (vec.size() != shape.size()) + if (layout.size()) { + if (layout.size() != shape.size()) return std::nullopt; - auto ratio = computeShapeRatio(shape, vec); + auto ratio = computeShapeRatio(shape, layout); if (!ratio.has_value()) return std::nullopt; newShape = ratio.value(); } - if (data) { - auto vec = llvm::to_vector_of<int64_t>(data.asArrayRef()); - if (vec.size() != shape.size()) + if (data.size()) { + if (data.size() != shape.size()) return std::nullopt; - auto ratio = computeShapeRatio(newShape, vec); + auto ratio = computeShapeRatio(newShape, data); if (!ratio.has_value() && rr) - ratio = computeShapeRatio(vec, newShape); + ratio = computeShapeRatio(data, newShape); if (!ratio.has_value()) return std::nullopt; // if data is not null, we always return it for next phase. 
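+ // (Illustration: for shape = [32, 64], layout = [2, 2], data = [8, 16], the + // layout ratio yields [16, 32]; since [16, 32] divides evenly by [8, 16], + // the data shape [8, 16] is what gets returned for the next phase.)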
- newShape = vec; + newShape = data; } return newShape; }; // check the sgLayout and sgData auto maybeSgShape = - tryDistribute(shape, attr.getSgLayout(), attr.getSgData()); + tryDistribute(shape, attr.getSgLayoutAsInt(), attr.getSgDataAsInt()); if (!maybeSgShape) return false; auto sgShape = maybeSgShape.value(); // check InstData, it neither have layout nor need round-robin auto maybeInstShape = - tryDistribute(sgShape, nullptr, attr.getInstData(), false); + tryDistribute(sgShape, {}, attr.getInstDataAsInt(), false); if (!maybeInstShape) return false; auto instShape = maybeInstShape.value(); // check LaneLayout and LaneData - auto maybeLaneShape = - tryDistribute(instShape, attr.getLaneLayout(), attr.getLaneData(), false); + auto maybeLaneShape = tryDistribute(instShape, attr.getLaneLayoutAsInt(), + attr.getLaneDataAsInt(), false); return maybeLaneShape.has_value(); } @@ -271,7 +270,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId) { // delinearizeSubgroupId is only available for // workgroup-level layout attribute - if (!isWgLayout()) + if (!isForWorkgroup()) return failure(); // TODO: handle order attribute @@ -283,29 +282,30 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, if (!hasDefaultOrder()) return mlir::emitError(loc, "order attribute is currently not supported."); - auto dims = llvm::map_to_vector(*getSgLayoutAsInt(), [&](int64_t d) -> Value { + auto dims = llvm::map_to_vector(getSgLayoutAsInt(), [&](int64_t d) -> Value { return builder.createOrFold<arith::ConstantIndexOp>(loc, d); }); return affine::delinearizeIndex(builder, loc, linearId, dims); } -/// Implements LayoutTrait::getOffsets to generate instructions for -/// computing multi-dimensional offsets when distributed by LayoutAttr. +/// Implements DistributeLayoutAttr::getOffsets to generate +/// instructions for computing multi-dimensional offsets when distributed by +/// LayoutAttr. 
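+/// For example (illustrative values): with sg_layout = [2, 4], no explicit +/// sg_data, and a workgroup-level shape of 32x64, the per-subgroup shape is +/// derived as computeShapeRatio([32, 64], [2, 4]) = [16, 16].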
FailureOr<SmallVector<SmallVector<Value>>> LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape) { - if (!isWgLayout()) + if (!isForWorkgroup()) return failure(); - SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value(); - SmallVector<int64_t> sgShape; - if (auto maybeSgShape = getSgDataAsInt()) - sgShape = maybeSgShape.value(); - else if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); - else - return failure(); + SmallVector<int64_t> sgLayout = getSgLayoutAsInt(); + SmallVector<int64_t> sgShape = getSgDataAsInt(); + if (sgShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, sgLayout)) + sgShape = derivedShape.value(); + else + return failure(); + } // delinearize Ids auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); @@ -322,7 +322,7 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, //===----------------------------------------------------------------------===// LogicalResult SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError, - xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) { + xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) { if (!parent || !dims) return emitError() << "expected parent layout and dims attribute"; @@ -340,7 +340,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError, } SliceAttr SliceAttr::flatten() const { - xegpu::LayoutTrait parent = getParent(); + xegpu::DistributeLayoutAttr parent = getParent(); SmallVector<DenseI64ArrayAttr> slicedDims({getDims()}); while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) { @@ -375,23 +375,24 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc, return parent.delinearizeSubgroupId(builder, loc, linearId); } -/// Implements LayoutTrait::getOffsets to generate instructions for -/// computing multi-dimensional offsets when distributed by SliceAttr. +/// Implements DistributeLayoutAttr::getOffsets to generate +/// instructions for computing multi-dimensional offsets when distributed by +/// SliceAttr. 
FailureOr<SmallVector<SmallVector<Value>>> SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape) { assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape."); - if (!isWgLayout()) + if (!isForWorkgroup()) return failure(); - SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value(); - SmallVector<int64_t> sgShape; - if (auto maybeSgShape = getSgDataAsInt()) - sgShape = maybeSgShape.value(); - else if (auto derivedShape = computeShapeRatio(shape, sgLayout)) - sgShape = derivedShape.value(); - else - return failure(); + SmallVector<int64_t> sgLayout = getSgLayoutAsInt(); + SmallVector<int64_t> sgShape = getSgDataAsInt(); + if (sgShape.empty()) { + if (auto derivedShape = computeShapeRatio(shape, sgLayout)) + sgShape = derivedShape.value(); + else + return failure(); + } // delinearize Ids auto maybeIds = delinearizeSubgroupId(builder, loc, linearId); @@ -427,7 +428,7 @@ RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError, // XeGPU_TensorDescType //===----------------------------------------------------------------------===// -mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { +mlir::Type TensorDescType::parse(AsmParser &parser) { llvm::SmallVector<int64_t> shape; mlir::Type elementType; mlir::FailureOr<mlir::Attribute> encoding; @@ -477,7 +478,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) { layout.value_or(mlir::Attribute())); } -void TensorDescType::print(::mlir::AsmPrinter &printer) const { +void TensorDescType::print(AsmPrinter &printer) const { printer << "<"; auto shape = getShape(); @@ -522,10 +523,10 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape, return Base::get(context, shape, elementType, attr, layout); } -LogicalResult TensorDescType::verify( - llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, - llvm::ArrayRef<int64_t> shape, mlir::Type elementType, - mlir::Attribute encoding, mlir::Attribute layout) { +LogicalResult +TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError, + llvm::ArrayRef<int64_t> shape, mlir::Type elementType, + mlir::Attribute encoding, mlir::Attribute layout) { size_t rank = shape.size(); if (rank == 0) @@ -591,6 +592,119 @@ LogicalResult TensorDescType::verify( return success(); } +//===----------------------------------------------------------------------===// +// XeGPU_MemDescType +//===----------------------------------------------------------------------===// +mlir::Type MemDescType::parse(AsmParser &parser) { + llvm::SmallVector<int64_t> shape; + mlir::Type elementType; + mlir::FailureOr<MemLayoutAttr> layout; + + // Parse literal '<' + if (parser.parseLess()) + return {}; + + auto shapeLoc = parser.getCurrentLocation(); + if (mlir::failed(parser.parseDimensionList(shape, false, true))) { + parser.emitError(shapeLoc, "failed to parse parameter 'shape'"); + return {}; + } + + auto elemTypeLoc = parser.getCurrentLocation(); + if (mlir::failed(parser.parseType(elementType))) { + parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'"); + return {}; + } + + // parse optional attributes + if (mlir::succeeded(parser.parseOptionalComma())) { + MemLayoutAttr attr; + ParseResult res = parser.parseAttribute(attr); + if (mlir::failed(res)) + return {}; + layout = attr; + } + + // Parse literal '>' + if (parser.parseGreater()) + return {}; + + MLIRContext *ctxt = parser.getContext(); + return MemDescType::getChecked( + [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape, 
+ elementType, layout.value_or(MemLayoutAttr())); +} + +void MemDescType::print(AsmPrinter &printer) const { + printer << "<"; + + printer.printDimensionList(getShape()); + printer << 'x'; + printer << getElementType(); + + if (auto layout = getMemLayout()) + printer << ", " << layout; + + printer << ">"; +} + +//===----------------------------------------------------------------------===// +// XeGPU_MemDescType +//===----------------------------------------------------------------------===// + +Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) { + + auto context = parser.getContext(); + llvm::SMLoc loc = parser.getCurrentLocation(); + + llvm::SmallDenseSet<StringRef> seenKeys; + SmallVector<NamedAttribute> attributes; + + auto parseElt = [&]() -> ParseResult { + StringRef nameId; + if (failed(parser.parseKeyword(&nameId))) + return parser.emitError(loc, "expected valid attribute name"); + + if (!seenKeys.insert(nameId).second) + return parser.emitError(loc, "duplicate key '") + << nameId << " in mem layout attribute"; + + if (failed(parser.parseEqual())) + return failure(); + + Attribute attr; + if (failed(parser.parseAttribute(attr))) + return failure(); + attributes.emplace_back(nameId, attr); + return success(); + }; + + // Parse literal '<' + if (parser.parseLess()) + return {}; + + if (failed(parser.parseCommaSeparatedList(parseElt))) + return {}; + + // Parse literal '>' + if (parser.parseGreater()) + return {}; + + return parser.getChecked<MemLayoutAttr>( + loc, context, DictionaryAttr::get(context, attributes)); +} + +void MemLayoutAttr::print(AsmPrinter &printer) const { + printer << "<"; + ArrayRef<NamedAttribute> attrs = getAttrs().getValue(); + for (size_t i = 0; i < attrs.size(); i++) { + printer << attrs[i].getName().str() << " = " << attrs[i].getValue(); + if (i < attrs.size() - 1) + printer << ", "; + } + printer << ">"; +} + } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp index fc11fa8..aca6654 100644 --- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp +++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp @@ -7,6 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" @@ -21,6 +23,17 @@ namespace mlir { namespace xegpu { +bool isSharedMemory(const MemRefType &memrefTy) { + Attribute attr = memrefTy.getMemorySpace(); + if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr)) + return intAttr.getInt() == 3; + if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr)) + return memrefSpace.getValue() == MemorySpace::SLM; + if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr)) + return xevmSpace.getValue() == xevm::AddrSpace::SHARED; + return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr); +} + template <typename T> static std::string makeString(T array, bool breakline = false) { std::string buf; @@ -45,13 +58,6 @@ static SmallVector<int64_t> getShapeOf(Type type) { return shape; } -static int64_t getRankOf(Value val) { - auto type = val.getType(); - if (auto ty = llvm::dyn_cast<ShapedType>(type)) - return ty.getRank(); - return 0; -} - static bool isReadHintOrNone(const CachePolicyAttr &attr) { if (!attr) return true; @@ -76,13 +82,18 @@ isValidGatherScatterParams(Type maskTy, 
VectorType valueTy, if (!tdescTy.isScattered()) return emitError() << "Expects a scattered TensorDesc."; - if (!valueTy) - return emitError() << "Expecting a vector type result."; + auto chunkSize = tdescTy.getChunkSizeAsInt(); + if (!valueTy) { + if (chunkSize > 1) + return emitError() << "Expecting chunk size == 1 for scalar result"; + if (dyn_cast<VectorType>(maskTy)) + return emitError() << "Expecting a vector type result."; + return success(); + } auto maskShape = getShapeOf(maskTy); auto valueShape = getShapeOf(valueTy); auto tdescShape = getShapeOf(tdescTy); - auto chunkSize = tdescTy.getChunkSizeAsInt(); if (valueTy.getElementType() != tdescTy.getElementType()) return emitError() @@ -111,25 +122,49 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy, } static LogicalResult -isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy, - int64_t chunkSize, +isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy, + VectorType valueTy, int64_t chunkSize, function_ref<InFlightDiagnostic()> emitError) { - if (!valueTy) - return emitError() << "Expecting a vector type result."; + auto maskVecTy = dyn_cast<VectorType>(maskTy); + auto offsetsVecTy = dyn_cast<VectorType>(offsetsTy); + if (!valueTy) { + if (chunkSize > 1) + return emitError() << "Expecting chunk size == 1 for scalar result"; + if (maskVecTy || offsetsVecTy) + return emitError() << "Expecting scalar mask and offsets."; + else if (maskVecTy && offsetsVecTy) + return emitError() << "Expecting a vector type result."; + return success(); + } + auto valueSize = valueTy.getNumElements(); + // SIMT mode with scalar mask and offsets. + if (!maskVecTy && !offsetsVecTy) { + if (valueSize != chunkSize) + return emitError() << "value elements must match chunk size " + << chunkSize; + return success(); + } auto maskShape = getShapeOf(maskTy); auto valueShape = getShapeOf(valueTy); - // a valid shape for SIMT case - if (valueTy.getRank() == 1) { - if (valueTy.getNumElements() != chunkSize) - return emitError() << "value elements must match chunk size " << chunkSize - << " for SIMT code."; - return success(); + if (!maskVecTy) + return emitError() << "Expecting a vector type mask."; + int64_t maskSize = maskVecTy.getNumElements(); + + if (chunkSize > 1) { + if ((valueTy.getRank() == 1) && (valueSize != chunkSize)) + return emitError() << "value elements must match chunk size " + << chunkSize; + } else { + if (valueSize != maskSize) + return emitError() + << "Mask should match value except the chunk size dim."; } - llvm::SmallVector<int64_t> expectedMaskShape(valueShape); + if (maskSize == 1) + return success(); if (chunkSize > 1) expectedMaskShape.pop_back(); if (expectedMaskShape != maskShape) @@ -156,41 +191,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, } void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, - Type tdesc, TypedValue<MemRefType> source, + Type tdesc, Value source, llvm::ArrayRef<OpFoldResult> shape, llvm::ArrayRef<OpFoldResult> strides) { - assert(shape.size() && strides.size() && shape.size() == strides.size() && - "Shape and strides must be present and of equal size for ui64 " - "initialization."); + Type srcTy = source.getType(); + assert((isa<IntegerType, MemRefType>(srcTy)) && + "Source has to be either int or memref."); - llvm::SmallVector<int64_t> staticShape; - llvm::SmallVector<int64_t> staticStrides; llvm::SmallVector<Value> dynamicShape; llvm::SmallVector<Value> dynamicStrides; - dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); - 
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); - - auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); - auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); - - build(builder, state, tdesc, source, ValueRange({}), dynamicShape, - dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, - staticStridesAttr); -} - -void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, - Type tdesc, TypedValue<IntegerType> source, - llvm::ArrayRef<OpFoldResult> shape, - llvm::ArrayRef<OpFoldResult> strides) { - assert(shape.size() && strides.size() && shape.size() == strides.size() && - "Shape and strides must be present and of equal size for ui64 " - "initialization."); - llvm::SmallVector<int64_t> staticShape; llvm::SmallVector<int64_t> staticStrides; - llvm::SmallVector<Value> dynamicShape; - llvm::SmallVector<Value> dynamicStrides; dispatchIndexOpFoldResults(shape, dynamicShape, staticShape); dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides); @@ -198,6 +210,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape); auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides); + if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) { + auto memrefShape = memrefTy.getShape(); + auto [memrefStrides, _] = memrefTy.getStridesAndOffset(); + + // if shape and strides are from Memref, we don't need attributes for them + // to keep the IR print clean. + if (staticShape == memrefShape && staticStrides == memrefStrides) { + staticShapeAttr = DenseI64ArrayAttr(); + staticStridesAttr = DenseI64ArrayAttr(); + } + } + build(builder, state, tdesc, source, ValueRange({}), dynamicShape, dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr, staticStridesAttr); @@ -265,8 +289,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state, } LogicalResult CreateNdDescOp::verify() { - auto rank = (int64_t)getMixedOffsets().size(); - bool invalidRank = false; + size_t rank = getMixedSizes().size(); + bool invalidRank = rank != getMixedStrides().size(); bool invalidElemTy = false; // Memory space of created TensorDesc should match with the source. @@ -280,31 +304,28 @@ LogicalResult CreateNdDescOp::verify() { << " Source: " << srcMemorySpace << ", TensorDesc: " << tdescMemorySpace; + if (size_t offsetRank = getMixedOffsets().size()) + invalidRank |= (offsetRank != rank); + // check source type matches the rank if it is a memref. // It also should have the same ElementType as TensorDesc. - auto memrefTy = dyn_cast<MemRefType>(getSourceType()); - if (memrefTy) { - invalidRank |= (memrefTy.getRank() != rank); + if (auto memrefTy = dyn_cast<MemRefType>(getSourceType())) invalidElemTy |= memrefTy.getElementType() != getElementType(); - } if (llvm::isa<IntegerType>(getSourceType())) { // strides and shape must present for integer source. if (getMixedStrides().empty() || getMixedSizes().empty()) - return emitOpError("Expecting strides and shape to be present for " + return emitOpError("expecting strides and shape to be present for " "integer source."); } - // mismatches among shape, strides, and offsets are - // already handeled by OffsetSizeAndStrideOpInterface. - // So they are not check here. 
if (invalidRank) return emitOpError( "Expecting the rank of shape, strides, offsets, and source (if source " "is a memref) should match with each other."); // check result TensorDesc rank - if (getType().getRank() > rank) + if (getType().getRank() > (int64_t)rank) return emitOpError( "Expecting the TensorDesc rank is not greater than the " "ranks of shape, strides, offsets or the memref source."); @@ -360,13 +381,10 @@ ParseResult parseOptionalDynamicIndexList( void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op, OperandRange values, DenseI64ArrayAttr integers) { - - if (!integers) + if (!integers || integers.empty()) return; - - return printDynamicIndexList(printer, op, values, integers, - /*scalableFlags=*/{}, {}, - AsmParser::Delimiter::Square); + printDynamicIndexList(printer, op, values, integers, + /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square); } //===----------------------------------------------------------------------===// // XeGPU_PrefetchNdOp @@ -381,6 +399,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, l1_hint, l2_hint, l3_hint); } +void PrefetchNdOp::build(OpBuilder &builder, OperationState &state, + Value tensorDesc, ArrayRef<OpFoldResult> offsets, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + SmallVector<Value> dynamicOffsets; + SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + + build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint, + l2_hint, l3_hint); +} + LogicalResult PrefetchNdOp::verify() { auto tdescTy = getTensorDescType(); if (tdescTy.isScattered()) @@ -423,6 +456,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, l3_hint); } +void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType, + Value tensorDesc, ArrayRef<OpFoldResult> offsets, + UnitAttr packed, DenseI64ArrayAttr transpose, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + SmallVector<Value> dynamicOffsets; + SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + + build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr, + packed, transpose, l1_hint, l2_hint, l3_hint); +} + LogicalResult LoadNdOp::verify() { auto tdescTy = getTensorDescType(); auto valueTy = getType(); @@ -529,6 +578,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint); } +void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value, + Value tensorDesc, ArrayRef<OpFoldResult> offsets, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + SmallVector<Value> dynamicOffsets; + SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + + build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr, + l1_hint, l2_hint, l3_hint); +} + LogicalResult StoreNdOp::verify() { auto dstTy = getTensorDescType(); // Tile auto valTy = getValueType(); // Vector @@ -635,10 +699,6 @@ void CreateDescOp::build(OpBuilder &builder, OperationState &state, 
LogicalResult CreateDescOp::verify() { auto tdescTy = getTensorDescType(); - if (getRankOf(getSource()) > 1) - return emitOpError( - "Expecting the source is a 1D memref or pointer (uint64_t)."); - if (!tdescTy.isScattered()) return emitOpError("Expects a scattered TensorDesc.\n"); @@ -673,12 +733,14 @@ LogicalResult CreateDescOp::verify() { LogicalResult PrefetchOp::verify() { auto tdescTy = getTensorDescType(); - if (tdescTy && !tdescTy.isScattered()) - return emitOpError("Expects a scattered TensorDesc.\n"); + if (!tdescTy && !getOffsets()) + return emitOpError("Expects offsets."); - if (!tdescTy && getRankOf(getSource()) > 1) - return emitOpError( - "Expecting the source is a 1D memref or pointer (uint64_t)."); + if (tdescTy && getOffsets()) + return emitOpError("offsets not allowed."); + + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc."); if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -689,6 +751,13 @@ LogicalResult PrefetchOp::verify() { if (!isReadHintOrNone(getL3HintAttr())) return emitOpError("invalid l3_hint: ") << getL3HintAttr(); + auto srcTy = getSourceType(); + if (srcTy.isInteger() && !getOffsetAlignByteAttr()) + return emitOpError("offset_align_byte is required with integer source."); + + if (getOffsetAlignByteAttr() && !srcTy.isInteger()) + return emitOpError("offset_align_byte only allowed with integer source."); + return success(); } @@ -696,7 +765,8 @@ void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source, xegpu::CachePolicyAttr l1_hint, xegpu::CachePolicyAttr l2_hint, xegpu::CachePolicyAttr l3_hint) { - build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint); + build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint, + IntegerAttr{}); } //===----------------------------------------------------------------------===// @@ -707,13 +777,15 @@ LogicalResult LoadGatherOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); + if (!tdescTy && !getOffsets()) + return emitOpError("Expects offsets."); + + if (tdescTy && getOffsets()) + return emitOpError("offsets not allowed."); + if (tdescTy && !tdescTy.isScattered()) return emitOpError("Expects a scattered TensorDesc."); - if (!tdescTy && getRankOf(getSource()) > 1) - return emitOpError( - "Expecting the source is a 1D memref or pointer (uint64_t)."); - if (!isReadHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -730,10 +802,11 @@ LogicalResult LoadGatherOp::verify() { uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); auto memTy = dyn_cast<MemRefType>(srcTy); - if (memTy && (valueTy.getElementType() != memTy.getElementType())) + if (memTy && (getElementType() != memTy.getElementType())) return emitError() << "Value should have the same element type as MemRef."; - return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + auto offsetsTy = getOffsets().getType(); + return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize, [&]() { return emitOpError(); }); } @@ -746,6 +819,22 @@ void LoadGatherOp::build(OpBuilder &builder, OperationState &state, l1_hint, l2_hint, l3_hint); } +void LoadGatherOp::build(OpBuilder &builder, OperationState &state, + Type valueType, Value source, + ArrayRef<OpFoldResult> offsets, Value mask, + IntegerAttr chunk_size, xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + auto loc = source.getLoc(); + 
int64_t size = static_cast<int64_t>(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + build(builder, state, valueType, source, offset, mask, chunk_size, l1_hint, + l2_hint, l3_hint); +} + //===----------------------------------------------------------------------===// // XeGPU_StoreScatterOp //===----------------------------------------------------------------------===// @@ -754,12 +843,14 @@ LogicalResult StoreScatterOp::verify() { auto maskTy = getMaskType(); auto valueTy = getValueType(); - if (tdescTy && !tdescTy.isScattered()) - return emitOpError("Expects a scattered TensorDesc.\n"); + if (!tdescTy && !getOffsets()) + return emitOpError("Expects offsets."); - if (!tdescTy && getRankOf(getDest()) > 1) - return emitOpError( - "Expecting the dest is a 1D memref or pointer (uint64_t)."); + if (tdescTy && getOffsets()) + return emitOpError("offsets not allowed."); + + if (tdescTy && !tdescTy.isScattered()) + return emitOpError("Expects a scattered TensorDesc."); if (!isWriteHintOrNone(getL1HintAttr())) return emitOpError("invalid l1_hint: ") << getL1HintAttr(); @@ -778,10 +869,11 @@ LogicalResult StoreScatterOp::verify() { uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1)); auto memTy = dyn_cast<MemRefType>(destTy); - if (memTy && (valueTy.getElementType() != memTy.getElementType())) + if (memTy && (getElementType() != memTy.getElementType())) return emitError() << "Value should have the same element type as MemRef."; - return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize, + auto offsetsTy = getOffsets().getType(); + return isValidGatherScatterBufferParams(offsetsTy, maskTy, valueTy, chunkSize, [&]() { return emitOpError(); }); } @@ -794,6 +886,24 @@ void StoreScatterOp::build(OpBuilder &builder, OperationState &state, l2_hint, l3_hint); } +void StoreScatterOp::build(OpBuilder &builder, OperationState &state, + Value value, Value dest, + ArrayRef<OpFoldResult> offsets, Value mask, + IntegerAttr chunk_size, + xegpu::CachePolicyAttr l1_hint, + xegpu::CachePolicyAttr l2_hint, + xegpu::CachePolicyAttr l3_hint) { + auto loc = dest.getLoc(); + int64_t size = static_cast<int64_t>(offsets.size()); + auto type = VectorType::get(size, builder.getIndexType()); + auto values = getValueOrCreateConstantIndexOp(builder, loc, offsets); + auto offset = vector::FromElementsOp::create(builder, loc, type, values); + + // Call the correct builder overload that does not expect result types. + build(builder, state, value, dest, offset, mask, chunk_size, l1_hint, l2_hint, + l3_hint); +} + //===----------------------------------------------------------------------===// // XeGPU_UpdateOffsetOp //===----------------------------------------------------------------------===// @@ -888,8 +998,8 @@ LogicalResult ConvertLayoutOp::verify() { // both input and target layouts should be WgLayout or SgLayout at the same // time. 
- if ((!srcLayout.isWgLayout() || !resLayout.isWgLayout()) && - (!srcLayout.isSgLayout() || !resLayout.isSgLayout())) + if ((!srcLayout.isForWorkgroup() || !resLayout.isForWorkgroup()) && + (!srcLayout.isForSubgroup() || !resLayout.isForSubgroup())) return emitOpError("expected input layout and target layout be WgLayout or " "SgLayout at the same time."); @@ -928,6 +1038,101 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, patterns.add<FoldConvertLayoutOp>(context); } +//===----------------------------------------------------------------------===// +// XeGPU_LoadMatrixOp +//===----------------------------------------------------------------------===// +void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res, + TypedValue<MemDescType> memDesc, + llvm::ArrayRef<OpFoldResult> offsets, + DistributeLayoutAttr layout) { + llvm::SmallVector<Value> dynamicOffsets; + llvm::SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr, + layout); +} + +LogicalResult LoadMatrixOp::verify() { + VectorType resTy = getRes().getType(); + MemDescType mdescTy = getMemDesc().getType(); + + if (mdescTy.getRank() != 2) + return emitOpError("mem_desc must be 2D."); + + ArrayRef<int64_t> valueShape = resTy.getShape(); + ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed mem_desc shape."); + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_StoreMatrixOp +//===----------------------------------------------------------------------===// +void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data, + TypedValue<MemDescType> memDesc, + llvm::ArrayRef<OpFoldResult> offsets, + DistributeLayoutAttr layout) { + llvm::SmallVector<Value> dynamicOffsets; + llvm::SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr, + layout); +} + +LogicalResult StoreMatrixOp::verify() { + VectorType dataTy = getData().getType(); + MemDescType mdescTy = getMemDesc().getType(); + + if (mdescTy.getRank() != 2) + return emitOpError("mem_desc must be 2D."); + + ArrayRef<int64_t> dataShape = dataTy.getShape(); + ArrayRef<int64_t> mdescShape = mdescTy.getShape(); + if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("data shape must not exceed mem_desc shape."); + + return success(); +} + +//===----------------------------------------------------------------------===// +// XeGPU_MemDescSubviewOp +//===----------------------------------------------------------------------===// + +void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state, + Type resTy, Value src, + llvm::ArrayRef<OpFoldResult> offsets) { + llvm::SmallVector<Value> dynamicOffsets; + llvm::SmallVector<int64_t> staticOffsets; + dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets); + auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets); + build(builder, state, resTy, src, 
dynamicOffsets, staticOffsetsAttr); +} + +LogicalResult MemDescSubviewOp::verify() { + MemDescType srcTy = getSrc().getType(); + MemDescType resTy = getRes().getType(); + ArrayRef<int64_t> srcShape = srcTy.getShape(); + ArrayRef<int64_t> resShape = resTy.getShape(); + + if (srcTy.getRank() < resTy.getRank()) + return emitOpError("result rank must not exceed source rank."); + + if (llvm::any_of( + llvm::zip_equal(resShape, srcShape.take_back(resShape.size())), + [](auto p) { return std::get<0>(p) > std::get<1>(p); })) + return emitOpError("result shape must not exceed source shape."); + + if (srcTy.getStrides() != resTy.getStrides()) + return emitOpError("result must inherit the source strides."); + + return success(); +} + } // namespace xegpu } // namespace mlir diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp index d82c541..9ee002e 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp @@ -84,9 +84,10 @@ struct ConvertLayoutOpPattern using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, PatternRewriter &rewriter) const override { - xegpu::LayoutAttr input_layout = op.getInputLayoutAttr(); - xegpu::LayoutAttr target_layout = op.getTargetLayoutAttr(); - if (!input_layout.getInstData() || !target_layout.getInstData()) + xegpu::DistributeLayoutAttr input_layout = op.getInputLayoutAttr(); + xegpu::DistributeLayoutAttr target_layout = op.getTargetLayoutAttr(); + if (input_layout.getInstDataAsInt().empty() || + target_layout.getInstDataAsInt().empty()) return rewriter.notifyMatchFailure(op, "Not a target ConvertLayoutOp."); input_layout = input_layout.dropInstData(); @@ -140,10 +141,11 @@ XeGPUBlockingPass::getTileShape(const T &operandOrResult) const { else value = (Value)operandOrResult; - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operandOrResult); - if (layout && layout.isSgLayout()) { - if (auto inst_data = layout.getInstData()) - return llvm::to_vector_of<int64_t>(inst_data.asArrayRef()); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(operandOrResult); + if (layout && layout.isForSubgroup()) { + if (!layout.getInstDataAsInt().empty()) + return layout.getInstDataAsInt(); if (auto type = dyn_cast<ShapedType>(value.getType())) return llvm::to_vector(type.getShape()); @@ -204,13 +206,15 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const { // skip the op if any of its operands or results has workgroup level layouts bool hasWgLayoutOperands = llvm::any_of(op->getOpOperands(), [](OpOperand &opr) { - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(opr); - return layout && layout.isWgLayout(); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(opr); + return layout && layout.isForWorkgroup(); }); bool hasWgLayoutResults = llvm::any_of(op->getOpResults(), [](OpResult result) { - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(result); - return layout && layout.isWgLayout(); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(result); + return layout && layout.isForWorkgroup(); }); if (hasWgLayoutOperands || hasWgLayoutResults) { LDBG() << "skip unrolling for op with workgroup level layout: " << *op; @@ -220,8 +224,8 @@ bool XeGPUBlockingPass::needsUnroll(Operation *op) const { auto isUnrollable = [](Value value, ArrayRef<int64_t> tileShape) { Type valTy = value.getType(); if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(valTy)) { - 
xegpu::LayoutAttr layout = tdescTy.getLayoutAttr(); - return layout && layout.getInstData(); + xegpu::DistributeLayoutAttr layout = tdescTy.getLayoutAttr(); + return layout && !layout.getInstDataAsInt().empty(); } auto shapedType = dyn_cast<ShapedType>(valTy); return shapedType && !llvm::equal(tileShape, shapedType.getShape()); @@ -247,7 +251,8 @@ void XeGPUBlockingPass::runOnOperation() { // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr. // This ensures that the LayoutAttr remains accessible even if the defining // operation is replaced. - xegpu::setLayoutAttrs(op, [](Value v) { return xegpu::getLayoutAttr(v); }); + xegpu::setDistributeLayoutAttrs( + op, [](Value v) { return xegpu::getDistributeLayoutAttr(v); }); auto getTileShapeAndCount = [](llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { @@ -272,7 +277,7 @@ void XeGPUBlockingPass::runOnOperation() { auto layout = llvm::dyn_cast_if_present<xegpu::LayoutAttr>(type.getEncoding()); - if (layout && layout.isWgLayout()) + if (layout && layout.isForWorkgroup()) return failure(); int count; @@ -289,7 +294,7 @@ void XeGPUBlockingPass::runOnOperation() { ArrayRef<int64_t> shape = type.getShape(); xegpu::LayoutAttr layout = type.getLayoutAttr(); - if (layout && layout.isWgLayout()) + if (layout && layout.isForWorkgroup()) return failure(); int count; @@ -377,7 +382,7 @@ void XeGPUBlockingPass::runOnOperation() { if (auto layout = op->getAttrOfType<xegpu::LayoutAttr>(name)) { op->removeAttr(name); if (!isa<LoopLikeOpInterface>(op)) - xegpu::setLayoutAttr(result, layout.dropInstData()); + xegpu::setDistributeLayoutAttr(result, layout.dropInstData()); } } diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp index bef8804..5cb47b2 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp @@ -718,7 +718,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op, } // If the result is a vector type, add a temporary layout attribute to the // op. - xegpu::setLayoutAttr(result, layout); + xegpu::setDistributeLayoutAttr(result, layout); } return success(); } @@ -800,7 +800,7 @@ updateControlFlowOps(mlir::OpBuilder &builder, // If the type is a vector type and this region argument is an OpResult, // set the layout attribute on the OpResult. 
if (auto result = dyn_cast<OpResult>(successorInput)) - xegpu::setLayoutAttr(result, successorOperandLayout); + xegpu::setDistributeLayoutAttr(result, successorOperandLayout); } } return success(); diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp index 2088c3c..dddb5ea 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp @@ -336,8 +336,7 @@ struct StoreNdDistribution final : public gpu::WarpDistributionPattern { using gpu::WarpDistributionPattern::WarpDistributionPattern; LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); Operation *lastNode = yield->getPrevNode(); auto storeOp = dyn_cast_or_null<xegpu::StoreNdOp>(lastNode); if (!storeOp) @@ -449,8 +448,7 @@ struct LoadNdDistribution final : public gpu::WarpDistributionPattern { // Make sure the same load op is the last operation in the warp op body. // This ensure that load op is not sinked earlier violating any barrier // synchronizations. - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); return yield->getPrevNode() == op; }); @@ -752,8 +750,7 @@ struct PrefetchNdDistribution final : public gpu::WarpDistributionPattern { using gpu::WarpDistributionPattern::WarpDistributionPattern; LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); Operation *lastNode = yield->getPrevNode(); auto prefetchOp = dyn_cast_or_null<xegpu::PrefetchNdOp>(lastNode); if (!prefetchOp) @@ -794,8 +791,7 @@ struct GpuBarrierDistribution final : public gpu::WarpDistributionPattern { using gpu::WarpDistributionPattern::WarpDistributionPattern; LogicalResult matchAndRewrite(gpu::WarpExecuteOnLane0Op warpOp, PatternRewriter &rewriter) const override { - auto yield = cast<gpu::YieldOp>( - warpOp.getBodyRegion().getBlocks().begin()->getTerminator()); + gpu::YieldOp yield = warpOp.getTerminator(); Operation *lastNode = yield->getPrevNode(); // The last node must be a gpu::BarrierOp. auto barrierOp = dyn_cast_or_null<gpu::BarrierOp>(lastNode); @@ -841,14 +837,15 @@ void XeGPUSubgroupDistributePass::runOnOperation() { if (!isa<VectorType>(operand.get().getType())) continue; - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(operand); + auto layout = + xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand); if (!layout) { op->emitError("Could not find layout attribute for operand ") << operand.getOperandNumber() << " of operation " << op->getName(); signalPassFailure(); return; } - xegpu::setLayoutAttr(operand, layout); + xegpu::setDistributeLayoutAttr(operand, layout); } }); // Step 2: Move all operations of a GPU function inside @@ -882,7 +879,8 @@ void XeGPUSubgroupDistributePass::runOnOperation() { if (vecRank == 0) return AffineMap::get(val.getContext()); // Get the layout of the vector type. 
- xegpu::LayoutAttr layout = xegpu::getLayoutAttr(val); + // TODO: support more layout types + auto layout = xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(val); // If no layout is specified, assume the inner most dimension is distributed // for now. if (!layout) diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 4a5525c..9f627c7 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -34,38 +34,29 @@ using namespace mlir; namespace { -// Check if there is sg id range attached to the scf.if op. -static bool isSgIdRangeSpecified(Operation *op, int64_t &startOfRange, - int64_t &endOfRange) { - Operation *parent = op->getParentOp(); - // Find the outermost scf::IfOp with xegpu.sg_id_range. +// Retrieve the RangeAttr if it is specified. +static xegpu::RangeAttr getRangeSpecAttr(Operation *op) { + Operation *parent = op->getParentOfType<scf::IfOp>(); while (parent) { - if (auto ifOp = dyn_cast<scf::IfOp>(parent)) { - if (auto attr = llvm::dyn_cast_or_null<xegpu::RangeAttr>( - ifOp->getAttr("sg_id_range"))) { - startOfRange = attr.getStart().getInt(); - endOfRange = attr.getEnd().getInt(); - break; - } - } - parent = parent->getParentOp(); + if (auto attr = llvm::dyn_cast_if_present<xegpu::RangeAttr>( + parent->getAttr("sg_id_range"))) + return attr; + parent = parent->getParentOfType<scf::IfOp>(); } - // Return false if startOfRange is 0 - return (startOfRange > 0 && endOfRange > startOfRange); + return {}; } static std::pair<SmallVector<int64_t>, int> -getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { +getSgShapeAndCount(ArrayRef<int64_t> shape, + xegpu::DistributeLayoutAttr layout) { int count = 1; SmallVector<int64_t> sgShape(shape); - - if (layout && layout.isWgLayout()) { - DenseI32ArrayAttr sgLayoutAttr = layout.getSgLayout(); - auto sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef()); - if (DenseI32ArrayAttr sgDataAttr = layout.getSgData()) - sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef()); - else - sgShape = computeShapeRatio(shape, sgLayout).value_or(sgShape); + if (layout && layout.isForWorkgroup()) { + SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt(); + if (!layout.getSgDataAsInt().empty()) + sgShape = layout.getSgDataAsInt(); + else if (auto maybeDerivedSgData = computeShapeRatio(shape, sgLayout)) + sgShape = *maybeDerivedSgData; SmallVector<int64_t> distUnit = computeElementwiseMul(sgLayout, sgShape); // Clamp distUnit to the original shape to handle cases where data is // shared among subgroups, which may cause distUnit to exceed the original @@ -77,6 +68,67 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) { return std::make_pair(sgShape, count); } +/// Utility helper for deriving a list of offsets for each sub-TensorDescs +/// or sub-MemDescs to be accessed by current subgroup (sgId) based on the +/// associated distribute layout attribute, the shape, subgroup id and the +/// original offsets of the op +template < + typename OpType, + typename = std::enable_if_t<llvm::is_one_of< + OpType, xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp, + xegpu::PrefetchNdOp, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>> +static LogicalResult +genOffsetsList(ConversionPatternRewriter &rewriter, OpType op, + SmallVector<SmallVector<OpFoldResult>> &offsetsList) { + Location loc = op.getLoc(); + SmallVector<OpFoldResult> 
origOffsets = op.getMixedOffsets(); + // not applicable to ops without offsets operands. + if (origOffsets.empty()) + return failure(); + + // not applicable to ops without workgroup layout attributes + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + Value sgId = rewriter.create<gpu::SubgroupIdOp>(loc, /*upper_bound=*/nullptr); + + // verify and adjust the sgId if the range specifier is present + xegpu::RangeAttr sgIdRange = getRangeSpecAttr(op); + if (sgIdRange) { + int64_t startOfRange = sgIdRange.getStart().getInt(); + int64_t endOfRange = sgIdRange.getEnd().getInt(); + // verify the RangeAttr against the layout attribute + if (layout.getNumSubgroups() != endOfRange - startOfRange) + return rewriter.notifyMatchFailure( + op, "sg_layout size must match the sg_id_range"); + // adjust the sgId if necessary + if (startOfRange > 0) { + Value startOfRangeVal = + rewriter.create<arith::ConstantIndexOp>(loc, startOfRange); + sgId = rewriter.create<index::SubOp>(loc, sgId, startOfRangeVal); + } + } + + // Compute the list of subgroup-relative offsets for sub-tensors or sub-memory + // descriptors to be accessed, based on the layout information. + ArrayRef<int64_t> wgShape = op.getDataShape(); + auto maybeDescOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape); + if (failed(maybeDescOffsets)) + return failure(); + + // Compute the final global offsets for each accessed sub-tensor + // or sub-memory descriptor. + for (const auto &sgOffsets : *maybeDescOffsets) { + SmallVector<OpFoldResult> newOffsets = xegpu::addWithRightAligned( + rewriter, loc, getAsOpFoldResult(sgOffsets), origOffsets); + offsetsList.push_back(std::move(newOffsets)); + } + + // callback(offsetsList); + return success(); +} + /// This pattern transforms the CreateNdDescOp to create a subgroup descriptor /// from a workgroup descriptor. It replaces the offsets and sizes with /// appropriate values for the subgroup. 
@@ -128,72 +180,72 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> { LogicalResult matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + MLIRContext *ctx = op.getContext(); + xegpu::TensorDescType tdescTy = op.getType(); + ArrayRef<int64_t> wgShape = tdescTy.getShape(); + Type elemTy = tdescTy.getElementType(); + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + auto newTdescTy = + xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), + layout.dropSgLayoutAndData()); + + SmallVector<Value> newOps; + for (auto offsets : offsetsList) { + auto newOp = xegpu::CreateNdDescOp::create( + rewriter, op.getLoc(), newTdescTy, op.getSource(), offsets, + op.getMixedSizes(), op.getMixedStrides()); + + newOps.push_back(newOp); + } + rewriter.replaceOpWithMultiple(op, {newOps}); + + return success(); + } +}; + +// This pattern transforms the CreateNdDescOp without offsets to create a +// subgroup descriptor from a workgroup descriptor +struct WgToSgCreateNdOpNoOffset + : public OpConversionPattern<xegpu::CreateNdDescOp> { + using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + // Check no offsets are specified. + if (!op.getMixedOffsets().empty()) + return failure(); + Location loc = op.getLoc(); MLIRContext *ctx = op.getContext(); xegpu::TensorDescType tdescTy = op.getType(); auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout()); - if (!layout) + if (!layout || !layout.isForWorkgroup()) return failure(); + Type elemTy = tdescTy.getElementType(); ArrayRef<int64_t> wgShape = tdescTy.getShape(); - // sgLayout must be present for workgroup-level distribution. 
- SmallVector<int64_t> sgLayout; - if (auto sgLayoutAttr = layout.getSgLayout()) - sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef()); - else - return rewriter.notifyMatchFailure( - op, "sgLayout attribute is required in layout"); - - // Get the subgroup ID - Value linearSgId = - gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); - - int64_t startOfRange = -1, endOfRange = -1; - bool sgIdRangeSpecified = - isSgIdRangeSpecified(op, startOfRange, endOfRange); - - if (sgIdRangeSpecified) { - int64_t sgCount = endOfRange - startOfRange; - if (computeProduct(sgLayout) != sgCount) - return rewriter.notifyMatchFailure( - op, "sg_layout size must match the sg_id_range"); - // Subtract startOfRange from the original subgroup id to get - // the adjusted sg id - Value startOfRangeVal = - rewriter.create<arith::ConstantIndexOp>(loc, startOfRange); - linearSgId = - rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal); - } - auto maybeTdescOffsets = - layout.getOffsets(rewriter, loc, linearSgId, wgShape); - if (failed(maybeTdescOffsets)) - return failure(); - - SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + SmallVector<int64_t> sgShape; + int count; + std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); xegpu::TensorDescType newTdescTy = xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(), layout.dropSgLayoutAndData()); - SmallVector<Value> newCreateNdOps; - SmallVector<OpFoldResult> wgOffsets = op.getMixedOffsets(); - - for (auto tdescOffsets : *maybeTdescOffsets) { - SmallVector<OpFoldResult> sgOffsets; - size_t rank = tdescOffsets.size(); - for (size_t i = 0; i < rank; i++) { - size_t idx = wgOffsets.size() - rank + i; - Value add = rewriter.createOrFold<index::AddOp>( - loc, tdescOffsets[i], - getValueOrCreateConstantIndexOp(rewriter, loc, wgOffsets[idx])); - sgOffsets.push_back(add); - } + SmallVector<Value> newCreateNdOps(count); + std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() { + return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy, + op.getSource(), op.getMixedSizes(), + op.getMixedStrides()); + }); - auto newOp = xegpu::CreateNdDescOp::create( - rewriter, loc, newTdescTy, op.getSource(), sgOffsets, - op.getMixedSizes(), op.getMixedStrides()); - newCreateNdOps.push_back(newOp); - } rewriter.replaceOpWithMultiple(op, {newCreateNdOps}); return success(); } @@ -205,12 +257,10 @@ struct WgToSgLoadNdOp : public OpConversionPattern<xegpu::LoadNdOp> { LogicalResult matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SmallVector<Value> newLoadOps; - - int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); - if ((offsetSize != 0) || op.getConstOffsetsAttr()) + if (!op.getMixedOffsets().empty()) return failure(); + SmallVector<Value> newLoadOps; for (auto src : adaptor.getTensorDesc()) { xegpu::TensorDescType tdescTy = dyn_cast<xegpu::TensorDescType>(src.getType()); @@ -233,9 +283,7 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> { LogicalResult matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - - int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size()); - if ((offsetSize != 0) || op.getConstOffsetsAttr()) + if (!op.getMixedOffsets().empty()) return failure(); for (auto [v, t] : llvm::zip(adaptor.getValue(), adaptor.getTensorDesc())) @@ -247,6 +295,84 @@ struct WgToSgStoreNdOp : public 
OpConversionPattern<xegpu::StoreNdOp> { } }; +// This pattern transforms the LoadNdOp with explicit offsets to load +// subgroup data. +struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> { + using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + SmallVector<Value> newOps; + for (auto [tdesc, offsets] : + llvm::zip(adaptor.getTensorDesc(), offsetsList)) { + auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType()); + VectorType newResTy = + VectorType::get(tdescTy.getShape(), tdescTy.getElementType()); + auto newOp = xegpu::LoadNdOp::create( + rewriter, op.getLoc(), newResTy, tdesc, offsets, + /*packed = */ nullptr, /*transpose = */ nullptr, op.getL1HintAttr(), + op.getL2HintAttr(), op.getL3HintAttr()); + newOps.push_back(newOp); + } + rewriter.replaceOpWithMultiple(op, {newOps}); + + return success(); + } +}; + +// This pattern transforms the StoreNdOp with explicit offsets to store +// subgroup data. +struct WgToSgStoreNdOpWithOffset + : public OpConversionPattern<xegpu::StoreNdOp> { + using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + for (auto [v, tdesc, offsets] : + llvm::zip(adaptor.getValue(), adaptor.getTensorDesc(), offsetsList)) { + rewriter.create<xegpu::StoreNdOp>(op.getLoc(), v, tdesc, offsets, + op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); + } + rewriter.eraseOp(op); + + return success(); + } +}; + +// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch +// subgroup data. +struct WgToSgPrefetchNdOpWithOffset + : public OpConversionPattern<xegpu::PrefetchNdOp> { + using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + for (auto [tdesc, offsets] : + llvm::zip(adaptor.getTensorDesc(), offsetsList)) { + rewriter.create<xegpu::PrefetchNdOp>( + op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(), + op.getL3HintAttr()); + } + rewriter.eraseOp(op); + + return success(); + } +}; + /// This pattern transforms the UpdateNdOffsetOp to update the offsets of a /// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the /// offsets of the new subgroup src tensor descriptors. 
@@ -280,7 +406,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> { if (resultTy.getRank() != 2) return failure(); - auto originalLayout = xegpu::getLayoutAttr(op.getResult()); + auto originalLayout = xegpu::getDistributeLayoutAttr(op.getResult()); if (!originalLayout) return failure(); @@ -303,8 +429,8 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> { VectorType resTy = VectorType::get({aVecShape[0], bVecShape[1]}, resultTy.getElementType()); tmpC = xegpu::DpasOp::create(rewriter, loc, resTy, operands); - xegpu::setLayoutAttr(cast<OpResult>(tmpC), - originalLayout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(cast<OpResult>(tmpC), + originalLayout.dropSgLayoutAndData()); newDpasOps.push_back(tmpC); } @@ -344,8 +470,9 @@ struct WgToSgVectorBroadcastOp VectorType resultType = op.getResult().getType(); ArrayRef<int64_t> wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); - if (!layout || !layout.getSgLayout()) + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) return failure(); // TODO: Currently only supports cases where the source and result ranks @@ -360,10 +487,8 @@ struct WgToSgVectorBroadcastOp VectorType::get(sgShape, resultType.getElementType()); // Check if the output layout is distributable - SmallVector<int64_t> sgLayout; - if (auto sgLayoutAttr = layout.getSgLayout()) - sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef()); - else + SmallVector<int64_t> sgLayout = layout.getSgLayoutAsInt(); + if (sgLayout.empty()) return failure(); if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout)) @@ -382,8 +507,8 @@ struct WgToSgVectorBroadcastOp for (auto operand : adaptor.getOperands().front()) { auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(), newResultType, operand); - xegpu::setLayoutAttr(newBroadcast->getResult(0), - layout.dropSgLayoutAndData()); + xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0), + layout.dropSgLayoutAndData()); newBroadcastOps.push_back(newBroadcast.getResult()); } @@ -409,8 +534,9 @@ struct WgToSgElementwiseOp : public ConversionPattern { ArrayRef<int64_t> wgShape = resultType.getShape(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0)); - if (!layout || !layout.getSgLayout()) + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op->getResult(0)); + if (!layout || !layout.isForWorkgroup()) return failure(); SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; @@ -475,8 +601,8 @@ struct WgToSgElementwiseOp : public ConversionPattern { // is lowered to: // #a = #xegpu.layout<inst_data = [16, 16]> // #b = #xegpu.layout<inst_data = [8, 16]> -// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32> -// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32> +// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32> +// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32> // xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32> // clang-format on struct WgToSgConvertLayoutOp @@ -485,10 +611,12 @@ struct WgToSgConvertLayoutOp LogicalResult matchAndRewrite(xegpu::ConvertLayoutOp op, OneToNOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - xegpu::LayoutAttr input = op.getInputLayout(); - 
xegpu::LayoutAttr target = op.getTargetLayout(); + // TODO: currently, we only support LayoutAttr + auto input = dyn_cast<xegpu::LayoutAttr>(op.getInputLayout()); + auto target = dyn_cast<xegpu::LayoutAttr>(op.getTargetLayout()); - if (!input || !target || !input.isWgLayout() || !target.isWgLayout()) + if (!input || !target || !input.isForWorkgroup() || + !target.isForWorkgroup()) return rewriter.notifyMatchFailure( op, "Input and target layouts must have subgroup layout"); @@ -598,16 +726,213 @@ struct UnrealizedConversionCastOpPattern } }; +// This pattern distributes arith.constant op into subgroup-level constants +struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> { + using OpConversionPattern<arith::ConstantOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue()); + auto vecType = dyn_cast<VectorType>(op.getType()); + if (!vecAttr || !vecAttr.isSplat() || !vecType) + return failure(); + + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + ArrayRef<int64_t> wgShape = vecType.getShape(); + SmallVector<int64_t> sgShape; + int count; + std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); + + // Current limitation: constant of vector with single value. + // TODO: support more complex cases, e.g., vector with multiple values. + Attribute singleVal = vecAttr.getSplatValue<Attribute>(); + + auto newType = VectorType::get(sgShape, vecType.getElementType()); + auto sgAttr = DenseElementsAttr::get(newType, singleVal); + auto cstOp = + arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr); + if (auto newLayout = layout.dropSgLayoutAndData()) + xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout); + SmallVector<Value> newConsts(count, cstOp); + + rewriter.replaceOpWithMultiple(op, {newConsts}); + return success(); + } +}; + +// This pattern transforms the LoadGatherOp with explicit offsets to load +// subgroup data +struct WgToSgLoadGatherOpWithOffset + : public OpConversionPattern<xegpu::LoadGatherOp> { + using OpConversionPattern<xegpu::LoadGatherOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadGatherOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getOffsets()) + return failure(); + + Location loc = op.getLoc(); + VectorType resultType = dyn_cast<VectorType>(op.getResult().getType()); + if (!resultType) + return failure(); + ArrayRef<int64_t> wgShape = resultType.getShape(); + + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + + // The offsets need to be distributed + auto offsetsVecType = + dyn_cast<VectorType>(adaptor.getOffsets().front().getType()); + auto maskVecType = + dyn_cast<VectorType>(adaptor.getMask().front().getType()); + if (!offsetsVecType || !maskVecType || + offsetsVecType.getShape() != maskVecType.getShape()) { + return rewriter.notifyMatchFailure(op, + "offsets have not been distributed"); + } + + SmallVector<Value> newLoadOps; + auto chunkSizeAttr = + rewriter.getI64IntegerAttr(op.getChunkSize().value_or(1)); + VectorType newTy = VectorType::get(sgShape, resultType.getElementType()); + for (auto 
[offsets, mask] : + llvm::zip(adaptor.getOffsets(), adaptor.getMask())) { + auto newLoadOp = rewriter.create<xegpu::LoadGatherOp>( + loc, newTy, op.getSource(), offsets, mask, chunkSizeAttr, + op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr()); + xegpu::setDistributeLayoutAttr(newLoadOp->getResult(0), + layout.dropSgLayoutAndData()); + newLoadOps.push_back(newLoadOp); + } + rewriter.replaceOpWithMultiple(op, {newLoadOps}); + return success(); + } +}; + +// This pattern transforms the StoreScatterOp with explicit offsets to store +// subgroup data +struct WgToSgStoreScatterOpWithOffset + : public OpConversionPattern<xegpu::StoreScatterOp> { + using OpConversionPattern<xegpu::StoreScatterOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::StoreScatterOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (!op.getOffsets()) + return failure(); + + Location loc = op.getLoc(); + VectorType valueType = dyn_cast<VectorType>(op.getValue().getType()); + if (!valueType) + return failure(); + + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getValue()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + // The offsets need to be distributed + auto offsetsVecType = + dyn_cast<VectorType>(adaptor.getOffsets().front().getType()); + auto maskVecType = + dyn_cast<VectorType>(adaptor.getMask().front().getType()); + if (!offsetsVecType || !maskVecType || + offsetsVecType.getShape() != maskVecType.getShape()) { + return rewriter.notifyMatchFailure(op, + "offsets have not been distributed"); + } + + auto chunkSizeOpt = op.getChunkSize(); + int64_t chunkSize = chunkSizeOpt ? static_cast<int64_t>(*chunkSizeOpt) : 1; + auto chunkSizeAttr = rewriter.getI64IntegerAttr(chunkSize); + for (auto [val, offs, mask] : llvm::zip( + adaptor.getValue(), adaptor.getOffsets(), adaptor.getMask())) { + rewriter.create<xegpu::StoreScatterOp>( + loc, val, op.getDest(), offs, mask, chunkSizeAttr, op.getL1HintAttr(), + op.getL2HintAttr(), op.getL3HintAttr()); + // Update the layout attribute to drop sg_layout and sg_data. 
+ if (auto newLayout = layout.dropSgLayoutAndData()) + op->setAttr("layout", newLayout); + } + rewriter.eraseOp(op); + return success(); + } +}; + +struct WgToSgLoadMatrixOp : public OpConversionPattern<xegpu::LoadMatrixOp> { + using OpConversionPattern<xegpu::LoadMatrixOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::LoadMatrixOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + ArrayRef<int64_t> wgShape = op.getDataShape(); + VectorType valueTy = op.getRes().getType(); + Type elemTy = valueTy.getElementType(); + + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType newResTy = VectorType::get(sgShape, elemTy); + SmallVector<Value> newOps; + for (auto offsets : offsetsList) { + auto newOp = rewriter.create<xegpu::LoadMatrixOp>( + op.getLoc(), newResTy, op.getMemDesc(), offsets, + layout.dropSgLayoutAndData()); + newOps.push_back(newOp); + } + rewriter.replaceOpWithMultiple(op, {newOps}); + + return success(); + } +}; + +struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> { + using OpConversionPattern<xegpu::StoreMatrixOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(xegpu::StoreMatrixOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + SmallVector<SmallVector<OpFoldResult>> offsetsList; + if (failed(genOffsetsList(rewriter, op, offsetsList))) + return failure(); + + xegpu::DistributeLayoutAttr layout = op.getLayoutAttr(); + for (auto [v, offsets] : llvm::zip(adaptor.getData(), offsetsList)) + rewriter.create<xegpu::StoreMatrixOp>(op.getLoc(), v, op.getMemDesc(), + offsets, + layout.dropSgLayoutAndData()); + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace namespace mlir { namespace xegpu { void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { - patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp, - WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, - UnrealizedConversionCastOpPattern, WgToSgElementwiseOp, - WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>( - patterns.getContext()); + patterns + .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp, + WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset, + WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp, + WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern, + WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp, + WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, + WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, + WgToSgStoreMatrixOp>(patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -697,8 +1022,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return xegpu::TensorDescType(); }; - auto isLegal = [&](xegpu::LayoutAttr layout) -> bool { - return !layout || !layout.isWgLayout(); + auto isLegal = [&](xegpu::DistributeLayoutAttr layout) -> bool { + return !layout || !layout.isForWorkgroup(); }; target.addDynamicallyLegalOp<xegpu::CreateNdDescOp, xegpu::LoadNdOp, @@ -710,13 +1035,46 @@ void XeGPUWgToSgDistributePass::runOnOperation() { }); target.addDynamicallyLegalOp<xegpu::DpasOp>([=](xegpu::DpasOp op) -> bool { - auto layout = xegpu::getLayoutAttr(op.getResult()); + auto layout = xegpu::getDistributeLayoutAttr(op.getResult()); return 
isLegal(layout); }); + target.addDynamicallyLegalOp<xegpu::LoadMatrixOp>( + [=](xegpu::LoadMatrixOp op) -> bool { + return isLegal(op.getLayoutAttr()); + }); + + target.addDynamicallyLegalOp<xegpu::StoreMatrixOp>( + [=](xegpu::StoreMatrixOp op) -> bool { + return isLegal(op.getLayoutAttr()); + }); + + target.addDynamicallyLegalOp<arith::ConstantOp>( + [=](arith::ConstantOp op) -> bool { + auto vecType = dyn_cast<VectorType>(op.getType()); + if (!vecType) + return true; + return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); + }); + + target.addDynamicallyLegalOp<xegpu::LoadGatherOp>( + [=](xegpu::LoadGatherOp op) -> bool { + auto layout = xegpu::getDistributeLayoutAttr(op.getResult()); + return isLegal(layout); + }); + + target.addDynamicallyLegalOp<xegpu::StoreScatterOp>( + [=](xegpu::StoreScatterOp op) -> bool { + // Check if the layout attribute is present on the result. + auto layout = op->getAttrOfType<xegpu::LayoutAttr>("layout"); + if (!layout) + return true; + return isLegal(layout); + }); + target.addDynamicallyLegalOp<vector::BroadcastOp>( [=](vector::BroadcastOp op) -> bool { - return isLegal(xegpu::getLayoutAttr(op.getResult())); + return isLegal(xegpu::getDistributeLayoutAttr(op.getResult())); }); target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>( @@ -744,7 +1102,8 @@ void XeGPUWgToSgDistributePass::runOnOperation() { } } - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0)); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op->getResult(0)); return isLegal(layout); }); diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt index 98e84a4..d9bf4a1 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt @@ -7,5 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUUtils LINK_LIBS PUBLIC MLIRIR MLIRSCFTransforms + MLIRGPUDialect + MLIRXeVMDialect MLIRXeGPUDialect ) diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp index 2cf21fb..cac1ffe 100644 --- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp +++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp @@ -11,6 +11,9 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/XeGPU/IR/XeGPU.h" @@ -38,7 +41,7 @@ mlir::xegpu::getDistributedVectorType(xegpu::TensorDescType tdescTy) { auto layout = llvm::dyn_cast_if_present<LayoutAttr>(tdescTy.getLayout()); // It only works for subgroup level layout, which only has lane_layout // and lane_data, and is to distribute a SIMD code into SIMT code. 
- if (!layout || !layout.isSgLayout()) + if (!layout || !layout.isForSubgroup()) return failure(); SmallVector<int64_t> laneData(layout.getLaneData().asArrayRef()); @@ -111,7 +114,7 @@ std::string xegpu::getLayoutName(const OpResult result) { return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str(); } -xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) { +xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) { if (!value) return nullptr; @@ -129,11 +132,11 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) { // for LoadNdOp, the layout is stored in the tensor descriptor if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp)) - return getLayoutAttr(loadNd.getTensorDesc()); + return getDistributeLayoutAttr(loadNd.getTensorDesc()); std::string layoutName = getLayoutName(result); if (defOp->hasAttr(layoutName)) - return defOp->getAttrOfType<xegpu::LayoutAttr>(layoutName); + return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); } if (auto arg = dyn_cast<BlockArgument>(value)) { @@ -141,49 +144,51 @@ xegpu::LayoutAttr xegpu::getLayoutAttr(const Value value) { if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) { OpOperand *tiedInit = loop.getTiedLoopInit(arg); if (tiedInit) - return getLayoutAttr(tiedInit->get()); + return getDistributeLayoutAttr(tiedInit->get()); } } return nullptr; } -xegpu::LayoutAttr xegpu::getLayoutAttr(const OpOperand &opr) { +xegpu::DistributeLayoutAttr +xegpu::getDistributeLayoutAttr(const OpOperand &opr) { Operation *op = opr.getOwner(); std::string layoutName = xegpu::getLayoutName(opr); if (op->hasAttr(layoutName)) - return op->getAttrOfType<xegpu::LayoutAttr>(layoutName); - return getLayoutAttr(opr.get()); + return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName); + return getDistributeLayoutAttr(opr.get()); } template <typename T, typename> -void xegpu::setLayoutAttr(const T &operandOrResult, const LayoutAttr layout) { +void xegpu::setDistributeLayoutAttr(const T &operandOrResult, + const DistributeLayoutAttr layout) { Operation *owner = operandOrResult.getOwner(); std::string name = xegpu::getLayoutName(operandOrResult); - if (layout && !owner->hasAttrOfType<LayoutAttr>(name)) + if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name)) owner->setAttr(name, layout); } // Explicit instantiation for OpResult -template void -xegpu::setLayoutAttr<mlir::OpResult>(const mlir::OpResult &result, - const mlir::xegpu::LayoutAttr layout); +template void xegpu::setDistributeLayoutAttr<mlir::OpResult>( + const mlir::OpResult &result, + const mlir::xegpu::DistributeLayoutAttr layout); // Explicit instantiation for OpOperand -template void -xegpu::setLayoutAttr<mlir::OpOperand>(const mlir::OpOperand &operand, - const mlir::xegpu::LayoutAttr layout); +template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>( + const mlir::OpOperand &operand, + const mlir::xegpu::DistributeLayoutAttr layout); -void xegpu::setLayoutAttrs(Operation *op, - function_ref<LayoutAttr(Value)> getLayoutImpl) { +void xegpu::setDistributeLayoutAttrs( + Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) { op->walk([&](Operation *nestOp) { for (OpOperand &opr : nestOp->getOpOperands()) { auto layout = getLayoutImpl(opr.get()); - setLayoutAttr(opr, layout); + setDistributeLayoutAttr(opr, layout); } for (OpResult result : nestOp->getOpResults()) { auto layout = getLayoutImpl(result); - setLayoutAttr(result, layout); + setDistributeLayoutAttr(result, layout); } }); } @@ -192,7 +197,7 @@ template 
<typename T, typename> void xegpu::removeLayoutAttr(const T &operandOrResult) { Operation *owner = operandOrResult.getOwner(); std::string name = xegpu::getLayoutName(operandOrResult); - if (owner->hasAttrOfType<LayoutAttr>(name)) + if (owner->hasAttrOfType<DistributeLayoutAttr>(name)) owner->removeAttr(name); } @@ -303,7 +308,8 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( if (!inputTy || !resultTy) return WalkResult::skip(); - xegpu::LayoutAttr layout = xegpu::getLayoutAttr(input); + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(input); if (!layout) return WalkResult::skip(); @@ -341,7 +347,7 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( } { // perform the conversion from RankedTensorType to VectorType based on the - // LayoutAttr + // DistributeLayoutAttr // Handle the UnrealizedConversionCastOp introduced by the first step. // For vector->RankedTensorType, it will simply forward the inputs. @@ -404,3 +410,49 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType( (void)mlir::applyPartialConversion(op, target, std::move(patterns)); } } + +std::optional<std::string> xegpu::getChipStr(Operation *op) { + auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>(); + + if (!gpuModuleOp) + return std::nullopt; + + auto targetAttrs = gpuModuleOp.getTargets(); + if (targetAttrs) { + for (auto &attr : *targetAttrs) { + auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr); + if (xevmAttr) + return xevmAttr.getChip().str(); + } + } + + return std::nullopt; +} + +/// Generates element-wise addition ops of two arrays with automatic alignment. +/// When the input arrays have different sizes, the shorter array is +/// right-aligned with the longer array, and the unmatched leading elements from +/// the longer array are preserved unchanged. This is commonly used for offset +/// computation where higher-dimensional offsets need to be added to +/// lower-dimensional adjustments. +/// +/// Example: +/// lhs = [l1, l2, l3], rhs = [r1, r2] +/// Result: [l1, l2+r1, l3+r2] +SmallVector<OpFoldResult> +xegpu::addWithRightAligned(OpBuilder &builder, Location loc, +  ArrayRef<OpFoldResult> lhs, + ArrayRef<OpFoldResult> rhs) { + // ensure a is the longer (or equal-length) array + ArrayRef<OpFoldResult> a = lhs.size() >= rhs.size() ? lhs : rhs; + ArrayRef<OpFoldResult> b = lhs.size() >= rhs.size() ? rhs : lhs; + SmallVector<OpFoldResult> results(a.take_front(a.size() - b.size())); + a = a.slice(a.size() - b.size()); + for (auto [l, r] : llvm::zip(a, b)) { + auto lval = getValueOrCreateConstantIndexOp(builder, loc, l); + auto rval = getValueOrCreateConstantIndexOp(builder, loc, r); + results.push_back(builder.createOrFold<index::AddOp>(loc, lval, rval)); + } + return results; +} diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index f704fbf..52162a4 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) { } // Compilation is lazy and it doesn't populate object cache unless requested. // In case object dump is requested before cache is populated, we need to - // force compilation manually. + // force compilation manually.
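A quick usage sketch for the xegpu::addWithRightAligned helper introduced above (illustrative only, not part of the patch; `builder` and `loc` are assumed to be in scope and the offset values are invented):

// Add a 2-D adjustment onto a 3-D offset vector; the arrays are aligned on
// the right, so the leading lhs element is passed through unchanged.
SmallVector<OpFoldResult> lhs = {builder.getIndexAttr(0),
                                 builder.getIndexAttr(4),
                                 builder.getIndexAttr(8)};
SmallVector<OpFoldResult> rhs = {builder.getIndexAttr(2),
                                 builder.getIndexAttr(3)};
// Yields offsets equivalent to [0, 4+2, 8+3] = [0, 6, 11]; constant inputs are
// materialized as index constants and summed via createOrFold<index::AddOp>.
SmallVector<OpFoldResult> sum =
    xegpu::addWithRightAligned(builder, loc, lhs, rhs);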
if (cache->isEmpty()) { for (std::string &functionName : functionNames) { auto result = lookupPacked(functionName); @@ -400,13 +400,6 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options, return symbolMap; }; engine->registerSymbols(runtimeSymbolMap); - - // Execute the global constructors from the module being processed. - // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a - // crash for AArch64 see related issue #71963. - if (!engine->jit->getTargetTriple().isAArch64()) - cantFail(engine->jit->initialize(engine->jit->getMainJITDylib())); - return std::move(engine); } @@ -442,6 +435,7 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const { Error ExecutionEngine::invokePacked(StringRef name, MutableArrayRef<void *> args) { + initialize(); auto expectedFPtr = lookupPacked(name); if (!expectedFPtr) return expectedFPtr.takeError(); @@ -451,3 +445,13 @@ Error ExecutionEngine::invokePacked(StringRef name, return Error::success(); } + +void ExecutionEngine::initialize() { + if (isInitialized) + return; + // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a + // crash for AArch64 see related issue #71963. + if (!jit->getTargetTriple().isAArch64()) + cantFail(jit->initialize(jit->getMainJITDylib())); + isInitialized = true; +} diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp index 2107df3..0ada4cc 100644 --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -202,6 +202,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint, auto engine = std::move(*expectedEngine); + engine->initialize(); + auto expectedFPtr = engine->lookupPacked(entryPoint); if (!expectedFPtr) return expectedFPtr.takeError(); diff --git a/mlir/lib/ExecutionEngine/VulkanRuntime.cpp b/mlir/lib/ExecutionEngine/VulkanRuntime.cpp index 9f653b2..9452a56 100644 --- a/mlir/lib/ExecutionEngine/VulkanRuntime.cpp +++ b/mlir/lib/ExecutionEngine/VulkanRuntime.cpp @@ -20,7 +20,7 @@ #include <iomanip> #include <iostream> -inline void emitVulkanError(const char *api, VkResult error) { +static inline void emitVulkanError(const char *api, VkResult error) { std::cerr << " failed with error code " << error << " when executing " << api; } diff --git a/mlir/lib/IR/Block.cpp b/mlir/lib/IR/Block.cpp index 57825d9..27b47e2 100644 --- a/mlir/lib/IR/Block.cpp +++ b/mlir/lib/IR/Block.cpp @@ -251,6 +251,16 @@ bool Block::mightHaveTerminator() { return !empty() && back().mightHaveTrait<OpTrait::IsTerminator>(); } +iterator_range<Block::iterator> Block::without_terminator_impl() { + // Note: When the op is unregistered, we do not know for sure if the last + // op is a terminator. In that case, we include it in `without_terminator`, + // but that decision is somewhat arbitrary. + if (!back().hasTrait<OpTrait::IsTerminator>()) + return {begin(), end()}; + auto endIt = --end(); + return {begin(), endIt}; +} + // Indexed successor access. unsigned Block::getNumSuccessors() { return empty() ? 
0 : back().getNumSuccessors(); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index fd898b7..6f880f8 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/Types.h" #include "llvm/ADT/APSInt.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/Endian.h" #include <optional> @@ -1119,9 +1120,8 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, auto denseEltBitWidth = getDenseElementBitWidth(type); auto dataSize = static_cast<size_t>(dataEltSize * CHAR_BIT); if (denseEltBitWidth != dataSize) { - LLVM_DEBUG(llvm::dbgs() << "expected dense element bit width " - << denseEltBitWidth << " to match data size " - << dataSize << " for type " << type << "\n"); + LDBG() << "expected dense element bit width " << denseEltBitWidth + << " to match data size " << dataSize << " for type " << type; return false; } @@ -1129,9 +1129,7 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, if (!isInt) { bool valid = llvm::isa<FloatType>(type); if (!valid) - LLVM_DEBUG(llvm::dbgs() - << "expected float type when isInt is false, but found " - << type << "\n"); + LDBG() << "expected float type when isInt is false, but found " << type; return valid; } if (type.isIndex()) @@ -1139,9 +1137,7 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, auto intType = llvm::dyn_cast<IntegerType>(type); if (!intType) { - LLVM_DEBUG(llvm::dbgs() - << "expected integer type when isInt is true, but found " << type - << "\n"); + LDBG() << "expected integer type when isInt is true, but found " << type; return false; } @@ -1151,8 +1147,7 @@ static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, bool valid = intType.isSigned() == isSigned; if (!valid) - LLVM_DEBUG(llvm::dbgs() << "expected signedness " << isSigned - << " to match type " << type << "\n"); + LDBG() << "expected signedness " << isSigned << " to match type " << type; return valid; } diff --git a/mlir/lib/IR/CMakeLists.txt b/mlir/lib/IR/CMakeLists.txt index 3ef69ce..d95bdc9 100644 --- a/mlir/lib/IR/CMakeLists.txt +++ b/mlir/lib/IR/CMakeLists.txt @@ -33,6 +33,7 @@ add_mlir_library(MLIRIR PatternMatch.cpp Region.cpp RegionKindInterface.cpp + Remarks.cpp SymbolTable.cpp TensorEncoding.cpp Types.cpp diff --git a/mlir/lib/IR/Dialect.cpp b/mlir/lib/IR/Dialect.cpp index f84fe89..952619b 100644 --- a/mlir/lib/IR/Dialect.cpp +++ b/mlir/lib/IR/Dialect.cpp @@ -19,7 +19,7 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/Twine.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/Regex.h" #include <memory> @@ -104,14 +104,8 @@ void Dialect::addInterface(std::unique_ptr<DialectInterface> interface) { auto it = registeredInterfaces.try_emplace(interface->getID(), std::move(interface)); - (void)it; - LLVM_DEBUG({ - if (!it.second) { - llvm::dbgs() << "[" DEBUG_TYPE - "] repeated interface registration for dialect " - << getNamespace(); - } - }); + if (!it.second) + LDBG() << "repeated interface registration for dialect " << getNamespace(); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp index 23e70c6..662681e 100644 --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -52,8 +52,8 @@ struct FileLineColRangeAttrStorage final 
FileLineColRangeAttrStorage::totalSizeToAlloc<unsigned>(locEnc - 1); auto *rawMem = allocator.allocate(byteSize, alignof(FileLineColRangeAttrStorage)); - auto *result = ::new (rawMem) FileLineColRangeAttrStorage( - std::move(std::get<0>(tblgenKey)), locEnc - 1); + auto *result = ::new (rawMem) + FileLineColRangeAttrStorage(std::get<0>(tblgenKey), locEnc - 1); if (numInArray > 0) { ArrayRef<unsigned> elements = std::get<1>(tblgenKey); result->startLine = elements[0]; diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 2d5381d..1fa04ed 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -25,12 +25,13 @@ #include "mlir/IR/Location.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" +#include "mlir/IR/Remarks.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ManagedStatic.h" #include "llvm/Support/Mutex.h" #include "llvm/Support/RWMutex.h" @@ -134,6 +135,11 @@ public: DiagnosticEngine diagEngine; //===--------------------------------------------------------------------===// + // Remark + //===--------------------------------------------------------------------===// + std::unique_ptr<remark::detail::RemarkEngine> remarkEngine; + + //===--------------------------------------------------------------------===// // Options //===--------------------------------------------------------------------===// @@ -388,6 +394,19 @@ bool MLIRContext::hasActionHandler() { return (bool)getImpl().actionHandler; } DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; } //===----------------------------------------------------------------------===// +// Remark Handlers +//===----------------------------------------------------------------------===// + +void MLIRContext::setRemarkEngine( + std::unique_ptr<remark::detail::RemarkEngine> engine) { + getImpl().remarkEngine = std::move(engine); +} + +remark::detail::RemarkEngine *MLIRContext::getRemarkEngine() { + return getImpl().remarkEngine.get(); +} + +//===----------------------------------------------------------------------===// // Dialect and Operation Registration //===----------------------------------------------------------------------===// @@ -455,8 +474,7 @@ MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID, auto dialectIt = impl.loadedDialects.try_emplace(dialectNamespace, nullptr); if (dialectIt.second) { - LLVM_DEBUG(llvm::dbgs() - << "Load new dialect in Context " << dialectNamespace << "\n"); + LDBG() << "Load new dialect in Context " << dialectNamespace; #ifndef NDEBUG if (impl.multiThreadedExecutionContext != 0) llvm::report_fatal_error( @@ -525,8 +543,7 @@ DynamicDialect *MLIRContext::getOrLoadDynamicDialect( "' has already been registered"); } - LLVM_DEBUG(llvm::dbgs() << "Load new dynamic dialect in Context " - << dialectNamespace << "\n"); + LDBG() << "Load new dynamic dialect in Context " << dialectNamespace; #ifndef NDEBUG if (impl.multiThreadedExecutionContext != 0) llvm::report_fatal_error( @@ -1192,11 +1209,10 @@ willBeValidAffineMap(unsigned dimCount, unsigned symbolCount, getMaxDimAndSymbol(ArrayRef<ArrayRef<AffineExpr>>(results), maxDimPosition, maxSymbolPosition); if ((maxDimPosition >= dimCount) || (maxSymbolPosition >= symbolCount)) { - LLVM_DEBUG( - llvm::dbgs() + LDBG() << "maximum dimensional 
identifier position in result expression must " "be less than `dimCount` and maximum symbolic identifier position " - "in result expression must be less than `symbolCount`\n"); + "in result expression must be less than `symbolCount`"; return false; } return true; diff --git a/mlir/lib/IR/ODSSupport.cpp b/mlir/lib/IR/ODSSupport.cpp index 69b4a56..f2665d2 100644 --- a/mlir/lib/IR/ODSSupport.cpp +++ b/mlir/lib/IR/ODSSupport.cpp @@ -112,7 +112,7 @@ Attribute mlir::convertToAttribute(MLIRContext *ctx, bool storage) { } template <typename DenseArrayTy, typename T> -LogicalResult +static LogicalResult convertDenseArrayFromAttr(MutableArrayRef<T> storage, Attribute attr, function_ref<InFlightDiagnostic()> emitError, StringRef denseArrayTyStr) { @@ -143,7 +143,7 @@ mlir::convertFromAttribute(MutableArrayRef<int32_t> storage, Attribute attr, } template <typename DenseArrayTy, typename T> -LogicalResult +static LogicalResult convertDenseArrayFromAttr(SmallVectorImpl<T> &storage, Attribute attr, function_ref<InFlightDiagnostic()> emitError, StringRef denseArrayTyStr) { diff --git a/mlir/lib/IR/Remarks.cpp b/mlir/lib/IR/Remarks.cpp new file mode 100644 index 0000000..78c9644 --- /dev/null +++ b/mlir/lib/IR/Remarks.cpp @@ -0,0 +1,279 @@ +//===- Remarks.cpp - MLIR Remarks -----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/Remarks.h" + +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Value.h" + +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" + +using namespace mlir::remark::detail; + +//------------------------------------------------------------------------------ +// Remark +//------------------------------------------------------------------------------ + +Remark::Arg::Arg(llvm::StringRef k, Value v) : key(k) { + llvm::raw_string_ostream os(val); + os << v; +} + +Remark::Arg::Arg(llvm::StringRef k, Type t) : key(k) { + llvm::raw_string_ostream os(val); + os << t; +} + +void Remark::insert(llvm::StringRef s) { args.emplace_back(s); } +void Remark::insert(Arg a) { args.push_back(std::move(a)); } + +// Simple helper to print key=val list (sorted). +static void printArgs(llvm::raw_ostream &os, llvm::ArrayRef<Remark::Arg> args) { + if (args.empty()) + return; + + llvm::SmallVector<Remark::Arg, 8> sorted(args.begin(), args.end()); + llvm::sort(sorted, [](const Remark::Arg &a, const Remark::Arg &b) { + return a.key < b.key; + }); + + for (size_t i = 0; i < sorted.size(); ++i) { + const auto &a = sorted[i]; + os << a.key << "="; + + llvm::StringRef val(a.val); + bool needsQuote = val.contains(' ') || val.contains(',') || + val.contains('{') || val.contains('}'); + if (needsQuote) + os << '"' << val << '"'; + else + os << val; + + if (i + 1 < sorted.size()) + os << ", "; + } +} + +/// Print the remark to the given output stream. 
+/// Example output: +// clang-format off +/// [Missed] Category: Loop | Pass:Unroller | Function=main | Reason="tripCount=4 < threshold=256" +/// [Failure] LoopOptimizer | Reason="failed due to unsupported pattern" +// clang-format on +void Remark::print(llvm::raw_ostream &os, bool printLocation) const { + // Header: [Type] pass:remarkName + StringRef type = getRemarkTypeString(); + StringRef categoryName = getFullCategoryName(); + StringRef name = remarkName; + + os << '[' << type << "] "; + os << name << " | "; + if (!categoryName.empty()) + os << "Category:" << categoryName << " | "; + if (!functionName.empty()) + os << "Function=" << getFunction() << " | "; + + if (printLocation) { + if (auto flc = mlir::dyn_cast<mlir::FileLineColLoc>(getLocation())) + os << " @" << flc.getFilename() << ":" << flc.getLine() << ":" + << flc.getColumn(); + } + + printArgs(os, getArgs()); +} + +std::string Remark::getMsg() const { + std::string s; + llvm::raw_string_ostream os(s); + print(os); + os.flush(); + return s; +} + +llvm::StringRef Remark::getRemarkTypeString() const { + switch (remarkKind) { + case RemarkKind::RemarkUnknown: + return "Unknown"; + case RemarkKind::RemarkPassed: + return "Passed"; + case RemarkKind::RemarkMissed: + return "Missed"; + case RemarkKind::RemarkFailure: + return "Failure"; + case RemarkKind::RemarkAnalysis: + return "Analysis"; + } + llvm_unreachable("Unknown remark kind"); +} + +llvm::remarks::Type Remark::getRemarkType() const { + switch (remarkKind) { + case RemarkKind::RemarkUnknown: + return llvm::remarks::Type::Unknown; + case RemarkKind::RemarkPassed: + return llvm::remarks::Type::Passed; + case RemarkKind::RemarkMissed: + return llvm::remarks::Type::Missed; + case RemarkKind::RemarkFailure: + return llvm::remarks::Type::Failure; + case RemarkKind::RemarkAnalysis: + return llvm::remarks::Type::Analysis; + } + llvm_unreachable("Unknown remark kind"); +} + +llvm::remarks::Remark Remark::generateRemark() const { + auto locLambda = [&]() -> llvm::remarks::RemarkLocation { + if (auto flc = dyn_cast<FileLineColLoc>(getLocation())) + return {flc.getFilename(), flc.getLine(), flc.getColumn()}; + return {"<unknown file>", 0, 0}; + }; + + llvm::remarks::Remark r; // The result. + r.RemarkType = getRemarkType(); + r.RemarkName = getRemarkName(); + // MLIR does not use passes; instead, it has categories and sub-categories. + r.PassName = getFullCategoryName(); + r.FunctionName = getFunction(); + r.Loc = locLambda(); + for (const Remark::Arg &arg : getArgs()) { + r.Args.emplace_back(); + r.Args.back().Key = arg.key; + r.Args.back().Val = arg.val; + } + return r; +} + +//===----------------------------------------------------------------------===// +// InFlightRemark +//===----------------------------------------------------------------------===// + +InFlightRemark::~InFlightRemark() { + if (remark && owner) + owner->report(std::move(*remark)); + owner = nullptr; +} + +//===----------------------------------------------------------------------===// +// Remark Engine +//===----------------------------------------------------------------------===// + +template <typename RemarkT, typename... 
Args> +InFlightRemark RemarkEngine::makeRemark(Args &&...args) { + static_assert(std::is_base_of_v<Remark, RemarkT>, + "RemarkT must derive from Remark"); + return InFlightRemark(*this, + std::make_unique<RemarkT>(std::forward<Args>(args)...)); +} + +template <typename RemarkT> +InFlightRemark +RemarkEngine::emitIfEnabled(Location loc, RemarkOpts opts, + bool (RemarkEngine::*isEnabled)(StringRef) const) { + return (this->*isEnabled)(opts.categoryName) ? makeRemark<RemarkT>(loc, opts) + : InFlightRemark{}; +} + +bool RemarkEngine::isMissedOptRemarkEnabled(StringRef categoryName) const { + return missFilter && missFilter->match(categoryName); +} + +bool RemarkEngine::isPassedOptRemarkEnabled(StringRef categoryName) const { + return passedFilter && passedFilter->match(categoryName); +} + +bool RemarkEngine::isAnalysisOptRemarkEnabled(StringRef categoryName) const { + return analysisFilter && analysisFilter->match(categoryName); +} + +bool RemarkEngine::isFailedOptRemarkEnabled(StringRef categoryName) const { + return failedFilter && failedFilter->match(categoryName); +} + +InFlightRemark RemarkEngine::emitOptimizationRemark(Location loc, + RemarkOpts opts) { + return emitIfEnabled<OptRemarkPass>(loc, opts, + &RemarkEngine::isPassedOptRemarkEnabled); +} + +InFlightRemark RemarkEngine::emitOptimizationRemarkMiss(Location loc, + RemarkOpts opts) { + return emitIfEnabled<OptRemarkMissed>( + loc, opts, &RemarkEngine::isMissedOptRemarkEnabled); +} + +InFlightRemark RemarkEngine::emitOptimizationRemarkFailure(Location loc, + RemarkOpts opts) { + return emitIfEnabled<OptRemarkFailure>( + loc, opts, &RemarkEngine::isFailedOptRemarkEnabled); +} + +InFlightRemark RemarkEngine::emitOptimizationRemarkAnalysis(Location loc, + RemarkOpts opts) { + return emitIfEnabled<OptRemarkAnalysis>( + loc, opts, &RemarkEngine::isAnalysisOptRemarkEnabled); +} + +//===----------------------------------------------------------------------===// +// RemarkEngine +//===----------------------------------------------------------------------===// + +void RemarkEngine::report(const Remark &&remark) { + // Stream the remark + if (remarkStreamer) + remarkStreamer->streamOptimizationRemark(remark); + + // Print using MLIR's diagnostic + if (printAsEmitRemarks) + emitRemark(remark.getLocation(), remark.getMsg()); +} + +RemarkEngine::~RemarkEngine() { + if (remarkStreamer) + remarkStreamer->finalize(); +} + +llvm::LogicalResult +RemarkEngine::initialize(std::unique_ptr<MLIRRemarkStreamerBase> streamer, + std::string *errMsg) { + // If you need to validate categories/filters, do so here and set errMsg. 
+ remarkStreamer = std::move(streamer); + return success(); +} + +RemarkEngine::RemarkEngine(bool printAsEmitRemarks, + const RemarkCategories &cats) + : printAsEmitRemarks(printAsEmitRemarks) { + if (cats.passed) + passedFilter = llvm::Regex(cats.passed.value()); + if (cats.missed) + missFilter = llvm::Regex(cats.missed.value()); + if (cats.analysis) + analysisFilter = llvm::Regex(cats.analysis.value()); + if (cats.failed) + failedFilter = llvm::Regex(cats.failed.value()); +} + +llvm::LogicalResult mlir::remark::enableOptimizationRemarks( + MLIRContext &ctx, + std::unique_ptr<remark::detail::MLIRRemarkStreamerBase> streamer, + const remark::RemarkCategories &cats, bool printAsEmitRemarks) { + auto engine = + std::make_unique<remark::detail::RemarkEngine>(printAsEmitRemarks, cats); + + std::string errMsg; + if (failed(engine->initialize(std::move(streamer), &errMsg))) { + llvm::report_fatal_error( + llvm::Twine("Failed to initialize remark engine. Error: ") + errMsg); + } + ctx.setRemarkEngine(std::move(engine)); + + return success(); +} diff --git a/mlir/lib/Interfaces/SideEffectInterfaces.cpp b/mlir/lib/Interfaces/SideEffectInterfaces.cpp index 266f6db..b5a6888 100644 --- a/mlir/lib/Interfaces/SideEffectInterfaces.cpp +++ b/mlir/lib/Interfaces/SideEffectInterfaces.cpp @@ -304,6 +304,11 @@ template bool mlir::hasEffect<BlockArgument, MemoryEffects::Write, MemoryEffects::Free>( Operation *, BlockArgument); +bool mlir::hasUnknownEffects(Operation *op) { + return !isa<MemoryEffectOpInterface>(op) && + !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>(); +} + bool mlir::wouldOpBeTriviallyDead(Operation *op) { if (op->mightHaveTrait<OpTrait::IsTerminator>()) return false; diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp index 2f47939..af4ea5a 100644 --- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp +++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp @@ -290,8 +290,7 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs, DivisionFixupFn fixup) { const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(), &rhsMax = rhs.umax(); - - if (!rhsMin.isZero()) { + if (!rhsMin.isZero() && !rhsMax.isZero()) { auto udiv = [&fixup](const APInt &a, const APInt &b) -> std::optional<APInt> { return fixup(a, b, a.udiv(b)); diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index c9481fb..caa9091 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -813,7 +813,7 @@ FailureOr<bool> ValueBoundsConstraintSet::strongCompare(const Variable &lhs, return false; // Keep processing as long as the strong relation cannot be proven. FailureOr<bool> ordered = cstr.strongComparePos(lhsPos, cmp, rhsPos); - return failed(ordered) ? 
true : false; + return failed(ordered); }; ValueBoundsConstraintSet cstr(lhs.getContext(), stopCondition); lhsPos = cstr.populateConstraints(lhs.map, lhs.mapOperands); diff --git a/mlir/lib/Query/Matcher/VariantValue.cpp b/mlir/lib/Query/Matcher/VariantValue.cpp index 98ed4cc..ec18c48 100644 --- a/mlir/lib/Query/Matcher/VariantValue.cpp +++ b/mlir/lib/Query/Matcher/VariantValue.cpp @@ -35,7 +35,7 @@ public: std::optional<DynMatcher> getDynMatcher() const override { std::vector<DynMatcher> dynMatchers; - for (auto variantMatcher : args) { + for (const auto &variantMatcher : args) { std::optional<DynMatcher> dynMatcher = variantMatcher.getDynMatcher(); if (dynMatcher) dynMatchers.push_back(dynMatcher.value()); @@ -66,8 +66,7 @@ VariantMatcher VariantMatcher::SingleMatcher(DynMatcher matcher) { VariantMatcher VariantMatcher::VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, ArrayRef<VariantMatcher> args) { - return VariantMatcher( - std::make_shared<VariadicOpPayload>(varOp, std::move(args))); + return VariantMatcher(std::make_shared<VariadicOpPayload>(varOp, args)); } std::optional<DynMatcher> VariantMatcher::MatcherOps::constructVariadicOperator( diff --git a/mlir/lib/Query/Query.cpp b/mlir/lib/Query/Query.cpp index 03e4177..375e820 100644 --- a/mlir/lib/Query/Query.cpp +++ b/mlir/lib/Query/Query.cpp @@ -141,7 +141,7 @@ LogicalResult MatchQuery::run(llvm::raw_ostream &os, QuerySession &qs) const { os << "\n"; for (auto &results : matches) { os << "Match #" << ++matchCount << ":\n\n"; - for (auto op : results.matchedOps) { + for (Operation *op : results.matchedOps) { if (op == results.rootOp) { finder.printMatch(os, qs, op, "root"); } else { diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp index 950b85e2..258fed1 100644 --- a/mlir/lib/RegisterAllDialects.cpp +++ b/mlir/lib/RegisterAllDialects.cpp @@ -102,6 +102,7 @@ #include "mlir/Interfaces/CastInterfaces.h" #include "mlir/Target/LLVM/NVVM/Target.h" #include "mlir/Target/LLVM/ROCDL/Target.h" +#include "mlir/Target/LLVM/XeVM/Target.h" #include "mlir/Target/SPIRV/Target.h" /// Add all the MLIR dialects to the provided registry. @@ -199,6 +200,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) { NVVM::registerNVVMTargetInterfaceExternalModels(registry); ROCDL::registerROCDLTargetInterfaceExternalModels(registry); spirv::registerSPIRVTargetInterfaceExternalModels(registry); + xevm::registerXeVMTargetInterfaceExternalModels(registry); } /// Append all the MLIR dialects to the registry contained in the given context. diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp index 8f7c67c..69a85db 100644 --- a/mlir/lib/RegisterAllExtensions.cpp +++ b/mlir/lib/RegisterAllExtensions.cpp @@ -28,6 +28,7 @@ #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" #include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h" #include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h" +#include "mlir/Conversion/PtrToLLVM/PtrToLLVM.h" #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" @@ -58,6 +59,7 @@ #include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" /// This function may be called to register all MLIR dialect extensions with the /// provided registry.
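For context, the two aggregate registration entry points touched above and below are all a downstream tool needs to call to pick up the new XeVM target interface and the Ptr-to-LLVM conversion interface. A minimal sketch (generic tool boilerplate, not code from this patch):

mlir::DialectRegistry registry;
// Now also registers the XeVM target attribute interface models.
mlir::registerAllDialects(registry);
// Now also registers the Ptr -> LLVM conversion interface.
mlir::registerAllExtensions(registry);
mlir::MLIRContext context(registry);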
@@ -80,6 +82,7 @@ void mlir::registerAllExtensions(DialectRegistry &registry) { registerConvertMemRefToEmitCInterface(registry); registerConvertMemRefToLLVMInterface(registry); registerConvertNVVMToLLVMInterface(registry); + ptr::registerConvertPtrToLLVMInterface(registry); registerConvertOpenMPToLLVMInterface(registry); registerConvertSCFToEmitCInterface(registry); ub::registerConvertUBToLLVMInterface(registry); diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp index 1ed3a37..c67b242 100644 --- a/mlir/lib/RegisterAllPasses.cpp +++ b/mlir/lib/RegisterAllPasses.cpp @@ -45,6 +45,7 @@ #include "mlir/Dialect/Transform/Transforms/Passes.h" #include "mlir/Dialect/Vector/Transforms/Passes.h" #include "mlir/Dialect/XeGPU/Transforms/Passes.h" +#include "mlir/Target/LLVMIR/Transforms/Passes.h" #include "mlir/Transforms/Passes.h" // This function may be called to register the MLIR passes with the @@ -74,6 +75,7 @@ void mlir::registerAllPasses() { registerNVGPUPasses(); registerSparseTensorPasses(); LLVM::registerLLVMPasses(); + LLVM::registerTargetLLVMIRTransformsPasses(); math::registerMathPasses(); memref::registerMemRefPasses(); shard::registerShardPasses(); diff --git a/mlir/lib/Remark/CMakeLists.txt b/mlir/lib/Remark/CMakeLists.txt new file mode 100644 index 0000000..920a95d --- /dev/null +++ b/mlir/lib/Remark/CMakeLists.txt @@ -0,0 +1,14 @@ +add_mlir_library(MLIRRemarkStreamer + RemarkStreamer.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Remark + + LINK_LIBS PUBLIC + MLIRIR + + LINK_COMPONENTS + Remarks + Core + BitstreamReader + ) diff --git a/mlir/lib/Remark/RemarkStreamer.cpp b/mlir/lib/Remark/RemarkStreamer.cpp new file mode 100644 index 0000000..8e3544f --- /dev/null +++ b/mlir/lib/Remark/RemarkStreamer.cpp @@ -0,0 +1,69 @@ +#include "mlir/Remark/RemarkStreamer.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Remarks.h" + +#include "llvm/Remarks/RemarkSerializer.h" +#include "llvm/Remarks/RemarkStreamer.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/ToolOutputFile.h" + +namespace mlir::remark::detail { + +FailureOr<std::unique_ptr<MLIRRemarkStreamerBase>> +LLVMRemarkStreamer::createToFile(llvm::StringRef path, + llvm::remarks::Format fmt) { + std::error_code ec; + // Use error_code ctor; YAML is text. (Bitstream also works fine here.) + auto f = + std::make_unique<llvm::ToolOutputFile>(path, ec, llvm::sys::fs::OF_Text); + if (ec) + return failure(); + + auto serOr = llvm::remarks::createRemarkSerializer( + fmt, llvm::remarks::SerializerMode::Separate, f->os()); + if (!serOr) { + llvm::consumeError(serOr.takeError()); + return failure(); + } + + auto rs = + std::make_unique<llvm::remarks::RemarkStreamer>(std::move(*serOr), path); + + auto impl = std::unique_ptr<LLVMRemarkStreamer>(new LLVMRemarkStreamer()); + impl->remarkStreamer = std::move(rs); + impl->file = std::move(f); + return std::unique_ptr<MLIRRemarkStreamerBase>(std::move(impl)); +} + +void LLVMRemarkStreamer::streamOptimizationRemark(const Remark &remark) { + if (!remarkStreamer->matchesFilter(remark.getCategoryName())) + return; + + // First, convert the diagnostic to a remark. + llvm::remarks::Remark r = remark.generateRemark(); + // Then, emit the remark through the serializer.
+ remarkStreamer->getSerializer().emit(r); +} + +LLVMRemarkStreamer::~LLVMRemarkStreamer() { + if (file && remarkStreamer) + file->keep(); +} +} // namespace mlir::remark::detail + +namespace mlir::remark { +LogicalResult enableOptimizationRemarksWithLLVMStreamer( + MLIRContext &ctx, StringRef path, llvm::remarks::Format fmt, + const RemarkCategories &cat, bool printAsEmitRemarks) { + + FailureOr<std::unique_ptr<detail::MLIRRemarkStreamerBase>> sOr = + detail::LLVMRemarkStreamer::createToFile(path, fmt); + if (failed(sOr)) + return failure(); + + return remark::enableOptimizationRemarks(ctx, std::move(*sOr), cat, + printAsEmitRemarks); +} + +} // namespace mlir::remark diff --git a/mlir/lib/Rewrite/ByteCode.cpp b/mlir/lib/Rewrite/ByteCode.cpp index ae3f22d..5cbea5d 100644 --- a/mlir/lib/Rewrite/ByteCode.cpp +++ b/mlir/lib/Rewrite/ByteCode.cpp @@ -20,8 +20,10 @@ #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/Format.h" #include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/InterleavedRange.h" #include <numeric> #include <optional> @@ -707,10 +709,8 @@ void Generator::allocateMemoryIndices(pdl_interp::FuncOp matcherFunc, } // Print the index usage and ensure that we did not run out of index space. - LLVM_DEBUG({ - llvm::dbgs() << "Allocated " << allocatedIndices.size() << " indices " - << "(down from initial " << valueDefRanges.size() << ").\n"; - }); + LDBG() << "Allocated " << allocatedIndices.size() << " indices " + << "(down from initial " << valueDefRanges.size() << ")."; assert(allocatedIndices.size() <= std::numeric_limits<ByteCodeField>::max() && "Ran out of memory for allocated indices"); @@ -736,6 +736,7 @@ void Generator::generate(Region *region, ByteCodeWriter &writer) { } void Generator::generate(Operation *op, ByteCodeWriter &writer) { + LDBG() << "Generating bytecode for operation: " << op->getName(); LLVM_DEBUG({ // The following list must contain all the operations that do not // produce any bytecode. @@ -1275,12 +1276,8 @@ private: /// Handle a switch operation with the provided value and cases. template <typename T, typename RangeT, typename Comparator = std::equal_to<T>> void handleSwitch(const T &value, RangeT &&cases, Comparator cmp = {}) { - LLVM_DEBUG({ - llvm::dbgs() << " * Value: " << value << "\n" - << " * Cases: "; - llvm::interleaveComma(cases, llvm::dbgs()); - llvm::dbgs() << "\n"; - }); + LDBG() << "Switch operation:\n * Value: " << value + << "\n * Cases: " << llvm::interleaved(cases); // Check to see if the attribute value is within the case list. Jump to // the correct successor index based on the result. 
@@ -1424,38 +1421,27 @@ private: } // namespace void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { - LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n"); + LDBG() << "Executing ApplyConstraint:"; ByteCodeField fun_idx = read(); SmallVector<PDLValue, 16> args; readList<PDLValue>(args); - LLVM_DEBUG({ - llvm::dbgs() << " * Arguments: "; - llvm::interleaveComma(args, llvm::dbgs()); - llvm::dbgs() << "\n"; - }); + LDBG() << " * Arguments: " << llvm::interleaved(args); ByteCodeField isNegated = read(); - LLVM_DEBUG({ - llvm::dbgs() << " * isNegated: " << isNegated << "\n"; - llvm::interleaveComma(args, llvm::dbgs()); - }); + LDBG() << " * isNegated: " << isNegated; ByteCodeField numResults = read(); const PDLRewriteFunction &constraintFn = constraintFunctions[fun_idx]; ByteCodeRewriteResultList results(numResults); LogicalResult rewriteResult = constraintFn(rewriter, results, args); [[maybe_unused]] ArrayRef<PDLValue> constraintResults = results.getResults(); - LLVM_DEBUG({ - if (succeeded(rewriteResult)) { - llvm::dbgs() << " * Constraint succeeded\n"; - llvm::dbgs() << " * Results: "; - llvm::interleaveComma(constraintResults, llvm::dbgs()); - llvm::dbgs() << "\n"; - } else { - llvm::dbgs() << " * Constraint failed\n"; - } - }); + if (succeeded(rewriteResult)) { + LDBG() << " * Constraint succeeded, results: " + << llvm::interleaved(constraintResults); + } else { + LDBG() << " * Constraint failed"; + } assert((failed(rewriteResult) || constraintResults.size() == numResults) && "native PDL rewrite function succeeded but returned " "unexpected number of results"); @@ -1466,15 +1452,12 @@ void ByteCodeExecutor::executeApplyConstraint(PatternRewriter &rewriter) { } LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { - LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n"); + LDBG() << "Executing ApplyRewrite:"; const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()]; SmallVector<PDLValue, 16> args; readList<PDLValue>(args); - LLVM_DEBUG({ - llvm::dbgs() << " * Arguments: "; - llvm::interleaveComma(args, llvm::dbgs()); - }); + LDBG() << " * Arguments: " << llvm::interleaved(args); // Execute the rewrite function. 
ByteCodeField numResults = read(); @@ -1487,7 +1470,7 @@ LogicalResult ByteCodeExecutor::executeApplyRewrite(PatternRewriter &rewriter) { processNativeFunResults(results, numResults, rewriteResult); if (failed(rewriteResult)) { - LLVM_DEBUG(llvm::dbgs() << " - Failed"); + LDBG() << " - Failed"; return failure(); } return success(); @@ -1516,7 +1499,7 @@ void ByteCodeExecutor::processNativeFunResults( PDLValue::Kind resultKind = read<PDLValue::Kind>(); (void)resultKind; PDLValue result = results.getResults()[resultIdx]; - LLVM_DEBUG(llvm::dbgs() << " * Result: " << result << "\n"); + LDBG() << " * Result: " << result; assert(result.getKind() == resultKind && "native PDL rewrite function returned an unexpected type of " "result"); @@ -1544,16 +1527,16 @@ void ByteCodeExecutor::processNativeFunResults( } void ByteCodeExecutor::executeAreEqual() { - LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); + LDBG() << "Executing AreEqual:"; const void *lhs = read<const void *>(); const void *rhs = read<const void *>(); - LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n"); + LDBG() << " * " << lhs << " == " << rhs; selectJump(lhs == rhs); } void ByteCodeExecutor::executeAreRangesEqual() { - LLVM_DEBUG(llvm::dbgs() << "Executing AreRangesEqual:\n"); + LDBG() << "Executing AreRangesEqual:"; PDLValue::Kind valueKind = read<PDLValue::Kind>(); const void *lhs = read<const void *>(); const void *rhs = read<const void *>(); @@ -1562,14 +1545,14 @@ void ByteCodeExecutor::executeAreRangesEqual() { case PDLValue::Kind::TypeRange: { const TypeRange *lhsRange = reinterpret_cast<const TypeRange *>(lhs); const TypeRange *rhsRange = reinterpret_cast<const TypeRange *>(rhs); - LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); + LDBG() << " * " << lhs << " == " << rhs; selectJump(*lhsRange == *rhsRange); break; } case PDLValue::Kind::ValueRange: { const auto *lhsRange = reinterpret_cast<const ValueRange *>(lhs); const auto *rhsRange = reinterpret_cast<const ValueRange *>(rhs); - LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); + LDBG() << " * " << lhs << " == " << rhs; selectJump(*lhsRange == *rhsRange); break; } @@ -1579,20 +1562,19 @@ void ByteCodeExecutor::executeAreRangesEqual() { } void ByteCodeExecutor::executeBranch() { - LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n"); + LDBG() << "Executing Branch"; curCodeIt = &code[read<ByteCodeAddr>()]; } void ByteCodeExecutor::executeCheckOperandCount() { - LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n"); + LDBG() << "Executing CheckOperandCount:"; Operation *op = read<Operation *>(); uint32_t expectedCount = read<uint32_t>(); bool compareAtLeast = read(); - LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumOperands() << "\n" - << " * Expected: " << expectedCount << "\n" - << " * Comparator: " - << (compareAtLeast ? ">=" : "==") << "\n"); + LDBG() << " * Found: " << op->getNumOperands() + << "\n * Expected: " << expectedCount + << "\n * Comparator: " << (compareAtLeast ? 
">=" : "=="); if (compareAtLeast) selectJump(op->getNumOperands() >= expectedCount); else @@ -1600,25 +1582,24 @@ void ByteCodeExecutor::executeCheckOperandCount() { } void ByteCodeExecutor::executeCheckOperationName() { - LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n"); + LDBG() << "Executing CheckOperationName:"; Operation *op = read<Operation *>(); OperationName expectedName = read<OperationName>(); - LLVM_DEBUG(llvm::dbgs() << " * Found: \"" << op->getName() << "\"\n" - << " * Expected: \"" << expectedName << "\"\n"); + LDBG() << " * Found: \"" << op->getName() << "\"\n * Expected: \"" + << expectedName << "\""; selectJump(op->getName() == expectedName); } void ByteCodeExecutor::executeCheckResultCount() { - LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n"); + LDBG() << "Executing CheckResultCount:"; Operation *op = read<Operation *>(); uint32_t expectedCount = read<uint32_t>(); bool compareAtLeast = read(); - LLVM_DEBUG(llvm::dbgs() << " * Found: " << op->getNumResults() << "\n" - << " * Expected: " << expectedCount << "\n" - << " * Comparator: " - << (compareAtLeast ? ">=" : "==") << "\n"); + LDBG() << " * Found: " << op->getNumResults() + << "\n * Expected: " << expectedCount + << "\n * Comparator: " << (compareAtLeast ? ">=" : "=="); if (compareAtLeast) selectJump(op->getNumResults() >= expectedCount); else @@ -1626,36 +1607,35 @@ void ByteCodeExecutor::executeCheckResultCount() { } void ByteCodeExecutor::executeCheckTypes() { - LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n"); + LDBG() << "Executing AreEqual:"; TypeRange *lhs = read<TypeRange *>(); Attribute rhs = read<Attribute>(); - LLVM_DEBUG(llvm::dbgs() << " * " << lhs << " == " << rhs << "\n\n"); + LDBG() << " * " << lhs << " == " << rhs; selectJump(*lhs == cast<ArrayAttr>(rhs).getAsValueRange<TypeAttr>()); } void ByteCodeExecutor::executeContinue() { ByteCodeField level = read(); - LLVM_DEBUG(llvm::dbgs() << "Executing Continue\n" - << " * Level: " << level << "\n"); + LDBG() << "Executing Continue\n * Level: " << level; ++loopIndex[level]; popCodeIt(); } void ByteCodeExecutor::executeCreateConstantTypeRange() { - LLVM_DEBUG(llvm::dbgs() << "Executing CreateConstantTypeRange:\n"); + LDBG() << "Executing CreateConstantTypeRange:"; unsigned memIndex = read(); unsigned rangeIndex = read(); ArrayAttr typesAttr = cast<ArrayAttr>(read<Attribute>()); - LLVM_DEBUG(llvm::dbgs() << " * Types: " << typesAttr << "\n\n"); + LDBG() << " * Types: " << typesAttr; assignRangeToMemory(typesAttr.getAsValueRange<TypeAttr>(), memIndex, rangeIndex); } void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, Location mainRewriteLoc) { - LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n"); + LDBG() << "Executing CreateOperation:"; unsigned memIndex = read(); OperationState state(mainRewriteLoc, read<OperationName>()); @@ -1696,45 +1676,37 @@ void ByteCodeExecutor::executeCreateOperation(PatternRewriter &rewriter, Operation *resultOp = rewriter.create(state); memory[memIndex] = resultOp; - LLVM_DEBUG({ - llvm::dbgs() << " * Attributes: " - << state.attributes.getDictionary(state.getContext()) - << "\n * Operands: "; - llvm::interleaveComma(state.operands, llvm::dbgs()); - llvm::dbgs() << "\n * Result Types: "; - llvm::interleaveComma(state.types, llvm::dbgs()); - llvm::dbgs() << "\n * Result: " << *resultOp << "\n"; - }); + LDBG() << " * Attributes: " + << state.attributes.getDictionary(state.getContext()) + << "\n * Operands: " << llvm::interleaved(state.operands) + << "\n * Result Types: " << 
llvm::interleaved(state.types) + << "\n * Result: " << *resultOp; } template <typename T> void ByteCodeExecutor::executeDynamicCreateRange(StringRef type) { - LLVM_DEBUG(llvm::dbgs() << "Executing CreateDynamic" << type << "Range:\n"); + LDBG() << "Executing CreateDynamic" << type << "Range:"; unsigned memIndex = read(); unsigned rangeIndex = read(); SmallVector<T> values; readList(values); - LLVM_DEBUG({ - llvm::dbgs() << "\n * " << type << "s: "; - llvm::interleaveComma(values, llvm::dbgs()); - llvm::dbgs() << "\n"; - }); + LDBG() << " * " << type << "s: " << llvm::interleaved(values); assignRangeToMemory(values, memIndex, rangeIndex); } void ByteCodeExecutor::executeEraseOp(PatternRewriter &rewriter) { - LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n"); + LDBG() << "Executing EraseOp:"; Operation *op = read<Operation *>(); - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); + LDBG() << " * Operation: " << *op; rewriter.eraseOp(op); } template <typename T, typename Range, PDLValue::Kind kind> void ByteCodeExecutor::executeExtract() { - LLVM_DEBUG(llvm::dbgs() << "Executing Extract" << kind << ":\n"); + LDBG() << "Executing Extract" << kind << ":"; Range *range = read<Range *>(); unsigned index = read<uint32_t>(); unsigned memIndex = read(); @@ -1745,18 +1717,16 @@ void ByteCodeExecutor::executeExtract() { } T result = index < range->size() ? (*range)[index] : T(); - LLVM_DEBUG(llvm::dbgs() << " * " << kind << "s(" << range->size() << ")\n" - << " * Index: " << index << "\n" - << " * Result: " << result << "\n"); + LDBG() << " * " << kind << "s(" << range->size() << ")"; + LDBG() << " * Index: " << index; + LDBG() << " * Result: " << result; storeToMemory(memIndex, result); } -void ByteCodeExecutor::executeFinalize() { - LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n"); -} +void ByteCodeExecutor::executeFinalize() { LDBG() << "Executing Finalize"; } void ByteCodeExecutor::executeForEach() { - LLVM_DEBUG(llvm::dbgs() << "Executing ForEach:\n"); + LDBG() << "Executing ForEach:"; const ByteCodeField *prevCodeIt = getPrevCodeIt(); unsigned rangeIndex = read(); unsigned memIndex = read(); @@ -1768,12 +1738,12 @@ void ByteCodeExecutor::executeForEach() { ArrayRef<Operation *> array = opRangeMemory[rangeIndex]; assert(index <= array.size() && "iterated past the end"); if (index < array.size()) { - LLVM_DEBUG(llvm::dbgs() << " * Result: " << array[index] << "\n"); + LDBG() << " * Result: " << array[index]; value = array[index]; break; } - LLVM_DEBUG(llvm::dbgs() << " * Done\n"); + LDBG() << " * Done"; index = 0; selectJump(size_t(0)); return; @@ -1791,49 +1761,47 @@ void ByteCodeExecutor::executeForEach() { } void ByteCodeExecutor::executeGetAttribute() { - LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n"); + LDBG() << "Executing GetAttribute:"; unsigned memIndex = read(); Operation *op = read<Operation *>(); StringAttr attrName = read<StringAttr>(); Attribute attr = op->getAttr(attrName); - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" - << " * Attribute: " << attrName << "\n" - << " * Result: " << attr << "\n"); + LDBG() << " * Operation: " << *op << "\n * Attribute: " << attrName + << "\n * Result: " << attr; memory[memIndex] = attr.getAsOpaquePointer(); } void ByteCodeExecutor::executeGetAttributeType() { - LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n"); + LDBG() << "Executing GetAttributeType:"; unsigned memIndex = read(); Attribute attr = read<Attribute>(); Type type; if (auto typedAttr = dyn_cast<TypedAttr>(attr)) type = typedAttr.getType(); - 
LLVM_DEBUG(llvm::dbgs() << " * Attribute: " << attr << "\n" - << " * Result: " << type << "\n"); + LDBG() << " * Attribute: " << attr << "\n * Result: " << type; memory[memIndex] = type.getAsOpaquePointer(); } void ByteCodeExecutor::executeGetDefiningOp() { - LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n"); + LDBG() << "Executing GetDefiningOp:"; unsigned memIndex = read(); Operation *op = nullptr; if (read<PDLValue::Kind>() == PDLValue::Kind::Value) { Value value = read<Value>(); if (value) op = value.getDefiningOp(); - LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + LDBG() << " * Value: " << value; } else { ValueRange *values = read<ValueRange *>(); if (values && !values->empty()) { op = values->front().getDefiningOp(); } - LLVM_DEBUG(llvm::dbgs() << " * Values: " << values << "\n"); + LDBG() << " * Values: " << values; } - LLVM_DEBUG(llvm::dbgs() << " * Result: " << op << "\n"); + LDBG() << " * Result: " << op; memory[memIndex] = op; } @@ -1843,9 +1811,8 @@ void ByteCodeExecutor::executeGetOperand(unsigned index) { Value operand = index < op->getNumOperands() ? op->getOperand(index) : Value(); - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" - << " * Index: " << index << "\n" - << " * Result: " << operand << "\n"); + LDBG() << " * Operation: " << *op << "\n * Index: " << index + << "\n * Result: " << operand; memory[memIndex] = operand.getAsOpaquePointer(); } @@ -1860,13 +1827,12 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index, // Check for the sentinel index that signals that all values should be // returned. if (index == std::numeric_limits<uint32_t>::max()) { - LLVM_DEBUG(llvm::dbgs() << " * Getting all values\n"); + LDBG() << " * Getting all values"; // `values` is already the full value range. // Otherwise, check to see if this operation uses AttrSizedSegments. } else if (op->hasTrait<AttrSizedSegmentsT>()) { - LLVM_DEBUG(llvm::dbgs() - << " * Extracting values from `" << attrSizedSegments << "`\n"); + LDBG() << " * Extracting values from `" << attrSizedSegments << "`"; auto segmentAttr = op->getAttrOfType<DenseI32ArrayAttr>(attrSizedSegments); if (!segmentAttr || segmentAttr.asArrayRef().size() <= index) @@ -1877,16 +1843,15 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index, std::accumulate(segments.begin(), segments.begin() + index, 0); values = values.slice(startIndex, *std::next(segments.begin(), index)); - LLVM_DEBUG(llvm::dbgs() << " * Extracting range[" << startIndex << ", " - << *std::next(segments.begin(), index) << "]\n"); + LDBG() << " * Extracting range[" << startIndex << ", " + << *std::next(segments.begin(), index) << "]"; // Otherwise, assume this is the last operand group of the operation. // FIXME: We currently don't support operations with // SameVariadicOperandSize/SameVariadicResultSize here given that we don't // have a way to detect it's presence. } else if (values.size() >= index) { - LLVM_DEBUG(llvm::dbgs() - << " * Treating values as trailing variadic range\n"); + LDBG() << " * Treating values as trailing variadic range"; values = values.drop_front(index); // If we couldn't detect a way to compute the values, bail out. 
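To make the segment arithmetic in executeGetOperandsResults above concrete, a small worked example (the segment values are invented for illustration):

// Suppose an op carries operandSegmentSizes = [2, 0, 3] and the bytecode asks
// for group index = 2:
//   startIndex = segments[0] + segments[1] = 2 + 0 = 2
//   length     = segments[2]               = 3
// so the extracted range is operands[2..5), i.e. the trailing operand group.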
@@ -1905,7 +1870,7 @@ executeGetOperandsResults(RangeT values, Operation *op, unsigned index, } void ByteCodeExecutor::executeGetOperands() { - LLVM_DEBUG(llvm::dbgs() << "Executing GetOperands:\n"); + LDBG() << "Executing GetOperands:"; unsigned index = read<uint32_t>(); Operation *op = read<Operation *>(); ByteCodeField rangeIndex = read(); @@ -1914,7 +1879,7 @@ void ByteCodeExecutor::executeGetOperands() { op->getOperands(), op, index, rangeIndex, "operandSegmentSizes", valueRangeMemory); if (!result) - LLVM_DEBUG(llvm::dbgs() << " * Invalid operand range\n"); + LDBG() << " * Invalid operand range"; memory[read()] = result; } @@ -1924,14 +1889,13 @@ void ByteCodeExecutor::executeGetResult(unsigned index) { OpResult result = index < op->getNumResults() ? op->getResult(index) : OpResult(); - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n" - << " * Index: " << index << "\n" - << " * Result: " << result << "\n"); + LDBG() << " * Operation: " << *op << "\n * Index: " << index + << "\n * Result: " << result; memory[memIndex] = result.getAsOpaquePointer(); } void ByteCodeExecutor::executeGetResults() { - LLVM_DEBUG(llvm::dbgs() << "Executing GetResults:\n"); + LDBG() << "Executing GetResults:"; unsigned index = read<uint32_t>(); Operation *op = read<Operation *>(); ByteCodeField rangeIndex = read(); @@ -1940,12 +1904,12 @@ void ByteCodeExecutor::executeGetResults() { op->getResults(), op, index, rangeIndex, "resultSegmentSizes", valueRangeMemory); if (!result) - LLVM_DEBUG(llvm::dbgs() << " * Invalid result range\n"); + LDBG() << " * Invalid result range"; memory[read()] = result; } void ByteCodeExecutor::executeGetUsers() { - LLVM_DEBUG(llvm::dbgs() << "Executing GetUsers:\n"); + LDBG() << "Executing GetUsers:"; unsigned memIndex = read(); unsigned rangeIndex = read(); OwningOpRange &range = opRangeMemory[rangeIndex]; @@ -1957,7 +1921,7 @@ void ByteCodeExecutor::executeGetUsers() { Value value = read<Value>(); if (!value) return; - LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + LDBG() << " * Value: " << value; // Extract the users of a single value. range = OwningOpRange(std::distance(value.user_begin(), value.user_end())); @@ -1967,11 +1931,8 @@ void ByteCodeExecutor::executeGetUsers() { ValueRange *values = read<ValueRange *>(); if (!values) return; - LLVM_DEBUG({ - llvm::dbgs() << " * Values (" << values->size() << "): "; - llvm::interleaveComma(*values, llvm::dbgs()); - llvm::dbgs() << "\n"; - }); + LDBG() << " * Values (" << values->size() + << "): " << llvm::interleaved(*values); // Extract all the users of a range of values. SmallVector<Operation *> users; @@ -1981,54 +1942,49 @@ void ByteCodeExecutor::executeGetUsers() { llvm::copy(users, range.begin()); } - LLVM_DEBUG(llvm::dbgs() << " * Result: " << range.size() << " operations\n"); + LDBG() << " * Result: " << range.size() << " operations"; } void ByteCodeExecutor::executeGetValueType() { - LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n"); + LDBG() << "Executing GetValueType:"; unsigned memIndex = read(); Value value = read<Value>(); Type type = value ? 
value.getType() : Type(); - LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n" - << " * Result: " << type << "\n"); + LDBG() << " * Value: " << value << "\n * Result: " << type; memory[memIndex] = type.getAsOpaquePointer(); } void ByteCodeExecutor::executeGetValueRangeTypes() { - LLVM_DEBUG(llvm::dbgs() << "Executing GetValueRangeTypes:\n"); + LDBG() << "Executing GetValueRangeTypes:"; unsigned memIndex = read(); unsigned rangeIndex = read(); ValueRange *values = read<ValueRange *>(); if (!values) { - LLVM_DEBUG(llvm::dbgs() << " * Values: <NULL>\n\n"); + LDBG() << " * Values: <NULL>"; memory[memIndex] = nullptr; return; } - LLVM_DEBUG({ - llvm::dbgs() << " * Values (" << values->size() << "): "; - llvm::interleaveComma(*values, llvm::dbgs()); - llvm::dbgs() << "\n * Result: "; - llvm::interleaveComma(values->getType(), llvm::dbgs()); - llvm::dbgs() << "\n"; - }); + LDBG() << " * Values (" << values->size() + << "): " << llvm::interleaved(*values) + << "\n * Result: " << llvm::interleaved(values->getType()); typeRangeMemory[rangeIndex] = values->getType(); memory[memIndex] = &typeRangeMemory[rangeIndex]; } void ByteCodeExecutor::executeIsNotNull() { - LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n"); + LDBG() << "Executing IsNotNull:"; const void *value = read<const void *>(); - LLVM_DEBUG(llvm::dbgs() << " * Value: " << value << "\n"); + LDBG() << " * Value: " << value; selectJump(value != nullptr); } void ByteCodeExecutor::executeRecordMatch( PatternRewriter &rewriter, SmallVectorImpl<PDLByteCode::MatchResult> &matches) { - LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n"); + LDBG() << "Executing RecordMatch:"; unsigned patternIndex = read(); PatternBenefit benefit = currentPatternBenefits[patternIndex]; const ByteCodeField *dest = &code[read<ByteCodeAddr>()]; @@ -2036,7 +1992,7 @@ void ByteCodeExecutor::executeRecordMatch( // If the benefit of the pattern is impossible, skip the processing of the // rest of the pattern. 
if (benefit.isImpossibleToMatch()) { - LLVM_DEBUG(llvm::dbgs() << " * Benefit: Impossible To Match\n"); + LDBG() << " * Benefit: Impossible To Match"; curCodeIt = dest; return; } @@ -2052,8 +2008,8 @@ void ByteCodeExecutor::executeRecordMatch( matchLocs.push_back(read<Operation *>()->getLoc()); Location matchLoc = rewriter.getFusedLoc(matchLocs); - LLVM_DEBUG(llvm::dbgs() << " * Benefit: " << benefit.getBenefit() << "\n" - << " * Location: " << matchLoc << "\n"); + LDBG() << " * Benefit: " << benefit.getBenefit(); + LDBG() << " * Location: " << matchLoc; matches.emplace_back(matchLoc, patterns[patternIndex], benefit); PDLByteCode::MatchResult &match = matches.back(); @@ -2083,38 +2039,34 @@ void ByteCodeExecutor::executeRecordMatch( } void ByteCodeExecutor::executeReplaceOp(PatternRewriter &rewriter) { - LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n"); + LDBG() << "Executing ReplaceOp:"; Operation *op = read<Operation *>(); SmallVector<Value, 16> args; readList(args); - LLVM_DEBUG({ - llvm::dbgs() << " * Operation: " << *op << "\n" - << " * Values: "; - llvm::interleaveComma(args, llvm::dbgs()); - llvm::dbgs() << "\n"; - }); + LDBG() << " * Operation: " << *op + << "\n * Values: " << llvm::interleaved(args); rewriter.replaceOp(op, args); } void ByteCodeExecutor::executeSwitchAttribute() { - LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n"); + LDBG() << "Executing SwitchAttribute:"; Attribute value = read<Attribute>(); ArrayAttr cases = read<ArrayAttr>(); handleSwitch(value, cases); } void ByteCodeExecutor::executeSwitchOperandCount() { - LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n"); + LDBG() << "Executing SwitchOperandCount:"; Operation *op = read<Operation *>(); auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); + LDBG() << " * Operation: " << *op; handleSwitch(op->getNumOperands(), cases); } void ByteCodeExecutor::executeSwitchOperationName() { - LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n"); + LDBG() << "Executing SwitchOperationName:"; OperationName value = read<Operation *>()->getName(); size_t caseCount = read(); @@ -2123,13 +2075,11 @@ void ByteCodeExecutor::executeSwitchOperationName() { // switch so that we can display all of the possible values. 
LLVM_DEBUG({ const ByteCodeField *prevCodeIt = curCodeIt; - llvm::dbgs() << " * Value: " << value << "\n" - << " * Cases: "; - llvm::interleaveComma( - llvm::map_range(llvm::seq<size_t>(0, caseCount), - [&](size_t) { return read<OperationName>(); }), - llvm::dbgs()); - llvm::dbgs() << "\n"; + LDBG() << " * Value: " << value << "\n * Cases: " + << llvm::interleaved( + llvm::map_range(llvm::seq<size_t>(0, caseCount), [&](size_t) { + return read<OperationName>(); + })); curCodeIt = prevCodeIt; }); @@ -2144,27 +2094,27 @@ void ByteCodeExecutor::executeSwitchOperationName() { } void ByteCodeExecutor::executeSwitchResultCount() { - LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n"); + LDBG() << "Executing SwitchResultCount:"; Operation *op = read<Operation *>(); auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>(); - LLVM_DEBUG(llvm::dbgs() << " * Operation: " << *op << "\n"); + LDBG() << " * Operation: " << *op; handleSwitch(op->getNumResults(), cases); } void ByteCodeExecutor::executeSwitchType() { - LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n"); + LDBG() << "Executing SwitchType:"; Type value = read<Type>(); auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>(); handleSwitch(value, cases); } void ByteCodeExecutor::executeSwitchTypes() { - LLVM_DEBUG(llvm::dbgs() << "Executing SwitchTypes:\n"); + LDBG() << "Executing SwitchTypes:"; TypeRange *value = read<TypeRange *>(); auto cases = read<ArrayAttr>().getAsRange<ArrayAttr>(); if (!value) { - LLVM_DEBUG(llvm::dbgs() << "Types: <NULL>\n"); + LDBG() << "Types: <NULL>"; return selectJump(size_t(0)); } handleSwitch(*value, cases, [](ArrayAttr caseValue, const TypeRange &value) { @@ -2178,7 +2128,7 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter, std::optional<Location> mainRewriteLoc) { while (true) { // Print the location of the operation being executed. - LLVM_DEBUG(llvm::dbgs() << readInline<Location>() << "\n"); + LDBG() << readInline<Location>(); OpCode opCode = static_cast<OpCode>(read()); switch (opCode) { @@ -2239,7 +2189,7 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter, break; case Finalize: executeFinalize(); - LLVM_DEBUG(llvm::dbgs() << "\n"); + LDBG() << ""; return success(); case ForEach: executeForEach(); @@ -2258,12 +2208,12 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter, case GetOperand2: case GetOperand3: { unsigned index = opCode - GetOperand0; - LLVM_DEBUG(llvm::dbgs() << "Executing GetOperand" << index << ":\n"); + LDBG() << "Executing GetOperand" << index << ":"; executeGetOperand(index); break; } case GetOperandN: - LLVM_DEBUG(llvm::dbgs() << "Executing GetOperandN:\n"); + LDBG() << "Executing GetOperandN:"; executeGetOperand(read<uint32_t>()); break; case GetOperands: @@ -2274,12 +2224,12 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter, case GetResult2: case GetResult3: { unsigned index = opCode - GetResult0; - LLVM_DEBUG(llvm::dbgs() << "Executing GetResult" << index << ":\n"); + LDBG() << "Executing GetResult" << index << ":"; executeGetResult(index); break; } case GetResultN: - LLVM_DEBUG(llvm::dbgs() << "Executing GetResultN:\n"); + LDBG() << "Executing GetResultN:"; executeGetResult(read<uint32_t>()); break; case GetResults: @@ -2324,7 +2274,7 @@ ByteCodeExecutor::execute(PatternRewriter &rewriter, executeSwitchTypes(); break; } - LLVM_DEBUG(llvm::dbgs() << "\n"); + LDBG() << ""; } } @@ -2383,7 +2333,7 @@ LogicalResult PDLByteCode::rewrite(PatternRewriter &rewriter, // bug in the user code (i.e. 
failable rewrites should not be used with // pattern rewriters that don't support it). if (failed(result) && !rewriter.canRecoverFromRewriteFailure()) { - LLVM_DEBUG(llvm::dbgs() << " and rollback is not supported - aborting"); + LDBG() << " and rollback is not supported - aborting"; llvm::report_fatal_error( "Native PDL Rewrite failed, but the pattern " "rewriter doesn't support recovery. Failable pattern rewrites should " diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp index e13bcff..23ae95a 100644 --- a/mlir/lib/Rewrite/PatternApplicator.cpp +++ b/mlir/lib/Rewrite/PatternApplicator.cpp @@ -37,9 +37,9 @@ PatternApplicator::~PatternApplicator() = default; #ifndef NDEBUG /// Log a message for a pattern that is impossible to match. static void logImpossibleToMatch(const Pattern &pattern) { - llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind() - << "' because it is impossible to match or cannot lead " - "to legal IR (by cost model)\n"; + LDBG() << "Ignoring pattern '" << pattern.getRootKind() + << "' because it is impossible to match or cannot lead " + "to legal IR (by cost model)"; } /// Log IR after pattern application. diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt index 6eb0abc..f0c3ac4 100644 --- a/mlir/lib/Target/CMakeLists.txt +++ b/mlir/lib/Target/CMakeLists.txt @@ -4,3 +4,4 @@ add_subdirectory(SPIRV) add_subdirectory(LLVMIR) add_subdirectory(LLVM) add_subdirectory(SMTLIB) +add_subdirectory(Wasm) diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index 8e83e45..570f38c 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -14,6 +14,7 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/IR/Value.h" #include "mlir/Support/IndentedOstream.h" #include "mlir/Support/LLVM.h" #include "mlir/Target/Cpp/CppEmitter.h" @@ -35,7 +36,7 @@ using llvm::formatv; /// on each element doesn't return a string. template <typename ForwardIterator, typename UnaryFunctor, typename NullaryFunctor> -inline LogicalResult +static inline LogicalResult interleaveWithError(ForwardIterator begin, ForwardIterator end, UnaryFunctor eachFn, NullaryFunctor betweenFn) { if (begin == end) @@ -52,16 +53,16 @@ interleaveWithError(ForwardIterator begin, ForwardIterator end, } template <typename Container, typename UnaryFunctor, typename NullaryFunctor> -inline LogicalResult interleaveWithError(const Container &c, - UnaryFunctor eachFn, - NullaryFunctor betweenFn) { +static inline LogicalResult interleaveWithError(const Container &c, + UnaryFunctor eachFn, + NullaryFunctor betweenFn) { return interleaveWithError(c.begin(), c.end(), eachFn, betweenFn); } template <typename Container, typename UnaryFunctor> -inline LogicalResult interleaveCommaWithError(const Container &c, - raw_ostream &os, - UnaryFunctor eachFn) { +static inline LogicalResult interleaveCommaWithError(const Container &c, + raw_ostream &os, + UnaryFunctor eachFn) { return interleaveWithError(c.begin(), c.end(), eachFn, [&]() { os << ", "; }); } @@ -364,9 +365,10 @@ static bool shouldBeInlined(ExpressionOp expressionOp) { if (hasDeferredEmission(user)) return false; - // Do not inline expressions used by ops with the CExpressionInterface. If - // this was intended, the user could have been merged into the expression op. 
- return !isa<emitc::CExpressionInterface>(*user); + // Do not inline expressions used by other expressions or by ops with the + // CExpressionInterface. If this was intended, the user could have been merged + // into the expression op. + return !isa<emitc::ExpressionOp, emitc::CExpressionInterface>(*user); } static LogicalResult printConstantOp(CppEmitter &emitter, Operation *operation, @@ -749,11 +751,7 @@ static LogicalResult printOperation(CppEmitter &emitter, if (t.getType().isIndex()) { int64_t idx = t.getInt(); Value operand = op.getOperand(idx); - if (!emitter.hasValueInScope(operand)) - return op.emitOpError("operand ") - << idx << "'s value not defined in scope"; - os << emitter.getOrCreateName(operand); - return success(); + return emitter.emitOperand(operand); } } if (failed(emitter.emitAttribute(op.getLoc(), attr))) @@ -782,9 +780,7 @@ static LogicalResult printOperation(CppEmitter &emitter, if (failed(emitter.emitAssignPrefix(op))) return failure(); os << applyOp.getApplicableOperator(); - os << emitter.getOrCreateName(applyOp.getOperand()); - - return success(); + return emitter.emitOperand(applyOp.getOperand()); } static LogicalResult printOperation(CppEmitter &emitter, @@ -1447,7 +1443,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { } if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) { if (auto iType = dyn_cast<IntegerType>( - cast<TensorType>(dense.getType()).getElementType())) { + cast<ShapedType>(dense.getType()).getElementType())) { os << '{'; interleaveComma(dense, os, [&](const APInt &val) { printInt(val, shouldMapToUnsigned(iType.getSignedness())); @@ -1456,7 +1452,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) { return success(); } if (auto iType = dyn_cast<IndexType>( - cast<TensorType>(dense.getType()).getElementType())) { + cast<ShapedType>(dense.getType()).getElementType())) { os << '{'; interleaveComma(dense, os, [&](const APInt &val) { printInt(val, false); }); @@ -1538,6 +1534,20 @@ LogicalResult CppEmitter::emitOperand(Value value) { if (expressionOp && shouldBeInlined(expressionOp)) return emitExpression(expressionOp); + if (BlockArgument arg = dyn_cast<BlockArgument>(value)) { + // If this operand is a block argument of an expression, emit instead the + // matching expression parameter. + Operation *argOp = arg.getParentBlock()->getParentOp(); + if (auto expressionOp = dyn_cast<ExpressionOp>(argOp)) { + // This scenario is only expected when one of the operations within the + // expression being emitted references one of the expression's block + // arguments. 
+ assert(expressionOp == emittedExpression && + "Expected expression being emitted"); + value = expressionOp->getOperand(arg.getArgNumber()); + } + } + os << getOrCreateName(value); return success(); } @@ -1793,7 +1803,7 @@ LogicalResult CppEmitter::emitType(Location loc, Type type) { case 16: { if (llvm::isa<Float16Type>(type)) return (os << "_Float16"), success(); - else if (llvm::isa<BFloat16Type>(type)) + if (llvm::isa<BFloat16Type>(type)) return (os << "__bf16"), success(); else return emitError(loc, "cannot emit float type ") << type; diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt index f6e44c6..9a0e4d4 100644 --- a/mlir/lib/Target/LLVM/CMakeLists.txt +++ b/mlir/lib/Target/LLVM/CMakeLists.txt @@ -210,3 +210,27 @@ if(MLIR_ENABLE_ROCM_CONVERSIONS) ) endif() +if ("SPIRV" IN_LIST LLVM_TARGETS_TO_BUILD) + set(SPIRV_LIBS + SPIRVCodeGen + SPIRVDesc + SPIRVInfo + ) +endif() + +add_mlir_dialect_library(MLIRXeVMTarget + XeVM/Target.cpp + + OBJECT + + LINK_COMPONENTS + ${SPIRV_LIBS} + + LINK_LIBS PUBLIC + MLIRIR + MLIRExecutionEngineUtils + MLIRSupport + MLIRGPUDialect + MLIRTargetLLVM + MLIRXeVMToLLVMIRTranslation +) diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp index 55c8a64..8760ea8 100644 --- a/mlir/lib/Target/LLVM/NVVM/Target.cpp +++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp @@ -24,9 +24,11 @@ #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Export.h" +#include "llvm/Support/InterleavedRange.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Config/Targets.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/FileSystem.h" #include "llvm/Support/FileUtilities.h" #include "llvm/Support/FormatVariadic.h" @@ -265,6 +267,8 @@ NVPTXSerializer::NVPTXSerializer(Operation &module, NVVMTargetAttr target, std::optional<NVPTXSerializer::TmpFile> NVPTXSerializer::createTemp(StringRef name, StringRef suffix) { llvm::SmallString<128> filename; + if (name.size() > 80) + name = name.substr(0, 80); std::error_code ec = llvm::sys::fs::createTemporaryFile(name, suffix, filename); if (ec) { @@ -452,17 +456,11 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) { // Dump tool invocation commands. #define DEBUG_TYPE "serialize-to-binary" - LLVM_DEBUG({ - llvm::dbgs() << "Tool invocation for module: " - << getOperation().getNameAttr() << "\n"; - llvm::dbgs() << "ptxas executable:" << ptxasCompiler.value() << "\n"; - llvm::interleave(ptxasArgs, llvm::dbgs(), " "); - llvm::dbgs() << "\n"; - if (createFatbin) { - llvm::interleave(fatbinArgs, llvm::dbgs(), " "); - llvm::dbgs() << "\n"; - } - }); + LDBG() << "Tool invocation for module: " << getOperation().getNameAttr() + << "\nptxas executable:" << ptxasCompiler.value() + << "\nptxas args: " << llvm::interleaved(ptxasArgs, " "); + if (createFatbin) + LDBG() << "fatbin args: " << llvm::interleaved(fatbinArgs, " "); #undef DEBUG_TYPE // Helper function for printing tool error logs. 
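// Editor's aside, not part of the patch: a minimal sketch of the logging idiom
// these hunks migrate to, assuming LDBG() from llvm/Support/DebugLog.h and
// llvm::interleaved() from llvm/Support/InterleavedRange.h (both #included by
// this change). LDBG() ties the message to DEBUG_TYPE and terminates it with a
// newline on its own, which is why the explicit llvm::dbgs() plumbing, the
// trailing "\n"s, and the flush() calls disappear. The helper function below is
// illustrative only.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "serialize-to-binary"
static void logToolInvocation(llvm::StringRef tool,
                              llvm::ArrayRef<llvm::StringRef> args) {
  // One statement replaces the old LLVM_DEBUG block; the argument range is
  // printed space-separated and the newline comes from LDBG() itself.
  LDBG() << "Invoking " << tool << ": " << llvm::interleaved(args, " ");
}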
@@ -507,7 +505,7 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) { llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer = llvm::MemoryBuffer::getFile(logFile->first); if (logBuffer && !(*logBuffer)->getBuffer().empty()) { - llvm::dbgs() << "Output:\n" << (*logBuffer)->getBuffer() << "\n"; + LDBG() << "Output:\n" << (*logBuffer)->getBuffer(); llvm::dbgs().flush(); } }); @@ -529,7 +527,7 @@ NVPTXSerializer::compileToBinary(const std::string &ptxCode) { llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> logBuffer = llvm::MemoryBuffer::getFile(logFile->first); if (logBuffer && !(*logBuffer)->getBuffer().empty()) { - llvm::dbgs() << "Output:\n" << (*logBuffer)->getBuffer() << "\n"; + LDBG() << "Output:\n" << (*logBuffer)->getBuffer(); llvm::dbgs().flush(); } }); @@ -629,12 +627,11 @@ NVPTXSerializer::compileToBinaryNVPTX(const std::string &ptxCode) { SmallVector<char> log(logSize + 1, 0); RETURN_ON_NVPTXCOMPILER_ERROR( nvPTXCompilerGetInfoLog(compiler, log.data())); - llvm::dbgs() << "NVPTX compiler invocation for module: " - << getOperation().getNameAttr() << "\n"; - llvm::dbgs() << "Arguments: "; - llvm::interleave(cmdOpts.second, llvm::dbgs(), " "); - llvm::dbgs() << "\nOutput\n" << log.data() << "\n"; - llvm::dbgs().flush(); + LDBG() << "NVPTX compiler invocation for module: " + << getOperation().getNameAttr() + << "\nArguments: " << llvm::interleaved(cmdOpts.second, " ") + << "\nOutput\n" + << log.data(); } }); #undef DEBUG_TYPE @@ -678,10 +675,8 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) { // Return LLVM IR if the compilation target is `offload`. #define DEBUG_TYPE "serialize-to-llvm" LLVM_DEBUG({ - llvm::dbgs() << "LLVM IR for module: " << getOperation().getNameAttr() - << "\n"; - llvm::dbgs() << llvmModule << "\n"; - llvm::dbgs().flush(); + LDBG() << "LLVM IR for module: " << getOperation().getNameAttr(); + LDBG() << llvmModule; }); #undef DEBUG_TYPE if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Offload) @@ -716,11 +711,8 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) { isaCallback(serializedISA.value()); #define DEBUG_TYPE "serialize-to-isa" - LLVM_DEBUG({ - llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n"; - llvm::dbgs() << *serializedISA << "\n"; - llvm::dbgs().flush(); - }); + LDBG() << "PTX for module: " << getOperation().getNameAttr() << "\n" + << *serializedISA; #undef DEBUG_TYPE // Return PTX if the compilation target is `assembly`. diff --git a/mlir/lib/Target/LLVM/XeVM/Target.cpp b/mlir/lib/Target/LLVM/XeVM/Target.cpp new file mode 100644 index 0000000..1e6784a2 --- /dev/null +++ b/mlir/lib/Target/LLVM/XeVM/Target.cpp @@ -0,0 +1,418 @@ +//===- Target.cpp - MLIR LLVM XeVM target compilation -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines XeVM target related functions including registration +// calls for the `#xevm.target` compilation attribute.
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVM/XeVM/Target.h" + +#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/DialectResourceBlobManager.h" +#include "mlir/Target/LLVM/XeVM/Utils.h" +#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Export.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/Target/TargetMachine.h" + +#include "llvm/Bitcode/BitcodeWriter.h" +#include "llvm/Config/Targets.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/FileUtilities.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" +#include "llvm/Support/Process.h" +#include "llvm/Support/Program.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/raw_ostream.h" + +#include <cstdint> +#include <cstdlib> + +using namespace mlir; +using namespace mlir::xevm; + +namespace { +// XeVM implementation of the gpu::TargetAttrInterface. +class XeVMTargetAttrImpl + : public gpu::TargetAttrInterface::FallbackModel<XeVMTargetAttrImpl> { +public: + std::optional<SmallVector<char, 0>> + serializeToObject(Attribute attribute, Operation *module, + const gpu::TargetOptions &options) const; + + Attribute createObject(Attribute attribute, Operation *module, + const SmallVector<char, 0> &object, + const gpu::TargetOptions &options) const; +}; +} // namespace + +void mlir::xevm::registerXeVMTargetInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) { + XeVMTargetAttr::attachInterface<XeVMTargetAttrImpl>(*ctx); + }); +} + +void mlir::xevm::registerXeVMTargetInterfaceExternalModels( + MLIRContext &context) { + DialectRegistry registry; + registerXeVMTargetInterfaceExternalModels(registry); + context.appendDialectRegistry(registry); +} + +SerializeGPUModuleBase::SerializeGPUModuleBase( + Operation &module, XeVMTargetAttr xeTarget, + const gpu::TargetOptions &targetOptions) + : ModuleToObject(module, xeTarget.getTriple(), "", {}, xeTarget.getO()), + xeTarget(xeTarget), librariesToLink(targetOptions.getLibrariesToLink()), + targetOptions(targetOptions) { + if (xeTarget.getLinkFiles()) + librariesToLink.append(xeTarget.getLinkFiles().begin(), + xeTarget.getLinkFiles().end()); +} + +XeVMTargetAttr SerializeGPUModuleBase::getTarget() const { return xeTarget; } + +std::optional<SmallVector<std::unique_ptr<llvm::Module>>> +SerializeGPUModuleBase::loadBitcodeFiles(llvm::Module &module) { + if (librariesToLink.empty()) + return SmallVector<std::unique_ptr<llvm::Module>>(); + SmallVector<std::unique_ptr<llvm::Module>> bcFiles; + if (failed(loadBitcodeFilesFromList(module.getContext(), librariesToLink, + bcFiles))) + return std::nullopt; + return std::move(bcFiles); +} + +gpu::GPUModuleOp SerializeGPUModuleBase::getGPUModuleOp() { + return dyn_cast<gpu::GPUModuleOp>(&SerializeGPUModuleBase::getOperation()); +} + +// There is 1 way to finalize IL to native code: IGC +// There are 2 ways to access IGC: AOT (ocloc) and JIT (L0 runtime).
+// - L0 runtime consumes IL and is external to MLIR codebase (rt wrappers). +// - `ocloc` tool can be "queried" from within MLIR. +std::optional<SmallVector<char, 0>> +SerializeGPUModuleBase::compileToBinary(const std::string &asmStr, + StringRef inputFormat) { + using TmpFile = std::pair<llvm::SmallString<128>, llvm::FileRemover>; + // Find the `ocloc` tool. + std::optional<std::string> oclocCompiler = findTool("ocloc"); + if (!oclocCompiler) + return std::nullopt; + Location loc = getGPUModuleOp().getLoc(); + std::string basename = llvm::formatv( + "mlir-{0}-{1}-{2}", getGPUModuleOp().getNameAttr().getValue(), + getTarget().getTriple(), getTarget().getChip()); + + auto createTemp = [&](StringRef name, + StringRef suffix) -> std::optional<TmpFile> { + llvm::SmallString<128> filePath; + if (auto ec = llvm::sys::fs::createTemporaryFile(name, suffix, filePath)) { + getGPUModuleOp().emitError() + << "Couldn't create the temp file: `" << filePath + << "`, error message: " << ec.message(); + return std::nullopt; + } + return TmpFile(filePath, llvm::FileRemover(filePath.c_str())); + }; + // Create temp file + std::optional<TmpFile> asmFile = createTemp(basename, "asm"); + std::optional<TmpFile> binFile = createTemp(basename, ""); + std::optional<TmpFile> logFile = createTemp(basename, "log"); + if (!logFile || !asmFile || !binFile) + return std::nullopt; + // Dump the assembly to a temp file + std::error_code ec; + { + llvm::raw_fd_ostream asmStream(asmFile->first, ec); + if (ec) { + emitError(loc) << "Couldn't open the file: `" << asmFile->first + << "`, error message: " << ec.message(); + return std::nullopt; + } + asmStream << asmStr; + if (asmStream.has_error()) { + emitError(loc) << "An error occurred while writing the assembly to: `" + << asmFile->first << "`."; + return std::nullopt; + } + asmStream.flush(); + } + // Set cmd options + std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts = + targetOptions.tokenizeCmdOptions(); + // Example: --gpu-module-to-binary="opts='opt1 opt2'" + const std::string cmdOptsStr = "\"" + llvm::join(cmdOpts.second, " ") + "\""; + SmallVector<StringRef, 12> oclocArgs( + {"ocloc", "compile", "-file", asmFile->first, inputFormat, "-device", + getTarget().getChip(), "-output", binFile->first, "-output_no_suffix", + "-options", cmdOptsStr}); + +// Dump tool invocation commands. +#define DEBUG_TYPE "serialize-to-binary" + LLVM_DEBUG({ + llvm::dbgs() << "Tool invocation for module: " + << getGPUModuleOp().getNameAttr() << "\n"; + llvm::interleave(oclocArgs, llvm::dbgs(), " "); + llvm::dbgs() << "\n"; + }); +#undef DEBUG_TYPE + // Helper function for printing tool error logs. + std::string message; + auto emitLogError = + [&](StringRef toolName) -> std::optional<SmallVector<char, 0>> { + if (message.empty()) { + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> toolStderr = + llvm::MemoryBuffer::getFile(logFile->first); + if (toolStderr) + emitError(loc) << toolName << " invocation failed. Log:\n" + << toolStderr->get()->getBuffer(); + else + emitError(loc) << toolName << " invocation failed."; + return std::nullopt; + } + emitError(loc) << toolName + << " invocation failed, error message: " << message; + return std::nullopt; + }; + std::optional<StringRef> redirects[] = { + std::nullopt, + logFile->first, + logFile->first, + }; + // Invoke ocloc. 
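// Editor's note, not part of the patch: in the ExecuteAndWait call below the three entries of `redirects` map to stdin, stdout, and stderr in that order, so std::nullopt leaves stdin untouched while both output streams are captured in `logFile`; emitLogError() reads that file back whenever ocloc exits with a non-zero status.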
+ if (llvm::sys::ExecuteAndWait(oclocCompiler.value(), oclocArgs, std::nullopt, + redirects, 0, 0, &message)) + return emitLogError("`ocloc`"); + binFile->first.append(".bin"); + llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> binaryBuffer = + llvm::MemoryBuffer::getFile(binFile->first); + if (!binaryBuffer) { + emitError(loc) << "Couldn't open the file: `" << binFile->first + << "`, error message: " << binaryBuffer.getError().message(); + return std::nullopt; + } + StringRef bin = (*binaryBuffer)->getBuffer(); + return SmallVector<char, 0>(bin.begin(), bin.end()); +} + +std::optional<std::string> SerializeGPUModuleBase::findTool(StringRef tool) { + // 1. Check the toolkit path given in the command line. + StringRef pathRef = targetOptions.getToolkitPath(); + SmallVector<char, 256> path; + if (!pathRef.empty()) { + path.insert(path.begin(), pathRef.begin(), pathRef.end()); + llvm::sys::path::append(path, "bin", tool); + if (llvm::sys::fs::can_execute(path)) + return StringRef(path.data(), path.size()).str(); + } + // 2. Check PATH. + if (std::optional<std::string> toolPath = + llvm::sys::Process::FindInEnvPath("PATH", tool)) + return *toolPath; + + getGPUModuleOp().emitError() + << "Couldn't find the `" << tool + << "` binary. Please specify the toolkit " + "path via GpuModuleToBinaryPass or add the compiler to $PATH`."; + return std::nullopt; +} + +namespace { +class SPIRVSerializer : public SerializeGPUModuleBase { +public: + SPIRVSerializer(Operation &module, XeVMTargetAttr xeTarget, + const gpu::TargetOptions &targetOptions) + : SerializeGPUModuleBase(module, xeTarget, targetOptions) {} + + static void init(); + + /// Serializes the LLVM module to an object format, depending on the + /// compilation target selected in target options. + std::optional<SmallVector<char, 0>> + moduleToObject(llvm::Module &llvmModule) override; + +private: + /// Translates the LLVM module to SPIR-V binary using LLVM's + /// SPIR-V target. + std::optional<std::string> + translateToSPIRVBinary(llvm::Module &llvmModule, + llvm::TargetMachine &targetMachine); +}; +} // namespace + +void SPIRVSerializer::init() { + static llvm::once_flag initializeBackendOnce; + llvm::call_once(initializeBackendOnce, []() { +#if LLVM_HAS_SPIRV_TARGET + LLVMInitializeSPIRVTarget(); + LLVMInitializeSPIRVTargetInfo(); + LLVMInitializeSPIRVTargetMC(); + LLVMInitializeSPIRVAsmPrinter(); +#endif + }); +} + +std::optional<SmallVector<char, 0>> +SPIRVSerializer::moduleToObject(llvm::Module &llvmModule) { +#define DEBUG_TYPE "serialize-to-llvm" + LLVM_DEBUG({ + llvm::dbgs() << "LLVM IR for module: " << getGPUModuleOp().getNameAttr() + << "\n"; + llvm::dbgs() << llvmModule << "\n"; + llvm::dbgs().flush(); + }); +#undef DEBUG_TYPE + + // Return LLVM IR if the compilation target is `offload`. + if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Offload) + return SerializeGPUModuleBase::moduleToObject(llvmModule); + +#if !LLVM_HAS_SPIRV_TARGET + getGPUModuleOp()->emitError("The `SPIRV` target was not built. Please enable " + "it when building LLVM."); + return std::nullopt; +#endif // LLVM_HAS_SPIRV_TARGET + + std::optional<llvm::TargetMachine *> targetMachine = + getOrCreateTargetMachine(); + if (!targetMachine) { + getGPUModuleOp().emitError() << "Target Machine unavailable for triple " + << triple << ", can't optimize with LLVM\n"; + return std::nullopt; + } + + // Return SPIRV if the compilation target is `assembly`. 
+ if (targetOptions.getCompilationTarget() == + gpu::CompilationTarget::Assembly) { + std::optional<std::string> serializedISA = + translateToISA(llvmModule, **targetMachine); + if (!serializedISA) { + getGPUModuleOp().emitError() << "Failed translating the module to ISA." + << triple << ", can't compile with LLVM\n"; + return std::nullopt; + } + +#define DEBUG_TYPE "serialize-to-isa" + LLVM_DEBUG({ + llvm::dbgs() << "SPIR-V for module: " << getGPUModuleOp().getNameAttr() + << "\n"; + llvm::dbgs() << *serializedISA << "\n"; + llvm::dbgs().flush(); + }); +#undef DEBUG_TYPE + + // Make sure to include the null terminator. + StringRef bin(serializedISA->c_str(), serializedISA->size() + 1); + return SmallVector<char, 0>(bin.begin(), bin.end()); + } + + // Level zero runtime is set up to accept SPIR-V binary + // translateToSPIRVBinary translates the LLVM module to SPIR-V binary + // using LLVM's SPIRV target. + // compileToBinary can be used in the future if level zero runtime + // implementation switches to native XeVM binary format. + std::optional<std::string> serializedSPIRVBinary = + translateToSPIRVBinary(llvmModule, **targetMachine); + if (!serializedSPIRVBinary) { + getGPUModuleOp().emitError() << "Failed translating the module to Binary."; + return std::nullopt; + } + if (serializedSPIRVBinary->size() % 4) { + getGPUModuleOp().emitError() << "SPIRV code size must be a multiple of 4."; + return std::nullopt; + } + StringRef bin(serializedSPIRVBinary->c_str(), serializedSPIRVBinary->size()); + return SmallVector<char, 0>(bin.begin(), bin.end()); +} + +std::optional<std::string> +SPIRVSerializer::translateToSPIRVBinary(llvm::Module &llvmModule, + llvm::TargetMachine &targetMachine) { + std::string targetISA; + llvm::raw_string_ostream stream(targetISA); + + { // Drop pstream after this to prevent the ISA from being stuck buffering + llvm::buffer_ostream pstream(stream); + llvm::legacy::PassManager codegenPasses; + if (targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr, + llvm::CodeGenFileType::ObjectFile)) + return std::nullopt; + + codegenPasses.run(llvmModule); + } + return targetISA; +} + +std::optional<SmallVector<char, 0>> +XeVMTargetAttrImpl::serializeToObject(Attribute attribute, Operation *module, + const gpu::TargetOptions &options) const { + if (!module) + return std::nullopt; + auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module); + if (!gpuMod) { + module->emitError("expected to be a gpu.module op"); + return std::nullopt; + } + auto xeTarget = cast<XeVMTargetAttr>(attribute); + if (xeTarget.getTriple().starts_with("spirv")) { + gpuMod.walk([&](LLVM::LLVMFuncOp funcOp) { + if (funcOp->hasAttr(gpu::GPUDialect::getKernelFuncAttrName())) { + funcOp.setIntelReqdSubGroupSize(16); + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + SPIRVSerializer serializer(*module, cast<XeVMTargetAttr>(attribute), + options); + serializer.init(); + +#if !LLVM_HAS_SPIRV_TARGET + module->emitError("Cannot run `TargetRegistry::lookupTarget()` for SPIRV " + "without having the target built."); +#endif + + return serializer.run(); + } + module->emitError("Unsupported XeVM target triple: ") << xeTarget.getTriple(); + return std::nullopt; +} + +Attribute +XeVMTargetAttrImpl::createObject(Attribute attribute, Operation *module, + const SmallVector<char, 0> &object, + const gpu::TargetOptions &options) const { + Builder builder(attribute.getContext()); + gpu::CompilationTarget format = options.getCompilationTarget(); + auto xeTarget = cast<XeVMTargetAttr>(attribute); 
+ SmallVector<NamedAttribute, 2> properties; + if (format == gpu::CompilationTarget::Assembly) + properties.push_back( + builder.getNamedAttr("O", builder.getI32IntegerAttr(xeTarget.getO()))); + + DictionaryAttr objectProps; + if (!properties.empty()) + objectProps = builder.getDictionaryAttr(properties); + + return builder.getAttr<gpu::ObjectAttr>( + attribute, format, + builder.getStringAttr(StringRef(object.data(), object.size())), + objectProps, /*kernels=*/nullptr); +} diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt index 9ea5c683..a73a78d 100644 --- a/mlir/lib/Target/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(Dialect) +add_subdirectory(Transforms) set(LLVM_OPTIONAL_SOURCES ConvertFromLLVMIR.cpp @@ -58,6 +59,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration MLIROpenACCToLLVMIRTranslation MLIROpenMPToLLVMIRTranslation MLIRROCDLToLLVMIRTranslation + MLIRPtrToLLVMIRTranslation MLIRSPIRVToLLVMIRTranslation MLIRVCIXToLLVMIRTranslation MLIRXeVMToLLVMIRTranslation diff --git a/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp b/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp index fbad5c2..8bd07cd 100644 --- a/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp +++ b/mlir/lib/Target/LLVMIR/DataLayoutImporter.cpp @@ -6,13 +6,14 @@ // //===----------------------------------------------------------------------===// -#include "DataLayoutImporter.h" +#include "mlir/Target/LLVMIR/DataLayoutImporter.h" #include "mlir/Dialect/DLTI/DLTI.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" #include "mlir/Target/LLVMIR/Import.h" + #include "llvm/IR/DataLayout.h" using namespace mlir; @@ -274,101 +275,88 @@ DataLayoutImporter::tryToEmplaceLegalIntWidthsEntry(StringRef token) { return success(); } -void DataLayoutImporter::translateDataLayout( - const llvm::DataLayout &llvmDataLayout) { - dataLayout = {}; - - // Transform the data layout to its string representation and append the - // default data layout string specified in the language reference - // (https://llvm.org/docs/LangRef.html#data-layout). The translation then - // parses the string and ignores the default value if a specific kind occurs - // in both strings. Additionally, the following default values exist: - // - non-default address space pointer specifications default to the default - // address space pointer specification - // - the alloca address space defaults to the default address space. - layoutStr = llvmDataLayout.getStringRepresentation(); - if (!layoutStr.empty()) - layoutStr += "-"; - layoutStr += kDefaultDataLayout; - StringRef layout(layoutStr); +DataLayoutSpecInterface DataLayoutImporter::dataLayoutSpecFromDataLayoutStr() { + if (!dataLayoutStr.empty()) + dataLayoutStr += "-"; + dataLayoutStr += kDefaultDataLayout; // Split the data layout string into tokens separated by a dash. SmallVector<StringRef> tokens; - layout.split(tokens, '-'); + StringRef(dataLayoutStr).split(tokens, '-'); for (StringRef token : tokens) { lastToken = token; FailureOr<StringRef> prefix = tryToParseAlphaPrefix(token); if (failed(prefix)) - return; + return {}; // Parse the endianness. 
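// Editor's note, not part of the patch: the tokens handled below come from
// splitting an LLVM data layout string on '-'. For example, the string
// "e-m:e-p:64:64-i64:64-n8:16:32:64-S128" yields an endianness token ("e"), a
// mangling token ("m:e"), a pointer alignment token ("p:64:64"), an integer
// alignment token ("i64:64"), the native integer widths ("n8:16:32:64"), and
// the stack alignment in bits ("S128").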
if (*prefix == "e") { if (failed(tryToEmplaceEndiannessEntry( DLTIDialect::kDataLayoutEndiannessLittle, token))) - return; + return {}; continue; } if (*prefix == "E") { if (failed(tryToEmplaceEndiannessEntry( DLTIDialect::kDataLayoutEndiannessBig, token))) - return; + return {}; continue; } // Parse the program address space. if (*prefix == "P") { if (failed(tryToEmplaceAddrSpaceEntry( token, DLTIDialect::kDataLayoutProgramMemorySpaceKey))) - return; + return {}; continue; } // Parse the mangling mode. if (*prefix == "m") { if (failed(tryToEmplaceManglingModeEntry( token, DLTIDialect::kDataLayoutManglingModeKey))) - return; + return {}; continue; } // Parse the global address space. if (*prefix == "G") { if (failed(tryToEmplaceAddrSpaceEntry( token, DLTIDialect::kDataLayoutGlobalMemorySpaceKey))) - return; + return {}; continue; } // Parse the alloca address space. if (*prefix == "A") { if (failed(tryToEmplaceAddrSpaceEntry( token, DLTIDialect::kDataLayoutAllocaMemorySpaceKey))) - return; + return {}; continue; } // Parse the stack alignment. if (*prefix == "S") { if (failed(tryToEmplaceStackAlignmentEntry(token))) - return; + return {}; continue; } // Parse integer alignment specifications. if (*prefix == "i") { FailureOr<uint64_t> width = tryToParseInt(token); if (failed(width)) - return; + return {}; Type type = IntegerType::get(context, *width); if (failed(tryToEmplaceAlignmentEntry(type, token))) - return; + return {}; continue; } // Parse float alignment specifications. if (*prefix == "f") { FailureOr<uint64_t> width = tryToParseInt(token); if (failed(width)) - return; + return {}; Type type = getFloatType(context, *width); if (failed(tryToEmplaceAlignmentEntry(type, token))) - return; + return {}; continue; } // Parse pointer alignment specifications. @@ -376,17 +364,17 @@ void DataLayoutImporter::translateDataLayout( FailureOr<uint64_t> space = token.starts_with(":") ? 0 : tryToParseInt(token); if (failed(space)) - return; + return {}; auto type = LLVMPointerType::get(context, *space); if (failed(tryToEmplacePointerAlignmentEntry(type, token))) - return; + return {}; continue; } // Parse native integer widths specifications. if (*prefix == "n") { if (failed(tryToEmplaceLegalIntWidthsEntry(token))) - return; + return {}; continue; } // Parse function pointer alignment specifications. @@ -394,7 +382,7 @@ void DataLayoutImporter::translateDataLayout( if (prefix->starts_with("F")) { StringRef nextPrefix = prefix->drop_front(1); if (failed(tryToEmplaceFunctionPointerAlignmentEntry(nextPrefix, token))) - return; + return {}; continue; } @@ -409,11 +397,12 @@ void DataLayoutImporter::translateDataLayout( entries.push_back(it.second); for (const auto &it : keyEntries) entries.push_back(it.second); - dataLayout = DataLayoutSpecAttr::get(context, entries); + return DataLayoutSpecAttr::get(context, entries); } DataLayoutSpecInterface mlir::translateDataLayout(const llvm::DataLayout &dataLayout, MLIRContext *context) { - return DataLayoutImporter(context, dataLayout).getDataLayout(); + return DataLayoutImporter(context, dataLayout.getStringRepresentation()) + .getDataLayoutSpec(); } diff --git a/mlir/lib/Target/LLVMIR/DataLayoutImporter.h b/mlir/lib/Target/LLVMIR/DataLayoutImporter.h deleted file mode 100644 index 88ceaf1..0000000 --- a/mlir/lib/Target/LLVMIR/DataLayoutImporter.h +++ /dev/null @@ -1,132 +0,0 @@ -//===- DataLayoutImporter.h - LLVM to MLIR data layout conversion -*- C++ -*-=// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements the translation between the LLVMIR data layout and the -// corresponding MLIR representation. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_LIB_TARGET_LLVMIR_DATALAYOUTIMPORTER_H_ -#define MLIR_LIB_TARGET_LLVMIR_DATALAYOUTIMPORTER_H_ - -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/Interfaces/DataLayoutInterfaces.h" -#include "llvm/ADT/MapVector.h" - -namespace llvm { -class StringRef; -class DataLayout; -} // namespace llvm - -namespace mlir { -class FloatType; -class MLIRContext; -class Operation; - -namespace LLVM { -class LLVMFuncOp; - -namespace detail { - -/// Returns a supported MLIR floating point type of the given bit width or -/// null if the bit width is not supported. -FloatType getFloatType(MLIRContext *context, unsigned width); - -/// Helper class that translates an LLVM data layout to an MLIR data layout -/// specification. Only integer, float, pointer, alloca memory space, stack -/// alignment, and endianness entries are translated. The class also returns all -/// entries from the default data layout specification found in the language -/// reference (https://llvm.org/docs/LangRef.html#data-layout) if they are not -/// overwritten by the provided data layout. -class DataLayoutImporter { -public: - DataLayoutImporter(MLIRContext *context, - const llvm::DataLayout &llvmDataLayout) - : context(context) { - translateDataLayout(llvmDataLayout); - } - - /// Returns the MLIR data layout specification translated from the LLVM - /// data layout. - DataLayoutSpecInterface getDataLayout() const { return dataLayout; } - - /// Returns the last data layout token that has been processed before - /// the data layout translation failed. - StringRef getLastToken() const { return lastToken; } - - /// Returns the data layout tokens that have not been handled during the - /// data layout translation. - ArrayRef<StringRef> getUnhandledTokens() const { return unhandledTokens; } - -private: - /// Translates the LLVM `dataLayout` to an MLIR data layout specification. - void translateDataLayout(const llvm::DataLayout &llvmDataLayout); - - /// Tries to parse the letter only prefix that identifies the specification - /// and removes the consumed characters from the beginning of the string. - FailureOr<StringRef> tryToParseAlphaPrefix(StringRef &token) const; - - /// Tries to parse an integer parameter and removes the integer from the - /// beginning of the string. - FailureOr<uint64_t> tryToParseInt(StringRef &token) const; - - /// Tries to parse an integer parameter array. - FailureOr<SmallVector<uint64_t>> tryToParseIntList(StringRef token) const; - - /// Tries to parse the parameters of a type alignment entry. - FailureOr<DenseIntElementsAttr> tryToParseAlignment(StringRef token) const; - - /// Tries to parse the parameters of a pointer alignment entry. - FailureOr<DenseIntElementsAttr> - tryToParsePointerAlignment(StringRef token) const; - - /// Adds a type alignment entry if there is none yet. - LogicalResult tryToEmplaceAlignmentEntry(Type type, StringRef token); - - /// Adds a pointer alignment entry if there is none yet. 
- LogicalResult tryToEmplacePointerAlignmentEntry(LLVMPointerType type, - StringRef token); - - /// Adds an endianness entry if there is none yet. - LogicalResult tryToEmplaceEndiannessEntry(StringRef endianness, - StringRef token); - - /// Adds an alloca address space entry if there is none yet. - LogicalResult tryToEmplaceAddrSpaceEntry(StringRef token, - llvm::StringLiteral spaceKey); - - /// Adds an mangling mode entry if there is none yet. - LogicalResult tryToEmplaceManglingModeEntry(StringRef token, - llvm::StringLiteral manglingKey); - - /// Adds a stack alignment entry if there is none yet. - LogicalResult tryToEmplaceStackAlignmentEntry(StringRef token); - - /// Adds a function pointer alignment entry if there is none yet. - LogicalResult - tryToEmplaceFunctionPointerAlignmentEntry(StringRef fnPtrAlignEntry, - StringRef token); - - /// Adds legal int widths entry if there is none yet. - LogicalResult tryToEmplaceLegalIntWidthsEntry(StringRef token); - - std::string layoutStr = {}; - StringRef lastToken = {}; - SmallVector<StringRef> unhandledTokens; - llvm::MapVector<StringAttr, DataLayoutEntryInterface> keyEntries; - llvm::MapVector<TypeAttr, DataLayoutEntryInterface> typeEntries; - MLIRContext *context; - DataLayoutSpecInterface dataLayout; -}; - -} // namespace detail -} // namespace LLVM -} // namespace mlir - -#endif // MLIR_LIB_TARGET_LLVMIR_DATALAYOUTIMPORTER_H_ diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt index 86c731a..a102c43 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt +++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt @@ -8,6 +8,7 @@ add_subdirectory(NVVM) add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(ROCDL) +add_subdirectory(Ptr) add_subdirectory(SPIRV) add_subdirectory(VCIX) add_subdirectory(XeVM) diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp index 0f675a0..fd8463a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp @@ -18,6 +18,7 @@ #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/IR/DIBuilder.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instructions.h" @@ -358,6 +359,17 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder, } } +static llvm::DILocalScope * +getLocalScopeFromLoc(llvm::IRBuilderBase &builder, Location loc, + LLVM::ModuleTranslation &moduleTranslation) { + if (auto scopeLoc = + loc->findInstanceOf<FusedLocWith<LLVM::DILocalScopeAttr>>()) + if (auto *localScope = llvm::dyn_cast<llvm::DILocalScope>( + moduleTranslation.translateDebugInfo(scopeLoc.getMetadata()))) + return localScope; + return builder.GetInsertBlock()->getParent()->getSubprogram(); +} + static LogicalResult convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp index 90462d1..7f69af14 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp @@ -135,33 +135,83 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) { llvm_unreachable("unsupported vote kind"); } -/// 
Return the intrinsic ID associated with ldmatrix for the given paramters. -static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout, - int32_t num) { - if (layout == NVVM::MMALayout::row) { +static llvm::Intrinsic::ID +getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num, + NVVM::LdStMatrixShapeAttr shape, + NVVM::LdStMatrixEltType eltType) { + if (shape.getM() == 8 && shape.getN() == 8) { switch (num) { case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16; + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16; + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; case 4: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16; - default: - llvm_unreachable("unsupported number of matrix"); + return (layout == NVVM::MMALayout::row) + ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16 + : llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; } - - } else { - switch (num) { - case 1: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16; - case 2: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16; - case 4: - return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16; - default: - llvm_unreachable("unsupported number of matrix"); + } else if (shape.getM() == 8 && shape.getN() == 16) { + if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32; + case 4: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64; + case 4: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64; + } + } + } else if (shape.getM() == 16 && shape.getN() == 16) { + if (eltType == NVVM::LdStMatrixEltType::B8) { + switch (num) { + case 1: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8; + case 2: + return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32; + } + } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) { + switch (num) { + case 1: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64; + case 2: + return llvm::Intrinsic:: + nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64; + } } } + llvm_unreachable("unknown ldmatrix kind"); } /// Return the intrinsic ID associated with stmatrix for the given paramters. 
@@ -418,7 +468,11 @@ public: } else if (attribute.getName() == NVVM::NVVMDialect::getKernelFuncAttrName()) { llvmFunc->setCallingConv(llvm::CallingConv::PTX_Kernel); + } else if (attribute.getName() == + NVVM::NVVMDialect::getBlocksAreClustersAttrName()) { + llvmFunc->addFnAttr("nvvm.blocksareclusters"); } + return success(); } @@ -429,51 +483,10 @@ public: llvm::LLVMContext &llvmContext = moduleTranslation.getLLVMContext(); llvm::Function *llvmFunc = moduleTranslation.lookupFunction(funcOp.getName()); - llvm::NamedMDNode *nvvmAnnotations = - moduleTranslation.getOrInsertNamedModuleMetadata("nvvm.annotations"); if (attribute.getName() == NVVM::NVVMDialect::getGridConstantAttrName()) { - llvm::MDNode *gridConstantMetaData = nullptr; - - // Check if a 'grid_constant' metadata node exists for the given function - for (llvm::MDNode *opnd : llvm::reverse(nvvmAnnotations->operands())) { - if (opnd->getNumOperands() == 3 && - opnd->getOperand(0) == llvm::ValueAsMetadata::get(llvmFunc) && - opnd->getOperand(1) == - llvm::MDString::get(llvmContext, "grid_constant")) { - gridConstantMetaData = opnd; - break; - } - } - - // 'grid_constant' is a function-level meta data node with a list of - // integers, where each integer n denotes that the nth parameter has the - // grid_constant annotation (numbering from 1). This requires aggregating - // the indices of the individual parameters that have this attribute. - llvm::Type *i32 = llvm::IntegerType::get(llvmContext, 32); - if (gridConstantMetaData == nullptr) { - // Create a new 'grid_constant' metadata node - SmallVector<llvm::Metadata *> gridConstMetadata = { - llvm::ValueAsMetadata::getConstant( - llvm::ConstantInt::get(i32, argIdx + 1))}; - llvm::Metadata *llvmMetadata[] = { - llvm::ValueAsMetadata::get(llvmFunc), - llvm::MDString::get(llvmContext, "grid_constant"), - llvm::MDNode::get(llvmContext, gridConstMetadata)}; - llvm::MDNode *llvmMetadataNode = - llvm::MDNode::get(llvmContext, llvmMetadata); - nvvmAnnotations->addOperand(llvmMetadataNode); - } else { - // Append argIdx + 1 to the 'grid_constant' argument list - if (auto argList = - dyn_cast<llvm::MDTuple>(gridConstantMetaData->getOperand(2))) { - llvm::TempMDTuple clonedArgList = argList->clone(); - clonedArgList->push_back((llvm::ValueAsMetadata::getConstant( - llvm::ConstantInt::get(i32, argIdx + 1)))); - gridConstantMetaData->replaceOperandWith( - 2, llvm::MDNode::replaceWithUniqued(std::move(clonedArgList))); - } - } + llvmFunc->addParamAttr( + argIdx, llvm::Attribute::get(llvmContext, "nvvm.grid_constant")); } return success(); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 2cdd502..8a1b554 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -2893,6 +2893,12 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, alignment = builder.getInt64(intAttr.getInt()); assert(ty->isPointerTy() && "Invalid type for aligned variable"); assert(alignment && "Invalid alignment value"); + + // Check if the alignment value is not a power of 2. If so, skip emitting + // alignment. 
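// Editor's note, not part of the patch: LLVM alignment assumptions are only meaningful for power-of-two values, so the check below keeps e.g. aligned(16) while a value such as 12 is skipped rather than emitted; APInt::isPowerOf2() performs the test on the attribute's value.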
+ if (!intAttr.getValue().isPowerOf2()) + continue; + auto curInsert = builder.saveIP(); builder.SetInsertPoint(sourceBlock); llvmVal = builder.CreateLoad(ty, llvmVal); @@ -4356,9 +4362,11 @@ createAlteredByCaptureMap(MapInfoData &mapData, if (!isPtrTy) { auto curInsert = builder.saveIP(); + llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation(); builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation)); auto *memTempAlloc = builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted"); + builder.SetCurrentDebugLocation(DbgLoc); builder.restoreIP(curInsert); builder.CreateStore(newV, memTempAlloc); @@ -5865,6 +5873,10 @@ static bool isTargetDeviceOp(Operation *op) { if (mlir::isa<omp::ThreadprivateOp>(op)) return true; + if (mlir::isa<omp::TargetAllocMemOp>(op) || + mlir::isa<omp::TargetFreeMemOp>(op)) + return true; + if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) if (auto declareTargetIface = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( @@ -5877,6 +5889,85 @@ static bool isTargetDeviceOp(Operation *op) { return false; } +static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder, + llvm::Module *llvmModule) { + llvm::Type *i64Ty = builder.getInt64Ty(); + llvm::Type *i32Ty = builder.getInt32Ty(); + llvm::Type *returnType = builder.getPtrTy(0); + llvm::FunctionType *fnType = + llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false); + llvm::Function *func = cast<llvm::Function>( + llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee()); + return func; +} + +static LogicalResult +convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto allocMemOp = cast<omp::TargetAllocMemOp>(opInst); + if (!allocMemOp) + return failure(); + + // Get "omp_target_alloc" function + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule); + // Get the corresponding device value in llvm + mlir::Value deviceNum = allocMemOp.getDevice(); + llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); + // Get the allocation size. + llvm::DataLayout dataLayout = llvmModule->getDataLayout(); + mlir::Type heapTy = allocMemOp.getAllocatedType(); + llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy); + llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy); + llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue()); + for (auto typeParam : allocMemOp.getTypeparams()) + allocSize = + builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam)); + // Create call to "omp_target_alloc" with the args as translated llvm values. 
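// Editor's note, not part of the patch: the declaration built by getOmpTargetAlloc() mirrors the OpenMP device-memory routine `void *omp_target_alloc(size_t size, int device_num);`, so the call below passes the byte size (element store size multiplied by any type parameters) and the device id, then converts the returned pointer to an i64 with ptrtoint so it can flow through the op's integer result.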
+ llvm::CallInst *call = + builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum}); + llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty()); + + // Map the result + moduleTranslation.mapValue(allocMemOp.getResult(), resultI64); + return success(); +} + +static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder, + llvm::Module *llvmModule) { + llvm::Type *ptrTy = builder.getPtrTy(0); + llvm::Type *i32Ty = builder.getInt32Ty(); + llvm::Type *voidTy = builder.getVoidTy(); + llvm::FunctionType *fnType = + llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false); + llvm::Function *func = dyn_cast<llvm::Function>( + llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee()); + return func; +} + +static LogicalResult +convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) { + auto freeMemOp = cast<omp::TargetFreeMemOp>(opInst); + if (!freeMemOp) + return failure(); + + // Get "omp_target_free" function + llvm::Module *llvmModule = moduleTranslation.getLLVMModule(); + llvm::Function *ompTragetFreeFunc = getOmpTargetFree(builder, llvmModule); + // Get the corresponding device value in llvm + mlir::Value deviceNum = freeMemOp.getDevice(); + llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum); + // Get the corresponding heapref value in llvm + mlir::Value heapref = freeMemOp.getHeapref(); + llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref); + // Convert heapref int to ptr and call "omp_target_free" + llvm::Value *intToPtr = + builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0)); + builder.CreateCall(ompTragetFreeFunc, {intToPtr, llvmDeviceNum}); + return success(); +} + /// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including /// OpenMP runtime calls). static LogicalResult @@ -6051,6 +6142,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder, // the omp.canonical_loop. 
return applyUnrollHeuristic(op, builder, moduleTranslation); }) + .Case([&](omp::TargetAllocMemOp) { + return convertTargetAllocMemOp(*op, builder, moduleTranslation); + }) + .Case([&](omp::TargetFreeMemOp) { + return convertTargetFreeMemOp(*op, builder, moduleTranslation); + }) .Default([&](Operation *inst) { return inst->emitError() << "not yet implemented: " << inst->getName(); @@ -6287,9 +6384,8 @@ LogicalResult OpenMPDialectLLVMIRTranslationInterface::convertOperation( if (ompBuilder->Config.isTargetDevice()) { if (isTargetDeviceOp(op)) { return convertTargetDeviceOp(op, builder, moduleTranslation); - } else { - return convertTargetOpsInNest(op, builder, moduleTranslation); } + return convertTargetOpsInNest(op, builder, moduleTranslation); } return convertHostOrTargetOperation(op, builder, moduleTranslation); } diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt new file mode 100644 index 0000000..f94410d --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_translation_library(MLIRPtrToLLVMIRTranslation + PtrToLLVMIRTranslation.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPtrDialect + MLIRSupport + MLIRTargetLLVMIRExport + ) diff --git a/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp new file mode 100644 index 0000000..7b89ec8 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.cpp @@ -0,0 +1,66 @@ +//===- PtrToLLVMIRTranslation.cpp - Translate `ptr` to LLVM IR ------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a translation between the MLIR `ptr` dialect and +// LLVM IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h" +#include "mlir/Dialect/Ptr/IR/PtrOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +using namespace mlir; +using namespace mlir::ptr; + +namespace { +/// Implementation of the dialect interface that converts operations belonging +/// to the `ptr` dialect to LLVM IR. +class PtrDialectLLVMIRTranslationInterface + : public LLVMTranslationDialectInterface { +public: + using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface; + + /// Translates the given operation to LLVM IR using the provided IR builder + /// and saving the state in `moduleTranslation`. + LogicalResult + convertOperation(Operation *op, llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation) const final { + // Translation for ptr dialect operations to LLVM IR is currently + // unimplemented. + return op->emitError("Translation for ptr dialect operations to LLVM IR is " + "not implemented."); + } + + /// Attaches module-level metadata for functions marked as kernels. + LogicalResult + amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions, + NamedAttribute attribute, + LLVM::ModuleTranslation &moduleTranslation) const final { + // Translation for ptr dialect operations to LLVM IR is currently + // unimplemented. 
+ return op->emitError("Translation for ptr dialect operations to LLVM IR is " + "not implemented."); + } +}; +} // namespace + +void mlir::registerPtrDialectTranslation(DialectRegistry ®istry) { + registry.insert<ptr::PtrDialect>(); + registry.addExtension(+[](MLIRContext *ctx, ptr::PtrDialect *dialect) { + dialect->addInterfaces<PtrDialectLLVMIRTranslationInterface>(); + }); +} + +void mlir::registerPtrDialectTranslation(MLIRContext &context) { + DialectRegistry registry; + registerPtrDialectTranslation(registry); + context.appendDialectRegistry(registry); +} diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp index 6325480..7a888bb3 100644 --- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp @@ -16,7 +16,6 @@ #include "mlir/Target/LLVMIR/Import.h" #include "AttrKindDetail.h" -#include "DataLayoutImporter.h" #include "DebugImporter.h" #include "LoopAnnotationImporter.h" @@ -25,6 +24,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/Matchers.h" #include "mlir/Interfaces/DataLayoutInterfaces.h" +#include "mlir/Target/LLVMIR/DataLayoutImporter.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "llvm/ADT/DepthFirstIterator.h" @@ -1045,8 +1045,9 @@ LogicalResult ModuleImport::convertIFuncs() { LogicalResult ModuleImport::convertDataLayout() { Location loc = mlirModule.getLoc(); - DataLayoutImporter dataLayoutImporter(context, llvmModule->getDataLayout()); - if (!dataLayoutImporter.getDataLayout()) + DataLayoutImporter dataLayoutImporter( + context, llvmModule->getDataLayout().getStringRepresentation()); + if (!dataLayoutImporter.getDataLayoutSpec()) return emitError(loc, "cannot translate data layout: ") << dataLayoutImporter.getLastToken(); @@ -1054,7 +1055,7 @@ LogicalResult ModuleImport::convertDataLayout() { emitWarning(loc, "unhandled data layout token: ") << token; mlirModule->setAttr(DLTIDialect::kDataLayoutAttrName, - dataLayoutImporter.getDataLayout()); + dataLayoutImporter.getDataLayoutSpec()); return success(); } @@ -1408,6 +1409,67 @@ LogicalResult ModuleImport::convertIFunc(llvm::GlobalIFunc *ifunc) { return success(); } +/// Converts LLVM string, integer, and enum attributes into MLIR attributes, +/// skipping those in `attributesToSkip` and emitting a warning at `loc` for +/// any other unsupported attributes. +static ArrayAttr +convertLLVMAttributesToMLIR(Location loc, MLIRContext *context, + llvm::AttributeSet attributes, + ArrayRef<StringLiteral> attributesToSkip = {}) { + SmallVector<Attribute> mlirAttributes; + for (llvm::Attribute attr : attributes) { + StringRef attrName; + if (attr.isStringAttribute()) + attrName = attr.getKindAsString(); + else + attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum()); + if (llvm::is_contained(attributesToSkip, attrName)) + continue; + + auto keyAttr = StringAttr::get(context, attrName); + if (attr.isStringAttribute()) { + StringRef val = attr.getValueAsString(); + if (val.empty()) { + // For string attributes without values, add only the attribute name. + mlirAttributes.push_back(keyAttr); + continue; + } + // For string attributes with a value, create a [name, value] pair. + mlirAttributes.push_back( + ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)})); + continue; + } + if (attr.isIntAttribute()) { + // For integer attributes, convert the value to a string and create a + // [name, value] pair. 
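The registerPtrDialectTranslation entry points added above in PtrToLLVMIRTranslation.cpp follow the usual translation-registration pattern. A minimal usage sketch follows; the wrapper function and surrounding setup are hypothetical, while the header paths and registration functions are the existing MLIR ones plus the one added by this patch.

#include "mlir/IR/DialectRegistry.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/Ptr/PtrToLLVMIRTranslation.h"

// Hypothetical helper: make both translations available to a translation tool.
static void registerTranslationsForExport(mlir::DialectRegistry &registry) {
  mlir::registerLLVMDialectTranslation(registry);
  mlir::registerPtrDialectTranslation(registry); // added by this patch
}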
+ auto val = std::to_string(attr.getValueAsInt()); + mlirAttributes.push_back( + ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)})); + continue; + } + if (attr.isEnumAttribute()) { + // For enum attributes, add only the attribute name. + mlirAttributes.push_back(keyAttr); + continue; + } + + emitWarning(loc) + << "'" << attrName + << "' attribute is invalid on current operation, skipping it"; + } + return ArrayAttr::get(context, mlirAttributes); +} + +/// Converts LLVM attributes from `globalVar` into MLIR attributes and adds them +/// to `globalOp` as target-specific attributes. +static void processTargetSpecificAttrs(llvm::GlobalVariable *globalVar, + GlobalOp globalOp) { + ArrayAttr targetSpecificAttrs = convertLLVMAttributesToMLIR( + globalOp.getLoc(), globalOp.getContext(), globalVar->getAttributes()); + if (!targetSpecificAttrs.empty()) + globalOp.setTargetSpecificAttrsAttr(targetSpecificAttrs); +} + LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) { // Insert the global after the last one or at the start of the module. OpBuilder::InsertionGuard guard = setGlobalInsertionPoint(); @@ -1473,6 +1535,8 @@ LogicalResult ModuleImport::convertGlobal(llvm::GlobalVariable *globalVar) { if (globalVar->hasComdat()) globalOp.setComdatAttr(comdatMapping.lookup(globalVar->getComdat())); + processTargetSpecificAttrs(globalVar, globalOp); + return success(); } @@ -2525,7 +2589,7 @@ static void processMemoryEffects(llvm::Function *func, LLVMFuncOp funcOp) { // List of LLVM IR attributes that map to an explicit attribute on the MLIR // LLVMFuncOp. -static constexpr std::array kExplicitAttributes{ +static constexpr std::array kExplicitLLVMFuncOpAttributes{ StringLiteral("aarch64_in_za"), StringLiteral("aarch64_inout_za"), StringLiteral("aarch64_new_za"), @@ -2535,7 +2599,6 @@ static constexpr std::array kExplicitAttributes{ StringLiteral("aarch64_pstate_sm_compatible"), StringLiteral("aarch64_pstate_sm_enabled"), StringLiteral("alwaysinline"), - StringLiteral("approx-func-fp-math"), StringLiteral("convergent"), StringLiteral("denormal-fp-math"), StringLiteral("denormal-fp-math-f32"), @@ -2543,6 +2606,7 @@ static constexpr std::array kExplicitAttributes{ StringLiteral("frame-pointer"), StringLiteral("instrument-function-entry"), StringLiteral("instrument-function-exit"), + StringLiteral("memory"), StringLiteral("no-infs-fp-math"), StringLiteral("no-nans-fp-math"), StringLiteral("no-signed-zeros-fp-math"), @@ -2557,61 +2621,17 @@ static constexpr std::array kExplicitAttributes{ StringLiteral("willreturn"), }; +/// Converts LLVM attributes from `func` into MLIR attributes and adds them +/// to `funcOp` as passthrough attributes, skipping those listed in +/// `kExplicitLLVMFuncAttributes`. static void processPassthroughAttrs(llvm::Function *func, LLVMFuncOp funcOp) { - MLIRContext *context = funcOp.getContext(); - SmallVector<Attribute> passthroughs; llvm::AttributeSet funcAttrs = func->getAttributes().getAttributes( llvm::AttributeList::AttrIndex::FunctionIndex); - for (llvm::Attribute attr : funcAttrs) { - // Skip the memory attribute since the LLVMFuncOp has an explicit memory - // attribute. - if (attr.hasAttribute(llvm::Attribute::Memory)) - continue; - - // Skip invalid type attributes. 
- if (attr.isTypeAttribute()) { - emitWarning(funcOp.getLoc(), - "type attributes on a function are invalid, skipping it"); - continue; - } - - StringRef attrName; - if (attr.isStringAttribute()) - attrName = attr.getKindAsString(); - else - attrName = llvm::Attribute::getNameFromAttrKind(attr.getKindAsEnum()); - auto keyAttr = StringAttr::get(context, attrName); - - // Skip attributes that map to an explicit attribute on the LLVMFuncOp. - if (llvm::is_contained(kExplicitAttributes, attrName)) - continue; - - if (attr.isStringAttribute()) { - StringRef val = attr.getValueAsString(); - if (val.empty()) { - passthroughs.push_back(keyAttr); - continue; - } - passthroughs.push_back( - ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)})); - continue; - } - if (attr.isIntAttribute()) { - auto val = std::to_string(attr.getValueAsInt()); - passthroughs.push_back( - ArrayAttr::get(context, {keyAttr, StringAttr::get(context, val)})); - continue; - } - if (attr.isEnumAttribute()) { - passthroughs.push_back(keyAttr); - continue; - } - - llvm_unreachable("unexpected attribute kind"); - } - - if (!passthroughs.empty()) - funcOp.setPassthroughAttr(ArrayAttr::get(context, passthroughs)); + ArrayAttr passthroughAttr = + convertLLVMAttributesToMLIR(funcOp.getLoc(), funcOp.getContext(), + funcAttrs, kExplicitLLVMFuncOpAttributes); + if (!passthroughAttr.empty()) + funcOp.setPassthroughAttr(passthroughAttr); } void ModuleImport::processFunctionAttributes(llvm::Function *func, @@ -2703,10 +2723,6 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func, attr.isStringAttribute()) funcOp.setNoNansFpMath(attr.getValueAsBool()); - if (llvm::Attribute attr = func->getFnAttribute("approx-func-fp-math"); - attr.isStringAttribute()) - funcOp.setApproxFuncFpMath(attr.getValueAsBool()); - if (llvm::Attribute attr = func->getFnAttribute("instrument-function-entry"); attr.isStringAttribute()) funcOp.setInstrumentFunctionEntry( diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index b3a06e2..97253591 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -1081,6 +1081,83 @@ static void addRuntimePreemptionSpecifier(bool dsoLocalRequested, gv->setDSOLocal(true); } +/// Attempts to translate an MLIR attribute identified by `key`, optionally with +/// the given `value`, into an LLVM IR attribute. Reports errors at `loc` if +/// any. If the attribute name corresponds to a known LLVM IR attribute kind, +/// creates the LLVM attribute of that kind; otherwise, keeps it as a string +/// attribute. Performs additional checks for attributes known to have or not +/// have a value in order to avoid assertions inside LLVM upon construction. 
+static FailureOr<llvm::Attribute> +convertMLIRAttributeToLLVM(Location loc, llvm::LLVMContext &ctx, StringRef key, + StringRef value = StringRef()) { + auto kind = llvm::Attribute::getAttrKindFromName(key); + if (kind == llvm::Attribute::None) + return llvm::Attribute::get(ctx, key, value); + + if (llvm::Attribute::isIntAttrKind(kind)) { + if (value.empty()) + return emitError(loc) << "LLVM attribute '" << key << "' expects a value"; + + int64_t result; + if (!value.getAsInteger(/*Radix=*/0, result)) + return llvm::Attribute::get(ctx, kind, result); + return llvm::Attribute::get(ctx, key, value); + } + + if (!value.empty()) + return emitError(loc) << "LLVM attribute '" << key + << "' does not expect a value, found '" << value + << "'"; + + return llvm::Attribute::get(ctx, kind); +} + +/// Converts the MLIR attributes listed in the given array attribute into LLVM +/// attributes. Returns an `AttrBuilder` containing the converted attributes. +/// Reports error to `loc` if any and returns immediately. Expects `arrayAttr` +/// to contain either string attributes, treated as value-less LLVM attributes, +/// or array attributes containing two string attributes, with the first string +/// being the name of the corresponding LLVM attribute and the second string +/// beings its value. Note that even integer attributes are expected to have +/// their values expressed as strings. +static FailureOr<llvm::AttrBuilder> +convertMLIRAttributesToLLVM(Location loc, llvm::LLVMContext &ctx, + ArrayAttr arrayAttr, StringRef arrayAttrName) { + llvm::AttrBuilder attrBuilder(ctx); + if (!arrayAttr) + return attrBuilder; + + for (Attribute attr : arrayAttr) { + if (auto stringAttr = dyn_cast<StringAttr>(attr)) { + FailureOr<llvm::Attribute> llvmAttr = + convertMLIRAttributeToLLVM(loc, ctx, stringAttr.getValue()); + if (failed(llvmAttr)) + return failure(); + attrBuilder.addAttribute(*llvmAttr); + continue; + } + + auto arrayAttr = dyn_cast<ArrayAttr>(attr); + if (!arrayAttr || arrayAttr.size() != 2) + return emitError(loc) << "expected '" << arrayAttrName + << "' to contain string or array attributes"; + + auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]); + auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]); + if (!keyAttr || !valueAttr) + return emitError(loc) << "expected arrays within '" << arrayAttrName + << "' to contain two strings"; + + FailureOr<llvm::Attribute> llvmAttr = convertMLIRAttributeToLLVM( + loc, ctx, keyAttr.getValue(), valueAttr.getValue()); + if (failed(llvmAttr)) + return failure(); + attrBuilder.addAttribute(*llvmAttr); + } + + return attrBuilder; +} + LogicalResult ModuleTranslation::convertGlobalsAndAliases() { // Mapping from compile unit to its respective set of global variables. DenseMap<llvm::DICompileUnit *, SmallVector<llvm::Metadata *>> allGVars; @@ -1191,6 +1268,15 @@ LogicalResult ModuleTranslation::convertGlobalsAndAliases() { } } } + + // Forward the target-specific attributes to LLVM. + FailureOr<llvm::AttrBuilder> convertedTargetSpecificAttrs = + convertMLIRAttributesToLLVM(op.getLoc(), var->getContext(), + op.getTargetSpecificAttrsAttr(), + op.getTargetSpecificAttrsAttrName()); + if (failed(convertedTargetSpecificAttrs)) + return failure(); + var->addAttributes(*convertedTargetSpecificAttrs); } // Create all llvm::GlobalAlias @@ -1381,44 +1467,6 @@ LogicalResult ModuleTranslation::convertGlobalsAndAliases() { return success(); } -/// Attempts to add an attribute identified by `key`, optionally with the given -/// `value` to LLVM function `llvmFunc`. 
Reports errors at `loc` if any. If the -/// attribute has a kind known to LLVM IR, create the attribute of this kind, -/// otherwise keep it as a string attribute. Performs additional checks for -/// attributes known to have or not have a value in order to avoid assertions -/// inside LLVM upon construction. -static LogicalResult checkedAddLLVMFnAttribute(Location loc, - llvm::Function *llvmFunc, - StringRef key, - StringRef value = StringRef()) { - auto kind = llvm::Attribute::getAttrKindFromName(key); - if (kind == llvm::Attribute::None) { - llvmFunc->addFnAttr(key, value); - return success(); - } - - if (llvm::Attribute::isIntAttrKind(kind)) { - if (value.empty()) - return emitError(loc) << "LLVM attribute '" << key << "' expects a value"; - - int64_t result; - if (!value.getAsInteger(/*Radix=*/0, result)) - llvmFunc->addFnAttr( - llvm::Attribute::get(llvmFunc->getContext(), kind, result)); - else - llvmFunc->addFnAttr(key, value); - return success(); - } - - if (!value.empty()) - return emitError(loc) << "LLVM attribute '" << key - << "' does not expect a value, found '" << value - << "'"; - - llvmFunc->addFnAttr(kind); - return success(); -} - /// Return a representation of `value` as metadata. static llvm::Metadata *convertIntegerToMetadata(llvm::LLVMContext &context, const llvm::APInt &value) { @@ -1454,45 +1502,6 @@ static llvm::MDNode *convertIntegerArrayToMDNode(llvm::LLVMContext &context, return llvm::MDNode::get(context, mdValues); } -/// Attaches the attributes listed in the given array attribute to `llvmFunc`. -/// Reports error to `loc` if any and returns immediately. Expects `attributes` -/// to be an array attribute containing either string attributes, treated as -/// value-less LLVM attributes, or array attributes containing two string -/// attributes, with the first string being the name of the corresponding LLVM -/// attribute and the second string beings its value. Note that even integer -/// attributes are expected to have their values expressed as strings. -static LogicalResult -forwardPassthroughAttributes(Location loc, std::optional<ArrayAttr> attributes, - llvm::Function *llvmFunc) { - if (!attributes) - return success(); - - for (Attribute attr : *attributes) { - if (auto stringAttr = dyn_cast<StringAttr>(attr)) { - if (failed( - checkedAddLLVMFnAttribute(loc, llvmFunc, stringAttr.getValue()))) - return failure(); - continue; - } - - auto arrayAttr = dyn_cast<ArrayAttr>(attr); - if (!arrayAttr || arrayAttr.size() != 2) - return emitError(loc) - << "expected 'passthrough' to contain string or array attributes"; - - auto keyAttr = dyn_cast<StringAttr>(arrayAttr[0]); - auto valueAttr = dyn_cast<StringAttr>(arrayAttr[1]); - if (!keyAttr || !valueAttr) - return emitError(loc) - << "expected arrays within 'passthrough' to contain two strings"; - - if (failed(checkedAddLLVMFnAttribute(loc, llvmFunc, keyAttr.getValue(), - valueAttr.getValue()))) - return failure(); - } - return success(); -} - LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { // Clear the block, branch value mappings, they are only relevant within one // function. 
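The import- and export-side helpers above share a single encoding for attribute lists such as passthrough and the new target_specific_attrs: a bare string for a value-less LLVM attribute, or a two-string array for an attribute with a value, where even integer values are spelled as strings. A small builder-side sketch of that shape (illustrative only; the attribute names are just examples):

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

static mlir::ArrayAttr buildExamplePassthrough(mlir::MLIRContext *ctx) {
  mlir::Builder b(ctx);
  return b.getArrayAttr({
      // Value-less LLVM attribute: a bare string.
      b.getStringAttr("nounwind"),
      // Integer-valued attribute: the value is encoded as a string.
      b.getArrayAttr({b.getStringAttr("alignstack"), b.getStringAttr("8")}),
      // String-valued attribute.
      b.getArrayAttr(
          {b.getStringAttr("probe-stack"), b.getStringAttr("inline-asm")}),
  });
}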
@@ -1561,10 +1570,6 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) { if (auto noNansFpMath = func.getNoNansFpMath()) llvmFunc->addFnAttr("no-nans-fp-math", llvm::toStringRef(*noNansFpMath)); - if (auto approxFuncFpMath = func.getApproxFuncFpMath()) - llvmFunc->addFnAttr("approx-func-fp-math", - llvm::toStringRef(*approxFuncFpMath)); - if (auto noSignedZerosFpMath = func.getNoSignedZerosFpMath()) llvmFunc->addFnAttr("no-signed-zeros-fp-math", llvm::toStringRef(*noSignedZerosFpMath)); @@ -1864,9 +1869,13 @@ LogicalResult ModuleTranslation::convertFunctionSignatures() { } // Forward the pass-through attributes to LLVM. - if (failed(forwardPassthroughAttributes( - function.getLoc(), function.getPassthrough(), llvmFunc))) + FailureOr<llvm::AttrBuilder> convertedPassthroughAttrs = + convertMLIRAttributesToLLVM(function.getLoc(), llvmFunc->getContext(), + function.getPassthroughAttr(), + function.getPassthroughAttrName()); + if (failed(convertedPassthroughAttrs)) return failure(); + llvmFunc->addFnAttrs(*convertedPassthroughAttrs); // Convert visibility attribute. llvmFunc->setVisibility(convertVisibilityToLLVM(function.getVisibility_())); @@ -2407,11 +2416,6 @@ mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext, if (failed(translator.convertUnresolvedBlockAddress())) return nullptr; - // Once we've finished constructing elements in the module, we should convert - // it to use the debug info format desired by LLVM. - // See https://llvm.org/docs/RemoveDIsDebugInfo.html - translator.llvmModule->convertToNewDbgValues(); - // Add the necessary debug info module flags, if they were not encoded in MLIR // beforehand. translator.debugTranslation->addModuleFlagsIfNotPresent(); diff --git a/mlir/lib/Target/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Transforms/CMakeLists.txt new file mode 100644 index 0000000..044da1c --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Transforms/CMakeLists.txt @@ -0,0 +1,23 @@ +add_mlir_dialect_library(MLIRTargetLLVMIRTransforms + TargetToDataLayout.cpp + TargetToTargetFeatures.cpp + TargetUtils.cpp + + DEPENDS + MLIRTargetLLVMIRTransformsIncGen + + LINK_COMPONENTS + MC + Target + TargetParser + AllTargetsAsmParsers + AllTargetsCodeGens + AllTargetsDescs + AllTargetsInfos + + LINK_LIBS PUBLIC + MLIRDLTIDialect + MLIRLLVMDialect + MLIRPass + MLIRTargetLLVMIRImport + ) diff --git a/mlir/lib/Target/LLVMIR/Transforms/TargetToDataLayout.cpp b/mlir/lib/Target/LLVMIR/Transforms/TargetToDataLayout.cpp new file mode 100644 index 0000000..c0f9ceb --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Transforms/TargetToDataLayout.cpp @@ -0,0 +1,62 @@ +//===- TargetToDataLayout.cpp - extract data layout from TargetMachine ----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Transforms/Passes.h" +#include "mlir/Target/LLVMIR/Transforms/TargetUtils.h" + +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Target/LLVMIR/Import.h" + +namespace mlir { +namespace LLVM { +#define GEN_PASS_DEF_LLVMTARGETTODATALAYOUT +#include "mlir/Target/LLVMIR/Transforms/Passes.h.inc" +} // namespace LLVM +} // namespace mlir + +using namespace mlir; + +struct TargetToDataLayoutPass + : public LLVM::impl::LLVMTargetToDataLayoutBase<TargetToDataLayoutPass> { + using LLVM::impl::LLVMTargetToDataLayoutBase< + TargetToDataLayoutPass>::LLVMTargetToDataLayoutBase; + + void runOnOperation() override { + Operation *op = getOperation(); + + if (initializeLLVMTargets) + LLVM::detail::initializeBackendsOnce(); + + auto targetAttr = op->getAttrOfType<LLVM::TargetAttrInterface>( + LLVM::LLVMDialect::getTargetAttrName()); + if (!targetAttr) { + op->emitError() + << "no TargetAttrInterface-implementing attribute at key \"" + << LLVM::LLVMDialect::getTargetAttrName() << "\""; + return signalPassFailure(); + } + + FailureOr<llvm::DataLayout> dataLayout = + LLVM::detail::getDataLayout(targetAttr); + if (failed(dataLayout)) { + op->emitError() << "failed to obtain llvm::DataLayout for " << targetAttr; + return signalPassFailure(); + } + + DataLayoutSpecInterface dataLayoutSpec = + mlir::translateDataLayout(dataLayout.value(), &getContext()); + + if (auto existingDlSpec = op->getAttrOfType<DataLayoutSpecInterface>( + DLTIDialect::kDataLayoutAttrName)) { + dataLayoutSpec = existingDlSpec.combineWith({dataLayoutSpec}); + } + + op->setAttr(DLTIDialect::kDataLayoutAttrName, dataLayoutSpec); + } +}; diff --git a/mlir/lib/Target/LLVMIR/Transforms/TargetToTargetFeatures.cpp b/mlir/lib/Target/LLVMIR/Transforms/TargetToTargetFeatures.cpp new file mode 100644 index 0000000..4a1ca46 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Transforms/TargetToTargetFeatures.cpp @@ -0,0 +1,78 @@ +//===- TargetToTargetFeatures.cpp - extract features from TargetMachine ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Transforms/Passes.h" +#include "mlir/Target/LLVMIR/Transforms/TargetUtils.h" + +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Target/LLVMIR/Import.h" + +#include "llvm/MC/MCSubtargetInfo.h" + +namespace mlir { +namespace LLVM { +#define GEN_PASS_DEF_LLVMTARGETTOTARGETFEATURES +#include "mlir/Target/LLVMIR/Transforms/Passes.h.inc" +} // namespace LLVM +} // namespace mlir + +using namespace mlir; + +struct TargetToTargetFeaturesPass + : public LLVM::impl::LLVMTargetToTargetFeaturesBase< + TargetToTargetFeaturesPass> { + using LLVM::impl::LLVMTargetToTargetFeaturesBase< + TargetToTargetFeaturesPass>::LLVMTargetToTargetFeaturesBase; + + void runOnOperation() override { + Operation *op = getOperation(); + + if (initializeLLVMTargets) + LLVM::detail::initializeBackendsOnce(); + + auto targetAttr = op->getAttrOfType<LLVM::TargetAttr>( + LLVM::LLVMDialect::getTargetAttrName()); + if (!targetAttr) { + op->emitError() << "no LLVM::TargetAttr attribute at key \"" + << LLVM::LLVMDialect::getTargetAttrName() << "\""; + return signalPassFailure(); + } + + FailureOr<std::unique_ptr<llvm::TargetMachine>> targetMachine = + LLVM::detail::getTargetMachine(targetAttr); + if (failed(targetMachine)) { + op->emitError() << "failed to obtain llvm::TargetMachine for " + << targetAttr; + return signalPassFailure(); + } + + llvm::MCSubtargetInfo const *subTargetInfo = + (*targetMachine)->getMCSubtargetInfo(); + + const std::vector<llvm::SubtargetFeatureKV> enabledFeatures = + subTargetInfo->getEnabledProcessorFeatures(); + + auto plussedFeatures = llvm::to_vector( + llvm::map_range(enabledFeatures, [](llvm::SubtargetFeatureKV feature) { + return std::string("+") + feature.Key; + })); + + auto plussedFeaturesRefs = llvm::to_vector(llvm::map_range( + plussedFeatures, [](auto &it) { return StringRef(it.c_str()); })); + + auto fullTargetFeaturesAttr = + LLVM::TargetFeaturesAttr::get(&getContext(), plussedFeaturesRefs); + + auto updatedTargetAttr = + LLVM::TargetAttr::get(&getContext(), targetAttr.getTriple(), + targetAttr.getChip(), fullTargetFeaturesAttr); + + op->setAttr(LLVM::LLVMDialect::getTargetAttrName(), updatedTargetAttr); + } +}; diff --git a/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp b/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp new file mode 100644 index 0000000..f1d3622 --- /dev/null +++ b/mlir/lib/Target/LLVMIR/Transforms/TargetUtils.cpp @@ -0,0 +1,71 @@ +//===- TargetUtils.cpp - utils for obtaining generic target backend info --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Target/LLVMIR/Transforms/Passes.h" + +#include "mlir/Dialect/DLTI/DLTI.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Target/LLVMIR/Import.h" + +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Target/TargetMachine.h" + +#define DEBUG_TYPE "mlir-llvm-target-utils" + +namespace mlir { +namespace LLVM { +namespace detail { +void initializeBackendsOnce() { + static const auto initOnce = [] { + // Ensure that the targets, that LLVM has been configured to support, + // are loaded into the TargetRegistry. + llvm::InitializeAllTargets(); + llvm::InitializeAllTargetMCs(); + return true; + }(); + (void)initOnce; // Dummy usage. +} + +FailureOr<std::unique_ptr<llvm::TargetMachine>> +getTargetMachine(mlir::LLVM::TargetAttrInterface attr) { + StringRef triple = attr.getTriple(); + StringRef chipAKAcpu = attr.getChip(); + // NB: `TargetAttrInterface::getFeatures()` is coarsely typed to work around + // cyclic dependency issue in tablegen files. + auto featuresAttr = + llvm::cast_if_present<LLVM::TargetFeaturesAttr>(attr.getFeatures()); + std::string features = featuresAttr ? featuresAttr.getFeaturesString() : ""; + + std::string error; + const llvm::Target *target = + llvm::TargetRegistry::lookupTarget(triple, error); + if (!target || !error.empty()) { + LDBG() << "Looking up target '" << triple << "' failed: " << error << "\n"; + return failure(); + } + + return std::unique_ptr<llvm::TargetMachine>(target->createTargetMachine( + llvm::Triple(triple), chipAKAcpu, features, {}, {})); +} + +FailureOr<llvm::DataLayout> +getDataLayout(mlir::LLVM::TargetAttrInterface attr) { + FailureOr<std::unique_ptr<llvm::TargetMachine>> targetMachine = + getTargetMachine(attr); + if (failed(targetMachine)) { + LDBG() << "Failed to retrieve the target machine for data layout.\n"; + return failure(); + } + return (targetMachine.value())->createDataLayout(); +} + +} // namespace detail +} // namespace LLVM +} // namespace mlir diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp index e4ba478..ddd5946 100644 --- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Target/LLVMIR/TypeToLLVM.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/BuiltinTypes.h" @@ -71,7 +72,7 @@ public: }) .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType, LLVM::LLVMPointerType, LLVM::LLVMStructType, VectorType, - LLVM::LLVMTargetExtType>( + LLVM::LLVMTargetExtType, PtrLikeTypeInterface>( [this](auto type) { return this->translate(type); }) .Default([](Type t) -> llvm::Type * { llvm_unreachable("unknown LLVM dialect type"); @@ -149,6 +150,14 @@ private: type.getIntParams()); } + /// Translates the given ptr type. + llvm::Type *translate(PtrLikeTypeInterface type) { + auto memSpace = dyn_cast<LLVM::AddressSpaceAttr>(type.getMemorySpace()); + assert(memSpace && "expected pointer with the LLVM address space"); + assert(!type.hasPtrMetadata() && "expected pointer without metadata"); + return llvm::PointerType::get(context, memSpace.getAddressSpace()); + } + /// Translates a list of types. 
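Stepping back to the TargetUtils.cpp helpers introduced above, the following is a hypothetical caller-side sketch of going from a module's target attribute to an llvm::DataLayout. The function name and surrounding code are illustrative; the mlir::LLVM::detail API and the TargetUtils.h header path are the ones added by this patch.

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Target/LLVMIR/Transforms/TargetUtils.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical helper: print the data layout implied by a module's target.
static void printTargetDataLayout(mlir::ModuleOp module) {
  mlir::LLVM::detail::initializeBackendsOnce();
  auto targetAttr = module->getAttrOfType<mlir::LLVM::TargetAttrInterface>(
      mlir::LLVM::LLVMDialect::getTargetAttrName());
  if (!targetAttr)
    return;
  mlir::FailureOr<llvm::DataLayout> dataLayout =
      mlir::LLVM::detail::getDataLayout(targetAttr);
  if (mlir::succeeded(dataLayout))
    llvm::outs() << dataLayout->getStringRepresentation() << "\n";
}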
void translateTypes(ArrayRef<Type> types, SmallVectorImpl<llvm::Type *> &result) { diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index d8c54ec..3625dd2 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -229,7 +229,7 @@ spirv::Deserializer::processMemoryModel(ArrayRef<uint32_t> operands) { } template <typename AttrTy, typename EnumAttrTy, typename EnumTy> -LogicalResult deserializeCacheControlDecoration( +static LogicalResult deserializeCacheControlDecoration( Location loc, OpBuilder &opBuilder, DenseMap<uint32_t, NamedAttrList> &decorations, ArrayRef<uint32_t> words, StringAttr symbol, StringRef decorationName, StringRef cacheControlKind) { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 7c007de..7fc7795 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" +#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -112,7 +113,9 @@ LogicalResult Serializer::serialize() { // TODO: handle the other sections processCapability(); - processExtension(); + if (failed(processExtension())) { + return failure(); + } processMemoryModel(); processDebugInfo(); @@ -204,13 +207,24 @@ void Serializer::processDebugInfo() { // TODO: Encode more debug instructions. } -void Serializer::processExtension() { +LogicalResult Serializer::processExtension() { llvm::SmallVector<uint32_t, 16> extName; - for (spirv::Extension ext : module.getVceTriple()->getExtensions()) { + llvm::SmallSet<Extension, 4> deducedExts( + llvm::from_range, module.getVceTriple()->getExtensions()); + auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info; + if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) { + TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(module); + if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt)) + return module.emitError( + "SPV_KHR_non_semantic_info extension not available"); + deducedExts.insert(nonSemanticInfoExt); + } + for (spirv::Extension ext : deducedExts) { extName.clear(); spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext)); encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName); } + return success(); } void Serializer::processMemoryModel() { diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h index 7047869..fb2cecd 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -102,7 +102,7 @@ private: void processDebugInfo(); - void processExtension(); + LogicalResult processExtension(); void processMemoryModel(); diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp index ac338d55..796354e 100644 --- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp +++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp @@ -21,8 +21,11 @@ #include "mlir/Target/SPIRV/Serialization.h" #include "mlir/Tools/mlir-translate/Translation.h" #include "llvm/ADT/StringRef.h" +#include 
"llvm/Support/FileSystem.h" #include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/Path.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ToolOutputFile.h" using namespace mlir; @@ -76,24 +79,66 @@ void registerFromSPIRVTranslation() { // Serialization registration //===----------------------------------------------------------------------===// -static LogicalResult serializeModule(spirv::ModuleOp module, - raw_ostream &output) { +static LogicalResult +serializeModule(spirv::ModuleOp moduleOp, raw_ostream &output, + const spirv::SerializationOptions &options) { SmallVector<uint32_t, 0> binary; - if (failed(spirv::serialize(module, binary))) + if (failed(spirv::serialize(moduleOp, binary))) return failure(); - output.write(reinterpret_cast<char *>(binary.data()), - binary.size() * sizeof(uint32_t)); + size_t sizeInBytes = binary.size() * sizeof(uint32_t); + + output.write(reinterpret_cast<char *>(binary.data()), sizeInBytes); + + if (options.saveModuleForValidation) { + size_t dirSeparator = + options.validationFilePrefix.find(llvm::sys::path::get_separator()); + // If file prefix includes directory check if that directory exists. + if (dirSeparator != std::string::npos) { + llvm::StringRef parentDir = + llvm::sys::path::parent_path(options.validationFilePrefix); + if (!llvm::sys::fs::is_directory(parentDir)) + return moduleOp.emitError( + "validation prefix directory does not exist\n"); + } + + SmallString<128> filename; + int fd = 0; + + std::error_code errorCode = llvm::sys::fs::createUniqueFile( + options.validationFilePrefix + "%%%%%%.spv", fd, filename); + if (errorCode) + return moduleOp.emitError("error creating validation output file: ") + << errorCode.message() << "\n"; + + llvm::raw_fd_ostream validationOutput(fd, /*shouldClose=*/true); + validationOutput.write(reinterpret_cast<char *>(binary.data()), + sizeInBytes); + validationOutput.flush(); + } return mlir::success(); } namespace mlir { void registerToSPIRVTranslation() { + static llvm::cl::opt<std::string> validationFilesPrefix( + "spirv-save-validation-files-with-prefix", + llvm::cl::desc( + "When non-empty string is passed each serialized SPIR-V module is " + "saved to an additional file that starts with the given prefix. This " + "is used to generate separate binaries for validation, where " + "`--split-input-file` normally combines all outputs into one. The " + "one combined output (`-o`) is still written. 
Created files need to " + "be removed manually once processed."), + llvm::cl::init("")); + TranslateFromMLIRRegistration toBinary( "serialize-spirv", "serialize SPIR-V dialect", - [](spirv::ModuleOp module, raw_ostream &output) { - return serializeModule(module, output); + [](spirv::ModuleOp moduleOp, raw_ostream &output) { + return serializeModule(moduleOp, output, + {true, false, !validationFilesPrefix.empty(), + validationFilesPrefix}); }, [](DialectRegistry ®istry) { registry.insert<spirv::SPIRVDialect>(); diff --git a/mlir/lib/Target/Wasm/CMakeLists.txt b/mlir/lib/Target/Wasm/CMakeLists.txt new file mode 100644 index 0000000..890fc0ec --- /dev/null +++ b/mlir/lib/Target/Wasm/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_translation_library(MLIRTargetWasmImport + TranslateRegistration.cpp + TranslateFromWasm.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/Target/Wasm + + LINK_LIBS PUBLIC + MLIRWasmSSADialect + MLIRIR + MLIRSupport + MLIRTranslateLib +) diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp new file mode 100644 index 0000000..6afbe05 --- /dev/null +++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp @@ -0,0 +1,1522 @@ +//===- TranslateFromWasm.cpp - Translating to WasmSSA dialect -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the WebAssembly importer. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Target/Wasm/WasmBinaryEncoding.h" +#include "mlir/Target/Wasm/WasmImporter.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/Endian.h" +#include "llvm/Support/FormatVariadic.h" +#include "llvm/Support/LEB128.h" +#include "llvm/Support/LogicalResult.h" + +#include <cstddef> +#include <cstdint> +#include <variant> + +#define DEBUG_TYPE "wasm-translate" + +static_assert(CHAR_BIT == 8, + "This code expects std::byte to be exactly 8 bits"); + +using namespace mlir; +using namespace mlir::wasm; +using namespace mlir::wasmssa; + +namespace { +using section_id_t = uint8_t; +enum struct WasmSectionType : section_id_t { + CUSTOM = 0, + TYPE = 1, + IMPORT = 2, + FUNCTION = 3, + TABLE = 4, + MEMORY = 5, + GLOBAL = 6, + EXPORT = 7, + START = 8, + ELEMENT = 9, + CODE = 10, + DATA = 11, + DATACOUNT = 12 +}; + +constexpr section_id_t highestWasmSectionID{ + static_cast<section_id_t>(WasmSectionType::DATACOUNT)}; + +#define APPLY_WASM_SEC_TRANSFORM \ + WASM_SEC_TRANSFORM(CUSTOM) \ + WASM_SEC_TRANSFORM(TYPE) \ + WASM_SEC_TRANSFORM(IMPORT) \ + WASM_SEC_TRANSFORM(FUNCTION) \ + WASM_SEC_TRANSFORM(TABLE) \ + WASM_SEC_TRANSFORM(MEMORY) \ + WASM_SEC_TRANSFORM(GLOBAL) \ + WASM_SEC_TRANSFORM(EXPORT) \ + WASM_SEC_TRANSFORM(START) \ + WASM_SEC_TRANSFORM(ELEMENT) \ + WASM_SEC_TRANSFORM(CODE) \ + WASM_SEC_TRANSFORM(DATA) \ + WASM_SEC_TRANSFORM(DATACOUNT) + +template <WasmSectionType> +constexpr const char *wasmSectionName = ""; + +#define WASM_SEC_TRANSFORM(section) \ + template <> \ + [[maybe_unused]] constexpr 
const char \ + *wasmSectionName<WasmSectionType::section> = #section; +APPLY_WASM_SEC_TRANSFORM +#undef WASM_SEC_TRANSFORM + +constexpr bool sectionShouldBeUnique(WasmSectionType secType) { + return secType != WasmSectionType::CUSTOM; +} + +template <std::byte... Bytes> +struct ByteSequence {}; + +/// Template class for representing a byte sequence of only one byte +template <std::byte Byte> +struct UniqueByte : ByteSequence<Byte> {}; + +[[maybe_unused]] constexpr ByteSequence< + WasmBinaryEncoding::Type::i32, WasmBinaryEncoding::Type::i64, + WasmBinaryEncoding::Type::f32, WasmBinaryEncoding::Type::f64, + WasmBinaryEncoding::Type::v128> valueTypesEncodings{}; + +template <std::byte... allowedFlags> +constexpr bool isValueOneOf(std::byte value, + ByteSequence<allowedFlags...> = {}) { + return ((value == allowedFlags) | ... | false); +} + +template <std::byte... flags> +constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) { + return !isValueOneOf<flags...>(value); +} + +struct GlobalTypeRecord { + Type type; + bool isMutable; +}; + +struct TypeIdxRecord { + size_t id; +}; + +struct SymbolRefContainer { + FlatSymbolRefAttr symbol; +}; + +struct GlobalSymbolRefContainer : SymbolRefContainer { + Type globalType; +}; + +struct FunctionSymbolRefContainer : SymbolRefContainer { + FunctionType functionType; +}; + +using ImportDesc = + std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>; + +using parsed_inst_t = FailureOr<SmallVector<Value>>; + +struct WasmModuleSymbolTables { + SmallVector<FunctionSymbolRefContainer> funcSymbols; + SmallVector<GlobalSymbolRefContainer> globalSymbols; + SmallVector<SymbolRefContainer> memSymbols; + SmallVector<SymbolRefContainer> tableSymbols; + SmallVector<FunctionType> moduleFuncTypes; + + std::string getNewSymbolName(StringRef prefix, size_t id) const { + return (prefix + Twine{id}).str(); + } + + std::string getNewFuncSymbolName() const { + size_t id = funcSymbols.size(); + return getNewSymbolName("func_", id); + } + + std::string getNewGlobalSymbolName() const { + size_t id = globalSymbols.size(); + return getNewSymbolName("global_", id); + } + + std::string getNewMemorySymbolName() const { + size_t id = memSymbols.size(); + return getNewSymbolName("mem_", id); + } + + std::string getNewTableSymbolName() const { + size_t id = tableSymbols.size(); + return getNewSymbolName("table_", id); + } +}; + +class ParserHead; + +/// Wrapper around SmallVector to only allow access as push and pop on the +/// stack. Makes sure that there are no "free accesses" on the stack to preserve +/// its state. +class ValueStack { +private: + struct LabelLevel { + size_t stackIdx; + LabelLevelOpInterface levelOp; + }; + +public: + bool empty() const { return values.empty(); } + + size_t size() const { return values.size(); } + + /// Pops values from the stack because they are being used in an operation. + /// @param operandTypes The list of expected types of the operation, used + /// to know how many values to pop and check if the types match the + /// expectation. + /// @param opLoc Location of the caller, used to report accurately the + /// location + /// if an error occurs. + /// @return Failure or the vector of popped values. + FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes, + Location *opLoc); + + /// Push the results of an operation to the stack so they can be used in a + /// following operation. 
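As an aside to the ValueStack interface above: a self-contained toy model of the typed operand-stack discipline that popOperands/pushResults enforce while decoding an expression such as i32.const 1; i32.const 2; i32.add. This is an editorial illustration, independent of the surrounding classes and not part of the patch.

#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> stack;
  stack.push_back("i32");                      // i32.const 1 pushes an i32
  stack.push_back("i32");                      // i32.const 2 pushes an i32
  // i32.add pops two operands, both of which must be i32 ...
  assert(stack.size() >= 2);
  assert(stack[stack.size() - 2] == "i32" && stack[stack.size() - 1] == "i32");
  stack.resize(stack.size() - 2);
  stack.push_back("i32");                      // ... and pushes its i32 result
  assert(stack.size() == 1);
  return 0;
}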
+ /// @param results The list of results of the operation + /// @param opLoc Location of the caller, used to report accurately the + /// location + /// if an error occurs. + LogicalResult pushResults(ValueRange results, Location *opLoc); + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + /// A simple dump function for debugging. + /// Writes output to llvm::dbgs(). + LLVM_DUMP_METHOD void dump() const; +#endif + +private: + SmallVector<Value> values; +}; + +using local_val_t = TypedValue<wasmssa::LocalRefType>; + +class ExpressionParser { +public: + using locals_t = SmallVector<local_val_t>; + ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols, + ArrayRef<local_val_t> initLocal) + : parser{parser}, symbols{symbols}, locals{initLocal} {} + +private: + template <std::byte opCode> + inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder); + + template <typename valueT> + parsed_inst_t + parseConstInst(OpBuilder &builder, + std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr); + + /// Construct an operation with \p numOperands operands and a single result. + /// Each operand must have the same type. Suitable for e.g. binops, unary + /// ops, etc. + /// + /// \p opcode - The WASM opcode to build. + /// \p valueType - The operand and result type for the built instruction. + /// \p numOperands - The number of operands for the built operation. + /// + /// \returns The parsed instruction result, or failure. + template <typename opcode, typename valueType, unsigned int numOperands> + inline parsed_inst_t + buildNumericOp(OpBuilder &builder, + std::enable_if_t<std::is_arithmetic_v<valueType>> * = nullptr); + + /// This function generates a dispatch tree to associate an opcode with a + /// parser. Parsers are registered by specialising the + /// `parseSpecificInstruction` function for the op code to handle. + /// + /// The dispatcher is generated by recursively creating all possible patterns + /// for an opcode and calling the relevant parser on the leaf. + /// + /// @tparam patternBitSize is the first bit for which the pattern is not fixed + /// + /// @tparam highBitPattern is the fixed pattern that this instance handles for + /// the 8-patternBitSize bits + template <size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}> + inline parsed_inst_t dispatchToInstParser(std::byte opCode, + OpBuilder &builder) { + static_assert(patternBitSize <= 8, + "PatternBitSize is outside of range of opcode space! " + "(expected at most 8 bits)"); + if constexpr (patternBitSize < 8) { + constexpr std::byte bitSelect{1 << (7 - patternBitSize)}; + constexpr std::byte nextHighBitPatternStem = highBitPattern << 1; + constexpr size_t nextPatternBitSize = patternBitSize + 1; + if ((opCode & bitSelect) != std::byte{0}) + return dispatchToInstParser<nextPatternBitSize, + nextHighBitPatternStem | std::byte{1}>( + opCode, builder); + return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>( + opCode, builder); + } else { + return parseSpecificInstruction<highBitPattern>(builder); + } + } + + struct ParseResultWithInfo { + SmallVector<Value> opResults; + std::byte endingByte; + }; + +public: + template <std::byte ParseEndByte = WasmBinaryEncoding::endByte> + parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {}); + + template <std::byte... 
ExpressionParseEnd> + FailureOr<ParseResultWithInfo> + parse(OpBuilder &builder, + ByteSequence<ExpressionParseEnd...> parsingEndFilters); + + FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) { + return valueStack.popOperands(operandTypes, ¤tOpLoc.value()); + } + + LogicalResult pushResults(ValueRange results) { + return valueStack.pushResults(results, ¤tOpLoc.value()); + } + + /// The local.set and local.tee operations behave similarly and only differ + /// on their return value. This function factorizes the behavior of the two + /// operations in one place. + template <typename OpToCreate> + parsed_inst_t parseSetOrTee(OpBuilder &); + +private: + std::optional<Location> currentOpLoc; + ParserHead &parser; + WasmModuleSymbolTables const &symbols; + locals_t locals; + ValueStack valueStack; +}; + +class ParserHead { +public: + ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {} + ParserHead(ParserHead &&) = default; + +private: + ParserHead(ParserHead const &other) = default; + +public: + auto getLocation() const { + return FileLineColLoc::get(locName, 0, anchorOffset + offset); + } + + FailureOr<StringRef> consumeNBytes(size_t nBytes) { + LDBG() << "Consume " << nBytes << " bytes"; + LDBG() << " Bytes remaining: " << size(); + LDBG() << " Current offset: " << offset; + if (nBytes > size()) + return emitError(getLocation(), "trying to extract ") + << nBytes << "bytes when only " << size() << "are available"; + + StringRef res = head.slice(offset, offset + nBytes); + offset += nBytes; + LDBG() << " Updated offset (+" << nBytes << "): " << offset; + return res; + } + + FailureOr<std::byte> consumeByte() { + FailureOr<StringRef> res = consumeNBytes(1); + if (failed(res)) + return failure(); + return std::byte{*res->bytes_begin()}; + } + + template <typename T> + FailureOr<T> parseLiteral(); + + FailureOr<uint32_t> parseVectorSize(); + +private: + // TODO: This is equivalent to parseLiteral<uint32_t> and could be removed + // if parseLiteral specialization were moved here, but default GCC on Ubuntu + // 22.04 has bug with template specialization in class declaration + inline FailureOr<uint32_t> parseUI32(); + inline FailureOr<int64_t> parseI64(); + +public: + FailureOr<StringRef> parseName() { + FailureOr<uint32_t> size = parseVectorSize(); + if (failed(size)) + return failure(); + + return consumeNBytes(*size); + } + + FailureOr<WasmSectionType> parseWasmSectionType() { + FailureOr<std::byte> id = consumeByte(); + if (failed(id)) + return failure(); + if (std::to_integer<unsigned>(*id) > highestWasmSectionID) + return emitError(getLocation(), "invalid section ID: ") + << static_cast<int>(*id); + return static_cast<WasmSectionType>(*id); + } + + FailureOr<LimitType> parseLimit(MLIRContext *ctx) { + using WasmLimits = WasmBinaryEncoding::LimitHeader; + FileLineColLoc limitLocation = getLocation(); + FailureOr<std::byte> limitHeader = consumeByte(); + if (failed(limitHeader)) + return failure(); + + if (isNotIn<WasmLimits::bothLimits, WasmLimits::lowLimitOnly>(*limitHeader)) + return emitError(limitLocation, "invalid limit header: ") + << static_cast<int>(*limitHeader); + FailureOr<uint32_t> minParse = parseUI32(); + if (failed(minParse)) + return failure(); + std::optional<uint32_t> max{std::nullopt}; + if (*limitHeader == WasmLimits::bothLimits) { + FailureOr<uint32_t> maxParse = parseUI32(); + if (failed(maxParse)) + return failure(); + max = *maxParse; + } + return LimitType::get(ctx, *minParse, max); + } + + FailureOr<Type> parseValueType(MLIRContext 
*ctx) { + FileLineColLoc typeLoc = getLocation(); + FailureOr<std::byte> typeEncoding = consumeByte(); + if (failed(typeEncoding)) + return failure(); + switch (*typeEncoding) { + case WasmBinaryEncoding::Type::i32: + return IntegerType::get(ctx, 32); + case WasmBinaryEncoding::Type::i64: + return IntegerType::get(ctx, 64); + case WasmBinaryEncoding::Type::f32: + return Float32Type::get(ctx); + case WasmBinaryEncoding::Type::f64: + return Float64Type::get(ctx); + case WasmBinaryEncoding::Type::v128: + return IntegerType::get(ctx, 128); + case WasmBinaryEncoding::Type::funcRef: + return wasmssa::FuncRefType::get(ctx); + case WasmBinaryEncoding::Type::externRef: + return wasmssa::ExternRefType::get(ctx); + default: + return emitError(typeLoc, "invalid value type encoding: ") + << static_cast<int>(*typeEncoding); + } + } + + FailureOr<GlobalTypeRecord> parseGlobalType(MLIRContext *ctx) { + using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability; + FailureOr<Type> typeParsed = parseValueType(ctx); + if (failed(typeParsed)) + return failure(); + FileLineColLoc mutLoc = getLocation(); + FailureOr<std::byte> mutSpec = consumeByte(); + if (failed(mutSpec)) + return failure(); + if (isNotIn<WasmGlobalMut::isConst, WasmGlobalMut::isMutable>(*mutSpec)) + return emitError(mutLoc, "invalid global mutability specifier: ") + << static_cast<int>(*mutSpec); + return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable}; + } + + FailureOr<TupleType> parseResultType(MLIRContext *ctx) { + FailureOr<uint32_t> nParamsParsed = parseVectorSize(); + if (failed(nParamsParsed)) + return failure(); + uint32_t nParams = *nParamsParsed; + SmallVector<Type> res{}; + res.reserve(nParams); + for (size_t i = 0; i < nParams; ++i) { + FailureOr<Type> parsedType = parseValueType(ctx); + if (failed(parsedType)) + return failure(); + res.push_back(*parsedType); + } + return TupleType::get(ctx, res); + } + + FailureOr<FunctionType> parseFunctionType(MLIRContext *ctx) { + FileLineColLoc typeLoc = getLocation(); + FailureOr<std::byte> funcTypeHeader = consumeByte(); + if (failed(funcTypeHeader)) + return failure(); + if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType) + return emitError(typeLoc, "invalid function type header byte. 
Expecting ") + << std::to_integer<unsigned>(WasmBinaryEncoding::Type::funcType) + << " got " << std::to_integer<unsigned>(*funcTypeHeader); + FailureOr<TupleType> inputTypes = parseResultType(ctx); + if (failed(inputTypes)) + return failure(); + + FailureOr<TupleType> resTypes = parseResultType(ctx); + if (failed(resTypes)) + return failure(); + + return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes()); + } + + FailureOr<TypeIdxRecord> parseTypeIndex() { + FailureOr<uint32_t> res = parseUI32(); + if (failed(res)) + return failure(); + return TypeIdxRecord{*res}; + } + + FailureOr<TableType> parseTableType(MLIRContext *ctx) { + FailureOr<Type> elmTypeParse = parseValueType(ctx); + if (failed(elmTypeParse)) + return failure(); + if (!isWasmRefType(*elmTypeParse)) + return emitError(getLocation(), "invalid element type for table"); + FailureOr<LimitType> limitParse = parseLimit(ctx); + if (failed(limitParse)) + return failure(); + return TableType::get(ctx, *elmTypeParse, *limitParse); + } + + FailureOr<ImportDesc> parseImportDesc(MLIRContext *ctx) { + FileLineColLoc importLoc = getLocation(); + FailureOr<std::byte> importType = consumeByte(); + auto packager = [](auto parseResult) -> FailureOr<ImportDesc> { + if (failed(parseResult)) + return failure(); + return {*parseResult}; + }; + if (failed(importType)) + return failure(); + switch (*importType) { + case WasmBinaryEncoding::Import::typeID: + return packager(parseTypeIndex()); + case WasmBinaryEncoding::Import::tableType: + return packager(parseTableType(ctx)); + case WasmBinaryEncoding::Import::memType: + return packager(parseLimit(ctx)); + case WasmBinaryEncoding::Import::globalType: + return packager(parseGlobalType(ctx)); + default: + return emitError(importLoc, "invalid import type descriptor: ") + << static_cast<int>(*importType); + } + } + + parsed_inst_t parseExpression(OpBuilder &builder, + WasmModuleSymbolTables const &symbols, + ArrayRef<local_val_t> locals = {}) { + auto eParser = ExpressionParser{*this, symbols, locals}; + return eParser.parse(builder); + } + + LogicalResult parseCodeFor(FuncOp func, + WasmModuleSymbolTables const &symbols) { + SmallVector<local_val_t> locals{}; + // Populating locals with function argument + Block &block = func.getBody().front(); + // Delete temporary return argument which was only created for IR validity + assert(func.getBody().getBlocks().size() == 1 && + "Function should only have its default created block at this point"); + assert(block.getOperations().size() == 1 && + "Only the placeholder return op should be present at this point"); + auto returnOp = cast<ReturnOp>(&block.back()); + assert(returnOp); + + FailureOr<uint32_t> codeSizeInBytes = parseUI32(); + if (failed(codeSizeInBytes)) + return failure(); + FailureOr<StringRef> codeContent = consumeNBytes(*codeSizeInBytes); + if (failed(codeContent)) + return failure(); + auto name = StringAttr::get(func->getContext(), + locName.str() + "::" + func.getSymName()); + auto cParser = ParserHead{*codeContent, name}; + FailureOr<uint32_t> localVecSize = cParser.parseVectorSize(); + if (failed(localVecSize)) + return failure(); + OpBuilder builder{&func.getBody().front().back()}; + for (auto arg : block.getArguments()) + locals.push_back(cast<TypedValue<LocalRefType>>(arg)); + // Declare the local ops + uint32_t nVarVec = *localVecSize; + for (size_t i = 0; i < nVarVec; ++i) { + FileLineColLoc varLoc = cParser.getLocation(); + FailureOr<uint32_t> nSubVar = cParser.parseUI32(); + if (failed(nSubVar)) + return failure(); + 
FailureOr<Type> varT = cParser.parseValueType(func->getContext()); + if (failed(varT)) + return failure(); + for (size_t j = 0; j < *nSubVar; ++j) { + auto local = builder.create<LocalOp>(varLoc, *varT); + locals.push_back(local.getResult()); + } + } + parsed_inst_t res = cParser.parseExpression(builder, symbols, locals); + if (failed(res)) + return failure(); + if (!cParser.end()) + return emitError(cParser.getLocation(), + "unparsed garbage remaining at end of code block"); + builder.create<ReturnOp>(func->getLoc(), *res); + returnOp->erase(); + return success(); + } + + bool end() const { return curHead().empty(); } + + ParserHead copy() const { return *this; } + +private: + StringRef curHead() const { return head.drop_front(offset); } + + FailureOr<std::byte> peek() const { + if (end()) + return emitError( + getLocation(), + "trying to peek at next byte, but input stream is empty"); + return static_cast<std::byte>(curHead().front()); + } + + size_t size() const { return head.size() - offset; } + + StringRef head; + StringAttr locName; + unsigned anchorOffset{0}; + unsigned offset{0}; +}; + +template <> +FailureOr<float> ParserHead::parseLiteral<float>() { + FailureOr<StringRef> bytes = consumeNBytes(4); + if (failed(bytes)) + return failure(); + return llvm::support::endian::read<float>(bytes->bytes_begin(), + llvm::endianness::little); +} + +template <> +FailureOr<double> ParserHead::parseLiteral<double>() { + FailureOr<StringRef> bytes = consumeNBytes(8); + if (failed(bytes)) + return failure(); + return llvm::support::endian::read<double>(bytes->bytes_begin(), + llvm::endianness::little); +} + +template <> +FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() { + char const *error = nullptr; + uint32_t res{0}; + unsigned encodingSize{0}; + StringRef src = curHead(); + uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + + if (std::isgreater(decoded, std::numeric_limits<uint32_t>::max())) + return emitError(getLocation()) << "literal does not fit on 32 bits"; + + res = static_cast<uint32_t>(decoded); + offset += encodingSize; + return res; +} + +template <> +FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() { + char const *error = nullptr; + int32_t res{0}; + unsigned encodingSize{0}; + StringRef src = curHead(); + int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + if (std::isgreater(decoded, std::numeric_limits<int32_t>::max()) || + std::isgreater(std::numeric_limits<int32_t>::min(), decoded)) + return emitError(getLocation()) << "literal does not fit on 32 bits"; + + res = static_cast<int32_t>(decoded); + offset += encodingSize; + return res; +} + +template <> +FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() { + char const *error = nullptr; + unsigned encodingSize{0}; + StringRef src = curHead(); + int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize, + src.bytes_end(), &error); + if (error) + return emitError(getLocation(), error); + + offset += encodingSize; + return res; +} + +FailureOr<uint32_t> ParserHead::parseVectorSize() { + return parseLiteral<uint32_t>(); +} + +inline FailureOr<uint32_t> ParserHead::parseUI32() { + return parseLiteral<uint32_t>(); +} + +inline FailureOr<int64_t> ParserHead::parseI64() { + return parseLiteral<int64_t>(); +} + +template <std::byte opCode> +inline parsed_inst_t 
ExpressionParser::parseSpecificInstruction(OpBuilder &) { + return emitError(*currentOpLoc, "unknown instruction opcode: ") + << static_cast<int>(opCode); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void ValueStack::dump() const { + llvm::dbgs() << "================= Wasm ValueStack =======================\n"; + llvm::dbgs() << "size: " << size() << "\n"; + llvm::dbgs() << "<Top>" + << "\n"; + // Stack is pushed to via push_back. Therefore the top of the stack is the + // end of the vector. Iterate in reverse so that the first thing we print + // is the top of the stack. + size_t stackSize = size(); + for (size_t idx = 0; idx < stackSize; idx++) { + size_t actualIdx = stackSize - 1 - idx; + llvm::dbgs() << " "; + values[actualIdx].dump(); + } + llvm::dbgs() << "<Bottom>" + << "\n"; + llvm::dbgs() << "=========================================================\n"; +} +#endif + +parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) { + LDBG() << "Popping from ValueStack\n" + << " Elements(s) to pop: " << operandTypes.size() << "\n" + << " Current stack size: " << values.size(); + if (operandTypes.size() > values.size()) + return emitError(*opLoc, + "stack doesn't contain enough values. trying to get ") + << operandTypes.size() << " operands on a stack containing only " + << values.size() << " values."; + size_t stackIdxOffset = values.size() - operandTypes.size(); + SmallVector<Value> res{}; + res.reserve(operandTypes.size()); + for (size_t i{0}; i < operandTypes.size(); ++i) { + Value operand = values[i + stackIdxOffset]; + Type stackType = operand.getType(); + if (stackType != operandTypes[i]) + return emitError(*opLoc, "invalid operand type on stack. expecting ") + << operandTypes[i] << ", value on stack is of type " << stackType + << "."; + LDBG() << " POP: " << operand; + res.push_back(operand); + } + values.resize(values.size() - operandTypes.size()); + LDBG() << " Updated stack size: " << values.size(); + return res; +} + +LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) { + LDBG() << "Pushing to ValueStack\n" + << " Elements(s) to push: " << results.size() << "\n" + << " Current stack size: " << values.size(); + for (Value val : results) { + if (!isWasmValueType(val.getType())) + return emitError(*opLoc, "invalid value type on stack: ") + << val.getType(); + LDBG() << " PUSH: " << val; + values.push_back(val); + } + + LDBG() << " Updated stack size: " << values.size(); + return success(); +} + +template <std::byte EndParseByte> +parsed_inst_t ExpressionParser::parse(OpBuilder &builder, + UniqueByte<EndParseByte> endByte) { + auto res = parse(builder, ByteSequence<EndParseByte>{}); + if (failed(res)) + return failure(); + return res->opResults; +} + +template <std::byte... 
ExpressionParseEnd> +FailureOr<ExpressionParser::ParseResultWithInfo> +ExpressionParser::parse(OpBuilder &builder, + ByteSequence<ExpressionParseEnd...> parsingEndFilters) { + SmallVector<Value> res; + for (;;) { + currentOpLoc = parser.getLocation(); + FailureOr<std::byte> opCode = parser.consumeByte(); + if (failed(opCode)) + return failure(); + if (isValueOneOf(*opCode, parsingEndFilters)) + return {{res, *opCode}}; + parsed_inst_t resParsed; + resParsed = dispatchToInstParser(*opCode, builder); + if (failed(resParsed)) + return failure(); + std::swap(res, *resParsed); + if (failed(pushResults(res))) + return failure(); + } +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::localGet>(OpBuilder &builder) { + FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>(); + Location instLoc = *currentOpLoc; + if (failed(id)) + return failure(); + if (*id >= locals.size()) + return emitError(instLoc, "invalid local index. function has ") + << locals.size() << " accessible locals, received index " << *id; + return {{builder.create<LocalGetOp>(instLoc, locals[*id]).getResult()}}; +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::globalGet>(OpBuilder &builder) { + FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>(); + Location instLoc = *currentOpLoc; + if (failed(id)) + return failure(); + if (*id >= symbols.globalSymbols.size()) + return emitError(instLoc, "invalid global index. function has ") + << symbols.globalSymbols.size() + << " accessible globals, received index " << *id; + GlobalSymbolRefContainer globalVar = symbols.globalSymbols[*id]; + auto globalOp = builder.create<GlobalGetOp>(instLoc, globalVar.globalType, + globalVar.symbol); + + return {{globalOp.getResult()}}; +} + +template <typename OpToCreate> +parsed_inst_t ExpressionParser::parseSetOrTee(OpBuilder &builder) { + FailureOr<uint32_t> id = parser.parseLiteral<uint32_t>(); + if (failed(id)) + return failure(); + if (*id >= locals.size()) + return emitError(*currentOpLoc, "invalid local index. 
function has ") + << locals.size() << " accessible locals, received index " << *id; + if (valueStack.empty()) + return emitError( + *currentOpLoc, + "invalid stack access, trying to access a value on an empty stack."); + + parsed_inst_t poppedOp = popOperands(locals[*id].getType().getElementType()); + if (failed(poppedOp)) + return failure(); + return { + builder.create<OpToCreate>(*currentOpLoc, locals[*id], poppedOp->front()) + ->getResults()}; +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::localSet>(OpBuilder &builder) { + return parseSetOrTee<LocalSetOp>(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::localTee>(OpBuilder &builder) { + return parseSetOrTee<LocalTeeOp>(builder); +} + +template <typename T> +inline Type buildLiteralType(OpBuilder &); + +template <> +inline Type buildLiteralType<int32_t>(OpBuilder &builder) { + return builder.getI32Type(); +} + +template <> +inline Type buildLiteralType<int64_t>(OpBuilder &builder) { + return builder.getI64Type(); +} + +template <> +[[maybe_unused]] inline Type buildLiteralType<uint32_t>(OpBuilder &builder) { + return builder.getI32Type(); +} + +template <> +[[maybe_unused]] inline Type buildLiteralType<uint64_t>(OpBuilder &builder) { + return builder.getI64Type(); +} + +template <> +inline Type buildLiteralType<float>(OpBuilder &builder) { + return builder.getF32Type(); +} + +template <> +inline Type buildLiteralType<double>(OpBuilder &builder) { + return builder.getF64Type(); +} + +template <typename ValT, + typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>> +struct AttrHolder; + +template <typename ValT> +struct AttrHolder<ValT, std::enable_if_t<std::is_integral_v<ValT>>> { + using type = IntegerAttr; +}; + +template <typename ValT> +struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> { + using type = FloatAttr; +}; + +template <typename ValT> +using attr_holder_t = typename AttrHolder<ValT>::type; + +template <typename ValT, + typename EnableT = std::enable_if_t<std::is_arithmetic_v<ValT>>> +attr_holder_t<ValT> buildLiteralAttr(OpBuilder &builder, ValT val) { + return attr_holder_t<ValT>::get(buildLiteralType<ValT>(builder), val); +} + +template <typename valueT> +parsed_inst_t ExpressionParser::parseConstInst( + OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueT>> *) { + auto parsedConstant = parser.parseLiteral<valueT>(); + if (failed(parsedConstant)) + return failure(); + auto constOp = + ConstOp::create(builder, *currentOpLoc, + buildLiteralAttr<valueT>(builder, *parsedConstant)); + return {{constOp.getResult()}}; +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) { + return parseConstInst<int32_t>(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) { + return parseConstInst<int64_t>(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) { + return parseConstInst<float>(builder); +} + +template <> +inline parsed_inst_t ExpressionParser::parseSpecificInstruction< + WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) { + return parseConstInst<double>(builder); +} + +template <typename opcode, typename valueType, unsigned int numOperands> +inline 
parsed_inst_t ExpressionParser::buildNumericOp( + OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueType>> *) { + auto ty = buildLiteralType<valueType>(builder); + LDBG() << "*** buildNumericOp: numOperands = " << numOperands + << ", type = " << ty << " ***"; + auto tysToPop = SmallVector<Type, numOperands>(); + tysToPop.resize(numOperands); + std::fill(tysToPop.begin(), tysToPop.end(), ty); + auto operands = popOperands(tysToPop); + if (failed(operands)) + return failure(); + auto op = builder.create<opcode>(*currentOpLoc, *operands).getResult(); + LDBG() << "Built operation: " << op; + return {{op}}; +} + +// Convenience macro for generating numerical operations. +#define BUILD_NUMERIC_OP(OP_NAME, N_ARGS, PREFIX, SUFFIX, TYPE) \ + template <> \ + inline parsed_inst_t ExpressionParser::parseSpecificInstruction< \ + WasmBinaryEncoding::OpCode::PREFIX##SUFFIX>(OpBuilder & builder) { \ + return buildNumericOp<OP_NAME, TYPE, N_ARGS>(builder); \ + } + +// Macro to define binops that only support integer types. +#define BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX) \ + BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I32, int32_t) \ + BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, I64, int64_t) + +// Macro to define binops that only support floating point types. +#define BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX) \ + BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F32, float) \ + BUILD_NUMERIC_OP(OP_NAME, 2, PREFIX, F64, double) + +// Macro to define binops that support both floating point and integer types. +#define BUILD_NUMERIC_BINOP_INTFP(OP_NAME, PREFIX) \ + BUILD_NUMERIC_BINOP_INT(OP_NAME, PREFIX) \ + BUILD_NUMERIC_BINOP_FP(OP_NAME, PREFIX) + +// Macro to implement unary ops that only support integers. +#define BUILD_NUMERIC_UNARY_OP_INT(OP_NAME, PREFIX) \ + BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I32, int32_t) \ + BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, I64, int64_t) + +// Macro to implement unary ops that support integer and floating point types. +#define BUILD_NUMERIC_UNARY_OP_FP(OP_NAME, PREFIX) \ + BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F32, float) \ + BUILD_NUMERIC_OP(OP_NAME, 1, PREFIX, F64, double) + +BUILD_NUMERIC_BINOP_FP(CopySignOp, copysign) +BUILD_NUMERIC_BINOP_FP(DivOp, div) +BUILD_NUMERIC_BINOP_FP(MaxOp, max) +BUILD_NUMERIC_BINOP_FP(MinOp, min) +BUILD_NUMERIC_BINOP_INT(AndOp, and) +BUILD_NUMERIC_BINOP_INT(DivSIOp, divS) +BUILD_NUMERIC_BINOP_INT(DivUIOp, divU) +BUILD_NUMERIC_BINOP_INT(OrOp, or) +BUILD_NUMERIC_BINOP_INT(RemSIOp, remS) +BUILD_NUMERIC_BINOP_INT(RemUIOp, remU) +BUILD_NUMERIC_BINOP_INT(RotlOp, rotl) +BUILD_NUMERIC_BINOP_INT(RotrOp, rotr) +BUILD_NUMERIC_BINOP_INT(ShLOp, shl) +BUILD_NUMERIC_BINOP_INT(ShRSOp, shrS) +BUILD_NUMERIC_BINOP_INT(ShRUOp, shrU) +BUILD_NUMERIC_BINOP_INT(XOrOp, xor) +BUILD_NUMERIC_BINOP_INTFP(AddOp, add) +BUILD_NUMERIC_BINOP_INTFP(MulOp, mul) +BUILD_NUMERIC_BINOP_INTFP(SubOp, sub) +BUILD_NUMERIC_UNARY_OP_FP(AbsOp, abs) +BUILD_NUMERIC_UNARY_OP_FP(CeilOp, ceil) +BUILD_NUMERIC_UNARY_OP_FP(FloorOp, floor) +BUILD_NUMERIC_UNARY_OP_FP(NegOp, neg) +BUILD_NUMERIC_UNARY_OP_FP(SqrtOp, sqrt) +BUILD_NUMERIC_UNARY_OP_FP(TruncOp, trunc) +BUILD_NUMERIC_UNARY_OP_INT(ClzOp, clz) +BUILD_NUMERIC_UNARY_OP_INT(CtzOp, ctz) +BUILD_NUMERIC_UNARY_OP_INT(PopCntOp, popcnt) + +// Don't need these anymore so let's undef them. 
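// Illustrative expansion (not part of the patch): an instantiation such as
// BUILD_NUMERIC_BINOP_INTFP(AddOp, add) above goes through
// BUILD_NUMERIC_OP(AddOp, 2, add, I32, int32_t) and yields, for the i32 case,
// roughly this specialization:
//
//   template <>
//   inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
//       WasmBinaryEncoding::OpCode::addI32>(OpBuilder &builder) {
//     return buildNumericOp<AddOp, int32_t, 2>(builder);
//   }
//
// so each generated opcode handler pops operands of the matching MLIR type
// from the value stack and materializes the corresponding WasmSSA op.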
+#undef BUILD_NUMERIC_BINOP_FP +#undef BUILD_NUMERIC_BINOP_INT +#undef BUILD_NUMERIC_BINOP_INTFP +#undef BUILD_NUMERIC_UNARY_OP_FP +#undef BUILD_NUMERIC_UNARY_OP_INT +#undef BUILD_NUMERIC_OP +#undef BUILD_NUMERIC_CAST_OP + +class WasmBinaryParser { +private: + struct SectionRegistry { + using section_location_t = StringRef; + + std::array<SmallVector<section_location_t>, highestWasmSectionID + 1> + registry; + + template <WasmSectionType SecType> + std::conditional_t<sectionShouldBeUnique(SecType), + std::optional<section_location_t>, + ArrayRef<section_location_t>> + getContentForSection() const { + constexpr auto idx = static_cast<size_t>(SecType); + if constexpr (sectionShouldBeUnique(SecType)) { + return registry[idx].empty() ? std::nullopt + : std::make_optional(registry[idx][0]); + } else { + return registry[idx]; + } + } + + bool hasSection(WasmSectionType secType) const { + return !registry[static_cast<size_t>(secType)].empty(); + } + + /// + /// @returns success if registration valid, failure in case registration + /// can't be done (if another section of same type already exist and this + /// section type should only be present once) + /// + LogicalResult registerSection(WasmSectionType secType, + section_location_t location, Location loc) { + if (sectionShouldBeUnique(secType) && hasSection(secType)) + return emitError(loc, + "trying to add a second instance of unique section"); + + registry[static_cast<size_t>(secType)].push_back(location); + emitRemark(loc, "Adding section with section ID ") + << static_cast<uint8_t>(secType); + return success(); + } + + LogicalResult populateFromBody(ParserHead ph) { + while (!ph.end()) { + FileLineColLoc sectionLoc = ph.getLocation(); + FailureOr<WasmSectionType> secType = ph.parseWasmSectionType(); + if (failed(secType)) + return failure(); + + FailureOr<uint32_t> secSizeParsed = ph.parseLiteral<uint32_t>(); + if (failed(secSizeParsed)) + return failure(); + + uint32_t secSize = *secSizeParsed; + FailureOr<StringRef> sectionContent = ph.consumeNBytes(secSize); + if (failed(sectionContent)) + return failure(); + + LogicalResult registration = + registerSection(*secType, *sectionContent, sectionLoc); + + if (failed(registration)) + return failure(); + } + return success(); + } + }; + + auto getLocation(int offset = 0) const { + return FileLineColLoc::get(srcName, 0, offset); + } + + template <WasmSectionType> + LogicalResult parseSectionItem(ParserHead &, size_t); + + template <WasmSectionType section> + LogicalResult parseSection() { + auto secName = std::string{wasmSectionName<section>}; + auto sectionNameAttr = + StringAttr::get(ctx, srcName.strref() + ":" + secName + "-SECTION"); + unsigned offset = 0; + auto getLocation = [sectionNameAttr, &offset]() { + return FileLineColLoc::get(sectionNameAttr, 0, offset); + }; + auto secContent = registry.getContentForSection<section>(); + if (!secContent) { + LDBG() << secName << " section is not present in file."; + return success(); + } + + auto secSrc = secContent.value(); + ParserHead ph{secSrc, sectionNameAttr}; + FailureOr<uint32_t> nElemsParsed = ph.parseVectorSize(); + if (failed(nElemsParsed)) + return failure(); + uint32_t nElems = *nElemsParsed; + LDBG() << "starting to parse " << nElems << " items for section " + << secName; + for (size_t i = 0; i < nElems; ++i) { + if (failed(parseSectionItem<section>(ph, i))) + return failure(); + } + + if (!ph.end()) + return emitError(getLocation(), "unparsed garbage at end of section ") + << secName; + return success(); + } + + /// Handles the 
registration of a function import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, TypeIdxRecord tid) { + using llvm::Twine; + if (tid.id >= symbols.moduleFuncTypes.size()) + return emitError(loc, "invalid type id: ") + << tid.id << ". Only " << symbols.moduleFuncTypes.size() + << " type registration."; + FunctionType type = symbols.moduleFuncTypes[tid.id]; + std::string symbol = symbols.getNewFuncSymbolName(); + auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName, + importName, type); + symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type}); + return funcOp.verify(); + } + + /// Handles the registration of a memory import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, LimitType limitType) { + std::string symbol = symbols.getNewMemorySymbolName(); + auto memOp = MemImportOp::create(builder, loc, symbol, moduleName, + importName, limitType); + symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)}); + return memOp.verify(); + } + + /// Handles the registration of a table import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, TableType tableType) { + std::string symbol = symbols.getNewTableSymbolName(); + auto tableOp = TableImportOp::create(builder, loc, symbol, moduleName, + importName, tableType); + symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)}); + return tableOp.verify(); + } + + /// Handles the registration of a global variable import + LogicalResult visitImport(Location loc, StringRef moduleName, + StringRef importName, GlobalTypeRecord globalType) { + std::string symbol = symbols.getNewGlobalSymbolName(); + auto giOp = + GlobalImportOp::create(builder, loc, symbol, moduleName, importName, + globalType.type, globalType.isMutable); + symbols.globalSymbols.push_back( + {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()}); + return giOp.verify(); + } + + // Detect occurence of errors + LogicalResult peekDiag(Diagnostic &diag) { + if (diag.getSeverity() == DiagnosticSeverity::Error) + isValid = false; + return failure(); + } + +public: + WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx) + : builder{ctx}, ctx{ctx} { + ctx->getDiagEngine().registerHandler( + [this](Diagnostic &diag) { return peekDiag(diag); }); + ctx->loadAllAvailableDialects(); + if (sourceMgr.getNumBuffers() != 1) { + emitError(UnknownLoc::get(ctx), "one source file should be provided"); + return; + } + uint32_t sourceBufId = sourceMgr.getMainFileID(); + StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer(); + srcName = StringAttr::get( + ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier()); + + auto parser = ParserHead{source, srcName}; + auto const wasmHeader = StringRef{"\0asm", 4}; + FileLineColLoc magicLoc = parser.getLocation(); + FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size()); + if (failed(magic) || magic->compare(wasmHeader)) { + emitError(magicLoc, "source file does not contain valid Wasm header."); + return; + } + auto const expectedVersionString = StringRef{"\1\0\0\0", 4}; + FileLineColLoc versionLoc = parser.getLocation(); + FailureOr<StringRef> version = + parser.consumeNBytes(expectedVersionString.size()); + if (failed(version)) + return; + if (version->compare(expectedVersionString)) { + emitError(versionLoc, + "unsupported Wasm version. 
only version 1 is supported"); + return; + } + LogicalResult fillRegistry = registry.populateFromBody(parser.copy()); + if (failed(fillRegistry)) + return; + + mOp = ModuleOp::create(builder, getLocation()); + builder.setInsertionPointToStart(&mOp.getBodyRegion().front()); + LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>(); + if (failed(parsingTypes)) + return; + + LogicalResult parsingImports = parseSection<WasmSectionType::IMPORT>(); + if (failed(parsingImports)) + return; + + firstInternalFuncID = symbols.funcSymbols.size(); + + LogicalResult parsingFunctions = parseSection<WasmSectionType::FUNCTION>(); + if (failed(parsingFunctions)) + return; + + LogicalResult parsingTables = parseSection<WasmSectionType::TABLE>(); + if (failed(parsingTables)) + return; + + LogicalResult parsingMems = parseSection<WasmSectionType::MEMORY>(); + if (failed(parsingMems)) + return; + + LogicalResult parsingGlobals = parseSection<WasmSectionType::GLOBAL>(); + if (failed(parsingGlobals)) + return; + + LogicalResult parsingCode = parseSection<WasmSectionType::CODE>(); + if (failed(parsingCode)) + return; + + LogicalResult parsingExports = parseSection<WasmSectionType::EXPORT>(); + if (failed(parsingExports)) + return; + + // Copy over sizes of containers into statistics. + LDBG() << "WASM Imports:" + << "\n" + << " - Num functions: " << symbols.funcSymbols.size() << "\n" + << " - Num globals: " << symbols.globalSymbols.size() << "\n" + << " - Num memories: " << symbols.memSymbols.size() << "\n" + << " - Num tables: " << symbols.tableSymbols.size(); + } + + ModuleOp getModule() { + if (isValid) + return mOp; + if (mOp) + mOp.erase(); + return ModuleOp{}; + } + +private: + mlir::StringAttr srcName; + OpBuilder builder; + WasmModuleSymbolTables symbols; + MLIRContext *ctx; + ModuleOp mOp; + SectionRegistry registry; + size_t firstInternalFuncID{0}; + bool isValid{true}; +}; + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph, + size_t) { + FileLineColLoc importLoc = ph.getLocation(); + auto moduleName = ph.parseName(); + if (failed(moduleName)) + return failure(); + + auto importName = ph.parseName(); + if (failed(importName)) + return failure(); + + FailureOr<ImportDesc> import = ph.parseImportDesc(ctx); + if (failed(import)) + return failure(); + + return std::visit( + [this, importLoc, &moduleName, &importName](auto import) { + return visitImport(importLoc, *moduleName, *importName, import); + }, + *import); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph, + size_t) { + FileLineColLoc exportLoc = ph.getLocation(); + + auto exportName = ph.parseName(); + if (failed(exportName)) + return failure(); + + FailureOr<std::byte> opcode = ph.consumeByte(); + if (failed(opcode)) + return failure(); + + FailureOr<uint32_t> idx = ph.parseLiteral<uint32_t>(); + if (failed(idx)) + return failure(); + + using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>, + SmallVector<GlobalSymbolRefContainer>, + SmallVector<FunctionSymbolRefContainer>>; + + SymbolRefDesc currentSymbolList; + std::string symbolType = ""; + switch (*opcode) { + case WasmBinaryEncoding::Export::function: + symbolType = "function"; + currentSymbolList = symbols.funcSymbols; + break; + case WasmBinaryEncoding::Export::table: + symbolType = "table"; + currentSymbolList = symbols.tableSymbols; + break; + case WasmBinaryEncoding::Export::memory: + symbolType = "memory"; + currentSymbolList = symbols.memSymbols; + 
break; + case WasmBinaryEncoding::Export::global: + symbolType = "global"; + currentSymbolList = symbols.globalSymbols; + break; + default: + return emitError(exportLoc, "invalid value for export type: ") + << std::to_integer<unsigned>(*opcode); + } + + auto currentSymbol = std::visit( + [&](const auto &list) -> FailureOr<FlatSymbolRefAttr> { + if (*idx > list.size()) { + emitError( + exportLoc, + llvm::formatv( + "trying to export {0} {1} which is undefined in this scope", + symbolType, *idx)); + return failure(); + } + return list[*idx].symbol; + }, + currentSymbolList); + + if (failed(currentSymbol)) + return failure(); + + Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol); + SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public); + StringAttr symName = SymbolTable::getSymbolName(op); + return SymbolTable{mOp}.rename(symName, *exportName); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph, + size_t) { + FileLineColLoc opLocation = ph.getLocation(); + FailureOr<TableType> tableType = ph.parseTableType(ctx); + if (failed(tableType)) + return failure(); + LDBG() << " Parsed table description: " << *tableType; + StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName()); + auto tableOp = + TableOp::create(builder, opLocation, symbol.strref(), *tableType); + symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)}); + return success(); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph, + size_t) { + FileLineColLoc opLoc = ph.getLocation(); + auto typeIdxParsed = ph.parseLiteral<uint32_t>(); + if (failed(typeIdxParsed)) + return failure(); + uint32_t typeIdx = *typeIdxParsed; + if (typeIdx >= symbols.moduleFuncTypes.size()) + return emitError(getLocation(), "invalid type index: ") << typeIdx; + std::string symbol = symbols.getNewFuncSymbolName(); + auto funcOp = + FuncOp::create(builder, opLoc, symbol, symbols.moduleFuncTypes[typeIdx]); + Block *block = funcOp.addEntryBlock(); + OpBuilder::InsertionGuard guard{builder}; + builder.setInsertionPointToEnd(block); + ReturnOp::create(builder, opLoc); + symbols.funcSymbols.push_back( + {{FlatSymbolRefAttr::get(funcOp.getSymNameAttr())}, + symbols.moduleFuncTypes[typeIdx]}); + return funcOp.verify(); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph, + size_t) { + FailureOr<FunctionType> funcType = ph.parseFunctionType(ctx); + if (failed(funcType)) + return failure(); + LDBG() << "Parsed function type " << *funcType; + symbols.moduleFuncTypes.push_back(*funcType); + return success(); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph, + size_t) { + FileLineColLoc opLocation = ph.getLocation(); + FailureOr<LimitType> memory = ph.parseLimit(ctx); + if (failed(memory)) + return failure(); + + LDBG() << " Registering memory " << *memory; + std::string symbol = symbols.getNewMemorySymbolName(); + auto memOp = MemOp::create(builder, opLocation, symbol, *memory); + symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)}); + return success(); +} + +template <> +LogicalResult +WasmBinaryParser::parseSectionItem<WasmSectionType::GLOBAL>(ParserHead &ph, + size_t) { + FileLineColLoc globalLocation = ph.getLocation(); + auto globalTypeParsed = ph.parseGlobalType(ctx); + if (failed(globalTypeParsed)) + return failure(); + + GlobalTypeRecord globalType = *globalTypeParsed; + auto 
symbol = builder.getStringAttr(symbols.getNewGlobalSymbolName()); + auto globalOp = builder.create<wasmssa::GlobalOp>( + globalLocation, symbol, globalType.type, globalType.isMutable); + symbols.globalSymbols.push_back( + {{FlatSymbolRefAttr::get(globalOp)}, globalOp.getType()}); + OpBuilder::InsertionGuard guard{builder}; + Block *block = builder.createBlock(&globalOp.getInitializer()); + builder.setInsertionPointToStart(block); + parsed_inst_t expr = ph.parseExpression(builder, symbols); + if (failed(expr)) + return failure(); + if (block->empty()) + return emitError(globalLocation, "global with empty initializer"); + if (expr->size() != 1 && (*expr)[0].getType() != globalType.type) + return emitError( + globalLocation, + "initializer result type does not match global declaration type"); + builder.create<ReturnOp>(globalLocation, *expr); + return success(); +} + +template <> +LogicalResult WasmBinaryParser::parseSectionItem<WasmSectionType::CODE>( + ParserHead &ph, size_t innerFunctionId) { + unsigned long funcId = innerFunctionId + firstInternalFuncID; + FunctionSymbolRefContainer symRef = symbols.funcSymbols[funcId]; + auto funcOp = + dyn_cast<FuncOp>(SymbolTable::lookupSymbolIn(mOp, symRef.symbol)); + assert(funcOp); + if (failed(ph.parseCodeFor(funcOp, symbols))) + return failure(); + return success(); +} +} // namespace + +namespace mlir::wasm { +OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source, + MLIRContext *context) { + WasmBinaryParser wBN{source, context}; + ModuleOp mOp = wBN.getModule(); + if (mOp) + return {mOp}; + + return {nullptr}; +} +} // namespace mlir::wasm diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp new file mode 100644 index 0000000..03b9784 --- /dev/null +++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp @@ -0,0 +1,28 @@ +//===- TranslateRegistration.cpp - Register translation -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/IR/OwningOpRef.h" +#include "mlir/Target/Wasm/WasmImporter.h" +#include "mlir/Tools/mlir-translate/Translation.h" + +using namespace mlir; + +namespace mlir { +void registerFromWasmTranslation() { + TranslateToMLIRRegistration registration{ + "import-wasm", "Translate WASM to MLIR", + [](llvm::SourceMgr &sourceMgr, + MLIRContext *context) -> OwningOpRef<Operation *> { + return wasm::importWebAssemblyToModule(sourceMgr, context); + }, + [](DialectRegistry &registry) { + registry.insert<wasmssa::WasmSSADialect>(); + }}; +} +} // namespace mlir diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index 51e702a..c883baa 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -147,8 +147,9 @@ private: std::string docStr; { llvm::raw_string_ostream docOS(docStr); + std::string tmpDocStr = doc.str(); raw_indented_ostream(docOS).printReindented( - StringRef(docStr).rtrim(" \t")); + StringRef(tmpDocStr).rtrim(" \t")); } return docStr; } diff --git a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp index 9950050..6945c09 100644 --- a/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp +++ b/mlir/lib/Tools/mlir-query/MlirQueryMain.cpp @@ -21,6 +21,7 @@ #include "llvm/LineEditor/LineEditor.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/InitLLVM.h" +#include "llvm/Support/Process.h" #include "llvm/Support/SourceMgr.h" //===----------------------------------------------------------------------===// @@ -43,7 +44,7 @@ mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context, llvm::cl::value_desc("command"), llvm::cl::cat(mlirQueryCategory)); static llvm::cl::opt<std::string> inputFilename( - llvm::cl::Positional, llvm::cl::desc("<input file>"), + llvm::cl::Positional, llvm::cl::desc("<input file>"), llvm::cl::init("-"), llvm::cl::cat(mlirQueryCategory)); static llvm::cl::opt<bool> noImplicitModule{ @@ -68,6 +69,14 @@ mlir::mlirQueryMain(int argc, char **argv, MLIRContext &context, return mlir::success(); } + // When reading from stdin and the input is a tty, it is often a user mistake + // and the process "appears to be stuck". Print a message to let the user + // know! + if (inputFilename == "-" && + llvm::sys::Process::FileDescriptorIsDisplayed(fileno(stdin))) + llvm::errs() << "(processing input from stdin now, hit ctrl-c/ctrl-d to " + "interrupt)\n"; + // Set up the input file. std::string errorMessage; auto file = openInputFile(inputFilename, &errorMessage); diff --git a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp index e89d392..34459b8 100644 --- a/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp +++ b/mlir/lib/Tools/mlir-reduce/MlirReduceMain.cpp @@ -26,9 +26,9 @@ using namespace mlir; // Parse and verify the input MLIR file. Returns null on error. -OwningOpRef<Operation *> loadModule(MLIRContext &context, - StringRef inputFilename, - bool insertImplictModule) { +static OwningOpRef<Operation *> loadModule(MLIRContext &context, + StringRef inputFilename, + bool insertImplictModule) { // Set up the input file. 
std::string errorMessage; auto file = openInputFile(inputFilename, &errorMessage); @@ -65,6 +65,11 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv, "Disable implicit addition of a top-level module op during parsing"), llvm::cl::init(false)}; + static llvm::cl::opt<bool> allowUnregisteredDialects( + "allow-unregistered-dialect", + llvm::cl::desc("Allow operation with no registered dialects"), + llvm::cl::init(false)); + llvm::cl::HideUnrelatedOptions(mlirReduceCategory); llvm::InitLLVM y(argc, argv); @@ -79,6 +84,8 @@ LogicalResult mlir::mlirReduceMain(int argc, char **argv, llvm::cl::PrintHelpMessage(); return success(); } + if (allowUnregisteredDialects) + context.allowUnregisteredDialects(); std::string errorMessage; diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 09e5a02..8eaac30 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -177,11 +177,10 @@ void CSEDriver::replaceUsesAndDelete(ScopedMapTy &knownValues, Operation *op, bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp, Operation *toOp) { assert(fromOp->getBlock() == toOp->getBlock()); - assert( - isa<MemoryEffectOpInterface>(fromOp) && - cast<MemoryEffectOpInterface>(fromOp).hasEffect<MemoryEffects::Read>() && - isa<MemoryEffectOpInterface>(toOp) && - cast<MemoryEffectOpInterface>(toOp).hasEffect<MemoryEffects::Read>()); + assert(hasEffect<MemoryEffects::Read>(fromOp) && + "expected read effect on fromOp"); + assert(hasEffect<MemoryEffects::Read>(toOp) && + "expected read effect on toOp"); Operation *nextOp = fromOp->getNextNode(); auto result = memEffectsCache.try_emplace(fromOp, std::make_pair(fromOp, nullptr)); @@ -245,11 +244,10 @@ LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues, // Some simple use case of operation with memory side-effect are dealt with // here. Operations with no side-effect are done after. if (!isMemoryEffectFree(op)) { - auto memEffects = dyn_cast<MemoryEffectOpInterface>(op); // TODO: Only basic use case for operations with MemoryEffects::Read can be // eleminated now. More work needs to be done for more complicated patterns // and other side-effects. - if (!memEffects || !memEffects.onlyHasEffect<MemoryEffects::Read>()) + if (!hasSingleEffect<MemoryEffects::Read>(op)) return failure(); // Look for an existing definition for the operation. diff --git a/mlir/lib/Transforms/InlinerPass.cpp b/mlir/lib/Transforms/InlinerPass.cpp index 703e517..77a9e6c 100644 --- a/mlir/lib/Transforms/InlinerPass.cpp +++ b/mlir/lib/Transforms/InlinerPass.cpp @@ -18,6 +18,7 @@ #include "mlir/Analysis/CallGraph.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Inliner.h" +#include "llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_INLINER @@ -120,8 +121,8 @@ static bool isProfitableToInline(const Inliner::ResolvedCall &resolvedCall, return true; unsigned ratio = countOps(calleeRegion) * 100 / callerOps; - LLVM_DEBUG(llvm::dbgs() << "Callee / caller operation ratio (max: " - << inliningThreshold << "%): " << ratio << "%\n"); + LDBG() << "Callee / caller operation ratio (max: " << inliningThreshold + << "%): " << ratio << "%"; return ratio <= inliningThreshold; } @@ -138,7 +139,7 @@ void InlinerPass::runOnOperation() { } // By default, assume that any inlining is profitable. 
- auto profitabilityCb = [=](const Inliner::ResolvedCall &call) { + auto profitabilityCb = [this](const Inliner::ResolvedCall &call) { return isProfitableToInline(call, inliningThreshold); }; diff --git a/mlir/lib/Transforms/Mem2Reg.cpp b/mlir/lib/Transforms/Mem2Reg.cpp index cf039c3..d36a3c1 100644 --- a/mlir/lib/Transforms/Mem2Reg.cpp +++ b/mlir/lib/Transforms/Mem2Reg.cpp @@ -19,6 +19,7 @@ #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/GenericIteratedDominanceFrontier.h" namespace mlir { @@ -632,8 +633,7 @@ MemorySlotPromoter::promoteSlot() { } } - LLVM_DEBUG(llvm::dbgs() << "[mem2reg] Promoted memory slot: " << slot.ptr - << "\n"); + LDBG() << "Promoted memory slot: " << slot.ptr; if (statistics.promotedAmount) (*statistics.promotedAmount)++; diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 4ccb83f..0e84b6d 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -258,18 +258,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { static void processSimpleOp(Operation *op, RunLivenessAnalysis &la, DenseSet<Value> &nonLiveSet, RDVFinalCleanupList &cl) { - LDBG() << "Processing simple op: " << *op; if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) { - LDBG() - << "Simple op is not memory effect free or has live results, skipping: " - << *op; + LDBG() << "Simple op is not memory effect free or has live results, " + "preserving it: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); return; } LDBG() << "Simple op has all dead results and is memory effect free, scheduling " "for removal: " - << *op; + << OpWithFlags(op, OpPrintingFlags().skipRegions()); cl.operations.push_back(op); collectNonLiveValues(nonLiveSet, op->getResults(), BitVector(op->getNumResults(), true)); @@ -345,8 +344,7 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module, // being returned, in order to optimize our IR. So, this demonstrates how we // can make our optimization strong by even removing a live return value (%0), // since it forwards only to non-live value(s) (%1#1). - Operation *lastReturnOp = funcOp.back().getTerminator(); - size_t numReturns = lastReturnOp->getNumOperands(); + size_t numReturns = funcOp.getNumResults(); BitVector nonLiveRets(numReturns, true); for (SymbolTable::SymbolUse use : uses) { Operation *callOp = use.getUser(); @@ -728,19 +726,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, /// Removes dead values collected in RDVFinalCleanupList. /// To be run once when all dead values have been collected. static void cleanUpDeadVals(RDVFinalCleanupList &list) { + LDBG() << "Starting cleanup of dead values..."; + // 1. Operations + LDBG() << "Cleaning up " << list.operations.size() << " operations"; for (auto &op : list.operations) { + LDBG() << "Erasing operation: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); op->dropAllUses(); op->erase(); } // 2. Values + LDBG() << "Cleaning up " << list.values.size() << " values"; for (auto &v : list.values) { + LDBG() << "Dropping all uses of value: " << v; v.dropAllUses(); } // 3. 
Functions + LDBG() << "Cleaning up " << list.functions.size() << " functions"; for (auto &f : list.functions) { + LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName(); + LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments"; + LDBG() << " Erasing " << f.nonLiveRets.count() + << " non-live return values"; // Some functions may not allow erasing arguments or results. These calls // return failure in such cases without modifying the function, so it's okay // to proceed. @@ -749,44 +759,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { } // 4. Operands + LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperationToCleanup &o : list.operands) { - if (o.op->getNumOperands() > 0) + if (o.op->getNumOperands() > 0) { + LDBG() << "Erasing " << o.nonLive.count() + << " non-live operands from operation: " + << OpWithFlags(o.op, OpPrintingFlags().skipRegions()); o.op->eraseOperands(o.nonLive); + } } // 5. Results + LDBG() << "Cleaning up " << list.results.size() << " result lists"; for (auto &r : list.results) { + LDBG() << "Erasing " << r.nonLive.count() + << " non-live results from operation: " + << OpWithFlags(r.op, OpPrintingFlags().skipRegions()); dropUsesAndEraseResults(r.op, r.nonLive); } // 6. Blocks + LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists"; for (auto &b : list.blocks) { // blocks that are accessed via multiple codepaths processed once if (b.b->getNumArguments() != b.nonLiveArgs.size()) continue; + LDBG() << "Erasing " << b.nonLiveArgs.count() + << " non-live arguments from block: " << b.b; // it iterates backwards because erase invalidates all successor indexes for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { if (!b.nonLiveArgs[i]) continue; + LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i); b.b->getArgument(i).dropAllUses(); b.b->eraseArgument(i); } } // 7. 
Successor Operands + LDBG() << "Cleaning up " << list.successorOperands.size() + << " successor operand lists"; for (auto &op : list.successorOperands) { SuccessorOperands successorOperands = op.branch.getSuccessorOperands(op.successorIndex); // blocks that are accessed via multiple codepaths processed once if (successorOperands.size() != op.nonLiveOperands.size()) continue; + LDBG() << "Erasing " << op.nonLiveOperands.count() + << " non-live successor operands from successor " + << op.successorIndex << " of branch: " + << OpWithFlags(op.branch, OpPrintingFlags().skipRegions()); // it iterates backwards because erase invalidates all successor indexes for (int i = successorOperands.size() - 1; i >= 0; --i) { if (!op.nonLiveOperands[i]) continue; + LDBG() << " Erasing successor operand " << i << ": " + << successorOperands[i]; successorOperands.erase(i); } } + + LDBG() << "Finished cleanup of dead values"; } struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> { diff --git a/mlir/lib/Transforms/SROA.cpp b/mlir/lib/Transforms/SROA.cpp index 67f536a..859c030 100644 --- a/mlir/lib/Transforms/SROA.cpp +++ b/mlir/lib/Transforms/SROA.cpp @@ -12,6 +12,7 @@ #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/Interfaces/MemorySlotInterfaces.h" #include "mlir/Transforms/Passes.h" +#include "llvm/Support/DebugLog.h" namespace mlir { #define GEN_PASS_DEF_SROA @@ -180,8 +181,7 @@ static void destructureSlot( assert(slot.ptr.use_empty() && "after destructuring, the original slot " "pointer should no longer be used"); - LLVM_DEBUG(llvm::dbgs() << "[sroa] Destructured memory slot: " << slot.ptr - << "\n"); + LDBG() << "Destructured memory slot: " << slot.ptr; if (statistics.destructuredAmount) (*statistics.destructuredAmount)++; diff --git a/mlir/lib/Transforms/SymbolDCE.cpp b/mlir/lib/Transforms/SymbolDCE.cpp index 0a925c4..87885be 100644 --- a/mlir/lib/Transforms/SymbolDCE.cpp +++ b/mlir/lib/Transforms/SymbolDCE.cpp @@ -13,8 +13,11 @@ #include "mlir/Transforms/Passes.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/SymbolTable.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" +#include "llvm/Support/InterleavedRange.h" namespace mlir { #define GEN_PASS_DEF_SYMBOLDCE @@ -87,8 +90,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, SymbolTableCollection &symbolTable, bool symbolTableIsHidden, DenseSet<Operation *> &liveSymbols) { - LLVM_DEBUG(llvm::dbgs() << "computeLiveness: " << symbolTableOp->getName() - << "\n"); + LDBG() << "computeLiveness: " + << OpWithFlags(symbolTableOp, OpPrintingFlags().skipRegions()); // A worklist of live operations to propagate uses from. SmallVector<Operation *, 16> worklist; @@ -116,7 +119,8 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, // consideration. while (!worklist.empty()) { Operation *op = worklist.pop_back_val(); - LLVM_DEBUG(llvm::dbgs() << "processing: " << op->getName() << "\n"); + LDBG() << "processing: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // If this is a symbol table, recursively compute its liveness. if (op->hasTrait<OpTrait::SymbolTable>()) { @@ -124,13 +128,14 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, // symbol, or if it is a private symbol. 
SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op); bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate(); - LLVM_DEBUG(llvm::dbgs() << "\tsymbol table: " << op->getName() - << " is hidden: " << symIsHidden << "\n"); + LDBG() << "\tsymbol table: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " is hidden: " << symIsHidden; if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols))) return failure(); } else { - LLVM_DEBUG(llvm::dbgs() - << "\tnon-symbol table: " << op->getName() << "\n"); + LDBG() << "\tnon-symbol table: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // If the op is not a symbol table, then, unless op itself is dead which // would be handled by DCE, we need to check all the regions and blocks // within the op to find the uses (e.g., consider visibility within op as @@ -160,20 +165,17 @@ LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp, } SmallVector<Operation *, 4> resolvedSymbols; - LLVM_DEBUG(llvm::dbgs() << "uses of " << op->getName() << "\n"); + LDBG() << "uses of " << OpWithFlags(op, OpPrintingFlags().skipRegions()); for (const SymbolTable::SymbolUse &use : *uses) { - LLVM_DEBUG(llvm::dbgs() << "\tuse: " << use.getUser() << "\n"); + LDBG() << "\tuse: " << use.getUser(); // Lookup the symbols referenced by this use. resolvedSymbols.clear(); if (failed(symbolTable.lookupSymbolIn(parentOp, use.getSymbolRef(), resolvedSymbols))) // Ignore references to unknown symbols. continue; - LLVM_DEBUG({ - llvm::dbgs() << "\t\tresolved symbols: "; - llvm::interleaveComma(resolvedSymbols, llvm::dbgs()); - llvm::dbgs() << "\n"; - }); + LDBG() << "\t\tresolved symbols: " + << llvm::interleaved(resolvedSymbols, ", "); // Mark each of the resolved symbols as live. for (Operation *resolvedSymbol : resolvedSymbols) diff --git a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp index cfd568f..19cf464 100644 --- a/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp +++ b/mlir/lib/Transforms/Utils/ControlFlowSinkUtils.cpp @@ -21,7 +21,10 @@ #include "mlir/Transforms/ControlFlowSinkUtils.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "llvm/Support/DebugLog.h" #include <vector> #define DEBUG_TYPE "cf-sink" @@ -84,13 +87,15 @@ bool Sinker::allUsersDominatedBy(Operation *op, Region *region) { void Sinker::tryToSinkPredecessors(Operation *user, Region *region, std::vector<Operation *> &stack) { - LLVM_DEBUG(user->print(llvm::dbgs() << "\nContained op:\n")); + LDBG() << "Contained op: " + << OpWithFlags(user, OpPrintingFlags().skipRegions()); for (Value value : user->getOperands()) { Operation *op = value.getDefiningOp(); // Ignore block arguments and ops that are already inside the region. if (!op || op->getParentRegion() == region) continue; - LLVM_DEBUG(op->print(llvm::dbgs() << "\nTry to sink:\n")); + LDBG() << "Try to sink:\n" + << OpWithFlags(op, OpPrintingFlags().skipRegions()); // If the op's users are all in the region and it can be moved, then do so. 
if (allUsersDominatedBy(op, region) && shouldMoveIntoRegion(op, region)) { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 0c26b4e..5ba109d 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -182,15 +182,24 @@ private: /// conversions.) static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__"; +/// Return the operation that defines all values in the vector. Return nullptr +/// if the values are not defined by the same operation. +static Operation *getCommonDefiningOp(const ValueVector &values) { + assert(!values.empty() && "expected non-empty value vector"); + Operation *op = values.front().getDefiningOp(); + for (Value v : llvm::drop_begin(values)) { + if (v.getDefiningOp() != op) + return nullptr; + } + return op; +} + /// A vector of values is a pure type conversion if all values are defined by /// the same operation and the operation has the `kPureTypeConversionMarker` /// attribute. static bool isPureTypeConversion(const ValueVector &values) { assert(!values.empty() && "expected non-empty value vector"); - Operation *op = values.front().getDefiningOp(); - for (Value v : llvm::drop_begin(values)) - if (v.getDefiningOp() != op) - return false; + Operation *op = getCommonDefiningOp(values); return op && op->hasAttr(kPureTypeConversionMarker); } @@ -839,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) { namespace mlir { namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { - explicit ConversionPatternRewriterImpl(MLIRContext *ctx, + explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter, const ConversionConfig &config) - : context(ctx), config(config) {} + : rewriter(rewriter), config(config), + notifyingRewriter(rewriter.getContext(), config.listener) {} //===--------------------------------------------------------------------===// // State Management @@ -863,6 +873,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// failure. template <typename RewriteTy, typename... Args> void appendRewrite(Args &&...args) { + assert(config.allowPatternRollback && "appending rewrites is not allowed"); rewrites.push_back( std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...)); } @@ -877,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// is the tag used when describing a value within a diagnostic, e.g. /// "operand". LogicalResult remapValues(StringRef valueDiagTag, - std::optional<Location> inputLoc, - PatternRewriter &rewriter, ValueRange values, + std::optional<Location> inputLoc, ValueRange values, SmallVector<ValueVector> &remapped); /// Return "true" if the given operation is ignored, and does not need to be @@ -889,15 +899,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { bool wasOpReplaced(Operation *op) const; /// Lookup the most recently mapped values with the desired types in the - /// mapping. - /// - /// Special cases: - /// - If the desired type range is empty, simply return the most recently - /// mapped values. - /// - If there is no mapping to the desired types, also return the most - /// recently mapped values. - /// - If there is no mapping for the given values at all, return the given - /// value. + /// mapping, taking into account only replacements. Perform a best-effort + /// search for existing materializations with the desired types. 
/// /// If `skipPureTypeConversions` is "true", materializations that are pure /// type conversions are not considered. @@ -915,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Convert the types of block arguments within the given region. FailureOr<Block *> - convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region, - const TypeConverter &converter, + convertRegionTypes(Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion); /// Apply the given signature conversion on the given block. The new block @@ -926,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// translate between the origin argument types and those specified in the /// signature conversion. Block *applySignatureConversion( - ConversionPatternRewriter &rewriter, Block *block, - const TypeConverter *converter, + Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion); /// Replace the results of the given operation with the given values and @@ -976,7 +977,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, - UnrealizedConversionCastOp *castOp = nullptr, bool isPureTypeConversion = true); /// Find a replacement value for the given SSA value in the conversion value @@ -1058,14 +1058,17 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { // State //===--------------------------------------------------------------------===// - /// MLIR context. - MLIRContext *context; + /// The rewriter that is used to perform the conversion. + ConversionPatternRewriter &rewriter; // Mapping between replaced values that differ in type. This happens when // replacing a value with one of a different type. ConversionValueMapping mapping; /// Ordered list of block operations (creations, splits, motions). + /// This vector is maintained only if `allowPatternRollback` is set to + /// "true". Otherwise, all IR rewrites are materialized immediately and no + /// bookkeeping is needed. SmallVector<std::unique_ptr<IRRewrite>> rewrites; /// A set of operations that should no longer be considered for legalization. @@ -1089,6 +1092,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// by the current pattern. SetVector<Block *> patternInsertedBlocks; + /// A list of unresolved materializations that were created by the current + /// pattern. + DenseSet<UnrealizedConversionCastOp> patternMaterializations; + /// A mapping for looking up metadata of unresolved materializations. DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo> unresolvedMaterializations; @@ -1104,15 +1111,37 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Dialect conversion configuration. const ConversionConfig &config; + /// A set of erased operations. This set is utilized only if + /// `allowPatternRollback` is set to "false". Conceptually, this set is + /// similar to `replacedOps` (which is maintained when the flag is set to + /// "true"). However, erasing from a DenseSet is more efficient than erasing + /// from a SetVector. + DenseSet<Operation *> erasedOps; + + /// A set of erased blocks. This set is utilized only if + /// `allowPatternRollback` is set to "false". 
+ DenseSet<Block *> erasedBlocks; + + /// A rewriter that notifies the listener (if any) about all IR + /// modifications. This rewriter is utilized only if `allowPatternRollback` + /// is set to "false". If the flag is set to "true", the listener is notified + /// with a separate mechanism (e.g., in `IRRewrite::commit`). + IRRewriter notifyingRewriter; + #ifndef NDEBUG + /// A set of replaced block arguments. This set is for debugging purposes + /// only and it is maintained only if `allowPatternRollback` is set to + /// "true". + DenseSet<BlockArgument> replacedArgs; + /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra /// verification. SmallPtrSet<Operation *, 1> pendingRootUpdates; /// A raw output stream used to prefix the debug log. - llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + "] ").str(), - llvm::dbgs(), /*HasPendingNewline=*/false}; + llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(), + llvm::dbgs()}; /// A logger used to emit diagnostics during the conversion process. llvm::ScopedPrinter logger{os}; @@ -1140,11 +1169,8 @@ void BlockTypeConversionRewrite::rollback() { getNewBlock()->replaceAllUsesWith(getOrigBlock()); } -void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); - if (!repl) - return; - +static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg, + Value repl) { if (isa<BlockArgument>(repl)) { rewriter.replaceAllUsesWith(arg, repl); return; @@ -1161,6 +1187,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { }); } +void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { + Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); + if (!repl) + return; + performReplaceBlockArg(rewriter, arg, repl); +} + void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { @@ -1223,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() { } void ConversionPatternRewriterImpl::applyRewrites() { - // Commit all rewrites. - IRRewriter rewriter(context, config.listener); + // Commit all rewrites. Use a new rewriter, so the modifications are not + // tracked for rollback purposes etc. + IRRewriter irRewriter(rewriter.getContext(), config.listener); // Note: New rewrites may be added during the "commit" phase and the // `rewrites` vector may reallocate. for (size_t i = 0; i < rewrites.size(); ++i) - rewrites[i]->commit(rewriter); + rewrites[i]->commit(irRewriter); // Clean up all rewrites. SingleEraseRewriter eraseRewriter( - context, /*opErasedCallback=*/[&](Operation *op) { + rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) { if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) unresolvedMaterializations.erase(castOp); }); @@ -1246,6 +1280,30 @@ void ConversionPatternRewriterImpl::applyRewrites() { ValueVector ConversionPatternRewriterImpl::lookupOrDefault( Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const { + // Helper function that looks up a single value. + auto lookup = [&](const ValueVector &values) -> ValueVector { + assert(!values.empty() && "expected non-empty value vector"); + + // If the pattern rollback is enabled, use the mapping to look up the + // values. + if (config.allowPatternRollback) + return mapping.lookup(values); + + // Otherwise, look up values by examining the IR. 
All replacements have + // already been materialized in IR. + Operation *op = getCommonDefiningOp(values); + if (!op) + return {}; + auto castOp = dyn_cast<UnrealizedConversionCastOp>(op); + if (!castOp) + return {}; + if (!this->unresolvedMaterializations.contains(castOp)) + return {}; + if (castOp.getOutputs() != values) + return {}; + return castOp.getInputs(); + }; + // Helper function that looks up each value in `values` individually and then // composes the results. If that fails, it tries to look up the entire vector // at once. @@ -1253,7 +1311,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault( // If possible, replace each value with (one or multiple) mapped values. ValueVector next; for (Value v : values) { - ValueVector r = mapping.lookup({v}); + ValueVector r = lookup({v}); if (!r.empty()) { llvm::append_range(next, r); } else { @@ -1273,7 +1331,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault( // be stored (and looked up) in the mapping. But for performance reasons, // we choose to reuse existing IR (when possible) instead of creating it // multiple times. - ValueVector r = mapping.lookup(values); + ValueVector r = lookup(values); if (r.empty()) { // No mapping found: The lookup stops here. return {}; @@ -1347,21 +1405,13 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state, void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep, StringRef patternName) { for (auto &rewrite : - llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) { - if (!config.allowPatternRollback && - !isa<UnresolvedMaterializationRewrite>(rewrite)) { - // Unresolved materializations can always be rolled back (erased). - llvm::report_fatal_error("pattern '" + patternName + - "' rollback of IR modifications requested"); - } + llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) rewrite->rollback(); - } rewrites.resize(numRewritesToKeep); } LogicalResult ConversionPatternRewriterImpl::remapValues( - StringRef valueDiagTag, std::optional<Location> inputLoc, - PatternRewriter &rewriter, ValueRange values, + StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values, SmallVector<ValueVector> &remapped) { remapped.reserve(llvm::size(values)); @@ -1383,7 +1433,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( // If there is no legal conversion, fail to match this pattern. SmallVector<Type, 1> legalTypes; - if (failed(currentTypeConverter->convertType(origType, legalTypes))) { + if (failed(currentTypeConverter->convertType(operand, legalTypes))) { notifyMatchFailure(operandLoc, [=](Diagnostic &diag) { diag << "unable to convert type for " << valueDiagTag << " #" << it.index() << ", type was " << origType; @@ -1419,12 +1469,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { // Check to see if this operation is ignored or was replaced. - return replacedOps.count(op) || ignoredOps.count(op); + return wasOpReplaced(op) || ignoredOps.count(op); } bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { // Check to see if this operation was replaced. 
- return replacedOps.count(op); + return replacedOps.count(op) || erasedOps.count(op); } //===----------------------------------------------------------------------===// @@ -1432,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { //===----------------------------------------------------------------------===// FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( - ConversionPatternRewriter &rewriter, Region *region, - const TypeConverter &converter, + Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion) { regionToConverter[region] = &converter; if (region->empty()) @@ -1448,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes( if (!conversion) return failure(); // Convert the block with the computed signature. - applySignatureConversion(rewriter, &block, &converter, *conversion); + applySignatureConversion(&block, &converter, *conversion); } // Convert the entry block. If an entry signature conversion was provided, // use that one. Otherwise, compute the signature with the type converter. if (entryConversion) - return applySignatureConversion(rewriter, &region->front(), &converter, + return applySignatureConversion(&region->front(), &converter, *entryConversion); std::optional<TypeConverter::SignatureConversion> conversion = converter.convertBlockSignature(&region->front()); if (!conversion) return failure(); - return applySignatureConversion(rewriter, &region->front(), &converter, - *conversion); + return applySignatureConversion(&region->front(), &converter, *conversion); } Block *ConversionPatternRewriterImpl::applySignatureConversion( - ConversionPatternRewriter &rewriter, Block *block, - const TypeConverter *converter, + Block *block, const TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion) { #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // A block cannot be converted multiple times. @@ -1508,7 +1555,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // a bit more efficient, so we try to do that when possible. bool fastPath = !config.listener; if (fastPath) { - appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end()); + if (config.allowPatternRollback) + appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end()); newBlock->getOperations().splice(newBlock->end(), block->getOperations()); } else { while (!block->empty()) @@ -1534,7 +1582,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( origArg.getLoc(), /*valuesToMap=*/{}, /*inputs=*/ValueRange(), /*outputTypes=*/origArgType, /*originalType=*/Type(), converter, - /*castOp=*/nullptr, /*isPureTypeConversion=*/false) + /*isPureTypeConversion=*/false) .front(); replaceUsesOfBlockArgument(origArg, mat, converter); continue; @@ -1556,7 +1604,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( replaceUsesOfBlockArgument(origArg, replArgs, converter); } - appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock); + if (config.allowPatternRollback) + appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.)
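For context, downstream conversion patterns drive the block-signature machinery above through the public ConversionPatternRewriter entry points rather than the Impl methods changed here. A minimal sketch follows (not part of this patch; the helper name convertBodySignature is hypothetical, while TypeConverter::convertType, SignatureConversion::addInputs, and ConversionPatternRewriter::convertRegionTypes are the APIs touched in this diff):

#include "llvm/ADT/STLExtras.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Convert every block signature in `region`, building the entry-block
// conversion by hand so that one argument may map to zero, one, or several
// converted arguments (1:N).
static LogicalResult convertBodySignature(Region &region,
                                          const TypeConverter &converter,
                                          ConversionPatternRewriter &rewriter) {
  TypeConverter::SignatureConversion sigConv(region.getNumArguments());
  for (auto [idx, arg] : llvm::enumerate(region.getArguments())) {
    SmallVector<Type> newTypes;
    if (failed(converter.convertType(arg.getType(), newTypes)))
      return failure();
    sigConv.addInputs(idx, newTypes);
  }
  // Non-entry blocks are converted with signatures computed by `converter`;
  // the entry block uses the conversion built above.
  if (failed(rewriter.convertRegionTypes(&region, converter, &sigConv)))
    return failure();
  return success();
}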
@@ -1575,7 +1624,7 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc, ValueVector valuesToMap, ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter, - UnrealizedConversionCastOp *castOp, bool isPureTypeConversion) { + bool isPureTypeConversion) { assert((!originalType || kind == MaterializationKind::Target) && "original type is valid only for target materializations"); assert(TypeRange(inputs) != outputTypes && @@ -1585,23 +1634,35 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( // tracking the materialization like we do for other operations. OpBuilder builder(outputTypes.front().getContext()); builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); - auto convertOp = + UnrealizedConversionCastOp convertOp = UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs); + if (config.attachDebugMaterializationKind) { + StringRef kindStr = + kind == MaterializationKind::Source ? "source" : "target"; + convertOp->setAttr("__kind__", builder.getStringAttr(kindStr)); + } if (isPureTypeConversion) convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr()); - if (!valuesToMap.empty()) - mapping.map(valuesToMap, convertOp.getResults()); - if (castOp) - *castOp = convertOp; + + // Register the materialization. unresolvedMaterializations[convertOp] = UnresolvedMaterializationInfo(converter, kind, originalType); - appendRewrite<UnresolvedMaterializationRewrite>(convertOp, - std::move(valuesToMap)); + if (config.allowPatternRollback) { + if (!valuesToMap.empty()) + mapping.map(valuesToMap, convertOp.getResults()); + appendRewrite<UnresolvedMaterializationRewrite>(convertOp, + std::move(valuesToMap)); + } else { + patternMaterializations.insert(convertOp); + } return convertOp.getResults(); } Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( Value value, const TypeConverter *converter) { + assert(config.allowPatternRollback && + "this code path is valid only in rollback mode"); + // Try to find a replacement value with the same type in the conversion value // mapping. This includes cached materializations. We try to reuse those // instead of generating duplicate IR. @@ -1663,26 +1724,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(op->getParentOp()) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. + assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) && "attempting to insert into a block within a replaced/erased op"); + // In "no rollback" mode, the listener is always notified immediately. + if (!config.allowPatternRollback && config.listener) + config.listener->notifyOperationInserted(op, previous); + if (wasDetached) { - // If the op was detached, it is most likely a newly created op. - // TODO: If the same op is inserted multiple times from a detached state, - // the rollback mechanism may erase the same op multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite<CreateOperationRewrite>(op); + // If the op was detached, it is most likely a newly created op. Add it to the + // set of newly created ops, so that it will be legalized. If this op is + // not a newly created op, it will be legalized a second time, which is + // inefficient but harmless.
patternNewOps.insert(op); + + if (config.allowPatternRollback) { + // TODO: If the same op is inserted multiple times from a detached + // state, the rollback mechanism may erase the same op multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite<CreateOperationRewrite>(op); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased operations that must be kept up to date. + erasedOps.erase(op); + } return; } // The op was moved from one place to another. - appendRewrite<MoveOperationRewrite>(op, previous); + if (config.allowPatternRollback) + appendRewrite<MoveOperationRewrite>(op, previous); +} + +/// Given that `fromRange` is about to be replaced with `toRange`, compute +/// replacement values with the types of `fromRange`. +static SmallVector<Value> +getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, + const SmallVector<SmallVector<Value>> &toRange, + const TypeConverter *converter) { + assert(!impl.config.allowPatternRollback && + "this code path is valid only in 'no rollback' mode"); + SmallVector<Value> repls; + for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) { + if (from.use_empty()) { + // The replaced value is dead. No replacement value is needed. + repls.push_back(Value()); + continue; + } + + if (to.empty()) { + // The replaced value is dropped. Materialize a replacement value "out of + // thin air". + Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(from), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/ValueRange(), + /*outputTypes=*/from.getType(), /*originalType=*/Type(), + converter)[0]; + repls.push_back(srcMat); + continue; + } + + if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) { + // The replacement value already has the correct type. Use it directly. + repls.push_back(to[0]); + continue; + } + + // The replacement value has the wrong type. Build a source materialization + // to the original type. + // TODO: This is a bit inefficient. We should try to reuse existing + // materializations if possible. This would require an extension of the + // `lookupOrDefault` API. + Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(to), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(), + /*originalType=*/Type(), converter)[0]; + repls.push_back(srcMat); + } + + return repls; } void ConversionPatternRewriterImpl::replaceOp( Operation *op, SmallVector<SmallVector<Value>> &&newValues) { - assert(newValues.size() == op->getNumResults()); + assert(newValues.size() == op->getNumResults() && + "incorrect number of replacement values"); + + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + SmallVector<Value> repls = getReplacementValues( + *this, op->getResults(), newValues, currentTypeConverter); + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + op->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. 
+ if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + op->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Replace the op with the replacement values and notify the listener. + notifyingRewriter.replaceOp(op, repls); + return; + } + assert(!ignoredOps.contains(op) && "operation was already replaced"); // Check if replaced op is an unresolved materialization, i.e., an @@ -1704,8 +1858,7 @@ void ConversionPatternRewriterImpl::replaceOp( MaterializationKind::Source, computeInsertPoint(result), result.getLoc(), /*valuesToMap=*/{result}, /*inputs=*/ValueRange(), /*outputTypes=*/result.getType(), /*originalType=*/Type(), - currentTypeConverter, /*castOp=*/nullptr, - /*isPureTypeConversion=*/false); + currentTypeConverter, /*isPureTypeConversion=*/false); continue; } @@ -1722,11 +1875,59 @@ void ConversionPatternRewriterImpl::replaceOp( void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument( BlockArgument from, ValueRange to, const TypeConverter *converter) { + if (!config.allowPatternRollback) { + SmallVector<Value> toConv = llvm::to_vector(to); + SmallVector<Value> repls = + getReplacementValues(*this, from, {toConv}, converter); + IRRewriter r(from.getContext()); + Value repl = repls.front(); + if (!repl) + return; + + performReplaceBlockArg(r, from, repl); + return; + } + +#ifndef NDEBUG + // Make sure that a block argument is not replaced multiple times. In + // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current + // uses of the given block argument, but also all future uses that may be + // introduced by future pattern applications. Therefore, it does not make + // sense to call `replaceUsesOfBlockArgument` multiple times with the same + // block argument. Doing so would overwrite the mapping and mess with the + // internal state of the dialect conversion driver. + assert(!replacedArgs.contains(from) && + "attempting to replace a block argument that was already replaced"); + replacedArgs.insert(from); +#endif // NDEBUG + appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter); mapping.map(from, to); } void ConversionPatternRewriterImpl::eraseBlock(Block *block) { + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + block->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. + if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + block->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Erase the block and notify the listener. + notifyingRewriter.eraseBlock(block); + return; + } + assert(!wasOpReplaced(block->getParentOp()) && "attempting to erase a block within a replaced/erased op"); appendRewrite<EraseBlockRewrite>(block); @@ -1760,23 +1961,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(newParentOp) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. 
+ assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) && "attempting to insert into a region within a replaced/erased op"); (void)newParentOp; + // In "no rollback" mode, the listener is always notified immediately. + if (!config.allowPatternRollback && config.listener) + config.listener->notifyBlockInserted(block, previous, previousIt); + patternInsertedBlocks.insert(block); if (wasDetached) { // If the block was detached, it is most likely a newly created block. - // TODO: If the same block is inserted multiple times from a detached state, - // the rollback mechanism may erase the same block multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite<CreateBlockRewrite>(block); + if (config.allowPatternRollback) { + // TODO: If the same block is inserted multiple times from a detached + // state, the rollback mechanism may erase the same block multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite<CreateBlockRewrite>(block); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased blocks that must be kept up to date. + erasedBlocks.erase(block); + } return; } // The block was moved from one place to another. - appendRewrite<MoveBlockRewrite>(block, previous, previousIt); + if (config.allowPatternRollback) + appendRewrite<MoveBlockRewrite>(block, previous, previousIt); } void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, @@ -1803,7 +2018,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure( ConversionPatternRewriter::ConversionPatternRewriter( MLIRContext *ctx, const ConversionConfig &config) : PatternRewriter(ctx), - impl(new detail::ConversionPatternRewriterImpl(ctx, config)) { + impl(new detail::ConversionPatternRewriterImpl(*this, config)) { setListener(impl.get()); } @@ -1880,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion( assert(!impl->wasOpReplaced(block->getParentOp()) && "attempting to apply a signature conversion to a block within a " "replaced/erased op"); - return impl->applySignatureConversion(*this, block, converter, conversion); + return impl->applySignatureConversion(block, converter, conversion); } FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( @@ -1889,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( assert(!impl->wasOpReplaced(region->getParentOp()) && "attempting to apply a signature conversion to a block within a " "replaced/erased op"); - return impl->convertRegionTypes(*this, region, converter, entryConversion); + return impl->convertRegionTypes(region, converter, entryConversion); } void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, @@ -1908,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, Value ConversionPatternRewriter::getRemappedValue(Value key) { SmallVector<ValueVector> remappedValues; - if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key, + if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key, remappedValues))) return nullptr; assert(remappedValues.front().size() == 1 && "1:N conversion not supported"); @@ -1921,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys, if (keys.empty()) return success(); SmallVector<ValueVector> remapped; - if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys, + if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys, 
remapped))) return failure(); for (const auto &values : remapped) { @@ -1956,7 +2171,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, // a bit more efficient, so we try to do that when possible. bool fastPath = !getConfig().listener; - if (fastPath) + if (fastPath && impl->config.allowPatternRollback) impl->inlineBlockBefore(source, dest, before); // Replace all uses of block arguments. @@ -1982,6 +2197,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, } void ConversionPatternRewriter::startOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + // Pattern rollback is not allowed: no extra bookkeeping is needed. + PatternRewriter::startOpModification(op); + return; + } assert(!impl->wasOpReplaced(op) && "attempting to modify a replaced/erased op"); #ifndef NDEBUG @@ -1991,20 +2211,29 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { } void ConversionPatternRewriter::finalizeOpModification(Operation *op) { - assert(!impl->wasOpReplaced(op) && - "attempting to modify a replaced/erased op"); - PatternRewriter::finalizeOpModification(op); impl->patternModifiedOps.insert(op); + if (!impl->config.allowPatternRollback) { + PatternRewriter::finalizeOpModification(op); + if (getConfig().listener) + getConfig().listener->notifyOperationModified(op); + return; + } // There is nothing to do here, we only need to track the operation at the // start of the update. #ifndef NDEBUG + assert(!impl->wasOpReplaced(op) && + "attempting to modify a replaced/erased op"); assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); #endif } void ConversionPatternRewriter::cancelOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + PatternRewriter::cancelOpModification(op); + return; + } #ifndef NDEBUG assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); @@ -2029,17 +2258,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { // ConversionPattern //===----------------------------------------------------------------------===// -SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands( +FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands( ArrayRef<ValueRange> operands) const { SmallVector<Value> oneToOneOperands; oneToOneOperands.reserve(operands.size()); for (ValueRange operand : operands) { if (operand.size() != 1) - llvm::report_fatal_error("pattern '" + getDebugName() + - "' does not support 1:N conversion"); + return failure(); + oneToOneOperands.push_back(operand.front()); } - return oneToOneOperands; + return std::move(oneToOneOperands); } LogicalResult @@ -2054,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op, // Remap the operands of the operation. SmallVector<ValueVector> remapped; - if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, + if (failed(rewriterImpl.remapValues("operand", op->getLoc(), op->getOperands(), remapped))) { return failure(); } @@ -2076,7 +2305,8 @@ class OperationLegalizer { public: using LegalizationAction = ConversionTarget::LegalizationAction; - OperationLegalizer(const ConversionTarget &targetInfo, + OperationLegalizer(ConversionPatternRewriter &rewriter, + const ConversionTarget &targetInfo, const FrozenRewritePatternSet &patterns); /// Returns true if the given operation is known to be illegal on the target. 
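The allowPatternRollback branches above are controlled by a flag on ConversionConfig. A minimal usage sketch (not part of this patch; runMyConversion and its surrounding setup are hypothetical, while allowPatternRollback and unlegalizedOps are the config fields referenced throughout this diff):

#include "llvm/ADT/DenseSet.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static LogicalResult runMyConversion(Operation *root, ConversionTarget &target,
                                     const FrozenRewritePatternSet &patterns) {
  ConversionConfig config;
  // Opt out of rollback: IR changes are materialized immediately, so patterns
  // must not return failure after they have started modifying the IR.
  config.allowPatternRollback = false;
  // Collect ops that could not be legalized instead of aborting the
  // conversion (relevant for partial conversion).
  DenseSet<Operation *> unlegalized;
  config.unlegalizedOps = &unlegalized;
  return applyPartialConversion(root, target, patterns, config);
}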
@@ -2084,29 +2314,25 @@ public: /// Attempt to legalize the given operation. Returns success if the operation /// was legalized, failure otherwise. - LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); + LogicalResult legalize(Operation *op); /// Returns the conversion target in use by the legalizer. const ConversionTarget &getTarget() { return target; } private: /// Attempt to legalize the given operation by folding it. - LogicalResult legalizeWithFold(Operation *op, - ConversionPatternRewriter &rewriter); + LogicalResult legalizeWithFold(Operation *op); /// Attempt to legalize the given operation by applying a pattern. Returns /// success if the operation was legalized, failure otherwise. - LogicalResult legalizeWithPattern(Operation *op, - ConversionPatternRewriter &rewriter); + LogicalResult legalizeWithPattern(Operation *op); /// Return true if the given pattern may be applied to the given operation, /// false otherwise. - bool canApplyPattern(Operation *op, const Pattern &pattern, - ConversionPatternRewriter &rewriter); + bool canApplyPattern(Operation *op, const Pattern &pattern); /// Legalize the resultant IR after successfully applying the given pattern. LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern, - ConversionPatternRewriter &rewriter, const RewriterState &curState, const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, @@ -2115,18 +2341,12 @@ private: /// Legalizes the actions registered during the execution of a pattern. LogicalResult legalizePatternBlockRewrites(Operation *op, - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, const SetVector<Block *> &insertedBlocks, const SetVector<Operation *> &newOps); LogicalResult - legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, - const SetVector<Operation *> &newOps); + legalizePatternCreatedOperations(const SetVector<Operation *> &newOps); LogicalResult - legalizePatternRootUpdates(ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, - const SetVector<Operation *> &modifiedOps); + legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps); //===--------------------------------------------------------------------===// // Cost Model @@ -2169,6 +2389,9 @@ private: /// The current set of patterns that have been applied. SmallPtrSet<const Pattern *, 8> appliedPatterns; + /// The rewriter to use when converting operations. + ConversionPatternRewriter &rewriter; + /// The legalization information provided by the target. const ConversionTarget &target; @@ -2177,9 +2400,10 @@ private: }; } // namespace -OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo, +OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter, + const ConversionTarget &targetInfo, const FrozenRewritePatternSet &patterns) - : target(targetInfo), applicator(patterns) { + : rewriter(rewriter), target(targetInfo), applicator(patterns) { // The set of patterns that can be applied to illegal operations to transform // them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns; @@ -2193,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const { return target.isIllegal(op); } -LogicalResult -OperationLegalizer::legalize(Operation *op, - ConversionPatternRewriter &rewriter) { +LogicalResult OperationLegalizer::legalize(Operation *op) { #ifndef NDEBUG const char *logLineComment = "//===-------------------------------------------===//\n"; @@ -2257,19 +2479,21 @@ OperationLegalizer::legalize(Operation *op, return success(); } - // If the operation isn't legal, try to fold it in-place. - // TODO: Should we always try to do this, even if the op is - // already legal? - if (succeeded(legalizeWithFold(op, rewriter))) { - LLVM_DEBUG({ - logSuccess(logger, "operation was folded"); - logger.startLine() << logLineComment; - }); - return success(); + // If the operation is not legal, try to fold it in-place if the folding mode + // is 'BeforePatterns'. 'Never' will skip this. + const ConversionConfig &config = rewriter.getConfig(); + if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) { + if (succeeded(legalizeWithFold(op))) { + LLVM_DEBUG({ + logSuccess(logger, "operation was folded"); + logger.startLine() << logLineComment; + }); + return success(); + } } // Otherwise, we need to apply a legalization pattern to this operation. - if (succeeded(legalizeWithPattern(op, rewriter))) { + if (succeeded(legalizeWithPattern(op))) { LLVM_DEBUG({ logSuccess(logger, ""); logger.startLine() << logLineComment; @@ -2277,6 +2501,18 @@ OperationLegalizer::legalize(Operation *op, return success(); } + // If the operation can't be legalized via patterns, try to fold it in-place + // if the folding mode is 'AfterPatterns'. + if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) { + if (succeeded(legalizeWithFold(op))) { + LLVM_DEBUG({ + logSuccess(logger, "operation was folded"); + logger.startLine() << logLineComment; + }); + return success(); + } + } + LLVM_DEBUG({ logFailure(logger, "no matched legalization pattern"); logger.startLine() << logLineComment; @@ -2293,9 +2529,7 @@ static T moveAndReset(T &obj) { return result; } -LogicalResult -OperationLegalizer::legalizeWithFold(Operation *op, - ConversionPatternRewriter &rewriter) { +LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) { auto &rewriterImpl = rewriter.getImpl(); LLVM_DEBUG({ rewriterImpl.logger.startLine() << "* Fold {\n"; @@ -2329,14 +2563,14 @@ OperationLegalizer::legalizeWithFold(Operation *op, // An empty list of replacement values indicates that the fold was in-place. // As the operation changed, a new legalization needs to be attempted. if (replacementValues.empty()) - return legalize(op, rewriter); + return legalize(op); // Insert a replacement for 'op' with the folded replacement values. rewriter.replaceOp(op, replacementValues); // Recursively legalize any new constant operations. 
for (Operation *newOp : newOps) { - if (failed(legalize(newOp, rewriter))) { + if (failed(legalize(newOp))) { LLVM_DEBUG(logFailure(rewriterImpl.logger, "failed to legalize generated constant '{0}'", newOp->getName())); @@ -2381,9 +2615,7 @@ reportNewIrLegalizationFatalError(const Pattern &pattern, llvm::join(insertedBlockNames, ", ") + "}"); } -LogicalResult -OperationLegalizer::legalizeWithPattern(Operation *op, - ConversionPatternRewriter &rewriter) { +LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) { auto &rewriterImpl = rewriter.getImpl(); const ConversionConfig &config = rewriter.getConfig(); @@ -2415,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // Functor that returns if the given pattern may be applied. auto canApply = [&](const Pattern &pattern) { - bool canApply = canApplyPattern(op, pattern, rewriter); + bool canApply = canApplyPattern(op, pattern); if (canApply && config.listener) config.listener->notifyPatternBegin(pattern, op); return canApply; @@ -2425,17 +2657,23 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); -#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (!rewriterImpl.config.allowPatternRollback) { - // Returning "failure" after modifying IR is not allowed. + // Erase all unresolved materializations. + for (auto op : rewriterImpl.patternMaterializations) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + rewriterImpl.patternMaterializations.clear(); +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Expensive pattern check that can detect API violations. if (checkOp) { OperationFingerPrint fingerPrintAfterPattern(checkOp); if (fingerPrintAfterPattern != *topLevelFingerPrint) llvm::report_fatal_error("pattern '" + pattern.getDebugName() + "' returned failure but IR did change"); } - } #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + } rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); rewriterImpl.patternInsertedBlocks.clear(); @@ -2459,12 +2697,22 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // successfully applied. auto onSuccess = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); + if (!rewriterImpl.config.allowPatternRollback) { + // Eagerly erase unused materializations. 
+ for (auto op : rewriterImpl.patternMaterializations) { + if (op->use_empty()) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + } + rewriterImpl.patternMaterializations.clear(); + } SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps); SetVector<Operation *> modifiedOps = moveAndReset(rewriterImpl.patternModifiedOps); SetVector<Block *> insertedBlocks = moveAndReset(rewriterImpl.patternInsertedBlocks); - auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps, + auto result = legalizePatternResult(op, pattern, curState, newOps, modifiedOps, insertedBlocks); appliedPatterns.erase(&pattern); if (failed(result)) { @@ -2483,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op, onSuccess); } -bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, - ConversionPatternRewriter &rewriter) { +bool OperationLegalizer::canApplyPattern(Operation *op, + const Pattern &pattern) { LLVM_DEBUG({ auto &os = rewriter.getImpl().logger; os.getOStream() << "\n"; @@ -2506,11 +2754,11 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern, } LogicalResult OperationLegalizer::legalizePatternResult( - Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter, - const RewriterState &curState, const SetVector<Operation *> &newOps, + Operation *op, const Pattern &pattern, const RewriterState &curState, + const SetVector<Operation *> &newOps, const SetVector<Operation *> &modifiedOps, const SetVector<Block *> &insertedBlocks) { - auto &impl = rewriter.getImpl(); + [[maybe_unused]] auto &impl = rewriter.getImpl(); assert(impl.pendingRootUpdates.empty() && "dangling root updates"); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS @@ -2528,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult( #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Legalize each of the actions registered during application. - if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks, - newOps)) || - failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) || - failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) { + if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) || + failed(legalizePatternRootUpdates(modifiedOps)) || + failed(legalizePatternCreatedOperations(newOps))) { return failure(); } @@ -2540,15 +2787,17 @@ LogicalResult OperationLegalizer::legalizePatternResult( } LogicalResult OperationLegalizer::legalizePatternBlockRewrites( - Operation *op, ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &impl, - const SetVector<Block *> &insertedBlocks, + Operation *op, const SetVector<Block *> &insertedBlocks, const SetVector<Operation *> &newOps) { + ConversionPatternRewriterImpl &impl = rewriter.getImpl(); SmallPtrSet<Operation *, 16> alreadyLegalized; // If the pattern moved or created any blocks, make sure the types of block // arguments get legalized. for (Block *block : insertedBlocks) { + if (impl.erasedBlocks.contains(block)) + continue; + // Only check blocks outside of the current operation. 
Operation *parentOp = block->getParentOp(); if (!parentOp || parentOp == op || block->getNumArguments() == 0) @@ -2564,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( "block")); return failure(); } - impl.applySignatureConversion(rewriter, block, converter, *conversion); + impl.applySignatureConversion(block, converter, *conversion); continue; } @@ -2573,7 +2822,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // operation, and blocks in regions created by this pattern will already be // legalized later on. if (!newOps.count(parentOp) && alreadyLegalized.insert(parentOp).second) { - if (failed(legalize(parentOp, rewriter))) { + if (failed(legalize(parentOp))) { LLVM_DEBUG(logFailure( impl.logger, "operation '{0}'({1}) became illegal after rewrite", parentOp->getName(), parentOp)); @@ -2585,11 +2834,10 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( } LogicalResult OperationLegalizer::legalizePatternCreatedOperations( - ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, const SetVector<Operation *> &newOps) { for (Operation *op : newOps) { - if (failed(legalize(op, rewriter))) { - LLVM_DEBUG(logFailure(impl.logger, + if (failed(legalize(op))) { + LLVM_DEBUG(logFailure(rewriter.getImpl().logger, "failed to legalize generated operation '{0}'({1})", op->getName(), op)); return failure(); @@ -2599,13 +2847,13 @@ LogicalResult OperationLegalizer::legalizePatternCreatedOperations( } LogicalResult OperationLegalizer::legalizePatternRootUpdates( - ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &impl, const SetVector<Operation *> &modifiedOps) { for (Operation *op : modifiedOps) { - if (failed(legalize(op, rewriter))) { - LLVM_DEBUG(logFailure( - impl.logger, "failed to legalize operation updated in-place '{0}'", - op->getName())); + if (failed(legalize(op))) { + LLVM_DEBUG( + logFailure(rewriter.getImpl().logger, + "failed to legalize operation updated in-place '{0}'", + op->getName())); return failure(); } } @@ -2825,21 +3073,22 @@ namespace mlir { // rewrite patterns. The conversion behaves differently depending on the // conversion mode. struct OperationConverter { - explicit OperationConverter(const ConversionTarget &target, + explicit OperationConverter(MLIRContext *ctx, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, const ConversionConfig &config, OpConversionMode mode) - : config(config), opLegalizer(target, patterns), mode(mode) {} + : rewriter(ctx, config), opLegalizer(rewriter, target, patterns), + mode(mode) {} /// Converts the given operations to the conversion target. LogicalResult convertOperations(ArrayRef<Operation *> ops); private: /// Converts an operation with the given rewriter. - LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); + LogicalResult convert(Operation *op); - /// Dialect conversion configuration. - ConversionConfig config; + /// The rewriter to use when converting operations. + ConversionPatternRewriter rewriter; /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; @@ -2849,10 +3098,11 @@ private: }; } // namespace mlir -LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, - Operation *op) { +LogicalResult OperationConverter::convert(Operation *op) { + const ConversionConfig &config = rewriter.getConfig(); + // Legalize the given operation. 
- if (failed(opLegalizer.legalize(op, rewriter))) { + if (failed(opLegalizer.legalize(op))) { // Handle the case of a failed conversion for each of the different modes. // Full conversions expect all operations to be converted. if (mode == OpConversionMode::Full) @@ -2928,7 +3178,6 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter, } LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { - assert(!ops.empty() && "expected at least one operation"); const ConversionTarget &target = opLegalizer.getTarget(); // Compute the set of operations and blocks to convert. @@ -2947,11 +3196,10 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { } // Convert each operation and discard rewrites on failure. - ConversionPatternRewriter rewriter(ops.front()->getContext(), config); ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); for (auto *op : toConvert) { - if (failed(convert(rewriter, op))) { + if (failed(convert(op))) { // Dialect conversion failed. if (rewriterImpl.config.allowPatternRollback) { // Rollback is allowed: restore the original IR. @@ -2986,13 +3234,16 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) { castOp->removeAttr(kPureTypeConversionMarker); // Try to legalize all unresolved materializations. - if (config.buildMaterializations) { - IRRewriter rewriter(rewriterImpl.context, config.listener); + if (rewriter.getConfig().buildMaterializations) { + // Use a new rewriter, so the modifications are not tracked for rollback + // purposes etc. + IRRewriter irRewriter(rewriterImpl.rewriter.getContext(), + rewriter.getConfig().listener); for (UnrealizedConversionCastOp castOp : remainingCastOps) { auto it = materializations.find(castOp); assert(it != materializations.end() && "inconsistent state"); - if (failed( - legalizeUnresolvedMaterialization(rewriter, castOp, it->second))) + if (failed(legalizeUnresolvedMaterialization(irRewriter, castOp, + it->second))) return failure(); } } @@ -3159,6 +3410,27 @@ LogicalResult TypeConverter::convertType(Type t, return failure(); } +LogicalResult TypeConverter::convertType(Value v, + SmallVectorImpl<Type> &results) const { + assert(v && "expected non-null value"); + + // If this type converter does not have context-aware type conversions, call + // the type-based overload, which has caching. + if (!hasContextAwareTypeConversions) + return convertType(v.getType(), results); + + // Walk the added converters in reverse order to apply the most recently + // registered first. + for (const ConversionCallbackFn &converter : llvm::reverse(conversions)) { + if (std::optional<LogicalResult> result = converter(v, results)) { + if (!succeeded(*result)) + return failure(); + return success(); + } + } + return failure(); +} + Type TypeConverter::convertType(Type t) const { // Use the multi-type result version to convert the type. SmallVector<Type, 1> results; @@ -3169,6 +3441,16 @@ Type TypeConverter::convertType(Type t) const { return results.size() == 1 ? results.front() : nullptr; } +Type TypeConverter::convertType(Value v) const { + // Use the multi-type result version to convert the type. + SmallVector<Type, 1> results; + if (failed(convertType(v, results))) + return nullptr; + + // Check to ensure that only one type was produced. + return results.size() == 1 ? 
results.front() : nullptr; +} + LogicalResult TypeConverter::convertTypes(TypeRange types, SmallVectorImpl<Type> &results) const { @@ -3178,21 +3460,38 @@ TypeConverter::convertTypes(TypeRange types, return success(); } +LogicalResult +TypeConverter::convertTypes(ValueRange values, + SmallVectorImpl<Type> &results) const { + for (Value value : values) + if (failed(convertType(value, results))) + return failure(); + return success(); +} + bool TypeConverter::isLegal(Type type) const { return convertType(type) == type; } + +bool TypeConverter::isLegal(Value value) const { + return convertType(value) == value.getType(); +} + bool TypeConverter::isLegal(Operation *op) const { - return isLegal(op->getOperandTypes()) && isLegal(op->getResultTypes()); + return isLegal(op->getOperands()) && isLegal(op->getResults()); } bool TypeConverter::isLegal(Region *region) const { - return llvm::all_of(*region, [this](Block &block) { - return isLegal(block.getArgumentTypes()); - }); + return llvm::all_of( + *region, [this](Block &block) { return isLegal(block.getArguments()); }); } bool TypeConverter::isSignatureLegal(FunctionType ty) const { - return isLegal(llvm::concat<const Type>(ty.getInputs(), ty.getResults())); + if (!isLegal(ty.getInputs())) + return false; + if (!isLegal(ty.getResults())) + return false; + return true; } LogicalResult @@ -3220,6 +3519,31 @@ TypeConverter::convertSignatureArgs(TypeRange types, return failure(); return success(); } +LogicalResult +TypeConverter::convertSignatureArg(unsigned inputNo, Value value, + SignatureConversion &result) const { + // Try to convert the given input type. + SmallVector<Type, 1> convertedTypes; + if (failed(convertType(value, convertedTypes))) + return failure(); + + // If this argument is being dropped, there is nothing left to do. + if (convertedTypes.empty()) + return success(); + + // Otherwise, add the new inputs. 
+ result.addInputs(inputNo, convertedTypes); + return success(); +} +LogicalResult +TypeConverter::convertSignatureArgs(ValueRange values, + SignatureConversion &result, + unsigned origInputOffset) const { + for (unsigned i = 0, e = values.size(); i != e; ++i) + if (failed(convertSignatureArg(origInputOffset + i, values[i], result))) + return failure(); + return success(); +} Value TypeConverter::materializeSourceConversion(OpBuilder &builder, Location loc, Type resultType, @@ -3263,7 +3587,7 @@ SmallVector<Value> TypeConverter::materializeTargetConversion( std::optional<TypeConverter::SignatureConversion> TypeConverter::convertBlockSignature(Block *block) const { SignatureConversion conversion(block->getNumArguments()); - if (failed(convertSignatureArgs(block->getArgumentTypes(), conversion))) + if (failed(convertSignatureArgs(block->getArguments(), conversion))) return std::nullopt; return conversion; } @@ -3388,7 +3712,7 @@ mlir::convertOpResultTypes(Operation *op, ValueRange operands, newOp.addOperands(operands); SmallVector<Type> newResultTypes; - if (failed(converter.convertTypes(op->getResultTypes(), newResultTypes))) + if (failed(converter.convertTypes(op->getResults(), newResultTypes))) return rewriter.notifyMatchFailure(loc, "couldn't convert return types"); newOp.addTypes(newResultTypes); @@ -3661,7 +3985,8 @@ static LogicalResult applyConversion(ArrayRef<Operation *> ops, SmallVector<IRUnit> irUnits(ops.begin(), ops.end()); ctx->executeAction<ApplyConversionAction>( [&] { - OperationConverter opConverter(target, patterns, config, mode); + OperationConverter opConverter(ops.front()->getContext(), target, + patterns, config, mode); status = opConverter.convertOperations(ops); }, irUnits); diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 607b86c..0324588 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -15,6 +15,8 @@ #include "mlir/Config/mlir-config.h" #include "mlir/IR/Action.h" #include "mlir/IR/Matchers.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" @@ -23,7 +25,7 @@ #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" @@ -178,9 +180,8 @@ static Operation *getDumpRootOp(Operation *op) { return op; } static void logSuccessfulFolding(Operation *op) { - llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n"; - op->dump(); - llvm::dbgs() << "\n\n"; + LDBG() << "// *** IR Dump After Successful Folding ***\n" + << OpWithFlags(op, OpPrintingFlags().elideLargeElementsAttrs()); } #endif // NDEBUG @@ -394,8 +395,12 @@ private: function_ref<void(Diagnostic &)> reasonCallback) override; #ifndef NDEBUG + /// A raw output stream used to prefix the debug log. + + llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(), + llvm::dbgs()}; /// A logger used to emit information during the application process. - llvm::ScopedPrinter logger{llvm::dbgs()}; + llvm::ScopedPrinter logger{os}; #endif /// The low-level pattern applicator. 
@@ -871,7 +876,18 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && { ctx->executeAction<GreedyPatternRewriteIteration>( [&] { - continueRewrites = processWorklist(); + continueRewrites = false; + + // Erase unreachable blocks + // Operations like: + // %add = arith.addi %add, %add : i64 + // are legal in unreachable code. Unfortunately many patterns would be + // unsafe to apply on such IR and can lead to crashes or infinite + // loops. + continueRewrites |= + succeeded(eraseUnreachableBlocks(rewriter, region)); + + continueRewrites |= processWorklist(); // After applying patterns, make sure that the CFG of each of the // regions is kept up to date. @@ -917,10 +933,9 @@ mlir::applyPatternsGreedily(Region &region, RegionPatternRewriteDriver driver(region.getContext(), patterns, config, region); LogicalResult converged = std::move(driver).simplify(changed); - LLVM_DEBUG(if (failed(converged)) { - llvm::dbgs() << "The pattern rewrite did not converge after scanning " - << config.getMaxIterations() << " times\n"; - }); + if (failed(converged)) + LDBG() << "The pattern rewrite did not converge after scanning " + << config.getMaxIterations() << " times"; return converged; } @@ -1052,9 +1067,8 @@ LogicalResult mlir::applyOpPatternsGreedily( LogicalResult converged = std::move(driver).simplify(ops, changed); if (allErased) *allErased = surviving.empty(); - LLVM_DEBUG(if (failed(converged)) { - llvm::dbgs() << "The pattern rewrite did not converge after " - << config.getMaxNumRewrites() << " rewrites"; - }); + if (failed(converged)) + LDBG() << "The pattern rewrite did not converge after " + << config.getMaxNumRewrites() << " rewrites"; return converged; } diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp index eeb4052..73107cf 100644 --- a/mlir/lib/Transforms/Utils/InliningUtils.cpp +++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp @@ -13,10 +13,12 @@ #include "mlir/Transforms/InliningUtils.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/Operation.h" #include "mlir/Interfaces/CallInterfaces.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/raw_ostream.h" #include <optional> @@ -182,13 +184,16 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src, IRMapping &valueMapping) { for (auto &block : *src) { for (auto &op : block) { + // UnrealizedConversionCastOp is inlineable but cannot implement the + // inliner interface due to layering constraints. + if (isa<UnrealizedConversionCastOp>(op)) + continue; + // Check this operation. if (!interface.isLegalToInline(&op, insertRegion, shouldCloneInlinedRegion, valueMapping)) { - LLVM_DEBUG({ - llvm::dbgs() << "* Illegal to inline because of op: "; - op.dump(); - }); + LDBG() << "* Illegal to inline because of op: " + << OpWithFlags(&op, OpPrintingFlags().skipRegions()); return false; } // Check any nested regions.
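The isLegalToInline query above is answered by the dialect inliner interface; UnrealizedConversionCastOp is special-cased because, as the new comment notes, it cannot implement that interface for layering reasons. For comparison, a dialect normally opts its ops in roughly like this (a sketch, not part of this patch; MyDialectInlinerInterface is hypothetical):

#include "mlir/Transforms/InliningUtils.h"

using namespace mlir;

struct MyDialectInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // All regions of this dialect may be inlined anywhere.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &valueMapping) const final {
    return true;
  }
  // All operations of this dialect may be inlined.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       IRMapping &valueMapping) const final {
    return true;
  }
};
// Registered from the dialect's initialize() hook, e.g.:
//   addInterfaces<MyDialectInlinerInterface>();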
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp index cb3f2c5..111f58e 100644 --- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp +++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp @@ -13,11 +13,13 @@ #include "mlir/Transforms/LoopInvariantCodeMotionUtils.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SubsetOpInterface.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugLog.h" #include <queue> #define DEBUG_TYPE "licm" @@ -64,8 +66,7 @@ size_t mlir::moveLoopInvariantCode( size_t numMoved = 0; for (Region *region : regions) { - LLVM_DEBUG(llvm::dbgs() << "Original loop:\n" - << *region->getParentOp() << "\n"); + LDBG() << "Original loop:\n" << *region->getParentOp(); std::queue<Operation *> worklist; // Add top-level operations in the loop body to the worklist. @@ -83,12 +84,13 @@ size_t mlir::moveLoopInvariantCode( if (op->getParentRegion() != region) continue; - LLVM_DEBUG(llvm::dbgs() << "Checking op: " << *op << "\n"); + LDBG() << "Checking op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); if (!shouldMoveOutOfRegion(op, region) || !canBeHoisted(op, definedOutside)) continue; - LLVM_DEBUG(llvm::dbgs() << "Moving loop-invariant op: " << *op << "\n"); + LDBG() << "Moving loop-invariant op: " << *op; moveOutOfRegion(op, region); ++numMoved; @@ -322,7 +324,7 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter, LoopLikeOpInterface loopLike, BlockArgument iterArg) { assert(iterArg.getOwner()->getParentOp() == loopLike && "invalid iter_arg"); - auto it = llvm::find(loopLike.getRegionIterArgs(), iterArg); + BlockArgument *it = llvm::find(loopLike.getRegionIterArgs(), iterArg); int64_t iterArgIdx = std::distance(loopLike.getRegionIterArgs().begin(), it); MatchingSubsets subsets; if (failed(subsets.populateSubsetOpsAtIterArg(loopLike, iterArg))) diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp index a1d975d..31ae1d1 100644 --- a/mlir/lib/Transforms/Utils/RegionUtils.cpp +++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp @@ -23,12 +23,15 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include <deque> #include <iterator> using namespace mlir; +#define DEBUG_TYPE "region-utils" + void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement, Region &region) { for (auto &use : llvm::make_early_inc_range(orig.getUses())) { @@ -182,19 +185,34 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove( // TODO: We could likely merge this with the DCE algorithm below. LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, MutableArrayRef<Region> regions) { + LDBG() << "Starting eraseUnreachableBlocks with " << regions.size() + << " regions"; + // Set of blocks found to be reachable within a given region. llvm::df_iterator_default_set<Block *, 16> reachable; // If any blocks were found to be dead.
- bool erasedDeadBlocks = false; + int erasedDeadBlocks = 0; SmallVector<Region *, 1> worklist; worklist.reserve(regions.size()); for (Region &region : regions) worklist.push_back(&region); + + LDBG(2) << "Initial worklist size: " << worklist.size(); + while (!worklist.empty()) { Region *region = worklist.pop_back_val(); - if (region->empty()) + if (region->empty()) { + LDBG(2) << "Skipping empty region"; continue; + } + + LDBG(2) << "Processing region with " << region->getBlocks().size() + << " blocks"; + if (region->getParentOp()) + LDBG(2) << " -> for operation: " + << OpWithFlags(region->getParentOp(), + OpPrintingFlags().skipRegions()); // If this is a single block region, just collect the nested regions. if (region->hasOneBlock()) { @@ -209,13 +227,17 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, for (Block *block : depth_first_ext(&region->front(), reachable)) (void)block /* Mark all reachable blocks */; + LDBG(2) << "Found " << reachable.size() << " reachable blocks out of " + << region->getBlocks().size() << " total blocks"; + // Collect all of the dead blocks and push the live regions onto the // worklist. for (Block &block : llvm::make_early_inc_range(*region)) { if (!reachable.count(&block)) { + LDBG() << "Erasing unreachable block: " << &block; block.dropAllDefinedValueUses(); rewriter.eraseBlock(&block); - erasedDeadBlocks = true; + ++erasedDeadBlocks; continue; } @@ -226,7 +248,10 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter, } } - return success(erasedDeadBlocks); + LDBG() << "Finished eraseUnreachableBlocks, erased " << erasedDeadBlocks + << " dead blocks"; + + return success(erasedDeadBlocks > 0); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp index ee5c642..1382550 100644 --- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp @@ -13,18 +13,40 @@ #include "mlir/Transforms/WalkPatternRewriteDriver.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Verifier.h" #include "mlir/IR/Visitors.h" #include "mlir/Rewrite/PatternApplicator.h" -#include "llvm/Support/Debug.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include "llvm/Support/ErrorHandling.h" #define DEBUG_TYPE "walk-rewriter" namespace mlir { +// Find all reachable blocks in the region and add them to the reachableBlocks +// set. +static void findReachableBlocks(Region &region, + DenseSet<Block *> &reachableBlocks) { + Block *entryBlock = &region.front(); + reachableBlocks.insert(entryBlock); + // Traverse the CFG and add all reachable blocks to the reachableBlocks set. + SmallVector<Block *> worklist({entryBlock}); + while (!worklist.empty()) { + Block *block = worklist.pop_back_val(); + Operation *terminator = &block->back(); + for (Block *successor : terminator->getSuccessors()) { + if (reachableBlocks.contains(successor)) + continue; + worklist.push_back(successor); + reachableBlocks.insert(successor); + } + } +} + namespace { struct WalkAndApplyPatternsAction final : tracing::ActionImpl<WalkAndApplyPatternsAction> { @@ -88,20 +110,104 @@ void walkAndApplyPatterns(Operation *op, PatternApplicator applicator(patterns); applicator.applyDefaultCostModel(); + // Iterator on all reachable operations in the region.
+ // Also keep track if we visited the nested regions of the current op + // already to drive the post-order traversal. + struct RegionReachableOpIterator { + RegionReachableOpIterator(Region *region) : region(region) { + regionIt = region->begin(); + if (regionIt != region->end()) + blockIt = regionIt->begin(); + if (!llvm::hasSingleElement(*region)) + findReachableBlocks(*region, reachableBlocks); + } + // Advance the iterator to the next reachable operation. + void advance() { + assert(regionIt != region->end()); + hasVisitedRegions = false; + if (blockIt == regionIt->end()) { + ++regionIt; + while (regionIt != region->end() && + !reachableBlocks.contains(&*regionIt)) + ++regionIt; + if (regionIt != region->end()) + blockIt = regionIt->begin(); + return; + } + ++blockIt; + if (blockIt != regionIt->end()) { + LDBG() << "Incrementing block iterator, next op: " + << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions()); + } + } + // The region we're iterating over. + Region *region; + // The Block currently being iterated over. + Region::iterator regionIt; + // The Operation currently being iterated over. + Block::iterator blockIt; + // The set of blocks that are reachable in the current region. + DenseSet<Block *> reachableBlocks; + // Whether we've visited the nested regions of the current op already. + bool hasVisitedRegions = false; + }; + + // Worklist of regions to visit to drive the post-order traversal. + SmallVector<RegionReachableOpIterator> worklist; + + LDBG() << "Starting walk-based pattern rewrite driver"; ctx->executeAction<WalkAndApplyPatternsAction>( [&] { + // Perform a post-order traversal of the regions, visiting each + // reachable operation. for (Region &region : op->getRegions()) { - region.walk([&](Operation *visitedOp) { - LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print( - llvm::dbgs(), OpPrintingFlags().skipRegions()); - llvm::dbgs() << "\n";); + assert(worklist.empty()); + if (region.empty()) + continue; + + // Prime the worklist with the entry block of this region. + worklist.push_back({&region}); + while (!worklist.empty()) { + RegionReachableOpIterator &it = worklist.back(); + if (it.regionIt == it.region->end()) { + // We're done with this region. + worklist.pop_back(); + continue; + } + if (it.blockIt == it.regionIt->end()) { + // We're done with this block. + it.advance(); + continue; + } + Operation *op = &*it.blockIt; + // If we haven't visited the nested regions of this op yet, + // enqueue them. + if (!it.hasVisitedRegions) { + it.hasVisitedRegions = true; + for (Region &nestedRegion : llvm::reverse(op->getRegions())) { + if (nestedRegion.empty()) + continue; + worklist.push_back({&nestedRegion}); + } + } + // If we're not at the back of the worklist, we've enqueued some + // nested region for processing. We'll come back to this op later + // (post-order) + if (&it != &worklist.back()) + continue; + + // Preemptively increment the iterator, in case the current op + // would be erased. + it.advance(); + + LDBG() << "Visiting op: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - erasedListener.visitedOp = visitedOp; + erasedListener.visitedOp = op; #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS - if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) { - LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";); - } - }); + if (succeeded(applicator.matchAndRewrite(op, rewriter))) + LDBG() << "\tOp matched and rewritten"; + } } }, {op});
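For reference, the driver rewritten above is invoked through walkAndApplyPatterns. A minimal caller looks roughly like this (a sketch, not part of this patch; runOnce is hypothetical, and the exact entry-point signature is assumed from WalkPatternRewriteDriver.h):

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

using namespace mlir;

// One-shot, walk-based rewrite: every reachable op is visited exactly once in
// post-order, so the patterns must not rely on being re-applied to converge.
static void runOnce(Operation *root, RewritePatternSet &&patterns) {
  FrozenRewritePatternSet frozen(std::move(patterns));
  walkAndApplyPatterns(root, frozen);
}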