Diffstat (limited to 'mlir/lib')
-rw-r--r--  mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp | 29
-rw-r--r--  mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 46
-rw-r--r--  mlir/lib/Analysis/FlatLinearValueConstraints.cpp | 2
-rw-r--r--  mlir/lib/Analysis/Presburger/IntegerRelation.cpp | 29
-rw-r--r--  mlir/lib/Analysis/Presburger/Simplex.cpp | 2
-rw-r--r--  mlir/lib/Bindings/Python/ExecutionEngineModule.cpp | 13
-rw-r--r--  mlir/lib/Bindings/Python/Globals.h | 39
-rw-r--r--  mlir/lib/Bindings/Python/IRCore.cpp | 196
-rw-r--r--  mlir/lib/Bindings/Python/IRModule.cpp | 70
-rw-r--r--  mlir/lib/Bindings/Python/IRModule.h | 15
-rw-r--r--  mlir/lib/Bindings/Python/MainModule.cpp | 23
-rw-r--r--  mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 9
-rw-r--r--  mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp | 37
-rw-r--r--  mlir/lib/Conversion/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp | 17
-rw-r--r--  mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp | 41
-rw-r--r--  mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp | 26
-rw-r--r--  mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 139
-rw-r--r--  mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp | 48
-rw-r--r--  mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp | 62
-rw-r--r--  mlir/lib/Conversion/LLVMCommon/Pattern.cpp | 121
-rw-r--r--  mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp | 115
-rw-r--r--  mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 116
-rw-r--r--  mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp | 56
-rw-r--r--  mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 15
-rw-r--r--  mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 28
-rw-r--r--  mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 2
-rw-r--r--  mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp | 7
-rw-r--r--  mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 4
-rw-r--r--  mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp | 10
-rw-r--r--  mlir/lib/Conversion/VectorToAMX/CMakeLists.txt | 19
-rw-r--r--  mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp | 283
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 14
-rw-r--r--  mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp | 1
-rw-r--r--  mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp | 313
-rw-r--r--  mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp | 4
-rw-r--r--  mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 5
-rw-r--r--  mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 10
-rw-r--r--  mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp | 4
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp | 3
-rw-r--r--  mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp | 1
-rw-r--r--  mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp | 32
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp | 56
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 58
-rw-r--r--  mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 76
-rw-r--r--  mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 504
-rw-r--r--  mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp | 24
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 22
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp | 18
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp | 11
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp | 8
-rw-r--r--  mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 9
-rw-r--r--  mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 1
-rw-r--r--  mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp | 6
-rw-r--r--  mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 29
-rw-r--r--  mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 101
-rw-r--r--  mlir/lib/Dialect/SCF/IR/SCF.cpp | 49
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp | 3
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp | 9
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp | 5
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp | 4
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp | 3
-rw-r--r--  mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 6
-rw-r--r--  mlir/lib/Dialect/SCF/Utils/Utils.cpp | 8
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp | 11
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp | 2
-rw-r--r--  mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp | 10
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Utils/StaticValueUtils.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 48
-rw-r--r--  mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp | 5
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp | 65
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp | 33
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 3
-rw-r--r--  mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 2
-rw-r--r--  mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 26
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt | 5
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp | 322
-rw-r--r--  mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 252
-rw-r--r--  mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp | 409
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt | 2
-rw-r--r--  mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp | 20
-rw-r--r--  mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 20
-rw-r--r--  mlir/lib/ExecutionEngine/JitRunner.cpp | 2
-rw-r--r--  mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp | 3
-rw-r--r--  mlir/lib/RegisterAllDialects.cpp | 2
-rw-r--r--  mlir/lib/RegisterAllExtensions.cpp | 1
-rw-r--r--  mlir/lib/Target/CMakeLists.txt | 1
-rw-r--r--  mlir/lib/Target/Cpp/TranslateToCpp.cpp | 4
-rw-r--r--  mlir/lib/Target/LLVM/CMakeLists.txt | 24
-rw-r--r--  mlir/lib/Target/LLVM/XeVM/Target.cpp | 418
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 90
-rw-r--r--  mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 91
-rw-r--r--  mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 14
-rw-r--r--  mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 53
-rw-r--r--  mlir/lib/Target/SPIRV/Serialization/Serializer.h | 2
-rw-r--r--  mlir/lib/Target/SPIRV/TranslateRegistration.cpp | 59
-rw-r--r--  mlir/lib/Target/Wasm/CMakeLists.txt | 13
-rw-r--r--  mlir/lib/Target/Wasm/TranslateFromWasm.cpp | 1245
-rw-r--r--  mlir/lib/Target/Wasm/TranslateRegistration.cpp | 28
-rw-r--r--  mlir/lib/Transforms/RemoveDeadValues.cpp | 46
-rw-r--r--  mlir/lib/Transforms/Utils/DialectConversion.cpp | 401
-rw-r--r--  mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 42
-rw-r--r--  mlir/lib/Transforms/Utils/InliningUtils.cpp | 6
-rw-r--r--  mlir/lib/Transforms/Utils/RegionUtils.cpp | 33
-rw-r--r--  mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp | 126
111 files changed, 6087 insertions, 892 deletions
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 509f520..65df355 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -294,7 +294,34 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
solver.load<LivenessAnalysis>(symbolTable);
LDBG() << "Initializing and running solver";
(void)solver.initializeAndRun(op);
- LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName();
+ LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName()
+         << "; now checking for unreachable code:";
+ // The framework doesn't visit operations in dead blocks, so we need to
+ // explicitly mark them as dead.
+ op->walk([&](Operation *op) {
+ if (op->getNumResults() == 0)
+ return;
+ for (auto result : llvm::enumerate(op->getResults())) {
+ if (getLiveness(result.value()))
+ continue;
+ LDBG() << "Result: " << result.index() << " of "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " has no liveness info (unreachable), mark dead";
+ solver.getOrCreateState<Liveness>(result.value());
+ }
+ for (auto &region : op->getRegions()) {
+ for (auto &block : region) {
+ for (auto blockArg : llvm::enumerate(block.getArguments())) {
+ if (getLiveness(blockArg.value()))
+ continue;
+ LDBG() << "Block argument: " << blockArg.index() << " of "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions())
+ << " has no liveness info, mark dead";
+ solver.getOrCreateState<Liveness>(blockArg.value());
+ }
+ }
+ }
+ });
}
const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
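The effect of the extra walk is that getLiveness() now returns a (default, dead) state rather than nullptr for values defined in unreachable code. A minimal client sketch of that contract, not part of the patch; the helper name collectDeadResults is made up and the check relies on the Liveness lattice's public isLive flag:

#include "mlir/Analysis/DataFlow/LivenessAnalysis.h"

using namespace mlir;
using namespace mlir::dataflow;

// Collect op results that the analysis reports as not live. After the walk
// added above, values in unreachable blocks also carry a (dead) state, so the
// null check no longer silently skips them.
static SmallVector<Value> collectDeadResults(Operation *root) {
  RunLivenessAnalysis analysis(root);
  SmallVector<Value> dead;
  root->walk([&](Operation *op) {
    for (Value result : op->getResults()) {
      const Liveness *state = analysis.getLiveness(result);
      if (state && !state->isLive)
        dead.push_back(result);
    }
  });
  return dead;
}
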
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index e625f62..13a3e14 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -19,12 +19,15 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#include <optional>
using namespace mlir;
using namespace mlir::dataflow;
+#define DEBUG_TYPE "dataflow"
+
//===----------------------------------------------------------------------===//
// AbstractSparseLattice
//===----------------------------------------------------------------------===//
@@ -64,22 +67,36 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
LogicalResult
AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
+ LDBG() << "Initializing recursively for operation: " << op->getName();
+
// Initialize the analysis by visiting every owner of an SSA value (all
// operations and blocks).
- if (failed(visitOperation(op)))
+ if (failed(visitOperation(op))) {
+ LDBG() << "Failed to visit operation: " << op->getName();
return failure();
+ }
for (Region &region : op->getRegions()) {
+ LDBG() << "Processing region with " << region.getBlocks().size()
+ << " blocks";
for (Block &block : region) {
+ LDBG() << "Processing block with " << block.getNumArguments()
+ << " arguments";
getOrCreate<Executable>(getProgramPointBefore(&block))
->blockContentSubscribe(this);
visitBlock(&block);
- for (Operation &op : block)
- if (failed(initializeRecursively(&op)))
+ for (Operation &op : block) {
+ LDBG() << "Recursively initializing nested operation: " << op.getName();
+ if (failed(initializeRecursively(&op))) {
+ LDBG() << "Failed to initialize nested operation: " << op.getName();
return failure();
+ }
+ }
}
}
+ LDBG() << "Successfully completed recursive initialization for operation: "
+ << op->getName();
return success();
}
@@ -409,11 +426,20 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
LogicalResult
AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
+ LDBG() << "Visiting operation: " << op->getName() << " with "
+ << op->getNumOperands() << " operands and " << op->getNumResults()
+ << " results";
+
// If we're in a dead block, bail out.
if (op->getBlock() != nullptr &&
- !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
+ !getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
+ ->isLive()) {
+ LDBG() << "Operation is in dead block, bailing out";
return success();
+ }
+ LDBG() << "Creating lattice elements for " << op->getNumOperands()
+ << " operands and " << op->getNumResults() << " results";
SmallVector<AbstractSparseLattice *> operandLattices =
getLatticeElements(op->getOperands());
SmallVector<const AbstractSparseLattice *> resultLattices =
@@ -422,11 +448,15 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// Block arguments of region branch operations flow back into the operands
// of the parent op
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
+ LDBG() << "Processing RegionBranchOpInterface operation";
visitRegionSuccessors(branch, operandLattices);
return success();
}
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
+ LDBG() << "Processing BranchOpInterface operation with "
+ << op->getNumSuccessors() << " successors";
+
// Block arguments of successor blocks flow back into our operands.
// We remember all operands not forwarded to any block in a BitVector.
@@ -463,6 +493,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// For function calls, connect the arguments of the entry blocks to the
// operands of the call op that are forwarded to these arguments.
if (auto call = dyn_cast<CallOpInterface>(op)) {
+ LDBG() << "Processing CallOpInterface operation";
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
// Not all operands of a call op forward to arguments. Such operands are
@@ -513,6 +544,7 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
// of this op itself and the operands of the terminators of the regions of
// this op.
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
+ LDBG() << "Processing RegionBranchTerminatorOpInterface operation";
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
visitRegionSuccessorsFromTerminator(terminator, branch);
return success();
@@ -520,12 +552,16 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
}
if (op->hasTrait<OpTrait::ReturnLike>()) {
+ LDBG() << "Processing ReturnLike operation";
// Going backwards, the operands of the return are derived from the
// results of all CallOps calling this CallableOp.
- if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
+ if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
+ LDBG() << "Callable parent found, visiting callable operation";
return visitCallableOperation(op, callable, operandLattices);
+ }
}
+ LDBG() << "Using default visitOperationImpl for operation: " << op->getName();
return visitOperationImpl(op, operandLattices, resultLattices);
}
diff --git a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
index f4b02b4..30ce1fb 100644
--- a/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
+++ b/mlir/lib/Analysis/FlatLinearValueConstraints.cpp
@@ -60,7 +60,7 @@ private:
AffineExpr localExpr) override {
SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
// Update localVarCst.
- localVarCst.addLocalFloorDiv(dividend, divisor);
+ (void)localVarCst.addLocalFloorDiv(dividend, divisor);
}
LogicalResult addLocalIdSemiAffine(ArrayRef<int64_t> lhs,
diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
index 5c4d4d1..0dcdd5b 100644
--- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
+++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp
@@ -1500,12 +1500,13 @@ void IntegerRelation::addBound(BoundType type, ArrayRef<DynamicAPInt> expr,
/// respect to a positive constant 'divisor'. Two constraints are added to the
/// system to capture equivalence with the floordiv.
/// q = expr floordiv c <=> c*q <= expr <= c*q + c - 1.
-void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
- const DynamicAPInt &divisor) {
+/// Returns the column position of the new local variable.
+unsigned IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
+ const DynamicAPInt &divisor) {
assert(dividend.size() == getNumCols() && "incorrect dividend size");
assert(divisor > 0 && "positive divisor expected");
- appendVar(VarKind::Local);
+ unsigned newVar = appendVar(VarKind::Local);
SmallVector<DynamicAPInt, 8> dividendCopy(dividend);
dividendCopy.insert(dividendCopy.end() - 1, DynamicAPInt(0));
@@ -1513,6 +1514,28 @@ void IntegerRelation::addLocalFloorDiv(ArrayRef<DynamicAPInt> dividend,
getDivLowerBound(dividendCopy, divisor, dividendCopy.size() - 2));
addInequality(
getDivUpperBound(dividendCopy, divisor, dividendCopy.size() - 2));
+ return newVar;
+}
+
+unsigned IntegerRelation::addLocalModulo(ArrayRef<DynamicAPInt> exprs,
+ const DynamicAPInt &modulus) {
+ assert(exprs.size() == getNumCols() && "incorrect exprs size");
+ assert(modulus > 0 && "positive modulus expected");
+
+ /// Add a local variable for q = expr floordiv modulus
+ addLocalFloorDiv(exprs, modulus);
+
+ /// Add a local var to represent the result
+ auto resultIndex = appendVar(VarKind::Local);
+
+ SmallVector<DynamicAPInt, 8> exprsCopy(exprs);
+ /// Insert the two new locals before the constant
+ /// Add locals that correspond to `q` and `result` to compute
+ /// 0 = (expr - modulus * q) - result
+ exprsCopy.insert(exprsCopy.end() - 1,
+ {DynamicAPInt(-modulus), DynamicAPInt(-1)});
+ addEquality(exprsCopy);
+ return resultIndex;
}
int IntegerRelation::findEqualityToConstant(unsigned pos, bool symbolic) const {
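For concreteness, here is what the new addLocalModulo encodes for a one-variable set with exprs = x and modulus = 3, written in the same notation as the comments above (an illustration, not part of the patch):

  q = x floordiv 3    via addLocalFloorDiv:  3*q <= x <= 3*q + 2
  x - 3*q - r = 0     via addEquality, defining the second local r

so r = x mod 3, and 0 <= r <= 2 follows from the floordiv bounds. The returned resultIndex is the column position of r, letting callers refer to the modulo result directly.
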
diff --git a/mlir/lib/Analysis/Presburger/Simplex.cpp b/mlir/lib/Analysis/Presburger/Simplex.cpp
index 08290db..51e2007 100644
--- a/mlir/lib/Analysis/Presburger/Simplex.cpp
+++ b/mlir/lib/Analysis/Presburger/Simplex.cpp
@@ -433,7 +433,7 @@ LogicalResult SymbolicLexSimplex::addSymbolicCut(unsigned row) {
normalizeDiv(divCoeffs, divDenom);
domainSimplex.addDivisionVariable(divCoeffs, divDenom);
- domainPoly.addLocalFloorDiv(divCoeffs, divDenom);
+ (void)domainPoly.addLocalFloorDiv(divCoeffs, divDenom);
// Update `this` to account for the additional symbol we just added.
appendSymbol();
diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
index 81dada3..4885d62c 100644
--- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
+++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp
@@ -7,8 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir-c/ExecutionEngine.h"
-#include "mlir/Bindings/Python/NanobindAdaptors.h"
#include "mlir/Bindings/Python/Nanobind.h"
+#include "mlir/Bindings/Python/NanobindAdaptors.h"
namespace nb = nanobind;
using namespace mlir;
@@ -125,6 +125,17 @@ NB_MODULE(_mlirExecutionEngine, m) {
nb::arg("name"), nb::arg("callback"),
"Register `callback` as the runtime symbol `name`.")
.def(
+ "initialize",
+ [](PyExecutionEngine &executionEngine) {
+ mlirExecutionEngineInitialize(executionEngine.get());
+ },
+ "Initialize the ExecutionEngine. Global constructors specified by "
+          "a kernel binary compiled from `gpu.module` is loaded during "
+ "kernel binary compiled from `gpu.module` gets loaded during "
+ "initialization. Make sure all symbols are resolvable before "
+ "initialization by calling `register_runtime` or including "
+ "shared libraries.")
+ .def(
"dump_to_object_file",
[](PyExecutionEngine &executionEngine, const std::string &fileName) {
mlirExecutionEngineDumpToObjectFile(
diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 826a34a..71a051c 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -10,15 +10,19 @@
#define MLIR_BINDINGS_PYTHON_GLOBALS_H
#include <optional>
+#include <regex>
#include <string>
+#include <unordered_set>
#include <vector>
#include "NanobindUtils.h"
#include "mlir-c/IR.h"
#include "mlir/CAPI/Support.h"
#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Regex.h"
namespace mlir {
namespace python {
@@ -114,6 +118,39 @@ public:
std::optional<nanobind::object>
lookupOperationClass(llvm::StringRef operationName);
+ class TracebackLoc {
+ public:
+ bool locTracebacksEnabled();
+
+ void setLocTracebacksEnabled(bool value);
+
+ size_t locTracebackFramesLimit();
+
+ void setLocTracebackFramesLimit(size_t value);
+
+ void registerTracebackFileInclusion(const std::string &file);
+
+ void registerTracebackFileExclusion(const std::string &file);
+
+ bool isUserTracebackFilename(llvm::StringRef file);
+
+ static constexpr size_t kMaxFrames = 512;
+
+ private:
+ nanobind::ft_mutex mutex;
+ bool locTracebackEnabled_ = false;
+ size_t locTracebackFramesLimit_ = 10;
+ std::unordered_set<std::string> userTracebackIncludeFiles;
+ std::unordered_set<std::string> userTracebackExcludeFiles;
+ std::regex userTracebackIncludeRegex;
+ bool rebuildUserTracebackIncludeRegex = false;
+ std::regex userTracebackExcludeRegex;
+ bool rebuildUserTracebackExcludeRegex = false;
+ llvm::StringMap<bool> isUserTracebackFilenameCache;
+ };
+
+ TracebackLoc &getTracebackLoc() { return tracebackLoc; }
+
private:
static PyGlobals *instance;
@@ -134,6 +171,8 @@ private:
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
+
+ TracebackLoc tracebackLoc;
};
} // namespace python
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5feed95..4b3a06c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -20,11 +20,8 @@
#include "nanobind/nanobind.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
-#include "llvm/Support/raw_ostream.h"
#include <optional>
-#include <system_error>
-#include <utility>
namespace nb = nanobind;
using namespace nb::literals;
@@ -1523,7 +1520,7 @@ nb::object PyOperation::create(std::string_view name,
llvm::ArrayRef<MlirValue> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- int regions, DefaultingPyLocation location,
+ int regions, PyLocation &location,
const nb::object &maybeIp, bool inferType) {
llvm::SmallVector<MlirType, 4> mlirResults;
llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1627,7 +1624,7 @@ nb::object PyOperation::create(std::string_view name,
if (!operation.ptr)
throw nb::value_error("Operation creation failed");
PyOperationRef created =
- PyOperation::createDetached(location->getContext(), operation);
+ PyOperation::createDetached(location.getContext(), operation);
maybeInsertOperation(created, maybeIp);
return created.getObject();
@@ -1937,9 +1934,9 @@ nb::object PyOpView::buildGeneric(
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions, PyLocation &location,
const nb::object &maybeIp) {
- PyMlirContextRef context = location->getContext();
+ PyMlirContextRef context = location.getContext();
// Class level operation construction metadata.
// Operand and result segment specs are either none, which does no
@@ -2789,6 +2786,156 @@ private:
PyOperationRef operation;
};
+// see
+// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h
+
+#ifndef _Py_CAST
+#define _Py_CAST(type, expr) ((type)(expr))
+#endif
+
+// Static inline functions should use _Py_NULL rather than using NULL directly
+// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer,
+// _Py_NULL is defined as nullptr.
+#ifndef _Py_NULL
+#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \
+ (defined(__cplusplus) && __cplusplus >= 201103)
+#define _Py_NULL nullptr
+#else
+#define _Py_NULL NULL
+#endif
+#endif
+
+// Python 3.10.0a3
+#if PY_VERSION_HEX < 0x030A00A3
+
+// bpo-42262 added Py_XNewRef()
+#if !defined(Py_XNewRef)
+[[maybe_unused]] PyObject *_Py_XNewRef(PyObject *obj) {
+ Py_XINCREF(obj);
+ return obj;
+}
+#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj))
+#endif
+
+// bpo-42262 added Py_NewRef()
+#if !defined(Py_NewRef)
+[[maybe_unused]] PyObject *_Py_NewRef(PyObject *obj) {
+ Py_INCREF(obj);
+ return obj;
+}
+#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj))
+#endif
+
+#endif // Python 3.10.0a3
+
+// Python 3.9.0b1
+#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION)
+
+// bpo-40429 added PyThreadState_GetFrame()
+PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) {
+ assert(tstate != _Py_NULL && "expected tstate != _Py_NULL");
+ return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame));
+}
+
+// bpo-40421 added PyFrame_GetBack()
+PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) {
+ assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+ return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back));
+}
+
+// bpo-40421 added PyFrame_GetCode()
+PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) {
+ assert(frame != _Py_NULL && "expected frame != _Py_NULL");
+ assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL");
+ return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code));
+}
+
+#endif // Python 3.9.0b1
+
+MlirLocation tracebackToLocation(MlirContext ctx) {
+ size_t framesLimit =
+ PyGlobals::get().getTracebackLoc().locTracebackFramesLimit();
+ // Use a thread_local here to avoid requiring a large amount of space.
+ thread_local std::array<MlirLocation, PyGlobals::TracebackLoc::kMaxFrames>
+ frames;
+ size_t count = 0;
+
+ nb::gil_scoped_acquire acquire;
+ PyThreadState *tstate = PyThreadState_GET();
+ PyFrameObject *next;
+ PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate);
+ // In the increment expression:
+ // 1. get the next prev frame;
+ // 2. decrement the ref count on the current frame (in order that it can get
+  //    gc'd, along with any objects in its closure, etc.);
+ // 3. set current = next.
+ for (; pyFrame != nullptr && count < framesLimit;
+ next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) {
+ PyCodeObject *code = PyFrame_GetCode(pyFrame);
+ auto fileNameStr =
+ nb::cast<std::string>(nb::borrow<nb::str>(code->co_filename));
+ llvm::StringRef fileName(fileNameStr);
+ if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName))
+ continue;
+
+ // co_qualname and PyCode_Addr2Location added in py3.11
+#if PY_VERSION_HEX < 0x030B00F0
+ std::string name =
+ nb::cast<std::string>(nb::borrow<nb::str>(code->co_name));
+ llvm::StringRef funcName(name);
+ int startLine = PyFrame_GetLineNumber(pyFrame);
+ MlirLocation loc =
+ mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0);
+#else
+ std::string name =
+ nb::cast<std::string>(nb::borrow<nb::str>(code->co_qualname));
+ llvm::StringRef funcName(name);
+ int startLine, startCol, endLine, endCol;
+ int lasti = PyFrame_GetLasti(pyFrame);
+ if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine,
+ &endCol)) {
+ throw nb::python_error();
+ }
+ MlirLocation loc = mlirLocationFileLineColRangeGet(
+ ctx, wrap(fileName), startLine, startCol, endLine, endCol);
+#endif
+
+ frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc);
+ ++count;
+ }
+  // When the loop breaks (after the last iteration), the current frame
+  // (if non-null) would be leaked without this final decref.
+ Py_XDECREF(pyFrame);
+
+ if (count == 0)
+ return mlirLocationUnknownGet(ctx);
+
+ MlirLocation callee = frames[0];
+ assert(!mlirLocationIsNull(callee) && "expected non-null callee location");
+ if (count == 1)
+ return callee;
+
+ MlirLocation caller = frames[count - 1];
+ assert(!mlirLocationIsNull(caller) && "expected non-null caller location");
+ for (int i = count - 2; i >= 1; i--)
+ caller = mlirLocationCallSiteGet(frames[i], caller);
+
+ return mlirLocationCallSiteGet(callee, caller);
+}
+
+PyLocation
+maybeGetTracebackLocation(const std::optional<PyLocation> &location) {
+ if (location.has_value())
+ return location.value();
+ if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled())
+ return DefaultingPyLocation::resolve();
+
+ PyMlirContext &ctx = DefaultingPyMlirContext::resolve();
+ MlirLocation mlirLoc = tracebackToLocation(ctx.get());
+ PyMlirContextRef ref = PyMlirContext::forContext(ctx.get());
+ return {ref, mlirLoc};
+}
+
} // namespace
//------------------------------------------------------------------------------
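As a worked example of the frame folding above (illustrative, not part of the patch): each kept frame becomes a NameLoc (the function name) wrapping a FileLineColLoc, and if the filtered walk collects frames f0 (innermost), f1, f2, the loop builds caller = CallSiteLoc(f1, f2) and the function returns CallSiteLoc(f0, CallSiteLoc(f1, f2)); the innermost user frame is the callee and the outer frames form the caller chain.
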
@@ -3052,10 +3199,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def("__eq__", [](PyLocation &self, nb::object other) { return false; })
.def_prop_ro_static(
"current",
- [](nb::object & /*class*/) {
+ [](nb::object & /*class*/) -> std::optional<PyLocation *> {
auto *loc = PyThreadContextEntry::getDefaultLocation();
if (!loc)
- throw nb::value_error("No current Location");
+ return std::nullopt;
return loc;
},
"Gets the Location bound to the current thread or raises ValueError")
@@ -3240,8 +3387,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
kModuleParseDocstring)
.def_static(
"create",
- [](DefaultingPyLocation loc) {
- MlirModule module = mlirModuleCreateEmpty(loc);
+ [](const std::optional<PyLocation> &loc) {
+ PyLocation pyLoc = maybeGetTracebackLocation(loc);
+ MlirModule module = mlirModuleCreateEmpty(pyLoc.get());
return PyModule::forModule(module).releaseObject();
},
nb::arg("loc").none() = nb::none(), "Creates an empty module")
@@ -3442,6 +3590,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
return operation.createOpView();
},
"Detaches the operation from its parent block.")
+ .def_prop_ro(
+ "attached",
+ [](PyOperationBase &self) {
+ PyOperation &operation = self.getOperation();
+ operation.checkValid();
+ return operation.isAttached();
+ },
+ "Reports if the operation is attached to its parent block.")
.def("erase", [](PyOperationBase &self) { self.getOperation().erase(); })
.def("walk", &PyOperationBase::walk, nb::arg("callback"),
nb::arg("walk_order") = MlirWalkPostOrder);
@@ -3454,8 +3610,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<std::vector<PyValue *>> operands,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
- DefaultingPyLocation location, const nb::object &maybeIp,
- bool inferType) {
+ const std::optional<PyLocation> &location,
+ const nb::object &maybeIp, bool inferType) {
// Unpack/validate operands.
llvm::SmallVector<MlirValue, 4> mlirOperands;
if (operands) {
@@ -3467,8 +3623,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
}
}
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
return PyOperation::create(name, results, mlirOperands, attributes,
- successors, regions, location, maybeIp,
+ successors, regions, pyLoc, maybeIp,
inferType);
},
nb::arg("name"), nb::arg("results").none() = nb::none(),
@@ -3512,12 +3669,14 @@ void mlir::python::populateIRCore(nb::module_ &m) {
std::optional<nb::list> resultTypeList, nb::list operandList,
std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions,
+ const std::optional<PyLocation> &location,
const nb::object &maybeIp) {
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
new (self) PyOpView(PyOpView::buildGeneric(
name, opRegionSpec, operandSegmentSpecObj,
resultSegmentSpecObj, resultTypeList, operandList,
- attributes, successors, regions, location, maybeIp));
+ attributes, successors, regions, pyLoc, maybeIp));
},
nb::arg("name"), nb::arg("opRegionSpec"),
nb::arg("operandSegmentSpecObj").none() = nb::none(),
@@ -3551,17 +3710,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](nb::handle cls, std::optional<nb::list> resultTypeList,
nb::list operandList, std::optional<nb::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions, std::optional<PyLocation> location,
const nb::object &maybeIp) {
std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
std::tuple<int, bool> opRegionSpec =
nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
+ PyLocation pyLoc = maybeGetTracebackLocation(location);
return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
resultSegmentSpec, resultTypeList,
operandList, attributes, successors,
- regions, location, maybeIp);
+ regions, pyLoc, maybeIp);
},
nb::arg("cls"), nb::arg("results").none() = nb::none(),
nb::arg("operands").none() = nb::none(),
diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp
index e600f1b..0de2f17 100644
--- a/mlir/lib/Bindings/Python/IRModule.cpp
+++ b/mlir/lib/Bindings/Python/IRModule.cpp
@@ -13,9 +13,9 @@
#include "Globals.h"
#include "NanobindUtils.h"
+#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/Support.h"
#include "mlir/Bindings/Python/Nanobind.h"
-#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
namespace nb = nanobind;
using namespace mlir;
@@ -197,3 +197,71 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) {
// Not found and loading did not yield a registration.
return std::nullopt;
}
+
+bool PyGlobals::TracebackLoc::locTracebacksEnabled() {
+ nanobind::ft_lock_guard lock(mutex);
+ return locTracebackEnabled_;
+}
+
+void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) {
+ nanobind::ft_lock_guard lock(mutex);
+ locTracebackEnabled_ = value;
+}
+
+size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() {
+ nanobind::ft_lock_guard lock(mutex);
+ return locTracebackFramesLimit_;
+}
+
+void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) {
+ nanobind::ft_lock_guard lock(mutex);
+ locTracebackFramesLimit_ = std::min(value, kMaxFrames);
+}
+
+void PyGlobals::TracebackLoc::registerTracebackFileInclusion(
+ const std::string &file) {
+ nanobind::ft_lock_guard lock(mutex);
+ auto reg = "^" + llvm::Regex::escape(file);
+ if (userTracebackIncludeFiles.insert(reg).second)
+ rebuildUserTracebackIncludeRegex = true;
+ if (userTracebackExcludeFiles.count(reg)) {
+ if (userTracebackExcludeFiles.erase(reg))
+ rebuildUserTracebackExcludeRegex = true;
+ }
+}
+
+void PyGlobals::TracebackLoc::registerTracebackFileExclusion(
+ const std::string &file) {
+ nanobind::ft_lock_guard lock(mutex);
+ auto reg = "^" + llvm::Regex::escape(file);
+ if (userTracebackExcludeFiles.insert(reg).second)
+ rebuildUserTracebackExcludeRegex = true;
+ if (userTracebackIncludeFiles.count(reg)) {
+ if (userTracebackIncludeFiles.erase(reg))
+ rebuildUserTracebackIncludeRegex = true;
+ }
+}
+
+bool PyGlobals::TracebackLoc::isUserTracebackFilename(
+ const llvm::StringRef file) {
+ nanobind::ft_lock_guard lock(mutex);
+ if (rebuildUserTracebackIncludeRegex) {
+ userTracebackIncludeRegex.assign(
+ llvm::join(userTracebackIncludeFiles, "|"));
+ rebuildUserTracebackIncludeRegex = false;
+ isUserTracebackFilenameCache.clear();
+ }
+ if (rebuildUserTracebackExcludeRegex) {
+ userTracebackExcludeRegex.assign(
+ llvm::join(userTracebackExcludeFiles, "|"));
+ rebuildUserTracebackExcludeRegex = false;
+ isUserTracebackFilenameCache.clear();
+ }
+ if (!isUserTracebackFilenameCache.contains(file)) {
+ std::string fileStr = file.str();
+ bool include = std::regex_search(fileStr, userTracebackIncludeRegex);
+ bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex);
+ isUserTracebackFilenameCache[file] = include || !exclude;
+ }
+ return isUserTracebackFilenameCache[file];
+}
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9c22dea..fa16ae3 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -192,16 +192,6 @@ public:
PyMlirContext(const PyMlirContext &) = delete;
PyMlirContext(PyMlirContext &&) = delete;
- /// For the case of a python __init__ (nanobind::init) method, pybind11 is
- /// quite strict about needing to return a pointer that is not yet associated
- /// to an nanobind::object. Since the forContext() method acts like a pool,
- /// possibly returning a recycled context, it does not satisfy this need. The
- /// usual way in python to accomplish such a thing is to override __new__, but
- /// that is also not supported by pybind11. Instead, we use this entry
- /// point which always constructs a fresh context (which cannot alias an
- /// existing one because it is fresh).
- static PyMlirContext *createNewContextForInit();
-
/// Returns a context reference for the singleton PyMlirContext wrapper for
/// the given context.
static PyMlirContextRef forContext(MlirContext context);
@@ -722,8 +712,7 @@ public:
llvm::ArrayRef<MlirValue> operands,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors, int regions,
- DefaultingPyLocation location, const nanobind::object &ip,
- bool inferType);
+ PyLocation &location, const nanobind::object &ip, bool inferType);
/// Creates an OpView suitable for this operation.
nanobind::object createOpView();
@@ -781,7 +770,7 @@ public:
nanobind::list operandList,
std::optional<nanobind::dict> attributes,
std::optional<std::vector<PyBlock *>> successors,
- std::optional<int> regions, DefaultingPyLocation location,
+ std::optional<int> regions, PyLocation &location,
const nanobind::object &maybeIp);
/// Construct an instance of a class deriving from OpView, bypassing its
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 6f49431..278847e 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -6,7 +6,6 @@
//
//===----------------------------------------------------------------------===//
-
#include "Globals.h"
#include "IRModule.h"
#include "NanobindUtils.h"
@@ -44,7 +43,27 @@ NB_MODULE(_mlir, m) {
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
"operation_name"_a, "operation_class"_a, nb::kw_only(),
"replace"_a = false,
- "Testing hook for directly registering an operation");
+ "Testing hook for directly registering an operation")
+ .def("loc_tracebacks_enabled",
+ [](PyGlobals &self) {
+ return self.getTracebackLoc().locTracebacksEnabled();
+ })
+ .def("set_loc_tracebacks_enabled",
+ [](PyGlobals &self, bool enabled) {
+ self.getTracebackLoc().setLocTracebacksEnabled(enabled);
+ })
+ .def("set_loc_tracebacks_frame_limit",
+ [](PyGlobals &self, int n) {
+ self.getTracebackLoc().setLocTracebackFramesLimit(n);
+ })
+ .def("register_traceback_file_inclusion",
+ [](PyGlobals &self, const std::string &filename) {
+ self.getTracebackLoc().registerTracebackFileInclusion(filename);
+ })
+ .def("register_traceback_file_exclusion",
+ [](PyGlobals &self, const std::string &filename) {
+ self.getTracebackLoc().registerTracebackFileExclusion(filename);
+ });
// Aside from making the globals accessible to python, having python manage
// it is necessary to make sure it is destroyed (and releases its python
diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
index 306cebd..2dbb993 100644
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
return wrap(jitOrError->release());
}
+extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) {
+ unwrap(jit)->initialize();
+}
+
extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
delete (unwrap(jit));
}
@@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
void *sym) {
unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
llvm::orc::SymbolMap symbolMap;
- symbolMap[interner(unwrap(name))] =
- { llvm::orc::ExecutorAddr::fromPtr(sym),
- llvm::JITSymbolFlags::Exported };
+ symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym),
+ llvm::JITSymbolFlags::Exported};
return symbolMap;
});
}
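A minimal sketch of where the new entry point fits in the C API flow; `module` is assumed to be an MlirModule already lowered to the LLVM dialect, and the looked-up symbol name "foo" is made up:

#include "mlir-c/ExecutionEngine.h"
#include "mlir-c/Support.h"

void runJit(MlirModule module) {
  MlirExecutionEngine jit = mlirExecutionEngineCreate(
      module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/nullptr,
      /*enableObjectDump=*/false);
  // Make sure all external symbols are resolvable (register them or pass
  // shared libraries), then run llvm.mlir.global_ctors before any lookup.
  mlirExecutionEngineInitialize(jit);
  void *fn =
      mlirExecutionEngineLookup(jit, mlirStringRefCreateFromCString("foo"));
  (void)fn;
  mlirExecutionEngineDestroy(jit);
}
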
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 18e857c..cb0c829 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -238,6 +238,16 @@ struct CmpFOpLowering : public ConvertOpToLLVMPattern<arith::CmpFOp> {
ConversionPatternRewriter &rewriter) const override;
};
+struct SelectOpOneToNLowering : public ConvertOpToLLVMPattern<arith::SelectOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<arith::SelectOp>::OneToNOpAdaptor;
+
+ LogicalResult
+ matchAndRewrite(arith::SelectOp op, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -480,6 +490,32 @@ CmpFOpLowering::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
}
//===----------------------------------------------------------------------===//
+// SelectOpOneToNLowering
+//===----------------------------------------------------------------------===//
+
+/// Pattern for arith.select where the true/false values lower to multiple
+/// SSA values (1:N conversion). This pattern generates multiple arith.select
+/// ops that can be lowered by the 1:1 arith.select pattern.
+LogicalResult SelectOpOneToNLowering::matchAndRewrite(
+ arith::SelectOp op, Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // In case of a 1:1 conversion, the 1:1 pattern will match.
+ if (llvm::hasSingleElement(adaptor.getTrueValue()))
+ return rewriter.notifyMatchFailure(
+ op, "not a 1:N conversion, 1:1 pattern will match");
+ if (!op.getCondition().getType().isInteger(1))
+ return rewriter.notifyMatchFailure(op,
+ "non-i1 conditions are not supported");
+ SmallVector<Value> results;
+ for (auto [trueValue, falseValue] :
+ llvm::zip_equal(adaptor.getTrueValue(), adaptor.getFalseValue()))
+ results.push_back(arith::SelectOp::create(
+ rewriter, op.getLoc(), op.getCondition(), trueValue, falseValue));
+ rewriter.replaceOpWithMultiple(op, {results});
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -587,6 +623,7 @@ void mlir::arith::populateArithToLLVMConversionPatterns(
RemSIOpLowering,
RemUIOpLowering,
SelectOpLowering,
+ SelectOpOneToNLowering,
ShLIOpLowering,
ShRSIOpLowering,
ShRUIOpLowering,
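An illustrative example of what the new 1:N pattern produces (not from the patch): if the type converter maps the select's value type to two LLVM values, then %r = arith.select %c, %a, %b, with %a converted to (%a0, %a1) and %b to (%b0, %b1), is rewritten into %r0 = arith.select %c, %a0, %b0 and %r1 = arith.select %c, %a1, %b1, and %r is replaced with (%r0, %r1); each of those i1-conditioned selects is then handled by the existing 1:1 SelectOpLowering.
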
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 785cb82..171f716 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -68,6 +68,7 @@ add_subdirectory(TosaToSCF)
add_subdirectory(TosaToTensor)
add_subdirectory(UBToLLVM)
add_subdirectory(UBToSPIRV)
+add_subdirectory(VectorToAMX)
add_subdirectory(VectorToArmSME)
add_subdirectory(VectorToGPU)
add_subdirectory(VectorToLLVM)
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 35ad99c..b3d6d59 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,14 +64,6 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
patterns.getContext(), "__ocml_cabs_f64");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
- patterns.getContext(), "__ocml_carg_f32");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
- patterns.getContext(), "__ocml_carg_f64");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
- patterns.getContext(), "__ocml_conj_f32");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
- patterns.getContext(), "__ocml_conj_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
patterns.getContext(), "__ocml_ccos_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
@@ -84,10 +76,6 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_clog_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
patterns.getContext(), "__ocml_clog_f64");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
- patterns.getContext(), "__ocml_cpow_f32");
- patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
- patterns.getContext(), "__ocml_cpow_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
patterns.getContext(), "__ocml_csin_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
@@ -122,9 +110,8 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
- complex::CosOp, complex::ExpOp, complex::LogOp,
- complex::PowOp, complex::SinOp, complex::SqrtOp,
+ target.addIllegalOp<complex::AbsOp, complex::CosOp, complex::ExpOp,
+ complex::LogOp, complex::SinOp, complex::SqrtOp,
complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
diff --git a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
index ff6d369..798d8b0 100644
--- a/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
+++ b/mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
@@ -125,22 +125,33 @@ static FailureOr<Block *> getConvertedBlock(ConversionPatternRewriter &rewriter,
return rewriter.applySignatureConversion(block, *conversion, converter);
}
+/// Flatten the given value ranges into a single vector of values.
+static SmallVector<Value> flattenValues(ArrayRef<ValueRange> values) {
+ SmallVector<Value> result;
+ for (const ValueRange &vals : values)
+ llvm::append_range(result, vals);
+ return result;
+}
+
/// Convert the destination block signature (if necessary) and lower the branch
/// op to llvm.br.
struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
using ConvertOpToLLVMPattern<cf::BranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::BranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::BranchOp op, typename cf::BranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::BranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptor = flattenValues(adaptor.getOperands());
FailureOr<Block *> convertedBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getSuccessor(),
- TypeRange(adaptor.getOperands()));
+ TypeRange(ValueRange(flattenedAdaptor)));
if (failed(convertedBlock))
return failure();
DictionaryAttr attrs = op->getAttrDictionary();
Operation *newOp = rewriter.replaceOpWithNewOp<LLVM::BrOp>(
- op, adaptor.getOperands(), *convertedBlock);
+ op, flattenedAdaptor, *convertedBlock);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
newOp->setAttrs(attrs);
@@ -152,29 +163,37 @@ struct BranchOpLowering : public ConvertOpToLLVMPattern<cf::BranchOp> {
/// branch op to llvm.cond_br.
struct CondBranchOpLowering : public ConvertOpToLLVMPattern<cf::CondBranchOp> {
using ConvertOpToLLVMPattern<cf::CondBranchOp>::ConvertOpToLLVMPattern;
+ using Adaptor =
+ typename ConvertOpToLLVMPattern<cf::CondBranchOp>::OneToNOpAdaptor;
LogicalResult
- matchAndRewrite(cf::CondBranchOp op,
- typename cf::CondBranchOp::Adaptor adaptor,
+ matchAndRewrite(cf::CondBranchOp op, Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+ SmallVector<Value> flattenedAdaptorTrue =
+ flattenValues(adaptor.getTrueDestOperands());
+ SmallVector<Value> flattenedAdaptorFalse =
+ flattenValues(adaptor.getFalseDestOperands());
+ if (!llvm::hasSingleElement(adaptor.getCondition()))
+ return rewriter.notifyMatchFailure(op,
+ "expected single element condition");
FailureOr<Block *> convertedTrueBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getTrueDest(),
- TypeRange(adaptor.getTrueDestOperands()));
+ TypeRange(ValueRange(flattenedAdaptorTrue)));
if (failed(convertedTrueBlock))
return failure();
FailureOr<Block *> convertedFalseBlock =
getConvertedBlock(rewriter, getTypeConverter(), op, op.getFalseDest(),
- TypeRange(adaptor.getFalseDestOperands()));
+ TypeRange(ValueRange(flattenedAdaptorFalse)));
if (failed(convertedFalseBlock))
return failure();
- DictionaryAttr attrs = op->getAttrDictionary();
+ DictionaryAttr attrs = op->getDiscardableAttrDictionary();
auto newOp = rewriter.replaceOpWithNewOp<LLVM::CondBrOp>(
- op, adaptor.getCondition(), adaptor.getTrueDestOperands(),
- adaptor.getFalseDestOperands(), op.getBranchWeightsAttr(),
+ op, llvm::getSingleElement(adaptor.getCondition()),
+ flattenedAdaptorTrue, flattenedAdaptorFalse, op.getBranchWeightsAttr(),
*convertedTrueBlock, *convertedFalseBlock);
// TODO: We should not just forward all attributes like that. But there are
// existing Flang tests that depend on this behavior.
- newOp->setAttrs(attrs);
+ newOp->setDiscardableAttrs(attrs);
return success();
}
};
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index ed5d6d4..cdb7150 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -31,7 +31,8 @@ namespace {
class ConvertToLLVMPassInterface {
public:
ConvertToLLVMPassInterface(MLIRContext *context,
- ArrayRef<std::string> filterDialects);
+ ArrayRef<std::string> filterDialects,
+ bool allowPatternRollback = true);
virtual ~ConvertToLLVMPassInterface() = default;
/// Get the dependent dialects used by `convert-to-llvm`.
@@ -60,6 +61,9 @@ protected:
MLIRContext *context;
/// List of dialects names to use as filters.
ArrayRef<std::string> filterDialects;
+ /// An experimental flag to disallow pattern rollback. This is more efficient
+ /// but not supported by all lowering patterns.
+ bool allowPatternRollback;
};
/// This DialectExtension can be attached to the context, which will invoke the
@@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
/// Apply the conversion driver.
LogicalResult transform(Operation *op, AnalysisManager manager) const final {
- if (failed(applyPartialConversion(op, *target, *patterns)))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(op, *target, *patterns, config)))
return failure();
return success();
}
@@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
patterns);
// Apply the conversion.
- if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
+ if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
return failure();
return success();
}
@@ -206,9 +214,11 @@ public:
std::shared_ptr<ConvertToLLVMPassInterface> impl;
// Choose the pass implementation.
if (useDynamic)
- impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
+ impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
+ allowPatternRollback);
else
- impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
+ impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
+ allowPatternRollback);
if (failed(impl->initialize()))
return failure();
this->impl = impl;
@@ -228,8 +238,10 @@ public:
//===----------------------------------------------------------------------===//
ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
- MLIRContext *context, ArrayRef<std::string> filterDialects)
- : context(context), filterDialects(filterDialects) {}
+ MLIRContext *context, ArrayRef<std::string> filterDialects,
+ bool allowPatternRollback)
+ : context(context), filterDialects(filterDialects),
+ allowPatternRollback(allowPatternRollback) {}
void ConvertToLLVMPassInterface::getDependentDialects(
DialectRegistry &registry) {
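For reference, the new flag just toggles the standard conversion-driver option; a minimal standalone sketch of the same knob used directly (not part of the patch, helper name made up):

#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// With allowPatternRollback = false the driver reports an error instead of
// rolling back a pattern that fails partway through; cheaper, but every
// pattern in the set must be able to cope without rollback.
static LogicalResult lowerWithoutRollback(Operation *op,
                                          const ConversionTarget &target,
                                          const FrozenRewritePatternSet &patterns) {
  ConversionConfig config;
  config.allowPatternRollback = false;
  return applyPartialConversion(op, target, patterns, config);
}
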
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 67bb1c1..42c76ed 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -527,19 +527,21 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
using ConvertOpToLLVMPattern<CallOpType>::ConvertOpToLLVMPattern;
using Super = CallOpInterfaceLowering<CallOpType>;
using Base = ConvertOpToLLVMPattern<CallOpType>;
+ using Adaptor = typename ConvertOpToLLVMPattern<CallOpType>::OneToNOpAdaptor;
- LogicalResult matchAndRewriteImpl(CallOpType callOp,
- typename CallOpType::Adaptor adaptor,
+ LogicalResult matchAndRewriteImpl(CallOpType callOp, Adaptor adaptor,
ConversionPatternRewriter &rewriter,
bool useBarePtrCallConv = false) const {
// Pack the result types into a struct.
Type packedResult = nullptr;
+ SmallVector<SmallVector<Type>> groupedResultTypes;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
-
+ int64_t numConvertedTypes = 0;
if (numResults != 0) {
if (!(packedResult = this->getTypeConverter()->packFunctionResults(
- resultTypes, useBarePtrCallConv)))
+ resultTypes, useBarePtrCallConv, &groupedResultTypes,
+ &numConvertedTypes)))
return failure();
}
@@ -565,34 +567,64 @@ struct CallOpInterfaceLowering : public ConvertOpToLLVMPattern<CallOpType> {
static_cast<int32_t>(promoted.size()), 0};
newOp.getProperties().op_bundle_sizes = rewriter.getDenseI32ArrayAttr({});
- SmallVector<Value, 4> results;
- if (numResults < 2) {
- // If < 2 results, packing did not do anything and we can just return.
- results.append(newOp.result_begin(), newOp.result_end());
- } else {
- // Otherwise, it had been converted to an operation producing a structure.
- // Extract individual results from the structure and return them as list.
- results.reserve(numResults);
- for (unsigned i = 0; i < numResults; ++i) {
- results.push_back(LLVM::ExtractValueOp::create(
- rewriter, callOp.getLoc(), newOp->getResult(0), i));
+ // Helper function that extracts an individual result from the return value
+ // of the new call op. llvm.call ops support only 0 or 1 result. In case of
+ // 2 or more results, the results are packed into a structure.
+ //
+ // The new call op may have more than 2 results because:
+ // a. The original call op has more than 2 results.
+    //  b. An original op result was type-converted to more than 1 result.
+ auto getUnpackedResult = [&](unsigned i) -> Value {
+ assert(numConvertedTypes > 0 && "convert op has no results");
+ if (numConvertedTypes == 1) {
+ assert(i == 0 && "out of bounds: converted op has only one result");
+ return newOp->getResult(0);
}
+ // Results have been converted to a structure. Extract individual results
+ // from the structure.
+ return LLVM::ExtractValueOp::create(rewriter, callOp.getLoc(),
+ newOp->getResult(0), i);
+ };
+
+ // Group the results into a vector of vectors, such that it is clear which
+ // original op result is replaced with which range of values. (In case of a
+ // 1:N conversion, there can be multiple replacements for a single result.)
+ SmallVector<SmallVector<Value>> results;
+ results.reserve(numResults);
+ unsigned counter = 0;
+ for (unsigned i = 0; i < numResults; ++i) {
+ SmallVector<Value> &group = results.emplace_back();
+ for (unsigned j = 0, e = groupedResultTypes[i].size(); j < e; ++j)
+ group.push_back(getUnpackedResult(counter++));
}
- if (useBarePtrCallConv) {
- // For the bare-ptr calling convention, promote memref results to
- // descriptors.
- assert(results.size() == resultTypes.size() &&
- "The number of arguments and types doesn't match");
- this->getTypeConverter()->promoteBarePtrsToDescriptors(
- rewriter, callOp.getLoc(), resultTypes, results);
- } else if (failed(this->copyUnrankedDescriptors(rewriter, callOp.getLoc(),
- resultTypes, results,
- /*toDynamic=*/false))) {
- return failure();
+ // Special handling for MemRef types.
+ for (unsigned i = 0; i < numResults; ++i) {
+ Type origType = resultTypes[i];
+ auto memrefType = dyn_cast<MemRefType>(origType);
+ auto unrankedMemrefType = dyn_cast<UnrankedMemRefType>(origType);
+ if (useBarePtrCallConv && memrefType) {
+ // For the bare-ptr calling convention, promote memref results to
+ // descriptors.
+ assert(results[i].size() == 1 && "expected one converted result");
+ results[i].front() = MemRefDescriptor::fromStaticShape(
+ rewriter, callOp.getLoc(), *this->getTypeConverter(), memrefType,
+ results[i].front());
+ }
+ if (unrankedMemrefType) {
+ assert(!useBarePtrCallConv && "unranked memref is not supported in the "
+ "bare-ptr calling convention");
+ assert(results[i].size() == 1 && "expected one converted result");
+ Value desc = this->copyUnrankedDescriptor(
+ rewriter, callOp.getLoc(), unrankedMemrefType, results[i].front(),
+ /*toDynamic=*/false);
+ if (!desc)
+ return failure();
+ results[i].front() = desc;
+ }
}
- rewriter.replaceOp(callOp, results);
+ rewriter.replaceOpWithMultiple(callOp, results);
return success();
}
};
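For reference, the index bookkeeping behind the grouping above can be reproduced in plain C++. This is a minimal sketch with hypothetical group sizes; plain ints stand in for MLIR values and nothing here is the converter's API.

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  // Hypothetical 1:N conversion: original result 0 stays 1:1, original
  // result 1 converts to three values, so the packed struct has 4 fields.
  std::vector<std::size_t> groupSizes = {1, 3};
  std::vector<int> unpacked = {10, 20, 21, 22}; // extracted struct fields

  std::vector<std::vector<int>> groups;
  std::size_t counter = 0;
  for (std::size_t size : groupSizes) {
    std::vector<int> &group = groups.emplace_back();
    for (std::size_t j = 0; j < size; ++j)
      group.push_back(unpacked[counter++]);
  }

  // Each original result maps to a contiguous range of converted values.
  assert(counter == unpacked.size());
  assert(groups[0] == std::vector<int>({10}));
  assert(groups[1] == std::vector<int>({20, 21, 22}));
  return 0;
}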
@@ -606,7 +638,7 @@ public:
symbolTables(symbolTables) {}
LogicalResult
- matchAndRewrite(func::CallOp callOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallOp callOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool useBarePtrCallConv = false;
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
@@ -636,7 +668,7 @@ struct CallIndirectOpLowering
using Super::Super;
LogicalResult
- matchAndRewrite(func::CallIndirectOp callIndirectOp, OpAdaptor adaptor,
+ matchAndRewrite(func::CallIndirectOp callIndirectOp, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
return matchAndRewriteImpl(callIndirectOp, adaptor, rewriter);
}
@@ -679,41 +711,50 @@ struct ReturnOpLowering : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;
LogicalResult
- matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
+ matchAndRewrite(func::ReturnOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
- unsigned numArguments = op.getNumOperands();
SmallVector<Value, 4> updatedOperands;
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
bool useBarePtrCallConv =
shouldUseBarePtrCallConv(funcOp, this->getTypeConverter());
- if (useBarePtrCallConv) {
- // For the bare-ptr calling convention, extract the aligned pointer to
- // be returned from the memref descriptor.
- for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) {
- Type oldTy = std::get<0>(it).getType();
- Value newOperand = std::get<1>(it);
- if (isa<MemRefType>(oldTy) && getTypeConverter()->canConvertToBarePtr(
- cast<BaseMemRefType>(oldTy))) {
- MemRefDescriptor memrefDesc(newOperand);
- newOperand = memrefDesc.allocatedPtr(rewriter, loc);
- } else if (isa<UnrankedMemRefType>(oldTy)) {
+
+ for (auto [oldOperand, newOperands] :
+ llvm::zip_equal(op->getOperands(), adaptor.getOperands())) {
+ Type oldTy = oldOperand.getType();
+ if (auto memRefType = dyn_cast<MemRefType>(oldTy)) {
+ assert(newOperands.size() == 1 && "expected one converted result");
+ if (useBarePtrCallConv &&
+ getTypeConverter()->canConvertToBarePtr(memRefType)) {
+ // For the bare-ptr calling convention, extract the aligned pointer to
+ // be returned from the memref descriptor.
+ MemRefDescriptor memrefDesc(newOperands.front());
+ updatedOperands.push_back(memrefDesc.allocatedPtr(rewriter, loc));
+ continue;
+ }
+ } else if (auto unrankedMemRefType =
+ dyn_cast<UnrankedMemRefType>(oldTy)) {
+ assert(newOperands.size() == 1 && "expected one converted result");
+ if (useBarePtrCallConv) {
// Unranked memref is not supported in the bare pointer calling
// convention.
return failure();
}
- updatedOperands.push_back(newOperand);
+ Value updatedDesc =
+ copyUnrankedDescriptor(rewriter, loc, unrankedMemRefType,
+ newOperands.front(), /*toDynamic=*/true);
+ if (!updatedDesc)
+ return failure();
+ updatedOperands.push_back(updatedDesc);
+ continue;
}
- } else {
- updatedOperands = llvm::to_vector<4>(adaptor.getOperands());
- (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(),
- updatedOperands,
- /*toDynamic=*/true);
+
+ llvm::append_range(updatedOperands, newOperands);
}
// If ReturnOp has 0 or 1 operand, create it and return immediately.
- if (numArguments <= 1) {
+ if (updatedOperands.size() <= 1) {
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(
op, TypeRange(), updatedOperands, op->getAttrs());
return success();
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index d22364e..e6fbcf9 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -79,17 +79,30 @@ static bool canBeCalledWithBarePointers(gpu::GPUFuncOp func) {
return canBeBare;
}
-static Value getLaneId(ConversionPatternRewriter &rewriter, Location loc,
- const unsigned indexBitwidth) {
+static Value getLaneId(RewriterBase &rewriter, Location loc) {
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
- Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, int32Type,
- ValueRange{minus1, zero});
- Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, int32Type,
- ValueRange{minus1, mbcntLo});
+ NamedAttribute noundef = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getNoUndefAttrName(), rewriter.getUnitAttr());
+ NamedAttribute lowRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 32)));
+ NamedAttribute highRange = rewriter.getNamedAttr(
+ LLVM::LLVMDialect::getRangeAttrName(),
+ LLVM::ConstantRangeAttr::get(rewriter.getContext(), APInt::getZero(32),
+ APInt(32, 64)));
+ Value mbcntLo = ROCDL::MbcntLoOp::create(
+ rewriter, loc, int32Type, minus1, zero, /*arg_attrs=*/{},
+ /*res_attrs=*/
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, lowRange})));
+ Value laneId = ROCDL::MbcntHiOp::create(
+ rewriter, loc, int32Type, minus1, mbcntLo, /*arg_attrs=*/{},
+ rewriter.getArrayAttr(rewriter.getDictionaryAttr({noundef, highRange})));
return laneId;
}
+
static constexpr StringLiteral amdgcnDataLayout =
"e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32"
"-p7:160:256:256:32-p8:128:128:128:48-p9:192:256:256:32-i64:64-v16:16-v24:"
@@ -104,18 +117,16 @@ struct GPULaneIdOpToROCDL : ConvertOpToLLVMPattern<gpu::LaneIdOp> {
LogicalResult
matchAndRewrite(gpu::LaneIdOp op, gpu::LaneIdOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- auto loc = op->getLoc();
+ Location loc = op.getLoc();
MLIRContext *context = rewriter.getContext();
- // convert to: %mlo = call @llvm.amdgcn.mbcnt.lo(-1, 0)
- // followed by: %lid = call @llvm.amdgcn.mbcnt.hi(-1, %mlo)
-
- Type intTy = IntegerType::get(context, 32);
- Value zero = arith::ConstantIntOp::create(rewriter, loc, 0, 32);
- Value minus1 = arith::ConstantIntOp::create(rewriter, loc, -1, 32);
- Value mbcntLo = ROCDL::MbcntLoOp::create(rewriter, loc, intTy,
- ValueRange{minus1, zero});
- Value laneId = ROCDL::MbcntHiOp::create(rewriter, loc, intTy,
- ValueRange{minus1, mbcntLo});
+ // convert to:
+ // %mlo = call noundef range(i32 0, 32)
+ // @llvm.amdgcn.mbcnt.lo(-1, 0)
+ // followed by:
+ // %lid = call noundef range(i32 0, 64)
+ // @llvm.amdgcn.mbcnt.hi(-1, %mlo)
+
+ Value laneId = getLaneId(rewriter, loc);
// Truncate or extend the result depending on the index bitwidth specified
// by the LLVMTypeConverter options.
const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
@@ -185,8 +196,7 @@ struct GPUShuffleOpLowering : public ConvertOpToLLVMPattern<gpu::ShuffleOp> {
Location loc = op->getLoc();
Value initShflValue = adaptor.getValue();
- const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
- Value srcLaneId = getLaneId(rewriter, loc, indexBitwidth);
+ Value srcLaneId = getLaneId(rewriter, loc);
auto int32Type = IntegerType::get(rewriter.getContext(), 32);
Value width = adaptor.getWidth();
diff --git a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
index fce7a3f..522e914 100644
--- a/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/MemRefBuilder.cpp
@@ -353,14 +353,9 @@ void UnrankedMemRefDescriptor::unpack(OpBuilder &builder, Location loc,
results.push_back(d.memRefDescPtr(builder, loc));
}
-void UnrankedMemRefDescriptor::computeSizes(
+Value UnrankedMemRefDescriptor::computeSize(
OpBuilder &builder, Location loc, const LLVMTypeConverter &typeConverter,
- ArrayRef<UnrankedMemRefDescriptor> values, ArrayRef<unsigned> addressSpaces,
- SmallVectorImpl<Value> &sizes) {
- if (values.empty())
- return;
- assert(values.size() == addressSpaces.size() &&
- "must provide address space for each descriptor");
+ UnrankedMemRefDescriptor desc, unsigned addressSpace) {
// Cache the index type.
Type indexType = typeConverter.getIndexType();
@@ -371,34 +366,31 @@ void UnrankedMemRefDescriptor::computeSizes(
builder, loc, indexType,
llvm::divideCeil(typeConverter.getIndexTypeBitwidth(), 8));
- sizes.reserve(sizes.size() + values.size());
- for (auto [desc, addressSpace] : llvm::zip(values, addressSpaces)) {
- // Emit IR computing the memory necessary to store the descriptor. This
- // assumes the descriptor to be
- // { type*, type*, index, index[rank], index[rank] }
- // and densely packed, so the total size is
- // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
- // TODO: consider including the actual size (including eventual padding due
- // to data layout) into the unranked descriptor.
- Value pointerSize = createIndexAttrConstant(
- builder, loc, indexType,
- llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
- Value doublePointerSize =
- LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
-
- // (1 + 2 * rank) * sizeof(index)
- Value rank = desc.rank(builder, loc);
- Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
- Value doubleRankIncremented =
- LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
- Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
- doubleRankIncremented, indexSize);
-
- // Total allocation size.
- Value allocationSize = LLVM::AddOp::create(
- builder, loc, indexType, doublePointerSize, rankIndexSize);
- sizes.push_back(allocationSize);
- }
+ // Emit IR computing the memory necessary to store the descriptor. This
+ // assumes the descriptor to be
+ // { type*, type*, index, index[rank], index[rank] }
+ // and densely packed, so the total size is
+ // 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index).
+ // TODO: consider including the actual size (including eventual padding due
+ // to data layout) into the unranked descriptor.
+ Value pointerSize = createIndexAttrConstant(
+ builder, loc, indexType,
+ llvm::divideCeil(typeConverter.getPointerBitwidth(addressSpace), 8));
+ Value doublePointerSize =
+ LLVM::MulOp::create(builder, loc, indexType, two, pointerSize);
+
+ // (1 + 2 * rank) * sizeof(index)
+ Value rank = desc.rank(builder, loc);
+ Value doubleRank = LLVM::MulOp::create(builder, loc, indexType, two, rank);
+ Value doubleRankIncremented =
+ LLVM::AddOp::create(builder, loc, indexType, doubleRank, one);
+ Value rankIndexSize = LLVM::MulOp::create(builder, loc, indexType,
+ doubleRankIncremented, indexSize);
+
+ // Total allocation size.
+ Value allocationSize = LLVM::AddOp::create(builder, loc, indexType,
+ doublePointerSize, rankIndexSize);
+ return allocationSize;
}
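The size expression built above reduces to 2 * sizeof(pointer) + (1 + 2 * rank) * sizeof(index). A small constant-folded check of that formula, assuming 64-bit pointers and a 64-bit index type (an assumption of this sketch, not a requirement of the converter):

#include <cassert>
#include <cstdint>

// Byte size of the ranked descriptor body referenced by an unranked
// descriptor: { ptr, ptr, offset, sizes[rank], strides[rank] }.
static std::uint64_t descriptorAllocationSize(std::uint64_t rank,
                                              std::uint64_t pointerBytes,
                                              std::uint64_t indexBytes) {
  return 2 * pointerBytes + (1 + 2 * rank) * indexBytes;
}

int main() {
  // Rank 0: two pointers plus the offset field only.
  assert(descriptorAllocationSize(0, 8, 8) == 24);
  // Rank 3: two pointers plus (1 offset + 3 sizes + 3 strides) indices.
  assert(descriptorAllocationSize(3, 8, 8) == 72);
  return 0;
}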
Value UnrankedMemRefDescriptor::allocatedPtr(
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 2568044..48a0319 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -216,34 +216,14 @@ MemRefDescriptor ConvertToLLVMPattern::createMemRefDescriptor(
return memRefDescriptor;
}
-LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
- OpBuilder &builder, Location loc, TypeRange origTypes,
- SmallVectorImpl<Value> &operands, bool toDynamic) const {
- assert(origTypes.size() == operands.size() &&
- "expected as may original types as operands");
-
- // Find operands of unranked memref type and store them.
- SmallVector<UnrankedMemRefDescriptor> unrankedMemrefs;
- SmallVector<unsigned> unrankedAddressSpaces;
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
- unrankedMemrefs.emplace_back(operands[i]);
- FailureOr<unsigned> addressSpace =
- getTypeConverter()->getMemRefAddressSpace(memRefType);
- if (failed(addressSpace))
- return failure();
- unrankedAddressSpaces.emplace_back(*addressSpace);
- }
- }
-
- if (unrankedMemrefs.empty())
- return success();
-
- // Compute allocation sizes.
- SmallVector<Value> sizes;
- UnrankedMemRefDescriptor::computeSizes(builder, loc, *getTypeConverter(),
- unrankedMemrefs, unrankedAddressSpaces,
- sizes);
+Value ConvertToLLVMPattern::copyUnrankedDescriptor(
+ OpBuilder &builder, Location loc, UnrankedMemRefType memRefType,
+ Value operand, bool toDynamic) const {
+ // Convert memory space.
+ FailureOr<unsigned> addressSpace =
+ getTypeConverter()->getMemRefAddressSpace(memRefType);
+ if (failed(addressSpace))
+ return {};
// Get frequently used types.
Type indexType = getTypeConverter()->getIndexType();
@@ -254,52 +234,61 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
if (toDynamic) {
mallocFunc = LLVM::lookupOrCreateMallocFn(builder, module, indexType);
if (failed(mallocFunc))
- return failure();
+ return {};
}
if (!toDynamic) {
freeFunc = LLVM::lookupOrCreateFreeFn(builder, module);
if (failed(freeFunc))
- return failure();
+ return {};
}
- unsigned unrankedMemrefPos = 0;
- for (unsigned i = 0, e = operands.size(); i < e; ++i) {
- Type type = origTypes[i];
- if (!isa<UnrankedMemRefType>(type))
- continue;
- Value allocationSize = sizes[unrankedMemrefPos++];
- UnrankedMemRefDescriptor desc(operands[i]);
-
- // Allocate memory, copy, and free the source if necessary.
- Value memory =
- toDynamic ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
- allocationSize)
- .getResult()
- : LLVM::AllocaOp::create(builder, loc, getPtrType(),
- IntegerType::get(getContext(), 8),
- allocationSize,
- /*alignment=*/0);
- Value source = desc.memRefDescPtr(builder, loc);
- LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
- if (!toDynamic)
- LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
-
- // Create a new descriptor. The same descriptor can be returned multiple
- // times, attempting to modify its pointer can lead to memory leaks
- // (allocated twice and overwritten) or double frees (the caller does not
- // know if the descriptor points to the same memory).
- Type descriptorType = getTypeConverter()->convertType(type);
- if (!descriptorType)
- return failure();
- auto updatedDesc =
- UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
- Value rank = desc.rank(builder, loc);
- updatedDesc.setRank(builder, loc, rank);
- updatedDesc.setMemRefDescPtr(builder, loc, memory);
+ UnrankedMemRefDescriptor desc(operand);
+ Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+ builder, loc, *getTypeConverter(), desc, *addressSpace);
+
+ // Allocate memory, copy, and free the source if necessary.
+ Value memory = toDynamic
+ ? LLVM::CallOp::create(builder, loc, mallocFunc.value(),
+ allocationSize)
+ .getResult()
+ : LLVM::AllocaOp::create(builder, loc, getPtrType(),
+ IntegerType::get(getContext(), 8),
+ allocationSize,
+ /*alignment=*/0);
+ Value source = desc.memRefDescPtr(builder, loc);
+ LLVM::MemcpyOp::create(builder, loc, memory, source, allocationSize, false);
+ if (!toDynamic)
+ LLVM::CallOp::create(builder, loc, freeFunc.value(), source);
+
+ // Create a new descriptor. The same descriptor can be returned multiple
+  // times; attempting to modify its pointer can lead to memory leaks
+ // (allocated twice and overwritten) or double frees (the caller does not
+ // know if the descriptor points to the same memory).
+ Type descriptorType = getTypeConverter()->convertType(memRefType);
+ if (!descriptorType)
+ return {};
+ auto updatedDesc =
+ UnrankedMemRefDescriptor::poison(builder, loc, descriptorType);
+ Value rank = desc.rank(builder, loc);
+ updatedDesc.setRank(builder, loc, rank);
+ updatedDesc.setMemRefDescPtr(builder, loc, memory);
+ return updatedDesc;
+}
- operands[i] = updatedDesc;
+LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
+ OpBuilder &builder, Location loc, TypeRange origTypes,
+ SmallVectorImpl<Value> &operands, bool toDynamic) const {
+ assert(origTypes.size() == operands.size() &&
+ "expected as may original types as operands");
+ for (unsigned i = 0, e = operands.size(); i < e; ++i) {
+ if (auto memRefType = dyn_cast<UnrankedMemRefType>(origTypes[i])) {
+ Value updatedDesc = copyUnrankedDescriptor(builder, loc, memRefType,
+ operands[i], toDynamic);
+ if (!updatedDesc)
+ return failure();
+ operands[i] = updatedDesc;
+ }
}
-
return success();
}
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 1a9bf56..cb9dea1 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -365,6 +365,7 @@ Type LLVMTypeConverter::convertFunctionSignatureImpl(
useBarePtrCallConv = useBarePtrCallConv || options.useBarePtrCallConv;
auto funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
: structFuncArgTypeConverter;
+
// Convert argument types one by one and check for errors.
for (auto [idx, type] : llvm::enumerate(funcTy.getInputs())) {
SmallVector<Type, 8> converted;
@@ -658,27 +659,19 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
/// UnrankedMemRefType, are converted following the specific rules for the
/// calling convention. Calling convention independent types are converted
/// following the default LLVM type conversions.
-Type LLVMTypeConverter::convertCallingConventionType(
- Type type, bool useBarePtrCallConv) const {
- if (useBarePtrCallConv)
- if (auto memrefTy = dyn_cast<BaseMemRefType>(type))
- return convertMemRefToBarePtr(memrefTy);
-
- return convertType(type);
-}
+LogicalResult LLVMTypeConverter::convertCallingConventionType(
+ Type type, SmallVectorImpl<Type> &result, bool useBarePtrCallConv) const {
+ if (useBarePtrCallConv) {
+ if (auto memrefTy = dyn_cast<BaseMemRefType>(type)) {
+ Type converted = convertMemRefToBarePtr(memrefTy);
+ if (!converted)
+ return failure();
+ result.push_back(converted);
+ return success();
+ }
+ }
-/// Promote the bare pointers in 'values' that resulted from memrefs to
-/// descriptors. 'stdTypes' holds they types of 'values' before the conversion
-/// to the LLVM-IR dialect (i.e., MemRefType, or any other builtin type).
-void LLVMTypeConverter::promoteBarePtrsToDescriptors(
- ConversionPatternRewriter &rewriter, Location loc, ArrayRef<Type> stdTypes,
- SmallVectorImpl<Value> &values) const {
- assert(stdTypes.size() == values.size() &&
- "The number of types and values doesn't match");
- for (unsigned i = 0, end = values.size(); i < end; ++i)
- if (auto memrefTy = dyn_cast<MemRefType>(stdTypes[i]))
- values[i] = MemRefDescriptor::fromStaticShape(rewriter, loc, *this,
- memrefTy, values[i]);
+ return convertType(type, result);
}
/// Convert a non-empty list of types of values produced by an operation into an
@@ -706,23 +699,35 @@ Type LLVMTypeConverter::packOperationResults(TypeRange types) const {
/// LLVM-compatible type. In particular, if more than one value is returned,
/// create an LLVM dialect structure type with elements that correspond to each
/// of the types converted with `convertCallingConventionType`.
-Type LLVMTypeConverter::packFunctionResults(TypeRange types,
- bool useBarePtrCallConv) const {
+Type LLVMTypeConverter::packFunctionResults(
+ TypeRange types, bool useBarePtrCallConv,
+ SmallVector<SmallVector<Type>> *groupedTypes,
+ int64_t *numConvertedTypes) const {
assert(!types.empty() && "expected non-empty list of type");
+ assert((!groupedTypes || groupedTypes->empty()) &&
+ "expected groupedTypes to be empty");
useBarePtrCallConv |= options.useBarePtrCallConv;
- if (types.size() == 1)
- return convertCallingConventionType(types.front(), useBarePtrCallConv);
-
SmallVector<Type> resultTypes;
resultTypes.reserve(types.size());
+ size_t sizeBefore = 0;
for (auto t : types) {
- auto converted = convertCallingConventionType(t, useBarePtrCallConv);
- if (!converted || !LLVM::isCompatibleType(converted))
+ if (failed(
+ convertCallingConventionType(t, resultTypes, useBarePtrCallConv)))
return {};
- resultTypes.push_back(converted);
+ if (groupedTypes) {
+ SmallVector<Type> &group = groupedTypes->emplace_back();
+ llvm::append_range(group, ArrayRef(resultTypes).drop_front(sizeBefore));
+ }
+ sizeBefore = resultTypes.size();
}
+ if (numConvertedTypes)
+ *numConvertedTypes = resultTypes.size();
+ if (resultTypes.size() == 1)
+ return resultTypes.front();
+ if (resultTypes.empty())
+ return {};
return LLVM::LLVMStructType::getLiteral(&getContext(), resultTypes);
}
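A standalone restatement of the packing rule after the 1:N conversion of result types. Types are modeled as strings here, not the MLIR API, and an empty string stands in for the null Type returned on failure:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Zero converted types -> failure (empty string), one -> that type,
// two or more -> a literal struct of all converted types.
static std::string packResults(const std::vector<std::string> &converted) {
  if (converted.empty())
    return {};
  if (converted.size() == 1)
    return converted.front();
  std::string packed = "struct<";
  for (std::size_t i = 0; i < converted.size(); ++i)
    packed += (i ? ", " : "") + converted[i];
  return packed + ">";
}

int main() {
  assert(packResults({"i32"}) == "i32");
  assert(packResults({"ptr", "i64", "i64"}) == "struct<ptr, i64, i64>");
  assert(packResults({}).empty());
  return 0;
}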
@@ -740,40 +745,50 @@ Value LLVMTypeConverter::promoteOneMemRefDescriptor(Location loc, Value operand,
return allocated;
}
-SmallVector<Value, 4>
-LLVMTypeConverter::promoteOperands(Location loc, ValueRange opOperands,
- ValueRange operands, OpBuilder &builder,
- bool useBarePtrCallConv) const {
+SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(
+ Location loc, ValueRange opOperands, ValueRange adaptorOperands,
+ OpBuilder &builder, bool useBarePtrCallConv) const {
+ SmallVector<ValueRange> ranges;
+ for (size_t i = 0, e = adaptorOperands.size(); i < e; i++)
+ ranges.push_back(adaptorOperands.slice(i, 1));
+ return promoteOperands(loc, opOperands, ranges, builder, useBarePtrCallConv);
+}
+
+SmallVector<Value, 4> LLVMTypeConverter::promoteOperands(
+ Location loc, ValueRange opOperands, ArrayRef<ValueRange> adaptorOperands,
+ OpBuilder &builder, bool useBarePtrCallConv) const {
SmallVector<Value, 4> promotedOperands;
- promotedOperands.reserve(operands.size());
+ promotedOperands.reserve(adaptorOperands.size());
useBarePtrCallConv |= options.useBarePtrCallConv;
- for (auto it : llvm::zip(opOperands, operands)) {
- auto operand = std::get<0>(it);
- auto llvmOperand = std::get<1>(it);
-
+ for (auto [operand, llvmOperand] :
+ llvm::zip_equal(opOperands, adaptorOperands)) {
if (useBarePtrCallConv) {
// For the bare-ptr calling convention, we only have to extract the
// aligned pointer of a memref.
if (isa<MemRefType>(operand.getType())) {
- MemRefDescriptor desc(llvmOperand);
- llvmOperand = desc.alignedPtr(builder, loc);
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ MemRefDescriptor desc(llvmOperand.front());
+ promotedOperands.push_back(desc.alignedPtr(builder, loc));
+ continue;
} else if (isa<UnrankedMemRefType>(operand.getType())) {
llvm_unreachable("Unranked memrefs are not supported");
}
} else {
if (isa<UnrankedMemRefType>(operand.getType())) {
- UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand,
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ UnrankedMemRefDescriptor::unpack(builder, loc, llvmOperand.front(),
promotedOperands);
continue;
}
if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
- MemRefDescriptor::unpack(builder, loc, llvmOperand, memrefType,
+ assert(llvmOperand.size() == 1 && "Expected a single operand");
+ MemRefDescriptor::unpack(builder, loc, llvmOperand.front(), memrefType,
promotedOperands);
continue;
}
}
- promotedOperands.push_back(llvmOperand);
+ llvm::append_range(promotedOperands, llvmOperand);
}
return promotedOperands;
}
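The practical effect of the unpacking above is a difference in how many scalar operands each memref contributes. A tiny sketch of that count, assuming the default descriptor layout of two pointers, an offset, and per-dimension sizes and strides:

#include <cassert>
#include <cstdint>

// Number of scalar LLVM operands a ranked memref argument expands to.
static std::int64_t numPromotedOperands(std::int64_t rank, bool barePtr) {
  // Bare-ptr convention: just the aligned pointer.
  // Default convention: allocated ptr, aligned ptr, offset, then
  // rank sizes and rank strides from the unpacked descriptor.
  return barePtr ? 1 : 3 + 2 * rank;
}

int main() {
  assert(numPromotedOperands(/*rank=*/2, /*barePtr=*/true) == 1);
  assert(numPromotedOperands(/*rank=*/2, /*barePtr=*/false) == 7);
  return 0;
}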
@@ -802,11 +817,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
result.append(converted.begin(), converted.end());
return success();
}
- auto converted = converter.convertType(type);
- if (!converted)
- return failure();
- result.push_back(converted);
- return success();
+ return converter.convertType(type, result);
}
/// Callback to convert function argument types. It converts MemRef function
@@ -814,11 +825,7 @@ mlir::structFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
LogicalResult
mlir::barePtrFuncArgTypeConverter(const LLVMTypeConverter &converter, Type type,
SmallVectorImpl<Type> &result) {
- auto llvmTy = converter.convertCallingConventionType(
- type, /*useBarePointerCallConv=*/true);
- if (!llvmTy)
- return failure();
-
- result.push_back(llvmTy);
- return success();
+ return converter.convertCallingConventionType(
+ type, result,
+ /*useBarePointerCallConv=*/true);
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index 6bd0e2d..2b7bdc9 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -17,11 +17,13 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
#include <cstdint>
+#include <numeric>
using namespace mlir;
@@ -97,6 +99,48 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
+static Value calculateMemrefTotalSizeBytes(Location loc, MemRefType memrefType,
+ OpBuilder &builder) {
+ assert(isMemRefTypeLegalForEmitC(memrefType) &&
+ "incompatible memref type for EmitC conversion");
+ emitc::CallOpaqueOp elementSize = emitc::CallOpaqueOp::create(
+ builder, loc, emitc::SizeTType::get(builder.getContext()),
+ builder.getStringAttr("sizeof"), ValueRange{},
+ ArrayAttr::get(builder.getContext(),
+ {TypeAttr::get(memrefType.getElementType())}));
+
+ IndexType indexType = builder.getIndexType();
+ int64_t numElements = std::accumulate(memrefType.getShape().begin(),
+ memrefType.getShape().end(), int64_t{1},
+ std::multiplies<int64_t>());
+ emitc::ConstantOp numElementsValue = emitc::ConstantOp::create(
+ builder, loc, indexType, builder.getIndexAttr(numElements));
+
+ Type sizeTType = emitc::SizeTType::get(builder.getContext());
+ emitc::MulOp totalSizeBytes = emitc::MulOp::create(
+ builder, loc, sizeTType, elementSize.getResult(0), numElementsValue);
+
+ return totalSizeBytes.getResult();
+}
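The emitted size expression is simply sizeof(element) * numElements, with numElements the product of the static shape. The same arithmetic in plain C++ for a hypothetical 4x8x2 float memref (4-byte float assumed):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  std::vector<std::int64_t> shape = {4, 8, 2};
  std::int64_t numElements =
      std::accumulate(shape.begin(), shape.end(), std::int64_t{1},
                      std::multiplies<std::int64_t>());
  std::size_t totalBytes =
      sizeof(float) * static_cast<std::size_t>(numElements);
  assert(numElements == 64);
  assert(totalBytes == 256); // assuming 4-byte float
  return 0;
}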
+
+static emitc::ApplyOp
+createPointerFromEmitcArray(Location loc, OpBuilder &builder,
+ TypedValue<emitc::ArrayType> arrayValue) {
+
+ emitc::ConstantOp zeroIndex = emitc::ConstantOp::create(
+ builder, loc, builder.getIndexType(), builder.getIndexAttr(0));
+
+ emitc::ArrayType arrayType = arrayValue.getType();
+ llvm::SmallVector<mlir::Value> indices(arrayType.getRank(), zeroIndex);
+ emitc::SubscriptOp subPtr =
+ emitc::SubscriptOp::create(builder, loc, arrayValue, ValueRange(indices));
+ emitc::ApplyOp ptr = emitc::ApplyOp::create(
+ builder, loc, emitc::PointerType::get(arrayType.getElementType()),
+ builder.getStringAttr("&"), subPtr);
+
+ return ptr;
+}
+
struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
using OpConversionPattern::OpConversionPattern;
LogicalResult
@@ -112,19 +156,21 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
Type elementType = memrefType.getElementType();
IndexType indexType = rewriter.getIndexType();
- emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
- loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
+ emitc::CallOpaqueOp sizeofElementOp = emitc::CallOpaqueOp::create(
+ rewriter, loc, sizeTType, rewriter.getStringAttr("sizeof"),
+ ValueRange{},
ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
int64_t numElements = 1;
for (int64_t dimSize : memrefType.getShape()) {
numElements *= dimSize;
}
- Value numElementsValue = rewriter.create<emitc::ConstantOp>(
- loc, indexType, rewriter.getIndexAttr(numElements));
+ Value numElementsValue = emitc::ConstantOp::create(
+ rewriter, loc, indexType, rewriter.getIndexAttr(numElements));
- Value totalSizeBytes = rewriter.create<emitc::MulOp>(
- loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
+ Value totalSizeBytes =
+ emitc::MulOp::create(rewriter, loc, sizeTType,
+ sizeofElementOp.getResult(0), numElementsValue);
emitc::CallOpaqueOp allocCall;
StringAttr allocFunctionName;
@@ -132,8 +178,8 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
SmallVector<Value, 2> argsVec;
if (allocOp.getAlignment()) {
allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
- alignmentValue = rewriter.create<emitc::ConstantOp>(
- loc, sizeTType,
+ alignmentValue = emitc::ConstantOp::create(
+ rewriter, loc, sizeTType,
rewriter.getIntegerAttr(indexType,
allocOp.getAlignment().value_or(0)));
argsVec.push_back(alignmentValue);
@@ -144,21 +190,62 @@ struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
argsVec.push_back(totalSizeBytes);
ValueRange args(argsVec);
- allocCall = rewriter.create<emitc::CallOpaqueOp>(
- loc,
+ allocCall = emitc::CallOpaqueOp::create(
+ rewriter, loc,
emitc::PointerType::get(
emitc::OpaqueType::get(rewriter.getContext(), "void")),
allocFunctionName, args);
emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
- emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
- loc, targetPointerType, allocCall.getResult(0));
+ emitc::CastOp castOp = emitc::CastOp::create(
+ rewriter, loc, targetPointerType, allocCall.getResult(0));
rewriter.replaceOp(allocOp, castOp);
return success();
}
};
+struct ConvertCopy final : public OpConversionPattern<memref::CopyOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(memref::CopyOp copyOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = copyOp.getLoc();
+ MemRefType srcMemrefType = cast<MemRefType>(copyOp.getSource().getType());
+ MemRefType targetMemrefType =
+ cast<MemRefType>(copyOp.getTarget().getType());
+
+ if (!isMemRefTypeLegalForEmitC(srcMemrefType))
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible source memref type for EmitC conversion");
+
+ if (!isMemRefTypeLegalForEmitC(targetMemrefType))
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible target memref type for EmitC conversion");
+
+ auto srcArrayValue =
+ cast<TypedValue<emitc::ArrayType>>(operands.getSource());
+ emitc::ApplyOp srcPtr =
+ createPointerFromEmitcArray(loc, rewriter, srcArrayValue);
+
+ auto targetArrayValue =
+ cast<TypedValue<emitc::ArrayType>>(operands.getTarget());
+ emitc::ApplyOp targetPtr =
+ createPointerFromEmitcArray(loc, rewriter, targetArrayValue);
+
+ emitc::CallOpaqueOp memCpyCall = emitc::CallOpaqueOp::create(
+ rewriter, loc, TypeRange{}, "memcpy",
+ ValueRange{
+ targetPtr.getResult(), srcPtr.getResult(),
+ calculateMemrefTotalSizeBytes(loc, srcMemrefType, rewriter)});
+
+ rewriter.replaceOp(copyOp, memCpyCall.getResults());
+
+ return success();
+ }
+};
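At the C level, the lowered memref.copy for a static shape amounts to a single memcpy of the total byte size between the two arrays. This is an illustrative model of that effect, not the EmitC output itself:

#include <cassert>
#include <cstring>

int main() {
  // Hypothetical 2x3 f32 memrefs backed by plain C arrays.
  float src[2][3] = {{1, 2, 3}, {4, 5, 6}};
  float dst[2][3] = {};
  // memcpy(&dst[0][0], &src[0][0], sizeof(float) * numElements)
  std::memcpy(&dst[0][0], &src[0][0], sizeof(float) * 6);
  assert(dst[0][0] == 1.0f && dst[1][2] == 6.0f);
  return 0;
}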
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -320,6 +407,7 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
- ConvertLoad, ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertCopy, ConvertGlobal,
+ ConvertGetGlobal, ConvertLoad, ConvertStore>(
+ converter, patterns.getContext());
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index e78dd76..a073a9a 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -18,6 +18,8 @@
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/StringRef.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMEMREFTOEMITC
@@ -27,6 +29,15 @@ namespace mlir {
using namespace mlir;
namespace {
+
+emitc::IncludeOp addStandardHeader(OpBuilder &builder, ModuleOp module,
+ StringRef headerName) {
+ StringAttr includeAttr = builder.getStringAttr(headerName);
+ return emitc::IncludeOp::create(
+ builder, module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+}
+
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
using Base::Base;
@@ -55,34 +66,29 @@ struct ConvertMemRefToEmitCPass
return signalPassFailure();
mlir::ModuleOp module = getOperation();
+ llvm::SmallSet<StringRef, 4> existingHeaders;
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ module.walk([&](mlir::emitc::IncludeOp includeOp) {
+ if (includeOp.getIsStandardInclude())
+ existingHeaders.insert(includeOp.getInclude());
+ });
+
module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
- if (callOp.getCallee() != alignedAllocFunctionName &&
- callOp.getCallee() != mallocFunctionName) {
+ StringRef expectedHeader;
+ if (callOp.getCallee() == alignedAllocFunctionName ||
+ callOp.getCallee() == mallocFunctionName)
+ expectedHeader = options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader;
+ else if (callOp.getCallee() == memcpyFunctionName)
+ expectedHeader =
+ options.lowerToCpp ? cppStringLibraryHeader : cStringLibraryHeader;
+ else
return mlir::WalkResult::advance();
+ if (!existingHeaders.contains(expectedHeader)) {
+ addStandardHeader(builder, module, expectedHeader);
+ existingHeaders.insert(expectedHeader);
}
-
- for (auto &op : *module.getBody()) {
- emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
- if (!includeOp) {
- continue;
- }
- if (includeOp.getIsStandardInclude() &&
- ((options.lowerToCpp &&
- includeOp.getInclude() == cppStandardLibraryHeader) ||
- (!options.lowerToCpp &&
- includeOp.getInclude() == cStandardLibraryHeader))) {
- return mlir::WalkResult::interrupt();
- }
- }
-
- mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
- StringAttr includeAttr =
- builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
- : cStandardLibraryHeader);
- builder.create<mlir::emitc::IncludeOp>(
- module.getLoc(), includeAttr,
- /*is_standard_include=*/builder.getUnitAttr());
- return mlir::WalkResult::interrupt();
+ return mlir::WalkResult::advance();
});
}
};
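The header bookkeeping in the walk above reduces to a callee-to-header map plus a de-duplicating set. A minimal sketch of that logic; the header and function names are illustrative stand-ins for the pass constants, which are not shown here:

#include <cassert>
#include <set>
#include <string>
#include <vector>

int main() {
  const bool lowerToCpp = false;
  std::set<std::string> existingHeaders = {"stdint.h"};
  std::vector<std::string> callees = {"malloc", "aligned_alloc", "memcpy"};

  for (const std::string &callee : callees) {
    std::string header;
    if (callee == "malloc" || callee == "aligned_alloc")
      header = lowerToCpp ? "cstdlib" : "stdlib.h";
    else if (callee == "memcpy")
      header = lowerToCpp ? "cstring" : "string.h";
    else
      continue; // unrelated call, nothing to include
    existingHeaders.insert(header); // set membership prevents duplicates
  }

  assert(existingHeaders.count("stdlib.h") == 1);
  assert(existingHeaders.count("string.h") == 1);
  assert(existingHeaders.size() == 3);
  return 0;
}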
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index d6bdd34..262e0e7 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -1246,10 +1246,8 @@ struct MemorySpaceCastOpLowering
auto result = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(resultTypeU));
result.setRank(rewriter, loc, rank);
- SmallVector<Value, 1> sizes;
- UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
- result, resultAddrSpace, sizes);
- Value resultUnderlyingSize = sizes.front();
+ Value resultUnderlyingSize = UnrankedMemRefDescriptor::computeSize(
+ rewriter, loc, *getTypeConverter(), result, resultAddrSpace);
Value resultUnderlyingDesc =
LLVM::AllocaOp::create(rewriter, loc, getPtrType(),
rewriter.getI8Type(), resultUnderlyingSize);
@@ -1530,12 +1528,11 @@ private:
auto targetDesc = UnrankedMemRefDescriptor::poison(
rewriter, loc, typeConverter->convertType(targetType));
targetDesc.setRank(rewriter, loc, resultRank);
- SmallVector<Value, 4> sizes;
- UnrankedMemRefDescriptor::computeSizes(rewriter, loc, *getTypeConverter(),
- targetDesc, addressSpace, sizes);
+ Value allocationSize = UnrankedMemRefDescriptor::computeSize(
+ rewriter, loc, *getTypeConverter(), targetDesc, addressSpace);
Value underlyingDescPtr = LLVM::AllocaOp::create(
rewriter, loc, getPtrType(), IntegerType::get(getContext(), 8),
- sizes.front());
+ allocationSize);
targetDesc.setMemRefDescPtr(rewriter, loc, underlyingDescPtr);
// Extract pointers and offset from the source memref.
@@ -1872,6 +1869,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umin;
case arith::AtomicRMWKind::ori:
return LLVM::AtomicBinOp::_or;
+ case arith::AtomicRMWKind::xori:
+ return LLVM::AtomicBinOp::_xor;
case arith::AtomicRMWKind::andi:
return LLVM::AtomicBinOp::_and;
default:
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 2549a9c..c6c5ab3 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -283,11 +283,13 @@ struct MmaLdMatrixOpToNVVM : public ConvertOpToLLVMPattern<nvgpu::LdMatrixOp> {
Value srcPtr =
getStridedElementPtr(rewriter, b.getLoc(), srcMemrefType,
adaptor.getSrcMemref(), adaptor.getIndices());
+ auto shape = NVVM::LdStMatrixShapeAttr::get(rewriter.getContext(), 8, 8);
Value ldMatrixResult = NVVM::LdMatrixOp::create(
b, ldMatrixResultType, srcPtr,
/*num=*/op.getNumTiles(),
/*layout=*/op.getTranspose() ? NVVM::MMALayout::col
- : NVVM::MMALayout::row);
+ : NVVM::MMALayout::row,
+ /*shape=*/shape, /*eltType=*/NVVM::LdStMatrixEltType::B16);
// The ldmatrix operation returns either a single i32 value or a struct of
// i32 values. Here we unpack those values and cast them back to their
@@ -1104,12 +1106,10 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LDBG() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ LDBG() << "Generating warpgroup.descriptor: " << "leading_off:"
+ << leadDimVal << "\t" << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t" << "layout_type:" << swizzle
+ << " (" << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
<< ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
@@ -1399,14 +1399,12 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LDBG() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
- << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
- << "][" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
- << "])";
+ LDBG() << "\t wgmma." << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
+ << "(A[" << (iterationM * wgmmaM) << ":"
+ << (iterationM * wgmmaM) + wgmmaM << "][" << (iterationK * wgmmaK)
+ << ":" << (iterationK * wgmmaK + wgmmaK) << "] * " << " B["
+ << (iterationK * wgmmaK) << ":" << (iterationK * wgmmaK + wgmmaK)
+ << "][" << 0 << ":" << wgmmaN << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 91788f9..e0144bf 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -61,7 +61,7 @@ struct PtxLowering
op.getAsmValues(rewriter, asmValues);
for (auto &[asmValue, modifier] : asmValues) {
- LDBG() << asmValue << "\t Modifier : " << &modifier;
+ LDBG() << asmValue << "\t Modifier : " << modifier;
generator.insertValue(asmValue, modifier);
}
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index ba448e4..37cfc9f 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -382,8 +382,11 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
// With the body block done, we can fill in the condition block.
rewriter.setInsertionPointToEnd(conditionBlock);
- auto comparison = arith::CmpIOp::create(
- rewriter, loc, arith::CmpIPredicate::slt, iv, upperBound);
+ arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
+ ? arith::CmpIPredicate::ult
+ : arith::CmpIPredicate::slt;
+ auto comparison =
+ arith::CmpIOp::create(rewriter, loc, predicate, iv, upperBound);
cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock,
ArrayRef<Value>(), endBlock, ArrayRef<Value>());
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 84cbd86..1f239aa 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -154,6 +154,10 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
+ if (forOp.getUnsignedCmp())
+ return rewriter.notifyMatchFailure(forOp,
+ "unsigned loops are not supported");
+
// Create an emitc::variable op for each result. These variables will be
// assigned to by emitc::assign ops within the loop body.
SmallVector<Value> resultVariables;
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index dc92367f..55ed31e 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -178,8 +178,14 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
// Generate the rest of the loop header.
rewriter.setInsertionPointToEnd(header);
auto *mergeBlock = loopOp.getMergeBlock();
- auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
- newIndVar, adaptor.getUpperBound());
+ Value cmpOp;
+ if (forOp.getUnsignedCmp()) {
+ cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.getI1Type(),
+ newIndVar, adaptor.getUpperBound());
+ } else {
+ cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
+ newIndVar, adaptor.getUpperBound());
+ }
spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
ArrayRef<Value>(), mergeBlock,
diff --git a/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
new file mode 100644
index 0000000..2d4b2b6
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/CMakeLists.txt
@@ -0,0 +1,19 @@
+add_mlir_conversion_library(MLIRVectorToAMX
+ VectorToAMX.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/VectorToAMX
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_LIBS PUBLIC
+ MLIRAMXDialect
+ MLIRAffineUtils
+ MLIRArithDialect
+ MLIRLinalgUtils
+ MLIRMemRefDialect
+ MLIRSCFDialect
+ MLIRTransforms
+ MLIRVectorDialect
+ )
diff --git a/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
new file mode 100644
index 0000000..a11e9b2
--- /dev/null
+++ b/mlir/lib/Conversion/VectorToAMX/VectorToAMX.cpp
@@ -0,0 +1,283 @@
+//===- VectorToAMX.cpp - Convert vector to AMX dialect ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/VectorToAMX/VectorToAMX.h"
+
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/ViewLikeInterfaceUtils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include <numeric>
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTVECTORTOAMX
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+/// Return true if vector shape is compatible with AMX tiles.
+/// The validation accounts for VNNI packing.
+static bool verifyAmxShape(VectorType vec) {
+ // Check overall shape:
+ // - 2D for plain layout input or output
+ // - 3D for VNNI packed input
+ if (vec.getRank() != 2 && vec.getRank() != 3)
+ return false;
+
+ ArrayRef<int64_t> shape = vec.getShape();
+ int64_t rows = shape[0];
+ int64_t cols = shape[1];
+ unsigned elemBitWidth = vec.getElementType().getIntOrFloatBitWidth();
+
+ // 3D shape indicates VNNI packed layout.
+ if (vec.getRank() == 3) {
+ int64_t vnniFactor = 32 / elemBitWidth;
+ if (shape.back() != vnniFactor)
+ return false;
+ cols *= vnniFactor;
+ }
+
+ // AMX tile supports up to 16 rows of 64 bytes each.
+ constexpr unsigned maxRows = 16;
+ constexpr unsigned maxBitsPerRow = 64 * 8;
+ return rows <= maxRows && (cols * elemBitWidth) <= maxBitsPerRow;
+}
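The same constraints, restated as standalone C++ with a few spot checks: at most 16 rows, at most 64 bytes per row, and for 3-D (VNNI-packed) shapes an innermost dimension equal to 32 / element-bit-width:

#include <cassert>
#include <cstdint>
#include <vector>

static bool amxShapeOk(const std::vector<std::int64_t> &shape,
                       unsigned elemBits) {
  if (shape.size() != 2 && shape.size() != 3)
    return false;
  std::int64_t rows = shape[0];
  std::int64_t cols = shape[1];
  if (shape.size() == 3) {
    std::int64_t vnniFactor = 32 / elemBits;
    if (shape.back() != vnniFactor)
      return false;
    cols *= vnniFactor;
  }
  return rows <= 16 && cols * elemBits <= 64 * 8;
}

int main() {
  assert(amxShapeOk({16, 32}, 16));     // plain 16x32 bf16/f16 tile
  assert(amxShapeOk({16, 16, 2}, 16));  // VNNI-packed 16-bit input
  assert(!amxShapeOk({16, 16, 4}, 16)); // wrong VNNI factor for 16-bit data
  assert(!amxShapeOk({32, 16}, 8));     // too many rows
  return 0;
}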
+
+/// Checks if contraction operands are in AMX-compatible packed VNNI layout.
+static LogicalResult isAmxVnniLayout(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+ if (!accType || accType.getRank() != 2)
+ return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");
+
+ // Expect 3D inputs for VNNI packed data.
+ VectorType lhsType = contractOp.getLhs().getType();
+ VectorType rhsType = contractOp.getRhs().getType();
+ if (lhsType.getRank() != 3 || rhsType.getRank() != 3)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects lhs and rhs 3D vectors");
+
+ // Check if shapes are compatible with AMX tile.
+ if (!verifyAmxShape(lhsType) || !verifyAmxShape(rhsType) ||
+ !verifyAmxShape(accType))
+ return rewriter.notifyMatchFailure(contractOp, "Invalid operand shape");
+
+ // Validate affine maps.
+ //
+ // Iterators can be ordered arbitrarily. Indexing map positions are based on
+ // operands' target shapes.
+ // The matrix layouts must match the following:
+ // - matrix A - [M]x[K/vnniFactor]x[vnniFactor]
+ // - matrix B - [K/vnniFactor]x[N]x[vnniFactor]
+ // - matrix C - [M]x[N]
+ SmallVector<AffineMap, 4> indexingMaps = contractOp.getIndexingMapsArray();
+ AffineMap mapA = indexingMaps[0];
+ AffineMap mapB = indexingMaps[1];
+ if (mapA.getNumInputs() != 4 || mapA.getNumResults() != 3 ||
+ mapB.getNumResults() != 3)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid input indexing maps");
+ FailureOr<linalg::ContractionDimensions> dims =
+ linalg::inferContractionDims(indexingMaps);
+ if (failed(dims))
+ return rewriter.notifyMatchFailure(contractOp,
+ "Failed to infer contraction dims");
+ // Two reduction dimensions are expected:
+ // - one for the K dimension
+ // - one for the VNNI factor
+ if (dims->k.size() != 2)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expected two reduction dims");
+ assert(dims->m.size() == 1 && dims->n.size() == 1 &&
+ "Invalid parallel contraction dims");
+
+ SmallVector<vector::IteratorType> iteratorTypes =
+ contractOp.getIteratorTypesArray();
+ // Check VNNI dim maps - the innermost dim for A and B inputs.
+ auto vnniDimA = dyn_cast<AffineDimExpr>(mapA.getResult(2));
+ auto vnniDimB = dyn_cast<AffineDimExpr>(mapB.getResult(2));
+ if (!vnniDimA || !vnniDimB || vnniDimA != vnniDimB ||
+ iteratorTypes[vnniDimA.getPosition()] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid VNNI dim map");
+ // Check K dim maps - non-transposed row-major layout.
+ auto redDimA = dyn_cast<AffineDimExpr>(mapA.getResult(1));
+ auto redDimB = dyn_cast<AffineDimExpr>(mapB.getResult(0));
+ if (!redDimA || !redDimB || redDimA != redDimB ||
+ iteratorTypes[redDimA.getPosition()] != vector::IteratorType::reduction)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid K dim map");
+ // Check M and N dim maps - map to non-transposed output.
+ AffineMap mapC = indexingMaps[2];
+ auto mDimC = dyn_cast<AffineDimExpr>(mapC.getResult(0));
+ auto nDimC = dyn_cast<AffineDimExpr>(mapC.getResult(1));
+ if (!mDimC || !nDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid acc maps");
+ auto parallelDimA = dyn_cast<AffineDimExpr>(mapA.getResult(0));
+ if (!parallelDimA ||
+ iteratorTypes[parallelDimA.getPosition()] !=
+ vector::IteratorType::parallel ||
+ parallelDimA != mDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid M dim map");
+ auto parallelDimB = dyn_cast<AffineDimExpr>(mapB.getResult(1));
+ if (!parallelDimB ||
+ iteratorTypes[parallelDimB.getPosition()] !=
+ vector::IteratorType::parallel ||
+ parallelDimB != nDimC)
+ return rewriter.notifyMatchFailure(contractOp, "Invalid N dim map");
+
+ return success();
+}
+
+/// Validate contraction operands for AMX lowering.
+static LogicalResult validateOperands(PatternRewriter &rewriter,
+ vector::ContractionOp contractOp) {
+ VectorType accType = dyn_cast<VectorType>(contractOp.getAcc().getType());
+ if (!accType)
+ return rewriter.notifyMatchFailure(contractOp, "Expects vector acc");
+
+ // Check if operand types are compatible with AMX compute ops.
+ bool validElemTypes = false;
+ Type lhsElemType = contractOp.getLhs().getType().getElementType();
+ Type rhsElemType = contractOp.getRhs().getType().getElementType();
+ Type accElemType = accType.getElementType();
+ if (accElemType.isInteger(32)) {
+ validElemTypes = lhsElemType.isInteger(8) && rhsElemType.isInteger(8);
+ } else if (accElemType.isF32()) {
+ validElemTypes = (lhsElemType.isF16() && rhsElemType.isF16()) ||
+ (lhsElemType.isBF16() && rhsElemType.isBF16());
+ }
+ if (!validElemTypes)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Invalid combination of operand types");
+
+ if (failed(isAmxVnniLayout(rewriter, contractOp)))
+ return failure();
+
+ return success();
+}
+
+/// Collapses the two innermost dimensions together.
+static Value collapseLastDim(PatternRewriter &rewriter,
+ TypedValue<MemRefType> memref) {
+ int64_t rank = memref.getType().getRank();
+ SmallVector<ReassociationIndices> reassocIndices;
+ for (auto i : llvm::seq<int64_t>(0, rank - 2))
+ reassocIndices.push_back({i});
+ reassocIndices.push_back({rank - 2, rank - 1});
+ return memref::CollapseShapeOp::create(rewriter, memref.getLoc(), memref,
+ reassocIndices);
+}
+
+/// Loads vector values to an AMX tile.
+static TypedValue<amx::TileType> loadTile(PatternRewriter &rewriter,
+ TypedValue<VectorType> vec) {
+ Location loc = vec.getLoc();
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+ // Transfer the vector to a tile through an intermediate buffer.
+ VectorType vecTy = vec.getType();
+ Value buf = memref::AllocaOp::create(
+ rewriter, loc, MemRefType::get(vecTy.getShape(), vecTy.getElementType()));
+ SmallVector<Value> indices(vecTy.getRank(), zeroIndex);
+ vector::TransferWriteOp::create(rewriter, loc, vec, buf, indices);
+
+ // Collapse the VNNI dimension in case of packing.
+ bool isPacked = vecTy.getRank() == 3;
+ if (isPacked)
+ buf = collapseLastDim(rewriter, cast<TypedValue<MemRefType>>(buf));
+
+ ArrayRef<int64_t> shape = vecTy.getShape();
+ int64_t rows = shape[0];
+ int64_t cols = std::accumulate(shape.begin() + 1, shape.end(), 1,
+ std::multiplies<int64_t>());
+ auto tileType = amx::TileType::get({rows, cols}, vecTy.getElementType());
+
+ return amx::TileLoadOp::create(rewriter, loc, tileType, buf,
+ {zeroIndex, zeroIndex});
+}
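The 2-D tile shape used by the load collapses everything past the leading dimension into the columns. For a hypothetical VNNI-packed vector<16x8x2xbf16> that gives a 16x16 tile of 16-bit elements:

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main() {
  std::vector<std::int64_t> vecShape = {16, 8, 2}; // M, K/vnniFactor, vnniFactor
  std::int64_t rows = vecShape[0];
  std::int64_t cols =
      std::accumulate(vecShape.begin() + 1, vecShape.end(), std::int64_t{1},
                      std::multiplies<std::int64_t>());
  assert(rows == 16 && cols == 16);
  // 16 rows x 16 elements x 2 bytes = 32 bytes per row, within tile limits.
  return 0;
}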
+
+/// Stores an AMX tile in a vector.
+static TypedValue<VectorType> storeTile(PatternRewriter &rewriter,
+ TypedValue<amx::TileType> tile) {
+ Location loc = tile.getLoc();
+ Value zeroIndex = rewriter.createOrFold<arith::ConstantIndexOp>(loc, 0);
+
+ // Transfer the tile to a vector through an intermediate buffer.
+ amx::TileType tileTy = tile.getType();
+ Value buf = memref::AllocaOp::create(
+ rewriter, loc,
+ MemRefType::get(tileTy.getShape(), tileTy.getElementType()));
+ SmallVector<Value> indices(2, zeroIndex);
+ amx::TileStoreOp::create(rewriter, loc, buf, indices, tile);
+
+ auto vecTy = VectorType::get(tileTy.getShape(), tileTy.getElementType());
+ return vector::TransferReadOp::create(rewriter, loc, vecTy, buf, indices, {});
+}
+
+struct ContractionToAMX : public OpRewritePattern<vector::ContractionOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
+ PatternRewriter &rewriter) const override {
+ Location loc = contractOp.getLoc();
+
+ if (contractOp.getKind() != vector::CombiningKind::ADD)
+ return rewriter.notifyMatchFailure(contractOp,
+ "Expects add combining kind");
+ if (failed(validateOperands(rewriter, contractOp)))
+ return failure();
+
+ TypedValue<amx::TileType> lhsTile = loadTile(rewriter, contractOp.getLhs());
+ TypedValue<amx::TileType> rhsTile = loadTile(rewriter, contractOp.getRhs());
+ auto acc = dyn_cast<TypedValue<VectorType>>(contractOp.getAcc());
+ assert(acc && "Invalid accumulator type");
+ TypedValue<amx::TileType> accTile = loadTile(rewriter, acc);
+
+ TypedValue<amx::TileType> tileMul;
+ if (acc.getType().getElementType().isFloat()) {
+ tileMul = amx::TileMulFOp::create(rewriter, loc, accTile.getType(),
+ lhsTile, rhsTile, accTile);
+ } else {
+ tileMul = amx::TileMulIOp::create(rewriter, loc, accTile.getType(),
+ lhsTile, rhsTile, accTile);
+ }
+
+ Value res = storeTile(rewriter, tileMul);
+ rewriter.replaceOp(contractOp, res);
+
+ return success();
+ }
+};
+
+struct ConvertVectorToAMXPass
+ : public impl::ConvertVectorToAMXBase<ConvertVectorToAMXPass> {
+ void runOnOperation() override {
+ MLIRContext &ctx = getContext();
+ RewritePatternSet patterns(&ctx);
+ populateVectorToAMXConversionPatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
+
+void mlir::populateVectorToAMXConversionPatterns(RewritePatternSet &patterns) {
+ patterns.add<ContractionToAMX>(patterns.getContext());
+}
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index f9e2a01..afc3d1b 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1891,15 +1891,21 @@ struct VectorFromElementsLowering
ConversionPatternRewriter &rewriter) const override {
Location loc = fromElementsOp.getLoc();
VectorType vectorType = fromElementsOp.getType();
- // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
- // Such ops should be handled in the same way as vector.insert.
+    // Only 1-D vectors are supported. Multi-dimensional vectors should have
+    // been broken down into 1-D vectors by earlier vector-to-vector
+    // transformations.
if (vectorType.getRank() > 1)
return rewriter.notifyMatchFailure(fromElementsOp,
"rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
+ Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
- for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
- result = vector::InsertOp::create(rewriter, loc, val, result, idx);
+ for (auto [idx, val] : llvm::enumerate(adaptor.getElements())) {
+ auto constIdx =
+ LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, idx);
+ result = LLVM::InsertElementOp::create(rewriter, loc, llvmType, result,
+ val, constIdx);
+ }
rewriter.replaceOp(fromElementsOp, result);
return success();
}
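A scalar model of the chain built above: start from an uninitialized value and write each element at its constant index, mirroring the per-element inserts in the lowering (plain C++ standing in for llvm.insertelement):

#include <array>
#include <cassert>

int main() {
  // Stand-in for the poison start value of a 1-D vector<4xi32>.
  std::array<int, 4> result{};
  const int elements[4] = {7, 8, 9, 10};
  // One insert per element, at a constant index.
  for (int idx = 0; idx < 4; ++idx)
    result[idx] = elements[idx];
  assert(result[0] == 7 && result[3] == 10);
  return 0;
}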
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index cf10869..9852df6 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -94,6 +94,7 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
+ populateVectorFromElementsLoweringPatterns(patterns);
if (armI8MM) {
if (armNeon)
arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
diff --git a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
index 567083d..e9ad67c5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToXeGPU/CMakeLists.txt
@@ -13,4 +13,5 @@ add_mlir_conversion_library(MLIRVectorToXeGPU
MLIRTransforms
MLIRVectorDialect
MLIRXeGPUDialect
+ MLIRXeGPUUtils
)
diff --git a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
index 8010755..819c2e5 100644
--- a/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
+++ b/mlir/lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -14,9 +14,11 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
@@ -68,11 +70,6 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
if (!srcTy)
return rewriter.notifyMatchFailure(xferOp, "Expects memref source");
- // Perform common data transfer checks.
- VectorType vecTy = xferOp.getVectorType();
- if (failed(storeLoadPreconditions(rewriter, xferOp, vecTy)))
- return failure();
-
// Validate further transfer op semantics.
SmallVector<int64_t> strides;
int64_t offset;
@@ -80,6 +77,7 @@ static LogicalResult transferPreconditions(PatternRewriter &rewriter,
return rewriter.notifyMatchFailure(
xferOp, "Buffer must be contiguous in the innermost dimension");
+ VectorType vecTy = xferOp.getVectorType();
unsigned vecRank = vecTy.getRank();
if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
return rewriter.notifyMatchFailure(
@@ -155,6 +153,277 @@ createNdDescriptor(PatternRewriter &rewriter, Location loc,
return ndDesc;
}
+// Adjusts the strides of a memref according to a given permutation map for
+// vector operations.
+//
+// This function updates the innermost strides in the `strides` array to
+// reflect the permutation specified by `permMap`. The permutation is computed
+// using the inverse and broadcasting-aware version of the permutation map,
+// and is applied to the relevant strides. This ensures that memory accesses
+// are consistent with the logical permutation of vector elements.
+//
+// Example:
+// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
+// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1,
+// 0]), then after calling this function, the last two strides will be
+// swapped:
+// Original strides: [s0, s1, s2, s3]
+// After permutation: [s0, s1, s3, s2]
+//
+static void adjustStridesForPermutation(AffineMap permMap,
+ SmallVectorImpl<Value> &strides) {
+
+ AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
+ SmallVector<unsigned> perms;
+ invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
+ SmallVector<int64_t> perms64(perms.begin(), perms.end());
+ strides = applyPermutation(strides, perms64);
+}
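The permutation step above boils down to reindexing the stride vector. A standalone C++ sketch of that reindexing, separate from the patch, with std::vector and a hand-rolled applyPerm as stand-ins for MLIR's SmallVector and applyPermutation:

  // permute_strides.cpp - reorder strides by an index permutation, mirroring
  // the effect of adjustStridesForPermutation above.
  #include <cstdint>
  #include <iostream>
  #include <vector>

  // out[i] = strides[perm[i]] (the convention assumed here).
  static std::vector<int64_t> applyPerm(const std::vector<int64_t> &strides,
                                        const std::vector<unsigned> &perm) {
    std::vector<int64_t> out(perm.size());
    for (unsigned i = 0; i < perm.size(); ++i)
      out[i] = strides[perm[i]];
    return out;
  }

  int main() {
    // Rank-4 strides; a permutation that swaps the last two dimensions, as in
    // the comment's example.
    std::vector<int64_t> strides = {1536, 384, 32, 1};
    std::vector<unsigned> perm = {0, 1, 3, 2};
    for (int64_t s : applyPerm(strides, perm))
      std::cout << s << ' '; // prints: 1536 384 1 32
    std::cout << '\n';
  }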
+
+// Computes memory strides for vector transfer operations, handling both
+// static and dynamic memrefs while applying permutation transformations
+// for XeGPU lowering.
+static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
+ SmallVector<Value> strides;
+ Value baseMemref = xferOp.getBase();
+ AffineMap permMap = xferOp.getPermutationMap();
+ MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+
+ Location loc = xferOp.getLoc();
+ if (memrefType.hasStaticShape()) {
+ int64_t offset;
+ SmallVector<int64_t> intStrides;
+ if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
+ return {};
+ // Wrap static strides as MLIR values
+ for (int64_t s : intStrides)
+ strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
+ } else {
+    // For dynamic-shape memrefs, use memref.extract_strided_metadata to get
+    // the stride values.
+ unsigned rank = memrefType.getRank();
+ Type indexType = rewriter.getIndexType();
+
+ // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
+ // size0, size1, ..., sizeN-1]
+ SmallVector<Type> resultTypes;
+    resultTypes.push_back(MemRefType::get(
+        {}, memrefType.getElementType())); // base memref (rank-0)
+ resultTypes.push_back(indexType); // offset
+
+ for (unsigned i = 0; i < rank; ++i)
+ resultTypes.push_back(indexType); // strides
+
+ for (unsigned i = 0; i < rank; ++i)
+ resultTypes.push_back(indexType); // sizes
+
+ auto meta = memref::ExtractStridedMetadataOp::create(
+ rewriter, loc, resultTypes, baseMemref);
+ strides.append(meta.getStrides().begin(), meta.getStrides().end());
+ }
+ // Adjust strides according to the permutation map (e.g., for transpose)
+ adjustStridesForPermutation(permMap, strides);
+ return strides;
+}
+
+// This function computes the vector of localOffsets for scattered
+// loads/stores. It is used in the lowering of vector.transfer_read/write to
+// load_gather/store_scatter. Example:
+// %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
+// %cst {in_bounds = [true, true, true, true]}>} :
+// memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
+//
+// %6 = vector.step: vector<4xindex>
+// %7 = vector.step: vector<2xindex>
+// %8 = vector.step: vector<6xindex>
+// %9 = vector.step: vector<32xindex>
+// %10 = arith.mul %6, 384
+// %11 = arith.mul %7, 192
+// %12 = arith.mul %8, 32
+// %13 = arith.mul %9, 1
+// %14 = vector.shape_cast %10: vector<4xindex> -> vector<4x1x1x1xbf16>
+// %15 = vector.shape_cast %11: vector<2xindex> -> vector<1x2x1x1xbf16>
+// %16 = vector.shape_cast %12: vector<6xindex> -> vector<1x1x6x1xbf16>
+// %17 = vector.shape_cast %13: vector<32xindex> -> vector<1x1x1x32xbf16>
+// %18 = vector.broadcast %14: vector<4x1x1x1xbf16> -> vector<4x2x6x32xindex>
+// %19 = vector.broadcast %15: vector<1x2x1x1xbf16> -> vector<4x2x6x32xindex>
+// %20 = vector.broadcast %16: vector<1x1x6x1xbf16> -> vector<4x2x6x32xindex>
+// %21 = vector.broadcast %17: vector<1x1x1x32xbf16> -> vector<4x2x6x32xindex>
+// %22 = arith.add %18, %19
+// %23 = arith.add %20, %21
+// %local_offsets = arith.add %22, %23
+// %orig_offset = %block_id_y * 4x2x6x32 // consider using affine map
+// %offsets = orig_offset + local_offsets
+static Value computeOffsets(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter,
+ ArrayRef<Value> strides) {
+ Location loc = xferOp.getLoc();
+ VectorType vectorType = xferOp.getVectorType();
+ SmallVector<Value> indices(xferOp.getIndices().begin(),
+ xferOp.getIndices().end());
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+ // Create vector.step operations for each dimension
+ SmallVector<Value> stepVectors;
+ llvm::map_to_vector(vectorShape, [&](int64_t dim) {
+ auto stepType = VectorType::get({dim}, rewriter.getIndexType());
+ auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
+ stepVectors.push_back(stepOp);
+ return stepOp;
+ });
+
+ // Multiply step vectors by corresponding strides
+ size_t memrefRank = strides.size();
+ size_t vectorRank = vectorShape.size();
+ SmallVector<Value> strideMultiplied;
+ for (size_t i = 0; i < vectorRank; ++i) {
+ size_t memrefDim = memrefRank - vectorRank + i;
+ Value strideValue = strides[memrefDim];
+ auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
+ auto bcastOp =
+ vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
+ auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
+ strideMultiplied.push_back(mulOp);
+ }
+
+ // Shape cast each multiplied vector to add singleton dimensions
+ SmallVector<Value> shapeCasted;
+ for (size_t i = 0; i < vectorRank; ++i) {
+ SmallVector<int64_t> newShape(vectorRank, 1);
+ newShape[i] = vectorShape[i];
+ auto newType = VectorType::get(newShape, rewriter.getIndexType());
+ auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
+ strideMultiplied[i]);
+ shapeCasted.push_back(castOp);
+ }
+
+ // Broadcast each shape-casted vector to full vector shape
+ SmallVector<Value> broadcasted;
+ auto fullIndexVectorType =
+ VectorType::get(vectorShape, rewriter.getIndexType());
+ for (Value shapeCastVal : shapeCasted) {
+ auto broadcastOp = vector::BroadcastOp::create(
+ rewriter, loc, fullIndexVectorType, shapeCastVal);
+ broadcasted.push_back(broadcastOp);
+ }
+
+ // Add all broadcasted vectors together to compute local offsets
+ Value localOffsets = broadcasted[0];
+ for (size_t i = 1; i < broadcasted.size(); ++i)
+ localOffsets =
+ arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);
+
+ // Compute base offset from transfer read indices
+ Value baseOffset = nullptr;
+ if (!indices.empty()) {
+ baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+ for (size_t i = 0; i < indices.size(); ++i) {
+ Value strideVal = strides[i];
+ Value offsetContrib =
+ arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
+ baseOffset =
+ arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
+ }
+ // Broadcast base offset to match vector shape
+ Value bcastBase = vector::BroadcastOp::create(
+ rewriter, loc, fullIndexVectorType, baseOffset);
+ localOffsets =
+ arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
+ }
+ return localOffsets;
+}
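For the example in the comment above (reading vector<4x2x6x32> from memref<8x4x2x6x32xbf16>), the step/shape_cast/broadcast/add chain materializes, element by element, the linear offsets base + i*384 + j*192 + k*32 + l. A scalar C++ model of that arithmetic, separate from the patch; the block id value used below is an arbitrary illustration:

  // scattered_offsets.cpp - scalar model of the offsets built above for the
  // documented example: vector<4x2x6x32> read from memref<8x4x2x6x32xbf16>.
  #include <cstdint>
  #include <iostream>

  int main() {
    // Strides of the four innermost memref dims (in elements, row-major).
    const int64_t strides[4] = {384, 192, 32, 1};
    // Base offset from the transfer indices [%block_id_y, 0, 0, 0, 0]; the
    // outermost stride is 4 * 2 * 6 * 32 = 1536. The value 3 is arbitrary.
    const int64_t blockIdY = 3;
    const int64_t base = blockIdY * 1536;

    int64_t count = 0;
    for (int64_t i = 0; i < 4; ++i)
      for (int64_t j = 0; j < 2; ++j)
        for (int64_t k = 0; k < 6; ++k)
          for (int64_t l = 0; l < 32; ++l) {
            int64_t offset = base + i * strides[0] + j * strides[1] +
                             k * strides[2] + l * strides[3];
            (void)offset; // one offset per vector element
            ++count;
          }
    std::cout << "generated " << count << " offsets\n"; // 4*2*6*32 = 1536
  }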
+
+// Collapse memref shape to 1D
+static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
+ PatternRewriter &rewriter) {
+ Location loc = xferOp.getLoc();
+
+ Value baseMemref = xferOp.getBase();
+ MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
+ Type elementType = memrefType.getElementType();
+
+ // Compute the total number of elements in the memref
+ MemRefType flatMemrefType;
+ if (memrefType.hasStaticShape()) {
+ auto totalElements = memrefType.getNumElements();
+ flatMemrefType = MemRefType::get({totalElements}, elementType);
+ } else {
+ flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
+ }
+
+ SmallVector<ReassociationIndices> reassociation;
+ ReassociationIndices allDims =
+ llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
+ reassociation.push_back(allDims);
+
+ auto collapseOp = memref::CollapseShapeOp::create(
+ rewriter, loc, flatMemrefType, baseMemref, reassociation);
+ return collapseOp;
+}
+
+static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
+ PatternRewriter &rewriter) {
+
+ Location loc = readOp.getLoc();
+ VectorType vectorType = readOp.getVectorType();
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+ auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
+ if (!memrefType)
+ return rewriter.notifyMatchFailure(readOp, "Expected memref source");
+
+ SmallVector<Value> strides = computeStrides(readOp, rewriter);
+ if (strides.empty())
+ return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");
+
+ Value localOffsets = computeOffsets(readOp, rewriter, strides);
+
+ Value flatMemref = collapseMemrefTo1D(readOp, rewriter);
+
+ Value mask = vector::ConstantMaskOp::create(
+ rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+ vectorShape);
+ auto gatherOp = xegpu::LoadGatherOp::create(
+ rewriter, loc, vectorType, flatMemref, localOffsets, mask,
+ /*chunk_size=*/IntegerAttr{},
+ /*l1_hint=*/xegpu::CachePolicyAttr{},
+ /*l2_hint=*/xegpu::CachePolicyAttr{},
+ /*l3_hint=*/xegpu::CachePolicyAttr{});
+
+ rewriter.replaceOp(readOp, gatherOp.getResult());
+ return success();
+}
+
+static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
+ PatternRewriter &rewriter) {
+
+ Location loc = writeOp.getLoc();
+ VectorType vectorType = writeOp.getVectorType();
+ ArrayRef<int64_t> vectorShape = vectorType.getShape();
+
+ auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
+ if (!memrefType)
+ return rewriter.notifyMatchFailure(writeOp, "Expected memref source");
+
+ SmallVector<Value> strides = computeStrides(writeOp, rewriter);
+
+ Value localOffsets = computeOffsets(writeOp, rewriter, strides);
+
+ Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);
+
+ Value mask = vector::ConstantMaskOp::create(
+ rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
+ vectorShape);
+ xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
+ localOffsets, mask,
+ /*chunk_size=*/IntegerAttr{},
+ /*l1_hint=*/xegpu::CachePolicyAttr{},
+ /*l2_hint=*/xegpu::CachePolicyAttr{},
+ /*l3_hint=*/xegpu::CachePolicyAttr{});
+ rewriter.eraseOp(writeOp);
+ return success();
+}
+
struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
@@ -165,6 +434,22 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
if (failed(transferPreconditions(rewriter, readOp)))
return failure();
+    // TODO: This check needs to be replaced with a proper uArch capability
+    // check.
+ auto chip = xegpu::getChipStr(readOp);
+ if (chip != "pvc" && chip != "bmg") {
+      // Lower to a scattered load op if the target HW doesn't have 2-D block
+      // load support.
+      // TODO: Add support for out-of-bounds accesses.
+ if (readOp.hasOutOfBoundsDim())
+ return failure();
+ return lowerToScatteredLoadOp(readOp, rewriter);
+ }
+
+ // Perform common data transfer checks.
+ VectorType vecTy = readOp.getVectorType();
+ if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
+ return failure();
+
bool isOutOfBounds = readOp.hasOutOfBoundsDim();
if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
return rewriter.notifyMatchFailure(
@@ -173,7 +458,6 @@ struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
AffineMap readMap = readOp.getPermutationMap();
bool isTransposeLoad = !readMap.isMinorIdentity();
- VectorType vecTy = readOp.getVectorType();
Type elementType = vecTy.getElementType();
unsigned minTransposeBitWidth = 32;
if (isTransposeLoad &&
@@ -221,11 +505,26 @@ struct TransferWriteLowering
if (failed(transferPreconditions(rewriter, writeOp)))
return failure();
+    // TODO: This check needs to be replaced with a proper uArch capability
+    // check.
+ auto chip = xegpu::getChipStr(writeOp);
+ if (chip != "pvc" && chip != "bmg") {
+      // Lower to a scattered store op if the target HW doesn't have 2-D block
+      // store support.
+      // TODO: Add support for out-of-bounds accesses.
+ if (writeOp.hasOutOfBoundsDim())
+ return failure();
+ return lowerToScatteredStoreOp(writeOp, rewriter);
+ }
+
+ // Perform common data transfer checks.
+ VectorType vecTy = writeOp.getVectorType();
+ if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
+ return failure();
+
AffineMap map = writeOp.getPermutationMap();
if (!map.isMinorIdentity())
return rewriter.notifyMatchFailure(writeOp, "Expects identity map");
- VectorType vecTy = writeOp.getVectorType();
auto descType = xegpu::TensorDescType::get(
vecTy.getShape(), vecTy.getElementType(),
/*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
diff --git a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
index 86edc2b..b405ec2 100644
--- a/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/AffineStructures.cpp
@@ -93,13 +93,13 @@ FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
int64_t lb = forOp.getConstantLowerBound();
dividend[pos] = 1;
dividend.back() -= lb;
- addLocalFloorDiv(dividend, step);
+ unsigned qPos = addLocalFloorDiv(dividend, step);
// Second constraint: (iv - lb) - step * q = 0.
SmallVector<int64_t, 8> eq(getNumCols(), 0);
eq[pos] = 1;
eq.back() -= lb;
// For the local var just added above.
- eq[getNumCols() - 2] = -step;
+ eq[qPos] = -step;
addEquality(eq);
}
}
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 488c3c3..7d4d818 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -2678,6 +2678,7 @@ TypedAttr mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
case AtomicRMWKind::addi:
case AtomicRMWKind::maxu:
case AtomicRMWKind::ori:
+ case AtomicRMWKind::xori:
return builder.getZeroAttr(resultType);
case AtomicRMWKind::andi:
return builder.getIntegerAttr(
@@ -2736,7 +2737,7 @@ std::optional<TypedAttr> mlir::arith::getNeutralElement(Operation *op) {
// Integer operations.
.Case([](arith::AddIOp op) { return AtomicRMWKind::addi; })
.Case([](arith::OrIOp op) { return AtomicRMWKind::ori; })
- .Case([](arith::XOrIOp op) { return AtomicRMWKind::ori; })
+ .Case([](arith::XOrIOp op) { return AtomicRMWKind::xori; })
.Case([](arith::AndIOp op) { return AtomicRMWKind::andi; })
.Case([](arith::MaxUIOp op) { return AtomicRMWKind::maxu; })
.Case([](arith::MinUIOp op) { return AtomicRMWKind::minu; })
@@ -2806,6 +2807,8 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
return arith::OrIOp::create(builder, loc, lhs, rhs);
case AtomicRMWKind::andi:
return arith::AndIOp::create(builder, loc, lhs, rhs);
+ case AtomicRMWKind::xori:
+ return arith::XOrIOp::create(builder, loc, lhs, rhs);
// TODO: Add remaining reduction operations.
default:
(void)emitOptionalError(loc, "Reduction operation type not supported");
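The new xori cases rely on 0 being the identity of XOR, matching the zero attribute returned by getIdentityValueAttr. A quick standalone check of the identities assumed above, separate from the patch:

  // neutral_elements.cpp - the identity values assumed by the reduction
  // handling above: 0 for addi/ori/xori/maxu, all-ones for andi.
  #include <cassert>
  #include <cstdint>

  int main() {
    uint32_t x = 0xDEADBEEFu;
    assert((x + 0u) == x);          // addi
    assert((x | 0u) == x);          // ori
    assert((x ^ 0u) == x);          // xori (the case added here)
    assert((x & ~0u) == x);         // andi
    assert((x > 0u ? x : 0u) == x); // maxu
    return 0;
  }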
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index e79da92..5359826 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1395,6 +1395,7 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
//===----------------------------------------------------------------------===//
// FieldOp
//===----------------------------------------------------------------------===//
+
static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op,
TypeAttr type,
Attribute initialValue) {
@@ -1452,6 +1453,15 @@ LogicalResult FieldOp::verify() {
//===----------------------------------------------------------------------===//
// GetFieldOp
//===----------------------------------------------------------------------===//
+
+LogicalResult GetFieldOp::verify() {
+ auto parentClassOp = getOperation()->getParentOfType<emitc::ClassOp>();
+ if (!parentClassOp.getOperation())
+ return emitOpError(" must be nested within an emitc.class operation");
+
+ return success();
+}
+
LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
FieldOp fieldOp =
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index c55e26e..06d7e07 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -64,8 +64,8 @@ public:
TypeAttr typeAttr = TypeAttr::get(val.getType());
fields.push_back({fieldName, typeAttr});
- FieldOp fieldop = rewriter.create<emitc::FieldOp>(
- funcOp->getLoc(), fieldName, typeAttr, nullptr);
+ FieldOp fieldop = emitc::FieldOp::create(rewriter, funcOp->getLoc(),
+ fieldName, typeAttr, nullptr);
if (argAttrs && idx < argAttrs->size()) {
fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
diff --git a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
index d4978ca..97adad6 100644
--- a/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/KernelOutlining.cpp
@@ -431,8 +431,7 @@ private:
if (std::optional<SymbolTable::UseRange> symbolUses =
SymbolTable::getSymbolUses(symbolDefWorklist.pop_back_val())) {
for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
- StringRef symbolName =
- cast<FlatSymbolRefAttr>(symbolUse.getSymbolRef()).getValue();
+ StringAttr symbolName = symbolUse.getSymbolRef().getLeafReference();
if (symbolTable.lookup(symbolName))
continue;
diff --git a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
index e9cf493..6da76e9 100644
--- a/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/XeVMAttachTarget.cpp
@@ -17,6 +17,7 @@
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
+#include "mlir/Target/LLVM/XeVM/Target.h"
#include "llvm/Support/Regex.h"
namespace mlir {
diff --git a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
index 384d1a0..be71bd0 100644
--- a/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
+++ b/mlir/lib/Dialect/GPU/Utils/DistributionUtils.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Value.h"
+#include "llvm/ADT/DenseMap.h"
#include <numeric>
@@ -57,26 +58,29 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
warpOp.getResultTypes().end());
auto yield = cast<gpu::YieldOp>(
warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
- llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(),
- yield.getOperands().end());
+ SmallVector<Value> yieldValues(yield.getOperands().begin(),
+ yield.getOperands().end());
+ llvm::SmallDenseMap<Value, unsigned> indexLookup;
+ // Record the value -> first index mapping for faster lookup.
+ for (auto [i, v] : llvm::enumerate(yieldValues)) {
+ if (!indexLookup.count(v))
+ indexLookup[v] = i;
+ }
+
for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) {
- if (yieldValues.insert(value)) {
+ // If the value already exists in the yield, don't create a new output.
+ if (indexLookup.count(value)) {
+ indices.push_back(indexLookup[value]);
+ } else {
+ // If the value is new, add it to the yield and to the types.
+ yieldValues.push_back(value);
types.push_back(type);
indices.push_back(yieldValues.size() - 1);
- } else {
- // If the value already exit the region don't create a new output.
- for (auto [idx, yieldOperand] :
- llvm::enumerate(yieldValues.getArrayRef())) {
- if (yieldOperand == value) {
- indices.push_back(idx);
- break;
- }
- }
}
}
- yieldValues.insert_range(newYieldedValues);
+
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
- rewriter, warpOp, yieldValues.getArrayRef(), types);
+ rewriter, warpOp, yieldValues, types);
rewriter.replaceOp(warpOp,
newWarpOp.getResults().take_front(warpOp.getNumResults()));
return newWarpOp;
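The rewritten loop above trades the SetVector plus linear scan for a value-to-first-index map. A generic C++ sketch of the same "reuse the first index or append" pattern, separate from the patch, with std containers standing in for MLIR values:

  // dedup_indices.cpp - append a value if it is new, otherwise reuse the
  // index of its first occurrence, as in the loop above.
  #include <cassert>
  #include <string>
  #include <unordered_map>
  #include <vector>

  int main() {
    std::vector<std::string> yielded = {"a", "b", "c"};
    std::unordered_map<std::string, unsigned> firstIndex;
    for (unsigned i = 0; i < yielded.size(); ++i)
      firstIndex.emplace(yielded[i], i); // emplace keeps the first occurrence

    std::vector<unsigned> indices;
    for (const std::string &v : {std::string("b"), std::string("d")}) {
      auto it = firstIndex.find(v);
      if (it != firstIndex.end()) {
        indices.push_back(it->second); // existing value: reuse its index
      } else {
        yielded.push_back(v);          // new value: append and record index
        indices.push_back(yielded.size() - 1);
      }
    }
    assert(indices[0] == 1 && indices[1] == 3);
    return 0;
  }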
diff --git a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
index 894de44..e004d5f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/BasicPtxBuilderInterface.cpp
@@ -107,11 +107,32 @@ void PtxBuilder::insertValue(Value v, PTXRegisterMod itype) {
ss << getModifier() << getRegisterType(v) << ",";
}
+/// Check if the operation needs to pack and unpack results.
+static bool needsPackUnpack(BasicPtxBuilderInterface interfaceOp) {
+ return interfaceOp->getNumResults() > 1;
+}
+
+/// Pack the result types of the interface operation.
+/// If the operation has multiple results, it packs them into a struct
+/// type. Otherwise, it returns the original result types.
+static SmallVector<Type> packResultTypes(MLIRContext *ctx,
+ BasicPtxBuilderInterface interfaceOp) {
+ TypeRange results = interfaceOp->getResultTypes();
+
+ if (!needsPackUnpack(interfaceOp))
+ return llvm::to_vector<1>(results);
+
+ SmallVector<mlir::Type> elems(results.begin(), results.end());
+ auto sTy = LLVM::LLVMStructType::getLiteral(ctx, elems, /*isPacked=*/false);
+ return {sTy};
+}
+
LLVM::InlineAsmOp PtxBuilder::build() {
+ MLIRContext *ctx = interfaceOp->getContext();
auto asmDialectAttr = LLVM::AsmDialectAttr::get(interfaceOp->getContext(),
LLVM::AsmDialect::AD_ATT);
- auto resultTypes = interfaceOp->getResultTypes();
+ SmallVector<Type> resultTypes = packResultTypes(ctx, interfaceOp);
// Remove the last comma from the constraints string.
if (!registerConstraints.empty() &&
@@ -136,7 +157,7 @@ LLVM::InlineAsmOp PtxBuilder::build() {
rewriter, interfaceOp->getLoc(),
/*result types=*/resultTypes,
/*operands=*/ptxOperands,
- /*asm_string=*/llvm::StringRef(ptxInstruction),
+ /*asm_string=*/ptxInstruction,
/*constraints=*/registerConstraints.data(),
/*has_side_effects=*/interfaceOp.hasSideEffect(),
/*is_align_stack=*/false, LLVM::TailCallKind::None,
@@ -147,9 +168,34 @@ LLVM::InlineAsmOp PtxBuilder::build() {
void PtxBuilder::buildAndReplaceOp() {
LLVM::InlineAsmOp inlineAsmOp = build();
LLVM_DEBUG(DBGS() << "\n Generated PTX \n\t" << inlineAsmOp << "\n");
- if (inlineAsmOp->getNumResults() == interfaceOp->getNumResults()) {
- rewriter.replaceOp(interfaceOp, inlineAsmOp);
- } else {
+
+ // Case 1: no result
+ if (inlineAsmOp->getNumResults() == 0) {
rewriter.eraseOp(interfaceOp);
+ return;
+ }
+
+ // Case 2: single result, forward it directly
+ if (!needsPackUnpack(interfaceOp)) {
+ rewriter.replaceOp(interfaceOp, inlineAsmOp->getResults());
+ return;
}
+
+ // Case 3: multiple results were packed; unpack the struct.
+ assert(mlir::LLVM::LLVMStructType::classof(
+ inlineAsmOp.getResultTypes().front()) &&
+ "Expected result type to be LLVMStructType when unpacking multiple "
+ "results");
+ auto structTy = llvm::cast<mlir::LLVM::LLVMStructType>(
+ inlineAsmOp.getResultTypes().front());
+
+ SmallVector<mlir::Value> unpacked;
+ Value structVal = inlineAsmOp.getResult(0);
+ for (auto [idx, elemTy] : llvm::enumerate(structTy.getBody())) {
+ Value unpackedValue = LLVM::ExtractValueOp::create(
+ rewriter, interfaceOp->getLoc(), structVal, idx);
+ unpacked.push_back(unpackedValue);
+ }
+
+ rewriter.replaceOp(interfaceOp, unpacked);
}
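LLVM inline asm yields multiple results as a single struct value, so the builder above packs the interface op's result types into one struct and later splits the fields back out with LLVM::ExtractValueOp. A plain C++ analogy of that pack/unpack round trip, separate from the patch:

  // pack_unpack.cpp - several logical results travel as one aggregate from
  // the producer and are unpacked field by field at the consumer.
  #include <cassert>

  // Stand-in for a producer that can only return one value: it returns a
  // single struct holding both logical results.
  struct Packed {
    int r0;
    float r1;
  };

  static Packed producer() { return Packed{42, 3.5f}; }

  int main() {
    Packed s = producer();
    // "Unpack the struct": one extraction per element, in declaration order.
    int a = s.r0;
    float b = s.r1;
    assert(a == 42 && b == 3.5f);
    return 0;
  }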
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 422039f..a6e89f6 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -141,6 +141,38 @@ static ParseResult parseLLVMLinkage(OpAsmParser &p, LinkageAttr &val) {
return success();
}
+static ArrayAttr getLLVMAlignParamForCompressExpand(OpBuilder &builder,
+ bool isExpandLoad,
+ uint64_t alignment = 1) {
+ // From
+ // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+ // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+ //
+ // The pointer alignment defaults to 1.
+ if (alignment == 1) {
+ return nullptr;
+ }
+
+ auto emptyDictAttr = builder.getDictionaryAttr({});
+ auto alignmentAttr = builder.getI64IntegerAttr(alignment);
+ auto namedAttr =
+ builder.getNamedAttr(LLVMDialect::getAlignAttrName(), alignmentAttr);
+ SmallVector<mlir::NamedAttribute> attrs = {namedAttr};
+ auto alignDictAttr = builder.getDictionaryAttr(attrs);
+ // From
+ // https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics
+ // https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics
+ //
+ // The align parameter attribute can be provided for [expandload]'s first
+ // argument. The align parameter attribute can be provided for
+ // [compressstore]'s second argument.
+ int pos = isExpandLoad ? 0 : 1;
+ return pos == 0 ? builder.getArrayAttr(
+ {alignDictAttr, emptyDictAttr, emptyDictAttr})
+ : builder.getArrayAttr(
+ {emptyDictAttr, alignDictAttr, emptyDictAttr});
+}
+
//===----------------------------------------------------------------------===//
// Operand bundle helpers.
//===----------------------------------------------------------------------===//
@@ -4117,6 +4149,32 @@ LogicalResult LLVM::masked_scatter::verify() {
}
//===----------------------------------------------------------------------===//
+// masked_expandload (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_expandload::build(OpBuilder &builder, OperationState &state,
+ mlir::TypeRange resTys, Value ptr,
+ Value mask, Value passthru,
+ uint64_t align) {
+ ArrayAttr argAttrs = getLLVMAlignParamForCompressExpand(builder, true, align);
+ build(builder, state, resTys, ptr, mask, passthru, /*arg_attrs=*/argAttrs,
+ /*res_attrs=*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
+// masked_compressstore (intrinsic)
+//===----------------------------------------------------------------------===//
+
+void LLVM::masked_compressstore::build(OpBuilder &builder,
+ OperationState &state, Value value,
+ Value ptr, Value mask, uint64_t align) {
+ ArrayAttr argAttrs =
+ getLLVMAlignParamForCompressExpand(builder, false, align);
+ build(builder, state, value, ptr, mask, /*arg_attrs=*/argAttrs,
+ /*res_attrs=*/nullptr);
+}
+
+//===----------------------------------------------------------------------===//
// InlineAsmOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index e0977f5..dbcc738 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -189,6 +189,26 @@ LogicalResult BulkStoreOp::verify() {
return success();
}
+LogicalResult PMEventOp::verify() {
+ auto eventId = getEventId();
+ auto maskedEventId = getMaskedEventId();
+ if (!maskedEventId && !eventId) {
+ return emitOpError() << "either `id` or `mask` must be set";
+ }
+
+ if (maskedEventId && eventId) {
+ return emitOpError() << "`id` and `mask` cannot be set at the same time";
+ }
+
+ if (eventId) {
+ if (eventId < 0 || eventId > 15) {
+ return emitOpError() << "`id` must be between 0 and 15";
+ }
+ }
+
+ return llvm::success();
+}
+
// Given the element type of an operand and whether or not it is an accumulator,
// this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
// operand's element type.
@@ -791,24 +811,58 @@ LogicalResult NVVM::WMMAMmaOp::verify() {
}
LogicalResult NVVM::LdMatrixOp::verify() {
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
- if (addressSpace != NVVM::kSharedMemorySpace)
- return emitOpError("expected source pointer in memory space 3");
-
- if (getNum() != 1 && getNum() != 2 && getNum() != 4)
- return emitOpError("expected num attribute to be 1, 2 or 4");
+ uint32_t num = getNum(), m = getShape().getM(), n = getShape().getN();
+ if (m == 8 && n == 8) {
+ if (num != 1 && num != 2 && num != 4) {
+ return emitOpError("expected num attribute to be 1, 2 or 4 for 8x8 "
+ "matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B16) {
+ return emitOpError("expected element type to be b16 for 8x8 matrix");
+ }
+ } else if (m == 8 && n == 16) {
+ if (num != 1 && num != 2 && num != 4) {
+ return emitOpError("expected num attribute to be 1, 2 or 4 for 8x16 "
+ "matrix");
+ }
+ if (getLayout() != MMALayout::row) {
+ return emitOpError("expected layout to be row for 8x16 matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
+ getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
+ return emitOpError("expected element type to be b8x16.b4x16_p64 or "
+ "b8x16.b6x16_p32 for 8x16 matrix");
+ }
+ } else if (m == 16 && n == 16) {
+ if (num != 1 && num != 2) {
+ return emitOpError("expected num attribute to be 1 or 2 for 16x16 "
+ "matrix");
+ }
+ if (getLayout() != MMALayout::col) {
+ return emitOpError("expected layout to be col for 16x16 matrix");
+ }
+ if (getEltType() != LdStMatrixEltType::B8 &&
+ getEltType() != LdStMatrixEltType::B8X16_B4X16_P64 &&
+ getEltType() != LdStMatrixEltType::B8X16_B6X16_P32) {
+ return emitOpError("expected element type to be b8, b8x16.b4x16_p64 or "
+ "b8x16.b6x16_p32 for 16x16 matrix");
+ }
+ } else {
+ return emitOpError("expected shape to be 8x8, 8x16 or 16x16");
+ }
Type i32 = IntegerType::get(getContext(), 32);
- if (getNum() == 1 && getType() != i32)
+ uint32_t numElements = (m == 16 && n == 16 ? num * 2 : num);
+ if (numElements == 1 && getType() != i32)
return emitOpError("expected destination type is i32");
- if (getNum() == 2 || getNum() == 4) {
+ if (numElements == 2 || numElements == 4) {
Type dstType = LLVM::LLVMStructType::getLiteral(
- getContext(), SmallVector<Type>(getNum(), i32));
+ getContext(), SmallVector<Type>(numElements, i32));
if (getType() != dstType)
return emitOpError("expected destination type is a structure of ")
- << getNum() << " elements of type i32";
+ << numElements << " elements of type i32";
}
+
return success();
}
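The verifier above encodes a small table: which num values are legal per shape, and how many i32 results to expect (16x16 tiles produce num*2). A standalone C++ sketch of just the result-count rule, separate from the patch and ignoring the layout and element-type constraints:

  // ldmatrix_result_count.cpp - mirrors the result-count rule in
  // LdMatrixOp::verify above: 16x16 shapes yield num*2 i32 values, the other
  // supported shapes yield num.
  #include <cassert>
  #include <cstdint>
  #include <optional>

  static std::optional<uint32_t> expectedI32Results(uint32_t m, uint32_t n,
                                                    uint32_t num) {
    bool validNum = (m == 16 && n == 16) ? (num == 1 || num == 2)
                                         : (num == 1 || num == 2 || num == 4);
    if (!validNum)
      return std::nullopt;
    if (m == 8 && (n == 8 || n == 16))
      return num;
    if (m == 16 && n == 16)
      return num * 2;
    return std::nullopt; // unsupported shape
  }

  int main() {
    assert(expectedI32Results(8, 8, 4) == 4u);
    assert(expectedI32Results(16, 16, 2) == 4u);
    assert(!expectedI32Results(16, 16, 4)); // num 4 is invalid for 16x16
    assert(!expectedI32Results(8, 32, 1));  // unsupported shape
    return 0;
  }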
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 34c63d3..e0e3716 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -194,9 +194,10 @@ static void buildMatmulOp(OpBuilder &b, OperationState &state,
ArrayRef<AffineMap> indexingMaps) {
// Initialize indexingMaps attribute, for MatmulOp.
SmallVector<Attribute, 3> indexingMapsAttrVal;
- indexingMapsAttrVal = llvm::map_to_vector(
- MatmulOp::getDefaultIndexingMaps(b.getContext()),
- [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); });
+ indexingMapsAttrVal =
+ llvm::map_to_vector(indexingMaps, [](AffineMap map) -> Attribute {
+ return AffineMapAttr::get(map);
+ });
state.addAttribute("indexing_maps", b.getArrayAttr(indexingMapsAttrVal));
return buildStructuredOp(b, state, resultTensorTypes, inputs, outputs,
attributes, regionBuilder);
@@ -1569,40 +1570,50 @@ ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
return success();
}
-// Retrieve the operation from the body, if it is the only one (except
-// yield) and if it gets the same amount of arguments as the body does.
-// If initFirst flag is enabled, we check that init takes the first position in
-// operands of payload.
-static Operation *findPayloadOp(Block *body, bool initFirst = false) {
+static bool canUseShortForm(Block *body, bool initFirst = false) {
+ // Check if the body can be printed in short form. The following 4 conditions
+ // must be satisfied:
+
+ // 1) The body must contain exactly 2 operations: the payload op and a yield.
if (body->getOperations().size() != 2)
- return nullptr;
+ return false;
Operation &payload = body->getOperations().front();
- assert(isa<YieldOp>(body->getOperations().back()));
+ // 2) The payload op must have the same number of operands as the number of
+ // block arguments.
if (payload.getNumOperands() == 0 ||
payload.getNumOperands() != body->getNumArguments())
- return nullptr;
+ return false;
+
+  // 3) If `initFirst` is true (e.g., for reduction ops), the init block
+  //    argument must be the first operand of the payload op; otherwise, the
+  //    operands must match the block arguments in order.
if (initFirst) {
// check init
if (payload.getOperands().back() != body->getArgument(0))
- return nullptr;
+ return false;
// check rest
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
} else {
for (const auto &[operand, bbArg] :
llvm::zip(payload.getOperands(), body->getArguments())) {
if (bbArg != operand)
- return nullptr;
+ return false;
}
}
- return &payload;
+
+ // 4) The `yield` operand must be the result of the payload op.
+ auto yieldOp = cast<YieldOp>(body->getTerminator());
+ return yieldOp.getNumOperands() == 1 &&
+ yieldOp.getOperand(0).getDefiningOp() &&
+ yieldOp.getOperand(0).getDefiningOp() == &payload;
}
-void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
+static void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
SmallVector<StringRef> elidedAttrs;
std::string attrToElide;
p << " { " << payloadOp->getName().getStringRef();
@@ -1621,15 +1632,15 @@ void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
void MapOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
p.printOptionalAttrDict((*this)->getAttrs());
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
@@ -1828,15 +1839,15 @@ static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
void ReduceOp::print(OpAsmPrinter &p) {
Block *mapper = getBody();
- Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
- if (payloadOp) {
- printShortForm(p, payloadOp);
+ bool useShortForm = canUseShortForm(mapper, /*initFirst=*/true);
+ if (useShortForm) {
+ printShortForm(p, &mapper->getOperations().front());
}
printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
- if (!payloadOp) {
+ if (!useShortForm) {
// Print region if the payload op was not detected.
p.increaseIndent();
p.printNewline();
@@ -3749,6 +3760,25 @@ std::pair<int64_t, int64_t> getFmrFromWinogradConv2DFmr(WinogradConv2DFmr fmr) {
// MatMulOp
//===----------------------------------------------------------------------===//
+static FailureOr<SmallVector<SmallVector<int64_t>>>
+getAffineResultPositions(ArrayAttr maps) {
+ SmallVector<SmallVector<int64_t>> positions;
+ for (auto map : maps) {
+ AffineMapAttr attr = dyn_cast<AffineMapAttr>(map);
+ if (!attr)
+ return failure();
+ SmallVector<int64_t> pos;
+ for (auto result : attr.getAffineMap().getResults()) {
+ auto dim = dyn_cast<AffineDimExpr>(result);
+ if (!dim)
+ return failure();
+ pos.push_back(dim.getPosition());
+ }
+ positions.push_back(pos);
+ }
+ return positions;
+}
+
 /// Returns a list of AffineMap with the typical matmul indexing characteristic.
SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
AffineExpr d0, d1, d2;
@@ -3760,6 +3790,20 @@ SmallVector<AffineMap> MatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool MatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
+ (*positions)[1] == SmallVector<int64_t>{2, 1} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
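The result positions checked above ({0,2} for A, {2,1} for B, {0,1} for C) are simply the subscripts of the canonical matmul loop nest with d0=i, d1=j, d2=k; the transpose variants added below only permute these subscripts (e.g. transpose-A reads A with positions {2,0}). A reference C++ loop nest spelling that out, separate from the patch:

  // matmul_maps.cpp - the default indexing maps (d0,d2), (d2,d1), (d0,d1)
  // written out as the loop nest they describe: C[i][j] += A[i][k] * B[k][j].
  #include <cassert>

  int main() {
    const int M = 2, N = 3, K = 4;
    double A[M][K], B[K][N], C[M][N] = {};
    for (int i = 0; i < M; ++i)
      for (int k = 0; k < K; ++k)
        A[i][k] = 1.0; // A indexed by (d0, d2)
    for (int k = 0; k < K; ++k)
      for (int j = 0; j < N; ++j)
        B[k][j] = 2.0; // B indexed by (d2, d1)
    for (int i = 0; i < M; ++i)         // d0: parallel
      for (int j = 0; j < N; ++j)       // d1: parallel
        for (int k = 0; k < K; ++k)     // d2: reduction
          C[i][j] += A[i][k] * B[k][j]; // C indexed by (d0, d1)
    assert(C[0][0] == 2.0 * K);
    return 0;
  }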
+
SmallVector<utils::IteratorType> MatmulOp::getIteratorTypesArray() {
return SmallVector<utils::IteratorType>{utils::IteratorType::parallel,
utils::IteratorType::parallel,
@@ -3912,6 +3956,380 @@ Speculation::Speculatability MatmulOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+SmallVector<AffineMap>
+MatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2);
+ AffineMap mapLHS = AffineMap::get(3, 0, {d2, d0}, context);
+ AffineMap mapRHS = AffineMap::get(3, 0, {d2, d1}, context);
+ AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool MatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{2, 0} &&
+ (*positions)[1] == SmallVector<int64_t>{2, 1} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeAOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeAOp
+MatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<MatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool MatmulTransposeAOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::MatmulOp>(op) &&
+ MatmulTransposeAOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
+SmallVector<AffineMap>
+MatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2);
+ AffineMap mapLHS = AffineMap::get(3, 0, {d0, d2}, context);
+ AffineMap mapRHS = AffineMap::get(3, 0, {d1, d2}, context);
+ AffineMap mapOut = AffineMap::get(3, 0, {d0, d1}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool MatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 2} &&
+ (*positions)[1] == SmallVector<int64_t>{1, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1};
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::MatmulTransposeBOp::build(OpBuilder &builder,
+ OperationState &result,
+ TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ MatmulOp::getRegionBuilder(), getDefaultIndexingMaps(builder));
+}
+
+MatmulTransposeBOp
+MatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<MatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool MatmulTransposeBOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::MatmulOp>(op) &&
+ MatmulTransposeBOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeAOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2, d3;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2, d3);
+ AffineMap mapLHS = AffineMap::get(4, 0, {d0, d3, d1}, context);
+ AffineMap mapRHS = AffineMap::get(4, 0, {d0, d3, d2}, context);
+ AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool BatchMatmulTransposeAOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 3, 1} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeAOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeAOp
+BatchMatmulTransposeAOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeAOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool BatchMatmulTransposeAOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
+ BatchMatmulTransposeAOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
+SmallVector<AffineMap>
+BatchMatmulTransposeBOp::getDefaultIndexingMaps(OpBuilder &builder) {
+ AffineExpr d0, d1, d2, d3;
+ MLIRContext *context = builder.getContext();
+ bindDims(context, d0, d1, d2, d3);
+ AffineMap mapLHS = AffineMap::get(4, 0, {d0, d1, d3}, context);
+ AffineMap mapRHS = AffineMap::get(4, 0, {d0, d2, d3}, context);
+ AffineMap mapOut = AffineMap::get(4, 0, {d0, d1, d2}, context);
+ return {mapLHS, mapRHS, mapOut};
+}
+
+bool BatchMatmulTransposeBOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 2, 3} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, ValueRange inputs,
+ ValueRange outputs, ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, std::nullopt, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+void linalg::BatchMatmulTransposeBOp::build(
+ OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+ ValueRange inputs, ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ result.addAttribute("cast", cast);
+ buildMatmulOp(builder, result, resultTensorTypes, inputs, outputs, attributes,
+ BatchMatmulOp::getRegionBuilder(),
+ getDefaultIndexingMaps(builder));
+}
+
+BatchMatmulTransposeBOp
+BatchMatmulTransposeBOp::create(OpBuilder &builder, Location location,
+ TypeRange resultTensorTypes, ValueRange inputs,
+ ValueRange outputs, Attribute cast,
+ ArrayRef<NamedAttribute> attributes) {
+ OperationState state(location, getOperationName());
+ build(builder, state, resultTensorTypes, inputs, outputs, cast, attributes);
+ auto res = dyn_cast<BatchMatmulTransposeBOp>(builder.create(state));
+ assert(res && "builder didn't return the right type");
+ return res;
+}
+
+bool BatchMatmulTransposeBOp::classof(Operation *op) {
+ return dyn_cast_or_null<linalg::BatchMatmulOp>(op) &&
+ BatchMatmulTransposeBOp::isDefaultIndexingMaps(
+ op->getAttr("indexing_maps"));
+}
+
//===----------------------------------------------------------------------===//
// ContractOp
//===----------------------------------------------------------------------===//
@@ -4120,6 +4538,20 @@ BatchMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool BatchMatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{0, 1, 2};
+}
+
SmallVector<utils::IteratorType> BatchMatmulOp::getIteratorTypesArray() {
return SmallVector<utils::IteratorType>{
utils::IteratorType::parallel, utils::IteratorType::parallel,
@@ -5345,11 +5777,18 @@ ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
auto innerDimsPos = getInnerDimsPos();
- auto packedShape = getSourceType().getShape();
+ SmallVector<int64_t> outerDims(getAllOuterDims());
SmallVector<int64_t> res;
+ // Recover the original order of the outer dims.
+ SmallVector<int64_t> outerDimPermInv(getOuterDimsPerm());
+ invertPermutationVector(outerDimPermInv);
+ if (!outerDimPermInv.empty())
+ applyPermutationToVector(outerDims, outerDimPermInv);
+
+  // Collect the outer dims corresponding to the tiled inner dims.
for (auto index : innerDimsPos)
- res.push_back(packedShape[index]);
+ res.push_back(outerDims[index]);
return res;
}
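getTiledOuterDims now has to undo outer_dims_perm before indexing by inner_dims_pos, which is a plain permutation inversion. A standalone C++ sketch of inverting a permutation and using it to recover the original order, separate from the patch and assuming the permuted[i] = original[perm[i]] convention:

  // invert_perm.cpp - invert a permutation and use it to recover the
  // pre-permutation order, as getTiledOuterDims does for the outer dims.
  #include <cassert>
  #include <cstdint>
  #include <vector>

  // inverse[perm[i]] = i, so applying `inverse` undoes `perm`.
  static std::vector<int64_t>
  invertPermutation(const std::vector<int64_t> &perm) {
    std::vector<int64_t> inv(perm.size());
    for (int64_t i = 0, e = perm.size(); i < e; ++i)
      inv[perm[i]] = i;
    return inv;
  }

  int main() {
    // outer_dims_perm = [2, 0, 1]: permuted[i] = original[perm[i]].
    std::vector<int64_t> perm = {2, 0, 1};
    std::vector<int64_t> original = {10, 20, 30};
    std::vector<int64_t> permuted = {original[perm[0]], original[perm[1]],
                                     original[perm[2]]}; // {30, 10, 20}
    std::vector<int64_t> inv = invertPermutation(perm);  // {1, 2, 0}
    std::vector<int64_t> recovered(3);
    for (int64_t i = 0; i < 3; ++i)
      recovered[i] = permuted[inv[i]];
    assert(recovered == original);
    return 0;
  }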
@@ -5646,6 +6085,19 @@ BatchReduceMatmulOp::getDefaultIndexingMaps(MLIRContext *context) {
return indexingMaps;
}
+bool BatchReduceMatmulOp::isDefaultIndexingMaps(Attribute attr) {
+ ArrayAttr maps = dyn_cast<ArrayAttr>(attr);
+ if (!maps)
+ return false;
+ if (maps.size() != 3)
+ return false;
+ auto positions = getAffineResultPositions(maps);
+ if (failed(positions))
+ return false;
+ return (*positions)[0] == SmallVector<int64_t>{0, 1, 3} &&
+ (*positions)[1] == SmallVector<int64_t>{0, 3, 2} &&
+ (*positions)[2] == SmallVector<int64_t>{1, 2};
+}
unsigned BatchReduceMatmulOp::getNumRegionArgs() { return 3; }
std::string BatchReduceMatmulOp::getLibraryCallName() {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8754743..639e0fe 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
+#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/TransformOps/GPUHeuristics.h"
@@ -27,6 +28,7 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Transform/Utils/Utils.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
@@ -1985,14 +1987,19 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
// Convert the padding values to attributes.
SmallVector<Attribute> paddingValues;
- for (auto const &it :
+ for (auto const &[untypedAttr, elementOrTensorType] :
llvm::zip(getPaddingValues(), linalgTarget->getOperandTypes())) {
- auto attr = dyn_cast<TypedAttr>(std::get<0>(it));
+
+ if (isa<ub::PoisonAttr>(untypedAttr)) {
+ paddingValues.push_back(untypedAttr);
+ continue;
+ }
+ auto attr = dyn_cast<TypedAttr>(untypedAttr);
if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ emitOpError("expects padding values to be typed attributes or poison");
return DiagnosedSilenceableFailure::definiteFailure();
}
- Type elementType = getElementTypeOrSelf(std::get<1>(it));
+ Type elementType = getElementTypeOrSelf(elementOrTensorType);
// Try to parse string attributes to obtain an attribute of element type.
if (auto stringAttr = dyn_cast<StringAttr>(attr)) {
auto parsedAttr = dyn_cast_if_present<TypedAttr>(parseAttribute(
@@ -2000,7 +2007,7 @@ transform::PadOp::apply(transform::TransformRewriter &rewriter,
/*numRead=*/nullptr, /*isKnownNullTerminated=*/true));
if (!parsedAttr || parsedAttr.getType() != elementType) {
auto diag = this->emitOpError("expects a padding that parses to ")
- << elementType << ", got " << std::get<0>(it);
+ << elementType << ", got " << untypedAttr;
diag.attachNote(linalgTarget.getLoc()) << "when applied to this op";
return DiagnosedSilenceableFailure::definiteFailure();
}
@@ -2235,8 +2242,13 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
llvm::zip(getPaddingValues(), targetOp->getOperandTypes())) {
auto attr = dyn_cast<TypedAttr>(untypedAttr);
Type elementType = getElementTypeOrSelf(elementOrTensorType);
+
+ if (isa<ub::PoisonAttr>(untypedAttr)) {
+ paddingValues.push_back(untypedAttr);
+ continue;
+ }
if (!attr) {
- emitOpError("expects padding values to be typed attributes");
+ emitOpError("expects padding values to be typed attributes or poison");
return DiagnosedSilenceableFailure::definiteFailure();
}
// Try to parse string attributes to obtain an attribute of element type.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
index 3908d73..b4507a9 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp
@@ -320,10 +320,6 @@ void linalg::populateBlockPackMatmulPatterns(
RewritePatternSet &patterns, const ControlBlockPackMatmulFn &controlFn) {
patterns.add<BlockPackMatmul<linalg::GenericOp>,
BlockPackMatmul<linalg::MatmulOp>,
- BlockPackMatmul<linalg::BatchMatmulOp>,
- BlockPackMatmul<linalg::MatmulTransposeAOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeAOp>,
- BlockPackMatmul<linalg::MatmulTransposeBOp>,
- BlockPackMatmul<linalg::BatchMatmulTransposeBOp>>(
- patterns.getContext(), controlFn);
+ BlockPackMatmul<linalg::BatchMatmulOp>>(patterns.getContext(),
+ controlFn);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index bf66ed0..22690da 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -691,9 +691,9 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
auto newResultType = RankedTensorType::get(
newResultShape, padOp.getResultType().getElementType());
- auto newPadOp = rewriter.create<tensor::PadOp>(
- padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad,
- newHighPad, paddingVal, padOp.getNofold());
+ auto newPadOp = tensor::PadOp::create(
+ rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource,
+ newLowPad, newHighPad, paddingVal, padOp.getNofold());
Value dest = padOp.getResult();
if (options.rankReductionStrategy ==
@@ -1052,12 +1052,8 @@ struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
static bool constexpr reduceLeft =
(std::is_same_v<FromOpTy, BatchMatmulOp> &&
std::is_same_v<ToOpTy, BatchVecmatOp>) ||
- (std::is_same_v<FromOpTy, BatchMatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, BatchVecmatOp>) ||
(std::is_same_v<FromOpTy, MatmulOp> &&
std::is_same_v<ToOpTy, VecmatOp>) ||
- (std::is_same_v<FromOpTy, MatmulTransposeAOp> &&
- std::is_same_v<ToOpTy, VecmatOp>) ||
(std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);
/// Look for non-batch spatial dims to collapse.
@@ -1113,27 +1109,15 @@ void mlir::linalg::populateContractionOpRankReducingPatterns(
MLIRContext *context = patterns.getContext();
// Unbatching patterns for unit batch size
patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeAOp, MatmulTransposeAOp>>(
- context);
- patterns
- .add<RankReduceToUnBatched<BatchMatmulTransposeBOp, MatmulTransposeBOp>>(
- context);
patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);
// Non-batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeAOp, VecmatOp>>(context);
- patterns.add<RankReduceMatmul<MatmulTransposeBOp, MatvecOp>>(context);
// Batch rank 1 reducing patterns
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeAOp, BatchVecmatOp>>(
- context);
- patterns.add<RankReduceMatmul<BatchMatmulTransposeBOp, BatchMatvecOp>>(
- context);
// Non-batch rank 0 reducing patterns
patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index fd530f2..9436f1c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -594,7 +594,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
auto clonedForOp = scf::ForOp::create(
rewriter, loc, bvm.lookupOrDefault(forOp.getLowerBound()),
bvm.lookupOrDefault(forOp.getUpperBound()),
- bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);
+ bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
// Map the induction var, region args and results to the `clonedForOp`.
bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 58986a6..922b7d6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -55,7 +55,8 @@ static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
scf::ForOp newLoop = scf::ForOp::create(
rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
- loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});
+ loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
+ loop.getUnsignedCmp());
// Generate the new yield with the replaced operand.
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2e62523..3d12bc3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -230,13 +231,18 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
Value paddingValue;
if (auto complexTy =
dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
- auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
- paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
- complexTy, complexAttr);
- } else {
- paddingValue = arith::ConstantOp::create(rewriter, opToPad.getLoc(),
- cast<TypedAttr>(paddingValueAttr));
+ if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
+ paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+ complexTy, complexAttr);
+ }
+ } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
+ paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+ getElementTypeOrSelf(v.getType()));
+ } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
+ paddingValue =
+ arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
}
+ assert(paddingValue && "failed to create value from padding attribute");
// Pad the operand to the bounding box defined by `paddedShape`.
SmallVector<int64_t> tensorShape;
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
index 455e1a6..35ba4f15 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -234,19 +234,8 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
/// Codegen the different matmul variants.
if (numOfBatchDims) {
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeAOp>(rewriter,
- genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<BatchMatmulTransposeBOp>(rewriter,
- genericOp);
return replaceWithMatmulVariant<BatchMatmulOp>(rewriter, genericOp);
}
-
- if (a == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeAOp>(rewriter, genericOp);
- if (b == IndexMatchResult::Transposed)
- return replaceWithMatmulVariant<MatmulTransposeBOp>(rewriter, genericOp);
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
index a2a4335..2650488 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TransposeMatmul.cpp
@@ -59,12 +59,12 @@ FailureOr<Operation *> mlir::linalg::transposeMatmul(RewriterBase &rewriter,
ArrayRef<int64_t>{1, 0});
Operation *newMatmulOp;
if (transposeLHS) {
- newMatmulOp = linalg::MatmulTransposeAOp::create(
+ newMatmulOp = MatmulTransposeAOp::create(
rewriter, loc, matmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), matmulOp.getInputs()[1]},
matmulOp.getOutputs());
} else {
- newMatmulOp = linalg::MatmulTransposeBOp::create(
+ newMatmulOp = MatmulTransposeBOp::create(
rewriter, loc, matmulOp.getResultTypes(),
ValueRange{matmulOp.getInputs()[0], transposeOp->getResult(0)},
matmulOp.getOutputs());
@@ -116,12 +116,12 @@ mlir::linalg::transposeBatchMatmul(RewriterBase &rewriter,
ArrayRef<int64_t>{0, 2, 1});
Operation *newMatmulOp;
if (transposeLHS) {
- newMatmulOp = linalg::BatchMatmulTransposeAOp::create(
+ newMatmulOp = BatchMatmulTransposeAOp::create(
rewriter, loc, batchMatmulOp.getResultTypes(),
ValueRange{transposeOp->getResult(0), batchMatmulOp.getInputs()[1]},
batchMatmulOp.getOutputs());
} else {
- newMatmulOp = linalg::BatchMatmulTransposeBOp::create(
+ newMatmulOp = BatchMatmulTransposeBOp::create(
rewriter, loc, batchMatmulOp.getResultTypes(),
ValueRange{batchMatmulOp.getInputs()[0], transposeOp->getResult(0)},
batchMatmulOp.getOutputs());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index cf65e67..406f05c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -2563,7 +2563,7 @@ vectorizeScalableVectorPrecondition(Operation *op,
"vectorization";
return failure();
}
- if (isa<linalg::MatmulOp>(op) || isa<linalg::MatmulTransposeAOp>(op)) {
+ if (isa<linalg::MatmulOp>(op)) {
LDBG()
<< "Scalable vectorization of the reduction dim in Matmul-like ops "
"is not supported";
@@ -2604,17 +2604,12 @@ vectorizeScalableVectorPrecondition(Operation *op,
return failure();
}
- // Check to not let go the matmul with extended semantic, through this
- // transform.
- if (linalgOp.hasUserDefinedMaps())
- return failure();
-
// Cond 4: Only the following ops are supported in the
// presence of scalable vectors
return success(isElementwise(linalgOp) || isa<linalg::MatmulOp>(op) ||
- isa<linalg::MatmulTransposeAOp>(op) ||
isa<linalg::DepthwiseConv1DNwcWcOp>(op) ||
isa<linalg::MatvecOp>(op) || isa<linalg::Mmt4DOp>(op) ||
+ isa<linalg::BatchMmt4DOp>(op) ||
hasReductionIterator(linalgOp));
}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index 74b968c..b59d73d 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3558,6 +3558,7 @@ LogicalResult AtomicRMWOp::verify() {
case arith::AtomicRMWKind::minu:
case arith::AtomicRMWKind::muli:
case arith::AtomicRMWKind::ori:
+ case arith::AtomicRMWKind::xori:
case arith::AtomicRMWKind::andi:
if (!llvm::isa<IntegerType>(getValue().getType()))
return emitOpError() << "with kind '"
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
index 34c95e3..8474244 100644
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -422,6 +422,12 @@ std::optional<InFlightDiagnostic> verifyTmaDescriptorWithMemref(
<< descMemref << " != " << dstMemref;
}
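+  // Tensor map (TMA) descriptors require the last dimension to cover a
+  // multiple of 16 bytes; e.g. (illustrative) 64xf16 = 128 bytes is valid,
+  // while 3xf32 = 12 bytes is not.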
+ int lastDimBytes =
+ descMemref.getShape().back() * descMemref.getElementTypeBitWidth() / 8;
+ if (lastDimBytes % 16 != 0) {
+ return op->emitError() << "the bytes in the last dimension of the tensor "
+ "map must be a multiple of 16";
+ }
return std::nullopt;
}
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index 485bb73..d7c8916 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1390,6 +1390,20 @@ void acc::ParallelOp::addPrivatization(MLIRContext *context,
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
+void acc::ParallelOp::addFirstPrivatization(
+ MLIRContext *context, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe) {
+ getFirstprivateOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getFirstprivatizationRecipesAttr())
+ llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
static ParseResult parseNumGangs(
mlir::OpAsmParser &parser,
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
@@ -2041,6 +2055,21 @@ void acc::SerialOp::addPrivatization(MLIRContext *context,
setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
}
+void acc::SerialOp::addFirstPrivatization(
+ MLIRContext *context, mlir::acc::FirstprivateOp op,
+ mlir::acc::FirstprivateRecipeOp recipe) {
+ getFirstprivateOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getFirstprivatizationRecipesAttr())
+ llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
//===----------------------------------------------------------------------===//
// KernelsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c1c1767..fa94219 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -3874,6 +3874,107 @@ LogicalResult AllocateDirOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// TargetAllocMemOp
+//===----------------------------------------------------------------------===//
+
+mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
+ return getInTypeAttr().getValue();
+}
+
+/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
+/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
+/// attr-dict-without-keyword
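+///
+/// A hypothetical instance (SSA names illustrative):
+///   %ptr = omp.target_alloc_mem %device : i32, vector<3x3xi32>
+/// The result is always i64, as fixed by the parser below.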
+static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ auto &builder = parser.getBuilder();
+ bool hasOperands = false;
+ std::int32_t typeparamsSize = 0;
+
+ // Parse device number as a new operand
+ mlir::OpAsmParser::UnresolvedOperand deviceOperand;
+ mlir::Type deviceType;
+ if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
+ return mlir::failure();
+ if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
+ return mlir::failure();
+ if (parser.parseComma())
+ return mlir::failure();
+
+ mlir::Type intype;
+ if (parser.parseType(intype))
+ return mlir::failure();
+ result.addAttribute("in_type", mlir::TypeAttr::get(intype));
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> operands;
+ llvm::SmallVector<mlir::Type> typeVec;
+ if (!parser.parseOptionalLParen()) {
+ // parse the LEN params of the derived type. (<params> : <types>)
+ if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None) ||
+ parser.parseColonTypeList(typeVec) || parser.parseRParen())
+ return mlir::failure();
+ typeparamsSize = operands.size();
+ hasOperands = true;
+ }
+ std::int32_t shapeSize = 0;
+ if (!parser.parseOptionalComma()) {
+    // Parse the sizes to scale by: n dimension extents of index type.
+ if (parser.parseOperandList(operands, mlir::OpAsmParser::Delimiter::None))
+ return mlir::failure();
+ shapeSize = operands.size() - typeparamsSize;
+ auto idxTy = builder.getIndexType();
+ for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
+ typeVec.push_back(idxTy);
+ hasOperands = true;
+ }
+ if (hasOperands &&
+ parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
+ result.operands))
+ return mlir::failure();
+
+ mlir::Type restype = builder.getIntegerType(64);
+ if (!restype) {
+ parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
+ return mlir::failure();
+ }
+ llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
+ result.addAttribute("operandSegmentSizes",
+ builder.getDenseI32ArrayAttr(segmentSizes));
+ if (parser.parseOptionalAttrDict(result.attributes) ||
+ parser.addTypeToList(restype, result.types))
+ return mlir::failure();
+ return mlir::success();
+}
+
+mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
+ mlir::OperationState &result) {
+ return parseTargetAllocMemOp(parser, result);
+}
+
+void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
+ p << " ";
+ p.printOperand(getDevice());
+ p << " : ";
+ p << getDevice().getType();
+ p << ", ";
+ p << getInType();
+ if (!getTypeparams().empty()) {
+ p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
+ }
+ for (auto sh : getShape()) {
+ p << ", ";
+ p.printOperand(sh);
+ }
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ {"in_type", "operandSegmentSizes"});
+}
+
+llvm::LogicalResult omp::TargetAllocMemOp::verify() {
+ mlir::Type outType = getType();
+ if (!mlir::dyn_cast<IntegerType>(outType))
+ return emitOpError("must be a integer type");
+ return mlir::success();
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 0262a1b..0dbc041 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -318,9 +318,12 @@ void ConditionOp::getSuccessorRegions(
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
Value ub, Value step, ValueRange initArgs,
- BodyBuilderFn bodyBuilder) {
+ BodyBuilderFn bodyBuilder, bool unsignedCmp) {
OpBuilder::InsertionGuard guard(builder);
+ if (unsignedCmp)
+ result.addAttribute(getUnsignedCmpAttrName(result.name),
+ builder.getUnitAttr());
result.addOperands({lb, ub, step});
result.addOperands(initArgs);
for (Value v : initArgs)
@@ -450,6 +453,9 @@ static void printInitializationList(OpAsmPrinter &p,
}
void ForOp::print(OpAsmPrinter &p) {
+ if (getUnsignedCmp())
+ p << " unsigned";
+
p << " " << getInductionVar() << " = " << getLowerBound() << " to "
<< getUpperBound() << " step " << getStep();
@@ -462,7 +468,8 @@ void ForOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/!getInitArgs().empty());
- p.printOptionalAttrDict((*this)->getAttrs());
+ p.printOptionalAttrDict((*this)->getAttrs(),
+ /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
}
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -472,6 +479,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
OpAsmParser::Argument inductionVariable;
OpAsmParser::UnresolvedOperand lb, ub, step;
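+  // An optional leading "unsigned" keyword selects unsigned bound comparison,
+  // e.g. (SSA names illustrative):
+  //   scf.for unsigned %iv = %lb to %ub step %c1 { ... }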
+ if (succeeded(parser.parseOptionalKeyword("unsigned")))
+ result.addAttribute(getUnsignedCmpAttrName(result.name),
+ builder.getUnitAttr());
+
// Parse the induction variable followed by '='.
if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
// Parse loop bounds.
@@ -562,7 +573,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
inits.append(newInitOperands.begin(), newInitOperands.end());
scf::ForOp newLoop = scf::ForOp::create(
rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
- [](OpBuilder &, Location, Value, ValueRange) {});
+ [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
// Generate the new yield values and append them to the scf.yield operation.
@@ -806,7 +817,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
// 2. Create the new forOp shell.
scf::ForOp newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newIterOperands);
+ forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
@@ -931,7 +943,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
scf::ForOp newForOp =
scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
- forOp.getUpperBound(), forOp.getStep(), newIterArgs);
+ forOp.getUpperBound(), forOp.getStep(), newIterArgs,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block &newBlock = newForOp.getRegion().front();
@@ -989,12 +1002,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
/// Util function that tries to compute a constant diff between u and l.
/// Returns std::nullopt when the difference between two AffineValueMap is
/// dynamic.
-static std::optional<int64_t> computeConstDiff(Value l, Value u) {
+static std::optional<APInt> computeConstDiff(Value l, Value u) {
IntegerAttr clb, cub;
if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
llvm::APInt lbValue = clb.getValue();
llvm::APInt ubValue = cub.getValue();
- return (ubValue - lbValue).getSExtValue();
+ return ubValue - lbValue;
}
// Else a simple pattern match for x + c or c + x
@@ -1003,7 +1016,7 @@ static std::optional<int64_t> computeConstDiff(Value l, Value u) {
u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
matchPattern(
u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
- return diff.getSExtValue();
+ return diff;
return std::nullopt;
}
@@ -1022,13 +1035,15 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
return success();
}
- std::optional<int64_t> diff =
+ std::optional<APInt> diff =
computeConstDiff(op.getLowerBound(), op.getUpperBound());
if (!diff)
return failure();
// If the loop is known to have 0 iterations, remove it.
- if (*diff <= 0) {
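+    // For unsigned loops the sign of the wrapped difference says nothing
+    // about the trip count, so only an exactly-zero difference proves the
+    // loop is empty.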
+ bool zeroOrLessIterations =
+ diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
+ if (zeroOrLessIterations) {
rewriter.replaceOp(op, op.getInitArgs());
return success();
}
@@ -3384,9 +3399,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
- << "expected as many input types as operands "
- << "(expected " << operands.size() << " got "
- << functionType.getNumInputs() << ")";
+ << "expected as many input types as operands " << "(expected "
+ << operands.size() << " got " << functionType.getNumInputs() << ")";
}
// Resolve input operands.
@@ -4222,14 +4236,15 @@ LogicalResult scf::IndexSwitchOp::verify() {
<< "see yield operation here";
}
for (auto [idx, result, operand] :
- llvm::zip(llvm::seq<unsigned>(0, getNumResults()), getResultTypes(),
- yield.getOperandTypes())) {
- if (result == operand)
+ llvm::enumerate(getResultTypes(), yield.getOperands())) {
+ if (!operand)
+ return yield.emitOpError() << "operand " << idx << " is null\n";
+ if (result == operand.getType())
continue;
return (emitOpError("expected result #")
<< idx << " of each region to be " << result)
.attachNote(yield.getLoc())
- << name << " returns " << operand << " here";
+ << name << " returns " << operand.getType() << " here";
}
return success();
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index f8799c5..fb179e6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -769,7 +769,8 @@ struct ForOpInterface
// Construct a new scf.for op with memref instead of tensor values.
auto newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), castedInitArgs);
+ forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
newForOp->setAttrs(forOp->getAttrs());
Block *loopBody = newForOp.getBody();
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index bee7780..ae52af5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -58,9 +58,12 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
auto *beforeBlock = rewriter.createBlock(
&whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
- auto cmpOp = arith::CmpIOp::create(
- rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt,
- beforeBlock->getArgument(0), forOp.getUpperBound());
+ arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
+ ? arith::CmpIPredicate::ult
+ : arith::CmpIPredicate::slt;
+ auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), predicate,
+ beforeBlock->getArgument(0),
+ forOp.getUpperBound());
scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(),
beforeBlock->getArguments());
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 1130538..7e7fba4 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -791,6 +791,11 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
bool *modifiedIR) {
if (modifiedIR)
*modifiedIR = false;
+
+ // TODO: Add support for unsigned loops.
+ if (forOp.getUnsignedCmp())
+ return failure();
+
LoopPipelinerInternal pipeliner;
if (!pipeliner.initializeLoopInfo(forOp, options))
return failure();
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
index 4752c08..f1203b2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopSpecialization.cpp
@@ -256,6 +256,10 @@ struct ForLoopPeelingPattern : public OpRewritePattern<ForOp> {
LogicalResult matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const override {
+ if (forOp.getUnsignedCmp())
+ return rewriter.notifyMatchFailure(forOp,
+ "unsigned loops are not supported");
+
// Do not peel already peeled loops.
if (forOp->hasAttr(kPeeledLoopLabel))
return failure();
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 1b07b77..3b75970 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -116,7 +116,8 @@ public:
llvm::getSingleElement(adaptor.getLowerBound()),
llvm::getSingleElement(adaptor.getUpperBound()),
llvm::getSingleElement(adaptor.getStep()),
- flattenValues(adaptor.getInitArgs()));
+ flattenValues(adaptor.getInitArgs()),
+ /*bodyBuilder=*/nullptr, op.getUnsignedCmp());
// Reserve whatever attributes in the original op.
newOp->setAttrs(op->getAttrs());
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index c0e47ee..250c413 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -797,7 +797,8 @@ FailureOr<LoopLikeOpInterface> yieldTiledValuesAndReplaceLoop<scf::ForOp>(
inits.append(newInitOperands.begin(), newInitOperands.end());
auto newLoop = scf::ForOp::create(
rewriter, loc, loopOp.getLowerBound(), loopOp.getUpperBound(),
- loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});
+ loopOp.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
+ loopOp.getUnsignedCmp());
// Move the loop body to the new op.
Block *loopBody = loopOp.getBody();
@@ -935,7 +936,8 @@ static LogicalResult addInitOperandsToLoopNest(
auto newLoop = scf::ForOp::create(
rewriter, forLoop.getLoc(), forLoop.getLowerBound(),
forLoop.getUpperBound(), forLoop.getStep(), newInits,
- [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {},
+ forLoop.getUnsignedCmp());
// Merge the body of the new loop with the body of the old loops.
SmallVector<Value> sourceBlockArgs;
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 5731795..4910258 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1233,6 +1233,7 @@ static void getPerfectlyNestedLoopsImpl(
static Loops stripmineSink(scf::ForOp forOp, Value factor,
ArrayRef<scf::ForOp> targets) {
+ assert(!forOp.getUnsignedCmp() && "unsigned loops are not supported");
auto originalStep = forOp.getStep();
auto iv = forOp.getInductionVar();
@@ -1241,6 +1242,8 @@ static Loops stripmineSink(scf::ForOp forOp, Value factor,
Loops innerLoops;
for (auto t : targets) {
+ assert(!t.getUnsignedCmp() && "unsigned loops are not supported");
+
// Save information for splicing ops out of t when done
auto begin = t.getBody()->begin();
auto nOps = t.getBody()->getOperations().size();
@@ -1415,6 +1418,8 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
scf::ForOp source,
RewriterBase &rewriter) {
+ assert(source.getUnsignedCmp() == target.getUnsignedCmp() &&
+ "incompatible signedness");
unsigned numTargetOuts = target.getNumResults();
unsigned numSourceOuts = source.getNumResults();
@@ -1428,7 +1433,8 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
rewriter.setInsertionPointAfter(source);
scf::ForOp fusedLoop = scf::ForOp::create(
rewriter, source.getLoc(), source.getLowerBound(), source.getUpperBound(),
- source.getStep(), fusedInitArgs);
+ source.getStep(), fusedInitArgs, /*bodyBuilder=*/nullptr,
+ source.getUnsignedCmp());
// Map original induction variables and operands to those of the fused loop.
IRMapping mapping;
diff --git a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
index 3b97786..dabbea1 100644
--- a/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp
@@ -71,7 +71,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createLowerAffinePass());
pm.addPass(
createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createFinalizeMemRefToLLVMConversionPass());
pm.addNestedPass<func::FuncOp>(createConvertComplexToStandardPass());
pm.addNestedPass<func::FuncOp>(arith::createArithExpandOpsPass());
pm.addNestedPass<func::FuncOp>(createConvertMathToLLVMPass());
@@ -79,12 +78,6 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createConvertComplexToLibm());
pm.addPass(
createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createConvertComplexToLLVMPass());
- pm.addPass(
- createConvertVectorToLLVMPass(options.convertVectorToLLVMOptions()));
- pm.addPass(createConvertFuncToLLVMPass());
- pm.addPass(createArithToLLVMConversionPass());
- pm.addPass(createConvertControlFlowToLLVMPass());
// Finalize GPU code generation.
if (gpuCodegen) {
@@ -99,8 +92,8 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
pm.addPass(createGpuModuleToBinaryPass(gpuModuleToBinaryPassOptions));
}
- // Convert poison values.
- pm.addPass(createUBToLLVMConversionPass());
+ // Convert to LLVM.
+ pm.addPass(createConvertToLLVMPass());
// Ensure all casts are realized.
pm.addPass(createReconcileUnrealizedCastsPass());
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 134aef3..0e88d31d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -730,9 +730,9 @@ public:
{tensor, lvlCoords, values, filled, added, count},
EmitCInterface::On);
Operation *parent = getTop(op);
+ rewriter.setInsertionPointAfter(parent);
rewriter.replaceOp(op, adaptor.getTensor());
// Deallocate the buffers on exit of the loop nest.
- rewriter.setInsertionPointAfter(parent);
memref::DeallocOp::create(rewriter, loc, values);
memref::DeallocOp::create(rewriter, loc, filled);
memref::DeallocOp::create(rewriter, loc, added);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
index 4464450..febec6d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp
@@ -533,8 +533,10 @@ static bool vectorizeStmt(PatternRewriter &rewriter, scf::ForOp forOp, VL vl,
VectorType vtp = vectorType(vl, init.getType());
Value vinit = genVectorReducInit(rewriter, loc, yield->getOperand(0),
forOp.getRegionIterArg(0), init, vtp);
- forOpNew = scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
- forOp.getUpperBound(), step, vinit);
+ forOpNew =
+ scf::ForOp::create(rewriter, loc, forOp.getLowerBound(),
+ forOp.getUpperBound(), step, vinit,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
forOpNew->setAttr(
LoopEmitter::getLoopEmitterLoopAttrName(),
forOp->getAttr(LoopEmitter::getLoopEmitterLoopAttrName()));
@@ -605,8 +607,8 @@ public:
ForOpRewriter(MLIRContext *context, unsigned vectorLength,
bool enableVLAVectorization, bool enableSIMDIndex32)
- : OpRewritePattern(context), vl{vectorLength, enableVLAVectorization,
- enableSIMDIndex32} {}
+ : OpRewritePattern(context),
+ vl{vectorLength, enableVLAVectorization, enableSIMDIndex32} {}
LogicalResult matchAndRewrite(scf::ForOp op,
PatternRewriter &rewriter) const override {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index e3cba388..fce61f2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -1120,13 +1120,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
}
if (rhsTy == resultTy) {
- if (isSplatZero(resultETy, lhsAttr))
+    // Splat constants can only be resized when the result type is static.
+    if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
return lhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, lhsAttr, shift))
return rhs;
}
if (lhsTy == resultTy) {
- if (isSplatZero(resultETy, rhsAttr))
+ if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
return rhsAttr.resizeSplat(resultTy);
if (isSplatOne(resultETy, rhsAttr, shift))
return lhs;
diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
index e6ef028..34385d7 100644
--- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
+++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp
@@ -276,7 +276,7 @@ std::optional<int64_t> constantTripCount(OpFoldResult lb, OpFoldResult ub,
if (!ubConstant)
return std::nullopt;
std::optional<int64_t> stepConstant = getConstantIntValue(step);
- if (!stepConstant)
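+  // A zero step would divide by zero below, so treat it as unknown.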
+ if (!stepConstant || *stepConstant == 0)
return std::nullopt;
return llvm::divideCeilSigned(*ubConstant - *lbConstant, *stepConstant);
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a450056..74e48b5 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2402,6 +2402,16 @@ LogicalResult ToElementsOp::fold(FoldAdaptor adaptor,
return foldToElementsFromElements(*this, results);
}
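+/// vector.to_elements produces one scalar result per vector element, all of
+/// the source's element type; e.g. (illustrative) a vector<2x3xf32> source
+/// yields six f32 results, which is what is inferred below.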
+LogicalResult
+ToElementsOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
+ ToElementsOp::Adaptor adaptor,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ auto vecType = cast<VectorType>(adaptor.getSource().getType());
+ Type elType = vecType.getElementType();
+ inferredReturnTypes.append(vecType.getNumElements(), elType);
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
@@ -2841,9 +2851,47 @@ LogicalResult BroadcastOp::verify() {
llvm_unreachable("unexpected vector.broadcast op error");
}
+// Fold broadcast(shape_cast(x)) into broadcast(x) if x's type is compatible
+// with broadcast's result type and shape_cast only adds or removes ones in the
+// leading dimensions.
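+//
+// For example (illustrative):
+//   %c = vector.shape_cast %x : vector<4x2xf32> to vector<1x4x2xf32>
+//   %b = vector.broadcast %c : vector<1x4x2xf32> to vector<8x4x2xf32>
+// folds to
+//   %b = vector.broadcast %x : vector<4x2xf32> to vector<8x4x2xf32>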
+static LogicalResult foldBroadcastOfShapeCast(BroadcastOp broadcastOp) {
+ auto srcShapeCast = broadcastOp.getSource().getDefiningOp<ShapeCastOp>();
+ if (!srcShapeCast)
+ return failure();
+
+ VectorType srcType = srcShapeCast.getSourceVectorType();
+ VectorType destType = broadcastOp.getResultVectorType();
+ // Check type compatibility.
+ if (vector::isBroadcastableTo(srcType, destType) !=
+ BroadcastableToResult::Success)
+ return failure();
+
+ ArrayRef<int64_t> srcShape = srcType.getShape();
+ ArrayRef<int64_t> shapecastShape =
+ srcShapeCast.getResultVectorType().getShape();
+ // Trailing dimensions should be the same if shape_cast only alters the
+ // leading dimensions.
+ unsigned numTrailingDims = std::min(srcShape.size(), shapecastShape.size());
+ if (!llvm::equal(srcShape.take_back(numTrailingDims),
+ shapecastShape.take_back(numTrailingDims)))
+ return failure();
+
+ assert(all_of(srcShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ all_of(shapecastShape.drop_back(numTrailingDims),
+ [](int64_t E) { return E == 1; }) &&
+ "ill-formed shape_cast");
+
+ broadcastOp.getSourceMutable().assign(srcShapeCast.getSource());
+ return success();
+}
+
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
if (getSourceType() == getResultVectorType())
return getSource();
+ if (succeeded(foldBroadcastOfShapeCast(*this)))
+ return getResult();
+
if (!adaptor.getSource())
return {};
auto vectorType = getResultVectorType();
diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
index 2d5cc07..fe066dc 100644
--- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
+++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp
@@ -139,6 +139,11 @@ void transform::ApplyLowerGatherPatternsOp::populatePatterns(
vector::populateVectorGatherLoweringPatterns(patterns);
}
+void transform::ApplyUnrollFromElementsPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ vector::populateVectorFromElementsLoweringPatterns(patterns);
+}
+
void transform::ApplyLowerScanPatternsOp::populatePatterns(
RewritePatternSet &patterns) {
vector::populateVectorScanLoweringPatterns(patterns);
diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
index 9e287fc..acbf2b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
LowerVectorBitCast.cpp
LowerVectorBroadcast.cpp
LowerVectorContract.cpp
+ LowerVectorFromElements.cpp
LowerVectorGather.cpp
LowerVectorInterleave.cpp
LowerVectorMask.cpp
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
new file mode 100644
index 0000000..c22fd54
--- /dev/null
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorFromElements.cpp
@@ -0,0 +1,65 @@
+//===- LowerVectorFromElements.cpp - Lower 'vector.from_elements' op -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements target-independent rewrites and utilities to lower the
+// 'vector.from_elements' operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
+
+#define DEBUG_TYPE "lower-vector-from-elements"
+
+using namespace mlir;
+
+namespace {
+
+/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
+/// outermost dimension. For example:
+/// ```
+/// %v = vector.from_elements %e0, %e1, %e2, %e3, %e4, %e5 : vector<2x3xf32>
+///
+/// ==>
+///
+/// %0 = ub.poison : vector<2x3xf32>
+/// %v0 = vector.from_elements %e0, %e1, %e2 : vector<3xf32>
+/// %1 = vector.insert %v0, %0 [0] : vector<3xf32> into vector<2x3xf32>
+/// %v1 = vector.from_elements %e3, %e4, %e5 : vector<3xf32>
+/// %v = vector.insert %v1, %1 [1] : vector<3xf32> into vector<2x3xf32>
+/// ```
+///
+/// When applied exhaustively, this will produce a sequence of 1-d from_elements
+/// ops.
+struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(vector::FromElementsOp op,
+ PatternRewriter &rewriter) const override {
+ ValueRange allElements = op.getElements();
+
+ auto unrollFromElementsFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ size_t subTyNumElements = subTy.getNumElements();
+ assert((index + 1) * subTyNumElements <= allElements.size() &&
+ "out of bounds");
+ ValueRange subElements =
+ allElements.slice(index * subTyNumElements, subTyNumElements);
+ return vector::FromElementsOp::create(rewriter, loc, subTy, subElements);
+ };
+
+ return unrollVectorOp(op, rewriter, unrollFromElementsFn);
+ }
+};
+
+} // namespace
+
+void mlir::vector::populateVectorFromElementsLoweringPatterns(
+ RewritePatternSet &patterns, PatternBenefit benefit) {
+ patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
+}
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
index e062f55..90f21c5 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp
@@ -54,27 +54,13 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
LogicalResult matchAndRewrite(vector::GatherOp op,
PatternRewriter &rewriter) const override {
- VectorType resultTy = op.getType();
- if (resultTy.getRank() < 2)
- return rewriter.notifyMatchFailure(op, "already 1-D");
-
- // Unrolling doesn't take vscale into account. Pattern is disabled for
- // vectors with leading scalable dim(s).
- if (resultTy.getScalableDims().front())
- return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
-
- Location loc = op.getLoc();
Value indexVec = op.getIndexVec();
Value maskVec = op.getMask();
Value passThruVec = op.getPassThru();
- Value result = arith::ConstantOp::create(rewriter, loc, resultTy,
- rewriter.getZeroAttr(resultTy));
-
- VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
-
- for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
- int64_t thisIdx[1] = {i};
+ auto unrollGatherFn = [&](PatternRewriter &rewriter, Location loc,
+ VectorType subTy, int64_t index) {
+ int64_t thisIdx[1] = {index};
Value indexSubVec =
vector::ExtractOp::create(rewriter, loc, indexVec, thisIdx);
@@ -82,15 +68,12 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
vector::ExtractOp::create(rewriter, loc, maskVec, thisIdx);
Value passThruSubVec =
vector::ExtractOp::create(rewriter, loc, passThruVec, thisIdx);
- Value subGather = vector::GatherOp::create(
- rewriter, loc, subTy, op.getBase(), op.getIndices(), indexSubVec,
- maskSubVec, passThruSubVec);
- result =
- vector::InsertOp::create(rewriter, loc, subGather, result, thisIdx);
- }
+ return vector::GatherOp::create(rewriter, loc, subTy, op.getBase(),
+ op.getIndices(), indexSubVec, maskSubVec,
+ passThruSubVec);
+ };
- rewriter.replaceOp(op, result);
- return success();
+ return unrollVectorOp(op, rewriter, unrollGatherFn);
}
};
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index bb0f339..be0d28a 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1826,7 +1826,8 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
rewriter.setInsertionPointAfter(newWarpOp);
auto newForOp = scf::ForOp::create(
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
- forOp.getStep(), newForOpOperands);
+ forOp.getStep(), newForOpOperands, /*bodyBuilder=*/nullptr,
+ forOp.getUnsignedCmp());
// Next, we insert a new `WarpOp` (called inner `WarpOp`) inside the
// newly created `ForOp`. This `WarpOp` will contain all ops that were
// contained within the original `ForOp` body.
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 2269a40..023c4da 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2274,7 +2274,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
LogicalResult matchAndRewrite(MulOpType mulOp,
PatternRewriter &rewriter) const override {
- auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
+ auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
if (!resType)
return failure();
if (resType.getRank() != 2)
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 6e2fa35..841e138 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -392,3 +392,29 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
}
return success();
}
+
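+/// Generic unrolling helper shared by the gather and from_elements lowerings:
+/// starting from a ub.poison accumulator of the result type, it asks
+/// `unrollFn` for each rank-reduced sub-vector and chains vector.insert ops
+/// along the leading dimension before replacing `op`.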
+LogicalResult vector::unrollVectorOp(Operation *op, PatternRewriter &rewriter,
+ vector::UnrollVectorOpFn unrollFn) {
+ assert(op->getNumResults() == 1 && "expected single result");
+ assert(isa<VectorType>(op->getResult(0).getType()) && "expected vector type");
+ VectorType resultTy = cast<VectorType>(op->getResult(0).getType());
+ if (resultTy.getRank() < 2)
+ return rewriter.notifyMatchFailure(op, "already 1-D");
+
+ // Unrolling doesn't take vscale into account. Pattern is disabled for
+ // vectors with leading scalable dim(s).
+ if (resultTy.getScalableDims().front())
+ return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
+
+ Location loc = op->getLoc();
+ Value result = ub::PoisonOp::create(rewriter, loc, resultTy);
+ VectorType subTy = VectorType::Builder(resultTy).dropDim(0);
+
+ for (int64_t i = 0, e = resultTy.getShape().front(); i < e; ++i) {
+ Value subVector = unrollFn(rewriter, loc, subTy, i);
+ result = vector::InsertOp::create(rewriter, loc, subVector, result, i);
+ }
+
+ rewriter.replaceOp(op, result);
+ return success();
+}
diff --git a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
index 242a97c..7869a28 100644
--- a/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -7,13 +7,18 @@ add_mlir_dialect_library(MLIRXeGPUDialect
DEPENDS
MLIRXeGPUIncGen
+ MLIRXeGPUAttrInterfaceIncGen
MLIRXeGPUAttrsIncGen
MLIRXeGPUEnumsIncGen
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRIndexDialect
+ MLIRAffineUtils
MLIRArithUtils
MLIRDialectUtils
+ MLIRGPUDialect
+ MLIRXeVMDialect
MLIRIR
MLIRViewLikeInterface
MLIRVectorDialect
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 3c0ca114..8ea8cb1 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -6,12 +6,16 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/Utils.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Debug.h"
using std::optional;
@@ -33,6 +37,57 @@ void XeGPUDialect::initialize() {
>();
}
+/// Generates instructions to compute offsets for a subgroup identified by
+/// its multidimensional indices (sgId), using the specified subgroup layout
+/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
+/// dimensions (sizePerWg).
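+///
+/// For example (illustrative numbers): with sgLayout = [2, 2],
+/// sizePerSg = [8, 8] and sizePerWg = [16, 32], the subgroup with
+/// sgId = (1, 0) gets the local offset (8, 0); distUnit is [16, 16], so the
+/// round-robin over sizePerWg yields the offsets (8, 0) and (8, 16).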
+static SmallVector<SmallVector<Value>>
+genOffsetsComputingInsts(OpBuilder &builder, Location loc,
+ SmallVector<Value> sgId, ArrayRef<int64_t> sgLayout,
+ ArrayRef<int64_t> sizePerSg,
+ ArrayRef<int64_t> sizePerWg) {
+
+ SmallVector<SmallVector<Value>> offsets;
+
+  // n-D local offset: localOffset[i] = sgId[i] * sizePerSg[i]
+ SmallVector<Value> localOffsets = llvm::map_to_vector(
+ llvm::zip(sgId, sizePerSg), [&](const auto &t) -> Value {
+ return builder.createOrFold<index::MulOp>(
+ loc, std::get<0>(t),
+ builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
+ });
+
+ // distUnit[i] is the minimum value between sizePerWg[i] and
+ // sgLayout[i] * sizePerSg[i]
+ SmallVector<int64_t> distUnit = llvm::map_to_vector(
+ llvm::zip_equal(sizePerWg, computeElementwiseMul(sgLayout, sizePerSg)),
+ [](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });
+
+ for (SmallVector<int64_t> unitOffs :
+ StaticTileOffsetRange(sizePerWg, distUnit)) {
+ SmallVector<Value> base =
+ llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
+ return arith::ConstantIndexOp::create(builder, loc, d);
+ });
+
+ SmallVector<Value> adds = llvm::map_to_vector(
+ llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
+ return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
+ std::get<1>(t));
+ });
+
+ SmallVector<Value> mods = llvm::map_to_vector(
+ llvm::zip_equal(adds, sizePerWg), [&](const auto &t) -> Value {
+ return builder.createOrFold<index::RemUOp>(
+ loc, std::get<0>(t),
+ arith::ConstantIndexOp::create(builder, loc, std::get<1>(t)));
+ });
+
+ offsets.push_back(mods);
+ }
+ return offsets;
+}
+
// Checks if the given shape can be evenly distributed based on the layout
// and data factors provided by the LayoutAttr.
bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
@@ -211,6 +266,148 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
return success();
}
+FailureOr<SmallVector<Value>>
+LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+ Value linearId) {
+ // delinearizeSubgroupId is only available for
+ // workgroup-level layout attribute
+ if (!isWgLayout())
+ return failure();
+
+ // TODO: handle order attribute
+ auto hasDefaultOrder = [&]() {
+ DenseI32ArrayAttr order = getOrder();
+ return !order || isIdentityPermutation(llvm::to_vector_of<int64_t>(
+ llvm::reverse(order.asArrayRef())));
+ };
+ if (!hasDefaultOrder())
+ return mlir::emitError(loc, "order attribute is currently not supported.");
+
+ auto dims = llvm::map_to_vector(*getSgLayoutAsInt(), [&](int64_t d) -> Value {
+ return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
+ });
+
+ return affine::delinearizeIndex(builder, loc, linearId, dims);
+}
+
+/// Implements LayoutTrait::getOffsets to generate instructions for
+/// computing multi-dimensional offsets when distributed by LayoutAttr.
+FailureOr<SmallVector<SmallVector<Value>>>
+LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+ ArrayRef<int64_t> shape) {
+ if (!isWgLayout())
+ return failure();
+
+ SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
+ SmallVector<int64_t> sgShape;
+ if (auto maybeSgShape = getSgDataAsInt())
+ sgShape = maybeSgShape.value();
+ else if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+ sgShape = derivedShape.value();
+ else
+ return failure();
+
+ // delinearize Ids
+ auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ if (failed(maybeIds))
+ return failure();
+ SmallVector<Value> sgIds = *maybeIds;
+
+ return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
+ shape);
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_SliceAttr
+//===----------------------------------------------------------------------===//
+LogicalResult
+SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+ xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
+ if (!parent || !dims)
+ return emitError() << "expected parent layout and dims attribute";
+
+ int64_t rank = parent.getRank();
+
+ // check every element in dims is unique and smaller than rank
+ llvm::SmallDenseSet<int64_t> seen;
+ for (int64_t dim : dims.asArrayRef()) {
+ if (dim < 0 || dim >= rank)
+ return emitError() << "invalid dim (" << dim << ") in slice attribute.";
+ if (!seen.insert(dim).second)
+ return emitError() << "repeated dim (" << dim << ") in slice attribute.";
+ }
+ return success();
+}
+
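+/// Folds nested SliceAttrs into a single slice over the root LayoutAttr;
+/// e.g. (illustrative) slicing dim 1 of a rank-3 layout and then dim 0 of
+/// the result flattens to a single SliceAttr with dims = [0, 1].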
+SliceAttr SliceAttr::flatten() const {
+ xegpu::LayoutTrait parent = getParent();
+ SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
+
+ while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
+ parent = sliceAttr.getParent();
+ slicedDims.push_back(sliceAttr.getDims());
+ }
+
+ auto layoutAttr = dyn_cast<xegpu::LayoutAttr>(parent);
+ SmallVector<int64_t> indices =
+ llvm::to_vector(llvm::seq<int64_t>(0, layoutAttr.getRank()));
+
+  // Get the remaining dims (flattened) by applying slice ops with all
+  // slicedDims.
+ SmallVector<int64_t> remainingDims(indices);
+ for (auto dim : llvm::reverse(slicedDims))
+ remainingDims = XeGPUDialect::slice(llvm::ArrayRef<int64_t>(remainingDims),
+ dim.asArrayRef());
+
+  // Get the flattened sliced dims by applying slice ops with the remaining
+  // dims.
+  SmallVector<int64_t> flattenedDims = XeGPUDialect::slice(
+ llvm::ArrayRef<int64_t>(indices), llvm::ArrayRef<int64_t>(remainingDims));
+
+ return xegpu::SliceAttr::get(
+ getContext(), layoutAttr,
+      DenseI64ArrayAttr::get(getContext(), flattenedDims));
+}
+
+FailureOr<SmallVector<Value>>
+SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
+ Value linearId) {
+ SliceAttr attr = flatten();
+ auto parent = dyn_cast<LayoutAttr>(attr.getParent());
+ return parent.delinearizeSubgroupId(builder, loc, linearId);
+}
+
+/// Implements LayoutTrait::getOffsets to generate instructions for
+/// computing multi-dimensional offsets when distributed by SliceAttr.
+FailureOr<SmallVector<SmallVector<Value>>>
+SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
+ ArrayRef<int64_t> shape) {
+ assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
+ if (!isWgLayout())
+ return failure();
+
+ SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
+ SmallVector<int64_t> sgShape;
+ if (auto maybeSgShape = getSgDataAsInt())
+ sgShape = maybeSgShape.value();
+ else if (auto derivedShape = computeShapeRatio(shape, sgLayout))
+ sgShape = derivedShape.value();
+ else
+ return failure();
+
+ // delinearize Ids
+ auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
+ if (failed(maybeIds))
+ return failure();
+
+  // The effective sgIds for computing offsets correspond to the dims that
+  // are not sliced.
+ ArrayRef<int64_t> dims = flatten().getDims().asArrayRef();
+ SmallVector<Value> sgIds =
+ XeGPUDialect::slice(ArrayRef<Value>(*maybeIds), dims);
+
+ return genOffsetsComputingInsts(builder, loc, sgIds, sgLayout, sgShape,
+ shape);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//
@@ -230,7 +427,7 @@ RangeAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
// XeGPU_TensorDescType
//===----------------------------------------------------------------------===//
-mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
+mlir::Type TensorDescType::parse(AsmParser &parser) {
llvm::SmallVector<int64_t> shape;
mlir::Type elementType;
mlir::FailureOr<mlir::Attribute> encoding;
@@ -280,7 +477,7 @@ mlir::Type TensorDescType::parse(::mlir::AsmParser &parser) {
layout.value_or(mlir::Attribute()));
}
-void TensorDescType::print(::mlir::AsmPrinter &printer) const {
+void TensorDescType::print(AsmPrinter &printer) const {
printer << "<";
auto shape = getShape();
@@ -325,10 +522,10 @@ TensorDescType TensorDescType::get(llvm::ArrayRef<int64_t> shape,
return Base::get(context, shape, elementType, attr, layout);
}
-LogicalResult TensorDescType::verify(
- llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
- llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
- mlir::Attribute encoding, mlir::Attribute layout) {
+LogicalResult
+TensorDescType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
+ llvm::ArrayRef<int64_t> shape, mlir::Type elementType,
+ mlir::Attribute encoding, mlir::Attribute layout) {
size_t rank = shape.size();
if (rank == 0)
@@ -394,6 +591,119 @@ LogicalResult TensorDescType::verify(
return success();
}
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescType
+//===----------------------------------------------------------------------===//
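+// The custom assembly prints the shape, element type, and an optional layout,
+// e.g. (illustrative): !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<...>>.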
+mlir::Type MemDescType::parse(AsmParser &parser) {
+ llvm::SmallVector<int64_t> shape;
+ mlir::Type elementType;
+ mlir::FailureOr<MemLayoutAttr> layout;
+
+ // Parse literal '<'
+ if (parser.parseLess())
+ return {};
+
+ auto shapeLoc = parser.getCurrentLocation();
+  if (mlir::failed(parser.parseDimensionList(shape, /*allowDynamic=*/false,
+                                             /*withTrailingX=*/true))) {
+ parser.emitError(shapeLoc, "failed to parse parameter 'shape'");
+ return {};
+ }
+
+ auto elemTypeLoc = parser.getCurrentLocation();
+ if (mlir::failed(parser.parseType(elementType))) {
+ parser.emitError(elemTypeLoc, "failed to parse parameter 'elementType'");
+ return {};
+ }
+
+ // parse optional attributes
+ if (mlir::succeeded(parser.parseOptionalComma())) {
+ MemLayoutAttr attr;
+ ParseResult res = parser.parseAttribute(attr);
+ if (mlir::failed(res))
+ return {};
+ layout = attr;
+ }
+
+ // Parse literal '>'
+ if (parser.parseGreater())
+ return {};
+
+ MLIRContext *ctxt = parser.getContext();
+ return MemDescType::getChecked(
+ [&]() { return parser.emitError(parser.getNameLoc()); }, ctxt, shape,
+ elementType, layout.value_or(MemLayoutAttr()));
+}
+
+void MemDescType::print(AsmPrinter &printer) const {
+ printer << "<";
+
+ printer.printDimensionList(getShape());
+ printer << 'x';
+ printer << getElementType();
+
+ if (auto layout = getMemLayout())
+ printer << ", " << layout;
+
+ printer << ">";
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemLayoutAttr
+//===----------------------------------------------------------------------===//
+
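+// MemLayoutAttr is a dictionary-like attribute printed as `<key = value, ...>`,
+// e.g. (keys illustrative): #xegpu.mem_layout<stride = [1, 16]>.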
+Attribute MemLayoutAttr::parse(AsmParser &parser, Type type) {
+
+ auto context = parser.getContext();
+ llvm::SMLoc loc = parser.getCurrentLocation();
+
+ llvm::SmallDenseSet<StringRef> seenKeys;
+ SmallVector<NamedAttribute> attributes;
+
+ auto parseElt = [&]() -> ParseResult {
+ StringRef nameId;
+ if (failed(parser.parseKeyword(&nameId)))
+ return parser.emitError(loc, "expected valid attribute name");
+
+ if (!seenKeys.insert(nameId).second)
+      return parser.emitError(loc, "duplicate key '")
+             << nameId << "' in mem layout attribute";
+
+ if (failed(parser.parseEqual()))
+ return failure();
+
+ Attribute attr;
+ if (failed(parser.parseAttribute(attr)))
+ return failure();
+ attributes.emplace_back(nameId, attr);
+ return success();
+ };
+
+ // Parse literal '<'
+ if (parser.parseLess())
+ return {};
+
+ if (failed(parser.parseCommaSeparatedList(parseElt)))
+ return {};
+
+ // Parse literal '>'
+ if (parser.parseGreater())
+ return {};
+
+ return parser.getChecked<MemLayoutAttr>(
+ loc, context, DictionaryAttr::get(context, attributes));
+}
+
+void MemLayoutAttr::print(AsmPrinter &printer) const {
+  printer << "<";
+  llvm::interleaveComma(getAttrs().getValue(), printer,
+                        [&](NamedAttribute attr) {
+                          printer << attr.getName().str() << " = "
+                                  << attr.getValue();
+                        });
+  printer << ">";
+}
+
} // namespace xegpu
} // namespace mlir
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 33450f3..906c71d 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -7,6 +7,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -21,6 +23,17 @@
namespace mlir {
namespace xegpu {
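+// Returns true if the memref lives in shared local memory (SLM). This covers
+// the generic integer address space 3, the XeGPU SLM memory space, the XeVM
+// shared address space, and the GPU dialect workgroup address space.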
+bool isSharedMemory(const MemRefType &memrefTy) {
+  Attribute attr = memrefTy.getMemorySpace();
+  // A missing memory space attribute denotes the default (global) space.
+  if (!attr)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
+    return intAttr.getInt() == 3;
+ if (auto memrefSpace = llvm::dyn_cast<MemorySpaceAttr>(attr))
+ return memrefSpace.getValue() == MemorySpace::SLM;
+ if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
+ return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
+ return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
+}
+
template <typename T>
static std::string makeString(T array, bool breakline = false) {
std::string buf;
@@ -121,12 +134,20 @@ isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
- // a valid shape for SIMT case
- if (valueTy.getRank() == 1) {
- if (valueTy.getNumElements() != chunkSize)
- return emitError() << "value elements must match chunk size " << chunkSize
- << " for SIMT code.";
- return success();
+ auto maskVecTy = dyn_cast<VectorType>(maskTy);
+ if (!maskVecTy)
+ return emitError() << "Expecting a vector type mask.";
+ int64_t maskSize = maskVecTy.getNumElements();
+
+ auto valueSize = valueTy.getNumElements();
+ if (chunkSize > 1) {
+ if ((valueTy.getRank() == 1) && (valueSize != chunkSize))
+ return emitError() << "value elements must match chunk size "
+ << chunkSize;
+ } else {
+    if (valueSize != maskSize)
+      return emitError()
+             << "Mask should have the same number of elements as value.";
}
llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
@@ -156,41 +177,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<MemRefType> source,
+ Type tdesc, Value source,
llvm::ArrayRef<OpFoldResult> shape,
llvm::ArrayRef<OpFoldResult> strides) {
- assert(shape.size() && strides.size() && shape.size() == strides.size() &&
- "Shape and strides must be present and of equal size for ui64 "
- "initialization.");
+ Type srcTy = source.getType();
+ assert((isa<IntegerType, MemRefType>(srcTy)) &&
+ "Source has to be either int or memref.");
- llvm::SmallVector<int64_t> staticShape;
- llvm::SmallVector<int64_t> staticStrides;
llvm::SmallVector<Value> dynamicShape;
llvm::SmallVector<Value> dynamicStrides;
- dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
- dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
-
- auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
- auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
-
- build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
- dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
- staticStridesAttr);
-}
-
-void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
- Type tdesc, TypedValue<IntegerType> source,
- llvm::ArrayRef<OpFoldResult> shape,
- llvm::ArrayRef<OpFoldResult> strides) {
- assert(shape.size() && strides.size() && shape.size() == strides.size() &&
- "Shape and strides must be present and of equal size for ui64 "
- "initialization.");
-
llvm::SmallVector<int64_t> staticShape;
llvm::SmallVector<int64_t> staticStrides;
- llvm::SmallVector<Value> dynamicShape;
- llvm::SmallVector<Value> dynamicStrides;
dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
@@ -198,6 +196,18 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
+ if (auto memrefTy = dyn_cast<MemRefType>(srcTy)) {
+ auto memrefShape = memrefTy.getShape();
+ auto [memrefStrides, _] = memrefTy.getStridesAndOffset();
+
+    // If the shape and strides match the memref's own shape and strides, drop
+    // the attributes to keep the printed IR clean.
+ if (staticShape == memrefShape && staticStrides == memrefStrides) {
+ staticShapeAttr = DenseI64ArrayAttr();
+ staticStridesAttr = DenseI64ArrayAttr();
+ }
+ }
+
build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
staticStridesAttr);
@@ -265,8 +275,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}
LogicalResult CreateNdDescOp::verify() {
- auto rank = (int64_t)getMixedOffsets().size();
- bool invalidRank = false;
+ size_t rank = getMixedSizes().size();
+ bool invalidRank = rank != getMixedStrides().size();
bool invalidElemTy = false;
// Memory space of created TensorDesc should match with the source.
@@ -280,31 +290,28 @@ LogicalResult CreateNdDescOp::verify() {
<< " Source: " << srcMemorySpace
<< ", TensorDesc: " << tdescMemorySpace;
+ if (size_t offsetRank = getMixedOffsets().size())
+ invalidRank |= (offsetRank != rank);
+
// check source type matches the rank if it is a memref.
// It also should have the same ElementType as TensorDesc.
- auto memrefTy = dyn_cast<MemRefType>(getSourceType());
- if (memrefTy) {
- invalidRank |= (memrefTy.getRank() != rank);
+ if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
invalidElemTy |= memrefTy.getElementType() != getElementType();
- }
if (llvm::isa<IntegerType>(getSourceType())) {
// strides and shape must present for integer source.
if (getMixedStrides().empty() || getMixedSizes().empty())
- return emitOpError("Expecting strides and shape to be present for "
+ return emitOpError("expecting strides and shape to be present for "
"integer source.");
}
- // mismatches among shape, strides, and offsets are
- // already handeled by OffsetSizeAndStrideOpInterface.
- // So they are not check here.
if (invalidRank)
return emitOpError(
"Expecting the rank of shape, strides, offsets, and source (if source "
"is a memref) should match with each other.");
// check result TensorDesc rank
- if (getType().getRank() > rank)
+ if (getType().getRank() > (int64_t)rank)
return emitOpError(
"Expecting the TensorDesc rank is not greater than the "
"ranks of shape, strides, offsets or the memref source.");
@@ -360,13 +367,10 @@ ParseResult parseOptionalDynamicIndexList(
void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
OperandRange values,
DenseI64ArrayAttr integers) {
-
- if (!integers)
+ if (!integers || integers.empty())
return;
-
- return printDynamicIndexList(printer, op, values, integers,
- /*scalableFlags=*/{}, {},
- AsmParser::Delimiter::Square);
+ printDynamicIndexList(printer, op, values, integers,
+ /*scalableFlags=*/{}, {}, AsmParser::Delimiter::Square);
}
//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
@@ -381,6 +385,21 @@ void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
l1_hint, l2_hint, l3_hint);
}
+void PrefetchNdOp::build(OpBuilder &builder, OperationState &state,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, tensorDesc, dynamicOffsets, staticOffsetsAttr, l1_hint,
+ l2_hint, l3_hint);
+}
+
LogicalResult PrefetchNdOp::verify() {
auto tdescTy = getTensorDescType();
if (tdescTy.isScattered())
@@ -423,6 +442,22 @@ void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
l3_hint);
}
+void LoadNdOp::build(OpBuilder &builder, OperationState &state, Type retType,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ UnitAttr packed, DenseI64ArrayAttr transpose,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, retType, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+ packed, transpose, l1_hint, l2_hint, l3_hint);
+}
+
LogicalResult LoadNdOp::verify() {
auto tdescTy = getTensorDescType();
auto valueTy = getType();
@@ -529,6 +564,21 @@ void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
DenseI64ArrayAttr(), l1_hint, l2_hint, l3_hint);
}
+void StoreNdOp::build(OpBuilder &builder, OperationState &state, Value value,
+ Value tensorDesc, ArrayRef<OpFoldResult> offsets,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ SmallVector<Value> dynamicOffsets;
+ SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+
+ build(builder, state, value, tensorDesc, dynamicOffsets, staticOffsetsAttr,
+ l1_hint, l2_hint, l3_hint);
+}
+
LogicalResult StoreNdOp::verify() {
auto dstTy = getTensorDescType(); // Tile
auto valTy = getValueType(); // Vector
@@ -674,7 +724,7 @@ LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
if (tdescTy && !tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
+ return emitOpError("Expects a scattered TensorDesc.");
if (!tdescTy && getRankOf(getSource()) > 1)
return emitOpError(
@@ -755,7 +805,7 @@ LogicalResult StoreScatterOp::verify() {
auto valueTy = getValueType();
if (tdescTy && !tdescTy.isScattered())
- return emitOpError("Expects a scattered TensorDesc.\n");
+ return emitOpError("Expects a scattered TensorDesc.");
if (!tdescTy && getRankOf(getDest()) > 1)
return emitOpError(
@@ -928,9 +978,107 @@ void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
patterns.add<FoldConvertLayoutOp>(context);
}
+//===----------------------------------------------------------------------===//
+// XeGPU_LoadMatrixOp
+//===----------------------------------------------------------------------===//
+void LoadMatrixOp::build(OpBuilder &builder, OperationState &state, Type res,
+ TypedValue<MemDescType> memDesc,
+ llvm::ArrayRef<OpFoldResult> offsets,
+ LayoutTrait layout) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, res, memDesc, dynamicOffsets, staticOffsetsAttr,
+ layout);
+}
+
+LogicalResult LoadMatrixOp::verify() {
+ VectorType resTy = getRes().getType();
+ MemDescType mdescTy = getMemDesc().getType();
+
+ if (mdescTy.getRank() != 2)
+ return emitOpError("mem_desc must be 2D.");
+
+ ArrayRef<int64_t> valueShape = resTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ if (llvm::any_of(llvm::zip_equal(valueShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("result shape must not exceed mem_desc shape.");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_StoreMatrixOp
+//===----------------------------------------------------------------------===//
+void StoreMatrixOp::build(OpBuilder &builder, OperationState &state, Value data,
+ TypedValue<MemDescType> memDesc,
+ llvm::ArrayRef<OpFoldResult> offsets,
+ LayoutTrait layout) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, data, memDesc, dynamicOffsets, staticOffsetsAttr,
+ layout);
+}
+
+LogicalResult StoreMatrixOp::verify() {
+ VectorType dataTy = getData().getType();
+ MemDescType mdescTy = getMemDesc().getType();
+
+ if (mdescTy.getRank() != 2)
+ return emitOpError("mem_desc must be 2D.");
+
+ ArrayRef<int64_t> dataShape = dataTy.getShape();
+ ArrayRef<int64_t> mdescShape = mdescTy.getShape();
+ if (llvm::any_of(llvm::zip_equal(dataShape, mdescShape),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("data shape must not exceed mem_desc shape.");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// XeGPU_MemDescSubviewOp
+//===----------------------------------------------------------------------===//
+
+void MemDescSubviewOp::build(OpBuilder &builder, OperationState &state,
+ Type resTy, Value src,
+ llvm::ArrayRef<OpFoldResult> offsets) {
+ llvm::SmallVector<Value> dynamicOffsets;
+ llvm::SmallVector<int64_t> staticOffsets;
+ dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
+ auto staticOffsetsAttr = builder.getDenseI64ArrayAttr(staticOffsets);
+ build(builder, state, resTy, src, dynamicOffsets, staticOffsetsAttr);
+}
+
+LogicalResult MemDescSubviewOp::verify() {
+ MemDescType srcTy = getSrc().getType();
+ MemDescType resTy = getRes().getType();
+ ArrayRef<int64_t> srcShape = srcTy.getShape();
+ ArrayRef<int64_t> resShape = resTy.getShape();
+
+ if (srcTy.getRank() < resTy.getRank())
+ return emitOpError("result rank must not exceed source rank.");
+
+ if (llvm::any_of(
+ llvm::zip_equal(resShape, srcShape.take_back(resShape.size())),
+ [](auto p) { return std::get<0>(p) > std::get<1>(p); }))
+ return emitOpError("result shape must not exceed source shape.");
+
+ if (srcTy.getStrides() != resTy.getStrides())
+ return emitOpError("result must inherit the source strides.");
+
+ return success();
+}
+
} // namespace xegpu
} // namespace mlir
+namespace mlir {
+#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.cpp.inc>
+} // namespace mlir
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.cpp.inc>
#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.cpp.inc>
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index 850f70c..8f1208e 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -125,42 +125,15 @@ getSgShapeAndCount(ArrayRef<int64_t> shape, xegpu::LayoutAttr layout) {
struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
- // Calculate offset for each subgroup
- static SmallVector<OpFoldResult>
- calculateGlobalOffsets(ConversionPatternRewriter &rewriter, Location loc,
- const SmallVector<OpFoldResult> &originalOffsets,
- const SmallVector<Value> &localOffset,
- const SmallVector<int64_t> &distUnitBaseAddr,
- const SmallVector<int64_t> &distUnitShape) {
- assert(localOffset.size() == distUnitBaseAddr.size() &&
- "localOffset and distUnitBaseAddr must have the same rank");
-
- SmallVector<OpFoldResult> globalOffsets(originalOffsets.begin(),
- originalOffsets.end());
- size_t rank = localOffset.size();
- for (size_t i = 0; i < rank; ++i) {
- size_t dimIdx = originalOffsets.size() - rank + i;
- Value constOffset =
- arith::ConstantIndexOp::create(rewriter, loc, distUnitBaseAddr[i]);
- Value offset =
- rewriter.createOrFold<index::AddOp>(loc, localOffset[i], constOffset);
- Value modValue =
- arith::ConstantIndexOp::create(rewriter, loc, distUnitShape[i]);
- Value offsetMod =
- rewriter.createOrFold<index::RemUOp>(loc, offset, modValue);
- Value origOffset = getValueOrCreateConstantIndexOp(
- rewriter, loc, originalOffsets[dimIdx]);
- Value globalOffset =
- rewriter.createOrFold<index::AddOp>(loc, origOffset, offsetMod);
- globalOffsets[dimIdx] = globalOffset;
- }
-
- return globalOffsets;
- }
-
LogicalResult
matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
+
+ // Ensure that the op has explicit offsets specified (either dynamic or
+ // constant).
+ if (op.getMixedOffsets().empty())
+ return failure();
+
Location loc = op.getLoc();
MLIRContext *ctx = op.getContext();
xegpu::TensorDescType tdescTy = op.getType();
@@ -177,73 +150,98 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
return rewriter.notifyMatchFailure(
op, "sgLayout attribute is required in layout");
- SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
-
- // TODO : Handle order attribute
// Get the subgroup ID
- auto linearSgId =
+ Value linearSgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
- // Create constants for layout dimensions
- SmallVector<Value> sgLayoutDim(sgLayout.size());
- SmallVector<Value> sgDataDim(sgShape.size());
-
- for (size_t i = 0; i < sgLayout.size(); i++) {
- sgLayoutDim[i] =
- arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]);
- sgDataDim[i] = arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]);
- }
-
int64_t startOfRange = -1, endOfRange = -1;
bool sgIdRangeSpecified =
isSgIdRangeSpecified(op, startOfRange, endOfRange);
- Value adjustedSgId = linearSgId;
if (sgIdRangeSpecified) {
int64_t sgCount = endOfRange - startOfRange;
if (computeProduct(sgLayout) != sgCount)
return rewriter.notifyMatchFailure(
op, "sg_layout size must match the sg_id_range");
- // Subtract startOfRange from the original subgroup id to get the adjusted
- // sg id
+ // Subtract startOfRange from the original subgroup id to get
+ // the adjusted sg id
Value startOfRangeVal =
arith::ConstantIndexOp::create(rewriter, loc, startOfRange);
- adjustedSgId =
+ linearSgId =
rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
}
- auto deLinearizeSgId =
- affine::delinearizeIndex(rewriter, loc, adjustedSgId, sgLayoutDim);
- if (failed(deLinearizeSgId))
+ auto maybeTdescOffsets =
+ layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+ if (failed(maybeTdescOffsets))
return failure();
- SmallVector<Value> sgIds = *deLinearizeSgId;
-
- // Calculate distribution unit shape and local offsets for subgroup
- SmallVector<int64_t> distUnitShape(sgLayout.size());
- SmallVector<Value> localOffset(sgLayout.size());
- for (size_t i = 0; i < sgLayout.size(); i++) {
- distUnitShape[i] = std::min(sgLayout[i] * sgShape[i], wgShape[i]);
- localOffset[i] =
- rewriter.createOrFold<index::MulOp>(loc, sgIds[i], sgDataDim[i]);
- }
-
- SmallVector<OpFoldResult> originalOffsets = op.getMixedOffsets();
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
xegpu::TensorDescType newTdescTy =
xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
layout.dropSgLayoutAndData());
+
SmallVector<Value> newCreateNdOps;
- for (SmallVector<int64_t> distUnitBaseAddr :
- StaticTileOffsetRange(wgShape, distUnitShape)) {
- SmallVector<OpFoldResult> globalOffsets =
- calculateGlobalOffsets(rewriter, loc, originalOffsets, localOffset,
- distUnitBaseAddr, distUnitShape);
-
- auto newCreateNdOp = xegpu::CreateNdDescOp::create(
- rewriter, loc, newTdescTy, op.getSource(), globalOffsets,
+ SmallVector<OpFoldResult> origOffsets = op.getMixedOffsets();
+
+ for (auto tdescOffsets : *maybeTdescOffsets) {
+ SmallVector<OpFoldResult> sgOffsets;
+ size_t rank = tdescOffsets.size();
+ for (size_t i = 0; i < rank; i++) {
+ size_t idx = origOffsets.size() - rank + i;
+ Value add = rewriter.createOrFold<index::AddOp>(
+ loc, tdescOffsets[i],
+ getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
+ sgOffsets.push_back(add);
+ }
+
+ auto newOp = xegpu::CreateNdDescOp::create(
+ rewriter, loc, newTdescTy, op.getSource(), sgOffsets,
op.getMixedSizes(), op.getMixedStrides());
- newCreateNdOps.push_back(newCreateNdOp);
+ newCreateNdOps.push_back(newOp);
}
+ rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
+ return success();
+ }
+};
+
+// This pattern transforms a CreateNdDescOp without offsets into subgroup-level
+// descriptors derived from the workgroup-level descriptor.
+struct WgToSgCreateNdOpNoOffset
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ // Check no offsets are specified.
+ if (!op.getMixedOffsets().empty())
+ return failure();
+
+ Location loc = op.getLoc();
+ MLIRContext *ctx = op.getContext();
+ xegpu::TensorDescType tdescTy = op.getType();
+ auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+ if (!layout || !layout.isWgLayout())
+ return failure();
+
+ Type elemTy = tdescTy.getElementType();
+ ArrayRef<int64_t> wgShape = tdescTy.getShape();
+
+ SmallVector<int64_t> sgShape;
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+ xegpu::TensorDescType newTdescTy =
+ xegpu::TensorDescType::get(ctx, sgShape, elemTy, tdescTy.getEncoding(),
+ layout.dropSgLayoutAndData());
+
+ SmallVector<Value> newCreateNdOps(count);
+ std::generate(newCreateNdOps.begin(), newCreateNdOps.end(), [&]() {
+ return xegpu::CreateNdDescOp::create(rewriter, loc, newTdescTy,
+ op.getSource(), op.getMixedSizes(),
+ op.getMixedStrides());
+ });
rewriter.replaceOpWithMultiple(op, {newCreateNdOps});
return success();
@@ -298,6 +296,205 @@ struct WgToSgStoreNdOp : public OpConversionPattern<xegpu::StoreNdOp> {
}
};
+// Utility function to compute global offsets for subgroup operations.
+// Returns a vector of new offsets for each subgroup, given the original op's
+// offsets and the subgroup-relative offsets.
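+// For example (illustrative), with original offsets [o0, o1] and
+// subgroup-relative offsets [r0, r1], the resulting offsets are
+// [o0 + r0, o1 + r1], aligned to the trailing dims of the original offsets.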
+static SmallVector<SmallVector<OpFoldResult>>
+computeOffsets(Operation *op, ArrayRef<SmallVector<Value>> sgOffsetsList,
+ ArrayRef<OpFoldResult> origOffsets,
+ ConversionPatternRewriter &rewriter) {
+ SmallVector<SmallVector<OpFoldResult>> finalOffsets;
+ Location loc = op->getLoc();
+ for (const auto &sgOffsets : sgOffsetsList) {
+ SmallVector<OpFoldResult> newOffsets;
+ size_t rank = sgOffsets.size();
+ for (size_t i = 0; i < rank; i++) {
+ size_t idx = origOffsets.size() - rank + i;
+ Value add = rewriter.createOrFold<index::AddOp>(
+ loc, sgOffsets[i],
+ getValueOrCreateConstantIndexOp(rewriter, loc, origOffsets[idx]));
+ newOffsets.push_back(add);
+ }
+ finalOffsets.push_back(std::move(newOffsets));
+ }
+ return finalOffsets;
+}
+
+// Utility function to compute the subgroup shape (sgShape) and the
+// per-subgroup offsets (sgOffsetList) for a given op.
+template <typename OpTy, typename AdaptorTy>
+LogicalResult getSgOffsets(OpTy op, AdaptorTy adaptor,
+ ConversionPatternRewriter &rewriter,
+ SmallVector<int64_t> &sgShape,
+ SmallVector<SmallVector<Value>> &sgOffsetList) {
+ int64_t offsetSize = static_cast<int64_t>(op.getOffsets().size());
+ if (offsetSize == 0 && (!op.getConstOffsetsAttr()))
+ return failure();
+
+ Location loc = op.getLoc();
+ Value tdesc = op.getTensorDesc();
+ auto tdescTy = dyn_cast<xegpu::TensorDescType>(tdesc.getType());
+ if (!tdescTy)
+ return failure();
+ auto layout = dyn_cast<xegpu::LayoutAttr>(tdescTy.getLayout());
+ if (!layout)
+ return failure();
+
+ SmallVector<int64_t> sgLayout;
+ auto sgLayoutAttr = layout.getSgLayout();
+ if (!sgLayoutAttr)
+ return rewriter.notifyMatchFailure(
+ op, "sgLayout attribute is required in layout");
+ sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+
+ ArrayRef<int64_t> wgShape = tdescTy.getShape();
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+ // Get the subgroup ID
+ Value linearSgId =
+ gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
+
+ int64_t startOfRange = -1, endOfRange = -1;
+ bool sgIdRangeSpecified = isSgIdRangeSpecified(op, startOfRange, endOfRange);
+
+ if (sgIdRangeSpecified) {
+ int64_t sgCount = endOfRange - startOfRange;
+ if (computeProduct(sgLayout) != sgCount)
+ return rewriter.notifyMatchFailure(
+ op, "sg_layout size must match the sg_id_range");
+ Value startOfRangeVal =
+ rewriter.create<arith::ConstantIndexOp>(loc, startOfRange);
+ linearSgId =
+ rewriter.createOrFold<index::SubOp>(loc, linearSgId, startOfRangeVal);
+ }
+
+ auto sgOffsets = layout.getOffsets(rewriter, loc, linearSgId, wgShape);
+ if (failed(sgOffsets))
+ return failure();
+
+ sgOffsetList = *sgOffsets;
+ return success();
+}
+
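+// Collects the op's original offsets, both constant and dynamic, as
+// OpFoldResults.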
+template <typename OpTy>
+SmallVector<OpFoldResult> getOffsets(OpTy op,
+ ConversionPatternRewriter &rewriter) {
+ SmallVector<OpFoldResult> origOffsets;
+ if (auto constOffsets = op.getConstOffsetsAttr()) {
+ for (auto attr : constOffsets.asArrayRef())
+ origOffsets.push_back(rewriter.getIndexAttr(attr));
+ }
+ for (auto v : op.getOffsets())
+ origOffsets.push_back(v);
+ return origOffsets;
+}
+
+// This pattern transforms the LoadNdOp with explicit offsets to load
+// subgroup data.
+struct WgToSgLoadNdOpWithOffset : public OpConversionPattern<xegpu::LoadNdOp> {
+ using OpConversionPattern<xegpu::LoadNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::LoadNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<int64_t> sgShape;
+ SmallVector<SmallVector<Value>> sgOffsetList;
+
+ // Do the distribution from workgroup to subgroup and get subgroup offsets
+ if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+ return failure();
+
+ // Get the original workgroup offsets
+ SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+ // Calculate the final offsets for each subgroup
+ auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+ SmallVector<Value> newLoadOps;
+ for (auto [offsets, tdesc] :
+ llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
+ VectorType newResTy = VectorType::get(
+ sgShape,
+ dyn_cast<xegpu::TensorDescType>(tdesc.getType()).getElementType());
+ auto newLoadOp = rewriter.create<xegpu::LoadNdOp>(
+ op.getLoc(), newResTy, tdesc, offsets,
+ /*packed=*/nullptr,
+ /*transpose=*/nullptr, op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ newLoadOps.push_back(newLoadOp);
+ }
+ rewriter.replaceOpWithMultiple(op, {newLoadOps});
+ return success();
+ }
+};
+
+// This pattern transforms the StoreNdOp with explicit offsets to store
+// subgroup data.
+struct WgToSgStoreNdOpWithOffset
+ : public OpConversionPattern<xegpu::StoreNdOp> {
+ using OpConversionPattern<xegpu::StoreNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::StoreNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<int64_t> sgShape;
+ SmallVector<SmallVector<Value>> sgOffsetList;
+
+ // Do the distribution from workgroup to subgroup and get subgroup offsets
+ if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+ return failure();
+
+ // Get the original workgroup offsets
+ SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+ // Calculate the final offsets for each subgroup
+ auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+ for (auto [offsets, tdesc, value] :
+ llvm::zip(finalOffsets, adaptor.getTensorDesc(), adaptor.getValue())) {
+ rewriter.create<xegpu::StoreNdOp>(op.getLoc(), value, tdesc, offsets,
+ op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
+// This pattern transforms the PrefetchNdOp with explicit offsets to prefetch
+// subgroup data.
+struct WgToSgPrefetchNdOpWithOffset
+ : public OpConversionPattern<xegpu::PrefetchNdOp> {
+ using OpConversionPattern<xegpu::PrefetchNdOp>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::PrefetchNdOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+
+ SmallVector<int64_t> sgShape;
+ SmallVector<SmallVector<Value>> sgOffsetList;
+
+ // Do the distribution from workgroup to subgroup and get subgroup offsets
+ if (failed(getSgOffsets(op, adaptor, rewriter, sgShape, sgOffsetList)))
+ return failure();
+
+ // Get the original workgroup offsets
+ SmallVector<OpFoldResult> origOffsets = getOffsets(op, rewriter);
+
+ // Calculate the final offsets for each subgroup
+ auto finalOffsets = computeOffsets(op, sgOffsetList, origOffsets, rewriter);
+
+ for (auto [offsets, tdesc] :
+ llvm::zip(finalOffsets, adaptor.getTensorDesc())) {
+ rewriter.create<xegpu::PrefetchNdOp>(
+ op.getLoc(), tdesc, offsets, op.getL1HintAttr(), op.getL2HintAttr(),
+ op.getL3HintAttr());
+ }
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
/// This pattern transforms the UpdateNdOffsetOp to update the offsets of a
/// subgroup descriptor. It creates an UpdateNdOffsetOp op to update the
/// offsets of the new subgroup src tensor descriptors.
@@ -526,8 +723,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
// is lowered to:
// #a = #xegpu.layout<inst_data = [16, 16]>
// #b = #xegpu.layout<inst_data = [8, 16]>
-// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, matrix_desc<32x64xf32>
-// %d = load_matrix %slm <{layout_result_0 = #a}> : matrix_desc<32x64xf32> -> vector<16x32xf32>
+// store_matrix %1, %slm <{layout_input_0 = #a}> : vector<32x16>, mem_desc<32x64xf32>
+// %d = load_matrix %slm <{layout_result_0 = #a}> : mem_desc<32x64xf32> -> vector<16x32xf32>
// xegpu.convert_layout %d <{input_layout = #a, target_layout = #b}> : vector<16x32xf32>
// clang-format on
struct WgToSgConvertLayoutOp
@@ -649,16 +846,56 @@ struct UnrealizedConversionCastOpPattern
}
};
+// This pattern distributes a vector-typed arith.constant op into
+// subgroup-level constants.
+struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
+ using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
+ auto vecType = dyn_cast<VectorType>(op.getType());
+ if (!vecAttr || !vecAttr.isSplat() || !vecType)
+ return failure();
+
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ if (!layout || !layout.getSgLayout())
+ return failure();
+
+ ArrayRef<int64_t> wgShape = vecType.getShape();
+ SmallVector<int64_t> sgShape;
+ int count;
+ std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
+
+    // Current limitation: only splat vector constants (a single repeated
+    // value) are supported.
+    // TODO: support non-splat vector constants.
+ Attribute singleVal = vecAttr.getSplatValue<Attribute>();
+
+ auto newType = VectorType::get(sgShape, vecType.getElementType());
+ auto sgAttr = DenseElementsAttr::get(newType, singleVal);
+ auto cstOp =
+ arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
+ if (auto newLayout = layout.dropSgLayoutAndData())
+ xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
+ SmallVector<Value> newConsts(count, cstOp);
+
+ rewriter.replaceOpWithMultiple(op, {newConsts});
+ return success();
+ }
+};
+
} // namespace
namespace mlir {
namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
- patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
- WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
- WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
- patterns.getContext());
+ patterns
+ .add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
+ WgToSgLoadNdOpWithOffset, WgToSgStoreNdOp, WgToSgStoreNdOpWithOffset,
+ WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
+ WgToSgPrefetchNdOpWithOffset, UnrealizedConversionCastOpPattern,
+ WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
+ WgToSgArithConstantOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -770,6 +1007,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(xegpu::getLayoutAttr(op.getResult()));
});
+ target.addDynamicallyLegalOp<arith::ConstantOp>(
+ [=](arith::ConstantOp op) -> bool {
+ auto vecType = dyn_cast<VectorType>(op.getType());
+ if (!vecType)
+ return true;
+ return isLegal(xegpu::getLayoutAttr(op.getResult()));
+ });
+
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
diff --git a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
index 98e84a4..d9bf4a1 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/XeGPU/Utils/CMakeLists.txt
@@ -7,5 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUUtils
LINK_LIBS PUBLIC
MLIRIR
MLIRSCFTransforms
+ MLIRGPUDialect
+ MLIRXeVMDialect
MLIRXeGPUDialect
)
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index 2cf21fb..19eedba 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -11,6 +11,8 @@
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
@@ -404,3 +406,21 @@ void xegpu::doSCFStructuralTypeConversionWithTensorType(
(void)mlir::applyPartialConversion(op, target, std::move(patterns));
}
}
+
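+/// Returns the chip (device) name from the first `#xevm.target` attribute
+/// attached to the op's enclosing gpu.module, if present.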
+std::optional<std::string> xegpu::getChipStr(Operation *op) {
+ auto gpuModuleOp = op->getParentOfType<gpu::GPUModuleOp>();
+
+ if (!gpuModuleOp)
+ return std::nullopt;
+
+ auto targetAttrs = gpuModuleOp.getTargets();
+ if (targetAttrs) {
+ for (auto &attr : *targetAttrs) {
+ auto xevmAttr = llvm::dyn_cast<xevm::XeVMTargetAttr>(attr);
+ if (xevmAttr)
+ return xevmAttr.getChip().str();
+ }
+ }
+
+ return std::nullopt;
+}
diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index f704fbf..52162a4 100644
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -106,7 +106,7 @@ void ExecutionEngine::dumpToObjectFile(StringRef filename) {
}
// Compilation is lazy and it doesn't populate object cache unless requested.
// In case object dump is requested before cache is populated, we need to
- // force compilation manually.
+ // force compilation manually.
if (cache->isEmpty()) {
for (std::string &functionName : functionNames) {
auto result = lookupPacked(functionName);
@@ -400,13 +400,6 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
return symbolMap;
};
engine->registerSymbols(runtimeSymbolMap);
-
- // Execute the global constructors from the module being processed.
- // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
- // crash for AArch64 see related issue #71963.
- if (!engine->jit->getTargetTriple().isAArch64())
- cantFail(engine->jit->initialize(engine->jit->getMainJITDylib()));
-
return std::move(engine);
}
@@ -442,6 +435,7 @@ Expected<void *> ExecutionEngine::lookup(StringRef name) const {
Error ExecutionEngine::invokePacked(StringRef name,
MutableArrayRef<void *> args) {
+ initialize();
auto expectedFPtr = lookupPacked(name);
if (!expectedFPtr)
return expectedFPtr.takeError();
@@ -451,3 +445,13 @@ Error ExecutionEngine::invokePacked(StringRef name,
return Error::success();
}
+
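+/// Runs the module's global constructors. On AArch64 this is skipped due to a
+/// known JIT issue; in all cases initialization happens at most once, and
+/// subsequent calls are no-ops.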
+void ExecutionEngine::initialize() {
+ if (isInitialized)
+ return;
+ // TODO: Allow JIT initialize for AArch64. Currently there's a bug causing a
+  // crash on AArch64; see the related issue #71963.
+ if (!jit->getTargetTriple().isAArch64())
+ cantFail(jit->initialize(jit->getMainJITDylib()));
+ isInitialized = true;
+}
diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp
index 2107df3..0ada4cc 100644
--- a/mlir/lib/ExecutionEngine/JitRunner.cpp
+++ b/mlir/lib/ExecutionEngine/JitRunner.cpp
@@ -202,6 +202,8 @@ compileAndExecute(Options &options, Operation *module, StringRef entryPoint,
auto engine = std::move(*expectedEngine);
+ engine->initialize();
+
auto expectedFPtr = engine->lookupPacked(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
diff --git a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
index 2f47939..af4ea5a 100644
--- a/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
+++ b/mlir/lib/Interfaces/Utils/InferIntRangeCommon.cpp
@@ -290,8 +290,7 @@ static ConstantIntRanges inferDivURange(const ConstantIntRanges &lhs,
DivisionFixupFn fixup) {
const APInt &lhsMin = lhs.umin(), &lhsMax = lhs.umax(), &rhsMin = rhs.umin(),
&rhsMax = rhs.umax();
-
- if (!rhsMin.isZero()) {
+ if (!rhsMin.isZero() && !rhsMax.isZero()) {
auto udiv = [&fixup](const APInt &a,
const APInt &b) -> std::optional<APInt> {
return fixup(a, b, a.udiv(b));
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
index 950b85e2..258fed1 100644
--- a/mlir/lib/RegisterAllDialects.cpp
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -102,6 +102,7 @@
#include "mlir/Interfaces/CastInterfaces.h"
#include "mlir/Target/LLVM/NVVM/Target.h"
#include "mlir/Target/LLVM/ROCDL/Target.h"
+#include "mlir/Target/LLVM/XeVM/Target.h"
#include "mlir/Target/SPIRV/Target.h"
/// Add all the MLIR dialects to the provided registry.
@@ -199,6 +200,7 @@ void mlir::registerAllDialects(DialectRegistry &registry) {
NVVM::registerNVVMTargetInterfaceExternalModels(registry);
ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
spirv::registerSPIRVTargetInterfaceExternalModels(registry);
+ xevm::registerXeVMTargetInterfaceExternalModels(registry);
}
/// Append all the MLIR dialects to the registry contained in the given context.
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
index 8f7c67c..232ddaf 100644
--- a/mlir/lib/RegisterAllExtensions.cpp
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -58,6 +58,7 @@
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
/// This function may be called to register all MLIR dialect extensions with the
/// provided registry.
diff --git a/mlir/lib/Target/CMakeLists.txt b/mlir/lib/Target/CMakeLists.txt
index 6eb0abc..f0c3ac4 100644
--- a/mlir/lib/Target/CMakeLists.txt
+++ b/mlir/lib/Target/CMakeLists.txt
@@ -4,3 +4,4 @@ add_subdirectory(SPIRV)
add_subdirectory(LLVMIR)
add_subdirectory(LLVM)
add_subdirectory(SMTLIB)
+add_subdirectory(Wasm)
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index 8e83e45..a5ee64c 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -1447,7 +1447,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
}
if (auto dense = dyn_cast<DenseIntElementsAttr>(attr)) {
if (auto iType = dyn_cast<IntegerType>(
- cast<TensorType>(dense.getType()).getElementType())) {
+ cast<ShapedType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os, [&](const APInt &val) {
printInt(val, shouldMapToUnsigned(iType.getSignedness()));
@@ -1456,7 +1456,7 @@ LogicalResult CppEmitter::emitAttribute(Location loc, Attribute attr) {
return success();
}
if (auto iType = dyn_cast<IndexType>(
- cast<TensorType>(dense.getType()).getElementType())) {
+ cast<ShapedType>(dense.getType()).getElementType())) {
os << '{';
interleaveComma(dense, os,
[&](const APInt &val) { printInt(val, false); });
diff --git a/mlir/lib/Target/LLVM/CMakeLists.txt b/mlir/lib/Target/LLVM/CMakeLists.txt
index f6e44c6..9a0e4d4 100644
--- a/mlir/lib/Target/LLVM/CMakeLists.txt
+++ b/mlir/lib/Target/LLVM/CMakeLists.txt
@@ -210,3 +210,27 @@ if(MLIR_ENABLE_ROCM_CONVERSIONS)
)
endif()
+if ("SPIRV" IN_LIST LLVM_TARGETS_TO_BUILD)
+ set(SPIRV_LIBS
+ SPIRVCodeGen
+ SPIRVDesc
+ SPIRVInfo
+ )
+endif()
+
+add_mlir_dialect_library(MLIRXeVMTarget
+ XeVM/Target.cpp
+
+ OBJECT
+
+ LINK_COMPONENTS
+ ${SPIRV_LIBS}
+
+ LINK_LIBS PUBLIC
+ MLIRIR
+ MLIRExecutionEngineUtils
+ MLIRSupport
+ MLIRGPUDialect
+ MLIRTargetLLVM
+ MLIRXeVMToLLVMIRTranslation
+)
diff --git a/mlir/lib/Target/LLVM/XeVM/Target.cpp b/mlir/lib/Target/LLVM/XeVM/Target.cpp
new file mode 100644
index 0000000..1e6784a2
--- /dev/null
+++ b/mlir/lib/Target/LLVM/XeVM/Target.cpp
@@ -0,0 +1,418 @@
+//===- Target.cpp - MLIR LLVM XeVM target compilation -----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines XeVM target-related functions, including registration
+// calls for the `#xevm.target` compilation attribute.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVM/XeVM/Target.h"
+
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/DialectResourceBlobManager.h"
+#include "mlir/Target/LLVM/XeVM/Utils.h"
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Export.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/Target/TargetMachine.h"
+
+#include "llvm/Bitcode/BitcodeWriter.h"
+#include "llvm/Config/Targets.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/FileUtilities.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
+#include "llvm/Support/Process.h"
+#include "llvm/Support/Program.h"
+#include "llvm/Support/TargetSelect.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include <cstdint>
+#include <cstdlib>
+
+using namespace mlir;
+using namespace mlir::xevm;
+
+namespace {
+// XeVM implementation of the gpu::TargetAttrInterface.
+class XeVMTargetAttrImpl
+ : public gpu::TargetAttrInterface::FallbackModel<XeVMTargetAttrImpl> {
+public:
+ std::optional<SmallVector<char, 0>>
+ serializeToObject(Attribute attribute, Operation *module,
+ const gpu::TargetOptions &options) const;
+
+ Attribute createObject(Attribute attribute, Operation *module,
+ const SmallVector<char, 0> &object,
+ const gpu::TargetOptions &options) const;
+};
+} // namespace
+
+void mlir::xevm::registerXeVMTargetInterfaceExternalModels(
+ DialectRegistry &registry) {
+ registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) {
+ XeVMTargetAttr::attachInterface<XeVMTargetAttrImpl>(*ctx);
+ });
+}
+
+void mlir::xevm::registerXeVMTargetInterfaceExternalModels(
+ MLIRContext &context) {
+ DialectRegistry registry;
+ registerXeVMTargetInterfaceExternalModels(registry);
+ context.appendDialectRegistry(registry);
+}
+
+SerializeGPUModuleBase::SerializeGPUModuleBase(
+ Operation &module, XeVMTargetAttr xeTarget,
+ const gpu::TargetOptions &targetOptions)
+ : ModuleToObject(module, xeTarget.getTriple(), "", {}, xeTarget.getO()),
+ xeTarget(xeTarget), librariesToLink(targetOptions.getLibrariesToLink()),
+ targetOptions(targetOptions) {
+ if (xeTarget.getLinkFiles())
+ librariesToLink.append(xeTarget.getLinkFiles().begin(),
+ xeTarget.getLinkFiles().end());
+}
+
+XeVMTargetAttr SerializeGPUModuleBase::getTarget() const { return xeTarget; }
+
+std::optional<SmallVector<std::unique_ptr<llvm::Module>>>
+SerializeGPUModuleBase::loadBitcodeFiles(llvm::Module &module) {
+ if (librariesToLink.empty())
+ return SmallVector<std::unique_ptr<llvm::Module>>();
+ SmallVector<std::unique_ptr<llvm::Module>> bcFiles;
+ if (failed(loadBitcodeFilesFromList(module.getContext(), librariesToLink,
+ bcFiles)))
+ return std::nullopt;
+ return std::move(bcFiles);
+}
+
+gpu::GPUModuleOp SerializeGPUModuleBase::getGPUModuleOp() {
+ return dyn_cast<gpu::GPUModuleOp>(&SerializeGPUModuleBase::getOperation());
+}
+
+// There is one way to finalize IL into native code: IGC.
+// There are two ways to access IGC: AOT (`ocloc`) and JIT (Level Zero runtime).
+// - The Level Zero runtime consumes IL directly and is external to the MLIR
+//   codebase (runtime wrappers).
+// - The `ocloc` tool can be invoked from within MLIR.
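+// The invocation assembled below looks roughly like (device name illustrative):
+//   ocloc compile -file <asm> <inputFormat> -device pvc -output <bin>
+//         -output_no_suffix -options "<opts>"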
+std::optional<SmallVector<char, 0>>
+SerializeGPUModuleBase::compileToBinary(const std::string &asmStr,
+ StringRef inputFormat) {
+ using TmpFile = std::pair<llvm::SmallString<128>, llvm::FileRemover>;
+ // Find the `ocloc` tool.
+ std::optional<std::string> oclocCompiler = findTool("ocloc");
+ if (!oclocCompiler)
+ return std::nullopt;
+ Location loc = getGPUModuleOp().getLoc();
+ std::string basename = llvm::formatv(
+ "mlir-{0}-{1}-{2}", getGPUModuleOp().getNameAttr().getValue(),
+ getTarget().getTriple(), getTarget().getChip());
+
+ auto createTemp = [&](StringRef name,
+ StringRef suffix) -> std::optional<TmpFile> {
+ llvm::SmallString<128> filePath;
+ if (auto ec = llvm::sys::fs::createTemporaryFile(name, suffix, filePath)) {
+ getGPUModuleOp().emitError()
+ << "Couldn't create the temp file: `" << filePath
+ << "`, error message: " << ec.message();
+ return std::nullopt;
+ }
+ return TmpFile(filePath, llvm::FileRemover(filePath.c_str()));
+ };
+ // Create temp file
+ std::optional<TmpFile> asmFile = createTemp(basename, "asm");
+ std::optional<TmpFile> binFile = createTemp(basename, "");
+ std::optional<TmpFile> logFile = createTemp(basename, "log");
+ if (!logFile || !asmFile || !binFile)
+ return std::nullopt;
+ // Dump the assembly to a temp file
+ std::error_code ec;
+ {
+ llvm::raw_fd_ostream asmStream(asmFile->first, ec);
+ if (ec) {
+ emitError(loc) << "Couldn't open the file: `" << asmFile->first
+ << "`, error message: " << ec.message();
+ return std::nullopt;
+ }
+ asmStream << asmStr;
+ if (asmStream.has_error()) {
+ emitError(loc) << "An error occurred while writing the assembly to: `"
+ << asmFile->first << "`.";
+ return std::nullopt;
+ }
+ asmStream.flush();
+ }
+ // Set cmd options
+ std::pair<llvm::BumpPtrAllocator, SmallVector<const char *>> cmdOpts =
+ targetOptions.tokenizeCmdOptions();
+ // Example: --gpu-module-to-binary="opts='opt1 opt2'"
+ const std::string cmdOptsStr = "\"" + llvm::join(cmdOpts.second, " ") + "\"";
+ SmallVector<StringRef, 12> oclocArgs(
+ {"ocloc", "compile", "-file", asmFile->first, inputFormat, "-device",
+ getTarget().getChip(), "-output", binFile->first, "-output_no_suffix",
+ "-options", cmdOptsStr});
+
+// Dump tool invocation commands.
+#define DEBUG_TYPE "serialize-to-binary"
+ LLVM_DEBUG({
+ llvm::dbgs() << "Tool invocation for module: "
+ << getGPUModuleOp().getNameAttr() << "\n";
+ llvm::interleave(oclocArgs, llvm::dbgs(), " ");
+ llvm::dbgs() << "\n";
+ });
+#undef DEBUG_TYPE
+ // Helper function for printing tool error logs.
+ std::string message;
+ auto emitLogError =
+ [&](StringRef toolName) -> std::optional<SmallVector<char, 0>> {
+ if (message.empty()) {
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> toolStderr =
+ llvm::MemoryBuffer::getFile(logFile->first);
+ if (toolStderr)
+ emitError(loc) << toolName << " invocation failed. Log:\n"
+ << toolStderr->get()->getBuffer();
+ else
+ emitError(loc) << toolName << " invocation failed.";
+ return std::nullopt;
+ }
+ emitError(loc) << toolName
+ << " invocation failed, error message: " << message;
+ return std::nullopt;
+ };
+ std::optional<StringRef> redirects[] = {
+ std::nullopt,
+ logFile->first,
+ logFile->first,
+ };
+ // Invoke ocloc.
+ if (llvm::sys::ExecuteAndWait(oclocCompiler.value(), oclocArgs, std::nullopt,
+ redirects, 0, 0, &message))
+ return emitLogError("`ocloc`");
+ binFile->first.append(".bin");
+ llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> binaryBuffer =
+ llvm::MemoryBuffer::getFile(binFile->first);
+ if (!binaryBuffer) {
+ emitError(loc) << "Couldn't open the file: `" << binFile->first
+ << "`, error message: " << binaryBuffer.getError().message();
+ return std::nullopt;
+ }
+ StringRef bin = (*binaryBuffer)->getBuffer();
+ return SmallVector<char, 0>(bin.begin(), bin.end());
+}
+
+std::optional<std::string> SerializeGPUModuleBase::findTool(StringRef tool) {
+ // 1. Check the toolkit path given in the command line.
+ StringRef pathRef = targetOptions.getToolkitPath();
+ SmallVector<char, 256> path;
+ if (!pathRef.empty()) {
+ path.insert(path.begin(), pathRef.begin(), pathRef.end());
+ llvm::sys::path::append(path, "bin", tool);
+ if (llvm::sys::fs::can_execute(path))
+ return StringRef(path.data(), path.size()).str();
+ }
+ // 2. Check PATH.
+ if (std::optional<std::string> toolPath =
+ llvm::sys::Process::FindInEnvPath("PATH", tool))
+ return *toolPath;
+
+ getGPUModuleOp().emitError()
+ << "Couldn't find the `" << tool
+      << "` binary. Please specify the toolkit "
+         "path via GpuModuleToBinaryPass or add the compiler to $PATH.";
+ return std::nullopt;
+}
+
+namespace {
+class SPIRVSerializer : public SerializeGPUModuleBase {
+public:
+ SPIRVSerializer(Operation &module, XeVMTargetAttr xeTarget,
+ const gpu::TargetOptions &targetOptions)
+ : SerializeGPUModuleBase(module, xeTarget, targetOptions) {}
+
+ static void init();
+
+ /// Serializes the LLVM module to an object format, depending on the
+ /// compilation target selected in target options.
+ std::optional<SmallVector<char, 0>>
+ moduleToObject(llvm::Module &llvmModule) override;
+
+private:
+ /// Translates the LLVM module to SPIR-V binary using LLVM's
+ /// SPIR-V target.
+ std::optional<std::string>
+ translateToSPIRVBinary(llvm::Module &llvmModule,
+ llvm::TargetMachine &targetMachine);
+};
+} // namespace
+
+void SPIRVSerializer::init() {
+ static llvm::once_flag initializeBackendOnce;
+ llvm::call_once(initializeBackendOnce, []() {
+#if LLVM_HAS_SPIRV_TARGET
+ LLVMInitializeSPIRVTarget();
+ LLVMInitializeSPIRVTargetInfo();
+ LLVMInitializeSPIRVTargetMC();
+ LLVMInitializeSPIRVAsmPrinter();
+#endif
+ });
+}
+
+std::optional<SmallVector<char, 0>>
+SPIRVSerializer::moduleToObject(llvm::Module &llvmModule) {
+#define DEBUG_TYPE "serialize-to-llvm"
+ LLVM_DEBUG({
+ llvm::dbgs() << "LLVM IR for module: " << getGPUModuleOp().getNameAttr()
+ << "\n";
+ llvm::dbgs() << llvmModule << "\n";
+ llvm::dbgs().flush();
+ });
+#undef DEBUG_TYPE
+
+ // Return LLVM IR if the compilation target is `offload`.
+ if (targetOptions.getCompilationTarget() == gpu::CompilationTarget::Offload)
+ return SerializeGPUModuleBase::moduleToObject(llvmModule);
+
+#if !LLVM_HAS_SPIRV_TARGET
+ getGPUModuleOp()->emitError("The `SPIRV` target was not built. Please enable "
+ "it when building LLVM.");
+ return std::nullopt;
+#endif // LLVM_HAS_SPIRV_TARGET
+
+ std::optional<llvm::TargetMachine *> targetMachine =
+ getOrCreateTargetMachine();
+ if (!targetMachine) {
+    getGPUModuleOp().emitError() << "Target machine unavailable for triple "
+                                 << triple << ", can't optimize with LLVM.";
+ return std::nullopt;
+ }
+
+ // Return SPIRV if the compilation target is `assembly`.
+ if (targetOptions.getCompilationTarget() ==
+ gpu::CompilationTarget::Assembly) {
+ std::optional<std::string> serializedISA =
+ translateToISA(llvmModule, **targetMachine);
+ if (!serializedISA) {
+      getGPUModuleOp().emitError()
+          << "Failed to translate the module to ISA for triple " << triple
+          << ", can't compile with LLVM.";
+ return std::nullopt;
+ }
+
+#define DEBUG_TYPE "serialize-to-isa"
+ LLVM_DEBUG({
+ llvm::dbgs() << "SPIR-V for module: " << getGPUModuleOp().getNameAttr()
+ << "\n";
+ llvm::dbgs() << *serializedISA << "\n";
+ llvm::dbgs().flush();
+ });
+#undef DEBUG_TYPE
+
+ // Make sure to include the null terminator.
+ StringRef bin(serializedISA->c_str(), serializedISA->size() + 1);
+ return SmallVector<char, 0>(bin.begin(), bin.end());
+ }
+
+  // The Level Zero runtime is set up to accept SPIR-V binaries.
+  // translateToSPIRVBinary translates the LLVM module to a SPIR-V binary
+  // using LLVM's SPIR-V target.
+  // compileToBinary can be used in the future if the Level Zero runtime
+  // implementation switches to a native XeVM binary format.
+ std::optional<std::string> serializedSPIRVBinary =
+ translateToSPIRVBinary(llvmModule, **targetMachine);
+ if (!serializedSPIRVBinary) {
+    getGPUModuleOp().emitError()
+        << "Failed translating the module to SPIR-V binary.";
+ return std::nullopt;
+ }
+ if (serializedSPIRVBinary->size() % 4) {
+ getGPUModuleOp().emitError() << "SPIRV code size must be a multiple of 4.";
+ return std::nullopt;
+ }
+ StringRef bin(serializedSPIRVBinary->c_str(), serializedSPIRVBinary->size());
+ return SmallVector<char, 0>(bin.begin(), bin.end());
+}
+
+std::optional<std::string>
+SPIRVSerializer::translateToSPIRVBinary(llvm::Module &llvmModule,
+ llvm::TargetMachine &targetMachine) {
+ std::string targetISA;
+ llvm::raw_string_ostream stream(targetISA);
+
+ { // Drop pstream after this to prevent the ISA from being stuck buffering
+ llvm::buffer_ostream pstream(stream);
+ llvm::legacy::PassManager codegenPasses;
+ if (targetMachine.addPassesToEmitFile(codegenPasses, pstream, nullptr,
+ llvm::CodeGenFileType::ObjectFile))
+ return std::nullopt;
+
+ codegenPasses.run(llvmModule);
+ }
+ return targetISA;
+}
+
+std::optional<SmallVector<char, 0>>
+XeVMTargetAttrImpl::serializeToObject(Attribute attribute, Operation *module,
+ const gpu::TargetOptions &options) const {
+ if (!module)
+ return std::nullopt;
+ auto gpuMod = dyn_cast<gpu::GPUModuleOp>(module);
+ if (!gpuMod) {
+ module->emitError("expected to be a gpu.module op");
+ return std::nullopt;
+ }
+ auto xeTarget = cast<XeVMTargetAttr>(attribute);
+ if (xeTarget.getTriple().starts_with("spirv")) {
+ gpuMod.walk([&](LLVM::LLVMFuncOp funcOp) {
+ if (funcOp->hasAttr(gpu::GPUDialect::getKernelFuncAttrName())) {
+ funcOp.setIntelReqdSubGroupSize(16);
+ return WalkResult::interrupt();
+ }
+ return WalkResult::advance();
+ });
+
+ SPIRVSerializer serializer(*module, cast<XeVMTargetAttr>(attribute),
+ options);
+ serializer.init();
+
+#if !LLVM_HAS_SPIRV_TARGET
+    module->emitError("Cannot run `TargetRegistry::lookupTarget()` for SPIRV "
+                      "without having the target built.");
+    return std::nullopt;
+#endif
+
+ return serializer.run();
+ }
+ module->emitError("Unsupported XeVM target triple: ") << xeTarget.getTriple();
+ return std::nullopt;
+}
+
+Attribute
+XeVMTargetAttrImpl::createObject(Attribute attribute, Operation *module,
+ const SmallVector<char, 0> &object,
+ const gpu::TargetOptions &options) const {
+ Builder builder(attribute.getContext());
+ gpu::CompilationTarget format = options.getCompilationTarget();
+ auto xeTarget = cast<XeVMTargetAttr>(attribute);
+ SmallVector<NamedAttribute, 2> properties;
+ if (format == gpu::CompilationTarget::Assembly)
+ properties.push_back(
+ builder.getNamedAttr("O", builder.getI32IntegerAttr(xeTarget.getO())));
+
+ DictionaryAttr objectProps;
+ if (!properties.empty())
+ objectProps = builder.getDictionaryAttr(properties);
+
+ return builder.getAttr<gpu::ObjectAttr>(
+ attribute, format,
+ builder.getStringAttr(StringRef(object.data(), object.size())),
+ objectProps, /*kernels=*/nullptr);
+}
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index 90462d1..e67cfed 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -135,33 +135,83 @@ static llvm::Intrinsic::ID getVoteSyncIntrinsicId(NVVM::VoteSyncKind kind) {
llvm_unreachable("unsupported vote kind");
}
-/// Return the intrinsic ID associated with ldmatrix for the given paramters.
-static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
- int32_t num) {
- if (layout == NVVM::MMALayout::row) {
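+/// Return the intrinsic ID associated with ldmatrix for the given parameters.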
+static llvm::Intrinsic::ID
+getLdMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
+ NVVM::LdStMatrixShapeAttr shape,
+ NVVM::LdStMatrixEltType eltType) {
+ if (shape.getM() == 8 && shape.getN() == 8) {
switch (num) {
case 1:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16;
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16
+ : llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
case 2:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16;
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16
+ : llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
case 4:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16;
- default:
- llvm_unreachable("unsupported number of matrix");
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16
+ : llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
}
-
- } else {
- switch (num) {
- case 1:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16;
- case 2:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16;
- case 4:
- return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16;
- default:
- llvm_unreachable("unsupported number of matrix");
+ } else if (shape.getM() == 8 && shape.getN() == 16) {
+ if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32;
+ case 4:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32;
+ }
+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64;
+ case 4:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64;
+ }
+ }
+ } else if (shape.getM() == 16 && shape.getN() == 16) {
+ if (eltType == NVVM::LdStMatrixEltType::B8) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8;
+ case 2:
+ return llvm::Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8;
+ }
+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B6X16_P32) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32;
+ }
+ } else if (eltType == NVVM::LdStMatrixEltType::B8X16_B4X16_P64) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64;
+ case 2:
+ return llvm::Intrinsic::
+ nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64;
+ }
}
}
+ llvm_unreachable("unknown ldmatrix kind");
}
/// Return the intrinsic ID associated with stmatrix for the given paramters.
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2cdd502..6694de8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4356,9 +4356,11 @@ createAlteredByCaptureMap(MapInfoData &mapData,
if (!isPtrTy) {
auto curInsert = builder.saveIP();
+ llvm::DebugLoc DbgLoc = builder.getCurrentDebugLocation();
builder.restoreIP(findAllocaInsertPoint(builder, moduleTranslation));
auto *memTempAlloc =
builder.CreateAlloca(builder.getPtrTy(), nullptr, ".casted");
+ builder.SetCurrentDebugLocation(DbgLoc);
builder.restoreIP(curInsert);
builder.CreateStore(newV, memTempAlloc);
@@ -5865,6 +5867,10 @@ static bool isTargetDeviceOp(Operation *op) {
if (mlir::isa<omp::ThreadprivateOp>(op))
return true;
+ if (mlir::isa<omp::TargetAllocMemOp>(op) ||
+ mlir::isa<omp::TargetFreeMemOp>(op))
+ return true;
+
if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
if (auto declareTargetIface =
llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(
@@ -5877,6 +5883,85 @@ static bool isTargetDeviceOp(Operation *op) {
return false;
}
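+/// Returns (declaring it in the module if necessary) the `omp_target_alloc`
+/// runtime function: `void *omp_target_alloc(size_t size, int device_num)`.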
+static llvm::Function *getOmpTargetAlloc(llvm::IRBuilderBase &builder,
+ llvm::Module *llvmModule) {
+ llvm::Type *i64Ty = builder.getInt64Ty();
+ llvm::Type *i32Ty = builder.getInt32Ty();
+ llvm::Type *returnType = builder.getPtrTy(0);
+ llvm::FunctionType *fnType =
+ llvm::FunctionType::get(returnType, {i64Ty, i32Ty}, false);
+ llvm::Function *func = cast<llvm::Function>(
+ llvmModule->getOrInsertFunction("omp_target_alloc", fnType).getCallee());
+ return func;
+}
+
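+/// Lowers an omp::TargetAllocMemOp to a call to `omp_target_alloc`, scaling
+/// the allocated type size by any type parameters, and maps the returned
+/// pointer (cast to i64) to the op result.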
+static LogicalResult
+convertTargetAllocMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+  auto allocMemOp = dyn_cast<omp::TargetAllocMemOp>(opInst);
+ if (!allocMemOp)
+ return failure();
+
+ // Get "omp_target_alloc" function
+ llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+ llvm::Function *ompTargetAllocFunc = getOmpTargetAlloc(builder, llvmModule);
+ // Get the corresponding device value in llvm
+ mlir::Value deviceNum = allocMemOp.getDevice();
+ llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
+ // Get the allocation size.
+ llvm::DataLayout dataLayout = llvmModule->getDataLayout();
+ mlir::Type heapTy = allocMemOp.getAllocatedType();
+ llvm::Type *llvmHeapTy = moduleTranslation.convertType(heapTy);
+ llvm::TypeSize typeSize = dataLayout.getTypeStoreSize(llvmHeapTy);
+ llvm::Value *allocSize = builder.getInt64(typeSize.getFixedValue());
+ for (auto typeParam : allocMemOp.getTypeparams())
+ allocSize =
+ builder.CreateMul(allocSize, moduleTranslation.lookupValue(typeParam));
+ // Create call to "omp_target_alloc" with the args as translated llvm values.
+ llvm::CallInst *call =
+ builder.CreateCall(ompTargetAllocFunc, {allocSize, llvmDeviceNum});
+ llvm::Value *resultI64 = builder.CreatePtrToInt(call, builder.getInt64Ty());
+
+ // Map the result
+ moduleTranslation.mapValue(allocMemOp.getResult(), resultI64);
+ return success();
+}
+
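+/// Returns (declaring it in the module if necessary) the `omp_target_free`
+/// runtime function: `void omp_target_free(void *device_ptr, int device_num)`.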
+static llvm::Function *getOmpTargetFree(llvm::IRBuilderBase &builder,
+ llvm::Module *llvmModule) {
+ llvm::Type *ptrTy = builder.getPtrTy(0);
+ llvm::Type *i32Ty = builder.getInt32Ty();
+ llvm::Type *voidTy = builder.getVoidTy();
+ llvm::FunctionType *fnType =
+ llvm::FunctionType::get(voidTy, {ptrTy, i32Ty}, false);
+ llvm::Function *func = dyn_cast<llvm::Function>(
+ llvmModule->getOrInsertFunction("omp_target_free", fnType).getCallee());
+ return func;
+}
+
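+/// Lowers an omp::TargetFreeMemOp to a call to `omp_target_free`, converting
+/// the i64 heap reference back into a pointer first.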
+static LogicalResult
+convertTargetFreeMemOp(Operation &opInst, llvm::IRBuilderBase &builder,
+ LLVM::ModuleTranslation &moduleTranslation) {
+  auto freeMemOp = dyn_cast<omp::TargetFreeMemOp>(opInst);
+ if (!freeMemOp)
+ return failure();
+
+ // Get "omp_target_free" function
+ llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
+  llvm::Function *ompTargetFreeFunc = getOmpTargetFree(builder, llvmModule);
+ // Get the corresponding device value in llvm
+ mlir::Value deviceNum = freeMemOp.getDevice();
+ llvm::Value *llvmDeviceNum = moduleTranslation.lookupValue(deviceNum);
+ // Get the corresponding heapref value in llvm
+ mlir::Value heapref = freeMemOp.getHeapref();
+ llvm::Value *llvmHeapref = moduleTranslation.lookupValue(heapref);
+ // Convert heapref int to ptr and call "omp_target_free"
+ llvm::Value *intToPtr =
+ builder.CreateIntToPtr(llvmHeapref, builder.getPtrTy(0));
+  builder.CreateCall(ompTargetFreeFunc, {intToPtr, llvmDeviceNum});
+ return success();
+}
+
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR (including
/// OpenMP runtime calls).
static LogicalResult
@@ -6051,6 +6136,12 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
// the omp.canonical_loop.
return applyUnrollHeuristic(op, builder, moduleTranslation);
})
+ .Case([&](omp::TargetAllocMemOp) {
+ return convertTargetAllocMemOp(*op, builder, moduleTranslation);
+ })
+ .Case([&](omp::TargetFreeMemOp) {
+ return convertTargetFreeMemOp(*op, builder, moduleTranslation);
+ })
.Default([&](Operation *inst) {
return inst->emitError()
<< "not yet implemented: " << inst->getName();
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index c967e86..d8c54ec 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -1560,7 +1560,19 @@ spirv::Deserializer::processConstantComposite(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
- if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
+ if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
+ SmallVector<Attribute> flattenedElems;
+ for (Attribute element : elements) {
+ if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(element)) {
+ for (auto value : denseElemAttr.getValues<Attribute>())
+ flattenedElems.push_back(value);
+ } else {
+ flattenedElems.push_back(element);
+ }
+ }
+ auto attr = DenseElementsAttr::get(tensorType, flattenedElems);
+ constantMap.try_emplace(resultID, attr, tensorType);
+ } else if (auto shapedType = dyn_cast<ShapedType>(resultType)) {
auto attr = DenseElementsAttr::get(shapedType, elements);
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index c049574..7fc7795 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
+#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
@@ -112,7 +113,9 @@ LogicalResult Serializer::serialize() {
// TODO: handle the other sections
processCapability();
- processExtension();
+ if (failed(processExtension())) {
+ return failure();
+ }
processMemoryModel();
processDebugInfo();
@@ -204,13 +207,24 @@ void Serializer::processDebugInfo() {
// TODO: Encode more debug instructions.
}
-void Serializer::processExtension() {
+LogicalResult Serializer::processExtension() {
llvm::SmallVector<uint32_t, 16> extName;
- for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
+ llvm::SmallSet<Extension, 4> deducedExts(
+ llvm::from_range, module.getVceTriple()->getExtensions());
+ auto nonSemanticInfoExt = spirv::Extension::SPV_KHR_non_semantic_info;
+ if (options.emitDebugInfo && !deducedExts.contains(nonSemanticInfoExt)) {
+ TargetEnvAttr targetEnvAttr = lookupTargetEnvOrDefault(module);
+ if (!is_contained(targetEnvAttr.getExtensions(), nonSemanticInfoExt))
+ return module.emitError(
+ "SPV_KHR_non_semantic_info extension not available");
+ deducedExts.insert(nonSemanticInfoExt);
+ }
+ for (spirv::Extension ext : deducedExts) {
extName.clear();
spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
}
+ return success();
}
void Serializer::processMemoryModel() {
@@ -956,6 +970,11 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
uint32_t resultID = getNextID();
SmallVector<uint32_t, 4> operands = {typeID, resultID};
auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
+ if (auto tensorArmType = dyn_cast<spirv::TensorArmType>(constType)) {
+ ArrayRef<int64_t> innerShape = tensorArmType.getShape().drop_front();
+ if (!innerShape.empty())
+ elementType = spirv::TensorArmType::get(innerShape, elementType);
+ }
// "If the Result Type is a cooperative matrix type, then there must be only
// one Constituent, with scalar type matching the cooperative matrix Component
@@ -979,30 +998,10 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
} else {
return 0;
}
- } else if (isa<spirv::TensorArmType>(constType)) {
- if (isZeroValue(valueAttr)) {
- encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
- {typeID, resultID});
- return resultID;
- }
- numberOfConstituents = shapedType.getNumElements();
- operands.reserve(numberOfConstituents + 2);
- for (int i = 0; i < numberOfConstituents; ++i) {
- uint32_t elementID = 0;
- if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
- elementID =
- elementType.isInteger(1)
- ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
- : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
- }
- if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
- elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
- }
- if (!elementID) {
- return 0;
- }
- operands.push_back(elementID);
- }
+ } else if (isa<spirv::TensorArmType>(constType) && isZeroValue(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ return resultID;
} else {
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 7047869..fb2cecd 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -102,7 +102,7 @@ private:
void processDebugInfo();
- void processExtension();
+ LogicalResult processExtension();
void processMemoryModel();
diff --git a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
index ac338d55..796354e 100644
--- a/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
+++ b/mlir/lib/Target/SPIRV/TranslateRegistration.cpp
@@ -21,8 +21,11 @@
#include "mlir/Target/SPIRV/Serialization.h"
#include "mlir/Tools/mlir-translate/Translation.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MemoryBuffer.h"
+#include "llvm/Support/Path.h"
#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/ToolOutputFile.h"
using namespace mlir;
@@ -76,24 +79,66 @@ void registerFromSPIRVTranslation() {
// Serialization registration
//===----------------------------------------------------------------------===//
-static LogicalResult serializeModule(spirv::ModuleOp module,
- raw_ostream &output) {
+static LogicalResult
+serializeModule(spirv::ModuleOp moduleOp, raw_ostream &output,
+ const spirv::SerializationOptions &options) {
SmallVector<uint32_t, 0> binary;
- if (failed(spirv::serialize(module, binary)))
+ if (failed(spirv::serialize(moduleOp, binary)))
return failure();
- output.write(reinterpret_cast<char *>(binary.data()),
- binary.size() * sizeof(uint32_t));
+ size_t sizeInBytes = binary.size() * sizeof(uint32_t);
+
+ output.write(reinterpret_cast<char *>(binary.data()), sizeInBytes);
+
+ if (options.saveModuleForValidation) {
+ size_t dirSeparator =
+ options.validationFilePrefix.find(llvm::sys::path::get_separator());
+    // If the file prefix includes a directory, check that it exists.
+ if (dirSeparator != std::string::npos) {
+ llvm::StringRef parentDir =
+ llvm::sys::path::parent_path(options.validationFilePrefix);
+ if (!llvm::sys::fs::is_directory(parentDir))
+ return moduleOp.emitError(
+ "validation prefix directory does not exist\n");
+ }
+
+ SmallString<128> filename;
+ int fd = 0;
+
+ std::error_code errorCode = llvm::sys::fs::createUniqueFile(
+ options.validationFilePrefix + "%%%%%%.spv", fd, filename);
+ if (errorCode)
+ return moduleOp.emitError("error creating validation output file: ")
+ << errorCode.message() << "\n";
+
+ llvm::raw_fd_ostream validationOutput(fd, /*shouldClose=*/true);
+ validationOutput.write(reinterpret_cast<char *>(binary.data()),
+ sizeInBytes);
+ validationOutput.flush();
+ }
return mlir::success();
}
namespace mlir {
void registerToSPIRVTranslation() {
+ static llvm::cl::opt<std::string> validationFilesPrefix(
+ "spirv-save-validation-files-with-prefix",
+      llvm::cl::desc(
+          "When a non-empty string is passed, each serialized SPIR-V module "
+          "is saved to an additional file that starts with the given prefix. "
+          "This is used to generate separate binaries for validation, where "
+          "`--split-input-file` normally combines all outputs into one. The "
+          "combined output (`-o`) is still written. Created files need to be "
+          "removed manually once processed."),
+ llvm::cl::init(""));
+
TranslateFromMLIRRegistration toBinary(
"serialize-spirv", "serialize SPIR-V dialect",
- [](spirv::ModuleOp module, raw_ostream &output) {
- return serializeModule(module, output);
+ [](spirv::ModuleOp moduleOp, raw_ostream &output) {
+ return serializeModule(moduleOp, output,
+ {true, false, !validationFilesPrefix.empty(),
+ validationFilesPrefix});
},
[](DialectRegistry &registry) {
registry.insert<spirv::SPIRVDialect>();
diff --git a/mlir/lib/Target/Wasm/CMakeLists.txt b/mlir/lib/Target/Wasm/CMakeLists.txt
new file mode 100644
index 0000000..890fc0ec
--- /dev/null
+++ b/mlir/lib/Target/Wasm/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_translation_library(MLIRTargetWasmImport
+ TranslateRegistration.cpp
+ TranslateFromWasm.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/Target/Wasm
+
+ LINK_LIBS PUBLIC
+ MLIRWasmSSADialect
+ MLIRIR
+ MLIRSupport
+ MLIRTranslateLib
+)
diff --git a/mlir/lib/Target/Wasm/TranslateFromWasm.cpp b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
new file mode 100644
index 0000000..8d45052
--- /dev/null
+++ b/mlir/lib/Target/Wasm/TranslateFromWasm.cpp
@@ -0,0 +1,1245 @@
+//===- TranslateFromWasm.cpp - Translating to WasmSSA dialect -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the WebAssembly importer.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Location.h"
+#include "mlir/Target/Wasm/WasmBinaryEncoding.h"
+#include "mlir/Target/Wasm/WasmImporter.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include "llvm/Support/LEB128.h"
+
+#include <climits>
+#include <cstdint>
+#include <variant>
+
+#define DEBUG_TYPE "wasm-translate"
+
+static_assert(CHAR_BIT == 8,
+ "This code expects std::byte to be exactly 8 bits");
+
+using namespace mlir;
+using namespace mlir::wasm;
+using namespace mlir::wasmssa;
+
+namespace {
+using section_id_t = uint8_t;
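+/// Section identifiers as defined by the WebAssembly binary format.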
+enum struct WasmSectionType : section_id_t {
+ CUSTOM = 0,
+ TYPE = 1,
+ IMPORT = 2,
+ FUNCTION = 3,
+ TABLE = 4,
+ MEMORY = 5,
+ GLOBAL = 6,
+ EXPORT = 7,
+ START = 8,
+ ELEMENT = 9,
+ CODE = 10,
+ DATA = 11,
+ DATACOUNT = 12
+};
+
+constexpr section_id_t highestWasmSectionID{
+ static_cast<section_id_t>(WasmSectionType::DATACOUNT)};
+
+#define APPLY_WASM_SEC_TRANSFORM \
+ WASM_SEC_TRANSFORM(CUSTOM) \
+ WASM_SEC_TRANSFORM(TYPE) \
+ WASM_SEC_TRANSFORM(IMPORT) \
+ WASM_SEC_TRANSFORM(FUNCTION) \
+ WASM_SEC_TRANSFORM(TABLE) \
+ WASM_SEC_TRANSFORM(MEMORY) \
+ WASM_SEC_TRANSFORM(GLOBAL) \
+ WASM_SEC_TRANSFORM(EXPORT) \
+ WASM_SEC_TRANSFORM(START) \
+ WASM_SEC_TRANSFORM(ELEMENT) \
+ WASM_SEC_TRANSFORM(CODE) \
+ WASM_SEC_TRANSFORM(DATA) \
+ WASM_SEC_TRANSFORM(DATACOUNT)
+
+template <WasmSectionType>
+constexpr const char *wasmSectionName = "";
+
+#define WASM_SEC_TRANSFORM(section) \
+ template <> \
+ [[maybe_unused]] constexpr const char \
+ *wasmSectionName<WasmSectionType::section> = #section;
+APPLY_WASM_SEC_TRANSFORM
+#undef WASM_SEC_TRANSFORM
+
+constexpr bool sectionShouldBeUnique(WasmSectionType secType) {
+ return secType != WasmSectionType::CUSTOM;
+}
+
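+/// Empty tag type carrying a compile-time sequence of bytes, used to describe
+/// sets of accepted opcode or terminator bytes.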
+template <std::byte... Bytes>
+struct ByteSequence {};
+
+/// Template class for representing a byte sequence of only one byte
+template <std::byte Byte>
+struct UniqueByte : ByteSequence<Byte> {};
+
+[[maybe_unused]] constexpr ByteSequence<
+ WasmBinaryEncoding::Type::i32, WasmBinaryEncoding::Type::i64,
+ WasmBinaryEncoding::Type::f32, WasmBinaryEncoding::Type::f64,
+ WasmBinaryEncoding::Type::v128> valueTypesEncodings{};
+
+template <std::byte... allowedFlags>
+constexpr bool isValueOneOf(std::byte value,
+ ByteSequence<allowedFlags...> = {}) {
+ return ((value == allowedFlags) | ... | false);
+}
+
+template <std::byte... flags>
+constexpr bool isNotIn(std::byte value, ByteSequence<flags...> = {}) {
+ return !isValueOneOf<flags...>(value);
+}
+
+struct GlobalTypeRecord {
+ Type type;
+ bool isMutable;
+};
+
+struct TypeIdxRecord {
+ size_t id;
+};
+
+struct SymbolRefContainer {
+ FlatSymbolRefAttr symbol;
+};
+
+struct GlobalSymbolRefContainer : SymbolRefContainer {
+ Type globalType;
+};
+
+struct FunctionSymbolRefContainer : SymbolRefContainer {
+ FunctionType functionType;
+};
+
+using ImportDesc =
+ std::variant<TypeIdxRecord, TableType, LimitType, GlobalTypeRecord>;
+
+using parsed_inst_t = FailureOr<SmallVector<Value>>;
+
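+/// Symbols created so far for functions, globals, memories and tables,
+/// together with the function types declared by the module.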
+struct WasmModuleSymbolTables {
+ SmallVector<FunctionSymbolRefContainer> funcSymbols;
+ SmallVector<GlobalSymbolRefContainer> globalSymbols;
+ SmallVector<SymbolRefContainer> memSymbols;
+ SmallVector<SymbolRefContainer> tableSymbols;
+ SmallVector<FunctionType> moduleFuncTypes;
+
+ std::string getNewSymbolName(StringRef prefix, size_t id) const {
+ return (prefix + Twine{id}).str();
+ }
+
+ std::string getNewFuncSymbolName() const {
+ auto id = funcSymbols.size();
+ return getNewSymbolName("func_", id);
+ }
+
+ std::string getNewGlobalSymbolName() const {
+ auto id = globalSymbols.size();
+ return getNewSymbolName("global_", id);
+ }
+
+ std::string getNewMemorySymbolName() const {
+ auto id = memSymbols.size();
+ return getNewSymbolName("mem_", id);
+ }
+
+ std::string getNewTableSymbolName() const {
+ auto id = tableSymbols.size();
+ return getNewSymbolName("table_", id);
+ }
+};
+
+class ParserHead;
+
+/// Wrapper around SmallVector to only allow access as push and pop on the
+/// stack. Makes sure that there are no "free accesses" on the stack to preserve
+/// its state.
+class ValueStack {
+private:
+ struct LabelLevel {
+ size_t stackIdx;
+ LabelLevelOpInterface levelOp;
+ };
+
+public:
+ bool empty() const { return values.empty(); }
+
+ size_t size() const { return values.size(); }
+
+ /// Pops values from the stack because they are being used in an operation.
+ /// @param operandTypes The list of expected types of the operation, used
+ /// to know how many values to pop and check if the types match the
+ /// expectation.
+  /// @param opLoc Location of the caller, used to accurately report the
+  /// location if an error occurs.
+ /// @return Failure or the vector of popped values.
+ FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes,
+ Location *opLoc);
+
+ /// Push the results of an operation to the stack so they can be used in a
+ /// following operation.
+ /// @param results The list of results of the operation
+  /// @param opLoc Location of the caller, used to accurately report the
+  /// location if an error occurs.
+ LogicalResult pushResults(ValueRange results, Location *opLoc);
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ /// A simple dump function for debugging.
+ /// Writes output to llvm::dbgs().
+ LLVM_DUMP_METHOD void dump() const;
+#endif
+
+private:
+ SmallVector<Value> values;
+};
+
+using local_val_t = TypedValue<wasmssa::LocalRefType>;
+
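+/// Parses a Wasm expression (a sequence of instructions terminated by a given
+/// end byte), creating the corresponding WasmSSA operations and maintaining
+/// the implicit Wasm value stack.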
+class ExpressionParser {
+public:
+ using locals_t = SmallVector<local_val_t>;
+ ExpressionParser(ParserHead &parser, WasmModuleSymbolTables const &symbols,
+ ArrayRef<local_val_t> initLocal)
+ : parser{parser}, symbols{symbols}, locals{initLocal} {}
+
+private:
+ template <std::byte opCode>
+ inline parsed_inst_t parseSpecificInstruction(OpBuilder &builder);
+
+ template <typename valueT>
+ parsed_inst_t
+ parseConstInst(OpBuilder &builder,
+ std::enable_if_t<std::is_arithmetic_v<valueT>> * = nullptr);
+
+ /// This function generates a dispatch tree to associate an opcode with a
+ /// parser. Parsers are registered by specialising the
+ /// `parseSpecificInstruction` function for the op code to handle.
+ ///
+ /// The dispatcher is generated by recursively creating all possible patterns
+ /// for an opcode and calling the relevant parser on the leaf.
+ ///
+ /// @tparam patternBitSize is the first bit for which the pattern is not fixed
+ ///
+ /// @tparam highBitPattern is the fixed pattern that this instance handles for
+ /// the 8-patternBitSize bits
+ template <size_t patternBitSize = 0, std::byte highBitPattern = std::byte{0}>
+ inline parsed_inst_t dispatchToInstParser(std::byte opCode,
+ OpBuilder &builder) {
+ static_assert(patternBitSize <= 8,
+ "PatternBitSize is outside of range of opcode space! "
+ "(expected at most 8 bits)");
+ if constexpr (patternBitSize < 8) {
+ constexpr std::byte bitSelect{1 << (7 - patternBitSize)};
+ constexpr std::byte nextHighBitPatternStem = highBitPattern << 1;
+ constexpr size_t nextPatternBitSize = patternBitSize + 1;
+ if ((opCode & bitSelect) != std::byte{0})
+ return dispatchToInstParser<nextPatternBitSize,
+ nextHighBitPatternStem | std::byte{1}>(
+ opCode, builder);
+ return dispatchToInstParser<nextPatternBitSize, nextHighBitPatternStem>(
+ opCode, builder);
+ } else {
+ return parseSpecificInstruction<highBitPattern>(builder);
+ }
+ }
+
+ struct ParseResultWithInfo {
+ SmallVector<Value> opResults;
+ std::byte endingByte;
+ };
+
+public:
+ template <std::byte ParseEndByte = WasmBinaryEncoding::endByte>
+ parsed_inst_t parse(OpBuilder &builder, UniqueByte<ParseEndByte> = {});
+
+ template <std::byte... ExpressionParseEnd>
+ FailureOr<ParseResultWithInfo>
+ parse(OpBuilder &builder,
+ ByteSequence<ExpressionParseEnd...> parsingEndFilters);
+
+ FailureOr<SmallVector<Value>> popOperands(TypeRange operandTypes) {
+ return valueStack.popOperands(operandTypes, &currentOpLoc.value());
+ }
+
+ LogicalResult pushResults(ValueRange results) {
+ return valueStack.pushResults(results, &currentOpLoc.value());
+ }
+
+private:
+ std::optional<Location> currentOpLoc;
+ ParserHead &parser;
+ [[maybe_unused]] WasmModuleSymbolTables const &symbols;
+ locals_t locals;
+ ValueStack valueStack;
+};
+
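+/// Read cursor over the raw binary input. Provides primitives to consume
+/// bytes and decode the scalar and composite encodings used by the Wasm
+/// binary format, tracking the current offset for diagnostics.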
+class ParserHead {
+public:
+ ParserHead(StringRef src, StringAttr name) : head{src}, locName{name} {}
+ ParserHead(ParserHead &&) = default;
+
+private:
+ ParserHead(ParserHead const &other) = default;
+
+public:
+ auto getLocation() const {
+ return FileLineColLoc::get(locName, 0, anchorOffset + offset);
+ }
+
+ FailureOr<StringRef> consumeNBytes(size_t nBytes) {
+ LDBG() << "Consume " << nBytes << " bytes";
+ LDBG() << " Bytes remaining: " << size();
+ LDBG() << " Current offset: " << offset;
+ if (nBytes > size())
+      return emitError(getLocation(), "trying to extract ")
+             << nBytes << " bytes when only " << size() << " are available";
+
+ StringRef res = head.slice(offset, offset + nBytes);
+ offset += nBytes;
+ LDBG() << " Updated offset (+" << nBytes << "): " << offset;
+ return res;
+ }
+
+ FailureOr<std::byte> consumeByte() {
+ auto res = consumeNBytes(1);
+ if (failed(res))
+ return failure();
+ return std::byte{*res->bytes_begin()};
+ }
+
+ template <typename T>
+ FailureOr<T> parseLiteral();
+
+ FailureOr<uint32_t> parseVectorSize();
+
+private:
+ // TODO: This is equivalent to parseLiteral<uint32_t> and could be removed
+  // if the parseLiteral specialization were moved here, but the default GCC on
+  // Ubuntu 22.04 has a bug with template specialization in class declarations.
+ inline FailureOr<uint32_t> parseUI32();
+ inline FailureOr<int64_t> parseI64();
+
+public:
+ FailureOr<StringRef> parseName() {
+ FailureOr<uint32_t> size = parseVectorSize();
+ if (failed(size))
+ return failure();
+
+ return consumeNBytes(*size);
+ }
+
+ FailureOr<WasmSectionType> parseWasmSectionType() {
+ FailureOr<std::byte> id = consumeByte();
+ if (failed(id))
+ return failure();
+ if (std::to_integer<unsigned>(*id) > highestWasmSectionID)
+ return emitError(getLocation(), "invalid section ID: ")
+ << static_cast<int>(*id);
+ return static_cast<WasmSectionType>(*id);
+ }
+
+ FailureOr<LimitType> parseLimit(MLIRContext *ctx) {
+ using WasmLimits = WasmBinaryEncoding::LimitHeader;
+ FileLineColLoc limitLocation = getLocation();
+ FailureOr<std::byte> limitHeader = consumeByte();
+ if (failed(limitHeader))
+ return failure();
+
+ if (isNotIn<WasmLimits::bothLimits, WasmLimits::lowLimitOnly>(*limitHeader))
+ return emitError(limitLocation, "invalid limit header: ")
+ << static_cast<int>(*limitHeader);
+ FailureOr<uint32_t> minParse = parseUI32();
+ if (failed(minParse))
+ return failure();
+ std::optional<uint32_t> max{std::nullopt};
+ if (*limitHeader == WasmLimits::bothLimits) {
+ FailureOr<uint32_t> maxParse = parseUI32();
+ if (failed(maxParse))
+ return failure();
+ max = *maxParse;
+ }
+ return LimitType::get(ctx, *minParse, max);
+ }
+
+ FailureOr<Type> parseValueType(MLIRContext *ctx) {
+ FileLineColLoc typeLoc = getLocation();
+ FailureOr<std::byte> typeEncoding = consumeByte();
+ if (failed(typeEncoding))
+ return failure();
+ switch (*typeEncoding) {
+ case WasmBinaryEncoding::Type::i32:
+ return IntegerType::get(ctx, 32);
+ case WasmBinaryEncoding::Type::i64:
+ return IntegerType::get(ctx, 64);
+ case WasmBinaryEncoding::Type::f32:
+ return Float32Type::get(ctx);
+ case WasmBinaryEncoding::Type::f64:
+ return Float64Type::get(ctx);
+ case WasmBinaryEncoding::Type::v128:
+ return IntegerType::get(ctx, 128);
+ case WasmBinaryEncoding::Type::funcRef:
+ return wasmssa::FuncRefType::get(ctx);
+ case WasmBinaryEncoding::Type::externRef:
+ return wasmssa::ExternRefType::get(ctx);
+ default:
+ return emitError(typeLoc, "invalid value type encoding: ")
+ << static_cast<int>(*typeEncoding);
+ }
+ }
+
+ FailureOr<GlobalTypeRecord> parseGlobalType(MLIRContext *ctx) {
+ using WasmGlobalMut = WasmBinaryEncoding::GlobalMutability;
+ FailureOr<Type> typeParsed = parseValueType(ctx);
+ if (failed(typeParsed))
+ return failure();
+ FileLineColLoc mutLoc = getLocation();
+ FailureOr<std::byte> mutSpec = consumeByte();
+ if (failed(mutSpec))
+ return failure();
+ if (isNotIn<WasmGlobalMut::isConst, WasmGlobalMut::isMutable>(*mutSpec))
+ return emitError(mutLoc, "invalid global mutability specifier: ")
+ << static_cast<int>(*mutSpec);
+ return GlobalTypeRecord{*typeParsed, *mutSpec == WasmGlobalMut::isMutable};
+ }
+
+ FailureOr<TupleType> parseResultType(MLIRContext *ctx) {
+ FailureOr<uint32_t> nParamsParsed = parseVectorSize();
+ if (failed(nParamsParsed))
+ return failure();
+ uint32_t nParams = *nParamsParsed;
+ SmallVector<Type> res{};
+ res.reserve(nParams);
+ for (size_t i = 0; i < nParams; ++i) {
+ FailureOr<Type> parsedType = parseValueType(ctx);
+ if (failed(parsedType))
+ return failure();
+ res.push_back(*parsedType);
+ }
+ return TupleType::get(ctx, res);
+ }
+
+ FailureOr<FunctionType> parseFunctionType(MLIRContext *ctx) {
+ FileLineColLoc typeLoc = getLocation();
+ FailureOr<std::byte> funcTypeHeader = consumeByte();
+ if (failed(funcTypeHeader))
+ return failure();
+ if (*funcTypeHeader != WasmBinaryEncoding::Type::funcType)
+ return emitError(typeLoc, "invalid function type header byte. Expecting ")
+ << std::to_integer<unsigned>(WasmBinaryEncoding::Type::funcType)
+ << " got " << std::to_integer<unsigned>(*funcTypeHeader);
+ FailureOr<TupleType> inputTypes = parseResultType(ctx);
+ if (failed(inputTypes))
+ return failure();
+
+ FailureOr<TupleType> resTypes = parseResultType(ctx);
+ if (failed(resTypes))
+ return failure();
+
+ return FunctionType::get(ctx, inputTypes->getTypes(), resTypes->getTypes());
+ }
+
+ FailureOr<TypeIdxRecord> parseTypeIndex() {
+ FailureOr<uint32_t> res = parseUI32();
+ if (failed(res))
+ return failure();
+ return TypeIdxRecord{*res};
+ }
+
+ FailureOr<TableType> parseTableType(MLIRContext *ctx) {
+ FailureOr<Type> elmTypeParse = parseValueType(ctx);
+ if (failed(elmTypeParse))
+ return failure();
+ if (!isWasmRefType(*elmTypeParse))
+ return emitError(getLocation(), "invalid element type for table");
+ FailureOr<LimitType> limitParse = parseLimit(ctx);
+ if (failed(limitParse))
+ return failure();
+ return TableType::get(ctx, *elmTypeParse, *limitParse);
+ }
+
+ FailureOr<ImportDesc> parseImportDesc(MLIRContext *ctx) {
+ FileLineColLoc importLoc = getLocation();
+ FailureOr<std::byte> importType = consumeByte();
+ auto packager = [](auto parseResult) -> FailureOr<ImportDesc> {
+ if (llvm::failed(parseResult))
+ return failure();
+ return {*parseResult};
+ };
+ if (failed(importType))
+ return failure();
+ switch (*importType) {
+ case WasmBinaryEncoding::Import::typeID:
+ return packager(parseTypeIndex());
+ case WasmBinaryEncoding::Import::tableType:
+ return packager(parseTableType(ctx));
+ case WasmBinaryEncoding::Import::memType:
+ return packager(parseLimit(ctx));
+ case WasmBinaryEncoding::Import::globalType:
+ return packager(parseGlobalType(ctx));
+ default:
+ return emitError(importLoc, "invalid import type descriptor: ")
+ << static_cast<int>(*importType);
+ }
+ }
+
+ parsed_inst_t parseExpression(OpBuilder &builder,
+ WasmModuleSymbolTables const &symbols,
+ ArrayRef<local_val_t> locals = {}) {
+ auto eParser = ExpressionParser{*this, symbols, locals};
+ return eParser.parse(builder);
+ }
+
+ bool end() const { return curHead().empty(); }
+
+ ParserHead copy() const { return *this; }
+
+private:
+ StringRef curHead() const { return head.drop_front(offset); }
+
+ FailureOr<std::byte> peek() const {
+ if (end())
+ return emitError(
+ getLocation(),
+ "trying to peek at next byte, but input stream is empty");
+ return static_cast<std::byte>(curHead().front());
+ }
+
+ size_t size() const { return head.size() - offset; }
+
+ StringRef head;
+ StringAttr locName;
+ unsigned anchorOffset{0};
+ unsigned offset{0};
+};
+
+template <>
+FailureOr<float> ParserHead::parseLiteral<float>() {
+ auto bytes = consumeNBytes(4);
+ if (failed(bytes))
+ return failure();
+ float result;
+ std::memcpy(&result, bytes->bytes_begin(), 4);
+ return result;
+}
+
+template <>
+FailureOr<double> ParserHead::parseLiteral<double>() {
+ auto bytes = consumeNBytes(8);
+ if (failed(bytes))
+ return failure();
+ double result;
+ std::memcpy(&result, bytes->bytes_begin(), 8);
+ return result;
+}
+
+template <>
+FailureOr<uint32_t> ParserHead::parseLiteral<uint32_t>() {
+ char const *error = nullptr;
+ uint32_t res{0};
+ unsigned encodingSize{0};
+ StringRef src = curHead();
+ uint64_t decoded = llvm::decodeULEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
+ if (error)
+ return emitError(getLocation(), error);
+
+ if (std::isgreater(decoded, std::numeric_limits<uint32_t>::max()))
+ return emitError(getLocation()) << "literal does not fit on 32 bits";
+
+ res = static_cast<uint32_t>(decoded);
+ offset += encodingSize;
+ return res;
+}
+
+template <>
+FailureOr<int32_t> ParserHead::parseLiteral<int32_t>() {
+ char const *error = nullptr;
+ int32_t res{0};
+ unsigned encodingSize{0};
+ StringRef src = curHead();
+ int64_t decoded = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
+ if (error)
+ return emitError(getLocation(), error);
+ if (std::isgreater(decoded, std::numeric_limits<int32_t>::max()) ||
+ std::isgreater(std::numeric_limits<int32_t>::min(), decoded))
+ return emitError(getLocation()) << "literal does not fit on 32 bits";
+
+ res = static_cast<int32_t>(decoded);
+ offset += encodingSize;
+ return res;
+}
+
+template <>
+FailureOr<int64_t> ParserHead::parseLiteral<int64_t>() {
+ char const *error = nullptr;
+ unsigned encodingSize{0};
+ StringRef src = curHead();
+ int64_t res = llvm::decodeSLEB128(src.bytes_begin(), &encodingSize,
+ src.bytes_end(), &error);
+ if (error)
+ return emitError(getLocation(), error);
+
+ offset += encodingSize;
+ return res;
+}
+
+FailureOr<uint32_t> ParserHead::parseVectorSize() {
+ return parseLiteral<uint32_t>();
+}
+
+inline FailureOr<uint32_t> ParserHead::parseUI32() {
+ return parseLiteral<uint32_t>();
+}
+
+inline FailureOr<int64_t> ParserHead::parseI64() {
+ return parseLiteral<int64_t>();
+}
+
+template <std::byte opCode>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction(OpBuilder &) {
+ return emitError(*currentOpLoc, "unknown instruction opcode: ")
+ << static_cast<int>(opCode);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void ValueStack::dump() const {
+ llvm::dbgs() << "================= Wasm ValueStack =======================\n";
+ llvm::dbgs() << "size: " << size() << "\n";
+ llvm::dbgs() << "<Top>"
+ << "\n";
+ // Stack is pushed to via push_back. Therefore the top of the stack is the
+ // end of the vector. Iterate in reverse so that the first thing we print
+ // is the top of the stack.
+ size_t stackSize = size();
+ for (size_t idx = 0; idx < stackSize; idx++) {
+ size_t actualIdx = stackSize - 1 - idx;
+ llvm::dbgs() << " ";
+ values[actualIdx].dump();
+ }
+ llvm::dbgs() << "<Bottom>"
+ << "\n";
+ llvm::dbgs() << "=========================================================\n";
+}
+#endif
+
+parsed_inst_t ValueStack::popOperands(TypeRange operandTypes, Location *opLoc) {
+ LDBG() << "Popping from ValueStack\n"
+ << " Elements(s) to pop: " << operandTypes.size() << "\n"
+ << " Current stack size: " << values.size();
+ if (operandTypes.size() > values.size())
+ return emitError(*opLoc,
+ "stack doesn't contain enough values. Trying to get ")
+ << operandTypes.size() << " operands on a stack containing only "
+ << values.size() << " values.";
+ size_t stackIdxOffset = values.size() - operandTypes.size();
+ SmallVector<Value> res{};
+ res.reserve(operandTypes.size());
+ for (size_t i{0}; i < operandTypes.size(); ++i) {
+ Value operand = values[i + stackIdxOffset];
+ Type stackType = operand.getType();
+ if (stackType != operandTypes[i])
+ return emitError(*opLoc, "invalid operand type on stack. Expecting ")
+ << operandTypes[i] << ", value on stack is of type " << stackType
+ << ".";
+ LDBG() << " POP: " << operand;
+ res.push_back(operand);
+ }
+ values.resize(values.size() - operandTypes.size());
+ LDBG() << " Updated stack size: " << values.size();
+ return res;
+}
+
+LogicalResult ValueStack::pushResults(ValueRange results, Location *opLoc) {
+ LDBG() << "Pushing to ValueStack\n"
+ << " Elements(s) to push: " << results.size() << "\n"
+ << " Current stack size: " << values.size();
+ for (Value val : results) {
+ if (!isWasmValueType(val.getType()))
+ return emitError(*opLoc, "invalid value type on stack: ")
+ << val.getType();
+ LDBG() << " PUSH: " << val;
+ values.push_back(val);
+ }
+
+ LDBG() << " Updated stack size: " << values.size();
+ return success();
+}
+
+template <std::byte EndParseByte>
+parsed_inst_t ExpressionParser::parse(OpBuilder &builder,
+ UniqueByte<EndParseByte> endByte) {
+ auto res = parse(builder, ByteSequence<EndParseByte>{});
+ if (failed(res))
+ return failure();
+ return res->opResults;
+}
+
+template <std::byte... ExpressionParseEnd>
+FailureOr<ExpressionParser::ParseResultWithInfo>
+ExpressionParser::parse(OpBuilder &builder,
+ ByteSequence<ExpressionParseEnd...> parsingEndFilters) {
+ SmallVector<Value> res;
+ for (;;) {
+ currentOpLoc = parser.getLocation();
+ FailureOr<std::byte> opCode = parser.consumeByte();
+ if (failed(opCode))
+ return failure();
+ if (isValueOneOf(*opCode, parsingEndFilters))
+ return {{res, *opCode}};
+ parsed_inst_t resParsed;
+ resParsed = dispatchToInstParser(*opCode, builder);
+ if (failed(resParsed))
+ return failure();
+ std::swap(res, *resParsed);
+ if (failed(pushResults(res)))
+ return failure();
+ }
+}
+
+template <typename T>
+inline Type buildLiteralType(OpBuilder &);
+
+template <>
+inline Type buildLiteralType<int32_t>(OpBuilder &builder) {
+ return builder.getI32Type();
+}
+
+template <>
+inline Type buildLiteralType<int64_t>(OpBuilder &builder) {
+ return builder.getI64Type();
+}
+
+template <>
+[[maybe_unused]] inline Type buildLiteralType<uint32_t>(OpBuilder &builder) {
+ return builder.getI32Type();
+}
+
+template <>
+[[maybe_unused]] inline Type buildLiteralType<uint64_t>(OpBuilder &builder) {
+ return builder.getI64Type();
+}
+
+template <>
+inline Type buildLiteralType<float>(OpBuilder &builder) {
+ return builder.getF32Type();
+}
+
+template <>
+inline Type buildLiteralType<double>(OpBuilder &builder) {
+ return builder.getF64Type();
+}
+
+template <typename ValT,
+ typename E = std::enable_if_t<std::is_arithmetic_v<ValT>>>
+struct AttrHolder;
+
+template <typename ValT>
+struct AttrHolder<ValT, std::enable_if_t<std::is_integral_v<ValT>>> {
+ using type = IntegerAttr;
+};
+
+template <typename ValT>
+struct AttrHolder<ValT, std::enable_if_t<std::is_floating_point_v<ValT>>> {
+ using type = FloatAttr;
+};
+
+template <typename ValT>
+using attr_holder_t = typename AttrHolder<ValT>::type;
+
+template <typename ValT,
+ typename EnableT = std::enable_if_t<std::is_arithmetic_v<ValT>>>
+attr_holder_t<ValT> buildLiteralAttr(OpBuilder &builder, ValT val) {
+ return attr_holder_t<ValT>::get(buildLiteralType<ValT>(builder), val);
+}
+
+template <typename valueT>
+parsed_inst_t ExpressionParser::parseConstInst(
+ OpBuilder &builder, std::enable_if_t<std::is_arithmetic_v<valueT>> *) {
+ auto parsedConstant = parser.parseLiteral<valueT>();
+ if (failed(parsedConstant))
+ return failure();
+ auto constOp =
+ ConstOp::create(builder, *currentOpLoc,
+ buildLiteralAttr<valueT>(builder, *parsedConstant));
+ return {{constOp.getResult()}};
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constI32>(OpBuilder &builder) {
+ return parseConstInst<int32_t>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constI64>(OpBuilder &builder) {
+ return parseConstInst<int64_t>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constFP32>(OpBuilder &builder) {
+ return parseConstInst<float>(builder);
+}
+
+template <>
+inline parsed_inst_t ExpressionParser::parseSpecificInstruction<
+ WasmBinaryEncoding::OpCode::constFP64>(OpBuilder &builder) {
+ return parseConstInst<double>(builder);
+}
+
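+/// Top-level parser for a Wasm binary: validates the header, indexes the
+/// sections of the binary, and parses the known sections into operations of
+/// the WasmSSA dialect in a fresh ModuleOp.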
+class WasmBinaryParser {
+private:
+ struct SectionRegistry {
+ using section_location_t = StringRef;
+
+ std::array<SmallVector<section_location_t>, highestWasmSectionID + 1>
+ registry;
+
+ template <WasmSectionType SecType>
+ std::conditional_t<sectionShouldBeUnique(SecType),
+ std::optional<section_location_t>,
+ ArrayRef<section_location_t>>
+ getContentForSection() const {
+ constexpr auto idx = static_cast<size_t>(SecType);
+ if constexpr (sectionShouldBeUnique(SecType)) {
+ return registry[idx].empty() ? std::nullopt
+ : std::make_optional(registry[idx][0]);
+ } else {
+ return registry[idx];
+ }
+ }
+
+ bool hasSection(WasmSectionType secType) const {
+ return !registry[static_cast<size_t>(secType)].empty();
+ }
+
+ ///
+    /// @returns success if the registration is valid, failure if it cannot be
+    /// done (another section of the same type already exists and this section
+    /// type must only be present once).
+ ///
+ LogicalResult registerSection(WasmSectionType secType,
+ section_location_t location, Location loc) {
+ if (sectionShouldBeUnique(secType) && hasSection(secType))
+ return emitError(loc,
+ "trying to add a second instance of unique section");
+
+ registry[static_cast<size_t>(secType)].push_back(location);
+ emitRemark(loc, "Adding section with section ID ")
+ << static_cast<uint8_t>(secType);
+ return success();
+ }
+
+ LogicalResult populateFromBody(ParserHead ph) {
+ while (!ph.end()) {
+ FileLineColLoc sectionLoc = ph.getLocation();
+ FailureOr<WasmSectionType> secType = ph.parseWasmSectionType();
+ if (failed(secType))
+ return failure();
+
+ FailureOr<uint32_t> secSizeParsed = ph.parseLiteral<uint32_t>();
+ if (failed(secSizeParsed))
+ return failure();
+
+ uint32_t secSize = *secSizeParsed;
+ FailureOr<StringRef> sectionContent = ph.consumeNBytes(secSize);
+ if (failed(sectionContent))
+ return failure();
+
+ LogicalResult registration =
+ registerSection(*secType, *sectionContent, sectionLoc);
+
+ if (failed(registration))
+ return failure();
+ }
+ return success();
+ }
+ };
+
+ auto getLocation(int offset = 0) const {
+ return FileLineColLoc::get(srcName, 0, offset);
+ }
+
+ template <WasmSectionType>
+ LogicalResult parseSectionItem(ParserHead &, size_t);
+
+ template <WasmSectionType section>
+ LogicalResult parseSection() {
+ auto secName = std::string{wasmSectionName<section>};
+ auto sectionNameAttr =
+ StringAttr::get(ctx, srcName.strref() + ":" + secName + "-SECTION");
+ unsigned offset = 0;
+ auto getLocation = [sectionNameAttr, &offset]() {
+ return FileLineColLoc::get(sectionNameAttr, 0, offset);
+ };
+ auto secContent = registry.getContentForSection<section>();
+ if (!secContent) {
+ LDBG() << secName << " section is not present in file.";
+ return success();
+ }
+
+ auto secSrc = secContent.value();
+ ParserHead ph{secSrc, sectionNameAttr};
+ FailureOr<uint32_t> nElemsParsed = ph.parseVectorSize();
+ if (failed(nElemsParsed))
+ return failure();
+ uint32_t nElems = *nElemsParsed;
+ LDBG() << "Starting to parse " << nElems << " items for section "
+ << secName;
+ for (size_t i = 0; i < nElems; ++i) {
+ if (failed(parseSectionItem<section>(ph, i)))
+ return failure();
+ }
+
+ if (!ph.end())
+ return emitError(getLocation(), "unparsed garbage at end of section ")
+ << secName;
+ return success();
+ }
+
+ /// Handles the registration of a function import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, TypeIdxRecord tid) {
+ using llvm::Twine;
+ if (tid.id >= symbols.moduleFuncTypes.size())
+      return emitError(loc, "invalid type id: ")
+             << tid.id << ". Only " << symbols.moduleFuncTypes.size()
+             << " type registrations.";
+ FunctionType type = symbols.moduleFuncTypes[tid.id];
+ std::string symbol = symbols.getNewFuncSymbolName();
+ auto funcOp = FuncImportOp::create(builder, loc, symbol, moduleName,
+ importName, type);
+ symbols.funcSymbols.push_back({{FlatSymbolRefAttr::get(funcOp)}, type});
+ return funcOp.verify();
+ }
+
+ /// Handles the registration of a memory import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, LimitType limitType) {
+ std::string symbol = symbols.getNewMemorySymbolName();
+ auto memOp = MemImportOp::create(builder, loc, symbol, moduleName,
+ importName, limitType);
+ symbols.memSymbols.push_back({FlatSymbolRefAttr::get(memOp)});
+ return memOp.verify();
+ }
+
+ /// Handles the registration of a table import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, TableType tableType) {
+ std::string symbol = symbols.getNewTableSymbolName();
+ auto tableOp = TableImportOp::create(builder, loc, symbol, moduleName,
+ importName, tableType);
+ symbols.tableSymbols.push_back({FlatSymbolRefAttr::get(tableOp)});
+ return tableOp.verify();
+ }
+
+ /// Handles the registration of a global variable import
+ LogicalResult visitImport(Location loc, StringRef moduleName,
+ StringRef importName, GlobalTypeRecord globalType) {
+ std::string symbol = symbols.getNewGlobalSymbolName();
+ auto giOp =
+ GlobalImportOp::create(builder, loc, symbol, moduleName, importName,
+ globalType.type, globalType.isMutable);
+ symbols.globalSymbols.push_back(
+ {{FlatSymbolRefAttr::get(giOp)}, giOp.getType()});
+ return giOp.verify();
+ }
+
+  // Detect occurrences of errors.
+ LogicalResult peekDiag(Diagnostic &diag) {
+ if (diag.getSeverity() == DiagnosticSeverity::Error)
+ isValid = false;
+ return failure();
+ }
+
+public:
+ WasmBinaryParser(llvm::SourceMgr &sourceMgr, MLIRContext *ctx)
+ : builder{ctx}, ctx{ctx} {
+ ctx->getDiagEngine().registerHandler(
+ [this](Diagnostic &diag) { return peekDiag(diag); });
+ ctx->loadAllAvailableDialects();
+ if (sourceMgr.getNumBuffers() != 1) {
+ emitError(UnknownLoc::get(ctx), "one source file should be provided");
+ return;
+ }
+ uint32_t sourceBufId = sourceMgr.getMainFileID();
+ StringRef source = sourceMgr.getMemoryBuffer(sourceBufId)->getBuffer();
+ srcName = StringAttr::get(
+ ctx, sourceMgr.getMemoryBuffer(sourceBufId)->getBufferIdentifier());
+
+ auto parser = ParserHead{source, srcName};
+ auto const wasmHeader = StringRef{"\0asm", 4};
+ FileLineColLoc magicLoc = parser.getLocation();
+ FailureOr<StringRef> magic = parser.consumeNBytes(wasmHeader.size());
+ if (failed(magic) || magic->compare(wasmHeader)) {
+ emitError(magicLoc, "source file does not contain valid Wasm header.");
+ return;
+ }
+ auto const expectedVersionString = StringRef{"\1\0\0\0", 4};
+ FileLineColLoc versionLoc = parser.getLocation();
+ FailureOr<StringRef> version =
+ parser.consumeNBytes(expectedVersionString.size());
+ if (failed(version))
+ return;
+ if (version->compare(expectedVersionString)) {
+ emitError(versionLoc,
+ "unsupported Wasm version. Only version 1 is supported.");
+ return;
+ }
+ LogicalResult fillRegistry = registry.populateFromBody(parser.copy());
+ if (failed(fillRegistry))
+ return;
+
+ mOp = ModuleOp::create(builder, getLocation());
+ builder.setInsertionPointToStart(&mOp.getBodyRegion().front());
+ LogicalResult parsingTypes = parseSection<WasmSectionType::TYPE>();
+ if (failed(parsingTypes))
+ return;
+
+ LogicalResult parsingImports = parseSection<WasmSectionType::IMPORT>();
+ if (failed(parsingImports))
+ return;
+
+ firstInternalFuncID = symbols.funcSymbols.size();
+
+ LogicalResult parsingFunctions = parseSection<WasmSectionType::FUNCTION>();
+ if (failed(parsingFunctions))
+ return;
+
+ LogicalResult parsingTables = parseSection<WasmSectionType::TABLE>();
+ if (failed(parsingTables))
+ return;
+
+ LogicalResult parsingMems = parseSection<WasmSectionType::MEMORY>();
+ if (failed(parsingMems))
+ return;
+
+ LogicalResult parsingExports = parseSection<WasmSectionType::EXPORT>();
+ if (failed(parsingExports))
+ return;
+
+    // Log the sizes of the symbol tables gathered so far.
+ LDBG() << "WASM Imports:"
+ << "\n"
+ << " - Num functions: " << symbols.funcSymbols.size() << "\n"
+ << " - Num globals: " << symbols.globalSymbols.size() << "\n"
+ << " - Num memories: " << symbols.memSymbols.size() << "\n"
+ << " - Num tables: " << symbols.tableSymbols.size();
+ }
+
+ ModuleOp getModule() {
+ if (isValid)
+ return mOp;
+ if (mOp)
+ mOp.erase();
+ return ModuleOp{};
+ }
+
+private:
+ mlir::StringAttr srcName;
+ OpBuilder builder;
+ WasmModuleSymbolTables symbols;
+ MLIRContext *ctx;
+ ModuleOp mOp;
+ SectionRegistry registry;
+ size_t firstInternalFuncID{0};
+ bool isValid{true};
+};
+
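+/// Parses one entry of the IMPORT section and creates the corresponding
+/// import operation via visitImport.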
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::IMPORT>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc importLoc = ph.getLocation();
+ auto moduleName = ph.parseName();
+ if (failed(moduleName))
+ return failure();
+
+ auto importName = ph.parseName();
+ if (failed(importName))
+ return failure();
+
+ FailureOr<ImportDesc> import = ph.parseImportDesc(ctx);
+ if (failed(import))
+ return failure();
+
+ return std::visit(
+ [this, importLoc, &moduleName, &importName](auto import) {
+ return visitImport(importLoc, *moduleName, *importName, import);
+ },
+ *import);
+}
+
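+/// Parses one entry of the EXPORT section: marks the referenced symbol as
+/// public and renames it to the exported name.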
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::EXPORT>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc exportLoc = ph.getLocation();
+
+ auto exportName = ph.parseName();
+ if (failed(exportName))
+ return failure();
+
+ FailureOr<std::byte> opcode = ph.consumeByte();
+ if (failed(opcode))
+ return failure();
+
+ FailureOr<uint32_t> idx = ph.parseLiteral<uint32_t>();
+ if (failed(idx))
+ return failure();
+
+ using SymbolRefDesc = std::variant<SmallVector<SymbolRefContainer>,
+ SmallVector<GlobalSymbolRefContainer>,
+ SmallVector<FunctionSymbolRefContainer>>;
+
+ SymbolRefDesc currentSymbolList;
+ std::string symbolType = "";
+ switch (*opcode) {
+ case WasmBinaryEncoding::Export::function:
+ symbolType = "function";
+ currentSymbolList = symbols.funcSymbols;
+ break;
+ case WasmBinaryEncoding::Export::table:
+ symbolType = "table";
+ currentSymbolList = symbols.tableSymbols;
+ break;
+ case WasmBinaryEncoding::Export::memory:
+ symbolType = "memory";
+ currentSymbolList = symbols.memSymbols;
+ break;
+ case WasmBinaryEncoding::Export::global:
+ symbolType = "global";
+ currentSymbolList = symbols.globalSymbols;
+ break;
+ default:
+ return emitError(exportLoc, "invalid value for export type: ")
+ << std::to_integer<unsigned>(*opcode);
+ }
+
+  auto currentSymbol = std::visit(
+      [&](const auto &list) -> FailureOr<FlatSymbolRefAttr> {
+        if (*idx >= list.size()) {
+          emitError(
+              exportLoc,
+              llvm::formatv(
+                  "trying to export {0} {1} which is undefined in this scope",
+                  symbolType, *idx));
+          return failure();
+        }
+        return list[*idx].symbol;
+      },
+      currentSymbolList);
+
+ if (failed(currentSymbol))
+ return failure();
+
+ Operation *op = SymbolTable::lookupSymbolIn(mOp, *currentSymbol);
+ SymbolTable::setSymbolVisibility(op, SymbolTable::Visibility::Public);
+ StringAttr symName = SymbolTable::getSymbolName(op);
+ return SymbolTable{mOp}.rename(symName, *exportName);
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::TABLE>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc opLocation = ph.getLocation();
+ FailureOr<TableType> tableType = ph.parseTableType(ctx);
+ if (failed(tableType))
+ return failure();
+ LDBG() << " Parsed table description: " << *tableType;
+ StringAttr symbol = builder.getStringAttr(symbols.getNewTableSymbolName());
+ auto tableOp =
+ TableOp::create(builder, opLocation, symbol.strref(), *tableType);
+ symbols.tableSymbols.push_back({SymbolRefAttr::get(tableOp)});
+ return success();
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::FUNCTION>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc opLoc = ph.getLocation();
+ auto typeIdxParsed = ph.parseLiteral<uint32_t>();
+ if (failed(typeIdxParsed))
+ return failure();
+ uint32_t typeIdx = *typeIdxParsed;
+ if (typeIdx >= symbols.moduleFuncTypes.size())
+ return emitError(getLocation(), "invalid type index: ") << typeIdx;
+ std::string symbol = symbols.getNewFuncSymbolName();
+ auto funcOp =
+ FuncOp::create(builder, opLoc, symbol, symbols.moduleFuncTypes[typeIdx]);
+ Block *block = funcOp.addEntryBlock();
+ auto ip = builder.saveInsertionPoint();
+ builder.setInsertionPointToEnd(block);
+ ReturnOp::create(builder, opLoc);
+ builder.restoreInsertionPoint(ip);
+ symbols.funcSymbols.push_back(
+ {{FlatSymbolRefAttr::get(funcOp.getSymNameAttr())},
+ symbols.moduleFuncTypes[typeIdx]});
+ return funcOp.verify();
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::TYPE>(ParserHead &ph,
+ size_t) {
+ FailureOr<FunctionType> funcType = ph.parseFunctionType(ctx);
+ if (failed(funcType))
+ return failure();
+ LDBG() << "Parsed function type " << *funcType;
+ symbols.moduleFuncTypes.push_back(*funcType);
+ return success();
+}
+
+template <>
+LogicalResult
+WasmBinaryParser::parseSectionItem<WasmSectionType::MEMORY>(ParserHead &ph,
+ size_t) {
+ FileLineColLoc opLocation = ph.getLocation();
+ FailureOr<LimitType> memory = ph.parseLimit(ctx);
+ if (failed(memory))
+ return failure();
+
+ LDBG() << " Registering memory " << *memory;
+ std::string symbol = symbols.getNewMemorySymbolName();
+ auto memOp = MemOp::create(builder, opLocation, symbol, *memory);
+ symbols.memSymbols.push_back({SymbolRefAttr::get(memOp)});
+ return success();
+}
+} // namespace
+
+namespace mlir::wasm {
+OwningOpRef<ModuleOp> importWebAssemblyToModule(llvm::SourceMgr &source,
+ MLIRContext *context) {
+ WasmBinaryParser wBN{source, context};
+ ModuleOp mOp = wBN.getModule();
+ if (mOp)
+ return {mOp};
+
+ return {nullptr};
+}
+} // namespace mlir::wasm
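For orientation, here is a minimal caller-side sketch (not part of this patch; the helper name and the exact set of includes are assumptions) of driving the importer entry point above directly:

#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/Target/Wasm/WasmImporter.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"

// Hypothetical helper: load a .wasm binary and import it into an MLIR module.
static mlir::OwningOpRef<mlir::ModuleOp>
loadWasmModule(llvm::StringRef path, mlir::MLIRContext &ctx) {
  // Read the binary into a SourceMgr, as expected by the importer.
  auto fileOrErr = llvm::MemoryBuffer::getFile(path);
  if (!fileOrErr)
    return nullptr;
  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
  // The importer produces wasmssa ops, so the dialect must be loaded.
  ctx.loadDialect<mlir::wasmssa::WasmSSADialect>();
  return mlir::wasm::importWebAssemblyToModule(sourceMgr, &ctx);
}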
diff --git a/mlir/lib/Target/Wasm/TranslateRegistration.cpp b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
new file mode 100644
index 0000000..03b9784
--- /dev/null
+++ b/mlir/lib/Target/Wasm/TranslateRegistration.cpp
@@ -0,0 +1,28 @@
+//===- TranslateRegistration.cpp - Register translation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/WasmSSA/IR/WasmSSA.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/IR/OwningOpRef.h"
+#include "mlir/Target/Wasm/WasmImporter.h"
+#include "mlir/Tools/mlir-translate/Translation.h"
+
+using namespace mlir;
+
+namespace mlir {
+void registerFromWasmTranslation() {
+ TranslateToMLIRRegistration registration{
+ "import-wasm", "Translate WASM to MLIR",
+ [](llvm::SourceMgr &sourceMgr,
+ MLIRContext *context) -> OwningOpRef<Operation *> {
+ return wasm::importWebAssemblyToModule(sourceMgr, context);
+ },
+ [](DialectRegistry &registry) {
+ registry.insert<wasmssa::WasmSSADialect>();
+ }};
+}
+} // namespace mlir
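A hedged sketch of pulling the registration above into a standalone mlir-translate-style tool; the tool name and the main() wrapper are assumptions, while registerFromWasmTranslation() and mlirTranslateMain() are used as declared elsewhere in the tree:

#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"

namespace mlir {
void registerFromWasmTranslation();
} // namespace mlir

int main(int argc, char **argv) {
  // Makes "--import-wasm" available to this tool.
  mlir::registerFromWasmTranslation();
  return mlir::failed(mlir::mlirTranslateMain(argc, argv, "WASM import tool"))
             ? 1
             : 0;
}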
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 4ccb83f..02dad69 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -258,18 +258,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
RDVFinalCleanupList &cl) {
- LDBG() << "Processing simple op: " << *op;
if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
- LDBG()
- << "Simple op is not memory effect free or has live results, skipping: "
- << *op;
+ LDBG() << "Simple op is not memory effect free or has live results, "
+ "preserving it: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
return;
}
LDBG()
<< "Simple op has all dead results and is memory effect free, scheduling "
"for removal: "
- << *op;
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
cl.operations.push_back(op);
collectNonLiveValues(nonLiveSet, op->getResults(),
BitVector(op->getNumResults(), true));
@@ -728,19 +727,31 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
/// Removes dead values collected in RDVFinalCleanupList.
/// To be run once when all dead values have been collected.
static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+ LDBG() << "Starting cleanup of dead values...";
+
// 1. Operations
+ LDBG() << "Cleaning up " << list.operations.size() << " operations";
for (auto &op : list.operations) {
+ LDBG() << "Erasing operation: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
op->dropAllUses();
op->erase();
}
// 2. Values
+ LDBG() << "Cleaning up " << list.values.size() << " values";
for (auto &v : list.values) {
+ LDBG() << "Dropping all uses of value: " << v;
v.dropAllUses();
}
// 3. Functions
+ LDBG() << "Cleaning up " << list.functions.size() << " functions";
for (auto &f : list.functions) {
+ LDBG() << "Cleaning up function: " << f.funcOp.getOperation()->getName();
+ LDBG() << " Erasing " << f.nonLiveArgs.count() << " non-live arguments";
+ LDBG() << " Erasing " << f.nonLiveRets.count()
+ << " non-live return values";
// Some functions may not allow erasing arguments or results. These calls
// return failure in such cases without modifying the function, so it's okay
// to proceed.
@@ -749,44 +760,67 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
}
// 4. Operands
+ LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
for (OperationToCleanup &o : list.operands) {
- if (o.op->getNumOperands() > 0)
+ if (o.op->getNumOperands() > 0) {
+ LDBG() << "Erasing " << o.nonLive.count()
+ << " non-live operands from operation: "
+ << OpWithFlags(o.op, OpPrintingFlags().skipRegions());
o.op->eraseOperands(o.nonLive);
+ }
}
// 5. Results
+ LDBG() << "Cleaning up " << list.results.size() << " result lists";
for (auto &r : list.results) {
+ LDBG() << "Erasing " << r.nonLive.count()
+ << " non-live results from operation: "
+ << OpWithFlags(r.op, OpPrintingFlags().skipRegions());
dropUsesAndEraseResults(r.op, r.nonLive);
}
// 6. Blocks
+ LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
for (auto &b : list.blocks) {
// blocks that are accessed via multiple codepaths processed once
if (b.b->getNumArguments() != b.nonLiveArgs.size())
continue;
+ LDBG() << "Erasing " << b.nonLiveArgs.count()
+ << " non-live arguments from block: " << b.b;
// it iterates backwards because erase invalidates all successor indexes
for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
if (!b.nonLiveArgs[i])
continue;
+ LDBG() << " Erasing block argument " << i << ": " << b.b->getArgument(i);
b.b->getArgument(i).dropAllUses();
b.b->eraseArgument(i);
}
}
// 7. Successor Operands
+ LDBG() << "Cleaning up " << list.successorOperands.size()
+ << " successor operand lists";
for (auto &op : list.successorOperands) {
SuccessorOperands successorOperands =
op.branch.getSuccessorOperands(op.successorIndex);
// blocks that are accessed via multiple codepaths processed once
if (successorOperands.size() != op.nonLiveOperands.size())
continue;
+ LDBG() << "Erasing " << op.nonLiveOperands.count()
+ << " non-live successor operands from successor "
+ << op.successorIndex << " of branch: "
+ << OpWithFlags(op.branch, OpPrintingFlags().skipRegions());
// it iterates backwards because erase invalidates all successor indexes
for (int i = successorOperands.size() - 1; i >= 0; --i) {
if (!op.nonLiveOperands[i])
continue;
+ LDBG() << " Erasing successor operand " << i << ": "
+ << successorOperands[i];
successorOperands.erase(i);
}
}
+
+ LDBG() << "Finished cleanup of dead values";
}
struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 0c26b4e..7494ca9 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -182,15 +182,24 @@ private:
/// conversions.)
static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
+/// Return the operation that defines all values in the vector. Return nullptr
+/// if the values are not defined by the same operation.
+static Operation *getCommonDefiningOp(const ValueVector &values) {
+ assert(!values.empty() && "expected non-empty value vector");
+ Operation *op = values.front().getDefiningOp();
+ for (Value v : llvm::drop_begin(values)) {
+ if (v.getDefiningOp() != op)
+ return nullptr;
+ }
+ return op;
+}
+
/// A vector of values is a pure type conversion if all values are defined by
/// the same operation and the operation has the `kPureTypeConversionMarker`
/// attribute.
static bool isPureTypeConversion(const ValueVector &values) {
assert(!values.empty() && "expected non-empty value vector");
- Operation *op = values.front().getDefiningOp();
- for (Value v : llvm::drop_begin(values))
- if (v.getDefiningOp() != op)
- return false;
+ Operation *op = getCommonDefiningOp(values);
return op && op->hasAttr(kPureTypeConversionMarker);
}
@@ -841,7 +850,7 @@ namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
- : context(ctx), config(config) {}
+ : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -863,6 +872,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// failure.
template <typename RewriteTy, typename... Args>
void appendRewrite(Args &&...args) {
+ assert(config.allowPatternRollback && "appending rewrites is not allowed");
rewrites.push_back(
std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
}
@@ -889,15 +899,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasOpReplaced(Operation *op) const;
/// Lookup the most recently mapped values with the desired types in the
- /// mapping.
- ///
- /// Special cases:
- /// - If the desired type range is empty, simply return the most recently
- /// mapped values.
- /// - If there is no mapping to the desired types, also return the most
- /// recently mapped values.
- /// - If there is no mapping for the given values at all, return the given
- /// value.
+ /// mapping, taking into account only replacements. Perform a best-effort
+ /// search for existing materializations with the desired types.
///
/// If `skipPureTypeConversions` is "true", materializations that are pure
/// type conversions are not considered.
@@ -1066,6 +1069,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
ConversionValueMapping mapping;
/// Ordered list of block operations (creations, splits, motions).
+ /// This vector is maintained only if `allowPatternRollback` is set to
+ /// "true". Otherwise, all IR rewrites are materialized immediately and no
+ /// bookkeeping is needed.
SmallVector<std::unique_ptr<IRRewrite>> rewrites;
/// A set of operations that should no longer be considered for legalization.
@@ -1089,6 +1095,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// by the current pattern.
SetVector<Block *> patternInsertedBlocks;
+  /// The set of unresolved materializations that were created by the current
+  /// pattern.
+ DenseSet<UnrealizedConversionCastOp> patternMaterializations;
+
/// A mapping for looking up metadata of unresolved materializations.
DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
unresolvedMaterializations;
@@ -1104,6 +1114,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Dialect conversion configuration.
const ConversionConfig &config;
+ /// A set of erased operations. This set is utilized only if
+ /// `allowPatternRollback` is set to "false". Conceptually, this set is
+ /// similar to `replacedOps` (which is maintained when the flag is set to
+ /// "true"). However, erasing from a DenseSet is more efficient than erasing
+ /// from a SetVector.
+ DenseSet<Operation *> erasedOps;
+
+ /// A set of erased blocks. This set is utilized only if
+ /// `allowPatternRollback` is set to "false".
+ DenseSet<Block *> erasedBlocks;
+
+ /// A rewriter that notifies the listener (if any) about all IR
+ /// modifications. This rewriter is utilized only if `allowPatternRollback`
+ /// is set to "false". If the flag is set to "true", the listener is notified
+ /// with a separate mechanism (e.g., in `IRRewrite::commit`).
+ IRRewriter notifyingRewriter;
+
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
/// strictly necessary, and is thus only active during debug builds for extra
@@ -1111,8 +1138,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
SmallPtrSet<Operation *, 1> pendingRootUpdates;
/// A raw output stream used to prefix the debug log.
- llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + "] ").str(),
- llvm::dbgs(), /*HasPendingNewline=*/false};
+ llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
+ llvm::dbgs()};
/// A logger used to emit diagnostics during the conversion process.
llvm::ScopedPrinter logger{os};
@@ -1140,11 +1167,8 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
- Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
- if (!repl)
- return;
-
+static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
+ Value repl) {
if (isa<BlockArgument>(repl)) {
rewriter.replaceAllUsesWith(arg, repl);
return;
@@ -1161,6 +1185,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
});
}
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
+ Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+ if (!repl)
+ return;
+ performReplaceBlockArg(rewriter, arg, repl);
+}
+
void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
@@ -1246,6 +1277,30 @@ void ConversionPatternRewriterImpl::applyRewrites() {
ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
+ // Helper function that looks up a single value.
+ auto lookup = [&](const ValueVector &values) -> ValueVector {
+ assert(!values.empty() && "expected non-empty value vector");
+
+ // If the pattern rollback is enabled, use the mapping to look up the
+ // values.
+ if (config.allowPatternRollback)
+ return mapping.lookup(values);
+
+ // Otherwise, look up values by examining the IR. All replacements have
+ // already been materialized in IR.
+ Operation *op = getCommonDefiningOp(values);
+ if (!op)
+ return {};
+ auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
+ if (!castOp)
+ return {};
+ if (!this->unresolvedMaterializations.contains(castOp))
+ return {};
+ if (castOp.getOutputs() != values)
+ return {};
+ return castOp.getInputs();
+ };
+
// Helper function that looks up each value in `values` individually and then
// composes the results. If that fails, it tries to look up the entire vector
// at once.
@@ -1253,7 +1308,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
// If possible, replace each value with (one or multiple) mapped values.
ValueVector next;
for (Value v : values) {
- ValueVector r = mapping.lookup({v});
+ ValueVector r = lookup({v});
if (!r.empty()) {
llvm::append_range(next, r);
} else {
@@ -1273,7 +1328,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
// be stored (and looked up) in the mapping. But for performance reasons,
// we choose to reuse existing IR (when possible) instead of creating it
// multiple times.
- ValueVector r = mapping.lookup(values);
+ ValueVector r = lookup(values);
if (r.empty()) {
// No mapping found: The lookup stops here.
return {};
@@ -1347,15 +1402,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state,
void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
StringRef patternName) {
for (auto &rewrite :
- llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
- if (!config.allowPatternRollback &&
- !isa<UnresolvedMaterializationRewrite>(rewrite)) {
- // Unresolved materializations can always be rolled back (erased).
- llvm::report_fatal_error("pattern '" + patternName +
- "' rollback of IR modifications requested");
- }
+ llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
rewrite->rollback();
- }
rewrites.resize(numRewritesToKeep);
}
@@ -1419,12 +1467,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
// Check to see if this operation is ignored or was replaced.
- return replacedOps.count(op) || ignoredOps.count(op);
+ return wasOpReplaced(op) || ignoredOps.count(op);
}
bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
// Check to see if this operation was replaced.
- return replacedOps.count(op);
+ return replacedOps.count(op) || erasedOps.count(op);
}
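The erasedOps/erasedBlocks bookkeeping above only exists for the "no rollback" mode. A minimal caller-side sketch of opting into that mode, assuming the ConversionConfig fields used throughout this patch are available to clients:

#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Sketch: run a partial conversion with rollback disabled, so IR changes are
// materialized immediately and the listener is notified as rewrites happen.
LogicalResult runOneShotConversion(Operation *root,
                                   const ConversionTarget &target,
                                   const FrozenRewritePatternSet &patterns,
                                   RewriterBase::Listener *listener) {
  ConversionConfig config;
  config.allowPatternRollback = false; // skip rollback bookkeeping
  config.listener = listener;          // forwarded via notifyingRewriter
  // config.attachDebugMaterializationKind = true; // tag casts with __kind__
  return applyPartialConversion(root, target, patterns, config);
}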
//===----------------------------------------------------------------------===//
@@ -1508,7 +1556,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
// a bit more efficient, so we try to do that when possible.
bool fastPath = !config.listener;
if (fastPath) {
- appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+ if (config.allowPatternRollback)
+ appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
newBlock->getOperations().splice(newBlock->end(), block->getOperations());
} else {
while (!block->empty())
@@ -1556,7 +1605,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
replaceUsesOfBlockArgument(origArg, replArgs, converter);
}
- appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
+ if (config.allowPatternRollback)
+ appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
// Erase the old block. (It is just unlinked for now and will be erased during
// cleanup.)
@@ -1585,23 +1635,37 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
// tracking the materialization like we do for other operations.
OpBuilder builder(outputTypes.front().getContext());
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
- auto convertOp =
+ UnrealizedConversionCastOp convertOp =
UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
+ if (config.attachDebugMaterializationKind) {
+ StringRef kindStr =
+ kind == MaterializationKind::Source ? "source" : "target";
+ convertOp->setAttr("__kind__", builder.getStringAttr(kindStr));
+ }
if (isPureTypeConversion)
convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
- if (!valuesToMap.empty())
- mapping.map(valuesToMap, convertOp.getResults());
+
+ // Register the materialization.
if (castOp)
*castOp = convertOp;
unresolvedMaterializations[convertOp] =
UnresolvedMaterializationInfo(converter, kind, originalType);
- appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
- std::move(valuesToMap));
+ if (config.allowPatternRollback) {
+ if (!valuesToMap.empty())
+ mapping.map(valuesToMap, convertOp.getResults());
+ appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+ std::move(valuesToMap));
+ } else {
+ patternMaterializations.insert(convertOp);
+ }
return convertOp.getResults();
}
Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
Value value, const TypeConverter *converter) {
+ assert(config.allowPatternRollback &&
+ "this code path is valid only in rollback mode");
+
// Try to find a replacement value with the same type in the conversion value
// mapping. This includes cached materializations. We try to reuse those
// instead of generating duplicate IR.
@@ -1663,26 +1727,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
logger.getOStream() << " (was detached)";
logger.getOStream() << "\n";
});
- assert(!wasOpReplaced(op->getParentOp()) &&
+
+ // In rollback mode, it is easier to misuse the API, so perform extra error
+ // checking.
+ assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
"attempting to insert into a block within a replaced/erased op");
+ // In "no rollback" mode, the listener is always notified immediately.
+ if (!config.allowPatternRollback && config.listener)
+ config.listener->notifyOperationInserted(op, previous);
+
if (wasDetached) {
- // If the op was detached, it is most likely a newly created op.
- // TODO: If the same op is inserted multiple times from a detached state,
- // the rollback mechanism may erase the same op multiple times. This is a
- // bug in the rollback-based dialect conversion driver.
- appendRewrite<CreateOperationRewrite>(op);
+    // If the op was detached, it is most likely a newly created op. Add it
+    // to the set of newly created ops so that it will be legalized. If this
+    // op is not a newly created op, it will be legalized a second time,
+    // which is inefficient but harmless.
patternNewOps.insert(op);
+
+ if (config.allowPatternRollback) {
+ // TODO: If the same op is inserted multiple times from a detached
+ // state, the rollback mechanism may erase the same op multiple times.
+ // This is a bug in the rollback-based dialect conversion driver.
+ appendRewrite<CreateOperationRewrite>(op);
+ } else {
+ // In "no rollback" mode, there is an extra data structure for tracking
+ // erased operations that must be kept up to date.
+ erasedOps.erase(op);
+ }
return;
}
// The op was moved from one place to another.
- appendRewrite<MoveOperationRewrite>(op, previous);
+ if (config.allowPatternRollback)
+ appendRewrite<MoveOperationRewrite>(op, previous);
+}
+
+/// Given that `fromRange` is about to be replaced with `toRange`, compute
+/// replacement values with the types of `fromRange`.
+static SmallVector<Value>
+getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
+ const SmallVector<SmallVector<Value>> &toRange,
+ const TypeConverter *converter) {
+ assert(!impl.config.allowPatternRollback &&
+ "this code path is valid only in 'no rollback' mode");
+ SmallVector<Value> repls;
+ for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
+ if (from.use_empty()) {
+ // The replaced value is dead. No replacement value is needed.
+ repls.push_back(Value());
+ continue;
+ }
+
+ if (to.empty()) {
+ // The replaced value is dropped. Materialize a replacement value "out of
+ // thin air".
+ Value srcMat = impl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
+ /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
+ /*outputTypes=*/from.getType(), /*originalType=*/Type(),
+ converter)[0];
+ repls.push_back(srcMat);
+ continue;
+ }
+
+ if (TypeRange(ValueRange(to)) == TypeRange(from.getType())) {
+ // The replacement value already has the correct type. Use it directly.
+ repls.push_back(to[0]);
+ continue;
+ }
+
+ // The replacement value has the wrong type. Build a source materialization
+ // to the original type.
+ // TODO: This is a bit inefficient. We should try to reuse existing
+ // materializations if possible. This would require an extension of the
+ // `lookupOrDefault` API.
+ Value srcMat = impl.buildUnresolvedMaterialization(
+ MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
+ /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
+ /*originalType=*/Type(), converter)[0];
+ repls.push_back(srcMat);
+ }
+
+ return repls;
}
void ConversionPatternRewriterImpl::replaceOp(
Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
- assert(newValues.size() == op->getNumResults());
+ assert(newValues.size() == op->getNumResults() &&
+ "incorrect number of replacement values");
+
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed: materialize all IR changes immediately.
+ SmallVector<Value> repls = getReplacementValues(
+ *this, op->getResults(), newValues, currentTypeConverter);
+ // Update internal data structures, so that there are no dangling pointers
+ // to erased IR.
+ op->walk([&](Operation *op) {
+ erasedOps.insert(op);
+ ignoredOps.remove(op);
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ unresolvedMaterializations.erase(castOp);
+ patternMaterializations.erase(castOp);
+ }
+ // The original op will be erased, so remove it from the set of
+ // unlegalized ops.
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->erase(op);
+ });
+ op->walk([&](Block *block) { erasedBlocks.insert(block); });
+ // Replace the op with the replacement values and notify the listener.
+ notifyingRewriter.replaceOp(op, repls);
+ return;
+ }
+
assert(!ignoredOps.contains(op) && "operation was already replaced");
// Check if replaced op is an unresolved materialization, i.e., an
@@ -1722,11 +1879,46 @@ void ConversionPatternRewriterImpl::replaceOp(
void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
BlockArgument from, ValueRange to, const TypeConverter *converter) {
+ if (!config.allowPatternRollback) {
+ SmallVector<Value> toConv = llvm::to_vector(to);
+ SmallVector<Value> repls =
+ getReplacementValues(*this, from, {toConv}, converter);
+ IRRewriter r(from.getContext());
+ Value repl = repls.front();
+ if (!repl)
+ return;
+
+ performReplaceBlockArg(r, from, repl);
+ return;
+ }
+
appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
mapping.map(from, to);
}
void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+ if (!config.allowPatternRollback) {
+ // Pattern rollback is not allowed: materialize all IR changes immediately.
+ // Update internal data structures, so that there are no dangling pointers
+ // to erased IR.
+ block->walk([&](Operation *op) {
+ erasedOps.insert(op);
+ ignoredOps.remove(op);
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+ unresolvedMaterializations.erase(castOp);
+ patternMaterializations.erase(castOp);
+ }
+ // The original op will be erased, so remove it from the set of
+ // unlegalized ops.
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->erase(op);
+ });
+ block->walk([&](Block *block) { erasedBlocks.insert(block); });
+ // Erase the block and notify the listener.
+ notifyingRewriter.eraseBlock(block);
+ return;
+ }
+
assert(!wasOpReplaced(block->getParentOp()) &&
"attempting to erase a block within a replaced/erased op");
appendRewrite<EraseBlockRewrite>(block);
@@ -1760,23 +1952,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
logger.getOStream() << " (was detached)";
logger.getOStream() << "\n";
});
- assert(!wasOpReplaced(newParentOp) &&
+
+ // In rollback mode, it is easier to misuse the API, so perform extra error
+ // checking.
+ assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
"attempting to insert into a region within a replaced/erased op");
(void)newParentOp;
+ // In "no rollback" mode, the listener is always notified immediately.
+ if (!config.allowPatternRollback && config.listener)
+ config.listener->notifyBlockInserted(block, previous, previousIt);
+
patternInsertedBlocks.insert(block);
if (wasDetached) {
// If the block was detached, it is most likely a newly created block.
- // TODO: If the same block is inserted multiple times from a detached state,
- // the rollback mechanism may erase the same block multiple times. This is a
- // bug in the rollback-based dialect conversion driver.
- appendRewrite<CreateBlockRewrite>(block);
+ if (config.allowPatternRollback) {
+ // TODO: If the same block is inserted multiple times from a detached
+ // state, the rollback mechanism may erase the same block multiple times.
+ // This is a bug in the rollback-based dialect conversion driver.
+ appendRewrite<CreateBlockRewrite>(block);
+ } else {
+ // In "no rollback" mode, there is an extra data structure for tracking
+ // erased blocks that must be kept up to date.
+ erasedBlocks.erase(block);
+ }
return;
}
// The block was moved from one place to another.
- appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
+ if (config.allowPatternRollback)
+ appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
}
void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
@@ -1956,7 +2162,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
// a bit more efficient, so we try to do that when possible.
bool fastPath = !getConfig().listener;
- if (fastPath)
+ if (fastPath && impl->config.allowPatternRollback)
impl->inlineBlockBefore(source, dest, before);
// Replace all uses of block arguments.
@@ -1982,6 +2188,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
}
void ConversionPatternRewriter::startOpModification(Operation *op) {
+ if (!impl->config.allowPatternRollback) {
+ // Pattern rollback is not allowed: no extra bookkeeping is needed.
+ PatternRewriter::startOpModification(op);
+ return;
+ }
assert(!impl->wasOpReplaced(op) &&
"attempting to modify a replaced/erased op");
#ifndef NDEBUG
@@ -1991,20 +2202,29 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
}
void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
- assert(!impl->wasOpReplaced(op) &&
- "attempting to modify a replaced/erased op");
- PatternRewriter::finalizeOpModification(op);
impl->patternModifiedOps.insert(op);
+ if (!impl->config.allowPatternRollback) {
+ PatternRewriter::finalizeOpModification(op);
+ if (getConfig().listener)
+ getConfig().listener->notifyOperationModified(op);
+ return;
+ }
// There is nothing to do here, we only need to track the operation at the
// start of the update.
#ifndef NDEBUG
+ assert(!impl->wasOpReplaced(op) &&
+ "attempting to modify a replaced/erased op");
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
#endif
}
void ConversionPatternRewriter::cancelOpModification(Operation *op) {
+ if (!impl->config.allowPatternRollback) {
+ PatternRewriter::cancelOpModification(op);
+ return;
+ }
#ifndef NDEBUG
assert(impl->pendingRootUpdates.erase(op) &&
"operation did not have a pending in-place update");
@@ -2029,17 +2249,17 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
// ConversionPattern
//===----------------------------------------------------------------------===//
-SmallVector<Value> ConversionPattern::getOneToOneAdaptorOperands(
+FailureOr<SmallVector<Value>> ConversionPattern::getOneToOneAdaptorOperands(
ArrayRef<ValueRange> operands) const {
SmallVector<Value> oneToOneOperands;
oneToOneOperands.reserve(operands.size());
for (ValueRange operand : operands) {
if (operand.size() != 1)
- llvm::report_fatal_error("pattern '" + getDebugName() +
- "' does not support 1:N conversion");
+ return failure();
+
oneToOneOperands.push_back(operand.front());
}
- return oneToOneOperands;
+ return std::move(oneToOneOperands);
}
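Following the signature change above, a hedged sketch of how a generic 1:1 pattern might handle the new FailureOr<> result; the pattern class and its in-place rewrite are illustrative, not from this patch:

#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

struct OneToOneOperandSketch : public ConversionPattern {
  OneToOneOperandSketch(MLIRContext *ctx)
      : ConversionPattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {}

  LogicalResult
  matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
                  ConversionPatternRewriter &rewriter) const override {
    // Decline gracefully if any operand was remapped to multiple values,
    // instead of hitting the previous report_fatal_error.
    FailureOr<SmallVector<Value>> flat = getOneToOneAdaptorOperands(operands);
    if (failed(flat))
      return rewriter.notifyMatchFailure(op, "expected 1:1 operand mapping");
    // Illustrative rewrite: switch the op over to the converted operands.
    rewriter.modifyOpInPlace(op, [&] { op->setOperands(*flat); });
    return success();
  }
};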
LogicalResult
@@ -2257,15 +2477,17 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
- // If the operation isn't legal, try to fold it in-place.
- // TODO: Should we always try to do this, even if the op is
- // already legal?
- if (succeeded(legalizeWithFold(op, rewriter))) {
- LLVM_DEBUG({
- logSuccess(logger, "operation was folded");
- logger.startLine() << logLineComment;
- });
- return success();
+ // If the operation is not legal, try to fold it in-place if the folding mode
+ // is 'BeforePatterns'. 'Never' will skip this.
+ const ConversionConfig &config = rewriter.getConfig();
+ if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
+ if (succeeded(legalizeWithFold(op, rewriter))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
}
// Otherwise, we need to apply a legalization pattern to this operation.
@@ -2277,6 +2499,18 @@ OperationLegalizer::legalize(Operation *op,
return success();
}
+ // If the operation can't be legalized via patterns, try to fold it in-place
+ // if the folding mode is 'AfterPatterns'.
+ if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
+ if (succeeded(legalizeWithFold(op, rewriter))) {
+ LLVM_DEBUG({
+ logSuccess(logger, "operation was folded");
+ logger.startLine() << logLineComment;
+ });
+ return success();
+ }
+ }
+
LLVM_DEBUG({
logFailure(logger, "no matched legalization pattern");
logger.startLine() << logLineComment;
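For reference, a hedged sketch of selecting the folding behaviour exercised above; DialectConversionFoldingMode and ConversionConfig::foldingMode are taken from their usage in this patch and assumed to be part of the public configuration:

#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

// Sketch: only fall back to folding for ops that no pattern could legalize.
LogicalResult convertWithLateFolding(Operation *root,
                                     const ConversionTarget &target,
                                     const FrozenRewritePatternSet &patterns) {
  ConversionConfig config;
  config.foldingMode = DialectConversionFoldingMode::AfterPatterns;
  return applyPartialConversion(root, target, patterns, config);
}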
@@ -2425,17 +2659,23 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
RewriterState curState = rewriterImpl.getCurrentState();
auto onFailure = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
-#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (!rewriterImpl.config.allowPatternRollback) {
- // Returning "failure" after modifying IR is not allowed.
+ // Erase all unresolved materializations.
+ for (auto op : rewriterImpl.patternMaterializations) {
+ rewriterImpl.unresolvedMaterializations.erase(op);
+ op.erase();
+ }
+ rewriterImpl.patternMaterializations.clear();
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ // Expensive pattern check that can detect API violations.
if (checkOp) {
OperationFingerPrint fingerPrintAfterPattern(checkOp);
if (fingerPrintAfterPattern != *topLevelFingerPrint)
llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
"' returned failure but IR did change");
}
- }
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+ }
rewriterImpl.patternNewOps.clear();
rewriterImpl.patternModifiedOps.clear();
rewriterImpl.patternInsertedBlocks.clear();
@@ -2459,6 +2699,16 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// successfully applied.
auto onSuccess = [&](const Pattern &pattern) {
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+ if (!rewriterImpl.config.allowPatternRollback) {
+ // Eagerly erase unused materializations.
+ for (auto op : rewriterImpl.patternMaterializations) {
+ if (op->use_empty()) {
+ rewriterImpl.unresolvedMaterializations.erase(op);
+ op.erase();
+ }
+ }
+ rewriterImpl.patternMaterializations.clear();
+ }
SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
SetVector<Operation *> modifiedOps =
moveAndReset(rewriterImpl.patternModifiedOps);
@@ -2549,6 +2799,9 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
// If the pattern moved or created any blocks, make sure the types of block
// arguments get legalized.
for (Block *block : insertedBlocks) {
+ if (impl.erasedBlocks.contains(block))
+ continue;
+
// Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp();
if (!parentOp || parentOp == op || block->getNumArguments() == 0)
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 607b86c..0324588 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -15,6 +15,8 @@
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
@@ -23,7 +25,7 @@
#include "llvm/ADT/BitVector.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/ScopeExit.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_ostream.h"
@@ -178,9 +180,8 @@ static Operation *getDumpRootOp(Operation *op) {
return op;
}
static void logSuccessfulFolding(Operation *op) {
- llvm::dbgs() << "// *** IR Dump After Successful Folding ***\n";
- op->dump();
- llvm::dbgs() << "\n\n";
+ LDBG() << "// *** IR Dump After Successful Folding ***\n"
+ << OpWithFlags(op, OpPrintingFlags().elideLargeElementsAttrs());
}
#endif // NDEBUG
@@ -394,8 +395,12 @@ private:
function_ref<void(Diagnostic &)> reasonCallback) override;
#ifndef NDEBUG
+  /// A raw output stream used to prefix the debug log.
+ llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + ":1] ").str(),
+ llvm::dbgs()};
/// A logger used to emit information during the application process.
- llvm::ScopedPrinter logger{llvm::dbgs()};
+ llvm::ScopedPrinter logger{os};
#endif
/// The low-level pattern applicator.
@@ -871,7 +876,18 @@ LogicalResult RegionPatternRewriteDriver::simplify(bool *changed) && {
ctx->executeAction<GreedyPatternRewriteIteration>(
[&] {
- continueRewrites = processWorklist();
+ continueRewrites = false;
+
+          // Erase unreachable blocks. Operations like:
+          //   %add = arith.addi %add, %add : i64
+          // are legal in unreachable code, but many patterns are unsafe to
+          // apply to such IR and can lead to crashes or infinite loops.
+ continueRewrites |=
+ succeeded(eraseUnreachableBlocks(rewriter, region));
+
+ continueRewrites |= processWorklist();
// After applying patterns, make sure that the CFG of each of the
// regions is kept up to date.
@@ -917,10 +933,9 @@ mlir::applyPatternsGreedily(Region &region,
RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
region);
LogicalResult converged = std::move(driver).simplify(changed);
- LLVM_DEBUG(if (failed(converged)) {
- llvm::dbgs() << "The pattern rewrite did not converge after scanning "
- << config.getMaxIterations() << " times\n";
- });
+ if (failed(converged))
+ LDBG() << "The pattern rewrite did not converge after scanning "
+ << config.getMaxIterations() << " times";
return converged;
}
@@ -1052,9 +1067,8 @@ LogicalResult mlir::applyOpPatternsGreedily(
LogicalResult converged = std::move(driver).simplify(ops, changed);
if (allErased)
*allErased = surviving.empty();
- LLVM_DEBUG(if (failed(converged)) {
- llvm::dbgs() << "The pattern rewrite did not converge after "
- << config.getMaxNumRewrites() << " rewrites";
- });
+ if (failed(converged))
+ LDBG() << "The pattern rewrite did not converge after "
+ << config.getMaxNumRewrites() << " rewrites";
return converged;
}
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index eeb4052..5ea3105 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -13,6 +13,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/CallInterfaces.h"
@@ -182,6 +183,11 @@ static bool isLegalToInline(InlinerInterface &interface, Region *src,
IRMapping &valueMapping) {
for (auto &block : *src) {
for (auto &op : block) {
+ // UnrealizedConversionCastOp is inlineable but cannot implement the
+ // inliner interface due to layering constraints.
+ if (isa<UnrealizedConversionCastOp>(op))
+ continue;
+
// Check this operation.
if (!interface.isLegalToInline(&op, insertRegion,
shouldCloneInlinedRegion, valueMapping)) {
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index a1d975d..31ae1d1 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -23,12 +23,15 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include <deque>
#include <iterator>
using namespace mlir;
+#define DEBUG_TYPE "region-utils"
+
void mlir::replaceAllUsesInRegionWith(Value orig, Value replacement,
Region &region) {
for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
@@ -182,19 +185,34 @@ SmallVector<Value> mlir::makeRegionIsolatedFromAbove(
// TODO: We could likely merge this with the DCE algorithm below.
LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
MutableArrayRef<Region> regions) {
+ LDBG() << "Starting eraseUnreachableBlocks with " << regions.size()
+ << " regions";
+
// Set of blocks found to be reachable within a given region.
llvm::df_iterator_default_set<Block *, 16> reachable;
// If any blocks were found to be dead.
- bool erasedDeadBlocks = false;
+ int erasedDeadBlocks = 0;
SmallVector<Region *, 1> worklist;
worklist.reserve(regions.size());
for (Region &region : regions)
worklist.push_back(&region);
+
+ LDBG(2) << "Initial worklist size: " << worklist.size();
+
while (!worklist.empty()) {
Region *region = worklist.pop_back_val();
- if (region->empty())
+ if (region->empty()) {
+ LDBG(2) << "Skipping empty region";
continue;
+ }
+
+ LDBG(2) << "Processing region with " << region->getBlocks().size()
+ << " blocks";
+ if (region->getParentOp())
+ LDBG(2) << " -> for operation: "
+ << OpWithFlags(region->getParentOp(),
+ OpPrintingFlags().skipRegions());
// If this is a single block region, just collect the nested regions.
if (region->hasOneBlock()) {
@@ -209,13 +227,17 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
for (Block *block : depth_first_ext(&region->front(), reachable))
(void)block /* Mark all reachable blocks */;
+ LDBG(2) << "Found " << reachable.size() << " reachable blocks out of "
+ << region->getBlocks().size() << " total blocks";
+
// Collect all of the dead blocks and push the live regions onto the
// worklist.
for (Block &block : llvm::make_early_inc_range(*region)) {
if (!reachable.count(&block)) {
+ LDBG() << "Erasing unreachable block: " << &block;
block.dropAllDefinedValueUses();
rewriter.eraseBlock(&block);
- erasedDeadBlocks = true;
+ ++erasedDeadBlocks;
continue;
}
@@ -226,7 +248,10 @@ LogicalResult mlir::eraseUnreachableBlocks(RewriterBase &rewriter,
}
}
- return success(erasedDeadBlocks);
+ LDBG() << "Finished eraseUnreachableBlocks, erased " << erasedDeadBlocks
+ << " dead blocks";
+
+ return success(erasedDeadBlocks > 0);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
index ee5c642..1382550 100644
--- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp
@@ -13,18 +13,40 @@
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Rewrite/PatternApplicator.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#define DEBUG_TYPE "walk-rewriter"
namespace mlir {
+// Find all reachable blocks in the region and add them to the
+// `reachableBlocks` set.
+static void findReachableBlocks(Region &region,
+ DenseSet<Block *> &reachableBlocks) {
+ Block *entryBlock = &region.front();
+ reachableBlocks.insert(entryBlock);
+  // Traverse the CFG and add all reachable blocks to `reachableBlocks`.
+ SmallVector<Block *> worklist({entryBlock});
+ while (!worklist.empty()) {
+ Block *block = worklist.pop_back_val();
+ Operation *terminator = &block->back();
+ for (Block *successor : terminator->getSuccessors()) {
+ if (reachableBlocks.contains(successor))
+ continue;
+ worklist.push_back(successor);
+ reachableBlocks.insert(successor);
+ }
+ }
+}
+
namespace {
struct WalkAndApplyPatternsAction final
: tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -88,20 +110,104 @@ void walkAndApplyPatterns(Operation *op,
PatternApplicator applicator(patterns);
applicator.applyDefaultCostModel();
+  // Iterator over all reachable operations in the region. Also keeps track
+  // of whether the nested regions of the current op have already been
+  // visited, to drive the post-order traversal.
+ struct RegionReachableOpIterator {
+ RegionReachableOpIterator(Region *region) : region(region) {
+ regionIt = region->begin();
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ if (!llvm::hasSingleElement(*region))
+ findReachableBlocks(*region, reachableBlocks);
+ }
+ // Advance the iterator to the next reachable operation.
+ void advance() {
+ assert(regionIt != region->end());
+ hasVisitedRegions = false;
+ if (blockIt == regionIt->end()) {
+ ++regionIt;
+ while (regionIt != region->end() &&
+ !reachableBlocks.contains(&*regionIt))
+ ++regionIt;
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ return;
+ }
+ ++blockIt;
+ if (blockIt != regionIt->end()) {
+ LDBG() << "Incrementing block iterator, next op: "
+ << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+ }
+ }
+ // The region we're iterating over.
+ Region *region;
+ // The Block currently being iterated over.
+ Region::iterator regionIt;
+ // The Operation currently being iterated over.
+ Block::iterator blockIt;
+ // The set of blocks that are reachable in the current region.
+ DenseSet<Block *> reachableBlocks;
+ // Whether we've visited the nested regions of the current op already.
+ bool hasVisitedRegions = false;
+ };
+
+ // Worklist of regions to visit to drive the post-order traversal.
+ SmallVector<RegionReachableOpIterator> worklist;
+
+ LDBG() << "Starting walk-based pattern rewrite driver";
ctx->executeAction<WalkAndApplyPatternsAction>(
[&] {
+ // Perform a post-order traversal of the regions, visiting each
+ // reachable operation.
for (Region &region : op->getRegions()) {
- region.walk([&](Operation *visitedOp) {
- LLVM_DEBUG(llvm::dbgs() << "Visiting op: "; visitedOp->print(
- llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";);
+ assert(worklist.empty());
+ if (region.empty())
+ continue;
+
+ // Prime the worklist with the entry block of this region.
+ worklist.push_back({&region});
+ while (!worklist.empty()) {
+ RegionReachableOpIterator &it = worklist.back();
+ if (it.regionIt == it.region->end()) {
+ // We're done with this region.
+ worklist.pop_back();
+ continue;
+ }
+ if (it.blockIt == it.regionIt->end()) {
+ // We're done with this block.
+ it.advance();
+ continue;
+ }
+ Operation *op = &*it.blockIt;
+ // If we haven't visited the nested regions of this op yet,
+ // enqueue them.
+ if (!it.hasVisitedRegions) {
+ it.hasVisitedRegions = true;
+ for (Region &nestedRegion : llvm::reverse(op->getRegions())) {
+ if (nestedRegion.empty())
+ continue;
+ worklist.push_back({&nestedRegion});
+ }
+ }
+            // If we're not at the back of the worklist, we've enqueued some
+            // nested region for processing. We'll come back to this op later
+            // (post-order).
+ if (&it != &worklist.back())
+ continue;
+
+            // Preemptively advance the iterator, in case a pattern erases
+            // the current op.
+ it.advance();
+
+ LDBG() << "Visiting op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- erasedListener.visitedOp = visitedOp;
+ erasedListener.visitedOp = op;
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (succeeded(applicator.matchAndRewrite(visitedOp, rewriter))) {
- LLVM_DEBUG(llvm::dbgs() << "\tOp matched and rewritten\n";);
- }
- });
+ if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+ LDBG() << "\tOp matched and rewritten";
+ }
}
},
{op});
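Callers of the walk driver are unchanged by the traversal rework above; a minimal invocation sketch (the contents of the pattern set are assumed):

#include "mlir/IR/PatternMatch.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"

using namespace mlir;

// Sketch: freeze a pattern set and apply it with the walk driver; operations
// in unreachable blocks of `root` are now skipped by the traversal.
void runWalkDriver(Operation *root, RewritePatternSet &&patterns) {
  FrozenRewritePatternSet frozen(std::move(patterns));
  walkAndApplyPatterns(root, frozen);
}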