Diffstat (limited to 'mlir/lib')
-rw-r--r-- | mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp | 4
-rw-r--r-- | mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp | 233
-rw-r--r-- | mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 177
-rw-r--r-- | mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp | 85
-rw-r--r-- | mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 74
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp | 64
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp | 28
-rw-r--r-- | mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 15
-rw-r--r-- | mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp | 218
-rw-r--r-- | mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp | 1
-rw-r--r-- | mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp | 7
-rw-r--r-- | mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 111
-rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 120
-rw-r--r-- | mlir/lib/Target/Cpp/TranslateToCpp.cpp | 47
-rw-r--r-- | mlir/lib/Target/LLVMIR/DebugImporter.cpp | 3
-rw-r--r-- | mlir/lib/Transforms/GenerateRuntimeVerification.cpp | 6
16 files changed, 971 insertions, 222 deletions
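The largest functional change in this range is the new scf.while -> emitc.do lowering added to SCFToEmitC.cpp (the WhileLowering pattern in the diff below): each scf.while result and loop-carried value becomes an emitc.lvalue variable, the `before` region is merged into the do-body, scf.condition is rewritten into assignments to the result variables and to an i1 flag variable, the `after` region is wrapped in an emitc.if on that flag, and the do-op's condition region simply reloads the flag. The following is a rough hand-written sketch of the C shape this corresponds to; all identifiers are invented for illustration and this is not the literal output of the C++ emitter.

#include <stdbool.h>
#include <stdint.h>

/* Sketch of the structure targeted by the new scf.while -> emitc.do
 * lowering, assuming a loop that increments iter from init until it
 * reaches limit and returns the counter's value at loop exit.          */
int32_t while_example(int32_t init, int32_t limit) {
  int32_t res;     /* result variable (one per scf.while result)        */
  int32_t iter;    /* loop-carried variable, seeded from the init value */
  bool keep_going; /* condition flag, assigned on every iteration       */
  iter = init;
  do {
    /* "before" region: evaluate the loop condition from current state. */
    bool c = iter < limit;
    /* scf.condition becomes plain assignments: its arguments feed the  */
    /* result variables and the condition itself feeds the flag.        */
    res = iter;
    keep_going = c;
    if (c) {
      /* "after" region (loop body); its scf.yield writes the           */
      /* loop-carried variables back for the next iteration.            */
      iter = iter + 1;
    }
  } while (keep_going);
  return res;
}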
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp index de1ed39..377f7eb 100644 --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -522,14 +522,14 @@ void DeadCodeAnalysis::visitRegionBranchEdges( // Mark the entry block as executable. auto *state = getOrCreate<Executable>(point); propagateIfChanged(state, state->setToLive()); - LDBG() << "Marked region successor live: " << point; + LDBG() << "Marked region successor live: " << *point; // Add the parent op as a predecessor. auto *predecessors = getOrCreate<PredecessorState>(point); propagateIfChanged( predecessors, predecessors->join(predecessorOp, successor.getSuccessorInputs())); - LDBG() << "Added region branch as predecessor for successor: " << point; + LDBG() << "Added region branch as predecessor for successor: " << *point; } } diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp index b51465b..daa3db5 100644 --- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp @@ -17,44 +17,74 @@ #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Support/DebugLog.h" #include <cassert> #include <optional> using namespace mlir; using namespace mlir::dataflow; +#define DEBUG_TYPE "dense-analysis" + //===----------------------------------------------------------------------===// // AbstractDenseForwardDataFlowAnalysis //===----------------------------------------------------------------------===// void AbstractDenseForwardDataFlowAnalysis::initializeEquivalentLatticeAnchor( Operation *top) { + LDBG() << "initializeEquivalentLatticeAnchor: " + << OpWithFlags(top, OpPrintingFlags().skipRegions()); top->walk([&](Operation *op) { - if (isa<RegionBranchOpInterface, CallOpInterface>(op)) + if (isa<RegionBranchOpInterface, CallOpInterface>(op)) { + LDBG() << " Skipping " + << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " (region branch or call)"; return; + } + LDBG() << " Building equivalent lattice anchor for " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); buildOperationEquivalentLatticeAnchor(op); }); } LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) { + LDBG() << "initialize (forward): " + << OpWithFlags(top, OpPrintingFlags().skipRegions()); // Visit every operation and block. 
- if (failed(processOperation(top))) + if (failed(processOperation(top))) { + LDBG() << " Failed to process top-level operation"; return failure(); + } for (Region ®ion : top->getRegions()) { + LDBG() << " Processing region with " << region.getBlocks().size() + << " blocks"; for (Block &block : region) { + LDBG() << " Processing block with " << block.getOperations().size() + << " operations"; visitBlock(&block); - for (Operation &op : block) - if (failed(initialize(&op))) + for (Operation &op : block) { + LDBG() << " Initializing operation: " + << OpWithFlags(&op, OpPrintingFlags().skipRegions()); + if (failed(initialize(&op))) { + LDBG() << " Failed to initialize operation"; return failure(); + } + } } } + LDBG() << " Forward initialization completed successfully"; return success(); } LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint *point) { - if (!point->isBlockStart()) + LDBG() << "visit (forward): " << *point; + if (!point->isBlockStart()) { + LDBG() << " Processing operation: " + << OpWithFlags(point->getPrevOp(), OpPrintingFlags().skipRegions()); return processOperation(point->getPrevOp()); + } + LDBG() << " Visiting block: " << point->getBlock(); visitBlock(point->getBlock()); return success(); } @@ -62,6 +92,11 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint *point) { void AbstractDenseForwardDataFlowAnalysis::visitCallOperation( CallOpInterface call, const AbstractDenseLattice &before, AbstractDenseLattice *after) { + LDBG() << "visitCallOperation (forward): " + << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions()); + LDBG() << " before state: " << before; + LDBG() << " after state: " << *after; + // Allow for customizing the behavior of calls to external symbols, including // when the analysis is explicitly marked as non-interprocedural. auto isExternalCallable = [&]() { @@ -70,6 +105,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation( return callable && !callable.getCallableRegion(); }; if (!getSolverConfig().isInterprocedural() || isExternalCallable()) { + LDBG() << " Handling as external callee (non-interprocedural or external)"; return visitCallControlFlowTransfer( call, CallControlFlowAction::ExternalCallee, before, after); } @@ -78,10 +114,16 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation( getProgramPointAfter(call.getOperation()), getProgramPointAfter(call)); // Otherwise, if not all return sites are known, then conservatively assume we // can't reason about the data-flow. 
- if (!predecessors->allPredecessorsKnown()) + if (!predecessors->allPredecessorsKnown()) { + LDBG() << " Not all predecessors known, setting to entry state"; return setToEntryState(after); + } + LDBG() << " Processing " << predecessors->getKnownPredecessors().size() + << " known predecessors"; for (Operation *predecessor : predecessors->getKnownPredecessors()) { + LDBG() << " Processing predecessor: " + << OpWithFlags(predecessor, OpPrintingFlags().skipRegions()); // Get the lattices at callee return: // // func.func @callee() { @@ -99,6 +141,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation( const AbstractDenseLattice *latticeAtCalleeReturn = getLatticeFor(getProgramPointAfter(call.getOperation()), getProgramPointAfter(predecessor)); + LDBG() << " Lattice at callee return: " << *latticeAtCalleeReturn; visitCallControlFlowTransfer(call, CallControlFlowAction::ExitCallee, *latticeAtCalleeReturn, latticeAfterCall); } @@ -106,12 +149,16 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation( LogicalResult AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) { + LDBG() << "processOperation (forward): " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); ProgramPoint *point = getProgramPointAfter(op); // If the containing block is not executable, bail out. if (op->getBlock() != nullptr && !getOrCreateFor<Executable>(point, getProgramPointBefore(op->getBlock())) - ->isLive()) + ->isLive()) { + LDBG() << " Block not executable, skipping operation"; return success(); + } // Get the dense lattice to update. AbstractDenseLattice *after = getLattice(point); @@ -119,10 +166,13 @@ AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) { // Get the dense state before the execution of the op. const AbstractDenseLattice *before = getLatticeFor(point, getProgramPointBefore(op)); + LDBG() << " before state: " << *before; + LDBG() << " after state: " << *after; // If this op implements region control-flow, then control-flow dictates its // transfer function. if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { + LDBG() << " Processing as region branch operation"; visitRegionBranchOperation(point, branch, after); return success(); } @@ -130,41 +180,57 @@ AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) { // If this is a call operation, then join its lattices across known return // sites. if (auto call = dyn_cast<CallOpInterface>(op)) { + LDBG() << " Processing as call operation"; visitCallOperation(call, *before, after); return success(); } // Invoke the operation transfer function. + LDBG() << " Invoking operation transfer function"; return visitOperationImpl(op, *before, after); } void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) { + LDBG() << "visitBlock (forward): " << block; // If the block is not executable, bail out. ProgramPoint *point = getProgramPointBefore(block); - if (!getOrCreateFor<Executable>(point, point)->isLive()) + if (!getOrCreateFor<Executable>(point, point)->isLive()) { + LDBG() << " Block not executable, skipping"; return; + } // Get the dense lattice to update. AbstractDenseLattice *after = getLattice(point); + LDBG() << " Block lattice state: " << *after; // The dense lattices of entry blocks are set by region control-flow or the // callgraph. if (block->isEntryBlock()) { + LDBG() << " Processing entry block"; // Check if this block is the entry block of a callable region. 
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); if (callable && callable.getCallableRegion() == block->getParent()) { + LDBG() << " Entry block of callable region"; const auto *callsites = getOrCreateFor<PredecessorState>( point, getProgramPointAfter(callable)); // If not all callsites are known, conservatively mark all lattices as // having reached their pessimistic fixpoints. Do the same if // interprocedural analysis is not enabled. if (!callsites->allPredecessorsKnown() || - !getSolverConfig().isInterprocedural()) + !getSolverConfig().isInterprocedural()) { + LDBG() << " Not all callsites known or non-interprocedural, setting " + "to entry state"; return setToEntryState(after); + } + LDBG() << " Processing " << callsites->getKnownPredecessors().size() + << " known callsites"; for (Operation *callsite : callsites->getKnownPredecessors()) { + LDBG() << " Processing callsite: " + << OpWithFlags(callsite, OpPrintingFlags().skipRegions()); // Get the dense lattice before the callsite. const AbstractDenseLattice *before; before = getLatticeFor(point, getProgramPointBefore(callsite)); + LDBG() << " Lattice before callsite: " << *before; visitCallControlFlowTransfer(cast<CallOpInterface>(callsite), CallControlFlowAction::EnterCallee, @@ -174,23 +240,32 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) { } // Check if we can reason about the control-flow. - if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) + if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { + LDBG() << " Entry block of region branch operation"; return visitRegionBranchOperation(point, branch, after); + } // Otherwise, we can't reason about the data-flow. + LDBG() << " Cannot reason about data-flow, setting to entry state"; return setToEntryState(after); } // Join the state with the state after the block's predecessors. + LDBG() << " Joining state from " + << std::distance(block->pred_begin(), block->pred_end()) + << " predecessors"; for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end(); it != e; ++it) { // Skip control edges that aren't executable. Block *predecessor = *it; if (!getOrCreateFor<Executable>( point, getLatticeAnchor<CFGEdge>(predecessor, block)) - ->isLive()) + ->isLive()) { + LDBG() << " Skipping non-executable edge from " << predecessor; continue; + } + LDBG() << " Joining state from predecessor " << predecessor; // Merge in the state from the predecessor's terminator. join(after, *getLatticeFor( point, getProgramPointAfter(predecessor->getTerminator()))); @@ -200,20 +275,34 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) { void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation( ProgramPoint *point, RegionBranchOpInterface branch, AbstractDenseLattice *after) { + LDBG() << "visitRegionBranchOperation (forward): " + << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions()); + LDBG() << " point: " << *point; + LDBG() << " after state: " << *after; + // Get the terminator predecessors. 
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point); assert(predecessors->allPredecessorsKnown() && "unexpected unresolved region successors"); + LDBG() << " Processing " << predecessors->getKnownPredecessors().size() + << " known predecessors"; for (Operation *op : predecessors->getKnownPredecessors()) { + LDBG() << " Processing predecessor: " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); const AbstractDenseLattice *before; // If the predecessor is the parent, get the state before the parent. if (op == branch) { + LDBG() << " Predecessor is the branch itself, getting state before " + "parent"; before = getLatticeFor(point, getProgramPointBefore(op)); // Otherwise, get the state after the terminator. } else { + LDBG() + << " Predecessor is terminator, getting state after terminator"; before = getLatticeFor(point, getProgramPointAfter(op)); } + LDBG() << " before state: " << *before; // This function is called in two cases: // 1. when visiting the block (point = block start); @@ -231,19 +320,31 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation( std::optional<unsigned> regionFrom = op == branch ? std::optional<unsigned>() : op->getBlock()->getParent()->getRegionNumber(); + LDBG() << " regionFrom: " + << (regionFrom ? std::to_string(*regionFrom) : "parent"); + if (point->isBlockStart()) { unsigned regionTo = point->getBlock()->getParent()->getRegionNumber(); + LDBG() << " Point is block start, regionTo: " << regionTo; + LDBG() << " Calling visitRegionBranchControlFlowTransfer with " + "regionFrom/regionTo"; visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo, *before, after); } else { assert(point->getPrevOp() == branch && "expected to be visiting the branch itself"); + LDBG() << " Point is not block start, checking if predecessor is " + "region or op itself"; // Only need to call the arc transfer when the predecessor is the region // or the op itself, not the previous op. if (op->getParentOp() == branch || op == branch) { + LDBG() << " Predecessor is region or op itself, calling " + "visitRegionBranchControlFlowTransfer"; visitRegionBranchControlFlowTransfer( branch, regionFrom, /*regionTo=*/std::nullopt, *before, after); } else { + LDBG() + << " Predecessor is not region or op itself, performing join"; join(after, *before); } } @@ -256,35 +357,61 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation( void AbstractDenseBackwardDataFlowAnalysis::initializeEquivalentLatticeAnchor( Operation *top) { + LDBG() << "initializeEquivalentLatticeAnchor (backward): " + << OpWithFlags(top, OpPrintingFlags().skipRegions()); top->walk([&](Operation *op) { - if (isa<RegionBranchOpInterface, CallOpInterface>(op)) + if (isa<RegionBranchOpInterface, CallOpInterface>(op)) { + LDBG() << " Skipping " + << OpWithFlags(op, OpPrintingFlags().skipRegions()) + << " (region branch or call)"; return; + } + LDBG() << " Building equivalent lattice anchor for " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); buildOperationEquivalentLatticeAnchor(op); }); } LogicalResult AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) { + LDBG() << "initialize (backward): " + << OpWithFlags(top, OpPrintingFlags().skipRegions()); // Visit every operation and block. 
- if (failed(processOperation(top))) + if (failed(processOperation(top))) { + LDBG() << " Failed to process top-level operation"; return failure(); + } for (Region ®ion : top->getRegions()) { + LDBG() << " Processing region with " << region.getBlocks().size() + << " blocks"; for (Block &block : region) { + LDBG() << " Processing block with " << block.getOperations().size() + << " operations"; visitBlock(&block); for (Operation &op : llvm::reverse(block)) { - if (failed(initialize(&op))) + LDBG() << " Initializing operation (backward): " + << OpWithFlags(&op, OpPrintingFlags().skipRegions()); + if (failed(initialize(&op))) { + LDBG() << " Failed to initialize operation"; return failure(); + } } } } + LDBG() << " Backward initialization completed successfully"; return success(); } LogicalResult AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) { - if (!point->isBlockEnd()) + LDBG() << "visit (backward): " << *point; + if (!point->isBlockEnd()) { + LDBG() << " Processing operation: " + << OpWithFlags(point->getNextOp(), OpPrintingFlags().skipRegions()); return processOperation(point->getNextOp()); + } + LDBG() << " Visiting block: " << point->getBlock(); visitBlock(point->getBlock()); return success(); } @@ -292,28 +419,47 @@ AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) { void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( CallOpInterface call, const AbstractDenseLattice &after, AbstractDenseLattice *before) { + LDBG() << "visitCallOperation (backward): " + << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions()); + LDBG() << " after state: " << after; + LDBG() << " before state: " << *before; + // If the solver is not interprocedural, let the hook handle it as an external // callee. - if (!getSolverConfig().isInterprocedural()) + if (!getSolverConfig().isInterprocedural()) { + LDBG() << " Non-interprocedural analysis, handling as external callee"; return visitCallControlFlowTransfer( call, CallControlFlowAction::ExternalCallee, after, before); + } // Find the callee. Operation *callee = call.resolveCallableInTable(&symbolTable); + if (callee) { + LDBG() << " Resolved callee: " + << OpWithFlags(callee, OpPrintingFlags().skipRegions()); + } else { + LDBG() << " Resolved callee: null"; + } auto callable = dyn_cast_or_null<CallableOpInterface>(callee); // No region means the callee is only declared in this module. // If that is the case or if the solver is not interprocedural, // let the hook handle it. - if (callable && - (!callable.getCallableRegion() || callable.getCallableRegion()->empty())) + if (callable && (!callable.getCallableRegion() || + callable.getCallableRegion()->empty())) { + LDBG() << " Callee has no region or empty region, handling as external " + "callee"; return visitCallControlFlowTransfer( call, CallControlFlowAction::ExternalCallee, after, before); + } - if (!callable) + if (!callable) { + LDBG() << " No callable found, setting to exit state"; return setToExitState(before); + } Region *region = callable.getCallableRegion(); + LDBG() << " Processing callable with region"; // Call-level control flow specifies the data flow here. 
// @@ -332,6 +478,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( ProgramPoint *calleeEntry = getProgramPointBefore(calleeEntryBlock); const AbstractDenseLattice &latticeAtCalleeEntry = *getLatticeFor(getProgramPointBefore(call.getOperation()), calleeEntry); + LDBG() << " Lattice at callee entry: " << latticeAtCalleeEntry; AbstractDenseLattice *latticeBeforeCall = before; visitCallControlFlowTransfer(call, CallControlFlowAction::EnterCallee, latticeAtCalleeEntry, latticeBeforeCall); @@ -339,12 +486,16 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation( LogicalResult AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) { + LDBG() << "processOperation (backward): " + << OpWithFlags(op, OpPrintingFlags().skipRegions()); ProgramPoint *point = getProgramPointBefore(op); // If the containing block is not executable, bail out. if (op->getBlock() != nullptr && !getOrCreateFor<Executable>(point, getProgramPointBefore(op->getBlock())) - ->isLive()) + ->isLive()) { + LDBG() << " Block not executable, skipping operation"; return success(); + } // Get the dense lattice to update. AbstractDenseLattice *before = getLattice(point); @@ -352,30 +503,39 @@ AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) { // Get the dense state after execution of this op. const AbstractDenseLattice *after = getLatticeFor(point, getProgramPointAfter(op)); + LDBG() << " before state: " << *before; + LDBG() << " after state: " << *after; // Special cases where control flow may dictate data flow. if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) { + LDBG() << " Processing as region branch operation"; visitRegionBranchOperation(point, branch, RegionBranchPoint::parent(), before); return success(); } if (auto call = dyn_cast<CallOpInterface>(op)) { + LDBG() << " Processing as call operation"; visitCallOperation(call, *after, before); return success(); } // Invoke the operation transfer function. + LDBG() << " Invoking operation transfer function"; return visitOperationImpl(op, *after, before); } void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { + LDBG() << "visitBlock (backward): " << block; ProgramPoint *point = getProgramPointAfter(block); // If the block is not executable, bail out. if (!getOrCreateFor<Executable>(point, getProgramPointBefore(block)) - ->isLive()) + ->isLive()) { + LDBG() << " Block not executable, skipping"; return; + } AbstractDenseLattice *before = getLattice(point); + LDBG() << " Block lattice state: " << *before; // We need "exit" blocks, i.e. the blocks that may return control to the // parent operation. @@ -391,23 +551,32 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { b->getTerminator()); }; if (isExitBlock(block)) { + LDBG() << " Processing exit block"; // If this block is exiting from a callable, the successors of exiting from // a callable are the successors of all call sites. And the call sites // themselves are predecessors of the callable. auto callable = dyn_cast<CallableOpInterface>(block->getParentOp()); if (callable && callable.getCallableRegion() == block->getParent()) { + LDBG() << " Exit block of callable region"; const auto *callsites = getOrCreateFor<PredecessorState>( point, getProgramPointAfter(callable)); // If not all call sites are known, conservative mark all lattices as // having reached their pessimistic fix points. 
if (!callsites->allPredecessorsKnown() || !getSolverConfig().isInterprocedural()) { + LDBG() << " Not all callsites known or non-interprocedural, setting " + "to exit state"; return setToExitState(before); } + LDBG() << " Processing " << callsites->getKnownPredecessors().size() + << " known callsites"; for (Operation *callsite : callsites->getKnownPredecessors()) { + LDBG() << " Processing callsite: " + << OpWithFlags(callsite, OpPrintingFlags().skipRegions()); const AbstractDenseLattice *after = getLatticeFor(point, getProgramPointAfter(callsite)); + LDBG() << " Lattice after callsite: " << *after; visitCallControlFlowTransfer(cast<CallOpInterface>(callsite), CallControlFlowAction::ExitCallee, *after, before); @@ -418,22 +587,29 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { // If this block is exiting from an operation with region-based control // flow, propagate the lattice back along the control flow edge. if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) { + LDBG() << " Exit block of region branch operation"; visitRegionBranchOperation(point, branch, block->getParent(), before); return; } // Cannot reason about successors of an exit block, set the pessimistic // fixpoint. + LDBG() << " Cannot reason about successors, setting to exit state"; return setToExitState(before); } // Meet the state with the state before block's successors. + LDBG() << " Meeting state from " << block->getSuccessors().size() + << " successors"; for (Block *successor : block->getSuccessors()) { if (!getOrCreateFor<Executable>(point, getLatticeAnchor<CFGEdge>(block, successor)) - ->isLive()) + ->isLive()) { + LDBG() << " Skipping non-executable edge to " << successor; continue; + } + LDBG() << " Meeting state from successor " << successor; // Merge in the state from the successor: either the first operation, or the // block itself when empty. meet(before, *getLatticeFor(point, getProgramPointBefore(successor))); @@ -443,28 +619,39 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) { void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation( ProgramPoint *point, RegionBranchOpInterface branch, RegionBranchPoint branchPoint, AbstractDenseLattice *before) { + LDBG() << "visitRegionBranchOperation (backward): " + << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions()); + LDBG() << " branchPoint: " << (branchPoint.isParent() ? "parent" : "region"); + LDBG() << " before state: " << *before; // The successors of the operation may be either the first operation of the // entry block of each possible successor region, or the next operation when // the branch is a successor of itself. 
SmallVector<RegionSuccessor> successors; branch.getSuccessorRegions(branchPoint, successors); + LDBG() << " Processing " << successors.size() << " successor regions"; for (const RegionSuccessor &successor : successors) { const AbstractDenseLattice *after; if (successor.isParent() || successor.getSuccessor()->empty()) { + LDBG() << " Successor is parent or empty region"; after = getLatticeFor(point, getProgramPointAfter(branch)); } else { Region *successorRegion = successor.getSuccessor(); assert(!successorRegion->empty() && "unexpected empty successor region"); Block *successorBlock = &successorRegion->front(); + LDBG() << " Successor region with " + << successorRegion->getBlocks().size() << " blocks"; if (!getOrCreateFor<Executable>(point, getProgramPointBefore(successorBlock)) - ->isLive()) + ->isLive()) { + LDBG() << " Successor block not executable, skipping"; continue; + } after = getLatticeFor(point, getProgramPointBefore(successorBlock)); } + LDBG() << " After state: " << *after; visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after, before); diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp index 1f239aa..519d9c8 100644 --- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp +++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/Passes.h" +#include "llvm/Support/LogicalResult.h" namespace mlir { #define GEN_PASS_DEF_SCFTOEMITC @@ -106,7 +107,7 @@ static void assignValues(ValueRange values, ValueRange variables, emitc::AssignOp::create(rewriter, loc, var, value); } -SmallVector<Value> loadValues(const SmallVector<Value> &variables, +SmallVector<Value> loadValues(ArrayRef<Value> variables, PatternRewriter &rewriter, Location loc) { return llvm::map_to_vector<>(variables, [&](Value var) { Type type = cast<emitc::LValueType>(var.getType()).getValueType(); @@ -116,16 +117,15 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables, static LogicalResult lowerYield(Operation *op, ValueRange resultVariables, ConversionPatternRewriter &rewriter, - scf::YieldOp yield) { + scf::YieldOp yield, bool createYield = true) { Location loc = yield.getLoc(); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(yield); SmallVector<Value> yieldOperands; - if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) { + if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) return rewriter.notifyMatchFailure(op, "failed to lower yield operands"); - } assignValues(yieldOperands, resultVariables, rewriter, loc); @@ -336,11 +336,177 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite( return success(); } +// Lower scf::while to emitc::do using mutable variables to maintain loop state +// across iterations. The do-while structure ensures the condition is evaluated +// after each iteration, matching SCF while semantics. +struct WhileLowering : public OpConversionPattern<WhileOp> { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = whileOp.getLoc(); + MLIRContext *context = loc.getContext(); + + // Create an emitc::variable op for each result. These variables will be + // assigned to by emitc::assign ops within the loop body. 
+ SmallVector<Value> resultVariables; + if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter, + resultVariables))) + return rewriter.notifyMatchFailure(whileOp, + "Failed to create result variables"); + + // Create variable storage for loop-carried values to enable imperative + // updates while maintaining SSA semantics at conversion boundaries. + SmallVector<Value> loopVariables; + if (failed(createVariablesForLoopCarriedValues( + whileOp, rewriter, loopVariables, loc, context))) + return failure(); + + if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context, + rewriter, loc))) + return failure(); + + rewriter.setInsertionPointAfter(whileOp); + + // Load the final result values from result variables. + SmallVector<Value> finalResults = + loadValues(resultVariables, rewriter, loc); + rewriter.replaceOp(whileOp, finalResults); + + return success(); + } + +private: + // Initialize variables for loop-carried values to enable state updates + // across iterations without SSA argument passing. + LogicalResult createVariablesForLoopCarriedValues( + WhileOp whileOp, ConversionPatternRewriter &rewriter, + SmallVectorImpl<Value> &loopVars, Location loc, + MLIRContext *context) const { + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(whileOp); + + emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, ""); + + for (Value init : whileOp.getInits()) { + Type convertedType = getTypeConverter()->convertType(init.getType()); + if (!convertedType) + return rewriter.notifyMatchFailure(whileOp, "type conversion failed"); + + emitc::VariableOp var = rewriter.create<emitc::VariableOp>( + loc, emitc::LValueType::get(convertedType), noInit); + rewriter.create<emitc::AssignOp>(loc, var.getResult(), init); + loopVars.push_back(var); + } + + return success(); + } + + // Lower scf.while to emitc.do. + LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars, + ArrayRef<Value> resultVars, MLIRContext *context, + ConversionPatternRewriter &rewriter, + Location loc) const { + // Create a global boolean variable to store the loop condition state. + Type i1Type = IntegerType::get(context, 1); + auto globalCondition = + rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type), + emitc::OpaqueAttr::get(context, "")); + Value conditionVal = globalCondition.getResult(); + + auto loweredDo = rewriter.create<emitc::DoOp>(loc); + + // Convert region types to match the target dialect type system. + if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(), + *getTypeConverter(), nullptr)) || + failed(rewriter.convertRegionTypes(&whileOp.getAfter(), + *getTypeConverter(), nullptr))) { + return rewriter.notifyMatchFailure(whileOp, + "region types conversion failed"); + } + + // Prepare the before region (condition evaluation) for merging. + Block *beforeBlock = &whileOp.getBefore().front(); + Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion()); + rewriter.setInsertionPointToStart(bodyBlock); + + // Load current variable values to use as initial arguments for the + // condition block. + SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc); + rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues); + + Operation *condTerminator = + loweredDo.getBodyRegion().back().getTerminator(); + scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator); + rewriter.setInsertionPoint(condOp); + + // Update result variables with values from scf::condition. 
+ SmallVector<Value> conditionArgs; + for (Value arg : condOp.getArgs()) { + conditionArgs.push_back(rewriter.getRemappedValue(arg)); + } + assignValues(conditionArgs, resultVars, rewriter, loc); + + // Convert scf.condition to condition variable assignment. + Value condition = rewriter.getRemappedValue(condOp.getCondition()); + rewriter.create<emitc::AssignOp>(loc, conditionVal, condition); + + // Wrap body region in conditional to preserve scf semantics. Only create + // ifOp if after-region is non-empty. + if (whileOp.getAfterBody()->getOperations().size() > 1) { + auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false); + + // Prepare the after region (loop body) for merging. + Block *afterBlock = &whileOp.getAfter().front(); + Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion()); + + // Replacement values for after block using condition op arguments. + SmallVector<Value> afterReplacingValues; + for (Value arg : condOp.getArgs()) + afterReplacingValues.push_back(rewriter.getRemappedValue(arg)); + + rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues); + + if (failed(lowerYield(whileOp, loopVars, rewriter, + cast<scf::YieldOp>(ifBodyBlock->getTerminator())))) + return failure(); + } + + rewriter.eraseOp(condOp); + + // Create condition region that loads from the flag variable. + Region &condRegion = loweredDo.getConditionRegion(); + Block *condBlock = rewriter.createBlock(&condRegion); + rewriter.setInsertionPointToStart(condBlock); + + auto exprOp = rewriter.create<emitc::ExpressionOp>( + loc, i1Type, conditionVal, /*do_not_inline=*/false); + Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion()); + + // Set up the expression block to load the condition variable. + exprBlock->addArgument(conditionVal.getType(), loc); + rewriter.setInsertionPointToStart(exprBlock); + + // Load the condition value and yield it as the expression result. + Value cond = + rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0)); + rewriter.create<emitc::YieldOp>(loc, cond); + + // Yield the expression as the condition region result. + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create<emitc::YieldOp>(loc, exprOp); + + return success(); + } +}; + void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns, TypeConverter &typeConverter) { patterns.add<ForLowering>(typeConverter, patterns.getContext()); patterns.add<IfLowering>(typeConverter, patterns.getContext()); patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext()); + patterns.add<WhileLowering>(typeConverter, patterns.getContext()); } void SCFToEmitCPass::runOnOperation() { @@ -357,7 +523,8 @@ void SCFToEmitCPass::runOnOperation() { // Configure conversion to lower out SCF operations. 
ConversionTarget target(getContext()); - target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>(); + target + .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 57877b8..f449d90 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -214,6 +214,10 @@ static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) { return op.getCacheControl(); } +static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) { + return op.getCacheControl(); +} + static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) { return op.getCacheControl(); } @@ -222,6 +226,10 @@ static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) { return op.getCacheControl(); } +static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) { + return op.getCacheControl(); +} + static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) { if (op->hasAttr("cache_control")) { auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control"); @@ -263,6 +271,7 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) { constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> || std::is_same_v<OpType, BlockPrefetch2dOp> || std::is_same_v<OpType, LLVM::LoadOp> || + std::is_same_v<OpType, BlockLoadOp> || std::is_same_v<OpType, PrefetchOp>; const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey}; SmallVector<int32_t, decorationCacheControlArity> decorationsL1{ @@ -618,6 +627,77 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> { return success(); } }; + +template <typename OpType> +class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> { + using OpConversionPattern<OpType>::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>; + // Get OpenCL function name + // https://registry.khronos.org/OpenCL/extensions/ + // intel/cl_intel_subgroup_local_block_io.html + std::string funcName{"intel_sub_group_block_"}; + // Value or Result type can be vector or scalar + Type valOrResTy; + if constexpr (isStore) { + funcName += "write_u"; + valOrResTy = op.getVal().getType(); + } else { + funcName += "read_u"; + valOrResTy = op.getType(); + } + // Get element type of the vector/scalar + VectorType vecTy = dyn_cast<VectorType>(valOrResTy); + Type elemType = vecTy ? 
vecTy.getElementType() : valOrResTy; + funcName += getTypeMangling(elemType); + if (vecTy) + funcName += std::to_string(vecTy.getNumElements()); + SmallVector<Type, 2> argTypes{}; + // XeVM BlockLoad/StoreOp always use signless integer types + // but OpenCL builtins expect unsigned types + // use unsigned types for mangling + SmallVector<bool, 2> isUnsigned{}; + // arg0: pointer to the src/dst address + // arg1 - only if store : vector to store + // Prepare arguments + SmallVector<Value, 2> args{}; + args.push_back(op.getPtr()); + argTypes.push_back(op.getPtr().getType()); + isUnsigned.push_back(true); + Type retType; + if constexpr (isStore) { + args.push_back(op.getVal()); + argTypes.push_back(op.getVal().getType()); + isUnsigned.push_back(true); + retType = LLVM::LLVMVoidType::get(rewriter.getContext()); + } else { + retType = valOrResTy; + } + funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName + + "PU3AS" + + std::to_string(op.getPtr().getType().getAddressSpace()); + funcName += getTypeMangling(elemType, /*isUnsigned=*/true); + if constexpr (isStore) + funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true); + LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs}; + + LLVM::CallOp call = + createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args, + {}, funcAttr, op.getOperation()); + if (std::optional<ArrayAttr> optCacheControls = + getCacheControlMetadata(rewriter, op)) { + call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); + } + if constexpr (isStore) + rewriter.eraseOp(op); + else + rewriter.replaceOp(op, call->getResult(0)); + return success(); + } +}; + template <typename OpType> class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> { using OpConversionPattern<OpType>::OpConversionPattern; @@ -693,7 +773,10 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target, LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>, MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern, LLVMLoadStoreToOCLPattern<LLVM::LoadOp>, - LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext()); + LLVMLoadStoreToOCLPattern<LLVM::StoreOp>, + BlockLoadStore1DToOCLPattern<BlockLoadOp>, + BlockLoadStore1DToOCLPattern<BlockStoreOp>>( + patterns.getContext()); } void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry ®istry) { diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 5c8564b..4754f0b 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -974,10 +974,10 @@ LogicalResult emitc::YieldOp::verify() { Value result = getResult(); Operation *containingOp = getOperation()->getParentOp(); - if (result && containingOp->getNumResults() != 1) + if (!isa<DoOp>(containingOp) && result && containingOp->getNumResults() != 1) return emitOpError() << "yields a value not returned by parent"; - if (!result && containingOp->getNumResults() != 0) + if (!isa<DoOp>(containingOp) && !result && containingOp->getNumResults() != 0) return emitOpError() << "does not yield a value to be returned by parent"; return success(); @@ -1562,6 +1562,76 @@ LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) { } //===----------------------------------------------------------------------===// +// DoOp +//===----------------------------------------------------------------------===// + +void DoOp::print(OpAsmPrinter &p) { + p << ' '; + p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false); + p << " while "; + 
p.printRegion(getConditionRegion()); + p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs()); +} + +LogicalResult emitc::DoOp::verify() { + Block &condBlock = getConditionRegion().front(); + + if (condBlock.getOperations().size() != 2) + return emitOpError( + "condition region must contain exactly two operations: " + "'emitc.expression' followed by 'emitc.yield', but found ") + << condBlock.getOperations().size() << " operations"; + + Operation &first = condBlock.front(); + auto exprOp = dyn_cast<emitc::ExpressionOp>(first); + if (!exprOp) + return emitOpError("expected first op in condition region to be " + "'emitc.expression', but got ") + << first.getName(); + + if (!exprOp.getResult().getType().isInteger(1)) + return emitOpError("emitc.expression in condition region must return " + "'i1', but returns ") + << exprOp.getResult().getType(); + + Operation &last = condBlock.back(); + auto condYield = dyn_cast<emitc::YieldOp>(last); + if (!condYield) + return emitOpError("expected last op in condition region to be " + "'emitc.yield', but got ") + << last.getName(); + + if (condYield.getNumOperands() != 1) + return emitOpError("expected condition region to return 1 value, but " + "it returns ") + << condYield.getNumOperands() << " values"; + + if (condYield.getOperand(0) != exprOp.getResult()) + return emitError("'emitc.yield' must return result of " + "'emitc.expression' from this condition region"); + + Block &bodyBlock = getBodyRegion().front(); + if (bodyBlock.mightHaveTerminator()) + return emitOpError("body region must not contain terminator"); + + return success(); +} + +ParseResult DoOp::parse(OpAsmParser &parser, OperationState &result) { + Region *bodyRegion = result.addRegion(); + Region *condRegion = result.addRegion(); + + if (parser.parseRegion(*bodyRegion) || parser.parseKeyword("while") || + parser.parseRegion(*condRegion)) + return failure(); + + if (bodyRegion->empty()) + bodyRegion->emplaceBlock(); + + return parser.parseOptionalAttrDictWithKeyword(result.attributes); +} + +//===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 7863c21..0dac688 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( Attribute oneIdxAttr = rewriter.getIndexAttr(1); Location loc = packOp.getLoc(); - Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); - DenseMap<int64_t, OpFoldResult> dimAndTileMapping = - packOp.getDimAndTileMapping(); int64_t srcRank = packOp.getSourceRank(); int64_t destRank = packOp.getDestRank(); - int64_t numTiles = destRank - srcRank; + ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos(); + int64_t numberOfTiles = innerDimsPos.size(); - // 1. Extract the inner tile sizes. - // Where possible, values are replaced with constant attributes (to match the - // behaviour of `getPackOpSourceOrPaddedSource`). 
- SmallVector<OpFoldResult> tileSizes; - for (auto i : llvm::seq<unsigned>(0, srcRank)) { - if (dimAndTileMapping.count(i)) { - // Rather than taking the tile size as is, extact the actual constant - // value Attribute where possible, e.g.: - // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8] - auto [_, tileSize] = - getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter); - tileSizes.push_back(tileSize); - } - } + // 1. Get the input that is going to be packed. If the input requires padding, + // add a padding operation and return that as the input. + Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); // 2. Transpose the input to match the inner tile order: // %init = tensor.empty() // %transposed_tile = linalg.transpose ins(%source_or_padded_source), // outs(%init) // Assumptions made: - // 1. All outer dims are 1 - the corresponding transposition order doesn't + // - All outer dims are 1 - the corresponding transposition order doesn't // matter, but requires all dim indices to be present. + + // 2.1 Get the permutation for linalg.transpose SmallVector<int64_t> srcPermForTranspose; - ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos()); for (int64_t i = 0; i < srcRank; i++) { // We assume the `k` dimensions of the inner dim position, where `k` is the // rank of the inner tiling, correspond to the last `k` indices of the @@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( // rank of the source tensor. For example if we have a source tensor with // indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining // indices are [1, 2]. and the transpose will be [1, 2, 3, 0]. - if (llvm::is_contained(innerDimPos, i)) + if (llvm::is_contained(innerDimsPos, i)) continue; srcPermForTranspose.push_back(i); } - srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end()); + srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end()); + + // 2.2 Create the init tensor for linalg.transpose with the correct shape + SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles, + oneIdxAttr); + shapeForEmptyOp.append(packOp.getMixedTiles()); + + // getMixedTiles() may contain Values pointing to constant ops, not the + // constant attributes. Replace them with a true OpFoldResult. 
+ llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(), + [&](OpFoldResult ofr) { + if (auto val = llvm::dyn_cast<Value>(ofr)) + return getAsOpFoldResult(val); + return ofr; + }); LDBG() << "Pack permutation: " << packOp; LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose); + LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp); - // 2.1 Create tensor.empty (init value for TransposeOp) - SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles, - oneIdxAttr); - transShapeForEmptyOp.append(tileSizes); - - applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, - srcPermForTranspose); - Value empty = - tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp, - packOp.getSourceType().getElementType()); + Value empty = tensor::EmptyOp::create( + rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType()); - // 2.2 Create linalg.transpose + // 2.3 Create linalg.transpose auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty, srcPermForTranspose); @@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr); SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr); // Outer dims are all 1s! - SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(), - oneIdxAttr); + SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr); SmallVector<int64_t> writeShape; for (auto tileSize : packOp.getMixedTiles()) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp index ef172c1..37bdd8b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp @@ -186,11 +186,11 @@ constexpr float A_2x2_5x5[] = { /// Structure to keep information of constant transform matrices. struct TransformMatrix { - TransformMatrix(const float *table, int64_t rows, int64_t cols, + TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols, int64_t scalarFactor = 1) : table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {} - const float *table; + ArrayRef<float> table; int64_t rows; int64_t cols; int64_t scalarFactor; @@ -199,14 +199,20 @@ struct TransformMatrix { /// Utility function to convert constant array to arith.constant Value. Value create2DTransformMatrix(OpBuilder &builder, Location loc, TransformMatrix transform, Type type) { - ArrayRef<float> constVec(transform.table, transform.rows * transform.cols); - + assert(transform.table.size() == + static_cast<size_t>(transform.rows * transform.cols)); + assert(type.isFloat() && "Only floats are supported by Winograd"); + ArrayRef<float> constVec(transform.table.data(), + transform.rows * transform.cols); + auto constAttrVec = + llvm::map_to_vector<>(constVec, [&](const float v) -> Attribute { + return builder.getFloatAttr(type, v); + }); + SmallVector<int64_t, 2> shape{transform.rows, transform.cols}; return arith::ConstantOp::create( builder, loc, - DenseFPElementsAttr::get( - RankedTensorType::get( - SmallVector<int64_t>{transform.rows, transform.cols}, type), - constVec)); + DenseFPElementsAttr::get(RankedTensorType::get(shape, type), + constAttrVec)); } /// Extract height x width data from 4D tensors. 
@@ -551,8 +557,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); - Value BT = - create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type()); + Value BT = create2DTransformMatrix(builder, loc, BTMatrix, elementType); // Multiply BT x d. auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, ValueRange{BT, matmulRetValue}, @@ -574,8 +579,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input, .getResult(); auto init = linalg::FillOp::create(builder, loc, zero, empty).getResult(0); - Value B = - create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type()); + Value B = create2DTransformMatrix(builder, loc, BMatrix, elementType); // Multiply v = (BT x d) x B. auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType, ValueRange{matmulRetValue, B}, diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp index 697cb35..237aab4 100644 --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -27,7 +27,7 @@ using namespace mlir::nvgpu; #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc" -void nvgpu::NVGPUDialect::initialize() { +void NVGPUDialect::initialize() { addTypes< #define GET_TYPEDEF_LIST #include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc" @@ -42,7 +42,7 @@ void nvgpu::NVGPUDialect::initialize() { >(); } -bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) { +bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) { if (!memorySpace) return false; if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace)) @@ -52,7 +52,7 @@ bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) { return false; } -bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) { +bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) { Attribute memorySpace = type.getMemorySpace(); return isSharedMemoryAddressSpace(memorySpace); } @@ -140,7 +140,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op, TypedValue<VectorType> matrixC, const std::array<int64_t, 3> &mmaShape, bool tf32Enabled, bool sparse = false) { - // The verification for mma.sync covering various shapes and data types is // based on the fundamental tensor core shape. 
@@ -292,7 +291,6 @@ LogicalResult MmaSparseSyncOp::verify() { // NVGPU_LdMatrixOp //===----------------------------------------------------------------------===// LogicalResult LdMatrixOp::verify() { - // ldmatrix reads data from source in shared memory auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType()); @@ -345,7 +343,7 @@ LogicalResult LdMatrixOp::verify() { // NVGPU_TmaAsyncLoadOp //===----------------------------------------------------------------------===// -unsigned getSwizzleBytes(TensorMapSwizzleKind kind) { +static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) { switch (kind) { case TensorMapSwizzleKind::SWIZZLE_32B: return 32; @@ -359,7 +357,7 @@ unsigned getSwizzleBytes(TensorMapSwizzleKind kind) { } std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref( - Operation *op, nvgpu::TensorMapDescriptorType descType, + Operation *op, TensorMapDescriptorType descType, std::optional<MemRefType> memrefType = std::nullopt) { MemRefType descMemref = descType.getTensor(); // Limitation @@ -655,8 +653,7 @@ LogicalResult WarpgroupMmaStoreOp::verify() { //===----------------------------------------------------------------------===// LogicalResult WarpgroupMmaInitAccumulatorOp::verify() { - - nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType(); + WarpgroupAccumulatorType accType = getMatrixC().getType(); int64_t sizeM = accType.getFragmented().getDimSize(0); int64_t sizeN = accType.getFragmented().getDimSize(1); Type elemType = accType.getFragmented().getElementType(); diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp index 46e82bd..2a857ed 100644 --- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp +++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp @@ -43,7 +43,7 @@ using namespace mlir::transform; // Apply...ConversionPatternsOp //===----------------------------------------------------------------------===// -void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( +void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( TypeConverter &typeConverter, RewritePatternSet &patterns) { auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter); /// device-side async tokens cannot be materialized in nvvm. 
We just @@ -62,62 +62,58 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns( llvm_unreachable("unknown address space enum value"); return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic); }); - llvmTypeConverter.addConversion( - [&](nvgpu::DeviceAsyncTokenType type) -> Type { - return llvmTypeConverter.convertType( - IntegerType::get(type.getContext(), 32)); - }); - llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type { + llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type { + return llvmTypeConverter.convertType( + IntegerType::get(type.getContext(), 32)); + }); + llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type { return llvmTypeConverter.convertType( IntegerType::get(type.getContext(), 64)); }); - llvmTypeConverter.addConversion( - [&](nvgpu::WarpgroupAccumulatorType type) -> Type { - Type elemType = type.getFragmented().getElementType(); - int64_t sizeM = type.getFragmented().getDimSize(0); - int64_t sizeN = type.getFragmented().getDimSize(1); - - unsigned numMembers; - if (elemType.isF32() || elemType.isInteger(32)) - numMembers = sizeN / 2; - else if (elemType.isF16()) - numMembers = sizeN / 4; - else - llvm_unreachable("unsupported type for warpgroup accumulator"); - - SmallVector<Type> innerStructBody; - for (unsigned i = 0; i < numMembers; i++) - innerStructBody.push_back(elemType); - auto innerStructType = LLVM::LLVMStructType::getLiteral( - type.getContext(), innerStructBody); - - SmallVector<Type> structBody; - for (int i = 0; i < sizeM; i += kWgmmaSizeM) - structBody.push_back(innerStructType); - - auto convertedType = - LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); - return llvmTypeConverter.convertType(convertedType); - }); - llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type { + llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type { + Type elemType = type.getFragmented().getElementType(); + int64_t sizeM = type.getFragmented().getDimSize(0); + int64_t sizeN = type.getFragmented().getDimSize(1); + + unsigned numMembers; + if (elemType.isF32() || elemType.isInteger(32)) + numMembers = sizeN / 2; + else if (elemType.isF16()) + numMembers = sizeN / 4; + else + llvm_unreachable("unsupported type for warpgroup accumulator"); + + SmallVector<Type> innerStructBody; + for (unsigned i = 0; i < numMembers; i++) + innerStructBody.push_back(elemType); + auto innerStructType = + LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody); + + SmallVector<Type> structBody; + for (int i = 0; i < sizeM; i += kWgmmaSizeM) + structBody.push_back(innerStructType); + + auto convertedType = + LLVM::LLVMStructType::getLiteral(type.getContext(), structBody); + return llvmTypeConverter.convertType(convertedType); + }); + llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type { return llvmTypeConverter.convertType( getMBarrierMemrefType(type.getContext(), type)); }); llvmTypeConverter.addConversion( - [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type { + [&](WarpgroupMatrixDescriptorType type) -> Type { return llvmTypeConverter.convertType( IntegerType::get(type.getContext(), 64)); }); - llvmTypeConverter.addConversion( - [&](nvgpu::TensorMapDescriptorType type) -> Type { - return LLVM::LLVMPointerType::get(type.getContext()); - }); + llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type { + return LLVM::LLVMPointerType::get(type.getContext()); + }); 
populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns); } -LogicalResult -transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter( - transform::TypeConverterBuilderOpInterface builder) { +LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter( + TypeConverterBuilderOpInterface builder) { if (builder.getTypeConverterType() != "LLVMTypeConverter") return emitOpError("expected LLVMTypeConverter"); return success(); @@ -127,17 +123,18 @@ transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter( // CreateAsyncGroupsOp //===---------------------------------------------------------------------===// -void transform::CreateAsyncGroupsOp::getEffects( +void CreateAsyncGroupsOp::getEffects( SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { - transform::consumesHandle(getTargetMutable(), effects); - transform::producesHandle(getOperation()->getOpResults(), effects); - transform::modifiesPayload(effects); + consumesHandle(getTargetMutable(), effects); + producesHandle(getOperation()->getOpResults(), effects); + modifiesPayload(effects); } -DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne( - TransformRewriter &rewriter, Operation *target, - ApplyToEachResultList &results, TransformState &state) { - nvgpu::createAsyncGroups(rewriter, target, getBypassL1()); +DiagnosedSilenceableFailure +CreateAsyncGroupsOp::applyToOne(TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, + TransformState &state) { + createAsyncGroups(rewriter, target, getBypassL1()); results.push_back(target); return DiagnosedSilenceableFailure::success(); } @@ -218,7 +215,7 @@ collectStage0PipeliningOps(scf::ForOp forOp, continue; } - if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) { + if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) { ops.insert(&op); ops.insert(std::make_move_iterator(barriers.begin()), std::make_move_iterator(barriers.end())); @@ -246,7 +243,7 @@ setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op, unsigned iteration, unsigned depth) { // Based on the order of copies within the loop we need to set the number // of copies in flight, unless it is already set. - auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op); + auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op); if (!waitOp || waitOp.getNumGroups()) return; @@ -312,13 +309,12 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, // original number of iterations, in particular side-effect free operations // and barriers, even if they cannot be predicated. if (isMemoryEffectFree(op) || - isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp, - nvgpu::DeviceAsyncWaitOp>(op)) { + isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) { return op; } // Otherwise, only async copies can currently be predicated. 
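For context on the predication path handled next: rather than guarding the copy with control flow, the rewrite selects the number of source elements down to zero when the predicate is false, so the copy degenerates to a zero-fill. A minimal IR sketch (names and shapes hypothetical, syntax approximate):

    %c0  = arith.constant 0 : index
    %n   = arith.select %pred, %orig_src_elements, %c0 : index
    %tok = nvgpu.device_async_copy %src[...], %dst[...], 4, %n
             : memref<...> to memref<..., 3>

With %n equal to zero the destination elements are simply zero-filled, which keeps the peeled iterations well-defined.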
- auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op); + auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op); if (!asyncCopyOp) return nullptr; @@ -335,8 +331,8 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter, Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0); auto srcElements = arith::SelectOp::create(rewriter, loc, predicate, originalSrcElement, c0Index); - auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create( - rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()), + auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create( + rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()), asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(), asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements, UnitAttr()); @@ -805,17 +801,16 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) { rhsIndexFn, rhsShape); Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef, resIndexFn, resShape); - res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, - info.tf32Enabled); + res = + MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled); buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn, resShape); return res.getDefiningOp(); } -DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne( - transform::TransformRewriter &rewriter, LinalgOp linalgOp, - transform::ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure RewriteMatmulAsMmaSyncOp::applyToOne( + TransformRewriter &rewriter, LinalgOp linalgOp, + ApplyToEachResultList &results, TransformState &state) { bool fail = true; // TODO: more robust detection of matmulOp, with transposes etc. if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) { @@ -854,43 +849,42 @@ struct HopperBuilder { HopperBuilder(RewriterBase &rewriter, Location loc) : rewriter(rewriter), loc(loc) {} - TypedValue<nvgpu::MBarrierGroupType> + TypedValue<MBarrierGroupType> buildAndInitBarrierInSharedMemory(OpFoldResult numThreads); /// Create tma descriptor op to initiate transfer from global to shared /// memory. This must be done before the launch op, on the host. - TypedValue<nvgpu::TensorMapDescriptorType> + TypedValue<TensorMapDescriptorType> buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, gpu::LaunchOp launchOp); /// Build a tma load from global memory to shared memory using `barrier` to /// synchronize. Return the number of bytes that will be transferred. - OpFoldResult - buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc, - TypedValue<MemRefType> sharedMemref, - TypedValue<nvgpu::MBarrierGroupType> barrier, - SmallVectorImpl<Operation *> &loadOps); - void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier, + OpFoldResult buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc, + TypedValue<MemRefType> sharedMemref, + TypedValue<MBarrierGroupType> barrier, + SmallVectorImpl<Operation *> &loadOps); + void buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier, ArrayRef<OpFoldResult> sizes); /// If threadIdx.x == 0 does TMA request + wait, else just wait. /// Return the operation that performs the transfer on thread0. // TODO: In the future, don't hardcode to thread 0 but elect a leader. 
SmallVector<Operation *> buildPredicateLoadsOnThread0( - ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, + ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors, ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, - TypedValue<nvgpu::MBarrierGroupType> barrier); + TypedValue<MBarrierGroupType> barrier); - void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier); + void buildTryWaitParity(TypedValue<MBarrierGroupType> barrier); RewriterBase &rewriter; Location loc; }; SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0( - ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors, + ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors, ArrayRef<TypedValue<MemRefType>> sharedMemBuffers, - TypedValue<nvgpu::MBarrierGroupType> barrier) { + TypedValue<MBarrierGroupType> barrier) { SmallVector<Operation *> loadOps; Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x); @@ -931,22 +925,22 @@ static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) { // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace)); } -TypedValue<nvgpu::MBarrierGroupType> +TypedValue<MBarrierGroupType> HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) { auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); - Value barrier = nvgpu::MBarrierCreateOp::create( + Value barrier = MBarrierCreateOp::create( rewriter, loc, - nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace)); + MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace)); Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); nvgpu::MBarrierInitOp::create( rewriter, loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero, Value()); gpu::BarrierOp::create(rewriter, loc); - return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier); + return cast<TypedValue<MBarrierGroupType>>(barrier); } -TypedValue<nvgpu::TensorMapDescriptorType> +TypedValue<TensorMapDescriptorType> HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, gpu::LaunchOp launchOp) { OpBuilder::InsertionGuard guard(rewriter); @@ -962,29 +956,29 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref, getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes); auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter); - Value desc = nvgpu::TmaCreateDescriptorOp::create( + Value desc = TmaCreateDescriptorOp::create( rewriter, loc, - nvgpu::TensorMapDescriptorType::get( - rewriter.getContext(), - MemRefType::Builder(memref.getType()) - .setMemorySpace(sharedMemorySpace), - TensorMapSwizzleKind::SWIZZLE_NONE, - TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO, - TensorMapInterleaveKind::INTERLEAVE_NONE), + TensorMapDescriptorType::get(rewriter.getContext(), + MemRefType::Builder(memref.getType()) + .setMemorySpace(sharedMemorySpace), + TensorMapSwizzleKind::SWIZZLE_NONE, + TensorMapL2PromoKind::L2PROMO_NONE, + TensorMapOOBKind::OOB_ZERO, + TensorMapInterleaveKind::INTERLEAVE_NONE), unrankedMemRef, sizes); - return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc); + return cast<TypedValue<TensorMapDescriptorType>>(desc); } -OpFoldResult HopperBuilder::buildTmaAsyncLoad( - TypedValue<nvgpu::TensorMapDescriptorType> globalDesc, - TypedValue<MemRefType> sharedMemref, - TypedValue<nvgpu::MBarrierGroupType> barrier, - SmallVectorImpl<Operation *> &loadOps) { +OpFoldResult 
+HopperBuilder::buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc, + TypedValue<MemRefType> sharedMemref, + TypedValue<MBarrierGroupType> barrier, + SmallVectorImpl<Operation *> &loadOps) { MLIRContext *ctx = rewriter.getContext(); Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); - Operation *loadOp = nvgpu::TmaAsyncLoadOp::create( - rewriter, loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, - zero, Value(), Value()); + Operation *loadOp = + TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc, + ValueRange{zero, zero}, zero, Value(), Value()); loadOps.push_back(loadOp); auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref); SmallVector<AffineExpr> symbols(mixedSizes.size()); @@ -997,9 +991,8 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad( return res; } -void HopperBuilder::buildBarrierArriveTx( - TypedValue<nvgpu::MBarrierGroupType> barrier, - ArrayRef<OpFoldResult> mixedSizes) { +void HopperBuilder::buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier, + ArrayRef<OpFoldResult> mixedSizes) { assert(!mixedSizes.empty() && "expecte non-empty sizes"); MLIRContext *ctx = rewriter.getContext(); SmallVector<AffineExpr> symbols(mixedSizes.size()); @@ -1013,8 +1006,7 @@ void HopperBuilder::buildBarrierArriveTx( Value()); } -void HopperBuilder::buildTryWaitParity( - TypedValue<nvgpu::MBarrierGroupType> barrier) { +void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) { Type i1 = rewriter.getI1Type(); Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0); // 10M is an arbitrary, not too small or too big number to specify the number @@ -1058,11 +1050,11 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) { ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(), launchOp.getBlockSizeZ()}); - TypedValue<nvgpu::MBarrierGroupType> barrier = + TypedValue<MBarrierGroupType> barrier = buildAndInitBarrierInSharedMemory(numThreads); SmallVector<TypedValue<MemRefType>> shmems; - SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs; + SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs; for (Operation *op : copyOps) { auto copyOp = cast<linalg::CopyOp>(op); auto inMemRef = @@ -1071,7 +1063,7 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) { "expected in to be a 2D memref"); // 2. Build global memory descriptor. 
- TypedValue<nvgpu::TensorMapDescriptorType> globalDesc = + TypedValue<TensorMapDescriptorType> globalDesc = buildGlobalMemRefDescriptor(inMemRef, launchOp); globalDescs.push_back(globalDesc); @@ -1098,9 +1090,8 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) { } DiagnosedSilenceableFailure -transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter, - transform::TransformResults &results, - transform::TransformState &state) { +RewriteCopyAsTmaOp::apply(TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { auto payloadOps = state.getPayloadOps(getTarget()); gpu::LaunchOp commonLaunchOp; Operation *firstOp, *failingOp; @@ -1137,15 +1128,14 @@ transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter, namespace { class NVGPUTransformDialectExtension - : public transform::TransformDialectExtension< - NVGPUTransformDialectExtension> { + : public TransformDialectExtension<NVGPUTransformDialectExtension> { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension) NVGPUTransformDialectExtension() { declareGeneratedDialect<arith::ArithDialect>(); declareGeneratedDialect<affine::AffineDialect>(); - declareGeneratedDialect<nvgpu::NVGPUDialect>(); + declareGeneratedDialect<NVGPUDialect>(); declareGeneratedDialect<NVVM::NVVMDialect>(); declareGeneratedDialect<vector::VectorDialect>(); registerTransformOps< diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp index 5b89c87..7f626a6 100644 --- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp +++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp @@ -64,6 +64,5 @@ private: void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns( RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) { - patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision); } diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp index 809d634..9e5ea93 100644 --- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp +++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp @@ -168,8 +168,7 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType) { Type elementType = fragmentType.vectorType.getElementType(); ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape(); - FailureOr<nvgpu::FragmentElementInfo> regInfo = - getMmaSyncRegisterType(fragmentType); + FailureOr<FragmentElementInfo> regInfo = getMmaSyncRegisterType(fragmentType); if (failed(regInfo)) return failure(); @@ -199,8 +198,8 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, (logicalValueIdDim % elementsPerRegister)}); } -FailureOr<nvgpu::LdMatrixParams> -nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) { +FailureOr<LdMatrixParams> nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, + bool transpose) { LdMatrixParams params; Type elType = type.vectorType.getElementType(); params.fragmentType = type.vectorType; diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp index 6598ac1..6564a4e 100644 --- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp +++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp @@ -7,6 +7,7 @@ // ============================================================================= #include "mlir/Dialect/OpenACC/OpenACC.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 
#include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -44,6 +45,7 @@ struct MemRefPointerLikeModel Type getElementType(Type pointer) const { return cast<MemRefType>(pointer).getElementType(); } + mlir::acc::VariableTypeCategory getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr, Type varType) const { @@ -70,6 +72,115 @@ struct MemRefPointerLikeModel assert(memrefTy.getRank() > 0 && "rank expected to be positive"); return mlir::acc::VariableTypeCategory::array; } + + mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc, + StringRef varName, Type varType, + Value originalVar) const { + auto memrefTy = cast<MemRefType>(pointer); + + // Check if this is a static memref (all dimensions are known) - if yes + // then we can generate an alloca operation. + if (memrefTy.hasStaticShape()) + return memref::AllocaOp::create(builder, loc, memrefTy).getResult(); + + // For dynamic memrefs, extract sizes from the original variable if + // provided. Otherwise they cannot be handled. + if (originalVar && originalVar.getType() == memrefTy && + memrefTy.hasRank()) { + SmallVector<Value> dynamicSizes; + for (int64_t i = 0; i < memrefTy.getRank(); ++i) { + if (memrefTy.isDynamicDim(i)) { + // Extract the size of dimension i from the original variable + auto indexValue = arith::ConstantIndexOp::create(builder, loc, i); + auto dimSize = + memref::DimOp::create(builder, loc, originalVar, indexValue); + dynamicSizes.push_back(dimSize); + } + // Note: We only add dynamic sizes to the dynamicSizes array + // Static dimensions are handled automatically by AllocOp + } + return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes) + .getResult(); + } + + // TODO: Unranked not yet supported. + return {}; + } + + bool genFree(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> varPtr, Type varType) const { + if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) { + // Walk through casts to find the original allocation + Value currentValue = memrefValue; + Operation *originalAlloc = nullptr; + + // Follow the chain of operations to find the original allocation + // even if a casted result is provided. + while (currentValue) { + if (auto *definingOp = currentValue.getDefiningOp()) { + // Check if this is an allocation operation + if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) { + originalAlloc = definingOp; + break; + } + + // Check if this is a cast operation we can look through + if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) { + currentValue = castOp.getSource(); + continue; + } + + // Check for other cast-like operations + if (auto reinterpretCastOp = + dyn_cast<memref::ReinterpretCastOp>(definingOp)) { + currentValue = reinterpretCastOp.getSource(); + continue; + } + + // If we can't look through this operation, stop + break; + } + // This is a block argument or similar - can't trace further. 
+ break; + } + + if (originalAlloc) { + if (isa<memref::AllocaOp>(originalAlloc)) { + // This is an alloca - no dealloc needed, but return true (success) + return true; + } + if (isa<memref::AllocOp>(originalAlloc)) { + // This is an alloc - generate dealloc + memref::DeallocOp::create(builder, loc, memrefValue); + return true; + } + } + } + + return false; + } + + bool genCopy(Type pointer, OpBuilder &builder, Location loc, + TypedValue<PointerLikeType> destination, + TypedValue<PointerLikeType> source, Type varType) const { + // Generate a copy operation between two memrefs + auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination); + auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source); + + // As per memref documentation, source and destination must have same + // element type and shape in order to be compatible. We do not want to fail + // with an IR verification error - thus check that before generating the + // copy operation. + if (destMemref && srcMemref && + destMemref.getType().getElementType() == + srcMemref.getType().getElementType() && + destMemref.getType().getShape() == srcMemref.getType().getShape()) { + memref::CopyOp::create(builder, loc, srcMemref, destMemref); + return true; + } + + return false; + } }; struct LLVMPointerPointerLikeModel diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index b0132e8..14e235f 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -47,6 +47,7 @@ #include <cassert> #include <cstdint> +#include <numeric> #include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc" // Pull in all enum type and utility function definitions. @@ -2412,9 +2413,38 @@ foldToElementsFromElements(ToElementsOp toElementsOp, return success(); } +/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only. +/// +/// Example: +/// %b = vector.broadcast %x : i32 to vector<3xf32> +/// %e:3 = vector.to_elements %b : vector<3xf32> +/// user_op %e#0, %e#1, %e#2 +/// becomes: +/// user_op %x, %x, %x +/// +/// The vector source case is handled by a canonicalization pattern. +static LogicalResult +foldToElementsOfBroadcast(ToElementsOp toElementsOp, + SmallVectorImpl<OpFoldResult> &results) { + auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>(); + if (!bcastOp) + return failure(); + // Vectors are handled in the ToElementsOfBroadcast RewritePattern. + if (isa<VectorType>(bcastOp.getSource().getType())) + return failure(); + + auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType()); + + Value scalar = bcastOp.getSource(); + results.assign(resultVecType.getNumElements(), scalar); + return success(); +} + LogicalResult ToElementsOp::fold(FoldAdaptor adaptor, SmallVectorImpl<OpFoldResult> &results) { - return foldToElementsFromElements(*this, results); + if (succeeded(foldToElementsFromElements(*this, results))) + return success(); + return foldToElementsOfBroadcast(*this, results); } LogicalResult @@ -2427,6 +2457,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc, return success(); } +/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a +/// vector. +/// - Build `vector.to_elements %v` and remap each destination element to the +/// corresponding source element using broadcast rules (match or 1 → +/// replicate). 
+/// +/// Example: +/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32> +/// %e:6 = vector.to_elements %v : vector<3x2xf32> +/// becomes: +/// %src_elems:2 = vector.to_elements %src : vector<2xf32> +/// // uses: %src_elems#0, %src_elems#1, %src_elems#0, +/// // %src_elems#1, %src_elems#0, %src_elems#1 +struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> { + using Base::Base; + + LogicalResult matchAndRewrite(ToElementsOp toElementsOp, + PatternRewriter &rewriter) const override { + auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>(); + if (!bcastOp) + return failure(); + + // Only handle broadcasts from a vector source here. + auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType()); + if (!srcType) + return failure(); + + auto dstType = cast<VectorType>(toElementsOp.getSource().getType()); + + ArrayRef<int64_t> dstShape = dstType.getShape(); + ArrayRef<int64_t> srcShape = srcType.getShape(); + + int64_t dstRank = dstShape.size(); + int64_t srcRank = srcShape.size(); + + // Create elements for the broadcast source vector. + auto srcElems = vector::ToElementsOp::create( + rewriter, toElementsOp.getLoc(), bcastOp.getSource()); + + int64_t dstCount = std::accumulate(dstShape.begin(), dstShape.end(), 1, + std::multiplies<int64_t>()); + + SmallVector<Value> replacements; + replacements.reserve(dstCount); + + // For each element of the destination, determine which element of the + // source should be used. We walk all destination positions using a single + // counter, decode it into per-dimension indices, then build the matching + // source position: use the same index where sizes match, and use 0 where + // the source size is 1 (replication). This mapping is needed so we can + // replace each result of to_elements with the corresponding element from + // the broadcast source. + // Inner-dimension stretch example: + // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32> + // %e:12 = vector.to_elements %v : vector<2x3x2xf32> + // becomes: + // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32> + // // uses: %src_elems#0, %src_elems#1, %src_elems#0, + // // %src_elems#1, %src_elems#0, %src_elems#1, + // // %src_elems#2, %src_elems#3, %src_elems#2, + // // %src_elems#3, %src_elems#2, %src_elems#3 + + // Row-major strides for the destination shape. + SmallVector<int64_t> dstStrides = computeStrides(dstShape); + // Row-major strides for the source shape. + SmallVector<int64_t> srcStrides = computeStrides(srcShape); + SmallVector<int64_t> dstIdx(dstRank); + SmallVector<int64_t> srcIdx(srcRank); + for (int64_t lin = 0; lin < dstCount; ++lin) { + // Convert linear destination index to per-dimension indices. + dstIdx = delinearize(lin, dstStrides); + for (int64_t k = 0; k < srcRank; ++k) + srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k]; + // Convert per-dimension source indices back to a linear index. 
+ int64_t srcLin = linearize(srcIdx, srcStrides); + replacements.push_back(srcElems.getResult(srcLin)); + } + + rewriter.replaceOp(toElementsOp, replacements); + return success(); + } +}; + +void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<ToElementsOfBroadcast>(context); +} + //===----------------------------------------------------------------------===// // FromElementsOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index a5bd80e..5fe5f41 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -201,6 +201,8 @@ struct CppEmitter { /// Return the existing or a new label of a Block. StringRef getOrCreateName(Block &block); + LogicalResult emitInlinedExpression(Value value); + /// Whether to map an mlir integer to a unsigned integer in C++. bool shouldMapToUnsigned(IntegerType::SignednessSemantics val); @@ -557,6 +559,30 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); } +static LogicalResult printOperation(CppEmitter &emitter, emitc::DoOp doOp) { + raw_indented_ostream &os = emitter.ostream(); + + os << "do {\n"; + os.indent(); + + Block &bodyBlock = doOp.getBodyRegion().front(); + for (Operation &op : bodyBlock) { + if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true))) + return failure(); + } + + os.unindent() << "} while ("; + + Block &condBlock = doOp.getConditionRegion().front(); + auto condYield = cast<emitc::YieldOp>(condBlock.back()); + if (failed(emitter.emitExpression( + cast<emitc::ExpressionOp>(condYield.getOperand(0).getDefiningOp())))) + return failure(); + + os << ");"; + return success(); +} + static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) { Operation *operation = cmpOp.getOperation(); @@ -1711,13 +1737,14 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp, emitc::CallOpaqueOp, emitc::CastOp, emitc::ClassOp, emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp, - emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp, - emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp, - emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp, - emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp, - emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp, - emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp, - emitc::VariableOp, emitc::VerbatimOp>( + emitc::DeclareFuncOp, emitc::DivOp, emitc::DoOp, + emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp, + emitc::ForOp, emitc::FuncOp, emitc::GlobalOp, emitc::IfOp, + emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp, + emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp, + emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp, + emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp, + emitc::VerbatimOp>( [&](auto op) { return printOperation(*this, op); }) // Func ops. @@ -1765,9 +1792,9 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { // Never emit a semicolon for some operations, especially if endening with // `}`. 
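The emitc.do printer added above emits a plain C++ do/while statement: the body block is printed between braces and the value yielded from the condition region (an emitc.expression) is emitted inline as the loop condition. A minimal sketch of the generated C++, assuming a body that increments v and a condition expression computing v < 10:

    do {
      v = v + 1;
    } while (v < 10);

This is also why emitc::DoOp joins the list of operations that suppress the generic trailing semicolon below: the printer already terminates the statement with ");".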
trailingSemicolon &= - !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::FileOp, emitc::ForOp, - emitc::IfOp, emitc::IncludeOp, emitc::SwitchOp, emitc::VerbatimOp>( - op); + !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::DoOp, emitc::FileOp, + emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::SwitchOp, + emitc::VerbatimOp>(op); os << (trailingSemicolon ? ";\n" : "\n"); diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp index 8b03265..4bbcd8e 100644 --- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp +++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp @@ -59,7 +59,8 @@ DICompileUnitAttr DebugImporter::translateImpl(llvm::DICompileUnit *node) { std::underlying_type_t<llvm::DICompileUnit::DebugNameTableKind>>( node->getNameTableKind())); return DICompileUnitAttr::get( - context, getOrCreateDistinctID(node), node->getSourceLanguage(), + context, getOrCreateDistinctID(node), + node->getSourceLanguage().getUnversionedName(), translate(node->getFile()), getStringAttrOrNull(node->getRawProducer()), node->isOptimized(), emissionKind.value(), nameTableKind.value(), getStringAttrOrNull(node->getRawSplitDebugFilename())); diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp index 63c71cd..1e226c0 100644 --- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp +++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp @@ -51,9 +51,11 @@ public: stream << "ERROR: Runtime op verification failed\n"; if (vLevel == 1) { op->print(stream, state); - stream << "\n"; + stream << "\n^ " << msg; + } else { + stream << "^ " << msg; } - stream << "^\nLocation: "; + stream << "\nLocation: "; op->getLoc().print(stream); return buffer; } |
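With the GenerateRuntimeVerification change above, the verifier message now follows the printed operation instead of a bare caret line. The generated buffer reads roughly as follows, with angle-bracket placeholders standing in for the actual contents:

    ERROR: Runtime op verification failed
    <offending op, printed only when vLevel == 1>
    ^ <verifier message>
    Location: <source location>

At other verbosity levels the op line is omitted and the caret-prefixed message directly follows the header.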