Diffstat (limited to 'mlir/examples/toy/Ch7')
-rw-r--r-- | mlir/examples/toy/Ch7/CMakeLists.txt              |  11
-rw-r--r-- | mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp | 122
-rw-r--r-- | mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp        |  12
-rw-r--r-- | mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp |   4
-rw-r--r-- | mlir/examples/toy/Ch7/toyc.cpp                    |   1
5 files changed, 65 insertions, 85 deletions
diff --git a/mlir/examples/toy/Ch7/CMakeLists.txt b/mlir/examples/toy/Ch7/CMakeLists.txt
index 362ab51..a489ae5 100644
--- a/mlir/examples/toy/Ch7/CMakeLists.txt
+++ b/mlir/examples/toy/Ch7/CMakeLists.txt
@@ -36,14 +36,8 @@ add_toy_chapter(toyc-ch7
 include_directories(${CMAKE_CURRENT_BINARY_DIR})
 include_directories(${CMAKE_CURRENT_BINARY_DIR}/include/)
 
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
-get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)
-get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS)
 target_link_libraries(toyc-ch7
   PRIVATE
-    ${dialect_libs}
-    ${conversion_libs}
-    ${extension_libs}
     MLIRAnalysis
     MLIRBuiltinToLLVMIRTranslation
     MLIRCallInterfaces
@@ -56,7 +50,10 @@ target_link_libraries(toyc-ch7
     MLIRMemRefDialect
     MLIRParser
     MLIRPass
+    MLIRRegisterAllDialects
+    MLIRRegisterAllExtensions
+    MLIRRegisterAllPasses
     MLIRSideEffectInterfaces
     MLIRTargetLLVMIRExport
     MLIRTransforms
-  )
+)
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index d65c89c..cbe4236 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -44,7 +44,7 @@ using namespace mlir;
 
 //===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns
+// ToyToAffine Conversion Patterns
 //===----------------------------------------------------------------------===//
 
 /// Convert the given RankedTensorType into the corresponding MemRefType.
@@ -69,15 +69,13 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
 }
 
 /// This defines the function type used to process an iteration of a lowered
-/// loop. It takes as input an OpBuilder, an range of memRefOperands
-/// corresponding to the operands of the input operation, and the range of loop
-/// induction variables for the iteration. It returns a value to store at the
-/// current index of the iteration.
-using LoopIterationFn = function_ref<Value(
-    OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>;
-
-static void lowerOpToLoops(Operation *op, ValueRange operands,
-                           PatternRewriter &rewriter,
+/// loop. It takes as input an OpBuilder and the range of loop induction
+/// variables for the iteration. It returns a value to store at the current
+/// index of the iteration.
+using LoopIterationFn =
+    function_ref<Value(OpBuilder &rewriter, ValueRange loopIvs)>;
+
+static void lowerOpToLoops(Operation *op, PatternRewriter &rewriter,
                            LoopIterationFn processIteration) {
   auto tensorType = llvm::cast<RankedTensorType>((*op->result_type_begin()));
   auto loc = op->getLoc();
@@ -95,10 +93,10 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
   affine::buildAffineLoopNest(
       rewriter, loc, lowerBounds, tensorType.getShape(), steps,
       [&](OpBuilder &nestedBuilder, Location loc, ValueRange ivs) {
-        // Call the processing function with the rewriter, the memref operands,
+        // Call the processing function with the rewriter
        // and the loop induction variables. This function will return the value
         // to store at the current index.
-        Value valueToStore = processIteration(nestedBuilder, operands, ivs);
+        Value valueToStore = processIteration(nestedBuilder, ivs);
         affine::AffineStoreOp::create(nestedBuilder, loc, valueToStore, alloc,
                                       ivs);
       });
@@ -109,38 +107,30 @@ static void lowerOpToLoops(Operation *op, ValueRange operands,
 
 namespace {
 //===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Binary operations
+// ToyToAffine Conversion Patterns: Binary operations
 //===----------------------------------------------------------------------===//
 
 template <typename BinaryOp, typename LoweredBinaryOp>
-struct BinaryOpLowering : public ConversionPattern {
-  BinaryOpLowering(MLIRContext *ctx)
-      : ConversionPattern(BinaryOp::getOperationName(), 1, ctx) {}
+struct BinaryOpLowering : public OpConversionPattern<BinaryOp> {
+  using OpConversionPattern<BinaryOp>::OpConversionPattern;
+  using OpAdaptor = typename OpConversionPattern<BinaryOp>::OpAdaptor;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(BinaryOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
-    lowerOpToLoops(op, operands, rewriter,
-                   [loc](OpBuilder &builder, ValueRange memRefOperands,
-                         ValueRange loopIvs) {
-                     // Generate an adaptor for the remapped operands of the
-                     // BinaryOp. This allows for using the nice named accessors
-                     // that are generated by the ODS.
-                     typename BinaryOp::Adaptor binaryAdaptor(memRefOperands);
-
-                     // Generate loads for the element of 'lhs' and 'rhs' at the
-                     // inner loop.
-                     auto loadedLhs = affine::AffineLoadOp::create(
-                         builder, loc, binaryAdaptor.getLhs(), loopIvs);
-                     auto loadedRhs = affine::AffineLoadOp::create(
-                         builder, loc, binaryAdaptor.getRhs(), loopIvs);
-
-                     // Create the binary operation performed on the loaded
-                     // values.
-                     return LoweredBinaryOp::create(builder, loc, loadedLhs,
-                                                    loadedRhs);
-                   });
+    lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+      // Generate loads for the element of 'lhs' and 'rhs' at the
+      // inner loop.
+      auto loadedLhs =
+          affine::AffineLoadOp::create(builder, loc, adaptor.getLhs(), loopIvs);
+      auto loadedRhs =
+          affine::AffineLoadOp::create(builder, loc, adaptor.getRhs(), loopIvs);
+
+      // Create the binary operation performed on the loaded
+      // values.
+      return LoweredBinaryOp::create(builder, loc, loadedLhs, loadedRhs);
+    });
     return success();
   }
 };
@@ -148,14 +138,15 @@ using AddOpLowering = BinaryOpLowering<toy::AddOp, arith::AddFOp>;
 using MulOpLowering = BinaryOpLowering<toy::MulOp, arith::MulFOp>;
 
 //===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Constant operations
+// ToyToAffine Conversion Patterns: Constant operations
 //===----------------------------------------------------------------------===//
 
-struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
-  using OpRewritePattern<toy::ConstantOp>::OpRewritePattern;
+struct ConstantOpLowering : public OpConversionPattern<toy::ConstantOp> {
+  using OpConversionPattern<toy::ConstantOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(toy::ConstantOp op,
-                                PatternRewriter &rewriter) const final {
+  LogicalResult
+  matchAndRewrite(toy::ConstantOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
     DenseElementsAttr constantValue = op.getValue();
     Location loc = op.getLoc();
 
@@ -216,7 +207,7 @@ struct ConstantOpLowering : public OpRewritePattern<toy::ConstantOp> {
 };
 
 //===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Func operations
+// ToyToAffine Conversion Patterns: Func operations
 //===----------------------------------------------------------------------===//
 
 struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
@@ -247,7 +238,7 @@ struct FuncOpLowering : public OpConversionPattern<toy::FuncOp> {
 };
 
 //===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Print operations
+// ToyToAffine Conversion Patterns: Print operations
 //===----------------------------------------------------------------------===//
 
 struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
@@ -265,14 +256,15 @@ struct PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
 };
 
 //===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Return operations
+// ToyToAffine Conversion Patterns: Return operations
 //===----------------------------------------------------------------------===//
 
-struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
-  using OpRewritePattern<toy::ReturnOp>::OpRewritePattern;
+struct ReturnOpLowering : public OpConversionPattern<toy::ReturnOp> {
+  using OpConversionPattern<toy::ReturnOp>::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(toy::ReturnOp op,
-                                PatternRewriter &rewriter) const final {
+  LogicalResult
+  matchAndRewrite(toy::ReturnOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
     // During this lowering, we expect that all function calls have been
     // inlined.
     if (op.hasOperand())
@@ -285,32 +277,24 @@ struct ReturnOpLowering : public OpRewritePattern<toy::ReturnOp> {
 };
 
 //===----------------------------------------------------------------------===//
-// ToyToAffine RewritePatterns: Transpose operations
+// ToyToAffine Conversion Patterns: Transpose operations
 //===----------------------------------------------------------------------===//
 
-struct TransposeOpLowering : public ConversionPattern {
-  TransposeOpLowering(MLIRContext *ctx)
-      : ConversionPattern(toy::TransposeOp::getOperationName(), 1, ctx) {}
+struct TransposeOpLowering : public OpConversionPattern<toy::TransposeOp> {
+  using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
     auto loc = op->getLoc();
-    lowerOpToLoops(op, operands, rewriter,
-                   [loc](OpBuilder &builder, ValueRange memRefOperands,
-                         ValueRange loopIvs) {
-                     // Generate an adaptor for the remapped operands of the
-                     // TransposeOp. This allows for using the nice named
-                     // accessors that are generated by the ODS.
-                     toy::TransposeOpAdaptor transposeAdaptor(memRefOperands);
-                     Value input = transposeAdaptor.getInput();
-
-                     // Transpose the elements by generating a load from the
-                     // reverse indices.
-                     SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
-                     return affine::AffineLoadOp::create(builder, loc, input,
-                                                         reverseIvs);
-                   });
+    lowerOpToLoops(op, rewriter, [&](OpBuilder &builder, ValueRange loopIvs) {
+      Value input = adaptor.getInput();
+
+      // Transpose the elements by generating a load from the
+      // reverse indices.
+      SmallVector<Value, 2> reverseIvs(llvm::reverse(loopIvs));
+      return affine::AffineLoadOp::create(builder, loc, input, reverseIvs);
+    });
     return success();
   }
 };
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
index 43a84da..8b48a8f 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
@@ -55,19 +55,18 @@ using namespace mlir;
 
 //===----------------------------------------------------------------------===//
-// ToyToLLVM RewritePatterns
+// ToyToLLVM Conversion Patterns
 //===----------------------------------------------------------------------===//
 
 namespace {
 /// Lowers `toy.print` to a loop nest calling `printf` on each of the individual
 /// elements of the array.
-class PrintOpLowering : public ConversionPattern {
+class PrintOpLowering : public OpConversionPattern<toy::PrintOp> {
 public:
-  explicit PrintOpLowering(MLIRContext *context)
-      : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}
+  using OpConversionPattern<toy::PrintOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+  matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto *context = rewriter.getContext();
     auto memRefType = llvm::cast<MemRefType>((*op->operand_type_begin()));
@@ -108,9 +107,8 @@ public:
     }
 
     // Generate a call to printf for the current element of the loop.
-    auto printOp = cast<toy::PrintOp>(op);
     auto elementLoad =
-        memref::LoadOp::create(rewriter, loc, printOp.getInput(), loopIvs);
+        memref::LoadOp::create(rewriter, loc, op.getInput(), loopIvs);
     LLVM::CallOp::create(rewriter, loc, getPrintfType(context), printfRef,
                          ArrayRef<Value>({formatSpecifierCst, elementLoad}));
diff --git a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
index 2522abe..a552e1f0 100644
--- a/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
+++ b/mlir/examples/toy/Ch7/mlir/ShapeInferencePass.cpp
@@ -23,7 +23,7 @@
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Casting.h"
-#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugLog.h"
 #include "llvm/Support/raw_ostream.h"
 #include <memory>
 
@@ -81,7 +81,7 @@ struct ShapeInferencePass
       opWorklist.erase(op);
 
       // Ask the operation to infer its output shapes.
-      LLVM_DEBUG(llvm::dbgs() << "Inferring shape for: " << *op << "\n");
+      LDBG() << "Inferring shape for: " << *op;
       if (auto shapeOp = dyn_cast<ShapeInference>(op)) {
         shapeOp.inferShapes();
       } else {
diff --git a/mlir/examples/toy/Ch7/toyc.cpp b/mlir/examples/toy/Ch7/toyc.cpp
index dd86265..32208ecca 100644
--- a/mlir/examples/toy/Ch7/toyc.cpp
+++ b/mlir/examples/toy/Ch7/toyc.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Func/Extensions/AllExtensions.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
 #include "toy/AST.h"
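The recurring change across LowerToAffineLoops.cpp and LowerToLLVM.cpp is the switch from the untyped ConversionPattern base class to OpConversionPattern<OpTy>: matchAndRewrite now receives the concrete op plus an OpAdaptor whose named accessors return the already-remapped operands, so the hand-built `typename Op::Adaptor adaptor(operands)` and the string-based constructor disappear. A minimal sketch of the resulting pattern shape, mirroring the TransposeOpLowering in the diff above and assuming the Toy dialect headers from the tutorial are on the include path (the pattern name and the populate helper below are illustrative, not part of this patch):

#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
#include "toy/Dialect.h"

using namespace mlir;

namespace {
/// Typed conversion pattern: `op` is already a toy::TransposeOp, and `adaptor`
/// exposes the operands after the conversion framework has remapped them.
struct TransposeLoweringSketch : public OpConversionPattern<toy::TransposeOp> {
  using OpConversionPattern<toy::TransposeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(toy::TransposeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    // No `ArrayRef<Value> operands` parameter and no manual adaptor
    // construction: the remapped input comes from the ODS-generated accessor.
    Value input = adaptor.getInput();
    (void)input;
    // ... build the replacement IR with `rewriter`, then replace/erase `op` ...
    return success();
  }
};
} // namespace

/// Registration is unchanged from the ConversionPattern version: the pattern
/// is still added to a RewritePatternSet with the MLIRContext.
static void addTransposeSketchPattern(RewritePatternSet &patterns) {
  patterns.add<TransposeLoweringSketch>(patterns.getContext());
}

The CMake portion of the patch follows the same simplification: instead of collecting every dialect, conversion, and extension library through the MLIR_*_LIBS global properties, toyc-ch7 links the MLIRRegisterAllDialects, MLIRRegisterAllExtensions, and MLIRRegisterAllPasses umbrella libraries directly.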