//===- TosaToTensor.cpp - Lowering Tosa to Tensor Dialect -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// These rewriters lower from the Tosa to the Tensor dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

#include <numeric>

using namespace mlir;
using namespace tosa;

namespace {

// Infer the type to which the input of a 'tosa.reshape' op must be cast when
// lowered.
TensorType inferReshapeInputType(TypedValue<TensorType> input,
                                 ArrayRef<int64_t> newShape) {
  // No need to cast input for non-empty target shape.
  if (!newShape.empty())
    return input.getType();

  // The input type must be cast into a tensor with the same rank and all
  // dimensions statically set to 1. This prevents the generation of a
  // tensor.collapse_shape op that converts a dynamically shaped tensor into a
  // 0D tensor. While such a construct is not incorrect on its own,
  // bufferization cannot properly handle it at the moment, so we avoid it.
  SmallVector<int64_t> shape(input.getType().getRank(), 1);
  return input.getType().clone(shape);
}

// Infer the result type of 'tensor.expand_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeExpandedType(TensorType inputType,
                                    ArrayRef<int64_t> newShape) {
  // Special case for 0D output tensor. Note: Watch out when using
  // Type::clone() with just '{}', as it will invoke the incorrect overload.
  if (newShape.empty())
    return inputType.clone(ArrayRef<int64_t>{});

  // Check if the input is static, and if so, get its total size.
  bool inputIsStatic = inputType.hasStaticShape();
  int64_t totalSize = inputIsStatic ? inputType.getNumElements() : -1;

  // Compute result shape.
  auto resultShape =
      llvm::map_to_vector(newShape, [&](int64_t size) -> int64_t {
        // If this is not a placeholder, do not change it.
        if (size >= 0)
          return size;

        // If we do not know the total size of the tensor, keep this dimension
        // dynamic in the result shape.
        if (!inputIsStatic)
          return ShapedType::kDynamic;

        // Calculate the product of all elements in 'newShape' except for the
        // -1 placeholder, which we discard by negating the result.
        int64_t totalSizeNoPlaceholder = -std::accumulate(
            newShape.begin(), newShape.end(), 1, std::multiplies<int64_t>());

        // If there is a 0 component in 'newShape', resolve the placeholder as
        // 0.
        if (totalSizeNoPlaceholder == 0)
          return 0;

        // Resolve the placeholder as the quotient between the total tensor
        // size and the product of all other sizes.
        return totalSize / totalSizeNoPlaceholder;
      });

  bool resultIsStatic = ShapedType::isStaticShape(resultShape);

  // A syntactic restriction in 'tensor.expand_shape' forbids a dynamically
  // shaped input from being reshaped into a statically shaped result. We may
  // simply turn the first result dimension dynamic to address this.
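  // For example, expanding a dynamically shaped input such as 'tensor<?xf32>'
  // into the fully static shape [2, 3] would otherwise hit this restriction;
  // making the first result dimension dynamic ([?, 3]) keeps the op valid.
  // (Illustrative shapes only.)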
  if (!inputIsStatic && resultIsStatic)
    resultShape[0] = ShapedType::kDynamic;

  // The 'tensor.expand_shape' op also forbids a statically shaped input from
  // being reshaped into a dynamically shaped result, but the placeholder
  // inference algorithm above guarantees that this will never be the case.
  assert(!inputIsStatic || resultIsStatic);

  // Create result type.
  return inputType.clone(resultShape);
}

// Infer the result type of 'tensor.collapse_shape' in the collapse-expand
// pair emitted for a 'tosa.reshape' op.
TensorType inferReshapeCollapsedType(TensorType lhsType, TensorType rhsType) {
  auto lhsShape = lhsType.getShape();
  auto rhsShape = rhsType.getShape();

  if (lhsShape.empty() || rhsShape.empty())
    return lhsType.clone(ArrayRef<int64_t>{});

  if (ShapedType::isDynamicShape(lhsShape) ||
      ShapedType::isDynamicShape(rhsShape))
    return lhsType.clone({ShapedType::kDynamic});

  SmallVector<int64_t> intermediateShape;
  unsigned currLhsDim = 0, currRhsDim = 0;
  while (currLhsDim < lhsShape.size() && currRhsDim < rhsShape.size()) {
    int64_t rhsSize = rhsShape[currRhsDim];
    int64_t lhsSize = lhsShape[currLhsDim];
    while (lhsSize != rhsSize && currLhsDim < lhsShape.size() &&
           currRhsDim < rhsShape.size()) {
      if (lhsSize < rhsSize) {
        currLhsDim++;
        if (currLhsDim < lhsShape.size()) {
          lhsSize *= lhsShape[currLhsDim];
        }
      } else {
        currRhsDim++;
        if (currRhsDim < rhsShape.size()) {
          rhsSize *= rhsShape[currRhsDim];
        }
      }
    }
    if (lhsSize == rhsSize) {
      intermediateShape.push_back(lhsSize);
    }
    currRhsDim++;
    currLhsDim++;
  }

  // Static shapes are guaranteed to be compatible by the op verifier, so all
  // leftover dimensions should be 1.
  for (; currLhsDim < lhsShape.size(); currLhsDim++) {
    assert(lhsShape[currLhsDim] == 1);
  }
  for (; currRhsDim < rhsShape.size(); currRhsDim++) {
    assert(rhsShape[currRhsDim] == 1);
  }

  return lhsType.clone(intermediateShape);
}

SmallVector<ReassociationExprs>
createReassociationMapForCollapse(OpBuilder &builder, Type srcType,
                                  Type dstType) {
  auto srcShape = cast<TensorType>(srcType).getShape();
  auto dstShape = cast<TensorType>(dstType).getShape();

  if (srcShape.empty() || dstShape.empty())
    return {};

  if (ShapedType::isDynamicShape(srcShape) ||
      ShapedType::isDynamicShape(dstShape)) {
    assert(dstShape.size() == 1);
    SmallVector<AffineExpr, 2> exprs;
    for (auto i : llvm::seq<int64_t>(srcShape.size()))
      exprs.push_back(builder.getAffineDimExpr(i));
    return {exprs};
  }

  SmallVector<ReassociationExprs> reassociationMap(dstShape.size());
  unsigned currSrcDim = 0, currDstDim = 0;
  while (currSrcDim < srcShape.size() && currDstDim < dstShape.size()) {
    int64_t dstSize = dstShape[currDstDim];
    int64_t srcSize = srcShape[currSrcDim];
    while (srcSize < dstSize && currSrcDim < srcShape.size()) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      srcSize *= srcShape[currSrcDim];
    }
    if (srcSize == dstSize) {
      reassociationMap[currDstDim].push_back(
          builder.getAffineDimExpr(currSrcDim++));
      // If the next dim in collapsedShape is not 1, treat subsequent dims in
      // expandedShape which are 1 to be collapsed.
      if (currDstDim == dstShape.size() - 1 || dstShape[currDstDim + 1] != 1) {
        while (currSrcDim < srcShape.size() && srcShape[currSrcDim] == 1) {
          reassociationMap[currDstDim].push_back(
              builder.getAffineDimExpr(currSrcDim++));
        }
      }
    }
    currDstDim++;
  }

  // If the source and target shapes are compatible, both iterators must have
  // reached the end. This condition is guaranteed by the op verifier for
  // static shapes.
  assert(currSrcDim == srcShape.size() && currDstDim == dstShape.size());
  return reassociationMap;
}
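
// For example, createReassociationMapForCollapse maps a 'tensor<2x3x4xf32>'
// source onto a 'tensor<6x4xf32>' destination with the reassociation map
// [[0, 1], [2]]. (Illustrative shapes only.)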

// Create a tensor.collapse_shape op that reshapes the input into the given
// result type.
Value createCollapse(OpBuilder &builder, Location loc, TensorType resultType,
                     Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, input.getType(), resultType);
  return builder.createOrFold<tensor::CollapseShapeOp>(loc, resultType, input,
                                                       reassociationMap);
}

// Create a tensor.expand_shape op that reshapes the input into the given
// result type.
Value createExpand(OpBuilder &builder, Location loc, TensorType resultType,
                   Value input) {
  auto reassociationMap =
      createReassociationMapForCollapse(builder, resultType, input.getType());
  return builder.createOrFold<tensor::ExpandShapeOp>(loc, resultType, input,
                                                     reassociationMap);
}

class ReshapeConverter : public OpConversionPattern<tosa::ReshapeOp> {
public:
  using OpConversionPattern<tosa::ReshapeOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = reshape.getLoc();
    auto resultType = cast_if_present<TensorType>(
        getTypeConverter()->convertType(reshape.getType()));
    if (!resultType) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "could not convert result type");
    }
    auto input = dyn_cast<TypedValue<TensorType>>(adaptor.getInput1());
    if (!input) {
      return rewriter.notifyMatchFailure(reshape.getLoc(),
                                         "expected input type to be tensor");
    }

    llvm::SmallVector<int64_t> newShape;
    if (!tosa::getConstShapeValues(reshape.getShape().getDefiningOp(),
                                   newShape)) {
      return failure();
    }

    // Infer all intermediate types.
    auto inputType = inferReshapeInputType(input, newShape);
    auto expandedType = inferReshapeExpandedType(inputType, newShape);
    auto collapsedType = inferReshapeCollapsedType(inputType, expandedType);

    // Cast input if needed.
    auto castInput =
        rewriter.createOrFold<tensor::CastOp>(loc, inputType, input);

    // Emit collapse-expand pair.
    auto collapsed = createCollapse(rewriter, loc, collapsedType, castInput);
    auto expanded = createExpand(rewriter, loc, expandedType, collapsed);

    // Cast to final result type if needed.
    auto result =
        rewriter.createOrFold<tensor::CastOp>(loc, resultType, expanded);
    rewriter.replaceOp(reshape, result);
    return success();
  }
};

class SliceConverter : public OpConversionPattern<tosa::SliceOp> {
public:
  using OpConversionPattern<tosa::SliceOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::SliceOp sliceOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    Location loc = sliceOp.getLoc();
    Value input = adaptor.getInput1();
    ShapedType resultType = cast<ShapedType>(sliceOp.getType());
    if (llvm::isa<UnrankedTensorType>(resultType))
      return failure();

    ElementsAttr startElems;
    ElementsAttr sizeElems;

    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "start of slice must be a static ranked shape");

    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
      return rewriter.notifyMatchFailure(
          sliceOp, "size of slice must be a static ranked shape");

    llvm::SmallVector<int64_t> sliceStarts =
        llvm::to_vector(startElems.getValues<int64_t>());
    llvm::SmallVector<int64_t> sliceSizes =
        llvm::to_vector(sizeElems.getValues<int64_t>());

    SmallVector<int64_t> strides, sizes;
    strides.resize(cast<ShapedType>(sliceOp.getType()).getRank(), 1);

    SmallVector<Value> dynSizes;
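    // A slice size of -1 marks a dynamic extent. For those dimensions the
    // static size becomes ShapedType::kDynamic and the runtime extent is
    // materialized below as dim(input, index) - start[index].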
    for (const auto &i : llvm::enumerate(sliceSizes)) {
      int64_t size = i.value();
      size_t index = i.index();
      sizes.push_back(size == -1 ? ShapedType::kDynamic : size);
      if (ShapedType::isStatic(sizes.back()))
        continue;

      auto dim = tensor::DimOp::create(rewriter, loc, input, index);
      auto offset = arith::ConstantOp::create(
          rewriter, loc, rewriter.getIndexAttr(sliceStarts[index]));
      dynSizes.push_back(arith::SubIOp::create(rewriter, loc, dim, offset));
    }

    auto newSliceOp = tensor::ExtractSliceOp::create(
        rewriter, sliceOp.getLoc(), sliceOp.getType(), input, ValueRange({}),
        dynSizes, ValueRange({}), rewriter.getDenseI64ArrayAttr(sliceStarts),
        rewriter.getDenseI64ArrayAttr(sizes),
        rewriter.getDenseI64ArrayAttr(strides));

    // Remove the producing const_shape ops if the slice op was their only
    // user.
    Operation *startConstShape = sliceOp.getStart().getDefiningOp();
    if (startConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(startConstShape);

    Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
    if (sizeConstShape->getResult(0).hasOneUse())
      rewriter.eraseOp(sizeConstShape);

    rewriter.replaceOp(sliceOp, newSliceOp.getResult());
    return success();
  }
};

class PadConverter : public OpConversionPattern<tosa::PadOp> {
public:
  using OpConversionPattern<tosa::PadOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::PadOp padOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const final {
    auto loc = padOp.getLoc();
    auto input = padOp.getInput1();

    ElementsAttr paddingElems;
    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
      return rewriter.notifyMatchFailure(
          padOp, "padding must be a static shape value");
    }
    llvm::SmallVector<int64_t> paddingVals;
    for (auto idx : paddingElems.getValues<IntegerAttr>()) {
      paddingVals.push_back(static_cast<int64_t>(idx.getInt()));
    }

    ShapedType inputTy = cast<ShapedType>(input.getType());
    int64_t rank = inputTy.getRank();

    // Extract the pad constant value from the pad_const operand.
    Value padConstant = rewriter.createOrFold<tensor::ExtractOp>(
        loc, padOp.getPadConst(),
        ValueRange({arith::ConstantIndexOp::create(rewriter, loc, 0)}));

    if (!padConstant) {
      return rewriter.notifyMatchFailure(
          padOp, "tosa.pad was unable to determine the pad constant value.");
    }

    SmallVector<Value> lowValues;
    SmallVector<Value> highValues;

    lowValues.reserve(rank);
    highValues.reserve(rank);

    for (int i = 0; i < rank; i++) {
      Value lowVal = arith::ConstantOp::create(
          rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i]));
      Value highVal = arith::ConstantOp::create(
          rewriter, loc, rewriter.getIndexAttr(paddingVals[2 * i + 1]));

      lowValues.push_back(lowVal);
      highValues.push_back(highVal);
    }

    auto newPadOp = tensor::PadOp::create(rewriter, loc, padOp.getType(), input,
                                          lowValues, highValues, padConstant);

    rewriter.replaceOp(padOp, newPadOp.getResult());
    return success();
  }
};

struct ConcatConverter : public OpConversionPattern<tosa::ConcatOp> {
  using OpConversionPattern<tosa::ConcatOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto resultType = dyn_cast<RankedTensorType>(op.getType());

    Location loc = op.getLoc();
    int axis = op.getAxis();
    Value axisValue =
        arith::ConstantOp::create(rewriter, loc, rewriter.getIndexAttr(axis));
    int64_t rank = resultType.getRank();

    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, op.getLoc(), adaptor.getOperands()[0]);

    // Pre-compute the offsets along the axis dimension. The axisOffsets
    // vector holds one entry per operand plus one, where the last value is
    // the total size of the result along the 'axis' dimension.
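    // For example, concatenating three operands with extents 2, 3, and 4
    // along 'axis' produces axisOffsets = [0, 2, 5, 9]. (Illustrative values.)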
    SmallVector<OpFoldResult> axisOffsets;
    axisOffsets.push_back(rewriter.getIndexAttr(0));
    axisOffsets.push_back(sizes[axis]);

    for (auto arg : adaptor.getOperands().drop_front()) {
      auto size = rewriter.createOrFold<tensor::DimOp>(loc, arg, axisValue);
      auto currentOffset =
          getValueOrCreateConstantIndexOp(rewriter, loc, axisOffsets.back());
      auto total =
          rewriter.createOrFold<arith::AddIOp>(loc, currentOffset, size);
      axisOffsets.push_back(getAsOpFoldResult(total));
    }
    sizes[axis] = axisOffsets.back();

    // Compute the dynamic sizes of the tensor.empty operation.
    // This is based on the specified result type of the tosa.concat
    // operation, since we don't want to change the result type of the
    // operation during the conversion.
    SmallVector<Value> dynDims;
    for (int64_t i = 0; i < rank; ++i) {
      if (resultType.isDynamicDim(i)) {
        dynDims.push_back(
            getValueOrCreateConstantIndexOp(rewriter, loc, sizes[i]));
      }
    }

    Value result =
        tensor::EmptyOp::create(rewriter, loc, resultType.getShape(),
                                resultType.getElementType(), dynDims);

    for (auto [arg, offset] : llvm::zip(adaptor.getOperands(), axisOffsets)) {
      auto sizes = tensor::getMixedSizes(rewriter, op.getLoc(), arg);
      offsets[axis] = offset;
      result = rewriter.createOrFold<tensor::InsertSliceOp>(
          loc, arg, result, offsets, sizes, strides);
    }
    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaToTensorConversionPatterns(
    const TypeConverter &converter, RewritePatternSet *patterns) {
  patterns
      ->add<ConcatConverter, PadConverter, ReshapeConverter, SliceConverter>(
          converter, patterns->getContext());
}