aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthias Springer <mspringer@nvidia.com>2025-01-11 13:39:06 +0100
committerMatthias Springer <mspringer@nvidia.com>2025-01-31 09:16:59 +0100
commit5a50536cc8e723fe6ed9fe026147045e0b0e10fa (patch)
tree28a5808603f61c1588d2d0887d0e658192997509
parent4435b7d8d3df31d59402b6b106d8d45fd2ba0f93 (diff)
downloadllvm-users/matthias-springer/winter_school_greedy_rewriter.zip
llvm-users/matthias-springer/winter_school_greedy_rewriter.tar.gz
llvm-users/matthias-springer/winter_school_greedy_rewriter.tar.bz2
-rw-r--r--mlir/include/mlir/Dialect/Arith/Transforms/Passes.h7
-rw-r--r--mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td19
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt15
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.cpp297
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.pdll29
-rw-r--r--mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidthConversion.cpp139
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp29
-rw-r--r--mlir/test/Dialect/Arith/WinterSchool/test_10_dialect_conversion.mlir36
-rw-r--r--mlir/test/Dialect/Arith/WinterSchool/test_1_arith_addf.mlir6
-rw-r--r--mlir/test/Dialect/Arith/WinterSchool/test_2_arith_addf.mlir6
-rw-r--r--mlir/test/Dialect/Arith/WinterSchool/test_3_arith_constant.mlir6
-rw-r--r--mlir/test/Dialect/Arith/WinterSchool/test_4_arith_constant.mlir6
-rw-r--r--mlir/test/Dialect/Arith/WinterSchool/test_5_arith_constant_addf.mlir8
-rw-r--r--mlir/test/Dialect/Arith/WinterSchool/test_6_func.mlir7
-rw-r--r--mlir/test/Dialect/Arith/WinterSchool/test_7_arith_truncf.mlir6
-rw-r--r--mlir/test/Dialect/Arith/WinterSchool/test_8_transform.mlir23
-rw-r--r--mlir/test/Dialect/Arith/WinterSchool/test_9_arith_truncf_pdl.mlir22
-rw-r--r--mlir/tools/mlir-opt/mlir-opt.cpp4
19 files changed, 666 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
index 58dce89..e580c49 100644
--- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h
@@ -13,6 +13,7 @@
namespace mlir {
class DataFlowSolver;
+class FloatType;
class ConversionTarget;
class TypeConverter;
@@ -82,6 +83,12 @@ void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
DataFlowSolver &solver,
ArrayRef<unsigned> bitwidthsSupported);
+/// Populate the specified patterns for reducing the bitwidth of FP
+/// computations.
+void populateTestReduceFloatBitwidthPatterns(
+ RewritePatternSet &patterns, ArrayRef<std::string> enabledPatterns,
+ FloatType sourceType, FloatType targetType);
+
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 081bf9b..8cd01bf 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -27,6 +27,25 @@ def TransformAnyParamTypeOrAnyHandle : Type<
"transform any param type or any handle type">;
//===----------------------------------------------------------------------===//
+// Winter School
+//===----------------------------------------------------------------------===//
+
+def ApplyReduceFloatBitwidthPatternsOp : Op<Transform_Dialect,
+ "apply_patterns.arith.reduce_float_bitwidth",
+ [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+ let description = [{
+ Apply patterns to reduce the bidwidth of floating-point computations.
+ }];
+ let arguments = (ins ArrayAttr:$enabled_patterns,
+ TypeAttr:$sourceType,
+ TypeAttr:$targetType);
+ let assemblyFormat = [{
+ $enabled_patterns `from` $sourceType `to` $targetType attr-dict
+ }];
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
index 6149b35..f34089c49 100644
--- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt
@@ -1,3 +1,12 @@
+add_mlir_pdll_library(MLIRArithPDLLPatternsIncGen
+ TestReduceFloatBitwidth.pdll
+ TestReduceFloatBitwidthPatterns.h.inc
+
+ EXTRA_INCLUDES
+ ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test
+ ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test
+ )
+
add_mlir_dialect_library(MLIRArithTransforms
BufferDeallocationOpInterfaceImpl.cpp
BufferizableOpInterfaceImpl.cpp
@@ -8,12 +17,15 @@ add_mlir_dialect_library(MLIRArithTransforms
ExpandOps.cpp
IntRangeOptimizations.cpp
ReifyValueBounds.cpp
+ TestReduceFloatBitwidth.cpp
+ TestReduceFloatBitwidthConversion.cpp
UnsignedWhenEquivalent.cpp
ADDITIONAL_HEADER_DIRS
{$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms
DEPENDS
+ MLIRArithPDLLPatternsIncGen
MLIRArithTransformsIncGen
LINK_LIBS PUBLIC
@@ -27,6 +39,9 @@ add_mlir_dialect_library(MLIRArithTransforms
MLIRIR
MLIRMemRefDialect
MLIRPass
+ MLIRPDLInterpDialect
+ MLIRPDLDialect
+ MLIRSupport
MLIRTensorDialect
MLIRTransforms
MLIRTransformUtils
diff --git a/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.cpp b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.cpp
new file mode 100644
index 0000000..1f2730c
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.cpp
@@ -0,0 +1,297 @@
+//===- TestReduceFloatBitwdith.cpp - Reduce Float Bitwidth -*- c++ -----*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A pass that reduces the bitwidth of Arith floating-point IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
+#include "mlir/Parser/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+
+#include "TestReduceFloatBitwidthPatterns.h.inc"
+
+namespace {
+
+/// Options for rewrite patterns.
+struct ReduceFloatOptions {
+ /// The source float type, who's bit width should be reduced.
+ FloatType sourceType;
+ /// The target float type.
+ FloatType targetType;
+};
+
+/// Pattern for arith.addf.
+class AddFOpPattern : public OpRewritePattern<AddFOp> {
+public:
+ AddFOpPattern(MLIRContext *context, const ReduceFloatOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(AddFOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getType() != options.sourceType)
+ return rewriter.notifyMatchFailure(op, "does not match source type");
+ Value lhsTrunc =
+ rewriter.create<TruncFOp>(op.getLoc(), options.targetType, op.getLhs());
+ Value rhsTrunc =
+ rewriter.create<TruncFOp>(op.getLoc(), options.targetType, op.getRhs());
+ Value newAdd = rewriter.create<AddFOp>(op.getLoc(), lhsTrunc, rhsTrunc);
+ rewriter.replaceOpWithNewOp<ExtFOp>(op, op.getType(), newAdd);
+ return success();
+ }
+
+private:
+ const ReduceFloatOptions options;
+};
+class AddFOpPatternV2 : public OpRewritePattern<AddFOp> {
+public:
+ AddFOpPatternV2(MLIRContext *context, const ReduceFloatOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(AddFOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getType() != options.sourceType)
+ return rewriter.notifyMatchFailure(op, "does not match source type");
+ Value lhsTrunc =
+ rewriter.create<TruncFOp>(op.getLoc(), options.targetType, op.getLhs());
+ Value rhsTrunc =
+ rewriter.create<TruncFOp>(op.getLoc(), options.targetType, op.getRhs());
+ Value newAdd = rewriter.create<AddFOp>(op.getLoc(), lhsTrunc, rhsTrunc);
+ Value replacementValue =
+ rewriter.create<ExtFOp>(op.getLoc(), op.getType(), newAdd);
+ rewriter.replaceAllUsesWith(op.getResult(), replacementValue);
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+private:
+ const ReduceFloatOptions options;
+};
+
+/// Pattern for arith.constant.
+class ConstantOpPattern : public OpRewritePattern<ConstantOp> {
+public:
+ ConstantOpPattern(MLIRContext *context, const ReduceFloatOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(ConstantOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getType() != options.sourceType)
+ return rewriter.notifyMatchFailure(op, "does not match source type");
+ double val = cast<FloatAttr>(op.getValue()).getValueAsDouble();
+ auto newAttr = FloatAttr::get(options.targetType, val);
+ Value newConstant = rewriter.create<ConstantOp>(op.getLoc(), newAttr);
+ rewriter.replaceOpWithNewOp<ExtFOp>(op, op.getType(), newConstant);
+ return success();
+ }
+
+private:
+ const ReduceFloatOptions options;
+};
+
+/// Pattern for func.func.
+class FuncOpPattern : public OpRewritePattern<func::FuncOp> {
+public:
+ FuncOpPattern(MLIRContext *context, const ReduceFloatOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(func::FuncOp op,
+ PatternRewriter &rewriter) const override {
+ if (!llvm::hasSingleElement(op.getBody()))
+ return rewriter.notifyMatchFailure(op, "0 or >1 blocks not supported");
+ FunctionType type = op.getFunctionType();
+ SmallVector<Type> newInputs;
+ for (Type t : type.getInputs()) {
+ if (t == options.sourceType) {
+ newInputs.push_back(options.targetType);
+ } else {
+ newInputs.push_back(t);
+ }
+ }
+ SmallVector<Type> newResults;
+ for (Type t : type.getResults()) {
+ if (t == options.sourceType) {
+ newResults.push_back(options.targetType);
+ } else {
+ newResults.push_back(t);
+ }
+ }
+ if (llvm::equal(type.getInputs(), newInputs) &&
+ llvm::equal(type.getResults(), newResults))
+ return rewriter.notifyMatchFailure(op, "no types to convert");
+ auto newFuncOp = rewriter.create<func::FuncOp>(
+ op.getLoc(), op.getSymName(),
+ FunctionType::get(op.getContext(), newInputs, newResults));
+ SmallVector<Location> locs =
+ llvm::map_to_vector(op.getBody().getArguments(),
+ [](BlockArgument arg) { return arg.getLoc(); });
+ Block *newBlock = rewriter.createBlock(
+ &newFuncOp.getBody(), newFuncOp.getBody().begin(), newInputs, locs);
+ rewriter.setInsertionPointToStart(newBlock);
+ SmallVector<Value> argRepl;
+ for (auto [oldType, newType, newArg] : llvm::zip_equal(
+ type.getInputs(), newInputs, newBlock->getArguments())) {
+ if (oldType == newType) {
+ argRepl.push_back(newArg);
+ } else {
+ argRepl.push_back(
+ rewriter.create<ExtFOp>(newArg.getLoc(), oldType, newArg));
+ }
+ }
+ rewriter.inlineBlockBefore(&op.getBody().front(), newBlock, newBlock->end(),
+ argRepl);
+ rewriter.eraseOp(op);
+ return success();
+ }
+
+private:
+ const ReduceFloatOptions options;
+};
+
+/// Pattern for func.return.
+class ReturnOpPattern : public OpRewritePattern<func::ReturnOp> {
+public:
+ ReturnOpPattern(MLIRContext *context, const ReduceFloatOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit), options(options) {}
+
+ LogicalResult matchAndRewrite(func::ReturnOp op,
+ PatternRewriter &rewriter) const override {
+ bool changedIR = false;
+ SmallVector<Value> newOperands;
+ for (Value val : op.getOperands()) {
+ if (val.getType() != options.sourceType) {
+ newOperands.push_back(val);
+ } else {
+ changedIR = true;
+ newOperands.push_back(
+ rewriter.create<TruncFOp>(val.getLoc(), options.targetType, val));
+ }
+ }
+ if (!changedIR)
+ return rewriter.notifyMatchFailure(op, "no types to convert");
+ rewriter.modifyOpInPlace(
+ op, [&]() { op.getOperandsMutable().assign(newOperands); });
+ return success();
+ }
+
+private:
+ const ReduceFloatOptions options;
+};
+
+/// Pattern that folds arith.truncf(arith.extf(x)) => x.
+class ExtTruncFolding : public OpRewritePattern<TruncFOp> {
+public:
+ ExtTruncFolding(MLIRContext *context, const ReduceFloatOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern(context, benefit) {}
+
+ LogicalResult matchAndRewrite(TruncFOp op,
+ PatternRewriter &rewriter) const override {
+ auto extfOp = op.getIn().getDefiningOp<ExtFOp>();
+ if (!extfOp)
+ return rewriter.notifyMatchFailure(op,
+ "'in' is not defined by arith.extf");
+ if (extfOp.getIn().getType() != op.getType())
+ return rewriter.notifyMatchFailure(op, "types do not match");
+ rewriter.replaceOp(op, extfOp.getIn());
+ return success();
+ }
+};
+
+struct TestReduceFloatBitwidthPass
+ : public PassWrapper<TestReduceFloatBitwidthPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReduceFloatBitwidthPass)
+
+ TestReduceFloatBitwidthPass() = default;
+ TestReduceFloatBitwidthPass(const TestReduceFloatBitwidthPass &pass)
+ : PassWrapper(pass) {}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, func::FuncDialect, pdl::PDLDialect,
+ pdl_interp::PDLInterpDialect>();
+ }
+ StringRef getArgument() const final {
+ return "test-arith-reduce-float-bitwidth";
+ }
+ StringRef getDescription() const final {
+ return "Pass that reduces the bitwidth of floating-point ops";
+ }
+
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+ RewritePatternSet patterns(ctx);
+ populateTestReduceFloatBitwidthPatterns(
+ patterns, optPatterns, FloatType::getF32(ctx), FloatType::getF16(ctx));
+
+ GreedyRewriteConfig config;
+ config.fold = optFold;
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
+ config))) {
+ getOperation()->emitError() << getArgument() << " failed";
+ signalPassFailure();
+ }
+ }
+
+ Option<bool> optFold{*this, "fold", llvm::cl::init(true),
+ llvm::cl::desc("fold ops")};
+ ListOption<std::string> optPatterns{*this, "patterns",
+ llvm::cl::desc("activated patterns")};
+};
+} // namespace
+
+static Attribute convertAttrF32ToF16(PatternRewriter &rewriter,
+ Attribute attr) {
+ auto floatAttr = dyn_cast<FloatAttr>(attr);
+ if (!attr)
+ return Attribute();
+ return rewriter.getFloatAttr(rewriter.getF16Type(),
+ floatAttr.getValueAsDouble());
+}
+
+void arith::populateTestReduceFloatBitwidthPatterns(
+ RewritePatternSet &patterns, ArrayRef<std::string> enabledPatterns,
+ FloatType sourceType, FloatType targetType) {
+ ReduceFloatOptions options{sourceType, targetType};
+ MLIRContext *ctx = patterns.getContext();
+ if (llvm::is_contained(enabledPatterns, "arith.addf"))
+ patterns.insert<AddFOpPattern>(ctx, options);
+ if (llvm::is_contained(enabledPatterns, "arith.addf_v2"))
+ patterns.insert<AddFOpPatternV2>(ctx, options);
+ if (llvm::is_contained(enabledPatterns, "arith.constant"))
+ patterns.insert<ConstantOpPattern>(ctx, options);
+ if (llvm::is_contained(enabledPatterns, "func.func"))
+ patterns.insert<FuncOpPattern>(ctx, options);
+ if (llvm::is_contained(enabledPatterns, "func.return"))
+ patterns.insert<ReturnOpPattern>(ctx, options);
+ if (llvm::is_contained(enabledPatterns, "arith.truncf"))
+ patterns.insert<ExtTruncFolding>(ctx, options);
+ if (llvm::is_contained(enabledPatterns, "pdl_patterns")) {
+ patterns.getPDLPatterns().registerRewriteFunction("ConvertAttrF32ToF16",
+ convertAttrF32ToF16);
+ populateGeneratedPDLLPatterns(patterns);
+ }
+}
+
+namespace mlir {
+void registerTestReduceFloatBitwidthPass() {
+ PassRegistration<TestReduceFloatBitwidthPass>();
+}
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.pdll b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.pdll
new file mode 100644
index 0000000..43e7b7f
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.pdll
@@ -0,0 +1,29 @@
+#include "mlir/Dialect/Arith/IR/ArithOps.td"
+
+Pattern TruncExtPattern {
+ let extf = op<arith.extf>(input: Value<type<"f16">>) -> (type<"f32">);
+ let truncf = op<arith.truncf>(extf) -> (type<"f16">);
+ replace truncf with input;
+}
+
+Pattern AddPattern {
+ let addf = op<arith.addf>(lhs: Value, rhs: Value) -> (type<"f32">);
+ rewrite addf with {
+ let lhs16 = op<arith.truncf>(lhs) -> (type<"f16">);
+ let rhs16 = op<arith.truncf>(rhs) -> (type<"f16">);
+ let addf16 = op<arith.addf>(lhs16, rhs16);
+ replace addf with op<arith.extf>(addf16) -> (type<"f32">);
+ };
+}
+
+Rewrite ConvertAttrF32ToF16(value: Attr) -> Attr;
+
+Pattern ConstantPattern {
+ let attr: Attr;
+ let constant = op<arith.constant> {value = attr} -> (type<"f32">);
+ rewrite constant with {
+ let attr16 = ConvertAttrF32ToF16(attr);
+ let const16 = op<arith.constant>() {value = attr16};
+ replace constant with op<arith.extf>(const16) -> (type<"f32">);
+ };
+}
diff --git a/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidthConversion.cpp b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidthConversion.cpp
new file mode 100644
index 0000000..09890dd
--- /dev/null
+++ b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidthConversion.cpp
@@ -0,0 +1,139 @@
+//===- TestReduceFloatBitwdithConversion.cpp ----------------*- c++ -----*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// A pass that reduces the bitwidth of Arith floating-point IR.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Func/Transforms/FuncConversions.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::arith;
+
+namespace {
+
+/// Pattern for arith.constant.
+class ConstantOpPattern : public OpConversionPattern<ConstantOp> {
+ using OpConversionPattern<ConstantOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ double val = cast<FloatAttr>(op.getValue()).getValueAsDouble();
+ auto newAttr = FloatAttr::get(Float16Type::get(op.getContext()), val);
+ rewriter.replaceOpWithNewOp<ConstantOp>(op, newAttr);
+ return success();
+ }
+};
+
+/// Pattern for arith.addf.
+class AddOpPattern : public OpConversionPattern<AddFOp> {
+ using OpConversionPattern<AddFOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(AddFOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ rewriter.replaceOpWithNewOp<AddFOp>(op, adaptor.getLhs(), adaptor.getRhs());
+ return success();
+ }
+};
+
+struct TestReduceFloatBitwidthConversionPass
+ : public PassWrapper<TestReduceFloatBitwidthConversionPass,
+ OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestReduceFloatBitwidthConversionPass)
+
+ TestReduceFloatBitwidthConversionPass() = default;
+ TestReduceFloatBitwidthConversionPass(
+ const TestReduceFloatBitwidthConversionPass &pass)
+ : PassWrapper(pass) {}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect, func::FuncDialect>();
+ }
+ StringRef getArgument() const final {
+ return "test-arith-reduce-float-bitwidth-conversion";
+ }
+ StringRef getDescription() const final {
+ return "Pass that reduces the bitwidth of floating-point ops (dialect "
+ "conversion)";
+ }
+
+ void runOnOperation() override {
+ MLIRContext *ctx = &getContext();
+
+ TypeConverter converter;
+ ConversionConfig config;
+ converter.addConversion([](Type type) { return type; });
+ converter.addConversion(
+ [&](Float32Type type) { return FloatType::getF16(ctx); });
+ if (optBuildMaterializations) {
+ converter.addSourceMaterialization(
+ [](OpBuilder &builder, FloatType resultType, ValueRange inputs,
+ Location loc) -> Value {
+ assert(inputs.size() == 1 && "expected single input");
+ return builder.create<ExtFOp>(loc, resultType, inputs[0]);
+ });
+ converter.addTargetMaterialization(
+ [](OpBuilder &builder, FloatType resultType, ValueRange inputs,
+ Location loc) -> Value {
+ assert(inputs.size() == 1 && "expected single input");
+ return builder.create<TruncFOp>(loc, resultType, inputs[0]);
+ });
+ config.buildMaterializations = true;
+ } else {
+ config.buildMaterializations = false;
+ }
+
+ RewritePatternSet patterns(ctx);
+ patterns.insert<ConstantOpPattern, AddOpPattern>(converter, ctx);
+ // Pattern for func.func.
+ populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
+ converter);
+ populateReturnOpTypeConversionPattern(patterns, converter);
+
+ ConversionTarget target(*ctx);
+ target.addDynamicallyLegalOp<ConstantOp, AddFOp, func::ReturnOp>(
+ [&](Operation *op) { return converter.isLegal(op); });
+ target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
+ return converter.isSignatureLegal(op.getFunctionType());
+ });
+
+ LogicalResult status = failure();
+ if (optFullConversion) {
+ status = applyFullConversion(getOperation(), target, std::move(patterns),
+ config);
+ } else {
+ status = applyPartialConversion(getOperation(), target,
+ std::move(patterns), config);
+ }
+ if (failed(status)) {
+ getOperation()->emitError() << getArgument() << " failed";
+ signalPassFailure();
+ }
+ }
+
+ Option<bool> optBuildMaterializations{
+ *this, "build-materializations", llvm::cl::init(false),
+ llvm::cl::desc("build materializations")};
+ Option<bool> optFullConversion{
+ *this, "full-conversion", llvm::cl::init(false),
+ llvm::cl::desc("full conversion (otherwise: partial)")};
+};
+} // namespace
+
+namespace mlir {
+void registerTestReduceFloatBitwidthConversionPass() {
+ PassRegistration<TestReduceFloatBitwidthConversionPass>();
+}
+} // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
index b15a470..b740309 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRLinalgTransformOps
LINK_LIBS PUBLIC
MLIRAffineDialect
MLIRArithDialect
+ MLIRArithTransforms
MLIRBufferizationDialect
MLIRBufferizationTransforms
MLIRFuncDialect
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index a1d619c..c98ce23 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
@@ -220,6 +221,34 @@ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults(
}
//===----------------------------------------------------------------------===//
+// Winter School
+//===----------------------------------------------------------------------===//
+
+void transform::ApplyReduceFloatBitwidthPatternsOp::populatePatterns(
+ RewritePatternSet &patterns) {
+ SmallVector<std::string> enabledPatternsStr;
+ for (Attribute attr : getEnabledPatterns()) {
+ enabledPatternsStr.push_back(cast<StringAttr>(attr).getValue().str());
+ }
+ FloatType sourceType = cast<FloatType>(getSourceType());
+ FloatType targetType = cast<FloatType>(getTargetType());
+ arith::populateTestReduceFloatBitwidthPatterns(patterns, enabledPatternsStr,
+ sourceType, targetType);
+}
+
+LogicalResult transform::ApplyReduceFloatBitwidthPatternsOp::verify() {
+ for (Attribute attr : getEnabledPatterns())
+ if (!isa<StringAttr>(attr))
+ return emitOpError(
+ "expected 'enabled_patterns' to be an array of string attributes");
+ if (!isa<FloatType>(getSourceType()))
+ return emitOpError("expected float source type");
+ if (!isa<FloatType>(getTargetType()))
+ return emitOpError("expected float target type");
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// Apply...PatternsOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_10_dialect_conversion.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_10_dialect_conversion.mlir
new file mode 100644
index 0000000..f032c88
--- /dev/null
+++ b/mlir/test/Dialect/Arith/WinterSchool/test_10_dialect_conversion.mlir
@@ -0,0 +1,36 @@
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth-conversion="build-materializations=0" -split-input-file
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth-conversion="build-materializations=1" -split-input-file
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth-conversion="build-materializations=1 full-conversion=1" -split-input-file
+
+func.func @test_constant() -> f32 {
+ %0 = arith.constant 2.0 : f32
+ return %0 : f32
+}
+
+// -----
+
+func.func @test_add(%arg0: f32, %arg1: f32) -> f32 {
+ %0 = arith.addf %arg0, %arg1 : f32
+ return %0 : f32
+}
+
+// -----
+
+func.func @test_add_constant(%arg0: f32) -> f32 {
+ %0 = arith.constant 2.0 : f32
+ %1 = arith.addf %arg0, %0 : f32
+ return %1 : f32
+}
+
+// -----
+
+func.func @test_func(%arg0: f32) -> f32 {
+ return %arg0 : f32
+}
+
+// -----
+
+func.func @test_boundary(%arg0: f32) -> f32 {
+ %0 = "test.consumer_producer"(%arg0) : (f32) -> (f32)
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_1_arith_addf.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_1_arith_addf.mlir
new file mode 100644
index 0000000..cd5b216
--- /dev/null
+++ b/mlir/test/Dialect/Arith/WinterSchool/test_1_arith_addf.mlir
@@ -0,0 +1,6 @@
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.addf"
+
+func.func @test_add(%arg0: f32, %arg1: f32) {
+ %0 = arith.addf %arg0, %arg1 : f32
+ return
+}
diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_2_arith_addf.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_2_arith_addf.mlir
new file mode 100644
index 0000000..85f9c57
--- /dev/null
+++ b/mlir/test/Dialect/Arith/WinterSchool/test_2_arith_addf.mlir
@@ -0,0 +1,6 @@
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.addf"
+
+func.func @test_add(%arg0: f32, %arg1: f32) -> f32 {
+ %0 = arith.addf %arg0, %arg1 : f32
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_3_arith_constant.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_3_arith_constant.mlir
new file mode 100644
index 0000000..8c8f933
--- /dev/null
+++ b/mlir/test/Dialect/Arith/WinterSchool/test_3_arith_constant.mlir
@@ -0,0 +1,6 @@
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.constant"
+
+func.func @test_constant() -> f32 {
+ %0 = arith.constant 2.0 : f32
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_4_arith_constant.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_4_arith_constant.mlir
new file mode 100644
index 0000000..5ef8434
--- /dev/null
+++ b/mlir/test/Dialect/Arith/WinterSchool/test_4_arith_constant.mlir
@@ -0,0 +1,6 @@
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.constant fold=false"
+
+func.func @test_constant() -> f32 {
+ %0 = arith.constant 2.0 : f32
+ return %0 : f32
+}
diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_5_arith_constant_addf.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_5_arith_constant_addf.mlir
new file mode 100644
index 0000000..b50a056
--- /dev/null
+++ b/mlir/test/Dialect/Arith/WinterSchool/test_5_arith_constant_addf.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.constant,arith.addf"
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.addf"
+
+func.func @test_add_constant(%arg0: f32) -> f32 {
+ %0 = arith.constant 2.0 : f32
+ %1 = arith.addf %arg0, %0 : f32
+ return %1 : f32
+}
diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_6_func.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_6_func.mlir
new file mode 100644
index 0000000..1b4ae9b
--- /dev/null
+++ b/mlir/test/Dialect/Arith/WinterSchool/test_6_func.mlir
@@ -0,0 +1,7 @@
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=func.return"
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=func.func"
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=func.func,func.return"
+
+func.func @test_func(%arg0: f32) -> f32 {
+ return %arg0 : f32
+}
diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_7_arith_truncf.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_7_arith_truncf.mlir
new file mode 100644
index 0000000..162f676
--- /dev/null
+++ b/mlir/test/Dialect/Arith/WinterSchool/test_7_arith_truncf.mlir
@@ -0,0 +1,6 @@
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=func.func,func.return,arith.truncf"
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=func.func,func.return,arith.truncf" -canonicalize
+
+func.func @test_func(%arg0: f32) -> f32 {
+ return %arg0 : f32
+}
diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_8_transform.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_8_transform.mlir
new file mode 100644
index 0000000..369ce48
--- /dev/null
+++ b/mlir/test/Dialect/Arith/WinterSchool/test_8_transform.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.addf"
+
+func.func @test_add(%arg0: f32, %arg1: f32) -> f32 {
+ %0 = arith.addf %arg0, %arg1 : f32
+ return %0 : f32
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func">
+ %add_op = transform.structured.match ops{["arith.addf"]} in %func_op : (!transform.op<"func.func">) -> !transform.op<"arith.addf">
+ transform.debug.emit_remark_at %add_op, "before pattern application" : !transform.op<"arith.addf">
+
+ transform.apply_patterns to %func_op {
+ transform.apply_patterns.arith.reduce_float_bitwidth ["arith.addf"] from f32 to f16
+ // transform.apply_patterns.arith.reduce_float_bitwidth ["func.func", "func.return", "arith.addf"] from f32 to f16
+ // transform.apply_patterns.arith.reduce_float_bitwidth ["arith.addf_v2"] from f32 to f16
+ } : !transform.op<"func.func">
+
+ transform.debug.emit_remark_at %add_op, "after pattern application" : !transform.op<"arith.addf">
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_9_arith_truncf_pdl.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_9_arith_truncf_pdl.mlir
new file mode 100644
index 0000000..68268a3
--- /dev/null
+++ b/mlir/test/Dialect/Arith/WinterSchool/test_9_arith_truncf_pdl.mlir
@@ -0,0 +1,22 @@
+// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=pdl_patterns" -split-input-file
+
+func.func @test_trunc_ext(%arg0: f16) -> f16 {
+ %0 = arith.extf %arg0 : f16 to f32
+ %1 = arith.truncf %0 : f32 to f16
+ return %1 : f16
+}
+
+// -----
+
+func.func @test_add(%arg0: f32, %arg1: f32) -> f32 {
+ %0 = arith.addf %arg0, %arg1 : f32
+ return %0 : f32
+}
+
+// -----
+
+func.func @test_add_constant(%arg0: f32) -> f32 {
+ %0 = arith.constant 2.0 : f32
+ %1 = arith.addf %arg0, %0 : f32
+ return %1 : f32
+}
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 960f703..551639b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -60,6 +60,8 @@ void registerTestPreserveUseListOrders();
void registerTestPrintDefUsePass();
void registerTestPrintInvalidPass();
void registerTestPrintNestingPass();
+void registerTestReduceFloatBitwidthPass();
+void registerTestReduceFloatBitwidthConversionPass();
void registerTestReducer();
void registerTestSpirvEntryPointABIPass();
void registerTestSpirvModuleCombinerPass();
@@ -200,6 +202,8 @@ void registerTestPasses() {
registerTestPrintDefUsePass();
registerTestPrintInvalidPass();
registerTestPrintNestingPass();
+ registerTestReduceFloatBitwidthPass();
+ registerTestReduceFloatBitwidthConversionPass();
registerTestReducer();
registerTestSpirvEntryPointABIPass();
registerTestSpirvModuleCombinerPass();