diff options
Diffstat (limited to 'mlir/lib')
5 files changed, 58 insertions, 28 deletions
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 8ee6308..0d56259 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -259,22 +259,23 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// -inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { +static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return *(static_cast<mlir::RewritePatternSet *>(module.ptr)); } -inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { +static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { return {module}; } -inline mlir::FrozenRewritePatternSet * +static inline mlir::FrozenRewritePatternSet * unwrap(MlirFrozenRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr); } -inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) { +static inline MlirFrozenRewritePatternSet +wrap(mlir::FrozenRewritePatternSet *module) { return {module}; } @@ -321,12 +322,12 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { //===----------------------------------------------------------------------===// #if MLIR_ENABLE_PDL_IN_PATTERNMATCH -inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { +static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { assert(module.ptr && "unexpected null module"); return static_cast<mlir::PDLPatternModule *>(module.ptr); } -inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { +static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { return {module}; } diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp index 035f197..399ccf3 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp @@ -267,9 +267,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> { copyInfo.push_back(info); } // Create a call to the kernel and copy the data back. - Operation *callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>( - op, kernelFunc, ArrayRef<Value>()); - rewriter.setInsertionPointAfter(callOp); + rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc, + ArrayRef<Value>()); for (CopyInfo info : copyInfo) copy(loc, info.src, info.dst, info.size, rewriter); return success(); diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 6f28849..0cb0bad 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -802,7 +802,6 @@ public: ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr, dilationAttr); - rewriter.setInsertionPointAfter(op); NanPropagationMode nanMode = op.getNanMode(); rewriter.replaceOp(op, resultOp); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp index 3bd763e..05fc7cb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1622,12 +1622,12 @@ static void generateCollapsedIndexingRegion( } } -void collapseOperandsAndResults(LinalgOp op, - const CollapsingInfo &collapsingInfo, - RewriterBase &rewriter, - SmallVectorImpl<Value> &inputOperands, - SmallVectorImpl<Value> &outputOperands, - SmallVectorImpl<Type> &resultTypes) { +static void collapseOperandsAndResults(LinalgOp op, + const CollapsingInfo &collapsingInfo, + RewriterBase &rewriter, + SmallVectorImpl<Value> &inputOperands, + SmallVectorImpl<Value> &outputOperands, + SmallVectorImpl<Type> &resultTypes) { Location loc = op->getLoc(); inputOperands = llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) { @@ -1651,8 +1651,8 @@ void collapseOperandsAndResults(LinalgOp op, /// Clone a `LinalgOp` to a collapsed version of same name template <typename OpTy> -OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, - const CollapsingInfo &collapsingInfo) { +static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp, + const CollapsingInfo &collapsingInfo) { return nullptr; } @@ -1699,8 +1699,9 @@ GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter, return collapsedOp; } -LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo, - RewriterBase &rewriter) { +static LinalgOp createCollapsedOp(LinalgOp op, + const CollapsingInfo &collapsingInfo, + RewriterBase &rewriter) { if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) { return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo); } else { diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index bf0136b..3a23bbf 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1856,6 +1856,44 @@ void ConversionPatternRewriterImpl::replaceOp( Operation *op, SmallVector<SmallVector<Value>> &&newValues) { assert(newValues.size() == op->getNumResults() && "incorrect number of replacement values"); + LLVM_DEBUG({ + logger.startLine() << "** Replace : '" << op->getName() << "'(" << op + << ")\n"; + if (currentTypeConverter) { + // If the user-provided replacement types are different from the + // legalized types, as per the current type converter, print a note. + // In most cases, the replacement types are expected to match the types + // produced by the type converter, so this could indicate a bug in the + // user code. + for (auto [result, repls] : + llvm::zip_equal(op->getResults(), newValues)) { + Type resultType = result.getType(); + auto logProlog = [&, repls = repls]() { + logger.startLine() << " Note: Replacing op result of type " + << resultType << " with value(s) of type ("; + llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) { + logger.getOStream() << v.getType(); + }); + logger.getOStream() << ")"; + }; + SmallVector<Type> convertedTypes; + if (failed(currentTypeConverter->convertTypes(resultType, + convertedTypes))) { + logProlog(); + logger.getOStream() << ", but the type converter failed to legalize " + "the original type.\n"; + continue; + } + if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) { + logProlog(); + logger.getOStream() << ", but the legalized type(s) is/are ("; + llvm::interleaveComma(convertedTypes, logger.getOStream(), + [&](Type t) { logger.getOStream() << t; }); + logger.getOStream() << ")\n"; + } + } + } + }); if (!config.allowPatternRollback) { // Pattern rollback is not allowed: materialize all IR changes immediately. @@ -2072,10 +2110,6 @@ void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) { assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); - LLVM_DEBUG({ - impl->logger.startLine() - << "** Replace : '" << op->getName() << "'(" << op << ")\n"; - }); // If the current insertion point is before the erased operation, we adjust // the insertion point to be after the operation. @@ -2093,10 +2127,6 @@ void ConversionPatternRewriter::replaceOpWithMultiple( Operation *op, SmallVector<SmallVector<Value>> &&newValues) { assert(op->getNumResults() == newValues.size() && "incorrect # of replacement values"); - LLVM_DEBUG({ - impl->logger.startLine() - << "** Replace : '" << op->getName() << "'(" << op << ")\n"; - }); // If the current insertion point is before the erased operation, we adjust // the insertion point to be after the operation. |