Diffstat (limited to 'mlir/lib/Dialect/Tosa/IR/TosaOps.cpp')
-rw-r--r--   mlir/lib/Dialect/Tosa/IR/TosaOps.cpp   377
1 file changed, 358 insertions(+), 19 deletions(-)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 00f84bc..0aff67f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -321,6 +321,19 @@ ParseResult parseAttrEntryWithEnumHandling(OpAsmParser &parser,
     }
   }
 
+  // Special handling: block_size accepts a *bare* BlockSize enum keyword
+  if constexpr (std::is_same_v<EnumType, tosa::BlockSize>) {
+    if (name == "block_size" && succeeded(parser.parseOptionalKeyword(&kw))) {
+      auto sym = symbolizeBlockSize(kw);
+      if (!sym)
+        return parser.emitError(parser.getCurrentLocation())
+               << "invalid block_size value: " << kw;
+      auto attr = BlockSizeAttr::get(parser.getContext(), sym.value());
+      outAttrs.push_back(NamedAttribute(name, attr));
+      return success();
+    }
+  }
+
   // Default path: parse any normal attribute literal, including fully
   // qualified enum keyword
   Attribute attr;
@@ -357,7 +370,7 @@ ParseResult parseWithEnumHandling(OpAsmParser &parser, OperationState &result) {
                              result.operands)))
     return failure();
 
-  result.addTypes(fnTy.getResult(0));
+  result.addTypes(fnTy.getResults());
   result.addAttributes(attrs);
 
   return success();
@@ -373,6 +386,8 @@ void printNamedAttr(OpAsmPrinter &parser, const NamedAttribute namedAttr) {
   } else if (auto nanPropagationModeAttr =
                  dyn_cast<tosa::NanPropagationModeAttr>(attr)) {
     parser << nanPropagationModeAttr.getValue();
+  } else if (auto blockSizeAttr = dyn_cast<tosa::BlockSizeAttr>(attr)) {
+    parser << blockSizeAttr.getValue();
   } else {
     parser.printAttribute(attr);
   }
@@ -508,6 +523,33 @@ void ReduceMinOp::print(OpAsmPrinter &parser) {
   printWithNanPropagationHandling(parser, *this);
 }
 
+ParseResult MatmulTBlockScaledOp::parse(OpAsmParser &parser,
+                                        OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void MatmulTBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastFromBlockScaledOp::parse(OpAsmParser &parser,
+                                         OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastFromBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
+ParseResult CastToBlockScaledOp::parse(OpAsmParser &parser,
+                                       OperationState &result) {
+  return parseWithEnumHandling<tosa::BlockSize>(parser, result);
+}
+
+void CastToBlockScaledOp::print(OpAsmPrinter &parser) {
+  printWithEnumHandling(parser, *this);
+}
+
 //===----------------------------------------------------------------------===//
 // Tosa utilities.
 //===----------------------------------------------------------------------===//
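[Editor's note] The effect of the special case above is that the assembly can carry a bare enum keyword instead of a fully qualified attribute literal. A minimal sketch of the intended round-trip, assuming the mnemonic tosa.cast_to_block_scaled, a BLOCK_SIZE_32 enumerator, and FP8 data / E8M0 scale element types; none of these spellings are confirmed by this diff:

// Hypothetical assembly; the bare keyword is consumed by
// parseAttrEntryWithEnumHandling and re-emitted by printNamedAttr.
%data, %scale = tosa.cast_to_block_scaled %in {block_size = BLOCK_SIZE_32} :
    (tensor<4x64xf32>) -> (tensor<4x64xf8E4M3FN>, tensor<4x2xf8E8M0FNU>)
// Without the special case, the attribute would have to be written as a
// fully qualified literal along the lines of #tosa<block_size BLOCK_SIZE_32>.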
@@ -933,32 +975,35 @@ static LogicalResult verifyVariableOpErrorIf(T op, Type type, StringRef name) {
 
 // verify that inType and outType have same element types
 template <typename T>
-static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
-  auto inputType = llvm::dyn_cast<TensorType>(inType);
-  auto outputType = llvm::dyn_cast<TensorType>(outType);
-  if (!inputType) {
-    op.emitOpError("expect shaped tensor for input, got ") << inType;
+static LogicalResult verifySameElementTypes(T op, Type aType, Type bType,
+                                            StringRef aName = "input",
+                                            StringRef bName = "output") {
+  auto aTType = llvm::dyn_cast<TensorType>(aType);
+  auto bTType = llvm::dyn_cast<TensorType>(bType);
+  if (!aTType) {
+    op.emitOpError("expect shaped tensor for ") << aName << ", got " << aType;
     return failure();
   }
-  if (!outputType) {
-    op.emitOpError("expect shaped tensor for output, got ") << outType;
+  if (!bTType) {
+    op.emitOpError("expect shaped tensor for ") << bName << ", got " << bType;
     return failure();
   }
-  auto inputElementType = inputType.getElementType();
-  auto outputElementType = outputType.getElementType();
-  auto inputQuantType =
-      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
-  auto outputQuantType =
-      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
-  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
-      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
-      inputElementType != outputElementType) {
+  auto aElementType = aTType.getElementType();
+  auto bElementType = bTType.getElementType();
+  auto aQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(aElementType);
+  auto bQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(bElementType);
+  if ((aElementType.isIntOrIndexOrFloat() || aQuantType) &&
+      (bElementType.isIntOrIndexOrFloat() || bQuantType) &&
+      aElementType != bElementType) {
     // only check if both element types are int/index/float/UniformQuantized
     // eg, not sure how to check quant::QuantizedType
     // this happens in test_conv2d_q_grouped_convolution in
     // tfl-to-tosa-pipeline.mlir
-    op.emitOpError("expect input and output to have same element type, got ")
-        << inputElementType << " and " << outputElementType;
+    op.emitOpError("expect ")
+        << aName << " and " << bName << " to have same element type, got "
+        << aElementType << " and " << bElementType;
     return failure();
   }
   return success();
 }
@@ -1846,6 +1891,161 @@ LogicalResult MatMulOp::verify() {
   return success();
 }
 
+LogicalResult tosa::MatmulTBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    MatmulTBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  SmallVector<int64_t, 3> outShape(3, ShapedType::kDynamic);
+
+  const auto aDataShape = cast<ShapedType>(adaptor.getAData().getType());
+  if (aDataShape.hasRank()) {
+    outShape[0] = aDataShape.getDimSize(0);
+    outShape[1] = aDataShape.getDimSize(1);
+  }
+
+  const auto aScaleShape = cast<ShapedType>(adaptor.getAScale().getType());
+  if (aScaleShape.hasRank()) {
+    outShape[0] = ShapedType::isDynamic(outShape[0])
+                      ? aScaleShape.getDimSize(0)
+                      : outShape[0];
+    outShape[1] = ShapedType::isDynamic(outShape[1])
+                      ? aScaleShape.getDimSize(1)
+                      : outShape[1];
+  }
+
+  // If B batch size is 1, it is broadcast across A's batch size
+  const auto bDataShape = cast<ShapedType>(adaptor.getBData().getType());
+  if (bDataShape.hasRank()) {
+    const int64_t bDataBatchSize = bDataShape.getDimSize(0);
+    if (bDataBatchSize != 1)
+      outShape[0] =
+          ShapedType::isDynamic(outShape[0]) ? bDataBatchSize : outShape[0];
+    outShape[2] = bDataShape.getDimSize(1);
+  }
+
+  const auto bScaleShape = cast<ShapedType>(adaptor.getBScale().getType());
+  if (bScaleShape.hasRank()) {
+    const int64_t bScaleBatchSize = bScaleShape.getDimSize(0);
+    if (bScaleBatchSize != 1)
+      outShape[0] =
+          ShapedType::isDynamic(outShape[0]) ? bScaleBatchSize : outShape[0];
+    outShape[2] = ShapedType::isDynamic(outShape[2])
+                      ? bScaleShape.getDimSize(1)
+                      : outShape[2];
+  }
+
+  inferredReturnShapes.push_back(ShapedTypeComponents(outShape));
+  return success();
+}
+
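[Editor's note] The broadcast rule is easiest to see with concrete shapes. A hypothetical example under the same assumed spellings as above, where A_data is [N, H, C], B_data is [D, W, C], and D == 1 broadcasts over N:

// A_data 4x8x64, A_scale 4x8x2, B_data 1x16x64, B_scale 1x16x2:
// N=4 and H=8 come from A_data, W=16 from B_data dim 1, and B's batch of 1
// broadcasts over N, so the inferred result shape is 4x8x16.
%0 = tosa.matmul_t_block_scaled %a_data, %a_scale, %b_data, %b_scale
       {block_size = BLOCK_SIZE_32} :
       (tensor<4x8x64xf8E4M3FN>, tensor<4x8x2xf8E8M0FNU>,
        tensor<1x16x64xf8E4M3FN>, tensor<1x16x2xf8E8M0FNU>) -> tensor<4x8x16xf32>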
+LogicalResult MatmulTBlockScaledOp::verify() {
+  // Verify same input data types
+  const Type aDataType = getAData().getType();
+  const Type bDataType = getBData().getType();
+  if (failed(verifySameElementTypes(*this, aDataType, bDataType, "A_data",
+                                    "B_data")))
+    return failure();
+
+  auto tryUpdateDimOrFailure = [&](int64_t &currDim, const int64_t newDim,
+                                   const StringRef operandName,
+                                   const StringRef dimName) -> LogicalResult {
+    if (ShapedType::isDynamic(currDim)) {
+      currDim = newDim;
+      return success();
+    } else if (ShapedType::isStatic(newDim) && currDim != newDim) {
+      return emitOpError("expected ")
+             << dimName << " of " << operandName << " to match size "
+             << currDim << ", got " << newDim;
+    }
+    return success();
+  };
+
+  // Verify input shape compatibility
+  int64_t N = ShapedType::kDynamic;
+  int64_t D = ShapedType::kDynamic;
+  int64_t H = ShapedType::kDynamic;
+  int64_t W = ShapedType::kDynamic;
+  int64_t C = ShapedType::kDynamic;
+  int64_t multiplesOfC = ShapedType::kDynamic;
+
+  const ShapeAdaptor aDataShape = ShapeAdaptor(aDataType);
+  if (aDataShape.hasRank()) {
+    N = aDataShape.getDimSize(0);
+    H = aDataShape.getDimSize(1);
+    C = aDataShape.getDimSize(2);
+  }
+
+  const ShapeAdaptor aScaleShape = ShapeAdaptor(getAScale().getType());
+  if (aScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(N, aScaleShape.getDimSize(0), "a_scale",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(H, aScaleShape.getDimSize(1), "a_scale",
+                                     "height")))
+      return failure();
+    multiplesOfC = aScaleShape.getDimSize(2);
+  }
+
+  const ShapeAdaptor bDataShape = ShapeAdaptor(bDataType);
+  if (bDataShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(D, bDataShape.getDimSize(0), "b_data",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(C, bDataShape.getDimSize(2), "b_data",
+                                     "channels")))
+      return failure();
+    W = bDataShape.getDimSize(1);
+  }
+
+  const ShapeAdaptor bScaleShape = ShapeAdaptor(getBScale().getType());
+  if (bScaleShape.hasRank()) {
+    if (failed(tryUpdateDimOrFailure(D, bScaleShape.getDimSize(0), "b_scale",
+                                     "batch")) ||
+        failed(tryUpdateDimOrFailure(W, bScaleShape.getDimSize(1), "b_scale",
+                                     "width")) ||
+        failed(tryUpdateDimOrFailure(multiplesOfC, bScaleShape.getDimSize(2),
+                                     "b_scale", "C/block_size")))
+      return failure();
+  }
+
+  // Verify batch size is broadcast compatible
+  if (ShapedType::isStatic(N) && ShapedType::isStatic(D) && N != D && D != 1)
+    return emitOpError("expect B matrix batch size to be broadcast compatible "
+                       "with A, got D=")
+           << D << " vs N=" << N;
+
+  // Verify C is a multiple of block size
+  const uint32_t blockSize = BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  if (ShapedType::isStatic(C) && C % blockSize != 0)
+    return emitOpError("expect C to be a multiple of block size, got C=")
+           << C << ", block_size=" << blockSize;
+
+  // Verify multiplesOfC is C / block size
+  if (ShapedType::isStatic(C) && ShapedType::isStatic(multiplesOfC) &&
+      multiplesOfC != C / blockSize)
+    return emitOpError(
+               "expect scale operands dimension 2 to equal C/block_size (")
+           << C << "/" << blockSize << "), got " << multiplesOfC;
+
+  // Verify output shape
+  N = ShapedType::isDynamic(N) ? D : N;
+  const SmallVector<int64_t, 3> expectedOutputShape = {N, H, W};
+  const auto outputType = cast<ShapedType>(getResult().getType());
+  if (outputType.hasRank() &&
+      failed(
+          verifyCompatibleShape(outputType.getShape(), expectedOutputShape))) {
+    InFlightDiagnostic opError = emitOpError("expected output shape ");
+    auto stringifyDim = [&](int64_t d) {
+      if (ShapedType::isDynamic(d))
+        opError << "?";
+      else
+        opError << d;
+    };
+    llvm::interleaveComma(outputType.getShape(), opError, stringifyDim);
+    opError << " to be compatible with expected output shape ";
+    llvm::interleaveComma(expectedOutputShape, opError, stringifyDim);
+    return opError;
+  }
+
+  return success();
+}
+
 LogicalResult tosa::PadOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     PadOp::Adaptor adaptor,
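[Editor's note] The C/block_size bookkeeping is clearer with numbers. A hypothetical failing case under the same assumed syntax: with C = 48 and block_size = 32, C % block_size != 0, so verification fails before the scale dimensions are compared:

// Rejected: "expect C to be a multiple of block size, got C=48, block_size=32"
%1 = tosa.matmul_t_block_scaled %a, %as, %b, %bs {block_size = BLOCK_SIZE_32} :
       (tensor<4x8x48xf8E4M3FN>, tensor<4x8x1xf8E8M0FNU>,
        tensor<1x16x48xf8E4M3FN>, tensor<1x16x1xf8E8M0FNU>) -> tensor<4x8x16xf32>
// With C = 64 instead, both scale operands would need 64/32 = 2 as their
// last dimension to satisfy the multiplesOfC check.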
@@ -3762,6 +3962,145 @@ LogicalResult RescaleOp::inferReturnTypeComponents(
   return success();
 }
 
+LogicalResult CastFromBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    CastFromBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  return success();
+}
+
+LogicalResult CastFromBlockScaledOp::verify() {
+  const Type inputDataType = getInputData().getType();
+  const Type outputDataType = getResult().getType();
+  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and output_data ("
+                         << outputDataType << ")";
+
+  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+
+  if (inputDataShape.hasRank()) {
+    const unsigned int blockSize =
+        BlockSizeAttr::getBlockSizeValue(getBlockSize());
+    const int64_t inputDataLastDim =
+        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+    if (ShapedType::isStatic(inputDataLastDim) &&
+        inputDataLastDim % blockSize != 0)
+      return emitOpError() << "expect last dimension of input_data ("
+                           << inputDataLastDim
+                           << ") to be divisible by block_size (" << blockSize
+                           << ")";
+
+    const Type inputScaleType = getInputScale().getType();
+    const ShapeAdaptor inputScaleShape = ShapeAdaptor(inputScaleType);
+
+    if (inputScaleShape.hasRank()) {
+      SmallVector<int64_t> inputDataDims, inputScaleDims;
+      inputDataShape.getDims(inputDataDims);
+      inputScaleShape.getDims(inputScaleDims);
+
+      if (inputDataDims.size() != inputScaleDims.size() ||
+          failed(verifyCompatibleShape(
+              ArrayRef<int64_t>(inputDataDims).drop_back(1),
+              ArrayRef<int64_t>(inputScaleDims).drop_back(1))))
+        return emitOpError() << "require compatible shapes for input_data ("
+                             << inputDataType << ") and input_scale ("
+                             << inputScaleType
+                             << ") except for the last dimension";
+
+      const SmallVector<int64_t, 2> dimsToCheck{inputDataLastDim / blockSize,
+                                                inputScaleDims.back()};
+      if (ShapedType::isStatic(inputDataLastDim) &&
+          failed(verifyCompatibleDims(dimsToCheck)))
+        return emitOpError()
+               << "expect last dimension of input_scale ("
+               << inputScaleDims.back()
+               << ") to be equal to last dimension of input_data / block_size ("
+               << inputDataDims.back() / blockSize << ")";
+    }
+  }
+
+  return success();
+}
+
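[Editor's note] A shape-consistent example of the relationship this verifier enforces, again under the assumed spellings: input_scale must match input_data on every dimension but the last, and its last dimension must be input_data's last dimension divided by block_size:

// 64 / 32 = 2, so a 4x2 scale tensor pairs with 4x64 data.
%out = tosa.cast_from_block_scaled %data, %scale {block_size = BLOCK_SIZE_32} :
    (tensor<4x64xf8E4M3FN>, tensor<4x2xf8E8M0FNU>) -> tensor<4x64xf32>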
+LogicalResult CastToBlockScaledOp::inferReturnTypeComponents(
+    MLIRContext *context, ::std::optional<Location> location,
+    CastToBlockScaledOp::Adaptor adaptor,
+    SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
+  const ShapeAdaptor inputShape(adaptor.getInputData().getType());
+  inferredReturnShapes.push_back(ShapedTypeComponents(inputShape));
+  if (!inputShape.hasRank())
+    return success();
+
+  // Calculate output_scale shape if ranked input provided
+  SmallVector<int64_t> outputScaleShape;
+  inputShape.getDims(outputScaleShape);
+  const int64_t lastDimLoc = inputShape.getRank() - 1;
+  const int64_t lastDimSize = inputShape.getDimSize(lastDimLoc);
+  if (ShapedType::isStatic(lastDimSize)) {
+    const unsigned int blockSize =
+        BlockSizeAttr::getBlockSizeValue(adaptor.getBlockSize());
+    outputScaleShape[lastDimLoc] = lastDimSize / blockSize;
+  }
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputScaleShape));
+  return success();
+}
+
+LogicalResult CastToBlockScaledOp::verify() {
+  const Type inputDataType = getInputData().getType();
+  const Type outputDataType = getResult(0).getType();
+  if (failed(verifyCompatibleShape(inputDataType, outputDataType)))
+    return emitOpError() << "require compatible shapes for input_data ("
+                         << inputDataType << ") and output_data ("
+                         << outputDataType << ")";
+
+  const unsigned int blockSize =
+      BlockSizeAttr::getBlockSizeValue(getBlockSize());
+  const ShapeAdaptor inputDataShape = ShapeAdaptor(inputDataType);
+  if (inputDataShape.hasRank()) {
+    const int64_t inputDataLastDim =
+        inputDataShape.getDimSize(inputDataShape.getRank() - 1);
+    if (ShapedType::isStatic(inputDataLastDim) &&
+        inputDataLastDim % blockSize != 0)
+      return emitOpError() << "expect last dimension of input_data ("
+                           << inputDataLastDim
+                           << ") to be divisible by block_size (" << blockSize
+                           << ")";
+  }
+
+  const ShapeAdaptor outputDataShape = ShapeAdaptor(outputDataType);
+  const Type outputScaleType = getResult(1).getType();
+  const ShapeAdaptor outputScaleShape = ShapeAdaptor(outputScaleType);
+  if (outputDataShape.hasRank() && outputScaleShape.hasRank()) {
+    SmallVector<int64_t> outputDataDims, outputScaleDims;
+    outputDataShape.getDims(outputDataDims);
+    outputScaleShape.getDims(outputScaleDims);
+
+    if (outputDataDims.size() != outputScaleDims.size() ||
+        failed(verifyCompatibleShape(
+            ArrayRef<int64_t>(outputDataDims).drop_back(1),
+            ArrayRef<int64_t>(outputScaleDims).drop_back(1))))
+      return emitOpError() << "require compatible shapes for output_data ("
+                           << outputDataType << ") and output_scale ("
+                           << outputScaleType
+                           << ") except for the last dimension";
+
+    const int64_t outputDataLastDim = outputDataDims.back();
+    const SmallVector<int64_t, 2> dimsToCheck{outputDataLastDim / blockSize,
+                                              outputScaleDims.back()};
+    if (ShapedType::isStatic(outputDataLastDim) &&
+        failed(verifyCompatibleDims(dimsToCheck)))
+      return emitOpError()
+             << "expect last dimension of output_scale ("
+             << outputScaleDims.back()
+             << ") to be equal to last dimension of output_data / block_size ("
+             << outputDataDims.back() / blockSize << ")";
+  }
+
+  return success();
+}
+
 LogicalResult IfOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     IfOp::Adaptor adaptor,
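[Editor's note] Closing the loop on the two inferred results: the first mirrors input_data's shape, and the second divides the last dimension by block_size only when it is static. A hypothetical sketch under the same assumed spellings:

// With a static last dimension the scale shape divides it by the block size:
//   tensor<4x64xf32> -> output_scale tensor<4x2xf8E8M0FNU>   (64 / 32 = 2)
%d, %s = tosa.cast_to_block_scaled %in {block_size = BLOCK_SIZE_32} :
    (tensor<4x?xf32>) -> (tensor<4x?xf8E4M3FN>, tensor<4x?xf8E8M0FNU>)
// With a dynamic last dimension, as here, the division is skipped and the
// inferred scale dimension stays dynamic.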
