aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Tosa/IR
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Tosa/IR')
-rw-r--r-- mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp 2
-rw-r--r-- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp 553
-rw-r--r-- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp 813
3 files changed, 1057 insertions, 311 deletions
diff --git a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
index eb47e85..9f616b2 100644
--- a/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TargetEnv.cpp
@@ -43,6 +43,8 @@ TosaSpecificationVersion getMinVersion(const Extension &extension) {
return TosaSpecificationVersion(1, 0);
case Extension::mxfp:
case Extension::int64:
+ case Extension::mxfp_conv:
+ case Extension::shape:
return TosaSpecificationVersion(1, 1);
case Extension::none:
return TosaSpecificationVersion(0, 0);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 293c6af..bb715a9 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
@@ -539,7 +540,7 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
}
Attribute newMinValAttr, newMaxValAttr;
@@ -884,38 +885,365 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
SliceDynamicSizeCanonicalization>(context);
}
+/// Rewrites `cast(cast(x))` into a single `cast(x)` when all three element
+/// types are integers and neither the intermediate nor the final element type
+/// is narrower than the element type of `x`.
+struct NonNarrowingCastsOptimization : public OpRewritePattern<tosa::CastOp> {
+  using OpRewritePattern<tosa::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::CastOp castOp,
+                                PatternRewriter &rewriter) const override {
+    // Only a cast that directly consumes another cast is of interest.
+    auto producerCast = castOp.getInput().getDefiningOp<tosa::CastOp>();
+    if (!producerCast)
+      return rewriter.notifyMatchFailure(castOp,
+                                         "input must be cast operation");
+
+    const Value source = producerCast.getInput();
+    const auto sourceTy = llvm::cast<ShapedType>(source.getType());
+    const auto middleTy = llvm::cast<ShapedType>(producerCast.getType());
+    const auto resultTy = llvm::cast<ShapedType>(castOp.getType());
+
+    const auto hasIntegerElements = [](const ShapedType type) {
+      return type.getElementType().isInteger();
+    };
+    if (!hasIntegerElements(sourceTy) || !hasIntegerElements(middleTy) ||
+        !hasIntegerElements(resultTy))
+      return rewriter.notifyMatchFailure(castOp,
+                                         "only integer types are supported");
+
+    const unsigned sourceWidth = sourceTy.getElementTypeBitWidth();
+
+    // The intermediate element type must be at least as wide as the source.
+    if (middleTy.getElementTypeBitWidth() < sourceWidth)
+      return rewriter.notifyMatchFailure(castOp,
+                                         "inner cast operation is narrowing");
+
+    // The final element type must be at least as wide as the source too.
+    if (resultTy.getElementTypeBitWidth() < sourceWidth)
+      return rewriter.notifyMatchFailure(castOp,
+                                         "outer cast operation is narrowing");
+
+    // Cast straight from the original source to the final type.
+    rewriter.replaceOpWithNewOp<tosa::CastOp>(castOp, resultTy, source);
+    return success();
+  }
+};
+
+/// Registers the canonicalization patterns of `tosa.cast`: adds
+/// NonNarrowingCastsOptimization to `results`.
+void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<NonNarrowingCastsOptimization>(context);
+}
+
//===----------------------------------------------------------------------===//
// Operator Folders.
//===----------------------------------------------------------------------===//
-template <typename IntFolder, typename FloatFolder>
-static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
- DenseElementsAttr rhs,
- RankedTensorType returnTy) {
- if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
- auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
- auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
- if (lETy != rETy)
- return {};
+/// Constant-folds a binary element-wise operation with `Folder::fold`.
+///
+/// Returns a null attribute when either operand is missing, when the operand
+/// element types differ, or when `Folder::fold` reports failure. Two splat
+/// operands of float or builtin integer type fold to a splat result. When
+/// `foldDenseValues` is set, integer/index operands are folded element by
+/// element; that path always passes `isUnsigned = false` to the folder.
+template <typename Folder>
+static DenseElementsAttr
+binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs, ShapedType returnTy,
+ bool foldDenseValues = false) {
+ if (!lhs || !rhs)
+ return {};
- if (llvm::isa<IntegerType>(lETy)) {
- APInt l = lhs.getSplatValue<APInt>();
- APInt r = rhs.getSplatValue<APInt>();
- auto result = IntFolder()(l, r);
- return DenseElementsAttr::get(returnTy, result);
+ const auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
+ const auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+ if (lETy != rETy)
+ return {};
+
+ if (lhs.isSplat() && rhs.isSplat()) {
+ if (isa<FloatType>(lETy)) {
+ const APFloat l = lhs.getSplatValue<APFloat>();
+ const APFloat r = rhs.getSplatValue<APFloat>();
+ const auto maybeResult = Folder::fold(l, r);
+ if (failed(maybeResult))
+ return {};
+ return DenseElementsAttr::get(returnTy, maybeResult.value());
 }
- if (llvm::isa<FloatType>(lETy)) {
- APFloat l = lhs.getSplatValue<APFloat>();
- APFloat r = rhs.getSplatValue<APFloat>();
- auto result = FloatFolder()(l, r);
- return DenseElementsAttr::get(returnTy, result);
+ if (const auto lIntTy = llvm::dyn_cast<IntegerType>(lETy)) {
+ const APInt l = lhs.getSplatValue<APInt>();
+ const APInt r = rhs.getSplatValue<APInt>();
+ const auto maybeResult = Folder::fold(l, r, lIntTy.isUnsigned());
+ if (failed(maybeResult))
+ return {};
+ return DenseElementsAttr::get(returnTy, maybeResult.value());
+ }
+ }
+
+ if (foldDenseValues) {
+ // Only integer/index elements can be read as APInt. Decline to fold other
+ // element types instead of asserting, matching unaryFolder.
+ if (!lETy.isIntOrIndex())
+ return {};
+ // llvm::zip stops at the shorter range, so operands with different
+ // element counts would silently produce too few results.
+ if (lhs.getNumElements() != rhs.getNumElements())
+ return {};
+ SmallVector<APInt> resultValues;
+ resultValues.reserve(lhs.getNumElements());
+ for (auto [l, r] :
+ llvm::zip(lhs.getValues<APInt>(), rhs.getValues<APInt>())) {
+ const auto maybeResult = Folder::fold(l, r, false);
+ if (failed(maybeResult))
+ return {};
+ resultValues.push_back(maybeResult.value());
 }
+ return DenseElementsAttr::get(returnTy, resultValues);
 }
 return {};
 }
+/// Constant-folds a unary element-wise operation with `Folder::fold`.
+///
+/// A splat of a builtin integer type folds to a splat result, passing the
+/// type's signedness to the folder. With `foldDenseValues` set, integer/index
+/// attributes are folded element by element with `isUnsigned = false`.
+/// Returns a null attribute in every other case or when the folder fails.
+template <typename Folder>
+static DenseElementsAttr unaryFolder(DenseElementsAttr val, ShapedType returnTy,
+                                     bool foldDenseValues = false) {
+  if (!val)
+    return {};
+
+  const Type elementTy = llvm::cast<ShapedType>(val.getType()).getElementType();
+
+  // Splat path: a single value stands for the whole attribute.
+  const auto integerTy = llvm::dyn_cast<IntegerType>(elementTy);
+  if (integerTy && val.isSplat()) {
+    const auto folded =
+        Folder::fold(val.getSplatValue<APInt>(), integerTy.isUnsigned());
+    if (failed(folded))
+      return {};
+    return DenseElementsAttr::get(returnTy, folded.value());
+  }
+
+  // Element-wise folding is opt-in and limited to integer/index elements;
+  // arbitrary tensors are not folded.
+  if (!foldDenseValues || !elementTy.isIntOrIndex())
+    return {};
+
+  SmallVector<APInt> foldedValues;
+  for (const APInt &element : val.getValues<APInt>()) {
+    const auto folded = Folder::fold(element, false);
+    if (failed(folded))
+      return {};
+    foldedValues.push_back(folded.value());
+  }
+  return DenseElementsAttr::get(returnTy, foldedValues);
+}
+
+/// Folder for addition, usable with binaryFolder.
+struct AddFoldAdaptor {
+  /// Integer addition; fails when the sum overflows in the requested
+  /// signedness.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    bool didOverflow = false;
+    const APInt sum = isUnsigned ? lhs.uadd_ov(rhs, didOverflow)
+                                 : lhs.sadd_ov(rhs, didOverflow);
+    if (didOverflow)
+      return failure();
+    return sum;
+  }
+
+  /// Floating-point addition.
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs + rhs;
+  }
+};
+
+/// Folder for subtraction, usable with binaryFolder.
+struct SubFoldAdaptor {
+  /// Integer subtraction; fails when the difference overflows in the
+  /// requested signedness.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    bool didOverflow = false;
+    const APInt difference = isUnsigned ? lhs.usub_ov(rhs, didOverflow)
+                                        : lhs.ssub_ov(rhs, didOverflow);
+    if (didOverflow)
+      return failure();
+    return difference;
+  }
+
+  /// Floating-point subtraction.
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs - rhs;
+  }
+};
+
+/// Folder for multiplication, usable with binaryFolder.
+struct MulFoldAdaptor {
+  /// Integer multiplication. Fails when the operand widths differ or the
+  /// product overflows in the requested signedness.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    const unsigned width = lhs.getBitWidth();
+    if (width != rhs.getBitWidth())
+      return failure();
+
+    // A zero operand gives zero.
+    if (lhs.isZero() || rhs.isZero())
+      return APInt::getZero(width);
+
+    bool didOverflow = false;
+    const APInt product = isUnsigned ? lhs.umul_ov(rhs, didOverflow)
+                                     : lhs.smul_ov(rhs, didOverflow);
+    if (didOverflow)
+      return failure();
+
+    return product.trunc(width);
+  }
+
+  /// Floating-point multiplication.
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs * rhs;
+  }
+};
+
+/// Returns true when exactly one of `a` and `b` has its sign bit set.
+static bool signsDiffer(const APInt &a, const APInt &b) {
+ return a.isNegative() != b.isNegative();
+}
+
+/// Folder for division, usable with binaryFolder. With `Ceil` set, inexact
+/// quotients are rounded up; otherwise the truncated quotient is returned.
+///
+/// NOTE(review): with `Ceil == false` the signed path truncates toward zero,
+/// which differs from floor when the operand signs differ. Confirm this is
+/// intended for every user of `DivFoldAdaptor<false>`.
+template <bool Ceil>
+struct DivFoldAdaptor {
+  /// Integer division. Fails on mismatched widths, a zero divisor, or signed
+  /// division overflow.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               bool isUnsigned) {
+    if (lhs.getBitWidth() != rhs.getBitWidth() || rhs.isZero())
+      return failure();
+
+    if (isUnsigned) {
+      APInt quotient;
+      APInt remainder;
+      APInt::udivrem(lhs, rhs, quotient, remainder);
+      // An inexact unsigned quotient is always below the ceiling.
+      if (Ceil && !remainder.isZero())
+        ++quotient;
+      return quotient;
+    }
+
+    // Signed: begin with the quotient truncated toward zero.
+    bool didOverflow = false;
+    APInt quotient = lhs.sdiv_ov(rhs, didOverflow);
+    if (didOverflow)
+      return failure();
+
+    if constexpr (Ceil) {
+      // With equal signs the exact quotient is positive, so truncation
+      // rounded down and the ceiling is one higher.
+      const bool isInexact = !lhs.srem(rhs).isZero();
+      if (isInexact && !signsDiffer(lhs, rhs))
+        ++quotient;
+    }
+    return quotient;
+  }
+
+  /// Floating-point division.
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs / rhs;
+  }
+};
+
+/// Folder for the modulo operation, usable with binaryFolder.
+struct ModFoldAdaptor {
+  /// Integer remainder. Fails on mismatched widths, when `lhs` has its sign
+  /// bit set, or when `rhs` is not strictly positive.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               bool isUnsigned) {
+    if (lhs.getBitWidth() != rhs.getBitWidth())
+      return failure();
+
+    const bool operandsInRange = !lhs.isNegative() && rhs.isStrictlyPositive();
+    if (!operandsInRange)
+      return failure();
+
+    if (isUnsigned)
+      return lhs.urem(rhs);
+    return lhs.srem(rhs);
+  }
+
+  /// Floating-point remainder; fails unless APFloat::mod reports opOK.
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    APFloat remainder = lhs;
+    if (remainder.mod(rhs) != llvm::APFloatBase::opStatus::opOK)
+      return failure();
+    return remainder;
+  }
+};
+
+/// Folder for the maximum of two values, usable with binaryFolder.
+struct MaxFoldAdaptor {
+  /// Integer maximum; fails when the operand widths differ. The comparison
+  /// honours `isUnsigned` and is done on the APInt values directly, so it
+  /// does not depend on the values fitting into 64 bits.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               bool isUnsigned) {
+    if (lhs.getBitWidth() != rhs.getBitWidth())
+      return failure();
+    const bool lhsIsMax = isUnsigned ? lhs.uge(rhs) : lhs.sge(rhs);
+    return lhsIsMax ? lhs : rhs;
+  }
+
+  /// Floating-point maximum; yields `rhs` whenever `lhs >= rhs` is false.
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs >= rhs ? lhs : rhs;
+  }
+};
+
+/// Folder for the minimum of two values, usable with binaryFolder.
+struct MinFoldAdaptor {
+  /// Integer minimum; fails when the operand widths differ. The comparison
+  /// honours `isUnsigned` and is done on the APInt values directly, so it
+  /// does not depend on the values fitting into 64 bits.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               bool isUnsigned) {
+    if (lhs.getBitWidth() != rhs.getBitWidth())
+      return failure();
+    const bool lhsIsMin = isUnsigned ? lhs.ule(rhs) : lhs.sle(rhs);
+    return lhsIsMin ? lhs : rhs;
+  }
+
+  /// Floating-point minimum; yields `rhs` whenever `lhs <= rhs` is false.
+  static FailureOr<APFloat> fold(const APFloat &lhs, const APFloat &rhs) {
+    return lhs <= rhs ? lhs : rhs;
+  }
+};
+
+/// Folder for 2^value on integers, usable with unaryFolder.
+struct Exp2FoldAdaptor {
+  /// Returns an integer of the same width with only bit `value` set. Fails
+  /// when the exponent is negative (signed) or the result is not
+  /// representable: an unsigned result can use every bit, a signed result
+  /// must leave the sign bit clear so that it stays positive.
+  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
+    const unsigned numBits = value.getBitWidth();
+    const unsigned reservedBits = isUnsigned ? 0 : 1;
+    if (numBits <= reservedBits)
+      return failure();
+
+    // An unsigned comparison also rejects negative signed exponents (their
+    // sign bit makes them large) and is safe for any bit width.
+    if (value.uge(numBits - reservedBits))
+      return failure();
+
+    // The exponent is now known to be below the bit width.
+    return APInt::getOneBitSet(numBits, value.getZExtValue());
+  }
+};
+
+/// Folder for the rounded-up base-2 logarithm, usable with unaryFolder.
+struct Log2CeilFoldAdaptor {
+  /// Fails unless `value` is strictly positive; the result keeps the width
+  /// of `value`.
+  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
+    if (value.isStrictlyPositive())
+      return APInt(/*numBits=*/value.getBitWidth(), value.ceilLogBase2());
+    return failure();
+  }
+};
+
+/// Folder for the rounded-down base-2 logarithm, usable with unaryFolder.
+struct Log2FloorFoldAdaptor {
+  /// Fails unless `value` is strictly positive; the result keeps the width
+  /// of `value`.
+  static FailureOr<APInt> fold(const APInt &value, bool isUnsigned) {
+    if (value.isStrictlyPositive())
+      return APInt(/*numBits=*/value.getBitWidth(), value.logBase2());
+    return failure();
+  }
+};
+
+/// Folder for `lhs > rhs`; the result is a 1-bit integer.
+struct GreaterFoldAdaptor {
+  /// Integer comparison in the requested signedness.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    const bool isGreater = isUnsigned ? lhs.ugt(rhs) : lhs.sgt(rhs);
+    return APInt(1, isGreater);
+  }
+
+  /// Floating-point comparison.
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    const bool isGreater = lhs > rhs;
+    return APInt(1, isGreater);
+  }
+};
+
+/// Folder for `lhs >= rhs`; the result is a 1-bit integer.
+struct GreaterEqualFoldAdaptor {
+  /// Integer comparison in the requested signedness.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    const bool isGreaterOrEqual = isUnsigned ? lhs.uge(rhs) : lhs.sge(rhs);
+    return APInt(1, isGreaterOrEqual);
+  }
+
+  /// Floating-point comparison.
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    const bool isGreaterOrEqual = lhs >= rhs;
+    return APInt(1, isGreaterOrEqual);
+  }
+};
+
+/// Folder for `lhs == rhs`; the result is a 1-bit integer.
+struct EqualFoldAdaptor {
+  /// Integer equality; signedness does not affect the result.
+  static FailureOr<APInt> fold(const APInt &lhs, const APInt &rhs,
+                               const bool isUnsigned) {
+    const bool areEqual = lhs == rhs;
+    return APInt(1, areEqual);
+  }
+
+  /// Floating-point equality.
+  static FailureOr<APInt> fold(const APFloat &lhs, const APFloat &rhs) {
+    const bool areEqual = lhs == rhs;
+    return APInt(1, areEqual);
+  }
+};
+
static bool isSplatZero(Type elemType, DenseElementsAttr val) {
if (llvm::isa<FloatType>(elemType))
return val && val.isSplat() && val.getSplatValue<APFloat>().isZero();
@@ -962,8 +1290,7 @@ OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::plus<APInt>, std::plus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return binaryFolder<AddFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult ArgMaxOp::fold(FoldAdaptor adaptor) {
@@ -992,7 +1319,8 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
if (lhsTy != rhsTy)
return {};
- // IntDivOp inputs must be integer type, no need to check for quantized type
+ // IntDivOp inputs must be integer type, no need to check for quantized
+ // type
auto resultETy = resultTy.getElementType();
auto lhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput1());
@@ -1015,8 +1343,12 @@ OpFoldResult IntDivOp::fold(FoldAdaptor adaptor) {
APInt l = lhsAttr.getSplatValue<APInt>();
APInt r = rhsAttr.getSplatValue<APInt>();
if (!r.isZero()) {
- APInt result = l.sdiv(r);
- return DenseElementsAttr::get(resultTy, result);
+ auto intTy = dyn_cast<mlir::IntegerType>(resultETy);
+ auto const result =
+ DivFoldAdaptor</*Ceil*/ false>::fold(l, r, intTy.isUnsigned());
+ if (failed(result))
+ return {};
+ return DenseElementsAttr::get(resultTy, result.value());
}
}
@@ -1028,13 +1360,18 @@ namespace {
// return nullopt if result is not in range of int32_t when shift > 0
std::optional<APInt> mulInt(APInt lhs, APInt rhs, int32_t shift,
unsigned bitwidth) {
- APInt result = lhs.sext(64) * rhs.sext(64);
+ bool overflow = false;
+ APInt result = lhs.sext(64).smul_ov(rhs.sext(64), overflow);
+
+ if (overflow)
+ return std::nullopt;
if (shift > 0) {
auto round = APInt(64, 1) << (shift - 1);
result += round;
result.ashrInPlace(shift);
- // REQUIRE(product >= minimum_s<i32_t>() && product <= maximum_s<i32_t>())
+ // REQUIRE(product >= minimum_s<i32_t>() && product <=
+ // maximum_s<i32_t>())
if (!(result.getSExtValue() >= INT32_MIN &&
result.getSExtValue() <= INT32_MAX)) {
// REQUIRE failed
@@ -1090,8 +1427,8 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
auto rhsAttr =
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
- // Result right shift on i32_t data type only. For simplification, synthesize
- // a zero shift for other data type.
+ // Result right shift on i32_t data type only. For simplification,
+ // synthesize a zero shift for other data type.
int32_t shift = 0;
if (resultETy.isInteger(32)) {
ElementsAttr shift_elem;
@@ -1144,38 +1481,9 @@ OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<std::minus<APInt>, std::minus<APFloat>>(lhsAttr, rhsAttr,
- resultTy);
+ return binaryFolder<SubFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
-namespace {
-template <typename Cmp>
-struct ComparisonFold {
- ComparisonFold() = default;
- APInt operator()(const APInt &l, const APInt &r) {
- return APInt(1, Cmp()(l, r));
- }
-
- APInt operator()(const APFloat &l, const APFloat &r) {
- return APInt(1, Cmp()(l, r));
- }
-};
-
-struct APIntFoldGreater {
- APIntFoldGreater() = default;
- APInt operator()(const APInt &l, const APInt &r) {
- return APInt(1, l.sgt(r));
- }
-};
-
-struct APIntFoldGreaterEqual {
- APIntFoldGreaterEqual() = default;
- APInt operator()(const APInt &l, const APInt &r) {
- return APInt(1, l.sge(r));
- }
-};
-} // namespace
-
OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
auto lhsAttr =
@@ -1186,8 +1494,7 @@ OpFoldResult GreaterOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreater, ComparisonFold<std::greater<APFloat>>>(
- lhsAttr, rhsAttr, resultTy);
+ return binaryFolder<GreaterFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
@@ -1200,9 +1507,7 @@ OpFoldResult GreaterEqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<APIntFoldGreaterEqual,
- ComparisonFold<std::greater_equal<APFloat>>>(
- lhsAttr, rhsAttr, resultTy);
+ return binaryFolder<GreaterEqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
@@ -1215,8 +1520,8 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
Value rhs = getInput2();
auto lhsTy = llvm::cast<ShapedType>(lhs.getType());
- // If we are comparing an integer value to itself it is always true. We can
- // not do this with float due to float values.
+ // If we are comparing an integer value to itself it is always true. We
+ // can not do this with float due to float values.
if (llvm::isa<IntegerType>(lhsTy.getElementType()) && resultTy &&
resultTy.hasStaticShape() && lhs == rhs) {
return DenseElementsAttr::get(resultTy, true);
@@ -1225,9 +1530,7 @@ OpFoldResult EqualOp::fold(FoldAdaptor adaptor) {
if (!lhsAttr || !rhsAttr)
return {};
- return binaryFolder<ComparisonFold<std::equal_to<APInt>>,
- ComparisonFold<std::equal_to<APFloat>>>(lhsAttr, rhsAttr,
- resultTy);
+ return binaryFolder<EqualFoldAdaptor>(lhsAttr, rhsAttr, resultTy);
}
OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
@@ -1329,9 +1632,9 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (!inputTy || !outputTy)
return {};
- // Fold when the input and output types are the same. This is only safe when
- // there is at most 1 dynamic dimension. For 2 or more dynamic dimensions,
- // there may still be a productive reshape.
+ // Fold when the input and output types are the same. This is only safe
+ // when there is at most 1 dynamic dimension. For 2 or more dynamic
+ // dimensions, there may still be a productive reshape.
if (inputTy == outputTy && inputTy.getNumDynamicDims() < 2)
return getInput1();
@@ -1485,7 +1788,24 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
return {};
}
+/// Returns true unless every operand type is a statically shaped ShapedType
+/// and verifyCompatibleShapes accepts the operand types.
+static bool
+mayRequireBroadcast(ValueTypeRange<mlir::OperandRange> operandTypes) {
+  const auto hasStaticShape = [](Type ty) {
+    const auto shapedTy = llvm::dyn_cast<ShapedType>(ty);
+    return shapedTy && shapedTy.hasStaticShape();
+  };
+
+  // Any unshaped or dynamically shaped operand could need broadcasting.
+  if (!llvm::all_of(operandTypes, hasStaticShape))
+    return true;
+
+  return failed(verifyCompatibleShapes(operandTypes));
+}
+
OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) {
+ // Select allows operand shapes to be broadcast to the output shape. For
+ // now, don't support folding when we cannot prove no broadcasting is
+ // involved.
+ if (mayRequireBroadcast(getOperandTypes()))
+ return {};
+
if (getOnTrue() == getOnFalse())
return getOnTrue();
@@ -1632,3 +1952,92 @@ OpFoldResult tosa::ReciprocalOp::fold(FoldAdaptor adaptor) {
return {};
}
+
+/// Folds a unary shape operation whose operand is produced by
+/// `tosa.const_shape`, applying `OpFoldAdaptor` to every element. Returns a
+/// null result when the operand has another producer or no producer at all.
+template <typename Op, typename OpFoldAdaptor>
+OpFoldResult unaryShapeFold(Op *op) {
+  // getDefiningOp() is null for block arguments, so use the null-tolerant
+  // cast; a plain dyn_cast asserts on a null operation.
+  auto input1ConstShape = llvm::dyn_cast_if_present<tosa::ConstShapeOp>(
+      op->getInput().getDefiningOp());
+  if (!input1ConstShape)
+    return {};
+
+  const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
+
+  return unaryFolder<OpFoldAdaptor>(input1Attr, input1Attr.getType(),
+                                    /*foldDenseValues=*/true);
+}
+
+/// Folds a binary shape operation whose operands are both produced by
+/// `tosa.const_shape`, applying `OpFoldAdaptor` element by element. Returns a
+/// null result when either operand has another producer or no producer.
+template <typename Op, typename OpFoldAdaptor>
+OpFoldResult binaryFold(Op *op) {
+  // getDefiningOp() is null for block arguments, so use the null-tolerant
+  // cast; a plain dyn_cast asserts on a null operation.
+  auto input1ConstShape = llvm::dyn_cast_if_present<tosa::ConstShapeOp>(
+      op->getInput1().getDefiningOp());
+  auto input2ConstShape = llvm::dyn_cast_if_present<tosa::ConstShapeOp>(
+      op->getInput2().getDefiningOp());
+  if (!input1ConstShape || !input2ConstShape)
+    return {};
+
+  const auto input1Attr = cast<DenseElementsAttr>(input1ConstShape.getValues());
+  const auto input2Attr = cast<DenseElementsAttr>(input2ConstShape.getValues());
+
+  return binaryFolder<OpFoldAdaptor>(input1Attr, input2Attr,
+                                     input1Attr.getType(),
+                                     /*foldDenseValues=*/true);
+}
+
+/// Folds `tosa.dim` to a constant when the input is ranked and the queried
+/// dimension is static. The result is a one-element index tensor attribute.
+OpFoldResult tosa::DimOp::fold(FoldAdaptor adaptor) {
+  const auto inputTy = llvm::dyn_cast<ShapedType>(getInput1().getType());
+  if (!inputTy || !inputTy.hasRank())
+    return {};
+
+  const int32_t axis = getAxis();
+  const int64_t dimSize = inputTy.getDimSize(axis);
+  // A dynamic extent is not known at compile time.
+  if (ShapedType::isDynamic(dimSize))
+    return {};
+
+  const auto resultAttrTy =
+      RankedTensorType::get(/*shape=*/{1}, IndexType::get(getContext()));
+  return DenseElementsAttr::get(resultAttrTy, dimSize);
+}
+
+// Folds tosa add_shape via binaryFold with AddFoldAdaptor.
+OpFoldResult tosa::AddShapeOp::fold(FoldAdaptor adaptor) {
+ return binaryFold<AddShapeOp, AddFoldAdaptor>(this);
+}
+
+// Folds tosa sub_shape via binaryFold with SubFoldAdaptor.
+OpFoldResult tosa::SubShapeOp::fold(FoldAdaptor adaptor) {
+ return binaryFold<SubShapeOp, SubFoldAdaptor>(this);
+}
+
+// Folds tosa mul_shape via binaryFold with MulFoldAdaptor.
+OpFoldResult tosa::MulShapeOp::fold(FoldAdaptor adaptor) {
+ return binaryFold<MulShapeOp, MulFoldAdaptor>(this);
+}
+
+// Folds the ceiling shape division via binaryFold with DivFoldAdaptor<true>.
+OpFoldResult tosa::DivCeilShapeOp::fold(FoldAdaptor adaptor) {
+ return binaryFold<DivCeilShapeOp, DivFoldAdaptor</*Ceil*/ true>>(this);
+}
+
+// Folds the floor shape division via binaryFold with DivFoldAdaptor<false>.
+OpFoldResult tosa::DivFloorShapeOp::fold(FoldAdaptor adaptor) {
+ return binaryFold<DivFloorShapeOp, DivFoldAdaptor</*Ceil*/ false>>(this);
+}
+
+// Folds the shape modulo via binaryFold with ModFoldAdaptor.
+OpFoldResult tosa::ModShapeOp::fold(FoldAdaptor adaptor) {
+ return binaryFold<ModShapeOp, ModFoldAdaptor>(this);
+}
+
+// Folds the shape maximum via binaryFold with MaxFoldAdaptor.
+OpFoldResult tosa::MaxShapeOp::fold(FoldAdaptor adaptor) {
+ return binaryFold<MaxShapeOp, MaxFoldAdaptor>(this);
+}
+
+// Folds the shape minimum via binaryFold with MinFoldAdaptor.
+OpFoldResult tosa::MinShapeOp::fold(FoldAdaptor adaptor) {
+ return binaryFold<MinShapeOp, MinFoldAdaptor>(this);
+}
+
+// Folds the shape power-of-two via unaryShapeFold with Exp2FoldAdaptor.
+OpFoldResult tosa::Exp2ShapeOp::fold(FoldAdaptor adaptor) {
+ return unaryShapeFold<Exp2ShapeOp, Exp2FoldAdaptor>(this);
+}
+
+// Folds the rounded-up shape log2 via unaryShapeFold with Log2CeilFoldAdaptor.
+OpFoldResult tosa::Log2CeilShapeOp::fold(FoldAdaptor adaptor) {
+ return unaryShapeFold<Log2CeilShapeOp, Log2CeilFoldAdaptor>(this);
+}
+
+// Folds the rounded-down shape log2 via unaryShapeFold with
+// Log2FloorFoldAdaptor.
+OpFoldResult tosa::Log2FloorShapeOp::fold(FoldAdaptor adaptor) {
+ return unaryShapeFold<Log2FloorShapeOp, Log2FloorFoldAdaptor>(this);
+}
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0aff67f..798fc36 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -19,6 +19,7 @@
#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/VerificationUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Matchers.h"
@@ -26,6 +27,7 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/APFloat.h"
+#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include <numeric>
@@ -134,9 +136,9 @@ SmallVector<Region *> tosa::WhileOp::getLoopRegions() {
//===----------------------------------------------------------------------===//
+// Copies `shape`, replacing every -1 entry with ShapedType::kDynamic.
static SmallVector<int64_t> convertToMlirShape(ArrayRef<int64_t> shape) {
- return to_vector(llvm::map_range(shape, [](int64_t dim) {
+ return map_to_vector(shape, [](int64_t dim) {
return dim == -1 ? ShapedType::kDynamic : dim;
- }));
+ });
}
// returns type of variable op
@@ -550,6 +552,15 @@ void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
printWithEnumHandling(parser, *this);
}
+// Custom assembly parser; forwards to parseWithEnumHandling instantiated for
+// tosa::BlockSize.
+ParseResult Conv2DBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+// Custom assembly printer; forwards to printWithEnumHandling.
+void Conv2DBlockScaledOp::print(OpAsmPrinter &printer) {
+  printWithEnumHandling(printer, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -563,7 +574,7 @@ static std::optional<int64_t> idivCheck(const int64_t lhs, const int64_t rhs) {
+// Returns the element type of `type` (or `type` itself), mapped through
+// getStorageElementTypeFromQuantized when that type is a quantized type.
static Type getStorageElementTypeOrSelf(Type type) {
auto srcType = getElementTypeOrSelf(type);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(srcType))
- srcType = quantType.getStorageType();
+ srcType = getStorageElementTypeFromQuantized(quantType);
return srcType;
}
@@ -606,6 +617,61 @@ Value mlir::tosa::createPadConstTensor(OpBuilder &builder, Location loc,
return tosa::ConstOp::create(builder, loc, padConstType, padConstAttr);
}
+// Returns the bit width of `type`: 8 for tosa::mxint8Type, otherwise the
+// integer or float bit width reported by the type itself.
+unsigned mlir::tosa::getBitWidth(Type type) {
+  // Only the kind of the type matters here, so test it with isa<> rather
+  // than discarding the result of a dyn_cast<>.
+  if (isa<tosa::mxint8Type>(type))
+    return 8;
+  return type.getIntOrFloatBitWidth();
+}
+
+// Merges `newDim` into `currDim`. A dynamic `currDim` simply takes the value
+// of `newDim`. Otherwise an op error naming `dimName` and `operandName` is
+// emitted when `newDim` is static and disagrees with `currDim`.
+LogicalResult tryUpdateDimOrFailure(Operation *op, int64_t &currDim,
+                                    const int64_t newDim,
+                                    const StringRef operandName,
+                                    const StringRef dimName) {
+  if (ShapedType::isDynamic(currDim)) {
+    currDim = newDim;
+    return success();
+  }
+
+  const bool isMismatch = ShapedType::isStatic(newDim) && currDim != newDim;
+  if (!isMismatch)
+    return success();
+
+  return op->emitOpError("expected ")
+         << dimName << " of " << operandName << " to match size " << currDim
+         << ", got " << newDim;
+}
+
+// Checks one spatial dimension of a convolution against
+//   O == idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
+// Nothing is checked when the input or kernel size is dynamic. An op error
+// is emitted when the division is not exact, or when a static output size
+// differs from the calculated one. The name arguments only feed the
+// diagnostics.
+LogicalResult verifyConvOutputSize(
+    Operation *op, const int64_t inputSize, const int64_t kernelSize,
+    const int64_t outputSize, const int64_t padBefore, const int64_t padAfter,
+    const int64_t stride, const int64_t dilation, const llvm::StringRef dimName,
+    const llvm::StringRef dimAxis, const llvm::StringRef padBeforeName,
+    const llvm::StringRef padAfterName) {
+  if (ShapedType::isDynamic(inputSize) || ShapedType::isDynamic(kernelSize))
+    return success();
+
+  const int64_t numerator =
+      inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation;
+  const std::optional<int64_t> quotient = idivCheck(numerator, stride);
+  if (!quotient)
+    return op->emitOpError("expected input_")
+           << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
+           << padAfterName << " - (kernel_" << dimName << " - 1) * dilation_"
+           << dimAxis << " to be wholly divisible by stride_" << dimAxis
+           << ", got (" << inputSize << " - 1 + " << padBefore << " + "
+           << padAfter << " - (" << kernelSize << " - 1) * " << dilation
+           << ") / " << stride;
+
+  const int64_t calculatedOutSize = *quotient + 1;
+  // A dynamic output size cannot contradict the calculation.
+  if (ShapedType::isDynamic(outputSize) || calculatedOutSize == outputSize)
+    return success();
+
+  return op->emitOpError("calculated output ")
+         << dimName << " did not match expected: "
+         << "calculated=" << calculatedOutSize << ", expected=" << outputSize;
+}
+
//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
@@ -625,16 +691,16 @@ static LogicalResult verifyConvOp(T op) {
bool resultIsFloat = llvm::isa<FloatType>(resultEType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(weightEType))
- weightEType = quantType.getStorageType();
+ weightEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(biasEType))
- biasEType = quantType.getStorageType();
+ biasEType = getStorageElementTypeFromQuantized(quantType);
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
- resultEType = quantType.getStorageType();
+ resultEType = getStorageElementTypeFromQuantized(quantType);
if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
// for now, only enforce bias element type == result element type for
@@ -703,7 +769,7 @@ LogicalResult tosa::ConstOp::verify() {
if (auto result = llvm::dyn_cast<mlir::quant::QuantizedType>(
outputType.getElementType())) {
- if (result.getStorageType() == attrType.getElementType())
+ if (getStorageElementTypeFromQuantized(result) == attrType.getElementType())
return success();
}
@@ -721,32 +787,40 @@ static LogicalResult verifyConvOpModes(T op) {
llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(inputEType))
- inputEType = quantType.getStorageType();
+ inputEType = getStorageElementTypeFromQuantized(quantType);
auto accType = op.getAccType();
if (inputEType.isInteger(8) && !accType.isInteger(32))
- return op.emitOpError("accumulator type for i8 tensor is not i32");
+ return op.emitOpError("accumulator type for i8 tensor is not i32, got ")
+ << accType;
if (inputEType.isInteger(16) && !accType.isInteger(48))
- return op.emitOpError("accumulator type for i16 tensor is not i48");
+ return op.emitOpError("accumulator type for i16 tensor is not i48, got ")
+ << accType;
- if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) && !accType.isF16())
- return op.emitOpError("accumulator type for f8 tensor is not f16");
+ if (isa<Float8E5M2Type, Float8E4M3Type>(inputEType) &&
+ !(accType.isF16() || accType.isF32()))
+ return op.emitOpError("accumulator type for f8 tensor is not f16/f32, got ")
+ << accType;
if (inputEType.isF16() && !(accType.isF16() || accType.isF32()))
- return op.emitOpError("accumulator type for f16 tensor is not f16/f32");
+ return op.emitOpError(
+ "accumulator type for f16 tensor is not f16/f32, got ")
+ << accType;
if (inputEType.isBF16() && !accType.isF32())
- return op.emitOpError("accumulator type for bf16 tensor is not f32");
+ return op.emitOpError("accumulator type for bf16 tensor is not f32, got ")
+ << accType;
if (inputEType.isF32() && !accType.isF32())
- return op.emitOpError("accumulator type for f32 tensor is not f32");
+ return op.emitOpError("accumulator type for f32 tensor is not f32, got ")
+ << accType;
auto resultEType =
llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
if (auto quantType = llvm::dyn_cast<mlir::quant::QuantizedType>(resultEType))
- resultEType = quantType.getStorageType();
+ resultEType = getStorageElementTypeFromQuantized(quantType);
return success();
}
@@ -785,53 +859,16 @@ static LogicalResult verifyConvOpErrorIf(T op) {
llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
if (inputType && weightType) {
- const auto verifyOutputSize =
- [&op](const int64_t inputSize, const int64_t kernelSize,
- const int64_t outputSize, const int64_t padBefore,
- const int64_t padAfter, const int64_t stride,
- const int64_t dilation, const llvm::StringRef dimName,
- const llvm::StringRef dimAxis,
- const llvm::StringRef padBeforeName,
- const llvm::StringRef padAfterName) -> LogicalResult {
- if (inputSize == ShapedType::kDynamic ||
- kernelSize == ShapedType::kDynamic)
- return success();
-
- // ERROR_IF: O != idiv_check(I - 1 + pa + pb - (K - 1) * d, s) + 1
-
- const std::optional<int64_t> calculatedOutSizeMinusOne = idivCheck(
- inputSize - 1 + padBefore + padAfter - (kernelSize - 1) * dilation,
- stride);
- if (!calculatedOutSizeMinusOne.has_value())
- return op.emitOpError("expected input_")
- << dimName << " - 1 + pad_" << padBeforeName << " + pad_"
- << padAfterName << " - (kernel_" << dimName
- << " - 1) * dilation_" << dimAxis
- << " to be wholly divisible by stride_" << dimAxis << ", got ("
- << inputSize << " - 1 + " << padBefore << " + " << padAfter
- << " - (" << kernelSize << " - 1) * " << dilation << ") / "
- << stride;
-
- const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
- if (outputSize != ShapedType::kDynamic && calculatedOutSize != outputSize)
- return op.emitOpError("calculated output ")
- << dimName << " did not match expected: "
- << "calculated=" << calculatedOutSize
- << ", expected=" << outputSize;
-
- return success();
- };
-
// input = [_,IH,IW,_], weight = [_,KH,KW,_], output = [_,OH,OW,_]
if constexpr (std::is_same<T, tosa::Conv2DOp>::value) {
- if (failed(verifyOutputSize(
- inputType.getDimSize(1), weightType.getDimSize(1),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(1), weightType.getDimSize(1),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "height", "y", "top", "bottom")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(2), weightType.getDimSize(2),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(2), weightType.getDimSize(2),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "width", "x", "left", "right")))
return failure();
@@ -839,14 +876,14 @@ static LogicalResult verifyConvOpErrorIf(T op) {
// input = [_,IH,IW,_], weight = [KH,KW,_,_], output = [_,OH,OW,_]
if constexpr (std::is_same<T, tosa::DepthwiseConv2DOp>::value) {
- if (failed(verifyOutputSize(
- inputType.getDimSize(1), weightType.getDimSize(0),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(1), weightType.getDimSize(0),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "height", "y", "top", "bottom")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(2), weightType.getDimSize(1),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(2), weightType.getDimSize(1),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "width", "x", "left", "right")))
return failure();
@@ -854,20 +891,20 @@ static LogicalResult verifyConvOpErrorIf(T op) {
// input = [_,ID,IH,IW,_], weight = [_,KD,KH,KW,_], output = [_,OD,OH,OW,_]
if constexpr (std::is_same<T, tosa::Conv3DOp>::value) {
- if (failed(verifyOutputSize(
- inputType.getDimSize(1), weightType.getDimSize(1),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(1), weightType.getDimSize(1),
outputType.getDimSize(1), padding[0], padding[1], strides[0],
dilations[0], "depth", "d", "front", "back")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(2), weightType.getDimSize(2),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(2), weightType.getDimSize(2),
outputType.getDimSize(2), padding[2], padding[3], strides[1],
dilations[1], "height", "y", "top", "bottom")))
return failure();
- if (failed(verifyOutputSize(
- inputType.getDimSize(3), weightType.getDimSize(3),
+ if (failed(verifyConvOutputSize(
+ op, inputType.getDimSize(3), weightType.getDimSize(3),
outputType.getDimSize(3), padding[4], padding[5], strides[2],
dilations[2], "width", "x", "left", "right")))
return failure();
@@ -1105,9 +1142,8 @@ static LogicalResult verifyPoolingOp(T op) {
const int64_t calculatedOutSize = calculatedOutSizeMinusOne.value() + 1;
if (ShapedType::isStatic(outputSize) && calculatedOutSize != outputSize)
return op.emitOpError("calculated output ")
- << dimName << " did not match expected: "
- << "calculated=" << calculatedOutSize
- << ", expected=" << outputSize;
+ << dimName << " did not match expected: " << "calculated="
+ << calculatedOutSize << ", expected=" << outputSize;
return success();
};
@@ -1173,13 +1209,13 @@ LogicalResult tosa::ClampOp::verify() {
llvm::cast<ShapedType>(getInput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy)) {
- inputETy = quantType.getStorageType();
+ inputETy = getStorageElementTypeFromQuantized(quantType);
}
mlir::Type outputETy =
llvm::cast<ShapedType>(getOutput().getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputETy)) {
- outputETy = quantType.getStorageType();
+ outputETy = getStorageElementTypeFromQuantized(quantType);
}
if (inputETy != outputETy)
return emitOpError("input/output element types are incompatible.");
@@ -1445,6 +1481,19 @@ static void buildVariableOp(OpBuilder &builder, OperationState &result,
//===----------------------------------------------------------------------===//
// TOSA Operator Return Type Inference.
//===----------------------------------------------------------------------===//
+/// Resolves a single pair of dimension sizes under broadcasting rules.
+///
+/// A size of 1 on either side yields the other side's size. Two static sizes
+/// that differ are incompatible and produce failure. Otherwise the static
+/// size is returned when exactly one side is dynamic, and `dim1` when both
+/// sides agree (including both being dynamic).
+static FailureOr<int64_t> resolveBroadcastDim(const int64_t dim1,
+                                              const int64_t dim2) {
+  // A unit dimension always gives way to the other operand's dimension.
+  if (dim1 == 1)
+    return dim2;
+  if (dim2 == 1)
+    return dim1;
+
+  // Only two known sizes can actually conflict.
+  const bool bothStatic =
+      ShapedType::isStatic(dim1) && ShapedType::isStatic(dim2);
+  if (bothStatic && dim1 != dim2)
+    return failure();
+
+  // Keep whichever side carries static information.
+  if (ShapedType::isDynamic(dim1))
+    return dim2;
+  return dim1;
+}
static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
SmallVector<int64_t> &outShape) {
@@ -1468,15 +1517,12 @@ static LogicalResult resolveBroadcastShape(const ValueShapeRange &operands,
for (size_t i = 0, e = shape.getRank(); i < e; ++i) {
auto dim1 = outShape[i + rankDiff];
auto dim2 = shape.getDimSize(i);
- auto resolvedDim = dim1;
- if (dim1 == 1) {
- resolvedDim = dim2;
- } else if (dim2 == 1) {
- resolvedDim = dim1;
- } else if (dim1 != dim2) {
+ const FailureOr<int64_t> maybeResolvedDim =
+ resolveBroadcastDim(dim1, dim2);
+ if (failed(maybeResolvedDim))
return failure();
- }
+ const int64_t resolvedDim = *maybeResolvedDim;
outShape[i + rankDiff] = resolvedDim;
}
}
@@ -1755,6 +1801,11 @@ LogicalResult tosa::ConcatOp::verify() {
}
}
+ const ShapeAdaptor outputShape(outType);
+ if (outputShape.hasRank() && outputShape.getRank() != firstInputRank)
+ return emitOpError("expect output rank to match inputs rank, got ")
+ << outputShape.getRank() << " vs " << firstInputRank;
+
// ERROR_IF(axis_sum != shape[axis]);
int64_t axisSum = 0;
for (const auto &input : inputList) {
@@ -1766,7 +1817,7 @@ LogicalResult tosa::ConcatOp::verify() {
}
axisSum += inputShape.getDimSize(axis);
}
- const ShapeAdaptor outputShape(outType);
+
if (axisSum >= 0 && outputShape.hasRank() &&
!outputShape.isDynamicDim(axis) &&
axisSum != outputShape.getDimSize(axis))
@@ -1943,20 +1994,6 @@ LogicalResult MatmulTBlockScaledOp::verify() {
"B_data")))
return failure();
- auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
- const StringRef operandName,
- const StringRef dimName) -> LogicalResult {
- if (ShapedType::isDynamic(currDim)) {
- currDim = newDim;
- return success();
- } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
- return emitOpError("expected ")
- << dimName << " of " << operandName << " to match size " << currDim
- << ", got " << newDim;
- }
- return success();
- };
-
// Verify input shape compatibility
int64_t N = ShapedType::kDynamic;
int64_t D = ShapedType::kDynamic;
@@ -1974,32 +2011,33 @@ LogicalResult MatmulTBlockScaledOp::verify() {
const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
if (aScaleShape.hasRank()) {
- if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
- "batch")) ||
- failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
- "height")))
+ if (failed(tryUpdateDimOrFailure(*this, N, aScaleShape.getDimSize(0),
+ "a_scale", "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, H, aScaleShape.getDimSize(1),
+ "a_scale", "height")))
return failure();
multiplesOfC = aScaleShape.getDimSize(2);
}
const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
if (bDataShape.hasRank()) {
- if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
- "batch")) ||
- failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
- "channels")))
+ if (failed(tryUpdateDimOrFailure(*this, D, bDataShape.getDimSize(0),
+ "b_data", "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, C, bDataShape.getDimSize(2),
+ "b_data", "channels")))
return failure();
W = bDataShape.getDimSize(1);
}
const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
if (bScaleShape.hasRank()) {
- if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
- "batch")) ||
- failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
- "width")) ||
- failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
- "b_scale", "C/block_size")))
+ if (failed(tryUpdateDimOrFailure(*this, D, bScaleShape.getDimSize(0),
+ "b_scale", "batch")) ||
+ failed(tryUpdateDimOrFailure(*this, W, bScaleShape.getDimSize(1),
+ "b_scale", "width")) ||
+ failed(tryUpdateDimOrFailure(*this, multiplesOfC,
+ bScaleShape.getDimSize(2), "b_scale",
+ "C/block_size")))
return failure();
}
@@ -2020,8 +2058,7 @@ LogicalResult MatmulTBlockScaledOp::verify() {
multiplesOfC != C / blockSize)
return emitOpError(
"expect scale operands dimension 2 to equal C/block_size (")
- << C << "/" << blockSize << ")"
- << ", got " << multiplesOfC;
+ << C << "/" << blockSize << ")" << ", got " << multiplesOfC;
// Verify output shape
N = ShapedType::isDynamic(N) ? D : N;
@@ -2115,17 +2152,14 @@ LogicalResult tosa::PadOp::verify() {
if (!inputType || !outputType)
return success();
- auto inputRank = inputType.getRank();
- auto outputRank = outputType.getRank();
- if (inputRank != outputRank)
- return emitOpError() << "expect same input and output tensor rank, but got "
- << "inputRank: " << inputRank
- << ", outputRank: " << outputRank;
+ if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
+ "output")))
+ return failure();
+ auto inputRank = inputType.getRank();
DenseIntElementsAttr paddingAttr;
- if (!matchPattern(getPadding(), m_Constant(&paddingAttr))) {
- return failure();
- }
+ if (!matchPattern(getPadding(), m_Constant(&paddingAttr)))
+ return success();
auto paddingValues = paddingAttr.getValues<APInt>();
if (paddingValues.size() != static_cast<size_t>(inputRank * 2))
@@ -2290,9 +2324,9 @@ LogicalResult tosa::MulOp::verify() {
}
// verify shift has value 0 for non-integer types
- ElementsAttr shift_elem;
- if (matchPattern(getShift(), m_Constant(&shift_elem))) {
- int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ ElementsAttr shiftElem;
+ if (matchPattern(getShift(), m_Constant(&shiftElem))) {
+ int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
if (shift != 0) {
return emitOpError() << "require shift to be 0 for float type";
}
@@ -2362,9 +2396,9 @@ LogicalResult tosa::TableOp::verify() {
if (!inputType.hasRank() || !outputType.hasRank())
return success();
- if (inputType.getRank() != outputType.getRank())
- return emitOpError()
- << "expected input tensor rank to equal result tensor rank";
+ if (failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
+ "result")))
+ return failure();
auto inputDims = inputType.getShape();
auto outputDims = outputType.getShape();
@@ -2386,9 +2420,9 @@ tosa::TileOp::getConstantMultiples(SmallVector<int64_t> &multiples) {
DenseIntElementsAttr multiplesAttr;
if (!matchPattern(getMultiples(), m_Constant(&multiplesAttr)))
return failure();
- multiples = llvm::to_vector(
- llvm::map_range(multiplesAttr.getValues<APInt>(),
- [](const APInt &val) { return val.getSExtValue(); }));
+ multiples =
+ llvm::map_to_vector(multiplesAttr.getValues<APInt>(),
+ [](const APInt &val) { return val.getSExtValue(); });
return success();
}
@@ -2454,8 +2488,10 @@ LogicalResult tosa::TileOp::verify() {
if (inputType.getRank() != multiplesRank)
return emitOpError("expect 'multiples' to have rank ")
<< inputType.getRank() << " but got " << multiplesRank << ".";
- if (outputType.hasRank() && inputType.getRank() != outputType.getRank())
- return emitOpError("expect same input and output tensor rank.");
+ if (outputType.hasRank() &&
+ failed(verifyRanksMatch(getOperation(), inputType, outputType, "input",
+ "output")))
+ return failure();
} else if (outputType.hasRank() && outputType.getRank() != multiplesRank)
return emitOpError("expect 'multiples' array to have length ")
<< outputType.getRank() << " but got " << multiplesRank << ".";
@@ -2622,7 +2658,7 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
if (!zpElemType.isInteger(8) && zp != 0) {
// convert operand to lower case for error message
std::string lower = operand;
- std::transform(lower.begin(), lower.end(), lower.begin(), ::tolower);
+ llvm::transform(lower, lower.begin(), ::tolower);
return op.emitOpError()
<< lower << " zero point must be zero for non-int8 integer types";
}
@@ -2780,8 +2816,8 @@ LogicalResult tosa::TransposeOp::verify() {
return s >= 0 &&
static_cast<size_t>(s) < constantPerms.size();
}) ||
- !isPermutationVector(llvm::to_vector(llvm::map_range(
- constantPerms, [](int32_t v) -> int64_t { return v; }))))
+ !isPermutationVector(llvm::map_to_vector(
+ constantPerms, [](int32_t v) -> int64_t { return v; })))
return emitOpError() << "expected valid permutation indices";
// ERROR_IF(tensor_size(shape1) != tensor_size(shape))
@@ -2871,39 +2907,39 @@ LogicalResult tosa::GatherOp::verify() {
const ShapeAdaptor indicesShape(getIndices().getType());
const ShapeAdaptor outputShape(getOutput().getType());
- int64_t N = ShapedType::kDynamic;
- int64_t W = ShapedType::kDynamic;
- int64_t C = ShapedType::kDynamic;
+ int64_t n = ShapedType::kDynamic;
+ int64_t w = ShapedType::kDynamic;
+ int64_t c = ShapedType::kDynamic;
if (valuesShape.hasRank()) {
- N = valuesShape.getDimSize(0);
- C = valuesShape.getDimSize(2);
+ n = valuesShape.getDimSize(0);
+ c = valuesShape.getDimSize(2);
}
if (indicesShape.hasRank()) {
const int64_t indicesN = indicesShape.getDimSize(0);
- W = indicesShape.getDimSize(1);
- if (N == ShapedType::kDynamic)
- N = indicesN;
- else if (indicesN != ShapedType::kDynamic && N != indicesN)
- return emitOpError() << "requires indices dimension 0 to have size " << N
+ w = indicesShape.getDimSize(1);
+ if (n == ShapedType::kDynamic)
+ n = indicesN;
+ else if (indicesN != ShapedType::kDynamic && n != indicesN)
+ return emitOpError() << "requires indices dimension 0 to have size " << n
<< ", got " << indicesN;
}
if (outputShape.hasRank()) {
const int64_t outputN = outputShape.getDimSize(0);
const int64_t outputW = outputShape.getDimSize(1);
const int64_t outputC = outputShape.getDimSize(2);
- if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
- N != outputN)
- return emitOpError() << "requires output dimension 0 to have size " << N
+ if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+ n != outputN)
+ return emitOpError() << "requires output dimension 0 to have size " << n
<< ", got " << outputN;
- if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
- W != outputW)
- return emitOpError() << "requires output dimension 1 to have size " << W
+ if (w != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
+ w != outputW)
+ return emitOpError() << "requires output dimension 1 to have size " << w
<< ", got " << outputW;
- if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
- C != outputC)
- return emitOpError() << "requires output dimension 2 to have size " << C
+ if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+ c != outputC)
+ return emitOpError() << "requires output dimension 2 to have size " << c
<< ", got " << outputC;
}
return success();
@@ -3096,66 +3132,66 @@ LogicalResult tosa::ScatterOp::verify() {
const ShapeAdaptor inputShape(getInput().getType());
const ShapeAdaptor outputShape(getValuesOut().getType());
- int64_t N = ShapedType::kDynamic;
- int64_t K = ShapedType::kDynamic;
- int64_t W = ShapedType::kDynamic;
- int64_t C = ShapedType::kDynamic;
+ int64_t n = ShapedType::kDynamic;
+ int64_t k = ShapedType::kDynamic;
+ int64_t w = ShapedType::kDynamic;
+ int64_t c = ShapedType::kDynamic;
if (valuesInShape.hasRank()) {
- N = valuesInShape.getDimSize(0);
- K = valuesInShape.getDimSize(1);
- C = valuesInShape.getDimSize(2);
+ n = valuesInShape.getDimSize(0);
+ k = valuesInShape.getDimSize(1);
+ c = valuesInShape.getDimSize(2);
}
if (indicesShape.hasRank()) {
const int64_t indicesN = indicesShape.getDimSize(0);
- W = indicesShape.getDimSize(1);
- if (N == ShapedType::kDynamic)
- N = indicesN;
- else if (indicesN != ShapedType::kDynamic && N != indicesN)
- return emitOpError() << "requires indices dimension 0 to have size " << N
+ w = indicesShape.getDimSize(1);
+ if (n == ShapedType::kDynamic)
+ n = indicesN;
+ else if (indicesN != ShapedType::kDynamic && n != indicesN)
+ return emitOpError() << "requires indices dimension 0 to have size " << n
<< ", got " << indicesN;
}
if (inputShape.hasRank()) {
const int64_t inputN = inputShape.getDimSize(0);
const int64_t inputW = inputShape.getDimSize(1);
const int64_t inputC = inputShape.getDimSize(2);
- if (N == ShapedType::kDynamic)
- N = inputN;
- else if (inputN != ShapedType::kDynamic && N != inputN)
- return emitOpError() << "requires input dimension 0 to have size " << N
+ if (n == ShapedType::kDynamic)
+ n = inputN;
+ else if (inputN != ShapedType::kDynamic && n != inputN)
+ return emitOpError() << "requires input dimension 0 to have size " << n
<< ", got " << inputN;
- if (W == ShapedType::kDynamic)
- W = inputW;
- else if (inputW != ShapedType::kDynamic && W != inputW)
- return emitOpError() << "requires input dimension 1 to have size " << W
+ if (w == ShapedType::kDynamic)
+ w = inputW;
+ else if (inputW != ShapedType::kDynamic && w != inputW)
+ return emitOpError() << "requires input dimension 1 to have size " << w
<< ", got " << inputW;
- if (C == ShapedType::kDynamic)
- C = inputC;
- else if (inputC != ShapedType::kDynamic && C != inputC)
- return emitOpError() << "requires input dimension 2 to have size " << C
+ if (c == ShapedType::kDynamic)
+ c = inputC;
+ else if (inputC != ShapedType::kDynamic && c != inputC)
+ return emitOpError() << "requires input dimension 2 to have size " << c
<< ", got " << inputC;
}
if (outputShape.hasRank()) {
const int64_t outputN = outputShape.getDimSize(0);
const int64_t outputK = outputShape.getDimSize(1);
const int64_t outputC = outputShape.getDimSize(2);
- if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
- N != outputN)
+ if (n != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+ n != outputN)
return emitOpError() << "requires values_out dimension 0 to have size "
- << N << ", got " << outputN;
- if (K == ShapedType::kDynamic)
- K = outputK;
- else if (outputK != ShapedType::kDynamic && K != outputK)
+ << n << ", got " << outputN;
+ if (k == ShapedType::kDynamic)
+ k = outputK;
+ else if (outputK != ShapedType::kDynamic && k != outputK)
return emitOpError() << "requires values_out dimension 1 to have size "
- << K << ", got " << outputK;
- if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
- C != outputC)
+ << k << ", got " << outputK;
+ if (c != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+ c != outputC)
return emitOpError() << "requires values_out dimension 2 to have size "
- << C << ", got " << outputC;
+ << c << ", got " << outputC;
}
- if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
- return emitOpError() << "requires dimensions K >= W, got K=" << K
- << " and W=" << W;
+ if (k != ShapedType::kDynamic && w != ShapedType::kDynamic && !(k >= w))
+ return emitOpError() << "requires dimensions K >= W, got K=" << k
+ << " and W=" << w;
return success();
}
@@ -3383,7 +3419,7 @@ static LogicalResult poolingInferReturnTypes(
outputShape.resize(4, ShapedType::kDynamic);
// We only know the rank if the input type is unranked.
- if (!inputShape) {
+ if (!inputShape.hasRank()) {
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
return success();
}
@@ -3474,6 +3510,234 @@ LogicalResult Conv2DOp::verify() {
return success();
}
+/// Infers the rank-4 result shape of a block-scaled 2D convolution.
+///
+/// Dimension 0 (batch) comes from the input operands and dimension 3 (output
+/// channels) from the weight operands or, failing that, the bias. Dimensions
+/// 1 and 2 (height/width) are only computed when pad, stride and dilation
+/// are constant shapes and the corresponding input and kernel sizes are
+/// static. Anything that cannot be determined stays dynamic. Always returns
+/// success.
+LogicalResult Conv2DBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    Conv2DBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  SmallVector<int64_t, 4> outShape(4, ShapedType::kDynamic);
+
+  int64_t inputHeight = ShapedType::kDynamic;
+  int64_t inputWidth = ShapedType::kDynamic;
+  int64_t weightHeight = ShapedType::kDynamic;
+  int64_t weightWidth = ShapedType::kDynamic;
+
+  // Takes a dimension from `shape` only while `dim` is still unknown; the
+  // shape is not queried at all when `dim` is already static.
+  const auto fillIfDynamic = [](int64_t &dim, const ShapeAdaptor &shape,
+                                int index) {
+    if (ShapedType::isDynamic(dim))
+      dim = shape.getDimSize(index);
+  };
+
+  // The input data gives batch, height and width; the input scale is used as
+  // a fallback for whatever the data left dynamic.
+  const ShapeAdaptor inputDataShape(adaptor.getInputData().getType());
+  if (inputDataShape.hasRank()) {
+    outShape[0] = inputDataShape.getDimSize(0);
+    inputHeight = inputDataShape.getDimSize(1);
+    inputWidth = inputDataShape.getDimSize(2);
+  }
+  const ShapeAdaptor inputScaleShape(adaptor.getInputScale().getType());
+  if (inputScaleShape.hasRank()) {
+    fillIfDynamic(outShape[0], inputScaleShape, 0);
+    fillIfDynamic(inputHeight, inputScaleShape, 1);
+    fillIfDynamic(inputWidth, inputScaleShape, 2);
+  }
+
+  // The weight data gives output channels and the filter height/width; the
+  // weight scale is the fallback.
+  const ShapeAdaptor weightDataShape(adaptor.getWeightData().getType());
+  if (weightDataShape.hasRank()) {
+    outShape[3] = weightDataShape.getDimSize(0);
+    weightHeight = weightDataShape.getDimSize(1);
+    weightWidth = weightDataShape.getDimSize(2);
+  }
+  const ShapeAdaptor weightScaleShape(adaptor.getWeightScale().getType());
+  if (weightScaleShape.hasRank()) {
+    fillIfDynamic(outShape[3], weightScaleShape, 0);
+    fillIfDynamic(weightHeight, weightScaleShape, 1);
+    fillIfDynamic(weightWidth, weightScaleShape, 2);
+  }
+
+  // The bias can also describe the output channels, except when its size is
+  // 1, in which case it may be broadcast and says nothing about them.
+  const ShapeAdaptor biasShape(adaptor.getBias().getType());
+  if (biasShape.hasRank()) {
+    const int64_t biasSize = biasShape.getDimSize(0);
+    if (biasSize != 1 && ShapedType::isDynamic(outShape[3]))
+      outShape[3] = biasSize;
+  }
+
+  // Spatial sizes need constant pad/stride/dilation; without all three only
+  // the batch/channel information gathered so far is reported.
+  SmallVector<int64_t> padValues;
+  SmallVector<int64_t> strideValues;
+  SmallVector<int64_t> dilationValues;
+  const bool haveConstAttrs =
+      tosa::getConstShapeValues(adaptor.getPad().getDefiningOp(), padValues) &&
+      tosa::getConstShapeValues(adaptor.getStride().getDefiningOp(),
+                                strideValues) &&
+      tosa::getConstShapeValues(adaptor.getDilation().getDefiningOp(),
+                                dilationValues);
+  if (!haveConstAttrs) {
+    inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+    return success();
+  }
+
+  // Output size along one spatial axis for the given input/kernel extents.
+  const auto convOutputDim = [](int64_t inputDim, int64_t kernelDim,
+                                int64_t padBefore, int64_t padAfter,
+                                int64_t dilation, int64_t stride) -> int64_t {
+    const int64_t paddedInput = inputDim + padBefore + padAfter;
+    const int64_t dilatedKernel = (kernelDim - 1) * dilation + 1;
+    const int64_t unstrided = paddedInput - dilatedKernel + 1;
+    return (unstrided - 1) / stride + 1;
+  };
+
+  if (ShapedType::isStatic(inputHeight) && ShapedType::isStatic(weightHeight))
+    outShape[1] = convOutputDim(inputHeight, weightHeight, padValues[0],
+                                padValues[1], dilationValues[0],
+                                strideValues[0]);
+
+  if (ShapedType::isStatic(inputWidth) && ShapedType::isStatic(weightWidth))
+    outShape[2] = convOutputDim(inputWidth, weightWidth, padValues[2],
+                                padValues[3], dilationValues[1],
+                                strideValues[1]);
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
+/// Verifies a block-scaled 2D convolution.
+///
+/// Checks, in order:
+///  * element types agree pairwise (input_data/weight_data,
+///    input_scale/weight_scale, bias/output);
+///  * the data and scale operands agree on every dimension they share;
+///  * IC is a multiple of the block size, and the last dimension of the
+///    scale operands equals IC / block_size;
+///  * constant pad values are >= 0, constant stride/dilation values are >= 1;
+///  * the output height/width match the sizes computed from the operands
+///    (only when pad, stride and dilation are all constant);
+///  * the bias has either the output channel count or size 1.
+/// Dynamic dimensions and unranked operands are skipped rather than
+/// rejected.
+LogicalResult Conv2DBlockScaledOp::verify() {
+  if (failed(verifySameElementTypes(*this, getInputData().getType(),
+                                    getWeightData().getType(), "input_data",
+                                    "weight_data")) ||
+      failed(verifySameElementTypes(*this, getInputScale().getType(),
+                                    getWeightScale().getType(), "input_scale",
+                                    "weight_scale")) ||
+      failed(verifySameElementTypes(*this, getBias().getType(),
+                                    getOutput().getType(), "bias", "output")))
+    return failure();
+
+  // Verify input shape compatibility
+  int64_t N = ShapedType::kDynamic;
+  int64_t IH = ShapedType::kDynamic;
+  int64_t IW = ShapedType::kDynamic;
+  int64_t IC = ShapedType::kDynamic;
+  int64_t multiplesOfIC = ShapedType::kDynamic;
+  int64_t OC = ShapedType::kDynamic;
+  int64_t KH = ShapedType::kDynamic;
+  int64_t KW = ShapedType::kDynamic;
+
+  const ShapeAdaptor inputDataShape(getInputData().getType());
+  if (inputDataShape.hasRank()) {
+    N = inputDataShape.getDimSize(0);
+    IH = inputDataShape.getDimSize(1);
+    IW = inputDataShape.getDimSize(2);
+    IC = inputDataShape.getDimSize(3);
+  }
+
+  const ShapeAdaptor inputScaleShape(getInputScale().getType());
+  if (inputScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, N, inputScaleShape.getDimSize(0),
+                                     "input_scale", "batch size")) ||
+        failed(tryUpdateDimOrFailure(*this, IH, inputScaleShape.getDimSize(1),
+                                     "input_scale", "input height")) ||
+        failed(tryUpdateDimOrFailure(*this, IW, inputScaleShape.getDimSize(2),
+                                     "input_scale", "input width")))
+      return failure();
+    // The last scale dimension counts blocks of input channels, not channels.
+    multiplesOfIC = inputScaleShape.getDimSize(3);
+  }
+
+  const ShapeAdaptor weightDataShape(getWeightData().getType());
+  if (weightDataShape.hasRank()) {
+    OC = weightDataShape.getDimSize(0);
+    KH = weightDataShape.getDimSize(1);
+    KW = weightDataShape.getDimSize(2);
+    if (failed(tryUpdateDimOrFailure(*this, IC, weightDataShape.getDimSize(3),
+                                     "weight_data", "input channels")))
+      return failure();
+  }
+
+  const ShapeAdaptor weightScaleShape(getWeightScale().getType());
+  if (weightScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(*this, OC, weightScaleShape.getDimSize(0),
+                                     "weight_scale", "output channels")) ||
+        failed(tryUpdateDimOrFailure(*this, KH, weightScaleShape.getDimSize(1),
+                                     "weight_scale", "kernel height")) ||
+        failed(tryUpdateDimOrFailure(*this, KW, weightScaleShape.getDimSize(2),
+                                     "weight_scale", "kernel width")) ||
+        failed(tryUpdateDimOrFailure(*this, multiplesOfIC,
+                                     weightScaleShape.getDimSize(3),
+                                     "weight_scale", "input channel blocks")))
+      return failure();
+  }
+
+  // Verify IC is a multiple of block size
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (ShapedType::isStatic(IC) && IC % blockSize != 0)
+    return emitOpError("expect IC to be a multiple of block size, got IC=")
+           << IC << ", block_size=" << blockSize;
+
+  // Verify multiplesOfIC is IC / block size. The value being compared is
+  // dimension 3 of input_scale/weight_scale, so the diagnostic names that
+  // dimension.
+  // NOTE(review): diagnostics tests matching the earlier "dimension 2"
+  // wording need the same update.
+  if (ShapedType::isStatic(IC) && ShapedType::isStatic(multiplesOfIC) &&
+      multiplesOfIC != IC / blockSize)
+    return emitOpError(
+               "expect scale operands dimension 3 to equal IC/block_size (")
+           << IC << "/" << blockSize << "), got " << multiplesOfIC;
+
+  // Verify pad/stride/dilation values
+  SmallVector<int64_t> padValues;
+  if (tosa::getConstShapeValues(getPad().getDefiningOp(), padValues)) {
+    if (llvm::any_of(padValues, [](int64_t p) { return p < 0; }))
+      return emitOpError("expect all padding values to be >= 0, got ")
+             << padValues;
+  }
+
+  SmallVector<int64_t> strideValues;
+  if (tosa::getConstShapeValues(getStride().getDefiningOp(), strideValues)) {
+    if (llvm::any_of(strideValues, [](int64_t s) { return s < 1; }))
+      return emitOpError("expect all stride values to be >= 1, got ")
+             << strideValues;
+  }
+
+  SmallVector<int64_t> dilationValues;
+  if (tosa::getConstShapeValues(getDilation().getDefiningOp(),
+                                dilationValues)) {
+    if (llvm::any_of(dilationValues, [](int64_t d) { return d < 1; }))
+      return emitOpError("expect all dilation values to be >= 1, got ")
+             << dilationValues;
+  }
+
+  // Verify output shape compatibility. The checks below read pad[0..3],
+  // stride[0..1] and dilation[0..1], so they only run when each constant
+  // actually provides that many elements.
+  const ShapeAdaptor outputShape(getOutput().getType());
+  const bool haveConvParams = padValues.size() >= 4 &&
+                              strideValues.size() >= 2 &&
+                              dilationValues.size() >= 2;
+  if (haveConvParams && outputShape.hasRank()) {
+    if (failed(verifyConvOutputSize(*this, IH, KH, outputShape.getDimSize(1),
+                                    padValues[0], padValues[1], strideValues[0],
+                                    dilationValues[0], "height", "y", "top",
+                                    "bottom")) ||
+        failed(verifyConvOutputSize(*this, IW, KW, outputShape.getDimSize(2),
+                                    padValues[2], padValues[3], strideValues[1],
+                                    dilationValues[1], "width", "x", "left",
+                                    "right")))
+      return failure();
+  }
+
+  // Verify bias: it must match the output channel count or be of size 1.
+  // The comparison is only meaningful when both sizes are static.
+  const ShapeAdaptor biasShape(getBias().getType());
+  if (biasShape.hasRank() && outputShape.hasRank()) {
+    const int64_t biasChannels = biasShape.getDimSize(0);
+    const int64_t outputChannels =
+        outputShape.getDimSize(outputShape.getRank() - 1);
+    if (ShapedType::isStatic(biasChannels) &&
+        ShapedType::isStatic(outputChannels) &&
+        biasChannels != outputChannels && biasChannels != 1)
+      return emitOpError(
+                 "bias channels expected to be equal to output channels (")
+             << outputChannels << ") or 1, got " << biasChannels;
+  }
+
+  return success();
+}
+
LogicalResult Conv3DOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
Conv3DOp::Adaptor adaptor,
@@ -3622,10 +3886,10 @@ LogicalResult DepthwiseConv2DOp::inferReturnTypeComponents(
// Bias shape can describe the output channels.
ShapeAdaptor biasShape(adaptor.getBias().getType());
- if (biasShape.hasRank()) {
- outputShape[3] = ShapedType::isDynamic(outputShape[3])
- ? biasShape.getDimSize(0)
- : outputShape[3];
+ if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
+ int64_t bc = biasShape.getDimSize(0);
+ if (bc != ShapedType::kDynamic && bc != 1)
+ outputShape[3] = bc;
}
llvm::ArrayRef<int64_t> dilation = adaptor.getDilation();
@@ -3689,11 +3953,11 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
}
// Bias shape can describe the output channels.
- ShapeAdaptor biasShape(adaptor.getInput().getType());
- if (biasShape.hasRank()) {
- outputShape[3] = ShapedType::isDynamic(outputShape[3])
- ? biasShape.getDimSize(0)
- : outputShape[3];
+ ShapeAdaptor biasShape(adaptor.getBias().getType());
+ if (biasShape.hasRank() && ShapedType::isDynamic(outputShape[3])) {
+ int64_t bc = biasShape.getDimSize(0);
+ if (bc != ShapedType::kDynamic && bc != 1)
+ outputShape[3] = bc;
}
llvm::ArrayRef<int64_t> padding = adaptor.getOutPad();
@@ -3730,14 +3994,13 @@ LogicalResult TransposeConv2DOp::verify() {
<< strides << "]";
const auto checkPadAgainstKernelDim =
- [this](int64_t pad_value, int64_t kernel_dim_size,
- llvm::StringRef pad_name,
- llvm::StringRef kernel_dim_name) -> LogicalResult {
- if (pad_value <= -kernel_dim_size)
+ [this](int64_t padValue, int64_t kernelDimSize, llvm::StringRef padName,
+ llvm::StringRef kernelDimName) -> LogicalResult {
+ if (padValue <= -kernelDimSize)
return emitOpError("expected ")
- << pad_name << " > -" << kernel_dim_name
- << ", but got: " << pad_name << "=" << pad_value << " and "
- << kernel_dim_name << "=" << kernel_dim_size;
+ << padName << " > -" << kernelDimName << ", but got: " << padName
+ << "=" << padValue << " and " << kernelDimName << "="
+ << kernelDimSize;
return success();
};
@@ -3976,8 +4239,8 @@ LogicalResult CastFromBlockScaledOp::verify() {
const Type outputDataType = getResult().getType();
if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
return emitOpError() << "require compatible shapes for input_data ("
- << inputDataType << ") and "
- << "output_data (" << outputDataType << ")";
+ << inputDataType << ") and " << "output_data ("
+ << outputDataType << ")";
const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
@@ -4004,10 +4267,10 @@ LogicalResult CastFromBlockScaledOp::verify() {
failed(verifyCompatibleShape(
ArrayRef<int64_t>(inputDataDims).drop_back(1),
ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
- return emitOpError() << "require compatible shapes for input_data ("
- << inputDataType << ") and "
- << "input_scale (" << inputScaleType
- << ") except for the last dimension";
+ return emitOpError()
+ << "require compatible shapes for input_data (" << inputDataType
+ << ") and " << "input_scale (" << inputScaleType
+ << ") except for the last dimension";
const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
inputScaleDims.back()};
@@ -4052,8 +4315,8 @@ LogicalResult CastToBlockScaledOp::verify() {
const Type outputDataType = getResult(0).getType();
if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
return emitOpError() << "require compatible shapes for input_data ("
- << inputDataType << ") and "
- << "output_data (" << outputDataType << ")";
+ << inputDataType << ") and " << "output_data ("
+ << outputDataType << ")";
const unsigned int blockSize =
BlockSizeAttr::getBlockSizeValue(getBlockSize());
@@ -4082,8 +4345,8 @@ LogicalResult CastToBlockScaledOp::verify() {
ArrayRef<int64_t>(outputDataDims).drop_back(1),
ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
return emitOpError() << "require compatible shapes for output_data ("
- << outputDataType << ") and "
- << "output_scale (" << outputScaleType
+ << outputDataType << ") and " << "output_scale ("
+ << outputScaleType
<< ") except for the last dimension";
const int64_t outputDataLastDim = outputDataDims.back();
@@ -4259,9 +4522,9 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(parser.getCurrentLocation())
- << "expected as many input types as operands "
- << "(expected " << operands.size() << " got "
- << functionType.getNumInputs() << ")";
+ << "expected as many input types as operands " << "(expected "
+ << operands.size() << " got " << functionType.getNumInputs()
+ << ")";
}
// Resolve input operands.
@@ -4510,9 +4773,8 @@ ParseResult WhileOp::parse(OpAsmParser &parser, OperationState &result) {
if (functionType.getNumInputs() != operands.size()) {
return parser.emitError(typeLoc)
- << "expected as many input types as operands "
- << "(expected " << operands.size() << " got "
- << functionType.getNumInputs() << ")";
+ << "expected as many input types as operands " << "(expected "
+ << operands.size() << " got " << functionType.getNumInputs() << ")";
}
// Resolve input operands.
@@ -4592,24 +4854,9 @@ LogicalResult OpTrait::tosa::verifyTosaResolvableShapeOperands(Operation *op) {
return success();
}
-LogicalResult OpTrait::tosa::verifyTosaShapeOperator(Operation *op) {
- for (auto type : op->getOperandTypes()) {
- if (!mlir::isa<mlir::tosa::shapeType>(type)) {
- return op->emitOpError("must have operands with tosa shape type");
- }
- }
- for (auto type : op->getResultTypes()) {
- if (!mlir::isa<mlir::tosa::shapeType>(type)) {
- return op->emitOpError("must have result with tosa shape type");
- }
- }
- return success();
-}
-
LogicalResult
OpTrait::tosa::verifyTosaShapeOperatorWithSameRanks(Operation *op) {
- if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)) ||
- failed(verifyTosaShapeOperator(op)))
+ if (failed(OpTrait::impl::verifyAtLeastNOperands(op, 1)))
return failure();
// delegate function that returns rank of shape type
@@ -4653,6 +4900,94 @@ LogicalResult tosa::ConstShapeOp::verify() {
return success();
}
+/// Verifier for tosa::DimOp.
+///
+/// Checks that the result shape type has rank 1 and, when the input type is
+/// ranked, that the `axis` attribute lies in [0, input rank). Returns
+/// success() if both checks pass, otherwise the emitted op error.
+LogicalResult tosa::DimOp::verify() {
+  const auto resultShapeType = cast<tosa::shapeType>(getResult().getType());
+  if (resultShapeType.getRank() != 1)
+    return emitOpError("expect output shape type to contain one element, got ")
+           << resultShapeType;
+
+  // Without a ranked input there is no bound to check the axis against.
+  const ShapeAdaptor inputShape(getInput1().getType());
+  if (!inputShape.hasRank())
+    return success();
+
+  const int64_t rank = inputShape.getRank();
+  const int64_t axisValue = getAxisAttr().getInt();
+  const bool axisInRange = axisValue >= 0 && axisValue < rank;
+  if (!axisInRange)
+    return emitOpError("expect axis to be in the range [0, ")
+           << rank << "), got " << axisValue;
+
+  return success();
+}
+
+/// Verifier for tosa::ConcatShapeOp.
+///
+/// Requires that:
+///   - at least one input shape is provided;
+///   - no input shape has rank 0;
+///   - the result shape rank equals the sum of the input shape ranks.
+/// Returns success() if all checks pass, otherwise the emitted op error.
+LogicalResult tosa::ConcatShapeOp::verify() {
+  const tosa::shapeType outShapeType =
+      cast<tosa::shapeType>(getResult().getType());
+  const int64_t outputRank = outShapeType.getRank();
+  const Operation::operand_range inputList = getInput();
+
+  if (inputList.empty())
+    return emitOpError("requires at least one input shape");
+
+  if (llvm::any_of(inputList, [](Value v) {
+        return cast<tosa::shapeType>(v.getType()).getRank() == 0;
+      }))
+    return emitOpError("requires all inputs shapes have a rank greater than 0");
+
+  // Seed the fold with an int64_t: the accumulator type is deduced from the
+  // initial value, so a plain `0` would narrow every partial sum to `int`.
+  const int64_t inputsRank = llvm::accumulate(
+      inputList, int64_t{0}, [](int64_t acc, const Value &input) {
+        const tosa::shapeType inShapeType =
+            cast<tosa::shapeType>(input.getType());
+        return acc + inShapeType.getRank();
+      });
+  if (outputRank != inputsRank)
+    return emitOpError("requires output shape rank to be equal to the sum of "
+                       "the input shape ranks (")
+           << inputsRank << "), got " << outputRank;
+
+  return success();
+}
+
+/// Verifier for tosa::SliceShapeOp.
+///
+/// Only `start` / `size` operands that match as constants are inspected:
+///   - a constant `start` must be non-negative;
+///   - a constant `size` must be positive and equal to the result shape rank;
+///   - when both are constant, `start + size` must not exceed the input shape
+///     rank.
+/// A non-constant `size` ends verification after the `start` check; a
+/// non-constant `start` ends it after the `size` checks.
+LogicalResult tosa::SliceShapeOp::verify() {
+  std::optional<int32_t> start;
+  DenseIntElementsAttr startAttr;
+  if (matchPattern(getStart(), m_Constant(&startAttr)))
+    start = startAttr.getValues<int32_t>()[0];
+  if (start && *start < 0)
+    return emitOpError("expected non-negative start index, got ") << *start;
+
+  std::optional<int32_t> size;
+  DenseIntElementsAttr sizeAttr;
+  if (matchPattern(getSize(), m_Constant(&sizeAttr)))
+    size = sizeAttr.getValues<int32_t>()[0];
+  if (size && *size <= 0)
+    return emitOpError("expected positive size, got ") << *size;
+
+  if (!size)
+    return success();
+
+  const tosa::shapeType outShapeType =
+      cast<tosa::shapeType>(getResult().getType());
+  const int64_t outputRank = outShapeType.getRank();
+  if (outputRank != *size)
+    return emitOpError(
+               "expected output type size to be equal to size attribute, got ")
+           << outputRank << " vs " << *size;
+
+  if (!start)
+    return success();
+
+  const tosa::shapeType inShapeType =
+      cast<tosa::shapeType>(getInput().getType());
+  const int64_t inputRank = inShapeType.getRank();
+  // Widen before adding: summing two int32_t values in 32-bit arithmetic
+  // overflows (undefined behaviour) for large constants.
+  const int64_t sliceSize =
+      static_cast<int64_t>(*start) + static_cast<int64_t>(*size);
+  if (sliceSize > inputRank)
+    return emitOpError("expected start + size to be less than or equal to "
+                       "input shape rank (")
+           << inputRank << "), got " << sliceSize;
+
+  return success();
+}
+
//===----------------------------------------------------------------------===//
// TOSA Attribute Definitions.
//===----------------------------------------------------------------------===//