Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp             |   4
-rw-r--r--  mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp                | 233
-rw-r--r--  mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp               | 177
-rw-r--r--  mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp               |  85
-rw-r--r--  mlir/lib/Dialect/EmitC/IR/EmitC.cpp                         |  74
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp           |  64
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp       |  28
-rw-r--r--  mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp                  |  15
-rw-r--r--  mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp   | 218
-rw-r--r--  mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp  |   1
-rw-r--r--  mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp                   |   7
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp                     | 111
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp                    | 120
-rw-r--r--  mlir/lib/Target/Cpp/TranslateToCpp.cpp                      |  47
-rw-r--r--  mlir/lib/Target/LLVMIR/DebugImporter.cpp                    |   3
-rw-r--r--  mlir/lib/Transforms/GenerateRuntimeVerification.cpp         |   6
16 files changed, 971 insertions, 222 deletions
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index de1ed39..377f7eb 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -522,14 +522,14 @@ void DeadCodeAnalysis::visitRegionBranchEdges(
// Mark the entry block as executable.
auto *state = getOrCreate<Executable>(point);
propagateIfChanged(state, state->setToLive());
- LDBG() << "Marked region successor live: " << point;
+ LDBG() << "Marked region successor live: " << *point;
// Add the parent op as a predecessor.
auto *predecessors = getOrCreate<PredecessorState>(point);
propagateIfChanged(
predecessors,
predecessors->join(predecessorOp, successor.getSuccessorInputs()));
- LDBG() << "Added region branch as predecessor for successor: " << point;
+ LDBG() << "Added region branch as predecessor for successor: " << *point;
}
}
diff --git a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
index b51465b..daa3db5 100644
--- a/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DenseAnalysis.cpp
@@ -17,44 +17,74 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>
using namespace mlir;
using namespace mlir::dataflow;
+#define DEBUG_TYPE "dense-analysis"
+
//===----------------------------------------------------------------------===//
// AbstractDenseForwardDataFlowAnalysis
//===----------------------------------------------------------------------===//
void AbstractDenseForwardDataFlowAnalysis::initializeEquivalentLatticeAnchor(
Operation *top) {
+ LDBG() << "initializeEquivalentLatticeAnchor: "
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
top->walk([&](Operation *op) {
- if (isa<RegionBranchOpInterface, CallOpInterface>(op))
+ if (isa<RegionBranchOpInterface, CallOpInterface>(op)) {
+ LDBG() << " Skipping "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " (region branch or call)";
return;
+ }
+ LDBG() << " Building equivalent lattice anchor for "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
buildOperationEquivalentLatticeAnchor(op);
});
}
LogicalResult AbstractDenseForwardDataFlowAnalysis::initialize(Operation *top) {
+ LDBG() << "initialize (forward): "
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
// Visit every operation and block.
- if (failed(processOperation(top)))
+ if (failed(processOperation(top))) {
+ LDBG() << " Failed to process top-level operation";
return failure();
+ }
for (Region &region : top->getRegions()) {
+ LDBG() << " Processing region with " << region.getBlocks().size()
+ << " blocks";
for (Block &block : region) {
+ LDBG() << " Processing block with " << block.getOperations().size()
+ << " operations";
visitBlock(&block);
- for (Operation &op : block)
- if (failed(initialize(&op)))
+ for (Operation &op : block) {
+ LDBG() << " Initializing operation: "
+ << OpWithFlags(&op, OpPrintingFlags().skipRegions());
+ if (failed(initialize(&op))) {
+ LDBG() << " Failed to initialize operation";
return failure();
+ }
+ }
}
}
+ LDBG() << " Forward initialization completed successfully";
return success();
}
LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
- if (!point->isBlockStart())
+ LDBG() << "visit (forward): " << *point;
+ if (!point->isBlockStart()) {
+ LDBG() << " Processing operation: "
+ << OpWithFlags(point->getPrevOp(), OpPrintingFlags().skipRegions());
return processOperation(point->getPrevOp());
+ }
+ LDBG() << " Visiting block: " << point->getBlock();
visitBlock(point->getBlock());
return success();
}
@@ -62,6 +92,11 @@ LogicalResult AbstractDenseForwardDataFlowAnalysis::visit(ProgramPoint *point) {
void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
CallOpInterface call, const AbstractDenseLattice &before,
AbstractDenseLattice *after) {
+ LDBG() << "visitCallOperation (forward): "
+ << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
+ LDBG() << " before state: " << before;
+ LDBG() << " after state: " << *after;
+
// Allow for customizing the behavior of calls to external symbols, including
// when the analysis is explicitly marked as non-interprocedural.
auto isExternalCallable = [&]() {
@@ -70,6 +105,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
return callable && !callable.getCallableRegion();
};
if (!getSolverConfig().isInterprocedural() || isExternalCallable()) {
+ LDBG() << " Handling as external callee (non-interprocedural or external)";
return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, before, after);
}
@@ -78,10 +114,16 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
getProgramPointAfter(call.getOperation()), getProgramPointAfter(call));
// Otherwise, if not all return sites are known, then conservatively assume we
// can't reason about the data-flow.
- if (!predecessors->allPredecessorsKnown())
+ if (!predecessors->allPredecessorsKnown()) {
+ LDBG() << " Not all predecessors known, setting to entry state";
return setToEntryState(after);
+ }
+ LDBG() << " Processing " << predecessors->getKnownPredecessors().size()
+ << " known predecessors";
for (Operation *predecessor : predecessors->getKnownPredecessors()) {
+ LDBG() << " Processing predecessor: "
+ << OpWithFlags(predecessor, OpPrintingFlags().skipRegions());
// Get the lattices at callee return:
//
// func.func @callee() {
@@ -99,6 +141,7 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
const AbstractDenseLattice *latticeAtCalleeReturn =
getLatticeFor(getProgramPointAfter(call.getOperation()),
getProgramPointAfter(predecessor));
+ LDBG() << " Lattice at callee return: " << *latticeAtCalleeReturn;
visitCallControlFlowTransfer(call, CallControlFlowAction::ExitCallee,
*latticeAtCalleeReturn, latticeAfterCall);
}
@@ -106,12 +149,16 @@ void AbstractDenseForwardDataFlowAnalysis::visitCallOperation(
LogicalResult
AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
+ LDBG() << "processOperation (forward): "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
ProgramPoint *point = getProgramPointAfter(op);
// If the containing block is not executable, bail out.
if (op->getBlock() != nullptr &&
!getOrCreateFor<Executable>(point, getProgramPointBefore(op->getBlock()))
- ->isLive())
+ ->isLive()) {
+ LDBG() << " Block not executable, skipping operation";
return success();
+ }
// Get the dense lattice to update.
AbstractDenseLattice *after = getLattice(point);
@@ -119,10 +166,13 @@ AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
// Get the dense state before the execution of the op.
const AbstractDenseLattice *before =
getLatticeFor(point, getProgramPointBefore(op));
+ LDBG() << " before state: " << *before;
+ LDBG() << " after state: " << *after;
// If this op implements region control-flow, then control-flow dictates its
// transfer function.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG() << " Processing as region branch operation";
visitRegionBranchOperation(point, branch, after);
return success();
}
@@ -130,41 +180,57 @@ AbstractDenseForwardDataFlowAnalysis::processOperation(Operation *op) {
// If this is a call operation, then join its lattices across known return
// sites.
if (auto call = dyn_cast<CallOpInterface>(op)) {
+ LDBG() << " Processing as call operation";
visitCallOperation(call, *before, after);
return success();
}
// Invoke the operation transfer function.
+ LDBG() << " Invoking operation transfer function";
return visitOperationImpl(op, *before, after);
}
void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
+ LDBG() << "visitBlock (forward): " << block;
// If the block is not executable, bail out.
ProgramPoint *point = getProgramPointBefore(block);
- if (!getOrCreateFor<Executable>(point, point)->isLive())
+ if (!getOrCreateFor<Executable>(point, point)->isLive()) {
+ LDBG() << " Block not executable, skipping";
return;
+ }
// Get the dense lattice to update.
AbstractDenseLattice *after = getLattice(point);
+ LDBG() << " Block lattice state: " << *after;
// The dense lattices of entry blocks are set by region control-flow or the
// callgraph.
if (block->isEntryBlock()) {
+ LDBG() << " Processing entry block";
// Check if this block is the entry block of a callable region.
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
if (callable && callable.getCallableRegion() == block->getParent()) {
+ LDBG() << " Entry block of callable region";
const auto *callsites = getOrCreateFor<PredecessorState>(
point, getProgramPointAfter(callable));
// If not all callsites are known, conservatively mark all lattices as
// having reached their pessimistic fixpoints. Do the same if
// interprocedural analysis is not enabled.
if (!callsites->allPredecessorsKnown() ||
- !getSolverConfig().isInterprocedural())
+ !getSolverConfig().isInterprocedural()) {
+ LDBG() << " Not all callsites known or non-interprocedural, setting "
+ "to entry state";
return setToEntryState(after);
+ }
+ LDBG() << " Processing " << callsites->getKnownPredecessors().size()
+ << " known callsites";
for (Operation *callsite : callsites->getKnownPredecessors()) {
+ LDBG() << " Processing callsite: "
+ << OpWithFlags(callsite, OpPrintingFlags().skipRegions());
// Get the dense lattice before the callsite.
const AbstractDenseLattice *before;
before = getLatticeFor(point, getProgramPointBefore(callsite));
+ LDBG() << " Lattice before callsite: " << *before;
visitCallControlFlowTransfer(cast<CallOpInterface>(callsite),
CallControlFlowAction::EnterCallee,
@@ -174,23 +240,32 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
}
// Check if we can reason about the control-flow.
- if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp()))
+ if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
+ LDBG() << " Entry block of region branch operation";
return visitRegionBranchOperation(point, branch, after);
+ }
// Otherwise, we can't reason about the data-flow.
+ LDBG() << " Cannot reason about data-flow, setting to entry state";
return setToEntryState(after);
}
// Join the state with the state after the block's predecessors.
+ LDBG() << " Joining state from "
+ << std::distance(block->pred_begin(), block->pred_end())
+ << " predecessors";
for (Block::pred_iterator it = block->pred_begin(), e = block->pred_end();
it != e; ++it) {
// Skip control edges that aren't executable.
Block *predecessor = *it;
if (!getOrCreateFor<Executable>(
point, getLatticeAnchor<CFGEdge>(predecessor, block))
- ->isLive())
+ ->isLive()) {
+ LDBG() << " Skipping non-executable edge from " << predecessor;
continue;
+ }
+ LDBG() << " Joining state from predecessor " << predecessor;
// Merge in the state from the predecessor's terminator.
join(after, *getLatticeFor(
point, getProgramPointAfter(predecessor->getTerminator())));
@@ -200,20 +275,34 @@ void AbstractDenseForwardDataFlowAnalysis::visitBlock(Block *block) {
void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
ProgramPoint *point, RegionBranchOpInterface branch,
AbstractDenseLattice *after) {
+ LDBG() << "visitRegionBranchOperation (forward): "
+ << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
+ LDBG() << " point: " << *point;
+ LDBG() << " after state: " << *after;
+
// Get the terminator predecessors.
const auto *predecessors = getOrCreateFor<PredecessorState>(point, point);
assert(predecessors->allPredecessorsKnown() &&
"unexpected unresolved region successors");
+ LDBG() << " Processing " << predecessors->getKnownPredecessors().size()
+ << " known predecessors";
for (Operation *op : predecessors->getKnownPredecessors()) {
+ LDBG() << " Processing predecessor: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
const AbstractDenseLattice *before;
// If the predecessor is the parent, get the state before the parent.
if (op == branch) {
+ LDBG() << " Predecessor is the branch itself, getting state before "
+ "parent";
before = getLatticeFor(point, getProgramPointBefore(op));
// Otherwise, get the state after the terminator.
} else {
+ LDBG()
+ << " Predecessor is terminator, getting state after terminator";
before = getLatticeFor(point, getProgramPointAfter(op));
}
+ LDBG() << " before state: " << *before;
// This function is called in two cases:
// 1. when visiting the block (point = block start);
@@ -231,19 +320,31 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
std::optional<unsigned> regionFrom =
op == branch ? std::optional<unsigned>()
: op->getBlock()->getParent()->getRegionNumber();
+ LDBG() << " regionFrom: "
+ << (regionFrom ? std::to_string(*regionFrom) : "parent");
+
if (point->isBlockStart()) {
unsigned regionTo = point->getBlock()->getParent()->getRegionNumber();
+ LDBG() << " Point is block start, regionTo: " << regionTo;
+ LDBG() << " Calling visitRegionBranchControlFlowTransfer with "
+ "regionFrom/regionTo";
visitRegionBranchControlFlowTransfer(branch, regionFrom, regionTo,
*before, after);
} else {
assert(point->getPrevOp() == branch &&
"expected to be visiting the branch itself");
+ LDBG() << " Point is not block start, checking if predecessor is "
+ "region or op itself";
// Only need to call the arc transfer when the predecessor is the region
// or the op itself, not the previous op.
if (op->getParentOp() == branch || op == branch) {
+ LDBG() << " Predecessor is region or op itself, calling "
+ "visitRegionBranchControlFlowTransfer";
visitRegionBranchControlFlowTransfer(
branch, regionFrom, /*regionTo=*/std::nullopt, *before, after);
} else {
+ LDBG()
+ << " Predecessor is not region or op itself, performing join";
join(after, *before);
}
}
@@ -256,35 +357,61 @@ void AbstractDenseForwardDataFlowAnalysis::visitRegionBranchOperation(
void AbstractDenseBackwardDataFlowAnalysis::initializeEquivalentLatticeAnchor(
Operation *top) {
+ LDBG() << "initializeEquivalentLatticeAnchor (backward): "
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
top->walk([&](Operation *op) {
- if (isa<RegionBranchOpInterface, CallOpInterface>(op))
+ if (isa<RegionBranchOpInterface, CallOpInterface>(op)) {
+ LDBG() << " Skipping "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " (region branch or call)";
return;
+ }
+ LDBG() << " Building equivalent lattice anchor for "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
buildOperationEquivalentLatticeAnchor(op);
});
}
LogicalResult
AbstractDenseBackwardDataFlowAnalysis::initialize(Operation *top) {
+ LDBG() << "initialize (backward): "
+ << OpWithFlags(top, OpPrintingFlags().skipRegions());
// Visit every operation and block.
- if (failed(processOperation(top)))
+ if (failed(processOperation(top))) {
+ LDBG() << " Failed to process top-level operation";
return failure();
+ }
for (Region &region : top->getRegions()) {
+ LDBG() << " Processing region with " << region.getBlocks().size()
+ << " blocks";
for (Block &block : region) {
+ LDBG() << " Processing block with " << block.getOperations().size()
+ << " operations";
visitBlock(&block);
for (Operation &op : llvm::reverse(block)) {
- if (failed(initialize(&op)))
+ LDBG() << " Initializing operation (backward): "
+ << OpWithFlags(&op, OpPrintingFlags().skipRegions());
+ if (failed(initialize(&op))) {
+ LDBG() << " Failed to initialize operation";
return failure();
+ }
}
}
}
+ LDBG() << " Backward initialization completed successfully";
return success();
}
LogicalResult
AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
- if (!point->isBlockEnd())
+ LDBG() << "visit (backward): " << *point;
+ if (!point->isBlockEnd()) {
+ LDBG() << " Processing operation: "
+ << OpWithFlags(point->getNextOp(), OpPrintingFlags().skipRegions());
return processOperation(point->getNextOp());
+ }
+ LDBG() << " Visiting block: " << point->getBlock();
visitBlock(point->getBlock());
return success();
}
@@ -292,28 +419,47 @@ AbstractDenseBackwardDataFlowAnalysis::visit(ProgramPoint *point) {
void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
CallOpInterface call, const AbstractDenseLattice &after,
AbstractDenseLattice *before) {
+ LDBG() << "visitCallOperation (backward): "
+ << OpWithFlags(call.getOperation(), OpPrintingFlags().skipRegions());
+ LDBG() << " after state: " << after;
+ LDBG() << " before state: " << *before;
+
// If the solver is not interprocedural, let the hook handle it as an external
// callee.
- if (!getSolverConfig().isInterprocedural())
+ if (!getSolverConfig().isInterprocedural()) {
+ LDBG() << " Non-interprocedural analysis, handling as external callee";
return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, after, before);
+ }
// Find the callee.
Operation *callee = call.resolveCallableInTable(&symbolTable);
+ if (callee) {
+ LDBG() << " Resolved callee: "
+ << OpWithFlags(callee, OpPrintingFlags().skipRegions());
+ } else {
+ LDBG() << " Resolved callee: null";
+ }
auto callable = dyn_cast_or_null<CallableOpInterface>(callee);
// No region means the callee is only declared in this module.
// If that is the case or if the solver is not interprocedural,
// let the hook handle it.
- if (callable &&
- (!callable.getCallableRegion() || callable.getCallableRegion()->empty()))
+ if (callable && (!callable.getCallableRegion() ||
+ callable.getCallableRegion()->empty())) {
+ LDBG() << " Callee has no region or empty region, handling as external "
+ "callee";
return visitCallControlFlowTransfer(
call, CallControlFlowAction::ExternalCallee, after, before);
+ }
- if (!callable)
+ if (!callable) {
+ LDBG() << " No callable found, setting to exit state";
return setToExitState(before);
+ }
Region *region = callable.getCallableRegion();
+ LDBG() << " Processing callable with region";
// Call-level control flow specifies the data flow here.
//
@@ -332,6 +478,7 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
ProgramPoint *calleeEntry = getProgramPointBefore(calleeEntryBlock);
const AbstractDenseLattice &latticeAtCalleeEntry =
*getLatticeFor(getProgramPointBefore(call.getOperation()), calleeEntry);
+ LDBG() << " Lattice at callee entry: " << latticeAtCalleeEntry;
AbstractDenseLattice *latticeBeforeCall = before;
visitCallControlFlowTransfer(call, CallControlFlowAction::EnterCallee,
latticeAtCalleeEntry, latticeBeforeCall);
@@ -339,12 +486,16 @@ void AbstractDenseBackwardDataFlowAnalysis::visitCallOperation(
LogicalResult
AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
+ LDBG() << "processOperation (backward): "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
ProgramPoint *point = getProgramPointBefore(op);
// If the containing block is not executable, bail out.
if (op->getBlock() != nullptr &&
!getOrCreateFor<Executable>(point, getProgramPointBefore(op->getBlock()))
- ->isLive())
+ ->isLive()) {
+ LDBG() << " Block not executable, skipping operation";
return success();
+ }
// Get the dense lattice to update.
AbstractDenseLattice *before = getLattice(point);
@@ -352,30 +503,39 @@ AbstractDenseBackwardDataFlowAnalysis::processOperation(Operation *op) {
// Get the dense state after execution of this op.
const AbstractDenseLattice *after =
getLatticeFor(point, getProgramPointAfter(op));
+ LDBG() << " before state: " << *before;
+ LDBG() << " after state: " << *after;
// Special cases where control flow may dictate data flow.
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG() << " Processing as region branch operation";
visitRegionBranchOperation(point, branch, RegionBranchPoint::parent(),
before);
return success();
}
if (auto call = dyn_cast<CallOpInterface>(op)) {
+ LDBG() << " Processing as call operation";
visitCallOperation(call, *after, before);
return success();
}
// Invoke the operation transfer function.
+ LDBG() << " Invoking operation transfer function";
return visitOperationImpl(op, *after, before);
}
void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
+ LDBG() << "visitBlock (backward): " << block;
ProgramPoint *point = getProgramPointAfter(block);
// If the block is not executable, bail out.
if (!getOrCreateFor<Executable>(point, getProgramPointBefore(block))
- ->isLive())
+ ->isLive()) {
+ LDBG() << " Block not executable, skipping";
return;
+ }
AbstractDenseLattice *before = getLattice(point);
+ LDBG() << " Block lattice state: " << *before;
// We need "exit" blocks, i.e. the blocks that may return control to the
// parent operation.
@@ -391,23 +551,32 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
b->getTerminator());
};
if (isExitBlock(block)) {
+ LDBG() << " Processing exit block";
// If this block is exiting from a callable, the successors of exiting from
// a callable are the successors of all call sites. And the call sites
// themselves are predecessors of the callable.
auto callable = dyn_cast<CallableOpInterface>(block->getParentOp());
if (callable && callable.getCallableRegion() == block->getParent()) {
+ LDBG() << " Exit block of callable region";
const auto *callsites = getOrCreateFor<PredecessorState>(
point, getProgramPointAfter(callable));
// If not all call sites are known, conservative mark all lattices as
// having reached their pessimistic fix points.
if (!callsites->allPredecessorsKnown() ||
!getSolverConfig().isInterprocedural()) {
+ LDBG() << " Not all callsites known or non-interprocedural, setting "
+ "to exit state";
return setToExitState(before);
}
+ LDBG() << " Processing " << callsites->getKnownPredecessors().size()
+ << " known callsites";
for (Operation *callsite : callsites->getKnownPredecessors()) {
+ LDBG() << " Processing callsite: "
+ << OpWithFlags(callsite, OpPrintingFlags().skipRegions());
const AbstractDenseLattice *after =
getLatticeFor(point, getProgramPointAfter(callsite));
+ LDBG() << " Lattice after callsite: " << *after;
visitCallControlFlowTransfer(cast<CallOpInterface>(callsite),
CallControlFlowAction::ExitCallee, *after,
before);
@@ -418,22 +587,29 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
// If this block is exiting from an operation with region-based control
// flow, propagate the lattice back along the control flow edge.
if (auto branch = dyn_cast<RegionBranchOpInterface>(block->getParentOp())) {
+ LDBG() << " Exit block of region branch operation";
visitRegionBranchOperation(point, branch, block->getParent(), before);
return;
}
// Cannot reason about successors of an exit block, set the pessimistic
// fixpoint.
+ LDBG() << " Cannot reason about successors, setting to exit state";
return setToExitState(before);
}
// Meet the state with the state before block's successors.
+ LDBG() << " Meeting state from " << block->getSuccessors().size()
+ << " successors";
for (Block *successor : block->getSuccessors()) {
if (!getOrCreateFor<Executable>(point,
getLatticeAnchor<CFGEdge>(block, successor))
- ->isLive())
+ ->isLive()) {
+ LDBG() << " Skipping non-executable edge to " << successor;
continue;
+ }
+ LDBG() << " Meeting state from successor " << successor;
// Merge in the state from the successor: either the first operation, or the
// block itself when empty.
meet(before, *getLatticeFor(point, getProgramPointBefore(successor)));
@@ -443,28 +619,39 @@ void AbstractDenseBackwardDataFlowAnalysis::visitBlock(Block *block) {
void AbstractDenseBackwardDataFlowAnalysis::visitRegionBranchOperation(
ProgramPoint *point, RegionBranchOpInterface branch,
RegionBranchPoint branchPoint, AbstractDenseLattice *before) {
+ LDBG() << "visitRegionBranchOperation (backward): "
+ << OpWithFlags(branch.getOperation(), OpPrintingFlags().skipRegions());
+ LDBG() << " branchPoint: " << (branchPoint.isParent() ? "parent" : "region");
+ LDBG() << " before state: " << *before;
// The successors of the operation may be either the first operation of the
// entry block of each possible successor region, or the next operation when
// the branch is a successor of itself.
SmallVector<RegionSuccessor> successors;
branch.getSuccessorRegions(branchPoint, successors);
+ LDBG() << " Processing " << successors.size() << " successor regions";
for (const RegionSuccessor &successor : successors) {
const AbstractDenseLattice *after;
if (successor.isParent() || successor.getSuccessor()->empty()) {
+ LDBG() << " Successor is parent or empty region";
after = getLatticeFor(point, getProgramPointAfter(branch));
} else {
Region *successorRegion = successor.getSuccessor();
assert(!successorRegion->empty() && "unexpected empty successor region");
Block *successorBlock = &successorRegion->front();
+ LDBG() << " Successor region with "
+ << successorRegion->getBlocks().size() << " blocks";
if (!getOrCreateFor<Executable>(point,
getProgramPointBefore(successorBlock))
- ->isLive())
+ ->isLive()) {
+ LDBG() << " Successor block not executable, skipping";
continue;
+ }
after = getLatticeFor(point, getProgramPointBefore(successorBlock));
}
+ LDBG() << " After state: " << *after;
visitRegionBranchControlFlowTransfer(branch, branchPoint, successor, *after,
before);
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 1f239aa..519d9c8 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -21,6 +21,7 @@
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/LogicalResult.h"
namespace mlir {
#define GEN_PASS_DEF_SCFTOEMITC
@@ -106,7 +107,7 @@ static void assignValues(ValueRange values, ValueRange variables,
emitc::AssignOp::create(rewriter, loc, var, value);
}
-SmallVector<Value> loadValues(const SmallVector<Value> &variables,
+SmallVector<Value> loadValues(ArrayRef<Value> variables,
PatternRewriter &rewriter, Location loc) {
return llvm::map_to_vector<>(variables, [&](Value var) {
Type type = cast<emitc::LValueType>(var.getType()).getValueType();
@@ -116,16 +117,15 @@ SmallVector<Value> loadValues(const SmallVector<Value> &variables,
static LogicalResult lowerYield(Operation *op, ValueRange resultVariables,
ConversionPatternRewriter &rewriter,
- scf::YieldOp yield) {
+ scf::YieldOp yield, bool createYield = true) {
Location loc = yield.getLoc();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(yield);
SmallVector<Value> yieldOperands;
- if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands))) {
+ if (failed(rewriter.getRemappedValues(yield.getOperands(), yieldOperands)))
return rewriter.notifyMatchFailure(op, "failed to lower yield operands");
- }
assignValues(yieldOperands, resultVariables, rewriter, loc);
@@ -336,11 +336,177 @@ LogicalResult IndexSwitchOpLowering::matchAndRewrite(
return success();
}
+// Lower scf::while to emitc::do using mutable variables to maintain loop state
+// across iterations. The do-while structure ensures the condition is evaluated
+// after each iteration, matching SCF while semantics.
+struct WhileLowering : public OpConversionPattern<WhileOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(WhileOp whileOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = whileOp.getLoc();
+ MLIRContext *context = loc.getContext();
+
+ // Create an emitc::variable op for each result. These variables will be
+ // assigned to by emitc::assign ops within the loop body.
+ SmallVector<Value> resultVariables;
+ if (failed(createVariablesForResults(whileOp, getTypeConverter(), rewriter,
+ resultVariables)))
+ return rewriter.notifyMatchFailure(whileOp,
+ "Failed to create result variables");
+
+ // Create variable storage for loop-carried values to enable imperative
+ // updates while maintaining SSA semantics at conversion boundaries.
+ SmallVector<Value> loopVariables;
+ if (failed(createVariablesForLoopCarriedValues(
+ whileOp, rewriter, loopVariables, loc, context)))
+ return failure();
+
+ if (failed(lowerDoWhile(whileOp, loopVariables, resultVariables, context,
+ rewriter, loc)))
+ return failure();
+
+ rewriter.setInsertionPointAfter(whileOp);
+
+ // Load the final result values from result variables.
+ SmallVector<Value> finalResults =
+ loadValues(resultVariables, rewriter, loc);
+ rewriter.replaceOp(whileOp, finalResults);
+
+ return success();
+ }
+
+private:
+ // Initialize variables for loop-carried values to enable state updates
+ // across iterations without SSA argument passing.
+ LogicalResult createVariablesForLoopCarriedValues(
+ WhileOp whileOp, ConversionPatternRewriter &rewriter,
+ SmallVectorImpl<Value> &loopVars, Location loc,
+ MLIRContext *context) const {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(whileOp);
+
+ emitc::OpaqueAttr noInit = emitc::OpaqueAttr::get(context, "");
+
+ for (Value init : whileOp.getInits()) {
+ Type convertedType = getTypeConverter()->convertType(init.getType());
+ if (!convertedType)
+ return rewriter.notifyMatchFailure(whileOp, "type conversion failed");
+
+ emitc::VariableOp var = rewriter.create<emitc::VariableOp>(
+ loc, emitc::LValueType::get(convertedType), noInit);
+ rewriter.create<emitc::AssignOp>(loc, var.getResult(), init);
+ loopVars.push_back(var);
+ }
+
+ return success();
+ }
+
+ // Lower scf.while to emitc.do.
+ LogicalResult lowerDoWhile(WhileOp whileOp, ArrayRef<Value> loopVars,
+ ArrayRef<Value> resultVars, MLIRContext *context,
+ ConversionPatternRewriter &rewriter,
+ Location loc) const {
+ // Create a global boolean variable to store the loop condition state.
+ Type i1Type = IntegerType::get(context, 1);
+ auto globalCondition =
+ rewriter.create<emitc::VariableOp>(loc, emitc::LValueType::get(i1Type),
+ emitc::OpaqueAttr::get(context, ""));
+ Value conditionVal = globalCondition.getResult();
+
+ auto loweredDo = rewriter.create<emitc::DoOp>(loc);
+
+ // Convert region types to match the target dialect type system.
+ if (failed(rewriter.convertRegionTypes(&whileOp.getBefore(),
+ *getTypeConverter(), nullptr)) ||
+ failed(rewriter.convertRegionTypes(&whileOp.getAfter(),
+ *getTypeConverter(), nullptr))) {
+ return rewriter.notifyMatchFailure(whileOp,
+ "region types conversion failed");
+ }
+
+ // Prepare the before region (condition evaluation) for merging.
+ Block *beforeBlock = &whileOp.getBefore().front();
+ Block *bodyBlock = rewriter.createBlock(&loweredDo.getBodyRegion());
+ rewriter.setInsertionPointToStart(bodyBlock);
+
+ // Load current variable values to use as initial arguments for the
+ // condition block.
+ SmallVector<Value> replacingValues = loadValues(loopVars, rewriter, loc);
+ rewriter.mergeBlocks(beforeBlock, bodyBlock, replacingValues);
+
+ Operation *condTerminator =
+ loweredDo.getBodyRegion().back().getTerminator();
+ scf::ConditionOp condOp = cast<scf::ConditionOp>(condTerminator);
+ rewriter.setInsertionPoint(condOp);
+
+ // Update result variables with values from scf::condition.
+ SmallVector<Value> conditionArgs;
+ for (Value arg : condOp.getArgs()) {
+ conditionArgs.push_back(rewriter.getRemappedValue(arg));
+ }
+ assignValues(conditionArgs, resultVars, rewriter, loc);
+
+ // Convert scf.condition to condition variable assignment.
+ Value condition = rewriter.getRemappedValue(condOp.getCondition());
+ rewriter.create<emitc::AssignOp>(loc, conditionVal, condition);
+
+ // Wrap body region in conditional to preserve scf semantics. Only create
+ // ifOp if after-region is non-empty.
+ if (whileOp.getAfterBody()->getOperations().size() > 1) {
+ auto ifOp = rewriter.create<emitc::IfOp>(loc, condition, false, false);
+
+ // Prepare the after region (loop body) for merging.
+ Block *afterBlock = &whileOp.getAfter().front();
+ Block *ifBodyBlock = rewriter.createBlock(&ifOp.getBodyRegion());
+
+ // Replacement values for after block using condition op arguments.
+ SmallVector<Value> afterReplacingValues;
+ for (Value arg : condOp.getArgs())
+ afterReplacingValues.push_back(rewriter.getRemappedValue(arg));
+
+ rewriter.mergeBlocks(afterBlock, ifBodyBlock, afterReplacingValues);
+
+ if (failed(lowerYield(whileOp, loopVars, rewriter,
+ cast<scf::YieldOp>(ifBodyBlock->getTerminator()))))
+ return failure();
+ }
+
+ rewriter.eraseOp(condOp);
+
+ // Create condition region that loads from the flag variable.
+ Region &condRegion = loweredDo.getConditionRegion();
+ Block *condBlock = rewriter.createBlock(&condRegion);
+ rewriter.setInsertionPointToStart(condBlock);
+
+ auto exprOp = rewriter.create<emitc::ExpressionOp>(
+ loc, i1Type, conditionVal, /*do_not_inline=*/false);
+ Block *exprBlock = rewriter.createBlock(&exprOp.getBodyRegion());
+
+ // Set up the expression block to load the condition variable.
+ exprBlock->addArgument(conditionVal.getType(), loc);
+ rewriter.setInsertionPointToStart(exprBlock);
+
+ // Load the condition value and yield it as the expression result.
+ Value cond =
+ rewriter.create<emitc::LoadOp>(loc, i1Type, exprBlock->getArgument(0));
+ rewriter.create<emitc::YieldOp>(loc, cond);
+
+ // Yield the expression as the condition region result.
+ rewriter.setInsertionPointToEnd(condBlock);
+ rewriter.create<emitc::YieldOp>(loc, exprOp);
+
+ return success();
+ }
+};
+
void mlir::populateSCFToEmitCConversionPatterns(RewritePatternSet &patterns,
TypeConverter &typeConverter) {
patterns.add<ForLowering>(typeConverter, patterns.getContext());
patterns.add<IfLowering>(typeConverter, patterns.getContext());
patterns.add<IndexSwitchOpLowering>(typeConverter, patterns.getContext());
+ patterns.add<WhileLowering>(typeConverter, patterns.getContext());
}
void SCFToEmitCPass::runOnOperation() {
@@ -357,7 +523,8 @@ void SCFToEmitCPass::runOnOperation() {
// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
- target.addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp>();
+ target
+ .addIllegalOp<scf::ForOp, scf::IfOp, scf::IndexSwitchOp, scf::WhileOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 57877b8..f449d90 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -214,6 +214,10 @@ static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
return op.getCacheControl();
}
+static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
+ return op.getCacheControl();
+}
+
static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
return op.getCacheControl();
}
@@ -222,6 +226,10 @@ static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
return op.getCacheControl();
}
+static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
+ return op.getCacheControl();
+}
+
static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
if (op->hasAttr("cache_control")) {
auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
@@ -263,6 +271,7 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
std::is_same_v<OpType, BlockPrefetch2dOp> ||
std::is_same_v<OpType, LLVM::LoadOp> ||
+ std::is_same_v<OpType, BlockLoadOp> ||
std::is_same_v<OpType, PrefetchOp>;
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
@@ -618,6 +627,77 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
return success();
}
};
+
+template <typename OpType>
+class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
+ // Get OpenCL function name
+ // https://registry.khronos.org/OpenCL/extensions/
+ // intel/cl_intel_subgroup_local_block_io.html
+ std::string funcName{"intel_sub_group_block_"};
+ // Value or Result type can be vector or scalar
+ Type valOrResTy;
+ if constexpr (isStore) {
+ funcName += "write_u";
+ valOrResTy = op.getVal().getType();
+ } else {
+ funcName += "read_u";
+ valOrResTy = op.getType();
+ }
+ // Get element type of the vector/scalar
+ VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
+ Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
+ funcName += getTypeMangling(elemType);
+ if (vecTy)
+ funcName += std::to_string(vecTy.getNumElements());
+ SmallVector<Type, 2> argTypes{};
+ // XeVM BlockLoad/StoreOp always use signless integer types
+ // but OpenCL builtins expect unsigned types
+ // use unsigned types for mangling
+ SmallVector<bool, 2> isUnsigned{};
+ // arg0: pointer to the src/dst address
+ // arg1 - only if store : vector to store
+ // Prepare arguments
+ SmallVector<Value, 2> args{};
+ args.push_back(op.getPtr());
+ argTypes.push_back(op.getPtr().getType());
+ isUnsigned.push_back(true);
+ Type retType;
+ if constexpr (isStore) {
+ args.push_back(op.getVal());
+ argTypes.push_back(op.getVal().getType());
+ isUnsigned.push_back(true);
+ retType = LLVM::LLVMVoidType::get(rewriter.getContext());
+ } else {
+ retType = valOrResTy;
+ }
+ funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
+ "PU3AS" +
+ std::to_string(op.getPtr().getType().getAddressSpace());
+ funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
+ if constexpr (isStore)
+ funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
+ LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
+
+ LLVM::CallOp call =
+ createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
+ {}, funcAttr, op.getOperation());
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op)) {
+ call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ }
+ if constexpr (isStore)
+ rewriter.eraseOp(op);
+ else
+ rewriter.replaceOp(op, call->getResult(0));
+ return success();
+ }
+};
+
template <typename OpType>
class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
using OpConversionPattern<OpType>::OpConversionPattern;
@@ -693,7 +773,10 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
- LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext());
+ LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
+ BlockLoadStore1DToOCLPattern<BlockLoadOp>,
+ BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
+ patterns.getContext());
}
void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 5c8564b..4754f0b 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -974,10 +974,10 @@ LogicalResult emitc::YieldOp::verify() {
Value result = getResult();
Operation *containingOp = getOperation()->getParentOp();
- if (result && containingOp->getNumResults() != 1)
+ if (!isa<DoOp>(containingOp) && result && containingOp->getNumResults() != 1)
return emitOpError() << "yields a value not returned by parent";
- if (!result && containingOp->getNumResults() != 0)
+ if (!isa<DoOp>(containingOp) && !result && containingOp->getNumResults() != 0)
return emitOpError() << "does not yield a value to be returned by parent";
return success();
@@ -1562,6 +1562,76 @@ LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
}
//===----------------------------------------------------------------------===//
+// DoOp
+//===----------------------------------------------------------------------===//
+
+void DoOp::print(OpAsmPrinter &p) {
+ p << ' ';
+ p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
+ p << " while ";
+ p.printRegion(getConditionRegion());
+ p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
+}
+
+LogicalResult emitc::DoOp::verify() {
+ Block &condBlock = getConditionRegion().front();
+
+ if (condBlock.getOperations().size() != 2)
+ return emitOpError(
+ "condition region must contain exactly two operations: "
+ "'emitc.expression' followed by 'emitc.yield', but found ")
+ << condBlock.getOperations().size() << " operations";
+
+ Operation &first = condBlock.front();
+ auto exprOp = dyn_cast<emitc::ExpressionOp>(first);
+ if (!exprOp)
+ return emitOpError("expected first op in condition region to be "
+ "'emitc.expression', but got ")
+ << first.getName();
+
+ if (!exprOp.getResult().getType().isInteger(1))
+ return emitOpError("emitc.expression in condition region must return "
+ "'i1', but returns ")
+ << exprOp.getResult().getType();
+
+ Operation &last = condBlock.back();
+ auto condYield = dyn_cast<emitc::YieldOp>(last);
+ if (!condYield)
+ return emitOpError("expected last op in condition region to be "
+ "'emitc.yield', but got ")
+ << last.getName();
+
+ if (condYield.getNumOperands() != 1)
+ return emitOpError("expected condition region to return 1 value, but "
+ "it returns ")
+ << condYield.getNumOperands() << " values";
+
+ if (condYield.getOperand(0) != exprOp.getResult())
+ return emitError("'emitc.yield' must return result of "
+ "'emitc.expression' from this condition region");
+
+ Block &bodyBlock = getBodyRegion().front();
+ if (bodyBlock.mightHaveTerminator())
+ return emitOpError("body region must not contain terminator");
+
+ return success();
+}
+
+ParseResult DoOp::parse(OpAsmParser &parser, OperationState &result) {
+ Region *bodyRegion = result.addRegion();
+ Region *condRegion = result.addRegion();
+
+ if (parser.parseRegion(*bodyRegion) || parser.parseKeyword("while") ||
+ parser.parseRegion(*condRegion))
+ return failure();
+
+ if (bodyRegion->empty())
+ bodyRegion->emplaceBlock();
+
+ return parser.parseOptionalAttrDictWithKeyword(result.attributes);
+}
+
+//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7863c21..0dac688 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1146,37 +1146,25 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
Location loc = packOp.getLoc();
- Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
- DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
- packOp.getDimAndTileMapping();
int64_t srcRank = packOp.getSourceRank();
int64_t destRank = packOp.getDestRank();
- int64_t numTiles = destRank - srcRank;
+ ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+ int64_t numberOfTiles = innerDimsPos.size();
- // 1. Extract the inner tile sizes.
- // Where possible, values are replaced with constant attributes (to match the
- // behaviour of `getPackOpSourceOrPaddedSource`).
- SmallVector<OpFoldResult> tileSizes;
- for (auto i : llvm::seq<unsigned>(0, srcRank)) {
- if (dimAndTileMapping.count(i)) {
- // Rather than taking the tile size as is, extact the actual constant
- // value Attribute where possible, e.g.:
- // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8]
- auto [_, tileSize] =
- getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
- tileSizes.push_back(tileSize);
- }
- }
+ // 1. Get the input that is going to be packed. If the input requires padding,
+ // add a padding operation and return that as the input.
+ Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
// 2. Transpose the input to match the inner tile order:
// %init = tensor.empty()
// %transposed_tile = linalg.transpose ins(%source_or_padded_source),
// outs(%init)
// Assumptions made:
- // 1. All outer dims are 1 - the corresponding transposition order doesn't
+ // - All outer dims are 1 - the corresponding transposition order doesn't
// matter, but requires all dim indices to be present.
+
+ // 2.1 Get the permutation for linalg.transpose
SmallVector<int64_t> srcPermForTranspose;
- ArrayRef<int64_t> innerDimPos(packOp.getInnerDimsPos());
for (int64_t i = 0; i < srcRank; i++) {
// We assume the `k` dimensions of the inner dim position, where `k` is the
// rank of the inner tiling, correspond to the last `k` indices of the
@@ -1185,27 +1173,34 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
// rank of the source tensor. For example if we have a source tensor with
// indices [0, 1, 2, 3] and inner dim position of [3, 0], the remaining
// indices are [1, 2]. and the transpose will be [1, 2, 3, 0].
- if (llvm::is_contained(innerDimPos, i))
+ if (llvm::is_contained(innerDimsPos, i))
continue;
srcPermForTranspose.push_back(i);
}
- srcPermForTranspose.append(innerDimPos.begin(), innerDimPos.end());
+ srcPermForTranspose.append(innerDimsPos.begin(), innerDimsPos.end());
+
+ // 2.2 Create the init tensor for linalg.transpose with the correct shape
+ SmallVector<OpFoldResult> shapeForEmptyOp(srcRank - numberOfTiles,
+ oneIdxAttr);
+ shapeForEmptyOp.append(packOp.getMixedTiles());
+
+ // getMixedTiles() may contain Values pointing to constant ops, not the
+ // constant attributes. Replace them with a true OpFoldResult.
+ llvm::transform(shapeForEmptyOp, shapeForEmptyOp.begin(),
+ [&](OpFoldResult ofr) {
+ if (auto val = llvm::dyn_cast<Value>(ofr))
+ return getAsOpFoldResult(val);
+ return ofr;
+ });
LDBG() << "Pack permutation: " << packOp;
LDBG() << "perm: " << llvm::interleaved(srcPermForTranspose);
+ LDBG() << "Shape of empty tensor: " << llvm::interleaved(shapeForEmptyOp);
- // 2.1 Create tensor.empty (init value for TransposeOp)
- SmallVector<OpFoldResult> transShapeForEmptyOp(srcRank - numTiles,
- oneIdxAttr);
- transShapeForEmptyOp.append(tileSizes);
-
- applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp,
- srcPermForTranspose);
- Value empty =
- tensor::EmptyOp::create(rewriter, loc, transShapeForEmptyOp,
- packOp.getSourceType().getElementType());
+ Value empty = tensor::EmptyOp::create(
+ rewriter, loc, shapeForEmptyOp, packOp.getSourceType().getElementType());
- // 2.2 Create linalg.transpose
+ // 2.3 Create linalg.transpose
auto transposedOp = linalg::TransposeOp::create(rewriter, loc, input, empty,
srcPermForTranspose);
@@ -1214,8 +1209,7 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
// Outer dims are all 1s!
- SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
- oneIdxAttr);
+ SmallVector<OpFoldResult> writeSizes(destRank - numberOfTiles, oneIdxAttr);
SmallVector<int64_t> writeShape;
for (auto tileSize : packOp.getMixedTiles()) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index ef172c1..37bdd8b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -186,11 +186,11 @@ constexpr float A_2x2_5x5[] = {
/// Structure to keep information of constant transform matrices.
struct TransformMatrix {
- TransformMatrix(const float *table, int64_t rows, int64_t cols,
+ TransformMatrix(ArrayRef<float> table, int64_t rows, int64_t cols,
int64_t scalarFactor = 1)
: table(table), rows(rows), cols(cols), scalarFactor(scalarFactor) {}
- const float *table;
+ ArrayRef<float> table;
int64_t rows;
int64_t cols;
int64_t scalarFactor;
@@ -199,14 +199,20 @@ struct TransformMatrix {
/// Utility function to convert constant array to arith.constant Value.
Value create2DTransformMatrix(OpBuilder &builder, Location loc,
TransformMatrix transform, Type type) {
- ArrayRef<float> constVec(transform.table, transform.rows * transform.cols);
-
+ assert(transform.table.size() ==
+ static_cast<size_t>(transform.rows * transform.cols));
+ assert(type.isFloat() && "Only floats are supported by Winograd");
+ ArrayRef<float> constVec(transform.table.data(),
+ transform.rows * transform.cols);
+ auto constAttrVec =
+ llvm::map_to_vector<>(constVec, [&](const float v) -> Attribute {
+ return builder.getFloatAttr(type, v);
+ });
+ SmallVector<int64_t, 2> shape{transform.rows, transform.cols};
return arith::ConstantOp::create(
builder, loc,
- DenseFPElementsAttr::get(
- RankedTensorType::get(
- SmallVector<int64_t>{transform.rows, transform.cols}, type),
- constVec));
+ DenseFPElementsAttr::get(RankedTensorType::get(shape, type),
+ constAttrVec));
}
/// Extract height x width data from 4D tensors.
@@ -551,8 +557,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
- Value BT =
- create2DTransformMatrix(builder, loc, BTMatrix, builder.getF32Type());
+ Value BT = create2DTransformMatrix(builder, loc, BTMatrix, elementType);
// Multiply BT x d.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{BT, matmulRetValue},
@@ -574,8 +579,7 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
.getResult();
auto init =
linalg::FillOp::create(builder, loc, zero, empty).getResult(0);
- Value B =
- create2DTransformMatrix(builder, loc, BMatrix, builder.getF32Type());
+ Value B = create2DTransformMatrix(builder, loc, BMatrix, elementType);
// Multiply v = (BT x d) x B.
auto matmulOp = linalg::MatmulOp::create(builder, loc, matmulType,
ValueRange{matmulRetValue, B},
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 697cb35..237aab4 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -27,7 +27,7 @@ using namespace mlir::nvgpu;
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.cpp.inc"
-void nvgpu::NVGPUDialect::initialize() {
+void NVGPUDialect::initialize() {
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/NVGPU/IR/NVGPUTypeDefs.cpp.inc"
@@ -42,7 +42,7 @@ void nvgpu::NVGPUDialect::initialize() {
>();
}
-bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+bool NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
if (!memorySpace)
return false;
if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
@@ -52,7 +52,7 @@ bool nvgpu::NVGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
return false;
}
-bool nvgpu::NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+bool NVGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
Attribute memorySpace = type.getMemorySpace();
return isSharedMemoryAddressSpace(memorySpace);
}
@@ -140,7 +140,6 @@ static LogicalResult verifyMmaSyncOp(Operation *op,
TypedValue<VectorType> matrixC,
const std::array<int64_t, 3> &mmaShape,
bool tf32Enabled, bool sparse = false) {
-
// The verification for mma.sync covering various shapes and data types is
// based on the fundamental tensor core shape.
@@ -292,7 +291,6 @@ LogicalResult MmaSparseSyncOp::verify() {
// NVGPU_LdMatrixOp
//===----------------------------------------------------------------------===//
LogicalResult LdMatrixOp::verify() {
-
// ldmatrix reads data from source in shared memory
auto srcMemref = llvm::cast<MemRefType>(getSrcMemref().getType());
@@ -345,7 +343,7 @@ LogicalResult LdMatrixOp::verify() {
// NVGPU_TmaAsyncLoadOp
//===----------------------------------------------------------------------===//
-unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
+static unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
switch (kind) {
case TensorMapSwizzleKind::SWIZZLE_32B:
return 32;
@@ -359,7 +357,7 @@ unsigned getSwizzleBytes(TensorMapSwizzleKind kind) {
}
std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
- Operation *op, nvgpu::TensorMapDescriptorType descType,
+ Operation *op, TensorMapDescriptorType descType,
std::optional<MemRefType> memrefType = std::nullopt) {
MemRefType descMemref = descType.getTensor();
// Limitation
@@ -655,8 +653,7 @@ LogicalResult WarpgroupMmaStoreOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult WarpgroupMmaInitAccumulatorOp::verify() {
-
- nvgpu::WarpgroupAccumulatorType accType = getMatrixC().getType();
+ WarpgroupAccumulatorType accType = getMatrixC().getType();
int64_t sizeM = accType.getFragmented().getDimSize(0);
int64_t sizeN = accType.getFragmented().getDimSize(1);
Type elemType = accType.getFragmented().getElementType();
diff --git a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
index 46e82bd..2a857ed 100644
--- a/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
+++ b/mlir/lib/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp
@@ -43,7 +43,7 @@ using namespace mlir::transform;
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//
-void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
+void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
TypeConverter &typeConverter, RewritePatternSet &patterns) {
auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
/// device-side async tokens cannot be materialized in nvvm. We just
@@ -62,62 +62,58 @@ void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
llvm_unreachable("unknown address space enum value");
return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
});
- llvmTypeConverter.addConversion(
- [&](nvgpu::DeviceAsyncTokenType type) -> Type {
- return llvmTypeConverter.convertType(
- IntegerType::get(type.getContext(), 32));
- });
- llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
+ llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type {
+ return llvmTypeConverter.convertType(
+ IntegerType::get(type.getContext(), 32));
+ });
+ llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type {
return llvmTypeConverter.convertType(
IntegerType::get(type.getContext(), 64));
});
- llvmTypeConverter.addConversion(
- [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
- Type elemType = type.getFragmented().getElementType();
- int64_t sizeM = type.getFragmented().getDimSize(0);
- int64_t sizeN = type.getFragmented().getDimSize(1);
-
- unsigned numMembers;
- if (elemType.isF32() || elemType.isInteger(32))
- numMembers = sizeN / 2;
- else if (elemType.isF16())
- numMembers = sizeN / 4;
- else
- llvm_unreachable("unsupported type for warpgroup accumulator");
-
- SmallVector<Type> innerStructBody;
- for (unsigned i = 0; i < numMembers; i++)
- innerStructBody.push_back(elemType);
- auto innerStructType = LLVM::LLVMStructType::getLiteral(
- type.getContext(), innerStructBody);
-
- SmallVector<Type> structBody;
- for (int i = 0; i < sizeM; i += kWgmmaSizeM)
- structBody.push_back(innerStructType);
-
- auto convertedType =
- LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
- return llvmTypeConverter.convertType(convertedType);
- });
- llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
+ llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type {
+ Type elemType = type.getFragmented().getElementType();
+ int64_t sizeM = type.getFragmented().getDimSize(0);
+ int64_t sizeN = type.getFragmented().getDimSize(1);
+
+ unsigned numMembers;
+ if (elemType.isF32() || elemType.isInteger(32))
+ numMembers = sizeN / 2;
+ else if (elemType.isF16())
+ numMembers = sizeN / 4;
+ else
+ llvm_unreachable("unsupported type for warpgroup accumulator");
+
+ SmallVector<Type> innerStructBody;
+ for (unsigned i = 0; i < numMembers; i++)
+ innerStructBody.push_back(elemType);
+ auto innerStructType =
+ LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
+
+ SmallVector<Type> structBody;
+ for (int i = 0; i < sizeM; i += kWgmmaSizeM)
+ structBody.push_back(innerStructType);
+
+ auto convertedType =
+ LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
+ return llvmTypeConverter.convertType(convertedType);
+ });
+ llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type {
return llvmTypeConverter.convertType(
getMBarrierMemrefType(type.getContext(), type));
});
llvmTypeConverter.addConversion(
- [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
+ [&](WarpgroupMatrixDescriptorType type) -> Type {
return llvmTypeConverter.convertType(
IntegerType::get(type.getContext(), 64));
});
- llvmTypeConverter.addConversion(
- [&](nvgpu::TensorMapDescriptorType type) -> Type {
- return LLVM::LLVMPointerType::get(type.getContext());
- });
+ llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type {
+ return LLVM::LLVMPointerType::get(type.getContext());
+ });
populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
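
The WarpgroupAccumulatorType conversion above packs an MxN fragment into an outer LLVM struct of kWgmmaSizeM-row slices, each slice an inner struct of per-element members. A minimal standalone sketch of that packing arithmetic, assuming kWgmmaSizeM is 64 (as in the NVGPU lowering) and using a hypothetical 128x64 f32 fragment:

    #include <cstdio>

    int main() {
      const int kWgmmaSizeM = 64; // assumed WGMMA row-tile size
      const int sizeM = 128;      // hypothetical fragment rows
      const int sizeN = 64;       // hypothetical fragment columns
      const bool isF16 = false;   // f32/i32 accumulators keep sizeN / 2 members per slice
      int numMembers = isF16 ? sizeN / 4 : sizeN / 2;
      int numSlices = 0;
      for (int i = 0; i < sizeM; i += kWgmmaSizeM)
        ++numSlices;
      // For 128x64xf32 this prints: outer struct with 2 copies of an inner struct of 32 f32
      std::printf("outer struct with %d copies of an inner struct of %d f32\n",
                  numSlices, numMembers);
      return 0;
    }
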
-LogicalResult
-transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
- transform::TypeConverterBuilderOpInterface builder) {
+LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
+ TypeConverterBuilderOpInterface builder) {
if (builder.getTypeConverterType() != "LLVMTypeConverter")
return emitOpError("expected LLVMTypeConverter");
return success();
@@ -127,17 +123,18 @@ transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
// CreateAsyncGroupsOp
//===---------------------------------------------------------------------===//
-void transform::CreateAsyncGroupsOp::getEffects(
+void CreateAsyncGroupsOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::consumesHandle(getTargetMutable(), effects);
- transform::producesHandle(getOperation()->getOpResults(), effects);
- transform::modifiesPayload(effects);
+ consumesHandle(getTargetMutable(), effects);
+ producesHandle(getOperation()->getOpResults(), effects);
+ modifiesPayload(effects);
}
-DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
- TransformRewriter &rewriter, Operation *target,
- ApplyToEachResultList &results, TransformState &state) {
- nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
+DiagnosedSilenceableFailure
+CreateAsyncGroupsOp::applyToOne(TransformRewriter &rewriter, Operation *target,
+ ApplyToEachResultList &results,
+ TransformState &state) {
+ createAsyncGroups(rewriter, target, getBypassL1());
results.push_back(target);
return DiagnosedSilenceableFailure::success();
}
@@ -218,7 +215,7 @@ collectStage0PipeliningOps(scf::ForOp forOp,
continue;
}
- if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
+ if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) {
ops.insert(&op);
ops.insert(std::make_move_iterator(barriers.begin()),
std::make_move_iterator(barriers.end()));
@@ -246,7 +243,7 @@ setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
unsigned iteration, unsigned depth) {
// Based on the order of copies within the loop we need to set the number
// of copies in flight, unless it is already set.
- auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
+ auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op);
if (!waitOp || waitOp.getNumGroups())
return;
@@ -312,13 +309,12 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
// original number of iterations, in particular side-effect free operations
// and barriers, even if they cannot be predicated.
if (isMemoryEffectFree(op) ||
- isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
- nvgpu::DeviceAsyncWaitOp>(op)) {
+ isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) {
return op;
}
// Otherwise, only async copies can currently be predicated.
- auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
+ auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op);
if (!asyncCopyOp)
return nullptr;
@@ -335,8 +331,8 @@ static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
originalSrcElement, c0Index);
- auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create(
- rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
+ auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create(
+ rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
UnitAttr());
@@ -805,17 +801,16 @@ FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
rhsIndexFn, rhsShape);
Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
resIndexFn, resShape);
- res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape,
- info.tf32Enabled);
+ res =
+ MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled);
buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
resShape);
return res.getDefiningOp();
}
-DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
- transform::TransformRewriter &rewriter, LinalgOp linalgOp,
- transform::ApplyToEachResultList &results,
- transform::TransformState &state) {
+DiagnosedSilenceableFailure RewriteMatmulAsMmaSyncOp::applyToOne(
+ TransformRewriter &rewriter, LinalgOp linalgOp,
+ ApplyToEachResultList &results, TransformState &state) {
bool fail = true;
// TODO: more robust detection of matmulOp, with transposes etc.
if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
@@ -854,43 +849,42 @@ struct HopperBuilder {
HopperBuilder(RewriterBase &rewriter, Location loc)
: rewriter(rewriter), loc(loc) {}
- TypedValue<nvgpu::MBarrierGroupType>
+ TypedValue<MBarrierGroupType>
buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
/// Create tma descriptor op to initiate transfer from global to shared
/// memory. This must be done before the launch op, on the host.
- TypedValue<nvgpu::TensorMapDescriptorType>
+ TypedValue<TensorMapDescriptorType>
buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
gpu::LaunchOp launchOp);
/// Build a tma load from global memory to shared memory using `barrier` to
/// synchronize. Return the number of bytes that will be transferred.
- OpFoldResult
- buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
- TypedValue<MemRefType> sharedMemref,
- TypedValue<nvgpu::MBarrierGroupType> barrier,
- SmallVectorImpl<Operation *> &loadOps);
- void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
+ OpFoldResult buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc,
+ TypedValue<MemRefType> sharedMemref,
+ TypedValue<MBarrierGroupType> barrier,
+ SmallVectorImpl<Operation *> &loadOps);
+ void buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
ArrayRef<OpFoldResult> sizes);
  /// If threadIdx.x == 0, do the TMA request and wait; otherwise just wait.
/// Return the operation that performs the transfer on thread0.
// TODO: In the future, don't hardcode to thread 0 but elect a leader.
SmallVector<Operation *> buildPredicateLoadsOnThread0(
- ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+ ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
- TypedValue<nvgpu::MBarrierGroupType> barrier);
+ TypedValue<MBarrierGroupType> barrier);
- void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);
+ void buildTryWaitParity(TypedValue<MBarrierGroupType> barrier);
RewriterBase &rewriter;
Location loc;
};
SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
- ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
+ ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
- TypedValue<nvgpu::MBarrierGroupType> barrier) {
+ TypedValue<MBarrierGroupType> barrier) {
SmallVector<Operation *> loadOps;
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
@@ -931,22 +925,22 @@ static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
// return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
}
-TypedValue<nvgpu::MBarrierGroupType>
+TypedValue<MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
- Value barrier = nvgpu::MBarrierCreateOp::create(
+ Value barrier = MBarrierCreateOp::create(
rewriter, loc,
- nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
+ MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
nvgpu::MBarrierInitOp::create(
rewriter, loc, barrier,
getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
Value());
gpu::BarrierOp::create(rewriter, loc);
- return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
+ return cast<TypedValue<MBarrierGroupType>>(barrier);
}
-TypedValue<nvgpu::TensorMapDescriptorType>
+TypedValue<TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
gpu::LaunchOp launchOp) {
OpBuilder::InsertionGuard guard(rewriter);
@@ -962,29 +956,29 @@ HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
- Value desc = nvgpu::TmaCreateDescriptorOp::create(
+ Value desc = TmaCreateDescriptorOp::create(
rewriter, loc,
- nvgpu::TensorMapDescriptorType::get(
- rewriter.getContext(),
- MemRefType::Builder(memref.getType())
- .setMemorySpace(sharedMemorySpace),
- TensorMapSwizzleKind::SWIZZLE_NONE,
- TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
- TensorMapInterleaveKind::INTERLEAVE_NONE),
+ TensorMapDescriptorType::get(rewriter.getContext(),
+ MemRefType::Builder(memref.getType())
+ .setMemorySpace(sharedMemorySpace),
+ TensorMapSwizzleKind::SWIZZLE_NONE,
+ TensorMapL2PromoKind::L2PROMO_NONE,
+ TensorMapOOBKind::OOB_ZERO,
+ TensorMapInterleaveKind::INTERLEAVE_NONE),
unrankedMemRef, sizes);
- return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
+ return cast<TypedValue<TensorMapDescriptorType>>(desc);
}
-OpFoldResult HopperBuilder::buildTmaAsyncLoad(
- TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
- TypedValue<MemRefType> sharedMemref,
- TypedValue<nvgpu::MBarrierGroupType> barrier,
- SmallVectorImpl<Operation *> &loadOps) {
+OpFoldResult
+HopperBuilder::buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc,
+ TypedValue<MemRefType> sharedMemref,
+ TypedValue<MBarrierGroupType> barrier,
+ SmallVectorImpl<Operation *> &loadOps) {
MLIRContext *ctx = rewriter.getContext();
Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
- Operation *loadOp = nvgpu::TmaAsyncLoadOp::create(
- rewriter, loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero},
- zero, Value(), Value());
+ Operation *loadOp =
+ TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc,
+ ValueRange{zero, zero}, zero, Value(), Value());
loadOps.push_back(loadOp);
auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
SmallVector<AffineExpr> symbols(mixedSizes.size());
@@ -997,9 +991,8 @@ OpFoldResult HopperBuilder::buildTmaAsyncLoad(
return res;
}
-void HopperBuilder::buildBarrierArriveTx(
- TypedValue<nvgpu::MBarrierGroupType> barrier,
- ArrayRef<OpFoldResult> mixedSizes) {
+void HopperBuilder::buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
+ ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
MLIRContext *ctx = rewriter.getContext();
SmallVector<AffineExpr> symbols(mixedSizes.size());
@@ -1013,8 +1006,7 @@ void HopperBuilder::buildBarrierArriveTx(
Value());
}
-void HopperBuilder::buildTryWaitParity(
- TypedValue<nvgpu::MBarrierGroupType> barrier) {
+void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) {
Type i1 = rewriter.getI1Type();
Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
// 10M is an arbitrary, not too small or too big number to specify the number
@@ -1058,11 +1050,11 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
launchOp.getBlockSizeZ()});
- TypedValue<nvgpu::MBarrierGroupType> barrier =
+ TypedValue<MBarrierGroupType> barrier =
buildAndInitBarrierInSharedMemory(numThreads);
SmallVector<TypedValue<MemRefType>> shmems;
- SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
+ SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs;
for (Operation *op : copyOps) {
auto copyOp = cast<linalg::CopyOp>(op);
auto inMemRef =
@@ -1071,7 +1063,7 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
"expected in to be a 2D memref");
// 2. Build global memory descriptor.
- TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
+ TypedValue<TensorMapDescriptorType> globalDesc =
buildGlobalMemRefDescriptor(inMemRef, launchOp);
globalDescs.push_back(globalDesc);
@@ -1098,9 +1090,8 @@ SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
}
DiagnosedSilenceableFailure
-transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
- transform::TransformResults &results,
- transform::TransformState &state) {
+RewriteCopyAsTmaOp::apply(TransformRewriter &rewriter,
+ TransformResults &results, TransformState &state) {
auto payloadOps = state.getPayloadOps(getTarget());
gpu::LaunchOp commonLaunchOp;
Operation *firstOp, *failingOp;
@@ -1137,15 +1128,14 @@ transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
namespace {
class NVGPUTransformDialectExtension
- : public transform::TransformDialectExtension<
- NVGPUTransformDialectExtension> {
+ : public TransformDialectExtension<NVGPUTransformDialectExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
NVGPUTransformDialectExtension() {
declareGeneratedDialect<arith::ArithDialect>();
declareGeneratedDialect<affine::AffineDialect>();
- declareGeneratedDialect<nvgpu::NVGPUDialect>();
+ declareGeneratedDialect<NVGPUDialect>();
declareGeneratedDialect<NVVM::NVVMDialect>();
declareGeneratedDialect<vector::VectorDialect>();
registerTransformOps<
diff --git a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
index 5b89c87..7f626a6 100644
--- a/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
+++ b/mlir/lib/Dialect/NVGPU/Transforms/MmaSyncTF32Transform.cpp
@@ -64,6 +64,5 @@ private:
void mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
RewritePatternSet &patterns, nvgpu::MmaSyncF32Lowering precision) {
-
patterns.add<MmaSyncF32ToTF32Pattern>(patterns.getContext(), precision);
}
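
As a usage sketch (not taken from this patch), a pass could drive these patterns through the greedy rewrite driver roughly as follows; the helper name is made up, the choice of MmaSyncF32Lowering::TF32 is illustrative, and the applyPatternsGreedily entry point is assumed to be the current one:

    #include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Hypothetical helper: rewrite nvgpu.mma.sync f32 ops to TF32 within `funcOp`.
    static mlir::LogicalResult lowerMmaSyncToTF32(mlir::Operation *funcOp) {
      mlir::RewritePatternSet patterns(funcOp->getContext());
      mlir::nvgpu::populateMmaSyncF32ToTF32Patterns(
          patterns, mlir::nvgpu::MmaSyncF32Lowering::TF32);
      return mlir::applyPatternsGreedily(funcOp, std::move(patterns));
    }
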
diff --git a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
index 809d634..9e5ea93 100644
--- a/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
+++ b/mlir/lib/Dialect/NVGPU/Utils/MMAUtils.cpp
@@ -168,8 +168,7 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
const WarpMatrixInfo &fragmentType) {
Type elementType = fragmentType.vectorType.getElementType();
ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape();
- FailureOr<nvgpu::FragmentElementInfo> regInfo =
- getMmaSyncRegisterType(fragmentType);
+ FailureOr<FragmentElementInfo> regInfo = getMmaSyncRegisterType(fragmentType);
if (failed(regInfo))
return failure();
@@ -199,8 +198,8 @@ nvgpu::getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc,
(logicalValueIdDim % elementsPerRegister)});
}
-FailureOr<nvgpu::LdMatrixParams>
-nvgpu::getLdMatrixParams(const WarpMatrixInfo &type, bool transpose) {
+FailureOr<LdMatrixParams> nvgpu::getLdMatrixParams(const WarpMatrixInfo &type,
+ bool transpose) {
LdMatrixParams params;
Type elType = type.vectorType.getElementType();
params.fragmentType = type.vectorType;
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 6598ac1..6564a4e 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -7,6 +7,7 @@
// =============================================================================
#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -44,6 +45,7 @@ struct MemRefPointerLikeModel
Type getElementType(Type pointer) const {
return cast<MemRefType>(pointer).getElementType();
}
+
mlir::acc::VariableTypeCategory
getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
Type varType) const {
@@ -70,6 +72,115 @@ struct MemRefPointerLikeModel
assert(memrefTy.getRank() > 0 && "rank expected to be positive");
return mlir::acc::VariableTypeCategory::array;
}
+
+ mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
+ StringRef varName, Type varType,
+ Value originalVar) const {
+ auto memrefTy = cast<MemRefType>(pointer);
+
+ // Check if this is a static memref (all dimensions are known) - if yes
+ // then we can generate an alloca operation.
+ if (memrefTy.hasStaticShape())
+ return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
+
+ // For dynamic memrefs, extract sizes from the original variable if
+ // provided. Otherwise they cannot be handled.
+ if (originalVar && originalVar.getType() == memrefTy &&
+ memrefTy.hasRank()) {
+ SmallVector<Value> dynamicSizes;
+ for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
+ if (memrefTy.isDynamicDim(i)) {
+ // Extract the size of dimension i from the original variable
+ auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
+ auto dimSize =
+ memref::DimOp::create(builder, loc, originalVar, indexValue);
+ dynamicSizes.push_back(dimSize);
+ }
+        // Note: we only add dynamic sizes to the dynamicSizes array;
+        // static dimensions are handled automatically by AllocOp.
+ }
+ return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
+ .getResult();
+ }
+
+ // TODO: Unranked not yet supported.
+ return {};
+ }
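
genAllocate above only passes operands for the dynamically sized dimensions; static extents stay encoded in the memref type. A minimal standalone sketch of that selection, with a hypothetical ?x4x? shape and a local stand-in for ShapedType::kDynamic:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      const int64_t kDynamic = INT64_MIN;                   // stand-in for ShapedType::kDynamic
      std::vector<int64_t> shape = {kDynamic, 4, kDynamic}; // hypothetical memref<?x4x?xf32>
      std::vector<int64_t> dynamicDims;                     // dims needing a memref.dim query
      for (int64_t i = 0, e = shape.size(); i < e; ++i)
        if (shape[i] == kDynamic)
          dynamicDims.push_back(i);
      // Prints: memref.alloc needs 2 dynamic size operands (dims 0 and 2).
      std::printf("memref.alloc needs %zu dynamic size operands\n",
                  dynamicDims.size());
      return 0;
    }
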
+
+ bool genFree(Type pointer, OpBuilder &builder, Location loc,
+ TypedValue<PointerLikeType> varPtr, Type varType) const {
+ if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varPtr)) {
+ // Walk through casts to find the original allocation
+ Value currentValue = memrefValue;
+ Operation *originalAlloc = nullptr;
+
+ // Follow the chain of operations to find the original allocation
+ // even if a casted result is provided.
+ while (currentValue) {
+ if (auto *definingOp = currentValue.getDefiningOp()) {
+ // Check if this is an allocation operation
+ if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
+ originalAlloc = definingOp;
+ break;
+ }
+
+ // Check if this is a cast operation we can look through
+ if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
+ currentValue = castOp.getSource();
+ continue;
+ }
+
+ // Check for other cast-like operations
+ if (auto reinterpretCastOp =
+ dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
+ currentValue = reinterpretCastOp.getSource();
+ continue;
+ }
+
+ // If we can't look through this operation, stop
+ break;
+ }
+ // This is a block argument or similar - can't trace further.
+ break;
+ }
+
+ if (originalAlloc) {
+ if (isa<memref::AllocaOp>(originalAlloc)) {
+ // This is an alloca - no dealloc needed, but return true (success)
+ return true;
+ }
+ if (isa<memref::AllocOp>(originalAlloc)) {
+ // This is an alloc - generate dealloc
+ memref::DeallocOp::create(builder, loc, memrefValue);
+ return true;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ bool genCopy(Type pointer, OpBuilder &builder, Location loc,
+ TypedValue<PointerLikeType> destination,
+ TypedValue<PointerLikeType> source, Type varType) const {
+ // Generate a copy operation between two memrefs
+ auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
+ auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
+
+    // As per the memref documentation, source and destination must have the
+    // same element type and shape in order to be compatible. We do not want
+    // to fail with an IR verification error, so check that before generating
+    // the copy operation.
+ if (destMemref && srcMemref &&
+ destMemref.getType().getElementType() ==
+ srcMemref.getType().getElementType() &&
+ destMemref.getType().getShape() == srcMemref.getType().getShape()) {
+ memref::CopyOp::create(builder, loc, srcMemref, destMemref);
+ return true;
+ }
+
+ return false;
+ }
};
struct LLVMPointerPointerLikeModel
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index b0132e8..14e235f 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -47,6 +47,7 @@
#include <cassert>
#include <cstdint>
+#include <numeric>
#include "mlir/Dialect/Vector/IR/VectorDialect.cpp.inc"
// Pull in all enum type and utility function definitions.
@@ -2412,9 +2413,38 @@ foldToElementsFromElements(ToElementsOp toElementsOp,
return success();
}
+/// Folds vector.to_elements(vector.broadcast(%x)) for the scalar case only.
+///
+/// Example:
+/// %b = vector.broadcast %x : f32 to vector<3xf32>
+/// %e:3 = vector.to_elements %b : vector<3xf32>
+/// user_op %e#0, %e#1, %e#2
+/// becomes:
+/// user_op %x, %x, %x
+///
+/// The vector source case is handled by a canonicalization pattern.
+static LogicalResult
+foldToElementsOfBroadcast(ToElementsOp toElementsOp,
+ SmallVectorImpl<OpFoldResult> &results) {
+ auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+ if (!bcastOp)
+ return failure();
+ // Vectors are handled in the ToElementsOfBroadcast RewritePattern.
+ if (isa<VectorType>(bcastOp.getSource().getType()))
+ return failure();
+
+ auto resultVecType = cast<VectorType>(toElementsOp.getSource().getType());
+
+ Value scalar = bcastOp.getSource();
+ results.assign(resultVecType.getNumElements(), scalar);
+ return success();
+}
+
LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
SmallVectorImpl<OpFoldResult> &results) {
- return foldToElementsFromElements(*this, results);
+ if (succeeded(foldToElementsFromElements(*this, results)))
+ return success();
+ return foldToElementsOfBroadcast(*this, results);
}
LogicalResult
@@ -2427,6 +2457,94 @@ ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
return success();
}
+/// Canonicalize `vector.to_elements(vector.broadcast(%v))` where `%v` is a
+/// vector.
+/// - Build `vector.to_elements %v` and remap each destination element to the
+/// corresponding source element using broadcast rules (match or 1 →
+/// replicate).
+///
+/// Example:
+/// %v = vector.broadcast %src : vector<2xf32> to vector<3x2xf32>
+/// %e:6 = vector.to_elements %v : vector<3x2xf32>
+/// becomes:
+/// %src_elems:2 = vector.to_elements %src : vector<2xf32>
+/// // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+/// // %src_elems#1, %src_elems#0, %src_elems#1
+struct ToElementsOfBroadcast final : OpRewritePattern<ToElementsOp> {
+ using Base::Base;
+
+ LogicalResult matchAndRewrite(ToElementsOp toElementsOp,
+ PatternRewriter &rewriter) const override {
+ auto bcastOp = toElementsOp.getSource().getDefiningOp<BroadcastOp>();
+ if (!bcastOp)
+ return failure();
+
+ // Only handle broadcasts from a vector source here.
+ auto srcType = dyn_cast<VectorType>(bcastOp.getSource().getType());
+ if (!srcType)
+ return failure();
+
+ auto dstType = cast<VectorType>(toElementsOp.getSource().getType());
+
+ ArrayRef<int64_t> dstShape = dstType.getShape();
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+
+ int64_t dstRank = dstShape.size();
+ int64_t srcRank = srcShape.size();
+
+ // Create elements for the broadcast source vector.
+ auto srcElems = vector::ToElementsOp::create(
+ rewriter, toElementsOp.getLoc(), bcastOp.getSource());
+
+    int64_t dstCount = std::accumulate(
+        dstShape.begin(), dstShape.end(), int64_t(1), std::multiplies<>());
+
+ SmallVector<Value> replacements;
+ replacements.reserve(dstCount);
+
+ // For each element of the destination, determine which element of the
+ // source should be used. We walk all destination positions using a single
+ // counter, decode it into per-dimension indices, then build the matching
+ // source position: use the same index where sizes match, and use 0 where
+ // the source size is 1 (replication). This mapping is needed so we can
+ // replace each result of to_elements with the corresponding element from
+ // the broadcast source.
+ // Inner-dimension stretch example:
+ // %v = vector.broadcast %src : vector<2x1x2xf32> to vector<2x3x2xf32>
+ // %e:12 = vector.to_elements %v : vector<2x3x2xf32>
+ // becomes:
+ // %src_elems:4 = vector.to_elements %src : vector<2x1x2xf32>
+ // // uses: %src_elems#0, %src_elems#1, %src_elems#0,
+ // // %src_elems#1, %src_elems#0, %src_elems#1,
+ // // %src_elems#2, %src_elems#3, %src_elems#2,
+ // // %src_elems#3, %src_elems#2, %src_elems#3
+
+ // Row-major strides for the destination shape.
+ SmallVector<int64_t> dstStrides = computeStrides(dstShape);
+ // Row-major strides for the source shape.
+ SmallVector<int64_t> srcStrides = computeStrides(srcShape);
+ SmallVector<int64_t> dstIdx(dstRank);
+ SmallVector<int64_t> srcIdx(srcRank);
+ for (int64_t lin = 0; lin < dstCount; ++lin) {
+ // Convert linear destination index to per-dimension indices.
+ dstIdx = delinearize(lin, dstStrides);
+ for (int64_t k = 0; k < srcRank; ++k)
+ srcIdx[k] = (srcShape[k] == 1) ? 0 : dstIdx[dstRank - srcRank + k];
+ // Convert per-dimension source indices back to a linear index.
+ int64_t srcLin = linearize(srcIdx, srcStrides);
+ replacements.push_back(srcElems.getResult(srcLin));
+ }
+
+ rewriter.replaceOp(toElementsOp, replacements);
+ return success();
+ }
+};
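
A standalone sketch of the destination-to-source index mapping used above, replayed for the 2x1x2 -> 2x3x2 example from the comment. It reimplements the row-major stride, delinearize, and linearize steps locally so it runs without MLIR, and it assumes equal source and destination rank (the pattern itself also handles a lower-rank source by dropping leading destination dims):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    static std::vector<int64_t> rowMajorStrides(const std::vector<int64_t> &shape) {
      std::vector<int64_t> strides(shape.size(), 1);
      for (int64_t i = (int64_t)shape.size() - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * shape[i + 1];
      return strides;
    }

    int main() {
      std::vector<int64_t> srcShape = {2, 1, 2}, dstShape = {2, 3, 2};
      std::vector<int64_t> srcStrides = rowMajorStrides(srcShape);
      std::vector<int64_t> dstStrides = rowMajorStrides(dstShape);
      int64_t dstCount = 1;
      for (int64_t d : dstShape)
        dstCount *= d;
      for (int64_t lin = 0; lin < dstCount; ++lin) {
        // Delinearize the destination index, then map each dim to the source:
        // same index where the sizes match, 0 where the source size is 1.
        int64_t srcLin = 0, rem = lin;
        for (size_t k = 0; k < dstShape.size(); ++k) {
          int64_t idx = rem / dstStrides[k];
          rem %= dstStrides[k];
          srcLin += (srcShape[k] == 1 ? 0 : idx) * srcStrides[k];
        }
        // Prints the 0,1,0,1,0,1,2,3,2,3,2,3 mapping from the comment above.
        std::printf("dst element %lld <- src element %lld\n", (long long)lin,
                    (long long)srcLin);
      }
      return 0;
    }
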
+
+void ToElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<ToElementsOfBroadcast>(context);
+}
+
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index a5bd80e..5fe5f41 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -201,6 +201,8 @@ struct CppEmitter {
/// Return the existing or a new label of a Block.
StringRef getOrCreateName(Block &block);
+ LogicalResult emitInlinedExpression(Value value);
+
/// Whether to map an mlir integer to a unsigned integer in C++.
bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);
@@ -557,6 +559,30 @@ static LogicalResult printOperation(CppEmitter &emitter,
return success();
}
+static LogicalResult printOperation(CppEmitter &emitter, emitc::DoOp doOp) {
+ raw_indented_ostream &os = emitter.ostream();
+
+ os << "do {\n";
+ os.indent();
+
+ Block &bodyBlock = doOp.getBodyRegion().front();
+ for (Operation &op : bodyBlock) {
+ if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/true)))
+ return failure();
+ }
+
+ os.unindent() << "} while (";
+
+ Block &condBlock = doOp.getConditionRegion().front();
+ auto condYield = cast<emitc::YieldOp>(condBlock.back());
+ if (failed(emitter.emitExpression(
+ cast<emitc::ExpressionOp>(condYield.getOperand(0).getDefiningOp()))))
+ return failure();
+
+ os << ");";
+ return success();
+}
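
The printer above emits the body statements first, then the condition expression inside the while clause. The emitted C++ should have roughly this shape (wrapped in main() so the sketch compiles; the identifiers are placeholders, not taken from a test):

    int main() {
      int i = 0, n = 3;
      do {              // body region, one statement per operation
        i = i + 1;
      } while (i < n);  // condition region's emitc.expression
      return i;         // 3
    }
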
+
static LogicalResult printOperation(CppEmitter &emitter, emitc::CmpOp cmpOp) {
Operation *operation = cmpOp.getOperation();
@@ -1711,13 +1737,14 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::BitwiseRightShiftOp, emitc::BitwiseXorOp, emitc::CallOp,
emitc::CallOpaqueOp, emitc::CastOp, emitc::ClassOp,
emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp,
- emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp,
- emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp,
- emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp,
- emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
- emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
- emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
- emitc::VariableOp, emitc::VerbatimOp>(
+ emitc::DeclareFuncOp, emitc::DivOp, emitc::DoOp,
+ emitc::ExpressionOp, emitc::FieldOp, emitc::FileOp,
+ emitc::ForOp, emitc::FuncOp, emitc::GlobalOp, emitc::IfOp,
+ emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp,
+ emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
+ emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp,
+ emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
+ emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
@@ -1765,9 +1792,9 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
  // Never emit a semicolon for some operations, especially if ending with
// `}`.
trailingSemicolon &=
- !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::FileOp, emitc::ForOp,
- emitc::IfOp, emitc::IncludeOp, emitc::SwitchOp, emitc::VerbatimOp>(
- op);
+ !isa<cf::CondBranchOp, emitc::DeclareFuncOp, emitc::DoOp, emitc::FileOp,
+ emitc::ForOp, emitc::IfOp, emitc::IncludeOp, emitc::SwitchOp,
+ emitc::VerbatimOp>(op);
os << (trailingSemicolon ? ";\n" : "\n");
diff --git a/mlir/lib/Target/LLVMIR/DebugImporter.cpp b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
index 8b03265..4bbcd8e 100644
--- a/mlir/lib/Target/LLVMIR/DebugImporter.cpp
+++ b/mlir/lib/Target/LLVMIR/DebugImporter.cpp
@@ -59,7 +59,8 @@ DICompileUnitAttr DebugImporter::translateImpl(llvm::DICompileUnit *node) {
std::underlying_type_t<llvm::DICompileUnit::DebugNameTableKind>>(
node->getNameTableKind()));
return DICompileUnitAttr::get(
- context, getOrCreateDistinctID(node), node->getSourceLanguage(),
+ context, getOrCreateDistinctID(node),
+ node->getSourceLanguage().getUnversionedName(),
translate(node->getFile()), getStringAttrOrNull(node->getRawProducer()),
node->isOptimized(), emissionKind.value(), nameTableKind.value(),
getStringAttrOrNull(node->getRawSplitDebugFilename()));
diff --git a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
index 63c71cd..1e226c0 100644
--- a/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
+++ b/mlir/lib/Transforms/GenerateRuntimeVerification.cpp
@@ -51,9 +51,11 @@ public:
stream << "ERROR: Runtime op verification failed\n";
if (vLevel == 1) {
op->print(stream, state);
- stream << "\n";
+ stream << "\n^ " << msg;
+ } else {
+ stream << "^ " << msg;
}
- stream << "^\nLocation: ";
+ stream << "\nLocation: ";
op->getLoc().print(stream);
return buffer;
}
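
With this change, the buffer built above at vLevel == 1 is laid out as: the fixed header, the printed op, "^ " plus the failure message on the next line, and the location last. A sketch that reconstructs that layout with made-up op text, message, and location:

    #include <cstdio>
    #include <string>

    int main() {
      std::string buffer;
      buffer += "ERROR: Runtime op verification failed\n";
      // op->print(stream, state) -- hypothetical op text:
      buffer += "\"memref.load\"(%alloc, %c5) : (memref<4xf32>, index) -> f32";
      buffer += "\n^ out-of-bounds access";               // "\n^ " << msg (msg is made up)
      buffer += "\nLocation: loc(\"example.mlir\":3:10)"; // op->getLoc().print(stream)
      std::printf("%s\n", buffer.c_str());
      return 0;
    }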