aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib')
-rw-r--r--mlir/lib/CAPI/Transforms/Rewrite.cpp13
-rw-r--r--mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp5
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp1
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp21
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp46
5 files changed, 58 insertions, 28 deletions
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 8ee6308..0d56259 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -259,22 +259,23 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
/// RewritePatternSet and FrozenRewritePatternSet API
//===----------------------------------------------------------------------===//
-inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
+static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
}
-inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
+static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
return {module};
}
-inline mlir::FrozenRewritePatternSet *
+static inline mlir::FrozenRewritePatternSet *
unwrap(MlirFrozenRewritePatternSet module) {
assert(module.ptr && "unexpected null module");
return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
}
-inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
+static inline MlirFrozenRewritePatternSet
+wrap(mlir::FrozenRewritePatternSet *module) {
return {module};
}
@@ -321,12 +322,12 @@ inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) {
//===----------------------------------------------------------------------===//
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
-inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
+static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
assert(module.ptr && "unexpected null module");
return static_cast<mlir::PDLPatternModule *>(module.ptr);
}
-inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
+static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
return {module};
}
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index 035f197..399ccf3 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -267,9 +267,8 @@ class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
copyInfo.push_back(info);
}
// Create a call to the kernel and copy the data back.
- Operation *callOp = rewriter.replaceOpWithNewOp<LLVM::CallOp>(
- op, kernelFunc, ArrayRef<Value>());
- rewriter.setInsertionPointAfter(callOp);
+ rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
+ ArrayRef<Value>());
for (CopyInfo info : copyInfo)
copy(loc, info.src, info.dst, info.size, rewriter);
return success();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 6f28849..0cb0bad 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -802,7 +802,6 @@ public:
ValueRange{paddedInput, fakeWindowDims}, filledEmptyTensor, strideAttr,
dilationAttr);
- rewriter.setInsertionPointAfter(op);
NanPropagationMode nanMode = op.getNanMode();
rewriter.replaceOp(op, resultOp);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index 3bd763e..05fc7cb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1622,12 +1622,12 @@ static void generateCollapsedIndexingRegion(
}
}
-void collapseOperandsAndResults(LinalgOp op,
- const CollapsingInfo &collapsingInfo,
- RewriterBase &rewriter,
- SmallVectorImpl<Value> &inputOperands,
- SmallVectorImpl<Value> &outputOperands,
- SmallVectorImpl<Type> &resultTypes) {
+static void collapseOperandsAndResults(LinalgOp op,
+ const CollapsingInfo &collapsingInfo,
+ RewriterBase &rewriter,
+ SmallVectorImpl<Value> &inputOperands,
+ SmallVectorImpl<Value> &outputOperands,
+ SmallVectorImpl<Type> &resultTypes) {
Location loc = op->getLoc();
inputOperands =
llvm::map_to_vector(op.getDpsInputOperands(), [&](OpOperand *opOperand) {
@@ -1651,8 +1651,8 @@ void collapseOperandsAndResults(LinalgOp op,
/// Clone a `LinalgOp` to a collapsed version of same name
template <typename OpTy>
-OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
- const CollapsingInfo &collapsingInfo) {
+static OpTy cloneToCollapsedOp(RewriterBase &rewriter, OpTy origOp,
+ const CollapsingInfo &collapsingInfo) {
return nullptr;
}
@@ -1699,8 +1699,9 @@ GenericOp cloneToCollapsedOp<GenericOp>(RewriterBase &rewriter,
return collapsedOp;
}
-LinalgOp createCollapsedOp(LinalgOp op, const CollapsingInfo &collapsingInfo,
- RewriterBase &rewriter) {
+static LinalgOp createCollapsedOp(LinalgOp op,
+ const CollapsingInfo &collapsingInfo,
+ RewriterBase &rewriter) {
if (GenericOp genericOp = dyn_cast<GenericOp>(op.getOperation())) {
return cloneToCollapsedOp(rewriter, genericOp, collapsingInfo);
} else {
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index bf0136b..3a23bbf 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1856,6 +1856,44 @@ void ConversionPatternRewriterImpl::replaceOp(
Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
assert(newValues.size() == op->getNumResults() &&
"incorrect number of replacement values");
+ LLVM_DEBUG({
+ logger.startLine() << "** Replace : '" << op->getName() << "'(" << op
+ << ")\n";
+ if (currentTypeConverter) {
+ // If the user-provided replacement types are different from the
+ // legalized types, as per the current type converter, print a note.
+ // In most cases, the replacement types are expected to match the types
+ // produced by the type converter, so this could indicate a bug in the
+ // user code.
+ for (auto [result, repls] :
+ llvm::zip_equal(op->getResults(), newValues)) {
+ Type resultType = result.getType();
+ auto logProlog = [&, repls = repls]() {
+ logger.startLine() << " Note: Replacing op result of type "
+ << resultType << " with value(s) of type (";
+ llvm::interleaveComma(repls, logger.getOStream(), [&](Value v) {
+ logger.getOStream() << v.getType();
+ });
+ logger.getOStream() << ")";
+ };
+ SmallVector<Type> convertedTypes;
+ if (failed(currentTypeConverter->convertTypes(resultType,
+ convertedTypes))) {
+ logProlog();
+ logger.getOStream() << ", but the type converter failed to legalize "
+ "the original type.\n";
+ continue;
+ }
+ if (TypeRange(convertedTypes) != TypeRange(ValueRange(repls))) {
+ logProlog();
+ logger.getOStream() << ", but the legalized type(s) is/are (";
+ llvm::interleaveComma(convertedTypes, logger.getOStream(),
+ [&](Type t) { logger.getOStream() << t; });
+ logger.getOStream() << ")\n";
+ }
+ }
+ }
+ });
if (!config.allowPatternRollback) {
// Pattern rollback is not allowed: materialize all IR changes immediately.
@@ -2072,10 +2110,6 @@ void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
- LLVM_DEBUG({
- impl->logger.startLine()
- << "** Replace : '" << op->getName() << "'(" << op << ")\n";
- });
// If the current insertion point is before the erased operation, we adjust
// the insertion point to be after the operation.
@@ -2093,10 +2127,6 @@ void ConversionPatternRewriter::replaceOpWithMultiple(
Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
assert(op->getNumResults() == newValues.size() &&
"incorrect # of replacement values");
- LLVM_DEBUG({
- impl->logger.startLine()
- << "** Replace : '" << op->getName() << "'(" << op << ")\n";
- });
// If the current insertion point is before the erased operation, we adjust
// the insertion point to be after the operation.