//===- ElementwiseToLinalg.cpp - conversion of elementwise to linalg ------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { #define GEN_PASS_DEF_CONVERTELEMENTWISETOLINALGPASS #include "mlir/Dialect/Linalg/Passes.h.inc" } // namespace mlir using namespace mlir; static inline bool isScalarLike(Type t) { return isa(t); } static bool isElementwiseMappableOpOnRankedTensors(Operation *op) { if (!OpTrait::hasElementwiseMappableTraits(op)) return false; auto types = op->getOperandTypes(); // We want at least one ranked tensor. bool anyRankedTensor = llvm::any_of(types, llvm::IsaPred); // No invalid operands (i.e., every operand is a ranked tensor or // scalar-like). bool noneInvalid = llvm::none_of(types, [](Type t) { return !(isa(t) || isScalarLike(t)); }); return anyRankedTensor && noneInvalid; } /// Given `op` assumed `isElementwiseMappableOpOnRankedTensors`, iterate over /// the result types and return a list of values such that, for each result type /// `t` and value `v` at the same index `idx`: /// 1. `v.getType() == t` /// 2. If an operand of `op` has type `t`, let `operand_first` be the first /// such operand. Then`v == operand_first`. /// 3. Otherwise, v is a newly created `tensor::EmptyOp` with: /// a. Static and dynamic dims extracted from the first operand of `op`. /// b. Elemental type equal to the elemental type of `t`. /// /// This is sufficient because ElementwiseMappable guarantees that "The static /// types of all vector (resp. tensor) operands and results must have the same /// shape". static SmallVector getOrCreateOperandsMatchingResultTypes(OpBuilder &b, Operation *op) { assert(isElementwiseMappableOpOnRankedTensors(op)); Location loc = op->getLoc(); ValueRange operands = op->getOperands(); TypeRange rankedTensorTypes = op->getResultTypes(); SmallVector res; res.reserve(rankedTensorTypes.size()); for (Type t : rankedTensorTypes) { // Try to find an operand with type matching the result tensor. bool found = false; for (Value v : operands) { if (v.getType() == t) { found = true; res.push_back(v); break; } } if (found) continue; // Extract static / dynamic shape mix from the first operand. res.push_back(tensor::EmptyOp::create( b, loc, tensor::getMixedSizes(b, loc, operands.front()), cast(t).getElementType())); } return res; } namespace { struct ConvertAnyElementwiseMappableOpOnRankedTensors : public RewritePattern { ConvertAnyElementwiseMappableOpOnRankedTensors(MLIRContext *context) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} LogicalResult matchAndRewrite(Operation *op, PatternRewriter &rewriter) const final { if (!isElementwiseMappableOpOnRankedTensors(op)) return rewriter.notifyMatchFailure( op, "requires elementwise op on ranked tensors"); auto resTy = cast(op->getResult(0).getType()); auto rank = resTy.getRank(); // Maps: identity for tensors (rank > 0), scalar map for scalars. AffineMap scalarMap = AffineMap::get(/*dimCount=*/rank, /*symbolCount=*/0, /*results=*/{}, rewriter.getContext()); AffineMap idMap = rewriter.getMultiDimIdentityMap(rank); // Match phase. SmallVector isScalarOperand; isScalarOperand.reserve(op->getNumOperands()); for (Type ty : op->getOperandTypes()) { if (isScalarLike(ty)) isScalarOperand.push_back(true); else if (auto rt = dyn_cast(ty)) isScalarOperand.push_back(false); else return rewriter.notifyMatchFailure( op, "unsupported operand type (expected scalar-like or ranked tensor)"); } // Create indexing maps. SmallVector indexingMaps; indexingMaps.reserve(op->getNumOperands() + op->getNumResults()); for (bool isScalar : isScalarOperand) indexingMaps.push_back(isScalar ? scalarMap : idMap); indexingMaps.append(op->getNumResults(), idMap); SmallVector iteratorTypes( rank, utils::IteratorType::parallel); SmallVector outputs = getOrCreateOperandsMatchingResultTypes(rewriter, op); rewriter.replaceOpWithNewOp( op, /*resultTensorTypes=*/op->getResultTypes(), /*inputs=*/op->getOperands(), /*outputs=*/outputs, /*indexingMaps=*/indexingMaps, /*iteratorTypes=*/iteratorTypes, /*bodyBuilder=*/ [&](OpBuilder &builder, Location loc, ValueRange regionArgs) { SmallVector resultEltTys = llvm::to_vector<6>( llvm::map_range(op->getResultTypes(), [](Type type) { return cast(type).getElementType(); })); Operation *scalarOp = builder.create(loc, op->getName().getIdentifier(), regionArgs.take_front(op->getNumOperands()), resultEltTys, op->getAttrs()); linalg::YieldOp::create(builder, loc, scalarOp->getResults()); }); return success(); } }; } // namespace void mlir::linalg::populateElementwiseToLinalgConversionPatterns( RewritePatternSet &patterns) { patterns.add( patterns.getContext()); } namespace { class ConvertElementwiseToLinalgPass : public impl::ConvertElementwiseToLinalgPassBase< ConvertElementwiseToLinalgPass> { using impl::ConvertElementwiseToLinalgPassBase< ConvertElementwiseToLinalgPass>::ConvertElementwiseToLinalgPassBase; void runOnOperation() final { auto *func = getOperation(); auto *context = &getContext(); ConversionTarget target(*context); RewritePatternSet patterns(context); mlir::linalg::populateElementwiseToLinalgConversionPatterns(patterns); target.markUnknownOpDynamicallyLegal([](Operation *op) { return !isElementwiseMappableOpOnRankedTensors(op); }); if (failed(applyPartialConversion(func, target, std::move(patterns)))) signalPassFailure(); } }; } // namespace