aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp8
-rw-r--r--mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp2
-rw-r--r--mlir/lib/Analysis/DataFlowFramework.cpp20
-rw-r--r--mlir/lib/AsmParser/DialectSymbolParser.cpp7
-rw-r--r--mlir/lib/AsmParser/Lexer.cpp27
-rw-r--r--mlir/lib/AsmParser/Lexer.h3
-rw-r--r--mlir/lib/CAPI/RegisterEverything/CMakeLists.txt11
-rw-r--r--mlir/lib/CMakeLists.txt34
-rw-r--r--mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp41
-rw-r--r--mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp20
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp78
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp36
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp69
-rw-r--r--mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp4
-rw-r--r--mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp64
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp16
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp55
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp94
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp2
-rw-r--r--mlir/lib/Dialect/GPU/IR/GPUDialect.cpp36
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp20
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp60
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp4
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp8
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp1
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp51
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp14
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp8
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp4
-rw-r--r--mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp17
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp12
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp3
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaOps.cpp122
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp14
-rw-r--r--mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp28
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp111
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp60
-rw-r--r--mlir/lib/IR/Diagnostics.cpp21
-rw-r--r--mlir/lib/IR/PatternMatch.cpp10
-rw-r--r--mlir/lib/Pass/Pass.cpp2
-rw-r--r--mlir/lib/RegisterAllDialects.cpp207
-rw-r--r--mlir/lib/RegisterAllExtensions.cpp115
-rw-r--r--mlir/lib/RegisterAllPasses.cpp99
-rw-r--r--mlir/lib/Support/ToolUtilities.cpp39
-rw-r--r--mlir/lib/Support/TypeID.cpp3
-rw-r--r--mlir/lib/Target/LLVMIR/CMakeLists.txt1
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt1
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp45
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt21
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp103
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.cpp13
-rw-r--r--mlir/lib/Tools/mlir-opt/MlirOptMain.cpp55
-rw-r--r--mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp7
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp51
-rw-r--r--mlir/lib/Transforms/Utils/Inliner.cpp33
55 files changed, 1347 insertions, 643 deletions
diff --git a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
index 51fa773..fb5649e 100644
--- a/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/ConstantPropagationAnalysis.cpp
@@ -16,6 +16,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include <cassert>
#define DEBUG_TYPE "constant-propagation"
@@ -46,7 +47,7 @@ void ConstantValue::print(raw_ostream &os) const {
LogicalResult SparseConstantPropagation::visitOperation(
Operation *op, ArrayRef<const Lattice<ConstantValue> *> operands,
ArrayRef<Lattice<ConstantValue> *> results) {
- LLVM_DEBUG(llvm::dbgs() << "SCP: Visiting operation: " << *op << "\n");
+ LDBG() << "SCP: Visiting operation: " << *op;
// Don't try to simulate the results of a region operation as we can't
// guarantee that folding will be out-of-place. We don't allow in-place
@@ -98,12 +99,11 @@ LogicalResult SparseConstantPropagation::visitOperation(
// Merge in the result of the fold, either a constant or a value.
OpFoldResult foldResult = std::get<1>(it);
if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(foldResult)) {
- LLVM_DEBUG(llvm::dbgs() << "Folded to constant: " << attr << "\n");
+ LDBG() << "Folded to constant: " << attr;
propagateIfChanged(lattice,
lattice->join(ConstantValue(attr, op->getDialect())));
} else {
- LLVM_DEBUG(llvm::dbgs()
- << "Folded to value: " << cast<Value>(foldResult) << "\n");
+ LDBG() << "Folded to value: " << cast<Value>(foldResult);
AbstractSparseForwardDataFlowAnalysis::join(
lattice, *getLatticeElement(cast<Value>(foldResult)));
}
diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
index 197f97f..509f520 100644
--- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp
@@ -294,7 +294,7 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
solver.load<LivenessAnalysis>(symbolTable);
LDBG() << "Initializing and running solver";
(void)solver.initializeAndRun(op);
- LDBG() << "Dumping liveness state for op";
+ LDBG() << "RunLivenessAnalysis initialized for op: " << op->getName();
}
const Liveness *RunLivenessAnalysis::getLiveness(Value val) {
diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp
index 176d53e..16f7033 100644
--- a/mlir/lib/Analysis/DataFlowFramework.cpp
+++ b/mlir/lib/Analysis/DataFlowFramework.cpp
@@ -14,7 +14,7 @@
#include "llvm/ADT/iterator.h"
#include "llvm/Config/abi-breaking.h"
#include "llvm/Support/Casting.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "dataflow"
@@ -44,9 +44,8 @@ void AnalysisState::addDependency(ProgramPoint *dependent,
(void)inserted;
DATAFLOW_DEBUG({
if (inserted) {
- llvm::dbgs() << "Creating dependency between " << debugName << " of "
- << anchor << "\nand " << debugName << " on " << dependent
- << "\n";
+ LDBG() << "Creating dependency between " << debugName << " of " << anchor
+ << "\nand " << debugName << " on " << dependent;
}
});
}
@@ -116,8 +115,7 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
// Initialize the analyses.
for (DataFlowAnalysis &analysis : llvm::make_pointee_range(childAnalyses)) {
- DATAFLOW_DEBUG(llvm::dbgs()
- << "Priming analysis: " << analysis.debugName << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Priming analysis: " << analysis.debugName);
if (failed(analysis.initialize(top)))
return failure();
}
@@ -129,8 +127,8 @@ LogicalResult DataFlowSolver::initializeAndRun(Operation *top) {
auto [point, analysis] = worklist.front();
worklist.pop();
- DATAFLOW_DEBUG(llvm::dbgs() << "Invoking '" << analysis->debugName
- << "' on: " << point << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Invoking '" << analysis->debugName
+ << "' on: " << point);
if (failed(analysis->visit(point)))
return failure();
}
@@ -143,9 +141,9 @@ void DataFlowSolver::propagateIfChanged(AnalysisState *state,
assert(isRunning &&
"DataFlowSolver is not running, should not use propagateIfChanged");
if (changed == ChangeResult::Change) {
- DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName
- << " of " << state->anchor << "\n"
- << "Value: " << *state << "\n");
+ DATAFLOW_DEBUG(LDBG() << "Propagating update to " << state->debugName
+ << " of " << state->anchor << "\n"
+ << "Value: " << *state);
state->onUpdate(this);
}
}
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 9f4a87a..8b14e71 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -89,6 +89,7 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
nestedPunctuation.pop_back();
return success();
};
+ const char *curBufferEnd = state.lex.getBufferEnd();
do {
// Handle code completions, which may appear in the middle of the symbol
// body.
@@ -98,6 +99,12 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
break;
}
+ if (curBufferEnd == curPtr) {
+ if (!nestedPunctuation.empty())
+ return emitPunctError();
+ return emitError("unexpected nul or EOF in pretty dialect name");
+ }
+
char c = *curPtr++;
switch (c) {
case '\0':
diff --git a/mlir/lib/AsmParser/Lexer.cpp b/mlir/lib/AsmParser/Lexer.cpp
index 751bd63..8f53529 100644
--- a/mlir/lib/AsmParser/Lexer.cpp
+++ b/mlir/lib/AsmParser/Lexer.cpp
@@ -37,6 +37,18 @@ Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context,
AsmParserCodeCompleteContext *codeCompleteContext)
: sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) {
auto bufferID = sourceMgr.getMainFileID();
+
+ // Check to see if the main buffer contains the last buffer, and if so the
+ // last buffer should be used as main file for parsing.
+ if (sourceMgr.getNumBuffers() > 1) {
+ unsigned lastFileID = sourceMgr.getNumBuffers();
+ const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID);
+ const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID);
+ if (main->getBufferStart() <= last->getBufferStart() &&
+ main->getBufferEnd() >= last->getBufferEnd()) {
+ bufferID = lastFileID;
+ }
+ }
curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
curPtr = curBuffer.begin();
@@ -71,6 +83,7 @@ Token Lexer::emitError(const char *loc, const Twine &message) {
}
Token Lexer::lexToken() {
+ const char *curBufferEnd = curBuffer.end();
while (true) {
const char *tokStart = curPtr;
@@ -78,6 +91,9 @@ Token Lexer::lexToken() {
if (tokStart == codeCompleteLoc)
return formToken(Token::code_complete, tokStart);
+ if (tokStart == curBufferEnd)
+ return formToken(Token::eof, tokStart);
+
// Lex the next token.
switch (*curPtr++) {
default:
@@ -102,7 +118,7 @@ Token Lexer::lexToken() {
case 0:
// This may either be a nul character in the source file or may be the EOF
// marker that llvm::MemoryBuffer guarantees will be there.
- if (curPtr - 1 == curBuffer.end())
+ if (curPtr - 1 == curBufferEnd)
return formToken(Token::eof, tokStart);
continue;
@@ -259,7 +275,11 @@ void Lexer::skipComment() {
assert(*curPtr == '/');
++curPtr;
+ const char *curBufferEnd = curBuffer.end();
while (true) {
+ if (curPtr == curBufferEnd)
+ return;
+
switch (*curPtr++) {
case '\n':
case '\r':
@@ -267,7 +287,7 @@ void Lexer::skipComment() {
return;
case 0:
// If this is the end of the buffer, end the comment.
- if (curPtr - 1 == curBuffer.end()) {
+ if (curPtr - 1 == curBufferEnd) {
--curPtr;
return;
}
@@ -405,6 +425,7 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
Token Lexer::lexString(const char *tokStart) {
assert(curPtr[-1] == '"');
+ const char *curBufferEnd = curBuffer.end();
while (true) {
// Check to see if there is a code completion location within the string. In
// these cases we generate a completion location and place the currently
@@ -419,7 +440,7 @@ Token Lexer::lexString(const char *tokStart) {
case 0:
// If this is a random nul character in the middle of a string, just
// include it. If it is the end of file, then it is an error.
- if (curPtr - 1 != curBuffer.end())
+ if (curPtr - 1 != curBufferEnd)
continue;
[[fallthrough]];
case '\n':
diff --git a/mlir/lib/AsmParser/Lexer.h b/mlir/lib/AsmParser/Lexer.h
index 4085a9b..670444e 100644
--- a/mlir/lib/AsmParser/Lexer.h
+++ b/mlir/lib/AsmParser/Lexer.h
@@ -40,6 +40,9 @@ public:
/// Returns the start of the buffer.
const char *getBufferBegin() { return curBuffer.data(); }
+ /// Returns the end of the buffer.
+ const char *getBufferEnd() { return curBuffer.end(); }
+
/// Return the code completion location of the lexer, or nullptr if there is
/// none.
const char *getCodeCompleteLoc() const { return codeCompleteLoc; }
diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
index 8b9a395..ccda668 100644
--- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
+++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
@@ -1,19 +1,16 @@
# Dialect registration.
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
-get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything
RegisterEverything.cpp
LINK_LIBS PUBLIC
- ${dialect_libs}
${translation_libs}
- ${conversion_libs}
- ${extension_libs}
MLIRBuiltinToLLVMIRTranslation
MLIRCAPIIR
- MLIRLLVMToLLVMIRTranslation
MLIRCAPITransforms
+ MLIRLLVMToLLVMIRTranslation
+ MLIRRegisterAllDialects
+ MLIRRegisterAllExtensions
+ MLIRRegisterAllPasses
)
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index d25c84a..191b5ab6 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -20,3 +20,37 @@ add_subdirectory(Target)
add_subdirectory(Tools)
add_subdirectory(Transforms)
add_subdirectory(ExecutionEngine)
+
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
+
+add_mlir_library(MLIRRegisterAllDialects
+ RegisterAllDialects.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs}
+ )
+
+add_mlir_library(MLIRRegisterAllPasses
+ RegisterAllPasses.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs} # Some passes are part of the dialect libs
+ ${conversion_libs}
+ )
+
+add_mlir_library(MLIRRegisterAllExtensions
+ RegisterAllExtensions.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs}
+ ${conversion_libs}
+ ${extension_libs}
+ )
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 6f0fc29..35ad99c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
patterns.getContext(), "__ocml_cabs_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
+ patterns.getContext(), "__ocml_carg_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
+ patterns.getContext(), "__ocml_carg_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
+ patterns.getContext(), "__ocml_conj_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
+ patterns.getContext(), "__ocml_conj_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ccos_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ccos_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
patterns.getContext(), "__ocml_cexp_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
patterns.getContext(), "__ocml_cexp_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
+ patterns.getContext(), "__ocml_clog_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
+ patterns.getContext(), "__ocml_clog_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
+ patterns.getContext(), "__ocml_cpow_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
+ patterns.getContext(), "__ocml_cpow_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csin_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csin_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csqrt_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csqrt_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctan_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctan_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctanh_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctanh_f64");
}
namespace {
@@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
+ target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
+ complex::CosOp, complex::ExpOp, complex::LogOp,
+ complex::PowOp, complex::SinOp, complex::SqrtOp,
+ complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 75e6563..1817861 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -507,25 +507,27 @@ LogicalResult GPURotateConversion::matchAndRewrite(
getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
unsigned subgroupSize =
targetEnv.getAttr().getResourceLimits().getSubgroupSize();
- IntegerAttr widthAttr;
- if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
- widthAttr.getValue().getZExtValue() > subgroupSize)
+ unsigned width = rotateOp.getWidth();
+ if (width > subgroupSize)
return rewriter.notifyMatchFailure(
- rotateOp,
- "rotate width is not a constant or larger than target subgroup size");
+ rotateOp, "rotate width is larger than target subgroup size");
Location loc = rotateOp.getLoc();
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+ Value offsetVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
+ Value widthVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
- rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(),
- adaptor.getWidth());
+ rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
Value validVal;
- if (widthAttr.getValue().getZExtValue() == subgroupSize) {
+ if (width == subgroupSize) {
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
} else {
+ IntegerAttr widthAttr = adaptor.getWidthAttr();
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
- laneId, adaptor.getWidth());
+ laneId, widthVal);
}
rewriter.replaceOp(rotateOp, {rotateResult, validVal});
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index e882845..6bd0e2d 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -19,10 +19,18 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <cstdint>
using namespace mlir;
+static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
+ return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
+ memRefType.getRank() != 0 &&
+ !llvm::is_contained(memRefType.getShape(), 0);
+}
+
namespace {
/// Implement the interface to convert MemRef to EmitC.
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
@@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
+struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = allocOp.getLoc();
+ MemRefType memrefType = allocOp.getType();
+ if (!isMemRefTypeLegalForEmitC(memrefType)) {
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible memref type for EmitC conversion");
+ }
+
+ Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
+ Type elementType = memrefType.getElementType();
+ IndexType indexType = rewriter.getIndexType();
+ emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
+ loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
+ ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
+
+ int64_t numElements = 1;
+ for (int64_t dimSize : memrefType.getShape()) {
+ numElements *= dimSize;
+ }
+ Value numElementsValue = rewriter.create<emitc::ConstantOp>(
+ loc, indexType, rewriter.getIndexAttr(numElements));
+
+ Value totalSizeBytes = rewriter.create<emitc::MulOp>(
+ loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
+
+ emitc::CallOpaqueOp allocCall;
+ StringAttr allocFunctionName;
+ Value alignmentValue;
+ SmallVector<Value, 2> argsVec;
+ if (allocOp.getAlignment()) {
+ allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
+ alignmentValue = rewriter.create<emitc::ConstantOp>(
+ loc, sizeTType,
+ rewriter.getIntegerAttr(indexType,
+ allocOp.getAlignment().value_or(0)));
+ argsVec.push_back(alignmentValue);
+ } else {
+ allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
+ }
+
+ argsVec.push_back(totalSizeBytes);
+ ValueRange args(argsVec);
+
+ allocCall = rewriter.create<emitc::CallOpaqueOp>(
+ loc,
+ emitc::PointerType::get(
+ emitc::OpaqueType::get(rewriter.getContext(), "void")),
+ allocFunctionName, args);
+
+ emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
+ emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
+ loc, targetPointerType, allocCall.getResult(0));
+
+ rewriter.replaceOp(allocOp, castOp);
+ return success();
+ }
+};
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion(
[&](MemRefType memRefType) -> std::optional<Type> {
- if (!memRefType.hasStaticShape() ||
- !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
- llvm::is_contained(memRefType.getShape(), 0)) {
+ if (!isMemRefTypeLegalForEmitC(memRefType)) {
return {};
}
Type convertedElementType =
@@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
- ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
+ ConvertLoad, ConvertStore>(converter, patterns.getContext());
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index cf25c09..e78dd76 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -15,6 +15,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -28,9 +29,11 @@ using namespace mlir;
namespace {
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+ using Base::Base;
void runOnOperation() override {
TypeConverter converter;
-
+ ConvertMemRefToEmitCOptions options;
+ options.lowerToCpp = this->lowerToCpp;
// Fallback for other types.
converter.addConversion([](Type type) -> std::optional<Type> {
if (!emitc::isSupportedEmitCType(type))
@@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
+
+ mlir::ModuleOp module = getOperation();
+ module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
+ if (callOp.getCallee() != alignedAllocFunctionName &&
+ callOp.getCallee() != mallocFunctionName) {
+ return mlir::WalkResult::advance();
+ }
+
+ for (auto &op : *module.getBody()) {
+ emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
+ if (!includeOp) {
+ continue;
+ }
+ if (includeOp.getIsStandardInclude() &&
+ ((options.lowerToCpp &&
+ includeOp.getInclude() == cppStandardLibraryHeader) ||
+ (!options.lowerToCpp &&
+ includeOp.getInclude() == cStandardLibraryHeader))) {
+ return mlir::WalkResult::interrupt();
+ }
+ }
+
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ StringAttr includeAttr =
+ builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader);
+ builder.create<mlir::emitc::IncludeOp>(
+ module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+ return mlir::WalkResult::interrupt();
+ });
}
};
} // namespace
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 4307bc6..17a79e3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1070,39 +1070,6 @@ public:
}
};
-class VectorExtractElementOpConversion
- : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
-public:
- using ConvertOpToLLVMPattern<
- vector::ExtractElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = extractEltOp.getSourceVectorType();
- auto llvmType = typeConverter->convertType(vectorType.getElementType());
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = extractEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
@@ -1206,39 +1173,6 @@ public:
}
};
-class VectorInsertElementOpConversion
- : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
-public:
- using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = insertEltOp.getDestVectorType();
- auto llvmType = typeConverter->convertType(vectorType);
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = insertEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
@@ -2244,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion, VectorScatterOpConversion>(
converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
- VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorExtractOpConversion, VectorFMAOp1DConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index b1af5f0..508f4e2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -690,7 +690,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
-/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
@@ -1643,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
-/// %t = vector.extractelement %vec[i] : vector<9xf32>
+/// %t = vector.extract %vec[i] : f32 from vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 986eae3..a4be7d4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -335,63 +335,6 @@ struct VectorInsertOpConvert final
}
};
-struct VectorExtractElementOpConvert final
- : public OpConversionPattern<vector::ExtractElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type resultType = getTypeConverter()->convertType(extractOp.getType());
- if (!resultType)
- return failure();
-
- if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, resultType, adaptor.getVector(),
- rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
- else
- rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
-struct VectorInsertElementOpConvert final
- : public OpConversionPattern<vector::InsertElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type vectorType = getTypeConverter()->convertType(insertOp.getType());
- if (!vectorType)
- return failure();
-
- if (isa<spirv::ScalarType>(vectorType)) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(),
- cstPos.getSExtValue());
- else
- rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
struct VectorInsertStridedSliceOpConvert final
: public OpConversionPattern<vector::InsertStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
@@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final
void mlir::populateVectorToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<
- VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
- VectorToElementOpConvert, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorToElementOpConvert, VectorInsertOpConvert,
+ VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
diff --git a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
index 3c00b32..6265f46 100644
--- a/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/DecomposeAffineOps.cpp
@@ -15,13 +15,13 @@
#include "mlir/Dialect/Affine/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/InterleavedRange.h"
using namespace mlir;
using namespace mlir::affine;
#define DEBUG_TYPE "decompose-affine-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
/// Count the number of loops surrounding `operand` such that operand could be
/// hoisted above.
@@ -115,7 +115,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
return rewriter.notifyMatchFailure(
op, "only add or mul binary expr can be reassociated");
- LLVM_DEBUG(DBGS() << "Start decomposeIntoFinerGrainedOps: " << op << "\n");
+ LDBG() << "Start decomposeIntoFinerGrainedOps: " << op;
// 2. Iteratively extract the RHS subexpressions while the top-level binary
// expr kind remains the same.
@@ -125,11 +125,11 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
auto currentBinExpr = dyn_cast<AffineBinaryOpExpr>(remainingExp);
if (!currentBinExpr || currentBinExpr.getKind() != binExpr.getKind()) {
subExpressions.push_back(remainingExp);
- LLVM_DEBUG(DBGS() << "--terminal: " << subExpressions.back() << "\n");
+ LDBG() << "--terminal: " << subExpressions.back();
break;
}
subExpressions.push_back(currentBinExpr.getRHS());
- LLVM_DEBUG(DBGS() << "--subExpr: " << subExpressions.back() << "\n");
+ LDBG() << "--subExpr: " << subExpressions.back();
remainingExp = currentBinExpr.getLHS();
}
@@ -146,9 +146,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
llvm::stable_sort(subExpressions, [&](AffineExpr e1, AffineExpr e2) {
return getMaxSymbol(e1) < getMaxSymbol(e2);
});
- LLVM_DEBUG(
- llvm::interleaveComma(subExpressions, DBGS() << "--sorted subexprs: ");
- llvm::dbgs() << "\n");
+ LDBG() << "--sorted subexprs: " << llvm::interleaved(subExpressions);
// 4. Merge sorted subExpressions iteratively, thus achieving reassociation.
auto s0 = getAffineSymbolExpr(0, ctx);
@@ -162,7 +160,7 @@ FailureOr<AffineApplyOp> mlir::affine::decompose(RewriterBase &rewriter,
Value tmp = createSubApply(rewriter, op, subExpressions[i]);
current = AffineApplyOp::create(rewriter, op.getLoc(), binMap,
ValueRange{current, tmp});
- LLVM_DEBUG(DBGS() << "--reassociate into: " << current << "\n");
+ LDBG() << "--reassociate into: " << current;
}
// 5. Replace original op.
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
index 8493b60..2521512 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineMinMax.cpp
@@ -19,11 +19,10 @@
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/IntEqClasses.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "affine-min-max"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
using namespace mlir::affine;
@@ -39,7 +38,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
ValueRange operands = affineOp.getOperands();
static constexpr bool isMin = std::is_same_v<AffineOp, AffineMinOp>;
- LLVM_DEBUG({ DBGS() << "analyzing value: `" << affineOp << "`\n"; });
+ LDBG() << "analyzing value: `" << affineOp;
// Create a `Variable` list with values corresponding to each of the results
// in the affine affineMap.
@@ -48,12 +47,9 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
[&](unsigned i) {
return Variable(affineMap.getSliceMap(i, 1), operands);
});
- LLVM_DEBUG({
- DBGS() << "- constructed variables are: "
- << llvm::interleaved_array(llvm::map_range(
- variables, [](const Variable &v) { return v.getMap(); }))
- << "`\n";
- });
+ LDBG() << "- constructed variables are: "
+ << llvm::interleaved_array(llvm::map_range(
+ variables, [](const Variable &v) { return v.getMap(); }));
// Get the comparison operation.
ComparisonOperator cmpOp =
@@ -72,10 +68,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Initialize the bound.
Variable *bound = &v;
- LLVM_DEBUG({
- DBGS() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
- << "`\n";
- });
+ LDBG() << "- inspecting variable: #" << i << ", with map: `" << v.getMap()
+ << "`\n";
// Check against the other variables.
for (size_t j = i + 1; j < variables.size(); ++j) {
@@ -87,10 +81,8 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Get the bound of the equivalence class or itself.
Variable *nv = bounds.lookup_or(jEqClass, &variables[j]);
- LLVM_DEBUG({
- DBGS() << "- comparing with variable: #" << jEqClass
- << ", with map: " << nv->getMap() << "\n";
- });
+ LDBG() << "- comparing with variable: #" << jEqClass
+ << ", with map: " << nv->getMap();
// Compare the variables.
FailureOr<bool> cmpResult =
@@ -98,18 +90,14 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// The variables cannot be compared.
if (failed(cmpResult)) {
- LLVM_DEBUG({
- DBGS() << "-- classes: #" << i << ", #" << jEqClass
- << " cannot be merged\n";
- });
+ LDBG() << "-- classes: #" << i << ", #" << jEqClass
+ << " cannot be merged";
continue;
}
// Join the equivalent classes and update the bound if necessary.
- LLVM_DEBUG({
- DBGS() << "-- merging classes: #" << i << ", #" << jEqClass
- << ", is cmp(lhs, rhs): " << *cmpResult << "`\n";
- });
+ LDBG() << "-- merging classes: #" << i << ", #" << jEqClass
+ << ", is cmp(lhs, rhs): " << *cmpResult << "`";
if (*cmpResult) {
boundedClasses.join(eqClass, jEqClass);
} else {
@@ -124,8 +112,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Return if there's no simplification.
if (bounds.size() >= affineMap.getNumResults()) {
- LLVM_DEBUG(
- { DBGS() << "- the affine operation couldn't get simplified\n"; });
+ LDBG() << "- the affine operation couldn't get simplified";
return false;
}
@@ -135,13 +122,11 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
for (auto [k, bound] : bounds)
results.push_back(bound->getMap().getResult(0));
- LLVM_DEBUG({
- DBGS() << "- starting from map: " << affineMap << "\n";
- DBGS() << "- creating new map with: \n";
- DBGS() << "--- dims: " << affineMap.getNumDims() << "\n";
- DBGS() << "--- syms: " << affineMap.getNumSymbols() << "\n";
- DBGS() << "--- res: " << llvm::interleaved_array(results) << "\n";
- });
+ LDBG() << "- starting from map: " << affineMap;
+ LDBG() << "- creating new map with:";
+ LDBG() << "--- dims: " << affineMap.getNumDims();
+ LDBG() << "--- syms: " << affineMap.getNumSymbols();
+ LDBG() << "--- res: " << llvm::interleaved_array(results);
affineMap =
AffineMap::get(0, affineMap.getNumSymbols() + affineMap.getNumDims(),
@@ -149,7 +134,7 @@ static bool simplifyAffineMinMaxOp(RewriterBase &rewriter, AffineOp affineOp) {
// Update the affine op.
rewriter.modifyOpInPlace(affineOp, [&]() { affineOp.setMap(affineMap); });
- LLVM_DEBUG({ DBGS() << "- simplified affine op: `" << affineOp << "`\n"; });
+ LDBG() << "- simplified affine op: `" << affineOp << "`";
return true;
}
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d1d1062..aa53f94 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -1,4 +1,5 @@
-//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
+//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries
+//----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,12 +9,13 @@
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
-// implementations for FuncOp, CallOp and ReturnOp.
+// implementations for FuncOp, CallOp and ReturnOp. Although it is named
+// Module Bufferization, it may operate on any SymbolTable.
//
-// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
-// This function analyzes the given module and determines the order of analysis
-// and bufferization: Functions that are called are processed before their
-// respective callers.
+// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
+// ...)`. This function analyzes the given op and determines the order of
+// analysis and bufferization: Functions that are called are processed before
+// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
@@ -309,7 +311,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
- ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
SymbolTableCollection &symbolTables) {
// For each FuncOp, the set of functions called by it (i.e. the union of
@@ -317,26 +319,29 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-
- for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
- // Collect function calls and populate the caller map.
- numberCallOpsContainedInFuncOp[funcOp] = 0;
- WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
- assert(calledFunction && "could not retrieved called func::FuncOp");
- // If the called function does not have any tensors in its signature, then
- // it is not necessary to bufferize the callee before the caller.
- if (!hasTensorSignature(calledFunction))
- return WalkResult::skip();
-
- callerMap[calledFunction].insert(callOp);
- if (calledBy[calledFunction].insert(funcOp).second) {
- numberCallOpsContainedInFuncOp[funcOp]++;
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ // Collect function calls and populate the caller map.
+ numberCallOpsContainedInFuncOp[funcOp] = 0;
+ WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+ func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
+ assert(calledFunction && "could not retrieved called func::FuncOp");
+ // If the called function does not have any tensors in its signature,
+ // then it is not necessary to bufferize the callee before the caller.
+ if (!hasTensorSignature(calledFunction))
+ return WalkResult::skip();
+
+ callerMap[calledFunction].insert(callOp);
+ if (calledBy[calledFunction].insert(funcOp).second) {
+ numberCallOpsContainedInFuncOp[funcOp]++;
+ }
+ return WalkResult::advance();
+ });
+ if (res.wasInterrupted())
+ return failure();
}
- return WalkResult::advance();
- });
- if (res.wasInterrupted())
- return failure();
+ }
}
// Iteratively remove function operations that do not call any of the
@@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}
LogicalResult
-mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
+mlir::bufferization::analyzeModuleOp(Operation *moduleOp,
OneShotAnalysisState &state,
BufferizationStatistics *statistics) {
assert(state.getOptions().bufferizeFunctionBoundaries &&
@@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
}
void mlir::bufferization::removeBufferizationAttributesInModule(
- ModuleOp moduleOp) {
- for (auto op : moduleOp.getOps<func::FuncOp>()) {
- for (BlockArgument bbArg : op.getArguments())
- removeBufferizationAttributes(bbArg);
+ Operation *moduleOp) {
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ for (BlockArgument bbArg : funcOp.getArguments())
+ removeBufferizationAttributes(bbArg);
+ }
+ }
}
}
LogicalResult mlir::bufferization::bufferizeModuleOp(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
- IRRewriter rewriter(moduleOp.getContext());
+ IRRewriter rewriter(moduleOp->getContext());
// A list of non-circular functions in the order in which they are analyzed
// and bufferized.
@@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
// Bufferize all other ops.
- for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
- // Functions were already bufferized.
- if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
- continue;
- if (failed(bufferizeOp(&op, options, state, statistics)))
- return failure();
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (mlir::Operation &op :
+ llvm::make_early_inc_range(block.getOperations())) {
+ // Functions were already bufferized.
+ if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
+ continue;
+ if (failed(bufferizeOp(&op, options, state, statistics)))
+ return failure();
+ }
+ }
}
// Post-pass cleanup of function argument attributes.
@@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index f999c93..a6159ee 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -33,7 +33,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
// analysis depending on whether function boundary bufferization is enabled or
// not.
if (options.bufferizeFunctionBoundaries) {
- if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
+ if (failed(analyzeModuleOp(op, analysisState, statistics)))
return failure();
} else {
if (failed(analyzeOp(op, analysisState, statistics)))
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d186a48..5a72ef1 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1395,40 +1395,12 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
// RotateOp
//===----------------------------------------------------------------------===//
-void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
- int32_t offset, int32_t width) {
- build(builder, result, value,
- arith::ConstantOp::create(builder, result.location,
- builder.getI32IntegerAttr(offset)),
- arith::ConstantOp::create(builder, result.location,
- builder.getI32IntegerAttr(width)));
-}
-
LogicalResult RotateOp::verify() {
- auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>();
- if (!offsetConstOp)
- return emitOpError() << "offset is not a constant value";
-
- auto offsetIntAttr =
- llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue());
-
- auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>();
- if (!widthConstOp)
- return emitOpError() << "width is not a constant value";
-
- auto widthIntAttr =
- llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue());
-
- llvm::APInt offsetValue = offsetIntAttr.getValue();
- llvm::APInt widthValue = widthIntAttr.getValue();
-
- if (!widthValue.isPowerOf2())
- return emitOpError() << "width must be a power of two";
+ uint32_t offset = getOffset();
+ uint32_t width = getWidth();
- if (offsetValue.sge(widthValue) || offsetValue.slt(0)) {
- int64_t widthValueInt = widthValue.getSExtValue();
- return emitOpError() << "offset must be in the range [0, " << widthValueInt
- << ")";
+ if (offset >= width) {
+ return emitOpError() << "offset must be in the range [0, " << width << ")";
}
return success();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f49d9a1..73ae029 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -476,10 +476,10 @@ inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
SmallVector<unsigned, 2>(ac.begin(), ac.end()),
SmallVector<unsigned, 2>(bc.begin(), bc.end()),
SmallVector<unsigned, 2>(ra.begin(), ra.end())};
- llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
- llvm::sort(dimensions.m.begin(), dimensions.m.end());
- llvm::sort(dimensions.n.begin(), dimensions.n.end());
- llvm::sort(dimensions.k.begin(), dimensions.k.end());
+ llvm::sort(dimensions.batch);
+ llvm::sort(dimensions.m);
+ llvm::sort(dimensions.n);
+ llvm::sort(dimensions.k);
return dimensions;
}
@@ -797,12 +797,12 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
SmallVector<unsigned, 2>(depth.begin(), depth.end()),
/*strides=*/SmallVector<int64_t, 2>{},
/*dilations=*/SmallVector<int64_t, 2>{}};
- llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
- llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
- llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
- llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
- llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
- llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
+ llvm::sort(dimensions.batch);
+ llvm::sort(dimensions.outputImage);
+ llvm::sort(dimensions.outputChannel);
+ llvm::sort(dimensions.filterLoop);
+ llvm::sort(dimensions.inputChannel);
+ llvm::sort(dimensions.depth);
// Use the op carried strides/dilations attribute if present.
auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27b6617..34c63d3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -32,6 +32,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -2292,9 +2293,39 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+/// Fold back-to-back broadcasts together.
+struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
+ using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
+ if (!defBroadcastOp)
+ return failure();
+ ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
+ ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+ SmallVector<int64_t> foldedDims(dimensions);
+ Value init = broadcastOp.getInit();
+ int64_t initRank = cast<ShapedType>(init.getType()).getRank();
+ // Mapping from input dims to init dims.
+ SmallVector<int64_t> dimMap;
+ for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+ if (!llvm::is_contained(dimensions, dim))
+ dimMap.push_back(dim);
+ }
+ for (auto dim : defDimensions)
+ foldedDims.push_back(dimMap[dim]);
+
+ llvm::sort(foldedDims);
+ rewriter.replaceOpWithNewOp<BroadcastOp>(
+ broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
+ return success();
+ }
+};
+
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+ results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
}
//===----------------------------------------------------------------------===//
@@ -4622,22 +4653,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
});
}
-/// Returns true if the dimension of `sourceShape` is smaller than the dimension
-/// of the `limitShape`.
-static bool areAllInBound(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> limitShape) {
- assert(
- sourceShape.size() == limitShape.size() &&
- "expected source shape rank, and limit of the shape to have same rank");
- return llvm::all_of(
- llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
- int64_t sourceExtent = std::get<0>(it);
- int64_t limit = std::get<1>(it);
- return ShapedType::isDynamic(sourceExtent) ||
- ShapedType::isDynamic(limit) || sourceExtent <= limit;
- });
-}
-
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4696,11 +4711,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
- if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
- return op->emitError("the shape of output is not large enough to hold the "
- "packed data. Expected at least ")
- << expectedPackedType << ", got " << packedType;
- }
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
@@ -4717,6 +4727,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
+ if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
+ packedType.getShape()))) {
+ return op->emitError("expected ")
+ << expectedPackedType << " for the packed domain value, got "
+ << packedType;
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
index c926dfb..5c8c2de 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/GPUHeuristics.cpp
@@ -11,6 +11,7 @@
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"
#include "llvm/Support/InterleavedRange.h"
#include "llvm/Support/MathExtras.h"
@@ -21,7 +22,6 @@
using namespace mlir;
#define DEBUG_TYPE "linalg-transforms"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
static Attribute linearId0(MLIRContext *ctx) {
return gpu::GPUThreadMappingAttr::get(ctx, gpu::MappingId::LinearDim0);
@@ -81,7 +81,7 @@ transform::gpu::CopyMappingInfo::CopyMappingInfo(MLIRContext *ctx,
this->threadMapping =
llvm::to_vector(ArrayRef(allThreadMappings)
.take_back(this->smallestBoundingTileSizes.size()));
- LLVM_DEBUG(this->print(DBGS()); llvm::dbgs() << "\n");
+ LDBG() << *this;
}
int64_t transform::gpu::CopyMappingInfo::maxContiguousElementsToTransfer(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
index 2fe72a3..d4a3e5f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
@@ -15,14 +15,13 @@
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InterleavedRange.h"
using namespace mlir;
#define DEBUG_TYPE "linalg-transforms"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
//===----------------------------------------------------------------------===//
// StructuredMatchOp
@@ -39,7 +38,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
return emitSilenceableError() << "expected a Linalg op";
}
// If errors are suppressed, succeed and set all results to empty lists.
- LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op");
+ LDBG() << "optional nested matcher expected a Linalg op";
results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
return DiagnosedSilenceableFailure::success();
}
@@ -75,8 +74,7 @@ DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
// When they are defined in this block, we additionally check if we have
// already applied the operation that defines them. If not, the
// corresponding results will be set to empty lists.
- LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
- << "\n");
+ LDBG() << "optional nested matcher failed: " << diag.getMessage();
(void)diag.silence();
SmallVector<OpOperand *> undefinedOperands;
for (OpOperand &terminatorOperand :
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 277e50b..9d7f4e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2c62cb6..2e62523 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
return paddingSizes;
}
+/// Extracts the constant multiplier from an affine expression of the form
+/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an
+/// AffineConstantExpr. Returns 1 if the expression is not a simple
+/// multiplication of a dimension and a constant.
+static int64_t extractConstantMultiplier(AffineExpr expr) {
+ if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
+ if (binOp.getKind() == AffineExprKind::Mul) {
+ auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
+ auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
+ if (lhsD && rhsC) {
+ return rhsC.getValue();
+ }
+ auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
+ auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
+ if (lhsC && rhsD) {
+ return lhsC.getValue();
+ }
+ }
+ }
+ return 1;
+}
+
/// Compute the padded shape of the given value `v` of `RankedTensorType` given
/// - `indexingSizes` a list of OpFoldResult.
/// - an `indexingMap` that encodes how the shape of varies with increases
@@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
/// The implementaiton below iteratively combines increases from contributing
/// dimensions using affine.apply operations.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
+/// permutation indexing maps are supported, such as
+/// - affine_map<(d0, d1, d2) -> (d0, d1)>
+/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult> linalg::computePaddedShape(
@@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
/*compressDims=*/true);
// If we are padding to the next multiple of, compose with ceil(sz) * sz.
+ OpFoldResult paddingDimOfr;
if (options.padToMultipleOf) {
AffineExpr d0, s0;
bindDims(rewriter.getContext(), d0);
bindSymbols(rewriter.getContext(), s0);
AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
AffineMap composedMap = projectedMap.compose(ceilMap);
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, composedMap,
{indexingSizes[paddingDim], paddingSize},
/*composeAffineMin=*/true);
- terms.push_back(paddingDimOfr);
} else {
// Otherwise just set to paddingSize.
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, projectedMap, paddingSize);
- terms.push_back(paddingDimOfr);
}
+ // Adjust for the maximum accessed index, which is (paddingSize - 1) *
+ // multiplier.
+ AffineExpr d0;
+ bindDims(rewriter.getContext(), d0);
+ int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
+ AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
+ OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, subtractMap, {paddingDimOfr});
+ terms.push_back(maxAccessIdx);
+
LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
}
@@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
AffineExpr sumExpr = dims.front();
for (unsigned i = 1; i < dims.size(); ++i)
sumExpr = sumExpr + dims[i];
- OpFoldResult paddedDimOfr =
- affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
+ // Add 1 to the maximum accessed index and get the final padded size.
+ OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, sumExpr + 1, terms);
paddedShape[resultIndex] = paddedDimOfr;
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index dad3526..57b610b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -932,20 +932,6 @@ struct PackOpTiling
continue;
}
- // If the dimension needs padding, it is not supported because there are
- // iterations that only write padding values to the whole tile. The
- // consumer fusion is driven by the source, so it is not possible to map
- // an empty slice to the tile.
- bool needExtraPadding =
- ShapedType::isDynamic(destDimSize) || !cstInnerSize ||
- destDimSize * cstInnerSize.value() != srcDimSize;
- // Prioritize the case that the op already says that it does not need
- // padding.
- if (!packOp.getPaddingValue())
- needExtraPadding = false;
- if (needExtraPadding)
- return failure();
-
// Currently fusing `packOp` as consumer only expects perfect tiling
// scenario because even if without padding semantic, the `packOp` may
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0170837..793eec7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1913,14 +1913,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
sourceShape.end());
- ReifiedRankedShapedTypeDims reifiedRetShapes;
- LogicalResult status =
- cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
- .reifyResultShapes(rewriter, reifiedRetShapes);
- if (status.failed()) {
- LDBG() << "Unable to reify result shapes of " << unpackOp;
- return failure();
- }
Location loc = unpackOp->getLoc();
auto padValue = arith::ConstantOp::create(
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 106c3b4..cce80db 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -80,10 +80,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
sourceOp.getMixedStrides(), op.getMixedSizes())) {
- // We only support static sizes.
- if (isa<Value>(opSize)) {
- return failure();
- }
sizes.push_back(opSize);
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
sourceOffsetAttr =
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index e73bdd3..9d5dfc1 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -2957,6 +2957,23 @@ bool acc::LoopOp::hasDefaultGangWorkerVector() {
getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
}
+acc::LoopParMode
+acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
+ if (hasSeq(deviceType))
+ return LoopParMode::loop_seq;
+ if (hasAuto(deviceType))
+ return LoopParMode::loop_auto;
+ if (hasIndependent(deviceType))
+ return LoopParMode::loop_independent;
+ if (hasSeq())
+ return LoopParMode::loop_seq;
+ if (hasAuto())
+ return LoopParMode::loop_auto;
+ assert(hasIndependent() &&
+ "loop must have default auto, seq, or independent");
+ return LoopParMode::loop_independent;
+}
+
void acc::LoopOp::addGangOperands(
MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 81365b4..3911ec0 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -58,7 +58,17 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
}
auto varPtrType = cast<spirv::PointerType>(varType);
- auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
+ Type pointeeType = varPtrType.getPointeeType();
+
+ // Images are an opaque type and so we can just return a pointer to an image.
+ // Note that currently only sampled images are supported in the SPIR-V
+ // lowering.
+ if (isa<spirv::SampledImageType>(pointeeType))
+ return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
+ varName, abiInfo.getDescriptorSet(),
+ abiInfo.getBinding());
+
+ auto varPointeeType = cast<spirv::StructType>(pointeeType);
// Set the offset information.
varPointeeType =
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 0e96b59..869d27a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -115,8 +115,7 @@ public:
bufferization::BufferizationState bufferizationState;
- if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
- updatedOptions,
+ if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions,
bufferizationState)))
return failure();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index ecd93ff..3cafb19 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3647,6 +3647,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
return std::nullopt;
}
+static void printInitializationList(OpAsmPrinter &parser,
+ Block::BlockArgListType blocksArgs,
+ ValueRange initializers,
+ StringRef prefix = "") {
+ assert(blocksArgs.size() == initializers.size() &&
+ "expected same length of arguments and initializers");
+ if (initializers.empty())
+ return;
+
+ parser << prefix << '(';
+ llvm::interleaveComma(
+ llvm::zip(blocksArgs, initializers), parser,
+ [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
+ parser << ")";
+}
+
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions for 'then'.
@@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
- auto &builder = parser.getBuilder();
OpAsmParser::UnresolvedOperand cond;
- // Create a i1 tensor type for the boolean condition.
- Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
- if (parser.parseOperand(cond) ||
- parser.resolveOperand(cond, i1Type, result.operands))
+
+ if (parser.parseOperand(cond))
return failure();
- // Parse optional results type list.
- if (parser.parseOptionalArrowTypeList(result.types))
+
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+
+ // Parse the optional block arguments
+ OptionalParseResult listResult =
+ parser.parseOptionalAssignmentList(regionArgs, operands);
+ if (listResult.has_value() && failed(listResult.value()))
return failure();
+
+ // Parse a colon.
+ if (failed(parser.parseColon()))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Parse the type of the condition operand
+ Type condType;
+ if (failed(parser.parseType(condType)))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Resolve operand with provided type
+ if (failed(parser.resolveOperand(cond, condType, result.operands)))
+ return failure();
+
+ // Parse optional block arg types
+ if (listResult.has_value()) {
+ FunctionType functionType;
+
+ if (failed(parser.parseType(functionType)))
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected list of types for block arguments "
+ << "followed by arrow type and list of return types";
+
+ result.addTypes(functionType.getResults());
+
+ if (functionType.getNumInputs() != operands.size()) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected as many input types as operands "
+ << "(expected " << operands.size() << " got "
+ << functionType.getNumInputs() << ")";
+ }
+
+ // Resolve input operands.
+ if (failed(parser.resolveOperands(operands, functionType.getInputs(),
+ parser.getCurrentLocation(),
+ result.operands)))
+ return failure();
+ } else {
+ // Parse optional results type list.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+ }
+
// Parse the 'then' region.
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
@@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}
void IfOp::print(OpAsmPrinter &p) {
- bool printBlockTerminators = false;
-
p << " " << getCondition();
- if (!getResults().empty()) {
- p << " -> (" << getResultTypes() << ")";
- // Print yield explicitly if the op defines values.
- printBlockTerminators = true;
+
+ printInitializationList(p, getThenGraph().front().getArguments(),
+ getInputList(), " ");
+ p << " : ";
+ p << getCondition().getType();
+
+ if (!getInputList().empty()) {
+ p << " (";
+ llvm::interleaveComma(getInputList().getTypes(), p);
+ p << ")";
}
- p << ' ';
- p.printRegion(getThenGraph(),
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printArrowTypeList(getResultTypes());
+ p << " ";
+
+ p.printRegion(getThenGraph());
// Print the 'else' regions if it exists and has a block.
auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
- p.printRegion(elseRegion,
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printRegion(elseRegion);
}
p.printOptionalAttrDict((*this)->getAttrs());
@@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
-static void printInitializationList(OpAsmPrinter &parser,
- Block::BlockArgListType blocksArgs,
- ValueRange initializers,
- StringRef prefix = "") {
- assert(blocksArgs.size() == initializers.size() &&
- "expected same length of arguments and initializers");
- if (initializers.empty())
- return;
-
- parser << prefix << '(';
- llvm::interleaveComma(
- llvm::zip(blocksArgs, initializers), parser,
- [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
- parser << ")";
-}
-
void WhileOp::print(OpAsmPrinter &parser) {
printInitializationList(parser, getCondGraph().front().getArguments(),
getInputList(), " ");
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 32b5fb6..8ec7765 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) {
// })
//
// Simplified:
- // %0 = tosa.cond_if %arg2 {
- // tosa.yield %arg0
+ // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg3
// } else {
- // tosa.yield %arg1
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg4
// }
- //
- // Unfortunately, the simplified syntax does not encapsulate values
- // used in then/else regions (see 'simplified' example above), so it
- // must be rewritten to use the generic syntax in order to be conformant
- // to the specification.
+
return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
}
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index c0d20d4..14a4fdf 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -21,17 +21,8 @@
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "transform-dialect"
-#define DEBUG_TYPE_FULL "transform-dialect-full"
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")
-#ifndef NDEBUG
-#define FULL_LDBG(X) \
- DEBUGLOG_WITH_STREAM_AND_TYPE(llvm::dbgs(), DEBUG_TYPE_FULL)
-#else
-#define FULL_LDBG(X) \
- for (bool _c = false; _c; _c = false) \
- ::llvm::nulls()
-#endif
+#define FULL_LDBG() LDBG(4)
using namespace mlir;
@@ -818,16 +809,14 @@ void transform::TransformState::compactOpHandles() {
DiagnosedSilenceableFailure
transform::TransformState::applyTransform(TransformOpInterface transform) {
- LLVM_DEBUG({
- DBGS() << "applying: ";
- transform->print(llvm::dbgs(), OpPrintingFlags().skipRegions());
- llvm::dbgs() << "\n";
- });
+ LDBG() << "applying: "
+ << OpWithFlags(transform, OpPrintingFlags().skipRegions());
FULL_LDBG() << "Top-level payload before application:\n" << *getTopLevel();
auto printOnFailureRAII = llvm::make_scope_exit([this] {
(void)this;
- LLVM_DEBUG(DBGS() << "Failing Top-level payload:\n"; getTopLevel()->print(
- llvm::dbgs(), mlir::OpPrintingFlags().printGenericOpForm()););
+ LDBG() << "Failing Top-level payload:\n"
+ << OpWithFlags(getTopLevel(),
+ OpPrintingFlags().printGenericOpForm());
});
// Set current transform op.
@@ -995,8 +984,7 @@ transform::TransformState::applyTransform(TransformOpInterface transform) {
printOnFailureRAII.release();
DEBUG_WITH_TYPE(DEBUG_PRINT_AFTER_ALL, {
- DBGS() << "Top-level payload:\n";
- getTopLevel()->print(llvm::dbgs());
+ LDBG() << "Top-level payload:\n" << *getTopLevel();
});
return result;
}
@@ -1273,7 +1261,7 @@ void transform::TrackingListener::notifyMatchFailure(
LLVM_DEBUG({
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
- DBGS() << "Match Failure : " << diag.str();
+ LDBG() << "Match Failure : " << diag.str();
});
}
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bce358d..8789f55 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1258,63 +1258,6 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
CanonicalizeContractAdd<arith::AddFOp>>(context);
}
-//===----------------------------------------------------------------------===//
-// ExtractElementOp
-//===----------------------------------------------------------------------===//
-
-void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges.front());
-}
-
-void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
- Value source) {
- result.addOperands({source});
- result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
-}
-
-LogicalResult vector::ExtractElementOp::verify() {
- VectorType vectorType = getSourceVectorType();
- if (vectorType.getRank() == 0) {
- if (getPosition())
- return emitOpError("expected position to be empty with 0-D vector");
- return success();
- }
- if (vectorType.getRank() != 1)
- return emitOpError("unexpected >1 vector rank");
- if (!getPosition())
- return emitOpError("expected position for 1-D vector");
- return success();
-}
-
-OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
- // Skip the 0-D vector here now.
- if (!adaptor.getPosition())
- return {};
-
- // Fold extractelement (splat X) -> X.
- if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
- return splat.getInput();
-
- // Fold extractelement(broadcast(X)) -> X.
- if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
- if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
- return broadcast.getSource();
-
- auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
- auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
- if (!pos || !src)
- return {};
-
- auto srcElements = src.getValues<Attribute>();
-
- uint64_t posIdx = pos.getInt();
- if (posIdx >= srcElements.size())
- return {};
-
- return srcElements[posIdx];
-}
-
// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
@@ -3184,60 +3127,6 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
-// InsertElementOp
-//===----------------------------------------------------------------------===//
-
-void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
-}
-
-void InsertElementOp::build(OpBuilder &builder, OperationState &result,
- Value source, Value dest) {
- build(builder, result, source, dest, {});
-}
-
-LogicalResult InsertElementOp::verify() {
- auto dstVectorType = getDestVectorType();
- if (dstVectorType.getRank() == 0) {
- if (getPosition())
- return emitOpError("expected position to be empty with 0-D vector");
- return success();
- }
- if (dstVectorType.getRank() != 1)
- return emitOpError("unexpected >1 vector rank");
- if (!getPosition())
- return emitOpError("expected position for 1-D vector");
- return success();
-}
-
-OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
- // Skip the 0-D vector here.
- if (!adaptor.getPosition())
- return {};
-
- auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
- auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
- auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
- if (!src || !dst || !pos)
- return {};
-
- if (src.getType() != getDestVectorType().getElementType())
- return {};
-
- auto dstElements = dst.getValues<Attribute>();
-
- SmallVector<Attribute> results(dstElements);
-
- uint64_t posIdx = pos.getInt();
- if (posIdx >= results.size())
- return {};
- results[posIdx] = src;
-
- return DenseElementsAttr::get(getDestVectorType(), results);
-}
-
-//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8de87fe..c51c7b7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1005,26 +1005,39 @@ struct ReorderElementwiseOpsOnBroadcast final
"might be a scalar");
}
- // Get the type of the lhs operand
- auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
- if (!lhsBcastOrSplat ||
- !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
+ // Get the type of the first non-constant operand
+ Operation *firstBroadcastOrSplat = nullptr;
+ for (Value operand : op->getOperands()) {
+ Operation *definingOp = operand.getDefiningOp();
+ if (!definingOp)
+ return failure();
+ if (definingOp->hasTrait<OpTrait::ConstantLike>())
+ continue;
+ if (!isa<vector::BroadcastOp, vector::SplatOp>(*definingOp))
+ return failure();
+ firstBroadcastOrSplat = definingOp;
+ break;
+ }
+ if (!firstBroadcastOrSplat)
return failure();
- auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+ Type firstBroadcastOrSplatType =
+ firstBroadcastOrSplat->getOperand(0).getType();
// Make sure that all operands are broadcast from identical types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
- auto bcast = val.getDefiningOp<vector::BroadcastOp>();
- if (bcast)
- return (bcast.getOperand().getType() == lhsBcastOrSplatType);
- auto splat = val.getDefiningOp<vector::SplatOp>();
- if (splat)
- return (splat.getOperand().getType() == lhsBcastOrSplatType);
- return false;
- })) {
+ if (!llvm::all_of(
+ op->getOperands(), [&firstBroadcastOrSplatType](Value val) {
+ if (auto bcastOp = val.getDefiningOp<vector::BroadcastOp>())
+ return (bcastOp.getOperand().getType() ==
+ firstBroadcastOrSplatType);
+ if (auto splatOp = val.getDefiningOp<vector::SplatOp>())
+ return (splatOp.getOperand().getType() ==
+ firstBroadcastOrSplatType);
+ SplatElementsAttr splatConst;
+ return matchPattern(val, m_Constant(&splatConst));
+ })) {
return failure();
}
@@ -1032,13 +1045,28 @@ struct ReorderElementwiseOpsOnBroadcast final
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
- srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ SplatElementsAttr splatConst;
+ if (matchPattern(operand, m_Constant(&splatConst))) {
+ Attribute newConst;
+ if (auto shapedTy = dyn_cast<ShapedType>(firstBroadcastOrSplatType)) {
+ newConst = splatConst.resizeSplat(shapedTy);
+ } else {
+ newConst = splatConst.getSplatValue<Attribute>();
+ }
+ Operation *newConstOp =
+ operand.getDefiningOp()->getDialect()->materializeConstant(
+ rewriter, newConst, firstBroadcastOrSplatType,
+ operand.getLoc());
+ srcValues.push_back(newConstOp->getResult(0));
+ } else {
+ srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ }
}
// Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
- lhsBcastOrSplatType, op->getAttrs());
+ firstBroadcastOrSplatType, op->getAttrs());
// Replace the original Op with the elementwise Op
auto vectorType = op->getResultTypes()[0];
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 3e33795..776b5c6 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -821,15 +821,7 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
(void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1));
- // Register a handler to verify the diagnostics.
- setHandler([&](Diagnostic &diag) {
- // Process the main diagnostics.
- process(diag);
-
- // Process each of the notes.
- for (auto &note : diag.getNotes())
- process(note);
- });
+ registerInContext(ctx);
}
SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
@@ -862,6 +854,17 @@ LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
return impl->status;
}
+void SourceMgrDiagnosticVerifierHandler::registerInContext(MLIRContext *ctx) {
+ ctx->getDiagEngine().registerHandler([&](Diagnostic &diag) {
+ // Process the main diagnostics.
+ process(diag);
+
+ // Process each of the notes.
+ for (auto &note : diag.getNotes())
+ process(note);
+ });
+}
+
/// Process a single diagnostic.
void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
return process(diag.getLocation(), diag.str(), diag.getSeverity());
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 5c98417..9332f55 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -156,6 +156,11 @@ void RewriterBase::eraseOp(Operation *op) {
assert(op->use_empty() && "expected 'op' to have no uses");
auto *rewriteListener = dyn_cast_if_present<Listener>(listener);
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
// Fast path: If no listener is attached, the op can be dropped in one go.
if (!rewriteListener) {
op->erase();
@@ -320,6 +325,11 @@ void RewriterBase::inlineBlockBefore(Block *source, Block *dest,
moveOpBefore(&source->front(), dest, before);
}
+ // If the current insertion point is within the source block, adjust the
+ // insertion point to the destination block.
+ if (getInsertionBlock() == source)
+ setInsertionPoint(dest, getInsertionPoint());
+
// Erase the source block.
assert(source->empty() && "expected 'source' to be empty");
eraseBlock(source);
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 0db9808..7094c8e 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -901,7 +901,7 @@ LogicalResult PassManager::run(Operation *op) {
if (failed(initialize(context, impl->initializationGeneration + 1)))
return failure();
initializationKey = newInitKey;
- pipelineKey = pipelineInitializationKey;
+ pipelineInitializationKey = pipelineKey;
}
// Construct a top level analysis manager for the pipeline.
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
new file mode 100644
index 0000000..7a345ed
--- /dev/null
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -0,0 +1,207 @@
+//===- RegisterAllDialects.cpp - MLIR Dialects Registration -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all dialects and
+// passes to the system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllDialects.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Target/LLVM/NVVM/Target.h"
+#include "mlir/Target/LLVM/ROCDL/Target.h"
+#include "mlir/Target/SPIRV/Target.h"
+
+/// Add all the MLIR dialects to the provided registry.
+void mlir::registerAllDialects(DialectRegistry &registry) {
+ // clang-format off
+ registry.insert<acc::OpenACCDialect,
+ affine::AffineDialect,
+ amdgpu::AMDGPUDialect,
+ amx::AMXDialect,
+ arith::ArithDialect,
+ arm_neon::ArmNeonDialect,
+ arm_sme::ArmSMEDialect,
+ arm_sve::ArmSVEDialect,
+ async::AsyncDialect,
+ bufferization::BufferizationDialect,
+ cf::ControlFlowDialect,
+ complex::ComplexDialect,
+ DLTIDialect,
+ emitc::EmitCDialect,
+ func::FuncDialect,
+ gpu::GPUDialect,
+ index::IndexDialect,
+ irdl::IRDLDialect,
+ linalg::LinalgDialect,
+ LLVM::LLVMDialect,
+ math::MathDialect,
+ memref::MemRefDialect,
+ shard::ShardDialect,
+ ml_program::MLProgramDialect,
+ mpi::MPIDialect,
+ nvgpu::NVGPUDialect,
+ NVVM::NVVMDialect,
+ omp::OpenMPDialect,
+ pdl::PDLDialect,
+ pdl_interp::PDLInterpDialect,
+ ptr::PtrDialect,
+ quant::QuantDialect,
+ ROCDL::ROCDLDialect,
+ scf::SCFDialect,
+ shape::ShapeDialect,
+ smt::SMTDialect,
+ sparse_tensor::SparseTensorDialect,
+ spirv::SPIRVDialect,
+ tensor::TensorDialect,
+ tosa::TosaDialect,
+ transform::TransformDialect,
+ ub::UBDialect,
+ vector::VectorDialect,
+ x86vector::X86VectorDialect,
+ xegpu::XeGPUDialect,
+ xevm::XeVMDialect>();
+ // clang-format on
+
+ // Register all external models.
+ affine::registerValueBoundsOpInterfaceExternalModels(registry);
+ arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ arith::registerBufferizableOpInterfaceExternalModels(registry);
+ arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
+ arith::registerShardingInterfaceExternalModels(registry);
+ arith::registerValueBoundsOpInterfaceExternalModels(registry);
+ bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+ registry);
+ builtin::registerCastOpInterfaceExternalModels(registry);
+ cf::registerBufferizableOpInterfaceExternalModels(registry);
+ cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ gpu::registerValueBoundsOpInterfaceExternalModels(registry);
+ LLVM::registerInlinerInterface(registry);
+ NVVM::registerInlinerInterface(registry);
+ linalg::registerAllDialectInterfaceImplementations(registry);
+ linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ memref::registerAllocationOpInterfaceExternalModels(registry);
+ memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
+ memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ memref::registerValueBoundsOpInterfaceExternalModels(registry);
+ memref::registerMemorySlotExternalModels(registry);
+ ml_program::registerBufferizableOpInterfaceExternalModels(registry);
+ scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ scf::registerBufferizableOpInterfaceExternalModels(registry);
+ scf::registerValueBoundsOpInterfaceExternalModels(registry);
+ shape::registerBufferizableOpInterfaceExternalModels(registry);
+ sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
+ tensor::registerBufferizableOpInterfaceExternalModels(registry);
+ tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
+ tensor::registerInferTypeOpInterfaceExternalModels(registry);
+ tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ tensor::registerSubsetOpInterfaceExternalModels(registry);
+ tensor::registerTilingInterfaceExternalModels(registry);
+ tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+ tosa::registerShardingInterfaceExternalModels(registry);
+ vector::registerBufferizableOpInterfaceExternalModels(registry);
+ vector::registerSubsetOpInterfaceExternalModels(registry);
+ vector::registerValueBoundsOpInterfaceExternalModels(registry);
+ NVVM::registerNVVMTargetInterfaceExternalModels(registry);
+ ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
+ spirv::registerSPIRVTargetInterfaceExternalModels(registry);
+}
+
+/// Append all the MLIR dialects to the registry contained in the given context.
+void mlir::registerAllDialects(MLIRContext &context) {
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ context.appendDialectRegistry(registry);
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
new file mode 100644
index 0000000..8f7c67c
--- /dev/null
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -0,0 +1,115 @@
+//===- RegisterAllExtensions.cpp - MLIR Extension Registration --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all dialect
+// extensions to the system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllExtensions.h"
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
+#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
+#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
+#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
+#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
+#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
+#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
+#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
+
+/// This function may be called to register all MLIR dialect extensions with the
+/// provided registry.
+/// If you're building a compiler, you generally shouldn't use this: you would
+/// individually register the specific extensions that are useful for the
+/// pipelines and transformations you are using.
+void mlir::registerAllExtensions(DialectRegistry &registry) {
+ // Register all conversions to LLVM extensions.
+ registerConvertArithToEmitCInterface(registry);
+ arith::registerConvertArithToLLVMInterface(registry);
+ registerConvertComplexToLLVMInterface(registry);
+ cf::registerConvertControlFlowToLLVMInterface(registry);
+ func::registerAllExtensions(registry);
+ tensor::registerAllExtensions(registry);
+ registerConvertFuncToEmitCInterface(registry);
+ registerConvertFuncToLLVMInterface(registry);
+ index::registerConvertIndexToLLVMInterface(registry);
+ registerConvertMathToLLVMInterface(registry);
+ mpi::registerConvertMPIToLLVMInterface(registry);
+ registerConvertMemRefToEmitCInterface(registry);
+ registerConvertMemRefToLLVMInterface(registry);
+ registerConvertNVVMToLLVMInterface(registry);
+ registerConvertOpenMPToLLVMInterface(registry);
+ registerConvertSCFToEmitCInterface(registry);
+ ub::registerConvertUBToLLVMInterface(registry);
+ registerConvertAMXToLLVMInterface(registry);
+ gpu::registerConvertGpuToLLVMInterface(registry);
+ NVVM::registerConvertGpuToNVVMInterface(registry);
+ vector::registerConvertVectorToLLVMInterface(registry);
+ registerConvertXeVMToLLVMInterface(registry);
+
+ // Register all transform dialect extensions.
+ affine::registerTransformDialectExtension(registry);
+ bufferization::registerTransformDialectExtension(registry);
+ dlti::registerTransformDialectExtension(registry);
+ func::registerTransformDialectExtension(registry);
+ gpu::registerTransformDialectExtension(registry);
+ linalg::registerTransformDialectExtension(registry);
+ memref::registerTransformDialectExtension(registry);
+ nvgpu::registerTransformDialectExtension(registry);
+ scf::registerTransformDialectExtension(registry);
+ sparse_tensor::registerTransformDialectExtension(registry);
+ tensor::registerTransformDialectExtension(registry);
+ transform::registerDebugExtension(registry);
+ transform::registerIRDLExtension(registry);
+ transform::registerLoopExtension(registry);
+ transform::registerPDLExtension(registry);
+ transform::registerTuneExtension(registry);
+ vector::registerTransformDialectExtension(registry);
+ arm_neon::registerTransformDialectExtension(registry);
+ arm_sve::registerTransformDialectExtension(registry);
+
+ // Translation extensions need to be registered by calling
+ // `registerAllToLLVMIRTranslations` (see All.h).
+}
diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
new file mode 100644
index 0000000..1ed3a37
--- /dev/null
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -0,0 +1,99 @@
+//===- RegisterAllPasses.cpp - MLIR Registration ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all passes to the
+// system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllPasses.h"
+
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/Pipelines/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Shard/Transforms/Passes.h"
+#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Transforms/Passes.h"
+
+// This function may be called to register the MLIR passes with the
+// global registry.
+// If you're building a compiler, you likely don't need this: you would build a
+// pipeline programmatically without the need to register with the global
+// registry, since it would already be calling the creation routine of the
+// individual passes.
+// The global registry is interesting to interact with the command-line tools.
+void mlir::registerAllPasses() {
+ // General passes
+ registerTransformsPasses();
+
+ // Conversion passes
+ registerConversionPasses();
+
+ // Dialect passes
+ acc::registerOpenACCPasses();
+ affine::registerAffinePasses();
+ amdgpu::registerAMDGPUPasses();
+ registerAsyncPasses();
+ arith::registerArithPasses();
+ bufferization::registerBufferizationPasses();
+ func::registerFuncPasses();
+ registerGPUPasses();
+ registerLinalgPasses();
+ registerNVGPUPasses();
+ registerSparseTensorPasses();
+ LLVM::registerLLVMPasses();
+ math::registerMathPasses();
+ memref::registerMemRefPasses();
+ shard::registerShardPasses();
+ ml_program::registerMLProgramPasses();
+ quant::registerQuantPasses();
+ registerSCFPasses();
+ registerShapePasses();
+ spirv::registerSPIRVPasses();
+ tensor::registerTensorPasses();
+ tosa::registerTosaOptPasses();
+ transform::registerTransformPasses();
+ vector::registerVectorPasses();
+ arm_sme::registerArmSMEPasses();
+ arm_sve::registerArmSVEPasses();
+ emitc::registerEmitCPasses();
+ xegpu::registerXeGPUPasses();
+
+ // Dialect pipelines
+ bufferization::registerBufferizationPipelines();
+ sparse_tensor::registerSparseTensorPipelines();
+ tosa::registerTosaToLinalgPipelines();
+ gpu::registerGPUToNVVMPipeline();
+}
diff --git a/mlir/lib/Support/ToolUtilities.cpp b/mlir/lib/Support/ToolUtilities.cpp
index 748f928..2cf33eb 100644
--- a/mlir/lib/Support/ToolUtilities.cpp
+++ b/mlir/lib/Support/ToolUtilities.cpp
@@ -14,6 +14,8 @@
#include "mlir/Support/LLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
+#include <string>
+#include <utility>
using namespace mlir;
@@ -22,18 +24,18 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
ChunkBufferHandler processChunkBuffer,
raw_ostream &os, llvm::StringRef inputSplitMarker,
llvm::StringRef outputSplitMarker) {
+ llvm::MemoryBufferRef originalBufferRef = originalBuffer->getMemBufferRef();
// If splitting is disabled, we process the full input buffer.
if (inputSplitMarker.empty())
- return processChunkBuffer(std::move(originalBuffer), os);
+ return processChunkBuffer(std::move(originalBuffer), originalBufferRef, os);
const int inputSplitMarkerLen = inputSplitMarker.size();
- auto *origMemBuffer = originalBuffer.get();
SmallVector<StringRef, 8> rawSourceBuffers;
const int checkLen = 2;
// Split dropping the last checkLen chars to enable flagging near misses.
- origMemBuffer->getBuffer().split(rawSourceBuffers,
- inputSplitMarker.drop_back(checkLen));
+ originalBufferRef.getBuffer().split(rawSourceBuffers,
+ inputSplitMarker.drop_back(checkLen));
if (rawSourceBuffers.empty())
return success();
@@ -79,11 +81,17 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
auto interleaveFn = [&](StringRef subBuffer) {
auto splitLoc = SMLoc::getFromPointer(subBuffer.data());
unsigned splitLine = fileSourceMgr.getLineAndColumn(splitLoc).first;
- auto subMemBuffer = llvm::MemoryBuffer::getMemBufferCopy(
- subBuffer, Twine("within split at ") +
- origMemBuffer->getBufferIdentifier() + ":" +
- Twine(splitLine) + " offset ");
- if (failed(processChunkBuffer(std::move(subMemBuffer), os)))
+ std::string name((Twine("within split at ") +
+ originalBufferRef.getBufferIdentifier() + ":" +
+ Twine(splitLine) + " offset ")
+ .str());
+ // Use MemoryBufferRef to avoid copying the buffer & keep at same location
+ // relative to the original buffer.
+ auto subMemBuffer =
+ llvm::MemoryBuffer::getMemBuffer(llvm::MemoryBufferRef(subBuffer, name),
+ /*RequiresNullTerminator=*/false);
+ if (failed(
+ processChunkBuffer(std::move(subMemBuffer), originalBufferRef, os)))
hadFailure = true;
};
llvm::interleave(sourceBuffers, os, interleaveFn,
@@ -92,3 +100,16 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
// If any fails, then return a failure of the tool.
return failure(hadFailure);
}
+
+LogicalResult
+mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
+ NoSourceChunkBufferHandler processChunkBuffer,
+ raw_ostream &os, llvm::StringRef inputSplitMarker,
+ llvm::StringRef outputSplitMarker) {
+ auto process = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
+ const llvm::MemoryBufferRef &, raw_ostream &os) {
+ return processChunkBuffer(std::move(chunkBuffer), os);
+ };
+ return splitAndProcessBuffer(std::move(originalBuffer), process, os,
+ inputSplitMarker, outputSplitMarker);
+}
diff --git a/mlir/lib/Support/TypeID.cpp b/mlir/lib/Support/TypeID.cpp
index 01ad910..304253c 100644
--- a/mlir/lib/Support/TypeID.cpp
+++ b/mlir/lib/Support/TypeID.cpp
@@ -27,9 +27,6 @@ namespace {
struct ImplicitTypeIDRegistry {
/// Lookup or insert a TypeID for the given type name.
TypeID lookupOrInsert(StringRef typeName) {
- LLVM_DEBUG(llvm::dbgs() << "ImplicitTypeIDRegistry::lookupOrInsert("
- << typeName << ")\n");
-
// Perform a heuristic check to see if this type is in an anonymous
// namespace. String equality is not valid for anonymous types, so we try to
// abort whenever we see them.
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index af22a7f..9ea5c683 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
MLIRROCDLToLLVMIRTranslation
MLIRSPIRVToLLVMIRTranslation
MLIRVCIXToLLVMIRTranslation
+ MLIRXeVMToLLVMIRTranslation
)
add_mlir_translation_library(MLIRTargetLLVMIRImport
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index f030fa7..86c731a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -10,3 +10,4 @@ add_subdirectory(OpenMP)
add_subdirectory(ROCDL)
add_subdirectory(SPIRV)
add_subdirectory(VCIX)
+add_subdirectory(XeVM)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f18199..49e1e55 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3877,29 +3877,28 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
llvm::SmallVector<size_t> indices(indexAttr.size());
std::iota(indices.begin(), indices.end(), 0);
- llvm::sort(indices.begin(), indices.end(),
- [&](const size_t a, const size_t b) {
- auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
- auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
- for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
- int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
- int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
-
- if (aIndex == bIndex)
- continue;
-
- if (aIndex < bIndex)
- return first;
-
- if (aIndex > bIndex)
- return !first;
- }
-
- // Iterated the up until the end of the smallest member and
- // they were found to be equal up to that point, so select
- // the member with the lowest index count, so the "parent"
- return memberIndicesA.size() < memberIndicesB.size();
- });
+ llvm::sort(indices, [&](const size_t a, const size_t b) {
+ auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
+ auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
+ for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
+ int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
+ int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
+
+ if (aIndex == bIndex)
+ continue;
+
+ if (aIndex < bIndex)
+ return first;
+
+ if (aIndex > bIndex)
+ return !first;
+ }
+
+ // Iterated the up until the end of the smallest member and
+ // they were found to be equal up to that point, so select
+ // the member with the lowest index count, so the "parent"
+ return memberIndicesA.size() < memberIndicesB.size();
+ });
return llvm::cast<omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
new file mode 100644
index 0000000..6308d7e
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+set(LLVM_OPTIONAL_SOURCES
+ XeVMToLLVMIRTranslation.cpp
+)
+
+add_mlir_translation_library(MLIRXeVMToLLVMIRTranslation
+ XeVMToLLVMIRTranslation.cpp
+
+ DEPENDS
+ MLIRXeVMConversionsIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRXeVMDialect
+ MLIRSupport
+ MLIRTargetLLVMIRExport
+)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
new file mode 100644
index 0000000..73b166d
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -0,0 +1,103 @@
+//===-- XeVMToLLVMIRTranslation.cpp - Translate XeVM to LLVM IR -*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR XeVM dialect and
+// LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the XeVM dialect to LLVM IR.
+class XeVMDialectLLVMIRTranslationInterface
+ : public LLVMTranslationDialectInterface {
+public:
+ using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+ /// Attaches module-level metadata for functions marked as kernels.
+ LogicalResult
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ StringRef attrName = attribute.getName().getValue();
+ if (attrName == mlir::xevm::XeVMDialect::getCacheControlsAttrName()) {
+ auto cacheControlsArray = dyn_cast<ArrayAttr>(attribute.getValue());
+ if (cacheControlsArray.size() != 2) {
+ return op->emitOpError(
+ "Expected both L1 and L3 cache control attributes!");
+ }
+ if (instructions.size() != 1) {
+ return op->emitOpError("Expecting a single instruction");
+ }
+ return handleDecorationCacheControl(instructions.front(),
+ cacheControlsArray.getValue());
+ }
+ auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!func)
+ return failure();
+
+ return success();
+ }
+
+private:
+ static LogicalResult handleDecorationCacheControl(llvm::Instruction *inst,
+ ArrayRef<Attribute> attrs) {
+ SmallVector<llvm::Metadata *> decorations;
+ llvm::LLVMContext &ctx = inst->getContext();
+ llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx);
+ llvm::transform(
+ attrs, std::back_inserter(decorations),
+ [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * {
+ auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue();
+ std::array<llvm::Metadata *, 4> metadata;
+ llvm::transform(
+ valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) {
+ return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
+ i32Ty, cast<IntegerAttr>(valueAttr).getValue()));
+ });
+ return llvm::MDNode::get(ctx, metadata);
+ });
+ constexpr llvm::StringLiteral decorationCacheControlMDName =
+ "spirv.DecorationCacheControlINTEL";
+ inst->setMetadata(decorationCacheControlMDName,
+ llvm::MDNode::get(ctx, decorations));
+ return success();
+ }
+};
+} // namespace
+
+void mlir::registerXeVMDialectTranslation(::mlir::DialectRegistry &registry) {
+ registry.insert<xevm::XeVMDialect>();
+ registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) {
+ dialect->addInterfaces<XeVMDialectLLVMIRTranslationInterface>();
+ });
+}
+
+void mlir::registerXeVMDialectTranslation(::mlir::MLIRContext &context) {
+ DialectRegistry registry;
+ registerXeVMDialectTranslation(registry);
+ context.appendDialectRegistry(registry);
+}
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 58e5353..a8a2b2e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -446,6 +446,19 @@ LogicalResult Serializer::processType(Location loc, Type type,
LogicalResult
Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
SetVector<StringRef> &serializationCtx) {
+
+ // Map unsigned integer types to singless integer types.
+ // This is needed otherwise the generated spirv assembly will contain
+ // twice a type declaration (like OpTypeInt 32 0) which is no permitted and
+ // such module fails validation. Indeed at MLIR level the two types are
+ // different and lookup in the cache below misses.
+ // Note: This conversion needs to happen here before the type is looked up in
+ // the cache.
+ if (type.isUnsignedInteger()) {
+ type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(),
+ IntegerType::SignednessSemantics::Signless);
+ }
+
typeID = getTypeID(type);
if (typeID)
return success();
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 8f78590..bdcdaa4 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -508,13 +508,20 @@ performActions(raw_ostream &os,
/// Parses the memory buffer. If successfully, run a series of passes against
/// it and print the result.
-static LogicalResult processBuffer(raw_ostream &os,
- std::unique_ptr<MemoryBuffer> ownedBuffer,
- const MlirOptMainConfig &config,
- DialectRegistry &registry,
- llvm::ThreadPoolInterface *threadPool) {
+static LogicalResult
+processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
+ llvm::MemoryBufferRef sourceBuffer,
+ const MlirOptMainConfig &config, DialectRegistry &registry,
+ SourceMgrDiagnosticVerifierHandler *verifyHandler,
+ llvm::ThreadPoolInterface *threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
auto sourceMgr = std::make_shared<SourceMgr>();
+ // Add the original buffer to the source manager to use for determining
+ // locations.
+ sourceMgr->AddNewSourceBuffer(
+ llvm::MemoryBuffer::getMemBuffer(sourceBuffer,
+ /*RequiresNullTerminator=*/false),
+ SMLoc());
sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
// Create a context just for the current buffer. Disable threading on creation
@@ -522,6 +529,8 @@ static LogicalResult processBuffer(raw_ostream &os,
MLIRContext context(registry, MLIRContext::Threading::DISABLED);
if (threadPool)
context.setThreadPool(*threadPool);
+ if (verifyHandler)
+ verifyHandler->registerInContext(&context);
StringRef irdlFile = config.getIrdlFile();
if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context)))
@@ -545,17 +554,12 @@ static LogicalResult processBuffer(raw_ostream &os,
return performActions(os, sourceMgr, &context, config);
}
- SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
- *sourceMgr, &context, config.verifyDiagnosticsLevel());
-
// Do any processing requested by command line flags. We don't care whether
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
(void)performActions(os, sourceMgr, &context, config);
- // Verify the diagnostic handler to make sure that each of the diagnostics
- // matched.
- return sourceMgrHandler.verify();
+ return success();
}
std::pair<std::string, std::string>
@@ -624,14 +628,31 @@ LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
if (threadPoolCtx.isMultithreadingEnabled())
threadPool = &threadPoolCtx.getThreadPool();
+ SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(
+ llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(),
+ /*RequiresNullTerminator=*/false),
+ SMLoc());
+ // Note: this creates a verifier handler independent of the the flag set, as
+ // internally if the flag is not set, a new scoped diagnostic handler is
+ // created which would intercept the diagnostics and verify them.
+ SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
+ sourceMgr, &threadPoolCtx, config.verifyDiagnosticsLevel());
auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
- raw_ostream &os) {
- return processBuffer(os, std::move(chunkBuffer), config, registry,
- threadPool);
+ llvm::MemoryBufferRef sourceBuffer, raw_ostream &os) {
+ return processBuffer(
+ os, std::move(chunkBuffer), sourceBuffer, config, registry,
+ config.shouldVerifyDiagnostics() ? &sourceMgrHandler : nullptr,
+ threadPool);
};
- return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
- config.inputSplitMarker(),
- config.outputSplitMarker());
+ LogicalResult status = splitAndProcessBuffer(
+ llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(),
+ /*RequiresNullTerminator=*/false),
+ chunkFn, outputStream, config.inputSplitMarker(),
+ config.outputSplitMarker());
+ if (config.shouldVerifyDiagnostics() && failed(sourceMgrHandler.verify()))
+ status = failure();
+ return status;
}
LogicalResult mlir::MlirOptMain(int argc, char **argv,
diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
index c11cb8d..e1c8afb 100644
--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
+++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
@@ -135,6 +135,13 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) {
+ // Many of the translations expect a null-terminated buffer while splitting
+ // the buffer does not guarantee null-termination. Make a copy of the buffer
+ // to ensure null-termination.
+ if (!ownedBuffer->getBuffer().ends_with('\0')) {
+ ownedBuffer = llvm::MemoryBuffer::getMemBufferCopy(
+ ownedBuffer->getBuffer(), ownedBuffer->getBufferIdentifier());
+ }
// Temporary buffers for chained translation processing.
std::string dataIn;
std::string dataOut;
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index df255cf..08803e0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -17,6 +17,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
@@ -1700,6 +1701,7 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
});
assert(!wasOpReplaced(newParentOp) &&
"attempting to insert into a region within a replaced/erased op");
+ (void)newParentOp;
patternInsertedBlocks.insert(block);
@@ -1758,6 +1760,12 @@ void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
SmallVector<SmallVector<Value>> newVals =
llvm::map_to_vector(newValues, [](Value v) -> SmallVector<Value> {
return v ? SmallVector<Value>{v} : SmallVector<Value>();
@@ -1773,6 +1781,12 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
impl->logger.startLine()
<< "** Replace : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
impl->replaceOp(op, std::move(newValues));
}
@@ -1781,6 +1795,12 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
impl->logger.startLine()
<< "** Erase : '" << op->getName() << "'(" << op << ")\n";
});
+
+ // If the current insertion point is before the erased operation, we adjust
+ // the insertion point to be after the operation.
+ if (getInsertionPoint() == op->getIterator())
+ setInsertionPointAfter(op);
+
SmallVector<SmallVector<Value>> nullRepls(op->getNumResults(), {});
impl->replaceOp(op, std::move(nullRepls));
}
@@ -1887,6 +1907,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
moveOpBefore(&source->front(), dest, before);
}
+ // If the current insertion point is within the source block, adjust the
+ // insertion point to the destination block.
+ if (getInsertionBlock() == source)
+ setInsertionPoint(dest, getInsertionPoint());
+
// Erase the source block.
eraseBlock(source);
}
@@ -2216,23 +2241,39 @@ OperationLegalizer::legalizeWithFold(Operation *op,
rewriterImpl.logger.startLine() << "* Fold {\n";
rewriterImpl.logger.indent();
});
- (void)rewriterImpl;
+
+ // Clear pattern state, so that the next pattern application starts with a
+ // clean slate. (The op/block sets are populated by listener notifications.)
+ auto cleanup = llvm::make_scope_exit([&]() {
+ rewriterImpl.patternNewOps.clear();
+ rewriterImpl.patternModifiedOps.clear();
+ rewriterImpl.patternInsertedBlocks.clear();
+ });
+
+ // Upon failure, undo all changes made by the folder.
+ RewriterState curState = rewriterImpl.getCurrentState();
// Try to fold the operation.
StringRef opName = op->getName().getStringRef();
SmallVector<Value, 2> replacementValues;
SmallVector<Operation *, 2> newOps;
rewriter.setInsertionPoint(op);
+ rewriter.startOpModification(op);
if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
+ rewriter.cancelOpModification(op);
return failure();
}
+ rewriter.finalizeOpModification(op);
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);
+ // Insert a replacement for 'op' with the folded replacement values.
+ rewriter.replaceOp(op, replacementValues);
+
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
if (failed(legalize(newOp, rewriter))) {
@@ -2245,16 +2286,12 @@ OperationLegalizer::legalizeWithFold(Operation *op,
"op '" + opName +
"' folder rollback of IR modifications requested");
}
- // Legalization failed: erase all materialized constants.
- for (Operation *op : newOps)
- rewriter.eraseOp(op);
+ rewriterImpl.resetState(
+ curState, std::string(op->getName().getStringRef()) + " folder");
return failure();
}
}
- // Insert a replacement for 'op' with the folded replacement values.
- rewriter.replaceOp(op, replacementValues);
-
LLVM_DEBUG(logSuccess(rewriterImpl.logger, ""));
return success();
}
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index b639e87f..26c965c 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -21,7 +21,7 @@
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "inlining"
@@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
// InlinerInterfaceImpl
//===----------------------------------------------------------------------===//
-#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
return debugString(op);
return "_unnamed_callee_";
}
-#endif
/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.
@@ -614,10 +612,10 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
std::vector<InlineHistoryT> callHistory(calls.size(), InlineHistoryT{});
LLVM_DEBUG({
- llvm::dbgs() << "* Inliner: Initial calls in SCC are: {\n";
+ LDBG() << "* Inliner: Initial calls in SCC are: {";
for (unsigned i = 0, e = calls.size(); i < e; ++i)
- llvm::dbgs() << " " << i << ". " << calls[i].call << ",\n";
- llvm::dbgs() << "}\n";
+ LDBG() << " " << i << ". " << calls[i].call << ",";
+ LDBG() << "}";
});
// Try to inline each of the call operations. Don't cache the end iterator
@@ -635,9 +633,9 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
CallOpInterface call = it.call;
LLVM_DEBUG({
if (doInline)
- llvm::dbgs() << "* Inlining call: " << i << ". " << call << "\n";
+ LDBG() << "* Inlining call: " << i << ". " << call;
else
- llvm::dbgs() << "* Not inlining call: " << i << ". " << call << "\n";
+ LDBG() << "* Not inlining call: " << i << ". " << call;
});
if (!doInline)
continue;
@@ -654,7 +652,7 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
cast<CallableOpInterface>(targetRegion->getParentOp()),
targetRegion, /*shouldCloneInlinedRegion=*/!inlineInPlace);
if (failed(inlineResult)) {
- LLVM_DEBUG(llvm::dbgs() << "** Failed to inline\n");
+ LDBG() << "** Failed to inline";
continue;
}
inlinedAnyCalls = true;
@@ -667,19 +665,16 @@ Inliner::Impl::inlineCallsInSCC(InlinerInterfaceImpl &inlinerIface,
auto historyToString = [](InlineHistoryT h) {
return h.has_value() ? std::to_string(*h) : "root";
};
- (void)historyToString;
- LLVM_DEBUG(llvm::dbgs()
- << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
- << getNodeName(call) << ", " << historyToString(inlineHistoryID)
- << "]\n");
+ LDBG() << "* new inlineHistory entry: " << newInlineHistoryID << ". ["
+ << getNodeName(call) << ", " << historyToString(inlineHistoryID)
+ << "]";
for (unsigned k = prevSize; k != calls.size(); ++k) {
callHistory.push_back(newInlineHistoryID);
- LLVM_DEBUG(llvm::dbgs() << "* new call " << k << " {" << calls[i].call
- << "}\n with historyID = " << newInlineHistoryID
- << ", added due to inlining of\n call {" << call
- << "}\n with historyID = "
- << historyToString(inlineHistoryID) << "\n");
+ LDBG() << "* new call " << k << " {" << calls[k].call
+ << "}\n with historyID = " << newInlineHistoryID
+ << ", added due to inlining of\n call {" << call
+ << "}\n with historyID = " << historyToString(inlineHistoryID);
}
// If the inlining was successful, Merge the new uses into the source node.