Diffstat (limited to 'mlir/lib')
-rw-r--r-- mlir/lib/AsmParser/DialectSymbolParser.cpp | 7
-rw-r--r-- mlir/lib/AsmParser/Lexer.cpp | 27
-rw-r--r-- mlir/lib/AsmParser/Lexer.h | 3
-rw-r--r-- mlir/lib/CAPI/RegisterEverything/CMakeLists.txt | 11
-rw-r--r-- mlir/lib/CMakeLists.txt | 34
-rw-r--r-- mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp | 38
-rw-r--r-- mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp | 41
-rw-r--r-- mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp | 1
-rw-r--r-- mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp | 1
-rw-r--r-- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp | 28
-rw-r--r-- mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp | 34
-rw-r--r-- mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp | 9
-rw-r--r-- mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp | 3
-rw-r--r-- mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp | 7
-rw-r--r-- mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp | 78
-rw-r--r-- mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp | 36
-rw-r--r-- mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp | 11
-rw-r--r-- mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp | 52
-rw-r--r-- mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp | 9
-rw-r--r-- mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp | 37
-rw-r--r-- mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp | 1
-rw-r--r-- mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp | 1
-rw-r--r-- mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp | 61
-rw-r--r-- mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp | 69
-rw-r--r-- mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp | 4
-rw-r--r-- mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp | 64
-rw-r--r-- mlir/lib/Dialect/Affine/IR/AffineOps.cpp | 8
-rw-r--r-- mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp | 94
-rw-r--r-- mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp | 2
-rw-r--r-- mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 42
-rw-r--r-- mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp | 17
-rw-r--r-- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 36
-rw-r--r-- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 22
-rw-r--r-- mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp | 32
-rw-r--r-- mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp | 20
-rw-r--r-- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 60
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp | 15
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp | 1
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp | 51
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp | 14
-rw-r--r-- mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 79
-rw-r--r-- mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp | 4
-rw-r--r-- mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp | 62
-rw-r--r-- mlir/lib/Dialect/SCF/IR/SCF.cpp | 8
-rw-r--r-- mlir/lib/Dialect/SPIRV/IR/CastOps.cpp | 42
-rw-r--r-- mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp | 59
-rw-r--r-- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 28
-rw-r--r-- mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp | 118
-rw-r--r-- mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp | 12
-rw-r--r-- mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp | 54
-rw-r--r-- mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp | 15
-rw-r--r-- mlir/lib/Dialect/Shard/IR/ShardOps.cpp | 1
-rw-r--r-- mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp | 3
-rw-r--r-- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 48
-rw-r--r-- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 122
-rw-r--r-- mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp | 7
-rw-r--r-- mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp | 16
-rw-r--r-- mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp | 10
-rw-r--r-- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 273
-rw-r--r-- mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp | 24
-rw-r--r-- mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp | 2
-rw-r--r-- mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp | 2
-rw-r--r-- mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp | 24
-rw-r--r-- mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp | 125
-rw-r--r-- mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp | 103
-rw-r--r-- mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp | 9
-rw-r--r-- mlir/lib/IR/AsmPrinter.cpp | 7
-rw-r--r-- mlir/lib/IR/Diagnostics.cpp | 21
-rw-r--r-- mlir/lib/Parser/Parser.cpp | 30
-rw-r--r-- mlir/lib/Pass/Pass.cpp | 2
-rw-r--r-- mlir/lib/RegisterAllDialects.cpp | 207
-rw-r--r-- mlir/lib/RegisterAllExtensions.cpp | 115
-rw-r--r-- mlir/lib/RegisterAllPasses.cpp | 99
-rw-r--r-- mlir/lib/Rewrite/PatternApplicator.cpp | 15
-rw-r--r-- mlir/lib/Support/ToolUtilities.cpp | 39
-rw-r--r-- mlir/lib/Target/Cpp/TranslateToCpp.cpp | 48
-rw-r--r-- mlir/lib/Target/LLVMIR/CMakeLists.txt | 1
-rw-r--r-- mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt | 1
-rw-r--r-- mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp | 50
-rw-r--r-- mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp | 1
-rw-r--r-- mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp | 36
-rw-r--r-- mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 45
-rw-r--r-- mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt | 21
-rw-r--r-- mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp | 103
-rw-r--r-- mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp | 10
-rw-r--r-- mlir/lib/Target/LLVMIR/ModuleImport.cpp | 66
-rw-r--r-- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 61
-rw-r--r-- mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp | 43
-rw-r--r-- mlir/lib/Target/SPIRV/Deserialization/Deserializer.h | 1
-rw-r--r-- mlir/lib/Target/SPIRV/Serialization/Serializer.cpp | 125
-rw-r--r-- mlir/lib/Tools/mlir-opt/MlirOptMain.cpp | 58
-rw-r--r-- mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp | 7
-rw-r--r-- mlir/lib/Transforms/Utils/DialectConversion.cpp | 8
-rw-r--r-- mlir/lib/Transforms/Utils/Inliner.cpp | 2
94 files changed, 2473 insertions, 1080 deletions
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 9f4a87a..8b14e71 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -89,6 +89,7 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
nestedPunctuation.pop_back();
return success();
};
+ const char *curBufferEnd = state.lex.getBufferEnd();
do {
// Handle code completions, which may appear in the middle of the symbol
// body.
@@ -98,6 +99,12 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
break;
}
+ if (curBufferEnd == curPtr) {
+ if (!nestedPunctuation.empty())
+ return emitPunctError();
+ return emitError("unexpected nul or EOF in pretty dialect name");
+ }
+
char c = *curPtr++;
switch (c) {
case '\0':
diff --git a/mlir/lib/AsmParser/Lexer.cpp b/mlir/lib/AsmParser/Lexer.cpp
index 751bd63..8f53529 100644
--- a/mlir/lib/AsmParser/Lexer.cpp
+++ b/mlir/lib/AsmParser/Lexer.cpp
@@ -37,6 +37,18 @@ Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context,
AsmParserCodeCompleteContext *codeCompleteContext)
: sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) {
auto bufferID = sourceMgr.getMainFileID();
+
+ // Check to see if the main buffer contains the last buffer, and if so, the
+ // last buffer should be used as the main file for parsing.
+ if (sourceMgr.getNumBuffers() > 1) {
+ unsigned lastFileID = sourceMgr.getNumBuffers();
+ const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID);
+ const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID);
+ if (main->getBufferStart() <= last->getBufferStart() &&
+ main->getBufferEnd() >= last->getBufferEnd()) {
+ bufferID = lastFileID;
+ }
+ }
curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
curPtr = curBuffer.begin();
@@ -71,6 +83,7 @@ Token Lexer::emitError(const char *loc, const Twine &message) {
}
Token Lexer::lexToken() {
+ const char *curBufferEnd = curBuffer.end();
while (true) {
const char *tokStart = curPtr;
@@ -78,6 +91,9 @@ Token Lexer::lexToken() {
if (tokStart == codeCompleteLoc)
return formToken(Token::code_complete, tokStart);
+ if (tokStart == curBufferEnd)
+ return formToken(Token::eof, tokStart);
+
// Lex the next token.
switch (*curPtr++) {
default:
@@ -102,7 +118,7 @@ Token Lexer::lexToken() {
case 0:
// This may either be a nul character in the source file or may be the EOF
// marker that llvm::MemoryBuffer guarantees will be there.
- if (curPtr - 1 == curBuffer.end())
+ if (curPtr - 1 == curBufferEnd)
return formToken(Token::eof, tokStart);
continue;
@@ -259,7 +275,11 @@ void Lexer::skipComment() {
assert(*curPtr == '/');
++curPtr;
+ const char *curBufferEnd = curBuffer.end();
while (true) {
+ if (curPtr == curBufferEnd)
+ return;
+
switch (*curPtr++) {
case '\n':
case '\r':
@@ -267,7 +287,7 @@ void Lexer::skipComment() {
return;
case 0:
// If this is the end of the buffer, end the comment.
- if (curPtr - 1 == curBuffer.end()) {
+ if (curPtr - 1 == curBufferEnd) {
--curPtr;
return;
}
@@ -405,6 +425,7 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
Token Lexer::lexString(const char *tokStart) {
assert(curPtr[-1] == '"');
+ const char *curBufferEnd = curBuffer.end();
while (true) {
// Check to see if there is a code completion location within the string. In
// these cases we generate a completion location and place the currently
@@ -419,7 +440,7 @@ Token Lexer::lexString(const char *tokStart) {
case 0:
// If this is a random nul character in the middle of a string, just
// include it. If it is the end of file, then it is an error.
- if (curPtr - 1 != curBuffer.end())
+ if (curPtr - 1 != curBufferEnd)
continue;
[[fallthrough]];
case '\n':
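
The added checks above make the lexer robust for buffers that are not guaranteed to end with a NUL byte, and keep embedded NULs from being mistaken for end-of-input. For context, a minimal standalone sketch of the same end-pointer guard outside MLIR (illustrative only; names are not from this patch):

#include "llvm/ADT/StringRef.h"

static unsigned countNonSpaceChars(llvm::StringRef buffer) {
  const char *cur = buffer.begin();
  const char *bufferEnd = buffer.end();
  unsigned count = 0;
  // Compare against the end pointer before every read instead of relying on a
  // terminating '\0', exactly as lexToken()/skipComment() now do.
  while (cur != bufferEnd) {
    if (*cur++ != ' ')
      ++count;
  }
  return count;
}
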
diff --git a/mlir/lib/AsmParser/Lexer.h b/mlir/lib/AsmParser/Lexer.h
index 4085a9b..670444e 100644
--- a/mlir/lib/AsmParser/Lexer.h
+++ b/mlir/lib/AsmParser/Lexer.h
@@ -40,6 +40,9 @@ public:
/// Returns the start of the buffer.
const char *getBufferBegin() { return curBuffer.data(); }
+ /// Returns the end of the buffer.
+ const char *getBufferEnd() { return curBuffer.end(); }
+
/// Return the code completion location of the lexer, or nullptr if there is
/// none.
const char *getCodeCompleteLoc() const { return codeCompleteLoc; }
diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
index 8b9a395..ccda668 100644
--- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
+++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt
@@ -1,19 +1,16 @@
# Dialect registration.
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
-get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything
RegisterEverything.cpp
LINK_LIBS PUBLIC
- ${dialect_libs}
${translation_libs}
- ${conversion_libs}
- ${extension_libs}
MLIRBuiltinToLLVMIRTranslation
MLIRCAPIIR
- MLIRLLVMToLLVMIRTranslation
MLIRCAPITransforms
+ MLIRLLVMToLLVMIRTranslation
+ MLIRRegisterAllDialects
+ MLIRRegisterAllExtensions
+ MLIRRegisterAllPasses
)
diff --git a/mlir/lib/CMakeLists.txt b/mlir/lib/CMakeLists.txt
index d25c84a..191b5ab6 100644
--- a/mlir/lib/CMakeLists.txt
+++ b/mlir/lib/CMakeLists.txt
@@ -20,3 +20,37 @@ add_subdirectory(Target)
add_subdirectory(Tools)
add_subdirectory(Transforms)
add_subdirectory(ExecutionEngine)
+
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
+get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
+get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
+
+add_mlir_library(MLIRRegisterAllDialects
+ RegisterAllDialects.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs}
+ )
+
+add_mlir_library(MLIRRegisterAllPasses
+ RegisterAllPasses.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs} # Some passes are part of the dialect libs
+ ${conversion_libs}
+ )
+
+add_mlir_library(MLIRRegisterAllExtensions
+ RegisterAllExtensions.cpp
+
+ PARTIAL_SOURCES_INTENDED
+
+ LINK_LIBS PUBLIC
+ ${dialect_libs}
+ ${conversion_libs}
+ ${extension_libs}
+ )
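
The three new libraries only package the registration entry points that already exist in mlir/InitAllDialects.h, InitAllExtensions.h, and InitAllPasses.h, so a downstream tool can link them instead of the full dialect/conversion/extension lists. A minimal sketch of how such a tool consumes them (illustrative only; assumes it links the libraries added above):

#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllExtensions.h"
#include "mlir/InitAllPasses.h"

int main() {
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);   // provided by MLIRRegisterAllDialects
  mlir::registerAllExtensions(registry); // provided by MLIRRegisterAllExtensions
  mlir::registerAllPasses();             // provided by MLIRRegisterAllPasses
  mlir::MLIRContext context(registry);
  return 0;
}
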
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index d43e681..265293b 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -99,6 +99,17 @@ static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
return builder.getF32FloatAttr(dstVal.convertToFloat());
}
+// Get an IntegerAttr from a FloatAttr while preserving the bits.
+// Useful for converting float constants to integer constants while preserving
+// the bits.
+static IntegerAttr
+getIntegerAttrFromFloatAttr(FloatAttr floatAttr, Type dstType,
+ ConversionPatternRewriter &rewriter) {
+ APFloat floatVal = floatAttr.getValue();
+ APInt intVal = floatVal.bitcastToAPInt();
+ return rewriter.getIntegerAttr(dstType, intVal);
+}
+
/// Returns true if the given `type` is a boolean scalar or vector type.
static bool isBoolScalarOrVector(Type type) {
assert(type && "Not a valid type");
@@ -296,8 +307,18 @@ struct ConstantCompositeOpPattern final
SmallVector<Attribute, 8> elements;
if (isa<FloatType>(srcElemType)) {
for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
- FloatAttr dstAttr =
- convertFloatAttr(srcAttr, cast<FloatType>(dstElemType), rewriter);
+ Attribute dstAttr = nullptr;
+ // Handle 8-bit float conversion to 8-bit integer.
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcElemType.getIntOrFloatBitWidth() == 8 &&
+ isa<IntegerType>(dstElemType)) {
+ dstAttr =
+ getIntegerAttrFromFloatAttr(srcAttr, dstElemType, rewriter);
+ } else {
+ dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstElemType),
+ rewriter);
+ }
if (!dstAttr)
return failure();
elements.push_back(dstAttr);
@@ -361,11 +382,19 @@ struct ConstantScalarOpPattern final
// Floating-point types.
if (isa<FloatType>(srcType)) {
auto srcAttr = cast<FloatAttr>(cstAttr);
- auto dstAttr = srcAttr;
+ Attribute dstAttr = srcAttr;
// Floating-point types not supported in the target environment are all
// converted to float type.
- if (srcType != dstType) {
+ auto *typeConverter = getTypeConverter<SPIRVTypeConverter>();
+ if (typeConverter->getOptions().emulateUnsupportedFloatTypes &&
+ srcType.getIntOrFloatBitWidth() == 8 && isa<IntegerType>(dstType) &&
+ dstType.getIntOrFloatBitWidth() == 8) {
+ // If the source is an 8-bit float, convert it to an 8-bit integer.
+ dstAttr = getIntegerAttrFromFloatAttr(srcAttr, dstType, rewriter);
+ if (!dstAttr)
+ return failure();
+ } else if (srcType != dstType) {
dstAttr = convertFloatAttr(srcAttr, cast<FloatType>(dstType), rewriter);
if (!dstAttr)
return failure();
@@ -1352,6 +1381,7 @@ struct ConvertArithToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// Use UnrealizedConversionCast as the bridge so that we don't need to pull
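
The new getIntegerAttrFromFloatAttr helper relies on APFloat::bitcastToAPInt to turn an 8-bit float constant into an i8 constant with the identical bit pattern. A standalone sketch of that round trip (illustrative only, not part of this patch):

#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"

// Reinterpret a float value as an integer of the same bit width without
// changing the bits, e.g. 1.5f (IEEE single) yields 0x3FC00000.
static llvm::APInt bitcastFloatToInt(const llvm::APFloat &value) {
  return value.bitcastToAPInt();
}

// Usage:
//   llvm::APFloat f(1.5f);
//   llvm::APInt bits = bitcastFloatToInt(f); // 32 bits, value 0x3FC00000
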
diff --git a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
index 6f0fc29..35ad99c 100644
--- a/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
+++ b/mlir/lib/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.cpp
@@ -64,10 +64,46 @@ void mlir::populateComplexToROCDLLibraryCallsConversionPatterns(
patterns.getContext(), "__ocml_cabs_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::AbsOp, Float64Type>>(
patterns.getContext(), "__ocml_cabs_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float32Type>>(
+ patterns.getContext(), "__ocml_carg_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::AngleOp, Float64Type>>(
+ patterns.getContext(), "__ocml_carg_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float32Type>>(
+ patterns.getContext(), "__ocml_conj_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::ConjOp, Float64Type>>(
+ patterns.getContext(), "__ocml_conj_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ccos_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::CosOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ccos_f64");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float32Type>>(
patterns.getContext(), "__ocml_cexp_f32");
patterns.add<ComplexOpToROCDLLibraryCalls<complex::ExpOp, Float64Type>>(
patterns.getContext(), "__ocml_cexp_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float32Type>>(
+ patterns.getContext(), "__ocml_clog_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::LogOp, Float64Type>>(
+ patterns.getContext(), "__ocml_clog_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float32Type>>(
+ patterns.getContext(), "__ocml_cpow_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::PowOp, Float64Type>>(
+ patterns.getContext(), "__ocml_cpow_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csin_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SinOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csin_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float32Type>>(
+ patterns.getContext(), "__ocml_csqrt_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::SqrtOp, Float64Type>>(
+ patterns.getContext(), "__ocml_csqrt_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctan_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctan_f64");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float32Type>>(
+ patterns.getContext(), "__ocml_ctanh_f32");
+ patterns.add<ComplexOpToROCDLLibraryCalls<complex::TanhOp, Float64Type>>(
+ patterns.getContext(), "__ocml_ctanh_f64");
}
namespace {
@@ -86,7 +122,10 @@ void ConvertComplexToROCDLLibraryCallsPass::runOnOperation() {
ConversionTarget target(getContext());
target.addLegalDialect<func::FuncDialect>();
- target.addIllegalOp<complex::AbsOp, complex::ExpOp>();
+ target.addIllegalOp<complex::AbsOp, complex::AngleOp, complex::ConjOp,
+ complex::CosOp, complex::ExpOp, complex::LogOp,
+ complex::PowOp, complex::SinOp, complex::SqrtOp,
+ complex::TanOp, complex::TanhOp>();
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
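
With these additions, most elementwise complex ops can be routed to the matching __ocml_c* library calls. A minimal sketch of driving the patterns programmatically, mirroring the pass body above (illustrative only; the op list is abbreviated):

#include "mlir/Conversion/ComplexToROCDLLibraryCalls/ComplexToROCDLLibraryCalls.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

// Assumes `op` is the root operation to convert (e.g. a module).
static mlir::LogicalResult lowerComplexToROCDL(mlir::Operation *op) {
  mlir::MLIRContext *ctx = op->getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::populateComplexToROCDLLibraryCallsConversionPatterns(patterns);
  mlir::ConversionTarget target(*ctx);
  target.addLegalDialect<mlir::func::FuncDialect>();
  // Mark a subset of the now-supported complex ops illegal; the pass above
  // lists the full set.
  target.addIllegalOp<mlir::complex::SinOp, mlir::complex::SqrtOp,
                      mlir::complex::TanhOp>();
  return mlir::applyPartialConversion(op, target, std::move(patterns));
}
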
diff --git a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
index 03f4bf4..56b6181 100644
--- a/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp
@@ -43,6 +43,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
// TODO: We should also take care of block argument type conversion.
diff --git a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
index 8ed9f65..c0439a4 100644
--- a/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp
@@ -42,6 +42,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 75e6563..3545acb 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -385,6 +385,14 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
if (auto attr = moduleOp->getAttrOfType<spirv::TargetEnvAttr>(
spirv::getTargetEnvAttrName()))
spvModule->setAttr(spirv::getTargetEnvAttrName(), attr);
+ if (ArrayAttr targets = moduleOp.getTargetsAttr()) {
+ for (Attribute targetAttr : targets)
+ if (auto spirvTargetEnvAttr =
+ dyn_cast<spirv::TargetEnvAttr>(targetAttr)) {
+ spvModule->setAttr(spirv::getTargetEnvAttrName(), spirvTargetEnvAttr);
+ break;
+ }
+ }
rewriter.eraseOp(moduleOp);
return success();
@@ -507,25 +515,27 @@ LogicalResult GPURotateConversion::matchAndRewrite(
getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
unsigned subgroupSize =
targetEnv.getAttr().getResourceLimits().getSubgroupSize();
- IntegerAttr widthAttr;
- if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
- widthAttr.getValue().getZExtValue() > subgroupSize)
+ unsigned width = rotateOp.getWidth();
+ if (width > subgroupSize)
return rewriter.notifyMatchFailure(
- rotateOp,
- "rotate width is not a constant or larger than target subgroup size");
+ rotateOp, "rotate width is larger than target subgroup size");
Location loc = rotateOp.getLoc();
auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+ Value offsetVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getOffsetAttr());
+ Value widthVal =
+ arith::ConstantOp::create(rewriter, loc, adaptor.getWidthAttr());
Value rotateResult = spirv::GroupNonUniformRotateKHROp::create(
- rewriter, loc, scope, adaptor.getValue(), adaptor.getOffset(),
- adaptor.getWidth());
+ rewriter, loc, scope, adaptor.getValue(), offsetVal, widthVal);
Value validVal;
- if (widthAttr.getValue().getZExtValue() == subgroupSize) {
+ if (width == subgroupSize) {
validVal = spirv::ConstantOp::getOne(rewriter.getI1Type(), loc, rewriter);
} else {
+ IntegerAttr widthAttr = adaptor.getWidthAttr();
Value laneId = gpu::LaneIdOp::create(rewriter, loc, widthAttr);
validVal = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ult,
- laneId, adaptor.getWidth());
+ laneId, widthVal);
}
rewriter.replaceOp(rotateOp, {rotateResult, validVal});
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index a344f88..5eab057 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -48,9 +48,36 @@ struct GPUToSPIRVPass final : impl::ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
void runOnOperation() override;
private:
+ /// Queries the target environment from the 'targets' attribute of the given
+ /// `moduleOp`.
+ spirv::TargetEnvAttr lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp);
+
+ /// Queries the target environment from the 'targets' attribute of the given
+ /// `moduleOp`, or returns the target environment as returned by
+ /// `spirv::lookupTargetEnvOrDefault` if not provided by 'targets'.
+ spirv::TargetEnvAttr lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp);
bool mapMemorySpace;
};
+spirv::TargetEnvAttr
+GPUToSPIRVPass::lookupTargetEnvInTargets(gpu::GPUModuleOp moduleOp) {
+ if (ArrayAttr targets = moduleOp.getTargetsAttr()) {
+ for (Attribute targetAttr : targets)
+ if (auto spirvTargetEnvAttr = dyn_cast<spirv::TargetEnvAttr>(targetAttr))
+ return spirvTargetEnvAttr;
+ }
+
+ return {};
+}
+
+spirv::TargetEnvAttr
+GPUToSPIRVPass::lookupTargetEnvOrDefault(gpu::GPUModuleOp moduleOp) {
+ if (spirv::TargetEnvAttr targetEnvAttr = lookupTargetEnvInTargets(moduleOp))
+ return targetEnvAttr;
+
+ return spirv::lookupTargetEnvOrDefault(moduleOp);
+}
+
void GPUToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
@@ -58,9 +85,8 @@ void GPUToSPIRVPass::runOnOperation() {
SmallVector<Operation *, 1> gpuModules;
OpBuilder builder(context);
- auto targetEnvSupportsKernelCapability = [](gpu::GPUModuleOp moduleOp) {
- Operation *gpuModule = moduleOp.getOperation();
- auto targetAttr = spirv::lookupTargetEnvOrDefault(gpuModule);
+ auto targetEnvSupportsKernelCapability = [this](gpu::GPUModuleOp moduleOp) {
+ auto targetAttr = lookupTargetEnvOrDefault(moduleOp);
spirv::TargetEnv targetEnv(targetAttr);
return targetEnv.allows(spirv::Capability::Kernel);
};
@@ -86,7 +112,7 @@ void GPUToSPIRVPass::runOnOperation() {
// TargetEnv attributes.
for (Operation *gpuModule : gpuModules) {
spirv::TargetEnvAttr targetAttr =
- spirv::lookupTargetEnvOrDefault(gpuModule);
+ lookupTargetEnvOrDefault(cast<gpu::GPUModuleOp>(gpuModule));
// Map MemRef memory space to SPIR-V storage class first if requested.
if (mapMemorySpace) {
diff --git a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
index 855c582..cde2340 100644
--- a/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
+++ b/mlir/lib/Conversion/MathToFuncs/MathToFuncs.cpp
@@ -22,7 +22,7 @@
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOFUNCS
@@ -32,7 +32,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-funcs"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace {
// Pattern to convert vector operations to scalar operations.
@@ -653,10 +652,8 @@ FPowIOpLowering::matchAndRewrite(math::FPowIOp op,
/// }
static func::FuncOp createCtlzFunc(ModuleOp *module, Type elementType) {
if (!isa<IntegerType>(elementType)) {
- LLVM_DEBUG({
- DBGS() << "non-integer element type for CtlzFunc; type was: ";
- elementType.print(llvm::dbgs());
- });
+ LDBG() << "non-integer element type for CtlzFunc; type was: "
+ << elementType;
llvm_unreachable("non-integer element type");
}
int64_t bitWidth = elementType.getIntOrFloatBitWidth();
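
Here and in several conversions below, the per-file DBGS() macro is replaced with the LDBG() stream from llvm/Support/DebugLog.h, which prefixes DEBUG_TYPE and terminates the line automatically. A minimal before/after sketch (illustrative; "my-pass" is a placeholder name):

#include "llvm/Support/Debug.h"
#include "llvm/Support/DebugLog.h"

#define DEBUG_TYPE "my-pass"

static void logValue(int value) {
  // Old style: explicit prefix and trailing newline, wrapped in LLVM_DEBUG.
  LLVM_DEBUG(llvm::dbgs() << "[" DEBUG_TYPE "]: value = " << value << "\n");
  // New style used throughout this patch.
  LDBG() << "value = " << value;
}
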
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
index 93d8b49..df219f3 100644
--- a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -21,7 +22,6 @@
#include "../GPUCommon/GPUOpsLowering.h"
#include "../GPUCommon/OpToFuncCallLowering.h"
-#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTMATHTOROCDL
@@ -31,7 +31,6 @@ namespace mlir {
using namespace mlir;
#define DEBUG_TYPE "math-to-rocdl"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
template <typename OpTy>
static void populateOpPatterns(const LLVMTypeConverter &converter,
diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
index a877ad2..1787e0a 100644
--- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
+++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp
@@ -488,7 +488,12 @@ namespace mlir {
void populateMathToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
// Core patterns
- patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
+ patterns
+ .add<CopySignPattern,
+ CheckedElementwiseOpPattern<math::IsInfOp, spirv::IsInfOp>,
+ CheckedElementwiseOpPattern<math::IsNaNOp, spirv::IsNanOp>,
+ CheckedElementwiseOpPattern<math::IsFiniteOp, spirv::IsFiniteOp>>(
+ typeConverter, patterns.getContext());
// GLSL patterns
patterns
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index e882845..6bd0e2d 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -19,10 +19,18 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <cstdint>
using namespace mlir;
+static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
+ return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
+ memRefType.getRank() != 0 &&
+ !llvm::is_contained(memRefType.getShape(), 0);
+}
+
namespace {
/// Implement the interface to convert MemRef to EmitC.
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
@@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
+struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = allocOp.getLoc();
+ MemRefType memrefType = allocOp.getType();
+ if (!isMemRefTypeLegalForEmitC(memrefType)) {
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible memref type for EmitC conversion");
+ }
+
+ Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
+ Type elementType = memrefType.getElementType();
+ IndexType indexType = rewriter.getIndexType();
+ emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
+ loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
+ ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
+
+ int64_t numElements = 1;
+ for (int64_t dimSize : memrefType.getShape()) {
+ numElements *= dimSize;
+ }
+ Value numElementsValue = rewriter.create<emitc::ConstantOp>(
+ loc, indexType, rewriter.getIndexAttr(numElements));
+
+ Value totalSizeBytes = rewriter.create<emitc::MulOp>(
+ loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
+
+ emitc::CallOpaqueOp allocCall;
+ StringAttr allocFunctionName;
+ Value alignmentValue;
+ SmallVector<Value, 2> argsVec;
+ if (allocOp.getAlignment()) {
+ allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
+ alignmentValue = rewriter.create<emitc::ConstantOp>(
+ loc, sizeTType,
+ rewriter.getIntegerAttr(indexType,
+ allocOp.getAlignment().value_or(0)));
+ argsVec.push_back(alignmentValue);
+ } else {
+ allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
+ }
+
+ argsVec.push_back(totalSizeBytes);
+ ValueRange args(argsVec);
+
+ allocCall = rewriter.create<emitc::CallOpaqueOp>(
+ loc,
+ emitc::PointerType::get(
+ emitc::OpaqueType::get(rewriter.getContext(), "void")),
+ allocFunctionName, args);
+
+ emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
+ emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
+ loc, targetPointerType, allocCall.getResult(0));
+
+ rewriter.replaceOp(allocOp, castOp);
+ return success();
+ }
+};
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion(
[&](MemRefType memRefType) -> std::optional<Type> {
- if (!memRefType.hasStaticShape() ||
- !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
- llvm::is_contained(memRefType.getShape(), 0)) {
+ if (!isMemRefTypeLegalForEmitC(memRefType)) {
return {};
}
Type convertedElementType =
@@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
- ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
+ ConvertLoad, ConvertStore>(converter, patterns.getContext());
}
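
Conceptually, the new ConvertAlloc pattern emits the EmitC equivalent of a sizeof-times-element-count allocation, calling aligned_alloc when an alignment is requested and malloc otherwise. A plain C++ sketch of the computation it models for a static memref<4x8xf32> (illustrative only; the pattern emits emitc ops, not this code):

#include <cstdlib>

static float *allocLikeMemRef4x8xF32(std::size_t alignment = 0) {
  std::size_t totalBytes = sizeof(float) * 4 * 8; // sizeof(elem) * numElements
  if (alignment)
    return static_cast<float *>(std::aligned_alloc(alignment, totalBytes));
  return static_cast<float *>(std::malloc(totalBytes));
}
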
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index cf25c09..e78dd76 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -15,6 +15,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -28,9 +29,11 @@ using namespace mlir;
namespace {
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+ using Base::Base;
void runOnOperation() override {
TypeConverter converter;
-
+ ConvertMemRefToEmitCOptions options;
+ options.lowerToCpp = this->lowerToCpp;
// Fallback for other types.
converter.addConversion([](Type type) -> std::optional<Type> {
if (!emitc::isSupportedEmitCType(type))
@@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
+
+ mlir::ModuleOp module = getOperation();
+ module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
+ if (callOp.getCallee() != alignedAllocFunctionName &&
+ callOp.getCallee() != mallocFunctionName) {
+ return mlir::WalkResult::advance();
+ }
+
+ for (auto &op : *module.getBody()) {
+ emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
+ if (!includeOp) {
+ continue;
+ }
+ if (includeOp.getIsStandardInclude() &&
+ ((options.lowerToCpp &&
+ includeOp.getInclude() == cppStandardLibraryHeader) ||
+ (!options.lowerToCpp &&
+ includeOp.getInclude() == cStandardLibraryHeader))) {
+ return mlir::WalkResult::interrupt();
+ }
+ }
+
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ StringAttr includeAttr =
+ builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader);
+ builder.create<mlir::emitc::IncludeOp>(
+ module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+ return mlir::WalkResult::interrupt();
+ });
}
};
} // namespace
diff --git a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
index 6ba5bfe4..dc2035b 100644
--- a/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
+++ b/mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
@@ -24,11 +24,12 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Pass/Pass.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/MathExtras.h"
+
#include <optional>
#define DEBUG_TYPE "memref-to-llvm"
-#define DBGS() llvm::dbgs() << "[" DEBUG_TYPE "] "
namespace mlir {
#define GEN_PASS_DEF_FINALIZEMEMREFTOLLVMCONVERSIONPASS
@@ -1848,8 +1849,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::xchg;
case arith::AtomicRMWKind::maximumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw maximumf changed "
- "from fmax to fmaximum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw maximumf changed "
+ "from fmax to fmaximum, expect more NaNs";
return LLVM::AtomicBinOp::fmaximum;
case arith::AtomicRMWKind::maxnumf:
return LLVM::AtomicBinOp::fmax;
@@ -1859,8 +1860,8 @@ matchSimpleAtomicOp(memref::AtomicRMWOp atomicOp) {
return LLVM::AtomicBinOp::umax;
case arith::AtomicRMWKind::minimumf:
// TODO: remove this by end of 2025.
- LLVM_DEBUG(DBGS() << "the lowering of memref.atomicrmw minimum changed "
- "from fmin to fminimum, expect more NaNs");
+ LDBG() << "the lowering of memref.atomicrmw minimum changed "
+ "from fmin to fminimum, expect more NaNs";
return LLVM::AtomicBinOp::fminimum;
case arith::AtomicRMWKind::minnumf:
return LLVM::AtomicBinOp::fmin;
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index 5d13353..2549a9c 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -26,13 +26,12 @@
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>
#define DEBUG_TYPE "nvgpu-to-nvvm"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define DBGSE() (llvm::dbgs())
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVGPUTONVVMPASS
@@ -1105,13 +1104,13 @@ struct NVGPUGenerateWarpgroupDescriptorLowering
// // [0,14) start_address
dsc = insertBit(dsc, basePtr14bit, startBaseAddrBit);
- LLVM_DEBUG(DBGS() << "Generating warpgroup.descriptor: "
- << "leading_off:" << leadDimVal << "\t"
- << "stride_off :" << strideDimVal << "\t"
- << "base_offset:" << offsetVal << "\t"
- << "layout_type:" << swizzle << " ("
- << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
- << ")\n start_addr : " << baseAddr << "\n");
+ LDBG() << "Generating warpgroup.descriptor: "
+ << "leading_off:" << leadDimVal << "\t"
+ << "stride_off :" << strideDimVal << "\t"
+ << "base_offset:" << offsetVal << "\t"
+ << "layout_type:" << swizzle << " ("
+ << nvgpu::stringifyTensorMapSwizzleKind(swizzleKind)
+ << ")\n start_addr : " << baseAddr;
rewriter.replaceOp(op, dsc);
return success();
@@ -1281,8 +1280,8 @@ struct NVGPUWarpgroupMmaOpLowering
} else {
llvm_unreachable("msg: not supported K shape");
}
- LLVM_DEBUG(DBGS() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
- << ", n = " << wgmmaN << ", k = " << wgmmaK << "]\n");
+ LDBG() << "Generating WgmmaMmaAsyncOp shape[m = " << wgmmaM
+ << ", n = " << wgmmaN << ", k = " << wgmmaK << "]";
}
/// Generates WGMMATypesAttr from MLIR Type
@@ -1366,9 +1365,9 @@ struct NVGPUWarpgroupMmaOpLowering
int tileShapeA = matrixTypeA.getDimSize(1);
int incrementVal = ((wgmmaK * k) + (totalK * tileShapeA * i)) * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGS() << "\t\t[m: " << i << " n: " << j << " k: " << k
- << "] [wgmma descriptors] Descriptor A + "
- << incrementVal << " | \t ");
+ LDBG() << "\t\t[m: " << i << " n: " << j << " k: " << k
+ << "] [wgmma descriptors] Descriptor A + " << incrementVal
+ << " | \t ";
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1391,7 +1390,7 @@ struct NVGPUWarpgroupMmaOpLowering
int byte = elemB.getIntOrFloatBitWidth() / 8;
int incrementVal = matrixTypeB.getDimSize(0) * wgmmaK * k * byte;
incrementVal = incrementVal >> exclude4LSB;
- LLVM_DEBUG(DBGSE() << "Descriptor B + " << incrementVal << "\n");
+ LDBG() << "Descriptor B + " << incrementVal;
if (!incrementVal)
return desc;
return makeAdd(desc, makeI64Const(b, incrementVal));
@@ -1400,15 +1399,14 @@ struct NVGPUWarpgroupMmaOpLowering
/// This function generates a WgmmaMmaAsyncOp using provided GMMA matrix
/// descriptors and arranges them based on induction variables: i, j, and k.
Value generateWgmma(int i, int j, int k, Value matrixC) {
- LLVM_DEBUG(DBGS() << "\t wgmma."
- << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK
- << "(A[" << (iterationM * wgmmaM) << ":"
- << (iterationM * wgmmaM) + wgmmaM << "]["
- << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "] * "
- << " B[" << (iterationK * wgmmaK) << ":"
- << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":"
- << wgmmaN << "])\n");
+ LDBG() << "\t wgmma."
+ << "m" << wgmmaM << "n" << wgmmaN << "k" << wgmmaK << "(A["
+ << (iterationM * wgmmaM) << ":" << (iterationM * wgmmaM) + wgmmaM
+ << "][" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "] * "
+ << " B[" << (iterationK * wgmmaK) << ":"
+ << (iterationK * wgmmaK + wgmmaK) << "][" << 0 << ":" << wgmmaN
+ << "])";
Value descriptorA = iterateDescriptorA(adaptor.getDescriptorA(), i, j, k);
Value descriptorB = iterateDescriptorB(adaptor.getDescriptorB(), i, j, k);
@@ -1467,9 +1465,9 @@ struct NVGPUWarpgroupMmaOpLowering
totalM = op.getDescriptorA().getType().getTensor().getDimSize(0);
totalN = op.getDescriptorB().getType().getTensor().getDimSize(1);
totalK = op.getDescriptorA().getType().getTensor().getDimSize(1);
- LLVM_DEBUG(DBGS() << "===--- GEMM D[" << totalM << "][" << totalN
- << "] += A[" << totalM << "][" << totalK << "] * B["
- << totalK << "][" << totalN << "] ---===\n");
+ LDBG() << "===--- GEMM D[" << totalM << "][" << totalN << "] += A["
+ << totalM << "][" << totalK << "] * B[" << totalK << "][" << totalN
+ << "] ---===";
// Find the shape for one wgmma instruction
findWgmmaShape(
diff --git a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
index 662ee9e..91788f9 100644
--- a/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
+++ b/mlir/lib/Conversion/NVVMToLLVM/NVVMToLLVM.cpp
@@ -25,11 +25,10 @@
#include "mlir/IR/Value.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "nvvm-to-llvm"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTNVVMTOLLVMPASS
@@ -52,17 +51,17 @@ struct PtxLowering
LogicalResult matchAndRewrite(BasicPtxBuilderInterface op,
PatternRewriter &rewriter) const override {
if (op.hasIntrinsic()) {
- LLVM_DEBUG(DBGS() << "Ptx Builder does not lower \n\t" << op << "\n");
+ LDBG() << "Ptx Builder does not lower \n\t" << op;
return failure();
}
SmallVector<std::pair<Value, PTXRegisterMod>> asmValues;
- LLVM_DEBUG(DBGS() << op.getPtx() << "\n");
+ LDBG() << op.getPtx();
PtxBuilder generator(op, rewriter);
op.getAsmValues(rewriter, asmValues);
for (auto &[asmValue, modifier] : asmValues) {
- LLVM_DEBUG(DBGSNL() << asmValue << "\t Modifier : " << &modifier);
+ LDBG() << asmValue << "\t Modifier : " << &modifier;
generator.insertValue(asmValue, modifier);
}
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 807be7e..ba448e4 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -312,6 +312,19 @@ struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
} // namespace
+static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
+ // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
+ // llvm.loop_annotation attribute.
+ // LLVM requires the loop metadata to be attached on the "latch" block, which
+ // is the back-edge to the header block (conditionBlock).
+ SmallVector<NamedAttribute> llvmAttrs;
+ llvm::copy_if(scfOp->getAttrs(), std::back_inserter(llvmAttrs),
+ [](auto attr) {
+ return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
+ });
+ brOp->setDiscardableAttrs(llvmAttrs);
+}
+
LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
PatternRewriter &rewriter) const {
Location loc = forOp.getLoc();
@@ -350,17 +363,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
auto branchOp =
cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried);
- // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
- // llvm.loop_annotation attribute.
- // LLVM requires the loop metadata to be attached on the "latch" block. Which
- // is the back-edge to the header block (conditionBlock)
- SmallVector<NamedAttribute> llvmAttrs;
- llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs),
- [](auto attr) {
- return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
- });
- branchOp->setDiscardableAttrs(llvmAttrs);
-
+ propagateLoopAttrs(forOp, branchOp);
rewriter.eraseOp(terminator);
// Compute loop bounds before branching to the condition.
@@ -589,9 +592,10 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
rewriter.setInsertionPointToEnd(after);
auto yieldOp = cast<scf::YieldOp>(after->getTerminator());
- rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
- yieldOp.getResults());
+ auto latch = rewriter.replaceOpWithNewOp<cf::BranchOp>(yieldOp, before,
+ yieldOp.getResults());
+ propagateLoopAttrs(whileOp, latch);
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
rewriter.replaceOp(whileOp, args);
@@ -631,10 +635,11 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
// Loop around the "before" region based on condition.
rewriter.setInsertionPointToEnd(before);
auto condOp = cast<ConditionOp>(before->getTerminator());
- cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(),
- before, condOp.getArgs(), continuation,
- ValueRange());
+ auto latch = cf::CondBranchOp::create(
+ rewriter, condOp.getLoc(), condOp.getCondition(), before,
+ condOp.getArgs(), continuation, ValueRange());
+ propagateLoopAttrs(whileOp, latch);
// Replace the op with values "yielded" from the "before" region, which are
// visible by dominance.
rewriter.replaceOp(whileOp, condOp.getArgs());
diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
index fd40e7c..fa9e544 100644
--- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
+++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp
@@ -36,7 +36,6 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#define DEBUG_TYPE "shard-to-mpi"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
namespace mlir {
#define GEN_PASS_DEF_CONVERTSHARDTOMPIPASS
diff --git a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
index f07386e..8cd650e 100644
--- a/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp
@@ -41,6 +41,7 @@ class ConvertTensorToSPIRVPass
SPIRVConversionOptions options;
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
+ options.emulateUnsupportedFloatTypes = this->emulateUnsupportedFloatTypes;
SPIRVTypeConverter typeConverter(targetAttr, options);
RewritePatternSet patterns(context);
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index a425eff..1d1904f 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -31,10 +31,9 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "vector-to-gpu"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
-#define DBGSNL() (llvm::dbgs() << "\n")
namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOGPU
@@ -366,7 +365,7 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
// by all operations.
if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
if (!supportsMMaMatrixType(op, useNvGpu)) {
- LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
+ LDBG() << "cannot convert op: " << *op;
return true;
}
return false;
@@ -548,7 +547,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
@@ -583,7 +582,7 @@ convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
isTranspose ? rewriter.getUnitAttr() : UnitAttr());
valueMapping[mappingResult] = load;
- LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
+ LDBG() << "transfer read to: " << load;
return success();
}
@@ -597,13 +596,13 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
std::optional<int64_t> stride =
getStaticallyKnownRowStride(op.getShapedType());
if (!stride.has_value()) {
- LLVM_DEBUG(DBGS() << "no stride\n");
+ LDBG() << "no stride";
return rewriter.notifyMatchFailure(op, "no stride");
}
auto it = valueMapping.find(op.getVector());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no mapping\n");
+ LDBG() << "no mapping";
return rewriter.notifyMatchFailure(op, "no mapping");
}
@@ -613,9 +612,9 @@ convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
(void)store;
- LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
+ LDBG() << "transfer write to: " << store;
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -641,21 +640,21 @@ convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
if (!dense) {
- LLVM_DEBUG(DBGS() << "not a splat\n");
+ LDBG() << "not a splat";
return rewriter.notifyMatchFailure(op, "not a splat");
}
@@ -677,8 +676,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
mlir::AffineMap map = op.getPermutationMap();
if (map.getNumResults() != 2) {
- LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
- "is not a 2d operand\n");
+ LDBG() << "Failed because the result of `vector.transfer_read` "
+ "is not a 2d operand";
return failure();
}
@@ -691,8 +690,8 @@ static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
auto exprN = dyn_cast<AffineDimExpr>(dN);
if (!exprM || !exprN) {
- LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
- "expressions, then transpose cannot be determined.\n");
+ LDBG() << "Failed because expressions are not affine dim "
+ "expressions, then transpose cannot be determined.";
return failure();
}
@@ -709,20 +708,20 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
nvgpu::getWarpMatrixInfo(op);
if (failed(warpMatrixInfo)) {
- LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
+ LDBG() << "no warpMatrixInfo";
return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
}
FailureOr<nvgpu::FragmentElementInfo> regInfo =
nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
if (failed(regInfo)) {
- LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
+ LDBG() << "not mma sync reg info";
return rewriter.notifyMatchFailure(op, "not mma sync reg info");
}
FailureOr<bool> transpose = isTransposed(op);
if (failed(transpose)) {
- LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
+ LDBG() << "failed to determine the transpose";
return rewriter.notifyMatchFailure(
op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
}
@@ -731,10 +730,8 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
if (failed(params)) {
- LLVM_DEBUG(
- DBGS()
- << "failed to convert vector.transfer_read to ldmatrix. "
- << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
+ LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
+ << "Op should likely not be converted to a nvgpu.ldmatrix call.";
return rewriter.notifyMatchFailure(
op, "failed to convert vector.transfer_read to ldmatrix; this op "
"likely should not be converted to a nvgpu.ldmatrix call.");
@@ -745,7 +742,7 @@ creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
FailureOr<AffineMap> offsets =
nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
if (failed(offsets)) {
- LLVM_DEBUG(DBGS() << "no offsets\n");
+ LDBG() << "no offsets";
return rewriter.notifyMatchFailure(op, "no offsets");
}
@@ -934,7 +931,7 @@ convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
}
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1132,9 +1129,9 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
loop.getNumResults())))
rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
- LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
- LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
- LLVM_DEBUG(DBGS() << "erase: " << loop);
+ LDBG() << "newLoop now: " << newLoop;
+ LDBG() << "stripped scf.for: " << loop;
+ LDBG() << "erase: " << loop;
rewriter.eraseOp(loop);
return newLoop;
@@ -1150,7 +1147,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
auto it = valueMapping.find(operand.value());
if (it == valueMapping.end()) {
- LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
+ LDBG() << "no value mapping for: " << operand.value();
continue;
}
argMapping.push_back(std::make_pair(
@@ -1168,7 +1165,7 @@ static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
}
- LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
+ LDBG() << "scf.for to: " << newForOp;
return success();
}
@@ -1191,7 +1188,7 @@ convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
}
scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
- LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
+ LDBG() << "erase: " << op;
rewriter.eraseOp(op);
return success();
}
@@ -1244,7 +1241,7 @@ LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
auto globalRes = LogicalResult::success();
for (Operation *op : ops) {
- LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
+ LDBG() << "Process op: " << *op;
// Apparently callers do not want to early exit on failure here.
auto res = LogicalResult::success();
if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 4307bc6..17a79e3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1070,39 +1070,6 @@ public:
}
};
-class VectorExtractElementOpConversion
- : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
-public:
- using ConvertOpToLLVMPattern<
- vector::ExtractElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = extractEltOp.getSourceVectorType();
- auto llvmType = typeConverter->convertType(vectorType.getElementType());
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = extractEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
@@ -1206,39 +1173,6 @@ public:
}
};
-class VectorInsertElementOpConversion
- : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
-public:
- using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = insertEltOp.getDestVectorType();
- auto llvmType = typeConverter->convertType(vectorType);
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = insertEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
@@ -2244,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion, VectorScatterOpConversion>(
converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
- VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorExtractOpConversion, VectorFMAOp1DConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index b1af5f0..508f4e2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -690,7 +690,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
-/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
@@ -1643,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
-/// %t = vector.extractelement %vec[i] : vector<9xf32>
+/// %t = vector.extract %vec[i] : f32 from vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 986eae3..a4be7d4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -335,63 +335,6 @@ struct VectorInsertOpConvert final
}
};
-struct VectorExtractElementOpConvert final
- : public OpConversionPattern<vector::ExtractElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type resultType = getTypeConverter()->convertType(extractOp.getType());
- if (!resultType)
- return failure();
-
- if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, resultType, adaptor.getVector(),
- rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
- else
- rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
-struct VectorInsertElementOpConvert final
- : public OpConversionPattern<vector::InsertElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type vectorType = getTypeConverter()->convertType(insertOp.getType());
- if (!vectorType)
- return failure();
-
- if (isa<spirv::ScalarType>(vectorType)) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(),
- cstPos.getSExtValue());
- else
- rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
struct VectorInsertStridedSliceOpConvert final
: public OpConversionPattern<vector::InsertStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
@@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final
void mlir::populateVectorToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<
- VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
- VectorToElementOpConvert, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorToElementOpConvert, VectorInsertOpConvert,
+ VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 8d7053c..22608a1 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -26,7 +26,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <numeric>
@@ -40,7 +40,6 @@ using llvm::divideFloorSigned;
using llvm::mod;
#define DEBUG_TYPE "affine-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
@@ -1062,12 +1061,9 @@ static LogicalResult replaceAffineMinBoundingBoxExpression(AffineMinOp minOp,
AffineMap *map,
ValueRange dims,
ValueRange syms) {
+ LDBG() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`";
AffineMap affineMinMap = minOp.getAffineMap();
- LLVM_DEBUG({
- DBGS() << "replaceAffineMinBoundingBoxExpression: `" << minOp << "`\n";
- });
-
// Check the value is positive.
for (unsigned i = 0, e = affineMinMap.getNumResults(); i < e; ++i) {
// Compare each expression in the minimum against 0.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d1d1062..aa53f94 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -1,4 +1,5 @@
-//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
+//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries
+//----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,12 +9,13 @@
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
-// implementations for FuncOp, CallOp and ReturnOp.
+// implementations for FuncOp, CallOp and ReturnOp. Although it is named
+// Module Bufferization, it may operate on any SymbolTable.
//
-// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
-// This function analyzes the given module and determines the order of analysis
-// and bufferization: Functions that are called are processed before their
-// respective callers.
+// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
+// ...)`. This function analyzes the given op and determines the order of
+// analysis and bufferization: Functions that are called are processed before
+// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
@@ -309,7 +311,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
- ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
SymbolTableCollection &symbolTables) {
// For each FuncOp, the set of functions called by it (i.e. the union of
@@ -317,26 +319,29 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-
- for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
- // Collect function calls and populate the caller map.
- numberCallOpsContainedInFuncOp[funcOp] = 0;
- WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
- assert(calledFunction && "could not retrieved called func::FuncOp");
- // If the called function does not have any tensors in its signature, then
- // it is not necessary to bufferize the callee before the caller.
- if (!hasTensorSignature(calledFunction))
- return WalkResult::skip();
-
- callerMap[calledFunction].insert(callOp);
- if (calledBy[calledFunction].insert(funcOp).second) {
- numberCallOpsContainedInFuncOp[funcOp]++;
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ // Collect function calls and populate the caller map.
+ numberCallOpsContainedInFuncOp[funcOp] = 0;
+ WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+ func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
+          assert(calledFunction && "could not retrieve called func::FuncOp");
+ // If the called function does not have any tensors in its signature,
+ // then it is not necessary to bufferize the callee before the caller.
+ if (!hasTensorSignature(calledFunction))
+ return WalkResult::skip();
+
+ callerMap[calledFunction].insert(callOp);
+ if (calledBy[calledFunction].insert(funcOp).second) {
+ numberCallOpsContainedInFuncOp[funcOp]++;
+ }
+ return WalkResult::advance();
+ });
+ if (res.wasInterrupted())
+ return failure();
}
- return WalkResult::advance();
- });
- if (res.wasInterrupted())
- return failure();
+ }
}
// Iteratively remove function operations that do not call any of the
@@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}
LogicalResult
-mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
+mlir::bufferization::analyzeModuleOp(Operation *moduleOp,
OneShotAnalysisState &state,
BufferizationStatistics *statistics) {
assert(state.getOptions().bufferizeFunctionBoundaries &&
@@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
}
void mlir::bufferization::removeBufferizationAttributesInModule(
- ModuleOp moduleOp) {
- for (auto op : moduleOp.getOps<func::FuncOp>()) {
- for (BlockArgument bbArg : op.getArguments())
- removeBufferizationAttributes(bbArg);
+ Operation *moduleOp) {
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ for (BlockArgument bbArg : funcOp.getArguments())
+ removeBufferizationAttributes(bbArg);
+ }
+ }
}
}
LogicalResult mlir::bufferization::bufferizeModuleOp(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
- IRRewriter rewriter(moduleOp.getContext());
+ IRRewriter rewriter(moduleOp->getContext());
// A list of non-circular functions in the order in which they are analyzed
// and bufferized.
@@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
// Bufferize all other ops.
- for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
- // Functions were already bufferized.
- if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
- continue;
- if (failed(bufferizeOp(&op, options, state, statistics)))
- return failure();
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (mlir::Operation &op :
+ llvm::make_early_inc_range(block.getOperations())) {
+ // Functions were already bufferized.
+ if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
+ continue;
+ if (failed(bufferizeOp(&op, options, state, statistics)))
+ return failure();
+ }
+ }
}
// Post-pass cleanup of function argument attributes.
@@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index f999c93..a6159ee 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -33,7 +33,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
// analysis depending on whether function boundary bufferization is enabled or
// not.
if (options.bufferizeFunctionBoundaries) {
- if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
+ if (failed(analyzeModuleOp(op, analysisState, statistics)))
return failure();
} else {
if (failed(analyzeOp(op, analysisState, statistics)))
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 4c09022..e6a3154 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -1398,6 +1398,45 @@ void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
//===----------------------------------------------------------------------===//
// FieldOp
//===----------------------------------------------------------------------===//
+static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op,
+ TypeAttr type,
+ Attribute initialValue) {
+ p << type;
+ if (initialValue) {
+ p << " = ";
+ p.printAttributeWithoutType(initialValue);
+ }
+}
+
+static Type getInitializerTypeForField(Type type) {
+ if (auto array = llvm::dyn_cast<ArrayType>(type))
+ return RankedTensorType::get(array.getShape(), array.getElementType());
+ return type;
+}
+
+static ParseResult
+parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
+ Attribute &initialValue) {
+ Type type;
+ if (parser.parseType(type))
+ return failure();
+
+ typeAttr = TypeAttr::get(type);
+
+ if (parser.parseOptionalEqual())
+ return success();
+
+ if (parser.parseAttribute(initialValue, getInitializerTypeForField(type)))
+ return failure();
+
+ if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
+ initialValue))
+ return parser.emitError(parser.getNameLoc())
+        << "initial value should be an integer, float, elements or opaque "
+ "attribute";
+ return success();
+}
+
LogicalResult FieldOp::verify() {
if (!isSupportedEmitCType(getType()))
return emitOpError("expected valid emitc type");
@@ -1410,9 +1449,6 @@ LogicalResult FieldOp::verify() {
if (!symName || symName.getValue().empty())
return emitOpError("field must have a non-empty symbol name");
- if (!getAttrs())
- return success();
-
return success();
}
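// A worked illustration of what the custom type/initializer directive above
// round-trips; the surrounding field syntax is abridged, so treat this as an
// example of the directive only, not the op's exact assembly form:
//   printed:  !emitc.array<2xi32> = dense<0>
//   parsed:   getInitializerTypeForField maps !emitc.array<2xi32> to
//             tensor<2xi32>, so the initializer is parsed as a tensor<2xi32>
//             elements attribute; only integer, float, elements, or opaque
//             attributes are accepted.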
diff --git a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
index fa05ad8..c55e26e 100644
--- a/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
+++ b/mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp
@@ -58,17 +58,18 @@ public:
auto argAttrs = funcOp.getArgAttrs();
for (auto [idx, val] : llvm::enumerate(funcOp.getArguments())) {
- StringAttr fieldName;
- Attribute argAttr = nullptr;
-
- fieldName = rewriter.getStringAttr("fieldName" + std::to_string(idx));
- if (argAttrs && idx < argAttrs->size())
- argAttr = (*argAttrs)[idx];
+ StringAttr fieldName =
+ rewriter.getStringAttr("fieldName" + std::to_string(idx));
TypeAttr typeAttr = TypeAttr::get(val.getType());
fields.push_back({fieldName, typeAttr});
- emitc::FieldOp::create(rewriter, funcOp.getLoc(), fieldName, typeAttr,
- argAttr);
+
+ FieldOp fieldop = rewriter.create<emitc::FieldOp>(
+ funcOp->getLoc(), fieldName, typeAttr, nullptr);
+
+ if (argAttrs && idx < argAttrs->size()) {
+ fieldop->setDiscardableAttrs(funcOp.getArgAttrDict(idx));
+ }
}
rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index d186a48..5a72ef1 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -1395,40 +1395,12 @@ void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value value,
// RotateOp
//===----------------------------------------------------------------------===//
-void RotateOp::build(OpBuilder &builder, OperationState &result, Value value,
- int32_t offset, int32_t width) {
- build(builder, result, value,
- arith::ConstantOp::create(builder, result.location,
- builder.getI32IntegerAttr(offset)),
- arith::ConstantOp::create(builder, result.location,
- builder.getI32IntegerAttr(width)));
-}
-
LogicalResult RotateOp::verify() {
- auto offsetConstOp = getOffset().getDefiningOp<arith::ConstantOp>();
- if (!offsetConstOp)
- return emitOpError() << "offset is not a constant value";
-
- auto offsetIntAttr =
- llvm::dyn_cast<mlir::IntegerAttr>(offsetConstOp.getValue());
-
- auto widthConstOp = getWidth().getDefiningOp<arith::ConstantOp>();
- if (!widthConstOp)
- return emitOpError() << "width is not a constant value";
-
- auto widthIntAttr =
- llvm::dyn_cast<mlir::IntegerAttr>(widthConstOp.getValue());
-
- llvm::APInt offsetValue = offsetIntAttr.getValue();
- llvm::APInt widthValue = widthIntAttr.getValue();
-
- if (!widthValue.isPowerOf2())
- return emitOpError() << "width must be a power of two";
+ uint32_t offset = getOffset();
+ uint32_t width = getWidth();
- if (offsetValue.sge(widthValue) || offsetValue.slt(0)) {
- int64_t widthValueInt = widthValue.getSExtValue();
- return emitOpError() << "offset must be in the range [0, " << widthValueInt
- << ")";
+ if (offset >= width) {
+ return emitOpError() << "offset must be in the range [0, " << width << ")";
}
return success();
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index cffe310..e0977f5 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -30,6 +30,7 @@
#include "mlir/IR/Types.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
@@ -812,15 +813,26 @@ LogicalResult NVVM::LdMatrixOp::verify() {
}
LogicalResult NVVM::StMatrixOp::verify() {
- unsigned addressSpace =
- llvm::cast<LLVM::LLVMPointerType>(getPtr().getType()).getAddressSpace();
- if (addressSpace != NVVM::kSharedMemorySpace)
- return emitOpError("expected source pointer in memory space 3");
-
int numMatrix = getSources().size();
if (numMatrix != 1 && numMatrix != 2 && numMatrix != 4)
return emitOpError("expected num attribute to be 1, 2 or 4");
+ int m = getShape().getM(), n = getShape().getN();
+ if (m == 8 && n == 8) {
+ if (getEltType() != NVVM::LdStMatrixEltType::B16) {
+ return emitOpError("expected element type to be B16 for 8x8 matrix");
+ }
+ } else if (m == 16 && n == 8) {
+ if (getEltType() != NVVM::LdStMatrixEltType::B8) {
+ return emitOpError("expected element type to be B8 for 16x8 matrix");
+ }
+ if (getLayout() != NVVM::MMALayout::col) {
+ return emitOpError("expected layout to be col for 16x8 matrix");
+ }
+ } else {
+ return emitOpError("expected shape to be 8x8 or 16x8");
+ }
+
return success();
}
diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
index 935aa3c..b951df8 100644
--- a/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp
@@ -22,6 +22,8 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
+
#define DEBUG_TYPE "llvm-inliner"
using namespace mlir;
@@ -670,44 +672,42 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
bool wouldBeCloned) const final {
auto callOp = dyn_cast<LLVM::CallOp>(call);
if (!callOp) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is not an '"
- << LLVM::CallOp::getOperationName() << "' op\n");
+ LDBG() << "Cannot inline: call is not an '"
+ << LLVM::CallOp::getOperationName() << "' op";
return false;
}
if (callOp.getNoInline()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: call is marked no_inline\n");
+ LDBG() << "Cannot inline: call is marked no_inline";
return false;
}
auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(callable);
if (!funcOp) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline: callable is not an '"
- << LLVM::LLVMFuncOp::getOperationName() << "' op\n");
+ LDBG() << "Cannot inline: callable is not an '"
+ << LLVM::LLVMFuncOp::getOperationName() << "' op";
return false;
}
if (funcOp.isNoInline()) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline: function is marked no_inline\n");
+ LDBG() << "Cannot inline: function is marked no_inline";
return false;
}
if (funcOp.isVarArg()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline: callable is variadic\n");
+ LDBG() << "Cannot inline: callable is variadic";
return false;
}
// TODO: Generate aliasing metadata from noalias result attributes.
if (auto attrs = funcOp.getArgAttrs()) {
for (DictionaryAttr attrDict : attrs->getAsRange<DictionaryAttr>()) {
if (attrDict.contains(LLVM::LLVMDialect::getInAllocaAttrName())) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
- << ": inalloca arguments not supported\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": inalloca arguments not supported";
return false;
}
}
}
// TODO: Handle exceptions.
if (funcOp.getPersonality()) {
- LLVM_DEBUG(llvm::dbgs() << "Cannot inline " << funcOp.getSymName()
- << ": unhandled function personality\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": unhandled function personality";
return false;
}
if (funcOp.getPassthrough()) {
@@ -717,10 +717,8 @@ struct LLVMInlinerInterface : public DialectInlinerInterface {
if (!stringAttr)
return false;
if (disallowedFunctionAttrs.contains(stringAttr)) {
- LLVM_DEBUG(llvm::dbgs()
- << "Cannot inline " << funcOp.getSymName()
- << ": found disallowed function attribute "
- << stringAttr << "\n");
+ LDBG() << "Cannot inline " << funcOp.getSymName()
+ << ": found disallowed function attribute " << stringAttr;
return true;
}
return false;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index f49d9a1..73ae029 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -476,10 +476,10 @@ inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
SmallVector<unsigned, 2>(ac.begin(), ac.end()),
SmallVector<unsigned, 2>(bc.begin(), bc.end()),
SmallVector<unsigned, 2>(ra.begin(), ra.end())};
- llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
- llvm::sort(dimensions.m.begin(), dimensions.m.end());
- llvm::sort(dimensions.n.begin(), dimensions.n.end());
- llvm::sort(dimensions.k.begin(), dimensions.k.end());
+ llvm::sort(dimensions.batch);
+ llvm::sort(dimensions.m);
+ llvm::sort(dimensions.n);
+ llvm::sort(dimensions.k);
return dimensions;
}
@@ -797,12 +797,12 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
SmallVector<unsigned, 2>(depth.begin(), depth.end()),
/*strides=*/SmallVector<int64_t, 2>{},
/*dilations=*/SmallVector<int64_t, 2>{}};
- llvm::sort(dimensions.batch.begin(), dimensions.batch.end());
- llvm::sort(dimensions.outputImage.begin(), dimensions.outputImage.end());
- llvm::sort(dimensions.outputChannel.begin(), dimensions.outputChannel.end());
- llvm::sort(dimensions.filterLoop.begin(), dimensions.filterLoop.end());
- llvm::sort(dimensions.inputChannel.begin(), dimensions.inputChannel.end());
- llvm::sort(dimensions.depth.begin(), dimensions.depth.end());
+ llvm::sort(dimensions.batch);
+ llvm::sort(dimensions.outputImage);
+ llvm::sort(dimensions.outputChannel);
+ llvm::sort(dimensions.filterLoop);
+ llvm::sort(dimensions.inputChannel);
+ llvm::sort(dimensions.depth);
// Use the op carried strides/dilations attribute if present.
auto nativeStrides = linalgOp->getAttrOfType<DenseIntElementsAttr>("strides");
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27b6617..34c63d3 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -32,6 +32,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -2292,9 +2293,39 @@ Speculation::Speculatability BroadcastOp::getSpeculatability() {
return getGenericSpeculatabilityImpl(cast<LinalgOp>(getOperation()));
}
+/// Fold back-to-back broadcasts together.
+struct FoldBroadcasts : OpRewritePattern<linalg::BroadcastOp> {
+ using OpRewritePattern<linalg::BroadcastOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(linalg::BroadcastOp broadcastOp,
+ PatternRewriter &rewriter) const override {
+ auto defBroadcastOp = broadcastOp.getInput().getDefiningOp<BroadcastOp>();
+ if (!defBroadcastOp)
+ return failure();
+ ArrayRef<int64_t> defDimensions = defBroadcastOp.getDimensions();
+ ArrayRef<int64_t> dimensions = broadcastOp.getDimensions();
+ SmallVector<int64_t> foldedDims(dimensions);
+ Value init = broadcastOp.getInit();
+ int64_t initRank = cast<ShapedType>(init.getType()).getRank();
+ // Mapping from input dims to init dims.
+ SmallVector<int64_t> dimMap;
+ for (auto dim : llvm::seq<int64_t>(0, initRank)) {
+ if (!llvm::is_contained(dimensions, dim))
+ dimMap.push_back(dim);
+ }
+ for (auto dim : defDimensions)
+ foldedDims.push_back(dimMap[dim]);
+
+ llvm::sort(foldedDims);
+ rewriter.replaceOpWithNewOp<BroadcastOp>(
+ broadcastOp, defBroadcastOp.getInput(), init, foldedDims);
+ return success();
+ }
+};
+
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
+ results.add<EraseIdentityLinalgOp<BroadcastOp>, FoldBroadcasts>(context);
}
//===----------------------------------------------------------------------===//
@@ -4622,22 +4653,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
});
}
-/// Returns true if the dimension of `sourceShape` is smaller than the dimension
-/// of the `limitShape`.
-static bool areAllInBound(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> limitShape) {
- assert(
- sourceShape.size() == limitShape.size() &&
- "expected source shape rank, and limit of the shape to have same rank");
- return llvm::all_of(
- llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
- int64_t sourceExtent = std::get<0>(it);
- int64_t limit = std::get<1>(it);
- return ShapedType::isDynamic(sourceExtent) ||
- ShapedType::isDynamic(limit) || sourceExtent <= limit;
- });
-}
-
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4696,11 +4711,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
- if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
- return op->emitError("the shape of output is not large enough to hold the "
- "packed data. Expected at least ")
- << expectedPackedType << ", got " << packedType;
- }
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
@@ -4717,6 +4727,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
+ if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
+ packedType.getShape()))) {
+ return op->emitError("expected ")
+ << expectedPackedType << " for the packed domain value, got "
+ << packedType;
+ }
return success();
}
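// A stand-alone sketch of the dimension remapping that FoldBroadcasts above
// performs, expressed on plain integers; the helper name and the concrete
// dims in the usage comment are made up for illustration:
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
static llvm::SmallVector<int64_t>
foldBroadcastDims(llvm::ArrayRef<int64_t> outerDims,
                  llvm::ArrayRef<int64_t> innerDims, int64_t initRank) {
  llvm::SmallVector<int64_t> folded(outerDims.begin(), outerDims.end());
  // Init dims not introduced by the outer broadcast correspond 1:1 to the
  // dims of the outer broadcast's input.
  llvm::SmallVector<int64_t> dimMap;
  for (int64_t d = 0; d < initRank; ++d)
    if (!llvm::is_contained(outerDims, d))
      dimMap.push_back(d);
  for (int64_t d : innerDims)
    folded.push_back(dimMap[d]);
  llvm::sort(folded);
  return folded;
}
// foldBroadcastDims(/*outerDims=*/{1}, /*innerDims=*/{0}, /*initRank=*/3)
// returns {0, 1}: a broadcast along dim 0 followed by one along dim 1 folds
// into a single broadcast along dims [0, 1].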
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 7f9ba1b..bf66ed0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -637,6 +637,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
}
ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
+ ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
int64_t padRank = sourceShape.size();
auto isStaticZero = [](OpFoldResult f) {
@@ -647,16 +648,18 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
allowedUnitDims.end());
llvm::SmallDenseSet<unsigned> unitDims;
SmallVector<int64_t> newShape;
+ SmallVector<int64_t> newResultShape;
SmallVector<OpFoldResult> newLowPad;
SmallVector<OpFoldResult> newHighPad;
- for (const auto [dim, size, low, high] :
- zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
- padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
+ for (const auto [dim, size, outSize, low, high] : zip_equal(
+ llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
+ resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
isStaticZero(high)) {
unitDims.insert(dim);
} else {
newShape.push_back(size);
+ newResultShape.push_back(outSize);
newLowPad.push_back(low);
newHighPad.push_back(high);
}
@@ -686,8 +689,10 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
reassociationMap, options.rankReductionStrategy);
- auto newPadOp = tensor::PadOp::create(
- rewriter, padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
+ auto newResultType = RankedTensorType::get(
+ newResultShape, padOp.getResultType().getElementType());
+ auto newPadOp = rewriter.create<tensor::PadOp>(
+ padOp.getLoc(), /*result=*/newResultType, collapsedSource, newLowPad,
newHighPad, paddingVal, padOp.getNofold());
Value dest = padOp.getResult();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 277e50b..9d7f4e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 2c62cb6..2e62523 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -55,6 +55,28 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
return paddingSizes;
}
+/// Extracts the constant multiplier from an affine expression of the form
+/// `d * c` or `c * d`, where `d` is an AffineDimExpr and `c` is an
+/// AffineConstantExpr. Returns 1 if the expression is not a simple
+/// multiplication of a dimension and a constant.
+static int64_t extractConstantMultiplier(AffineExpr expr) {
+ if (auto binOp = dyn_cast<AffineBinaryOpExpr>(expr)) {
+ if (binOp.getKind() == AffineExprKind::Mul) {
+ auto lhsD = dyn_cast<AffineDimExpr>(binOp.getLHS());
+ auto rhsC = dyn_cast<AffineConstantExpr>(binOp.getRHS());
+ if (lhsD && rhsC) {
+ return rhsC.getValue();
+ }
+ auto lhsC = dyn_cast<AffineConstantExpr>(binOp.getLHS());
+ auto rhsD = dyn_cast<AffineDimExpr>(binOp.getRHS());
+ if (lhsC && rhsD) {
+ return lhsC.getValue();
+ }
+ }
+ }
+ return 1;
+}
+
/// Compute the padded shape of the given value `v` of `RankedTensorType` given
/// - `indexingSizes` a list of OpFoldResult.
/// - an `indexingMap` that encodes how the shape of varies with increases
@@ -63,6 +85,13 @@ getFullRankPaddingSizes(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
/// The implementaiton below iteratively combines increases from contributing
/// dimensions using affine.apply operations.
+/// The padded shape is computed by evaluating the maximum accessed index per
+/// dimension, which may involve multiplying by constant factors derived from
+/// the affine indexing expressions. Currently, only a limited set of projected
+/// permutation indexing maps are supported, such as
+/// - affine_map<(d0, d1, d2) -> (d0, d1)>
+/// - affine_map<(d0, d1, d2) -> (d0, d1 + d2)>
+/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult> linalg::computePaddedShape(
@@ -114,24 +143,33 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
/*compressDims=*/true);
// If we are padding to the next multiple of, compose with ceil(sz) * sz.
+ OpFoldResult paddingDimOfr;
if (options.padToMultipleOf) {
AffineExpr d0, s0;
bindDims(rewriter.getContext(), d0);
bindSymbols(rewriter.getContext(), s0);
AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
AffineMap composedMap = projectedMap.compose(ceilMap);
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, composedMap,
{indexingSizes[paddingDim], paddingSize},
/*composeAffineMin=*/true);
- terms.push_back(paddingDimOfr);
} else {
// Otherwise just set to paddingSize.
- OpFoldResult paddingDimOfr = affine::makeComposedFoldedAffineApply(
+ paddingDimOfr = affine::makeComposedFoldedAffineApply(
rewriter, loc, projectedMap, paddingSize);
- terms.push_back(paddingDimOfr);
}
+ // Adjust for the maximum accessed index, which is (paddingSize - 1) *
+ // multiplier.
+ AffineExpr d0;
+ bindDims(rewriter.getContext(), d0);
+ int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
+ AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
+ OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, subtractMap, {paddingDimOfr});
+ terms.push_back(maxAccessIdx);
+
LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
}
@@ -148,8 +186,9 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
AffineExpr sumExpr = dims.front();
for (unsigned i = 1; i < dims.size(); ++i)
sumExpr = sumExpr + dims[i];
- OpFoldResult paddedDimOfr =
- affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
+ // Add 1 to the maximum accessed index and get the final padded size.
+ OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, sumExpr + 1, terms);
paddedShape[resultIndex] = paddedDimOfr;
}
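// A worked numeric instance of the max-accessed-index computation described
// above; the sizes are made up. For the indexing expression d0 * 3 + d1,
// assume d0 is padded to 6 and d1 to 4:
//   d0 contributes (6 - 1) * 3 = 15   (constant multiplier 3)
//   d1 contributes (4 - 1) * 1 =  3   (constant multiplier 1)
//   padded result dim = 15 + 3 + 1 = 19   (sum of max indices, plus 1)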
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index dad3526..57b610b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -932,20 +932,6 @@ struct PackOpTiling
continue;
}
- // If the dimension needs padding, it is not supported because there are
- // iterations that only write padding values to the whole tile. The
- // consumer fusion is driven by the source, so it is not possible to map
- // an empty slice to the tile.
- bool needExtraPadding =
- ShapedType::isDynamic(destDimSize) || !cstInnerSize ||
- destDimSize * cstInnerSize.value() != srcDimSize;
- // Prioritize the case that the op already says that it does not need
- // padding.
- if (!packOp.getPaddingValue())
- needExtraPadding = false;
- if (needExtraPadding)
- return failure();
-
// Currently fusing `packOp` as consumer only expects perfect tiling
// scenario because even if without padding semantic, the `packOp` may
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 0170837..0860cea 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1831,6 +1831,53 @@ vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
return success();
}
+/// Given the re-associations, "collapses" the input Vector type
+///
+/// This is similar to CollapseShapeOp::inferCollapsedType with two notable
+/// differences:
+/// * We can safely assume that there are no dynamic sizes.
+/// * Scalable flags are updated alongside regular dims.
+///
+/// When collapsing scalable flags, this conservatively rejects inputs with
+/// more than one scalable dim. We could revisit this in the future.
+///
+/// EXAMPLE:
+/// type = vector<4x16x[8]x16xf32>
+/// reassociation = [(d0, d1, d2, d3) -> (d0, d1),
+/// (d0, d1, d2, d3) -> (d2, d3)]
+/// Result:
+/// vector<64x[128]xf32>
+static VectorType getCollapsedVecType(VectorType type,
+ ArrayRef<AffineMap> reassociation) {
+ assert(type.getNumScalableDims() < 2 &&
+ "Collapsing more than 1 scalable dim is not supported ATM");
+
+ // Use the fact that reassociation is valid to simplify the logic: only use
+ // each map's rank.
+ assert(isReassociationValid(reassociation) && "invalid reassociation");
+
+ auto shape = type.getShape();
+ auto scalableFlags = type.getScalableDims();
+ SmallVector<int64_t> newShape;
+ SmallVector<bool> newScalableFlags;
+
+ unsigned currentDim = 0;
+ for (AffineMap m : reassociation) {
+ unsigned dim = m.getNumResults();
+ int64_t size = 1;
+ bool flag = false;
+ for (unsigned d = 0; d < dim; ++d) {
+ size *= shape[currentDim + d];
+ flag |= scalableFlags[currentDim + d];
+ }
+ newShape.push_back(size);
+ newScalableFlags.push_back(flag);
+ currentDim += dim;
+ }
+
+ return VectorType::get(newShape, type.getElementType(), newScalableFlags);
+}
+
/// Vectorize a `linalg::UnPackOp` to these 4 Ops:
/// Vector::TransferReadOp - Reads a vector from the source tensor
/// vector::TransposeOp - Transpose the Source tensor
@@ -1913,14 +1960,6 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
readVectorSizes.append(sourceShape.begin() + vectorSizes.size(),
sourceShape.end());
- ReifiedRankedShapedTypeDims reifiedRetShapes;
- LogicalResult status =
- cast<ReifyRankedShapedTypeOpInterface>(unpackOp.getOperation())
- .reifyResultShapes(rewriter, reifiedRetShapes);
- if (status.failed()) {
- LDBG() << "Unable to reify result shapes of " << unpackOp;
- return failure();
- }
Location loc = unpackOp->getLoc();
auto padValue = arith::ConstantOp::create(
@@ -1936,30 +1975,18 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
PackingMetadata packMetadata;
SmallVector<int64_t> lastDimToInsertPosPerm =
getUnPackInverseSrcPerm(unpackOp, packMetadata);
- ShapedType maskedOpShapedType = cast<ShapedType>(readResult.getType());
- SmallVector<int64_t> stripMineShape(maskedOpShapedType.getShape());
- mlir::Type stripMineElemType = maskedOpShapedType.getElementType();
- applyPermutationToVector(stripMineShape, lastDimToInsertPosPerm);
- RankedTensorType stripMineTensorType =
- RankedTensorType::get(stripMineShape, stripMineElemType);
// Transpose the appropriate rows to match output.
vector::TransposeOp transposeOp = vector::TransposeOp::create(
rewriter, loc, readResult, lastDimToInsertPosPerm);
// Collapse the vector to the size required by result.
- RankedTensorType collapsedType = tensor::CollapseShapeOp::inferCollapsedType(
- stripMineTensorType, packMetadata.reassociations);
- mlir::VectorType vecCollapsedType =
- VectorType::get(collapsedType.getShape(), collapsedType.getElementType());
+ VectorType collapsedVecType = getCollapsedVecType(
+ transposeOp.getType(),
+ getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
+ rewriter.getContext(), packMetadata.reassociations)));
vector::ShapeCastOp shapeCastOp = vector::ShapeCastOp::create(
- rewriter, loc, vecCollapsedType, transposeOp->getResult(0));
-
- // writeVectorSizes had to match the shapecast shape for dynamic sizes,
- // otherwise the validator complains that the mask size is invalid.
- SmallVector<int64_t> writeVectorSizes(
- unpackOp.getDestType().hasStaticShape()
- ? vectorSizes
- : shapeCastOp.getResultVectorType().getShape());
+ rewriter, loc, collapsedVecType, transposeOp->getResult(0));
+
Operation *write = createWriteOrMaskedWrite(
rewriter, loc, shapeCastOp.getResult(), unpackOp.getDest(),
/*writeIndices=*/{}, useInBoundsInsteadOfMasking);
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
index 106c3b4..cce80db 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ComposeSubView.cpp
@@ -80,10 +80,6 @@ struct ComposeSubViewOpPattern : public OpRewritePattern<memref::SubViewOp> {
for (auto &&[opOffset, sourceOffset, sourceStride, opSize] :
llvm::zip(op.getMixedOffsets(), sourceOp.getMixedOffsets(),
sourceOp.getMixedStrides(), op.getMixedSizes())) {
- // We only support static sizes.
- if (isa<Value>(opSize)) {
- return failure();
- }
sizes.push_back(opSize);
Attribute opOffsetAttr = llvm::dyn_cast_if_present<Attribute>(opOffset),
sourceOffsetAttr =
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index e73bdd3..485bb73 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -1375,6 +1375,21 @@ void acc::ParallelOp::addWaitOperands(
setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
+void acc::ParallelOp::addPrivatization(MLIRContext *context,
+ mlir::acc::PrivateOp op,
+ mlir::acc::PrivateRecipeOp recipe) {
+ getPrivateOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getPrivatizationRecipesAttr())
+ llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
static ParseResult parseNumGangs(
mlir::OpAsmParser &parser,
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
@@ -2011,6 +2026,21 @@ void acc::SerialOp::addWaitOperands(
setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
+void acc::SerialOp::addPrivatization(MLIRContext *context,
+ mlir::acc::PrivateOp op,
+ mlir::acc::PrivateRecipeOp recipe) {
+ getPrivateOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getPrivatizationRecipesAttr())
+ llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
//===----------------------------------------------------------------------===//
// KernelsOp
//===----------------------------------------------------------------------===//
@@ -2957,6 +2987,23 @@ bool acc::LoopOp::hasDefaultGangWorkerVector() {
getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
}
+acc::LoopParMode
+acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
+ if (hasSeq(deviceType))
+ return LoopParMode::loop_seq;
+ if (hasAuto(deviceType))
+ return LoopParMode::loop_auto;
+ if (hasIndependent(deviceType))
+ return LoopParMode::loop_independent;
+ if (hasSeq())
+ return LoopParMode::loop_seq;
+ if (hasAuto())
+ return LoopParMode::loop_auto;
+ assert(hasIndependent() &&
+ "loop must have default auto, seq, or independent");
+ return LoopParMode::loop_independent;
+}
+
void acc::LoopOp::addGangOperands(
MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
llvm::ArrayRef<GangArgType> argTypes, mlir::ValueRange values) {
@@ -2997,6 +3044,21 @@ void acc::LoopOp::addGangOperands(
}
}
+void acc::LoopOp::addPrivatization(MLIRContext *context,
+ mlir::acc::PrivateOp op,
+ mlir::acc::PrivateRecipeOp recipe) {
+ getPrivateOperandsMutable().append(op.getResult());
+
+ llvm::SmallVector<mlir::Attribute> recipes;
+
+ if (getPrivatizationRecipesAttr())
+ llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
+
+ recipes.push_back(
+ mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
+ setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
+}
+
//===----------------------------------------------------------------------===//
// DataOp
//===----------------------------------------------------------------------===//
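// A minimal usage sketch for the addPrivatization helpers added above; the
// acc.private value and its recipe are assumed to exist already, and the
// function and variable names are hypothetical:
#include "mlir/Dialect/OpenACC/OpenACC.h"
static void attachPrivate(mlir::acc::ParallelOp parallelOp,
                          mlir::acc::PrivateOp privateOp,
                          mlir::acc::PrivateRecipeOp recipeOp) {
  parallelOp.addPrivatization(parallelOp.getContext(), privateOp, recipeOp);
  // acc::SerialOp and acc::LoopOp gained the same helper with an identical
  // signature.
}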
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 759e58b..0262a1b 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -137,6 +137,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
if (parser.parseOptionalArrowTypeList(result.types))
return failure();
+ if (succeeded(parser.parseOptionalKeyword("no_inline")))
+ result.addAttribute("no_inline", parser.getBuilder().getUnitAttr());
+
// Introduce the body region and parse it.
Region *body = result.addRegion();
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
@@ -148,8 +151,9 @@ ParseResult ExecuteRegionOp::parse(OpAsmParser &parser,
void ExecuteRegionOp::print(OpAsmPrinter &p) {
p.printOptionalArrowTypeList(getResultTypes());
-
p << ' ';
+ if (getNoInline())
+ p << "no_inline ";
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
@@ -184,7 +188,7 @@ struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
LogicalResult matchAndRewrite(ExecuteRegionOp op,
PatternRewriter &rewriter) const override {
- if (!op.getRegion().hasOneBlock())
+ if (!op.getRegion().hasOneBlock() || op.getNoInline())
return failure();
replaceOpWithRegion(rewriter, op, op.getRegion());
return success();
diff --git a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
index e27dc27..fcf4eb6 100644
--- a/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/CastOps.cpp
@@ -270,48 +270,6 @@ LogicalResult ConvertUToFOp::verify() {
}
//===----------------------------------------------------------------------===//
-// spirv.INTELConvertBF16ToFOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult INTELConvertBF16ToFOp::verify() {
- auto operandType = getOperand().getType();
- auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type have the same
- // shape.
- if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
- unsigned operandNumElements = vectorType.getNumElements();
- unsigned resultNumElements =
- llvm::cast<VectorType>(resultType).getNumElements();
- if (operandNumElements != resultNumElements) {
- return emitOpError(
- "operand and result must have same number of elements");
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
-// spirv.INTELConvertFToBF16Op
-//===----------------------------------------------------------------------===//
-
-LogicalResult INTELConvertFToBF16Op::verify() {
- auto operandType = getOperand().getType();
- auto resultType = getResult().getType();
- // ODS checks that vector result type and vector operand type have the same
- // shape.
- if (auto vectorType = llvm::dyn_cast<VectorType>(operandType)) {
- unsigned operandNumElements = vectorType.getNumElements();
- unsigned resultNumElements =
- llvm::cast<VectorType>(resultType).getNumElements();
- if (operandNumElements != resultNumElements) {
- return emitOpError(
- "operand and result must have same number of elements");
- }
- }
- return success();
-}
-
-//===----------------------------------------------------------------------===//
// spirv.FConvertOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index 9bee200..fcf1526 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -693,7 +693,9 @@ static ParseResult parseStructMemberDecorations(
// `!spirv.struct<` (id `,`)?
// `(`
// (spirv-type (`[` struct-member-decoration `]`)?)*
-// `)>`
+// `)`
+// (`,` struct-decoration)?
+// `>`
static Type parseStructType(SPIRVDialect const &dialect,
DialectAsmParser &parser) {
// TODO: This function is quite lengthy. Break it down into smaller chunks.
@@ -767,17 +769,48 @@ static Type parseStructType(SPIRVDialect const &dialect,
return Type();
}
- if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
+ if (failed(parser.parseRParen()))
+ return Type();
+
+ SmallVector<StructType::StructDecorationInfo, 1> structDecorationInfo;
+
+ auto parseStructDecoration = [&]() {
+ std::optional<spirv::Decoration> decoration =
+ parseAndVerify<spirv::Decoration>(dialect, parser);
+ if (!decoration)
+ return failure();
+
+ // Parse decoration value if it exists.
+ if (succeeded(parser.parseOptionalEqual())) {
+ Attribute decorationValue;
+ if (failed(parser.parseAttribute(decorationValue)))
+ return failure();
+
+ structDecorationInfo.emplace_back(decoration.value(), decorationValue);
+ } else {
+ structDecorationInfo.emplace_back(decoration.value(),
+ UnitAttr::get(dialect.getContext()));
+ }
+ return success();
+ };
+
+ while (succeeded(parser.parseOptionalComma()))
+ if (failed(parseStructDecoration()))
+ return Type();
+
+ if (failed(parser.parseGreater()))
return Type();
if (!identifier.empty()) {
if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationInfo)))
+ memberDecorationInfo,
+ structDecorationInfo)))
return Type();
return idStructTy;
}
- return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
+ return StructType::get(memberTypes, offsetInfo, memberDecorationInfo,
+ structDecorationInfo);
}
// spirv-type ::= array-type
@@ -893,7 +926,23 @@ static void print(StructType type, DialectAsmPrinter &os) {
};
llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
printMember);
- os << ")>";
+ os << ")";
+
+ SmallVector<spirv::StructType::StructDecorationInfo, 1> decorations;
+ type.getStructDecorations(decorations);
+ if (!decorations.empty()) {
+ os << ", ";
+ auto eachFn = [&os](spirv::StructType::StructDecorationInfo decoration) {
+ os << stringifyDecoration(decoration.decoration);
+ if (decoration.hasValue()) {
+ os << "=";
+ os.printAttributeWithoutType(decoration.decorationValue);
+ }
+ };
+ llvm::interleaveComma(decorations, os, eachFn);
+ }
+
+ os << ">";
}
static void print(CooperativeMatrixType type, DialectAsmPrinter &os) {
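// A minimal construction sketch matching what the extended struct parser and
// printer above round-trip (e.g. `!spirv.struct<(f32 [0]), Block>`); the
// member/offset/member-decoration lists and `ctx` are assumed to have been
// set up as before:
llvm::SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
structDecorations.emplace_back(spirv::Decoration::Block,
                               mlir::UnitAttr::get(ctx));
auto structTy = spirv::StructType::get(memberTypes, offsetInfo,
                                       memberDecorations, structDecorations);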
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 52c672a..f993398 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -767,19 +767,25 @@ void mlir::spirv::AddressOfOp::getAsmResultNames(
// spirv.EXTConstantCompositeReplicate
//===----------------------------------------------------------------------===//
+// Returns the type of the attribute. For a TypedAttr this is simply its type.
+// For an ArrayAttr, which is untyped and may be multidimensional, the
+// spirv::ArrayType is built recursively.
+static Type getValueType(Attribute attr) {
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+ return typedAttr.getType();
+ }
+
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+ return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
+ }
+
+ return nullptr;
+}
+
LogicalResult spirv::EXTConstantCompositeReplicateOp::verify() {
- Type valueType;
- if (auto typedAttr = dyn_cast<TypedAttr>(getValue())) {
- valueType = typedAttr.getType();
- } else if (auto arrayAttr = dyn_cast<ArrayAttr>(getValue())) {
- auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
- if (!typedElemAttr)
- return emitError("value attribute is not typed");
- valueType =
- spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
- } else {
+ Type valueType = getValueType(getValue());
+ if (!valueType)
return emitError("unknown value attribute type");
- }
auto compositeType = dyn_cast<spirv::CompositeType>(getType());
if (!compositeType)
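// A worked instance of the getValueType recursion above; the attribute is
// made up. For a two-level array attribute [[1 : i32, 2 : i32], [3 : i32, 4 : i32]]:
//   the inner ArrayAttr yields !spirv.array<2 x i32> (element type taken from
//   its first, typed element), and the outer ArrayAttr wraps that again,
//   giving !spirv.array<2 x !spirv.array<2 x i32>>.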
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 46739bc..ddb3426 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -835,12 +835,14 @@ void SampledImageType::getCapabilities(
/// - for literal structs:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
///
/// Identified structures only have a mutable component consisting of:
/// - a list of member types;
/// - a list of member offset info;
-/// - a list of member decoration info.
+/// - a list of member decoration info;
+/// - a list of struct decoration info.
struct spirv::detail::StructTypeStorage : public TypeStorage {
/// Construct a storage object for an identified struct type. A struct type
/// associated with such storage must call StructType::trySetBody(...) later
@@ -848,6 +850,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(StringRef identifier)
: memberTypesAndIsBodySet(nullptr, false), offsetInfo(nullptr),
numMembers(0), numMemberDecorations(0), memberDecorationsInfo(nullptr),
+ numStructDecorations(0), structDecorationsInfo(nullptr),
identifier(identifier) {}
/// Construct a storage object for a literal struct type. A struct type
@@ -855,10 +858,14 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
StructTypeStorage(
unsigned numMembers, Type const *memberTypes,
StructType::OffsetInfo const *layoutInfo, unsigned numMemberDecorations,
- StructType::MemberDecorationInfo const *memberDecorationsInfo)
+ StructType::MemberDecorationInfo const *memberDecorationsInfo,
+ unsigned numStructDecorations,
+ StructType::StructDecorationInfo const *structDecorationsInfo)
: memberTypesAndIsBodySet(memberTypes, false), offsetInfo(layoutInfo),
numMembers(numMembers), numMemberDecorations(numMemberDecorations),
- memberDecorationsInfo(memberDecorationsInfo) {}
+ memberDecorationsInfo(memberDecorationsInfo),
+ numStructDecorations(numStructDecorations),
+ structDecorationsInfo(structDecorationsInfo) {}
/// A storage key is divided into 2 parts:
/// - for identified structs:
@@ -867,16 +874,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - an ArrayRef<Type> for member types;
/// - an ArrayRef<StructType::OffsetInfo> for member offset info;
/// - an ArrayRef<StructType::MemberDecorationInfo> for member decoration
+ /// info;
+ /// - an ArrayRef<StructType::StructDecorationInfo> for struct decoration
/// info.
///
/// An identified struct type is uniqued only by the first part (field 0)
/// of the key.
///
- /// A literal struct type is uniqued only by the second part (fields 1, 2, and
- /// 3) of the key. The identifier field (field 0) must be empty.
+ /// A literal struct type is uniqued only by the second part (fields 1, 2, 3
+ /// and 4) of the key. The identifier field (field 0) must be empty.
using KeyTy =
std::tuple<StringRef, ArrayRef<Type>, ArrayRef<StructType::OffsetInfo>,
- ArrayRef<StructType::MemberDecorationInfo>>;
+ ArrayRef<StructType::MemberDecorationInfo>,
+ ArrayRef<StructType::StructDecorationInfo>>;
/// For identified structs, return true if the given key contains the same
/// identifier.
@@ -890,7 +900,7 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
}
return key == KeyTy(StringRef(), getMemberTypes(), getOffsetInfo(),
- getMemberDecorationsInfo());
+ getMemberDecorationsInfo(), getStructDecorationsInfo());
}
/// If the given key contains a non-empty identifier, this method constructs
@@ -937,9 +947,17 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
memberDecorationList = allocator.copyInto(keyMemberDecorations).data();
}
- return new (allocator.allocate<StructTypeStorage>())
- StructTypeStorage(keyTypes.size(), typesList, offsetInfoList,
- numMemberDecorations, memberDecorationList);
+ const StructType::StructDecorationInfo *structDecorationList = nullptr;
+ unsigned numStructDecorations = 0;
+ if (!std::get<4>(key).empty()) {
+ auto keyStructDecorations = std::get<4>(key);
+ numStructDecorations = keyStructDecorations.size();
+ structDecorationList = allocator.copyInto(keyStructDecorations).data();
+ }
+
+ return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(
+ keyTypes.size(), typesList, offsetInfoList, numMemberDecorations,
+ memberDecorationList, numStructDecorations, structDecorationList);
}
ArrayRef<Type> getMemberTypes() const {
@@ -961,6 +979,13 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
return {};
}
+ ArrayRef<StructType::StructDecorationInfo> getStructDecorationsInfo() const {
+ if (structDecorationsInfo)
+ return ArrayRef<StructType::StructDecorationInfo>(structDecorationsInfo,
+ numStructDecorations);
+ return {};
+ }
+
StringRef getIdentifier() const { return identifier; }
bool isIdentified() const { return !identifier.empty(); }
@@ -973,17 +998,19 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
/// - If called for an identified struct whose body was set before (through a
/// call to this method) but with different contents from the passed
/// arguments.
- LogicalResult mutate(
- TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
- ArrayRef<StructType::OffsetInfo> structOffsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo) {
+ LogicalResult
+ mutate(TypeStorageAllocator &allocator, ArrayRef<Type> structMemberTypes,
+ ArrayRef<StructType::OffsetInfo> structOffsetInfo,
+ ArrayRef<StructType::MemberDecorationInfo> structMemberDecorationInfo,
+ ArrayRef<StructType::StructDecorationInfo> structDecorationInfo) {
if (!isIdentified())
return failure();
if (memberTypesAndIsBodySet.getInt() &&
(getMemberTypes() != structMemberTypes ||
getOffsetInfo() != structOffsetInfo ||
- getMemberDecorationsInfo() != structMemberDecorationInfo))
+ getMemberDecorationsInfo() != structMemberDecorationInfo ||
+ getStructDecorationsInfo() != structDecorationInfo))
return failure();
memberTypesAndIsBodySet.setInt(true);
@@ -1007,6 +1034,11 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
allocator.copyInto(structMemberDecorationInfo).data();
}
+ if (!structDecorationInfo.empty()) {
+ numStructDecorations = structDecorationInfo.size();
+ structDecorationsInfo = allocator.copyInto(structDecorationInfo).data();
+ }
+
return success();
}
@@ -1015,21 +1047,30 @@ struct spirv::detail::StructTypeStorage : public TypeStorage {
unsigned numMembers;
unsigned numMemberDecorations;
StructType::MemberDecorationInfo const *memberDecorationsInfo;
+ unsigned numStructDecorations;
+ StructType::StructDecorationInfo const *structDecorationsInfo;
StringRef identifier;
};
StructType
StructType::get(ArrayRef<Type> memberTypes,
ArrayRef<StructType::OffsetInfo> offsetInfo,
- ArrayRef<StructType::MemberDecorationInfo> memberDecorations) {
+ ArrayRef<StructType::MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructType::StructDecorationInfo> structDecorations) {
assert(!memberTypes.empty() && "Struct needs at least one member type");
// Sort the decorations.
- SmallVector<StructType::MemberDecorationInfo, 4> sortedDecorations(
+ SmallVector<StructType::MemberDecorationInfo, 4> sortedMemberDecorations(
memberDecorations);
- llvm::array_pod_sort(sortedDecorations.begin(), sortedDecorations.end());
+ llvm::array_pod_sort(sortedMemberDecorations.begin(),
+ sortedMemberDecorations.end());
+ SmallVector<StructType::StructDecorationInfo, 1> sortedStructDecorations(
+ structDecorations);
+ llvm::array_pod_sort(sortedStructDecorations.begin(),
+ sortedStructDecorations.end());
+
return Base::get(memberTypes.vec().front().getContext(),
/*identifier=*/StringRef(), memberTypes, offsetInfo,
- sortedDecorations);
+ sortedMemberDecorations, sortedStructDecorations);
}
StructType StructType::getIdentified(MLIRContext *context,
@@ -1039,18 +1080,21 @@ StructType StructType::getIdentified(MLIRContext *context,
return Base::get(context, identifier, ArrayRef<Type>(),
ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
}
StructType StructType::getEmpty(MLIRContext *context, StringRef identifier) {
StructType newStructType = Base::get(
context, identifier, ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>());
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>());
// Set an empty body in case this is a identified struct.
if (newStructType.isIdentified() &&
failed(newStructType.trySetBody(
ArrayRef<Type>(), ArrayRef<StructType::OffsetInfo>(),
- ArrayRef<StructType::MemberDecorationInfo>())))
+ ArrayRef<StructType::MemberDecorationInfo>(),
+ ArrayRef<StructType::StructDecorationInfo>())))
return StructType();
return newStructType;
@@ -1074,6 +1118,15 @@ TypeRange StructType::getElementTypes() const {
bool StructType::hasOffset() const { return getImpl()->offsetInfo; }
+bool StructType::hasDecoration(spirv::Decoration decoration) const {
+ for (StructType::StructDecorationInfo info :
+ getImpl()->getStructDecorationsInfo())
+ if (info.decoration == decoration)
+ return true;
+
+ return false;
+}
+
uint64_t StructType::getMemberOffset(unsigned index) const {
assert(getNumElements() > index && "member index out of range");
return getImpl()->offsetInfo[index];
@@ -1105,11 +1158,21 @@ void StructType::getMemberDecorations(
}
}
+void StructType::getStructDecorations(
+ SmallVectorImpl<StructType::StructDecorationInfo> &structDecorations)
+ const {
+ structDecorations.clear();
+ auto implDecorations = getImpl()->getStructDecorationsInfo();
+ structDecorations.append(implDecorations.begin(), implDecorations.end());
+}
+
LogicalResult
StructType::trySetBody(ArrayRef<Type> memberTypes,
ArrayRef<OffsetInfo> offsetInfo,
- ArrayRef<MemberDecorationInfo> memberDecorations) {
- return Base::mutate(memberTypes, offsetInfo, memberDecorations);
+ ArrayRef<MemberDecorationInfo> memberDecorations,
+ ArrayRef<StructDecorationInfo> structDecorations) {
+ return Base::mutate(memberTypes, offsetInfo, memberDecorations,
+ structDecorations);
}
void StructType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
@@ -1131,6 +1194,11 @@ llvm::hash_code spirv::hash_value(
memberDecorationInfo.decoration);
}
+llvm::hash_code spirv::hash_value(
+ const StructType::StructDecorationInfo &structDecorationInfo) {
+ return llvm::hash_value(structDecorationInfo.decoration);
+}
+
//===----------------------------------------------------------------------===//
// MatrixType
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index 81365b4..3911ec0 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -58,7 +58,17 @@ createGlobalVarForEntryPointArgument(OpBuilder &builder, spirv::FuncOp funcOp,
spirv::PointerType::get(spirv::StructType::get(varType), *storageClass);
}
auto varPtrType = cast<spirv::PointerType>(varType);
- auto varPointeeType = cast<spirv::StructType>(varPtrType.getPointeeType());
+ Type pointeeType = varPtrType.getPointeeType();
+
+  // Images are an opaque type, so we can just return a pointer to the image.
+  // Note that currently only sampled images are supported in the SPIR-V
+  // lowering.
+ if (isa<spirv::SampledImageType>(pointeeType))
+ return spirv::GlobalVariableOp::create(builder, funcOp.getLoc(), varType,
+ varName, abiInfo.getDescriptorSet(),
+ abiInfo.getBinding());
+
+ auto varPointeeType = cast<spirv::StructType>(pointeeType);
// Set the offset information.
varPointeeType =
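
With this change, an entry-point argument that lowers to a pointer-to-sampled-image is emitted directly as a global variable instead of being wrapped in a struct for offset decoration. A minimal sketch of such a global (descriptor set, binding, and image parameters are illustrative, not taken from the patch):

  spirv.GlobalVariable @arg0 bind(0, 0) :
      !spirv.ptr<!spirv.sampled_image<!spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NeedSampler, Unknown>>, UniformConstant>
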
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 35ec019..8f4c4cc 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -182,6 +182,14 @@ getTypeNumBytes(const SPIRVConversionOptions &options, Type type) {
return bitWidth / 8;
}
+ // Handle 8-bit floats.
+ if (options.emulateUnsupportedFloatTypes && isa<FloatType>(type)) {
+ auto bitWidth = type.getIntOrFloatBitWidth();
+ if (bitWidth == 8)
+ return bitWidth / 8;
+ return std::nullopt;
+ }
+
if (auto complexType = dyn_cast<ComplexType>(type)) {
auto elementSize = getTypeNumBytes(options, complexType.getElementType());
if (!elementSize)
@@ -318,6 +326,44 @@ static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
type.getSignedness());
}
+/// Converts 8-bit float types to integer types with the same bit width.
+/// Returns a nullptr for unsupported 8-bit float types.
+static Type convert8BitFloatType(const SPIRVConversionOptions &options,
+ FloatType type) {
+ if (!options.emulateUnsupportedFloatTypes)
+ return nullptr;
+ // F8 types are converted to integer types with the same bit width.
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ Float8E8M0FNUType>(type))
+ return IntegerType::get(type.getContext(), type.getWidth());
+ LLVM_DEBUG(llvm::dbgs() << "unsupported 8-bit float type: " << type << "\n");
+ return nullptr;
+}
+
+/// Returns a type with the same shape but with any 8-bit float element type
+/// converted to an integer type of the same bit width. This is a no-op when the
+/// element type is not an 8-bit float type or the emulation flag is disabled.
+static ShapedType
+convertShaped8BitFloatType(ShapedType type,
+ const SPIRVConversionOptions &options) {
+ if (!options.emulateUnsupportedFloatTypes)
+ return type;
+ Type srcElementType = type.getElementType();
+ Type convertedElementType = nullptr;
+ // F8 types are converted to integer types with the same bit width.
+ if (isa<Float8E5M2Type, Float8E4M3Type, Float8E4M3FNType, Float8E5M2FNUZType,
+ Float8E4M3FNUZType, Float8E4M3B11FNUZType, Float8E3M4Type,
+ Float8E8M0FNUType>(srcElementType))
+ convertedElementType = IntegerType::get(
+ type.getContext(), srcElementType.getIntOrFloatBitWidth());
+
+ if (!convertedElementType)
+ return type;
+
+ return type.clone(convertedElementType);
+}
+
/// Returns a type with the same shape but with any index element type converted
/// to the matching integer type. This is a noop when the element type is not
/// the index type.
@@ -337,6 +383,7 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
const SPIRVConversionOptions &options, VectorType type,
std::optional<spirv::StorageClass> storageClass = {}) {
type = cast<VectorType>(convertIndexElementType(type, options));
+ type = cast<VectorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
// If this is not a spec allowed scalar type, try to handle sub-byte integer
@@ -433,6 +480,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
}
type = cast<TensorType>(convertIndexElementType(type, options));
+ type = cast<TensorType>(convertShaped8BitFloatType(type, options));
auto scalarType = dyn_cast_or_null<spirv::ScalarType>(type.getElementType());
if (!scalarType) {
LLVM_DEBUG(llvm::dbgs()
@@ -596,6 +644,10 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
} else if (auto indexType = dyn_cast<IndexType>(elementType)) {
type = cast<MemRefType>(convertIndexElementType(type, options));
arrayElemType = type.getElementType();
+ } else if (auto floatType = dyn_cast<FloatType>(elementType)) {
+    // Handle 8-bit float types.
+ type = cast<MemRefType>(convertShaped8BitFloatType(type, options));
+ arrayElemType = type.getElementType();
} else {
LLVM_DEBUG(
llvm::dbgs()
@@ -1444,6 +1496,8 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
addConversion([this](FloatType floatType) -> std::optional<Type> {
if (auto scalarType = dyn_cast<spirv::ScalarType>(floatType))
return convertScalarType(this->targetEnv, this->options, scalarType);
+ if (floatType.getWidth() == 8)
+ return convert8BitFloatType(this->options, floatType);
return Type();
});
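
The new emulation path treats 8-bit float element types as same-width integers when computing SPIR-V sizes and types. A minimal sketch of inputs that now convert once emulateUnsupportedFloatTypes is enabled (function and operand names are illustrative; the exact lowered SPIR-V types depend on the target environment):

  func.func @f8_example(%scalar: f8E4M3FN, %vec: vector<4xf8E5M2>) {
    return
  }
  // During conversion, the f8E4M3FN scalar is handled as i8 and the
  // vector<4xf8E5M2> operand as vector<4xi8>.
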
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
index 6a9b951..a53d0a7 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/UpdateVCEPass.cpp
@@ -174,6 +174,21 @@ void UpdateVCEPass::runOnOperation() {
if (walkResult.wasInterrupted())
return signalPassFailure();
+ // Update min version requirement for capabilities after deducing them.
+ for (spirv::Capability cap : deducedCapabilities) {
+ if (std::optional<spirv::Version> minVersion = spirv::getMinVersion(cap)) {
+ deducedVersion = std::max(deducedVersion, *minVersion);
+ if (deducedVersion > allowedVersion) {
+ module.emitError("Capability '")
+ << spirv::stringifyCapability(cap) << "' requires min version "
+ << spirv::stringifyVersion(deducedVersion)
+ << " but target environment allows up to "
+ << spirv::stringifyVersion(allowedVersion);
+ return signalPassFailure();
+ }
+ }
+ }
+
// TODO: verify that the deduced version is consistent with
// SPIR-V ops' maximal version requirements.
diff --git a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
index e5a3b5d..08fccfa 100644
--- a/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
+++ b/mlir/lib/Dialect/Shard/IR/ShardOps.cpp
@@ -38,7 +38,6 @@
#include <utility>
#define DEBUG_TYPE "shard-ops"
-#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
using namespace mlir;
using namespace mlir::shard;
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 0e96b59..869d27a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -115,8 +115,7 @@ public:
bufferization::BufferizationState bufferizationState;
- if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
- updatedOptions,
+ if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions,
bufferizationState)))
return failure();
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 6d2cbb5..e3cba388 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -452,18 +452,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
auto inputElementType = inputType.getElementType();
- if (!inputType.hasStaticShape()) {
- return failure();
- }
-
if (isa<FloatType>(inputElementType)) {
// Unlike integer types, floating point types can represent infinity.
- auto minClamp =
+ const auto minClamp =
llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
- auto maxClamp =
+ const auto maxClamp =
llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
- bool isMin = minClamp.isNegInfinity();
- bool isMax = maxClamp.isInfinity();
+ const bool isMin = minClamp.isNegInfinity();
+ const bool isMax = maxClamp.isInfinity();
if (isMin && isMax) {
rewriter.replaceOp(op, input);
@@ -472,18 +468,19 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
return failure();
}
- if (inputElementType.isUnsignedInteger()) {
- int64_t minClamp =
- llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
- int64_t maxClamp =
- llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();
+ // i1 types are boolean in TOSA
+ const bool isBoolean = inputElementType.isInteger(1);
+ if (inputElementType.isUnsignedInteger() || isBoolean) {
+ const int64_t minClamp = llvm::cast<mlir::IntegerAttr>(op.getMinValAttr())
+ .getValue()
+ .getZExtValue();
+ const int64_t maxClamp = llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr())
+ .getValue()
+ .getZExtValue();
- int64_t intMin =
- APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
- .getZExtValue();
- int64_t intMax =
- APInt::getMaxValue(inputElementType.getIntOrFloatBitWidth())
- .getZExtValue();
+ const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+ const int64_t intMin = APInt::getMinValue(bitWidth).getZExtValue();
+ const int64_t intMax = APInt::getMaxValue(bitWidth).getZExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
@@ -493,17 +490,14 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}
if (llvm::isa<IntegerType>(inputElementType)) {
- int64_t minClamp =
+ const int64_t minClamp =
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
- int64_t maxClamp =
+ const int64_t maxClamp =
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();
- int64_t intMin =
- APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
- .getSExtValue();
- int64_t intMax =
- APInt::getSignedMaxValue(inputElementType.getIntOrFloatBitWidth())
- .getSExtValue();
+ const unsigned bitWidth = inputElementType.getIntOrFloatBitWidth();
+ const int64_t intMin = APInt::getSignedMinValue(bitWidth).getSExtValue();
+ const int64_t intMax = APInt::getSignedMaxValue(bitWidth).getSExtValue();
if (minClamp <= intMin && maxClamp >= intMax) {
rewriter.replaceOp(op, input);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index ecd93ff..3cafb19 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -3647,6 +3647,22 @@ std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
return std::nullopt;
}
+static void printInitializationList(OpAsmPrinter &parser,
+ Block::BlockArgListType blocksArgs,
+ ValueRange initializers,
+ StringRef prefix = "") {
+ assert(blocksArgs.size() == initializers.size() &&
+ "expected same length of arguments and initializers");
+ if (initializers.empty())
+ return;
+
+ parser << prefix << '(';
+ llvm::interleaveComma(
+ llvm::zip(blocksArgs, initializers), parser,
+ [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
+ parser << ")";
+}
+
// parse and print of IfOp refer to the implementation of SCF dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
// Create the regions for 'then'.
@@ -3654,16 +3670,64 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
Region *thenRegion = result.addRegion();
Region *elseRegion = result.addRegion();
- auto &builder = parser.getBuilder();
OpAsmParser::UnresolvedOperand cond;
- // Create a i1 tensor type for the boolean condition.
- Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1));
- if (parser.parseOperand(cond) ||
- parser.resolveOperand(cond, i1Type, result.operands))
+
+ if (parser.parseOperand(cond))
return failure();
- // Parse optional results type list.
- if (parser.parseOptionalArrowTypeList(result.types))
+
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
+
+ // Parse the optional block arguments
+ OptionalParseResult listResult =
+ parser.parseOptionalAssignmentList(regionArgs, operands);
+ if (listResult.has_value() && failed(listResult.value()))
return failure();
+
+ // Parse a colon.
+ if (failed(parser.parseColon()))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Parse the type of the condition operand
+ Type condType;
+ if (failed(parser.parseType(condType)))
+ return parser.emitError(parser.getCurrentLocation(),
+ "expected type for condition operand");
+
+ // Resolve operand with provided type
+ if (failed(parser.resolveOperand(cond, condType, result.operands)))
+ return failure();
+
+ // Parse optional block arg types
+ if (listResult.has_value()) {
+ FunctionType functionType;
+
+ if (failed(parser.parseType(functionType)))
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected list of types for block arguments "
+ << "followed by arrow type and list of return types";
+
+ result.addTypes(functionType.getResults());
+
+ if (functionType.getNumInputs() != operands.size()) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected as many input types as operands "
+ << "(expected " << operands.size() << " got "
+ << functionType.getNumInputs() << ")";
+ }
+
+ // Resolve input operands.
+ if (failed(parser.resolveOperands(operands, functionType.getInputs(),
+ parser.getCurrentLocation(),
+ result.operands)))
+ return failure();
+ } else {
+ // Parse optional results type list.
+ if (parser.parseOptionalArrowTypeList(result.types))
+ return failure();
+ }
+
// Parse the 'then' region.
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
@@ -3681,26 +3745,28 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}
void IfOp::print(OpAsmPrinter &p) {
- bool printBlockTerminators = false;
-
p << " " << getCondition();
- if (!getResults().empty()) {
- p << " -> (" << getResultTypes() << ")";
- // Print yield explicitly if the op defines values.
- printBlockTerminators = true;
+
+ printInitializationList(p, getThenGraph().front().getArguments(),
+ getInputList(), " ");
+ p << " : ";
+ p << getCondition().getType();
+
+ if (!getInputList().empty()) {
+ p << " (";
+ llvm::interleaveComma(getInputList().getTypes(), p);
+ p << ")";
}
- p << ' ';
- p.printRegion(getThenGraph(),
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printArrowTypeList(getResultTypes());
+ p << " ";
+
+ p.printRegion(getThenGraph());
// Print the 'else' regions if it exists and has a block.
auto &elseRegion = getElseGraph();
if (!elseRegion.empty()) {
p << " else ";
- p.printRegion(elseRegion,
- /*printEntryBlockArgs=*/false,
- /*printBlockTerminators=*/printBlockTerminators);
+ p.printRegion(elseRegion);
}
p.printOptionalAttrDict((*this)->getAttrs());
@@ -3909,22 +3975,6 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
-static void printInitializationList(OpAsmPrinter &parser,
- Block::BlockArgListType blocksArgs,
- ValueRange initializers,
- StringRef prefix = "") {
- assert(blocksArgs.size() == initializers.size() &&
- "expected same length of arguments and initializers");
- if (initializers.empty())
- return;
-
- parser << prefix << '(';
- llvm::interleaveComma(
- llvm::zip(blocksArgs, initializers), parser,
- [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
- parser << ")";
-}
-
void WhileOp::print(OpAsmPrinter &parser) {
printInitializationList(parser, getCondGraph().front().getArguments(),
getInputList(), " ");
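
The rewritten IfOp parser and printer accept an optional initialization list that binds region arguments to the op's input list, mirroring WhileOp, and the condition type is now printed explicitly. A schematic of the extended form (names and types are illustrative):

  %r = tosa.cond_if %cond (%arg3 = %a, %arg4 = %b) : tensor<i1> (tensor<f32>, tensor<f32>) -> tensor<f32> {
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
    tosa.yield %arg3 : tensor<f32>
  } else {
  ^bb0(%arg3: tensor<f32>, %arg4: tensor<f32>):
    tosa.yield %arg4 : tensor<f32>
  }
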
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
index 88b0f36..9543fa1 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
@@ -464,9 +464,12 @@ LogicalResult TosaProfileCompliance::checkInvalid(Operation *op) {
CheckCondition condition = CheckCondition::invalid;
const auto maybeProfDef = getOperatorDefinition<Profile>(op, condition);
const auto maybeExtDef = getOperatorDefinition<Extension>(op, condition);
+ if (failed(maybeProfDef) && failed(maybeExtDef))
+ return success();
- if (!failed(maybeProfDef) && !failed(maybeExtDef) &&
- !maybeProfDef.value().size() && !maybeExtDef.value().size()) {
+ const bool hasEntry = (succeeded(maybeProfDef) && !maybeProfDef->empty()) ||
+ (succeeded(maybeExtDef) && !maybeExtDef->empty());
+ if (!hasEntry) {
std::string message;
llvm::raw_string_ostream os(message);
os << "illegal: operation operand/result data types did not align with any "
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
index 32b5fb6..c7b9534 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp
@@ -1248,16 +1248,14 @@ bool checkErrorIfCondIf(Operation *op) {
// })
//
// Simplified:
- // %0 = tosa.cond_if %arg2 {
- // tosa.yield %arg0
+ // %0 = tosa.cond_if %arg2 (%arg3 = %arg0, %arg4 = %arg1) {
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg3
// } else {
- // tosa.yield %arg1
+ // ^bb0(%arg3, %arg4):
+ // tosa.yield %arg4
// }
- //
- // Unfortunately, the simplified syntax does not encapsulate values
- // used in then/else regions (see 'simplified' example above), so it
- // must be rewritten to use the generic syntax in order to be conformant
- // to the specification.
+
return failed(checkIsolatedRegion(op, ifOp.getThenGraph(), "then")) ||
failed(checkIsolatedRegion(op, ifOp.getElseGraph(), "else"));
}
@@ -1383,7 +1381,7 @@ void TosaValidation::runOnOperation() {
// Some uses of TOSA rely on the constant operands of particular
// operations.
- if (strictOpSpecAlignment && failed(applyConstantOperandCheck(op)))
+ if (failed(applyConstantOperandCheck(op)))
signalPassFailure();
// do level checks
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index e297f7c..14a4fdf 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -21,16 +21,8 @@
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "transform-dialect"
-#define DEBUG_TYPE_FULL "transform-dialect-full"
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
-#ifndef NDEBUG
-#define FULL_LDBG(X) \
- DEBUGLOG_WITH_STREAM_AND_TYPE(llvm::dbgs(), DEBUG_TYPE_FULL)
-#else
-#define FULL_LDBG(X) \
- for (bool _c = false; _c; _c = false) \
- ::llvm::nulls()
-#endif
+#define FULL_LDBG() LDBG(4)
using namespace mlir;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bce358d..a450056 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1258,63 +1258,6 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
CanonicalizeContractAdd<arith::AddFOp>>(context);
}
-//===----------------------------------------------------------------------===//
-// ExtractElementOp
-//===----------------------------------------------------------------------===//
-
-void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges.front());
-}
-
-void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
- Value source) {
- result.addOperands({source});
- result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
-}
-
-LogicalResult vector::ExtractElementOp::verify() {
- VectorType vectorType = getSourceVectorType();
- if (vectorType.getRank() == 0) {
- if (getPosition())
- return emitOpError("expected position to be empty with 0-D vector");
- return success();
- }
- if (vectorType.getRank() != 1)
- return emitOpError("unexpected >1 vector rank");
- if (!getPosition())
- return emitOpError("expected position for 1-D vector");
- return success();
-}
-
-OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
- // Skip the 0-D vector here now.
- if (!adaptor.getPosition())
- return {};
-
- // Fold extractelement (splat X) -> X.
- if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
- return splat.getInput();
-
- // Fold extractelement(broadcast(X)) -> X.
- if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
- if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
- return broadcast.getSource();
-
- auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
- auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
- if (!pos || !src)
- return {};
-
- auto srcElements = src.getValues<Attribute>();
-
- uint64_t posIdx = pos.getInt();
- if (posIdx >= srcElements.size())
- return {};
-
- return srcElements[posIdx];
-}
-
// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
@@ -2533,17 +2476,19 @@ OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
return {};
}
-/// Rewrite a vector.from_elements into a vector.splat if all elements are the
-/// same SSA value. E.g.:
-///
-/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
-/// ==> rewrite to vector.splat %a : vector<3xf32>
-static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp,
- PatternRewriter &rewriter) {
+/// Rewrite vector.from_elements as vector.broadcast if the elements are the
+/// same. Example:
+/// %0 = vector.from_elements %a, %a, %a : vector<3xf32>
+/// =>
+/// %0 = vector.broadcast %a : f32 to vector<3xf32>
+static LogicalResult
+rewriteFromElementsAsBroadcast(FromElementsOp fromElementsOp,
+ PatternRewriter &rewriter) {
if (!llvm::all_equal(fromElementsOp.getElements()))
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(fromElementsOp, fromElementsOp.getType(),
- fromElementsOp.getElements().front());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(
+ fromElementsOp, fromElementsOp.getType(),
+ fromElementsOp.getElements().front());
return success();
}
@@ -2574,7 +2519,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
LogicalResult matchAndRewrite(FromElementsOp fromElements,
PatternRewriter &rewriter) const override {
- // Handled by `rewriteFromElementsAsSplat`
+ // Handled by `rewriteFromElementsAsBroadcast`.
if (fromElements.getType().getNumElements() == 1)
return failure();
@@ -2667,7 +2612,7 @@ class FromElementsToShapeCast : public OpRewritePattern<FromElementsOp> {
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add(rewriteFromElementsAsSplat);
+ results.add(rewriteFromElementsAsBroadcast);
results.add<FromElementsToShapeCast>(context);
}
@@ -3115,23 +3060,50 @@ struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
}
};
-/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
+/// Consider the defining operation `defOp` of `value`. If `defOp` is a
+/// vector.splat or a vector.broadcast with a scalar operand, return the scalar
+/// value that is splatted. Otherwise return null.
+///
+/// Examples:
+///
+/// scalar_source --> vector.splat --> value - return scalar_source
+/// scalar_source --> vector.broadcast --> value - return scalar_source
+static Value getScalarSplatSource(Value value) {
+ // Block argument:
+ Operation *defOp = value.getDefiningOp();
+ if (!defOp)
+ return {};
+
+ // Splat:
+ if (auto splat = dyn_cast<vector::SplatOp>(defOp))
+ return splat.getInput();
+
+ auto broadcast = dyn_cast<vector::BroadcastOp>(defOp);
+
+ // Not broadcast (and not splat):
+ if (!broadcast)
+ return {};
+
+ // Broadcast of a vector:
+ if (isa<VectorType>(broadcast.getSourceType()))
+ return {};
+
+ // Broadcast of a scalar:
+ return broadcast.getSource();
+}
+
+/// Pattern to rewrite shuffle(splat-like(v), splat-like(v)) as broadcast(v).
class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ShuffleOp op,
PatternRewriter &rewriter) const override {
- auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
- auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
-
- if (!v1Splat || !v2Splat)
- return failure();
-
- if (v1Splat.getInput() != v2Splat.getInput())
+ Value splat = getScalarSplatSource(op.getV1());
+ if (!splat || getScalarSplatSource(op.getV2()) != splat)
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
@@ -3184,60 +3156,6 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
-// InsertElementOp
-//===----------------------------------------------------------------------===//
-
-void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
-}
-
-void InsertElementOp::build(OpBuilder &builder, OperationState &result,
- Value source, Value dest) {
- build(builder, result, source, dest, {});
-}
-
-LogicalResult InsertElementOp::verify() {
- auto dstVectorType = getDestVectorType();
- if (dstVectorType.getRank() == 0) {
- if (getPosition())
- return emitOpError("expected position to be empty with 0-D vector");
- return success();
- }
- if (dstVectorType.getRank() != 1)
- return emitOpError("unexpected >1 vector rank");
- if (!getPosition())
- return emitOpError("expected position for 1-D vector");
- return success();
-}
-
-OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
- // Skip the 0-D vector here.
- if (!adaptor.getPosition())
- return {};
-
- auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
- auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
- auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
- if (!src || !dst || !pos)
- return {};
-
- if (src.getType() != getDestVectorType().getElementType())
- return {};
-
- auto dstElements = dst.getValues<Attribute>();
-
- SmallVector<Attribute> results(dstElements);
-
- uint64_t posIdx = pos.getInt();
- if (posIdx >= results.size())
- return {};
- results[posIdx] = src;
-
- return DenseElementsAttr::get(getDestVectorType(), results);
-}
-
-//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
@@ -3341,23 +3259,19 @@ public:
}
};
-/// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp.
+/// Pattern to rewrite an insert(splat-like(v), splat-like(v)) as broadcast(v).
class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(InsertOp op,
PatternRewriter &rewriter) const override {
- auto srcSplat = op.getValueToStore().getDefiningOp<SplatOp>();
- auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
-
- if (!srcSplat || !dstSplat)
- return failure();
- if (srcSplat.getInput() != dstSplat.getInput())
+ Value splat = getScalarSplatSource(op.getValueToStore());
+ if (!splat || getScalarSplatSource(op.getDest()) != splat)
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
@@ -3625,8 +3539,7 @@ LogicalResult InsertStridedSliceOp::verify() {
}
namespace {
-/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
-/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
+/// Rewrite insert_strided_slice(splat-like(v), splat-like(v)) as v.
class FoldInsertStridedSliceSplat final
: public OpRewritePattern<InsertStridedSliceOp> {
public:
@@ -3634,18 +3547,13 @@ public:
LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
PatternRewriter &rewriter) const override {
- auto srcSplatOp =
- insertStridedSliceOp.getValueToStore().getDefiningOp<vector::SplatOp>();
- auto destSplatOp =
- insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
- if (!srcSplatOp || !destSplatOp)
+ auto dst = insertStridedSliceOp.getDest();
+ auto splat = getScalarSplatSource(insertStridedSliceOp.getValueToStore());
+ if (!splat || getScalarSplatSource(dst) != splat)
return failure();
- if (srcSplatOp.getInput() != destSplatOp.getInput())
- return failure();
-
- rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
+ rewriter.replaceOp(insertStridedSliceOp, dst);
return success();
}
};
@@ -4300,17 +4208,18 @@ public:
}
};
-/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
+/// Rewrite extract_strided_slice(splat-like(v)) as broadcast(v).
class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
PatternRewriter &rewriter) const override {
- auto splat = op.getVector().getDefiningOp<SplatOp>();
+
+ Value splat = getScalarSplatSource(op.getVector());
if (!splat)
return failure();
- rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
+ rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), splat);
return success();
}
};
@@ -6027,14 +5936,13 @@ OpFoldResult ShapeCastOp::fold(FoldAdaptor adaptor) {
}
// shape_cast(constant) -> constant
- if (auto splatAttr =
- llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()))
- return splatAttr.reshape(getType());
+ if (auto denseAttr =
+ dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()))
+ return denseAttr.reshape(getType());
// shape_cast(poison) -> poison
- if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource())) {
+ if (llvm::dyn_cast_if_present<ub::PoisonAttr>(adaptor.getSource()))
return ub::PoisonAttr::get(getContext());
- }
return {};
}
@@ -6427,6 +6335,11 @@ std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
return llvm::to_vector<4>(getResultVectorType().getShape());
}
+void TransposeOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ setResultRanges(getResult(), argRanges.front());
+}
+
namespace {
// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
@@ -6461,19 +6374,19 @@ public:
}
};
-// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
+/// Replace transpose(splat-like(v)) with broadcast(v).
class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(TransposeOp transposeOp,
PatternRewriter &rewriter) const override {
- auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
- if (!splatOp)
+ Value splat = getScalarSplatSource(transposeOp.getVector());
+ if (!splat)
return failure();
- rewriter.replaceOpWithNewOp<vector::SplatOp>(
- transposeOp, transposeOp.getResultVectorType(), splatOp.getInput());
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
+ transposeOp, transposeOp.getResultVectorType(), splat);
return success();
}
};
@@ -7224,6 +7137,23 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
return SplatElementsAttr::get(getType(), {constOperand});
}
+// Canonicalization pattern for vector.splat: it is always rewritten to an
+// equivalent vector.broadcast.
+class SplatToBroadcastPattern final : public OpRewritePattern<SplatOp> {
+public:
+ using OpRewritePattern<SplatOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(SplatOp splatOp,
+ PatternRewriter &rewriter) const override {
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(splatOp, splatOp.getType(),
+ splatOp.getOperand());
+ return success();
+ }
+};
+void SplatOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SplatToBroadcastPattern>(context);
+}
+
void SplatOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
SetIntRangeFn setResultRanges) {
setResultRanges(getResult(), argRanges.front());
@@ -7309,6 +7239,23 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
}
//===----------------------------------------------------------------------===//
+// StepOp
+//===----------------------------------------------------------------------===//
+
+void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+ SetIntRangeFn setResultRanges) {
+ auto resultType = cast<VectorType>(getType());
+ if (resultType.isScalable()) {
+ return;
+ }
+ unsigned bitwidth = ConstantIntRanges::getStorageBitwidth(resultType);
+ APInt zero(bitwidth, 0);
+ APInt high(bitwidth, resultType.getDimSize(0) - 1);
+ ConstantIntRanges result = {zero, high, zero, high};
+ setResultRanges(getResult(), result);
+}
+
+//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
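
Taken together, the VectorOps changes retire ExtractElementOp/InsertElementOp and steer the remaining splat-style folds toward vector.broadcast. A minimal sketch of the canonicalizations involved (names are illustrative):

  // vector.splat is now canonicalized to a broadcast:
  %v = vector.splat %s : vector<4xf32>
  // becomes
  %v2 = vector.broadcast %s : f32 to vector<4xf32>

  // Splat-like producers feeding transpose (and shuffle / insert /
  // extract_strided_slice) fold to a broadcast of the scalar source:
  %b = vector.broadcast %s : f32 to vector<2x3xf32>
  %t = vector.transpose %b, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
  // folds to
  %t2 = vector.broadcast %s : f32 to vector<3x2xf32>
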
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
index cb8e566..dedc3b3 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorBroadcast.cpp
@@ -28,7 +28,10 @@ using namespace mlir;
using namespace mlir::vector;
namespace {
-/// Progressive lowering of BroadcastOp.
+
+/// Convert a vector.broadcast with a vector operand to a lower rank
+/// vector.broadcast. vector.broadcast with a scalar operand is expected to be
+/// convertible to the lower level target dialect (LLVM, SPIR-V, etc.) directly.
class BroadcastOpLowering : public OpRewritePattern<vector::BroadcastOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -40,20 +43,23 @@ public:
VectorType srcType = dyn_cast<VectorType>(op.getSourceType());
Type eltType = dstType.getElementType();
- // Scalar to any vector can use splat.
- if (!srcType) {
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, op.getSource());
- return success();
- }
+ // A broadcast from a scalar is considered to be in the lowered form.
+ if (!srcType)
+ return rewriter.notifyMatchFailure(
+ op, "broadcast from scalar already in lowered form");
// Determine rank of source and destination.
int64_t srcRank = srcType.getRank();
int64_t dstRank = dstType.getRank();
- // Stretching scalar inside vector (e.g. vector<1xf32>) can use splat.
+ // Here we are broadcasting to a rank-1 vector. Ensure that the source is a
+ // scalar.
if (srcRank <= 1 && dstRank == 1) {
- Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource());
- rewriter.replaceOpWithNewOp<vector::SplatOp>(op, dstType, ext);
+ SmallVector<int64_t> fullRankPosition(srcRank, 0);
+ Value ext = vector::ExtractOp::create(rewriter, loc, op.getSource(),
+ fullRankPosition);
+ assert(!isa<VectorType>(ext.getType()) && "expected scalar");
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, dstType, ext);
return success();
}
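
A minimal sketch of the updated rank-1 path (operand names are illustrative): broadcasting from a vector<1xf32> now extracts the scalar and re-broadcasts it rather than emitting vector.splat:

  %b = vector.broadcast %v : vector<1xf32> to vector<4xf32>
  // is rewritten to
  %s = vector.extract %v[0] : f32 from vector<1xf32>
  %b2 = vector.broadcast %s : f32 to vector<4xf32>
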
diff --git a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
index 4baeb11..2cf8f0b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/LowerVectorTransfer.cpp
@@ -468,7 +468,7 @@ struct TransferReadToVectorLoadLowering
read, "vector type is not rank 1, can't create masked load, needs "
"VectorToSCF");
- Value fill = vector::SplatOp::create(
+ Value fill = vector::BroadcastOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getPadding());
res = vector::MaskedLoadOp::create(
rewriter, read.getLoc(), unbroadcastedVectorType, read.getBase(),
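
A minimal sketch of the rank-1 masked-load lowering after this change (operand names are illustrative): the padding value is materialized with vector.broadcast before feeding the masked load:

  %fill = vector.broadcast %padding : f32 to vector<16xf32>
  %res = vector.maskedload %base[%idx], %mask, %fill
      : memref<?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
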
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
index 72352d7..cbb9d4b 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorInsertExtractStridedSliceRewritePatterns.cpp
@@ -303,7 +303,7 @@ public:
// Extract/insert on a lower ranked extract strided slice op.
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
- Value res = SplatOp::create(rewriter, loc, dstType, zero);
+ Value res = BroadcastOp::create(rewriter, loc, dstType, zero);
for (int64_t off = offset, e = offset + size * stride, idx = 0; off < e;
off += stride, ++idx) {
Value one = ExtractOp::create(rewriter, loc, op.getVector(), off);
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
index 48d680c..c707f38 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp
@@ -25,12 +25,10 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#define DEBUG_TYPE "vector-transfer-opt"
-#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-
using namespace mlir;
/// Return the ancestor op in the region or nullptr if the region is not
@@ -88,8 +86,7 @@ bool TransferOptimization::isReachable(Operation *start, Operation *dest) {
/// transfer_write is dead if all reads that can be reached from the potentially
/// dead transfer_write are dominated by the overwriting transfer_write.
void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
- LLVM_DEBUG(DBGS() << "Candidate for dead store: " << *write.getOperation()
- << "\n");
+ LDBG() << "Candidate for dead store: " << *write.getOperation();
llvm::SmallVector<Operation *, 8> blockingAccesses;
Operation *firstOverwriteCandidate = nullptr;
Value source = memref::skipViewLikeOps(cast<MemrefValue>(write.getBase()));
@@ -150,13 +147,12 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
!isReachable(writeAncestor, accessAncestor))
continue;
if (!dominators.dominates(firstOverwriteCandidate, accessAncestor)) {
- LLVM_DEBUG(DBGS() << "Store may not be dead due to op: "
- << *accessAncestor << "\n");
+ LDBG() << "Store may not be dead due to op: " << *accessAncestor;
return;
}
}
- LLVM_DEBUG(DBGS() << "Found dead store: " << *write.getOperation()
- << " overwritten by: " << *firstOverwriteCandidate << "\n");
+ LDBG() << "Found dead store: " << *write.getOperation()
+ << " overwritten by: " << *firstOverwriteCandidate;
opToErase.push_back(write.getOperation());
}
@@ -174,8 +170,7 @@ void TransferOptimization::deadStoreOp(vector::TransferWriteOp write) {
void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
if (read.hasOutOfBoundsDim())
return;
- LLVM_DEBUG(DBGS() << "Candidate for Forwarding: " << *read.getOperation()
- << "\n");
+ LDBG() << "Candidate for Forwarding: " << *read.getOperation();
SmallVector<Operation *, 8> blockingWrites;
vector::TransferWriteOp lastwrite = nullptr;
Value source = memref::skipViewLikeOps(cast<MemrefValue>(read.getBase()));
@@ -230,14 +225,13 @@ void TransferOptimization::storeToLoadForwarding(vector::TransferReadOp read) {
if (writeAncestor == nullptr || !isReachable(writeAncestor, readAncestor))
continue;
if (!postDominators.postDominates(lastwrite, write)) {
- LLVM_DEBUG(DBGS() << "Fail to do write to read forwarding due to op: "
- << *write << "\n");
+ LDBG() << "Fail to do write to read forwarding due to op: " << *write;
return;
}
}
- LLVM_DEBUG(DBGS() << "Forward value from " << *lastwrite.getOperation()
- << " to: " << *read.getOperation() << "\n");
+ LDBG() << "Forward value from " << *lastwrite.getOperation()
+ << " to: " << *read.getOperation();
read.replaceAllUsesWith(lastwrite.getVector());
opToErase.push_back(read.getOperation());
}
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 8de87fe..2269a40 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -939,7 +939,7 @@ public:
Value zero = arith::ConstantOp::create(rewriter, loc, elemType,
rewriter.getZeroAttr(elemType));
- Value res = SplatOp::create(rewriter, loc, castDstType, zero);
+ Value res = BroadcastOp::create(rewriter, loc, castDstType, zero);
SmallVector<int64_t> sliceShape = {castDstLastDim};
SmallVector<int64_t> strides = {1};
@@ -965,6 +965,45 @@ private:
std::function<bool(BitCastOp)> controlFn;
};
+static bool haveSameShapeAndScaling(Type t, Type u) {
+ auto tVec = dyn_cast<VectorType>(t);
+ auto uVec = dyn_cast<VectorType>(u);
+ if (!tVec) {
+ return !uVec;
+ }
+ if (!uVec) {
+ return false;
+ }
+ return tVec.getShape() == uVec.getShape() &&
+ tVec.getScalableDims() == uVec.getScalableDims();
+}
+
+/// If `type` is shaped, clone it with `newElementType`. Otherwise,
+/// return `newElementType`.
+static Type cloneOrReplace(Type type, Type newElementType) {
+ if (auto shapedType = dyn_cast<ShapedType>(type)) {
+ return shapedType.clone(newElementType);
+ }
+ return newElementType;
+}
+
+/// If `value` is the result of a splat or broadcast operation, return the input
+/// of the splat/broadcast operation.
+static Value getBroadcastLikeSource(Value value) {
+
+ Operation *op = value.getDefiningOp();
+ if (!op)
+ return {};
+
+ if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
+ return broadcast.getSource();
+
+ if (auto splat = dyn_cast<vector::SplatOp>(op))
+ return splat.getInput();
+
+ return {};
+}
+
/// Reorders elementwise(broadcast/splat) to broadcast(elementwise). Ex:
///
/// Example:
@@ -988,16 +1027,14 @@ struct ReorderElementwiseOpsOnBroadcast final
PatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return failure();
- if (!llvm::isa<ShapedType>(op->getResults()[0].getType()))
+ auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
+ if (!resultType)
return failure();
if (!OpTrait::hasElementwiseMappableTraits(op))
return rewriter.notifyMatchFailure(
op, "Op doesn't have ElementwiseMappableTraits");
if (op->getNumOperands() == 0)
return failure();
- if (op->getResults()[0].getType() != op->getOperand(0).getType())
- return rewriter.notifyMatchFailure(op,
- "result and operand type mismatch");
if (isa<vector::FMAOp>(op)) {
return rewriter.notifyMatchFailure(
op,
@@ -1005,45 +1042,71 @@ struct ReorderElementwiseOpsOnBroadcast final
"might be a scalar");
}
- // Get the type of the lhs operand
- auto *lhsBcastOrSplat = op->getOperand(0).getDefiningOp();
- if (!lhsBcastOrSplat ||
- !isa<vector::BroadcastOp, vector::SplatOp>(*lhsBcastOrSplat))
+ Type resultElemType = resultType.getElementType();
+
+    // Find the broadcast/splat source of the first non-constant operand.
+ Value splatSource;
+ for (Value operand : op->getOperands()) {
+ Operation *definingOp = operand.getDefiningOp();
+ if (!definingOp)
+ return failure();
+ if (definingOp->hasTrait<OpTrait::ConstantLike>())
+ continue;
+ splatSource = getBroadcastLikeSource(operand);
+ break;
+ }
+ if (!splatSource)
return failure();
- auto lhsBcastOrSplatType = lhsBcastOrSplat->getOperand(0).getType();
+ Type unbroadcastResultType =
+ cloneOrReplace(splatSource.getType(), resultElemType);
- // Make sure that all operands are broadcast from identical types:
+ // Make sure that all operands are broadcast from identically-shaped types:
// * scalar (`vector.broadcast` + `vector.splat`), or
// * vector (`vector.broadcast`).
// Otherwise the re-ordering wouldn't be safe.
- if (!llvm::all_of(op->getOperands(), [&lhsBcastOrSplatType](Value val) {
- auto bcast = val.getDefiningOp<vector::BroadcastOp>();
- if (bcast)
- return (bcast.getOperand().getType() == lhsBcastOrSplatType);
- auto splat = val.getDefiningOp<vector::SplatOp>();
- if (splat)
- return (splat.getOperand().getType() == lhsBcastOrSplatType);
- return false;
+ if (!llvm::all_of(op->getOperands(), [splatSource](Value val) {
+ if (auto source = getBroadcastLikeSource(val))
+ return haveSameShapeAndScaling(source.getType(),
+ splatSource.getType());
+ SplatElementsAttr splatConst;
+ return matchPattern(val, m_Constant(&splatConst));
})) {
- return failure();
+ return rewriter.notifyMatchFailure(
+ op,
+ "not all operands are constants or broadcasts from the same type");
}
// Collect the source values before broadcasting
SmallVector<Value> srcValues;
srcValues.reserve(op->getNumOperands());
for (Value operand : op->getOperands()) {
- srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ SplatElementsAttr splatConst;
+ if (matchPattern(operand, m_Constant(&splatConst))) {
+ Attribute newConst;
+ Type elementType = getElementTypeOrSelf(operand.getType());
+ Type newType = cloneOrReplace(unbroadcastResultType, elementType);
+ if (auto newTypeShaped = dyn_cast<ShapedType>(newType)) {
+ newConst = splatConst.resizeSplat(newTypeShaped);
+ } else {
+ newConst = splatConst.getSplatValue<Attribute>();
+ }
+ Operation *newConstOp =
+ operand.getDefiningOp()->getDialect()->materializeConstant(
+ rewriter, newConst, newType, operand.getLoc());
+ srcValues.push_back(newConstOp->getResult(0));
+ } else {
+ srcValues.push_back(operand.getDefiningOp()->getOperand(0));
+ }
}
// Create the "elementwise" Op
Operation *elementwiseOp =
rewriter.create(op->getLoc(), op->getName().getIdentifier(), srcValues,
- lhsBcastOrSplatType, op->getAttrs());
+ unbroadcastResultType, op->getAttrs());
// Replace the original Op with the elementwise Op
- auto vectorType = op->getResultTypes()[0];
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- op, vectorType, elementwiseOp->getResults());
+ op, resultType, elementwiseOp->getResults());
return success();
}
@@ -1239,15 +1302,17 @@ public:
return rewriter.notifyMatchFailure(
op, "only 1-element vectors are supported");
- Operation *splat = op.getValueToStore().getDefiningOp();
- if (!isa_and_present<vector::BroadcastOp, vector::SplatOp>(splat))
- return rewriter.notifyMatchFailure(op, "neither a splat nor a broadcast");
+ Value toStore = op.getValueToStore();
+ Value source = getBroadcastLikeSource(toStore);
+ if (!source)
+ return rewriter.notifyMatchFailure(
+ op, "value to store is not from a broadcast");
// Checking for single use so we can remove splat.
+ Operation *splat = toStore.getDefiningOp();
if (!splat->hasOneUse())
return rewriter.notifyMatchFailure(op, "expected single op use");
- Value source = splat->getOperand(0);
Value base = op.getBase();
ValueRange indices = op.getIndices();
@@ -1297,13 +1362,13 @@ static Value buildVectorComparison(PatternRewriter &rewriter, Operation *op,
// Add in an offset if requested.
if (off) {
Value o = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, *off);
- Value ov = vector::SplatOp::create(rewriter, loc, indices.getType(), o);
+ Value ov = vector::BroadcastOp::create(rewriter, loc, indices.getType(), o);
indices = arith::AddIOp::create(rewriter, loc, ov, indices);
}
// Construct the vector comparison.
Value bound = getValueOrCreateCastToIndexLike(rewriter, loc, idxType, b);
Value bounds =
- vector::SplatOp::create(rewriter, loc, indices.getType(), bound);
+ vector::BroadcastOp::create(rewriter, loc, indices.getType(), bound);
return arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::slt,
indices, bounds);
}
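
The generalized ReorderElementwiseOpsOnBroadcast pattern now also accepts splat constants and no longer requires the result element type to match the operands'. A minimal sketch of the reordering (values are illustrative):

  %lhs = vector.broadcast %x : f32 to vector<4xf32>
  %cst = arith.constant dense<2.0> : vector<4xf32>
  %res = arith.mulf %lhs, %cst : vector<4xf32>
  // is reordered to scalar arithmetic followed by a single broadcast:
  %c = arith.constant 2.0 : f32
  %m = arith.mulf %x, %c : f32
  %res2 = vector.broadcast %m : f32 to vector<4xf32>
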
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
index 704deea..33450f3 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -110,6 +110,34 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return success();
}
+static LogicalResult
+isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
+ int64_t chunkSize,
+ function_ref<InFlightDiagnostic()> emitError) {
+
+ if (!valueTy)
+ return emitError() << "Expecting a vector type result.";
+
+ auto maskShape = getShapeOf(maskTy);
+ auto valueShape = getShapeOf(valueTy);
+
+  // A valid shape for the SIMT case.
+ if (valueTy.getRank() == 1) {
+ if (valueTy.getNumElements() != chunkSize)
+ return emitError() << "value elements must match chunk size " << chunkSize
+ << " for SIMT code.";
+ return success();
+ }
+
+ llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
+ if (chunkSize > 1)
+ expectedMaskShape.pop_back();
+ if (expectedMaskShape != maskShape)
+ return emitError() << "Mask should match value except the chunk size dim.";
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -644,9 +672,14 @@ LogicalResult CreateDescOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
- if (!tdescTy.isScattered())
+
+ if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
+ if (!tdescTy && getRankOf(getSource()) > 1)
+ return emitOpError(
+ "Expecting the source is a 1D memref or pointer (uint64_t).");
+
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -659,6 +692,13 @@ LogicalResult PrefetchOp::verify() {
return success();
}
+void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
+}
+
//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
@@ -667,6 +707,13 @@ LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.");
+
+ if (!tdescTy && getRankOf(getSource()) > 1)
+ return emitOpError(
+ "Expecting the source is a 1D memref or pointer (uint64_t).");
+
if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -676,8 +723,27 @@ LogicalResult LoadGatherOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
- return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
- [&]() { return emitOpError(); });
+ if (tdescTy)
+ return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+ [&]() { return emitOpError(); });
+ auto srcTy = getSourceType();
+  int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+ auto memTy = dyn_cast<MemRefType>(srcTy);
+
+ if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ return emitError() << "Value should have the same element type as MemRef.";
+
+ return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ [&]() { return emitOpError(); });
+}
+
+void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
+ Type valueType, Value source, Value mask,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
+ l1_hint, l2_hint, l3_hint);
}
//===----------------------------------------------------------------------===//
@@ -688,6 +754,13 @@ LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();
+ if (tdescTy && !tdescTy.isScattered())
+ return emitOpError("Expects a scattered TensorDesc.\n");
+
+ if (!tdescTy && getRankOf(getDest()) > 1)
+ return emitOpError(
+ "Expecting the dest is a 1D memref or pointer (uint64_t).");
+
if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -697,8 +770,28 @@ LogicalResult StoreScatterOp::verify() {
if (!isWriteHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();
- return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
- [&]() { return emitOpError(); });
+ if (tdescTy)
+ return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
+ [&]() { return emitOpError(); });
+
+ auto destTy = getDestType();
+  int64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
+ auto memTy = dyn_cast<MemRefType>(destTy);
+
+ if (memTy && (valueTy.getElementType() != memTy.getElementType()))
+ return emitError() << "Value should have the same element type as MemRef.";
+
+ return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
+ [&]() { return emitOpError(); });
+}
+
+void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
+ Value value, Value dest, Value mask,
+ xegpu::CachePolicyAttr l1_hint,
+ xegpu::CachePolicyAttr l2_hint,
+ xegpu::CachePolicyAttr l3_hint) {
+ build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
+ l2_hint, l3_hint);
}
//===----------------------------------------------------------------------===//
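The convenience builders above forward to the full builders with an empty offsets value (and, for the load/store, an empty chunk-size attribute). Below is a minimal sketch, not taken from the patch, of driving them through the Op::create style used elsewhere in this diff; `b`, `loc`, `ctx`, `src`, `dst`, `mask`, and the chosen cache policy are assumptions.

// Hedged sketch: exercise the new convenience builders.
auto hint = xegpu::CachePolicyAttr::get(ctx, xegpu::CachePolicy::CACHED);
auto vecTy = VectorType::get({16}, b.getF32Type());
// Gather-load directly from a 1D memref / pointer source.
Value loaded =
    xegpu::LoadGatherOp::create(b, loc, vecTy, src, mask, hint, hint, hint);
// Scatter-store the loaded vector back to a 1D destination buffer.
xegpu::StoreScatterOp::create(b, loc, loaded, dst, mask, hint, hint, hint);
// Prefetch the source buffer with the same cache hints.
xegpu::PrefetchOp::create(b, loc, src, hint, hint, hint);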
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
index ec8fad4..c793b71 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
@@ -481,7 +481,8 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+    // TODO: handle the unstructured source case (!tdescTy)
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -543,7 +544,8 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+    // TODO: handle the unstructured source case (!tdescTy)
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
@@ -572,7 +574,8 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();
- if (!tdescTy.isScattered())
+    // TODO: handle the unstructured source case (!tdescTy)
+ if (!tdescTy || op.getOffsets())
return failure();
std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index f95ad29..de52fbd 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -40,7 +40,7 @@
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/Endian.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Regex.h"
@@ -2070,9 +2070,8 @@ static OpPrintingFlags verifyOpAndAdjustFlags(Operation *op,
return failure();
});
if (failed(verify(op))) {
- LLVM_DEBUG(llvm::dbgs()
- << DEBUG_TYPE << ": '" << op->getName()
- << "' failed to verify and will be printed in generic form\n");
+    LDBG() << "'" << op->getName()
+           << "' failed to verify and will be printed in generic form";
printerFlags.printGenericOpForm();
}
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index 3e33795..776b5c6 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -821,15 +821,7 @@ SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
for (unsigned i = 0, e = mgr.getNumBuffers(); i != e; ++i)
(void)impl->computeExpectedDiags(out, mgr, mgr.getMemoryBuffer(i + 1));
- // Register a handler to verify the diagnostics.
- setHandler([&](Diagnostic &diag) {
- // Process the main diagnostics.
- process(diag);
-
- // Process each of the notes.
- for (auto &note : diag.getNotes())
- process(note);
- });
+ registerInContext(ctx);
}
SourceMgrDiagnosticVerifierHandler::SourceMgrDiagnosticVerifierHandler(
@@ -862,6 +854,17 @@ LogicalResult SourceMgrDiagnosticVerifierHandler::verify() {
return impl->status;
}
+void SourceMgrDiagnosticVerifierHandler::registerInContext(MLIRContext *ctx) {
+ ctx->getDiagEngine().registerHandler([&](Diagnostic &diag) {
+ // Process the main diagnostics.
+ process(diag);
+
+ // Process each of the notes.
+ for (auto &note : diag.getNotes())
+ process(note);
+ });
+}
+
/// Process a single diagnostic.
void SourceMgrDiagnosticVerifierHandler::process(Diagnostic &diag) {
return process(diag.getLocation(), diag.str(), diag.getSeverity());
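Splitting the handler registration out into registerInContext lets the verifier be attached to a context explicitly rather than only as a side effect of construction. A minimal sketch, assuming a test driver that owns both the SourceMgr and the context; the constructor still registers on the context it is given.

// Hedged sketch of the new entry point.
llvm::SourceMgr sourceMgr;
MLIRContext ctx;
SourceMgrDiagnosticVerifierHandler handler(sourceMgr, &ctx);
handler.registerInContext(&ctx);
// ... parse and process IR that is expected to emit diagnostics ...
LogicalResult status = handler.verify();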
diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp
index e9b5e92..310680b 100644
--- a/mlir/lib/Parser/Parser.cpp
+++ b/mlir/lib/Parser/Parser.cpp
@@ -17,14 +17,32 @@
using namespace mlir;
+static std::pair<int64_t, int64_t>
+getLineAndColStart(const llvm::SourceMgr &sourceMgr) {
+ unsigned lastFileID = sourceMgr.getNumBuffers();
+ if (lastFileID == 1)
+ return {0, 0};
+
+ auto bufferID = sourceMgr.getMainFileID();
+ const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID);
+ const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID);
+ // Exclude same start.
+ if (main->getBufferStart() < last->getBufferStart() &&
+ main->getBufferEnd() >= last->getBufferEnd()) {
+ return sourceMgr.getLineAndColumn(
+ llvm::SMLoc::getFromPointer(last->getBufferStart()), bufferID);
+ }
+ return {0, 0};
+}
+
LogicalResult mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
Block *block, const ParserConfig &config,
LocationAttr *sourceFileLoc) {
const auto *sourceBuf = sourceMgr.getMemoryBuffer(sourceMgr.getMainFileID());
if (sourceFileLoc) {
- *sourceFileLoc = FileLineColLoc::get(config.getContext(),
- sourceBuf->getBufferIdentifier(),
- /*line=*/0, /*column=*/0);
+ auto [line, column] = getLineAndColStart(sourceMgr);
+ *sourceFileLoc = FileLineColLoc::get(
+ config.getContext(), sourceBuf->getBufferIdentifier(), line, column);
}
if (isBytecode(*sourceBuf))
return readBytecodeFile(*sourceBuf, block, config);
@@ -37,9 +55,9 @@ mlir::parseSourceFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr,
const auto *sourceBuf =
sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID());
if (sourceFileLoc) {
- *sourceFileLoc = FileLineColLoc::get(config.getContext(),
- sourceBuf->getBufferIdentifier(),
- /*line=*/0, /*column=*/0);
+ auto [line, column] = getLineAndColStart(*sourceMgr);
+ *sourceFileLoc = FileLineColLoc::get(
+ config.getContext(), sourceBuf->getBufferIdentifier(), line, column);
}
if (isBytecode(*sourceBuf))
return readBytecodeFile(sourceMgr, block, config);
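With getLineAndColStart, when the SourceMgr's main buffer encloses the most recently added buffer (a chunk carved out of it), the reported start location becomes the chunk's line and column inside the main buffer instead of the hard-coded 0:0. A hedged sketch of the observable effect; `sourceMgr`, `block`, and `config` are assumed to be set up as usual.

// Hedged sketch of the resulting location.
LocationAttr fileLoc;
if (failed(parseSourceFile(sourceMgr, &block, config, &fileLoc)))
  return failure();
// fileLoc now points at the chunk's start line/column within the original
// buffer, rather than <file>:0:0.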
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 0db9808..7094c8e 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -901,7 +901,7 @@ LogicalResult PassManager::run(Operation *op) {
if (failed(initialize(context, impl->initializationGeneration + 1)))
return failure();
initializationKey = newInitKey;
- pipelineKey = pipelineInitializationKey;
+ pipelineInitializationKey = pipelineKey;
}
// Construct a top level analysis manager for the pipeline.
diff --git a/mlir/lib/RegisterAllDialects.cpp b/mlir/lib/RegisterAllDialects.cpp
new file mode 100644
index 0000000..7a345ed
--- /dev/null
+++ b/mlir/lib/RegisterAllDialects.cpp
@@ -0,0 +1,207 @@
+//===- RegisterAllDialects.cpp - MLIR Dialects Registration -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all dialects and
+// passes to the system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllDialects.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferViewFlowOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Arith/Transforms/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/Async/IR/Async.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
+#include "mlir/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/ControlFlow/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/DLTI/DLTI.h"
+#include "mlir/Dialect/EmitC/IR/EmitC.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/GPU/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/GPU/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/IRDL/IR/IRDL.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h"
+#include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
+#include "mlir/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/MPI/IR/MPI.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/MemRefMemorySlot.h"
+#include "mlir/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/BufferViewFlowOpInterfaceImpl.h"
+#include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Dialect/Ptr/IR/PtrDialect.h"
+#include "mlir/Dialect/Quant/IR/Quant.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.h"
+#include "mlir/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/SMT/IR/SMTDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
+#include "mlir/Dialect/Shape/IR/Shape.h"
+#include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Shard/IR/ShardDialect.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/Transforms/RuntimeOpVerification.h"
+#include "mlir/Dialect/Tensor/Transforms/SubsetInsertionOpInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/ShardingInterfaceImpl.h"
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
+#include "mlir/Dialect/Vector/IR/ValueBoundsOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Vector/Transforms/SubsetOpInterfaceImpl.h"
+#include "mlir/Dialect/X86Vector/X86VectorDialect.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/Interfaces/CastInterfaces.h"
+#include "mlir/Target/LLVM/NVVM/Target.h"
+#include "mlir/Target/LLVM/ROCDL/Target.h"
+#include "mlir/Target/SPIRV/Target.h"
+
+/// Add all the MLIR dialects to the provided registry.
+void mlir::registerAllDialects(DialectRegistry &registry) {
+ // clang-format off
+ registry.insert<acc::OpenACCDialect,
+ affine::AffineDialect,
+ amdgpu::AMDGPUDialect,
+ amx::AMXDialect,
+ arith::ArithDialect,
+ arm_neon::ArmNeonDialect,
+ arm_sme::ArmSMEDialect,
+ arm_sve::ArmSVEDialect,
+ async::AsyncDialect,
+ bufferization::BufferizationDialect,
+ cf::ControlFlowDialect,
+ complex::ComplexDialect,
+ DLTIDialect,
+ emitc::EmitCDialect,
+ func::FuncDialect,
+ gpu::GPUDialect,
+ index::IndexDialect,
+ irdl::IRDLDialect,
+ linalg::LinalgDialect,
+ LLVM::LLVMDialect,
+ math::MathDialect,
+ memref::MemRefDialect,
+ shard::ShardDialect,
+ ml_program::MLProgramDialect,
+ mpi::MPIDialect,
+ nvgpu::NVGPUDialect,
+ NVVM::NVVMDialect,
+ omp::OpenMPDialect,
+ pdl::PDLDialect,
+ pdl_interp::PDLInterpDialect,
+ ptr::PtrDialect,
+ quant::QuantDialect,
+ ROCDL::ROCDLDialect,
+ scf::SCFDialect,
+ shape::ShapeDialect,
+ smt::SMTDialect,
+ sparse_tensor::SparseTensorDialect,
+ spirv::SPIRVDialect,
+ tensor::TensorDialect,
+ tosa::TosaDialect,
+ transform::TransformDialect,
+ ub::UBDialect,
+ vector::VectorDialect,
+ x86vector::X86VectorDialect,
+ xegpu::XeGPUDialect,
+ xevm::XeVMDialect>();
+ // clang-format on
+
+ // Register all external models.
+ affine::registerValueBoundsOpInterfaceExternalModels(registry);
+ arith::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ arith::registerBufferizableOpInterfaceExternalModels(registry);
+ arith::registerBufferViewFlowOpInterfaceExternalModels(registry);
+ arith::registerShardingInterfaceExternalModels(registry);
+ arith::registerValueBoundsOpInterfaceExternalModels(registry);
+ bufferization::func_ext::registerBufferizableOpInterfaceExternalModels(
+ registry);
+ builtin::registerCastOpInterfaceExternalModels(registry);
+ cf::registerBufferizableOpInterfaceExternalModels(registry);
+ cf::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ gpu::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ gpu::registerValueBoundsOpInterfaceExternalModels(registry);
+ LLVM::registerInlinerInterface(registry);
+ NVVM::registerInlinerInterface(registry);
+ linalg::registerAllDialectInterfaceImplementations(registry);
+ linalg::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ memref::registerAllocationOpInterfaceExternalModels(registry);
+ memref::registerBufferViewFlowOpInterfaceExternalModels(registry);
+ memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ memref::registerValueBoundsOpInterfaceExternalModels(registry);
+ memref::registerMemorySlotExternalModels(registry);
+ ml_program::registerBufferizableOpInterfaceExternalModels(registry);
+ scf::registerBufferDeallocationOpInterfaceExternalModels(registry);
+ scf::registerBufferizableOpInterfaceExternalModels(registry);
+ scf::registerValueBoundsOpInterfaceExternalModels(registry);
+ shape::registerBufferizableOpInterfaceExternalModels(registry);
+ sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
+ tensor::registerBufferizableOpInterfaceExternalModels(registry);
+ tensor::registerFindPayloadReplacementOpInterfaceExternalModels(registry);
+ tensor::registerInferTypeOpInterfaceExternalModels(registry);
+ tensor::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
+ tensor::registerSubsetOpInterfaceExternalModels(registry);
+ tensor::registerTilingInterfaceExternalModels(registry);
+ tensor::registerValueBoundsOpInterfaceExternalModels(registry);
+ tosa::registerShardingInterfaceExternalModels(registry);
+ vector::registerBufferizableOpInterfaceExternalModels(registry);
+ vector::registerSubsetOpInterfaceExternalModels(registry);
+ vector::registerValueBoundsOpInterfaceExternalModels(registry);
+ NVVM::registerNVVMTargetInterfaceExternalModels(registry);
+ ROCDL::registerROCDLTargetInterfaceExternalModels(registry);
+ spirv::registerSPIRVTargetInterfaceExternalModels(registry);
+}
+
+/// Append all the MLIR dialects to the registry contained in the given context.
+void mlir::registerAllDialects(MLIRContext &context) {
+ DialectRegistry registry;
+ registerAllDialects(registry);
+ context.appendDialectRegistry(registry);
+}
diff --git a/mlir/lib/RegisterAllExtensions.cpp b/mlir/lib/RegisterAllExtensions.cpp
new file mode 100644
index 0000000..8f7c67c
--- /dev/null
+++ b/mlir/lib/RegisterAllExtensions.cpp
@@ -0,0 +1,115 @@
+//===- RegisterAllExtensions.cpp - MLIR Extension Registration --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all dialect
+// extensions to the system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllExtensions.h"
+
+#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"
+#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
+#include "mlir/Conversion/ComplexToLLVM/ComplexToLLVM.h"
+#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/FuncToEmitC/FuncToEmitC.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/GPUCommon/GPUToLLVM.h"
+#include "mlir/Conversion/GPUToNVVM/GPUToNVVM.h"
+#include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h"
+#include "mlir/Conversion/MPIToLLVM/MPIToLLVM.h"
+#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
+#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
+#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
+#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
+#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
+#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
+#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
+#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
+#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
+#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
+#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
+#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
+#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
+#include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/TransformOps/FuncTransformOps.h"
+#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
+#include "mlir/Dialect/Linalg/TransformOps/DialectExtension.h"
+#include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
+#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
+#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
+#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
+#include "mlir/Dialect/Tensor/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h"
+#include "mlir/Dialect/Transform/DebugExtension/DebugExtension.h"
+#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
+#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
+#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
+#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
+#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
+#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
+#include "mlir/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.h"
+
+/// This function may be called to register all MLIR dialect extensions with the
+/// provided registry.
+/// If you're building a compiler, you generally shouldn't use this: you would
+/// individually register the specific extensions that are useful for the
+/// pipelines and transformations you are using.
+void mlir::registerAllExtensions(DialectRegistry &registry) {
+ // Register all conversions to LLVM extensions.
+ registerConvertArithToEmitCInterface(registry);
+ arith::registerConvertArithToLLVMInterface(registry);
+ registerConvertComplexToLLVMInterface(registry);
+ cf::registerConvertControlFlowToLLVMInterface(registry);
+ func::registerAllExtensions(registry);
+ tensor::registerAllExtensions(registry);
+ registerConvertFuncToEmitCInterface(registry);
+ registerConvertFuncToLLVMInterface(registry);
+ index::registerConvertIndexToLLVMInterface(registry);
+ registerConvertMathToLLVMInterface(registry);
+ mpi::registerConvertMPIToLLVMInterface(registry);
+ registerConvertMemRefToEmitCInterface(registry);
+ registerConvertMemRefToLLVMInterface(registry);
+ registerConvertNVVMToLLVMInterface(registry);
+ registerConvertOpenMPToLLVMInterface(registry);
+ registerConvertSCFToEmitCInterface(registry);
+ ub::registerConvertUBToLLVMInterface(registry);
+ registerConvertAMXToLLVMInterface(registry);
+ gpu::registerConvertGpuToLLVMInterface(registry);
+ NVVM::registerConvertGpuToNVVMInterface(registry);
+ vector::registerConvertVectorToLLVMInterface(registry);
+ registerConvertXeVMToLLVMInterface(registry);
+
+ // Register all transform dialect extensions.
+ affine::registerTransformDialectExtension(registry);
+ bufferization::registerTransformDialectExtension(registry);
+ dlti::registerTransformDialectExtension(registry);
+ func::registerTransformDialectExtension(registry);
+ gpu::registerTransformDialectExtension(registry);
+ linalg::registerTransformDialectExtension(registry);
+ memref::registerTransformDialectExtension(registry);
+ nvgpu::registerTransformDialectExtension(registry);
+ scf::registerTransformDialectExtension(registry);
+ sparse_tensor::registerTransformDialectExtension(registry);
+ tensor::registerTransformDialectExtension(registry);
+ transform::registerDebugExtension(registry);
+ transform::registerIRDLExtension(registry);
+ transform::registerLoopExtension(registry);
+ transform::registerPDLExtension(registry);
+ transform::registerTuneExtension(registry);
+ vector::registerTransformDialectExtension(registry);
+ arm_neon::registerTransformDialectExtension(registry);
+ arm_sve::registerTransformDialectExtension(registry);
+
+ // Translation extensions need to be registered by calling
+ // `registerAllToLLVMIRTranslations` (see All.h).
+}
diff --git a/mlir/lib/RegisterAllPasses.cpp b/mlir/lib/RegisterAllPasses.cpp
new file mode 100644
index 0000000..1ed3a37
--- /dev/null
+++ b/mlir/lib/RegisterAllPasses.cpp
@@ -0,0 +1,99 @@
+//===- RegisterAllPasses.cpp - MLIR Registration ----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines a helper to trigger the registration of all passes to the
+// system.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/InitAllPasses.h"
+
+#include "mlir/Conversion/Passes.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+#include "mlir/Dialect/Affine/Passes.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Passes.h"
+#include "mlir/Dialect/Async/Passes.h"
+#include "mlir/Dialect/Bufferization/Pipelines/Passes.h"
+#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
+#include "mlir/Dialect/EmitC/Transforms/Passes.h"
+#include "mlir/Dialect/Func/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/Pipelines/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/Transforms/Passes.h"
+#include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/MLProgram/Transforms/Passes.h"
+#include "mlir/Dialect/Math/Transforms/Passes.h"
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+#include "mlir/Dialect/NVGPU/Transforms/Passes.h"
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+#include "mlir/Dialect/Quant/Transforms/Passes.h"
+#include "mlir/Dialect/SCF/Transforms/Passes.h"
+#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
+#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/Shard/Transforms/Passes.h"
+#include "mlir/Dialect/SparseTensor/Pipelines/Passes.h"
+#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tensor/Transforms/Passes.h"
+#include "mlir/Dialect/Tosa/Transforms/Passes.h"
+#include "mlir/Dialect/Transform/Transforms/Passes.h"
+#include "mlir/Dialect/Vector/Transforms/Passes.h"
+#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
+#include "mlir/Transforms/Passes.h"
+
+// This function may be called to register the MLIR passes with the
+// global registry.
+// If you're building a compiler, you likely don't need this: you would build a
+// pipeline programmatically without the need to register with the global
+// registry, since it would already be calling the creation routine of the
+// individual passes.
+// The global registry is interesting to interact with the command-line tools.
+void mlir::registerAllPasses() {
+ // General passes
+ registerTransformsPasses();
+
+ // Conversion passes
+ registerConversionPasses();
+
+ // Dialect passes
+ acc::registerOpenACCPasses();
+ affine::registerAffinePasses();
+ amdgpu::registerAMDGPUPasses();
+ registerAsyncPasses();
+ arith::registerArithPasses();
+ bufferization::registerBufferizationPasses();
+ func::registerFuncPasses();
+ registerGPUPasses();
+ registerLinalgPasses();
+ registerNVGPUPasses();
+ registerSparseTensorPasses();
+ LLVM::registerLLVMPasses();
+ math::registerMathPasses();
+ memref::registerMemRefPasses();
+ shard::registerShardPasses();
+ ml_program::registerMLProgramPasses();
+ quant::registerQuantPasses();
+ registerSCFPasses();
+ registerShapePasses();
+ spirv::registerSPIRVPasses();
+ tensor::registerTensorPasses();
+ tosa::registerTosaOptPasses();
+ transform::registerTransformPasses();
+ vector::registerVectorPasses();
+ arm_sme::registerArmSMEPasses();
+ arm_sve::registerArmSVEPasses();
+ emitc::registerEmitCPasses();
+ xegpu::registerXeGPUPasses();
+
+ // Dialect pipelines
+ bufferization::registerBufferizationPipelines();
+ sparse_tensor::registerSparseTensorPipelines();
+ tosa::registerTosaToLinalgPipelines();
+ gpu::registerGPUToNVVMPipeline();
+}
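Together with RegisterAllDialects.cpp and RegisterAllExtensions.cpp above, this file appears to move the registerAll* implementations out of the InitAll*.h headers while keeping their signatures. A minimal sketch of the usual opt-style tool setup that consumes all three; the tool name is a placeholder.

#include "mlir/InitAllDialects.h"
#include "mlir/InitAllExtensions.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Tools/mlir-opt/MlirOptMain.h"

int main(int argc, char **argv) {
  mlir::registerAllPasses();
  mlir::DialectRegistry registry;
  mlir::registerAllDialects(registry);
  mlir::registerAllExtensions(registry);
  return mlir::asMainReturnCode(
      mlir::MlirOptMain(argc, argv, "example-opt driver\n", registry));
}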
diff --git a/mlir/lib/Rewrite/PatternApplicator.cpp b/mlir/lib/Rewrite/PatternApplicator.cpp
index b2b372b..e13bcff 100644
--- a/mlir/lib/Rewrite/PatternApplicator.cpp
+++ b/mlir/lib/Rewrite/PatternApplicator.cpp
@@ -13,7 +13,7 @@
#include "mlir/Rewrite/PatternApplicator.h"
#include "ByteCode.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#ifndef NDEBUG
#include "llvm/ADT/ScopeExit.h"
@@ -51,9 +51,7 @@ static Operation *getDumpRootOp(Operation *op) {
return op;
}
static void logSucessfulPatternApplication(Operation *op) {
- llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n";
- op->dump();
- llvm::dbgs() << "\n\n";
+ LDBG(2) << "// *** IR Dump After Pattern Application ***\n" << *op << "\n";
}
#endif
@@ -208,8 +206,8 @@ LogicalResult PatternApplicator::matchAndRewrite(
result =
bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
} else {
- LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
- << bestPattern->getDebugName() << "\"\n");
+ LDBG() << "Trying to match \"" << bestPattern->getDebugName()
+ << "\"";
const auto *pattern =
static_cast<const RewritePattern *>(bestPattern);
@@ -223,9 +221,8 @@ LogicalResult PatternApplicator::matchAndRewrite(
[&] { rewriter.setListener(oldListener); });
#endif
result = pattern->matchAndRewrite(op, rewriter);
- LLVM_DEBUG(llvm::dbgs()
- << "\"" << bestPattern->getDebugName() << "\" result "
- << succeeded(result) << "\n");
+ LDBG() << " -> matchAndRewrite "
+ << (succeeded(result) ? "successful" : "failed");
}
// Process the result of the pattern application.
diff --git a/mlir/lib/Support/ToolUtilities.cpp b/mlir/lib/Support/ToolUtilities.cpp
index 748f928..2cf33eb 100644
--- a/mlir/lib/Support/ToolUtilities.cpp
+++ b/mlir/lib/Support/ToolUtilities.cpp
@@ -14,6 +14,8 @@
#include "mlir/Support/LLVM.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
+#include <string>
+#include <utility>
using namespace mlir;
@@ -22,18 +24,18 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
ChunkBufferHandler processChunkBuffer,
raw_ostream &os, llvm::StringRef inputSplitMarker,
llvm::StringRef outputSplitMarker) {
+ llvm::MemoryBufferRef originalBufferRef = originalBuffer->getMemBufferRef();
// If splitting is disabled, we process the full input buffer.
if (inputSplitMarker.empty())
- return processChunkBuffer(std::move(originalBuffer), os);
+ return processChunkBuffer(std::move(originalBuffer), originalBufferRef, os);
const int inputSplitMarkerLen = inputSplitMarker.size();
- auto *origMemBuffer = originalBuffer.get();
SmallVector<StringRef, 8> rawSourceBuffers;
const int checkLen = 2;
// Split dropping the last checkLen chars to enable flagging near misses.
- origMemBuffer->getBuffer().split(rawSourceBuffers,
- inputSplitMarker.drop_back(checkLen));
+ originalBufferRef.getBuffer().split(rawSourceBuffers,
+ inputSplitMarker.drop_back(checkLen));
if (rawSourceBuffers.empty())
return success();
@@ -79,11 +81,17 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
auto interleaveFn = [&](StringRef subBuffer) {
auto splitLoc = SMLoc::getFromPointer(subBuffer.data());
unsigned splitLine = fileSourceMgr.getLineAndColumn(splitLoc).first;
- auto subMemBuffer = llvm::MemoryBuffer::getMemBufferCopy(
- subBuffer, Twine("within split at ") +
- origMemBuffer->getBufferIdentifier() + ":" +
- Twine(splitLine) + " offset ");
- if (failed(processChunkBuffer(std::move(subMemBuffer), os)))
+ std::string name((Twine("within split at ") +
+ originalBufferRef.getBufferIdentifier() + ":" +
+ Twine(splitLine) + " offset ")
+ .str());
+    // Use a MemoryBufferRef to avoid copying the buffer and to keep it at the
+    // same location relative to the original buffer.
+ auto subMemBuffer =
+ llvm::MemoryBuffer::getMemBuffer(llvm::MemoryBufferRef(subBuffer, name),
+ /*RequiresNullTerminator=*/false);
+ if (failed(
+ processChunkBuffer(std::move(subMemBuffer), originalBufferRef, os)))
hadFailure = true;
};
llvm::interleave(sourceBuffers, os, interleaveFn,
@@ -92,3 +100,16 @@ mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
// If any fails, then return a failure of the tool.
return failure(hadFailure);
}
+
+LogicalResult
+mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer> originalBuffer,
+ NoSourceChunkBufferHandler processChunkBuffer,
+ raw_ostream &os, llvm::StringRef inputSplitMarker,
+ llvm::StringRef outputSplitMarker) {
+ auto process = [&](std::unique_ptr<llvm::MemoryBuffer> chunkBuffer,
+ const llvm::MemoryBufferRef &, raw_ostream &os) {
+ return processChunkBuffer(std::move(chunkBuffer), os);
+ };
+ return splitAndProcessBuffer(std::move(originalBuffer), process, os,
+ inputSplitMarker, outputSplitMarker);
+}
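The chunk handler now also receives a MemoryBufferRef of the original, unsplit buffer, and existing two-argument handlers keep working through the adaptor overload added above. A hedged sketch of the new-style handler; the split markers, `inputBuffer`, and the handler body are placeholders.

// New-style handler: gets the chunk plus a reference to the original buffer,
// so positions can be reported relative to the full input.
auto processChunk = [&](std::unique_ptr<llvm::MemoryBuffer> chunk,
                        const llvm::MemoryBufferRef &origBuffer,
                        llvm::raw_ostream &os) -> LogicalResult {
  // ... parse/process `chunk`, which aliases `origBuffer` ...
  return success();
};
return splitAndProcessBuffer(std::move(inputBuffer), processChunk,
                             llvm::outs(), "// -----", "// -----");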
diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
index dcd2e11..8e83e45 100644
--- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp
+++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp
@@ -333,7 +333,8 @@ private:
/// Determine whether expression \p op should be emitted in a deferred way.
static bool hasDeferredEmission(Operation *op) {
return isa_and_nonnull<emitc::GetGlobalOp, emitc::LiteralOp, emitc::MemberOp,
- emitc::MemberOfPtrOp, emitc::SubscriptOp>(op);
+ emitc::MemberOfPtrOp, emitc::SubscriptOp,
+ emitc::GetFieldOp>(op);
}
/// Determine whether expression \p expressionOp should be emitted inline, i.e.
@@ -1049,25 +1050,17 @@ static LogicalResult printOperation(CppEmitter &emitter, ClassOp classOp) {
static LogicalResult printOperation(CppEmitter &emitter, FieldOp fieldOp) {
raw_ostream &os = emitter.ostream();
- if (failed(emitter.emitType(fieldOp->getLoc(), fieldOp.getType())))
+ if (failed(emitter.emitVariableDeclaration(
+ fieldOp->getLoc(), fieldOp.getType(), fieldOp.getSymName())))
return failure();
- os << " " << fieldOp.getSymName() << ";";
- return success();
-}
-
-static LogicalResult printOperation(CppEmitter &emitter,
- GetFieldOp getFieldOp) {
- raw_indented_ostream &os = emitter.ostream();
-
- Value result = getFieldOp.getResult();
- if (failed(emitter.emitType(getFieldOp->getLoc(), result.getType())))
- return failure();
- os << " ";
- if (failed(emitter.emitOperand(result)))
- return failure();
- os << " = ";
+ std::optional<Attribute> initialValue = fieldOp.getInitialValue();
+ if (initialValue) {
+ os << " = ";
+ if (failed(emitter.emitAttribute(fieldOp->getLoc(), *initialValue)))
+ return failure();
+ }
- os << getFieldOp.getFieldName().str();
+ os << ";";
return success();
}
@@ -1204,7 +1197,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
os << ") {\n";
if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
return failure();
- os << "}\n";
+ os << "}";
return success();
}
@@ -1245,7 +1238,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
os << ") {\n";
if (failed(printFunctionBody(emitter, operation, functionOp.getBlocks())))
return failure();
- os << "}\n";
+ os << "}";
return success();
}
@@ -1700,12 +1693,11 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
emitc::CmpOp, emitc::ConditionalOp, emitc::ConstantOp,
emitc::DeclareFuncOp, emitc::DivOp, emitc::ExpressionOp,
emitc::FieldOp, emitc::FileOp, emitc::ForOp, emitc::FuncOp,
- emitc::GetFieldOp, emitc::GlobalOp, emitc::IfOp,
- emitc::IncludeOp, emitc::LoadOp, emitc::LogicalAndOp,
- emitc::LogicalNotOp, emitc::LogicalOrOp, emitc::MulOp,
- emitc::RemOp, emitc::ReturnOp, emitc::SubOp, emitc::SwitchOp,
- emitc::UnaryMinusOp, emitc::UnaryPlusOp, emitc::VariableOp,
- emitc::VerbatimOp>(
+ emitc::GlobalOp, emitc::IfOp, emitc::IncludeOp, emitc::LoadOp,
+ emitc::LogicalAndOp, emitc::LogicalNotOp, emitc::LogicalOrOp,
+ emitc::MulOp, emitc::RemOp, emitc::ReturnOp, emitc::SubOp,
+ emitc::SwitchOp, emitc::UnaryMinusOp, emitc::UnaryPlusOp,
+ emitc::VariableOp, emitc::VerbatimOp>(
[&](auto op) { return printOperation(*this, op); })
// Func ops.
@@ -1715,6 +1707,10 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
cacheDeferredOpResult(op.getResult(), op.getName());
return success();
})
+ .Case<emitc::GetFieldOp>([&](auto op) {
+ cacheDeferredOpResult(op.getResult(), op.getFieldName());
+ return success();
+ })
.Case<emitc::LiteralOp>([&](auto op) {
cacheDeferredOpResult(op.getResult(), op.getValue());
return success();
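Field declarations now go through emitVariableDeclaration and append an optional initializer, while emitc.get_field results are deferred and later inlined as the field's name. A hedged illustration of the C++ one would expect for a field named `count` of type i32 with an initial value of 0 (the names are made up for the example):

// Sketch of the expected emitter output for such a field inside the
// enclosing class body; later uses of the field print as plain `count`.
int32_t count = 0;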
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
index af22a7f..9ea5c683 100644
--- a/mlir/lib/Target/LLVMIR/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/CMakeLists.txt
@@ -60,6 +60,7 @@ add_mlir_translation_library(MLIRToLLVMIRTranslationRegistration
MLIRROCDLToLLVMIRTranslation
MLIRSPIRVToLLVMIRTranslation
MLIRVCIXToLLVMIRTranslation
+ MLIRXeVMToLLVMIRTranslation
)
add_mlir_translation_library(MLIRTargetLLVMIRImport
diff --git a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
index f030fa7..86c731a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
+++ b/mlir/lib/Target/LLVMIR/Dialect/CMakeLists.txt
@@ -10,3 +10,4 @@ add_subdirectory(OpenMP)
add_subdirectory(ROCDL)
add_subdirectory(SPIRV)
add_subdirectory(VCIX)
+add_subdirectory(XeVM)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
index ff34a08..0f675a0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp
@@ -13,6 +13,7 @@
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
@@ -136,46 +137,6 @@ convertOperandBundles(OperandRangeRange bundleOperands,
return convertOperandBundles(bundleOperands, *bundleTags, moduleTranslation);
}
-static LogicalResult
-convertParameterAndResultAttrs(mlir::Location loc, ArrayAttr argAttrsArray,
- ArrayAttr resAttrsArray, llvm::CallBase *call,
- LLVM::ModuleTranslation &moduleTranslation) {
- if (argAttrsArray) {
- for (auto [argIdx, argAttrsAttr] : llvm::enumerate(argAttrsArray)) {
- if (auto argAttrs = cast<DictionaryAttr>(argAttrsAttr);
- !argAttrs.empty()) {
- FailureOr<llvm::AttrBuilder> attrBuilder =
- moduleTranslation.convertParameterAttrs(loc, argAttrs);
- if (failed(attrBuilder))
- return failure();
- call->addParamAttrs(argIdx, *attrBuilder);
- }
- }
- }
-
- if (resAttrsArray && resAttrsArray.size() > 0) {
- if (resAttrsArray.size() != 1)
- return mlir::emitError(loc, "llvm.func cannot have multiple results");
- if (auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
- !resAttrs.empty()) {
- FailureOr<llvm::AttrBuilder> attrBuilder =
- moduleTranslation.convertParameterAttrs(loc, resAttrs);
- if (failed(attrBuilder))
- return failure();
- call->addRetAttrs(*attrBuilder);
- }
- }
- return success();
-}
-
-static LogicalResult
-convertParameterAndResultAttrs(CallOpInterface callOp, llvm::CallBase *call,
- LLVM::ModuleTranslation &moduleTranslation) {
- return convertParameterAndResultAttrs(
- callOp.getLoc(), callOp.getArgAttrsAttr(), callOp.getResAttrsAttr(), call,
- moduleTranslation);
-}
-
/// Builder for LLVM_CallIntrinsicOp
static LogicalResult
convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
@@ -243,9 +204,7 @@ convertCallLLVMIntrinsicOp(CallIntrinsicOp op, llvm::IRBuilderBase &builder,
convertOperandBundles(op.getOpBundleOperands(), op.getOpBundleTags(),
moduleTranslation));
- if (failed(convertParameterAndResultAttrs(op.getLoc(), op.getArgAttrsAttr(),
- op.getResAttrsAttr(), inst,
- moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(op, inst)))
return failure();
if (op.getNumResults() == 1)
@@ -455,7 +414,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
if (callOp.getInlineHintAttr())
call->addFnAttr(llvm::Attribute::InlineHint);
- if (failed(convertParameterAndResultAttrs(callOp, call, moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(callOp, call)))
return failure();
if (MemoryEffectsAttr memAttr = callOp.getMemoryEffectsAttr()) {
@@ -569,8 +528,7 @@ convertOperationImpl(Operation &opInst, llvm::IRBuilderBase &builder,
operandsRef.drop_front(), opBundles);
}
result->setCallingConv(convertCConvToLLVM(invOp.getCConv()));
- if (failed(
- convertParameterAndResultAttrs(invOp, result, moduleTranslation)))
+ if (failed(moduleTranslation.convertArgAndResultAttrs(invOp, result)))
return failure();
moduleTranslation.mapBranch(invOp, result);
// InvokeOp can only have 0 or 1 result
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
index 1c9e226..55e73e8 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
@@ -13,6 +13,7 @@
#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Target/LLVMIR/ModuleImport.h"
+#include "llvm/IR/ConstantRange.h"
using namespace mlir;
using namespace mlir::NVVM;
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
index b3577c6..90462d1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.cpp
@@ -164,6 +164,42 @@ static llvm::Intrinsic::ID getLdMatrixIntrinsicId(NVVM::MMALayout layout,
}
}
+/// Return the intrinsic ID associated with stmatrix for the given parameters.
+static llvm::Intrinsic::ID
+getStMatrixIntrinsicId(NVVM::MMALayout layout, int32_t num,
+ NVVM::LdStMatrixShapeAttr shape,
+ NVVM::LdStMatrixEltType eltType) {
+ if (shape.getM() == 8 && shape.getN() == 8) {
+ switch (num) {
+ case 1:
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x1_b16
+ : llvm::Intrinsic::
+ nvvm_stmatrix_sync_aligned_m8n8_x1_trans_b16;
+ case 2:
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16
+ : llvm::Intrinsic::
+ nvvm_stmatrix_sync_aligned_m8n8_x2_trans_b16;
+ case 4:
+ return (layout == NVVM::MMALayout::row)
+ ? llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x4_b16
+ : llvm::Intrinsic::
+ nvvm_stmatrix_sync_aligned_m8n8_x4_trans_b16;
+ }
+ } else if (shape.getM() == 16 && shape.getN() == 8) {
+ switch (num) {
+ case 1:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x1_trans_b8;
+ case 2:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x2_trans_b8;
+ case 4:
+ return llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m16n8_x4_trans_b8;
+ }
+ }
+ llvm_unreachable("unknown stmatrix kind");
+}
+
/// Return the intrinsic ID associated with st.bulk for the given address type.
static llvm::Intrinsic::ID
getStBulkIntrinsicId(LLVM::LLVMPointerType addrType) {
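getStMatrixIntrinsicId picks between the m8n8 b16 and m16n8 b8 stmatrix intrinsics based on the layout, the repetition count, and the tile shape. A minimal sketch of the mapping it implements; `shapeAttr` and `eltType` are assumed to describe an 8x8 tile of b16 elements.

// Hedged sketch: with num = 2 and a row-major layout the switch above
// resolves to the non-transposed x2 variant.
llvm::Intrinsic::ID id =
    getStMatrixIntrinsicId(NVVM::MMALayout::row, /*num=*/2, shapeAttr, eltType);
// id == llvm::Intrinsic::nvvm_stmatrix_sync_aligned_m8n8_x2_b16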
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9f18199..49e1e55 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3877,29 +3877,28 @@ static omp::MapInfoOp getFirstOrLastMappedMemberPtr(omp::MapInfoOp mapInfo,
llvm::SmallVector<size_t> indices(indexAttr.size());
std::iota(indices.begin(), indices.end(), 0);
- llvm::sort(indices.begin(), indices.end(),
- [&](const size_t a, const size_t b) {
- auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
- auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
- for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
- int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
- int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
-
- if (aIndex == bIndex)
- continue;
-
- if (aIndex < bIndex)
- return first;
-
- if (aIndex > bIndex)
- return !first;
- }
-
- // Iterated the up until the end of the smallest member and
- // they were found to be equal up to that point, so select
- // the member with the lowest index count, so the "parent"
- return memberIndicesA.size() < memberIndicesB.size();
- });
+ llvm::sort(indices, [&](const size_t a, const size_t b) {
+ auto memberIndicesA = cast<ArrayAttr>(indexAttr[a]);
+ auto memberIndicesB = cast<ArrayAttr>(indexAttr[b]);
+ for (const auto it : llvm::zip(memberIndicesA, memberIndicesB)) {
+ int64_t aIndex = cast<IntegerAttr>(std::get<0>(it)).getInt();
+ int64_t bIndex = cast<IntegerAttr>(std::get<1>(it)).getInt();
+
+ if (aIndex == bIndex)
+ continue;
+
+ if (aIndex < bIndex)
+ return first;
+
+ if (aIndex > bIndex)
+ return !first;
+ }
+
+    // We iterated up to the end of the shorter member list and found the
+    // indices equal so far, so select the member with the fewest indices,
+    // i.e. the "parent".
+ return memberIndicesA.size() < memberIndicesB.size();
+ });
return llvm::cast<omp::MapInfoOp>(
mapInfo.getMembers()[indices.front()].getDefiningOp());
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
new file mode 100644
index 0000000..6308d7e
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/CMakeLists.txt
@@ -0,0 +1,21 @@
+set(LLVM_OPTIONAL_SOURCES
+ XeVMToLLVMIRTranslation.cpp
+)
+
+add_mlir_translation_library(MLIRXeVMToLLVMIRTranslation
+ XeVMToLLVMIRTranslation.cpp
+
+ DEPENDS
+ MLIRXeVMConversionsIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRIR
+ MLIRLLVMDialect
+ MLIRXeVMDialect
+ MLIRSupport
+ MLIRTargetLLVMIRExport
+)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
new file mode 100644
index 0000000..73b166d
--- /dev/null
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -0,0 +1,103 @@
+//===-- XeVMToLLVMIRTranslation.cpp - Translate XeVM to LLVM IR -*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a translation between the MLIR XeVM dialect and
+// LLVM IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+using namespace mlir::LLVM;
+
+namespace {
+/// Implementation of the dialect interface that converts operations belonging
+/// to the XeVM dialect to LLVM IR.
+class XeVMDialectLLVMIRTranslationInterface
+ : public LLVMTranslationDialectInterface {
+public:
+ using LLVMTranslationDialectInterface::LLVMTranslationDialectInterface;
+
+ /// Attaches module-level metadata for functions marked as kernels.
+ LogicalResult
+ amendOperation(Operation *op, ArrayRef<llvm::Instruction *> instructions,
+ NamedAttribute attribute,
+ LLVM::ModuleTranslation &moduleTranslation) const final {
+ StringRef attrName = attribute.getName().getValue();
+ if (attrName == mlir::xevm::XeVMDialect::getCacheControlsAttrName()) {
+ auto cacheControlsArray = dyn_cast<ArrayAttr>(attribute.getValue());
+ if (cacheControlsArray.size() != 2) {
+ return op->emitOpError(
+ "Expected both L1 and L3 cache control attributes!");
+ }
+ if (instructions.size() != 1) {
+ return op->emitOpError("Expecting a single instruction");
+ }
+ return handleDecorationCacheControl(instructions.front(),
+ cacheControlsArray.getValue());
+ }
+ auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
+ if (!func)
+ return failure();
+
+ return success();
+ }
+
+private:
+ static LogicalResult handleDecorationCacheControl(llvm::Instruction *inst,
+ ArrayRef<Attribute> attrs) {
+ SmallVector<llvm::Metadata *> decorations;
+ llvm::LLVMContext &ctx = inst->getContext();
+ llvm::Type *i32Ty = llvm::IntegerType::getInt32Ty(ctx);
+ llvm::transform(
+ attrs, std::back_inserter(decorations),
+ [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * {
+ auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue();
+ std::array<llvm::Metadata *, 4> metadata;
+ llvm::transform(
+ valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) {
+ return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
+ i32Ty, cast<IntegerAttr>(valueAttr).getValue()));
+ });
+ return llvm::MDNode::get(ctx, metadata);
+ });
+ constexpr llvm::StringLiteral decorationCacheControlMDName =
+ "spirv.DecorationCacheControlINTEL";
+ inst->setMetadata(decorationCacheControlMDName,
+ llvm::MDNode::get(ctx, decorations));
+ return success();
+ }
+};
+} // namespace
+
+void mlir::registerXeVMDialectTranslation(::mlir::DialectRegistry &registry) {
+ registry.insert<xevm::XeVMDialect>();
+ registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) {
+ dialect->addInterfaces<XeVMDialectLLVMIRTranslationInterface>();
+ });
+}
+
+void mlir::registerXeVMDialectTranslation(::mlir::MLIRContext &context) {
+ DialectRegistry registry;
+ registerXeVMDialectTranslation(registry);
+ context.appendDialectRegistry(registry);
+}
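The two registration helpers defined above follow the usual dialect-translation pattern. A minimal sketch of wiring the XeVM translation into a tool; whether you go through a registry or a context depends on how the tool constructs its MLIRContext.

// Register through a DialectRegistry that will back the context...
mlir::DialectRegistry registry;
mlir::registerXeVMDialectTranslation(registry);
mlir::MLIRContext context(registry);
// ...or attach the interface to an already-constructed context.
mlir::registerXeVMDialectTranslation(context);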
diff --git a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
index 580afdd..cb1f234 100644
--- a/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
+++ b/mlir/lib/Target/LLVMIR/LLVMImportInterface.cpp
@@ -33,7 +33,9 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(
SmallVector<Value> mlirOperands;
SmallVector<NamedAttribute> mlirAttrs;
if (failed(moduleImport.convertIntrinsicArguments(
- llvmOperands, llvmOpBundles, false, {}, {}, mlirOperands, mlirAttrs)))
+ llvmOperands, llvmOpBundles, /*requiresOpBundles=*/false,
+ /*immArgPositions=*/{}, /*immArgAttrNames=*/{}, mlirOperands,
+ mlirAttrs)))
return failure();
Type resultType = moduleImport.convertType(inst->getType());
@@ -44,11 +46,7 @@ LogicalResult mlir::LLVMImportInterface::convertUnregisteredIntrinsic(
ValueRange{mlirOperands}, FastmathFlagsAttr{});
moduleImport.setFastmathFlagsAttr(inst, op);
-
- ArrayAttr argsAttr, resAttr;
- moduleImport.convertParameterAttributes(inst, argsAttr, resAttr, builder);
- op.setArgAttrsAttr(argsAttr);
- op.setResAttrsAttr(resAttr);
+ moduleImport.convertArgAndResultAttrs(inst, op);
// Update importer tracking of results.
unsigned numRes = op.getNumResults();
diff --git a/mlir/lib/Target/LLVMIR/ModuleImport.cpp b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
index 58e3c44..6325480 100644
--- a/mlir/lib/Target/LLVMIR/ModuleImport.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleImport.cpp
@@ -30,6 +30,7 @@
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/PostOrderIterator.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/IR/Comdat.h"
#include "llvm/IR/Constants.h"
@@ -1063,6 +1064,18 @@ void ModuleImport::convertTargetTriple() {
builder.getStringAttr(llvmModule->getTargetTriple().str()));
}
+void ModuleImport::convertModuleLevelAsm() {
+ llvm::StringRef asmStr = llvmModule->getModuleInlineAsm();
+ llvm::SmallVector<mlir::Attribute> asmArrayAttr;
+
+ for (llvm::StringRef line : llvm::split(asmStr, '\n'))
+ if (!line.empty())
+ asmArrayAttr.push_back(builder.getStringAttr(line));
+
+ mlirModule->setAttr(LLVM::LLVMDialect::getModuleLevelAsmAttrName(),
+ builder.getArrayAttr(asmArrayAttr));
+}
+
LogicalResult ModuleImport::convertFunctions() {
for (llvm::Function &func : llvmModule->functions())
if (failed(processFunction(&func)))
@@ -2267,7 +2280,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
// Handle parameter and result attributes unless it's an incompatible
// call.
if (!isIncompatibleCall)
- convertParameterAttributes(callInst, callOp, builder);
+ convertArgAndResultAttrs(callInst, callOp);
return callOp.getOperation();
}();
@@ -2364,7 +2377,7 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
// Handle parameter and result attributes unless it's an incompatible
// invoke.
if (!isIncompatibleInvoke)
- convertParameterAttributes(invokeInst, invokeOp, builder);
+ convertArgAndResultAttrs(invokeInst, invokeOp);
if (!invokeInst->getType()->isVoidTy())
mapValue(inst, invokeOp.getResults().front());
@@ -2730,11 +2743,10 @@ void ModuleImport::processFunctionAttributes(llvm::Function *func,
}
DictionaryAttr
-ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
- OpBuilder &builder) {
+ModuleImport::convertArgOrResultAttrSet(llvm::AttributeSet llvmAttrSet) {
SmallVector<NamedAttribute> paramAttrs;
for (auto [llvmKind, mlirName] : getAttrKindToNameMapping()) {
- auto llvmAttr = llvmParamAttrs.getAttribute(llvmKind);
+ auto llvmAttr = llvmAttrSet.getAttribute(llvmKind);
// Skip attributes that are not attached.
if (!llvmAttr.isValid())
continue;
@@ -2769,13 +2781,12 @@ ModuleImport::convertParameterAttribute(llvm::AttributeSet llvmParamAttrs,
return builder.getDictionaryAttr(paramAttrs);
}
-void ModuleImport::convertParameterAttributes(llvm::Function *func,
- LLVMFuncOp funcOp,
- OpBuilder &builder) {
+void ModuleImport::convertArgAndResultAttrs(llvm::Function *func,
+ LLVMFuncOp funcOp) {
auto llvmAttrs = func->getAttributes();
for (size_t i = 0, e = funcOp.getNumArguments(); i < e; ++i) {
llvm::AttributeSet llvmArgAttrs = llvmAttrs.getParamAttrs(i);
- funcOp.setArgAttrs(i, convertParameterAttribute(llvmArgAttrs, builder));
+ funcOp.setArgAttrs(i, convertArgOrResultAttrSet(llvmArgAttrs));
}
// Convert the result attributes and attach them wrapped in an ArrayAttribute
// to the funcOp.
@@ -2783,17 +2794,23 @@ void ModuleImport::convertParameterAttributes(llvm::Function *func,
if (!llvmResAttr.hasAttributes())
return;
funcOp.setResAttrsAttr(
- builder.getArrayAttr(convertParameterAttribute(llvmResAttr, builder)));
+ builder.getArrayAttr({convertArgOrResultAttrSet(llvmResAttr)}));
}
-void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
- ArrayAttr &argsAttr,
- ArrayAttr &resAttr,
- OpBuilder &builder) {
+void ModuleImport::convertArgAndResultAttrs(
+ llvm::CallBase *call, ArgAndResultAttrsOpInterface attrsOp,
+ ArrayRef<unsigned> immArgPositions) {
+ // Compute the set of immediate argument positions.
+ llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(),
+ immArgPositions.end());
+ // Convert the argument attributes and filter out immediate arguments.
llvm::AttributeList llvmAttrs = call->getAttributes();
SmallVector<llvm::AttributeSet> llvmArgAttrsSet;
bool anyArgAttrs = false;
for (size_t i = 0, e = call->arg_size(); i < e; ++i) {
+ // Skip immediate arguments.
+ if (immArgPositionsSet.contains(i))
+ continue;
llvmArgAttrsSet.emplace_back(llvmAttrs.getParamAttrs(i));
if (llvmArgAttrsSet.back().hasAttributes())
anyArgAttrs = true;
@@ -2807,24 +2824,16 @@ void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
if (anyArgAttrs) {
SmallVector<DictionaryAttr> argAttrs;
for (auto &llvmArgAttrs : llvmArgAttrsSet)
- argAttrs.emplace_back(convertParameterAttribute(llvmArgAttrs, builder));
- argsAttr = getArrayAttr(argAttrs);
+ argAttrs.emplace_back(convertArgOrResultAttrSet(llvmArgAttrs));
+ attrsOp.setArgAttrsAttr(getArrayAttr(argAttrs));
}
+ // Convert the result attributes.
llvm::AttributeSet llvmResAttr = llvmAttrs.getRetAttrs();
if (!llvmResAttr.hasAttributes())
return;
- DictionaryAttr resAttrs = convertParameterAttribute(llvmResAttr, builder);
- resAttr = getArrayAttr({resAttrs});
-}
-
-void ModuleImport::convertParameterAttributes(llvm::CallBase *call,
- CallOpInterface callOp,
- OpBuilder &builder) {
- ArrayAttr argsAttr, resAttr;
- convertParameterAttributes(call, argsAttr, resAttr, builder);
- callOp.setArgAttrsAttr(argsAttr);
- callOp.setResAttrsAttr(resAttr);
+ DictionaryAttr resAttrs = convertArgOrResultAttrSet(llvmResAttr);
+ attrsOp.setResAttrsAttr(getArrayAttr({resAttrs}));
}
template <typename Op>
@@ -2892,7 +2901,7 @@ LogicalResult ModuleImport::processFunction(llvm::Function *func) {
builder, loc, func->getName(), functionType,
convertLinkageFromLLVM(func->getLinkage()), dsoLocal, cconv);
- convertParameterAttributes(func, funcOp, builder);
+ convertArgAndResultAttrs(func, funcOp);
if (FlatSymbolRefAttr personality = getPersonalityAsAttr(func))
funcOp.setPersonalityAttr(personality);
@@ -3199,5 +3208,6 @@ OwningOpRef<ModuleOp> mlir::translateLLVMIRToModule(
if (failed(moduleImport.convertIFuncs()))
return {};
moduleImport.convertTargetTriple();
+ moduleImport.convertModuleLevelAsm();
return module;
}
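As a minimal standalone sketch of the line splitting used by convertModuleLevelAsm above (llvm::split from StringExtras.h; the asm strings here are illustrative, not from the patch):

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/ADT/StringExtras.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      // Module-level inline asm arrives as a single string; the importer keeps
      // one array entry per non-empty line.
      llvm::StringRef moduleAsm = ".globl foo\nfoo:\n\n  ret\n";
      llvm::SmallVector<llvm::StringRef> lines;
      for (llvm::StringRef line : llvm::split(moduleAsm, '\n'))
        if (!line.empty())
          lines.push_back(line);
      for (llvm::StringRef line : lines)
        llvm::outs() << line << "\n"; // prints ".globl foo", "foo:", "  ret"
    }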
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index b997e55..b3a06e2 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -1758,6 +1758,48 @@ ModuleTranslation::convertParameterAttrs(LLVMFuncOp func, int argIdx,
return attrBuilder;
}
+LogicalResult ModuleTranslation::convertArgAndResultAttrs(
+ ArgAndResultAttrsOpInterface attrsOp, llvm::CallBase *call,
+ ArrayRef<unsigned> immArgPositions) {
+ // Convert the argument attributes.
+ if (ArrayAttr argAttrsArray = attrsOp.getArgAttrsAttr()) {
+ unsigned argAttrIdx = 0;
+ llvm::SmallDenseSet<unsigned> immArgPositionsSet(immArgPositions.begin(),
+ immArgPositions.end());
+ for (unsigned argIdx : llvm::seq<unsigned>(call->arg_size())) {
+ if (argAttrIdx >= argAttrsArray.size())
+ break;
+ // Skip immediate arguments (they have no entries in argAttrsArray).
+ if (immArgPositionsSet.contains(argIdx))
+ continue;
+ // Skip empty argument attributes.
+ auto argAttrs = cast<DictionaryAttr>(argAttrsArray[argAttrIdx++]);
+ if (argAttrs.empty())
+ continue;
+ // Convert and add attributes to the call instruction.
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(attrsOp->getLoc(), argAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addParamAttrs(argIdx, *attrBuilder);
+ }
+ }
+
+ // Convert the result attributes.
+ if (ArrayAttr resAttrsArray = attrsOp.getResAttrsAttr()) {
+ if (!resAttrsArray.empty()) {
+ auto resAttrs = cast<DictionaryAttr>(resAttrsArray[0]);
+ FailureOr<llvm::AttrBuilder> attrBuilder =
+ convertParameterAttrs(attrsOp->getLoc(), resAttrs);
+ if (failed(attrBuilder))
+ return failure();
+ call->addRetAttrs(*attrBuilder);
+ }
+ }
+
+ return success();
+}
+
FailureOr<llvm::AttrBuilder>
ModuleTranslation::convertParameterAttrs(Location loc,
DictionaryAttr paramAttrs) {
@@ -2276,6 +2318,25 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
llvmModule->setTargetTriple(
llvm::Triple(cast<StringAttr>(targetTripleAttr).getValue()));
+ if (auto asmAttr = m->getDiscardableAttr(
+ LLVM::LLVMDialect::getModuleLevelAsmAttrName())) {
+ auto asmArrayAttr = dyn_cast<ArrayAttr>(asmAttr);
+ if (!asmArrayAttr) {
+ m->emitError("expected an array attribute for a module level asm");
+ return nullptr;
+ }
+
+ for (Attribute elt : asmArrayAttr) {
+ auto asmStrAttr = dyn_cast<StringAttr>(elt);
+ if (!asmStrAttr) {
+ m->emitError(
+ "expected a string attribute for each entry of a module level asm");
+ return nullptr;
+ }
+ llvmModule->appendModuleInlineAsm(asmStrAttr.getValue());
+ }
+ }
+
return llvmModule;
}
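A minimal standalone sketch of the index bookkeeping used by convertArgAndResultAttrs above: attribute dictionaries exist only for non-immediate arguments, so a separate cursor walks the attribute entries while immarg positions are skipped. The positions and attribute strings below are illustrative, not from the patch:

    #include "llvm/ADT/DenseSet.h"
    #include "llvm/ADT/Sequence.h"
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      unsigned numCallArgs = 4;
      llvm::SmallDenseSet<unsigned> immArgPositions;
      immArgPositions.insert(1); // hypothetical immarg operand positions
      immArgPositions.insert(3);
      // One entry per non-immediate argument, in call-operand order.
      llvm::SmallVector<llvm::StringRef> argAttrEntries = {"{llvm.noundef}", "{}"};
      unsigned attrIdx = 0;
      for (unsigned argIdx : llvm::seq<unsigned>(numCallArgs)) {
        if (attrIdx >= argAttrEntries.size())
          break;
        if (immArgPositions.contains(argIdx))
          continue; // immediates carry no attribute dictionary
        llvm::outs() << "arg " << argIdx << " <- " << argAttrEntries[attrIdx++]
                     << "\n";
      }
    }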
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index e5934bb..d0ae513 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -347,10 +347,6 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
}
- // Block decoration does not affect spirv.struct type, but is still stored
- // for verification.
- // TODO: Update StructType to contain this information since
- // it is needed for many validation rules.
decorations[words[0]].set(symbol, opBuilder.getUnitAttr());
break;
case spirv::Decoration::Location:
@@ -993,7 +989,8 @@ spirv::Deserializer::processOpTypePointer(ArrayRef<uint32_t> operands) {
if (failed(structType.trySetBody(
deferredStructIt->memberTypes, deferredStructIt->offsetInfo,
- deferredStructIt->memberDecorationsInfo)))
+ deferredStructIt->memberDecorationsInfo,
+ deferredStructIt->structDecorationsInfo)))
return failure();
deferredStructIt = deferredStructTypesInfos.erase(deferredStructIt);
@@ -1203,24 +1200,37 @@ spirv::Deserializer::processStructType(ArrayRef<uint32_t> operands) {
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
+ if (decorations.count(operands[0])) {
+ NamedAttrList &allDecorations = decorations[operands[0]];
+ for (NamedAttribute &decorationAttr : allDecorations) {
+ std::optional<spirv::Decoration> decoration = spirv::symbolizeDecoration(
+ llvm::convertToCamelFromSnakeCase(decorationAttr.getName(), true));
+ assert(decoration.has_value());
+ structDecorationsInfo.emplace_back(decoration.value(),
+ decorationAttr.getValue());
+ }
+ }
+
uint32_t structID = operands[0];
std::string structIdentifier = nameMap.lookup(structID).str();
if (structIdentifier.empty()) {
assert(unresolvedMemberTypes.empty() &&
"didn't expect unresolved member types");
- typeMap[structID] =
- spirv::StructType::get(memberTypes, offsetInfo, memberDecorationsInfo);
+ typeMap[structID] = spirv::StructType::get(
+ memberTypes, offsetInfo, memberDecorationsInfo, structDecorationsInfo);
} else {
auto structTy = spirv::StructType::getIdentified(context, structIdentifier);
typeMap[structID] = structTy;
if (!unresolvedMemberTypes.empty())
- deferredStructTypesInfos.push_back({structTy, unresolvedMemberTypes,
- memberTypes, offsetInfo,
- memberDecorationsInfo});
+ deferredStructTypesInfos.push_back(
+ {structTy, unresolvedMemberTypes, memberTypes, offsetInfo,
+ memberDecorationsInfo, structDecorationsInfo});
else if (failed(structTy.trySetBody(memberTypes, offsetInfo,
- memberDecorationsInfo)))
+ memberDecorationsInfo,
+ structDecorationsInfo)))
return failure();
}
@@ -1769,7 +1779,7 @@ LogicalResult
spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
return emitError(unknownLoc,
- "OpConstantNull must have type <id> and result <id>");
+ "OpConstantNull must only have type <id> and result <id>");
}
Type resultType = getType(operands[0]);
@@ -1779,8 +1789,15 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
}
auto resultID = operands[1];
+ Attribute attr;
if (resultType.isIntOrFloat() || isa<VectorType>(resultType)) {
- auto attr = opBuilder.getZeroAttr(resultType);
+ attr = opBuilder.getZeroAttr(resultType);
+ } else if (auto tensorType = dyn_cast<TensorArmType>(resultType)) {
+ if (auto element = opBuilder.getZeroAttr(tensorType.getElementType()))
+ attr = DenseElementsAttr::get(tensorType, element);
+ }
+
+ if (attr) {
// For normal constants, we just record the attribute (and its type) for
// later materialization at use sites.
constantMap.try_emplace(resultID, attr, resultType);
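A minimal sketch of how a shaped zero constant can be materialized the way the OpConstantNull handling above does it, by splatting the element type's zero attribute (getZeroAttrFor is a hypothetical helper, not part of the patch):

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinAttributes.h"
    #include "mlir/IR/BuiltinTypes.h"

    // Returns a splat-of-zero DenseElementsAttr for a statically shaped `type`,
    // or a null Attribute if the element type has no zero attribute.
    static mlir::Attribute getZeroAttrFor(mlir::OpBuilder &builder,
                                          mlir::ShapedType type) {
      if (mlir::Attribute element = builder.getZeroAttr(type.getElementType()))
        return mlir::DenseElementsAttr::get(type, element);
      return {};
    }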
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index 20482bd..db1cc3f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -95,6 +95,7 @@ struct DeferredStructTypeInfo {
SmallVector<Type, 4> memberTypes;
SmallVector<spirv::StructType::OffsetInfo, 0> offsetInfo;
SmallVector<spirv::StructType::MemberDecorationInfo, 0> memberDecorationsInfo;
+ SmallVector<spirv::StructType::StructDecorationInfo, 0> structDecorationsInfo;
};
/// A struct that collects the info needed to materialize/emit a
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 58e5353..3053663 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -69,6 +69,25 @@ static Block *getPhiIncomingBlock(Block *block) {
return block;
}
+static bool isZeroValue(Attribute attr) {
+ if (auto floatAttr = dyn_cast<FloatAttr>(attr)) {
+ return floatAttr.getValue().isZero();
+ }
+ if (auto boolAttr = dyn_cast<BoolAttr>(attr)) {
+ return !boolAttr.getValue();
+ }
+ if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
+ return intAttr.getValue().isZero();
+ }
+ if (auto splatElemAttr = dyn_cast<SplatElementsAttr>(attr)) {
+ return isZeroValue(splatElemAttr.getSplatValue<Attribute>());
+ }
+ if (auto denseElemAttr = dyn_cast<DenseElementsAttr>(attr)) {
+ return all_of(denseElemAttr.getValues<Attribute>(), isZeroValue);
+ }
+ return false;
+}
+
namespace mlir {
namespace spirv {
@@ -318,6 +337,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::RestrictPointer:
case spirv::Decoration::NoContraction:
case spirv::Decoration::Constant:
+ case spirv::Decoration::Block:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
@@ -446,6 +466,19 @@ LogicalResult Serializer::processType(Location loc, Type type,
LogicalResult
Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
SetVector<StringRef> &serializationCtx) {
+
+ // Map unsigned integer types to signless integer types.
+ // This is needed because otherwise the generated SPIR-V assembly would
+ // contain two declarations of the same type (e.g. OpTypeInt 32 0), which is
+ // not permitted and makes the module fail validation. At the MLIR level the
+ // two types are distinct, so the cache lookup below would miss.
+ // Note: This conversion needs to happen here before the type is looked up in
+ // the cache.
+ if (type.isUnsignedInteger()) {
+ type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(),
+ IntegerType::SignednessSemantics::Signless);
+ }
+
typeID = getTypeID(type);
if (typeID)
return success();
@@ -617,11 +650,16 @@ LogicalResult Serializer::prepareBasicType(
operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
operands.push_back(pointeeTypeID);
+ // TODO: Now that struct decorations are supported, this code may no longer
+ // be necessary. However, it is kept for backward compatibility. Ideally,
+ // Block decorations should be inserted when converting to SPIR-V.
if (isInterfaceStructPtrType(ptrType)) {
- if (failed(emitDecoration(getTypeID(pointeeStruct),
- spirv::Decoration::Block)))
- return emitError(loc, "cannot decorate ")
- << pointeeStruct << " with Block decoration";
+ auto structType = cast<spirv::StructType>(ptrType.getPointeeType());
+ if (!structType.hasDecoration(spirv::Decoration::Block))
+ if (failed(emitDecoration(getTypeID(pointeeStruct),
+ spirv::Decoration::Block)))
+ return emitError(loc, "cannot decorate ")
+ << pointeeStruct << " with Block decoration";
}
return success();
@@ -691,6 +729,20 @@ LogicalResult Serializer::prepareBasicType(
}
}
+ SmallVector<spirv::StructType::StructDecorationInfo, 1> structDecorations;
+ structType.getStructDecorations(structDecorations);
+
+ for (spirv::StructType::StructDecorationInfo &structDecoration :
+ structDecorations) {
+ if (failed(processDecorationAttr(loc, resultID,
+ structDecoration.decoration,
+ structDecoration.decorationValue))) {
+ return emitError(loc, "cannot decorate struct ")
+ << structType << " with "
+ << stringifyDecoration(structDecoration.decoration);
+ }
+ }
+
typeEnum = spirv::Opcode::OpTypeStruct;
if (structType.isIdentified())
@@ -925,6 +977,30 @@ Serializer::prepareDenseElementsConstant(Location loc, Type constType,
} else {
return 0;
}
+ } else if (isa<spirv::TensorArmType>(constType)) {
+ if (isZeroValue(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ return resultID;
+ }
+ numberOfConstituents = shapedType.getNumElements();
+ operands.reserve(numberOfConstituents + 2);
+ for (int i = 0; i < numberOfConstituents; ++i) {
+ uint32_t elementID = 0;
+ if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
+ elementID =
+ elementType.isInteger(1)
+ ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[i])
+ : prepareConstantInt(loc, attr.getValues<IntegerAttr>()[i]);
+ }
+ if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
+ elementID = prepareConstantFp(loc, attr.getValues<FloatAttr>()[i]);
+ }
+ if (!elementID) {
+ return 0;
+ }
+ operands.push_back(elementID);
+ }
} else {
operands.reserve(numberOfConstituents + 2);
for (int i = 0; i < numberOfConstituents; ++i) {
@@ -1111,6 +1187,21 @@ uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
return resultID;
}
+// Returns the type of the attribute. For a TypedAttr this simply returns its
+// type. For an ArrayAttr, which is untyped and can be multidimensional, it
+// builds the corresponding spirv::ArrayType recursively.
+static Type getValueType(Attribute attr) {
+ if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
+ return typedAttr.getType();
+ }
+
+ if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
+ return spirv::ArrayType::get(getValueType(arrayAttr[0]), arrayAttr.size());
+ }
+
+ return nullptr;
+}
+
uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
Type resultType,
Attribute valueAttr) {
@@ -1124,18 +1215,9 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
return 0;
}
- Type valueType;
- if (auto typedAttr = dyn_cast<TypedAttr>(valueAttr)) {
- valueType = typedAttr.getType();
- } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
- auto typedElemAttr = dyn_cast<TypedAttr>(arrayAttr[0]);
- if (!typedElemAttr)
- return 0;
- valueType =
- spirv::ArrayType::get(typedElemAttr.getType(), arrayAttr.size());
- } else {
+ Type valueType = getValueType(valueAttr);
+ if (!valueType)
return 0;
- }
auto compositeType = dyn_cast<CompositeType>(resultType);
if (!compositeType)
@@ -1150,11 +1232,14 @@ uint32_t Serializer::prepareConstantCompositeReplicate(Location loc,
}
uint32_t resultID = getNextID();
- uint32_t operands[] = {typeID, resultID, constandID};
-
- encodeInstructionInto(typesGlobalValues,
- spirv::Opcode::OpConstantCompositeReplicateEXT,
- operands);
+ if (dyn_cast<spirv::TensorArmType>(resultType) && isZeroValue(valueAttr)) {
+ encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantNull,
+ {typeID, resultID});
+ } else {
+ encodeInstructionInto(typesGlobalValues,
+ spirv::Opcode::OpConstantCompositeReplicateEXT,
+ {typeID, resultID, constandID});
+ }
constCompositeReplicateIDMap[valueTypePair] = resultID;
return resultID;
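A hypothetical snippet (not from the patch) showing the kind of nested, untyped ArrayAttr that the recursive getValueType helper above handles; for this value it would reconstruct !spirv.array<2 x !spirv.array<2 x f32>>:

    #include "mlir/IR/Builders.h"
    #include "mlir/IR/BuiltinAttributes.h"
    #include "mlir/IR/MLIRContext.h"

    // Builds the nested value [[1.0 : f32, 2.0 : f32], [3.0 : f32, 4.0 : f32]].
    static mlir::ArrayAttr buildNestedValue(mlir::MLIRContext *ctx) {
      mlir::Builder b(ctx);
      mlir::ArrayAttr row0 =
          b.getArrayAttr({b.getF32FloatAttr(1.0f), b.getF32FloatAttr(2.0f)});
      mlir::ArrayAttr row1 =
          b.getArrayAttr({b.getF32FloatAttr(3.0f), b.getF32FloatAttr(4.0f)});
      // The ArrayAttr itself carries no type; the nesting alone determines the
      // reconstructed spirv::ArrayType.
      return b.getArrayAttr({row0, row1});
    }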
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index 8f78590..de714d8b 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -501,20 +501,26 @@ performActions(raw_ostream &os,
<< "bytecode version while not emitting bytecode";
AsmState asmState(op.get(), OpPrintingFlags(), /*locationMap=*/nullptr,
&fallbackResourceMap);
- op.get()->print(os, asmState);
- os << '\n';
+ os << OpWithState(op.get(), asmState) << '\n';
return success();
}
/// Parses the memory buffer. If successfully, run a series of passes against
/// it and print the result.
-static LogicalResult processBuffer(raw_ostream &os,
- std::unique_ptr<MemoryBuffer> ownedBuffer,
- const MlirOptMainConfig &config,
- DialectRegistry &registry,
- llvm::ThreadPoolInterface *threadPool) {
+static LogicalResult
+processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
+ llvm::MemoryBufferRef sourceBuffer,
+ const MlirOptMainConfig &config, DialectRegistry &registry,
+ SourceMgrDiagnosticVerifierHandler *verifyHandler,
+ llvm::ThreadPoolInterface *threadPool) {
// Tell sourceMgr about this buffer, which is what the parser will pick up.
auto sourceMgr = std::make_shared<SourceMgr>();
+ // Add the original buffer to the source manager to use for determining
+ // locations.
+ sourceMgr->AddNewSourceBuffer(
+ llvm::MemoryBuffer::getMemBuffer(sourceBuffer,
+ /*RequiresNullTerminator=*/false),
+ SMLoc());
sourceMgr->AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
// Create a context just for the current buffer. Disable threading on creation
@@ -522,6 +528,8 @@ static LogicalResult processBuffer(raw_ostream &os,
MLIRContext context(registry, MLIRContext::Threading::DISABLED);
if (threadPool)
context.setThreadPool(*threadPool);
+ if (verifyHandler)
+ verifyHandler->registerInContext(&context);
StringRef irdlFile = config.getIrdlFile();
if (!irdlFile.empty() && failed(loadIRDLDialects(irdlFile, context)))
@@ -545,17 +553,12 @@ static LogicalResult processBuffer(raw_ostream &os,
return performActions(os, sourceMgr, &context, config);
}
- SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
- *sourceMgr, &context, config.verifyDiagnosticsLevel());
-
// Do any processing requested by command line flags. We don't care whether
// these actions succeed or fail, we only care what diagnostics they produce
// and whether they match our expectations.
(void)performActions(os, sourceMgr, &context, config);
- // Verify the diagnostic handler to make sure that each of the diagnostics
- // matched.
- return sourceMgrHandler.verify();
+ return success();
}
std::pair<std::string, std::string>
@@ -624,14 +627,31 @@ LogicalResult mlir::MlirOptMain(llvm::raw_ostream &outputStream,
if (threadPoolCtx.isMultithreadingEnabled())
threadPool = &threadPoolCtx.getThreadPool();
+ SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(
+ llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(),
+ /*RequiresNullTerminator=*/false),
+ SMLoc());
+ // Note: this creates a verifier handler independent of the flag set, as
+ // internally, if the flag is not set, a new scoped diagnostic handler is
+ // created which would intercept the diagnostics and verify them.
+ SourceMgrDiagnosticVerifierHandler sourceMgrHandler(
+ sourceMgr, &threadPoolCtx, config.verifyDiagnosticsLevel());
auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
- raw_ostream &os) {
- return processBuffer(os, std::move(chunkBuffer), config, registry,
- threadPool);
+ llvm::MemoryBufferRef sourceBuffer, raw_ostream &os) {
+ return processBuffer(
+ os, std::move(chunkBuffer), sourceBuffer, config, registry,
+ config.shouldVerifyDiagnostics() ? &sourceMgrHandler : nullptr,
+ threadPool);
};
- return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
- config.inputSplitMarker(),
- config.outputSplitMarker());
+ LogicalResult status = splitAndProcessBuffer(
+ llvm::MemoryBuffer::getMemBuffer(buffer->getMemBufferRef(),
+ /*RequiresNullTerminator=*/false),
+ chunkFn, outputStream, config.inputSplitMarker(),
+ config.outputSplitMarker());
+ if (config.shouldVerifyDiagnostics() && failed(sourceMgrHandler.verify()))
+ status = failure();
+ return status;
}
LogicalResult mlir::MlirOptMain(int argc, char **argv,
diff --git a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
index c11cb8d..e1c8afb 100644
--- a/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
+++ b/mlir/lib/Tools/mlir-translate/MlirTranslateMain.cpp
@@ -135,6 +135,13 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
// Processes the memory buffer with a new MLIRContext.
auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
raw_ostream &os) {
+ // Many of the translations expect a null-terminated buffer, but splitting
+ // the buffer does not guarantee null-termination. Make a copy of the buffer
+ // to ensure it is null-terminated.
+ if (!ownedBuffer->getBuffer().ends_with('\0')) {
+ ownedBuffer = llvm::MemoryBuffer::getMemBufferCopy(
+ ownedBuffer->getBuffer(), ownedBuffer->getBufferIdentifier());
+ }
// Temporary buffers for chained translation processing.
std::string dataIn;
std::string dataOut;
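A minimal standalone sketch of the null-termination guard above; MemoryBuffer::getMemBufferCopy always null-terminates the copied data, so re-copying a split chunk restores the invariant (buffer contents and names are illustrative):

    #include "llvm/Support/MemoryBuffer.h"
    #include "llvm/Support/raw_ostream.h"

    #include <memory>

    int main() {
      // A chunk produced by buffer splitting may be a slice that is not
      // null-terminated.
      std::unique_ptr<llvm::MemoryBuffer> chunk = llvm::MemoryBuffer::getMemBuffer(
          "module {}", "chunk0", /*RequiresNullTerminator=*/false);
      if (!chunk->getBuffer().ends_with('\0'))
        chunk = llvm::MemoryBuffer::getMemBufferCopy(chunk->getBuffer(),
                                                     chunk->getBufferIdentifier());
      llvm::outs() << chunk->getBufferIdentifier() << ": "
                   << chunk->getBufferSize() << " bytes\n";
    }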
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 08803e0..f23c619 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -20,6 +20,7 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/SaveAndRestore.h"
#include "llvm/Support/ScopedPrinter.h"
@@ -1129,8 +1130,13 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// verification.
SmallPtrSet<Operation *, 1> pendingRootUpdates;
+ /// A raw output stream that prefixes each debug log line with the debug type.
+ llvm::impl::raw_ldbg_ostream os{(Twine("[") + DEBUG_TYPE + "] ").str(),
+ llvm::dbgs(), /*HasPendingNewline=*/false};
+
/// A logger used to emit diagnostics during the conversion process.
- llvm::ScopedPrinter logger{llvm::dbgs()};
+ llvm::ScopedPrinter logger{os};
+ std::string logPrefix;
#endif
};
} // namespace detail
diff --git a/mlir/lib/Transforms/Utils/Inliner.cpp b/mlir/lib/Transforms/Utils/Inliner.cpp
index dc78065..26c965c 100644
--- a/mlir/lib/Transforms/Utils/Inliner.cpp
+++ b/mlir/lib/Transforms/Utils/Inliner.cpp
@@ -348,13 +348,11 @@ static void collectCallOps(iterator_range<Region::iterator> blocks,
// InlinerInterfaceImpl
//===----------------------------------------------------------------------===//
-#ifndef NDEBUG
static std::string getNodeName(CallOpInterface op) {
if (llvm::dyn_cast_if_present<SymbolRefAttr>(op.getCallableForCallee()))
return debugString(op);
return "_unnamed_callee_";
}
-#endif
/// Return true if the specified `inlineHistoryID` indicates an inline history
/// that already includes `node`.