//===- TosaOps.cpp - MLIR Dialect for TOSA --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// \file
// This file implements the TOSA Specification:
// https://www.mlplatform.org/tosa/tosa_spec.html
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Quant/IR/Quant.h"
#include "mlir/Dialect/Shard/Interfaces/ShardingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>

using namespace mlir;
using namespace mlir::tosa;

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.cpp.inc"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"

//===----------------------------------------------------------------------===//
// Tosa dialect interface includes.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaAvailability.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaEnums.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaInterfaces.cpp.inc"
#include "mlir/Dialect/Tosa/IR/TosaOpAvailabilityImpl.inc"

namespace {
#include "mlir/Dialect/Tosa/IR/TosaDialectBytecode.cpp.inc"

//===----------------------------------------------------------------------===//
// Dialect Function Inliner Interface.
//===----------------------------------------------------------------------===//
struct TosaInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks.
  //===--------------------------------------------------------------------===//

  /// All operations can be inlined by default.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       IRMapping &map) const final {
    return true;
  }

  /// All regions with If and While parent operators can be inlined.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &map) const final {
    return (isa<tosa::IfOp>(dest->getParentOp()) ||
            isa<tosa::WhileOp>(dest->getParentOp()));
  }
};

/// This class implements the bytecode interface for the Tosa dialect.
struct TosaDialectBytecodeInterface : public BytecodeDialectInterface {
  TosaDialectBytecodeInterface(Dialect *dialect)
      : BytecodeDialectInterface(dialect) {}

  //===--------------------------------------------------------------------===//
  // Attributes

  Attribute readAttribute(DialectBytecodeReader &reader) const override {
    return ::readAttribute(getContext(), reader);
  }

  LogicalResult writeAttribute(Attribute attr,
                               DialectBytecodeWriter &writer) const override {
    return ::writeAttribute(attr, writer);
  }

  //===--------------------------------------------------------------------===//
  // Types

  Type readType(DialectBytecodeReader &reader) const override {
    return ::readType(getContext(), reader);
  }

  LogicalResult writeType(Type type,
                          DialectBytecodeWriter &writer) const override {
    return ::writeType(type, writer);
  }

  void writeVersion(DialectBytecodeWriter &writer) const final {
    // TODO: Populate.
  }

  std::unique_ptr<DialectVersion>
  readVersion(DialectBytecodeReader &reader) const final {
    // TODO: Populate
    reader.emitError("Dialect does not support versioning");
    return nullptr;
  }

  LogicalResult upgradeFromVersion(Operation *topLevelOp,
                                   const DialectVersion &version) const final {
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// TOSA control flow support.
//===----------------------------------------------------------------------===//

/// Returns the while loop body.
SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
  return {&getBodyGraph()};
}

//===----------------------------------------------------------------------===//
// TOSA variable operator support.
//===----------------------------------------------------------------------===//

static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
  return to_vector(llvm::map_range(shape, [](int64_t dim) {
    return dim == -1 ? ShapedType::kDynamic : dim;
  }));
}

// returns type of variable op
RankedTensorType mlir::tosa::getVariableType(tosa::VariableOp variableOp) {
  Type elementType = variableOp.getType();
  DenseIntElementsAttr varShapeAttr = variableOp.getVarShape();
  auto shape =
      convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
  return RankedTensorType::get(shape, elementType);
}
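// For example (illustrative): a declaration such as
//   tosa.variable @v : tensor<2x?xf32>
// stores var_shape = [2, -1]; getVariableType() maps each -1 back to
// ShapedType::kDynamic and reconstructs tensor<2x?xf32>.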
//===----------------------------------------------------------------------===//
// Tosa dialect initialization.
//===----------------------------------------------------------------------===//

void TosaDialect::initialize() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"
      >();
  addInterfaces<TosaDialectBytecodeInterface, TosaInlinerInterface>();
  declarePromisedInterfaces<
      shard::ShardingInterface, ClampOp, SigmoidOp, TanhOp, AddOp,
      ArithmeticRightShiftOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, IntDivOp,
      LogicalAndOp, LogicalLeftShiftOp, LogicalRightShiftOp, LogicalOrOp,
      LogicalXorOp, MaximumOp, MinimumOp, MulOp, PowOp, SubOp, AbsOp,
      BitwiseNotOp, CeilOp, ClzOp, ExpOp, FloorOp, LogOp, LogicalNotOp,
      NegateOp, ReciprocalOp, RsqrtOp, SelectOp, EqualOp, GreaterOp,
      GreaterEqualOp, MatMulOp>();
}

Operation *TosaDialect::materializeConstant(OpBuilder &builder, Attribute value,
                                            Type type, Location loc) {
  // Tosa dialect constants only support ElementsAttr unlike standard dialect
  // constant which supports all attributes.
  if (llvm::isa<tosa::shapeType>(type) &&
      llvm::isa<DenseIntElementsAttr>(value)) {
    return tosa::ConstShapeOp::create(builder, loc, type,
                                      llvm::cast<DenseIntElementsAttr>(value));
  }
  if (llvm::isa<ElementsAttr>(value))
    return tosa::ConstOp::create(builder, loc, type,
                                 llvm::cast<ElementsAttr>(value));
  return nullptr;
}

//===----------------------------------------------------------------------===//
// Parsers and printers
//===----------------------------------------------------------------------===//

namespace {

ParseResult getShapeAndElementType(OpAsmParser &parser, Type parsedType,
                                   DenseElementsAttr &varShapeAttr,
                                   TypeAttr &typeAttr) {
  if (auto shapedType = dyn_cast<ShapedType>(parsedType)) {
    if (!shapedType.hasRank())
      return parser.emitError(parser.getCurrentLocation())
             << "expected ranked type";

    auto elementType = shapedType.getElementType();
    typeAttr = TypeAttr::get(elementType);
    ArrayRef<int64_t> shape = shapedType.getShape();
    Builder builder(parser.getContext());
    varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));
    return success();
  }
  return parser.emitError(parser.getCurrentLocation())
         << "expected shaped type";
}

} // namespace

// parses the optional initial value or type for a tosa variable
//  with initial value:
//    tosa.variable @name = dense<0.0> : tensor<1x8xf32>
//
//  without initial value:
//    tosa.variable @name : tensor<1x8xf32>
ParseResult mlir::tosa::parseVariableOpTypeOrInitialValue(
    OpAsmParser &parser, DenseElementsAttr &varShapeAttr, TypeAttr &typeAttr,
    Attribute &initialValueAttr) {
  if (succeeded(parser.parseOptionalEqual())) {
    if (failed(parser.parseAttribute(initialValueAttr))) {
      return parser.emitError(parser.getCurrentLocation())
             << "expected attribute";
    }
    if (auto typedAttr = dyn_cast<TypedAttr>(initialValueAttr)) {
      return getShapeAndElementType(parser, typedAttr.getType(), varShapeAttr,
                                    typeAttr);
    }
    return parser.emitError(parser.getCurrentLocation())
           << "expected Typed attr";
  }

  initialValueAttr = nullptr;
  Type parsedType;
  if (failed(parser.parseColonType(parsedType))) {
    return parser.emitError(parser.getCurrentLocation())
           << "expected type after colon";
  }
  return getShapeAndElementType(parser, parsedType, varShapeAttr, typeAttr);
}

void mlir::tosa::printVariableOpTypeOrInitialValue(
    OpAsmPrinter &p, Operation *op, DenseElementsAttr varShapeAttr,
    TypeAttr typeAttr, Attribute initialValueAttr) {
  bool needsSpace = false;
  if (!dyn_cast_or_null<DenseElementsAttr>(initialValueAttr)) {
    auto shape =
        convertToMlirShape(to_vector(varShapeAttr.getValues<int64_t>()));
    Type elementType = typeAttr.getValue();
    RankedTensorType tensorType =
        RankedTensorType::get(ArrayRef<int64_t>(shape), elementType);
    auto tensorTypeAttr = TypeAttr::get(tensorType);
    p << ": ";
    p.printAttribute(tensorTypeAttr);
    needsSpace = true; // subsequent attr value needs a space separator
  }
  if (initialValueAttr) {
    if (needsSpace)
      p << ' ';
    p << "= ";
    p.printAttribute(initialValueAttr);
  }
}

namespace {

// parse attributes with special handling for tosa enum attributes
template <typename T>
ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
                                           NamedAttrList &outAttrs) {
  llvm::StringRef name;
  if (parser.parseOptionalKeyword(&name) || parser.parseEqual())
    return failure();

  // special handling: rounding_mode accepts a *bare* RoundingMode enum
  // keyword.
  llvm::StringRef kw;
  if constexpr (std::is_same_v<T, RoundingModeAttr>) {
    if (name == "rounding_mode" &&
        succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeRoundingMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid rounding_mode value: " << kw;
      auto attr = RoundingModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }
  // special handling: mode accepts a *bare* ResizeMode enum keyword.
  if constexpr (std::is_same_v<T, ResizeModeAttr>) {
    if (name == "mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeResizeMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid resize mode value: " << kw;
      auto attr = ResizeModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }
  // special handling: nan_mode accepts a *bare* NanPropagationMode enum
  // keyword.
  if constexpr (std::is_same_v<T, NanPropagationModeAttr>) {
    if (name == "nan_mode" && succeeded(parser.parseOptionalKeyword(&kw))) {
      auto sym = symbolizeNanPropagationMode(kw);
      if (!sym)
        return parser.emitError(parser.getCurrentLocation())
               << "invalid nan_mode value: " << kw;
      auto attr =
          NanPropagationModeAttr::get(parser.getContext(), sym.value());
      outAttrs.push_back(NamedAttribute(name, attr));
      return success();
    }
  }

  // Default path: parse any normal attribute literal, including fully
  // qualified enum keyword
  Attribute attr;
  return parser.parseAttribute(attr, name, outAttrs);
}

template <typename T>
ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
  // parse operands
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  if (parser.parseCommaSeparatedList(
          [&]() { return parser.parseOperand(operands.emplace_back()); }))
    return failure();

  // Parse { attr-dict } with special handling for enum bare token
  NamedAttrList attrs;
  if (succeeded(parser.parseOptionalLBrace()) &&
      failed(parser.parseOptionalRBrace())) {
    do {
      if (parseAttrEntryWithEnumHandling<T>(parser, attrs))
        return failure();
    } while (succeeded(parser.parseOptionalComma()));
    if (parser.parseRBrace())
      return failure();
  }

  FunctionType fnTy;
  if (parser.parseColonType(fnTy))
    return failure();

  // Resolve operands and types
  if (failed(parser.resolveOperands(operands, fnTy.getInputs(),
                                    parser.getCurrentLocation(),
                                    result.operands)))
    return failure();

  result.addTypes(fnTy.getResult(0));
  result.addAttributes(attrs);

  return success();
}

void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
  parser << namedAttr.getName().strref() << " = ";
  auto attr = namedAttr.getValue();
  if (auto roundingModeAttr = dyn_cast<RoundingModeAttr>(attr)) {
    parser << roundingModeAttr.getValue();
  } else if (auto resizeModeAttr = dyn_cast<ResizeModeAttr>(attr)) {
    parser << resizeModeAttr.getValue();
  } else if (auto nanPropagationModeAttr =
                 dyn_cast<NanPropagationModeAttr>(attr)) {
    parser << nanPropagationModeAttr.getValue();
  } else {
    parser.printAttribute(attr);
  }
}

// print with special handling for default valued NanPropagationMode attribute
void printWithNanPropagationHandling(OpAsmPrinter &parser, Operation *op) {
  parser << " ";
  parser.printOperands(op->getOperands());

  NamedAttrList toPrint(op->getAttrs());

  // remove default NanPropagate attribute
  const auto kDefaultNanValue = NanPropagationMode::PROPAGATE;
  for (auto attr : op->getAttrs()) {
    if (auto nanAttr = dyn_cast<NanPropagationModeAttr>(attr.getValue())) {
      if (nanAttr.getValue() == kDefaultNanValue) {
        // elide from toPrint
        toPrint.erase(attr.getName());
        break;
      }
    }
  }

  if (!toPrint.empty()) {
    parser << " {";
    llvm::interleaveComma(toPrint, parser, [&](const NamedAttribute namedAttr) {
      printNamedAttr(parser, namedAttr);
    });
    parser << "}";
  }

  parser << " : ";
  parser.printFunctionalType(op);
}

// print with special handling for enums: RoundingMode, ResizeMode
void printWithEnumHandling(OpAsmPrinter &parser, Operation *op) {
  parser << " ";
  parser.printOperands(op->getOperands());

  if (!op->getAttrs().empty()) {
    parser << " {";
    llvm::interleaveComma(op->getAttrs(), parser,
                          [&](const NamedAttribute namedAttr) {
                            printNamedAttr(parser, namedAttr);
                          });
    parser << "}";
  }

  parser << " : ";
  parser.printFunctionalType(op);
}

} // namespace

ParseResult RescaleOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<RoundingModeAttr>(parser, result);
}

void RescaleOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult ApplyScaleOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<RoundingModeAttr>(parser, result);
}

void ApplyScaleOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult ResizeOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<ResizeModeAttr>(parser, result);
}

void ResizeOp::print(OpAsmPrinter &parser) {
  printWithEnumHandling(parser, *this);
}

ParseResult ArgMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<NanPropagationModeAttr>(parser, result);
}

void ArgMaxOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult MaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<NanPropagationModeAttr>(parser, result);
}

void MaxPool2dOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult ClampOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<NanPropagationModeAttr>(parser, result);
}

void ClampOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult MaximumOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<NanPropagationModeAttr>(parser, result);
}

void MaximumOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult MinimumOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<NanPropagationModeAttr>(parser, result);
}

void MinimumOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult ReduceMaxOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<NanPropagationModeAttr>(parser, result);
}

void ReduceMaxOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}

ParseResult ReduceMinOp::parse(OpAsmParser &parser, OperationState &result) {
  return parseWithEnumHandling<NanPropagationModeAttr>(parser, result);
}

void ReduceMinOp::print(OpAsmPrinter &parser) {
  printWithNanPropagationHandling(parser, *this);
}
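// Illustrative round-trip for the custom parse/print above (assumed syntax):
//   %0 = tosa.clamp %arg0 {min_val = -1.0 : f32, max_val = 1.0 : f32,
//                          nan_mode = IGNORE} : (tensor<4xf32>) -> tensor<4xf32>
// The bare `IGNORE` keyword is accepted in place of a fully qualified
// attribute, and printWithNanPropagationHandling elides the attribute
// entirely when it equals the PROPAGATE default.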
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//

static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
  if (lhs % rhs != 0)
    return std::nullopt;
  return lhs / rhs;
}

static Type getStorageElementTypeOrSelf(Type type) {
  auto srcType = getElementTypeOrSelf(type);
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(srcType))
    srcType = quantType.getStorageType();
  return srcType;
}

static Type getStorageElementTypeOrSelf(Value value) {
  return getStorageElementTypeOrSelf(value.getType());
}

static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
                                                  Value valZp, StringRef name) {
  Type eType = getStorageElementTypeOrSelf(val.getType());
  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());

  bool bothInts =
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);
  bool sameBitWidth =
      (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());

  if (!bothInts || !sameBitWidth) {
    return op->emitOpError()
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;
  }
  return success();
}

// Create a pad-const const tensor with value of `val` of required data-type
Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
                                       Value src, int32_t val) {
  const auto srcType = getElementTypeOrSelf(src);
  const auto srcElemType = getStorageElementTypeOrSelf(src);
  const auto padConstType = mlir::RankedTensorType::get({1}, srcType);
  const auto padConstEType = mlir::RankedTensorType::get({1}, srcElemType);
  const auto padConstAttr{
      llvm::isa<FloatType>(srcElemType)
          ? DenseElementsAttr::get(padConstEType,
                                   builder.getFloatAttr(srcElemType, val))
          : DenseElementsAttr::get(padConstEType,
                                   builder.getIntegerAttr(srcElemType, val))};
  return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//

template <typename T>
static LogicalResult verifyConvOp(T op) {
  const auto inputType = llvm::dyn_cast<TensorType>(op.getInput().getType());
  const auto weightType = llvm::dyn_cast<TensorType>(op.getWeight().getType());

  auto inputEType = inputType.getElementType();
  auto weightEType = weightType.getElementType();
  auto biasEType =
      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
  bool resultIsFloat = llvm::isa<FloatType>(resultEType);

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(weightEType))
    weightEType = quantType.getStorageType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
    biasEType = quantType.getStorageType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
    // for now, only enforce bias element type == result element type for
    // float types.
    op.emitOpError(
        "expect both bias and result to have same element type, got ")
        << biasEType << " and " << resultEType;
    return failure();
  }

  if (isa<Float8E5M2Type>(inputEType) || isa<Float8E4M3FNType>(inputEType) ||
      isa<Float8E5M2Type>(weightEType) || isa<Float8E4M3FNType>(weightEType)) {
    if (inputEType != weightEType) {
      op.emitOpError(
          "expect both input and weight to have same element type, got ")
          << inputEType << " and " << weightEType;
      return failure();
    }
  }

  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
  bool weightIsFloat = llvm::isa<FloatType>(weightEType);

  // Either both must be float or both non-float.
  if (inputIsFloat != weightIsFloat) {
    op.emitOpError(
        "expect both input and weight to be float or not together, got ")
        << inputEType << " and " << weightEType;
    return failure();
  }

  auto inputZpEType = getStorageElementTypeOrSelf(op.getInputZp().getType());
  if (inputEType != inputZpEType) {
    return op.emitOpError("expect both input and its zero point are the same "
                          "element type, got ")
           << inputEType << " and " << inputZpEType;
  }

  auto weightZpEType = getStorageElementTypeOrSelf(op.getWeightZp().getType());
  if (weightEType != weightZpEType) {
    return op.emitOpError("expect both weight and its zero point are the same "
                          "element type, got ")
           << weightEType << " and " << weightZpEType;
  }

  FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
  if (succeeded(maybeIZp) && op.verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeWZp = op.getWeightZeroPoint();
  if (succeeded(maybeWZp) && op.verifyWeightZeroPoint(*maybeWZp).failed())
    return failure();

  return success();
}

LogicalResult tosa::ConstOp::verify() {
  auto attrType = llvm::dyn_cast<TensorType>(getValuesAttr().getType());
  auto outputType = llvm::dyn_cast<TensorType>(getOutput().getType());

  if (!attrType || !outputType) {
    emitOpError("expected tensors for attr/result type");
    return failure();
  }

  if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
          outputType.getElementType())) {
    if (result.getStorageType() == attrType.getElementType())
      return success();
  }

  if (attrType.getElementType() != outputType.getElementType()) {
    emitOpError("expected same attr/result element types");
    return failure();
  }

  return success();
}

template <typename T>
static LogicalResult verifyConvOpModes(T op) {
  auto inputEType =
      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();

  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
    inputEType = quantType.getStorageType();

  auto accType = op.getAccType();
  if (inputEType.isInteger(8) && !accType.isInteger(32))
    return op.emitOpError("accumulator type for i8 tensor is not i32");

  if (inputEType.isInteger(16) && !accType.isInteger(48))
    return op.emitOpError("accumulator type for i16 tensor is not i48");

  if (isa<Float8E5M2Type, Float8E4M3FNType>(inputEType) && !accType.isF16())
    return op.emitOpError("accumulator type for f8 tensor is not f16");

  if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
    return op.emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputEType.isBF16() && !accType.isF32())
    return op.emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputEType.isF32() && !accType.isF32())
    return op.emitOpError("accumulator type for f32 tensor is not f32");

  auto resultEType =
      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
    resultEType = quantType.getStorageType();

  return success();
}

//===----------------------------------------------------------------------===//
// ERROR_IF functions.
// ERROR_IF is a predicate that must set an error if the condition holds.
//===----------------------------------------------------------------------===//
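// Worked example of the output-size ERROR_IF below (illustrative): for a
// conv2d height with I = 8, K = 3, pad = (1, 1), dilation d = 1 and stride
// s = 2, the rule computes idiv_check(8 - 1 + 1 + 1 - (3 - 1) * 1, 2) + 1 =
// idiv_check(7, 2) + 1, which fails because 7 is not divisible by 2; with
// s = 1 the op verifies and the output height must be 8.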
template <typename T>
static LogicalResult verifyConvOpErrorIf(T op) {
  llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  llvm::ArrayRef<int64_t> dilations = op.getDilation();
  if (llvm::any_of(dilations, [](int64_t d) { return d < 1; }))
    return op.emitOpError("expect all dilation values to be >= 1, got ")
           << dilations;

  const RankedTensorType outputType =
      llvm::dyn_cast<RankedTensorType>(op.getOutput().getType());
  if (!outputType)
    // Skip following checks if output is not ranked
    return success();

  const RankedTensorType inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const RankedTensorType weightType =
      llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());

  if (inputType && weightType) {
    const auto verifyOutputSize =
        [&op](const int64_t inputSize, const int64_t kernelSize,
              const int64_t outputSize, const int64_t padBefore,
              const int64_t padAfter, const int64_t stride,
              const int64_t dilation, const llvm::StringRef dimName,
              const llvm::StringRef dimAxis,
              const llvm::StringRef padBeforeName,
              const llvm::StringRef padAfterName) -> LogicalResult {
      if (inputSize == ShapedType::kDynamic ||
          kernelSize == ShapedType::kDynamic)
        return success();

      // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
      const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
          inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
          stride);
      if (!calculatedOutSizeMinusOne.has_value())
        return op.emitOpError("expected input_")
               << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
               << padAfterName << " - (kernel_" << dimName
               << " - 1) * dilation_" << dimAxis
               << " to be wholly divisible by stride_" << dimAxis << ", got ("
               << inputSize << " - 1 + " << padBefore << " + " << padAfter
               << " - (" << kernelSize << " - 1) * " << dilation << ") / "
               << stride;

      const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
      if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
        return op.emitOpError("calculated output ")
               << dimName << " did not match expected: "
               << "calculated=" << calculatedOutSize
               << ", expected=" << outputSize;

      return success();
    };

    // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
    if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(0),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(1),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "width", "x", "left", "right")))
        return failure();
    }

    // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
    if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
      if (failed(verifyOutputSize(
              inputType.getDimSize(1), weightType.getDimSize(1),
              outputType.getDimSize(1), padding[0], padding[1], strides[0],
              dilations[0], "depth", "d", "front", "back")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(2), weightType.getDimSize(2),
              outputType.getDimSize(2), padding[2], padding[3], strides[1],
              dilations[1], "height", "y", "top", "bottom")))
        return failure();

      if (failed(verifyOutputSize(
              inputType.getDimSize(3), weightType.getDimSize(3),
              outputType.getDimSize(3), padding[4], padding[5], strides[2],
              dilations[2], "width", "x", "left", "right")))
        return failure();
    }
  }

  const RankedTensorType biasType =
      llvm::dyn_cast<RankedTensorType>(op.getBias().getType());
  if (!biasType)
    // Skip following checks if bias is not ranked
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);
  const int64_t outputChannels =
      outputType.getDimSize(outputType.getRank() - 1);
  if (biasChannels == ShapedType::kDynamic ||
      outputChannels == ShapedType::kDynamic)
    // Skip following checks if biasChannels or outputChannels is dynamic dim
    return success();

  if (biasChannels != outputChannels && biasChannels != 1)
    return op.emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}

// Verify whether same type and shape of the given two types.
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, Type type1,
                                                StringRef name1, Type type2,
                                                StringRef name2) {
  auto shapeType1 = dyn_cast<ShapedType>(type1);
  auto shapeType2 = dyn_cast<ShapedType>(type2);
  if (!shapeType1 || !shapeType2)
    return failure();

  auto elemType1 = shapeType1.getElementType();
  auto elemType2 = shapeType2.getElementType();
  if (elemType1 != elemType2)
    return op->emitOpError()
           << "require same element type for " << name1 << " (" << elemType1
           << ") and " << name2 << " (" << elemType2 << ")";

  if (failed(verifyCompatibleShape(type1, type2)))
    return op->emitOpError()
           << "require same shapes for " << name1 << " (" << type1 << ") and "
           << name2 << " (" << type2 << ")";

  return success();
}

// Verify whether same length, type, and shape of the given two tensor lists.
static LogicalResult errorIfTypeOrShapeMismatch(Operation *op, ValueRange list1,
                                                StringRef name1,
                                                ValueRange list2,
                                                StringRef name2) {
  if (list1.size() != list2.size())
    return op->emitOpError()
           << "require same number of values in " << name1 << " ("
           << list1.size() << ") and " << name2 << " (" << list2.size() << ")";

  for (auto [type1, type2] :
       llvm::zip_equal(list1.getTypes(), list2.getTypes())) {
    if (errorIfTypeOrShapeMismatch(op, type1, name1, type2, name2).failed())
      return failure();
  }

  return success();
}

static inline LogicalResult errorIfShapeNotSizeOne(Operation *op, Type type) {
  ShapeAdaptor shapeAdaptor(type);
  if (!shapeAdaptor.hasRank() || !shapeAdaptor.hasStaticShape())
    return success();

  return shapeAdaptor.getNumElements() == 1 ? success() : failure();
}
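// For example (illustrative): given a declaration
//   tosa.variable @v = dense<0.0> : tensor<2x4xf32>
// a later read or write of @v with a tensor<2x4xi8> operand fails the
// element-type check below, and an undeclared symbol name fails the lookup.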
template <typename T>
static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
  Operation *symTableOp =
      op->template getParentWithTrait<OpTrait::SymbolTable>();
  if (!symTableOp)
    // If the operation is not the scope of a symbol table, we cannot
    // verify it against its declaration.
    return success();

  SymbolTable symTable(symTableOp);
  const auto varOp = symTable.lookup<tosa::VariableOp>(op.getName());

  // Verify prior declaration
  if (!varOp)
    return op->emitOpError("'")
           << op.getName() << "' has not been declared by 'tosa.variable'";

  // Verify type and shape
  auto variableType = getVariableType(varOp);
  if (errorIfTypeOrShapeMismatch(op, type, name, variableType,
                                 "the input tensor")
          .failed())
    return failure();

  return success();
}

// verify that inType and outType have same element types
template <typename T>
static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
  auto inputType = llvm::dyn_cast<ShapedType>(inType);
  auto outputType = llvm::dyn_cast<ShapedType>(outType);
  if (!inputType) {
    op.emitOpError("expect shaped tensor for input, got ") << inType;
    return failure();
  }
  if (!outputType) {
    op.emitOpError("expect shaped tensor for output, got ") << outType;
    return failure();
  }

  auto inputElementType = inputType.getElementType();
  auto outputElementType = outputType.getElementType();
  auto inputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
  auto outputQuantType =
      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
      inputElementType != outputElementType) {
    // only check if both element types are int/index/float/UniformQuantized
    // eg, not sure how to check quant::QuantizedType
    // this happens in test_conv2d_q_grouped_convolution in
    // tfl-to-tosa-pipeline.mlir
    op.emitOpError("expect input and output to have same element type, got ")
        << inputElementType << " and " << outputElementType;
    return failure();
  }
  return success();
}

LogicalResult tosa::ArgMaxOp::verify() {
  const ShapedType resultType = llvm::cast<ShapedType>(getType());

  // Ensure output is of 32-bit integer
  if (const auto resultETy = resultType.getElementType();
      !resultETy.isIntOrIndex())
    return emitOpError("result tensor is not of integer type");

  const auto inputType = llvm::cast<ShapedType>(getInput().getType());
  if (!inputType.hasRank())
    return success();

  // Ensure axis is within the tensor rank
  const int64_t axis = getAxisAttr().getInt();
  if (((axis < 0) || axis >= inputType.getRank()))
    return emitOpError("specified axis is outside the rank of the tensor");

  if (!resultType.hasRank())
    return success();

  const ArrayRef<int64_t> inputShape = inputType.getShape();
  const ArrayRef<int64_t> outputShape = resultType.getShape();
  llvm::SmallVector<int64_t> expectedOutputShape(inputShape);
  expectedOutputShape.erase(expectedOutputShape.begin() + axis);
  if (failed(verifyCompatibleShape(expectedOutputShape, outputShape)))
    return emitOpError("expected output shape '")
           << expectedOutputShape << "', got '" << outputShape << "'";

  return success();
}

template <typename T>
static LogicalResult verifyPoolingOp(T op) {
  const llvm::ArrayRef<int64_t> kernel = op.getKernel();
  if (llvm::any_of(kernel, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all kernel values to be >= 1, got ")
           << kernel;

  const llvm::ArrayRef<int64_t> strides = op.getStride();
  if (llvm::any_of(strides, [](int64_t s) { return s < 1; }))
    return op.emitOpError("expect all stride values to be >= 1, got ")
           << strides;

  const llvm::ArrayRef<int64_t> padding = op.getPad();
  if (llvm::any_of(padding, [](int64_t p) { return p < 0; }))
    return op.emitOpError("expect all padding values to be >= 0, got ")
           << padding;

  // Padding must be less than kernel size to avoid a divide-by-zero
  const int64_t kernelX = kernel[1];
  const int64_t padLeft = padding[2];
  const int64_t padRight = padding[3];
  if (padRight >= kernelX || padLeft >= kernelX)
    return op.emitOpError("expected left/right padding to be less than the "
                          "width of the kernel, got pad_left=")
           << padLeft << ", pad_right=" << padRight
           << ", kernel_x=" << kernelX;

  const int64_t kernelY = kernel[0];
  const int64_t padTop = padding[0];
  const int64_t padBottom = padding[1];
  if (padTop >= kernelY || padBottom >= kernelY)
    return op.emitOpError("expected top/bottom padding to be less than the "
                          "height of the kernel, got pad_top=")
           << padTop << ", pad_bottom=" << padBottom
           << ", kernel_y=" << kernelY;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!inputType || !outputType)
    return success();

  const auto verifyOutputSize =
      [&op](const int64_t inputSize, const int64_t outputSize,
            const int64_t kernelSize, const int64_t strideSize,
            const int64_t padBefore, const int64_t padAfter,
            const llvm::StringRef dimName, const llvm::StringRef dimAxis,
            const llvm::StringRef padBeforeName,
            const llvm::StringRef padAfterName) -> LogicalResult {
    if (ShapedType::isDynamic(inputSize))
      return success();

    const std::optional<int64_t> calculatedOutSizeMinusOne =
        idivCheck(inputSize + padBefore + padAfter - kernelSize, strideSize);
    if (!calculatedOutSizeMinusOne.has_value())
      return op.emitOpError("expected input_")
             << dimName << " + pad_" << padBeforeName << " + pad_"
             << padAfterName << " - kernel_" << dimAxis
             << " to be wholly divisible by stride_" << dimAxis << ", got ("
             << inputSize << " + " << padBefore << " + " << padAfter << " - "
             << kernelSize << ") / " << strideSize;

    const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
    if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
      return op.emitOpError("calculated output ")
             << dimName << " did not match expected: "
             << "calculated=" << calculatedOutSize
             << ", expected=" << outputSize;

    return success();
  };

  if (failed(verifyOutputSize(inputType.getDimSize(1),
                              outputType.getDimSize(1), kernel[0], strides[0],
                              padding[0], padding[1], "height", "y", "top",
                              "bottom")))
    return failure();

  if (failed(verifyOutputSize(inputType.getDimSize(2),
                              outputType.getDimSize(2), kernel[1], strides[1],
                              padding[2], padding[3], "width", "x", "left",
                              "right")))
    return failure();

  return success();
}

LogicalResult tosa::AvgPool2dOp::verify() {
  if (failed(verifyPoolingOp(*this)))
    return failure();

  const Type inputETy = getStorageElementTypeOrSelf(getInput().getType());
  const Type resultETy = getStorageElementTypeOrSelf(getOutput().getType());
  const Type inputZpETy = getStorageElementTypeOrSelf(getInputZp().getType());
  const Type outputZpETy = getStorageElementTypeOrSelf(getOutputZp().getType());

  auto accType = getAccType();
  if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
    return emitOpError("accumulator type for integer tensor is not i32");

  if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
    return emitOpError("accumulator type for f16 tensor is not f16/f32");

  if (inputETy.isBF16() && !accType.isF32())
    return emitOpError("accumulator type for bf16 tensor is not f32");

  if (inputETy.isF32() && !accType.isF32())
    return emitOpError("accumulator type for f32 tensor is not f32");

  if (inputETy != inputZpETy)
    return emitOpError("expect both input and its zero point are the same "
                       "element type, got ")
           << inputETy << " and " << inputZpETy;

  if (resultETy != outputZpETy)
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << resultETy << " and " << outputZpETy;

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) &&
      verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}

LogicalResult tosa::ClampOp::verify() {
  mlir::Type inputETy =
      llvm::cast<ShapedType>(getInput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
    inputETy = quantType.getStorageType();
  }
  mlir::Type outputETy =
      llvm::cast<ShapedType>(getOutput().getType()).getElementType();
  if (auto quantType =
          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
    outputETy = quantType.getStorageType();
  }
  if (inputETy != outputETy)
    return emitOpError("input/output element types are incompatible.");

  auto maxValAttr = getMaxValAttr();
  auto minValAttr = getMinValAttr();

  unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth();

  if (inputETy.isInteger(dataTypeBitWidth)) {
    // if input datatype is integer, check that the min_val/max_val attributes
    // are integer attributes, and that their type is the same as the input's
    // datatype
    auto intMaxValAttr = mlir::dyn_cast<mlir::IntegerAttr>(maxValAttr);
    auto intMinValAttr = mlir::dyn_cast<mlir::IntegerAttr>(minValAttr);
    if (!intMaxValAttr || !intMinValAttr ||
        (intMaxValAttr.getType() != intMinValAttr.getType()) ||
        (intMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const bool isUnsigned = inputETy.isUnsignedInteger();
    const bool isBoolean = inputETy.isInteger(1);
    const APInt minVal = intMinValAttr.getValue();
    const APInt maxVal = intMaxValAttr.getValue();
    if ((isUnsigned || isBoolean) ? maxVal.ult(minVal) : maxVal.slt(minVal))
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  } else {
    // otherwise, input datatype is float, check that the min_val/max_val
    // attributes share the same type and that their type is the same as the
    // input's datatype
    auto floatMaxValAttr = mlir::dyn_cast<mlir::FloatAttr>(maxValAttr);
    auto floatMinValAttr = mlir::dyn_cast<mlir::FloatAttr>(minValAttr);
    if (!floatMaxValAttr || !floatMinValAttr ||
        (floatMaxValAttr.getType() != floatMinValAttr.getType()) ||
        (floatMaxValAttr.getType() != inputETy))
      return emitOpError("min/max attributes types are incompatible with "
                         "input/output element types.");

    const APFloat minVal = floatMinValAttr.getValue();
    const APFloat maxVal = floatMaxValAttr.getValue();
    if (minVal.isNaN() || maxVal.isNaN())
      return emitOpError(
                 "min/max attributes should not be 'NaN', got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;

    if (maxVal < minVal)
      return emitOpError("expected min_val <= max_val, got min_val=")
             << minValAttr << ", max_val=" << maxValAttr;
  }

  return success();
}
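// Illustrative clamp cases: for an i8 input, min_val = -128 : i8 and
// max_val = 127 : i8 verify; f32 bounds on an i8 tensor, min_val > max_val,
// or NaN float bounds are all rejected by the checks above.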
//===----------------------------------------------------------------------===//
// TOSA Operator Quantization Builders.
//===----------------------------------------------------------------------===//

/// This builder is called on all convolution operators except TransposeConv,
/// which has specialized output shape semantics. The builder also defines the
/// bitwidth of the output given the bit width of the input & weight content.
static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Type outputType, Value input, Value weight,
                                     Value bias, DenseI64ArrayAttr pad,
                                     DenseI64ArrayAttr stride,
                                     DenseI64ArrayAttr dilation,
                                     TypeAttr accType) {
  auto zps = createZPsAsConst(builder, input, weight);
  result.addOperands({input, weight, bias, zps.first, zps.second});
  result.addAttribute("pad", pad);
  result.addAttribute("stride", stride);
  result.addAttribute("dilation", dilation);
  result.addAttribute("acc_type", accType);
  Type finalOutputType = outputType;
  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    finalOutputType =
        buildConvOpResultTypeInfo(builder, outputType, input, weight);
  }
  result.addTypes(finalOutputType);
}

/// Handles tosa.transpose_conv2d which has outpad and output shape
/// attributes.
static void
buildTransConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input, Value weight,
                              Value bias, DenseI64ArrayAttr outpad,
                              DenseI64ArrayAttr stride, TypeAttr accType) {
  auto zps = createZPsAsConst(builder, input, weight);
  result.addOperands({input, weight, bias, zps.first, zps.second});
  result.addAttribute("out_pad", outpad);
  result.addAttribute("stride", stride);
  result.addAttribute("acc_type", accType);
  Type finalOutputType = outputType;
  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
  if (quantAttr) {
    finalOutputType =
        buildConvOpResultTypeInfo(builder, outputType, input, weight);
  }
  result.addTypes(finalOutputType);
}

/// The tosa.matmul op is also intended to be generated where a fully_connected
/// op must be constructed where the weight is not a constant. In this case,
/// the fully_connected op must be expressed using matmul.
/// TODO: Add link to the legalization document explaining this.
static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value a, Value b) {
  auto zps = createZPsAsConst(builder, a, b);
  result.addOperands({a, b, zps.first, zps.second});

  Type finalOutputType{outputType};
  if (auto quantAttr = buildMatMulOpQuantizationAttr(builder, a, b)) {
    auto eType = getStorageElementTypeOrSelf(a.getType());
    auto inputBits = eType.getIntOrFloatBitWidth();

    auto outputShapedType = llvm::dyn_cast<ShapedType>(outputType);
    assert(outputShapedType && "Output must be a shaped type");

    IntegerType accElementType;
    if (inputBits == 16)
      accElementType = builder.getIntegerType(48);
    else
      accElementType = builder.getI32Type();

    finalOutputType = outputShapedType.clone(accElementType);
  }
  result.addTypes(finalOutputType);
}
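// The accumulator width above follows the input storage width: i8 inputs
// accumulate in i32 and i16 inputs in i48. E.g. (illustrative) two quantized
// operands with i16 storage produce an i48-element result type here.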
/// Both the tosa.avg_pool2d and unary ops use the same
/// UnaryOpQuantizationAttr but avg_pool operator has its own builder as it
/// has additional parameters not part of the unary ops.
static void
buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                              Type outputType, Value input,
                              DenseArrayAttr kernel, DenseArrayAttr stride,
                              DenseArrayAttr pad, TypeAttr accType) {
  const Location loc{result.location};
  int64_t inputZp{0};
  int64_t outputZp{0};

  if (auto quantAttr =
          buildUnaryOpQuantizationAttr(builder, input, outputType)) {
    inputZp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();
  }
  const std::optional<Value> inputZpOp =
      createZeroPointTensor(builder, loc, input.getType(), inputZp);
  if (!inputZpOp) {
    (void)emitError(
        loc,
        "Failed to create input zero point tensor for quantized AVG_POOL2D op");
  }
  const std::optional<Value> outputZpOp =
      createZeroPointTensor(builder, loc, outputType, outputZp);
  if (!outputZpOp) {
    (void)emitError(loc, "Failed to create output zero point tensor for "
                         "quantized AVG_POOL2D op");
  }

  if (inputZpOp && outputZpOp) {
    result.addOperands({input, inputZpOp.value(), outputZpOp.value()});
  } else {
    // failed to create one or more zero points above: just add input as
    // operands this will trigger error in building the op because of missing
    // zero points
    result.addOperands({input});
  }

  result.addAttribute("kernel", kernel);
  result.addAttribute("stride", stride);
  result.addAttribute("pad", pad);
  result.addAttribute("acc_type", accType);
  result.types.push_back(outputType);
}

/// This builder is called on single-parameter negate operator
/// to construct input and output zero points based on their
/// types.
static void buildNegateOpWithQuantInfo(OpBuilder &builder,
                                       OperationState &result, Type outputType,
                                       Value input) {
  const Location loc{result.location};
  int64_t input1Zp{0};
  int64_t outputZp{0};
  auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
  if (quantAttr) {
    input1Zp = quantAttr.getInputZp();
    outputZp = quantAttr.getOutputZp();
  }
  const std::optional<Value> input1ZpOp =
      createZeroPointTensor(builder, loc, input.getType(), input1Zp);
  if (!input1ZpOp) {
    (void)emitError(
        loc, "Failed to create input1 zero point for quantized NEGATE op");
  }

  const std::optional<Value> outputZpOp =
      createZeroPointTensor(builder, loc, input.getType(), outputZp);
  if (!outputZpOp) {
    (void)emitError(
        loc, "Failed to create output zero point for quantized NEGATE op");
  }

  if (input1ZpOp && outputZpOp) {
    result.addOperands({input, input1ZpOp.value(), outputZpOp.value()});
  } else {
    // failed to create one or more zero points above: just add input as
    // operands. This will trigger error in building the op because of
    // missing zero points
    result.addOperands({input});
  }

  result.types.push_back(outputType);
}

/// This builder is called on TOSA pad operator that needs to create its own
/// OptionalAttr quantization_attr parameter to scale the padding values
/// correctly. No pad_const is interpreted as zero-padding.
static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                    Type outputType, Value input,
                                    Value paddings) {
  const Location loc{result.location};
  int32_t zp{0};
  const auto quantAttr = buildPadOpQuantizationAttr(builder, input);
  if (quantAttr) {
    zp = static_cast<int32_t>(quantAttr.getInputZp());
  }
  const auto padConstOp{createPadConstTensor(builder, loc, input, zp)};
  result.addOperands({input, paddings, padConstOp});
  result.types.push_back(outputType);
}

static void buildVariableOp(OpBuilder &builder, OperationState &result,
                            StringRef name, Type variableType,
                            Attribute initialValue) {
  const Location loc{result.location};
  auto nameAttr = builder.getStringAttr(name);

  auto shapedType = dyn_cast<ShapedType>(variableType);
  if (!shapedType) {
    (void)emitError(loc, "variable type must be a shaped type");
    return;
  }
  if (!shapedType.hasRank()) {
    (void)emitError(loc, "variable type must be a ranked type");
    return;
  }

  auto elementType = shapedType.getElementType();
  auto elementTypeAttr = TypeAttr::get(elementType);
  ArrayRef<int64_t> shape = shapedType.getShape();
  auto varShapeAttr = builder.getIndexTensorAttr(convertFromMlirShape(shape));

  result.addAttribute("sym_name", nameAttr);
  result.addAttribute("var_shape", varShapeAttr);
  result.addAttribute("type", elementTypeAttr);
  result.addAttribute("initial_value", initialValue);
}

//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//
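// Illustrative broadcast resolution (used by the helper below): shapes
// [2, 1, 3] and [4, 3] are right-aligned and resolve to [2, 4, 3], while
// [2, 3] vs [4, 3] fails because neither mismatched dimension is 1.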
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
                                           SmallVector<int64_t> &outShape) {
  int64_t outRank = 0;
  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    if (!shape.hasRank()) {
      // TODO(jennik): Update function to have better case handling for
      // invalid operands and for ranked tensors.
      return failure();
    }
    outRank = std::max(outRank, shape.getRank());
  }

  outShape.resize(outRank, 1);

  for (int i = 0, e = operands.size(); i != e; ++i) {
    auto shape = operands.getShape(i);
    auto rankDiff = outShape.size() - shape.getRank();

    for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
      auto dim1 = outShape[i + rankDiff];
      auto dim2 = shape.getDimSize(i);
      auto resolvedDim = dim1;

      if (dim1 == 1) {
        resolvedDim = dim2;
      } else if (dim2 == 1) {
        resolvedDim = dim1;
      } else if (dim1 != dim2) {
        return failure();
      }
      outShape[i + rankDiff] = resolvedDim;
    }
  }

  return success();
}

LogicalResult tosa::ArgMaxOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ArgMaxOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  IntegerAttr axis = adaptor.getProperties().axis;
  int32_t axisVal = axis.getValue().getSExtValue();

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  SmallVector<int64_t> outShape;
  outShape.reserve(inputShape.getRank() - 1);
  for (int i = 0, s = inputShape.getRank(); i < s; i++) {
    if (i == axisVal)
      continue;
    outShape.push_back(inputShape.getDimSize(i));
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
  return success();
}

LogicalResult tosa::RFFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RFFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInputReal().getType());

  if (!inputShape.hasRank())
    return failure();

  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(3, ShapedType::kDynamic);
  outputShape[0] = inputShape.getDimSize(0);
  outputShape[1] = inputShape.getDimSize(1);
  int64_t inWidth = inputShape.getDimSize(2);

  // Note that we can support this calculation symbolically
  // in the future e.g.
  // [x, y, z] -> [x, y, z / 2 + 1]
  if (inWidth != ShapedType::kDynamic)
    outputShape[2] = inWidth / 2 + 1;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));

  return success();
}

static LogicalResult verifyDimIsPowerOfTwo(Operation *op, const int64_t dimSize,
                                           const llvm::StringRef dimName) {
  const bool isPowerOfTwo = (dimSize & (dimSize - 1)) == 0 && dimSize > 0;
  if (!isPowerOfTwo)
    return op->emitOpError("expected ")
           << dimName << " to be a power of two, got " << dimSize;

  return success();
}

LogicalResult tosa::RFFT2dOp::verify() {
  const auto outputTypes = getResultTypes();
  if (failed(verifyCompatibleShapes(outputTypes)))
    return emitOpError("expected output shapes to match, got ") << outputTypes;

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  if (!inputType)
    return success();

  const int64_t height = inputType.getDimSize(1);
  if (ShapedType::isStatic(height) &&
      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
    return failure();

  const int64_t width = inputType.getDimSize(2);
  if (ShapedType::isStatic(width) &&
      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
    return failure();

  const auto outputType = llvm::dyn_cast<RankedTensorType>(outputTypes[0]);
  if (!outputType)
    return success();

  // Batch and height input/output dimensions should match
  if (failed(verifyCompatibleShape(inputType.getShape().drop_back(),
                                   outputType.getShape().drop_back())))
    return emitOpError("expected batch and height dimensions of input/output "
                       "to match, got input=")
           << inputType << " output=" << outputType;

  // Output width dimension expected to be input_width / 2 + 1
  const int64_t outputWidth = outputType.getDimSize(2);
  if (ShapedType::isStatic(width) && ShapedType::isStatic(outputWidth) &&
      (outputWidth != (width / 2) + 1))
    return emitOpError(
               "expected output width to be equal to input_width / 2 + 1, got ")
           << outputWidth;

  return success();
}

LogicalResult tosa::FFT2dOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    FFT2dOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputReal().getType())));
  inferredReturnShapes.push_back(
      ShapedTypeComponents(ShapeAdaptor(adaptor.getInputImag().getType())));
  return success();
}

LogicalResult tosa::FFT2dOp::verify() {
  const auto inputRealType =
      llvm::dyn_cast<RankedTensorType>(getInputReal().getType());
  const auto inputImagType =
      llvm::dyn_cast<RankedTensorType>(getInputImag().getType());
  if (!inputRealType || !inputImagType)
    return success();

  const auto trySelectStaticDim = [](const int64_t a, const int64_t b) {
    return ShapedType::isDynamic(a) ? b : a;
  };

  const int64_t height = trySelectStaticDim(inputRealType.getDimSize(1),
                                            inputImagType.getDimSize(1));
  if (ShapedType::isStatic(height) &&
      failed(verifyDimIsPowerOfTwo(*this, height, "height")))
    return failure();

  const int64_t width = trySelectStaticDim(inputRealType.getDimSize(2),
                                           inputImagType.getDimSize(2));
  if (ShapedType::isStatic(width) &&
      failed(verifyDimIsPowerOfTwo(*this, width, "width")))
    return failure();

  return success();
}
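// Illustrative concat inference: tensor<2x3xf32> and tensor<2x5xf32>
// concatenated on axis = 1 infer tensor<2x8xf32>; if any input's axis
// dimension is dynamic, the output axis dimension stays dynamic.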
LogicalResult tosa::ConcatOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ConcatOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  // Infer all dimension sizes by reducing based on inputs.
  const Properties &prop = adaptor.getProperties();
  int32_t axis = prop.axis.getValue().getSExtValue();
  llvm::SmallVector<int64_t> outputShape;
  bool hasRankedInput = false;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());
    if (!operandShape.hasRank())
      continue;

    // Copy the Operand's rank.
    if (!hasRankedInput)
      outputShape.resize(operandShape.getRank(), ShapedType::kDynamic);

    // Copy shapes until the dim is non-dynamic.
    for (int i = 0, s = operandShape.getRank(); i < s; i++) {
      if (i == axis || operandShape.isDynamicDim(i))
        continue;
      if (outputShape[i] == ShapedType::kDynamic)
        outputShape[i] = operandShape.getDimSize(i);
      if (outputShape[i] != operandShape.getDimSize(i))
        return emitOptionalError(location,
                                 "Cannot concat tensors with different sizes"
                                 " on the non-axis dimension ",
                                 i);
    }

    hasRankedInput = true;
  }

  if (adaptor.getInput1().empty())
    return failure();

  Type inputType =
      llvm::cast<TensorType>(adaptor.getInput1().getType()[0])
          .getElementType();
  if (!hasRankedInput) {
    inferredReturnShapes.push_back(ShapedTypeComponents(inputType));
    return success();
  }

  // Determine the dimension size along the concatenation axis.
  int64_t concatDimSize = 0;
  for (auto operand : adaptor.getOperands()) {
    ShapeAdaptor operandShape(operand.getType());

    // We need to know the length of the concatenation axis of all inputs to
    // determine the dimension size of the output shape.
    if (!operandShape.hasRank() || operandShape.isDynamicDim(axis)) {
      concatDimSize = ShapedType::kDynamic;
      break;
    }

    concatDimSize += operandShape.getDimSize(axis);
  }

  outputShape[axis] = concatDimSize;

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
  return success();
}

LogicalResult tosa::ConcatOp::verify() {
  // check that each input has same element type as output
  auto outType = getOutput().getType();
  const Operation::operand_range inputList = getInput1();

  // Check there is at least one input
  if (inputList.empty())
    return emitOpError("expect at least one input");

  if (!llvm::all_of(inputList, [&](auto input) {
        return succeeded(verifySameElementTypes(
            *this, /* inType = */ input.getType(), outType));
      })) {
    return failure();
  }

  const int32_t axis = getAxis();
  ShapeAdaptor firstRankedInputShape = nullptr;
  for (const auto &input : inputList) {
    const Type inputType = input.getType();
    ShapeAdaptor currShape(inputType);
    if (currShape.hasRank()) {
      firstRankedInputShape = currShape;
      // Check axis is in expected range
      if (axis < 0 || axis >= firstRankedInputShape.getRank())
        return emitOpError("expect axis to be within range 0 <= axis < "
                           "rank(input1[firstRankedTensorIdx]), got ")
               << axis;
      break;
    }
  }

  const auto allOperandsHasRank = [](const Value input) {
    return ShapeAdaptor(input.getType()).hasRank();
  };
  if (llvm::all_of(inputList, allOperandsHasRank)) {
    const int64_t firstInputRank = firstRankedInputShape.getRank();

    for (const auto &[index, input] : llvm::enumerate(inputList.drop_front())) {
      const ShapeAdaptor inputShape(input.getType());
      const int64_t inputRank = inputShape.getRank();
      const size_t operandNum = index + 1;

      // Check that each operand has the same rank
      if (inputRank != firstInputRank)
        return emitOpError(
                   "expect all operands to have the same rank, but got ")
               << firstInputRank << " vs " << inputRank << " on operands 0 and "
               << operandNum;

      // Check non-axis dims match
      for (int i = 0; i < inputRank; i++) {
        const int64_t inputDim = inputShape.getDimSize(i);
        const int64_t firstInputDim = firstRankedInputShape.getDimSize(i);
        if (i == axis || firstRankedInputShape.isDynamicDim(i) ||
            inputShape.isDynamicDim(i))
          continue;
        if (inputDim != firstInputDim)
          return emitOpError("expect all operand shapes to have the same sizes "
                             "on non-axis dimensions, but got ")
                 << inputDim << " vs " << firstInputDim << " at index " << i
                 << " on operands 0 and " << operandNum;
      }
    }

    // ERROR_IF(axis_sum != shape[axis]);
    int64_t axisSum = 0;
    for (const auto &input : inputList) {
      const ShapeAdaptor inputShape(input.getType());
      if (inputShape.isDynamicDim(axis)) {
        // make axisSum negative to indicate invalid value
        axisSum = -1;
        break;
      }
      axisSum += inputShape.getDimSize(axis);
    }
    const ShapeAdaptor outputShape(outType);
    if (axisSum >= 0 && outputShape.hasRank() &&
        !outputShape.isDynamicDim(axis) &&
        axisSum != outputShape.getDimSize(axis))
      return emitOpError("requires sum of axis dimensions of input1 "
                         "equal to output axis dimension, got ")
             << axisSum << " and " << outputShape.getDimSize(axis);
  }

  return success();
}

LogicalResult tosa::EqualOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    ValueShapeRange operands, DictionaryAttr attributes,
    OpaqueProperties properties, RegionRange regions,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  auto elementType = IntegerType::get(context, /*width=*/1);

  llvm::SmallVector<int64_t> outShape;
  if (resolveBroadcastShape(operands, outShape).failed()) {
    inferredReturnShapes.push_back(ShapedTypeComponents(elementType));
    return success();
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outShape, elementType));
  return success();
}

bool tosa::EqualOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != r.size() || l.size() != 1)
    return false;
  return succeeded(verifyCompatibleShape(l[0], r[0]));
}
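// Illustrative matmul inference: a = tensor<?x3x4xf32> with
// b = tensor<1x4x7xf32> infers tensor<1x3x7xf32>; the batch dimension is
// taken from b when a's is dynamic.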
rhsShape.getDimSize(0) : outShape[0]; outShape[2] = rhsShape.getDimSize(2); } inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); return success(); } LogicalResult MatMulOp::verify() { auto aType = llvm::dyn_cast(getA().getType()); auto bType = llvm::dyn_cast(getB().getType()); // Must be shaped tensor types if (!aType) return emitOpError("expect a shaped tensor for input a, got ") << getA().getType(); if (!bType) return emitOpError("expect a shaped tensor for input b, got ") << getB().getType(); auto aElementType = aType.getElementType(); auto bElementType = bType.getElementType(); auto aQuantizedEType = llvm::dyn_cast(aElementType); auto bQuantizedEType = llvm::dyn_cast(bElementType); if (aQuantizedEType || bQuantizedEType) { if (!aQuantizedEType || !bQuantizedEType) { return emitOpError("expect operands to be both quantized or both not " "quantized, got ") << aElementType << " and " << bElementType; } // both a and b have quantized element types auto aQuantWidth = aQuantizedEType.getStorageTypeIntegralWidth(); auto bQuantWidth = bQuantizedEType.getStorageTypeIntegralWidth(); if (aQuantWidth != bQuantWidth) { return emitOpError("expect quantized operands to have same widths, got ") << aQuantWidth << " and " << bQuantWidth; } } // check a_zp and b_zp auto aEType = getStorageElementTypeOrSelf(aType); auto aZpEType = getStorageElementTypeOrSelf(getAZp().getType()); if (aEType != aZpEType) { return emitOpError("expect input a and a_zp have the same " "element type, got ") << aEType << " and " << aZpEType; } auto bEType = getStorageElementTypeOrSelf(bType); auto bZpEType = getStorageElementTypeOrSelf(getBZp().getType()); if (bEType != bZpEType) { return emitOpError("expect input b and b_zp have the same " "element type, got ") << bEType << " and " << bZpEType; } FailureOr maybeAZp = getAZeroPoint(); if (succeeded(maybeAZp) && verifyAZeroPoint(*maybeAZp).failed()) return failure(); FailureOr maybeBZp = getBZeroPoint(); if (succeeded(maybeBZp) && verifyBZeroPoint(*maybeBZp).failed()) return failure(); return success(); } LogicalResult tosa::PadOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, PadOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput1().getType()); auto paddingRank = cast(adaptor.getPadding().getType()).getRank(); SmallVector outputShape; // If the input rank is unknown, we can infer the output rank using the // padding shape's rank divided by 2. if (!inputShape.hasRank()) { outputShape.resize(paddingRank / 2, ShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } SmallVector paddingValues; // If the paddings value is not a constant, all dimensions must be dynamic. 
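// When the padding values are constant, each output dimension below is
// computed as input dim + pad_front + pad_back. Illustrative example: an
// input of tensor<2x3xf32> with padding values [1, 1, 2, 2] infers
// tensor<4x7xf32> (2 + 1 + 1 by 3 + 2 + 2).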
if (!tosa::getConstShapeValues(adaptor.getPadding().getDefiningOp(), paddingValues)) { outputShape.resize(inputShape.getRank(), ShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } outputShape.reserve(inputShape.getRank()); for (int i = 0, s = inputShape.getRank(); i < s; i++) { if (inputShape.isDynamicDim(i)) { outputShape.push_back(ShapedType::kDynamic); continue; } auto padFront = paddingValues[i * 2]; auto padBack = paddingValues[i * 2 + 1]; if (padFront < 0 || padBack < 0) { // if either padding for dim i is -1, output dim is unknown outputShape.push_back(ShapedType::kDynamic); continue; } outputShape.push_back(inputShape.getDimSize(i) + padFront + padBack); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::PadOp::verify() { if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(), /* outType = */ getOutput().getType()) .failed()) { return failure(); } if (auto padConst = getPadConst()) { if (verifySameElementTypes(*this, /* inType = */ padConst.getType(), /* outType = */ getOutput().getType()) .failed()) { return failure(); } } RankedTensorType inputType = llvm::dyn_cast(getInput1().getType()); RankedTensorType outputType = llvm::dyn_cast(getOutput().getType()); if (!inputType || !outputType) return success(); auto inputRank = inputType.getRank(); auto outputRank = outputType.getRank(); if (inputRank != outputRank) return emitOpError() << "expect same input and output tensor rank, but got " << "inputRank: " << inputRank << ", outputRank: " << outputRank; DenseIntElementsAttr paddingAttr; if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) { return failure(); } auto paddingValues = paddingAttr.getValues(); if (paddingValues.size() != static_cast(inputRank * 2)) return emitOpError() << "padding tensor must have " << inputRank << " * 2 = " << inputRank * 2 << " elements, but got " << paddingValues.size(); auto inputShape = inputType.getShape(); auto outputShape = outputType.getShape(); for (int64_t i = 0; i < inputRank; ++i) { int64_t padStart = paddingValues[i * 2].getSExtValue(); int64_t padEnd = paddingValues[i * 2 + 1].getSExtValue(); if ((padStart < 0 && padStart != -1) || (padEnd < 0 && padEnd != -1)) { return emitOpError() << "invalid padding values at dimension " << i << ": values must be non-negative or -1 for dynamic padding, got [" << padStart << ", " << padEnd << "]"; } // Skip shape verification for dynamic input/output if (inputShape[i] == ShapedType::kDynamic || outputShape[i] == ShapedType::kDynamic) continue; if (outputShape[i] != inputShape[i] + padStart + padEnd) { return emitOpError() << "mismatch in output shape at dimension " << i << ": expected " << inputShape[i] << " + " << padStart << " + " << padEnd << " = " << (inputShape[i] + padStart + padEnd) << ", but got " << outputShape[i]; } } return success(); } LogicalResult tosa::SliceOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, SliceOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); SmallVector start; SmallVector size; if (!tosa::getConstShapeValues(adaptor.getStart().getDefiningOp(), start) || !tosa::getConstShapeValues(adaptor.getSize().getDefiningOp(), size)) { auto rank = cast(adaptor.getSize().getType()).getRank(); SmallVector fallback(rank, ShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType)); return success(); } // if 
size[i] is -1, all remaining elements in dimension i are included // in the slice, similar to TF. ShapeAdaptor inputShape(adaptor.getInput1().getType()); // initialize outputShape to all unknown SmallVector outputShape(size.size(), ShapedType::kDynamic); if (inputShape.hasRank()) { for (size_t i = 0; i < size.size(); i++) { if (size[i] != 0 && size[i] >= -1 && start[i] >= 0 && (ShapedType::isDynamic(inputShape.getDimSize(i)) || start[i] < inputShape.getDimSize(i))) { // size[i] is not 0 and not < -1, and start[i] is in valid range if (ShapedType::isDynamic(inputShape.getDimSize(i))) { // input shape has unknown dim[i] - only valid if size[i] > 0 if (size[i] > 0) { outputShape[i] = size[i]; } } else { // input shape has known dim[i] if (size[i] == -1) { outputShape[i] = inputShape.getDimSize(i) - start[i]; } else if (start[i] + size[i] <= inputShape.getDimSize(i)) { // start[i] + size[i] is within bound of input shape's dim[i] outputShape[i] = size[i]; } } } } } else { outputShape = convertToMlirShape(size); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::SliceOp::verify() { if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(), /* outType = */ getOutput().getType()) .failed()) return failure(); const ShapeAdaptor inputShape(getInput1().getType()); if (inputShape.hasRank()) { const auto inputRank = inputShape.getRank(); const ShapeAdaptor outputShape(getOutput().getType()); if (outputShape.hasRank() && inputRank != outputShape.getRank()) return emitOpError( "expect input1 and output to have the same ranks, got ") << inputRank << " and " << outputShape.getRank(); const auto startShapeRank = llvm::cast(getStart().getType()).getRank(); if (inputRank != startShapeRank) return emitOpError("length of start is not equal to rank of input shape"); const auto sizeShapeRank = llvm::cast(getSize().getType()).getRank(); if (inputRank != sizeShapeRank) return emitOpError("length of size is not equal to rank of input shape"); } return success(); } LogicalResult tosa::MulOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { // mul op's output shape only depend on input1 and input2, not on shift ValueShapeRange twoInputs = operands.drop_back(); llvm::SmallVector outShape; if (resolveBroadcastShape(twoInputs, outShape).failed()) { inferredReturnShapes.push_back(ShapedTypeComponents()); } else { inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); } return success(); } LogicalResult tosa::MulOp::verify() { const Value output = getOutput(); auto resElemType = getElementTypeOrSelf(output); // Verify if the element type among operands and result match tosa // specification. if (auto resIntType = dyn_cast(resElemType)) { IntegerType lhsIntType = dyn_cast(getElementTypeOrSelf(getInput1())); IntegerType rhsIntType = dyn_cast(getElementTypeOrSelf(getInput2())); if (!lhsIntType || !rhsIntType || lhsIntType != rhsIntType) return emitOpError("requires the same element type for all operands"); // Though the spec requires the element type of result to be i32, a more // relaxed way is provided at dialect level for easier cooperating with // other dialects. 
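// Illustrative example: i8 * i8 producing an i32 result (the spec form) or
// an i8 result is accepted, while i16 * i16 producing an i8 result is
// rejected by the width check below.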
if (lhsIntType.getWidth() > resIntType.getWidth())
      return emitOpError("invalid data type size for operands or result");
  } else {
    // For other supported types, the spec requires the same element
    // type for all operands (excludes `shift` operand) and results.
    for (int i = 0; i < 2; ++i) {
      if (getElementTypeOrSelf(getOperand(i)) != resElemType)
        return emitOpError(
            "requires the same element type for all operands and results");
    }

    // verify shift has value 0 for non-integer types
    ElementsAttr shift_elem;
    if (matchPattern(getShift(), m_Constant(&shift_elem))) {
      int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
      if (shift != 0) {
        return emitOpError() << "require shift to be 0 for float type";
      }
    }
  }

  // Verify the op has same ranks for all main operands (excludes extra
  // operands such as shift of mul op, so this is the only difference with the
  // built-in `SameOperandsAndResultRank` trait) and results types, if known.
  TypeRange operandTypes = getOperandTypes();
  ShapedType aType = cast<ShapedType>(operandTypes[0]);
  ShapedType bType = cast<ShapedType>(operandTypes[1]);

  const bool aHasRank = aType.hasRank();
  const bool bHasRank = bType.hasRank();
  if (aHasRank && bHasRank) {
    const int64_t aRank = aType.getRank();
    const int64_t bRank = bType.getRank();
    if (aRank != bRank)
      return emitOpError("a and b operands don't have matching ranks, got ")
             << aRank << " and " << bRank;

    // check for broadcast compatible shapes
    SmallVector<int64_t> resultShape;
    if (!mlir::OpTrait::util::getBroadcastedShape(
            aType.getShape(), bType.getShape(), resultShape))
      return emitOpError("a and b operands don't have broadcast-compatible "
                         "shapes, got ")
             << aType << " and " << bType;
  }

  ShapedType resultType = cast<ShapedType>(output.getType());
  if (!resultType.hasRank())
    return success();

  const int64_t resultRank = resultType.getRank();
  if (aHasRank && resultRank != aType.getRank())
    return emitOpError("result type has different rank than a, got ")
           << resultRank << " vs " << aType.getRank();
  if (bHasRank && resultRank != bType.getRank())
    return emitOpError("result type has different rank than b, got ")
           << resultRank << " vs " << bType.getRank();

  return success();
}

LogicalResult tosa::TableOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    TableOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());

  if (!inputShape.hasRank()) {
    inferredReturnShapes.push_back(ShapedTypeComponents());
    return success();
  }

  inferredReturnShapes.resize(1);
  inputShape.getDims(inferredReturnShapes[0]);
  return success();
}

LogicalResult tosa::TableOp::verify() {
  const TensorType inputType = getInput1().getType();
  const TensorType outputType = getOutput().getType();

  if (!inputType.hasRank() || !outputType.hasRank())
    return success();

  if (inputType.getRank() != outputType.getRank())
    return emitOpError()
           << "expected input tensor rank to equal result tensor rank";

  auto inputDims = inputType.getShape();
  auto outputDims = outputType.getShape();
  for (auto it : llvm::enumerate(llvm::zip(inputDims, outputDims))) {
    int64_t dim = it.index();
    auto [inputDim, outputDim] = it.value();
    if (ShapedType::isStatic(outputDim) && outputDim != inputDim) {
      return emitOpError() << "dim(result, " << dim << ") = " << outputDim
                           << " doesn't match dim(input, " << dim
                           << ") = " << inputDim;
    }
  }
  return success();
}

LogicalResult
tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
  // Multiples must be constants.
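  // Illustrative example: a constant multiples operand holding [3, 1] fills
  // multiples with {3, 1}; if the operand is not a compile-time constant this
  // helper returns failure() and the verifier below skips the multiples check.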
DenseIntElementsAttr multiplesAttr; if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr))) return failure(); multiples = llvm::to_vector( llvm::map_range(multiplesAttr.getValues(), [](const APInt &val) { return val.getSExtValue(); })); return success(); } LogicalResult tosa::TileOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TileOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); SmallVector multiples; if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(), multiples)) { auto rank = cast(adaptor.getMultiples().getType()).getRank(); SmallVector fallback(rank, ShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType)); return success(); } else { multiples = convertToMlirShape(multiples); } ShapeAdaptor inputShape(adaptor.getInput1().getType()); SmallVector outputShape; if (!inputShape.hasRank()) { outputShape.resize(multiples.size(), ShapedType::kDynamic); inferredReturnShapes.push_back( ShapedTypeComponents(outputShape, inputType)); return success(); } else if (static_cast(inputShape.getRank()) != multiples.size()) return failure(); // Any non dynamic dimension can be multiplied to a known size. outputShape.reserve(multiples.size()); for (int i = 0, s = inputShape.getRank(); i < s; i++) { if (multiples[i] == ShapedType::kDynamic) { outputShape.push_back(ShapedType::kDynamic); } else { int64_t dim = inputShape.getDimSize(i); if (dim != ShapedType::kDynamic) dim *= multiples[i]; outputShape.push_back(dim); } } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } LogicalResult tosa::TileOp::verify() { if (verifySameElementTypes(*this, /* intype = */ getInput1().getType(), /* outType = */ getOutput().getType()) .failed()) { return failure(); } ShapedType inputType = llvm::cast(getInput1().getType()); ShapedType outputType = llvm::cast(getType()); shapeType multiplesType = llvm::cast(getMultiples().getType()); auto multiplesRank = multiplesType.getRank(); if (inputType.hasRank()) { if (inputType.getRank() != multiplesRank) return emitOpError("expect 'multiples' to have rank ") << inputType.getRank() << " but got " << multiplesRank << "."; if (outputType.hasRank() && inputType.getRank() != outputType.getRank()) return emitOpError("expect same input and output tensor rank."); } else if (outputType.hasRank() && outputType.getRank() != multiplesRank) return emitOpError("expect 'multiples' array to have length ") << outputType.getRank() << " but got " << multiplesRank << "."; SmallVector multiples; if (getConstantMultiples(multiples).succeeded() && llvm::any_of(multiples, [](int64_t v) { return v <= 0 && v != -1; })) return emitOpError( "expect element of 'multiples' to be positive integer or -1."); return success(); } bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { if (l.size() != r.size() || l.size() != 1) return false; return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]); } LogicalResult tosa::ReshapeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ReshapeOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput1().getType()); Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType()); llvm::SmallVector newShapeValue; if (!tosa::getConstShapeValues(adaptor.getShape().getDefiningOp(), newShapeValue)) { auto rank = cast(adaptor.getShape().getType()).getRank(); 
SmallVector fallback(rank, ShapedType::kDynamic); inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType)); return success(); } else { newShapeValue = convertToMlirShape(newShapeValue); } // We cannot infer from the total number of elements so we must take the // shape attribute as exact. if (!inputShape.hasRank() || !inputShape.hasStaticShape()) { inferredReturnShapes.push_back( ShapedTypeComponents(newShapeValue, inputType)); return success(); } // Determine the number of elements covered by the slice of all static // dimensions. This allows us to infer the length of the remaining dynamic // dimension. int64_t numElements = inputShape.getNumElements(); int64_t staticMul = 1; for (auto val : newShapeValue) { if (ShapedType::isStatic(val)) { staticMul *= val; } } // Determine the length of the dynamic dimension. for (auto &val : newShapeValue) { if (ShapedType::isDynamic(val)) val = numElements / staticMul; } inferredReturnShapes.push_back( ShapedTypeComponents(newShapeValue, inputType)); return success(); } llvm::LogicalResult tosa::ReshapeOp::verify() { if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(), /* outType = */ getOutput().getType()) .failed()) { return failure(); } TensorType inputType = getInput1().getType(); SmallVector shapeValues; if (!tosa::getConstShapeValues(getShape().getDefiningOp(), shapeValues)) { // skip following checks if shape is not constant return mlir::success(); } int missingDims = llvm::count(shapeValues, -1); if (missingDims > 1) return emitOpError() << "expected at most one target dimension to be -1"; const auto outputType = dyn_cast(getType()); if (!outputType) return success(); if ((int64_t)shapeValues.size() != outputType.getRank()) return emitOpError() << "new shape does not match result rank"; for (auto [newShapeDim, outputShapeDim] : zip(shapeValues, outputType.getShape())) { if (newShapeDim != -1 && newShapeDim != ShapedType::kDynamic && outputShapeDim != ShapedType::kDynamic && newShapeDim != outputShapeDim) return emitOpError() << "new shape is inconsistent with result shape"; if (newShapeDim != ShapedType::kDynamic && newShapeDim < -1) return emitOpError() << "new shape has invalid tensor dimension size " << newShapeDim; } if (inputType.hasStaticShape()) { int64_t inputElementsNum = inputType.getNumElements(); if (outputType.hasStaticShape()) { int64_t outputElementsNum = outputType.getNumElements(); if (inputElementsNum != outputElementsNum) { return emitOpError() << "cannot reshape " << inputElementsNum << " elements into " << outputElementsNum; } } int64_t newShapeElementsNum = std::accumulate( shapeValues.begin(), shapeValues.end(), 1LL, [](int64_t acc, int64_t dim) { return (dim > 0) ? 
acc * dim : acc; });

    bool isStaticNewShape =
        llvm::all_of(shapeValues, [](int64_t s) { return s > 0; });
    if ((isStaticNewShape && inputElementsNum != newShapeElementsNum) ||
        (!isStaticNewShape && newShapeElementsNum > inputElementsNum)) {
      return emitOpError() << "cannot reshape " << inputElementsNum
                           << " elements into " << newShapeElementsNum;
    }
  }

  return mlir::success();
}

// return failure if val is not a constant
// set zp to -1 if val is non-zero float or val is not integer nor float
// otherwise set zp to val's constant value
static FailureOr<int64_t> getZeroPoint(Value val, bool signExtend) {
  ElementsAttr zpAttr;
  if (!matchPattern(val, m_Constant(&zpAttr))) {
    return failure();
  }

  Type zpElemType = zpAttr.getElementType();

  if (llvm::isa<FloatType>(zpElemType)) {
    if (zpAttr.getValues<APFloat>()[0].isZero()) {
      return 0;
    }
    // return non-zero value to trigger error check
    return -1;
  }

  if (llvm::isa<IntegerType>(zpElemType)) {
    if (signExtend)
      return zpAttr.getValues<APInt>()[0].getSExtValue();
    else
      return zpAttr.getValues<APInt>()[0].getZExtValue();
  }

  // return non-zero value to trigger error check
  return -1;
}

template <typename T>
static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
                                     const std::string &operand) {
  Type zpElemType = getElementTypeOrSelf(val);

  if (!zpElemType.isInteger(8) && zp != 0) {
    // convert operand to lower case for error message
    std::string lower = operand;
    std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
    return op.emitOpError()
           << lower << " zero point must be zero for non-int8 integer types";
  }

  return success();
}

static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
                                     const int64_t &zp,
                                     const std::string &operand) {
  bool isInputZp = (operand == "Input");

  bool tensorUnsigned =
      isInputZp ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef tensorName = isInputZp ?
"input" : "output"; Type zpElemType = getElementTypeOrSelf(zpVal); if (zp != 0) { if (!zpElemType.isInteger(8) && !(zpElemType.isInteger(16) && tensorUnsigned)) { return op.emitOpError() << "expect " << tensorName << "_zp of 0, got " << zp; } if (zpElemType.isInteger(16) && tensorUnsigned && zp != 32768) { return op.emitOpError() << "expect " << tensorName << "_zp of 0 or 32768 for unsigned int16 " << tensorName << ", got " << zp; } } return success(); } #define ZERO_POINT_HELPER(OP, OPERAND_NAME, SIGN_EXTEND) \ FailureOr tosa::OP::get##OPERAND_NAME##ZeroPoint() { \ return getZeroPoint(get##OPERAND_NAME##Zp(), SIGN_EXTEND); \ } \ LogicalResult tosa::OP::verify##OPERAND_NAME##ZeroPoint(int64_t zp) { \ return verifyZeroPoint(*this, get##OPERAND_NAME##Zp(), zp, #OPERAND_NAME); \ } ZERO_POINT_HELPER(Conv2DOp, Input, true) ZERO_POINT_HELPER(Conv2DOp, Weight, true) ZERO_POINT_HELPER(Conv3DOp, Input, true) ZERO_POINT_HELPER(Conv3DOp, Weight, true) ZERO_POINT_HELPER(DepthwiseConv2DOp, Input, true) ZERO_POINT_HELPER(DepthwiseConv2DOp, Weight, true) ZERO_POINT_HELPER(TransposeConv2DOp, Input, true) ZERO_POINT_HELPER(TransposeConv2DOp, Weight, true) ZERO_POINT_HELPER(AvgPool2dOp, Input, true) ZERO_POINT_HELPER(AvgPool2dOp, Output, true) ZERO_POINT_HELPER(MatMulOp, A, true) ZERO_POINT_HELPER(MatMulOp, B, true) ZERO_POINT_HELPER(NegateOp, Input1, true) ZERO_POINT_HELPER(NegateOp, Output, true) ZERO_POINT_HELPER(RescaleOp, Input, !getInputUnsigned()) ZERO_POINT_HELPER(RescaleOp, Output, !getOutputUnsigned()) #undef ZERO_POINT_HELPER LogicalResult tosa::TransposeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TransposeOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput1().getType()); // If input rank and permutation length is unknown, the output rank is // unknown. if (!inputShape.hasRank()) { inferredReturnShapes.push_back(ShapedTypeComponents()); return success(); } const auto inputRank = inputShape.getRank(); // This would imply the number of permutations does not match the rank of // the input which is illegal. if (adaptor.getPerms().size() != static_cast(inputRank)) { return failure(); } SmallVector outputShape; // Rank-0 means no permutations matter. if (inputRank == 0) { inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } // Check whether the input dimensions are all the same. bool allTheSame = true; for (int i = 1, s = inputRank; i < s; i++) { if (inputShape.getDimSize(0) != inputShape.getDimSize(i)) { allTheSame = false; break; } } // If all of the input dimensions are the same we don't care about the // permutation. if (allTheSame) { outputShape.resize(inputRank, inputShape.getDimSize(0)); inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } outputShape.resize(inputRank, ShapedType::kDynamic); // Constant permutation values must be within the input rank. 
if (llvm::any_of(adaptor.getPerms(), [inputRank](const auto i) { return i >= inputRank; })) return failure(); outputShape.reserve(inputRank); for (int i = 0, s = inputRank; i < s; i++) { outputShape[i] = inputShape.getDimSize(adaptor.getPerms()[i]); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::TransposeOp::verify() { if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(), /* outType = */ getOutput().getType()) .failed()) { return failure(); } const ShapeAdaptor inputShape(getInput1().getType()); const ShapeAdaptor outputShape(getOutput().getType()); const llvm::ArrayRef constantPerms = getPerms(); if (inputShape.hasRank() && constantPerms.size() != static_cast(inputShape.getRank())) return emitOpError() << "expected perms attribute to have size " << inputShape.getRank() << " (input rank) but got size " << constantPerms.size(); if (inputShape.hasRank() && outputShape.hasRank() && inputShape.getRank() != outputShape.getRank()) return emitOpError() << "expected input tensor rank to equal result tensor rank"; if (outputShape.hasRank() && constantPerms.size() != static_cast(outputShape.getRank())) return emitOpError() << "expected perms attribute to have size " << outputShape.getRank() << " (output rank) but got size " << constantPerms.size(); if (!llvm::all_of(constantPerms, [&constantPerms](int32_t s) { return s >= 0 && static_cast(s) < constantPerms.size(); }) || !isPermutationVector(llvm::to_vector(llvm::map_range( constantPerms, [](int32_t v) -> int64_t { return v; })))) return emitOpError() << "expected valid permutation indices"; // ERROR_IF(tensor_size(shape1) != tensor_size(shape)) if (inputShape.hasStaticShape() && outputShape.hasStaticShape() && inputShape.getNumElements() != outputShape.getNumElements()) return emitOpError() << "expected input1 and output to have same numbers " "of elements, got " << inputShape.getNumElements() << " and " << outputShape.getNumElements(); // Verify that the types of the input and output tensors are properly // permuted. 
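  // Illustrative example: with perms = [2, 0, 1], output dim i must match
  // input dim perms[i], so an input of tensor<1x2x3xf32> must produce an
  // output of tensor<3x1x2xf32>.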
if (inputShape.hasRank() && outputShape.hasRank()) { for (auto i = 0; i < outputShape.getRank(); i++) { if (inputShape.isDynamicDim(constantPerms[i]) || outputShape.isDynamicDim(i)) continue; if (inputShape.getDimSize(constantPerms[i]) != outputShape.getDimSize(i)) return emitOpError() << "expected output tensor dim " << i << " to match " << "input dim " << constantPerms[i] << " with value of " << inputShape.getDimSize(constantPerms[i]); } } return success(); } LogicalResult TransposeOp::reifyResultShapes( OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) { const llvm::ArrayRef transposePerms = getPerms(); Value input = getInput1(); auto inputType = cast(input.getType()); SmallVector returnedDims(inputType.getRank()); for (auto dim : transposePerms) { int32_t dimInInput = transposePerms[dim]; if (inputType.isDynamicDim(dimInInput)) returnedDims[dim] = tensor::DimOp::create(builder, getLoc(), input, dimInInput) .getResult(); else returnedDims[dim] = builder.getIndexAttr(inputType.getDimSize(dimInInput)); } reifiedReturnShapes.emplace_back(std::move(returnedDims)); return success(); } LogicalResult tosa::GatherOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, GatherOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(3, ShapedType::kDynamic); ShapeAdaptor valuesShape(adaptor.getValues().getType()); if (valuesShape.hasRank()) { outputShape[0] = valuesShape.getDimSize(0); outputShape[2] = valuesShape.getDimSize(2); } ShapeAdaptor indicesShape(adaptor.getIndices().getType()); if (indicesShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamic) outputShape[0] = indicesShape.getDimSize(0); if (outputShape[1] == ShapedType::kDynamic) outputShape[1] = indicesShape.getDimSize(1); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::GatherOp::verify() { if (verifySameElementTypes(*this, /* inType = */ getValues().getType(), /* outType = */ getOutput().getType()) .failed()) { return failure(); } const ShapeAdaptor valuesShape(getValues().getType()); const ShapeAdaptor indicesShape(getIndices().getType()); const ShapeAdaptor outputShape(getOutput().getType()); int64_t N = ShapedType::kDynamic; int64_t W = ShapedType::kDynamic; int64_t C = ShapedType::kDynamic; if (valuesShape.hasRank()) { N = valuesShape.getDimSize(0); C = valuesShape.getDimSize(2); } if (indicesShape.hasRank()) { const int64_t indicesN = indicesShape.getDimSize(0); W = indicesShape.getDimSize(1); if (N == ShapedType::kDynamic) N = indicesN; else if (indicesN != ShapedType::kDynamic && N != indicesN) return emitOpError() << "requires indices dimension 0 to have size " << N << ", got " << indicesN; } if (outputShape.hasRank()) { const int64_t outputN = outputShape.getDimSize(0); const int64_t outputW = outputShape.getDimSize(1); const int64_t outputC = outputShape.getDimSize(2); if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic && N != outputN) return emitOpError() << "requires output dimension 0 to have size " << N << ", got " << outputN; if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic && W != outputW) return emitOpError() << "requires output dimension 1 to have size " << W << ", got " << outputW; if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic && C != outputC) return emitOpError() << "requires output dimension 2 to have size " << C << ", got " << outputC; } return success(); } LogicalResult 
tosa::ResizeOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ResizeOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(4, ShapedType::kDynamic); ShapeAdaptor inputShape(adaptor.getInput().getType()); if (!inputShape.hasRank()) return failure(); outputShape[0] = inputShape.getDimSize(0); outputShape[3] = inputShape.getDimSize(3); int64_t inputHeight = inputShape.getDimSize(1); int64_t inputWidth = inputShape.getDimSize(2); if ((inputHeight == ShapedType::kDynamic) || (inputWidth == ShapedType::kDynamic)) return failure(); SmallVector scaleInt, offsetInt, borderInt; if (!tosa::getConstShapeValues(adaptor.getScale().getDefiningOp(), scaleInt) || !tosa::getConstShapeValues(adaptor.getOffset().getDefiningOp(), offsetInt) || !tosa::getConstShapeValues(adaptor.getBorder().getDefiningOp(), borderInt)) { return failure(); } // Compute the output shape based on attributes: scale, offset, and border. const int64_t outputHeight = (((inputHeight - 1) * scaleInt[0] - offsetInt[0] + borderInt[0]) / scaleInt[1]) + 1; const int64_t outputWidth = (((inputWidth - 1) * scaleInt[2] - offsetInt[1] + borderInt[1]) / scaleInt[3]) + 1; if (outputHeight < 0 || outputWidth < 0) { return emitOptionalError( location, "calculated output height and width must be non-negative, " "got height = ", outputHeight, ", width = ", outputWidth); } outputShape[1] = outputHeight; outputShape[2] = outputWidth; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::ResizeOp::verify() { const Value input = getInput(); const Value output = getOutput(); const RankedTensorType inputType = llvm::dyn_cast(input.getType()); const RankedTensorType outputType = llvm::dyn_cast(output.getType()); SmallVector scaleValues; SmallVector offsetValues; SmallVector borderValues; if (!tosa::getConstShapeValues(getScale().getDefiningOp(), scaleValues) || !tosa::getConstShapeValues(getOffset().getDefiningOp(), offsetValues) || !tosa::getConstShapeValues(getBorder().getDefiningOp(), borderValues)) { // Skip following checks if shape is not constant return success(); } if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; })) return emitOpError("expect all scale values to be > 0, got ") << scaleValues; const int64_t scaleYN = scaleValues[0]; const int64_t scaleYD = scaleValues[1]; const int64_t scaleXN = scaleValues[2]; const int64_t scaleXD = scaleValues[3]; const int64_t offsetY = offsetValues[0]; const int64_t offsetX = offsetValues[1]; const int64_t borderY = borderValues[0]; const int64_t borderX = borderValues[1]; if (!inputType) return success(); if (!outputType) return success(); const int64_t oh = outputType.getDimSize(1); const int64_t ow = outputType.getDimSize(2); const int64_t ih = inputType.getDimSize(1); const int64_t iw = inputType.getDimSize(2); // Don't check with input height that could be broadcast (ih != 1) // since Linalg, a consumer of TOSA, expects broadcasting support // in resize to be available. Taking the cautious approach for now, // we can consider removing support for broadcasting later. 
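  // Worked example (illustrative) of the output-height relation checked below:
  // with ih = 8, scale_y_n / scale_y_d = 4 / 2, offset_y = 0 and border_y = 0,
  // the expected output height is ((8 - 1) * 4 - 0 + 0) / 2 + 1 = 15, and the
  // division must be exact for verification to pass.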
if (ih != ShapedType::kDynamic && ih != 1) { const std::optional calculatedOutHeightMinusOne = idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD); if (!calculatedOutHeightMinusOne.has_value()) return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + " "border_y ") << "to be wholly divisible by scale_y_d, got ((" << ih << " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY << ") / " << scaleYD; const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1; if (oh != ShapedType::kDynamic && calculatedOutHeight != oh) return emitOpError("calculated output height did not match expected: ") << "calculated=" << calculatedOutHeight << ", expected=" << oh; } // Don't check with input width that could be broadcast (iw != 1) // since Linalg, a consumer of TOSA, expects broadcasting support // in resize to be available. Taking the cautious approach for now, // we can consider removing support for broadcasting later. if (iw != ShapedType::kDynamic && iw != 1) { const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX; const std::optional calculatedOutWidthMinusOne = idivCheck(scaledInWidth, scaleXD); if (!calculatedOutWidthMinusOne.has_value()) return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + " "border_x ") << "to be wholly divisible by scale_x_d, got ((" << iw << " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX << ") / " << scaleXD; const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1; if (ow != ShapedType::kDynamic && calculatedOutWidth != ow) return emitOpError("calculated output width did not match expected: ") << "calculated=" << calculatedOutWidth << ", expected=" << ow; } return success(); } LogicalResult tosa::ScatterOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, ScatterOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape; outputShape.resize(3, ShapedType::kDynamic); ShapeAdaptor valuesInShape(adaptor.getValuesIn().getType()); if (valuesInShape.hasRank()) { outputShape[0] = valuesInShape.getDimSize(0); outputShape[1] = valuesInShape.getDimSize(1); outputShape[2] = valuesInShape.getDimSize(2); } ShapeAdaptor indicesShape(adaptor.getIndices().getType()); if (indicesShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamic) outputShape[0] = indicesShape.getDimSize(0); } ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { if (outputShape[0] == ShapedType::kDynamic) outputShape[0] = inputShape.getDimSize(0); if (outputShape[2] == ShapedType::kDynamic) outputShape[2] = inputShape.getDimSize(2); } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult tosa::ScatterOp::verify() { if (verifySameElementTypes(*this, /* inType = */ getValuesIn().getType(), /* outType = */ getValuesOut().getType()) .failed() || verifySameElementTypes(*this, /* inType = */ getInput().getType(), /* outType = */ getValuesOut().getType()) .failed()) { return failure(); } const ShapeAdaptor valuesInShape(getValuesIn().getType()); const ShapeAdaptor indicesShape(getIndices().getType()); const ShapeAdaptor inputShape(getInput().getType()); const ShapeAdaptor outputShape(getValuesOut().getType()); int64_t N = ShapedType::kDynamic; int64_t K = ShapedType::kDynamic; int64_t W = ShapedType::kDynamic; int64_t C = ShapedType::kDynamic; if (valuesInShape.hasRank()) { N = valuesInShape.getDimSize(0); K = valuesInShape.getDimSize(1); C = valuesInShape.getDimSize(2); } if 
(indicesShape.hasRank()) { const int64_t indicesN = indicesShape.getDimSize(0); W = indicesShape.getDimSize(1); if (N == ShapedType::kDynamic) N = indicesN; else if (indicesN != ShapedType::kDynamic && N != indicesN) return emitOpError() << "requires indices dimension 0 to have size " << N << ", got " << indicesN; } if (inputShape.hasRank()) { const int64_t inputN = inputShape.getDimSize(0); const int64_t inputW = inputShape.getDimSize(1); const int64_t inputC = inputShape.getDimSize(2); if (N == ShapedType::kDynamic) N = inputN; else if (inputN != ShapedType::kDynamic && N != inputN) return emitOpError() << "requires input dimension 0 to have size " << N << ", got " << inputN; if (W == ShapedType::kDynamic) W = inputW; else if (inputW != ShapedType::kDynamic && W != inputW) return emitOpError() << "requires input dimension 1 to have size " << W << ", got " << inputW; if (C == ShapedType::kDynamic) C = inputC; else if (inputC != ShapedType::kDynamic && C != inputC) return emitOpError() << "requires input dimension 2 to have size " << C << ", got " << inputC; } if (outputShape.hasRank()) { const int64_t outputN = outputShape.getDimSize(0); const int64_t outputK = outputShape.getDimSize(1); const int64_t outputC = outputShape.getDimSize(2); if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic && N != outputN) return emitOpError() << "requires values_out dimension 0 to have size " << N << ", got " << outputN; if (K == ShapedType::kDynamic) K = outputK; else if (outputK != ShapedType::kDynamic && K != outputK) return emitOpError() << "requires values_out dimension 1 to have size " << K << ", got " << outputK; if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic && C != outputC) return emitOpError() << "requires values_out dimension 2 to have size " << C << ", got " << outputC; } if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W)) return emitOpError() << "requires dimensions K >= W, got K=" << K << " and W=" << W; return success(); } static LogicalResult ReduceInferReturnTypes( ShapeAdaptor operandShape, Type inputType, IntegerAttr axis, SmallVectorImpl &inferredReturnShapes) { int64_t axisVal = axis.getValue().getSExtValue(); if (!operandShape.hasRank() || operandShape.getRank() <= axisVal) { inferredReturnShapes.push_back(ShapedTypeComponents(inputType)); return success(); } SmallVector outputShape; operandShape.getDims(outputShape); outputShape[axisVal] = 1; inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType)); return success(); } #define COMPATIBLE_RETURN_TYPES(OP) \ bool OP::isCompatibleReturnTypes(TypeRange l, TypeRange r) { \ if (l.size() != r.size() || l.size() != 1) \ return false; \ if (getElementTypeOrSelf(l[0]) != getElementTypeOrSelf(r[0])) \ return false; \ return succeeded(verifyCompatibleShape(l[0], r[0])); \ } #define REDUCE_SHAPE_INFER(OP) \ LogicalResult OP::inferReturnTypeComponents( \ MLIRContext *context, ::std::optional location, \ OP::Adaptor adaptor, \ SmallVectorImpl &inferredReturnShapes) { \ Type inputType = \ llvm::cast(adaptor.getInput().getType()).getElementType(); \ ShapeAdaptor inputShape(adaptor.getInput().getType()); \ const Properties &prop = adaptor.getProperties(); \ return ReduceInferReturnTypes(inputShape, inputType, prop.axis, \ inferredReturnShapes); \ } \ COMPATIBLE_RETURN_TYPES(OP) REDUCE_SHAPE_INFER(tosa::ReduceAllOp) REDUCE_SHAPE_INFER(tosa::ReduceAnyOp) REDUCE_SHAPE_INFER(tosa::ReduceMaxOp) REDUCE_SHAPE_INFER(tosa::ReduceMinOp) REDUCE_SHAPE_INFER(tosa::ReduceProductOp) 
REDUCE_SHAPE_INFER(tosa::ReduceSumOp) #undef REDUCE_SHAPE_INFER COMPATIBLE_RETURN_TYPES(tosa::ConcatOp) #undef COMPATIBLE_RETURN_TYPES template static LogicalResult verifyReduceOp(T op) { // All TOSA reduce Ops have input, output and axis. TensorType inputType = op.getInput().getType(); TensorType outputType = op.getOutput().getType(); int32_t reduceAxis = op.getAxis(); if (reduceAxis < 0) { op.emitOpError("reduce axis must not be negative"); return failure(); } if (inputType.hasRank()) { int64_t inputRank = inputType.getRank(); // We allow for a special case where the input/output shape has rank 0 and // axis is also 0. if (reduceAxis >= inputRank && (reduceAxis != 0 || inputRank != 0)) { op.emitOpError("expect input tensor rank (") << inputRank << ") to be larger than reduce axis (" << reduceAxis << ")"; return failure(); } } if (outputType.hasRank()) { int64_t outputRank = outputType.getRank(); if (inputType.hasRank() && outputRank != inputType.getRank()) { op.emitOpError( "expect output tensor rank to be equal to input tensor rank"); return failure(); } if (reduceAxis >= outputRank && (reduceAxis != 0 || outputRank != 0)) { op.emitOpError("expect output tensor rank (") << outputRank << ") to be larger than reduce axis (" << reduceAxis << ")"; return failure(); } // We can only verify the reduced dimension size to be 1 if this is not // the special case of output rank == 0. if (outputRank != 0) { auto outputShape = outputType.getShape(); if (!outputType.isDynamicDim(reduceAxis) && outputShape[reduceAxis] != 1) { op.emitOpError("expect reduced dimension size to be 1, got ") << outputShape[reduceAxis]; return failure(); } } } return success(); } LogicalResult tosa::ReduceAllOp::verify() { return verifyReduceOp(*this); } LogicalResult tosa::ReduceAnyOp::verify() { return verifyReduceOp(*this); } LogicalResult tosa::ReduceMaxOp::verify() { return verifyReduceOp(*this); } LogicalResult tosa::ReduceMinOp::verify() { return verifyReduceOp(*this); } LogicalResult tosa::ReduceProductOp::verify() { return verifyReduceOp(*this); } LogicalResult tosa::ReduceSumOp::verify() { return verifyReduceOp(*this); } static LogicalResult NAryInferReturnTypes( const ValueShapeRange &operands, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outShape; if (resolveBroadcastShape(operands, outShape).failed()) { inferredReturnShapes.push_back(ShapedTypeComponents()); } else { inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); } return success(); } #define NARY_SHAPE_INFER(OP) \ LogicalResult OP::inferReturnTypeComponents( \ MLIRContext *context, ::std::optional location, \ ValueShapeRange operands, DictionaryAttr attributes, \ OpaqueProperties properties, RegionRange regions, \ SmallVectorImpl &inferredReturnShapes) { \ return NAryInferReturnTypes(operands, inferredReturnShapes); \ } NARY_SHAPE_INFER(tosa::AbsOp) NARY_SHAPE_INFER(tosa::AddOp) NARY_SHAPE_INFER(tosa::ArithmeticRightShiftOp) NARY_SHAPE_INFER(tosa::BitwiseAndOp) NARY_SHAPE_INFER(tosa::BitwiseOrOp) NARY_SHAPE_INFER(tosa::BitwiseXorOp) NARY_SHAPE_INFER(tosa::BitwiseNotOp) NARY_SHAPE_INFER(tosa::CastOp) NARY_SHAPE_INFER(tosa::CeilOp) NARY_SHAPE_INFER(tosa::ClampOp) NARY_SHAPE_INFER(tosa::ClzOp) NARY_SHAPE_INFER(tosa::CosOp) NARY_SHAPE_INFER(tosa::ExpOp) NARY_SHAPE_INFER(tosa::FloorOp) NARY_SHAPE_INFER(tosa::GreaterEqualOp) NARY_SHAPE_INFER(tosa::GreaterOp) NARY_SHAPE_INFER(tosa::IdentityOp) NARY_SHAPE_INFER(tosa::IntDivOp) NARY_SHAPE_INFER(tosa::LogOp) NARY_SHAPE_INFER(tosa::LogicalAndOp) NARY_SHAPE_INFER(tosa::LogicalLeftShiftOp) 
NARY_SHAPE_INFER(tosa::LogicalNotOp)
NARY_SHAPE_INFER(tosa::LogicalOrOp)
NARY_SHAPE_INFER(tosa::LogicalRightShiftOp)
NARY_SHAPE_INFER(tosa::LogicalXorOp)
NARY_SHAPE_INFER(tosa::MaximumOp)
NARY_SHAPE_INFER(tosa::MinimumOp)
NARY_SHAPE_INFER(tosa::PowOp)
NARY_SHAPE_INFER(tosa::ReciprocalOp)
NARY_SHAPE_INFER(tosa::ReverseOp)
NARY_SHAPE_INFER(tosa::RsqrtOp)
NARY_SHAPE_INFER(tosa::SinOp)
NARY_SHAPE_INFER(tosa::SelectOp)
NARY_SHAPE_INFER(tosa::SubOp)
NARY_SHAPE_INFER(tosa::TanhOp)
NARY_SHAPE_INFER(tosa::ErfOp)
NARY_SHAPE_INFER(tosa::SigmoidOp)
#undef NARY_SHAPE_INFER

LogicalResult tosa::NegateOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    NegateOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput1().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult tosa::NegateOp::verify() {
  // Verify same element type
  const Type input1Type = getInput1().getType();
  const Type outputType = getOutput().getType();
  if (verifySameElementTypes(*this, input1Type, outputType).failed())
    return failure();

  // Verify same shape
  const SmallVector<Type> types = {input1Type, outputType};
  if (failed(verifyCompatibleShapes(types)))
    return emitOpError() << "requires the same shape for input1 and output";

  const Type input1EType = getStorageElementTypeOrSelf(getInput1().getType());
  const Type input1ZpEType =
      getStorageElementTypeOrSelf(getInput1Zp().getType());
  if (input1EType != input1ZpEType) {
    return emitOpError("expect both input1 and its zero point are the same "
                       "element type, got ")
           << input1EType << " and " << input1ZpEType;
  }
  const Type outputEType = getStorageElementTypeOrSelf(getOutput().getType());
  const Type outputZpEType =
      getStorageElementTypeOrSelf(getOutputZp().getType());
  if (outputEType != outputZpEType) {
    return emitOpError("expect both output and its zero point are the same "
                       "element type, got ")
           << outputEType << " and " << outputZpEType;
  }

  FailureOr<int64_t> maybeIZp = getInput1ZeroPoint();
  if (succeeded(maybeIZp) && verifyInput1ZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  return success();
}

static LogicalResult poolingInferReturnTypes(
    ShapeAdaptor inputShape, ArrayRef<int64_t> kernel, ArrayRef<int64_t> stride,
    ArrayRef<int64_t> pad,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<int64_t> outputShape;
  outputShape.resize(4, ShapedType::kDynamic);

  // We only know the rank if the input type is unranked.
  if (!inputShape) {
    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
    return success();
  }

  // Batch and number of channels are identical for pooling layer.
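  // Illustrative example for the spatial dims computed below: an NHWC input of
  // 1x32x32x8 with kernel = [2, 2], stride = [2, 2] and pad = [0, 0, 0, 0]
  // yields an output of 1x16x16x8, since (32 + 0 + 0 - 2) / 2 + 1 = 16.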
outputShape[0] = inputShape.getDimSize(0); outputShape[3] = inputShape.getDimSize(3); int64_t height = inputShape.getDimSize(1); int64_t width = inputShape.getDimSize(2); if (ShapedType::isStatic(height)) { int64_t padded = height + pad[0] + pad[1] - kernel[0]; outputShape[1] = padded / stride[0] + 1; } if (ShapedType::isStatic(width)) { int64_t padded = width + pad[2] + pad[3] - kernel[1]; outputShape[2] = padded / stride[1] + 1; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult Conv2DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, Conv2DOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamic); int64_t inputWidth = ShapedType::kDynamic; int64_t inputHeight = ShapedType::kDynamic; int64_t weightWidth = ShapedType::kDynamic; int64_t weightHeight = ShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { outputShape[0] = inputShape.getDimSize(0); inputHeight = inputShape.getDimSize(1); inputWidth = inputShape.getDimSize(2); } // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape(adaptor.getWeight().getType()); if (weightShape.hasRank()) { outputShape[3] = weightShape.getDimSize(0); weightHeight = weightShape.getDimSize(1); weightWidth = weightShape.getDimSize(2); } // Bias shape can describe the output channels. ShapeAdaptor biasShape(adaptor.getBias().getType()); if (biasShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) ? biasShape.getDimSize(0) : outputShape[3]; } llvm::ArrayRef dilation = adaptor.getDilation(); llvm::ArrayRef stride = adaptor.getStride(); llvm::ArrayRef padding = adaptor.getPad(); if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) { int64_t inputSize = inputHeight + padding[0] + padding[1]; int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) { int64_t inputSize = inputWidth + padding[2] + padding[3]; int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult Conv2DOp::verify() { if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() || verifyConvOpErrorIf(*this).failed()) return failure(); return success(); } LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, Conv3DOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(5, ShapedType::kDynamic); int64_t inputWidth = ShapedType::kDynamic; int64_t inputHeight = ShapedType::kDynamic; int64_t inputDepth = ShapedType::kDynamic; int64_t weightWidth = ShapedType::kDynamic; int64_t weightHeight = ShapedType::kDynamic; int64_t weightDepth = ShapedType::kDynamic; // Input shape describes input width/height and batch. 
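  // Illustrative example of the per-dimension size formula used below (the
  // same formula as Conv2D above): a spatial extent of 16 with pad = 1 + 1,
  // kernel = 3, dilation = 1 and stride = 2 gives inputSize = 18,
  // filterSize = 3, and an output extent of (18 - 3 + 1 - 1) / 2 + 1 = 8.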
ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { outputShape[0] = inputShape.getDimSize(0); inputDepth = inputShape.getDimSize(1); inputHeight = inputShape.getDimSize(2); inputWidth = inputShape.getDimSize(3); } // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape(adaptor.getWeight().getType()); if (weightShape.hasRank()) { outputShape[4] = weightShape.getDimSize(0); weightDepth = weightShape.getDimSize(1); weightHeight = weightShape.getDimSize(2); weightWidth = weightShape.getDimSize(3); } // Bias shape can describe the output channels. ShapeAdaptor biasShape(adaptor.getBias().getType()); if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[4])) { outputShape[4] = biasShape.getDimSize(0); } llvm::ArrayRef dilation = adaptor.getDilation(); llvm::ArrayRef stride = adaptor.getStride(); llvm::ArrayRef pad = adaptor.getPad(); if (ShapedType::isStatic(inputDepth) && ShapedType::isStatic(weightDepth)) { int32_t inputSize = inputDepth + pad[0] + pad[1]; int32_t filterSize = (weightDepth - 1) * dilation[0] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) { int32_t inputSize = inputHeight + pad[2] + pad[3]; int32_t filterSize = (weightHeight - 1) * dilation[1] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1; } if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) { int32_t inputSize = inputWidth + pad[4] + pad[5]; int32_t filterSize = (weightWidth - 1) * dilation[2] + 1; int32_t unstridedResult = inputSize - filterSize + 1; outputShape[3] = (unstridedResult - 1) / stride[2] + 1; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult Conv3DOp::verify() { if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() || verifyConvOpErrorIf(*this).failed()) return failure(); return success(); } LogicalResult AvgPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, AvgPool2dOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput().getType()); const Properties &prop = adaptor.getProperties(); return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad, inferredReturnShapes); } LogicalResult MaxPool2dOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, MaxPool2dOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { ShapeAdaptor inputShape(adaptor.getInput().getType()); const Properties &prop = adaptor.getProperties(); return poolingInferReturnTypes(inputShape, prop.kernel, prop.stride, prop.pad, inferredReturnShapes); } LogicalResult MaxPool2dOp::verify() { if (failed(verifySameElementTypes(*this, /* intype = */ getInput().getType(), /* outType = */ getOutput().getType()))) return failure(); if (failed(verifyPoolingOp(*this))) return failure(); return success(); } LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, DepthwiseConv2DOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamic); int64_t inputWidth = ShapedType::kDynamic; int64_t inputHeight = ShapedType::kDynamic; int64_t inputChannels = ShapedType::kDynamic; int64_t weightWidth = ShapedType::kDynamic; int64_t weightHeight = 
ShapedType::kDynamic; int64_t depthChannels = ShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { outputShape[0] = inputShape.getDimSize(0); inputHeight = inputShape.getDimSize(1); inputWidth = inputShape.getDimSize(2); inputChannels = inputShape.getDimSize(3); } // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape(adaptor.getWeight().getType()); if (weightShape.hasRank()) { weightHeight = weightShape.getDimSize(0); weightWidth = weightShape.getDimSize(1); inputChannels = ShapedType::isDynamic(inputChannels) ? weightShape.getDimSize(2) : inputChannels; depthChannels = weightShape.getDimSize(3); } // If both inputChannels and depthChannels are available we can determine // the output channels. if (ShapedType::isStatic(inputChannels) && ShapedType::isStatic(depthChannels)) { outputShape[3] = inputChannels * depthChannels; } // Bias shape can describe the output channels. ShapeAdaptor biasShape(adaptor.getBias().getType()); if (biasShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) ? biasShape.getDimSize(0) : outputShape[3]; } llvm::ArrayRef dilation = adaptor.getDilation(); llvm::ArrayRef padding = adaptor.getPad(); llvm::ArrayRef stride = adaptor.getStride(); if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) { int64_t inputSize = inputHeight + padding[0] + padding[1]; int64_t filterSize = (weightHeight - 1) * dilation[0] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[1] = (unstridedResult - 1) / stride[0] + 1; } if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) { int64_t inputSize = inputWidth + padding[2] + padding[3]; int64_t filterSize = (weightWidth - 1) * dilation[1] + 1; int64_t unstridedResult = inputSize - filterSize + 1; outputShape[2] = (unstridedResult - 1) / stride[1] + 1; } inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } LogicalResult DepthwiseConv2DOp::verify() { if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed() || verifyConvOpErrorIf(*this).failed()) return failure(); return success(); } LogicalResult TransposeConv2DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional location, TransposeConv2DOp::Adaptor adaptor, SmallVectorImpl &inferredReturnShapes) { llvm::SmallVector outputShape(4, ShapedType::kDynamic); int64_t inputWidth = ShapedType::kDynamic; int64_t inputHeight = ShapedType::kDynamic; int64_t weightWidth = ShapedType::kDynamic; int64_t weightHeight = ShapedType::kDynamic; // Input shape describes input width/height and batch. ShapeAdaptor inputShape(adaptor.getInput().getType()); if (inputShape.hasRank()) { outputShape[0] = ShapedType::isDynamic(outputShape[0]) ? inputShape.getDimSize(0) : outputShape[0]; inputHeight = inputShape.getDimSize(1); inputWidth = inputShape.getDimSize(2); } // Weight shapes describes the filter width/height and the output channels. ShapeAdaptor weightShape(adaptor.getWeight().getType()); if (weightShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) ? weightShape.getDimSize(0) : outputShape[3]; weightHeight = weightShape.getDimSize(1); weightWidth = weightShape.getDimSize(2); } // Bias shape can describe the output channels. ShapeAdaptor biasShape(adaptor.getInput().getType()); if (biasShape.hasRank()) { outputShape[3] = ShapedType::isDynamic(outputShape[3]) ? 
        biasShape.getDimSize(0) : outputShape[3];
  }

  llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
  llvm::ArrayRef<int64_t> stride = adaptor.getStride();

  if (ShapedType::isStatic(inputHeight) &&
      ShapedType::isStatic(weightHeight)) {
    int64_t calculateSize =
        (inputHeight - 1) * stride[0] + padding[0] + padding[1] + weightHeight;
    outputShape[1] =
        ShapedType::isDynamic(outputShape[1]) ? calculateSize : outputShape[1];
  }

  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) {
    int64_t calculateSize =
        (inputWidth - 1) * stride[1] + padding[2] + padding[3] + weightWidth;
    outputShape[2] =
        ShapedType::isDynamic(outputShape[2]) ? calculateSize : outputShape[2];
  }

  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
  return success();
}

LogicalResult TransposeConv2DOp::verify() {
  if (verifyConvOp(*this).failed() || verifyConvOpModes(*this).failed())
    return failure();

  const llvm::ArrayRef<int64_t> strides = getStride();
  const int64_t strideY = strides[0];
  const int64_t strideX = strides[1];

  if (strideY < 1 || strideX < 1)
    return emitOpError("expect all stride values to be >= 1, got [")
           << strides << "]";

  const auto checkPadAgainstKernelDim =
      [this](int64_t pad_value, int64_t kernel_dim_size,
             llvm::StringRef pad_name,
             llvm::StringRef kernel_dim_name) -> LogicalResult {
    if (pad_value <= -kernel_dim_size)
      return emitOpError("expected ")
             << pad_name << " > -" << kernel_dim_name
             << ", but got: " << pad_name << "=" << pad_value << " and "
             << kernel_dim_name << "=" << kernel_dim_size;
    return success();
  };

  const llvm::ArrayRef<int64_t> padding = getOutPad();
  const int64_t outPadTop = padding[0];
  const int64_t outPadBottom = padding[1];
  const int64_t outPadLeft = padding[2];
  const int64_t outPadRight = padding[3];

  const auto weightType =
      llvm::dyn_cast<RankedTensorType>(getWeight().getType());
  if (weightType) {
    const int64_t kernelHeight = weightType.getDimSize(1);
    if (ShapedType::isStatic(kernelHeight)) {
      if (failed(checkPadAgainstKernelDim(outPadTop, kernelHeight,
                                          "out_pad_top", "KH")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadBottom, kernelHeight,
                                          "out_pad_bottom", "KH")))
        return failure();
    }

    const int64_t kernelWidth = weightType.getDimSize(2);
    if (ShapedType::isStatic(kernelWidth)) {
      if (failed(checkPadAgainstKernelDim(outPadLeft, kernelWidth,
                                          "out_pad_left", "KW")))
        return failure();

      if (failed(checkPadAgainstKernelDim(outPadRight, kernelWidth,
                                          "out_pad_right", "KW")))
        return failure();
    }
  }

  // The remaining checks depend on the output type being a RankedTensorType.
  const auto outputType =
      llvm::dyn_cast<RankedTensorType>(getOutput().getType());
  if (!outputType)
    return success();

  const auto inputType =
      llvm::dyn_cast<RankedTensorType>(getInput().getType());
  if (inputType && weightType) {
    const int64_t inputHeight = inputType.getDimSize(1);
    const int64_t kernelHeight = weightType.getDimSize(1);
    const int64_t outputHeight = outputType.getDimSize(1);

    if (ShapedType::isStatic(inputHeight) &&
        ShapedType::isStatic(outputHeight)) {
      if (outputHeight !=
          (inputHeight - 1) * strideY + outPadTop + outPadBottom + kernelHeight)
        return emitOpError(
                   "dimension mismatch: expected OH == (IH - 1) * stride_y "
                   "+ out_pad_top + out_pad_bottom + KH, but got ")
               << outputHeight << " != (" << inputHeight << " - 1) * "
               << strideY << " + " << outPadTop << " + " << outPadBottom
               << " + " << kernelHeight;
    }

    const int64_t inputWidth = inputType.getDimSize(2);
    const int64_t kernelWidth = weightType.getDimSize(2);
    const int64_t outputWidth = outputType.getDimSize(2);

    if (ShapedType::isStatic(inputWidth) &&
        ShapedType::isStatic(outputWidth)) {
      if (outputWidth !=
          (inputWidth - 1) * strideX + outPadLeft + outPadRight + kernelWidth)
        return emitOpError(
                   "dimension mismatch: expected OW == (IW - 1) * stride_x "
                   "+ out_pad_left + out_pad_right + KW, but got ")
               << outputWidth << " != (" << inputWidth << " - 1) * " << strideX
               << " + " << outPadLeft << " + " << outPadRight << " + "
               << kernelWidth;
    }
  }

  const auto biasType = llvm::dyn_cast<RankedTensorType>(getBias().getType());
  if (!biasType)
    return success();

  const int64_t biasChannels = biasType.getDimSize(0);

  // Skip further checks if the bias is dynamic.
  if (biasChannels == ShapedType::kDynamic)
    return success();

  const int64_t outputChannels = outputType.getDimSize(3);
  if (!ShapedType::isDynamic(outputChannels) &&
      biasChannels != outputChannels && biasChannels != 1)
    return emitOpError(
               "bias channels expected to be equal to output channels (")
           << outputChannels << ") or 1, got " << biasChannels;

  return success();
}

LogicalResult RescaleOp::verify() {
  auto inputType = llvm::dyn_cast<ShapedType>(getInput().getType());
  if (!inputType) {
    emitOpError("expect shaped tensor for input, got ") << getInput().getType();
    return failure();
  }

  auto inputElementType =
      getStorageElementTypeOrSelf(inputType.getElementType());
  if (!mlir::isa<IntegerType>(inputElementType)) {
    emitOpError("expect input to have integer element type, got ")
        << inputElementType;
    return failure();
  }

  auto outputType = llvm::dyn_cast<ShapedType>(getOutput().getType());
  if (!outputType) {
    emitOpError("expect shaped tensor for output, got ")
        << getOutput().getType();
    return failure();
  }

  auto outputElementType =
      getStorageElementTypeOrSelf(outputType.getElementType());
  if (!mlir::isa<IntegerType>(outputElementType)) {
    emitOpError("expect output to have integer element type, got ")
        << outputElementType;
    return failure();
  }

  if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
          .failed())
    return failure();

  if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
          .failed())
    return failure();

  FailureOr<int64_t> maybeIZp = getInputZeroPoint();
  if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
    return failure();

  FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
  if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
    return failure();

  auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
  if (!multiplierType) {
    emitOpError("expect shaped tensor for multiplier, got ")
        << getMultiplier().getType();
    return failure();
  }

  auto shiftType = llvm::dyn_cast<ShapedType>(getShift().getType());
  if (!shiftType) {
    emitOpError("expect shaped tensor for shift, got ") << getShift().getType();
    return failure();
  }

  // Multiplier element type must be i32 for scale32 = true.
  if (getScale32() && !multiplierType.getElementType().isInteger(32)) {
    emitOpError(
        "expect i32 element type for multiplier for scale32=true, got ")
        << multiplierType.getElementType();
    return failure();
  }

  // Multiplier element type must be i16 for scale32 = false.
  if (!getScale32() && !multiplierType.getElementType().isInteger(16)) {
    emitOpError(
        "expect i16 element type for multiplier for scale32=false, got ")
        << multiplierType.getElementType();
    return failure();
  }

  if (!inputType.hasRank())
    return success();

  // multiplier/shift must have shape {numChannels}, where numChannels is 1
  // when per_channel = false and the extent of the input's last axis
  // otherwise.
  int64_t numChannels = 1;
  if (getPerChannel()) {
    if (inputType.getRank() < 1) {
      emitOpError("requires input to be at least rank 1 when per_channel is "
                  "true, but got rank ")
          << inputType.getRank();
      return failure();
    }
    numChannels = inputType.getDimSize(inputType.getRank() - 1);
  }
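  // Worked example (illustrative only, not normative): for a per-channel
  // rescale of tensor<2x3x16xi8>, numChannels is the last-axis extent, 16,
  // so the checks below require both multiplier and shift to have shape
  // {16}; with per_channel = false they must have shape {1}.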
  if (!multiplierType.hasRank())
    return success();

  ArrayRef<int64_t> multiplierShape = multiplierType.getShape();
  // Multiplier input has rank 1 by dialect definition.
  if (multiplierShape[0] != ShapedType::kDynamic &&
      multiplierShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for multiplier input, got { "
        << multiplierShape[0] << " }";
    return failure();
  }

  if (!shiftType.hasRank())
    return success();

  ArrayRef<int64_t> shiftShape = shiftType.getShape();
  // Shift input has rank 1 by dialect definition.
  if (shiftShape[0] != ShapedType::kDynamic && shiftShape[0] != numChannels) {
    emitOpError("expect shape of { ")
        << numChannels << " } for shift input, got { " << shiftShape[0]
        << " }";
    return failure();
  }

  return success();
}

LogicalResult RescaleOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    RescaleOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  ShapeAdaptor inputShape(adaptor.getInput().getType());
  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
  return success();
}

LogicalResult IfOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    IfOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (Region *region : adaptor.getRegions()) {
    for (auto &block : *region)
      if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
        yieldOps.push_back(returnOp);
  }

  if (yieldOps.empty())
    return failure();

  // Get the initial type information for the yield op.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      auto meet = ValueKnowledge::meet(
          resultKnowledge[index],
          ValueKnowledge::getKnowledgeFromType(it.value().getType()));
      if (!meet)
        continue;
      resultKnowledge[index] = meet;
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

LogicalResult WhileOp::inferReturnTypeComponents(
    MLIRContext *context, ::std::optional<Location> location,
    WhileOp::Adaptor adaptor,
    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
  llvm::SmallVector<tosa::YieldOp> yieldOps;
  for (auto &block : adaptor.getBodyGraph())
    if (auto returnOp = dyn_cast<tosa::YieldOp>(block.getTerminator()))
      yieldOps.push_back(returnOp);

  // TOSA's while must have a tosa.yield as its terminator. If none is found,
  // this tosa.while is invalid.
  if (yieldOps.empty())
    return failure();
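  // The loop below folds the types of every tosa.yield into one
  // ValueKnowledge per result via ValueKnowledge::meet. A hedged summary of
  // the behavior visible here (see ShapeUtils.h for the exact lattice
  // rules): when two pieces of knowledge cannot be reconciled, meet produces
  // an invalid knowledge, in which case the previous estimate is kept.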
  // Get the initial type information from the operand types.
  llvm::SmallVector<ValueKnowledge> resultKnowledge;
  resultKnowledge.reserve(yieldOps.front().getNumOperands());
  for (auto operand : yieldOps.front().getOperands()) {
    resultKnowledge.push_back(
        ValueKnowledge::getKnowledgeFromType(operand.getType()));
  }

  for (auto yieldOp : yieldOps) {
    if (resultKnowledge.size() != yieldOp.getNumOperands())
      return failure();

    for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
      int32_t index = it.index();
      if (auto meet = ValueKnowledge::meet(
              resultKnowledge[index],
              ValueKnowledge::getKnowledgeFromType(it.value().getType()))) {
        resultKnowledge[index] = meet;
      }
    }
  }

  for (const ValueKnowledge &result : resultKnowledge) {
    inferredReturnShapes.push_back(result.getShapedTypeComponents());
  }

  return success();
}

std::optional<SmallVector<int64_t, 4>> ApplyScaleOp::getShapeForUnroll() {
  if (auto vt = llvm::dyn_cast<VectorType>(getType()))
    return llvm::to_vector<4>(vt.getShape());
  return std::nullopt;
}

static void printInitializationList(OpAsmPrinter &parser,
                                    Block::BlockArgListType blocksArgs,
                                    ValueRange initializers,
                                    StringRef prefix = "") {
  assert(blocksArgs.size() == initializers.size() &&
         "expected same length of arguments and initializers");
  if (initializers.empty())
    return;

  parser << prefix << '(';
  llvm::interleaveComma(
      llvm::zip(blocksArgs, initializers), parser,
      [&](auto it) { parser << std::get<0>(it) << " = " << std::get<1>(it); });
  parser << ")";
}

// The parse and print methods of IfOp follow the implementation of the SCF
// dialect.
ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create the regions for 'then' and 'else'.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  OpAsmParser::UnresolvedOperand cond;
  if (parser.parseOperand(cond))
    return failure();

  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;

  // Parse the optional block arguments.
  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.has_value() && failed(listResult.value()))
    return failure();

  // Parse a colon.
  if (failed(parser.parseColon()))
    return parser.emitError(parser.getCurrentLocation(),
                            "expected type for condition operand");

  // Parse the type of the condition operand.
  Type condType;
  if (failed(parser.parseType(condType)))
    return parser.emitError(parser.getCurrentLocation(),
                            "expected type for condition operand");

  // Resolve the condition operand with the provided type.
  if (failed(parser.resolveOperand(cond, condType, result.operands)))
    return failure();

  // Parse the optional block argument types.
  if (listResult.has_value()) {
    FunctionType functionType;

    if (failed(parser.parseType(functionType)))
      return parser.emitError(parser.getCurrentLocation())
             << "expected list of types for block arguments "
             << "followed by arrow type and list of return types";

    result.addTypes(functionType.getResults());

    if (functionType.getNumInputs() != operands.size()) {
      return parser.emitError(parser.getCurrentLocation())
             << "expected as many input types as operands "
             << "(expected " << operands.size() << " got "
             << functionType.getNumInputs() << ")";
    }

    // Resolve input operands.
    if (failed(parser.resolveOperands(operands, functionType.getInputs(),
                                      parser.getCurrentLocation(),
                                      result.operands)))
      return failure();
  } else {
    // Parse the optional result type list.
    if (parser.parseOptionalArrowTypeList(result.types))
      return failure();
  }

  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
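  // Illustrative textual form handled by this parser (assumed from the
  // parse/print pair, not normative; compare IfOp::print below). The 'else'
  // keyword and region are optional:
  //   tosa.cond_if %cond (%arg0 = %x)
  //       : tensor<i1> (tensor<f32>) -> tensor<f32> {
  //     ...
  //   } else {
  //     ...
  //   }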
  // If we find an 'else' keyword, parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
      return failure();
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void IfOp::print(OpAsmPrinter &p) {
  p << " " << getCondition();

  printInitializationList(p, getThenGraph().front().getArguments(),
                          getInputList(), " ");

  p << " : ";
  p << getCondition().getType();

  if (!getInputList().empty()) {
    p << " (";
    llvm::interleaveComma(getInputList().getTypes(), p);
    p << ")";
  }

  p.printArrowTypeList(getResultTypes());
  p << " ";
  p.printRegion(getThenGraph());

  // Print the 'else' region if it exists and has a block.
  auto &elseRegion = getElseGraph();
  if (!elseRegion.empty()) {
    p << " else ";
    p.printRegion(elseRegion);
  }

  p.printOptionalAttrDict((*this)->getAttrs());
}

LogicalResult IfOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getThenGraph().front().getArguments(),
                                 "'then_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getElseGraph().front().getArguments(),
                                 "'else_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  // If the block has no terminator, MLIR's generic verifier reports it for us.
  if (getThenGraph().front().mightHaveTerminator()) {
    auto thenYield =
        dyn_cast<tosa::YieldOp>(getThenGraph().front().getTerminator());
    if (thenYield && errorIfTypeOrShapeMismatch(
                         *this, thenYield.getInputs(), "'then_graph' results",
                         getOutputList(), "'output_list'")
                         .failed())
      return failure();
  }

  // If the block has no terminator, MLIR's generic verifier reports it for us.
  if (getElseGraph().front().mightHaveTerminator()) {
    auto elseYield =
        dyn_cast<tosa::YieldOp>(getElseGraph().front().getTerminator());
    if (elseYield && errorIfTypeOrShapeMismatch(
                         *this, elseYield.getInputs(), "'else_graph' results",
                         getOutputList(), "'output_list'")
                         .failed())
      return failure();
  }

  auto condType = getCondition().getType();
  if (errorIfShapeNotSizeOne(*this, condType).failed())
    return emitOpError() << "'condition' must be a size 1 tensor, got "
                         << condType;

  return success();
}

LogicalResult WhileOp::verify() {
  if (errorIfTypeOrShapeMismatch(*this, getInputList(), "'input_list'",
                                 getOutputList(), "'output_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getCondGraph().front().getArguments(),
                                 "'cond_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (errorIfTypeOrShapeMismatch(*this, getBodyGraph().front().getArguments(),
                                 "'body_graph' arguments", getInputList(),
                                 "'input_list'")
          .failed())
    return failure();

  if (getBodyGraph().front().mightHaveTerminator()) {
    auto bodyYield =
        dyn_cast<tosa::YieldOp>(getBodyGraph().front().getTerminator());
    if (bodyYield && errorIfTypeOrShapeMismatch(
                         *this, bodyYield.getInputs(), "'body_graph' results",
                         getInputList(), "'input_list'")
                         .failed())
      return failure();
  }
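  // Illustrative examples of what the checks below accept and reject
  // (assumed, for documentation only): a cond_graph yielding tensor<i1> or
  // tensor<1xi1> passes; yielding tensor<2xi1> (not size 1), tensor<i32>
  // (not boolean), or more than one value is rejected.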
  // The condition block output must be a single-element tensor holding a
  // single bool value.
  if (!getCondGraph().front().mightHaveTerminator())
    return success();

  auto condYield =
      dyn_cast<tosa::YieldOp>(getCondGraph().front().getTerminator());
  if (!condYield)
    return success();

  if (condYield.getInputs().size() != 1)
    return emitOpError() << "require 'cond_graph' only have one result";

  auto condOutType = condYield.getInputs()[0].getType();
  if (errorIfShapeNotSizeOne(*this, condOutType).failed())
    return emitOpError() << "'cond_graph' result must be a size 1 tensor, got "
                         << condOutType;

  if (!getElementTypeOrSelf(condOutType).isInteger(1))
    return emitOpError() << "'cond_graph' result must be a boolean tensor, got "
                         << condOutType;

  return success();
}

LogicalResult ReverseOp::verify() {
  if (verifySameElementTypes(*this, /* inType = */ getInput1().getType(),
                             /* outType = */ getOutput().getType())
          .failed())
    return failure();

  TensorType inputType = getInput1().getType();
  TensorType outputType = getOutput().getType();
  int32_t reverseAxis = getAxis();

  if (reverseAxis < 0)
    return emitOpError("expected non-negative reverse axis");

  if (inputType.hasRank()) {
    int64_t inputRank = inputType.getRank();
    // We allow for a special case where the input/output shape has rank 0 and
    // axis is also 0.
    if (reverseAxis >= inputRank && (reverseAxis != 0 || inputRank != 0))
      return emitOpError("expect input tensor rank (")
             << inputRank << ") to be larger than reverse axis ("
             << reverseAxis << ")";
  }

  if (outputType.hasRank()) {
    int64_t outputRank = outputType.getRank();
    if (inputType.hasRank() && outputRank != inputType.getRank())
      return emitOpError(
          "expect output tensor rank to be equal to input tensor rank");
    if (reverseAxis >= outputRank && (reverseAxis != 0 || outputRank != 0))
      return emitOpError("expect output tensor rank (")
             << outputRank << ") to be larger than reverse axis ("
             << reverseAxis << ")";
  }

  return success();
}

LogicalResult tosa::SelectOp::verify() {
  // Verify that input2 and input3 have the same element type as the output.
  if (verifySameElementTypes(*this, /* inType = */ getOnTrue().getType(),
                             /* outType = */ getOutput().getType())
          .failed() ||
      verifySameElementTypes(*this, /* inType = */ getOnFalse().getType(),
                             /* outType = */ getOutput().getType())
          .failed()) {
    return failure();
  }

  // Verify that input1 has a bool element type.
  auto predicateType = llvm::dyn_cast<ShapedType>(getPred().getType());
  if (!predicateType) {
    return emitOpError("expect shaped tensor for input1, got ")
           << getPred().getType();
  }

  auto predicateElementType = predicateType.getElementType();
  if (!predicateElementType.isInteger(1)) {
    return emitOpError("expect element type of bool for input1, got ")
           << predicateElementType;
  }

  return success();
}

LogicalResult tosa::VariableReadOp::verify() {
  if (verifyVariableOpErrorIf(*this, getOutput1().getType(), "'output1'")
          .failed())
    return failure();

  return success();
}

LogicalResult tosa::VariableWriteOp::verify() {
  if (verifyVariableOpErrorIf(*this, getInput1().getType(), "'input1'")
          .failed())
    return failure();

  return success();
}

// The parse and print methods of WhileOp follow the implementation of the SCF
// dialect.
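// Illustrative textual form handled by WhileOp::parse/WhileOp::print below
// (assumed from the parser and printer, not normative):
//   tosa.while_loop (%arg0 = %init) : (tensor<i32>) -> tensor<i32> {
//     ...                       // cond_graph: computes the continue flag
//     tosa.yield %continue : tensor<i1>
//   } do {
//     ...                       // body_graph: computes the next iteration
//     tosa.yield %next : tensor<i32>
//   }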
ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::Argument, 4> regionArgs;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
  Region *cond = result.addRegion();
  Region *body = result.addRegion();

  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.has_value() && failed(listResult.value()))
    return failure();

  FunctionType functionType;
  SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Resolve input operands.
  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
                                    parser.getCurrentLocation(),
                                    result.operands)))
    return failure();

  // Propagate the types into the region arguments.
  for (size_t i = 0, e = regionArgs.size(); i != e; ++i)
    regionArgs[i].type = functionType.getInput(i);

  return failure(parser.parseRegion(*cond, regionArgs) ||
                 parser.parseKeyword("do") || parser.parseRegion(*body) ||
                 parser.parseOptionalAttrDictWithKeyword(result.attributes));
}

void WhileOp::print(OpAsmPrinter &parser) {
  printInitializationList(parser, getCondGraph().front().getArguments(),
                          getInputList(), " ");
  parser << " : ";
  parser.printFunctionalType(getInputList().getTypes(),
                             getResults().getTypes());
  parser << ' ';
  parser.printRegion(getCondGraph(), /*printEntryBlockArgs=*/false);
  parser << " do ";
  parser.printRegion(getBodyGraph());
  parser.printOptionalAttrDictWithKeyword((*this)->getAttrs());
}

// Create a rank-1 const tensor for the zero point of the source tensor.
std::optional<Value> mlir::tosa::createZeroPointTensor(OpBuilder &builder,
                                                       Location loc,
                                                       Type srcElemType,
                                                       int64_t zp) {
  srcElemType = getStorageElementTypeOrSelf(srcElemType);
  auto zpType = mlir::RankedTensorType::get({1}, srcElemType);
  if (llvm::isa<FloatType>(srcElemType)) {
    auto zpAttr = DenseElementsAttr::get(
        zpType, builder.getFloatAttr(srcElemType, static_cast<double>(zp)));
    return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
  }
  if (llvm::isa<IntegerType>(srcElemType)) {
    auto zpAttr =
        DenseElementsAttr::get(zpType, builder.getIntegerAttr(srcElemType, zp));
    return tosa::ConstOp::create(builder, loc, zpType, zpAttr);
  }
  llvm::errs() << "zero point is not allowed for unsupported data types\n";
  return std::nullopt;
}
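// Example use of createZeroPointTensor (illustrative sketch; `builder` and
// `loc` are assumed to come from the surrounding rewrite context):
//   std::optional<Value> zp = mlir::tosa::createZeroPointTensor(
//       builder, loc, builder.getIntegerType(8), /*zp=*/-128);
//   if (!zp)
//     return failure(); // unsupported element type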
//===----------------------------------------------------------------------===//
// TOSA Shape and Shape Operators Helper functions.
//===----------------------------------------------------------------------===//

bool mlir::tosa::isa_tosa_shape_type(mlir::Type t) {
  return mlir::isa<tosa::shapeType>(t);
}

LogicalResult
mlir::tosa::shapeType::verify(function_ref<InFlightDiagnostic()> emitError,
                              int rank) {
  if (rank < 0)
    return emitError() << "invalid rank (must be >= 0): " << rank;
  return success();
}

LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
  for (auto v : op->getOperands()) {
    if (mlir::isa<::mlir::tosa::shapeType>(v.getType())) {
      Operation *definingOp = v.getDefiningOp();
      if (!definingOp ||
          !definingOp->hasTrait<mlir::OpTrait::ConstantLike>()) {
        return op->emitOpError("shape operand is not compile time resolvable");
      }
    }
  }
  return success();
}

LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
  for (auto type : op->getOperandTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have operands with tosa shape type");
    }
  }
  for (auto type : op->getResultTypes()) {
    if (!mlir::isa<mlir::tosa::shapeType>(type)) {
      return op->emitOpError("must have result with tosa shape type");
    }
  }
  return success();
}

LogicalResult
OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
  if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
      failed(verifyTosaShapeOperator(op)))
    return failure();

  // Delegate that returns the rank of a shape type.
  auto getRank = [](const Type type) {
    return mlir::cast<mlir::tosa::shapeType>(type).getRank();
  };

  auto operandTypes = op->getOperandTypes();
  auto resultTypes = op->getResultTypes();

  auto rank = getRank(*op->getOperandTypes().begin());
  for (auto type : operandTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("operands don't have matching ranks");
    }
  }
  for (auto type : resultTypes) {
    if (getRank(type) != rank) {
      return op->emitOpError("result shape has different rank than operands");
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Shape Operators verify functions.
//===----------------------------------------------------------------------===//

LogicalResult tosa::ConstShapeOp::verify() {
  // Check that the values attribute is one-dimensional.
  auto valuesRank = getValues().getType().getRank();
  if (valuesRank != 1)
    return emitOpError("expect elements in attribute values with rank 1");
  // Check that the number of elements in the values attribute equals the rank
  // of the result shape type.
  auto count = getValues().getNumElements();
  auto rank = (cast<tosa::shapeType>(getResult().getType())).getRank();
  if (count != rank && (count != 1 || rank != 0)) {
    return emitOpError("expect number of elements in attribute values (")
           << count << ") to be equal to the rank (" << rank
           << ") for the result shape type";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaAttributes.cpp.inc"

//===----------------------------------------------------------------------===//
// TOSA Type Definitions.
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOpsTypesBase.cpp.inc"

//===----------------------------------------------------------------------===//
// TOSA Operator Definitions.
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.cpp.inc"