aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/AsmParser/DialectSymbolParser.cpp7
-rw-r--r--mlir/lib/AsmParser/Lexer.cpp27
-rw-r--r--mlir/lib/AsmParser/Lexer.h3
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp78
-rw-r--r--mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp36
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp69
-rw-r--r--mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp4
-rw-r--r--mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp64
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp94
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp28
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp1
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp14
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp3
-rw-r--r--mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp10
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp111
-rw-r--r--mlir/lib/Pass/Pass.cpp2
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.cpp13
18 files changed, 227 insertions, 339 deletions
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp
index 9f4a87a..8b14e71 100644
--- a/mlir/lib/AsmParser/DialectSymbolParser.cpp
+++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp
@@ -89,6 +89,7 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
nestedPunctuation.pop_back();
return success();
};
+ const char *curBufferEnd = state.lex.getBufferEnd();
do {
// Handle code completions, which may appear in the middle of the symbol
// body.
@@ -98,6 +99,12 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body,
break;
}
+ if (curBufferEnd == curPtr) {
+ if (!nestedPunctuation.empty())
+ return emitPunctError();
+ return emitError("unexpected nul or EOF in pretty dialect name");
+ }
+
char c = *curPtr++;
switch (c) {
case '\0':
diff --git a/mlir/lib/AsmParser/Lexer.cpp b/mlir/lib/AsmParser/Lexer.cpp
index 751bd63..8f53529 100644
--- a/mlir/lib/AsmParser/Lexer.cpp
+++ b/mlir/lib/AsmParser/Lexer.cpp
@@ -37,6 +37,18 @@ Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context,
AsmParserCodeCompleteContext *codeCompleteContext)
: sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) {
auto bufferID = sourceMgr.getMainFileID();
+
+ // Check to see if the main buffer contains the last buffer, and if so the
+ // last buffer should be used as main file for parsing.
+ if (sourceMgr.getNumBuffers() > 1) {
+ unsigned lastFileID = sourceMgr.getNumBuffers();
+ const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID);
+ const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID);
+ if (main->getBufferStart() <= last->getBufferStart() &&
+ main->getBufferEnd() >= last->getBufferEnd()) {
+ bufferID = lastFileID;
+ }
+ }
curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer();
curPtr = curBuffer.begin();
@@ -71,6 +83,7 @@ Token Lexer::emitError(const char *loc, const Twine &message) {
}
Token Lexer::lexToken() {
+ const char *curBufferEnd = curBuffer.end();
while (true) {
const char *tokStart = curPtr;
@@ -78,6 +91,9 @@ Token Lexer::lexToken() {
if (tokStart == codeCompleteLoc)
return formToken(Token::code_complete, tokStart);
+ if (tokStart == curBufferEnd)
+ return formToken(Token::eof, tokStart);
+
// Lex the next token.
switch (*curPtr++) {
default:
@@ -102,7 +118,7 @@ Token Lexer::lexToken() {
case 0:
// This may either be a nul character in the source file or may be the EOF
// marker that llvm::MemoryBuffer guarantees will be there.
- if (curPtr - 1 == curBuffer.end())
+ if (curPtr - 1 == curBufferEnd)
return formToken(Token::eof, tokStart);
continue;
@@ -259,7 +275,11 @@ void Lexer::skipComment() {
assert(*curPtr == '/');
++curPtr;
+ const char *curBufferEnd = curBuffer.end();
while (true) {
+ if (curPtr == curBufferEnd)
+ return;
+
switch (*curPtr++) {
case '\n':
case '\r':
@@ -267,7 +287,7 @@ void Lexer::skipComment() {
return;
case 0:
// If this is the end of the buffer, end the comment.
- if (curPtr - 1 == curBuffer.end()) {
+ if (curPtr - 1 == curBufferEnd) {
--curPtr;
return;
}
@@ -405,6 +425,7 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) {
Token Lexer::lexString(const char *tokStart) {
assert(curPtr[-1] == '"');
+ const char *curBufferEnd = curBuffer.end();
while (true) {
// Check to see if there is a code completion location within the string. In
// these cases we generate a completion location and place the currently
@@ -419,7 +440,7 @@ Token Lexer::lexString(const char *tokStart) {
case 0:
// If this is a random nul character in the middle of a string, just
// include it. If it is the end of file, then it is an error.
- if (curPtr - 1 != curBuffer.end())
+ if (curPtr - 1 != curBufferEnd)
continue;
[[fallthrough]];
case '\n':
diff --git a/mlir/lib/AsmParser/Lexer.h b/mlir/lib/AsmParser/Lexer.h
index 4085a9b..670444e 100644
--- a/mlir/lib/AsmParser/Lexer.h
+++ b/mlir/lib/AsmParser/Lexer.h
@@ -40,6 +40,9 @@ public:
/// Returns the start of the buffer.
const char *getBufferBegin() { return curBuffer.data(); }
+ /// Returns the end of the buffer.
+ const char *getBufferEnd() { return curBuffer.end(); }
+
/// Return the code completion location of the lexer, or nullptr if there is
/// none.
const char *getCodeCompleteLoc() const { return codeCompleteLoc; }
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
index e882845..6bd0e2d 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp
@@ -19,10 +19,18 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
+#include "mlir/IR/Value.h"
#include "mlir/Transforms/DialectConversion.h"
+#include <cstdint>
using namespace mlir;
+static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) {
+ return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() &&
+ memRefType.getRank() != 0 &&
+ !llvm::is_contained(memRefType.getShape(), 0);
+}
+
namespace {
/// Implement the interface to convert MemRef to EmitC.
struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
@@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) {
return resultTy;
}
+struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = allocOp.getLoc();
+ MemRefType memrefType = allocOp.getType();
+ if (!isMemRefTypeLegalForEmitC(memrefType)) {
+ return rewriter.notifyMatchFailure(
+ loc, "incompatible memref type for EmitC conversion");
+ }
+
+ Type sizeTType = emitc::SizeTType::get(rewriter.getContext());
+ Type elementType = memrefType.getElementType();
+ IndexType indexType = rewriter.getIndexType();
+ emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>(
+ loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{},
+ ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)}));
+
+ int64_t numElements = 1;
+ for (int64_t dimSize : memrefType.getShape()) {
+ numElements *= dimSize;
+ }
+ Value numElementsValue = rewriter.create<emitc::ConstantOp>(
+ loc, indexType, rewriter.getIndexAttr(numElements));
+
+ Value totalSizeBytes = rewriter.create<emitc::MulOp>(
+ loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue);
+
+ emitc::CallOpaqueOp allocCall;
+ StringAttr allocFunctionName;
+ Value alignmentValue;
+ SmallVector<Value, 2> argsVec;
+ if (allocOp.getAlignment()) {
+ allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName);
+ alignmentValue = rewriter.create<emitc::ConstantOp>(
+ loc, sizeTType,
+ rewriter.getIntegerAttr(indexType,
+ allocOp.getAlignment().value_or(0)));
+ argsVec.push_back(alignmentValue);
+ } else {
+ allocFunctionName = rewriter.getStringAttr(mallocFunctionName);
+ }
+
+ argsVec.push_back(totalSizeBytes);
+ ValueRange args(argsVec);
+
+ allocCall = rewriter.create<emitc::CallOpaqueOp>(
+ loc,
+ emitc::PointerType::get(
+ emitc::OpaqueType::get(rewriter.getContext(), "void")),
+ allocFunctionName, args);
+
+ emitc::PointerType targetPointerType = emitc::PointerType::get(elementType);
+ emitc::CastOp castOp = rewriter.create<emitc::CastOp>(
+ loc, targetPointerType, allocCall.getResult(0));
+
+ rewriter.replaceOp(allocOp, castOp);
+ return success();
+ }
+};
+
struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> {
using OpConversionPattern::OpConversionPattern;
@@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> {
void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
typeConverter.addConversion(
[&](MemRefType memRefType) -> std::optional<Type> {
- if (!memRefType.hasStaticShape() ||
- !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 ||
- llvm::is_contained(memRefType.getShape(), 0)) {
+ if (!isMemRefTypeLegalForEmitC(memRefType)) {
return {};
}
Type convertedElementType =
@@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) {
void mlir::populateMemRefToEmitCConversionPatterns(
RewritePatternSet &patterns, const TypeConverter &converter) {
- patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad,
- ConvertStore>(converter, patterns.getContext());
+ patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal,
+ ConvertLoad, ConvertStore>(converter, patterns.getContext());
}
diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
index cf25c09..e78dd76 100644
--- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
+++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp
@@ -15,6 +15,7 @@
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -28,9 +29,11 @@ using namespace mlir;
namespace {
struct ConvertMemRefToEmitCPass
: public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> {
+ using Base::Base;
void runOnOperation() override {
TypeConverter converter;
-
+ ConvertMemRefToEmitCOptions options;
+ options.lowerToCpp = this->lowerToCpp;
// Fallback for other types.
converter.addConversion([](Type type) -> std::optional<Type> {
if (!emitc::isSupportedEmitCType(type))
@@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();
+
+ mlir::ModuleOp module = getOperation();
+ module.walk([&](mlir::emitc::CallOpaqueOp callOp) {
+ if (callOp.getCallee() != alignedAllocFunctionName &&
+ callOp.getCallee() != mallocFunctionName) {
+ return mlir::WalkResult::advance();
+ }
+
+ for (auto &op : *module.getBody()) {
+ emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
+ if (!includeOp) {
+ continue;
+ }
+ if (includeOp.getIsStandardInclude() &&
+ ((options.lowerToCpp &&
+ includeOp.getInclude() == cppStandardLibraryHeader) ||
+ (!options.lowerToCpp &&
+ includeOp.getInclude() == cStandardLibraryHeader))) {
+ return mlir::WalkResult::interrupt();
+ }
+ }
+
+ mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
+ StringAttr includeAttr =
+ builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader
+ : cStandardLibraryHeader);
+ builder.create<mlir::emitc::IncludeOp>(
+ module.getLoc(), includeAttr,
+ /*is_standard_include=*/builder.getUnitAttr());
+ return mlir::WalkResult::interrupt();
+ });
}
};
} // namespace
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 4307bc6..17a79e3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1070,39 +1070,6 @@ public:
}
};
-class VectorExtractElementOpConversion
- : public ConvertOpToLLVMPattern<vector::ExtractElementOp> {
-public:
- using ConvertOpToLLVMPattern<
- vector::ExtractElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = extractEltOp.getSourceVectorType();
- auto llvmType = typeConverter->convertType(vectorType.getElementType());
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = extractEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>(
- extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
class VectorExtractOpConversion
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
public:
@@ -1206,39 +1173,6 @@ public:
}
};
-class VectorInsertElementOpConversion
- : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
-public:
- using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- auto vectorType = insertEltOp.getDestVectorType();
- auto llvmType = typeConverter->convertType(vectorType);
-
- // Bail if result type cannot be lowered.
- if (!llvmType)
- return failure();
-
- if (vectorType.getRank() == 0) {
- Location loc = insertEltOp.getLoc();
- auto idxType = rewriter.getIndexType();
- auto zero = LLVM::ConstantOp::create(rewriter, loc,
- typeConverter->convertType(idxType),
- rewriter.getIntegerAttr(idxType, 0));
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero);
- return success();
- }
-
- rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>(
- insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
class VectorInsertOpConversion
: public ConvertOpToLLVMPattern<vector::InsertOp> {
public:
@@ -2244,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorGatherOpConversion, VectorScatterOpConversion>(
converter, useVectorAlignment);
patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
- VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorExtractOpConversion, VectorFMAOp1DConversion,
VectorInsertOpConversion, VectorPrintOpConversion,
VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index b1af5f0..508f4e2 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -690,7 +690,7 @@ struct PrepareTransferWriteConversion
/// %lastIndex = arith.subi %length, %c1 : index
/// vector.print punctuation <open>
/// scf.for %i = %c0 to %length step %c1 {
-/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32>
+/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32>
/// vector.print %el : i32 punctuation <no_punctuation>
/// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index
/// scf.if %notLastIndex {
@@ -1643,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> {
/// Is rewritten to approximately the following pseudo-IR:
/// ```
/// for i = 0 to 9 {
-/// %t = vector.extractelement %vec[i] : vector<9xf32>
+/// %t = vector.extract %vec[i] : f32 from vector<9xf32>
/// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32>
/// }
/// ```
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 986eae3..a4be7d4 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -335,63 +335,6 @@ struct VectorInsertOpConvert final
}
};
-struct VectorExtractElementOpConvert final
- : public OpConversionPattern<vector::ExtractElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type resultType = getTypeConverter()->convertType(extractOp.getType());
- if (!resultType)
- return failure();
-
- if (isa<spirv::ScalarType>(adaptor.getVector().getType())) {
- rewriter.replaceOp(extractOp, adaptor.getVector());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
- extractOp, resultType, adaptor.getVector(),
- rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())}));
- else
- rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
- extractOp, resultType, adaptor.getVector(), adaptor.getPosition());
- return success();
- }
-};
-
-struct VectorInsertElementOpConvert final
- : public OpConversionPattern<vector::InsertElementOp> {
- using OpConversionPattern::OpConversionPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- Type vectorType = getTypeConverter()->convertType(insertOp.getType());
- if (!vectorType)
- return failure();
-
- if (isa<spirv::ScalarType>(vectorType)) {
- rewriter.replaceOp(insertOp, adaptor.getSource());
- return success();
- }
-
- APInt cstPos;
- if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos)))
- rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
- insertOp, adaptor.getSource(), adaptor.getDest(),
- cstPos.getSExtValue());
- else
- rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
- insertOp, vectorType, insertOp.getDest(), adaptor.getSource(),
- adaptor.getPosition());
- return success();
- }
-};
-
struct VectorInsertStridedSliceOpConvert final
: public OpConversionPattern<vector::InsertStridedSliceOp> {
using OpConversionPattern::OpConversionPattern;
@@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final
void mlir::populateVectorToSPIRVPatterns(
const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<
- VectorBitcastConvert, VectorBroadcastConvert,
- VectorExtractElementOpConvert, VectorExtractOpConvert,
+ VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert,
- VectorToElementOpConvert, VectorInsertElementOpConvert,
- VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
+ VectorToElementOpConvert, VectorInsertOpConvert,
+ VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
index d1d1062..aa53f94 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp
@@ -1,4 +1,5 @@
-//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
+//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries
+//----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@@ -8,12 +9,13 @@
//
// Module Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
-// implementations for FuncOp, CallOp and ReturnOp.
+// implementations for FuncOp, CallOp and ReturnOp. Although it is named
+// Module Bufferization, it may operate on any SymbolTable.
//
-// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
-// This function analyzes the given module and determines the order of analysis
-// and bufferization: Functions that are called are processed before their
-// respective callers.
+// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp,
+// ...)`. This function analyzes the given op and determines the order of
+// analysis and bufferization: Functions that are called are processed before
+// their respective callers.
//
// After analyzing a FuncOp, additional information about its bbArgs is
// gathered and stored in `FuncAnalysisState`.
@@ -309,7 +311,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
- ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
+ Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
SymbolTableCollection &symbolTables) {
// For each FuncOp, the set of functions called by it (i.e. the union of
@@ -317,26 +319,29 @@ static LogicalResult getFuncOpsOrderedByCalls(
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
-
- for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
- // Collect function calls and populate the caller map.
- numberCallOpsContainedInFuncOp[funcOp] = 0;
- WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
- func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
- assert(calledFunction && "could not retrieved called func::FuncOp");
- // If the called function does not have any tensors in its signature, then
- // it is not necessary to bufferize the callee before the caller.
- if (!hasTensorSignature(calledFunction))
- return WalkResult::skip();
-
- callerMap[calledFunction].insert(callOp);
- if (calledBy[calledFunction].insert(funcOp).second) {
- numberCallOpsContainedInFuncOp[funcOp]++;
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ // Collect function calls and populate the caller map.
+ numberCallOpsContainedInFuncOp[funcOp] = 0;
+ WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
+ func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
+ assert(calledFunction && "could not retrieved called func::FuncOp");
+ // If the called function does not have any tensors in its signature,
+ // then it is not necessary to bufferize the callee before the caller.
+ if (!hasTensorSignature(calledFunction))
+ return WalkResult::skip();
+
+ callerMap[calledFunction].insert(callOp);
+ if (calledBy[calledFunction].insert(funcOp).second) {
+ numberCallOpsContainedInFuncOp[funcOp]++;
+ }
+ return WalkResult::advance();
+ });
+ if (res.wasInterrupted())
+ return failure();
}
- return WalkResult::advance();
- });
- if (res.wasInterrupted())
- return failure();
+ }
}
// Iteratively remove function operations that do not call any of the
@@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}
LogicalResult
-mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
+mlir::bufferization::analyzeModuleOp(Operation *moduleOp,
OneShotAnalysisState &state,
BufferizationStatistics *statistics) {
assert(state.getOptions().bufferizeFunctionBoundaries &&
@@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
}
void mlir::bufferization::removeBufferizationAttributesInModule(
- ModuleOp moduleOp) {
- for (auto op : moduleOp.getOps<func::FuncOp>()) {
- for (BlockArgument bbArg : op.getArguments())
- removeBufferizationAttributes(bbArg);
+ Operation *moduleOp) {
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
+ for (BlockArgument bbArg : funcOp.getArguments())
+ removeBufferizationAttributes(bbArg);
+ }
+ }
}
}
LogicalResult mlir::bufferization::bufferizeModuleOp(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
- IRRewriter rewriter(moduleOp.getContext());
+ IRRewriter rewriter(moduleOp->getContext());
// A list of non-circular functions in the order in which they are analyzed
// and bufferized.
@@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
// Bufferize all other ops.
- for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
- // Functions were already bufferized.
- if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
- continue;
- if (failed(bufferizeOp(&op, options, state, statistics)))
- return failure();
+ for (mlir::Region &region : moduleOp->getRegions()) {
+ for (mlir::Block &block : region.getBlocks()) {
+ for (mlir::Operation &op :
+ llvm::make_early_inc_range(block.getOperations())) {
+ // Functions were already bufferized.
+ if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
+ continue;
+ if (failed(bufferizeOp(&op, options, state, statistics)))
+ return failure();
+ }
+ }
}
// Post-pass cleanup of function argument attributes.
@@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
- ModuleOp moduleOp, const OneShotBufferizationOptions &options,
+ Operation *moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
index f999c93..a6159ee 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -33,7 +33,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
// analysis depending on whether function boundary bufferization is enabled or
// not.
if (options.bufferizeFunctionBoundaries) {
- if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
+ if (failed(analyzeModuleOp(op, analysisState, statistics)))
return failure();
} else {
if (failed(analyzeOp(op, analysisState, statistics)))
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 27b6617..b56a212 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -32,6 +32,7 @@
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@@ -4622,22 +4623,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos,
});
}
-/// Returns true if the dimension of `sourceShape` is smaller than the dimension
-/// of the `limitShape`.
-static bool areAllInBound(ArrayRef<int64_t> sourceShape,
- ArrayRef<int64_t> limitShape) {
- assert(
- sourceShape.size() == limitShape.size() &&
- "expected source shape rank, and limit of the shape to have same rank");
- return llvm::all_of(
- llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) {
- int64_t sourceExtent = std::get<0>(it);
- int64_t limit = std::get<1>(it);
- return ShapedType::isDynamic(sourceExtent) ||
- ShapedType::isDynamic(limit) || sourceExtent <= limit;
- });
-}
-
template <typename OpTy>
static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value,
@@ -4696,11 +4681,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
- if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
- return op->emitError("the shape of output is not large enough to hold the "
- "packed data. Expected at least ")
- << expectedPackedType << ", got " << packedType;
- }
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
mixedTiles),
@@ -4717,6 +4697,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
return op->emitError("mismatch in inner tile sizes specified and shaped of "
"tiled dimension in the packed type");
}
+ if (failed(verifyCompatibleShape(expectedPackedType.getShape(),
+ packedType.getShape()))) {
+ return op->emitError("expected ")
+ << expectedPackedType << " for the packed domain value, got "
+ << packedType;
+ }
return success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
index 277e50b..9d7f4e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp
@@ -10,6 +10,7 @@
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/PatternMatch.h"
namespace mlir {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index dad3526..57b610b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -932,20 +932,6 @@ struct PackOpTiling
continue;
}
- // If the dimension needs padding, it is not supported because there are
- // iterations that only write padding values to the whole tile. The
- // consumer fusion is driven by the source, so it is not possible to map
- // an empty slice to the tile.
- bool needExtraPadding =
- ShapedType::isDynamic(destDimSize) || !cstInnerSize ||
- destDimSize * cstInnerSize.value() != srcDimSize;
- // Prioritize the case that the op already says that it does not need
- // padding.
- if (!packOp.getPaddingValue())
- needExtraPadding = false;
- if (needExtraPadding)
- return failure();
-
// Currently fusing `packOp` as consumer only expects perfect tiling
// scenario because even if without padding semantic, the `packOp` may
// also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>,
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
index 0e96b59..869d27a 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp
@@ -115,8 +115,7 @@ public:
bufferization::BufferizationState bufferizationState;
- if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()),
- updatedOptions,
+ if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions,
bufferizationState)))
return failure();
diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
index e297f7c..14a4fdf 100644
--- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
+++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp
@@ -21,16 +21,8 @@
#include "llvm/Support/InterleavedRange.h"
#define DEBUG_TYPE "transform-dialect"
-#define DEBUG_TYPE_FULL "transform-dialect-full"
#define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all"
-#ifndef NDEBUG
-#define FULL_LDBG(X) \
- DEBUGLOG_WITH_STREAM_AND_TYPE(llvm::dbgs(), DEBUG_TYPE_FULL)
-#else
-#define FULL_LDBG(X) \
- for (bool _c = false; _c; _c = false) \
- ::llvm::nulls()
-#endif
+#define FULL_LDBG() LDBG(4)
using namespace mlir;
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index bce358d..8789f55 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1258,63 +1258,6 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
CanonicalizeContractAdd<arith::AddFOp>>(context);
}
-//===----------------------------------------------------------------------===//
-// ExtractElementOp
-//===----------------------------------------------------------------------===//
-
-void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges.front());
-}
-
-void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
- Value source) {
- result.addOperands({source});
- result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType());
-}
-
-LogicalResult vector::ExtractElementOp::verify() {
- VectorType vectorType = getSourceVectorType();
- if (vectorType.getRank() == 0) {
- if (getPosition())
- return emitOpError("expected position to be empty with 0-D vector");
- return success();
- }
- if (vectorType.getRank() != 1)
- return emitOpError("unexpected >1 vector rank");
- if (!getPosition())
- return emitOpError("expected position for 1-D vector");
- return success();
-}
-
-OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) {
- // Skip the 0-D vector here now.
- if (!adaptor.getPosition())
- return {};
-
- // Fold extractelement (splat X) -> X.
- if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
- return splat.getInput();
-
- // Fold extractelement(broadcast(X)) -> X.
- if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
- if (!llvm::isa<VectorType>(broadcast.getSource().getType()))
- return broadcast.getSource();
-
- auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector());
- auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
- if (!pos || !src)
- return {};
-
- auto srcElements = src.getValues<Attribute>();
-
- uint64_t posIdx = pos.getInt();
- if (posIdx >= srcElements.size())
- return {};
-
- return srcElements[posIdx];
-}
-
// Returns `true` if `index` is either within [0, maxIndex) or equal to
// `poisonValue`.
static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue,
@@ -3184,60 +3127,6 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
}
//===----------------------------------------------------------------------===//
-// InsertElementOp
-//===----------------------------------------------------------------------===//
-
-void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
- SetIntRangeFn setResultRanges) {
- setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1]));
-}
-
-void InsertElementOp::build(OpBuilder &builder, OperationState &result,
- Value source, Value dest) {
- build(builder, result, source, dest, {});
-}
-
-LogicalResult InsertElementOp::verify() {
- auto dstVectorType = getDestVectorType();
- if (dstVectorType.getRank() == 0) {
- if (getPosition())
- return emitOpError("expected position to be empty with 0-D vector");
- return success();
- }
- if (dstVectorType.getRank() != 1)
- return emitOpError("unexpected >1 vector rank");
- if (!getPosition())
- return emitOpError("expected position for 1-D vector");
- return success();
-}
-
-OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) {
- // Skip the 0-D vector here.
- if (!adaptor.getPosition())
- return {};
-
- auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource());
- auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest());
- auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition());
- if (!src || !dst || !pos)
- return {};
-
- if (src.getType() != getDestVectorType().getElementType())
- return {};
-
- auto dstElements = dst.getValues<Attribute>();
-
- SmallVector<Attribute> results(dstElements);
-
- uint64_t posIdx = pos.getInt();
- if (posIdx >= results.size())
- return {};
- results[posIdx] = src;
-
- return DenseElementsAttr::get(getDestVectorType(), results);
-}
-
-//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 0db9808..7094c8e 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -901,7 +901,7 @@ LogicalResult PassManager::run(Operation *op) {
if (failed(initialize(context, impl->initializationGeneration + 1)))
return failure();
initializationKey = newInitKey;
- pipelineKey = pipelineInitializationKey;
+ pipelineInitializationKey = pipelineKey;
}
// Construct a top level analysis manager for the pipeline.
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 58e5353..a8a2b2e 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -446,6 +446,19 @@ LogicalResult Serializer::processType(Location loc, Type type,
LogicalResult
Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
SetVector<StringRef> &serializationCtx) {
+
+ // Map unsigned integer types to singless integer types.
+ // This is needed otherwise the generated spirv assembly will contain
+ // twice a type declaration (like OpTypeInt 32 0) which is no permitted and
+ // such module fails validation. Indeed at MLIR level the two types are
+ // different and lookup in the cache below misses.
+ // Note: This conversion needs to happen here before the type is looked up in
+ // the cache.
+ if (type.isUnsignedInteger()) {
+ type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(),
+ IntegerType::SignednessSemantics::Signless);
+ }
+
typeID = getTypeID(type);
if (typeID)
return success();