diff options
author | Matthias Springer <mspringer@nvidia.com> | 2025-01-11 13:39:06 +0100 |
---|---|---|
committer | Matthias Springer <mspringer@nvidia.com> | 2025-01-31 09:16:59 +0100 |
commit | 5a50536cc8e723fe6ed9fe026147045e0b0e10fa (patch) | |
tree | 28a5808603f61c1588d2d0887d0e658192997509 | |
parent | 4435b7d8d3df31d59402b6b106d8d45fd2ba0f93 (diff) | |
download | llvm-users/matthias-springer/winter_school_greedy_rewriter.zip llvm-users/matthias-springer/winter_school_greedy_rewriter.tar.gz llvm-users/matthias-springer/winter_school_greedy_rewriter.tar.bz2 |
Greedy pattern rewriter exampleusers/matthias-springer/winter_school_greedy_rewriter
19 files changed, 666 insertions, 0 deletions
diff --git a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h index 58dce89..e580c49 100644 --- a/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Arith/Transforms/Passes.h @@ -13,6 +13,7 @@ namespace mlir { class DataFlowSolver; +class FloatType; class ConversionTarget; class TypeConverter; @@ -82,6 +83,12 @@ void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns, DataFlowSolver &solver, ArrayRef<unsigned> bitwidthsSupported); +/// Populate the specified patterns for reducing the bitwidth of FP +/// computations. +void populateTestReduceFloatBitwidthPatterns( + RewritePatternSet &patterns, ArrayRef<std::string> enabledPatterns, + FloatType sourceType, FloatType targetType); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 081bf9b..8cd01bf 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -27,6 +27,25 @@ def TransformAnyParamTypeOrAnyHandle : Type< "transform any param type or any handle type">; //===----------------------------------------------------------------------===// +// Winter School +//===----------------------------------------------------------------------===// + +def ApplyReduceFloatBitwidthPatternsOp : Op<Transform_Dialect, + "apply_patterns.arith.reduce_float_bitwidth", + [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> { + let description = [{ + Apply patterns to reduce the bidwidth of floating-point computations. + }]; + let arguments = (ins ArrayAttr:$enabled_patterns, + TypeAttr:$sourceType, + TypeAttr:$targetType); + let assemblyFormat = [{ + $enabled_patterns `from` $sourceType `to` $targetType attr-dict + }]; + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// // Apply...PatternsOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt index 6149b35..f34089c49 100644 --- a/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Arith/Transforms/CMakeLists.txt @@ -1,3 +1,12 @@ +add_mlir_pdll_library(MLIRArithPDLLPatternsIncGen + TestReduceFloatBitwidth.pdll + TestReduceFloatBitwidthPatterns.h.inc + + EXTRA_INCLUDES + ${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test + ${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test + ) + add_mlir_dialect_library(MLIRArithTransforms BufferDeallocationOpInterfaceImpl.cpp BufferizableOpInterfaceImpl.cpp @@ -8,12 +17,15 @@ add_mlir_dialect_library(MLIRArithTransforms ExpandOps.cpp IntRangeOptimizations.cpp ReifyValueBounds.cpp + TestReduceFloatBitwidth.cpp + TestReduceFloatBitwidthConversion.cpp UnsignedWhenEquivalent.cpp ADDITIONAL_HEADER_DIRS {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith/Transforms DEPENDS + MLIRArithPDLLPatternsIncGen MLIRArithTransformsIncGen LINK_LIBS PUBLIC @@ -27,6 +39,9 @@ add_mlir_dialect_library(MLIRArithTransforms MLIRIR MLIRMemRefDialect MLIRPass + MLIRPDLInterpDialect + MLIRPDLDialect + MLIRSupport MLIRTensorDialect MLIRTransforms MLIRTransformUtils diff --git a/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.cpp b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.cpp new file mode 100644 index 0000000..1f2730c --- /dev/null +++ b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.cpp @@ -0,0 +1,297 @@ +//===- TestReduceFloatBitwdith.cpp - Reduce Float Bitwidth -*- c++ -----*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A pass that reduces the bitwidth of Arith floating-point IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Parser/Parser.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::arith; + +#include "TestReduceFloatBitwidthPatterns.h.inc" + +namespace { + +/// Options for rewrite patterns. +struct ReduceFloatOptions { + /// The source float type, who's bit width should be reduced. + FloatType sourceType; + /// The target float type. + FloatType targetType; +}; + +/// Pattern for arith.addf. +class AddFOpPattern : public OpRewritePattern<AddFOp> { +public: + AddFOpPattern(MLIRContext *context, const ReduceFloatOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + + LogicalResult matchAndRewrite(AddFOp op, + PatternRewriter &rewriter) const override { + if (op.getType() != options.sourceType) + return rewriter.notifyMatchFailure(op, "does not match source type"); + Value lhsTrunc = + rewriter.create<TruncFOp>(op.getLoc(), options.targetType, op.getLhs()); + Value rhsTrunc = + rewriter.create<TruncFOp>(op.getLoc(), options.targetType, op.getRhs()); + Value newAdd = rewriter.create<AddFOp>(op.getLoc(), lhsTrunc, rhsTrunc); + rewriter.replaceOpWithNewOp<ExtFOp>(op, op.getType(), newAdd); + return success(); + } + +private: + const ReduceFloatOptions options; +}; +class AddFOpPatternV2 : public OpRewritePattern<AddFOp> { +public: + AddFOpPatternV2(MLIRContext *context, const ReduceFloatOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + + LogicalResult matchAndRewrite(AddFOp op, + PatternRewriter &rewriter) const override { + if (op.getType() != options.sourceType) + return rewriter.notifyMatchFailure(op, "does not match source type"); + Value lhsTrunc = + rewriter.create<TruncFOp>(op.getLoc(), options.targetType, op.getLhs()); + Value rhsTrunc = + rewriter.create<TruncFOp>(op.getLoc(), options.targetType, op.getRhs()); + Value newAdd = rewriter.create<AddFOp>(op.getLoc(), lhsTrunc, rhsTrunc); + Value replacementValue = + rewriter.create<ExtFOp>(op.getLoc(), op.getType(), newAdd); + rewriter.replaceAllUsesWith(op.getResult(), replacementValue); + rewriter.eraseOp(op); + return success(); + } + +private: + const ReduceFloatOptions options; +}; + +/// Pattern for arith.constant. +class ConstantOpPattern : public OpRewritePattern<ConstantOp> { +public: + ConstantOpPattern(MLIRContext *context, const ReduceFloatOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + + LogicalResult matchAndRewrite(ConstantOp op, + PatternRewriter &rewriter) const override { + if (op.getType() != options.sourceType) + return rewriter.notifyMatchFailure(op, "does not match source type"); + double val = cast<FloatAttr>(op.getValue()).getValueAsDouble(); + auto newAttr = FloatAttr::get(options.targetType, val); + Value newConstant = rewriter.create<ConstantOp>(op.getLoc(), newAttr); + rewriter.replaceOpWithNewOp<ExtFOp>(op, op.getType(), newConstant); + return success(); + } + +private: + const ReduceFloatOptions options; +}; + +/// Pattern for func.func. +class FuncOpPattern : public OpRewritePattern<func::FuncOp> { +public: + FuncOpPattern(MLIRContext *context, const ReduceFloatOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + + LogicalResult matchAndRewrite(func::FuncOp op, + PatternRewriter &rewriter) const override { + if (!llvm::hasSingleElement(op.getBody())) + return rewriter.notifyMatchFailure(op, "0 or >1 blocks not supported"); + FunctionType type = op.getFunctionType(); + SmallVector<Type> newInputs; + for (Type t : type.getInputs()) { + if (t == options.sourceType) { + newInputs.push_back(options.targetType); + } else { + newInputs.push_back(t); + } + } + SmallVector<Type> newResults; + for (Type t : type.getResults()) { + if (t == options.sourceType) { + newResults.push_back(options.targetType); + } else { + newResults.push_back(t); + } + } + if (llvm::equal(type.getInputs(), newInputs) && + llvm::equal(type.getResults(), newResults)) + return rewriter.notifyMatchFailure(op, "no types to convert"); + auto newFuncOp = rewriter.create<func::FuncOp>( + op.getLoc(), op.getSymName(), + FunctionType::get(op.getContext(), newInputs, newResults)); + SmallVector<Location> locs = + llvm::map_to_vector(op.getBody().getArguments(), + [](BlockArgument arg) { return arg.getLoc(); }); + Block *newBlock = rewriter.createBlock( + &newFuncOp.getBody(), newFuncOp.getBody().begin(), newInputs, locs); + rewriter.setInsertionPointToStart(newBlock); + SmallVector<Value> argRepl; + for (auto [oldType, newType, newArg] : llvm::zip_equal( + type.getInputs(), newInputs, newBlock->getArguments())) { + if (oldType == newType) { + argRepl.push_back(newArg); + } else { + argRepl.push_back( + rewriter.create<ExtFOp>(newArg.getLoc(), oldType, newArg)); + } + } + rewriter.inlineBlockBefore(&op.getBody().front(), newBlock, newBlock->end(), + argRepl); + rewriter.eraseOp(op); + return success(); + } + +private: + const ReduceFloatOptions options; +}; + +/// Pattern for func.return. +class ReturnOpPattern : public OpRewritePattern<func::ReturnOp> { +public: + ReturnOpPattern(MLIRContext *context, const ReduceFloatOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), options(options) {} + + LogicalResult matchAndRewrite(func::ReturnOp op, + PatternRewriter &rewriter) const override { + bool changedIR = false; + SmallVector<Value> newOperands; + for (Value val : op.getOperands()) { + if (val.getType() != options.sourceType) { + newOperands.push_back(val); + } else { + changedIR = true; + newOperands.push_back( + rewriter.create<TruncFOp>(val.getLoc(), options.targetType, val)); + } + } + if (!changedIR) + return rewriter.notifyMatchFailure(op, "no types to convert"); + rewriter.modifyOpInPlace( + op, [&]() { op.getOperandsMutable().assign(newOperands); }); + return success(); + } + +private: + const ReduceFloatOptions options; +}; + +/// Pattern that folds arith.truncf(arith.extf(x)) => x. +class ExtTruncFolding : public OpRewritePattern<TruncFOp> { +public: + ExtTruncFolding(MLIRContext *context, const ReduceFloatOptions &options, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + + LogicalResult matchAndRewrite(TruncFOp op, + PatternRewriter &rewriter) const override { + auto extfOp = op.getIn().getDefiningOp<ExtFOp>(); + if (!extfOp) + return rewriter.notifyMatchFailure(op, + "'in' is not defined by arith.extf"); + if (extfOp.getIn().getType() != op.getType()) + return rewriter.notifyMatchFailure(op, "types do not match"); + rewriter.replaceOp(op, extfOp.getIn()); + return success(); + } +}; + +struct TestReduceFloatBitwidthPass + : public PassWrapper<TestReduceFloatBitwidthPass, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestReduceFloatBitwidthPass) + + TestReduceFloatBitwidthPass() = default; + TestReduceFloatBitwidthPass(const TestReduceFloatBitwidthPass &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<arith::ArithDialect, func::FuncDialect, pdl::PDLDialect, + pdl_interp::PDLInterpDialect>(); + } + StringRef getArgument() const final { + return "test-arith-reduce-float-bitwidth"; + } + StringRef getDescription() const final { + return "Pass that reduces the bitwidth of floating-point ops"; + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populateTestReduceFloatBitwidthPatterns( + patterns, optPatterns, FloatType::getF32(ctx), FloatType::getF16(ctx)); + + GreedyRewriteConfig config; + config.fold = optFold; + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { + getOperation()->emitError() << getArgument() << " failed"; + signalPassFailure(); + } + } + + Option<bool> optFold{*this, "fold", llvm::cl::init(true), + llvm::cl::desc("fold ops")}; + ListOption<std::string> optPatterns{*this, "patterns", + llvm::cl::desc("activated patterns")}; +}; +} // namespace + +static Attribute convertAttrF32ToF16(PatternRewriter &rewriter, + Attribute attr) { + auto floatAttr = dyn_cast<FloatAttr>(attr); + if (!attr) + return Attribute(); + return rewriter.getFloatAttr(rewriter.getF16Type(), + floatAttr.getValueAsDouble()); +} + +void arith::populateTestReduceFloatBitwidthPatterns( + RewritePatternSet &patterns, ArrayRef<std::string> enabledPatterns, + FloatType sourceType, FloatType targetType) { + ReduceFloatOptions options{sourceType, targetType}; + MLIRContext *ctx = patterns.getContext(); + if (llvm::is_contained(enabledPatterns, "arith.addf")) + patterns.insert<AddFOpPattern>(ctx, options); + if (llvm::is_contained(enabledPatterns, "arith.addf_v2")) + patterns.insert<AddFOpPatternV2>(ctx, options); + if (llvm::is_contained(enabledPatterns, "arith.constant")) + patterns.insert<ConstantOpPattern>(ctx, options); + if (llvm::is_contained(enabledPatterns, "func.func")) + patterns.insert<FuncOpPattern>(ctx, options); + if (llvm::is_contained(enabledPatterns, "func.return")) + patterns.insert<ReturnOpPattern>(ctx, options); + if (llvm::is_contained(enabledPatterns, "arith.truncf")) + patterns.insert<ExtTruncFolding>(ctx, options); + if (llvm::is_contained(enabledPatterns, "pdl_patterns")) { + patterns.getPDLPatterns().registerRewriteFunction("ConvertAttrF32ToF16", + convertAttrF32ToF16); + populateGeneratedPDLLPatterns(patterns); + } +} + +namespace mlir { +void registerTestReduceFloatBitwidthPass() { + PassRegistration<TestReduceFloatBitwidthPass>(); +} +} // namespace mlir diff --git a/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.pdll b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.pdll new file mode 100644 index 0000000..43e7b7f --- /dev/null +++ b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidth.pdll @@ -0,0 +1,29 @@ +#include "mlir/Dialect/Arith/IR/ArithOps.td" + +Pattern TruncExtPattern { + let extf = op<arith.extf>(input: Value<type<"f16">>) -> (type<"f32">); + let truncf = op<arith.truncf>(extf) -> (type<"f16">); + replace truncf with input; +} + +Pattern AddPattern { + let addf = op<arith.addf>(lhs: Value, rhs: Value) -> (type<"f32">); + rewrite addf with { + let lhs16 = op<arith.truncf>(lhs) -> (type<"f16">); + let rhs16 = op<arith.truncf>(rhs) -> (type<"f16">); + let addf16 = op<arith.addf>(lhs16, rhs16); + replace addf with op<arith.extf>(addf16) -> (type<"f32">); + }; +} + +Rewrite ConvertAttrF32ToF16(value: Attr) -> Attr; + +Pattern ConstantPattern { + let attr: Attr; + let constant = op<arith.constant> {value = attr} -> (type<"f32">); + rewrite constant with { + let attr16 = ConvertAttrF32ToF16(attr); + let const16 = op<arith.constant>() {value = attr16}; + replace constant with op<arith.extf>(const16) -> (type<"f32">); + }; +} diff --git a/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidthConversion.cpp b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidthConversion.cpp new file mode 100644 index 0000000..09890dd --- /dev/null +++ b/mlir/lib/Dialect/Arith/Transforms/TestReduceFloatBitwidthConversion.cpp @@ -0,0 +1,139 @@ +//===- TestReduceFloatBitwdithConversion.cpp ----------------*- c++ -----*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// A pass that reduces the bitwidth of Arith floating-point IR. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace mlir::arith; + +namespace { + +/// Pattern for arith.constant. +class ConstantOpPattern : public OpConversionPattern<ConstantOp> { + using OpConversionPattern<ConstantOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(ConstantOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + double val = cast<FloatAttr>(op.getValue()).getValueAsDouble(); + auto newAttr = FloatAttr::get(Float16Type::get(op.getContext()), val); + rewriter.replaceOpWithNewOp<ConstantOp>(op, newAttr); + return success(); + } +}; + +/// Pattern for arith.addf. +class AddOpPattern : public OpConversionPattern<AddFOp> { + using OpConversionPattern<AddFOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(AddFOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + rewriter.replaceOpWithNewOp<AddFOp>(op, adaptor.getLhs(), adaptor.getRhs()); + return success(); + } +}; + +struct TestReduceFloatBitwidthConversionPass + : public PassWrapper<TestReduceFloatBitwidthConversionPass, + OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestReduceFloatBitwidthConversionPass) + + TestReduceFloatBitwidthConversionPass() = default; + TestReduceFloatBitwidthConversionPass( + const TestReduceFloatBitwidthConversionPass &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<arith::ArithDialect, func::FuncDialect>(); + } + StringRef getArgument() const final { + return "test-arith-reduce-float-bitwidth-conversion"; + } + StringRef getDescription() const final { + return "Pass that reduces the bitwidth of floating-point ops (dialect " + "conversion)"; + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + + TypeConverter converter; + ConversionConfig config; + converter.addConversion([](Type type) { return type; }); + converter.addConversion( + [&](Float32Type type) { return FloatType::getF16(ctx); }); + if (optBuildMaterializations) { + converter.addSourceMaterialization( + [](OpBuilder &builder, FloatType resultType, ValueRange inputs, + Location loc) -> Value { + assert(inputs.size() == 1 && "expected single input"); + return builder.create<ExtFOp>(loc, resultType, inputs[0]); + }); + converter.addTargetMaterialization( + [](OpBuilder &builder, FloatType resultType, ValueRange inputs, + Location loc) -> Value { + assert(inputs.size() == 1 && "expected single input"); + return builder.create<TruncFOp>(loc, resultType, inputs[0]); + }); + config.buildMaterializations = true; + } else { + config.buildMaterializations = false; + } + + RewritePatternSet patterns(ctx); + patterns.insert<ConstantOpPattern, AddOpPattern>(converter, ctx); + // Pattern for func.func. + populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, + converter); + populateReturnOpTypeConversionPattern(patterns, converter); + + ConversionTarget target(*ctx); + target.addDynamicallyLegalOp<ConstantOp, AddFOp, func::ReturnOp>( + [&](Operation *op) { return converter.isLegal(op); }); + target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { + return converter.isSignatureLegal(op.getFunctionType()); + }); + + LogicalResult status = failure(); + if (optFullConversion) { + status = applyFullConversion(getOperation(), target, std::move(patterns), + config); + } else { + status = applyPartialConversion(getOperation(), target, + std::move(patterns), config); + } + if (failed(status)) { + getOperation()->emitError() << getArgument() << " failed"; + signalPassFailure(); + } + } + + Option<bool> optBuildMaterializations{ + *this, "build-materializations", llvm::cl::init(false), + llvm::cl::desc("build materializations")}; + Option<bool> optFullConversion{ + *this, "full-conversion", llvm::cl::init(false), + llvm::cl::desc("full conversion (otherwise: partial)")}; +}; +} // namespace + +namespace mlir { +void registerTestReduceFloatBitwidthConversionPass() { + PassRegistration<TestReduceFloatBitwidthConversionPass>(); +} +} // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt index b15a470..b740309 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/TransformOps/CMakeLists.txt @@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRLinalgTransformOps LINK_LIBS PUBLIC MLIRAffineDialect MLIRArithDialect + MLIRArithTransforms MLIRBufferizationDialect MLIRBufferizationTransforms MLIRFuncDialect diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index a1d619c..c98ce23 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Arith/Transforms/Passes.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" @@ -220,6 +221,34 @@ static DiagnosedSilenceableFailure reifyMixedParamAndHandleResults( } //===----------------------------------------------------------------------===// +// Winter School +//===----------------------------------------------------------------------===// + +void transform::ApplyReduceFloatBitwidthPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + SmallVector<std::string> enabledPatternsStr; + for (Attribute attr : getEnabledPatterns()) { + enabledPatternsStr.push_back(cast<StringAttr>(attr).getValue().str()); + } + FloatType sourceType = cast<FloatType>(getSourceType()); + FloatType targetType = cast<FloatType>(getTargetType()); + arith::populateTestReduceFloatBitwidthPatterns(patterns, enabledPatternsStr, + sourceType, targetType); +} + +LogicalResult transform::ApplyReduceFloatBitwidthPatternsOp::verify() { + for (Attribute attr : getEnabledPatterns()) + if (!isa<StringAttr>(attr)) + return emitOpError( + "expected 'enabled_patterns' to be an array of string attributes"); + if (!isa<FloatType>(getSourceType())) + return emitOpError("expected float source type"); + if (!isa<FloatType>(getTargetType())) + return emitOpError("expected float target type"); + return success(); +} + +//===----------------------------------------------------------------------===// // Apply...PatternsOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_10_dialect_conversion.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_10_dialect_conversion.mlir new file mode 100644 index 0000000..f032c88 --- /dev/null +++ b/mlir/test/Dialect/Arith/WinterSchool/test_10_dialect_conversion.mlir @@ -0,0 +1,36 @@ +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth-conversion="build-materializations=0" -split-input-file +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth-conversion="build-materializations=1" -split-input-file +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth-conversion="build-materializations=1 full-conversion=1" -split-input-file + +func.func @test_constant() -> f32 { + %0 = arith.constant 2.0 : f32 + return %0 : f32 +} + +// ----- + +func.func @test_add(%arg0: f32, %arg1: f32) -> f32 { + %0 = arith.addf %arg0, %arg1 : f32 + return %0 : f32 +} + +// ----- + +func.func @test_add_constant(%arg0: f32) -> f32 { + %0 = arith.constant 2.0 : f32 + %1 = arith.addf %arg0, %0 : f32 + return %1 : f32 +} + +// ----- + +func.func @test_func(%arg0: f32) -> f32 { + return %arg0 : f32 +} + +// ----- + +func.func @test_boundary(%arg0: f32) -> f32 { + %0 = "test.consumer_producer"(%arg0) : (f32) -> (f32) + return %0 : f32 +} diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_1_arith_addf.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_1_arith_addf.mlir new file mode 100644 index 0000000..cd5b216 --- /dev/null +++ b/mlir/test/Dialect/Arith/WinterSchool/test_1_arith_addf.mlir @@ -0,0 +1,6 @@ +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.addf" + +func.func @test_add(%arg0: f32, %arg1: f32) { + %0 = arith.addf %arg0, %arg1 : f32 + return +} diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_2_arith_addf.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_2_arith_addf.mlir new file mode 100644 index 0000000..85f9c57 --- /dev/null +++ b/mlir/test/Dialect/Arith/WinterSchool/test_2_arith_addf.mlir @@ -0,0 +1,6 @@ +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.addf" + +func.func @test_add(%arg0: f32, %arg1: f32) -> f32 { + %0 = arith.addf %arg0, %arg1 : f32 + return %0 : f32 +} diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_3_arith_constant.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_3_arith_constant.mlir new file mode 100644 index 0000000..8c8f933 --- /dev/null +++ b/mlir/test/Dialect/Arith/WinterSchool/test_3_arith_constant.mlir @@ -0,0 +1,6 @@ +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.constant" + +func.func @test_constant() -> f32 { + %0 = arith.constant 2.0 : f32 + return %0 : f32 +} diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_4_arith_constant.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_4_arith_constant.mlir new file mode 100644 index 0000000..5ef8434 --- /dev/null +++ b/mlir/test/Dialect/Arith/WinterSchool/test_4_arith_constant.mlir @@ -0,0 +1,6 @@ +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.constant fold=false" + +func.func @test_constant() -> f32 { + %0 = arith.constant 2.0 : f32 + return %0 : f32 +} diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_5_arith_constant_addf.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_5_arith_constant_addf.mlir new file mode 100644 index 0000000..b50a056 --- /dev/null +++ b/mlir/test/Dialect/Arith/WinterSchool/test_5_arith_constant_addf.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.constant,arith.addf" +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.addf" + +func.func @test_add_constant(%arg0: f32) -> f32 { + %0 = arith.constant 2.0 : f32 + %1 = arith.addf %arg0, %0 : f32 + return %1 : f32 +} diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_6_func.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_6_func.mlir new file mode 100644 index 0000000..1b4ae9b --- /dev/null +++ b/mlir/test/Dialect/Arith/WinterSchool/test_6_func.mlir @@ -0,0 +1,7 @@ +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=func.return" +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=func.func" +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=func.func,func.return" + +func.func @test_func(%arg0: f32) -> f32 { + return %arg0 : f32 +} diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_7_arith_truncf.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_7_arith_truncf.mlir new file mode 100644 index 0000000..162f676 --- /dev/null +++ b/mlir/test/Dialect/Arith/WinterSchool/test_7_arith_truncf.mlir @@ -0,0 +1,6 @@ +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=func.func,func.return,arith.truncf" +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=func.func,func.return,arith.truncf" -canonicalize + +func.func @test_func(%arg0: f32) -> f32 { + return %arg0 : f32 +} diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_8_transform.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_8_transform.mlir new file mode 100644 index 0000000..369ce48 --- /dev/null +++ b/mlir/test/Dialect/Arith/WinterSchool/test_8_transform.mlir @@ -0,0 +1,23 @@ +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=arith.addf" + +func.func @test_add(%arg0: f32, %arg1: f32) -> f32 { + %0 = arith.addf %arg0, %arg1 : f32 + return %0 : f32 +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %func_op = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.op<"func.func"> + %add_op = transform.structured.match ops{["arith.addf"]} in %func_op : (!transform.op<"func.func">) -> !transform.op<"arith.addf"> + transform.debug.emit_remark_at %add_op, "before pattern application" : !transform.op<"arith.addf"> + + transform.apply_patterns to %func_op { + transform.apply_patterns.arith.reduce_float_bitwidth ["arith.addf"] from f32 to f16 + // transform.apply_patterns.arith.reduce_float_bitwidth ["func.func", "func.return", "arith.addf"] from f32 to f16 + // transform.apply_patterns.arith.reduce_float_bitwidth ["arith.addf_v2"] from f32 to f16 + } : !transform.op<"func.func"> + + transform.debug.emit_remark_at %add_op, "after pattern application" : !transform.op<"arith.addf"> + transform.yield + } +} diff --git a/mlir/test/Dialect/Arith/WinterSchool/test_9_arith_truncf_pdl.mlir b/mlir/test/Dialect/Arith/WinterSchool/test_9_arith_truncf_pdl.mlir new file mode 100644 index 0000000..68268a3 --- /dev/null +++ b/mlir/test/Dialect/Arith/WinterSchool/test_9_arith_truncf_pdl.mlir @@ -0,0 +1,22 @@ +// RUN: mlir-opt %s -test-arith-reduce-float-bitwidth="patterns=pdl_patterns" -split-input-file + +func.func @test_trunc_ext(%arg0: f16) -> f16 { + %0 = arith.extf %arg0 : f16 to f32 + %1 = arith.truncf %0 : f32 to f16 + return %1 : f16 +} + +// ----- + +func.func @test_add(%arg0: f32, %arg1: f32) -> f32 { + %0 = arith.addf %arg0, %arg1 : f32 + return %0 : f32 +} + +// ----- + +func.func @test_add_constant(%arg0: f32) -> f32 { + %0 = arith.constant 2.0 : f32 + %1 = arith.addf %arg0, %0 : f32 + return %1 : f32 +} diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 960f703..551639b 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -60,6 +60,8 @@ void registerTestPreserveUseListOrders(); void registerTestPrintDefUsePass(); void registerTestPrintInvalidPass(); void registerTestPrintNestingPass(); +void registerTestReduceFloatBitwidthPass(); +void registerTestReduceFloatBitwidthConversionPass(); void registerTestReducer(); void registerTestSpirvEntryPointABIPass(); void registerTestSpirvModuleCombinerPass(); @@ -200,6 +202,8 @@ void registerTestPasses() { registerTestPrintDefUsePass(); registerTestPrintInvalidPass(); registerTestPrintNestingPass(); + registerTestReduceFloatBitwidthPass(); + registerTestReduceFloatBitwidthConversionPass(); registerTestReducer(); registerTestSpirvEntryPointABIPass(); registerTestSpirvModuleCombinerPass(); |