14 files changed, 161 insertions, 128 deletions
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 8ede271..98bcbca 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -78,7 +78,8 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$pad,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp,
+    OptionalAttr<I32Attr>:$output_zp
   );
 
   let results = (outs
@@ -237,7 +238,8 @@ def Tosa_FullyConnectedOp : Tosa_InferShapedTypeOp<"fully_connected"> {
     Tosa_Tensor2D:$input,
     TosaTensorRankOf<[Tosa_Weight], [2]>:$weight,
     Tosa_Tensor1D:$bias,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp,
+    OptionalAttr<I32Attr>:$weight_zp
   );
 
   let results = (outs
@@ -263,7 +265,8 @@ def Tosa_MatMulOp : Tosa_InferShapedTypeOp<"matmul"> {
   let arguments = (ins
     Tosa_Tensor3D:$a,
     Tosa_Tensor3D:$b,
-    OptionalAttr<Tosa_MatMulOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$a_zp,
+    OptionalAttr<I32Attr>:$b_zp
   );
 
   let results = (outs
@@ -1114,7 +1117,8 @@ def Tosa_NegateOp : Tosa_ElementwiseUnaryOp<"negate"> {
   let arguments = (ins
     Tosa_Tensor:$input1,
-    OptionalAttr<Tosa_UnaryOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input1_zp,
+    OptionalAttr<I32Attr>:$output_zp
   );
 
   let results = (outs
@@ -1589,7 +1593,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> {
     Tosa_RankedTensor:$input1,
     Tosa_Shape:$padding,
     Optional<Tosa_ScalarTensor>:$pad_const,
-    OptionalAttr<Tosa_PadOpQuantizationAttr>:$quantization_info
+    OptionalAttr<I32Attr>:$input_zp
   );
 
   let results = (outs
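The net effect of the .td changes above: the composite quantization_info attribute on these five ops is split into per-operand i32 zero-point attributes. As a minimal sketch (not part of the patch) of how the resulting ODS-generated accessors are consumed, assuming the usual OptionalAttr<I32Attr> pair of getters (an optional value getter plus a possibly-null IntegerAttr getter, both of which the lowerings below rely on):

#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical helper, not in the patch: reads tosa.matmul zero points
// through the accessors generated for OptionalAttr<I32Attr>.
static void dumpMatMulZeroPoints(mlir::tosa::MatMulOp op) {
  // Value form: empty optional when the attribute is absent.
  if (auto aZp = op.getAZp())
    llvm::outs() << "a_zp = " << *aZp << "\n";
  else
    llvm::outs() << "a_zp absent (treated as 0)\n";
  // Attribute form: null IntegerAttr when absent; handy when forwarding the
  // value verbatim, e.g. rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr()).
  if (mlir::IntegerAttr bZpAttr = op.getBZpAttr())
    llvm::outs() << "b_zp = " << bZpAttr.getInt() << "\n";
}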
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b0eb2d6..67218ce 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -141,63 +141,65 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   }
 
   // tosa::NegateOp
-  if (isa<tosa::NegateOp>(op) && isa<FloatType>(elementTy))
-    return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
+  if (isa<tosa::NegateOp>(op)) {
+    if (isa<FloatType>(elementTy))
+      return rewriter.create<arith::NegFOp>(loc, resultTypes, args);
 
-  if (isa<tosa::NegateOp>(op) && isa<IntegerType>(elementTy)) {
-    int64_t inZp = 0, outZp = 0;
+    if (isa<IntegerType>(elementTy)) {
+      auto inputZpAttr = cast<tosa::NegateOp>(op).getInput1Zp();
+      auto outputZpAttr = cast<tosa::NegateOp>(op).getOutputZp();
 
-    if (cast<tosa::NegateOp>(op).getQuantizationInfo()) {
-      auto quantizationInfo = cast<tosa::NegateOp>(op).getQuantizationInfo();
-      inZp = quantizationInfo.value().getInputZp();
-      outZp = quantizationInfo.value().getOutputZp();
-    }
+      const int64_t inZp = inputZpAttr ? *inputZpAttr : 0;
+      const int64_t outZp = outputZpAttr ? *outputZpAttr : 0;
 
-    int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
-    if (!inZp && !outZp) {
-      auto constant = rewriter.create<arith::ConstantOp>(
-          loc, IntegerAttr::get(elementTy, 0));
-      return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
-                                            args[0]);
-    }
+      if (!inZp && !outZp) {
+        auto constant = rewriter.create<arith::ConstantOp>(
+            loc, IntegerAttr::get(elementTy, 0));
+        return rewriter.create<arith::SubIOp>(loc, resultTypes, constant,
+                                              args[0]);
+      }
 
-    // Compute the maximum value that can occur in the intermediate buffer.
-    int64_t zpAdd = inZp + outZp;
-    int64_t maxValue = APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
-                       std::abs(zpAdd) + 1;
-
-    // Convert that maximum value into the maximum bitwidth needed to represent
-    // it. We assume 48-bit numbers may be supported further in the pipeline.
-    int intermediateBitWidth = 64;
-    if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
-      intermediateBitWidth = 16;
-    } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
-      intermediateBitWidth = 32;
-    } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
-      intermediateBitWidth = 48;
-    }
+      // Compute the maximum value that can occur in the intermediate buffer.
+      const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
+      const int64_t zpAdd = inZp + outZp;
+      const int64_t maxValue =
+          APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
+          std::abs(zpAdd) + 1;
+
+      // Convert that maximum value into the maximum bitwidth needed to
+      // represent it. We assume 48-bit numbers may be supported further in
+      // the pipeline.
+      int intermediateBitWidth = 64;
+      if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
+        intermediateBitWidth = 16;
+      } else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
+        intermediateBitWidth = 32;
+      } else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
+        intermediateBitWidth = 48;
+      }
 
-    Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
-    Value zpAddValue = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
-
-    // The negation can be applied by doing:
-    //   outputValue = inZp + outZp - inputValue
-    auto ext = rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
-    auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
-
-    // Clamp to the negation range.
-    Value min = rewriter.create<arith::ConstantIntOp>(
-        loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
-        intermediateType);
-    Value max = rewriter.create<arith::ConstantIntOp>(
-        loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
-        intermediateType);
-    auto clamp =
-        clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
-
-    // Truncate to the final value.
-    return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+      Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
+      Value zpAddValue = rewriter.create<arith::ConstantOp>(
+          loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
+
+      // The negation can be applied by doing:
+      //   outputValue = inZp + outZp - inputValue
+      auto ext =
+          rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[0]);
+      auto sub = rewriter.create<arith::SubIOp>(loc, zpAddValue, ext);
+
+      // Clamp to the negation range.
+      Value min = rewriter.create<arith::ConstantIntOp>(
+          loc, APInt::getSignedMinValue(inputBitWidth).getSExtValue(),
+          intermediateType);
+      Value max = rewriter.create<arith::ConstantIntOp>(
+          loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
+          intermediateType);
+      auto clamp =
+          clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);
+
+      // Truncate to the final value.
+      return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
+    }
   }
 
   // tosa::BitwiseAndOp
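To see where the negate tests below get the constants 32639 and 32640, here is an illustrative standalone sketch (not part of the patch) of the intermediate-bitwidth selection above; for an i8 input the two values sit exactly on either side of the i16/i32 boundary:

#include <cstdint>
#include <cstdlib>

// Mirrors the selection logic above for signed inputs.
static int pickIntermediateBitWidth(int64_t inZp, int64_t outZp,
                                    int inputBitWidth) {
  const int64_t inputMax = (int64_t{1} << (inputBitWidth - 1)) - 1;
  const int64_t maxValue = inputMax + std::abs(inZp + outZp) + 1;
  if (maxValue <= INT16_MAX)
    return 16;
  if (maxValue <= INT32_MAX)
    return 32;
  if (maxValue <= (int64_t{1} << 47) - 1)
    return 48;
  return 64;
}

// For an i8 input (inputMax = 127):
//   pickIntermediateBitWidth(32639, 0, 8): 127 + 32639 + 1 = 32767 -> i16
//   pickIntermediateBitWidth(32640, 0, 8): 127 + 32640 + 1 = 32768 -> i32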
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index cf9852e..6321cb6 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -590,18 +590,15 @@ public:
             .create<linalg::FillOp>(loc, ValueRange{zero},
                                     ValueRange{emptyTensor})
             .result();
-    if (!op.getQuantizationInfo()) {
+    if (!op.getAZp() && !op.getBZp()) {
       rewriter.replaceOpWithNewOp<linalg::BatchMatmulOp>(
           op, TypeRange{op.getType()},
          ValueRange{adaptor.getA(), adaptor.getB()}, ValueRange{zeroTensor});
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto aZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getAZp()));
-    auto bZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getBZp()));
+    auto aZp = rewriter.create<arith::ConstantOp>(loc, op.getAZpAttr());
+    auto bZp = rewriter.create<arith::ConstantOp>(loc, op.getBZpAttr());
     rewriter.replaceOpWithNewOp<linalg::QuantizedBatchMatmulOp>(
         op, TypeRange{op.getType()},
         ValueRange{adaptor.getA(), adaptor.getB(), aZp, bZp}, zeroTensor);
@@ -661,7 +658,7 @@ public:
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (!op.getQuantizationInfo()) {
+    if (!op.getInputZp() && !op.getWeightZp()) {
       Value matmul = rewriter
                          .create<linalg::MatmulOp>(
                              loc, TypeRange{op.getType()},
@@ -672,11 +669,9 @@ public:
       return success();
     }
 
-    auto quantizationInfo = *op.getQuantizationInfo();
-    auto inputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getInputZp()));
-    auto outputZp = rewriter.create<arith::ConstantOp>(
-        loc, rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp()));
+    auto inputZp = rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
+    auto outputZp =
+        rewriter.create<arith::ConstantOp>(loc, op.getWeightZpAttr());
     Value matmul =
         rewriter
             .create<linalg::QuantizedMatmulOp>(
@@ -958,10 +953,9 @@ public:
 
           // If we have quantization information we need to apply an offset
          // for the input zp value.
-          if (op.getQuantizationInfo()) {
-            auto quantizationInfo = *op.getQuantizationInfo();
-            auto inputZp = rewriter.create<arith::ConstantOp>(
-                loc, b.getIntegerAttr(accETy, quantizationInfo.getInputZp()));
+          if (op.getInputZp()) {
+            auto inputZp =
+                rewriter.create<arith::ConstantOp>(loc, op.getInputZpAttr());
             Value offset =
                 rewriter.create<arith::MulIOp>(loc, accETy, count, inputZp);
             poolVal =
@@ -1013,11 +1007,9 @@ public:
 
           // If we have quantization information we need to apply output
           // zeropoint.
-          if (op.getQuantizationInfo()) {
-            auto quantizationInfo = *op.getQuantizationInfo();
-            auto outputZp = rewriter.create<arith::ConstantOp>(
-                loc, b.getIntegerAttr(scaled.getType(),
-                                      quantizationInfo.getOutputZp()));
+          if (op.getOutputZp()) {
+            auto outputZp =
+                rewriter.create<arith::ConstantOp>(loc, op.getOutputZpAttr());
             scaled = rewriter.create<arith::AddIOp>(loc, scaled, outputZp)
                          .getResult();
           }
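For the two avg_pool2d hunks, a scalar model of where the zero points enter the computation (illustrative only; it assumes the offset computed above, offset = count * input_zp, is subtracted from the accumulator, and an integer division stands in for the op's real multiplier/shift rescale):

#include <cstdint>

// Scalar model of a quantized average pool; not the lowering itself.
static int64_t avgPoolScalarModel(const int8_t *window, int64_t count,
                                  int32_t inputZp, int32_t outputZp) {
  int64_t acc = 0;
  for (int64_t i = 0; i < count; ++i)
    acc += window[i];            // raw sum in the accumulator type
  acc -= count * inputZp;        // input_zp hunk: remove the input offset
  int64_t scaled = acc / count;  // stand-in for the real rescale
  return scaled + outputZp;      // output_zp hunk: arith.addi after scaling
}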
diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
index c4b787d..2a9b4d1 100644
--- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
+++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp
@@ -358,10 +358,10 @@ public:
     TypedAttr constantAttr;
     if (isa<FloatType>(elementTy)) {
       constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    } else if (isa<IntegerType>(elementTy) && !padOp.getQuantizationInfo()) {
+    } else if (isa<IntegerType>(elementTy) && !padOp.getInputZpAttr()) {
       constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    } else if (isa<IntegerType>(elementTy) && padOp.getQuantizationInfo()) {
-      int64_t value = padOp.getQuantizationInfo()->getInputZp();
+    } else if (isa<IntegerType>(elementTy) && padOp.getInputZpAttr()) {
+      int64_t value = padOp.getInputZpAttr().getInt();
       constantAttr = rewriter.getIntegerAttr(elementTy, value);
     }
 
     if (constantAttr)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 9d36947..8e22c87 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -207,10 +207,10 @@ struct MaterializePadValue : public OpRewritePattern<tosa::PadOp> {
     Attribute constantAttr;
     if (llvm::isa<FloatType>(elementTy)) {
      constantAttr = rewriter.getFloatAttr(elementTy, 0.0);
-    } else if (llvm::isa<IntegerType>(elementTy) && !op.getQuantizationInfo()) {
+    } else if (llvm::isa<IntegerType>(elementTy) && !op.getInputZpAttr()) {
       constantAttr = rewriter.getIntegerAttr(elementTy, 0);
-    } else if (llvm::isa<IntegerType>(elementTy) && op.getQuantizationInfo()) {
-      auto value = op.getQuantizationInfo()->getInputZp();
+    } else if (llvm::isa<IntegerType>(elementTy) && op.getInputZpAttr()) {
+      int64_t value = op.getInputZpAttr().getInt();
       constantAttr = rewriter.getIntegerAttr(elementTy, value);
     }
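Both files above materialize the pad value with the same three-way choice; condensed into one hypothetical helper (same logic as the hunks, helper name invented here):

#include "mlir/IR/Builders.h"

// Floats pad with 0.0, quantized integers pad with their zero point,
// plain integers pad with 0 (e.g. input_zp = 42 pads with 42 in the tests).
static mlir::TypedAttr pickPadConstant(mlir::Builder &b, mlir::Type elementTy,
                                       mlir::IntegerAttr inputZpAttr) {
  if (mlir::isa<mlir::FloatType>(elementTy))
    return b.getFloatAttr(elementTy, 0.0);
  if (inputZpAttr)
    return b.getIntegerAttr(elementTy, inputZpAttr.getInt());
  return b.getIntegerAttr(elementTy, 0);
}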
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index e8b2890..031c279 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -271,11 +271,11 @@ static LogicalResult verifyConvOp(T op) {
     }
   }
 
-  bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
-  bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
+  bool inputIsFloat = llvm::isa<FloatType>(inputEType);
+  bool weightIsFloat = llvm::isa<FloatType>(weightEType);
 
-  // Either both must be quantized or both unquantized.
-  if (inputIsQuant != weightIsQuant) {
+  // Either both must be float or both non-float.
+  if (inputIsFloat != weightIsFloat) {
     op.emitOpError(
         "expect both input and weight to be float or not together, got ")
         << inputEType << " and " << weightEType;
@@ -527,7 +527,12 @@ static void buildTransConvOpWithQuantInfo(
   auto quantAttr = ::buildConvOpQuantizationAttr(builder, input, weight);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("weight_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getWeightZp())));
     result.addTypes(
         buildConvOpResultTypeInfo(builder, outputType, input, weight));
   } else {
@@ -563,7 +568,10 @@ static void buildMatMulOpWithQuantInfo(OpBuilder &builder,
   auto quantAttr = ::buildMatMulOpQuantizationAttr(builder, a, b);
 
   if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
+    result.addAttribute("a_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getAZp())));
+    result.addAttribute("b_zp", builder.getI32IntegerAttr(
+                                    static_cast<int32_t>(quantAttr.getBZp())));
 
     auto inputType = llvm::dyn_cast<ShapedType>(a.getType());
     assert(inputType && "Input must be a shaped tensor type!");
@@ -603,8 +611,14 @@ buildAvgPool2dOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   result.addAttribute("pad", pad);
   result.addAttribute("acc_type", accType);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("output_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getOutputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -616,8 +630,15 @@ static void buildUnaryOpWithQuantInfo(OpBuilder &builder,
                                       Value input) {
   result.addOperands(input);
   auto quantAttr = buildUnaryOpQuantizationAttr(builder, input, outputType);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    // Note: NegateOp has attributes input1_zp and output_zp.
+    result.addAttribute("input1_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+    result.addAttribute("output_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getOutputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -629,8 +650,11 @@ static void buildPadOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                     Value paddings) {
   result.addOperands({input, paddings});
   auto quantAttr = buildPadOpQuantizationAttr(builder, input);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -643,8 +667,11 @@ static void buildExplicitValuePadOpWithQuantInfo(OpBuilder &builder,
                                                  Value padConst) {
   result.addOperands({input, paddings, padConst});
   auto quantAttr = buildPadOpQuantizationAttr(builder, input);
-  if (quantAttr)
-    result.addAttribute("quantization_info", quantAttr);
+  if (quantAttr) {
+    result.addAttribute("input_zp",
+                        builder.getI32IntegerAttr(
+                            static_cast<int32_t>(quantAttr.getInputZp())));
+  }
   result.types.push_back(outputType);
 }
 
@@ -898,9 +925,8 @@ LogicalResult FullyConnectedOp::verify() {
 
   // Quantized type must have constructed the quantizationattr, and unquantized
   // types should not have a quantizationattr.
-  if ((inputIsQuant && !getQuantizationInfo()) ||
-      (!inputIsQuant && getQuantizationInfo())) {
-    emitOpError("quantizationattr is required for quantized type, and not "
+  if ((inputIsQuant && !getInputZp()) || (!inputIsQuant && getInputZp())) {
+    emitOpError("input zero point is required for quantized type, and not "
                 "allowed for float type");
     return failure();
   }
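Each builder above now follows the same pattern: when the inferred quantization attribute exists, its fields are copied out into individual i32 attributes under the op-specific names. A schematic helper (hypothetical, not in the patch) for the matmul case:

#include "mlir/IR/Builders.h"
#include "mlir/IR/OperationSupport.h"

// Flattens the two matmul zero points into plain i32 attributes, exactly as
// buildMatMulOpWithQuantInfo does with quantAttr.getAZp()/getBZp() above.
static void addMatMulZeroPoints(mlir::OpBuilder &builder,
                                mlir::OperationState &result, int32_t aZp,
                                int32_t bZp) {
  result.addAttribute("a_zp", builder.getI32IntegerAttr(aZp));
  result.addAttribute("b_zp", builder.getI32IntegerAttr(bZp));
}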
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
index 7d3deae..4eba89b 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp
@@ -130,13 +130,13 @@ struct Conv2DIsFullyConnected : public OpRewritePattern<tosa::Conv2DOp> {
     auto maybeZps = failureOrMaybeZps.value();
 
     Value fullyConnectedValue;
     if (maybeZps) {
-      auto zeroPointAttr = rewriter.getAttr<tosa::ConvOpQuantizationAttr>(
-          maybeZps->inputZp, maybeZps->weightZp);
       fullyConnectedValue =
           rewriter
              .create<tosa::FullyConnectedOp>(
                   op.getLoc(), fullyConnectedShapeType, reshapedInput,
-                  reshapedWeight, op.getBias(), zeroPointAttr)
+                  reshapedWeight, op.getBias(),
+                  rewriter.getI32IntegerAttr(maybeZps->inputZp),
+                  rewriter.getI32IntegerAttr(maybeZps->weightZp))
               .getResult();
     } else {
       fullyConnectedValue =
           rewriter
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index ae22467..b5b3e9d 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -143,8 +143,7 @@ public:
       weight = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
          weightPaddingVal, nullptr,
-          rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->weightZp));
-
+          rewriter.getI32IntegerAttr(maybeZps->weightZp));
     } else {
       weight = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(weightETy), weight,
@@ -203,7 +202,7 @@ public:
       input = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(inputETy), input,
           inputPaddingVal, nullptr,
-          rewriter.getAttr<PadOpQuantizationAttr>(maybeZps->inputZp));
+          rewriter.getI32IntegerAttr(maybeZps->inputZp));
     } else {
       input = CreateOpAndInferShape<tosa::PadOp>(
           rewriter, loc, UnrankedTensorType::get(inputETy), input,
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 116cd04..87c388b 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -23,7 +23,7 @@ func.func @matmul_quantized(%arg0: tensor<1x5x3xi8>, %arg1: tensor<1x3x6xi8>) ->
   // CHECK: [[ONE:%.+]] = arith.constant 1
   // CHECK: [[TWO:%.+]] = arith.constant 2
   // CHECK: linalg.quantized_batch_matmul ins(%arg0, %arg1, [[ONE]], [[TWO]] : tensor<1x5x3xi8>, tensor<1x3x6xi8>, i32, i32) outs([[FILLED]] : tensor<1x5x6xi32>) -> tensor<1x5x6xi32>
-  %0 = tosa.matmul %arg0, %arg1 {quantization_info = #tosa.matmul_quant<a_zp = 1, b_zp = 2>} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
+  %0 = tosa.matmul %arg0, %arg1 {a_zp = 1 : i32, b_zp = 2 : i32} : (tensor<1x5x3xi8>, tensor<1x3x6xi8>) -> tensor<1x5x6xi32>
   return %0 : tensor<1x5x6xi32>
 }
 
@@ -124,7 +124,7 @@ func.func @quantized_fully_connected(%arg0: tensor<5x3xi8>, %arg1: tensor<6x3xi8
   // CHECK: %[[C1:.+]] = arith.constant 1 : i32
   // CHECK: %[[C2:.+]] = arith.constant 2 : i32
   // CHECK: linalg.quantized_matmul ins(%arg0, %[[TRANSPOSE]], %[[C1]], %[[C2]] : tensor<5x3xi8>, tensor<3x6xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<5x6xi32>) -> tensor<5x6xi32>
-  %0 = tosa.fully_connected %arg0, %arg1, %arg2 {quantization_info = #tosa.conv_quant<input_zp = 1, weight_zp = 2>} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
+  %0 = tosa.fully_connected %arg0, %arg1, %arg2 {input_zp = 1 : i32, weight_zp = 2 : i32} : (tensor<5x3xi8>, tensor<6x3xi8>, tensor<6xi32>) -> tensor<5x6xi32>
   return %0 : tensor<5x6xi32>
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index f9bdcef..6e8501a 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -880,26 +880,36 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
 func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[CNST:%.+]] = arith.constant 7
+  // CHECK: [[ZERO:%.+]] = arith.constant 0
+  // CHECK: [[SUB:%.+]] = arith.subi [[ZERO]], %[[BBARG0]]
+  // CHECK: linalg.yield [[SUB]]
+  %0 = tosa.negate %arg0 {input1_zp = 0 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
+
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+  // CHECK: [[C32639:%.+]] = arith.constant 32639
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
-  // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
+  // CHECK: [[SUB:%.+]] = arith.subi [[C32639]], [[EXT]]
   // CHECK: [[MIN:%.+]] = arith.constant -128
   // CHECK: [[MAX:%.+]] = arith.constant 127
   // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
   // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
   // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
   // CHECK: linalg.yield [[TRUNC]]
-  %0 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 0, output_zp = 7>} : (tensor<1xi8>) -> tensor<1xi8>
-
-  // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
-  // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
-  %1 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32639, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+  %1 = tosa.negate %arg0 {input1_zp = 32639 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
+  // CHECK: [[C32640:%.+]] = arith.constant 32640
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i32
-  %2 = tosa.negate %arg0 {quantization_info = #tosa.unary_quant<input_zp = 32640, output_zp = 0>} : (tensor<1xi8>) -> tensor<1xi8>
+  // CHECK: [[SUB:%.+]] = arith.subi [[C32640]], [[EXT]]
+  // CHECK: [[MIN:%.+]] = arith.constant -128
+  // CHECK: [[MAX:%.+]] = arith.constant 127
+  // CHECK: [[LBOUND:%.+]] = arith.maxsi [[MIN]], [[SUB]]
+  // CHECK: [[UBOUND:%.+]] = arith.minsi [[MAX]], [[LBOUND]]
+  // CHECK: [[TRUNC:%.+]] = arith.trunci [[UBOUND]]
+  // CHECK: linalg.yield [[TRUNC]]
+  %2 = tosa.negate %arg0 {input1_zp = 32640 : i32, output_zp = 0 : i32} : (tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[BBARG0:.+]]: i8,
diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
index f95de79..e83e898 100644
--- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
+++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir
@@ -492,7 +492,7 @@ func.func @pad_quant(%arg0 : tensor<1x2xi32>) -> (tensor<4x9xi32>) {
   // CHECK: [[CST:%.+]] = arith.constant 42 : i32
   // CHECK: tensor.pad
   // CHECK:   tensor.yield [[CST]]
-  %1 = "tosa.pad"(%arg0, %0) {quantization_info = #tosa.pad_quant<input_zp = 42>} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
+  %1 = "tosa.pad"(%arg0, %0) {input_zp = 42 : i32} : (tensor<1x2xi32>, !tosa.shape<4>) -> (tensor<4x9xi32>)
   return %1 : tensor<4x9xi32>
 }
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 71a7e28..e0e1de6 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -317,7 +317,7 @@ func.func @pad_determine_val_f32(%arg0: tensor<?x?xf32>, %arg1 : tensor<2x2xi32>
 
 // CHECK-LABEL: @pad_determine_val_quant
 func.func @pad_determine_val_quant(%arg0: tensor<?x?xi32>, %arg1 : tensor<2x2xi32>) -> tensor<?x?xi32> {
-  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<i32>}
+  // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<42> : tensor<i32>}
   // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
   // CHECK: tosa.pad %arg0, %[[PADDING]], %[[ZERO]]
   %0 = tosa.const_shape { value = dense<[1, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
index 685f799..e4a2897 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir
@@ -28,7 +28,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
   // CHECK: %[[VAR1:.*]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
   // CHECK-SAME: -> tensor<3x2xi8>
   // CHECK: %[[VAR2:.*]] = tosa.fully_connected %[[VAR0]], %[[VAR1]], %arg2
-  // CHECK-SAME: quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>
+  // CHECK-SAME: {input_zp = 42 : i32, weight_zp = 24 : i32}
   // CHECK-SAME: -> tensor<400x3xi32>
   // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]] {new_shape = array<i64: 4, 10, 10, 3>}
   // CHECK-SAME: -> tensor<4x10x10x3xi32>
@@ -48,7 +48,7 @@ func.func @conv2d_as_fully_connected_quant(%arg0: tensor<4x10x10x2xi8>, %arg1: t
 func.func @conv_with_dynamic_dim(%arg0: tensor<?x14x14x64xi8>, %arg1: tensor<384x1x1x64xi8>, %arg2: tensor<384xi32>) -> tensor<?x14x14x384xi32> {
 // CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: -1, 64>} : (tensor<?x14x14x64xi8>) -> tensor<?x64xi8>
 // CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_1]] {new_shape = array<i64: 384, 64>} : (tensor<384x1x1x64xi8>) -> tensor<384x64xi8>
-// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {quantization_info = #tosa.conv_quant<input_zp = -6, weight_zp = 11>} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
+// CHECK: %[[VAL_5:.*]] = tosa.fully_connected %[[VAL_3]], %[[VAL_4]], %[[VAL_2]] {input_zp = -6 : i32, weight_zp = 11 : i32} : (tensor<?x64xi8>, tensor<384x64xi8>, tensor<384xi32>) -> tensor<?x384xi32>
 // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array<i64: -1, 14, 14, 384>} : (tensor<?x384xi32>) -> tensor<?x14x14x384xi32>
 // CHECK: return %[[VAL_6]] : tensor<?x14x14x384xi32>
 // CHECK: }
@@ -67,7 +67,7 @@ func.func @conv2d_as_fully_connected_padded(%arg0: tensor<4x10x10x2xi8>, %arg1:
   // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PAD_SHAPE]], %[[PAD_VAL]] : (tensor<4x10x10x2xi8>, !tosa.shape<8>, tensor<i8>) -> tensor<4x12x12x2xi8>
   // CHECK-DAG: %[[RESHAPE_INPUT:.+]] = tosa.reshape %[[PAD]] {new_shape = array<i64: 576, 2>}
   // CHECK-DAG: %[[RESHAPE_FILTER:.+]] = tosa.reshape %arg1 {new_shape = array<i64: 3, 2>}
-  // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {quantization_info = #tosa.conv_quant<input_zp = 42, weight_zp = 24>}
+  // CHECK-DAG: %[[FULLY:.+]] = tosa.fully_connected %[[RESHAPE_INPUT]], %[[RESHAPE_FILTER]], %arg2 {input_zp = 42 : i32, weight_zp = 24 : i32}
   // CHECK: %[[RESHAPE:.+]] = tosa.reshape %[[FULLY]] {new_shape = array<i64: 4, 12, 12, 3>}
   %input_zp = "tosa.const"() {value = dense<42> : tensor<1xi8>} : () -> tensor<1xi8>
   %weight_zp = "tosa.const"() {value = dense<24> : tensor<1xi8>} : () -> tensor<1xi8>
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index bb6de82..82838cc 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -91,7 +91,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   // Manipulate the weight matrix to handle striding.
   // CHECK-DAG: %[[PADV:.+]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK-DAG: %[[TRANSV:.+]] = "tosa.const"() <{value = dense<[2, 4, 0, 1, 3, 5]> : tensor<6xi32>}
-  // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {quantization_info = #tosa.pad_quant<input_zp = 42>}
+  // CHECK-DAG: %[[PADW:.+]] = tosa.pad %arg1, %[[PADV]] {input_zp = 42 : i32}
   // CHECK-DAG: %[[RESW1:.+]] = tosa.reshape %[[PADW]] {new_shape = array<i64: 5, 2, 2, 2, 3, 3>}
   // CHECK-DAG: %[[TRANS:.+]] = tosa.transpose %[[RESW1]], %[[TRANSV]]
   // CHECK-DAG: %[[RESW2:.+]] = tosa.reshape %[[TRANS]] {new_shape = array<i64: 30, 2, 2, 3>}
@@ -101,7 +101,7 @@ func.func @transpose_conv2d_strided_quantized(%arg0: tensor<2x17x15x3xi8>, %arg1
   // Pad out the input matrix to handle the transpose conv.
   // CHECK-DAG: %[[PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 1, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
   // CHECK-DAG: %[[TRANS2:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
-  // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {quantization_info = #tosa.pad_quant<input_zp = -22>}
+  // CHECK-DAG: %[[NEWINPUT:.+]] = tosa.pad %arg0, %[[PAD]] {input_zp = -22 : i32}
 
   // Manipulate the final shape.
   // CHECK-DAG: %[[BIAS:.+]] = "tosa.const"() <{value = dense<0> : tensor<30xi32>}
@@ -132,14 +132,14 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
   // CHECK-DAG: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0> : tensor<2xi32>}
   // CHECK-DAG: %[[RESULT_PERMS:.+]] = "tosa.const"() <{value = dense<[0, 1, 3, 2, 4, 5]> : tensor<6xi32>}
   // CHECK-DAG: %[[RESULT_PAD:.+]] = tosa.const_shape {value = dense<[0, 0, 2, 0, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
-  // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = 93>}
+  // CHECK: %[[PAD_WEIGHT:.+]] = tosa.pad %arg1, %[[WEIGHT_PAD]] {input_zp = 93 : i32}
   // CHECK: %[[RESHAPE_WEIGHT_0:.+]] = tosa.reshape %[[PAD_WEIGHT]] {new_shape = array<i64: 1, 2, 1, 1, 2, 1>}
   // CHECK: %[[TRANSPOSE_WEIGHT:.+]] = tosa.transpose %[[RESHAPE_WEIGHT_0]], %[[WEIGHT_PERMS]]
   // CHECK: %[[RESHAPE_WEIGHT_1:.+]] = tosa.reshape %[[TRANSPOSE_WEIGHT]] {new_shape = array<i64: 2, 2, 1, 1>}
   // CHECK: %[[REVERSE:.+]] = tosa.reverse %[[RESHAPE_WEIGHT_1]] {axis = 1 : i32}
-  // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {quantization_info = #tosa.pad_quant<input_zp = -103>}
+  // CHECK: %[[PAD_INPUT:.+]] = tosa.pad %arg0, %[[INPUT_PAD]] {input_zp = -103 : i32}
   // CHECK: %[[CONV:.+]] = tosa.conv2d %[[PAD_INPUT]], %[[REVERSE]], %[[ZERO]]
-  // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], quantization_info = #tosa.conv_quant<input_zp = -103, weight_zp = 93>, stride = [1, 1]}
+  // CHECK-SAME{literal}: dilation = [1, 1], pad = [0, 0, 0, 0], input_zp = -103 : i32, weight_zp = 93 : i32, stride = [1, 1]}
   // CHECK: %[[RESHAPE_RESULT_0:.+]] = tosa.reshape %[[CONV]] {new_shape = array<i64: 1, 17, 1, 1, 2, 1>}
   // CHECK: %[[TRANSPOSE_RESULT:.+]] = tosa.transpose %[[RESHAPE_RESULT_0]], %[[RESULT_PERMS]]
   // CHECK: %[[RESHAPE_RESULT_1:.+]] = tosa.reshape %[[TRANSPOSE_RESULT]] {new_shape = array<i64: 1, 17, 2, 1>}