//===- FoldIntoElementwise.cpp - Fold Ops into elementwise if possible ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements folding ops such as transpose and broadcast into the // affine maps of the elementwise op. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/ADT/SmallVector.h" namespace mlir { #define GEN_PASS_DEF_LINALGFOLDINTOELEMENTWISEPASS #include "mlir/Dialect/Linalg/Passes.h.inc" } // namespace mlir using namespace mlir; using namespace mlir::linalg; #define DEBUG_TYPE "linalg-fold-into-elementwise" namespace { struct FoldTransposePattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ElementwiseOp op, PatternRewriter &rewriter) const override { bool changed = false; SmallVector newIns; SmallVector newMaps; for (OpOperand *operand : op.getDpsInputOperands()) { AffineMap map = op.getMatchingIndexingMap(operand); auto transposeOp = operand->get().getDefiningOp(); if (!map.isIdentity() || !transposeOp) { // push in original operand and its map. newIns.push_back(operand->get()); newMaps.push_back(map); continue; } newIns.push_back(transposeOp.getInput()); // push in transposeOp's inverse permutation map. newMaps.push_back(transposeOp.getMatchingIndexingMap( transposeOp.getDpsInputOperand(0))); changed = true; } if (!changed) return failure(); newMaps.push_back(op.getIndexingMapsArray().back()); rewriter.replaceOpWithNewOp( op, newIns, op.getDpsInits()[0], op.getKindAttr(), rewriter.getAffineMapArrayAttr(newMaps)); return success(); } }; struct LinalgFoldIntoElementwisePass : public impl::LinalgFoldIntoElementwisePassBase< LinalgFoldIntoElementwisePass> { using impl::LinalgFoldIntoElementwisePassBase< LinalgFoldIntoElementwisePass>::LinalgFoldIntoElementwisePassBase; void runOnOperation() override { Operation *op = getOperation(); RewritePatternSet patterns(op->getContext()); populateLinalgFoldIntoElementwisePatterns(patterns); if (failed(applyPatternsGreedily(op, std::move(patterns)))) return signalPassFailure(); } }; } // namespace void mlir::linalg::populateLinalgFoldIntoElementwisePatterns( RewritePatternSet &patterns) { patterns.add(patterns.getContext()); }