aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChris Lattner <clattner@nondot.org>2021-03-20 16:29:41 -0700
committerChris Lattner <clattner@nondot.org>2021-03-21 10:06:31 -0700
commit3a506b31a341585a21b21c42253ea9fc54c55b37 (patch)
treeef62eda91c35b5d5a4a7b37366fa23da92830418
parent9f864d202558b4206adc26789aff8a204ebbe0b2 (diff)
downloadllvm-3a506b31a341585a21b21c42253ea9fc54c55b37.zip
llvm-3a506b31a341585a21b21c42253ea9fc54c55b37.tar.gz
llvm-3a506b31a341585a21b21c42253ea9fc54c55b37.tar.bz2
Change OwningRewritePatternList to carry an MLIRContext with it.
This updates the codebase to pass the context when creating an instance of OwningRewritePatternList, and starts removing extraneous MLIRContext parameters. There are many many more to be removed. Differential Revision: https://reviews.llvm.org/D99028
-rw-r--r--mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h6
-rw-r--r--mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h4
-rw-r--r--mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h3
-rw-r--r--mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h3
-rw-r--r--mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h2
-rw-r--r--mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h3
-rw-r--r--mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h4
-rw-r--r--mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h4
-rw-r--r--mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h9
-rw-r--r--mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h5
-rw-r--r--mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h8
-rw-r--r--mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h2
-rw-r--r--mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h3
-rw-r--r--mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h4
-rw-r--r--mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h2
-rw-r--r--mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h3
-rw-r--r--mlir/include/mlir/Dialect/GPU/Passes.h8
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Passes.h13
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h36
-rw-r--r--mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h31
-rw-r--r--mlir/include/mlir/Dialect/Math/Transforms/Passes.h8
-rw-r--r--mlir/include/mlir/Dialect/SCF/Transforms.h4
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h2
-rw-r--r--mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h3
-rw-r--r--mlir/include/mlir/Dialect/Shape/Transforms/Passes.h10
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h5
-rw-r--r--mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h6
-rw-r--r--mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h3
-rw-r--r--mlir/include/mlir/Dialect/Vector/VectorOps.h19
-rw-r--r--mlir/include/mlir/IR/PatternMatch.h33
-rw-r--r--mlir/include/mlir/Transforms/Bufferize.h3
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h32
-rw-r--r--mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp14
-rw-r--r--mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp12
-rw-r--r--mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp2
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp6
-rw-r--r--mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp5
-rw-r--r--mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp5
-rw-r--r--mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp4
-rw-r--r--mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp6
-rw-r--r--mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp2
-rw-r--r--mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp6
-rw-r--r--mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp6
-rw-r--r--mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp8
-rw-r--r--mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp2
-rw-r--r--mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp6
-rw-r--r--mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp4
-rw-r--r--mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp2
-rw-r--r--mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp5
-rw-r--r--mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp8
-rw-r--r--mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp10
-rw-r--r--mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp2
-rw-r--r--mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp15
-rw-r--r--mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp10
-rw-r--r--mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp12
-rw-r--r--mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp10
-rw-r--r--mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp2
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp10
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp10
-rw-r--r--mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp8
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp5
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp5
-rw-r--r--mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp6
-rw-r--r--mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp5
-rw-r--r--mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp8
-rw-r--r--mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp9
-rw-r--r--mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp12
-rw-r--r--mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp2
-rw-r--r--mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp11
-rw-r--r--mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp6
-rw-r--r--mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp4
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp2
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp2
-rw-r--r--mlir/lib/Dialect/Affine/Utils/Utils.cpp2
-rw-r--r--mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp2
-rw-r--r--mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp2
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp5
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp9
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp8
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp7
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp9
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp6
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp27
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp15
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Loops.cpp2
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp4
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp5
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp33
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp42
-rw-r--r--mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp5
-rw-r--r--mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp8
-rw-r--r--mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp2
-rw-r--r--mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp2
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp6
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp6
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp5
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp10
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp2
-rw-r--r--mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp5
-rw-r--r--mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp8
-rw-r--r--mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp9
-rw-r--r--mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp9
-rw-r--r--mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp6
-rw-r--r--mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp10
-rw-r--r--mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp9
-rw-r--r--mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp11
-rw-r--r--mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp16
-rw-r--r--mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp2
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp9
-rw-r--r--mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp2
-rw-r--r--mlir/lib/Dialect/Vector/VectorOps.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/VectorTransforms.cpp36
-rw-r--r--mlir/lib/Transforms/Bufferize.cpp12
-rw-r--r--mlir/lib/Transforms/Canonicalizer.cpp2
-rw-r--r--mlir/lib/Transforms/Utils/DialectConversion.cpp17
-rw-r--r--mlir/lib/Transforms/Utils/LoopUtils.cpp2
-rw-r--r--mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp2
-rw-r--r--mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp2
-rw-r--r--mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp4
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp23
-rw-r--r--mlir/test/lib/Dialect/Test/TestTraits.cpp2
-rw-r--r--mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp2
-rw-r--r--mlir/test/lib/Transforms/TestConvVectorization.cpp10
-rw-r--r--mlir/test/lib/Transforms/TestConvertCallOp.cpp10
-rw-r--r--mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp2
-rw-r--r--mlir/test/lib/Transforms/TestExpandTanh.cpp4
-rw-r--r--mlir/test/lib/Transforms/TestGpuRewrite.cpp4
-rw-r--r--mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp2
-rw-r--r--mlir/test/lib/Transforms/TestLinalgTransforms.cpp50
-rw-r--r--mlir/test/lib/Transforms/TestPolynomialApproximation.cpp4
-rw-r--r--mlir/test/lib/Transforms/TestSparsification.cpp10
-rw-r--r--mlir/test/lib/Transforms/TestVectorTransforms.cpp46
-rw-r--r--mlir/unittests/Rewrite/PatternBenefit.cpp2
134 files changed, 550 insertions, 574 deletions
diff --git a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
index 4647cac..8d3301c 100644
--- a/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
+++ b/mlir/include/mlir/Conversion/AffineToStandard/AffineToStandard.h
@@ -18,7 +18,6 @@ class AffineMap;
class AffineParallelOp;
class Location;
struct LogicalResult;
-class MLIRContext;
class OpBuilder;
class Pass;
class RewritePattern;
@@ -43,13 +42,12 @@ Optional<SmallVector<Value, 8>> expandAffineMap(OpBuilder &builder,
/// Collect a set of patterns to convert from the Affine dialect to the Standard
/// dialect, in particular convert structured affine control flow into CFG
/// branch-based control flow.
-void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
+void populateAffineToStdConversionPatterns(OwningRewritePatternList &patterns);
/// Collect a set of patterns to convert vector-related Affine ops to the Vector
/// dialect.
void populateAffineToVectorConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
/// Emit code that computes the lower bound of the given affine loop using
/// standard arithmetic operations.
diff --git a/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
index 938c5cb..670942a 100644
--- a/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
+++ b/mlir/include/mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h
@@ -33,8 +33,8 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertAsyncToLLVMPass();
/// the TypeConverter, but otherwise don't care what type conversions are
/// happening.
void populateAsyncStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target);
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target);
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
index ad5dac0..e679b86 100644
--- a/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
+++ b/mlir/include/mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h
@@ -21,8 +21,7 @@ class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating GPU Ops to
/// SPIR-V ops. For a gpu.func to be converted, it should have a
/// spv.entry_point_abi attribute.
-void populateGPUToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
index b2fc9e4..8f94597 100644
--- a/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
+++ b/mlir/include/mlir/Conversion/LinalgToSPIRV/LinalgToSPIRV.h
@@ -20,8 +20,7 @@ class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating Linalg ops to
/// SPIR-V ops.
-void populateLinalgToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
index 3a6c8bb..240bc1f 100644
--- a/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
+++ b/mlir/include/mlir/Conversion/LinalgToStandard/LinalgToStandard.h
@@ -70,7 +70,7 @@ public:
/// Populate the given list with patterns that convert from Linalg to Standard.
void populateLinalgToStandardConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
} // namespace linalg
diff --git a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
index d6316f6..14c1608 100644
--- a/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
+++ b/mlir/include/mlir/Conversion/SCFToGPU/SCFToGPU.h
@@ -42,8 +42,7 @@ LogicalResult convertAffineLoopNestToGPULaunch(AffineForOp forOp,
/// Adds the conversion pattern from `scf.parallel` to `gpu.launch` to the
/// provided pattern list.
-void populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
+void populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns);
/// Configures the rewrite target such that only `scf.parallel` operations that
/// are not rewritten by the provided patterns are legal.
diff --git a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
index e0bab27..5a14c9b 100644
--- a/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
+++ b/mlir/include/mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h
@@ -15,7 +15,6 @@
#include <memory>
namespace mlir {
-class MLIRContext;
class Pass;
// Owning list of rewriting patterns.
@@ -35,8 +34,7 @@ private:
/// Collects a set of patterns to lower from scf.for, scf.if, and
/// loop.terminator to CFG operations within the SPIR-V dialect.
-void populateSCFToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
OwningRewritePatternList &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h b/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
index fd85a3d..95667d8 100644
--- a/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
+++ b/mlir/include/mlir/Conversion/SCFToStandard/SCFToStandard.h
@@ -14,7 +14,6 @@
namespace mlir {
struct LogicalResult;
-class MLIRContext;
class Pass;
class RewritePattern;
@@ -24,8 +23,7 @@ class OwningRewritePatternList;
/// Collect a set of patterns to lower from scf.for, scf.if, and
/// loop.terminator to CFG operations within the Standard dialect, in particular
/// convert structured control flow into CFG branch-based control flow.
-void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
+void populateLoopToStdConversionPatterns(OwningRewritePatternList &patterns);
/// Creates a pass to convert scf.for, scf.if and loop.terminator ops to CFG.
std::unique_ptr<Pass> createLowerToCFGPass();
diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
index 3ba24ea..2f6b6d7 100644
--- a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
+++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h
@@ -40,20 +40,17 @@ void encodeBindAttribute(ModuleOp module);
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
/// Populates the given list with patterns that convert from SPIR-V to LLVM.
-void populateSPIRVToLLVMConversionPatterns(MLIRContext *context,
- LLVMTypeConverter &typeConverter,
+void populateSPIRVToLLVMConversionPatterns(LLVMTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
/// Populates the given list with patterns for function conversion from SPIR-V
/// to LLVM.
void populateSPIRVToLLVMFunctionConversionPatterns(
- MLIRContext *context, LLVMTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns);
/// Populates the given patterns for module conversion from SPIR-V to LLVM.
void populateSPIRVToLLVMModuleConversionPatterns(
- MLIRContext *context, LLVMTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
index 176f101..7c94470 100644
--- a/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
+++ b/mlir/include/mlir/Conversion/ShapeToStandard/ShapeToStandard.h
@@ -14,19 +14,18 @@
namespace mlir {
class FuncOp;
-class MLIRContext;
class ModuleOp;
template <typename T>
class OperationPass;
class OwningRewritePatternList;
void populateShapeToStandardConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
std::unique_ptr<OperationPass<ModuleOp>> createConvertShapeToStandardPass();
void populateConvertShapeConstraintsConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
std::unique_ptr<OperationPass<FuncOp>> createConvertShapeConstraintsPass();
diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
index 87946d3..18cf4f3 100644
--- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
@@ -21,8 +21,7 @@ class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating standard ops
/// to SPIR-V ops. Also adds the patterns to legalize ops not directly
/// translated to SPIR-V dialect.
-void populateStandardToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
/// Appends to a pattern list additional patterns for translating tensor ops
@@ -37,15 +36,14 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
/// variables. SPIR-V consumers in GPU drivers may or may not optimize that
/// away. So this has implications over register pressure. Therefore, a
/// threshold is used to control when the patterns should kick in.
-void populateTensorToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
int64_t byteCountThreshold,
OwningRewritePatternList &patterns);
/// Appends to a pattern list patterns to legalize ops that are not directly
/// lowered to SPIR-V.
void populateStdLegalizationPatternsForSPIRVLowering(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
index 42493a5..7553839 100644
--- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
+++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h
@@ -28,7 +28,7 @@ void addTosaToLinalgOnTensorsPasses(OpPassManager &pm);
/// Populates conversion passes from TOSA dialect to Linalg dialect.
void populateTosaToLinalgOnTensorsConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns);
+ OwningRewritePatternList *patterns);
} // namespace tosa
} // namespace mlir
diff --git a/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h b/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
index 68ed0e0..08b2fe9 100644
--- a/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
+++ b/mlir/include/mlir/Conversion/TosaToSCF/TosaToSCF.h
@@ -20,8 +20,7 @@ namespace tosa {
std::unique_ptr<Pass> createTosaToSCF();
-void populateTosaToSCFConversionPatterns(MLIRContext *context,
- OwningRewritePatternList *patterns);
+void populateTosaToSCFConversionPatterns(OwningRewritePatternList *patterns);
/// Populates passes to convert from TOSA to SCF.
void addTosaToSCFPasses(OpPassManager &pm);
diff --git a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
index 5a63d78..f130471 100644
--- a/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
+++ b/mlir/include/mlir/Conversion/TosaToStandard/TosaToStandard.h
@@ -21,10 +21,10 @@ namespace tosa {
std::unique_ptr<Pass> createTosaToStandard();
void populateTosaToStandardConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns);
+ OwningRewritePatternList *patterns);
void populateTosaRescaleToStandardConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns);
+ OwningRewritePatternList *patterns);
/// Populates passes to convert from TOSA to Standard.
void addTosaToStandardPasses(OpPassManager &pm);
diff --git a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
index f34a576..e7478cf 100644
--- a/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
+++ b/mlir/include/mlir/Conversion/VectorToSCF/VectorToSCF.h
@@ -162,7 +162,7 @@ struct VectorTransferRewriter : public RewritePattern {
/// Collect a set of patterns to convert from the Vector dialect to SCF + std.
void populateVectorToSCFConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
+ OwningRewritePatternList &patterns,
const VectorTransferToSCFOptions &options = VectorTransferToSCFOptions());
/// Create a pass to convert a subset of vector ops to SCF.
diff --git a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
index 7908f6e..8fc606f 100644
--- a/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
+++ b/mlir/include/mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h
@@ -20,8 +20,7 @@ class SPIRVTypeConverter;
/// Appends to a pattern list additional patterns for translating Vector Ops to
/// SPIR-V ops.
-void populateVectorToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/GPU/Passes.h b/mlir/include/mlir/Dialect/GPU/Passes.h
index bfb5626..327f9d6 100644
--- a/mlir/include/mlir/Dialect/GPU/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Passes.h
@@ -31,13 +31,11 @@ std::unique_ptr<OperationPass<ModuleOp>> createGpuKernelOutliningPass();
std::unique_ptr<OperationPass<FuncOp>> createGpuAsyncRegionPass();
/// Collect a set of patterns to rewrite all-reduce ops within the GPU dialect.
-void populateGpuAllReducePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+void populateGpuAllReducePatterns(OwningRewritePatternList &patterns);
/// Collect all patterns to rewrite ops within the GPU dialect.
-inline void populateGpuRewritePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
- populateGpuAllReducePatterns(context, patterns);
+inline void populateGpuRewritePatterns(OwningRewritePatternList &patterns) {
+ populateGpuAllReducePatterns(patterns);
}
namespace gpu {
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
index 34e2568..24f49b5 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -53,7 +53,7 @@ std::unique_ptr<OperationPass<FuncOp>> createLinalgBufferizePass();
/// Populate patterns that convert `ElementwiseMappable` ops to linalg
/// parallel loops.
void populateElementwiseToLinalgConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
/// Create a pass to conver named Linalg operations to Linalg generic
/// operations.
@@ -67,14 +67,14 @@ std::unique_ptr<Pass> createLinalgDetensorizePass();
/// producer (consumer) generic operation by expanding the dimensionality of the
/// loop in the generic op.
void populateFoldReshapeOpsByExpansionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
/// indexing map used to access the source (target) of the reshape operation in
/// the generic/indexed_generic operation.
void populateFoldReshapeOpsByLinearizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
/// producer (consumer) generic/indexed_generic operation by linearizing the
@@ -83,16 +83,15 @@ void populateFoldReshapeOpsByLinearizationPatterns(
/// the tensor reshape involved is collapsing (introducing) unit-extent
/// dimensions.
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
/// Patterns for fusing linalg operation on tensors.
-void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+void populateLinalgTensorOpsFusionPatterns(OwningRewritePatternList &patterns);
/// Patterns to fold unit-extent dimensions in operands/results of linalg ops on
/// tensors.
void populateLinalgFoldUnitExtentDimsPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
//===----------------------------------------------------------------------===//
// Registration
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
index 872e763..421a544 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -36,11 +36,11 @@ template <template <typename> class PatternType, typename ConcreteOpType,
typename = std::enable_if_t<std::is_member_function_pointer<
decltype(&ConcreteOpType::getOperationName)>::value>>
void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
- MLIRContext *context, StringRef opName,
- linalg::LinalgTransformationFilter m) {
+ StringRef opName, linalg::LinalgTransformationFilter m) {
assert(opName == ConcreteOpType::getOperationName() &&
"explicit name must match ConcreteOpType::getOperationName");
- patternList.insert<PatternType<ConcreteOpType>>(context, options, m);
+ patternList.insert<PatternType<ConcreteOpType>>(patternList.getContext(),
+ options, m);
}
/// SFINAE: Enqueue helper for OpType that do not have a `getOperationName`
@@ -48,25 +48,26 @@ void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
template <template <typename> class PatternType, typename OpType,
typename OptionsType>
void sfinae_enqueue(OwningRewritePatternList &patternList, OptionsType options,
- MLIRContext *context, StringRef opName,
- linalg::LinalgTransformationFilter m) {
+ StringRef opName, linalg::LinalgTransformationFilter m) {
assert(!opName.empty() && "opName must not be empty");
- patternList.insert<PatternType<OpType>>(opName, context, options, m);
+ patternList.insert<PatternType<OpType>>(opName, patternList.getContext(),
+ options, m);
}
template <typename PatternType, typename OpType, typename OptionsType>
void enqueue(OwningRewritePatternList &patternList, OptionsType options,
- MLIRContext *context, StringRef opName,
- linalg::LinalgTransformationFilter m) {
+ StringRef opName, linalg::LinalgTransformationFilter m) {
if (!opName.empty())
- patternList.insert<PatternType>(opName, context, options, m);
+ patternList.insert<PatternType>(opName, patternList.getContext(), options,
+ m);
else
patternList.insert<PatternType>(m.addOpFilter<OpType>(), options);
}
/// Promotion transformation enqueues a particular stage-1 pattern for
/// `Tile<LinalgOpType>`with the appropriate `options`.
-template <typename LinalgOpType> struct Tile : public Transformation {
+template <typename LinalgOpType>
+struct Tile : public Transformation {
explicit Tile(linalg::LinalgTilingOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
: Transformation(f), opName(LinalgOpType::getOperationName()),
@@ -79,9 +80,9 @@ template <typename LinalgOpType> struct Tile : public Transformation {
OwningRewritePatternList
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
- OwningRewritePatternList tilingPatterns;
+ OwningRewritePatternList tilingPatterns(context);
sfinae_enqueue<linalg::LinalgTilingPattern, LinalgOpType>(
- tilingPatterns, options, context, opName, m);
+ tilingPatterns, options, opName, m);
return tilingPatterns;
}
@@ -92,7 +93,8 @@ private:
/// Promotion transformation enqueues a particular stage-1 pattern for
/// `Promote<LinalgOpType>`with the appropriate `options`.
-template <typename LinalgOpType> struct Promote : public Transformation {
+template <typename LinalgOpType>
+struct Promote : public Transformation {
explicit Promote(
linalg::LinalgPromotionOptions options,
linalg::LinalgTransformationFilter::FilterFunction f = nullptr)
@@ -106,9 +108,9 @@ template <typename LinalgOpType> struct Promote : public Transformation {
OwningRewritePatternList
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
- OwningRewritePatternList promotionPatterns;
+ OwningRewritePatternList promotionPatterns(context);
sfinae_enqueue<linalg::LinalgPromotionPattern, LinalgOpType>(
- promotionPatterns, options, context, opName, m);
+ promotionPatterns, options, opName, m);
return promotionPatterns;
}
@@ -134,9 +136,9 @@ struct Vectorize : public Transformation {
OwningRewritePatternList
buildRewritePatterns(MLIRContext *context,
linalg::LinalgTransformationFilter m) override {
- OwningRewritePatternList vectorizationPatterns;
+ OwningRewritePatternList vectorizationPatterns(context);
enqueue<linalg::LinalgVectorizationPattern, LinalgOpType>(
- vectorizationPatterns, options, context, opName, m);
+ vectorizationPatterns, options, opName, m);
vectorizationPatterns.insert<linalg::LinalgCopyVTRForwardingPattern,
linalg::LinalgCopyVTWForwardingPattern>(
context, /*benefit=*/2);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6d42838..318db82 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -37,8 +37,7 @@ void populateConvVectorizationPatterns(
ArrayRef<int64_t> tileSizes);
/// Populates the given list with patterns to bufferize linalg ops.
-void populateLinalgBufferizePatterns(MLIRContext *context,
- BufferizeTypeConverter &converter,
+void populateLinalgBufferizePatterns(BufferizeTypeConverter &converter,
OwningRewritePatternList &patterns);
/// Performs standalone tiling of a single LinalgOp by `tileSizes`.
@@ -445,7 +444,7 @@ struct LinalgTilingOptions {
OwningRewritePatternList
getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx);
void populateLinalgTilingCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
/// Base pattern that applied the tiling transformation specified by `options`.
/// Abort and return failure in 2 cases:
@@ -692,11 +691,10 @@ template <
typename = std::enable_if_t<detect_has_get_operation_name<OpType>::value>,
typename = void>
void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
- MLIRContext *context,
linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter f) {
patternList.insert<linalg::LinalgVectorizationPattern>(
- OpType::getOperationName(), context, options, f);
+ OpType::getOperationName(), patternList.getContext(), options, f);
}
/// SFINAE helper for single C++ class without a `getOperationName` method (e.g.
@@ -704,7 +702,6 @@ void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
template <typename OpType, typename = std::enable_if_t<
!detect_has_get_operation_name<OpType>::value>>
void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
- MLIRContext *context,
linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter f) {
patternList.insert<linalg::LinalgVectorizationPattern>(
@@ -714,14 +711,14 @@ void insertVectorizationPatternImpl(OwningRewritePatternList &patternList,
/// Variadic helper function to insert vectorization patterns for C++ ops.
template <typename... OpTypes>
void insertVectorizationPatterns(OwningRewritePatternList &patternList,
- MLIRContext *context,
linalg::LinalgVectorizationOptions options,
linalg::LinalgTransformationFilter f =
linalg::LinalgTransformationFilter()) {
// FIXME: In c++17 this can be simplified by using 'fold expressions'.
- (void)std::initializer_list<int>{0, (insertVectorizationPatternImpl<OpTypes>(
- patternList, context, options, f),
- 0)...};
+ (void)std::initializer_list<int>{
+ 0, (insertVectorizationPatternImpl<OpTypes>(
+ patternList, patternList.getContext(), options, f),
+ 0)...};
}
///
@@ -793,13 +790,13 @@ private:
/// Populates `patterns` with patterns to convert spec-generated named ops to
/// linalg.generic ops.
void populateLinalgNamedOpsGeneralizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
+ OwningRewritePatternList &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter());
/// Populates `patterns` with patterns to convert linalg.conv ops to
/// linalg.generic ops.
void populateLinalgConvGeneralizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
+ OwningRewritePatternList &patterns,
LinalgTransformationFilter filter = LinalgTransformationFilter());
//===----------------------------------------------------------------------===//
@@ -893,7 +890,7 @@ struct AffineMinSCFCanonicalizationPattern
PatternRewriter &rewriter) const override;
};
- /// Helper struct to return the results of `substituteMin`.
+/// Helper struct to return the results of `substituteMin`.
struct AffineMapAndOperands {
AffineMap map;
SmallVector<Value> dims;
@@ -914,8 +911,8 @@ struct AffineMapAndOperands {
/// Return a new AffineMap, dims and symbols that have been canonicalized and
/// simplified.
AffineMapAndOperands substituteMin(
- AffineMinOp affineMinOp,
- llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
+ AffineMinOp affineMinOp,
+ llvm::function_ref<bool(Operation *)> substituteOperation = nullptr);
/// Converts Convolution op into vector contraction.
///
@@ -1060,12 +1057,12 @@ struct SparsificationOptions {
/// Sets up sparsification rewriting rules with the given options.
void populateSparsificationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
+ OwningRewritePatternList &patterns,
const SparsificationOptions &options = SparsificationOptions());
/// Sets up sparsification conversion rules with the given options.
void populateSparsificationConversionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns);
+ OwningRewritePatternList &patterns);
} // namespace linalg
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
index c965bab..3ce88a13 100644
--- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h
@@ -9,18 +9,14 @@
#ifndef MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
#define MLIR_DIALECT_MATH_TRANSFORMS_PASSES_H_
-#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/Bufferize.h"
-
namespace mlir {
class OwningRewritePatternList;
-void populateExpandTanhPattern(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
+void populateExpandTanhPattern(OwningRewritePatternList &patterns);
void populateMathPolynomialApproximationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx);
+ OwningRewritePatternList &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms.h b/mlir/include/mlir/Dialect/SCF/Transforms.h
index 456eb4e..914a1a0 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms.h
@@ -60,8 +60,8 @@ tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes);
/// corresponding scf.yield ops need to update their types accordingly to the
/// TypeConverter, but otherwise don't care what type conversions are happening.
void populateSCFStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target);
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target);
} // namespace scf
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
index 1921dbb..098d4fd 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.h
@@ -24,7 +24,7 @@
namespace mlir {
namespace spirv {
void populateSPIRVGLSLCanonicalizationPatterns(
- mlir::OwningRewritePatternList &results, mlir::MLIRContext *context);
+ mlir::OwningRewritePatternList &results);
} // namespace spirv
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
index 1ac7db1..d7cd76b 100644
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -67,8 +67,7 @@ private:
/// `func` op to the SPIR-V dialect. These patterns do not handle shader
/// interface/ABI; they convert function parameters to be of SPIR-V allowed
/// types.
-void populateBuiltinFuncToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
namespace spirv {
diff --git a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
index 6df1299..9e4b4af 100644
--- a/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Shape/Transforms/Passes.h
@@ -28,8 +28,7 @@ namespace mlir {
std::unique_ptr<Pass> createShapeToShapeLowering();
/// Collects a set of patterns to rewrite ops within the Shape dialect.
-void populateShapeRewritePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+void populateShapeRewritePatterns(OwningRewritePatternList &patterns);
// Collects a set of patterns to replace all constraints with passing witnesses.
// This is intended to then allow all ShapeConstraint related ops and data to
@@ -37,8 +36,7 @@ void populateShapeRewritePatterns(MLIRContext *context,
// canonicalization and dead code elimination.
//
// After this pass, no cstr_ operations exist.
-void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx);
+void populateRemoveShapeConstraintsPatterns(OwningRewritePatternList &patterns);
std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
/// Populates patterns for shape dialect structural type conversions and sets up
@@ -53,8 +51,8 @@ std::unique_ptr<FunctionPass> createRemoveShapeConstraintsPass();
/// do for a structural type conversion is to update both of their types
/// consistently to the new types prescribed by the TypeConverter.
void populateShapeStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target);
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target);
// Bufferizes shape dialect ops.
//
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
index 1a0308d..a7eb59a 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/FuncConversions.h
@@ -25,7 +25,6 @@ class TypeConverter;
/// Add a pattern to the given pattern list to convert the operand and result
/// types of a CallOp with the given type converter.
void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
- MLIRContext *ctx,
TypeConverter &converter);
/// Add a pattern to the given pattern list to rewrite branch operations to use
@@ -33,8 +32,7 @@ void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
/// be done if the branch operation implements the BranchOpInterface. Only
/// needed for partial conversions.
void populateBranchOpInterfaceTypeConversionPattern(
- OwningRewritePatternList &patterns, MLIRContext *ctx,
- TypeConverter &converter);
+ OwningRewritePatternList &patterns, TypeConverter &converter);
/// Return true if op is a BranchOpInterface op whose operands are all legal
/// according to converter.
@@ -44,7 +42,6 @@ bool isLegalForBranchOpInterfaceTypeConversionPattern(Operation *op,
/// Add a pattern to the given pattern list to rewrite `return` ops to use
/// operands that have been legalized by the conversion framework.
void populateReturnOpTypeConversionPattern(OwningRewritePatternList &patterns,
- MLIRContext *ctx,
TypeConverter &converter);
/// For ReturnLike ops (except `return`), return True. If op is a `return` &&
diff --git a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
index a6fdca8..1e04b22 100644
--- a/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/StandardOps/Transforms/Passes.h
@@ -21,8 +21,7 @@ namespace mlir {
class OwningRewritePatternList;
-void populateStdBufferizePatterns(MLIRContext *context,
- BufferizeTypeConverter &typeConverter,
+void populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
/// Creates an instance of std bufferization pass.
@@ -42,8 +41,7 @@ std::unique_ptr<Pass> createTensorConstantBufferizePass();
std::unique_ptr<Pass> createStdExpandOpsPass();
/// Collects a set of patterns to rewrite ops within the Std dialect.
-void populateStdExpandOpsPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns);
+void populateStdExpandOpsPatterns(OwningRewritePatternList &patterns);
//===----------------------------------------------------------------------===//
// Registration
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
index 436b3fc..72539c8 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Passes.h
@@ -16,8 +16,7 @@ namespace mlir {
class OwningRewritePatternList;
-void populateTensorBufferizePatterns(MLIRContext *context,
- BufferizeTypeConverter &typeConverter,
+void populateTensorBufferizePatterns(BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns);
/// Creates an instance of `tensor` dialect bufferization pass.
diff --git a/mlir/include/mlir/Dialect/Vector/VectorOps.h b/mlir/include/mlir/Dialect/Vector/VectorOps.h
index 9e486d0..7d20e64 100644
--- a/mlir/include/mlir/Dialect/Vector/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/VectorOps.h
@@ -39,11 +39,11 @@ struct BitmaskEnumStorage;
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context);
+ OwningRewritePatternList &patterns);
/// Collect a set of vector-to-vector transformation patterns.
void populateVectorToVectorTransformationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context);
+ OwningRewritePatternList &patterns);
/// Collect a set of patterns to split transfer read/write ops.
///
@@ -54,7 +54,7 @@ void populateVectorToVectorTransformationPatterns(
/// of being generic canonicalization patterns. Also one can let the
/// `ignoreFilter` to return true to fail matching for fine-grained control.
void populateSplitVectorTransferPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
+ OwningRewritePatternList &patterns,
std::function<bool(Operation *)> ignoreFilter = nullptr);
/// Collect a set of leading one dimension removal patterns.
@@ -64,15 +64,14 @@ void populateSplitVectorTransferPatterns(
/// With them, there are more chances that we can cancel out extract-insert
/// pairs or forward write-read pairs.
void populateCastAwayVectorLeadingOneDimPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context);
+ OwningRewritePatternList &patterns);
/// Collect a set of patterns that bubble up/down bitcast ops.
///
/// These patterns move vector.bitcast ops to be before insert ops or after
/// extract ops where suitable. With them, bitcast will happen on smaller
/// vectors and there are more chances to share extract/insert ops.
-void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
- MLIRContext *context);
+void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns);
/// Collect a set of vector slices transformation patterns:
/// ExtractSlicesOpLowering, InsertSlicesOpLowering
@@ -82,15 +81,13 @@ void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
/// use for "slices" ops), this lowering removes all tuple related
/// operations as well (through DCE and folding). If tuple values
/// "leak" coming in, however, some tuple related ops will remain.
-void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
- MLIRContext *context);
+void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns);
/// Collect a set of transfer read/write lowering patterns.
///
/// These patterns lower transfer ops to simpler ops like `vector.load`,
/// `vector.store` and `vector.broadcast`.
-void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns,
- MLIRContext *context);
+void populateVectorTransferLoweringPatterns(OwningRewritePatternList &patterns);
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
@@ -174,7 +171,7 @@ struct VectorTransformsOptions {
/// These transformation express higher level vector ops in terms of more
/// elementary extraction, insertion, reduction, product, and broadcast ops.
void populateVectorContractLoweringPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
+ OwningRewritePatternList &patterns,
VectorTransformsOptions vectorTransformOptions = VectorTransformsOptions());
/// Returns the integer type required for subscripts in the vector dialect.
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index c797f53..bc491037 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -255,7 +255,8 @@ public:
PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
/// Returns true if the type of the held value is `T`.
- template <typename T> bool isa() const {
+ template <typename T>
+ bool isa() const {
assert(value && "isa<> used on a null value");
return kind == getKindOf<T>();
}
@@ -271,7 +272,8 @@ public:
/// Cast this value to type `T`, asserts if this value is not an instance of
/// `T`.
- template <typename T> T cast() const {
+ template <typename T>
+ T cast() const {
assert(isa<T>() && "expected value to be of type `T`");
return castImpl<T>();
}
@@ -290,7 +292,8 @@ public:
private:
/// Find the index of a given type in a range of other types.
- template <typename...> struct index_of_t;
+ template <typename...>
+ struct index_of_t;
template <typename T, typename... R>
struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
template <typename T, typename F, typename... R>
@@ -298,7 +301,8 @@ private:
: std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
/// Return the kind used for the given T.
- template <typename T> static Kind getKindOf() {
+ template <typename T>
+ static Kind getKindOf() {
return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
TypeRange, Value, ValueRange>::value);
}
@@ -718,14 +722,19 @@ class OwningRewritePatternList {
using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
public:
- OwningRewritePatternList() = default;
+ OwningRewritePatternList(MLIRContext *context) : context(context) {}
/// Construct a OwningRewritePatternList populated with the given pattern.
- OwningRewritePatternList(std::unique_ptr<RewritePattern> pattern) {
+ OwningRewritePatternList(MLIRContext *context,
+ std::unique_ptr<RewritePattern> pattern)
+ : context(context) {
nativePatterns.emplace_back(std::move(pattern));
}
OwningRewritePatternList(PDLPatternModule &&pattern)
- : pdlPatterns(std::move(pattern)) {}
+ : context(pattern.getModule()->getContext()),
+ pdlPatterns(std::move(pattern)) {}
+
+ MLIRContext *getContext() const { return context; }
/// Return the native patterns held in this list.
NativePatternListT &getNativePatterns() { return nativePatterns; }
@@ -750,7 +759,7 @@ public:
typename... ConstructorArgs,
typename = std::enable_if_t<sizeof...(Ts) != 0>>
OwningRewritePatternList &insert(ConstructorArg &&arg,
- ConstructorArgs &&...args) {
+ ConstructorArgs &&... args) {
// The following expands a call to emplace_back for each of the pattern
// types 'Ts'. This magic is necessary due to a limitation in the places
// that a parameter pack can be expanded in c++11.
@@ -761,7 +770,8 @@ public:
/// Add an instance of each of the pattern types 'Ts'. Return a reference to
/// `this` for chaining insertions.
- template <typename... Ts> OwningRewritePatternList &insert() {
+ template <typename... Ts>
+ OwningRewritePatternList &insert() {
(void)std::initializer_list<int>{0, (insertImpl<Ts>(), 0)...};
return *this;
}
@@ -785,16 +795,17 @@ private:
/// chaining insertions.
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
- insertImpl(Args &&...args) {
+ insertImpl(Args &&... args) {
nativePatterns.emplace_back(
std::make_unique<T>(std::forward<Args>(args)...));
}
template <typename T, typename... Args>
std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
- insertImpl(Args &&...args) {
+ insertImpl(Args &&... args) {
pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
}
+ MLIRContext *const context;
NativePatternListT nativePatterns;
PDLPatternModule pdlPatterns;
};
diff --git a/mlir/include/mlir/Transforms/Bufferize.h b/mlir/include/mlir/Transforms/Bufferize.h
index 29e16c2..9f2c0e3 100644
--- a/mlir/include/mlir/Transforms/Bufferize.h
+++ b/mlir/include/mlir/Transforms/Bufferize.h
@@ -56,8 +56,7 @@ void populateBufferizeMaterializationLegality(ConversionTarget &target);
///
/// In particular, these are the tensor_load/buffer_cast ops.
void populateEliminateBufferizeMaterializationsPatterns(
- MLIRContext *context, BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns);
+ BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns);
} // end namespace mlir
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5cc5d8a..b93fffa 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -425,20 +425,18 @@ private:
/// FunctionLike ops which use FunctionType to represent their type.
void populateFunctionLikeTypeConversionPattern(
StringRef functionLikeOpName, OwningRewritePatternList &patterns,
- MLIRContext *ctx, TypeConverter &converter);
+ TypeConverter &converter);
template <typename FuncOpT>
void populateFunctionLikeTypeConversionPattern(
- OwningRewritePatternList &patterns, MLIRContext *ctx,
- TypeConverter &converter) {
+ OwningRewritePatternList &patterns, TypeConverter &converter) {
populateFunctionLikeTypeConversionPattern(FuncOpT::getOperationName(),
- patterns, ctx, converter);
+ patterns, converter);
}
/// Add a pattern to the given pattern list to convert the signature of a FuncOp
/// with the given type converter.
void populateFuncOpTypeConversionPattern(OwningRewritePatternList &patterns,
- MLIRContext *ctx,
TypeConverter &converter);
//===----------------------------------------------------------------------===//
@@ -604,22 +602,26 @@ public:
/// Register a legality action for the given operation.
void setOpAction(OperationName op, LegalizationAction action);
- template <typename OpT> void setOpAction(LegalizationAction action) {
+ template <typename OpT>
+ void setOpAction(LegalizationAction action) {
setOpAction(OperationName(OpT::getOperationName(), &ctx), action);
}
/// Register the given operations as legal.
- template <typename OpT> void addLegalOp() {
+ template <typename OpT>
+ void addLegalOp() {
setOpAction<OpT>(LegalizationAction::Legal);
}
- template <typename OpT, typename OpT2, typename... OpTs> void addLegalOp() {
+ template <typename OpT, typename OpT2, typename... OpTs>
+ void addLegalOp() {
addLegalOp<OpT>();
addLegalOp<OpT2, OpTs...>();
}
/// Register the given operation as dynamically legal, i.e. requiring custom
/// handling by the target via 'isDynamicallyLegal'.
- template <typename OpT> void addDynamicallyLegalOp() {
+ template <typename OpT>
+ void addDynamicallyLegalOp() {
setOpAction<OpT>(LegalizationAction::Dynamic);
}
template <typename OpT, typename OpT2, typename... OpTs>
@@ -651,10 +653,12 @@ public:
/// Register the given operation as illegal, i.e. this operation is known to
/// not be supported by this target.
- template <typename OpT> void addIllegalOp() {
+ template <typename OpT>
+ void addIllegalOp() {
setOpAction<OpT>(LegalizationAction::Illegal);
}
- template <typename OpT, typename OpT2, typename... OpTs> void addIllegalOp() {
+ template <typename OpT, typename OpT2, typename... OpTs>
+ void addIllegalOp() {
addIllegalOp<OpT>();
addIllegalOp<OpT2, OpTs...>();
}
@@ -692,7 +696,8 @@ public:
SmallVector<StringRef, 2> dialectNames({name, names...});
setDialectAction(dialectNames, LegalizationAction::Legal);
}
- template <typename... Args> void addLegalDialect() {
+ template <typename... Args>
+ void addLegalDialect() {
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
setDialectAction(dialectNames, LegalizationAction::Legal);
}
@@ -736,7 +741,8 @@ public:
SmallVector<StringRef, 2> dialectNames({name, names...});
setDialectAction(dialectNames, LegalizationAction::Illegal);
}
- template <typename... Args> void addIllegalDialect() {
+ template <typename... Args>
+ void addIllegalDialect() {
SmallVector<StringRef, 2> dialectNames({Args::getDialectNamespace()...});
setDialectAction(dialectNames, LegalizationAction::Illegal);
}
diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
index de2e059..4c741d4 100644
--- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
+++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
@@ -747,7 +747,7 @@ public:
} // end namespace
void mlir::populateAffineToStdConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
// clang-format off
patterns.insert<
AffineApplyLowering,
@@ -761,25 +761,25 @@ void mlir::populateAffineToStdConversionPatterns(
AffineStoreLowering,
AffineForLowering,
AffineIfLowering,
- AffineYieldOpLowering>(ctx);
+ AffineYieldOpLowering>(patterns.getContext());
// clang-format on
}
void mlir::populateAffineToVectorConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
// clang-format off
patterns.insert<
AffineVectorLoadLowering,
- AffineVectorStoreLowering>(ctx);
+ AffineVectorStoreLowering>(patterns.getContext());
// clang-format on
}
namespace {
class LowerAffinePass : public ConvertAffineToStandardBase<LowerAffinePass> {
void runOnOperation() override {
- OwningRewritePatternList patterns;
- populateAffineToStdConversionPatterns(patterns, &getContext());
- populateAffineToVectorConversionPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateAffineToStdConversionPatterns(patterns);
+ populateAffineToVectorConversionPatterns(patterns);
ConversionTarget target(getContext());
target.addLegalDialect<memref::MemRefDialect, scf::SCFDialect,
StandardOpsDialect, VectorDialect>();
diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
index 3fe1c7f..23a826a 100644
--- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
+++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp
@@ -875,7 +875,7 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
// Convert async dialect types and operations to LLVM dialect.
AsyncRuntimeTypeConverter converter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
// We use conversion to LLVM type to lower async.runtime load and store
// operations.
@@ -883,8 +883,8 @@ void ConvertAsyncToLLVMPass::runOnOperation() {
llvmConverter.addConversion(AsyncRuntimeTypeConverter::convertAsyncTypes);
// Convert async types in function signatures and function calls.
- populateFuncOpTypeConversionPattern(patterns, ctx, converter);
- populateCallOpTypeConversionPattern(patterns, ctx, converter);
+ populateFuncOpTypeConversionPattern(patterns, converter);
+ populateCallOpTypeConversionPattern(patterns, converter);
// Convert return operations inside async.execute regions.
patterns.insert<ReturnOpOpConversion>(converter, ctx);
@@ -985,8 +985,8 @@ std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertAsyncToLLVMPass() {
}
void mlir::populateAsyncStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target) {
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target) {
typeConverter.addConversion([&](TokenType type) { return type; });
typeConverter.addConversion([&](ValueType type) {
return ValueType::get(typeConverter.convertType(type.getValueType()));
@@ -994,7 +994,7 @@ void mlir::populateAsyncStructuralTypeConversionsAndLegality(
patterns
.insert<ConvertExecuteOpTypes, ConvertAwaitOpTypes, ConvertYieldOpTypes>(
- typeConverter, context);
+ typeConverter, patterns.getContext());
target.addDynamicallyLegalOp<AwaitOp, ExecuteOp, async::YieldOp>(
[&](Operation *op) { return typeConverter.isLegal(op); });
diff --git a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
index 00ab637..71b2fc0 100644
--- a/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
+++ b/mlir/lib/Conversion/ComplexToLLVM/ComplexToLLVM.cpp
@@ -284,7 +284,7 @@ void ConvertComplexToLLVMPass::runOnOperation() {
auto module = getOperation();
// Convert to the LLVM IR dialect using the converter defined above.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
LLVMTypeConverter converter(&getContext());
populateComplexToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index d490c52..dde968c 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -308,13 +308,13 @@ private:
void GpuToLLVMConversionPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
LLVMConversionTarget target(getContext());
populateVectorToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
- populateAsyncStructuralTypeConversionsAndLegality(&getContext(), converter,
- patterns, target);
+ populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
+ target);
converter.addConversion(
[context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
index 9e16712..3a6548b 100644
--- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
+++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
@@ -125,12 +125,13 @@ struct LowerGpuOpsToNVVMOpsPass
return converter.convertType(MemRefType::Builder(type).setMemorySpace(0));
});
- OwningRewritePatternList patterns, llvmPatterns;
+ OwningRewritePatternList patterns(m.getContext());
+ OwningRewritePatternList llvmPatterns(m.getContext());
// Apply in-dialect lowering first. In-dialect lowering will replace ops
// which need to be lowered further, which is not supported by a single
// conversion pass.
- populateGpuRewritePatterns(m.getContext(), patterns);
+ populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
populateStdToLLVMConversionPatterns(converter, llvmPatterns);
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index d61c047..21ae015 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -60,9 +60,10 @@ struct LowerGpuOpsToROCDLOpsPass
/*useAlignedAlloc =*/false};
LLVMTypeConverter converter(m.getContext(), options);
- OwningRewritePatternList patterns, llvmPatterns;
+ OwningRewritePatternList patterns(m.getContext());
+ OwningRewritePatternList llvmPatterns(m.getContext());
- populateGpuRewritePatterns(m.getContext(), patterns);
+ populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(m, std::move(patterns));
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index 1e0a766..2bb1543 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -329,9 +329,9 @@ namespace {
#include "GPUToSPIRV.cpp.inc"
}
-void mlir::populateGPUToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
+ auto *context = patterns.getContext();
populateWithGenerated(context, patterns);
patterns.insert<
GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
index 8edb42e..a8644c8 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRVPass.cpp
@@ -57,9 +57,9 @@ void GPUToSPIRVPass::runOnOperation() {
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns;
- populateGPUToSPIRVPatterns(context, typeConverter, patterns);
- populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+ OwningRewritePatternList patterns(context);
+ populateGPUToSPIRVPatterns(typeConverter, patterns);
+ populateStandardToSPIRVPatterns(typeConverter, patterns);
if (failed(applyFullConversion(kernelModules, *target, std::move(patterns))))
return signalPassFailure();
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
index 5c0eb5e..e49d6b8 100644
--- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
+++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
@@ -221,7 +221,7 @@ void ConvertLinalgToLLVMPass::runOnOperation() {
auto module = getOperation();
// Convert to the LLVM IR dialect using the converter defined above.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
LLVMTypeConverter converter(&getContext());
populateLinalgToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
index 0db760b..052dea4 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -203,8 +203,8 @@ LogicalResult SingleWorkgroupReduction::matchAndRewrite(
// Pattern population
//===----------------------------------------------------------------------===//
-void mlir::populateLinalgToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void mlir::populateLinalgToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<SingleWorkgroupReduction>(typeConverter, context);
+ patterns.insert<SingleWorkgroupReduction>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
index ddcc97d..d9df551 100644
--- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRVPass.cpp
@@ -30,9 +30,9 @@ void LinalgToSPIRVPass::runOnOperation() {
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns;
- populateLinalgToSPIRVPatterns(context, typeConverter, patterns);
- populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+ OwningRewritePatternList patterns(context);
+ populateLinalgToSPIRVPatterns(typeConverter, patterns);
+ populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
// Allow builtin ops.
target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
diff --git a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
index bf947a4..ce4fe8a 100644
--- a/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
+++ b/mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
@@ -192,14 +192,14 @@ mlir::linalg::IndexedGenericOpToLibraryCallRewrite::matchAndRewrite(
/// Populate the given list with patterns that convert from Linalg to Standard.
void mlir::linalg::populateLinalgToStandardConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
// TODO: ConvOp conversion needs to export a descriptor with relevant
// attribute values such as kernel striding and dilation.
// clang-format off
patterns.insert<
CopyOpToLibraryCallRewrite,
CopyTransposeRewrite,
- IndexedGenericOpToLibraryCallRewrite>(ctx);
+ IndexedGenericOpToLibraryCallRewrite>(patterns.getContext());
patterns.insert<LinalgOpToLibraryCallRewrite>();
// clang-format on
}
@@ -218,8 +218,8 @@ void ConvertLinalgToStandardPass::runOnOperation() {
StandardOpsDialect>();
target.addLegalOp<ModuleOp, FuncOp, ModuleTerminatorOp, ReturnOp>();
target.addLegalOp<linalg::ReshapeOp, linalg::RangeOp>();
- OwningRewritePatternList patterns;
- populateLinalgToStandardConversionPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateLinalgToStandardConversionPatterns(patterns);
if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 7bc5100..833d51f 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -58,7 +58,7 @@ void ConvertOpenMPToLLVMPass::runOnOperation() {
auto module = getOperation();
// Convert to OpenMP operations with LLVM IR dialect
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
LLVMTypeConverter converter(&getContext());
populateStdToLLVMConversionPatterns(converter, patterns);
populateOpenMPToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
index 9f5e4ab..b9602dd 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPU.cpp
@@ -642,9 +642,9 @@ ParallelToGpuLaunchLowering::matchAndRewrite(ParallelOp parallelOp,
return success();
}
-void mlir::populateParallelLoopToGPUPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- patterns.insert<ParallelToGpuLaunchLowering>(ctx);
+void mlir::populateParallelLoopToGPUPatterns(
+ OwningRewritePatternList &patterns) {
+ patterns.insert<ParallelToGpuLaunchLowering>(patterns.getContext());
}
void mlir::configureParallelLoopToGPULegality(ConversionTarget &target) {
diff --git a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
index 15075b5..a6ab449 100644
--- a/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
+++ b/mlir/lib/Conversion/SCFToGPU/SCFToGPUPass.cpp
@@ -47,8 +47,8 @@ struct ForLoopMapper : public ConvertAffineForToGPUBase<ForLoopMapper> {
struct ParallelLoopToGpuPass
: public ConvertParallelLoopToGpuBase<ParallelLoopToGpuPass> {
void runOnOperation() override {
- OwningRewritePatternList patterns;
- populateParallelLoopToGPUPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateParallelLoopToGPUPatterns(patterns);
ConversionTarget target(getContext());
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
configureParallelLoopToGPULegality(target);
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 3adb02a..46e67e5 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -90,7 +90,7 @@ static LogicalResult applyPatterns(FuncOp func) {
[](scf::YieldOp op) { return !isa<scf::ParallelOp>(op->getParentOp()); });
target.addLegalDialect<omp::OpenMPDialect>();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(func.getContext());
patterns.insert<ParallelOpLowering>(func.getContext());
FrozenRewritePatternList frozen(std::move(patterns));
return applyPartialConversion(func, target, frozen);
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 19837fe..344af68 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -319,10 +319,9 @@ LogicalResult TerminatorOpConversion::matchAndRewrite(
// Hooks
//===----------------------------------------------------------------------===//
-void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void mlir::populateSCFToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
ScfToSPIRVContext &scfToSPIRVContext,
OwningRewritePatternList &patterns) {
patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
- context, typeConverter, scfToSPIRVContext.getImpl());
+ patterns.getContext(), typeConverter, scfToSPIRVContext.getImpl());
}
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
index b0d8799..024ff2c 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRVPass.cpp
@@ -37,10 +37,10 @@ void SCFToSPIRVPass::runOnOperation() {
SPIRVTypeConverter typeConverter(targetAttr);
ScfToSPIRVContext scfContext;
- OwningRewritePatternList patterns;
- populateSCFToSPIRVPatterns(context, typeConverter, scfContext, patterns);
- populateStandardToSPIRVPatterns(context, typeConverter, patterns);
- populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+ OwningRewritePatternList patterns(context);
+ populateSCFToSPIRVPatterns(typeConverter, scfContext, patterns);
+ populateStandardToSPIRVPatterns(typeConverter, patterns);
+ populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(module, *target, std::move(patterns))))
return signalPassFailure();
diff --git a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
index b8f3140..5250d53 100644
--- a/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
+++ b/mlir/lib/Conversion/SCFToStandard/SCFToStandard.cpp
@@ -569,15 +569,15 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp,
}
void mlir::populateLoopToStdConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
patterns.insert<ForLowering, IfLowering, ParallelLowering, WhileLowering>(
- ctx);
- patterns.insert<DoWhileLowering>(ctx, /*benefit=*/2);
+ patterns.getContext());
+ patterns.insert<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}
void SCFToStandardPass::runOnOperation() {
- OwningRewritePatternList patterns;
- populateLoopToStdConversionPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateLoopToStdConversionPatterns(patterns);
// Configure conversion to lower out scf.for, scf.if, scf.parallel and
// scf.while. Anything else is fine.
ConversionTarget target(getContext());
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
index d152a73..7f3752f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/ConvertLaunchFuncToLLVMCalls.cpp
@@ -278,7 +278,7 @@ public:
/*emitCWrappers=*/true,
/*indexBitwidth=*/kDeriveIndexBitwidthFromDataLayout};
auto *context = module.getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
LLVMTypeConverter typeConverter(context, options);
populateStdToLLVMConversionPatterns(typeConverter, patterns);
patterns.insert<GPULaunchLowering>(typeConverter);
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 3a139b4..6f6d56f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1385,8 +1385,7 @@ void mlir::populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter) {
}
void mlir::populateSPIRVToLLVMConversionPatterns(
- MLIRContext *context, LLVMTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
+ LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
patterns.insert<
// Arithmetic ops
DirectConversionPattern<spirv::IAddOp, LLVM::AddOp>,
@@ -1496,20 +1495,18 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
// Return ops
- ReturnPattern, ReturnValuePattern>(context, typeConverter);
+ ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
}
void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
- MLIRContext *context, LLVMTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
- patterns.insert<FuncConversionPattern>(context, typeConverter);
+ LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+ patterns.insert<FuncConversionPattern>(patterns.getContext(), typeConverter);
}
void mlir::populateSPIRVToLLVMModuleConversionPatterns(
- MLIRContext *context, LLVMTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
+ LLVMTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
- context, typeConverter);
+ patterns.getContext(), typeConverter);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
index 2a4113f..a807b31 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.cpp
@@ -36,15 +36,15 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
// Encode global variable's descriptor set and binding if they exist.
encodeBindAttribute(module);
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
populateSPIRVToLLVMTypeConversion(converter);
- populateSPIRVToLLVMModuleConversionPatterns(context, converter, patterns);
- populateSPIRVToLLVMConversionPatterns(context, converter, patterns);
- populateSPIRVToLLVMFunctionConversionPatterns(context, converter, patterns);
+ populateSPIRVToLLVMModuleConversionPatterns(converter, patterns);
+ populateSPIRVToLLVMConversionPatterns(converter, patterns);
+ populateSPIRVToLLVMFunctionConversionPatterns(converter, patterns);
- ConversionTarget target(getContext());
+ ConversionTarget target(*context);
target.addIllegalDialect<spirv::SPIRVDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
diff --git a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
index af976056..28697ba 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ConvertShapeConstraints.cpp
@@ -37,10 +37,10 @@ public:
} // namespace
void mlir::populateConvertShapeConstraintsConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<CstrBroadcastableToRequire>(ctx);
- patterns.insert<CstrEqToRequire>(ctx);
- patterns.insert<ConvertCstrRequireOp>(ctx);
+ OwningRewritePatternList &patterns) {
+ patterns.insert<CstrBroadcastableToRequire>(patterns.getContext());
+ patterns.insert<CstrEqToRequire>(patterns.getContext());
+ patterns.insert<ConvertCstrRequireOp>(patterns.getContext());
}
namespace {
@@ -54,8 +54,8 @@ class ConvertShapeConstraints
auto func = getOperation();
auto *context = &getContext();
- OwningRewritePatternList patterns;
- populateConvertShapeConstraintsConversionPatterns(patterns, context);
+ OwningRewritePatternList patterns(context);
+ populateConvertShapeConstraintsConversionPatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
return signalPassFailure();
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 2c06702..048e352 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -678,8 +678,8 @@ void ConvertShapeToStandardPass::runOnOperation() {
target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
// Setup conversion patterns.
- OwningRewritePatternList patterns;
- populateShapeToStandardConversionPatterns(patterns, &ctx);
+ OwningRewritePatternList patterns(&ctx);
+ populateShapeToStandardConversionPatterns(patterns);
// Apply conversion.
auto module = getOperation();
@@ -688,9 +688,9 @@ void ConvertShapeToStandardPass::runOnOperation() {
}
void mlir::populateShapeToStandardConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
// clang-format off
- populateWithGenerated(ctx, patterns);
+ populateWithGenerated(patterns.getContext(), patterns);
patterns.insert<
AnyOpConversion,
BinaryOpConversion<AddOp, AddIOp>,
@@ -705,7 +705,7 @@ void mlir::populateShapeToStandardConversionPatterns(
ShapeEqOpConverter,
ShapeOfOpConversion,
SplitAtOpConversion,
- ToExtentTensorOpConversion>(ctx);
+ ToExtentTensorOpConversion>(patterns.getContext());
// clang-format on
}
diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
index 2490f35..63036c4 100644
--- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
@@ -4079,7 +4079,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase<LLVMLoweringPass> {
llvm::DataLayout(this->dataLayout)};
LLVMTypeConverter typeConverter(&getContext(), options);
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);
LLVMConversionTarget target(getContext());
diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
index 00bf6c0..57f1b17 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
@@ -193,11 +193,12 @@ StoreOpOfSubViewFolder<OpTy>::matchAndRewrite(OpTy storeOp,
//===----------------------------------------------------------------------===//
void mlir::populateStdLegalizationPatternsForSPIRVLowering(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
patterns.insert<LoadOpOfSubViewFolder<memref::LoadOp>,
LoadOpOfSubViewFolder<vector::TransferReadOp>,
StoreOpOfSubViewFolder<memref::StoreOp>,
- StoreOpOfSubViewFolder<vector::TransferWriteOp>>(context);
+ StoreOpOfSubViewFolder<vector::TransferWriteOp>>(
+ patterns.getContext());
}
//===----------------------------------------------------------------------===//
@@ -212,9 +213,8 @@ struct SPIRVLegalization final
} // namespace
void SPIRVLegalization::runOnOperation() {
- OwningRewritePatternList patterns;
- auto *context = &getContext();
- populateStdLegalizationPatternsForSPIRVLowering(context, patterns);
+ OwningRewritePatternList patterns(&getContext());
+ populateStdLegalizationPatternsForSPIRVLowering(patterns);
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns));
}
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
index 025029a..8552db4 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -1224,9 +1224,10 @@ XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef<Value> operands,
//===----------------------------------------------------------------------===//
namespace mlir {
-void populateStandardToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
+ MLIRContext *context = patterns.getContext();
+
patterns.insert<
// Math dialect operations.
// TODO: Move to separate pass.
@@ -1293,11 +1294,10 @@ void populateStandardToSPIRVPatterns(MLIRContext *context,
/*benefit=*/2);
}
-void populateTensorToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
int64_t byteCountThreshold,
OwningRewritePatternList &patterns) {
- patterns.insert<TensorExtractPattern>(typeConverter, context,
+ patterns.insert<TensorExtractPattern>(typeConverter, patterns.getContext(),
byteCountThreshold);
}
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
index ce8419b..a1c6f98 100644
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
@@ -35,11 +35,11 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns;
- populateStandardToSPIRVPatterns(context, typeConverter, patterns);
- populateTensorToSPIRVPatterns(context, typeConverter,
+ OwningRewritePatternList patterns(context);
+ populateStandardToSPIRVPatterns(typeConverter, patterns);
+ populateTensorToSPIRVPatterns(typeConverter,
/*byteCountThreshold=*/64, patterns);
- populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
+ populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
if (failed(applyPartialConversion(module, *target, std::move(patterns))))
return signalPassFailure();
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index fc83116..698fb5a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -989,7 +989,7 @@ public:
} // namespace
void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns) {
+ OwningRewritePatternList *patterns) {
patterns->insert<
PointwiseConverter<tosa::AddOp>, PointwiseConverter<tosa::SubOp>,
PointwiseConverter<tosa::MulOp>, PointwiseConverter<tosa::NegateOp>,
@@ -1014,5 +1014,6 @@ void mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
IdentityNConverter<tosa::IdentityNOp>, ReduceConverter<tosa::ReduceMinOp>,
ReduceConverter<tosa::ReduceMaxOp>, ReduceConverter<tosa::ReduceSumOp>,
ReduceConverter<tosa::ReduceProdOp>, ConcatConverter, ReshapeConverter,
- RescaleConverter, ReverseConverter, TransposeConverter>(context);
+ RescaleConverter, ReverseConverter, TransposeConverter>(
+ patterns->getContext());
}
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
index e0f1369..7d6815e 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp
@@ -37,7 +37,7 @@ public:
}
void runOnFunction() override {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
StandardOpsDialect>();
@@ -52,8 +52,7 @@ public:
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
FuncOp func = getFunction();
- mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(
- func.getContext(), &patterns);
+ mlir::tosa::populateTosaToLinalgOnTensorsConversionPatterns(&patterns);
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
index 55ed64b..4fb06d1 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCF.cpp
@@ -103,7 +103,7 @@ public:
} // namespace
void mlir::tosa::populateTosaToSCFConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns) {
- patterns->insert<IfOpConverter>(context);
- patterns->insert<WhileOpConverter>(context);
+ OwningRewritePatternList *patterns) {
+ patterns->insert<IfOpConverter>(patterns->getContext());
+ patterns->insert<WhileOpConverter>(patterns->getContext());
}
diff --git a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
index f403a46..9b562fa 100644
--- a/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
+++ b/mlir/lib/Conversion/TosaToSCF/TosaToSCFPass.cpp
@@ -29,15 +29,14 @@ namespace {
struct TosaToSCF : public TosaToSCFBase<TosaToSCF> {
public:
void runOnOperation() override {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext());
target.addLegalDialect<tensor::TensorDialect, scf::SCFDialect>();
target.addIllegalOp<tosa::IfOp, tosa::WhileOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
auto *op = getOperation();
- mlir::tosa::populateTosaToSCFConversionPatterns(op->getContext(),
- &patterns);
+ mlir::tosa::populateTosaToSCFConversionPatterns(&patterns);
if (failed(applyPartialConversion(op, target, std::move(patterns))))
signalPassFailure();
}
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
index 95f5c51..8db7868 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandard.cpp
@@ -154,12 +154,12 @@ public:
} // namespace
void mlir::tosa::populateTosaToStandardConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns) {
+ OwningRewritePatternList *patterns) {
patterns->insert<ApplyScaleOpConverter, ConstOpConverter, SliceOpConverter>(
- context);
+ patterns->getContext());
}
void mlir::tosa::populateTosaRescaleToStandardConversionPatterns(
- MLIRContext *context, OwningRewritePatternList *patterns) {
- patterns->insert<ApplyScaleOpConverter>(context);
+ OwningRewritePatternList *patterns) {
+ patterns->insert<ApplyScaleOpConverter>(patterns->getContext());
}
diff --git a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
index 14c800e..de8768b 100644
--- a/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
+++ b/mlir/lib/Conversion/TosaToStandard/TosaToStandardPass.cpp
@@ -29,17 +29,16 @@ namespace {
struct TosaToStandard : public TosaToStandardBase<TosaToStandard> {
public:
void runOnOperation() override {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
ConversionTarget target(getContext());
target.addIllegalOp<tosa::ConstOp>();
target.addIllegalOp<tosa::SliceOp>();
target.addIllegalOp<tosa::ApplyScaleOp>();
target.addLegalDialect<StandardOpsDialect>();
- auto *op = getOperation();
- mlir::tosa::populateTosaToStandardConversionPatterns(op->getContext(),
- &patterns);
- if (failed(applyPartialConversion(op, target, std::move(patterns))))
+ mlir::tosa::populateTosaToStandardConversionPatterns(&patterns);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
signalPassFailure();
}
};
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 8565774..b8c43c8 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -61,16 +61,16 @@ void LowerVectorToLLVMPass::runOnOperation() {
// Perform progressive lowering of operations on slices and
// all contraction operations. Also applies folding and DCE.
{
- OwningRewritePatternList patterns;
- populateVectorToVectorCanonicalizationPatterns(patterns, &getContext());
- populateVectorSlicesLoweringPatterns(patterns, &getContext());
- populateVectorContractLoweringPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorSlicesLoweringPatterns(patterns);
+ populateVectorContractLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
// Convert to the LLVM IR dialect.
LLVMTypeConverter converter(&getContext());
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(
converter, patterns, reassociateFPReductions, enableIndexOptimizations);
@@ -98,7 +98,7 @@ void LowerVectorToLLVMPass::runOnOperation() {
return false;
};
// Remove any ArmSVE-specific types from function signatures and results.
- populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
+ populateFuncOpTypeConversionPattern(patterns, converter);
target.addDynamicallyLegalOp<FuncOp>([hasScalableVectorType](FuncOp op) {
return !hasScalableVectorType(op.getType().getInputs()) &&
!hasScalableVectorType(op.getType().getResults());
diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
index 42c0726..4b097c5 100644
--- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
+++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp
@@ -158,7 +158,7 @@ struct LowerVectorToROCDLPass
void LowerVectorToROCDLPass::runOnOperation() {
LLVMTypeConverter converter(&getContext());
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
populateVectorToROCDLConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index dce5b64..3c7c457 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -694,11 +694,11 @@ LogicalResult VectorTransferRewriter<TransferWriteOp>::matchAndRewrite(
}
void populateVectorToSCFConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
+ OwningRewritePatternList &patterns,
const VectorTransferToSCFOptions &options) {
patterns.insert<VectorTransferRewriter<vector::TransferReadOp>,
- VectorTransferRewriter<vector::TransferWriteOp>>(options,
- context);
+ VectorTransferRewriter<vector::TransferWriteOp>>(
+ options, patterns.getContext());
}
} // namespace mlir
@@ -713,10 +713,9 @@ struct ConvertVectorToSCFPass
}
void runOnFunction() override {
- OwningRewritePatternList patterns;
- auto *context = getFunction().getContext();
+ OwningRewritePatternList patterns(getFunction().getContext());
populateVectorToSCFConversionPatterns(
- patterns, context, VectorTransferToSCFOptions().setUnroll(fullUnroll));
+ patterns, VectorTransferToSCFOptions().setUnroll(fullUnroll));
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 8d4fcba..2d8ffc0 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -241,12 +241,12 @@ struct VectorInsertStridedSliceOpConvert final
} // namespace
-void mlir::populateVectorToSPIRVPatterns(MLIRContext *context,
- SPIRVTypeConverter &typeConverter,
+void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
patterns.insert<VectorBitcastConvert, VectorBroadcastConvert,
VectorExtractElementOpConvert, VectorExtractOpConvert,
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert,
VectorInsertElementOpConvert, VectorInsertOpConvert,
- VectorInsertStridedSliceOpConvert>(typeConverter, context);
+ VectorInsertStridedSliceOpConvert>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
index 9a4d09f..b3c63848 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp
@@ -37,8 +37,8 @@ void LowerVectorToSPIRVPass::runOnOperation() {
spirv::SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
- OwningRewritePatternList patterns;
- populateVectorToSPIRVPatterns(context, typeConverter, patterns);
+ OwningRewritePatternList patterns(context);
+ populateVectorToSPIRVPatterns(typeConverter, patterns);
target->addLegalOp<ModuleOp, ModuleTerminatorOp>();
target->addLegalOp<FuncOp>();
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index e3834ea..62cad1f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -227,7 +227,7 @@ void AffineDataCopyGeneration::runOnFunction() {
// Promoting single iteration loops could lead to simplification of
// contained load's/store's, and the latter could anyway also be
// canonicalized.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
AffineLoadOp::getCanonicalizationPatterns(patterns, &getContext());
AffineStoreOp::getCanonicalizationPatterns(patterns, &getContext());
FrozenRewritePatternList frozenPatterns(std::move(patterns));
diff --git a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
index 918fec4..512ecd6 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SimplifyAffineStructures.cpp
@@ -79,7 +79,7 @@ mlir::createSimplifyAffineStructuresPass() {
void SimplifyAffineStructures::runOnFunction() {
auto func = getFunction();
simplifiedAttributes.clear();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(func.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, func.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, func.getContext());
AffineApplyOp::getCanonicalizationPatterns(patterns, func.getContext());
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index acd854d..12d3a73e 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -188,7 +188,7 @@ LogicalResult mlir::hoistAffineIfOp(AffineIfOp ifOp, bool *folded) {
// effective (no unused operands). Since the pattern rewriter's folding is
// entangled with application of patterns, we may fold/end up erasing the op,
// in which case we return with `folded` being set.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ifOp.getContext());
AffineIfOp::getCanonicalizationPatterns(patterns, ifOp.getContext());
bool erased;
FrozenRewritePatternList frozenPatterns(std::move(patterns));
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index f4f6e0b..cb124e3 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -270,7 +270,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
void AsyncParallelForPass::runOnFunction() {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
patterns.insert<AsyncParallelForRewrite>(ctx, numConcurrentAsyncExecute);
if (failed(applyPatternsAndFoldGreedily(getFunction(), std::move(patterns))))
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
index a17da42..99cc0b0 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp
@@ -485,7 +485,7 @@ void AsyncToAsyncRuntimePass::runOnOperation() {
// Lower async operations to async.runtime operations.
MLIRContext *ctx = module->getContext();
- OwningRewritePatternList asyncPatterns;
+ OwningRewritePatternList asyncPatterns(ctx);
// Async lowering does not use type converter because it must preserve all
// types for async.runtime operations.
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 8e9ec0b..3e4189d 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -401,7 +401,6 @@ struct GpuAllReduceConversion : public RewritePattern {
};
} // namespace
-void mlir::populateGpuAllReducePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
- patterns.insert<GpuAllReduceConversion>(context);
+void mlir::populateGpuAllReducePatterns(OwningRewritePatternList &patterns) {
+ patterns.insert<GpuAllReduceConversion>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
index 419226b..df195af 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -323,8 +323,8 @@ struct LinalgBufferizePass : public LinalgBufferizeBase<LinalgBufferizePass> {
target.addDynamicallyLegalDialect<linalg::LinalgDialect>(isLegalOperation);
target.addDynamicallyLegalOp<ConstantOp>(isLegalOperation);
- OwningRewritePatternList patterns;
- populateLinalgBufferizePatterns(&context, typeConverter, patterns);
+ OwningRewritePatternList patterns(&context);
+ populateLinalgBufferizePatterns(typeConverter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
signalPassFailure();
@@ -337,8 +337,7 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgBufferizePass() {
}
void mlir::linalg::populateLinalgBufferizePatterns(
- MLIRContext *context, BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
+ BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
patterns.insert<BufferizeAnyLinalgOp>(typeConverter);
// TODO: Drop this once tensor constants work in standard.
// clang-format off
@@ -347,6 +346,6 @@ void mlir::linalg::populateLinalgBufferizePatterns(
BufferizeInitTensorOp,
SubTensorOpConverter,
SubTensorInsertOpConverter
- >(typeConverter, context);
+ >(typeConverter, patterns.getContext());
// clang-format on
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
index cd7b481..a7e1332 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/CodegenStrategy.cpp
@@ -76,7 +76,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
// Programmatic splitting of slow/fast path vector transfers.
if (lateCodegenStrategyOptions.enableVectorTransferPartialRewrite) {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
patterns.insert<vector::VectorTransferFullPartialRewriter>(
context, vectorTransformsOptions);
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
@@ -84,7 +84,7 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
// Programmatic controlled lowering of vector.contract only.
if (lateCodegenStrategyOptions.enableVectorContractLowering) {
- OwningRewritePatternList vectorContractLoweringPatterns;
+ OwningRewritePatternList vectorContractLoweringPatterns(context);
vectorContractLoweringPatterns
.insert<ContractionOpToOuterProductOpLowering,
ContractionOpToMatmulOpLowering, ContractionOpLowering>(
@@ -95,8 +95,8 @@ void mlir::linalg::CodegenStrategy::transform(FuncOp func) const {
// Programmatic controlled lowering of vector.transfer only.
if (lateCodegenStrategyOptions.enableVectorToSCFConversion) {
- OwningRewritePatternList vectorToLoopsPatterns;
- populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
+ OwningRewritePatternList vectorToLoopsPatterns(context);
+ populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
vectorToSCFOptions);
(void)applyPatternsAndFoldGreedily(func, std::move(vectorToLoopsPatterns));
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 2d34468..cc95218 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -163,7 +163,7 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
void runOnFunction() override {
auto *context = &getContext();
DetensorizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
target.addDynamicallyLegalOp<GenericOp>([&](GenericOp op) {
@@ -199,13 +199,12 @@ struct LinalgDetensorize : public LinalgDetensorizeBase<LinalgDetensorize> {
context, typeConverter);
// Since non-entry block arguments get detensorized, we also need to update
// the control flow inside the function to reflect the correct types.
- populateBranchOpInterfaceTypeConversionPattern(patterns, context,
- typeConverter);
+ populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
if (failed(applyFullConversion(getFunction(), target, std::move(patterns))))
signalPassFailure();
- OwningRewritePatternList canonPatterns;
+ OwningRewritePatternList canonPatterns(context);
canonPatterns.insert<ExtractFromReshapeFromElements>(context);
if (failed(applyPatternsAndFoldGreedily(getFunction(),
std::move(canonPatterns))))
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index c7b7640..a8db840 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -490,14 +490,15 @@ struct FoldReshapeOpWithUnitExtent : OpRewritePattern<TensorReshapeOp> {
/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
void mlir::populateLinalgFoldUnitExtentDimsPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
+ auto *context = patterns.getContext();
patterns
.insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
ReplaceUnitExtentTensors<GenericOp>,
ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
patterns.insert<FoldReshapeOpWithUnitExtent>(context);
- populateFoldUnitDimsReshapeOpsByLinearizationPatterns(context, patterns);
+ populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
}
namespace {
@@ -505,14 +506,14 @@ namespace {
struct LinalgFoldUnitExtentDimsPass
: public LinalgFoldUnitExtentDimsBase<LinalgFoldUnitExtentDimsPass> {
void runOnFunction() override {
- OwningRewritePatternList patterns;
FuncOp funcOp = getFunction();
MLIRContext *context = funcOp.getContext();
+ OwningRewritePatternList patterns(context);
if (foldOneTripLoopsOnly)
patterns.insert<FoldUnitDimLoops<GenericOp>,
FoldUnitDimLoops<IndexedGenericOp>>(context);
else
- populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
+ populateLinalgFoldUnitExtentDimsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns));
}
};
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
index 1d50e06..48677df 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -116,7 +116,7 @@ struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern {
} // namespace
void mlir::populateElementwiseToLinalgConversionPatterns(
- OwningRewritePatternList &patterns, MLIRContext *) {
+ OwningRewritePatternList &patterns) {
patterns.insert<ConvertAnyElementwiseMappableOpOnRankedTensors>();
}
@@ -128,9 +128,9 @@ class ConvertElementwiseToLinalgPass
auto func = getOperation();
auto *context = &getContext();
ConversionTarget target(*context);
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
- populateElementwiseToLinalgConversionPatterns(patterns, context);
+ populateElementwiseToLinalgConversionPatterns(patterns);
target.markUnknownOpDynamicallyLegal([](Operation *op) {
return !isElementwiseMappableOpOnRankedTensors(op);
});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index ad7ad11..a61102d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1112,9 +1112,9 @@ struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
struct FusionOfTensorOpsPass
: public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
void runOnOperation() override {
- OwningRewritePatternList patterns;
Operation *op = getOperation();
- populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
+ OwningRewritePatternList patterns(op->getContext());
+ populateLinalgTensorOpsFusionPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
@@ -1125,9 +1125,9 @@ struct FoldReshapeOpsByLinearizationPass
: public LinalgFoldReshapeOpsByLinearizationBase<
FoldReshapeOpsByLinearizationPass> {
void runOnOperation() override {
- OwningRewritePatternList patterns;
Operation *op = getOperation();
- populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns);
+ OwningRewritePatternList patterns(op->getContext());
+ populateFoldReshapeOpsByLinearizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
}
};
@@ -1135,33 +1135,36 @@ struct FoldReshapeOpsByLinearizationPass
} // namespace
void mlir::populateFoldReshapeOpsByLinearizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, false>,
FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
- FoldConsumerReshapeOpByLinearization<false>>(context);
+ FoldConsumerReshapeOpByLinearization<false>>(
+ patterns.getContext());
}
void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, true>,
FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
- FoldConsumerReshapeOpByLinearization<true>>(context);
+ FoldConsumerReshapeOpByLinearization<true>>(
+ patterns.getContext());
}
void mlir::populateFoldReshapeOpsByExpansionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
patterns.insert<FoldReshapeWithGenericOpByExpansion,
FoldWithProducerReshapeOpByExpansion<GenericOp>,
FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
- context);
+ patterns.getContext());
}
void mlir::populateLinalgTensorOpsFusionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
+ auto *context = patterns.getContext();
patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
FoldSplatConstants<GenericOp>,
FoldSplatConstants<IndexedGenericOp>>(context);
- populateFoldReshapeOpsByExpansionPatterns(context, patterns);
+ populateFoldReshapeOpsByExpansionPatterns(patterns);
GenericOp::getCanonicalizationPatterns(patterns, context);
IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
index 69de55c..3783ef5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
@@ -143,9 +143,9 @@ struct LinalgGeneralizationPass
void LinalgGeneralizationPass::runOnFunction() {
FuncOp func = getFunction();
- OwningRewritePatternList patterns;
- linalg::populateLinalgConvGeneralizationPatterns(&getContext(), patterns);
- linalg::populateLinalgNamedOpsGeneralizationPatterns(&getContext(), patterns);
+ OwningRewritePatternList patterns(&getContext());
+ linalg::populateLinalgConvGeneralizationPatterns(patterns);
+ linalg::populateLinalgNamedOpsGeneralizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(func.getBody(), std::move(patterns));
}
@@ -167,15 +167,16 @@ linalg::GenericOp GeneralizeConvOp::createGenericOp(linalg::ConvOp convOp,
}
void mlir::linalg::populateLinalgConvGeneralizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
+ OwningRewritePatternList &patterns,
linalg::LinalgTransformationFilter marker) {
- patterns.insert<GeneralizeConvOp>(context, marker);
+ patterns.insert<GeneralizeConvOp>(patterns.getContext(), marker);
}
void mlir::linalg::populateLinalgNamedOpsGeneralizationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
+ OwningRewritePatternList &patterns,
linalg::LinalgTransformationFilter marker) {
- patterns.insert<LinalgNamedOpGeneralizationPattern>(context, marker);
+ patterns.insert<LinalgNamedOpGeneralizationPattern>(patterns.getContext(),
+ marker);
}
std::unique_ptr<OperationPass<FuncOp>> mlir::createLinalgGeneralizationPass() {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index cc0cce7..635855f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -378,7 +378,7 @@ void mlir::linalg::hoistRedundantVectorTransfersOnTensor(FuncOp func) {
// Apply canonicalization so the newForOp + yield folds immediately, thus
// cleaning up the IR and potentially enabling more hoisting.
if (changed) {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(func->getContext());
scf::ForOp::getCanonicalizationPatterns(patterns, func->getContext());
(void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index d6423f4..10b4cac 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -545,7 +545,7 @@ template <typename LoopType>
static void lowerLinalgToLoopsImpl(FuncOp funcOp,
ArrayRef<unsigned> interchangeVector) {
MLIRContext *context = funcOp.getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
patterns.insert<LinalgRewritePattern<LoopType>>(interchangeVector);
memref::DimOp::getCanonicalizationPatterns(patterns, context);
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp b/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
index d9c2580..1fc82d5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SparseLowering.cpp
@@ -137,8 +137,8 @@ public:
/// Populates the given patterns list with conversion rules required for
/// the sparsification of linear algebra operations.
void linalg::populateSparsificationConversionPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns) {
+ OwningRewritePatternList &patterns) {
patterns.insert<TensorFromPointerConverter, TensorToDimSizeConverter,
TensorToPointersConverter, TensorToIndicesConverter,
- TensorToValuesConverter>(context);
+ TensorToValuesConverter>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
index a940bd6..c740241 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp
@@ -1361,7 +1361,6 @@ private:
/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
void linalg::populateSparsificationPatterns(
- MLIRContext *context, OwningRewritePatternList &patterns,
- const SparsificationOptions &options) {
- patterns.insert<GenericOpSparsifier>(context, options);
+ OwningRewritePatternList &patterns, const SparsificationOptions &options) {
+ patterns.insert<GenericOpSparsifier>(patterns.getContext(), options);
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index d638c60..3f4c698 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -511,15 +511,15 @@ class CanonicalizationPatternList;
template <>
class CanonicalizationPatternList<> {
public:
- static void insert(OwningRewritePatternList &patterns, MLIRContext *ctx) {}
+ static void insert(OwningRewritePatternList &patterns) {}
};
template <typename OpTy, typename... OpTypes>
class CanonicalizationPatternList<OpTy, OpTypes...> {
public:
- static void insert(OwningRewritePatternList &patterns, MLIRContext *ctx) {
- OpTy::getCanonicalizationPatterns(patterns, ctx);
- CanonicalizationPatternList<OpTypes...>::insert(patterns, ctx);
+ static void insert(OwningRewritePatternList &patterns) {
+ OpTy::getCanonicalizationPatterns(patterns, patterns.getContext());
+ CanonicalizationPatternList<OpTypes...>::insert(patterns);
}
};
@@ -531,32 +531,34 @@ template <>
class RewritePatternList<> {
public:
static void insert(OwningRewritePatternList &patterns,
- const LinalgTilingOptions &options, MLIRContext *ctx) {}
+ const LinalgTilingOptions &options) {}
};
template <typename OpTy, typename... OpTypes>
class RewritePatternList<OpTy, OpTypes...> {
public:
static void insert(OwningRewritePatternList &patterns,
- const LinalgTilingOptions &options, MLIRContext *ctx) {
+ const LinalgTilingOptions &options) {
+ auto *ctx = patterns.getContext();
patterns.insert<LinalgTilingPattern<OpTy>>(
ctx, options,
LinalgTransformationFilter(ArrayRef<Identifier>{},
Identifier::get("tiled", ctx)));
- RewritePatternList<OpTypes...>::insert(patterns, options, ctx);
+ RewritePatternList<OpTypes...>::insert(patterns, options);
}
};
} // namespace
OwningRewritePatternList
mlir::linalg::getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx) {
- OwningRewritePatternList patterns;
- populateLinalgTilingCanonicalizationPatterns(patterns, ctx);
+ OwningRewritePatternList patterns(ctx);
+ populateLinalgTilingCanonicalizationPatterns(patterns);
return patterns;
}
void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
+ auto *ctx = patterns.getContext();
AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
AffineForOp::getCanonicalizationPatterns(patterns, ctx);
AffineMinOp::getCanonicalizationPatterns(patterns, ctx);
@@ -571,17 +573,16 @@ void mlir::linalg::populateLinalgTilingCanonicalizationPatterns(
CanonicalizationPatternList<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
- >::insert(patterns, ctx);
+ >::insert(patterns);
}
/// Populate the given list with patterns that apply Linalg tiling.
static void insertTilingPatterns(OwningRewritePatternList &patterns,
- const LinalgTilingOptions &options,
- MLIRContext *ctx) {
+ const LinalgTilingOptions &options) {
RewritePatternList<GenericOp, IndexedGenericOp,
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
- >::insert(patterns, options, ctx);
+ >::insert(patterns, options);
}
static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
@@ -590,8 +591,8 @@ static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType,
auto options =
LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType);
MLIRContext *ctx = funcOp.getContext();
- OwningRewritePatternList patterns;
- insertTilingPatterns(patterns, options, ctx);
+ OwningRewritePatternList patterns(ctx);
+ insertTilingPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
(void)applyPatternsAndFoldGreedily(
funcOp, getLinalgTilingCanonicalizationPatterns(ctx));
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index dab32d2..b56072c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -580,8 +580,8 @@ static void
populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
OwningRewritePatternList &promotionPatterns,
OwningRewritePatternList &vectorizationPatterns,
- ArrayRef<int64_t> tileSizes,
- MLIRContext *context) {
+ ArrayRef<int64_t> tileSizes) {
+ auto *context = tilingPatterns.getContext();
if (tileSizes.size() < N)
return;
@@ -608,45 +608,47 @@ populateVectorizationPatterns(OwningRewritePatternList &tilingPatterns,
void mlir::linalg::populateConvVectorizationPatterns(
MLIRContext *context, SmallVectorImpl<OwningRewritePatternList> &patterns,
ArrayRef<int64_t> tileSizes) {
- OwningRewritePatternList tiling, promotion, vectorization;
+ OwningRewritePatternList tiling(context);
+ OwningRewritePatternList promotion(context);
+ OwningRewritePatternList vectorization(context);
populateVectorizationPatterns<ConvWOp, 1>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvNWCOp, 3>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvInputNWCFilterWCFOp, 3>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
populateVectorizationPatterns<ConvNCWOp, 3>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvInputNCWFilterWCFOp, 3>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
populateVectorizationPatterns<ConvHWOp, 2>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvNHWCOp, 4>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvInputNHWCFilterHWCFOp, 4>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
populateVectorizationPatterns<ConvNCHWOp, 4>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
populateVectorizationPatterns<ConvInputNCHWFilterHWCFOp, 4>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
populateVectorizationPatterns<ConvDHWOp, 3>(tiling, promotion, vectorization,
- tileSizes, context);
+ tileSizes);
- populateVectorizationPatterns<ConvNDHWCOp, 5>(
- tiling, promotion, vectorization, tileSizes, context);
+ populateVectorizationPatterns<ConvNDHWCOp, 5>(tiling, promotion,
+ vectorization, tileSizes);
populateVectorizationPatterns<ConvInputNDHWCFilterDHWCFOp, 5>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
- populateVectorizationPatterns<ConvNCDHWOp, 5>(
- tiling, promotion, vectorization, tileSizes, context);
+ populateVectorizationPatterns<ConvNCDHWOp, 5>(tiling, promotion,
+ vectorization, tileSizes);
populateVectorizationPatterns<ConvInputNCDHWFilterDHWCFOp, 5>(
- tiling, promotion, vectorization, tileSizes, context);
+ tiling, promotion, vectorization, tileSizes);
patterns.push_back(std::move(tiling));
patterns.push_back(std::move(promotion));
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
index 06d5158..d61dc31 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandTanh.cpp
@@ -60,7 +60,6 @@ public:
};
} // namespace
-void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
- patterns.insert<TanhOpConverter>(ctx);
+void mlir::populateExpandTanhPattern(OwningRewritePatternList &patterns) {
+ patterns.insert<TanhOpConverter>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
index f13e48e..6c5d74f 100644
--- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
@@ -10,6 +10,7 @@
// that do not rely on any of the library functions.
//
//===----------------------------------------------------------------------===//
+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/Dialect/Math/IR/Math.h"
@@ -17,9 +18,10 @@
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/Transforms/Bufferize.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include <limits.h>
+#include <climits>
using namespace mlir;
using namespace mlir::vector;
@@ -530,7 +532,7 @@ ExpApproximation::matchAndRewrite(math::ExpOp op,
//----------------------------------------------------------------------------//
void mlir::populateMathPolynomialApproximationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
+ OwningRewritePatternList &patterns) {
patterns.insert<TanhApproximation, LogApproximation, Log2Approximation,
- ExpApproximation>(ctx);
+ ExpApproximation>(patterns.getContext());
}
diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
index f67020d..44d8be9 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertConst.cpp
@@ -91,7 +91,7 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
}
void ConvertConstPass::runOnFunction() {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
auto func = getFunction();
auto *context = &getContext();
patterns.insert<QuantizedConstRewrite>(context);
diff --git a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
index daa1cda..ac28ce6 100644
--- a/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/ConvertSimQuant.cpp
@@ -124,8 +124,8 @@ public:
void ConvertSimulatedQuantPass::runOnFunction() {
bool hadFailure = false;
- OwningRewritePatternList patterns;
auto func = getFunction();
+ OwningRewritePatternList patterns(func.getContext());
auto ctx = func.getContext();
patterns.insert<ConstFakeQuantRewrite, ConstFakeQuantPerAxisRewrite>(
ctx, &hadFailure);
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index aa25f47..15a5aba 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -25,12 +25,12 @@ struct SCFBufferizePass : public SCFBufferizeBase<SCFBufferizePass> {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
populateBufferizeMaterializationLegality(target);
- populateSCFStructuralTypeConversionsAndLegality(context, typeConverter,
- patterns, target);
+ populateSCFStructuralTypeConversionsAndLegality(typeConverter, patterns,
+ target);
if (failed(applyPartialConversion(func, target, std::move(patterns))))
return signalPassFailure();
};
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 9197375..0029c3b7 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -134,10 +134,10 @@ public:
} // namespace
void mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target) {
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target) {
patterns.insert<ConvertForOpTypes, ConvertIfOpTypes, ConvertYieldOpTypes>(
- typeConverter, context);
+ typeConverter, patterns.getContext());
target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
return typeConverter.isLegal(op->getResultTypes());
});
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
index 0aa41394..c5eeb8a 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVGLSLCanonicalization.cpp
@@ -23,13 +23,14 @@ namespace {
namespace mlir {
namespace spirv {
void populateSPIRVGLSLCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
+ OwningRewritePatternList &results) {
results.insert<ConvertComparisonIntoClampSPV_FOrdLessThanOp,
ConvertComparisonIntoClampSPV_FOrdLessThanEqualOp,
ConvertComparisonIntoClampSPV_SLessThanOp,
ConvertComparisonIntoClampSPV_SLessThanEqualOp,
ConvertComparisonIntoClampSPV_ULessThanOp,
- ConvertComparisonIntoClampSPV_ULessThanEqualOp>(context);
+ ConvertComparisonIntoClampSPV_ULessThanEqualOp>(
+ results.getContext());
}
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
index c4954ca..afaadb0 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp
@@ -74,10 +74,10 @@ public:
};
} // namespace
-static void populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns,
- MLIRContext *ctx) {
+static void
+populateSPIRVLayoutInfoPatterns(OwningRewritePatternList &patterns) {
patterns.insert<SPIRVGlobalVariableOpLayoutInfoDecoration,
- SPIRVAddressOfOpLayoutInfoDecoration>(ctx);
+ SPIRVAddressOfOpLayoutInfoDecoration>(patterns.getContext());
}
namespace {
@@ -90,8 +90,8 @@ class DecorateSPIRVCompositeTypeLayoutPass
void DecorateSPIRVCompositeTypeLayoutPass::runOnOperation() {
auto module = getOperation();
- OwningRewritePatternList patterns;
- populateSPIRVLayoutInfoPatterns(patterns, module.getContext());
+ OwningRewritePatternList patterns(module.getContext());
+ populateSPIRVLayoutInfoPatterns(patterns);
ConversionTarget target(*(module.getContext()));
target.addLegalDialect<spirv::SPIRVDialect>();
target.addLegalOp<FuncOp>();
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
index d96892b..71ebf8c 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp
@@ -246,7 +246,7 @@ void LowerABIAttributesPass::runOnOperation() {
return builder.create<spirv::BitcastOp>(loc, type, inputs[0]).getResult();
});
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
patterns.insert<ProcessInterfaceVarABI>(typeConverter, context);
ConversionTarget target(*context);
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index c544512..4aa8bd4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -515,9 +515,8 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
}
void mlir::populateBuiltinFuncToSPIRVPatterns(
- MLIRContext *context, SPIRVTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
- patterns.insert<FuncOpConversion>(typeConverter, context);
+ SPIRVTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+ patterns.insert<FuncOpConversion>(typeConverter, patterns.getContext());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
index 36b5eac..779993c 100644
--- a/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/Bufferize.cpp
@@ -19,13 +19,13 @@ struct ShapeBufferizePass : public ShapeBufferizeBase<ShapeBufferizePass> {
void runOnFunction() override {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&ctx);
BufferizeTypeConverter typeConverter;
- ConversionTarget target(getContext());
+ ConversionTarget target(ctx);
populateBufferizeMaterializationLegality(target);
- populateShapeStructuralTypeConversionsAndLegality(&ctx, typeConverter,
- patterns, target);
+ populateShapeStructuralTypeConversionsAndLegality(typeConverter, patterns,
+ target);
if (failed(
applyPartialConversion(getFunction(), target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
index 492abce..b712264 100644
--- a/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/RemoveShapeConstraints.cpp
@@ -46,8 +46,8 @@ class RemoveShapeConstraintsPass
void runOnFunction() override {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns;
- populateRemoveShapeConstraintsPatterns(patterns, &ctx);
+ OwningRewritePatternList patterns(&ctx);
+ populateRemoveShapeConstraintsPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
@@ -56,8 +56,9 @@ class RemoveShapeConstraintsPass
} // namespace
void mlir::populateRemoveShapeConstraintsPatterns(
- OwningRewritePatternList &patterns, MLIRContext *ctx) {
- patterns.insert<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(ctx);
+ OwningRewritePatternList &patterns) {
+ patterns.insert<RemoveCstrBroadcastableOp, RemoveCstrEqOp>(
+ patterns.getContext());
}
std::unique_ptr<FunctionPass> mlir::createRemoveShapeConstraintsPass() {
diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
index 6190ff3..479ce71 100644
--- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
@@ -61,8 +61,8 @@ struct ShapeToShapeLowering
void ShapeToShapeLowering::runOnFunction() {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns;
- populateShapeRewritePatterns(&ctx, patterns);
+ OwningRewritePatternList patterns(&ctx);
+ populateShapeRewritePatterns(patterns);
ConversionTarget target(getContext());
target.addLegalDialect<ShapeDialect, StandardOpsDialect>();
@@ -72,9 +72,8 @@ void ShapeToShapeLowering::runOnFunction() {
signalPassFailure();
}
-void mlir::populateShapeRewritePatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
- patterns.insert<NumElementsOpConverter>(context);
+void mlir::populateShapeRewritePatterns(OwningRewritePatternList &patterns) {
+ patterns.insert<NumElementsOpConverter>(patterns.getContext());
}
std::unique_ptr<Pass> mlir::createShapeToShapeLowering() {
diff --git a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
index 041b54b..6ebf9fc 100644
--- a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp
@@ -57,10 +57,10 @@ public:
} // namespace
void mlir::populateShapeStructuralTypeConversionsAndLegality(
- MLIRContext *context, TypeConverter &typeConverter,
- OwningRewritePatternList &patterns, ConversionTarget &target) {
+ TypeConverter &typeConverter, OwningRewritePatternList &patterns,
+ ConversionTarget &target) {
patterns.insert<ConvertAssumingOpTypes, ConvertAssumingYieldOpTypes>(
- typeConverter, context);
+ typeConverter, patterns.getContext());
target.addDynamicallyLegalOp<AssumingOp>([&](AssumingOp op) {
return typeConverter.isLegal(op.getResultTypes());
});
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
index c2b9c93..6eeb39e 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp
@@ -54,10 +54,10 @@ public:
};
} // namespace
-void mlir::populateStdBufferizePatterns(MLIRContext *context,
- BufferizeTypeConverter &typeConverter,
+void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeDimOp, BufferizeSelectOp>(typeConverter, context);
+ patterns.insert<BufferizeDimOp, BufferizeSelectOp>(typeConverter,
+ patterns.getContext());
}
namespace {
@@ -65,14 +65,14 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
void runOnFunction() override {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<memref::MemRefDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<scf::SCFDialect>();
- populateStdBufferizePatterns(context, typeConverter, patterns);
+ populateStdBufferizePatterns(typeConverter, patterns);
// We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to
// touch the data).
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
index 98b261c..3f2504e 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/ExpandOps.cpp
@@ -211,8 +211,8 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
void runOnFunction() override {
MLIRContext &ctx = getContext();
- OwningRewritePatternList patterns;
- populateStdExpandOpsPatterns(&ctx, patterns);
+ OwningRewritePatternList patterns(&ctx);
+ populateStdExpandOpsPatterns(patterns);
ConversionTarget target(getContext());
@@ -234,11 +234,10 @@ struct StdExpandOpsPass : public StdExpandOpsBase<StdExpandOpsPass> {
} // namespace
-void mlir::populateStdExpandOpsPatterns(MLIRContext *context,
- OwningRewritePatternList &patterns) {
+void mlir::populateStdExpandOpsPatterns(OwningRewritePatternList &patterns) {
patterns.insert<AtomicRMWOpConverter, MemRefReshapeOpConverter,
SignedCeilDivIOpConverter, SignedFloorDivIOpConverter>(
- context);
+ patterns.getContext());
}
std::unique_ptr<Pass> mlir::createStdExpandOpsPass() {
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
index d38a564..04424c7 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncBufferize.cpp
@@ -28,21 +28,20 @@ struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
- populateFuncOpTypeConversionPattern(patterns, context, typeConverter);
+ populateFuncOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp op) {
return typeConverter.isSignatureLegal(op.getType()) &&
typeConverter.isLegal(&op.getBody());
});
- populateCallOpTypeConversionPattern(patterns, context, typeConverter);
+ populateCallOpTypeConversionPattern(patterns, typeConverter);
target.addDynamicallyLegalOp<CallOp>(
[&](CallOp op) { return typeConverter.isLegal(op); });
- populateBranchOpInterfaceTypeConversionPattern(patterns, context,
- typeConverter);
- populateReturnOpTypeConversionPattern(patterns, context, typeConverter);
+ populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter);
+ populateReturnOpTypeConversionPattern(patterns, typeConverter);
target.addLegalOp<ModuleOp, ModuleTerminatorOp, memref::TensorLoadOp,
memref::BufferCastOp>();
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
index 4ba2069..4008676 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp
@@ -38,9 +38,8 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
} // end anonymous namespace
void mlir::populateCallOpTypeConversionPattern(
- OwningRewritePatternList &patterns, MLIRContext *ctx,
- TypeConverter &converter) {
- patterns.insert<CallOpSignatureConversion>(converter, ctx);
+ OwningRewritePatternList &patterns, TypeConverter &converter) {
+ patterns.insert<CallOpSignatureConversion>(converter, patterns.getContext());
}
namespace {
@@ -103,9 +102,9 @@ public:
} // end anonymous namespace
void mlir::populateBranchOpInterfaceTypeConversionPattern(
- OwningRewritePatternList &patterns, MLIRContext *ctx,
- TypeConverter &typeConverter) {
- patterns.insert<BranchOpInterfaceTypeConversion>(typeConverter, ctx);
+ OwningRewritePatternList &patterns, TypeConverter &typeConverter) {
+ patterns.insert<BranchOpInterfaceTypeConversion>(typeConverter,
+ patterns.getContext());
}
bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
@@ -125,9 +124,8 @@ bool mlir::isLegalForBranchOpInterfaceTypeConversionPattern(
}
void mlir::populateReturnOpTypeConversionPattern(
- OwningRewritePatternList &patterns, MLIRContext *ctx,
- TypeConverter &typeConverter) {
- patterns.insert<ReturnOpTypeConversion>(typeConverter, ctx);
+ OwningRewritePatternList &patterns, TypeConverter &typeConverter) {
+ patterns.insert<ReturnOpTypeConversion>(typeConverter, patterns.getContext());
}
bool mlir::isLegalForReturnOpTypeConversionPattern(Operation *op,
diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
index 55d3405..625bdc1 100644
--- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
+++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp
@@ -90,7 +90,7 @@ struct TensorConstantBufferizePass
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<memref::MemRefDialect>();
diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
index 1ef742e..4c1d0b7 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp
@@ -138,10 +138,9 @@ public:
} // namespace
void mlir::populateTensorBufferizePatterns(
- MLIRContext *context, BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
+ BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
patterns.insert<BufferizeCastOp, BufferizeExtractOp, BufferizeFromElementsOp,
- BufferizeGenerateOp>(typeConverter, context);
+ BufferizeGenerateOp>(typeConverter, patterns.getContext());
}
namespace {
@@ -149,12 +148,12 @@ struct TensorBufferizePass : public TensorBufferizeBase<TensorBufferizePass> {
void runOnFunction() override {
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
populateBufferizeMaterializationLegality(target);
- populateTensorBufferizePatterns(context, typeConverter, patterns);
+ populateTensorBufferizePatterns(typeConverter, patterns);
target.addIllegalOp<tensor::CastOp, tensor::ExtractOp,
tensor::FromElementsOp, tensor::GenerateOp>();
target.addLegalDialect<memref::MemRefDialect>();
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 540a790..2ab1a64 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -251,7 +251,7 @@ struct TosaMakeBroadcastable
public:
void runOnFunction() override {
auto func = getFunction();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(func.getContext());
MLIRContext *ctx = func.getContext();
// Add the generated patterns to the list.
patterns.insert<ConvertTosaOp<tosa::AddOp>>(ctx);
diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp
index 08bf762..23b194d 100644
--- a/mlir/lib/Dialect/Vector/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/VectorOps.cpp
@@ -3534,11 +3534,11 @@ void CreateMaskOp::getCanonicalizationPatterns(
}
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ OwningRewritePatternList &patterns) {
patterns.insert<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder,
GatherFolder, ScatterFolder, ExpandLoadFolder,
CompressStoreFolder, StridedSliceConstantMaskFolder,
- TransposeFolder>(context);
+ TransposeFolder>(patterns.getContext());
}
#define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Vector/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
index 57602a5..16664b1 100644
--- a/mlir/lib/Dialect/Vector/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/VectorTransforms.cpp
@@ -2784,7 +2784,7 @@ struct TransferReadToVectorLoadLowering
// If broadcasting is required and the number of loaded elements is 1 then
// we can create `memref.load` instead of `vector.load`.
loadOp = rewriter.create<memref::LoadOp>(read.getLoc(), read.source(),
- read.indices());
+ read.indices());
} else {
// Otherwise create `vector.load`.
loadOp = rewriter.create<vector::LoadOp>(read.getLoc(),
@@ -3263,43 +3263,43 @@ struct BubbleUpBitCastForStridedSliceInsert
// TODO: Add pattern to rewrite ExtractSlices(ConstantMaskOp).
// TODO: Add this as DRR pattern.
void mlir::vector::populateVectorToVectorTransformationPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ OwningRewritePatternList &patterns) {
patterns.insert<ShapeCastOpDecomposer, ShapeCastOpFolder, TupleGetFolderOp,
TransferReadExtractPattern, TransferWriteInsertPattern>(
- context);
+ patterns.getContext());
}
void mlir::vector::populateSplitVectorTransferPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
+ OwningRewritePatternList &patterns,
std::function<bool(Operation *)> ignoreFilter) {
- patterns.insert<SplitTransferReadOp, SplitTransferWriteOp>(context,
- ignoreFilter);
+ patterns.insert<SplitTransferReadOp, SplitTransferWriteOp>(
+ patterns.getContext(), ignoreFilter);
}
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ OwningRewritePatternList &patterns) {
patterns.insert<CastAwayExtractStridedSliceLeadingOneDim,
CastAwayInsertStridedSliceLeadingOneDim,
CastAwayTransferReadLeadingOneDim,
CastAwayTransferWriteLeadingOneDim, ShapeCastOpFolder>(
- context);
+ patterns.getContext());
}
void mlir::vector::populateBubbleVectorBitCastOpPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ OwningRewritePatternList &patterns) {
patterns.insert<BubbleDownVectorBitCastForExtract,
BubbleDownBitCastForStridedSliceExtract,
- BubbleUpBitCastForStridedSliceInsert>(context);
+ BubbleUpBitCastForStridedSliceInsert>(patterns.getContext());
}
void mlir::vector::populateVectorSlicesLoweringPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
- patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
+ OwningRewritePatternList &patterns) {
+ patterns.insert<ExtractSlicesOpLowering, InsertSlicesOpLowering>(
+ patterns.getContext());
}
void mlir::vector::populateVectorContractLoweringPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context,
- VectorTransformsOptions parameters) {
+ OwningRewritePatternList &patterns, VectorTransformsOptions parameters) {
// clang-format off
patterns.insert<BroadcastOpLowering,
CreateMaskOpLowering,
@@ -3307,16 +3307,16 @@ void mlir::vector::populateVectorContractLoweringPatterns(
OuterProductOpLowering,
ShapeCastOp2DDownCastRewritePattern,
ShapeCastOp2DUpCastRewritePattern,
- ShapeCastOpRewritePattern>(context);
+ ShapeCastOpRewritePattern>(patterns.getContext());
patterns.insert<TransposeOpLowering,
ContractionOpLowering,
ContractionOpToMatmulOpLowering,
- ContractionOpToOuterProductOpLowering>(parameters, context);
+ ContractionOpToOuterProductOpLowering>(parameters, patterns.getContext());
// clang-format on
}
void mlir::vector::populateVectorTransferLoweringPatterns(
- OwningRewritePatternList &patterns, MLIRContext *context) {
+ OwningRewritePatternList &patterns) {
patterns.insert<TransferReadToVectorLoadLowering,
- TransferWriteToVectorStoreLowering>(context);
+ TransferWriteToVectorStoreLowering>(patterns.getContext());
}
diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp
index 74de861..ba1f566 100644
--- a/mlir/lib/Transforms/Bufferize.cpp
+++ b/mlir/lib/Transforms/Bufferize.cpp
@@ -84,10 +84,9 @@ public:
} // namespace
void mlir::populateEliminateBufferizeMaterializationsPatterns(
- MLIRContext *context, BufferizeTypeConverter &typeConverter,
- OwningRewritePatternList &patterns) {
- patterns.insert<BufferizeTensorLoadOp, BufferizeCastOp>(typeConverter,
- context);
+ BufferizeTypeConverter &typeConverter, OwningRewritePatternList &patterns) {
+ patterns.insert<BufferizeTensorLoadOp, BufferizeCastOp>(
+ typeConverter, patterns.getContext());
}
namespace {
@@ -101,11 +100,10 @@ struct FinalizingBufferizePass
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
ConversionTarget target(*context);
- populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
- patterns);
+ populateEliminateBufferizeMaterializationsPatterns(typeConverter, patterns);
// If all result types are legal, and all block arguments are legal (ensured
// by func conversion above), then all types in the program are legal.
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index cd99681..900d89c 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -25,7 +25,7 @@ struct Canonicalizer : public CanonicalizerBase<Canonicalizer> {
/// Initialize the canonicalizer by building the set of patterns used during
/// execution.
LogicalResult initialize(MLIRContext *context) override {
- OwningRewritePatternList owningPatterns;
+ OwningRewritePatternList owningPatterns(context);
for (auto *op : context->getRegisteredOperations())
op->getCanonicalizationPatterns(owningPatterns, context);
patterns = std::move(owningPatterns);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 5c99c58..113ba46 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -75,7 +75,8 @@ computeConversionSet(iterator_range<Region::iterator> region,
/// A utility function to log a successful result for the given reason.
template <typename... Args>
-static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
+static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt,
+ Args &&... args) {
LLVM_DEBUG({
os.unindent();
os.startLine() << "} -> SUCCESS";
@@ -88,7 +89,8 @@ static void logSuccess(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
/// A utility function to log a failure result for the given reason.
template <typename... Args>
-static void logFailure(llvm::ScopedPrinter &os, StringRef fmt, Args &&...args) {
+static void logFailure(llvm::ScopedPrinter &os, StringRef fmt,
+ Args &&... args) {
LLVM_DEBUG({
os.unindent();
os.startLine() << "} -> FAILURE : "
@@ -2611,15 +2613,14 @@ struct FunctionLikeSignatureConversion : public ConversionPattern {
void mlir::populateFunctionLikeTypeConversionPattern(
StringRef functionLikeOpName, OwningRewritePatternList &patterns,
- MLIRContext *ctx, TypeConverter &converter) {
- patterns.insert<FunctionLikeSignatureConversion>(functionLikeOpName, ctx,
- converter);
+ TypeConverter &converter) {
+ patterns.insert<FunctionLikeSignatureConversion>(
+ functionLikeOpName, patterns.getContext(), converter);
}
void mlir::populateFuncOpTypeConversionPattern(
- OwningRewritePatternList &patterns, MLIRContext *ctx,
- TypeConverter &converter) {
- populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, ctx, converter);
+ OwningRewritePatternList &patterns, TypeConverter &converter) {
+ populateFunctionLikeTypeConversionPattern<FuncOp>(patterns, converter);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/LoopUtils.cpp b/mlir/lib/Transforms/Utils/LoopUtils.cpp
index a9b5979..cd58ec9 100644
--- a/mlir/lib/Transforms/Utils/LoopUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopUtils.cpp
@@ -403,7 +403,7 @@ LogicalResult mlir::affineForOpBodySkew(AffineForOp forOp,
if (res) {
// Simplify/canonicalize the affine.for.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(res.getContext());
AffineForOp::getCanonicalizationPatterns(patterns, res.getContext());
bool erased;
(void)applyOpPatternsAndFold(res, std::move(patterns), &erased);
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index 4808557..b8aa7da 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -110,7 +110,7 @@ void TestAffineDataCopy::runOnFunction() {
// Promoting single iteration loops could lead to simplification of
// generated load's/store's, and the latter could anyway also be
// canonicalized.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
for (auto op : copyOps) {
patterns.clear();
if (isa<AffineLoadOp>(op)) {
diff --git a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
index 99a6022..f66ac8c 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
@@ -139,7 +139,7 @@ void ConvertToTargetEnv::runOnFunction() {
auto target = spirv::SPIRVConversionTarget::get(targetEnv);
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
ConvertToGroupNonUniformBallot, ConvertToModule,
ConvertToSubgroupBallot>(context);
diff --git a/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
index d80f912..75bc52a 100644
--- a/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
+++ b/mlir/test/lib/Dialect/SPIRV/TestGLSLCanonicalization.cpp
@@ -25,8 +25,8 @@ public:
} // namespace
void TestGLSLCanonicalizationPass::runOnOperation() {
- OwningRewritePatternList patterns;
- spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ spirv::populateSPIRVGLSLCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 53651de..8c09406 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -79,7 +79,7 @@ public:
struct TestPatternDriver : public PassWrapper<TestPatternDriver, FunctionPass> {
void runOnFunction() override {
- mlir::OwningRewritePatternList patterns;
+ mlir::OwningRewritePatternList patterns(&getContext());
populateWithGenerated(&getContext(), patterns);
// Verify named pattern is generated with expected name.
@@ -557,7 +557,7 @@ struct TestLegalizePatternDriver
void runOnOperation() override {
TestTypeConverter converter;
- mlir::OwningRewritePatternList patterns;
+ mlir::OwningRewritePatternList patterns(&getContext());
populateWithGenerated(&getContext(), patterns);
patterns.insert<
TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
@@ -568,10 +568,8 @@ struct TestLegalizePatternDriver
TestNonRootReplacement, TestBoundedRecursiveRewrite,
TestNestedOpCreationUndoRewrite>(&getContext());
patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
- mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
- converter);
- mlir::populateCallOpTypeConversionPattern(patterns, &getContext(),
- converter);
+ mlir::populateFuncOpTypeConversionPattern(patterns, converter);
+ mlir::populateCallOpTypeConversionPattern(patterns, converter);
// Define the conversion target used for the test.
ConversionTarget target(getContext());
@@ -700,7 +698,7 @@ struct OneVResOneVOperandOp1Converter
struct TestRemappedValue
: public mlir::PassWrapper<TestRemappedValue, FunctionPass> {
void runOnFunction() override {
- mlir::OwningRewritePatternList patterns;
+ mlir::OwningRewritePatternList patterns(&getContext());
patterns.insert<OneVResOneVOperandOp1Converter>(&getContext());
mlir::ConversionTarget target(getContext());
@@ -742,7 +740,7 @@ struct RemoveTestDialectOps : public RewritePattern {
struct TestUnknownRootOpDriver
: public mlir::PassWrapper<TestUnknownRootOpDriver, FunctionPass> {
void runOnFunction() override {
- mlir::OwningRewritePatternList patterns;
+ mlir::OwningRewritePatternList patterns(&getContext());
patterns.insert<RemoveTestDialectOps>();
mlir::ConversionTarget target(getContext());
@@ -878,12 +876,11 @@ struct TestTypeConversionDriver
});
// Initialize the set of rewrite patterns.
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
patterns.insert<TestTypeConsumerForward, TestTypeConversionProducer,
TestSignatureConversionUndo>(converter, &getContext());
patterns.insert<TestTypeConversionAnotherProducer>(&getContext());
- mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
- converter);
+ mlir::populateFuncOpTypeConversionPattern(patterns, converter);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
@@ -966,8 +963,8 @@ struct TestMergeBlocksPatternDriver
: public PassWrapper<TestMergeBlocksPatternDriver,
OperationPass<ModuleOp>> {
void runOnOperation() override {
- mlir::OwningRewritePatternList patterns;
MLIRContext *context = &getContext();
+ mlir::OwningRewritePatternList patterns(context);
patterns
.insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
context);
@@ -1035,8 +1032,8 @@ struct TestSelectiveReplacementPatternDriver
: public PassWrapper<TestSelectiveReplacementPatternDriver,
OperationPass<>> {
void runOnOperation() override {
- mlir::OwningRewritePatternList patterns;
MLIRContext *context = &getContext();
+ mlir::OwningRewritePatternList patterns(context);
patterns.insert<TestSelectiveOpReplacementPattern>(context);
(void)applyPatternsAndFoldGreedily(getOperation()->getRegions(),
std::move(patterns));
diff --git a/mlir/test/lib/Dialect/Test/TestTraits.cpp b/mlir/test/lib/Dialect/Test/TestTraits.cpp
index 87bd782..e1f151f 100644
--- a/mlir/test/lib/Dialect/Test/TestTraits.cpp
+++ b/mlir/test/lib/Dialect/Test/TestTraits.cpp
@@ -34,7 +34,7 @@ namespace {
struct TestTraitFolder : public PassWrapper<TestTraitFolder, FunctionPass> {
void runOnFunction() override {
(void)applyPatternsAndFoldGreedily(getFunction(),
- OwningRewritePatternList());
+ OwningRewritePatternList(&getContext()));
}
};
} // end anonymous namespace
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index 416bbca..06777ea0 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -183,8 +183,8 @@ struct TosaTestQuantUtilAPI
};
void TosaTestQuantUtilAPI::runOnFunction() {
- OwningRewritePatternList patterns;
auto *ctx = &getContext();
+ OwningRewritePatternList patterns(ctx);
auto func = getFunction();
patterns.insert<ConvertTosaNegateOp>(ctx);
diff --git a/mlir/test/lib/Transforms/TestConvVectorization.cpp b/mlir/test/lib/Transforms/TestConvVectorization.cpp
index cda3542..cd741d0 100644
--- a/mlir/test/lib/Transforms/TestConvVectorization.cpp
+++ b/mlir/test/lib/Transforms/TestConvVectorization.cpp
@@ -91,7 +91,7 @@ void TestConvVectorization::runOnOperation() {
VectorTransformsOptions vectorTransformsOptions{
VectorContractLowering::Dot, VectorTransposeLowering::EltWise};
- OwningRewritePatternList vectorTransferPatterns;
+ OwningRewritePatternList vectorTransferPatterns(context);
// Pattern is not applied because rank-reducing vector transfer is not yet
// supported as can be seen in splitFullAndPartialTransferPrecondition,
// VectorTransforms.cpp
@@ -106,15 +106,15 @@ void TestConvVectorization::runOnOperation() {
llvm_unreachable("Unexpected failure in linalg to loops pass.");
// Programmatic controlled lowering of vector.contract only.
- OwningRewritePatternList vectorContractLoweringPatterns;
+ OwningRewritePatternList vectorContractLoweringPatterns(context);
populateVectorContractLoweringPatterns(vectorContractLoweringPatterns,
- context, vectorTransformsOptions);
+ vectorTransformsOptions);
(void)applyPatternsAndFoldGreedily(module,
std::move(vectorContractLoweringPatterns));
// Programmatic controlled lowering of vector.transfer only.
- OwningRewritePatternList vectorToLoopsPatterns;
- populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context,
+ OwningRewritePatternList vectorToLoopsPatterns(context);
+ populateVectorToSCFConversionPatterns(vectorToLoopsPatterns,
VectorTransferToSCFOptions());
(void)applyPatternsAndFoldGreedily(module, std::move(vectorToLoopsPatterns));
diff --git a/mlir/test/lib/Transforms/TestConvertCallOp.cpp b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
index 2fe29b4..dbe1a31 100644
--- a/mlir/test/lib/Transforms/TestConvertCallOp.cpp
+++ b/mlir/test/lib/Transforms/TestConvertCallOp.cpp
@@ -43,15 +43,15 @@ public:
ModuleOp m = getOperation();
// Populate type conversions.
- LLVMTypeConverter type_converter(m.getContext());
- type_converter.addConversion([&](test::TestType type) {
+ LLVMTypeConverter typeConverter(m.getContext());
+ typeConverter.addConversion([&](test::TestType type) {
return LLVM::LLVMPointerType::get(IntegerType::get(m.getContext(), 8));
});
// Populate patterns.
- OwningRewritePatternList patterns;
- populateStdToLLVMConversionPatterns(type_converter, patterns);
- patterns.insert<TestTypeProducerOpConverter>(type_converter);
+ OwningRewritePatternList patterns(m.getContext());
+ populateStdToLLVMConversionPatterns(typeConverter, patterns);
+ patterns.insert<TestTypeProducerOpConverter>(typeConverter);
// Set target.
ConversionTarget target(getContext());
diff --git a/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
index 2dd2c34..13c01a1 100644
--- a/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
+++ b/mlir/test/lib/Transforms/TestDecomposeCallGraphTypes.cpp
@@ -33,7 +33,7 @@ struct TestDecomposeCallGraphTypes
TypeConverter typeConverter;
ConversionTarget target(*context);
ValueDecomposer decomposer;
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(context);
target.addLegalDialect<test::TestDialect>();
diff --git a/mlir/test/lib/Transforms/TestExpandTanh.cpp b/mlir/test/lib/Transforms/TestExpandTanh.cpp
index e67e89b..dc54a4b 100644
--- a/mlir/test/lib/Transforms/TestExpandTanh.cpp
+++ b/mlir/test/lib/Transforms/TestExpandTanh.cpp
@@ -24,8 +24,8 @@ struct TestExpandTanhPass
} // end anonymous namespace
void TestExpandTanhPass::runOnFunction() {
- OwningRewritePatternList patterns;
- populateExpandTanhPattern(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateExpandTanhPattern(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/lib/Transforms/TestGpuRewrite.cpp b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
index 44ffd38..5f87a9f 100644
--- a/mlir/test/lib/Transforms/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Transforms/TestGpuRewrite.cpp
@@ -25,8 +25,8 @@ struct TestGpuRewritePass
registry.insert<StandardOpsDialect, memref::MemRefDialect>();
}
void runOnOperation() override {
- OwningRewritePatternList patterns;
- populateGpuRewritePatterns(&getContext(), patterns);
+ OwningRewritePatternList patterns(&getContext());
+ populateGpuRewritePatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
diff --git a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
index 1efc565..8cb7702 100644
--- a/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgFusionTransforms.cpp
@@ -109,7 +109,7 @@ struct TestLinalgFusionTransforms
void runOnFunction() override {
MLIRContext *context = &this->getContext();
FuncOp funcOp = this->getFunction();
- OwningRewritePatternList fusionPatterns;
+ OwningRewritePatternList fusionPatterns(context);
Aliases alias;
LinalgDependenceGraph dependenceGraph =
LinalgDependenceGraph::buildDependenceGraph(alias, funcOp);
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
index 6cc390f..8e1cd2d 100644
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -92,7 +92,7 @@ struct TestLinalgTransforms
static void applyPatterns(FuncOp funcOp) {
MLIRContext *ctx = funcOp.getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
//===--------------------------------------------------------------------===//
// Linalg tiling patterns.
@@ -237,21 +237,26 @@ static void fillL1TilingAndMatmulToVectorPatterns(
FuncOp funcOp, StringRef startMarker,
SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
MLIRContext *ctx = funcOp.getContext();
- patternsVector.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
- ctx,
- LinalgTilingOptions().setTileSizes({8, 12, 16}).setInterchange({1, 0, 2}),
- LinalgTransformationFilter(Identifier::get(startMarker, ctx),
- Identifier::get("L1", ctx))));
+ patternsVector.emplace_back(
+ ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
+ ctx,
+ LinalgTilingOptions()
+ .setTileSizes({8, 12, 16})
+ .setInterchange({1, 0, 2}),
+ LinalgTransformationFilter(Identifier::get(startMarker, ctx),
+ Identifier::get("L1", ctx))));
patternsVector.emplace_back(
+ ctx,
std::make_unique<LinalgPromotionPattern<MatmulOp>>(
ctx, LinalgPromotionOptions().setUseFullTileBuffersByDefault(true),
LinalgTransformationFilter(Identifier::get("L1", ctx),
Identifier::get("VEC", ctx))));
- patternsVector.emplace_back(std::make_unique<LinalgVectorizationPattern>(
- MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
- LinalgTransformationFilter(Identifier::get("VEC", ctx))));
+ patternsVector.emplace_back(
+ ctx, std::make_unique<LinalgVectorizationPattern>(
+ MatmulOp::getOperationName(), ctx, LinalgVectorizationOptions(),
+ LinalgTransformationFilter(Identifier::get("VEC", ctx))));
patternsVector.back().insert<LinalgVectorizationPattern>(
LinalgTransformationFilter().addFilter(
[](Operation *op) { return success(isa<FillOp, CopyOp>(op)); }));
@@ -462,13 +467,14 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("START", ctx),
stage1Patterns);
} else if (testMatmulToVectorPatterns2dTiling) {
- stage1Patterns.emplace_back(std::make_unique<LinalgTilingPattern<MatmulOp>>(
- ctx,
- LinalgTilingOptions()
- .setTileSizes({768, 264, 768})
- .setInterchange({1, 2, 0}),
- LinalgTransformationFilter(Identifier::get("START", ctx),
- Identifier::get("L2", ctx))));
+ stage1Patterns.emplace_back(
+ ctx, std::make_unique<LinalgTilingPattern<MatmulOp>>(
+ ctx,
+ LinalgTilingOptions()
+ .setTileSizes({768, 264, 768})
+ .setInterchange({1, 2, 0}),
+ LinalgTransformationFilter(Identifier::get("START", ctx),
+ Identifier::get("L2", ctx))));
fillL1TilingAndMatmulToVectorPatterns(funcOp, Identifier::get("L2", ctx),
stage1Patterns);
}
@@ -481,14 +487,14 @@ applyMatmulToVectorPatterns(FuncOp funcOp,
}
static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
- OwningRewritePatternList forwardPattern;
+ OwningRewritePatternList forwardPattern(funcOp.getContext());
forwardPattern.insert<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
forwardPattern.insert<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
(void)applyPatternsAndFoldGreedily(funcOp, std::move(forwardPattern));
}
static void applyLinalgToVectorPatterns(FuncOp funcOp) {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(funcOp.getContext());
patterns.insert<LinalgVectorizationPattern>(
LinalgTransformationFilter()
.addOpFilter<ContractionOpInterface, FillOp, CopyOp, GenericOp>());
@@ -497,7 +503,7 @@ static void applyLinalgToVectorPatterns(FuncOp funcOp) {
}
static void applyAffineMinSCFCanonicalizationPatterns(FuncOp funcOp) {
- OwningRewritePatternList foldPattern;
+ OwningRewritePatternList foldPattern(funcOp.getContext());
foldPattern.insert<AffineMinSCFCanonicalizationPattern>(funcOp.getContext());
FrozenRewritePatternList frozenPatterns(std::move(foldPattern));
@@ -517,7 +523,7 @@ static Value getNeutralOfLinalgOp(OpBuilder &b, OpOperand &op) {
static void applyTileAndPadPattern(FuncOp funcOp) {
MLIRContext *context = funcOp.getContext();
- OwningRewritePatternList tilingPattern;
+ OwningRewritePatternList tilingPattern(context);
auto linalgTilingOptions =
linalg::LinalgTilingOptions()
.setTileSizes({2, 3, 4})
@@ -539,13 +545,13 @@ void TestLinalgTransforms::runOnFunction() {
std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
if (testPromotionOptions) {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
fillPromotionCallBackPatterns(&getContext(), patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
return;
}
if (testTileAndDistributionOptions) {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
fillTileAndDistributePatterns(&getContext(), patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
return;
diff --git a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
index b4b8ac5..c702301 100644
--- a/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
+++ b/mlir/test/lib/Transforms/TestPolynomialApproximation.cpp
@@ -32,8 +32,8 @@ struct TestMathPolynomialApproximationPass
} // end anonymous namespace
void TestMathPolynomialApproximationPass::runOnFunction() {
- OwningRewritePatternList patterns;
- populateMathPolynomialApproximationPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateMathPolynomialApproximationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
diff --git a/mlir/test/lib/Transforms/TestSparsification.cpp b/mlir/test/lib/Transforms/TestSparsification.cpp
index a76b8664..8c58f6e 100644
--- a/mlir/test/lib/Transforms/TestSparsification.cpp
+++ b/mlir/test/lib/Transforms/TestSparsification.cpp
@@ -101,25 +101,25 @@ struct TestSparsification
/// Runs the test on a function.
void runOnOperation() override {
auto *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
// Translate strategy flags to strategy options.
linalg::SparsificationOptions options(parallelOption(), vectorOption(),
vectorLength, typeOption(ptrType),
typeOption(indType), fastOutput);
// Apply rewriting.
- linalg::populateSparsificationPatterns(ctx, patterns, options);
- vector::populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
+ linalg::populateSparsificationPatterns(patterns, options);
+ vector::populateVectorToVectorCanonicalizationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
// Lower sparse primitives to calls into runtime support library.
if (lower) {
- OwningRewritePatternList conversionPatterns;
+ OwningRewritePatternList conversionPatterns(ctx);
ConversionTarget target(*ctx);
target.addIllegalOp<linalg::SparseTensorFromPointerOp,
linalg::SparseTensorToPointersMemRefOp,
linalg::SparseTensorToIndicesMemRefOp,
linalg::SparseTensorToValuesMemRefOp>();
target.addLegalOp<CallOp>();
- linalg::populateSparsificationConversionPatterns(ctx, conversionPatterns);
+ linalg::populateSparsificationConversionPatterns(conversionPatterns);
if (failed(applyPartialConversion(getOperation(), target,
std::move(conversionPatterns))))
signalPassFailure();
diff --git a/mlir/test/lib/Transforms/TestVectorTransforms.cpp b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
index f11ee13..ac0b099f 100644
--- a/mlir/test/lib/Transforms/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestVectorTransforms.cpp
@@ -36,19 +36,19 @@ struct TestVectorToVectorConversion
llvm::cl::init(false)};
void runOnFunction() override {
- OwningRewritePatternList patterns;
auto *ctx = &getContext();
+ OwningRewritePatternList patterns(ctx);
if (unroll) {
patterns.insert<UnrollVectorPattern>(
ctx,
UnrollVectorOptions().setNativeShapeFn(getShape).setFilterConstraint(
filter));
}
- populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
- populateVectorToVectorTransformationPatterns(patterns, ctx);
- populateBubbleVectorBitCastOpPatterns(patterns, ctx);
- populateCastAwayVectorLeadingOneDimPatterns(patterns, ctx);
- populateSplitVectorTransferPatterns(patterns, ctx);
+ populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorToVectorTransformationPatterns(patterns);
+ populateBubbleVectorBitCastOpPatterns(patterns);
+ populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ populateSplitVectorTransferPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
@@ -70,8 +70,8 @@ private:
struct TestVectorSlicesConversion
: public PassWrapper<TestVectorSlicesConversion, FunctionPass> {
void runOnFunction() override {
- OwningRewritePatternList patterns;
- populateVectorSlicesLoweringPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateVectorSlicesLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
@@ -101,7 +101,7 @@ struct TestVectorContractionConversion
llvm::cl::init(false)};
void runOnFunction() override {
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&getContext());
// Test on one pattern in isolation.
if (lowerToOuterProduct) {
@@ -138,7 +138,7 @@ struct TestVectorContractionConversion
if (lowerToFlatTranspose)
transposeLowering = VectorTransposeLowering::Flat;
VectorTransformsOptions options{contractLowering, transposeLowering};
- populateVectorContractLoweringPatterns(patterns, &getContext(), options);
+ populateVectorContractLoweringPatterns(patterns, options);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
@@ -149,7 +149,7 @@ struct TestVectorUnrollingPatterns
TestVectorUnrollingPatterns(const TestVectorUnrollingPatterns &pass) {}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
patterns.insert<UnrollVectorPattern>(
ctx, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{2, 2})
@@ -185,8 +185,8 @@ struct TestVectorUnrollingPatterns
return success(isa<ContractionOp>(op));
}));
}
- populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
- populateVectorToVectorTransformationPatterns(patterns, ctx);
+ populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
@@ -210,7 +210,7 @@ struct TestVectorDistributePatterns
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
FuncOp func = getFunction();
func.walk([&](AddFOp op) {
OpBuilder builder(op);
@@ -241,7 +241,7 @@ struct TestVectorDistributePatterns
}
});
patterns.insert<PointwiseExtractPattern>(ctx);
- populateVectorToVectorTransformationPatterns(patterns, ctx);
+ populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
@@ -260,7 +260,7 @@ struct TestVectorToLoopPatterns
llvm::cl::init(32)};
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
FuncOp func = getFunction();
func.walk([&](AddFOp op) {
// Check that the operation type can be broken down into a loop.
@@ -301,7 +301,7 @@ struct TestVectorToLoopPatterns
return mlir::WalkResult::interrupt();
});
patterns.insert<PointwiseExtractPattern>(ctx);
- populateVectorToVectorTransformationPatterns(patterns, ctx);
+ populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
@@ -313,7 +313,7 @@ struct TestVectorTransferUnrollingPatterns
}
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
patterns.insert<UnrollVectorPattern>(
ctx,
UnrollVectorOptions()
@@ -322,8 +322,8 @@ struct TestVectorTransferUnrollingPatterns
return success(
isa<vector::TransferReadOp, vector::TransferWriteOp>(op));
}));
- populateVectorToVectorCanonicalizationPatterns(patterns, ctx);
- populateVectorToVectorTransformationPatterns(patterns, ctx);
+ populateVectorToVectorCanonicalizationPatterns(patterns);
+ populateVectorToVectorTransformationPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
@@ -347,7 +347,7 @@ struct TestVectorTransferFullPartialSplitPatterns
llvm::cl::init(false)};
void runOnFunction() override {
MLIRContext *ctx = &getContext();
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(ctx);
VectorTransformsOptions options;
if (useLinalgOps)
options.setVectorTransferSplit(VectorTransferSplit::LinalgCopy);
@@ -369,8 +369,8 @@ struct TestVectorTransferLoweringPatterns
registry.insert<memref::MemRefDialect>();
}
void runOnFunction() override {
- OwningRewritePatternList patterns;
- populateVectorTransferLoweringPatterns(patterns, &getContext());
+ OwningRewritePatternList patterns(&getContext());
+ populateVectorTransferLoweringPatterns(patterns);
(void)applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));
}
};
diff --git a/mlir/unittests/Rewrite/PatternBenefit.cpp b/mlir/unittests/Rewrite/PatternBenefit.cpp
index 721ec5e..ee36c6a 100644
--- a/mlir/unittests/Rewrite/PatternBenefit.cpp
+++ b/mlir/unittests/Rewrite/PatternBenefit.cpp
@@ -52,7 +52,7 @@ TEST(PatternBenefitTest, BenefitOrder) {
bool *called;
};
- OwningRewritePatternList patterns;
+ OwningRewritePatternList patterns(&context);
bool called1 = false;
bool called2 = false;