diff options
Diffstat (limited to 'mlir/lib/Dialect/Tosa/IR')
| -rw-r--r-- | mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp | 2 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 553 | ||||
| -rw-r--r-- | mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 813 |
3 files changed, 1057 insertions, 311 deletions
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp index eb47e85..9f616b2 100644 --- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp @@ -43,6 +43,8 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) { return TosaSpecificationVersion(1, 0); case Extension::mxfp: case Extension::int64: + case Extension::mxfp_conv: + case Extension::shape: return TosaSpecificationVersion(1, 1); case Extension::none: return TosaSpecificationVersion(0, 0); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 293c6af..bb715a9 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Matchers.h" @@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> { auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) { - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); } Attribute newMinValAttr, newMaxValAttr; @@ -884,38 +885,365 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results, SliceDynamicSizeCanonicalization>(context); } +struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> { + using OpRewritePattern<tosa::CastOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(tosa::CastOp castOp, + PatternRewriter &rewriter) const override { + const Value castInput = castOp.getInput(); + auto innerCastOp = castInput.getDefiningOp<tosa::CastOp>(); + if (!innerCastOp) + return rewriter.notifyMatchFailure(castOp, + "input must be cast operation"); + + const Value innerCastInput = innerCastOp.getInput(); + + const auto innerInputType = + llvm::cast<ShapedType>(innerCastInput.getType()); + const auto innerOutputType = llvm::cast<ShapedType>(innerCastOp.getType()); + const auto outerOutputType = llvm::cast<ShapedType>(castOp.getType()); + + const SmallVector<ShapedType, 3> types = {innerInputType, innerOutputType, + outerOutputType}; + if (llvm::any_of(types, [](const ShapedType type) { + return !type.getElementType().isInteger(); + })) + return rewriter.notifyMatchFailure(castOp, + "only integer types are supported"); + + // Check inner cast is non-narrowing + const unsigned innerInputBitWidth = innerInputType.getElementTypeBitWidth(); + if (innerInputBitWidth > innerOutputType.getElementTypeBitWidth()) + return rewriter.notifyMatchFailure(castOp, + "inner cast operation is narrowing"); + + // Check outer cast is non-narrowing from the inner cast input + if (innerInputBitWidth > outerOutputType.getElementTypeBitWidth()) + return rewriter.notifyMatchFailure(castOp, + "outer cast operation is narrowing"); + + rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, outerOutputType, + innerCastInput); + + return success(); + } +}; + +void CastOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add<NonNarrowingCastsOptimization>(context); +} + //===----------------------------------------------------------------------===// // Operator Folders. //===----------------------------------------------------------------------===// -template <typename IntFolder, typename FloatFolder> -static DenseElementsAttr binaryFolder(DenseElementsAttr lhs, - DenseElementsAttr rhs, - RankedTensorType returnTy) { - if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) { - auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType(); - auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType(); - if (lETy != rETy) - return {}; +template <typename Folder> +static DenseElementsAttr +binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy, + bool foldDenseValues = false) { + if (!lhs || !rhs) + return {}; - if (llvm::isa<IntegerType>(lETy)) { - APInt l = lhs.getSplatValue<APInt>(); - APInt r = rhs.getSplatValue<APInt>(); - auto result = IntFolder()(l, r); - return DenseElementsAttr::get(returnTy, result); + const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType(); + const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType(); + if (lETy != rETy) + return {}; + + if (lhs.isSplat() && rhs.isSplat()) { + if (isa<FloatType>(lETy)) { + const APFloat l = lhs.getSplatValue<APFloat>(); + const APFloat r = rhs.getSplatValue<APFloat>(); + const auto maybeResult = Folder::fold(l, r); + if (failed(maybeResult)) + return {}; + return DenseElementsAttr::get(returnTy, maybeResult.value()); } - if (llvm::isa<FloatType>(lETy)) { - APFloat l = lhs.getSplatValue<APFloat>(); - APFloat r = rhs.getSplatValue<APFloat>(); - auto result = FloatFolder()(l, r); - return DenseElementsAttr::get(returnTy, result); + if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) { + const APInt l = lhs.getSplatValue<APInt>(); + const APInt r = rhs.getSplatValue<APInt>(); + const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned()); + if (failed(maybeResult)) + return {}; + return DenseElementsAttr::get(returnTy, maybeResult.value()); + } + } + + if (foldDenseValues) { + assert(lETy.isIntOrIndex() && + "Only integer types are currently supported."); + SmallVector<APInt> resultValues; + for (auto [l, r] : + llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) { + const auto maybeResult = Folder::fold(l, r, false); + if (failed(maybeResult)) + return {}; + resultValues.push_back(maybeResult.value()); } + return DenseElementsAttr::get(returnTy, resultValues); } return {}; } +template <typename Folder> +static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy, + bool foldDenseValues = false) { + if (!val) + return {}; + + const auto vETy = llvm::cast<ShapedType>(val.getType()).getElementType(); + + if (val.isSplat()) { + if (const auto vIntTy = llvm::dyn_cast<IntegerType>(vETy)) { + const APInt v = val.getSplatValue<APInt>(); + const auto maybeResult = Folder::fold(v, vIntTy.isUnsigned()); + if (failed(maybeResult)) + return {}; + return DenseElementsAttr::get(returnTy, maybeResult.value()); + } + } + + if (foldDenseValues) { + mlir::Type elemTy = val.getElementType(); + if (elemTy.isIntOrIndex()) { + SmallVector<APInt> resultValues; + for (auto const &v : val.getValues<APInt>()) { + const auto maybeResult = Folder::fold(v, false); + if (failed(maybeResult)) + return {}; + resultValues.push_back(maybeResult.value()); + } + return DenseElementsAttr::get(returnTy, resultValues); + } + } + + // Folding arbitrarily sized tensor operations is not supported + return {}; +} + +struct AddFoldAdaptor { + static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs, + const bool isUnsigned) { + bool overflow; + const APInt result = + isUnsigned ? lhs.uadd_ov(rhs, overflow) : lhs.sadd_ov(rhs, overflow); + if (overflow) + return failure(); + return result; + } + + static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) { + return lhs + rhs; + } +}; + +struct SubFoldAdaptor { + static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs, + const bool isUnsigned) { + bool overflow; + const APInt result = + isUnsigned ? lhs.usub_ov(rhs, overflow) : lhs.ssub_ov(rhs, overflow); + if (overflow) + return failure(); + return result; + } + + static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) { + return lhs - rhs; + } +}; + +struct MulFoldAdaptor { + static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs, + const bool isUnsigned) { + + const unsigned originalWidth = lhs.getBitWidth(); + + // Check same type + if (lhs.getBitWidth() != rhs.getBitWidth()) { + return failure(); + } + + // If either is `0` + if (lhs == 0 || rhs == 0) + return APInt::getZero(originalWidth); + + bool overflow = false; + APInt const result = + isUnsigned ? lhs.umul_ov(rhs, overflow) : lhs.smul_ov(rhs, overflow); + + if (overflow) + return failure(); + + return result.trunc(originalWidth); + } + + static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) { + return lhs * rhs; + } +}; + +static bool signsDiffer(const APInt &a, const APInt &b) { + return a.isNegative() != b.isNegative(); +} + +template <bool Ceil> +struct DivFoldAdaptor { + static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs, + bool isUnsigned) { + if (lhs.getBitWidth() != rhs.getBitWidth()) + return failure(); + if (rhs.isZero()) + return failure(); + + if (isUnsigned) { + APInt q{}; + APInt r{}; + APInt::udivrem(lhs, rhs, q, r); + if (!r.isZero() && Ceil) { + return q + 1; + } + return q; + } + + // Signed: start from trunc-toward-zero, then adjust to ceil. + bool overflow{false}; + APInt const q = lhs.sdiv_ov(rhs, overflow); + if (overflow) + return failure(); + APInt const r = lhs.srem(rhs); + + if (Ceil && !r.isZero() && !signsDiffer(lhs, rhs)) { + // Same sign => exact quotient is positive; trunc is below ceil => + // increment q. + return q + 1; + } + return q; + } + + static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) { + return lhs / rhs; + } +}; + +struct ModFoldAdaptor { + static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs, + bool isUnsigned) { + if (lhs.getBitWidth() != rhs.getBitWidth()) + return failure(); + if (lhs.isNegative() || (!rhs.isStrictlyPositive())) + return failure(); + + if (isUnsigned) { + return lhs.urem(rhs); + } + + return lhs.srem(rhs); + } + + static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) { + auto t = lhs; + auto const r = t.mod(rhs); + if (llvm::APFloatBase::opStatus::opOK == r) { + return t; + } + return failure(); + } +}; + +struct MaxFoldAdaptor { + static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs, + bool isUnsigned) { + if (lhs.getBitWidth() != rhs.getBitWidth()) + return failure(); + return lhs.getSExtValue() >= rhs.getSExtValue() ? lhs : rhs; + } + + static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) { + return lhs >= rhs ? lhs : rhs; + } +}; + +struct MinFoldAdaptor { + static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs, + bool isUnsigned) { + if (lhs.getBitWidth() != rhs.getBitWidth()) + return failure(); + return lhs.getSExtValue() <= rhs.getSExtValue() ? lhs : rhs; + } + + static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) { + return lhs <= rhs ? lhs : rhs; + } +}; + +struct Exp2FoldAdaptor { + static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) { + auto const numBits = value.getBitWidth(); + if (isUnsigned) { + auto const zextv = value.getZExtValue(); + if (zextv >= numBits) + return failure(); + return APInt::getOneBitSet(numBits, zextv); + } + auto const sextv = value.getSExtValue(); + if (sextv < 0 || sextv >= numBits || (value.isNegative())) + return failure(); + return APInt::getOneBitSet(numBits, sextv); + } +}; + +struct Log2CeilFoldAdaptor { + static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) { + if (!value.isStrictlyPositive()) + return failure(); + return APInt(/*numBits=*/value.getBitWidth(), value.ceilLogBase2()); + } +}; + +struct Log2FloorFoldAdaptor { + static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) { + if (!value.isStrictlyPositive()) + return failure(); + return APInt(/*numBits=*/value.getBitWidth(), value.logBase2()); + } +}; + +struct GreaterFoldAdaptor { + static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs, + const bool isUnsigned) { + return isUnsigned ? APInt(1, lhs.ugt(rhs)) : APInt(1, lhs.sgt(rhs)); + } + + static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) { + return APInt(1, lhs > rhs); + } +}; + +struct GreaterEqualFoldAdaptor { + static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs, + const bool isUnsigned) { + return isUnsigned ? APInt(1, lhs.uge(rhs)) : APInt(1, lhs.sge(rhs)); + } + + static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) { + return APInt(1, lhs >= rhs); + } +}; + +struct EqualFoldAdaptor { + static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs, + const bool isUnsigned) { + return APInt(1, lhs == rhs); + } + + static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) { + return APInt(1, lhs == rhs); + } +}; + static bool isSplatZero(Type elemType, DenseElementsAttr val) { if (llvm::isa<FloatType>(elemType)) return val && val.isSplat() && val.getSplatValue<APFloat>().isZero(); @@ -962,8 +1290,7 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr, - resultTy); + return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy); } OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) { @@ -992,7 +1319,8 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) { if (lhsTy != rhsTy) return {}; - // IntDivOp inputs must be integer type, no need to check for quantized type + // IntDivOp inputs must be integer type, no need to check for quantized + // type auto resultETy = resultTy.getElementType(); auto lhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1()); @@ -1015,8 +1343,12 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) { APInt l = lhsAttr.getSplatValue<APInt>(); APInt r = rhsAttr.getSplatValue<APInt>(); if (!r.isZero()) { - APInt result = l.sdiv(r); - return DenseElementsAttr::get(resultTy, result); + auto intTy = dyn_cast<mlir::IntegerType>(resultETy); + auto const result = + DivFoldAdaptor</*Ceil*/ false>::fold(l, r, intTy.isUnsigned()); + if (failed(result)) + return {}; + return DenseElementsAttr::get(resultTy, result.value()); } } @@ -1028,13 +1360,18 @@ namespace { // return nullopt if result is not in range of int32_t when shift > 0 std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift, unsigned bitwidth) { - APInt result = lhs.sext(64) * rhs.sext(64); + bool overflow = false; + APInt result = lhs.sext(64).smul_ov(rhs.sext(64), overflow); + + if (overflow) + return std::nullopt; if (shift > 0) { auto round = APInt(64, 1) << (shift - 1); result += round; result.ashrInPlace(shift); - // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>()) + // REQUIRE(product >= minimum_s<i32_t>() && product <= + // maximum_s<i32_t>()) if (!(result.getSExtValue() >= INT32_MIN && result.getSExtValue() <= INT32_MAX)) { // REQUIRE failed @@ -1090,8 +1427,8 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) { auto rhsAttr = llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2()); - // Result right shift on i32_t data type only. For simplification, synthesize - // a zero shift for other data type. + // Result right shift on i32_t data type only. For simplification, + // synthesize a zero shift for other data type. int32_t shift = 0; if (resultETy.isInteger(32)) { ElementsAttr shift_elem; @@ -1144,38 +1481,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr, - resultTy); + return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy); } -namespace { -template <typename Cmp> -struct ComparisonFold { - ComparisonFold() = default; - APInt operator()(const APInt &l, const APInt &r) { - return APInt(1, Cmp()(l, r)); - } - - APInt operator()(const APFloat &l, const APFloat &r) { - return APInt(1, Cmp()(l, r)); - } -}; - -struct APIntFoldGreater { - APIntFoldGreater() = default; - APInt operator()(const APInt &l, const APInt &r) { - return APInt(1, l.sgt(r)); - } -}; - -struct APIntFoldGreaterEqual { - APIntFoldGreaterEqual() = default; - APInt operator()(const APInt &l, const APInt &r) { - return APInt(1, l.sge(r)); - } -}; -} // namespace - OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { auto resultTy = llvm::dyn_cast<RankedTensorType>(getType()); auto lhsAttr = @@ -1186,8 +1494,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>( - lhsAttr, rhsAttr, resultTy); + return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy); } OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { @@ -1200,9 +1507,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder<APIntFoldGreaterEqual, - ComparisonFold<std::greater_equal<APFloat>>>( - lhsAttr, rhsAttr, resultTy); + return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy); } OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { @@ -1215,8 +1520,8 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { Value rhs = getInput2(); auto lhsTy = llvm::cast<ShapedType>(lhs.getType()); - // If we are comparing an integer value to itself it is always true. We can - // not do this with float due to float values. + // If we are comparing an integer value to itself it is always true. We + // can not do this with float due to float values. if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy && resultTy.hasStaticShape() && lhs == rhs) { return DenseElementsAttr::get(resultTy, true); @@ -1225,9 +1530,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) { if (!lhsAttr || !rhsAttr) return {}; - return binaryFolder<ComparisonFold<std::equal_to<APInt>>, - ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr, - resultTy); + return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy); } OpFoldResult CastOp::fold(FoldAdaptor adaptor) { @@ -1329,9 +1632,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) { if (!inputTy || !outputTy) return {}; - // Fold when the input and output types are the same. This is only safe when - // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions, - // there may still be a productive reshape. + // Fold when the input and output types are the same. This is only safe + // when there is at most 1 dynamic dimension. For 2 or more dynamic + // dimensions, there may still be a productive reshape. if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2) return getInput1(); @@ -1485,7 +1788,24 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { return {}; } +static bool +mayRequireBroadcast(ValueTypeRange<mlir::OperandRange> operandTypes) { + const auto isDynamic = [](Type ty) { + const auto shapedTy = llvm::dyn_cast<ShapedType>(ty); + return !shapedTy || !shapedTy.hasStaticShape(); + }; + + return llvm::any_of(operandTypes, isDynamic) || + failed(verifyCompatibleShapes(operandTypes)); +} + OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { + // Select allows operand shapes to be broadcast to the output shape. For + // now, don't support folding when we cannot prove no broadcasting is + // involved. + if (mayRequireBroadcast(getOperandTypes())) + return {}; + if (getOnTrue() == getOnFalse()) return getOnTrue(); @@ -1632,3 +1952,92 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) { return {}; } + +template <typename Op, typename OpFoldAdaptor> +OpFoldResult unaryShapeFold(Op *op) { + auto input1ConstShape = + dyn_cast<tosa::ConstShapeOp>(op->getInput().getDefiningOp()); + if (!input1ConstShape) + return {}; + + const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues()); + + return unaryFolder<OpFoldAdaptor>(input1Attr, input1Attr.getType(), + /*foldDenseValues=*/true); +} + +template <typename Op, typename OpFoldAdaptor> +OpFoldResult binaryFold(Op *op) { + auto input1ConstShape = + dyn_cast<tosa::ConstShapeOp>(op->getInput1().getDefiningOp()); + auto input2ConstShape = + dyn_cast<tosa::ConstShapeOp>(op->getInput2().getDefiningOp()); + if (!input1ConstShape || !input2ConstShape) + return {}; + + const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues()); + const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues()); + + return binaryFolder<OpFoldAdaptor>(input1Attr, input2Attr, + input1Attr.getType(), + /*foldDenseValues=*/true); +} + +OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) { + const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType()); + if (!inputTy || !inputTy.hasRank()) + return {}; + const int32_t axis = getAxis(); + const int64_t dimSize = inputTy.getDimSize(axis); + if (ShapedType::isDynamic(dimSize)) + return {}; + + OpBuilder builder(getContext()); + const auto resultAttrTy = + RankedTensorType::get(/*rank=*/1, builder.getIndexType()); + return DenseElementsAttr::get(resultAttrTy, dimSize); +} + +OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) { + return binaryFold<AddShapeOp, AddFoldAdaptor>(this); +} + +OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) { + return binaryFold<SubShapeOp, SubFoldAdaptor>(this); +} + +OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) { + return binaryFold<MulShapeOp, MulFoldAdaptor>(this); +} + +OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) { + return binaryFold<DivCeilShapeOp, DivFoldAdaptor</*Ceil*/ true>>(this); +} + +OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) { + return binaryFold<DivFloorShapeOp, DivFoldAdaptor</*Ceil*/ false>>(this); +} + +OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) { + return binaryFold<ModShapeOp, ModFoldAdaptor>(this); +} + +OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) { + return binaryFold<MaxShapeOp, MaxFoldAdaptor>(this); +} + +OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) { + return binaryFold<MinShapeOp, MinFoldAdaptor>(this); +} + +OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) { + return unaryShapeFold<Exp2ShapeOp, Exp2FoldAdaptor>(this); +} + +OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) { + return unaryShapeFold<Log2CeilShapeOp, Log2CeilFoldAdaptor>(this); +} + +OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) { + return unaryShapeFold<Log2FloorShapeOp, Log2FloorFoldAdaptor>(this); +} diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 0aff67f..798fc36 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/Dialect/Tosa/Utils/ShapeUtils.h" #include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Utils/VerificationUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/Matchers.h" @@ -26,6 +27,7 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/SmallVectorExtras.h" #include "llvm/ADT/TypeSwitch.h" #include <numeric> @@ -134,9 +136,9 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() { //===----------------------------------------------------------------------===// static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) { - return to_vector(llvm::map_range(shape, [](int64_t dim) { + return map_to_vector(shape, [](int64_t dim) { return dim == -1 ? ShapedType::kDynamic : dim; - })); + }); } // returns type of variable op @@ -550,6 +552,15 @@ void CastToBlockScaledOp::print(OpAsmPrinter &parser) { printWithEnumHandling(parser, *this); } +ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser, + OperationState &result) { + return parseWithEnumHandling<tosa::BlockSize>(parser, result); +} + +void Conv2DBlockScaledOp::print(OpAsmPrinter &parser) { + printWithEnumHandling(parser, *this); +} + //===----------------------------------------------------------------------===// // Tosa utilities. //===----------------------------------------------------------------------===// @@ -563,7 +574,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) { static Type getStorageElementTypeOrSelf(Type type) { auto srcType = getElementTypeOrSelf(type); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType)) - srcType = quantType.getStorageType(); + srcType = getStorageElementTypeFromQuantized(quantType); return srcType; } @@ -606,6 +617,61 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc, return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr); } +unsigned mlir::tosa::getBitWidth(Type type) { + if (dyn_cast<tosa::mxint8Type>(type)) + return 8; + return type.getIntOrFloatBitWidth(); +} + +// Update dim size if current dim is dynamic, otherwise raise an error if sizes +// do not match +LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim, + const int64_t newDim, + const StringRef operandName, + const StringRef dimName) { + if (ShapedType::isDynamic(currDim)) { + currDim = newDim; + return success(); + } else if (ShapedType::isStatic(newDim) && currDim != newDim) { + return op->emitOpError("expected ") + << dimName << " of " << operandName << " to match size " << currDim + << ", got " << newDim; + } + return success(); +} + +LogicalResult verifyConvOutputSize( + Operation *op, const int64_t inputSize, const int64_t kernelSize, + const int64_t outputSize, const int64_t padBefore, const int64_t padAfter, + const int64_t stride, const int64_t dilation, const llvm::StringRef dimName, + const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName, + const llvm::StringRef padAfterName) { + if (inputSize == ShapedType::kDynamic || kernelSize == ShapedType::kDynamic) + return success(); + + // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1 + + const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck( + inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation, + stride); + if (!calculatedOutSizeMinusOne.has_value()) + return op->emitOpError("expected input_") + << dimName << " - 1 + pad_" << padBeforeName << " + pad_" + << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_" + << dimAxis << " to be wholly divisible by stride_" << dimAxis + << ", got (" << inputSize << " - 1 + " << padBefore << " + " + << padAfter << " - (" << kernelSize << " - 1) * " << dilation + << ") / " << stride; + + const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1; + if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize) + return op->emitOpError("calculated output ") + << dimName << " did not match expected: " + << "calculated=" << calculatedOutSize << ", expected=" << outputSize; + + return success(); +} + //===----------------------------------------------------------------------===// // TOSA Operator Verifiers. //===----------------------------------------------------------------------===// @@ -625,16 +691,16 @@ static LogicalResult verifyConvOp(T op) { bool resultIsFloat = llvm::isa<FloatType>(resultEType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType)) - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType)) - weightEType = quantType.getStorageType(); + weightEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType)) - biasEType = quantType.getStorageType(); + biasEType = getStorageElementTypeFromQuantized(quantType); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType)) - resultEType = quantType.getStorageType(); + resultEType = getStorageElementTypeFromQuantized(quantType); if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) { // for now, only enforce bias element type == result element type for @@ -703,7 +769,7 @@ LogicalResult tosa::ConstOp::verify() { if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>( outputType.getElementType())) { - if (result.getStorageType() == attrType.getElementType()) + if (getStorageElementTypeFromQuantized(result) == attrType.getElementType()) return success(); } @@ -721,32 +787,40 @@ static LogicalResult verifyConvOpModes(T op) { llvm::cast<ShapedType>(op.getInput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType)) - inputEType = quantType.getStorageType(); + inputEType = getStorageElementTypeFromQuantized(quantType); auto accType = op.getAccType(); if (inputEType.isInteger(8) && !accType.isInteger(32)) - return op.emitOpError("accumulator type for i8 tensor is not i32"); + return op.emitOpError("accumulator type for i8 tensor is not i32, got ") + << accType; if (inputEType.isInteger(16) && !accType.isInteger(48)) - return op.emitOpError("accumulator type for i16 tensor is not i48"); + return op.emitOpError("accumulator type for i16 tensor is not i48, got ") + << accType; - if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16()) - return op.emitOpError("accumulator type for f8 tensor is not f16"); + if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && + !(accType.isF16() || accType.isF32())) + return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ") + << accType; if (inputEType.isF16() && !(accType.isF16() || accType.isF32())) - return op.emitOpError("accumulator type for f16 tensor is not f16/f32"); + return op.emitOpError( + "accumulator type for f16 tensor is not f16/f32, got ") + << accType; if (inputEType.isBF16() && !accType.isF32()) - return op.emitOpError("accumulator type for bf16 tensor is not f32"); + return op.emitOpError("accumulator type for bf16 tensor is not f32, got ") + << accType; if (inputEType.isF32() && !accType.isF32()) - return op.emitOpError("accumulator type for f32 tensor is not f32"); + return op.emitOpError("accumulator type for f32 tensor is not f32, got ") + << accType; auto resultEType = llvm::cast<ShapedType>(op.getResult().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType)) - resultEType = quantType.getStorageType(); + resultEType = getStorageElementTypeFromQuantized(quantType); return success(); } @@ -785,53 +859,16 @@ static LogicalResult verifyConvOpErrorIf(T op) { llvm::dyn_cast<RankedTensorType>(op.getWeight().getType()); if (inputType && weightType) { - const auto verifyOutputSize = - [&op](const int64_t inputSize, const int64_t kernelSize, - const int64_t outputSize, const int64_t padBefore, - const int64_t padAfter, const int64_t stride, - const int64_t dilation, const llvm::StringRef dimName, - const llvm::StringRef dimAxis, - const llvm::StringRef padBeforeName, - const llvm::StringRef padAfterName) -> LogicalResult { - if (inputSize == ShapedType::kDynamic || - kernelSize == ShapedType::kDynamic) - return success(); - - // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1 - - const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck( - inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation, - stride); - if (!calculatedOutSizeMinusOne.has_value()) - return op.emitOpError("expected input_") - << dimName << " - 1 + pad_" << padBeforeName << " + pad_" - << padAfterName << " - (kernel_" << dimName - << " - 1) * dilation_" << dimAxis - << " to be wholly divisible by stride_" << dimAxis << ", got (" - << inputSize << " - 1 + " << padBefore << " + " << padAfter - << " - (" << kernelSize << " - 1) * " << dilation << ") / " - << stride; - - const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1; - if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize) - return op.emitOpError("calculated output ") - << dimName << " did not match expected: " - << "calculated=" << calculatedOutSize - << ", expected=" << outputSize; - - return success(); - }; - // input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_] if constexpr (std::is_same<T, tosa::Conv2DOp>::value) { - if (failed(verifyOutputSize( - inputType.getDimSize(1), weightType.getDimSize(1), + if (failed(verifyConvOutputSize( + op, inputType.getDimSize(1), weightType.getDimSize(1), outputType.getDimSize(1), padding[0], padding[1], strides[0], dilations[0], "height", "y", "top", "bottom"))) return failure(); - if (failed(verifyOutputSize( - inputType.getDimSize(2), weightType.getDimSize(2), + if (failed(verifyConvOutputSize( + op, inputType.getDimSize(2), weightType.getDimSize(2), outputType.getDimSize(2), padding[2], padding[3], strides[1], dilations[1], "width", "x", "left", "right"))) return failure(); @@ -839,14 +876,14 @@ static LogicalResult verifyConvOpErrorIf(T op) { // input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_] if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) { - if (failed(verifyOutputSize( - inputType.getDimSize(1), weightType.getDimSize(0), + if (failed(verifyConvOutputSize( + op, inputType.getDimSize(1), weightType.getDimSize(0), outputType.getDimSize(1), padding[0], padding[1], strides[0], dilations[0], "height", "y", "top", "bottom"))) return failure(); - if (failed(verifyOutputSize( - inputType.getDimSize(2), weightType.getDimSize(1), + if (failed(verifyConvOutputSize( + op, inputType.getDimSize(2), weightType.getDimSize(1), outputType.getDimSize(2), padding[2], padding[3], strides[1], dilations[1], "width", "x", "left", "right"))) return failure(); @@ -854,20 +891,20 @@ static LogicalResult verifyConvOpErrorIf(T op) { // input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_] if constexpr (std::is_same<T, tosa::Conv3DOp>::value) { - if (failed(verifyOutputSize( - inputType.getDimSize(1), weightType.getDimSize(1), + if (failed(verifyConvOutputSize( + op, inputType.getDimSize(1), weightType.getDimSize(1), outputType.getDimSize(1), padding[0], padding[1], strides[0], dilations[0], "depth", "d", "front", "back"))) return failure(); - if (failed(verifyOutputSize( - inputType.getDimSize(2), weightType.getDimSize(2), + if (failed(verifyConvOutputSize( + op, inputType.getDimSize(2), weightType.getDimSize(2), outputType.getDimSize(2), padding[2], padding[3], strides[1], dilations[1], "height", "y", "top", "bottom"))) return failure(); - if (failed(verifyOutputSize( - inputType.getDimSize(3), weightType.getDimSize(3), + if (failed(verifyConvOutputSize( + op, inputType.getDimSize(3), weightType.getDimSize(3), outputType.getDimSize(3), padding[4], padding[5], strides[2], dilations[2], "width", "x", "left", "right"))) return failure(); @@ -1105,9 +1142,8 @@ static LogicalResult verifyPoolingOp(T op) { const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1; if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize) return op.emitOpError("calculated output ") - << dimName << " did not match expected: " - << "calculated=" << calculatedOutSize - << ", expected=" << outputSize; + << dimName << " did not match expected: " << "calculated=" + << calculatedOutSize << ", expected=" << outputSize; return success(); }; @@ -1173,13 +1209,13 @@ LogicalResult tosa::ClampOp::verify() { llvm::cast<ShapedType>(getInput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) { - inputETy = quantType.getStorageType(); + inputETy = getStorageElementTypeFromQuantized(quantType); } mlir::Type outputETy = llvm::cast<ShapedType>(getOutput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) { - outputETy = quantType.getStorageType(); + outputETy = getStorageElementTypeFromQuantized(quantType); } if (inputETy != outputETy) return emitOpError("input/output element types are incompatible."); @@ -1445,6 +1481,19 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result, //===----------------------------------------------------------------------===// // TOSA Operator Return Type Inference. //===----------------------------------------------------------------------===// +static FailureOr<int64_t> resolveBroadcastDim(const int64_t dim1, + const int64_t dim2) { + if (dim1 == 1) + return dim2; + if (dim2 == 1) + return dim1; + + if (ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2) && dim1 != dim2) + return failure(); + + // Prefer static dimension over dynamic + return ShapedType::isDynamic(dim1) ? dim2 : dim1; +} static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, SmallVector<int64_t> &outShape) { @@ -1468,15 +1517,12 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands, for (size_t i = 0, e = shape.getRank(); i < e; ++i) { auto dim1 = outShape[i + rankDiff]; auto dim2 = shape.getDimSize(i); - auto resolvedDim = dim1; - if (dim1 == 1) { - resolvedDim = dim2; - } else if (dim2 == 1) { - resolvedDim = dim1; - } else if (dim1 != dim2) { + const FailureOr<int64_t> maybeResolvedDim = + resolveBroadcastDim(dim1, dim2); + if (failed(maybeResolvedDim)) return failure(); - } + const int64_t resolvedDim = *maybeResolvedDim; outShape[i + rankDiff] = resolvedDim; } } @@ -1755,6 +1801,11 @@ LogicalResult tosa::ConcatOp::verify() { } } + const ShapeAdaptor outputShape(outType); + if (outputShape.hasRank() && outputShape.getRank() != firstInputRank) + return emitOpError("expect output rank to match inputs rank, got ") + << outputShape.getRank() << " vs " << firstInputRank; + // ERROR_IF(axis_sum != shape[axis]); int64_t axisSum = 0; for (const auto &input : inputList) { @@ -1766,7 +1817,7 @@ LogicalResult tosa::ConcatOp::verify() { } axisSum += inputShape.getDimSize(axis); } - const ShapeAdaptor outputShape(outType); + if (axisSum >= 0 && outputShape.hasRank() && !outputShape.isDynamicDim(axis) && axisSum != outputShape.getDimSize(axis)) @@ -1943,20 +1994,6 @@ LogicalResult MatmulTBlockScaledOp::verify() { "B_data"))) return failure(); - auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim, - const StringRef operandName, - const StringRef dimName) -> LogicalResult { - if (ShapedType::isDynamic(currDim)) { - currDim = newDim; - return success(); - } else if (ShapedType::isStatic(newDim) && currDim != newDim) { - return emitOpError("expected ") - << dimName << " of " << operandName << " to match size " << currDim - << ", got " << newDim; - } - return success(); - }; - // Verify input shape compatibility int64_t N = ShapedType::kDynamic; int64_t D = ShapedType::kDynamic; @@ -1974,32 +2011,33 @@ LogicalResult MatmulTBlockScaledOp::verify() { const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType()); if (aScaleShape.hasRank()) { - if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale", - "batch")) || - failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale", - "height"))) + if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0), + "a_scale", "batch")) || + failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1), + "a_scale", "height"))) return failure(); multiplesOfC = aScaleShape.getDimSize(2); } const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType); if (bDataShape.hasRank()) { - if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data", - "batch")) || - failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data", - "channels"))) + if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0), + "b_data", "batch")) || + failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2), + "b_data", "channels"))) return failure(); W = bDataShape.getDimSize(1); } const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType()); if (bScaleShape.hasRank()) { - if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale", - "batch")) || - failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale", - "width")) || - failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2), - "b_scale", "C/block_size"))) + if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0), + "b_scale", "batch")) || + failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1), + "b_scale", "width")) || + failed(tryUpdateDimOrFailure(*this, multiplesOfC, + bScaleShape.getDimSize(2), "b_scale", + "C/block_size"))) return failure(); } @@ -2020,8 +2058,7 @@ LogicalResult MatmulTBlockScaledOp::verify() { multiplesOfC != C / blockSize) return emitOpError( "expect scale operands dimension 2 to equal C/block_size (") - << C << "/" << blockSize << ")" - << ", got " << multiplesOfC; + << C << "/" << blockSize << ")" << ", got " << multiplesOfC; // Verify output shape N = ShapedType::isDynamic(N) ? D : N; @@ -2115,17 +2152,14 @@ LogicalResult tosa::PadOp::verify() { if (!inputType || !outputType) return success(); - auto inputRank = inputType.getRank(); - auto outputRank = outputType.getRank(); - if (inputRank != outputRank) - return emitOpError() << "expect same input and output tensor rank, but got " - << "inputRank: " << inputRank - << ", outputRank: " << outputRank; + if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input", + "output"))) + return failure(); + auto inputRank = inputType.getRank(); DenseIntElementsAttr paddingAttr; - if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) { - return failure(); - } + if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) + return success(); auto paddingValues = paddingAttr.getValues<APInt>(); if (paddingValues.size() != static_cast<size_t>(inputRank * 2)) @@ -2290,9 +2324,9 @@ LogicalResult tosa::MulOp::verify() { } // verify shift has value 0 for non-integer types - ElementsAttr shift_elem; - if (matchPattern(getShift(), m_Constant(&shift_elem))) { - int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt(); + ElementsAttr shiftElem; + if (matchPattern(getShift(), m_Constant(&shiftElem))) { + int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt(); if (shift != 0) { return emitOpError() << "require shift to be 0 for float type"; } @@ -2362,9 +2396,9 @@ LogicalResult tosa::TableOp::verify() { if (!inputType.hasRank() || !outputType.hasRank()) return success(); - if (inputType.getRank() != outputType.getRank()) - return emitOpError() - << "expected input tensor rank to equal result tensor rank"; + if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input", + "result"))) + return failure(); auto inputDims = inputType.getShape(); auto outputDims = outputType.getShape(); @@ -2386,9 +2420,9 @@ tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) { DenseIntElementsAttr multiplesAttr; if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr))) return failure(); - multiples = llvm::to_vector( - llvm::map_range(multiplesAttr.getValues<APInt>(), - [](const APInt &val) { return val.getSExtValue(); })); + multiples = + llvm::map_to_vector(multiplesAttr.getValues<APInt>(), + [](const APInt &val) { return val.getSExtValue(); }); return success(); } @@ -2454,8 +2488,10 @@ LogicalResult tosa::TileOp::verify() { if (inputType.getRank() != multiplesRank) return emitOpError("expect 'multiples' to have rank ") << inputType.getRank() << " but got " << multiplesRank << "."; - if (outputType.hasRank() && inputType.getRank() != outputType.getRank()) - return emitOpError("expect same input and output tensor rank."); + if (outputType.hasRank() && + failed(verifyRanksMatch(getOperation(), inputType, outputType, "input", + "output"))) + return failure(); } else if (outputType.hasRank() && outputType.getRank() != multiplesRank) return emitOpError("expect 'multiples' array to have length ") << outputType.getRank() << " but got " << multiplesRank << "."; @@ -2622,7 +2658,7 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp, if (!zpElemType.isInteger(8) && zp != 0) { // convert operand to lower case for error message std::string lower = operand; - std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower); + llvm::transform(lower, lower.begin(), ::tolower); return op.emitOpError() << lower << " zero point must be zero for non-int8 integer types"; } @@ -2780,8 +2816,8 @@ LogicalResult tosa::TransposeOp::verify() { return s >= 0 && static_cast<size_t>(s) < constantPerms.size(); }) || - !isPermutationVector(llvm::to_vector(llvm::map_range( - constantPerms, [](int32_t v) -> int64_t { return v; })))) + !isPermutationVector(llvm::map_to_vector( + constantPerms, [](int32_t v) -> int64_t { return v; }))) return emitOpError() << "expected valid permutation indices"; // ERROR_IF(tensor_size(shape1) != tensor_size(shape)) @@ -2871,39 +2907,39 @@ LogicalResult tosa::GatherOp::verify() { const ShapeAdaptor indicesShape(getIndices().getType()); const ShapeAdaptor outputShape(getOutput().getType()); - int64_t N = ShapedType::kDynamic; - int64_t W = ShapedType::kDynamic; - int64_t C = ShapedType::kDynamic; + int64_t n = ShapedType::kDynamic; + int64_t w = ShapedType::kDynamic; + int64_t c = ShapedType::kDynamic; if (valuesShape.hasRank()) { - N = valuesShape.getDimSize(0); - C = valuesShape.getDimSize(2); + n = valuesShape.getDimSize(0); + c = valuesShape.getDimSize(2); } if (indicesShape.hasRank()) { const int64_t indicesN = indicesShape.getDimSize(0); - W = indicesShape.getDimSize(1); - if (N == ShapedType::kDynamic) - N = indicesN; - else if (indicesN != ShapedType::kDynamic && N != indicesN) - return emitOpError() << "requires indices dimension 0 to have size " << N + w = indicesShape.getDimSize(1); + if (n == ShapedType::kDynamic) + n = indicesN; + else if (indicesN != ShapedType::kDynamic && n != indicesN) + return emitOpError() << "requires indices dimension 0 to have size " << n << ", got " << indicesN; } if (outputShape.hasRank()) { const int64_t outputN = outputShape.getDimSize(0); const int64_t outputW = outputShape.getDimSize(1); const int64_t outputC = outputShape.getDimSize(2); - if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic && - N != outputN) - return emitOpError() << "requires output dimension 0 to have size " << N + if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic && + n != outputN) + return emitOpError() << "requires output dimension 0 to have size " << n << ", got " << outputN; - if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic && - W != outputW) - return emitOpError() << "requires output dimension 1 to have size " << W + if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic && + w != outputW) + return emitOpError() << "requires output dimension 1 to have size " << w << ", got " << outputW; - if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic && - C != outputC) - return emitOpError() << "requires output dimension 2 to have size " << C + if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic && + c != outputC) + return emitOpError() << "requires output dimension 2 to have size " << c << ", got " << outputC; } return success(); @@ -3096,66 +3132,66 @@ LogicalResult tosa::ScatterOp::verify() { const ShapeAdaptor inputShape(getInput().getType()); const ShapeAdaptor outputShape(getValuesOut().getType()); - int64_t N = ShapedType::kDynamic; - int64_t K = ShapedType::kDynamic; - int64_t W = ShapedType::kDynamic; - int64_t C = ShapedType::kDynamic; + int64_t n = ShapedType::kDynamic; + int64_t k = ShapedType::kDynamic; + int64_t w = ShapedType::kDynamic; + int64_t c = ShapedType::kDynamic; if (valuesInShape.hasRank()) { - N = valuesInShape.getDimSize(0); - K = valuesInShape.getDimSize(1); - C = valuesInShape.getDimSize(2); + n = valuesInShape.getDimSize(0); + k = valuesInShape.getDimSize(1); + c = valuesInShape.getDimSize(2); } if (indicesShape.hasRank()) { const int64_t indicesN = indicesShape.getDimSize(0); - W = indicesShape.getDimSize(1); - if (N == ShapedType::kDynamic) - N = indicesN; - else if (indicesN != ShapedType::kDynamic && N != indicesN) - return emitOpError() << "requires indices dimension 0 to have size " << N + w = indicesShape.getDimSize(1); + if (n == ShapedType::kDynamic) + n = indicesN; + else if (indicesN != ShapedType::kDynamic && n != indicesN) + return emitOpError() << "requires indices dimension 0 to have size " << n << ", got " << indicesN; } if (inputShape.hasRank()) { const int64_t inputN = inputShape.getDimSize(0); const int64_t inputW = inputShape.getDimSize(1); const int64_t inputC = inputShape.getDimSize(2); - if (N == ShapedType::kDynamic) - N = inputN; - else if (inputN != ShapedType::kDynamic && N != inputN) - return emitOpError() << "requires input dimension 0 to have size " << N + if (n == ShapedType::kDynamic) + n = inputN; + else if (inputN != ShapedType::kDynamic && n != inputN) + return emitOpError() << "requires input dimension 0 to have size " << n << ", got " << inputN; - if (W == ShapedType::kDynamic) - W = inputW; - else if (inputW != ShapedType::kDynamic && W != inputW) - return emitOpError() << "requires input dimension 1 to have size " << W + if (w == ShapedType::kDynamic) + w = inputW; + else if (inputW != ShapedType::kDynamic && w != inputW) + return emitOpError() << "requires input dimension 1 to have size " << w << ", got " << inputW; - if (C == ShapedType::kDynamic) - C = inputC; - else if (inputC != ShapedType::kDynamic && C != inputC) - return emitOpError() << "requires input dimension 2 to have size " << C + if (c == ShapedType::kDynamic) + c = inputC; + else if (inputC != ShapedType::kDynamic && c != inputC) + return emitOpError() << "requires input dimension 2 to have size " << c << ", got " << inputC; } if (outputShape.hasRank()) { const int64_t outputN = outputShape.getDimSize(0); const int64_t outputK = outputShape.getDimSize(1); const int64_t outputC = outputShape.getDimSize(2); - if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic && - N != outputN) + if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic && + n != outputN) return emitOpError() << "requires values_out dimension 0 to have size " - << N << ", got " << outputN; - if (K == ShapedType::kDynamic) - K = outputK; - else if (outputK != ShapedType::kDynamic && K != outputK) + << n << ", got " << outputN; + if (k == ShapedType::kDynamic) + k = outputK; + else if (outputK != ShapedType::kDynamic && k != outputK) return emitOpError() << "requires values_out dimension 1 to have size " - << K << ", got " << outputK; - if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic && - C != outputC) + << k << ", got " << outputK; + if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic && + c != outputC) return emitOpError() << "requires values_out dimension 2 to have size " - << C << ", got " << outputC; + << c << ", got " << outputC; } - if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W)) - return emitOpError() << "requires dimensions K >= W, got K=" << K - << " and W=" << W; + if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w)) + return emitOpError() << "requires dimensions K >= W, got K=" << k + << " and W=" << w; return success(); } @@ -3383,7 +3419,7 @@ static LogicalResult poolingInferReturnTypes( outputShape.resize(4, ShapedType::kDynamic); // We only know the rank if the input type is unranked. - if (!inputShape) { + if (!inputShape.hasRank()) { inferredReturnShapes.push_back(ShapedTypeComponents(outputShape)); return success(); } @@ -3474,6 +3510,234 @@ LogicalResult Conv2DOp::verify() { return success(); } +LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents( + MLIRContext *context, ::std::optional<Location> location, + Conv2DBlockScaledOp::Adaptor adaptor, + SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) { + SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic); + + int64_t inputWidth = ShapedType::kDynamic; + int64_t inputHeight = ShapedType::kDynamic; + int64_t weightWidth = ShapedType::kDynamic; + int64_t weightHeight = ShapedType::kDynamic; + + // Input shape describes input width/height and batch. + const ShapeAdaptor inputDataShape(adaptor.getInputData().getType()); + if (inputDataShape.hasRank()) { + outShape[0] = inputDataShape.getDimSize(0); + inputHeight = inputDataShape.getDimSize(1); + inputWidth = inputDataShape.getDimSize(2); + } + const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType()); + if (inputScaleShape.hasRank()) { + outShape[0] = ShapedType::isDynamic(outShape[0]) + ? inputScaleShape.getDimSize(0) + : outShape[0]; + inputHeight = ShapedType::isDynamic(inputHeight) + ? inputScaleShape.getDimSize(1) + : inputHeight; + inputWidth = ShapedType::isDynamic(inputWidth) + ? inputScaleShape.getDimSize(2) + : inputWidth; + } + + // Weight shapes describes the filter width/height and the output channels. + const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType()); + if (weightDataShape.hasRank()) { + outShape[3] = weightDataShape.getDimSize(0); + weightHeight = weightDataShape.getDimSize(1); + weightWidth = weightDataShape.getDimSize(2); + } + const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType()); + if (weightScaleShape.hasRank()) { + outShape[3] = ShapedType::isDynamic(outShape[3]) + ? weightScaleShape.getDimSize(0) + : outShape[3]; + weightHeight = ShapedType::isDynamic(weightHeight) + ? weightScaleShape.getDimSize(1) + : weightHeight; + weightWidth = ShapedType::isDynamic(weightWidth) + ? weightScaleShape.getDimSize(2) + : weightWidth; + } + + // Bias shape can describe the output channels. + const ShapeAdaptor biasShape(adaptor.getBias().getType()); + if (biasShape.hasRank()) { + const int64_t biasSize = biasShape.getDimSize(0); + // Bias of size 1 may be broadcast + if (biasSize != 1) { + outShape[3] = ShapedType::isDynamic(outShape[3]) ? biasSize : outShape[3]; + } + } + + SmallVector<int64_t> padValues; + SmallVector<int64_t> strideValues; + SmallVector<int64_t> dilationValues; + if (!tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues) || + !tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(), + strideValues) || + !tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(), + dilationValues)) { + inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); + return success(); + } + + if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight)) { + const int64_t inputSize = inputHeight + padValues[0] + padValues[1]; + const int64_t filterSize = (weightHeight - 1) * dilationValues[0] + 1; + const int64_t unstridedResult = inputSize - filterSize + 1; + outShape[1] = (unstridedResult - 1) / strideValues[0] + 1; + } + + if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth)) { + const int64_t inputSize = inputWidth + padValues[2] + padValues[3]; + const int64_t filterSize = (weightWidth - 1) * dilationValues[1] + 1; + const int64_t unstridedResult = inputSize - filterSize + 1; + outShape[2] = (unstridedResult - 1) / strideValues[1] + 1; + } + + inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); + return success(); +} + +LogicalResult Conv2DBlockScaledOp::verify() { + if (failed(verifySameElementTypes(*this, getInputData().getType(), + getWeightData().getType(), "input_data", + "weight_data")) || + failed(verifySameElementTypes(*this, getInputScale().getType(), + getWeightScale().getType(), "input_scale", + "weight_scale")) || + failed(verifySameElementTypes(*this, getBias().getType(), + getOutput().getType(), "bias", "output"))) + return failure(); + + // Verify input shape compatibility + int64_t N = ShapedType::kDynamic; + int64_t IH = ShapedType::kDynamic; + int64_t IW = ShapedType::kDynamic; + int64_t IC = ShapedType::kDynamic; + int64_t multiplesOfIC = ShapedType::kDynamic; + int64_t OC = ShapedType::kDynamic; + int64_t KH = ShapedType::kDynamic; + int64_t KW = ShapedType::kDynamic; + + const ShapeAdaptor inputDataShape(getInputData().getType()); + if (inputDataShape.hasRank()) { + N = inputDataShape.getDimSize(0); + IH = inputDataShape.getDimSize(1); + IW = inputDataShape.getDimSize(2); + IC = inputDataShape.getDimSize(3); + } + + const ShapeAdaptor inputScaleShape(getInputScale().getType()); + if (inputScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0), + "input_scale", "batch size")) || + failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1), + "input_scale", "input height")) || + failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2), + "input_scale", "input width"))) + return failure(); + multiplesOfIC = inputScaleShape.getDimSize(3); + } + + const ShapeAdaptor weightDataShape(getWeightData().getType()); + if (weightDataShape.hasRank()) { + OC = weightDataShape.getDimSize(0); + KH = weightDataShape.getDimSize(1); + KW = weightDataShape.getDimSize(2); + if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3), + "weight_data", "input channels"))) + return failure(); + } + + const ShapeAdaptor weightScaleShape(getWeightScale().getType()); + if (weightScaleShape.hasRank()) { + if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0), + "weight_scale", "output channels")) || + failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1), + "weight_scale", "kernel height")) || + failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2), + "weight_scale", "kernel width")) || + failed(tryUpdateDimOrFailure(*this, multiplesOfIC, + weightScaleShape.getDimSize(3), + "weight_scale", "input channel blocks"))) + return failure(); + } + + // Verify IC is a multiple of block size + const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize()); + if (ShapedType::isStatic(IC) && IC % blockSize != 0) + return emitOpError("expect IC to be a multiple of block size, got IC=") + << IC << ", block_size=" << blockSize; + + // Verify multiplesOfIC is IC / block size + if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) && + multiplesOfIC != IC / blockSize) + return emitOpError( + "expect scale operands dimension 2 to equal IC/block_size (") + << IC << "/" << blockSize << ")" + << ", got " << multiplesOfIC; + + // Verify pad/stride/dilation values + SmallVector<int64_t> padValues; + if (tosa::getConstShapeValues(getPad().getDefiningOp(), padValues)) { + if (llvm::any_of(padValues, [](int64_t p) { return p < 0; })) + return emitOpError("expect all padding values to be >= 0, got ") + << padValues; + } + + SmallVector<int64_t> strideValues; + if (tosa::getConstShapeValues(getStride().getDefiningOp(), strideValues)) { + if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; })) + return emitOpError("expect all stride values to be >= 1, got ") + << strideValues; + } + + SmallVector<int64_t> dilationValues; + if (tosa::getConstShapeValues(getDilation().getDefiningOp(), + dilationValues)) { + if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; })) + return emitOpError("expect all dilation values to be >= 1, got ") + << dilationValues; + } + + // Verify output shape compatibility + const ShapeAdaptor outputShape(getOutput().getType()); + if (!padValues.empty() && !strideValues.empty() && !dilationValues.empty() && + outputShape.hasRank()) { + if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1), + padValues[0], padValues[1], strideValues[0], + dilationValues[0], "height", "y", "top", + "bottom")) || + failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2), + padValues[2], padValues[3], strideValues[1], + dilationValues[1], "width", "x", "left", + "right"))) + return failure(); + } + + // Verify bias + const ShapeAdaptor biasShape(getBias().getType()); + if (biasShape.hasRank() && outputShape.hasRank()) { + const int64_t biasChannels = biasShape.getDimSize(0); + const int64_t outputChannels = + outputShape.getDimSize(outputShape.getRank() - 1); + if (biasChannels == ShapedType::kDynamic || + outputChannels == ShapedType::kDynamic) + // Skip following checks if biasChannels or outputChannels is dynamic dim + return success(); + + if (biasChannels != outputChannels && biasChannels != 1) + return emitOpError( + "bias channels expected to be equal to output channels (") + << outputChannels << ") or 1, got " << biasChannels; + } + + return success(); +} + LogicalResult Conv3DOp::inferReturnTypeComponents( MLIRContext *context, ::std::optional<Location> location, Conv3DOp::Adaptor adaptor, @@ -3622,10 +3886,10 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents( // Bias shape can describe the output channels. ShapeAdaptor biasShape(adaptor.getBias().getType()); - if (biasShape.hasRank()) { - outputShape[3] = ShapedType::isDynamic(outputShape[3]) - ? biasShape.getDimSize(0) - : outputShape[3]; + if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) { + int64_t bc = biasShape.getDimSize(0); + if (bc != ShapedType::kDynamic && bc != 1) + outputShape[3] = bc; } llvm::ArrayRef<int64_t> dilation = adaptor.getDilation(); @@ -3689,11 +3953,11 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents( } // Bias shape can describe the output channels. - ShapeAdaptor biasShape(adaptor.getInput().getType()); - if (biasShape.hasRank()) { - outputShape[3] = ShapedType::isDynamic(outputShape[3]) - ? biasShape.getDimSize(0) - : outputShape[3]; + ShapeAdaptor biasShape(adaptor.getBias().getType()); + if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) { + int64_t bc = biasShape.getDimSize(0); + if (bc != ShapedType::kDynamic && bc != 1) + outputShape[3] = bc; } llvm::ArrayRef<int64_t> padding = adaptor.getOutPad(); @@ -3730,14 +3994,13 @@ LogicalResult TransposeConv2DOp::verify() { << strides << "]"; const auto checkPadAgainstKernelDim = - [this](int64_t pad_value, int64_t kernel_dim_size, - llvm::StringRef pad_name, - llvm::StringRef kernel_dim_name) -> LogicalResult { - if (pad_value <= -kernel_dim_size) + [this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName, + llvm::StringRef kernelDimName) -> LogicalResult { + if (padValue <= -kernelDimSize) return emitOpError("expected ") - << pad_name << " > -" << kernel_dim_name - << ", but got: " << pad_name << "=" << pad_value << " and " - << kernel_dim_name << "=" << kernel_dim_size; + << padName << " > -" << kernelDimName << ", but got: " << padName + << "=" << padValue << " and " << kernelDimName << "=" + << kernelDimSize; return success(); }; @@ -3976,8 +4239,8 @@ LogicalResult CastFromBlockScaledOp::verify() { const Type outputDataType = getResult().getType(); if (failed(verifyCompatibleShape(inputDataType, outputDataType))) return emitOpError() << "require compatible shapes for input_data (" - << inputDataType << ") and " - << "output_data (" << outputDataType << ")"; + << inputDataType << ") and " << "output_data (" + << outputDataType << ")"; const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType); @@ -4004,10 +4267,10 @@ LogicalResult CastFromBlockScaledOp::verify() { failed(verifyCompatibleShape( ArrayRef<int64_t>(inputDataDims).drop_back(1), ArrayRef<int64_t>(inputScaleDims).drop_back(1)))) - return emitOpError() << "require compatible shapes for input_data (" - << inputDataType << ") and " - << "input_scale (" << inputScaleType - << ") except for the last dimension"; + return emitOpError() + << "require compatible shapes for input_data (" << inputDataType + << ") and " << "input_scale (" << inputScaleType + << ") except for the last dimension"; const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize, inputScaleDims.back()}; @@ -4052,8 +4315,8 @@ LogicalResult CastToBlockScaledOp::verify() { const Type outputDataType = getResult(0).getType(); if (failed(verifyCompatibleShape(inputDataType, outputDataType))) return emitOpError() << "require compatible shapes for input_data (" - << inputDataType << ") and " - << "output_data (" << outputDataType << ")"; + << inputDataType << ") and " << "output_data (" + << outputDataType << ")"; const unsigned int blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize()); @@ -4082,8 +4345,8 @@ LogicalResult CastToBlockScaledOp::verify() { ArrayRef<int64_t>(outputDataDims).drop_back(1), ArrayRef<int64_t>(outputScaleDims).drop_back(1)))) return emitOpError() << "require compatible shapes for output_data (" - << outputDataType << ") and " - << "output_scale (" << outputScaleType + << outputDataType << ") and " << "output_scale (" + << outputScaleType << ") except for the last dimension"; const int64_t outputDataLastDim = outputDataDims.back(); @@ -4259,9 +4522,9 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) { if (functionType.getNumInputs() != operands.size()) { return parser.emitError(parser.getCurrentLocation()) - << "expected as many input types as operands " - << "(expected " << operands.size() << " got " - << functionType.getNumInputs() << ")"; + << "expected as many input types as operands " << "(expected " + << operands.size() << " got " << functionType.getNumInputs() + << ")"; } // Resolve input operands. @@ -4510,9 +4773,8 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) { if (functionType.getNumInputs() != operands.size()) { return parser.emitError(typeLoc) - << "expected as many input types as operands " - << "(expected " << operands.size() << " got " - << functionType.getNumInputs() << ")"; + << "expected as many input types as operands " << "(expected " + << operands.size() << " got " << functionType.getNumInputs() << ")"; } // Resolve input operands. @@ -4592,24 +4854,9 @@ LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) { return success(); } -LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) { - for (auto type : op->getOperandTypes()) { - if (!mlir::isa<mlir::tosa::shapeType>(type)) { - return op->emitOpError("must have operands with tosa shape type"); - } - } - for (auto type : op->getResultTypes()) { - if (!mlir::isa<mlir::tosa::shapeType>(type)) { - return op->emitOpError("must have result with tosa shape type"); - } - } - return success(); -} - LogicalResult OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) { - if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) || - failed(verifyTosaShapeOperator(op))) + if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1))) return failure(); // delegate function that returns rank of shape type @@ -4653,6 +4900,94 @@ LogicalResult tosa::ConstShapeOp::verify() { return success(); } +LogicalResult tosa::DimOp::verify() { + const tosa::shapeType outShapeType = + cast<tosa::shapeType>(getResult().getType()); + if (outShapeType.getRank() != 1) + return emitOpError("expect output shape type to contain one element, got ") + << outShapeType; + + const ShapeAdaptor inputType(getInput1().getType()); + if (inputType.hasRank()) { + const int64_t inputRank = inputType.getRank(); + const int64_t axis = getAxisAttr().getInt(); + if (axis < 0 || axis >= inputRank) + return emitOpError("expect axis to be in the range [0, ") + << inputRank << "), got " << axis; + } + return success(); +} + +LogicalResult tosa::ConcatShapeOp::verify() { + const tosa::shapeType outShapeType = + cast<tosa::shapeType>(getResult().getType()); + const int64_t outputRank = outShapeType.getRank(); + const Operation::operand_range inputList = getInput(); + + if (inputList.size() == 0) + return emitOpError("requires at least one input shape"); + + if (llvm::any_of(inputList, [](Value v) { + return cast<tosa::shapeType>(v.getType()).getRank() == 0; + })) + return emitOpError("requires all inputs shapes have a rank greater than 0"); + + const int64_t inputsRank = + llvm::accumulate(inputList, 0, [](int64_t acc, const Value &input) { + const tosa::shapeType inShapeType = + cast<tosa::shapeType>(input.getType()); + return acc + inShapeType.getRank(); + }); + if (outputRank != inputsRank) + return emitOpError("requires output shape rank to be equal to the sum of " + "the input shape ranks (") + << inputsRank << "), got " << outputRank; + + return success(); +} + +LogicalResult tosa::SliceShapeOp::verify() { + std::optional<int32_t> start; + DenseIntElementsAttr startAttr; + if (matchPattern(getStart(), m_Constant(&startAttr))) + start = startAttr.getValues<int32_t>()[0]; + if (start && start.value() < 0) + return emitOpError("expected non-negative start index, got ") + << start.value(); + + std::optional<int32_t> size; + DenseIntElementsAttr sizeAttr; + if (matchPattern(getSize(), m_Constant(&sizeAttr))) + size = sizeAttr.getValues<int32_t>()[0]; + if (size && size.value() <= 0) + return emitOpError("expected positive size, got ") << size.value(); + + if (!size) + return success(); + + const tosa::shapeType outShapeType = + cast<tosa::shapeType>(getResult().getType()); + const int64_t outputRank = outShapeType.getRank(); + if (outputRank != size) + return emitOpError( + "expected output type size to be equal to size attribute, got ") + << outputRank << " vs " << size.value(); + + if (!start) + return success(); + + const tosa::shapeType inShapeType = + cast<tosa::shapeType>(getInput().getType()); + const int64_t inputRank = inShapeType.getRank(); + const int64_t sliceSize = start.value() + size.value(); + if (sliceSize > inputRank) + return emitOpError("expected start + size to be less than or equal to " + "input shape rank (") + << inputRank << "), got " << sliceSize; + + return success(); +} + //===----------------------------------------------------------------------===// // TOSA Attribute Definitions. //===----------------------------------------------------------------------===// |
