diff options
Diffstat (limited to 'mlir/lib')
18 files changed, 227 insertions, 339 deletions
diff --git a/mlir/lib/AsmParser/DialectSymbolParser.cpp b/mlir/lib/AsmParser/DialectSymbolParser.cpp index 9f4a87a..8b14e71 100644 --- a/mlir/lib/AsmParser/DialectSymbolParser.cpp +++ b/mlir/lib/AsmParser/DialectSymbolParser.cpp @@ -89,6 +89,7 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, nestedPunctuation.pop_back(); return success(); }; + const char *curBufferEnd = state.lex.getBufferEnd(); do { // Handle code completions, which may appear in the middle of the symbol // body. @@ -98,6 +99,12 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body, break; } + if (curBufferEnd == curPtr) { + if (!nestedPunctuation.empty()) + return emitPunctError(); + return emitError("unexpected nul or EOF in pretty dialect name"); + } + char c = *curPtr++; switch (c) { case '\0': diff --git a/mlir/lib/AsmParser/Lexer.cpp b/mlir/lib/AsmParser/Lexer.cpp index 751bd63..8f53529 100644 --- a/mlir/lib/AsmParser/Lexer.cpp +++ b/mlir/lib/AsmParser/Lexer.cpp @@ -37,6 +37,18 @@ Lexer::Lexer(const llvm::SourceMgr &sourceMgr, MLIRContext *context, AsmParserCodeCompleteContext *codeCompleteContext) : sourceMgr(sourceMgr), context(context), codeCompleteLoc(nullptr) { auto bufferID = sourceMgr.getMainFileID(); + + // Check to see if the main buffer contains the last buffer, and if so the + // last buffer should be used as main file for parsing. + if (sourceMgr.getNumBuffers() > 1) { + unsigned lastFileID = sourceMgr.getNumBuffers(); + const llvm::MemoryBuffer *main = sourceMgr.getMemoryBuffer(bufferID); + const llvm::MemoryBuffer *last = sourceMgr.getMemoryBuffer(lastFileID); + if (main->getBufferStart() <= last->getBufferStart() && + main->getBufferEnd() >= last->getBufferEnd()) { + bufferID = lastFileID; + } + } curBuffer = sourceMgr.getMemoryBuffer(bufferID)->getBuffer(); curPtr = curBuffer.begin(); @@ -71,6 +83,7 @@ Token Lexer::emitError(const char *loc, const Twine &message) { } Token Lexer::lexToken() { + const char *curBufferEnd = curBuffer.end(); while (true) { const char *tokStart = curPtr; @@ -78,6 +91,9 @@ Token Lexer::lexToken() { if (tokStart == codeCompleteLoc) return formToken(Token::code_complete, tokStart); + if (tokStart == curBufferEnd) + return formToken(Token::eof, tokStart); + // Lex the next token. switch (*curPtr++) { default: @@ -102,7 +118,7 @@ Token Lexer::lexToken() { case 0: // This may either be a nul character in the source file or may be the EOF // marker that llvm::MemoryBuffer guarantees will be there. - if (curPtr - 1 == curBuffer.end()) + if (curPtr - 1 == curBufferEnd) return formToken(Token::eof, tokStart); continue; @@ -259,7 +275,11 @@ void Lexer::skipComment() { assert(*curPtr == '/'); ++curPtr; + const char *curBufferEnd = curBuffer.end(); while (true) { + if (curPtr == curBufferEnd) + return; + switch (*curPtr++) { case '\n': case '\r': @@ -267,7 +287,7 @@ void Lexer::skipComment() { return; case 0: // If this is the end of the buffer, end the comment. - if (curPtr - 1 == curBuffer.end()) { + if (curPtr - 1 == curBufferEnd) { --curPtr; return; } @@ -405,6 +425,7 @@ Token Lexer::lexPrefixedIdentifier(const char *tokStart) { Token Lexer::lexString(const char *tokStart) { assert(curPtr[-1] == '"'); + const char *curBufferEnd = curBuffer.end(); while (true) { // Check to see if there is a code completion location within the string. In // these cases we generate a completion location and place the currently @@ -419,7 +440,7 @@ Token Lexer::lexString(const char *tokStart) { case 0: // If this is a random nul character in the middle of a string, just // include it. If it is the end of file, then it is an error. - if (curPtr - 1 != curBuffer.end()) + if (curPtr - 1 != curBufferEnd) continue; [[fallthrough]]; case '\n': diff --git a/mlir/lib/AsmParser/Lexer.h b/mlir/lib/AsmParser/Lexer.h index 4085a9b..670444e 100644 --- a/mlir/lib/AsmParser/Lexer.h +++ b/mlir/lib/AsmParser/Lexer.h @@ -40,6 +40,9 @@ public: /// Returns the start of the buffer. const char *getBufferBegin() { return curBuffer.data(); } + /// Returns the end of the buffer. + const char *getBufferEnd() { return curBuffer.end(); } + /// Return the code completion location of the lexer, or nullptr if there is /// none. const char *getCodeCompleteLoc() const { return codeCompleteLoc; } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp index e882845..6bd0e2d 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitC.cpp @@ -19,10 +19,18 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeRange.h" +#include "mlir/IR/Value.h" #include "mlir/Transforms/DialectConversion.h" +#include <cstdint> using namespace mlir; +static bool isMemRefTypeLegalForEmitC(MemRefType memRefType) { + return memRefType.hasStaticShape() && memRefType.getLayout().isIdentity() && + memRefType.getRank() != 0 && + !llvm::is_contained(memRefType.getShape(), 0); +} + namespace { /// Implement the interface to convert MemRef to EmitC. struct MemRefToEmitCDialectInterface : public ConvertToEmitCPatternInterface { @@ -89,6 +97,68 @@ Type convertMemRefType(MemRefType opTy, const TypeConverter *typeConverter) { return resultTy; } +struct ConvertAlloc final : public OpConversionPattern<memref::AllocOp> { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(memref::AllocOp allocOp, OpAdaptor operands, + ConversionPatternRewriter &rewriter) const override { + Location loc = allocOp.getLoc(); + MemRefType memrefType = allocOp.getType(); + if (!isMemRefTypeLegalForEmitC(memrefType)) { + return rewriter.notifyMatchFailure( + loc, "incompatible memref type for EmitC conversion"); + } + + Type sizeTType = emitc::SizeTType::get(rewriter.getContext()); + Type elementType = memrefType.getElementType(); + IndexType indexType = rewriter.getIndexType(); + emitc::CallOpaqueOp sizeofElementOp = rewriter.create<emitc::CallOpaqueOp>( + loc, sizeTType, rewriter.getStringAttr("sizeof"), ValueRange{}, + ArrayAttr::get(rewriter.getContext(), {TypeAttr::get(elementType)})); + + int64_t numElements = 1; + for (int64_t dimSize : memrefType.getShape()) { + numElements *= dimSize; + } + Value numElementsValue = rewriter.create<emitc::ConstantOp>( + loc, indexType, rewriter.getIndexAttr(numElements)); + + Value totalSizeBytes = rewriter.create<emitc::MulOp>( + loc, sizeTType, sizeofElementOp.getResult(0), numElementsValue); + + emitc::CallOpaqueOp allocCall; + StringAttr allocFunctionName; + Value alignmentValue; + SmallVector<Value, 2> argsVec; + if (allocOp.getAlignment()) { + allocFunctionName = rewriter.getStringAttr(alignedAllocFunctionName); + alignmentValue = rewriter.create<emitc::ConstantOp>( + loc, sizeTType, + rewriter.getIntegerAttr(indexType, + allocOp.getAlignment().value_or(0))); + argsVec.push_back(alignmentValue); + } else { + allocFunctionName = rewriter.getStringAttr(mallocFunctionName); + } + + argsVec.push_back(totalSizeBytes); + ValueRange args(argsVec); + + allocCall = rewriter.create<emitc::CallOpaqueOp>( + loc, + emitc::PointerType::get( + emitc::OpaqueType::get(rewriter.getContext(), "void")), + allocFunctionName, args); + + emitc::PointerType targetPointerType = emitc::PointerType::get(elementType); + emitc::CastOp castOp = rewriter.create<emitc::CastOp>( + loc, targetPointerType, allocCall.getResult(0)); + + rewriter.replaceOp(allocOp, castOp); + return success(); + } +}; + struct ConvertGlobal final : public OpConversionPattern<memref::GlobalOp> { using OpConversionPattern::OpConversionPattern; @@ -223,9 +293,7 @@ struct ConvertStore final : public OpConversionPattern<memref::StoreOp> { void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { typeConverter.addConversion( [&](MemRefType memRefType) -> std::optional<Type> { - if (!memRefType.hasStaticShape() || - !memRefType.getLayout().isIdentity() || memRefType.getRank() == 0 || - llvm::is_contained(memRefType.getShape(), 0)) { + if (!isMemRefTypeLegalForEmitC(memRefType)) { return {}; } Type convertedElementType = @@ -252,6 +320,6 @@ void mlir::populateMemRefToEmitCTypeConversion(TypeConverter &typeConverter) { void mlir::populateMemRefToEmitCConversionPatterns( RewritePatternSet &patterns, const TypeConverter &converter) { - patterns.add<ConvertAlloca, ConvertGlobal, ConvertGetGlobal, ConvertLoad, - ConvertStore>(converter, patterns.getContext()); + patterns.add<ConvertAlloca, ConvertAlloc, ConvertGlobal, ConvertGetGlobal, + ConvertLoad, ConvertStore>(converter, patterns.getContext()); } diff --git a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp index cf25c09..e78dd76 100644 --- a/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp +++ b/mlir/lib/Conversion/MemRefToEmitC/MemRefToEmitCPass.cpp @@ -15,6 +15,7 @@ #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitC.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/IR/Attributes.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -28,9 +29,11 @@ using namespace mlir; namespace { struct ConvertMemRefToEmitCPass : public impl::ConvertMemRefToEmitCBase<ConvertMemRefToEmitCPass> { + using Base::Base; void runOnOperation() override { TypeConverter converter; - + ConvertMemRefToEmitCOptions options; + options.lowerToCpp = this->lowerToCpp; // Fallback for other types. converter.addConversion([](Type type) -> std::optional<Type> { if (!emitc::isSupportedEmitCType(type)) @@ -50,6 +53,37 @@ struct ConvertMemRefToEmitCPass if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); + + mlir::ModuleOp module = getOperation(); + module.walk([&](mlir::emitc::CallOpaqueOp callOp) { + if (callOp.getCallee() != alignedAllocFunctionName && + callOp.getCallee() != mallocFunctionName) { + return mlir::WalkResult::advance(); + } + + for (auto &op : *module.getBody()) { + emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op); + if (!includeOp) { + continue; + } + if (includeOp.getIsStandardInclude() && + ((options.lowerToCpp && + includeOp.getInclude() == cppStandardLibraryHeader) || + (!options.lowerToCpp && + includeOp.getInclude() == cStandardLibraryHeader))) { + return mlir::WalkResult::interrupt(); + } + } + + mlir::OpBuilder builder(module.getBody(), module.getBody()->begin()); + StringAttr includeAttr = + builder.getStringAttr(options.lowerToCpp ? cppStandardLibraryHeader + : cStandardLibraryHeader); + builder.create<mlir::emitc::IncludeOp>( + module.getLoc(), includeAttr, + /*is_standard_include=*/builder.getUnitAttr()); + return mlir::WalkResult::interrupt(); + }); } }; } // namespace diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 4307bc6..17a79e3 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -1070,39 +1070,6 @@ public: } }; -class VectorExtractElementOpConversion - : public ConvertOpToLLVMPattern<vector::ExtractElementOp> { -public: - using ConvertOpToLLVMPattern< - vector::ExtractElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = extractEltOp.getSourceVectorType(); - auto llvmType = typeConverter->convertType(vectorType.getElementType()); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = extractEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::ExtractElementOp>( - extractEltOp, llvmType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - class VectorExtractOpConversion : public ConvertOpToLLVMPattern<vector::ExtractOp> { public: @@ -1206,39 +1173,6 @@ public: } }; -class VectorInsertElementOpConversion - : public ConvertOpToLLVMPattern<vector::InsertElementOp> { -public: - using ConvertOpToLLVMPattern<vector::InsertElementOp>::ConvertOpToLLVMPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertEltOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto vectorType = insertEltOp.getDestVectorType(); - auto llvmType = typeConverter->convertType(vectorType); - - // Bail if result type cannot be lowered. - if (!llvmType) - return failure(); - - if (vectorType.getRank() == 0) { - Location loc = insertEltOp.getLoc(); - auto idxType = rewriter.getIndexType(); - auto zero = LLVM::ConstantOp::create(rewriter, loc, - typeConverter->convertType(idxType), - rewriter.getIntegerAttr(idxType, 0)); - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), zero); - return success(); - } - - rewriter.replaceOpWithNewOp<LLVM::InsertElementOp>( - insertEltOp, llvmType, adaptor.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - class VectorInsertOpConversion : public ConvertOpToLLVMPattern<vector::InsertOp> { public: @@ -2244,8 +2178,7 @@ void mlir::populateVectorToLLVMConversionPatterns( VectorGatherOpConversion, VectorScatterOpConversion>( converter, useVectorAlignment); patterns.add<VectorBitCastOpConversion, VectorShuffleOpConversion, - VectorExtractElementOpConversion, VectorExtractOpConversion, - VectorFMAOp1DConversion, VectorInsertElementOpConversion, + VectorExtractOpConversion, VectorFMAOp1DConversion, VectorInsertOpConversion, VectorPrintOpConversion, VectorTypeCastOpConversion, VectorScaleOpConversion, VectorExpandLoadOpConversion, VectorCompressStoreOpConversion, diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp index b1af5f0..508f4e2 100644 --- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp +++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp @@ -690,7 +690,7 @@ struct PrepareTransferWriteConversion /// %lastIndex = arith.subi %length, %c1 : index /// vector.print punctuation <open> /// scf.for %i = %c0 to %length step %c1 { -/// %el = vector.extractelement %v[%i : index] : vector<[4]xi32> +/// %el = vector.extract %v[%i] : i32 from vector<[4]xi32> /// vector.print %el : i32 punctuation <no_punctuation> /// %notLastIndex = arith.cmpi ult, %i, %lastIndex : index /// scf.if %notLastIndex { @@ -1643,7 +1643,7 @@ struct Strategy1d<TransferWriteOp> { /// Is rewritten to approximately the following pseudo-IR: /// ``` /// for i = 0 to 9 { -/// %t = vector.extractelement %vec[i] : vector<9xf32> +/// %t = vector.extract %vec[i] : f32 from vector<9xf32> /// memref.store %t, %arg0[%a + i, %b] : memref<?x?xf32> /// } /// ``` diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 986eae3..a4be7d4 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -335,63 +335,6 @@ struct VectorInsertOpConvert final } }; -struct VectorExtractElementOpConvert final - : public OpConversionPattern<vector::ExtractElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::ExtractElementOp extractOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type resultType = getTypeConverter()->convertType(extractOp.getType()); - if (!resultType) - return failure(); - - if (isa<spirv::ScalarType>(adaptor.getVector().getType())) { - rewriter.replaceOp(extractOp, adaptor.getVector()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>( - extractOp, resultType, adaptor.getVector(), - rewriter.getI32ArrayAttr({static_cast<int>(cstPos.getSExtValue())})); - else - rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>( - extractOp, resultType, adaptor.getVector(), adaptor.getPosition()); - return success(); - } -}; - -struct VectorInsertElementOpConvert final - : public OpConversionPattern<vector::InsertElementOp> { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(vector::InsertElementOp insertOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type vectorType = getTypeConverter()->convertType(insertOp.getType()); - if (!vectorType) - return failure(); - - if (isa<spirv::ScalarType>(vectorType)) { - rewriter.replaceOp(insertOp, adaptor.getSource()); - return success(); - } - - APInt cstPos; - if (matchPattern(adaptor.getPosition(), m_ConstantInt(&cstPos))) - rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>( - insertOp, adaptor.getSource(), adaptor.getDest(), - cstPos.getSExtValue()); - else - rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>( - insertOp, vectorType, insertOp.getDest(), adaptor.getSource(), - adaptor.getPosition()); - return success(); - } -}; - struct VectorInsertStridedSliceOpConvert final : public OpConversionPattern<vector::InsertStridedSliceOp> { using OpConversionPattern::OpConversionPattern; @@ -1107,12 +1050,11 @@ struct VectorToElementOpConvert final void mlir::populateVectorToSPIRVPatterns( const SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { patterns.add< - VectorBitcastConvert, VectorBroadcastConvert, - VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorBitcastConvert, VectorBroadcastConvert, VectorExtractOpConvert, VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>, VectorFmaOpConvert<spirv::CLFmaOp>, VectorFromElementsOpConvert, - VectorToElementOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>, + VectorToElementOpConvert, VectorInsertOpConvert, + VectorReductionPattern<GL_INT_MAX_MIN_OPS>, VectorReductionPattern<CL_INT_MAX_MIN_OPS>, VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>, VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast, diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index d1d1062..aa53f94 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -1,4 +1,5 @@ -//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===// +//===- OneShotModuleBufferize.cpp - Bufferization across Func. Boundaries +//----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,12 +9,13 @@ // // Module Bufferization is an extension of One-Shot Bufferize that // bufferizes function boundaries. It provides `BufferizableOpInterface` -// implementations for FuncOp, CallOp and ReturnOp. +// implementations for FuncOp, CallOp and ReturnOp. Although it is named +// Module Bufferization, it may operate on any SymbolTable. // -// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`. -// This function analyzes the given module and determines the order of analysis -// and bufferization: Functions that are called are processed before their -// respective callers. +// Module Bufferization is run via `runOneShotModuleBufferize(SymbolTableOp, +// ...)`. This function analyzes the given op and determines the order of +// analysis and bufferization: Functions that are called are processed before +// their respective callers. // // After analyzing a FuncOp, additional information about its bbArgs is // gathered and stored in `FuncAnalysisState`. @@ -309,7 +311,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) { /// Return `failure()` if we are unable to retrieve the called FuncOp from /// any func::CallOp. static LogicalResult getFuncOpsOrderedByCalls( - ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, + Operation *moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps, SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables) { // For each FuncOp, the set of functions called by it (i.e. the union of @@ -317,26 +319,29 @@ static LogicalResult getFuncOpsOrderedByCalls( DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy; // For each FuncOp, the number of func::CallOp it contains. DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp; - - for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) { - // Collect function calls and populate the caller map. - numberCallOpsContainedInFuncOp[funcOp] = 0; - WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { - func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables); - assert(calledFunction && "could not retrieved called func::FuncOp"); - // If the called function does not have any tensors in its signature, then - // it is not necessary to bufferize the callee before the caller. - if (!hasTensorSignature(calledFunction)) - return WalkResult::skip(); - - callerMap[calledFunction].insert(callOp); - if (calledBy[calledFunction].insert(funcOp).second) { - numberCallOpsContainedInFuncOp[funcOp]++; + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) { + // Collect function calls and populate the caller map. + numberCallOpsContainedInFuncOp[funcOp] = 0; + WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { + func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables); + assert(calledFunction && "could not retrieved called func::FuncOp"); + // If the called function does not have any tensors in its signature, + // then it is not necessary to bufferize the callee before the caller. + if (!hasTensorSignature(calledFunction)) + return WalkResult::skip(); + + callerMap[calledFunction].insert(callOp); + if (calledBy[calledFunction].insert(funcOp).second) { + numberCallOpsContainedInFuncOp[funcOp]++; + } + return WalkResult::advance(); + }); + if (res.wasInterrupted()) + return failure(); } - return WalkResult::advance(); - }); - if (res.wasInterrupted()) - return failure(); + } } // Iteratively remove function operations that do not call any of the @@ -447,7 +452,7 @@ static void foldMemRefCasts(func::FuncOp funcOp) { } LogicalResult -mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, +mlir::bufferization::analyzeModuleOp(Operation *moduleOp, OneShotAnalysisState &state, BufferizationStatistics *statistics) { assert(state.getOptions().bufferizeFunctionBoundaries && @@ -512,19 +517,23 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, } void mlir::bufferization::removeBufferizationAttributesInModule( - ModuleOp moduleOp) { - for (auto op : moduleOp.getOps<func::FuncOp>()) { - for (BlockArgument bbArg : op.getArguments()) - removeBufferizationAttributes(bbArg); + Operation *moduleOp) { + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) { + for (BlockArgument bbArg : funcOp.getArguments()) + removeBufferizationAttributes(bbArg); + } + } } } LogicalResult mlir::bufferization::bufferizeModuleOp( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, + Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); - IRRewriter rewriter(moduleOp.getContext()); + IRRewriter rewriter(moduleOp->getContext()); // A list of non-circular functions in the order in which they are analyzed // and bufferized. @@ -571,12 +580,17 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } // Bufferize all other ops. - for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) { - // Functions were already bufferized. - if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>()) - continue; - if (failed(bufferizeOp(&op, options, state, statistics))) - return failure(); + for (mlir::Region ®ion : moduleOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (mlir::Operation &op : + llvm::make_early_inc_range(block.getOperations())) { + // Functions were already bufferized. + if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>()) + continue; + if (failed(bufferizeOp(&op, options, state, statistics))) + return failure(); + } + } } // Post-pass cleanup of function argument attributes. @@ -586,7 +600,7 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } LogicalResult mlir::bufferization::runOneShotModuleBufferize( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, + Operation *moduleOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp index f999c93..a6159ee 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -33,7 +33,7 @@ LogicalResult mlir::bufferization::insertTensorCopies( // analysis depending on whether function boundary bufferization is enabled or // not. if (options.bufferizeFunctionBoundaries) { - if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics))) + if (failed(analyzeModuleOp(op, analysisState, statistics))) return failure(); } else { if (failed(analyzeOp(op, analysisState, statistics))) diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 27b6617..b56a212 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -32,6 +32,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" @@ -4622,22 +4623,6 @@ static bool isInvalidPackingPosSpecification(ArrayRef<int64_t> dimsPos, }); } -/// Returns true if the dimension of `sourceShape` is smaller than the dimension -/// of the `limitShape`. -static bool areAllInBound(ArrayRef<int64_t> sourceShape, - ArrayRef<int64_t> limitShape) { - assert( - sourceShape.size() == limitShape.size() && - "expected source shape rank, and limit of the shape to have same rank"); - return llvm::all_of( - llvm::zip(sourceShape, limitShape), [](std::tuple<int64_t, int64_t> it) { - int64_t sourceExtent = std::get<0>(it); - int64_t limit = std::get<1>(it); - return ShapedType::isDynamic(sourceExtent) || - ShapedType::isDynamic(limit) || sourceExtent <= limit; - }); -} - template <typename OpTy> static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { static_assert(llvm::is_one_of<OpTy, PackOp, UnPackOp>::value, @@ -4696,11 +4681,6 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // represents full tiles. RankedTensorType expectedPackedType = PackOp::inferPackedType( unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); - if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) { - return op->emitError("the shape of output is not large enough to hold the " - "packed data. Expected at least ") - << expectedPackedType << ", got " << packedType; - } if (!llvm::all_of( llvm::zip(packedType.getShape().take_back(mixedTiles.size()), mixedTiles), @@ -4717,6 +4697,12 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { return op->emitError("mismatch in inner tile sizes specified and shaped of " "tiled dimension in the packed type"); } + if (failed(verifyCompatibleShape(expectedPackedType.getShape(), + packedType.getShape()))) { + return op->emitError("expected ") + << expectedPackedType << " for the packed domain value, got " + << packedType; + } return success(); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp index 277e50b..9d7f4e0 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/PatternMatch.h" namespace mlir { diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp index dad3526..57b610b 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp @@ -932,20 +932,6 @@ struct PackOpTiling continue; } - // If the dimension needs padding, it is not supported because there are - // iterations that only write padding values to the whole tile. The - // consumer fusion is driven by the source, so it is not possible to map - // an empty slice to the tile. - bool needExtraPadding = - ShapedType::isDynamic(destDimSize) || !cstInnerSize || - destDimSize * cstInnerSize.value() != srcDimSize; - // Prioritize the case that the op already says that it does not need - // padding. - if (!packOp.getPaddingValue()) - needExtraPadding = false; - if (needExtraPadding) - return failure(); - // Currently fusing `packOp` as consumer only expects perfect tiling // scenario because even if without padding semantic, the `packOp` may // also yield incomplete tiles. E.g. tensor<30xf32> -> tensor<5x6xf32>, diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 0e96b59..869d27a 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -115,8 +115,7 @@ public: bufferization::BufferizationState bufferizationState; - if (failed(bufferization::bufferizeModuleOp(cast<ModuleOp>(getOperation()), - updatedOptions, + if (failed(bufferization::bufferizeModuleOp(getOperation(), updatedOptions, bufferizationState))) return failure(); diff --git a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp index e297f7c..14a4fdf 100644 --- a/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/Interfaces/TransformInterfaces.cpp @@ -21,16 +21,8 @@ #include "llvm/Support/InterleavedRange.h" #define DEBUG_TYPE "transform-dialect" -#define DEBUG_TYPE_FULL "transform-dialect-full" #define DEBUG_PRINT_AFTER_ALL "transform-dialect-print-top-level-after-all" -#ifndef NDEBUG -#define FULL_LDBG(X) \ - DEBUGLOG_WITH_STREAM_AND_TYPE(llvm::dbgs(), DEBUG_TYPE_FULL) -#else -#define FULL_LDBG(X) \ - for (bool _c = false; _c; _c = false) \ - ::llvm::nulls() -#endif +#define FULL_LDBG() LDBG(4) using namespace mlir; diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index bce358d..8789f55 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1258,63 +1258,6 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results, CanonicalizeContractAdd<arith::AddFOp>>(context); } -//===----------------------------------------------------------------------===// -// ExtractElementOp -//===----------------------------------------------------------------------===// - -void ExtractElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges.front()); -} - -void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, - Value source) { - result.addOperands({source}); - result.addTypes(llvm::cast<VectorType>(source.getType()).getElementType()); -} - -LogicalResult vector::ExtractElementOp::verify() { - VectorType vectorType = getSourceVectorType(); - if (vectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (vectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here now. - if (!adaptor.getPosition()) - return {}; - - // Fold extractelement (splat X) -> X. - if (auto splat = getVector().getDefiningOp<vector::SplatOp>()) - return splat.getInput(); - - // Fold extractelement(broadcast(X)) -> X. - if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>()) - if (!llvm::isa<VectorType>(broadcast.getSource().getType())) - return broadcast.getSource(); - - auto src = dyn_cast_or_null<DenseElementsAttr>(adaptor.getVector()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!pos || !src) - return {}; - - auto srcElements = src.getValues<Attribute>(); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= srcElements.size()) - return {}; - - return srcElements[posIdx]; -} - // Returns `true` if `index` is either within [0, maxIndex) or equal to // `poisonValue`. static bool isValidPositiveIndexOrPoison(int64_t index, int64_t poisonValue, @@ -3184,60 +3127,6 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results, } //===----------------------------------------------------------------------===// -// InsertElementOp -//===----------------------------------------------------------------------===// - -void InsertElementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges, - SetIntRangeFn setResultRanges) { - setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1])); -} - -void InsertElementOp::build(OpBuilder &builder, OperationState &result, - Value source, Value dest) { - build(builder, result, source, dest, {}); -} - -LogicalResult InsertElementOp::verify() { - auto dstVectorType = getDestVectorType(); - if (dstVectorType.getRank() == 0) { - if (getPosition()) - return emitOpError("expected position to be empty with 0-D vector"); - return success(); - } - if (dstVectorType.getRank() != 1) - return emitOpError("unexpected >1 vector rank"); - if (!getPosition()) - return emitOpError("expected position for 1-D vector"); - return success(); -} - -OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) { - // Skip the 0-D vector here. - if (!adaptor.getPosition()) - return {}; - - auto src = dyn_cast_or_null<TypedAttr>(adaptor.getSource()); - auto dst = dyn_cast_or_null<DenseElementsAttr>(adaptor.getDest()); - auto pos = dyn_cast_or_null<IntegerAttr>(adaptor.getPosition()); - if (!src || !dst || !pos) - return {}; - - if (src.getType() != getDestVectorType().getElementType()) - return {}; - - auto dstElements = dst.getValues<Attribute>(); - - SmallVector<Attribute> results(dstElements); - - uint64_t posIdx = pos.getInt(); - if (posIdx >= results.size()) - return {}; - results[posIdx] = src; - - return DenseElementsAttr::get(getDestVectorType(), results); -} - -//===----------------------------------------------------------------------===// // InsertOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp index 0db9808..7094c8e 100644 --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -901,7 +901,7 @@ LogicalResult PassManager::run(Operation *op) { if (failed(initialize(context, impl->initializationGeneration + 1))) return failure(); initializationKey = newInitKey; - pipelineKey = pipelineInitializationKey; + pipelineInitializationKey = pipelineKey; } // Construct a top level analysis manager for the pipeline. diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 58e5353..a8a2b2e 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -446,6 +446,19 @@ LogicalResult Serializer::processType(Location loc, Type type, LogicalResult Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, SetVector<StringRef> &serializationCtx) { + + // Map unsigned integer types to singless integer types. + // This is needed otherwise the generated spirv assembly will contain + // twice a type declaration (like OpTypeInt 32 0) which is no permitted and + // such module fails validation. Indeed at MLIR level the two types are + // different and lookup in the cache below misses. + // Note: This conversion needs to happen here before the type is looked up in + // the cache. + if (type.isUnsignedInteger()) { + type = IntegerType::get(loc->getContext(), type.getIntOrFloatBitWidth(), + IntegerType::SignednessSemantics::Signless); + } + typeID = getTypeID(type); if (typeID) return success(); |