Diffstat (limited to 'mlir/lib/Dialect/Tosa/IR/TosaOps.cpp')
-rw-r--r--  mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 377
1 file changed, 358 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 00f84bc..0aff67f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
}
}
+ // special handling: block_size accepts a *bare* BlockSize enum keyword
+ if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
+ if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
+ auto sym = symbolizeBlockSize(kw);
+ if (!sym)
+ return parser.emitError(parser.getCurrentLocation())
+ << "invalid block_size value: " << kw;
+ auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
+ outAttrs.push_back(NamedAttribute(name, attr));
+ return success();
+ }
+ }
+
// Default path: parse any normal attribute literal, including fully qualified
// enum keyword
Attribute attr;
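With this hook, block_size can be written as a bare keyword in the custom assembly form, mirroring the NaN-propagation handling above it; the printer change below emits the bare value again, so the form round-trips. A minimal sketch of the two accepted spellings (op name, operands, and the BLOCK_SIZE_32 enumerator are illustrative assumptions, not taken from this patch):

  %d, %s = tosa.cast_to_block_scaled %in {block_size = BLOCK_SIZE_32} : ...                   // bare keyword, special case above
  %d, %s = tosa.cast_to_block_scaled %in {block_size = #tosa<block_size BLOCK_SIZE_32>} : ... // qualified attribute, default path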
@@ -357,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
result.operands)))
return failure();
- result.addTypes(fnTy.getResult(0));
+ result.addTypes(fnTy.getResults());
result.addAttributes(attrs);
return success();
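This one-line fix is needed by the new ops below: the parser previously registered only the first result type from the parsed function type, silently dropping the second result of a two-result op such as cast_to_block_scaled (output_data plus output_scale). getResults() registers all of them. An illustrative two-result signature (element types assumed, not from this patch):

  (tensor<4x64xf32>) -> (tensor<4x64xf4E2M1FN>, tensor<4x2xf8E8M0FNU>)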
@@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
} else if (auto nanPropagationModeAttr =
dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
parser << nanPropagationModeAttr.getValue();
+ } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
+ parser << blockSizeAttr.getValue();
} else {
parser.printAttribute(attr);
}
@@ -508,6 +523,33 @@ void ReduceMinOp::print(OpAsmPrinter &parser) {
printWithNanPropagationHandling(parser, *this);
}
+ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
+ OperationState &result) {
+ return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
+ printWithEnumHandling(parser, *this);
+}
+
//===----------------------------------------------------------------------===//
// Tosa utilities.
//===----------------------------------------------------------------------===//
@@ -933,32 +975,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
-// verify that inType and outType have same element types
+// verify that aType and bType have the same element types
template <typename T>
-static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
- auto inputType = llvm::dyn_cast<TensorType>(inType);
- auto outputType = llvm::dyn_cast<TensorType>(outType);
- if (!inputType) {
- op.emitOpError("expect shaped tensor for input, got ") << inType;
+static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
+ StringRef aName = "input",
+ StringRef bName = "output") {
+ auto aTType = llvm::dyn_cast<TensorType>(aType);
+ auto bTType = llvm::dyn_cast<TensorType>(bType);
+ if (!aTType) {
+ op.emitOpError("expect shaped tensor for") << aName << ", got " << aType;
return failure();
}
- if (!outputType) {
- op.emitOpError("expect shaped tensor for output, got ") << outType;
+ if (!bTType) {
+ op.emitOpError("expect shaped tensor for") << bName << ", got" << bType;
return failure();
}
- auto inputElementType = inputType.getElementType();
- auto outputElementType = outputType.getElementType();
- auto inputQuantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
- auto outputQuantType =
- llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
- if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
- (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
- inputElementType != outputElementType) {
+ auto aElementType = aTType.getElementType();
+ auto bElementType = bTType.getElementType();
+ auto aQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
+ auto bQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
+ if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
+ (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
+ aElementType != bElementType) {
// only check if both element types are int/index/float/UniformQuantized
// eg, not sure how to check quant::QuantizedType
// this happens in test_conv2d_q_grouped_convolution in
// tfl-to-tosa-pipeline.mlir
- op.emitOpError("expect input and output to have same element type, got ")
- << inputElementType << " and " << outputElementType;
+ op.emitOpError("expect ")
+ << aName << " and " << bName << " to have same element type, got "
+ << aElementType << " and " << bElementType;
return failure();
}
return success();
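With the operand names parameterized, diagnostics name the actual operands being compared. An illustrative mismatch for the matmul verifier below (element types assumed for the example) would read:

  'tosa.matmul_t_block_scaled' op expect A_data and B_data to have same element type, got 'f8E4M3FN' and 'f8E5M2'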
@@ -1846,6 +1891,161 @@ LogicalResult MatMulOp::verify() {
return success();
}
+LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ MatmulTBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
+
+ const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
+ if (aDataShape.hasRank()) {
+ outShape[0] = aDataShape.getDimSize(0);
+ outShape[1] = aDataShape.getDimSize(1);
+ }
+
+ const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
+ if (aScaleShape.hasRank()) {
+ outShape[0] = ShapedType::isDynamic(outShape[0]) ? aScaleShape.getDimSize(0)
+ : outShape[0];
+ outShape[1] = ShapedType::isDynamic(outShape[1]) ? aScaleShape.getDimSize(1)
+ : outShape[1];
+ }
+
+ // If B batch size is 1, it is broadcast across A's batch size
+ const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
+ if (bDataShape.hasRank()) {
+ const int64_t bDataBatchSize = bDataShape.getDimSize(0);
+ if (bDataBatchSize != 1)
+ outShape[0] =
+ ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
+ outShape[2] = bDataShape.getDimSize(1);
+ }
+
+ const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
+ if (bScaleShape.hasRank()) {
+ const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
+ if (bScaleBatchSize != 1)
+ outShape[0] =
+ ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
+ outShape[2] = ShapedType::isDynamic(outShape[2]) ? bScaleShape.getDimSize(1)
+ : outShape[2];
+ }
+
+ inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+ return success();
+}
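A worked example of these rules, with illustrative shapes: A_data tensor<4x8x64x...> fixes the first two output dims to 4x8; B_data tensor<1x16x64x...> contributes W=16, and its batch of 1 is skipped so A's batch survives; the inferred result shape is 4x8x16. Had both A operands been unranked and B_data's batch a static 3, the 3 would have filled the batch dim instead.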
+
+LogicalResult MatmulTBlockScaledOp::verify() {
+ // Verify same input data types
+ const Type aDataType = getAData().getType();
+ const Type bDataType = getBData().getType();
+ if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
+ "B_data")))
+ return failure();
+
+ auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
+ const StringRef operandName,
+ const StringRef dimName) -> LogicalResult {
+ if (ShapedType::isDynamic(currDim)) {
+ currDim = newDim;
+ return success();
+ } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+ return emitOpError("expected ")
+ << dimName << " of " << operandName << " to match size " << currDim
+ << ", got " << newDim;
+ }
+ return success();
+ };
+
+ // Verify input shape compatibility
+ int64_t N = ShapedType::kDynamic;
+ int64_t D = ShapedType::kDynamic;
+ int64_t H = ShapedType::kDynamic;
+ int64_t W = ShapedType::kDynamic;
+ int64_t C = ShapedType::kDynamic;
+ int64_t multiplesOfC = ShapedType::kDynamic;
+
+ const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
+ if (aDataShape.hasRank()) {
+ N = aDataShape.getDimSize(0);
+ H = aDataShape.getDimSize(1);
+ C = aDataShape.getDimSize(2);
+ }
+
+ const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
+ if (aScaleShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
+ "height")))
+ return failure();
+ multiplesOfC = aScaleShape.getDimSize(2);
+ }
+
+ const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
+ if (bDataShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
+ "channels")))
+ return failure();
+ W = bDataShape.getDimSize(1);
+ }
+
+ const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
+ if (bScaleShape.hasRank()) {
+ if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
+ "batch")) ||
+ failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
+ "width")) ||
+ failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
+ "b_scale", "C/block_size")))
+ return failure();
+ }
+
+ // Verify batch size is broadcast compatible
+ if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
+ return emitOpError("expect B matrix batch size to be broadcast compatible "
+ "with A, got D=")
+ << D << " vs N=" << N;
+
+ // Verify C is a multiple of block size
+ const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ if (ShapedType::isStatic(C) && C % blockSize != 0)
+ return emitOpError("expect C to be a multiple of block size, got C=")
+ << C << ", block_size=" << blockSize;
+
+ // Verify multiplesOfC is C / block size
+ if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
+ multiplesOfC != C / blockSize)
+ return emitOpError(
+ "expect scale operands dimension 2 to equal C/block_size (")
+ << C << "/" << blockSize << ")"
+ << ", got " << multiplesOfC;
+
+ // Verify output shape
+ N = ShapedType::isDynamic(N) ? D : N;
+ const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
+ const auto outputType = cast<ShapedType>(getResult().getType());
+ if (outputType.hasRank() &&
+ failed(
+ verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
+ InFlightDiagnostic opError = emitOpError("expected output shape ");
+ auto stringifyDim = [&](int64_t d) {
+ if (ShapedType::isDynamic(d))
+ opError << "?";
+ else
+ opError << d;
+ };
+ llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
+ opError << " to be compatible with expected output shape ";
+ llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
+ return opError;
+ }
+
+ return success();
+}
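Concretely, with illustrative numbers: for block_size=32, A_data tensor<4x8x64x...> and B_data tensor<1x16x64x...> pass when both scale operands have last dimension 64/32 = 2 (a_scale 4x8x2, b_scale 1x16x2). C=48 is rejected since 48 % 32 != 0, and a static B batch of D=3 against N=4 fails the broadcast check, while D=1 is accepted.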
+
LogicalResult tosa::PadOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
PadOp::Adaptor adaptor,
@@ -3762,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
return success();
}
+LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ CastFromBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ return success();
+}
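Unlike cast_to_block_scaled below, this op has a single result whose shape matches input_data exactly, so inference simply forwards the input's shape components, ranked or not.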
+
+LogicalResult CastFromBlockScaledOp::verify() {
+ const Type inputDataType = getInputData().getType();
+ const Type outputDataType = getResult().getType();
+ if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "output_data (" << outputDataType << ")";
+
+ const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+
+ if (inputDataShape.hasRank()) {
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ const int64_t inputDataLastDim =
+ inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+ if (ShapedType::isStatic(inputDataLastDim) &&
+ inputDataLastDim % blockSize != 0)
+ return emitOpError() << "expect last dimension of input_data ("
+ << inputDataLastDim
+ << ") to be divisible by block_size (" << blockSize
+ << ")";
+
+ const Type inputScaleType = getInputScale().getType();
+ const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
+
+ if (inputScaleShape.hasRank()) {
+ SmallVector<int64_t> inputDataDims, inputScaleDims;
+ inputDataShape.getDims(inputDataDims);
+ inputScaleShape.getDims(inputScaleDims);
+
+ if (inputDataDims.size() != inputScaleDims.size() ||
+ failed(verifyCompatibleShape(
+ ArrayRef<int64_t>(inputDataDims).drop_back(1),
+ ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "input_scale (" << inputScaleType
+ << ") except for the last dimension";
+
+ const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
+ inputScaleDims.back()};
+ if (ShapedType::isStatic(inputDataLastDim) &&
+ failed(verifyCompatibleDims(dimsToCheck)))
+ return emitOpError()
+ << "expect last dimension of input_scale ("
+ << inputScaleDims.back()
+ << ") to be equal to last dimension of input_data / block_size ("
+ << inputDataDims.back() / blockSize << ")";
+ }
+ }
+
+ return success();
+}
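For example, with illustrative shapes: for block_size=32, input_data tensor<4x64x...> pairs with input_scale tensor<4x2x...>: same rank, compatible leading dims, and a scale last dimension of 64/32 = 2. A static input_data last dimension of 50 is rejected at the divisibility check before the scale shape is consulted.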
+
+LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
+ MLIRContext *context, ::std::optional<Location> location,
+ CastToBlockScaledOp::Adaptor adaptor,
+ SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+ const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+ inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+ if (!inputShape.hasRank())
+ return success();
+
+ // Calculate output_scale shape if ranked input provided
+ SmallVector<int64_t> outputScaleShape;
+ inputShape.getDims(outputScaleShape);
+ const int64_t lastDimLoc = inputShape.getRank() - 1;
+ const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
+ if (ShapedType::isStatic(lastDimSize)) {
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+ outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
+ }
+ inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
+ return success();
+}
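E.g., with illustrative shapes: input_data tensor<4x96xf32> and block_size=32 infer output_data of shape 4x96 and output_scale of shape 4x3 (96/32 = 3); a dynamic last input dimension leaves the scale's last dimension dynamic as well.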
+
+LogicalResult CastToBlockScaledOp::verify() {
+ const Type inputDataType = getInputData().getType();
+ const Type outputDataType = getResult(0).getType();
+ if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+ return emitOpError() << "require compatible shapes for input_data ("
+ << inputDataType << ") and "
+ << "output_data (" << outputDataType << ")";
+
+ const unsigned int blockSize =
+ BlockSizeAttr::getBlockSizeValue(getBlockSize());
+ const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+ if (inputDataShape.hasRank()) {
+ const int64_t inputDataLastDim =
+ inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+ if (ShapedType::isStatic(inputDataLastDim) &&
+ inputDataLastDim % blockSize != 0)
+ return emitOpError() << "expect last dimension of input_data ("
+ << inputDataLastDim
+ << ") to be divisible by block_size (" << blockSize
+ << ")";
+ }
+
+ const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
+ const Type outputScaleType = getResult(1).getType();
+ const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
+ if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
+ SmallVector<int64_t> outputDataDims, outputScaleDims;
+ outputDataShape.getDims(outputDataDims);
+ outputScaleShape.getDims(outputScaleDims);
+
+ if (outputDataDims.size() != outputScaleDims.size() ||
+ failed(verifyCompatibleShape(
+ ArrayRef<int64_t>(outputDataDims).drop_back(1),
+ ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
+ return emitOpError() << "require compatible shapes for output_data ("
+ << outputDataType << ") and "
+ << "output_scale (" << outputScaleType
+ << ") except for the last dimension";
+
+ const int64_t outputDataLastDim = outputDataDims.back();
+ const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
+ outputScaleDims.back()};
+ if (ShapedType::isStatic(outputDataLastDim) &&
+ failed(verifyCompatibleDims(dimsToCheck)))
+ return emitOpError()
+ << "expect last dimension of output_scale ("
+ << outputScaleDims.back()
+ << ") to be equal to last dimension of output_data / block_size ("
+ << outputDataDims.back() / blockSize << ")";
+ }
+
+ return success();
+}
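This mirrors the input-side check in CastFromBlockScaledOp: with block_size=32 and output_data tensor<4x96x...> (illustrative), output_scale must have last dimension 96/32 = 3; any other static value trips the final emitOpError.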
+
LogicalResult IfOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
IfOp::Adaptor adaptor,