//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://developer.mlplatform.org/w/tosa/
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Quant/QuantOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"

//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"

namespace {
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};

/// This class implements the bytecode interface for the Tosa dialect.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  void writeVersion(DialectBytecodeWriter &writer) const final {
    // TODO: Populate.
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    // TODO: Populate
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }

//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // Tosa dialect constants only support ElementsAttr unlike standard dialect
  // constant which supports all attributes.
  if (llvm::isa<ElementsAttr>(value))
    return builder.create<tosa::ConstOp>(loc, type,
                                         llvm::cast<ElementsAttr>(value));
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Parsers and printers
//===----------------------------------------------------------------------===//

ParseResult mlir::tosa::parseTypeOrAttr(OpAsmParser &parser, TypeAttr &typeAttr,
                                        Attribute &attr) {
  if (succeeded(parser.parseOptionalEqual())) {
    if (failed(parser.parseAttribute(attr))) {
      return parser.emitError(parser.getCurrentLocation())
             << "expected attribute";
    }
    if (auto typedAttr = attr.dyn_cast<TypedAttr>()) {
      typeAttr = TypeAttr::get(typedAttr.getType());
    }
    return success();
  }

  Type type;
  if (failed(parser.parseColonType(type))) {
    return parser.emitError(parser.getCurrentLocation()) << "expected type";
  }
  typeAttr = TypeAttr::get(type);
  return success();
}

void mlir::tosa::printTypeOrAttr(OpAsmPrinter &p, Operation *op, TypeAttr type,
                                 Attribute attr) {
  bool needsSpace = false;
  auto typedAttr = attr.dyn_cast_or_null<TypedAttr>();
  if (!typedAttr || typedAttr.getType() != type.getValue()) {
    p << ": ";
    p.printAttribute(type);
    needsSpace = true; // subsequent attr value needs a space separator
  }
  if (attr) {
    if (needsSpace)
      p << ' ';
    p << "= ";
    p.printAttribute(attr);
  }
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

static bool hasZeroDimension(ShapedType shapedType) {
  if (!shapedType.hasRank())
    return false;

  auto rank = shapedType.getRank();

  for (int i = 0; i < rank; i++) {
    if (shapedType.isDynamicDim(i))
      continue;
    if (shapedType.getDimSize(i) == 0)
      return true;
  }

  return false;
}

template <typename T>
static LogicalResult verifyConvOp(T op) {
  // All TOSA conv ops have an input() and weight().
  auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  // Must be ranked tensor types
  if (!inputType) {
    op.emitOpError("expect a ranked tensor for input, got ") << op.getInput();
    return failure();
  }
  if (!weightType) {
    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
    return failure();
  }

  if (hasZeroDimension(inputType))
    return op.emitOpError() << "tensor has a dimension with size zero. Each "
                               "dimension of a tensor must have size >= 1";
  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();

  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);

  // Either both must be quantized or both unquantized.
  if (inputIsQuant != weightIsQuant) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  // Quantized type must have constructed the quantizationattr, and unquantized
  // types should not have a quantizationattr.
  if ((inputIsQuant && !op.getQuantizationInfo()) ||
      (!inputIsQuant && op.getQuantizationInfo())) {
    op.emitOpError("quantizationattr is required for quantized type, and not "
                   "allowed for float type");
    return failure();
  }

  return success();
}

LogicalResult tosa::ArgMaxOp::verify() {
  // Ensure output is of 32-bit integer
  const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
  if (!resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  // Ensure axis is within the tensor rank
  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  const int64_t axis = getAxisAttr().getInt();
  if (inputType.hasRank() && ((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  return success();
}

LogicalResult tosa::AvgPool2dOp::verify() {
  auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (hasZeroDimension(inputType))
    return emitOpError() << "tensor has a dimension with size zero. Each "
                            "dimension of a tensor must have size >= 1";

  auto inputETy = inputType.getElementType();
  auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
    inputETy = quantType.getStorageType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
    resultETy = quantType.getStorageType();

  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if ((inputETy.isF32() && resultETy.isF32()) ||
      (inputETy.isF16() && resultETy.isF16()) ||
      (inputETy.isBF16() && resultETy.isBF16()) ||
      (inputETy.isInteger(8) && resultETy.isInteger(8)) ||
      (inputETy.isInteger(16) && resultETy.isInteger(16)))
    return success();

  return emitOpError("input/output element types are incompatible.");
}

LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }
  mlir::Type maxFpType = getMaxFpAttr().getType();
  mlir::Type minFpType = getMinFpAttr().getType();
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }
  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  // If the input datatype is float, check that the two min/max_fp attributes
  // share the same type and that their type is either the same as the input's
  // datatype, or a float type whose bitwidth > input datatype bitwidth.
  if (!inputETy.isInteger(dataTypeBitWidth)) {
    if (((maxFpType != minFpType) ||
         (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <=
                                       inputETy.getIntOrFloatBitWidth())))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");
  }

  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bitwidth of the output given the bit width of the input & weight content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, DenseI64ArrayAttr pad,
                                     DenseI64ArrayAttr stride,
                                     DenseI64ArrayAttr dilation) {
  result.addOperands({input, weight, bias});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);

  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// Handles tosa.transpose_conv2d which has outpad and output shape attributes.
static void buildTransConvOpWithQuantInfo(
    OpBuilder &builder, OperationState &result, Type outputType, Value input,
    Value weight, Value bias, DenseI64ArrayAttr outpad,
    DenseI64ArrayAttr stride, DenseI64ArrayAttr outputShape) {
  result.addOperands({input, weight, bias});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("out_shape", outputShape);
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.fully_connected op has its own builder as it does not have
/// strides/dilation/padding.
static void buildFCOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                   Type outputType, Value input, Value weight,
                                   Value bias) {
  result.addOperands({input, weight, bias});
  auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);
    result.addTypes(
        buildConvOpResultTypeInfo(builder, outputType, input, weight));
  } else {
    result.addTypes(outputType);
  }
}

/// The tosa.matmul op is also intended to be generated where a fully_connected
/// op must be constructed where the weight is not a constant. In this case,
/// the fully_connected op must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
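///
/// For example (values illustrative, derived from the storage-width check in
/// the builder below): a quantized i8 x i8 tosa.matmul is given an i32
/// accumulator/result element type, while an i16 x i16 matmul is widened to an
/// i48 accumulator.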
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  result.addOperands({a, b});
  auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);

  if (quantAttr) {
    result.addAttribute("quantization_info", quantAttr);

    auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
    assert(inputType && "Input must be a shaped tensor type!");

    auto inputQType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(
        inputType.getElementType());
    assert(inputQType && "Tensor must have quantized datatype!");

    unsigned inputBits = inputQType.getStorageTypeIntegralWidth();

    auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();
    auto accType = outputShapedType.clone(accElementType);
    result.addTypes(accType);
  } else {
    result.addTypes(outputType);
  }
}

/// Both the tosa.avg_pool2d and unary ops use the same UnaryOpQuantizationAttr
/// but avg_pool operator has its own builder as it has additional parameters
/// not part of the unary ops.
static void buildAvgPool2dOpWithQuantInfo(OpBuilder &builder,
                                          OperationState &result,
                                          Type outputType, Value input,
                                          DenseArrayAttr kernel,
                                          DenseArrayAttr stride,
                                          DenseArrayAttr pad,
                                          TypeAttr accType) {
  result.addOperands(input);
  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  result.addAttribute("acc_type", accType);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on single-parameter unary operators that have scale
/// relationship between their input and output, expressed by the
/// UnaryOpQuantizationAttr.
static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                      OperationState &result, Type outputType,
                                      Value input) {
  result.addOperands(input);
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on TOSA pad operator that needs to create its own
/// OptionalAttr quantization_attr parameter to scale the padding values
/// correctly. No pad_const is interpreted as zero-padding.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  result.addOperands({input, paddings});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

/// This builder is called on TOSA pad operator when an explicit pad_const
/// value is passed in. It also optionally constructs quantization_attr.
static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
                                                 OperationState &result,
                                                 Type outputType, Value input,
                                                 Value paddings,
                                                 Value padConst) {
  result.addOperands({input, paddings, padConst});
  auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr)
    result.addAttribute("quantization_info", quantAttr);
  result.types.push_back(outputType);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===// static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector &outShape) { int64_t outRank = 0; for (int i = 0, e = operands.size(); i != e; ++i) { auto shape = operands.getShape(i); if (!shape.hasRank()) { // TODO(jennik): Update function to have better case handling for invalid // operands and for ranked tensors. return failure(); } outRank = std::max(outRank, shape.getRank()); } outShape.resize(outRank, 1); for (int i = 0, e = operands.size(); i != e; ++i) { auto shape = operands.getShape(i); auto rankDiff = outShape.size() - shape.getRank(); for (size_t i = 0, e = shape.getRank(); i < e; ++i) { auto dim1 = outShape[i + rankDiff]; auto dim2 = shape.getDimSize(i); auto resolvedDim = dim1; if (dim1 == 1) { resolvedDim = dim2; } else if (dim2 == 1) { resolvedDim = dim1; } else if (dim1 != dim2) { return failure(); } outShape[i + rankDiff] = resolvedDim; } } return success(); } LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ArgMaxOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput().getType()); IntegerAttr axis = adaptor.getProperties().axis; int32_t axisVal = axis.getValue().getSExtValue(); if (!inputShape.hasRank()) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } SmallVector outShape; outShape.reserve(inputShape.getRank() - 1); for (int i = 0, s = inputShape.getRank(); i < s; i++) { if (i == axisVal) continue; outShape.push_back(inputShape.getDimSize(i)); } inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); return success(); } LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, RFFT2dOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput().getType()); if (!inputShape.hasRank()) return failure(); llvm::SmallVector outputShape; outputShape.resize(3, ShapedType::kDynamic); outputShape[0] = inputShape.getDimSize(0); outputShape[1] = inputShape.getDimSize(1); int64_t inWidth = inputShape.getDimSize(2); // Note that we can support this calculation symbolically // in the future e.g. [x, y, z] -> [x, y, z / 2 - 1] if (inWidth != ShapedType::kDynamic) outputShape[2] = inWidth / 2 + 1; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::FFT2dOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, FFT2dOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { inferredReturnShapes.push_back( ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType()))); inferredReturnShapes.push_back( ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType()))); return success(); } LogicalResult tosa::ConcatOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ConcatOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { // Infer all dimension sizes by reducing based on inputs. const Properties &prop = adaptor.getProperties(); int32_t axis = prop.axis.getValue().getSExtValue(); llvm::SmallVector outputShape; bool hasRankedInput = false; for (auto operand : adaptor.getOperands()) { ShapeAdaptor operandShape(operand.getType()); if (!operandShape.hasRank()) continue; // Copy the Operand's rank. 
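    // For example (sizes illustrative): concatenating tensor<2x3xf32> and
    // tensor<2x5xf32> along axis 1 infers tensor<2x8xf32>; every non-axis
    // dimension must agree across the ranked operands.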
if (!hasRankedInput) outputShape.resize(operandShape.getRank(), ShapedType::kDynamic); // Copy shapes until the dim is non-dynamic. for (int i = 0, s = operandShape.getRank(); i < s; i++) { if (i == axis || operandShape.isDynamicDim(i)) continue; if (outputShape[i] == ShapedType::kDynamic) outputShape[i] = operandShape.getDimSize(i); if (outputShape[i] != operandShape.getDimSize(i)) return emitOptionalError(location, "Cannot concat tensors with different sizes" " on the non-axis dimension ", i); } hasRankedInput = true; } Type inputType = llvm::cast(adaptor.getInput1().getType()[0]).getElementType(); if (!hasRankedInput) { inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); return success(); } // Determine the dimension size along the concatenation axis. int64_t concatDimSize = 0; for (auto operand : adaptor.getOperands()) { ShapeAdaptor operandShape(operand.getType()); // We need to know the length of the concatenation axis of all inputs to // determine the dimension size of the output shape. if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) { concatDimSize = ShapedType::kDynamic; break; } concatDimSize += operandShape.getDimSize(axis); } outputShape[axis] = concatDimSize; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } LogicalResult tosa::EqualOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { auto elementType = IntegerType::get(context, /*width=*/1); llvm::SmallVector outShape; if (resolveBroadcastShape(operands, outShape).failed()) { inferredReturnShapes.push_back(ShapedTypeComponents(elementType)); return success(); } inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType)); return success(); } bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { if (l.size() != r.size() || l.size() != 1) return false; return succeeded(verifyCompatibleShape(l[0], r[0])); } LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, FullyConnectedOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput().getType()); ShapeAdaptor weightShape(adaptor.getWeight().getType()); ShapeAdaptor biasShape(adaptor.getBias().getType()); // All shapes are dynamic. SmallVector outShape; outShape.resize(2, ShapedType::kDynamic); if (inputShape.hasRank()) { outShape[0] = inputShape.getDimSize(0); } if (weightShape.hasRank()) { outShape[1] = weightShape.getDimSize(0); } if (biasShape.hasRank()) { outShape[1] = outShape[1] == ShapedType::kDynamic ? biasShape.getDimSize(0) : outShape[1]; } inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); return success(); } LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); } LogicalResult tosa::MatMulOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, MatMulOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor lhsShape(adaptor.getA().getType()); ShapeAdaptor rhsShape(adaptor.getB().getType()); // All shapes are dynamic. SmallVector outShape; outShape.resize(3, ShapedType::kDynamic); if (lhsShape.hasRank()) { outShape[0] = lhsShape.getDimSize(0); outShape[1] = lhsShape.getDimSize(1); } if (rhsShape.hasRank()) { outShape[0] = outShape[0] == ShapedType::kDynamic ? 
rhsShape.getDimSize(0) : outShape[0];
    outShape[2] = rhsShape.getDimSize(2);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::PadOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    PadOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  ShapeAdaptor paddingShape(adaptor.getPadding().getType());
  SmallVector<int64_t> outputShape;

  // If both inputs have unknown shape, we cannot determine the shape of the
  // output.
  if (!inputShape.hasRank() && !paddingShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  // If the input rank is unknown, we can infer the output rank using the
  // padding shape's first dim.
  if (!inputShape.hasRank()) {
    if (paddingShape.isDynamicDim(0)) {
      inferredReturnShapes.push_back(ShapedTypeComponents());
      return success();
    }

    outputShape.resize(paddingShape.getDimSize(0), ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  DenseIntElementsAttr paddings;
  // If the paddings value is not a constant, all dimensions must be dynamic.
  if (!matchPattern(adaptor.getPadding(), m_Constant(&paddings))) {
    outputShape.resize(inputShape.getRank(), ShapedType::kDynamic);
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  SmallVector<int64_t> paddingValues;
  for (auto val : paddings) {
    paddingValues.push_back(val.getSExtValue());
  }

  outputShape.reserve(inputShape.getRank());
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (inputShape.isDynamicDim(i)) {
      outputShape.push_back(ShapedType::kDynamic);
      continue;
    }

    outputShape.push_back(inputShape.getDimSize(i) + paddingValues[i * 2] +
                          paddingValues[i * 2 + 1]);
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ?
ShapedType::kDynamic : dim; })); } LogicalResult tosa::SliceOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, SliceOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { inferredReturnShapes.push_back( ShapedTypeComponents(convertToMlirShape(adaptor.getSize()))); return success(); } LogicalResult tosa::SliceOp::verify() { auto inputType = llvm::dyn_cast(getInput().getType()); if (!inputType) return success(); if (static_cast(inputType.getRank()) != getStart().size()) return emitOpError( "length of start attribute is not equal rank of input shape"); if (static_cast(inputType.getRank()) != getSize().size()) return emitOpError( "length of size attribute is not equal rank of input shape"); return success(); } LogicalResult tosa::TableOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TableOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput().getType()); if (!inputShape.hasRank()) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } inferredReturnShapes.resize(1); inputShape.getDims(inferredReturnShapes[0]); return success(); } LogicalResult tosa::TileOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TileOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ArrayRef multiples = adaptor.getMultiples(); ShapeAdaptor inputShape(adaptor.getInput1().getType()); SmallVector outputShape; if (!inputShape.hasRank()) { outputShape.resize(multiples.size(), ShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } else if (static_cast(inputShape.getRank()) != multiples.size()) return failure(); // Any non dynamic dimension can be multiplied to a known size. outputShape.reserve(multiples.size()); for (int i = 0, s = inputShape.getRank(); i < s; i++) { int64_t dim = inputShape.getDimSize(i); if (dim != ShapedType::kDynamic) dim *= multiples[i]; outputShape.push_back(dim); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::TileOp::verify() { ShapedType inputType = llvm::cast(getInput1().getType()); ShapedType outputType = llvm::cast(getType()); auto multiples = getMultiples(); if (inputType.hasRank()) { if (static_cast(inputType.getRank()) != multiples.size()) return emitOpError("expect 'multiples' array to have length ") << inputType.getRank() << " but got " << multiples.size() << "."; if (outputType.hasRank() && inputType.getRank() != outputType.getRank()) return emitOpError("expect same input and output tensor rank."); } else if (outputType.hasRank() && static_cast(outputType.getRank()) != multiples.size()) return emitOpError("expect 'multiples' array to have length ") << outputType.getRank() << " but got " << multiples.size() << "."; return success(); } bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { if (l.size() != r.size() || l.size() != 1) return false; return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]); } LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ReshapeOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput1().getType()); Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); llvm::SmallVector newShapeValue = convertToMlirShape(adaptor.getNewShape()); // We cannot infer from the total number of elements so we must take the // shape attribute as exact. 
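  // Worked example (sizes illustrative): a static tensor<2x3x4xf32> holds 24
  // elements, so new_shape = [4, -1] resolves the dynamic dimension to
  // 24 / 4 = 6 and the inferred result is tensor<4x6xf32>.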
if (!inputShape.hasRank() || !inputShape.hasStaticShape()) { inferredReturnShapes.push_back( ShapedTypeComponents(newShapeValue, inputType)); return success(); } // Determine the number of elements covered by the slice of all static // dimensions. This allows us to infer the length of the remaining dynamic // dimension. int64_t numElements = inputShape.getNumElements(); int64_t staticMul = 1; for (auto val : newShapeValue) { if (!ShapedType::isDynamic(val)) { staticMul *= val; } } // Determine the length of the dynamic dimension. for (auto &val : newShapeValue) { if (ShapedType::isDynamic(val)) val = numElements / staticMul; } inferredReturnShapes.push_back( ShapedTypeComponents(newShapeValue, inputType)); return success(); } mlir::LogicalResult tosa::ReshapeOp::verify() { ShapedType inputType = llvm::cast(getInput1().getType()); ShapedType outputType = llvm::cast(getType()); if (hasZeroDimension(inputType) || hasZeroDimension(outputType)) return emitOpError() << "tensor has a dimension with size zero. Each " "dimension of a tensor must have size >= 1"; if (inputType.hasStaticShape() && outputType.hasStaticShape()) { int64_t inputElementsNum = inputType.getNumElements(); int64_t outputElementsNum = outputType.getNumElements(); if (inputElementsNum != outputElementsNum) { return emitOpError() << "Cannot reshape " << inputElementsNum << " elements into " << outputElementsNum; } } return mlir::success(); } LogicalResult tosa::TransposeOp::getConstantPerms(SmallVector &perms) { // Perms must be constants. DenseIntElementsAttr permsAttr; if (!matchPattern(getPerms(), m_Constant(&permsAttr))) return failure(); // Transpose is not the identity transpose. perms = llvm::to_vector( llvm::map_range(permsAttr.getValues(), [](const APInt &val) { return val.getSExtValue(); })); return success(); } LogicalResult tosa::TransposeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TransposeOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput1().getType()); ShapeAdaptor permsShape(adaptor.getPerms().getType()); // We cannot infer anything from a rank-0 "permutation" tensor. if (permsShape.hasRank() && permsShape.getRank() == 0) return failure(); // If input rank and permutation length is unknown, the output rank is // unknown. if (!inputShape.hasRank() || !permsShape.hasRank() || permsShape.isDynamicDim(0)) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } // This would imply the number of permutations does not match the rank of the // input which is illegal. if (permsShape.getDimSize(0) != inputShape.getRank()) { return failure(); } SmallVector outputShape; // Rank-0 means no permutations matter. if (inputShape.getRank() == 0) { inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } // Check whether the input dimensions are all the same. bool allTheSame = true; for (int i = 1, s = inputShape.getRank(); i < s; i++) { if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) { allTheSame = false; break; } } // If all of the input dimensions are the same we don't care about the // permutation. if (allTheSame) { outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0)); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); // If the permuations are a constant we can directly determine the output // shape. 
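  // Worked example (sizes illustrative): for an input of shape [2, 3, 4] and a
  // constant perms value of [2, 0, 1], each output dimension i is
  // inputShape[perms[i]], giving an output shape of [4, 2, 3].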
DenseIntElementsAttr attr; if (matchPattern(adaptor.getPerms(), m_Constant(&attr)) && attr.getType().getRank() == 1) { ShapeAdaptor permShape = attr; // Constant permutation must be the same length as the input rank. if (inputShape.getRank() != permShape.getRank()) return emitOptionalError(location, "constant permutation must be the same length" " as the input rank"); // Constant permutation values must be within the input rank. for (int i = 0, e = inputShape.getRank(); i < e; i++) { if (inputShape.getRank() <= permShape.getDimSize(i)) return failure(); } outputShape.reserve(inputShape.getRank()); for (int i = 0, s = inputShape.getRank(); i < s; i++) { outputShape[i] = inputShape.getDimSize(permShape.getDimSize(i)); } } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::TransposeOp::verify() { TensorType inputType = getInput1().getType(); TensorType permType = getPerms().getType(); TensorType outputType = getOutput().getType(); if (permType.hasRank() && permType.getRank() != 1) return emitOpError() << "expected permutation tensor to be rank 1 but got rank " << permType.getRank(); if (inputType.hasRank() && permType.hasRank()) if (!permType.isDynamicDim(0) && permType.getDimSize(0) != inputType.getRank()) return emitOpError() << "expected permutation tensor dim 0 to have size " << inputType.getRank() << " (input rank) but got size " << permType.getDimSize(0); if (inputType.hasRank() && outputType.hasRank() && inputType.getRank() != outputType.getRank()) return emitOpError() << "expected input tensor rank to equal result tensor rank"; if (outputType.hasRank() && permType.hasRank()) if (!permType.isDynamicDim(0) && permType.getDimSize(0) != outputType.getRank()) return emitOpError() << "expected permutation tensor dim 0 to have size " << outputType.getRank() << " (output rank) but got size " << permType.getDimSize(0); SmallVector constantPerms; if (succeeded(getConstantPerms(constantPerms))) { // Assert that the permutation tensor has a rank, which means that the rank // has been verified above. 
assert(permType.hasRank() && "Unexpectedly found permutation tensor without rank"); if (!isPermutationVector(constantPerms)) return emitOpError() << "expected valid permutation tensor"; } return success(); } LogicalResult tosa::GatherOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, GatherOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(3, ShapedType::kDynamic); ShapeAdaptor valuesShape(adaptor.getValues().getType()); if (valuesShape.hasRank()) { outputShape[0] = valuesShape.getDimSize(0); outputShape[2] = valuesShape.getDimSize(2); } ShapeAdaptor indicesShape(adaptor.getIndices().getType()); if (indicesShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamic) outputShape[0] = indicesShape.getDimSize(0); if (outputShape[1] == ShapedType::kDynamic) outputShape[1] = indicesShape.getDimSize(1); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::ResizeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ResizeOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(4, ShapedType::kDynamic); ShapeAdaptor inputShape(adaptor.getInput().getType()); if (!inputShape.hasRank()) return failure(); outputShape[0] = inputShape.getDimSize(0); outputShape[3] = inputShape.getDimSize(3); int64_t inputHeight = inputShape.getDimSize(1); int64_t inputWidth = inputShape.getDimSize(2); if ((inputHeight == ShapedType::kDynamic) || (inputWidth == ShapedType::kDynamic)) return failure(); llvm::ArrayRef scaleInt = adaptor.getScale(); llvm::ArrayRef offsetInt = adaptor.getOffset(); llvm::ArrayRef borderInt = adaptor.getBorder(); // Compute the output shape based on attributes: scale, offset, and border. 
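  // Per spatial axis the computation below is
  //   out = ((in - 1) * scale_n - offset + border) / scale_d + 1
  // e.g. (values illustrative) inputHeight = 4, scale_n = 2, scale_d = 1,
  // offset = 0, border = 0 yields (3 * 2 - 0 + 0) / 1 + 1 = 7.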
outputShape[1] = (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) / scaleInt[1]) + 1; outputShape[2] = (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) / scaleInt[3]) + 1; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::ScatterOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ScatterOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(3, ShapedType::kDynamic); ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType()); if (valuesInShape.hasRank()) { outputShape[0] = valuesInShape.getDimSize(0); outputShape[1] = valuesInShape.getDimSize(1); outputShape[2] = valuesInShape.getDimSize(2); } ShapeAdaptor indicesShape(adaptor.getIndices().getType()); if (indicesShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamic) outputShape[0] = indicesShape.getDimSize(0); } ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamic) outputShape[0] = inputShape.getDimSize(0); if (outputShape[2] == ShapedType::kDynamic) outputShape[2] = inputShape.getDimSize(2); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } static LogicalResult ReduceInferReturnTypes( ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl &inferredReturnShapes) { int64_t axisVal = axis.getValue().getSExtValue(); if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) { inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); return success(); } SmallVector outputShape; operandShape.getDims(outputShape); outputShape[axisVal] = 1; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } #define COMPATIBLE_RETURN_TYPES(OP) \ bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \ if (l.size() != r.size() || l.size() != 1) \ return false; \ if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \ return false; \ return succeeded(verifyCompatibleShape(l[0], r[0])); \ } #define REDUCE_SHAPE_INFER(OP) \ LogicalResult OP::inferReturnTypeComponents( \ MLIRContext *context, ::std::optional location, \ OP::Adaptor adaptor, \ SmallVectorImpl &inferredReturnShapes) { \ Type inputType = \ llvm::cast(adaptor.getInput().getType()).getElementType(); \ ShapeAdaptor inputShape(adaptor.getInput().getType()); \ const Properties &prop = adaptor.getProperties(); \ return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \ inferredReturnShapes); \ } \ COMPATIBLE_RETURN_TYPES(OP) REDUCE_SHAPE_INFER(tosa::ReduceAllOp) REDUCE_SHAPE_INFER(tosa::ReduceAnyOp) REDUCE_SHAPE_INFER(tosa::ReduceMaxOp) REDUCE_SHAPE_INFER(tosa::ReduceMinOp) REDUCE_SHAPE_INFER(tosa::ReduceProdOp) REDUCE_SHAPE_INFER(tosa::ReduceSumOp) #undef REDUCE_SHAPE_INFER COMPATIBLE_RETURN_TYPES(tosa::ConcatOp) #undef COMPATIBLE_RETURN_TYPES template static LogicalResult verifyReduceOp(T op) { // All TOSA reduce Ops have input, output and axis. TensorType inputType = op.getInput().getType(); TensorType outputType = op.getOutput().getType(); int32_t reduceAxis = op.getAxis(); if (reduceAxis < 0) { op.emitOpError("reduce axis must not be negative"); return failure(); } if (inputType.hasRank()) { int64_t inputRank = inputType.getRank(); // We allow for a special case where the input/output shape has rank 0 and // axis is also 0. 
    if (reduceAxis >= inputRank && !(reduceAxis == 0 && inputRank == 0)) {
      op.emitOpError("expect input tensor rank (")
          << inputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
  }
  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank()) {
      op.emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
      return failure();
    }
    if (reduceAxis >= outputRank && !(reduceAxis == 0 && outputRank == 0)) {
      op.emitOpError("expect output tensor rank (")
          << outputRank << ") to be larger than reduce axis (" << reduceAxis
          << ")";
      return failure();
    }
    // We can only verify the reduced dimension size to be 1 if this is not the
    // special case of output rank == 0.
    if (outputRank != 0) {
      auto outputShape = outputType.getShape();
      if (!outputType.isDynamicDim(reduceAxis) &&
          outputShape[reduceAxis] != 1) {
        op.emitOpError("expect reduced dimension size to be 1, got ")
            << outputShape[reduceAxis];
        return failure();
      }
    }
  }
  return success();
}

LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceProdOp::verify() { return verifyReduceOp(*this); }
LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); }

static LogicalResult NAryInferReturnTypes(
    const ValueShapeRange &operands,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
  } else {
    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  }
  return success();
}

#define NARY_SHAPE_INFER(OP)                                                   \
  LogicalResult OP::inferReturnTypeComponents(                                 \
      MLIRContext *context, ::std::optional<Location> location,               \
      ValueShapeRange operands, DictionaryAttr attributes,                     \
      OpaqueProperties properties, RegionRange regions,                        \
      SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {           \
    return NAryInferReturnTypes(operands, inferredReturnShapes);               \
  }

NARY_SHAPE_INFER(tosa::AbsOp)
NARY_SHAPE_INFER(tosa::AddOp)
NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp)
NARY_SHAPE_INFER(tosa::BitwiseAndOp)
NARY_SHAPE_INFER(tosa::BitwiseOrOp)
NARY_SHAPE_INFER(tosa::BitwiseXorOp)
NARY_SHAPE_INFER(tosa::BitwiseNotOp)
NARY_SHAPE_INFER(tosa::CastOp)
NARY_SHAPE_INFER(tosa::CeilOp)
NARY_SHAPE_INFER(tosa::ClampOp)
NARY_SHAPE_INFER(tosa::ClzOp)
NARY_SHAPE_INFER(tosa::DivOp)
NARY_SHAPE_INFER(tosa::ExpOp)
NARY_SHAPE_INFER(tosa::FloorOp)
NARY_SHAPE_INFER(tosa::GreaterEqualOp)
NARY_SHAPE_INFER(tosa::GreaterOp)
NARY_SHAPE_INFER(tosa::IdentityOp)
NARY_SHAPE_INFER(tosa::LogOp)
NARY_SHAPE_INFER(tosa::LogicalAndOp)
NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp)
NARY_SHAPE_INFER(tosa::LogicalNotOp)
NARY_SHAPE_INFER(tosa::LogicalOrOp)
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::MulOp)
NARY_SHAPE_INFER(tosa::NegateOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::RescaleOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
NARY_SHAPE_INFER(tosa::ErfOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef NARY_SHAPE_INFER

static LogicalResult
poolingInferReturnTypes( ShapeAdaptor inputShape, ArrayRef kernel, ArrayRef stride, ArrayRef pad, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(4, ShapedType::kDynamic); // We only know the rank if the input type is unranked. if (!inputShape) { inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } // Batch and number of channels are identical for pooling layer. outputShape[0] = inputShape.getDimSize(0); outputShape[3] = inputShape.getDimSize(3); int64_t height = inputShape.getDimSize(1); int64_t width = inputShape.getDimSize(2); if (!ShapedType::isDynamic(height)) { int64_t padded = height + pad[0] + pad[1] - kernel[0]; outputShape[1] = padded / stride[0] + 1; } if (!ShapedType::isDynamic(width)) { int64_t padded = width + pad[2] + pad[3] - kernel[1]; outputShape[2] = padded / stride[1] + 1; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult Conv2DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, Conv2DOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamic); int64_t inputWidth = ShapedType::kDynamic; int64_t inputHeight = ShapedType::kDynamic; int64_t weightWidth = ShapedType::kDynamic; int64_t weightHeight = ShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { outputShape[0] = inputShape.getDimSize(0); inputHeight = inputShape.getDimSize(1); inputWidth = inputShape.getDimSize(2); } // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape(adaptor.getWeight().getType()); if (weightShape.hasRank()) { outputShape[3] = weightShape.getDimSize(0); weightHeight = weightShape.getDimSize(1); weightWidth = weightShape.getDimSize(2); } // Bias shape can describe the output channels. ShapeAdaptor biasShape(adaptor.getBias().getType()); if (biasShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) ? 
biasShape.getDimSize(0) : outputShape[3]; } llvm::ArrayRef dilation = adaptor.getDilation(); llvm::ArrayRef stride = adaptor.getStride(); llvm::ArrayRef padding = adaptor.getPad(); if (!ShapedType::isDynamic(inputHeight) && !ShapedType::isDynamic(weightHeight)) { int64_t inputSize = inputHeight + padding[0] + padding[1]; int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } if (!ShapedType::isDynamic(inputWidth) && !ShapedType::isDynamic(weightWidth)) { int64_t inputSize = inputWidth + padding[2] + padding[3]; int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult Conv2DOp::verify() { return verifyConvOp(*this); } LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, Conv3DOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(5, ShapedType::kDynamic); int64_t inputWidth = ShapedType::kDynamic; int64_t inputHeight = ShapedType::kDynamic; int64_t inputDepth = ShapedType::kDynamic; int64_t weightWidth = ShapedType::kDynamic; int64_t weightHeight = ShapedType::kDynamic; int64_t weightDepth = ShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { outputShape[0] = inputShape.getDimSize(0); inputDepth = inputShape.getDimSize(1); inputHeight = inputShape.getDimSize(2); inputWidth = inputShape.getDimSize(3); } // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape(adaptor.getWeight().getType()); if (weightShape.hasRank()) { outputShape[4] = weightShape.getDimSize(0); weightDepth = weightShape.getDimSize(1); weightHeight = weightShape.getDimSize(2); weightWidth = weightShape.getDimSize(3); } // Bias shape can describe the output channels. 
ShapeAdaptor biasShape(adaptor.getBias().getType()); if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) { outputShape[4] = biasShape.getDimSize(0); } llvm::ArrayRef dilation = adaptor.getDilation(); llvm::ArrayRef stride = adaptor.getStride(); llvm::ArrayRef pad = adaptor.getPad(); if (!ShapedType::isDynamic(inputDepth) && !ShapedType::isDynamic(weightDepth)) { int32_t inputSize = inputDepth + pad[0] + pad[1]; int32_t filterSize = (weightDepth - 1) * dilation[0] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } if (!ShapedType::isDynamic(inputHeight) && !ShapedType::isDynamic(weightHeight)) { int32_t inputSize = inputHeight + pad[2] + pad[3]; int32_t filterSize = (weightHeight - 1) * dilation[1] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1; } if (!ShapedType::isDynamic(inputWidth) && !ShapedType::isDynamic(weightWidth)) { int32_t inputSize = inputWidth + pad[4] + pad[5]; int32_t filterSize = (weightWidth - 1) * dilation[2] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[3] = (unstridedResult - 1) / stride[2] + 1; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult Conv3DOp::verify() { return verifyConvOp(*this); } LogicalResult AvgPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, AvgPool2dOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput().getType()); const Properties &prop = adaptor.getProperties(); return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad, inferredReturnShapes); } LogicalResult MaxPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, MaxPool2dOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput().getType()); const Properties &prop = adaptor.getProperties(); return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad, inferredReturnShapes); } LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, DepthwiseConv2DOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamic); int64_t inputWidth = ShapedType::kDynamic; int64_t inputHeight = ShapedType::kDynamic; int64_t inputChannels = ShapedType::kDynamic; int64_t weightWidth = ShapedType::kDynamic; int64_t weightHeight = ShapedType::kDynamic; int64_t depthChannels = ShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { outputShape[0] = inputShape.getDimSize(0); inputHeight = inputShape.getDimSize(1); inputWidth = inputShape.getDimSize(2); inputChannels = inputShape.getDimSize(3); } // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape(adaptor.getWeight().getType()); if (weightShape.hasRank()) { weightHeight = weightShape.getDimSize(0); weightWidth = weightShape.getDimSize(1); inputChannels = ShapedType::isDynamic(inputChannels) ? weightShape.getDimSize(2) : inputChannels; depthChannels = weightShape.getDimSize(3); } // If both inputChannels and depthChannels are available we can determine // the output channels. 
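  // For example (sizes illustrative): a depthwise weight of shape
  // [kH, kW, C = 8, M = 2] over an 8-channel input yields 8 * 2 = 16 output
  // channels.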
if (!ShapedType::isDynamic(inputChannels) && !ShapedType::isDynamic(depthChannels)) { outputShape[3] = inputChannels * depthChannels; } // Bias shape can describe the output channels. ShapeAdaptor biasShape(adaptor.getBias().getType()); if (biasShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) ? biasShape.getDimSize(0) : outputShape[3]; } llvm::ArrayRef dilation = adaptor.getDilation(); llvm::ArrayRef padding = adaptor.getPad(); llvm::ArrayRef stride = adaptor.getStride(); if (!ShapedType::isDynamic(inputHeight) && !ShapedType::isDynamic(weightHeight)) { int64_t inputSize = inputHeight + padding[0] + padding[1]; int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } if (!ShapedType::isDynamic(inputWidth) && !ShapedType::isDynamic(weightWidth)) { int64_t inputSize = inputWidth + padding[2] + padding[3]; int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult DepthwiseConv2DOp::verify() { return verifyConvOp(*this); } LogicalResult TransposeConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TransposeConv2DOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { // outputShape is mutable. llvm::SmallVector outputShape = convertToMlirShape(adaptor.getOutShape()); int64_t inputWidth = ShapedType::kDynamic; int64_t inputHeight = ShapedType::kDynamic; int64_t weightWidth = ShapedType::kDynamic; int64_t weightHeight = ShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { outputShape[0] = ShapedType::isDynamic(outputShape[0]) ? inputShape.getDimSize(0) : outputShape[0]; inputHeight = inputShape.getDimSize(1); inputWidth = inputShape.getDimSize(2); } // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape(adaptor.getFilter().getType()); if (weightShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) ? weightShape.getDimSize(0) : outputShape[3]; weightHeight = weightShape.getDimSize(1); weightWidth = weightShape.getDimSize(2); } // Bias shape can describe the output channels. ShapeAdaptor biasShape(adaptor.getInput().getType()); if (biasShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) ? biasShape.getDimSize(0) : outputShape[3]; } llvm::ArrayRef padding = adaptor.getOutPad(); llvm::ArrayRef stride = adaptor.getStride(); if (!ShapedType::isDynamic(inputHeight) && !ShapedType::isDynamic(weightHeight)) { int64_t calculateSize = (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight; outputShape[1] = ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1]; } if (!ShapedType::isDynamic(inputWidth) && !ShapedType::isDynamic(weightWidth)) { int64_t calculateSize = (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth; outputShape[2] = ShapedType::isDynamic(outputShape[2]) ? 
calculateSize : outputShape[2]; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult IfOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, IfOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector yieldOps; for (Region *region : adaptor.getRegions()) { for (auto &block : *region) if (auto returnOp = dyn_cast(block.getTerminator())) yieldOps.push_back(returnOp); } if (yieldOps.empty()) return failure(); // Get the initial type information for the yield op. llvm::SmallVector resultKnowledge; resultKnowledge.reserve(yieldOps.front().getNumOperands()); for (auto operand : yieldOps.front().getOperands()) { resultKnowledge.push_back( ValueKnowledge::getKnowledgeFromType(operand.getType())); } for (auto yieldOp : yieldOps) { if (resultKnowledge.size() != yieldOp.getNumOperands()) return failure(); for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { int32_t index = it.index(); auto meet = ValueKnowledge::meet( resultKnowledge[index], ValueKnowledge::getKnowledgeFromType(it.value().getType())); if (!meet) continue; resultKnowledge[index] = meet; } } for (const ValueKnowledge &result : resultKnowledge) { inferredReturnShapes.push_back(result.getShapedTypeComponents()); } return success(); } LogicalResult WhileOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, WhileOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector yieldOps; for (auto &block : adaptor.getBody()) if (auto returnOp = dyn_cast(block.getTerminator())) yieldOps.push_back(returnOp); // TOSA's while must have a tosa.yield as its terminator. If not found this // tosa.while is invalid. if (yieldOps.empty()) return failure(); // Get the initial type information from the operand types. llvm::SmallVector resultKnowledge; resultKnowledge.reserve(yieldOps.front().getNumOperands()); for (auto operand : yieldOps.front().getOperands()) { resultKnowledge.push_back( ValueKnowledge::getKnowledgeFromType(operand.getType())); } for (auto yieldOp : yieldOps) { if (resultKnowledge.size() != yieldOp.getNumOperands()) return failure(); for (const auto &it : llvm::enumerate(yieldOp.getOperands())) { int32_t index = it.index(); if (auto meet = ValueKnowledge::meet( resultKnowledge[index], ValueKnowledge::getKnowledgeFromType(it.value().getType()))) { resultKnowledge[index] = meet; } } } for (const ValueKnowledge &result : resultKnowledge) { inferredReturnShapes.push_back(result.getShapedTypeComponents()); } return success(); } std::optional> ApplyScaleOp::getShapeForUnroll() { if (auto vt = llvm::dyn_cast(getType())) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } // parse and print of IfOp refer to the implementation of SCF dialect. ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { // Create the regions for 'then'. result.regions.reserve(2); Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); auto &builder = parser.getBuilder(); OpAsmParser::UnresolvedOperand cond; // Create a i1 tensor type for the boolean condition. Type i1Type = RankedTensorType::get({}, builder.getIntegerType(1)); if (parser.parseOperand(cond) || parser.resolveOperand(cond, i1Type, result.operands)) return failure(); // Parse optional results type list. if (parser.parseOptionalArrowTypeList(result.types)) return failure(); // Parse the 'then' region. 
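  // Sketch of the custom assembly accepted here (operand name and types are
  // illustrative):
  //   tosa.cond_if %condition -> (tensor<f32>) {
  //     ...
  //   } else {
  //     ...
  //   }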
if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); // If we find an 'else' keyword then parse the 'else' region. if (!parser.parseOptionalKeyword("else")) { if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{})) return failure(); } // Parse the optional attribute list. if (parser.parseOptionalAttrDict(result.attributes)) return failure(); return success(); } void IfOp::print(OpAsmPrinter &p) { bool printBlockTerminators = false; p << " " << getCond(); if (!getResults().empty()) { p << " -> (" << getResultTypes() << ")"; // Print yield explicitly if the op defines values. printBlockTerminators = true; } p << ' '; p.printRegion(getThenBranch(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/printBlockTerminators); // Print the 'else' regions if it exists and has a block. auto &elseRegion = getElseBranch(); if (!elseRegion.empty()) { p << " else "; p.printRegion(elseRegion, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/printBlockTerminators); } p.printOptionalAttrDict((*this)->getAttrs()); } LogicalResult ReverseOp::verify() { TensorType inputType = getInput().getType(); TensorType outputType = getOutput().getType(); int32_t reverseAxis = getAxis(); if (reverseAxis < 0) return emitOpError("expected non-negative reverse axis"); if (inputType.hasRank()) { int64_t inputRank = inputType.getRank(); // We allow for a special case where the input/output shape has rank 0 and // axis is also 0. if (reverseAxis >= inputRank && !(reverseAxis == 0 && inputRank == 0)) return emitOpError("expect input tensor rank (") << inputRank << ") to be larger than reverse axis (" << reverseAxis << ")"; } if (outputType.hasRank()) { int64_t outputRank = outputType.getRank(); if (inputType.hasRank() && outputRank != inputType.getRank()) return emitOpError( "expect output tensor rank to be equal to input tensor rank"); if (reverseAxis >= outputRank && !(reverseAxis == 0 && outputRank == 0)) return emitOpError("expect output tensor rank (") << outputRank << ") to be larger than reverse axis (" << reverseAxis << ")"; } return success(); } // parse and print of WhileOp refer to the implementation of SCF dialect. ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { SmallVector regionArgs; SmallVector operands; Region *cond = result.addRegion(); Region *body = result.addRegion(); OptionalParseResult listResult = parser.parseOptionalAssignmentList(regionArgs, operands); if (listResult.has_value() && failed(listResult.value())) return failure(); FunctionType functionType; SMLoc typeLoc = parser.getCurrentLocation(); if (failed(parser.parseColonType(functionType))) return failure(); result.addTypes(functionType.getResults()); if (functionType.getNumInputs() != operands.size()) { return parser.emitError(typeLoc) << "expected as many input types as operands " << "(expected " << operands.size() << " got " << functionType.getNumInputs() << ")"; } // Resolve input operands. if (failed(parser.resolveOperands(operands, functionType.getInputs(), parser.getCurrentLocation(), result.operands))) return failure(); // Propagate the types into the region arguments. 
for (size_t i = 0, e = regionArgs.size(); i != e; ++i) regionArgs[i].type = functionType.getInput(i); return failure(parser.parseRegion(*cond, regionArgs) || parser.parseKeyword("do") || parser.parseRegion(*body) || parser.parseOptionalAttrDictWithKeyword(result.attributes)); } static void printInitializationList(OpAsmPrinter &parser, Block::BlockArgListType blocksArgs, ValueRange initializers, StringRef prefix = "") { assert(blocksArgs.size() == initializers.size() && "expected same length of arguments and initializers"); if (initializers.empty()) return; parser << prefix << '('; llvm::interleaveComma( llvm::zip(blocksArgs, initializers), parser, [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); }); parser << ")"; } void WhileOp::print(OpAsmPrinter &parser) { printInitializationList(parser, getCond().front().getArguments(), getInputs(), " "); parser << " : "; parser.printFunctionalType(getInputs().getTypes(), getResults().getTypes()); parser << ' '; parser.printRegion(getCond(), /*printEntryBlockArgs=*/false); parser << " do "; parser.printRegion(getBody()); parser.printOptionalAttrDictWithKeyword((*this)->getAttrs()); } //===----------------------------------------------------------------------===// // TOSA Attribute Definitions. //===----------------------------------------------------------------------===// #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc" //===----------------------------------------------------------------------===// // TOSA Operator Definitions. //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"