diff options
author | Matthias Springer <mspringer@nvidia.com> | 2025-04-18 16:14:03 +0200 |
---|---|---|
committer | Matthias Springer <mspringer@nvidia.com> | 2025-04-18 16:18:41 +0200 |
commit | c5cf03a74d5a27d6f03da7d3c21532e98906c638 (patch) | |
tree | 7294e6c650c4dba061d3cf9dbbb6127cdf85bdee | |
parent | 622765f976f0a76313aa38d0f549eea9affeebfd (diff) | |
download | llvm-users/matthias-springer/tosa_no_rollback.zip llvm-users/matthias-springer/tosa_no_rollback.tar.gz llvm-users/matthias-springer/tosa_no_rollback.tar.bz2 |
-rw-r--r-- | mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp | 217 |
1 files changed, 119 insertions, 98 deletions
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 9ca93ab..bc4ef58 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -91,6 +91,50 @@ createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType, op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); } +/// Return "failure" if the given elementwise operation cannot be converted. +static LogicalResult +isSupportedElementwiseOperation(ConversionPatternRewriter &rewriter, + Operation *op, RankedTensorType resultType) { + auto elementTy = + cast<ShapedType>(op->getOperand(0).getType()).getElementType(); + + // tosa::MulOp + if (isa<tosa::MulOp>(op)) { + auto shiftVal = cast<tosa::MulOp>(op).getShift(); + DenseElementsAttr shiftElem; + if (!matchPattern(shiftVal, m_Constant(&shiftElem))) + return rewriter.notifyMatchFailure(op, "shift value of mul not found"); + + int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt(); + if (isa<FloatType>(elementTy) && shift != 0) + return rewriter.notifyMatchFailure(op, + "Cannot have shift value for float"); + return success(); + } + + // tosa::NegateOp + if (isa<tosa::NegateOp>(op)) { + auto negate = cast<tosa::NegateOp>(op); + if (failed(negate.getInput1ZeroPoint())) + return rewriter.notifyMatchFailure( + op, "input1 zero point cannot be statically determined"); + if (failed(negate.getOutputZeroPoint())) + return rewriter.notifyMatchFailure( + op, "output zero point cannot be statically determined"); + return success(); + } + + // tosa::CastOp + if (isa<tosa::CastOp>(op)) { + if (!elementTy.isIntOrFloat() || + !resultType.getElementType().isIntOrFloat()) + return rewriter.notifyMatchFailure(op, "unsupported type"); + return success(); + } + + return success(); +} + static Value createLinalgBodyCalculationForElementwiseOp( Operation *op, ValueRange args, ArrayRef<Type> resultTypes, ConversionPatternRewriter &rewriter) { @@ -139,17 +183,14 @@ static Value createLinalgBodyCalculationForElementwiseOp( auto shiftVal = cast<tosa::MulOp>(op).getShift(); DenseElementsAttr shiftElem; if (!matchPattern(shiftVal, m_Constant(&shiftElem))) { - (void)rewriter.notifyMatchFailure(op, "shift value of mul not found"); - return nullptr; + llvm_unreachable("shift value of mul not found"); } int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt(); if (isa<FloatType>(elementTy)) { if (shift != 0) { - (void)rewriter.notifyMatchFailure(op, - "Cannot have shift value for float"); - return nullptr; + llvm_unreachable("Cannot have shift value for float"); } return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]); } @@ -196,16 +237,12 @@ static Value createLinalgBodyCalculationForElementwiseOp( FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint(); if (failed(maybeInZp)) { - (void)rewriter.notifyMatchFailure( - op, "input1 zero point cannot be statically determined"); - return nullptr; + llvm_unreachable("input1 zero point cannot be statically determined"); } FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint(); if (failed(maybeOutZp)) { - (void)rewriter.notifyMatchFailure( - op, "output zero point cannot be statically determined"); - return nullptr; + llvm_unreachable("output zero point cannot be statically determined"); } int64_t inZp = *maybeInZp; @@ -548,10 +585,8 @@ static Value createLinalgBodyCalculationForElementwiseOp( if (isa<tosa::CastOp>(op)) { Type srcTy = elementTy; Type dstTy = resultTypes.front(); - if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) { - (void)rewriter.notifyMatchFailure(op, "unsupported type"); - return nullptr; - } + if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) + llvm_unreachable("unsupported type"); bool bitExtend = srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth(); @@ -706,8 +741,8 @@ static Value createLinalgBodyCalculationForElementwiseOp( } } - (void)rewriter.notifyMatchFailure( - op, "unhandled op for linalg body calculation for elementwise op"); + llvm_unreachable( + "unhandled op for linalg body calculation for elementwise op"); return nullptr; } @@ -930,17 +965,11 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, }); } -static LogicalResult -emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, - Operation *operation, ValueRange operands, - ArrayRef<OpFoldResult> targetShape, - const TypeConverter &converter) { +static LogicalResult emitElementwiseComputation( + ConversionPatternRewriter &rewriter, Location loc, Operation *operation, + ValueRange operands, ArrayRef<OpFoldResult> targetShape, + const TypeConverter &converter, RankedTensorType resultType) { // Generate output tensor - auto resultType = cast_or_null<RankedTensorType>( - converter.convertType(operation->getResultTypes().front())); - if (!resultType) { - return rewriter.notifyMatchFailure(operation, "failed to convert type"); - } Value outputTensor = rewriter.create<tensor::EmptyOp>( loc, targetShape, resultType.getElementType()); @@ -967,7 +996,6 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); // Emit 'linalg.generic' op - bool encounteredError = false; auto linalgOp = rewriter.create<linalg::GenericOp>( loc, outputTensor.getType(), operands, outputTensor, affineMaps, getNParallelLoopsAttrs(rank), @@ -975,15 +1003,10 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Value opResult = createLinalgBodyCalculationForElementwiseOp( operation, blockArgs.take_front(operation->getNumOperands()), {resultType.getElementType()}, rewriter); - if (!opResult) { - encounteredError = true; - return; - } + assert(opResult && + "unable to create linalg.generic body for elementwise op"); opBuilder.create<linalg::YieldOp>(loc, opResult); }); - if (encounteredError) - return rewriter.notifyMatchFailure( - operation, "unable to create linalg.generic body for elementwise op"); // Cast 'linalg.generic' result into original result type if needed auto castResult = rewriter.createOrFold<tensor::CastOp>( @@ -1008,13 +1031,20 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter) { - // Collect op properties + // Check if operation is supported. assert(operation->getNumResults() == 1 && "elementwise op expects 1 result"); assert(operation->getNumOperands() >= 1 && "elementwise op expects at least 1 operand"); if (!operandsAndResultsRanked(operation)) return rewriter.notifyMatchFailure(operation, "Unranked tensors not supported"); + auto resultType = cast_or_null<RankedTensorType>( + converter.convertType(operation->getResultTypes().front())); + if (!resultType) { + return rewriter.notifyMatchFailure(operation, "failed to convert type"); + } + if (failed(isSupportedElementwiseOperation(rewriter, operation, resultType))) + return failure(); // Lower operation IndexPool indexPool; @@ -1026,7 +1056,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast, targetShape, masterOperands); return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands, - targetShape, converter); + targetShape, converter, resultType); } // Returns the constant initial value for a given reduction operation. The @@ -1126,7 +1156,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op, if (isa<tosa::ReduceAnyOp>(op) && elementTy.isInteger(1)) return rewriter.create<arith::OrIOp>(loc, args); - return {}; + llvm_unreachable("unhandled reduction op"); } // Performs the match and rewrite for reduction operations. This includes @@ -1142,6 +1172,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, return rewriter.notifyMatchFailure(op, "unranked tensors not supported"); auto elementTy = resultTy.getElementType(); + auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); + if (!fillValueAttr) + return rewriter.notifyMatchFailure( + op, "No initial value found for reduction operation"); Value input = op->getOperand(0); SmallVector<int64_t> reduceShape; @@ -1164,11 +1198,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, dynDims) .getResult(); - auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); - if (!fillValueAttr) - return rewriter.notifyMatchFailure( - op, "No initial value found for reduction operation"); - auto fillValue = rewriter.create<arith::ConstantOp>(loc, fillValueAttr); auto filledTensor = rewriter .create<linalg::FillOp>(loc, ValueRange{fillValue}, @@ -1212,7 +1241,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, } } - bool didEncounterError = false; linalg::LinalgOp linalgOp = rewriter.create<linalg::ReduceOp>( loc, inputs, outputs, axis, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { @@ -1220,8 +1248,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]}; auto result = createLinalgBodyCalculationForReduceOp( op, binaryArgs, elementTy, rewriter); - if (result) - didEncounterError = true; + assert(result && "could not create reduction body"); SmallVector<Value> resultsToYield; if (isNanIgnoreMode) { @@ -1247,10 +1274,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, nestedBuilder.create<linalg::YieldOp>(loc, resultsToYield); }); - if (!didEncounterError) - return rewriter.notifyMatchFailure( - op, "unable to create linalg.generic body for reduce op"); - if (isNanIgnoreMode) { // Materialize a check to see whether we encountered any non-NaN values, if // we didn't we need to select a tensor of NaNs since the result will just @@ -1358,13 +1381,6 @@ public: if (!isa<IntegerType>(inputTy.getElementType())) return rewriter.notifyMatchFailure(op, "only support integer type"); - SmallVector<Value> dynDims; - for (int i = 0; i < outputTy.getRank(); i++) { - if (outputTy.isDynamicDim(i)) { - dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); - } - } - // The shift and multiplier values. DenseElementsAttr shiftElems; if (!matchPattern(op.getShift(), m_Constant(&shiftElems))) @@ -1376,6 +1392,21 @@ public: return rewriter.notifyMatchFailure( op, "tosa.rescale requires constant multiplier input values"); + if (failed(op.getInputZeroPoint())) + return rewriter.notifyMatchFailure( + op, "input zero point cannot be statically determined"); + + if (failed(op.getOutputZeroPoint())) + return rewriter.notifyMatchFailure( + op, "output zero point cannot be statically determined"); + + SmallVector<Value> dynDims; + for (int i = 0; i < outputTy.getRank(); i++) { + if (outputTy.isDynamicDim(i)) { + dynDims.push_back(rewriter.create<tensor::DimOp>(loc, input, i)); + } + } + llvm::SmallVector<int8_t> shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>()); // explicit cast is required here @@ -1473,23 +1504,10 @@ public: int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32; FailureOr<int64_t> maybeIZp = op.getInputZeroPoint(); - if (failed(maybeIZp)) { - (void)rewriter.notifyMatchFailure( - op, "input zero point cannot be statically determined"); - return; - } - auto inputZp = createConstOpFromZpVal<int32_t>( op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth), nestedBuilder); - FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint(); - if (failed(maybeOZp)) { - (void)rewriter.notifyMatchFailure( - op, "output zero point cannot be statically determined"); - return; - }; - auto outputZp = createConstOpFromZpVal<int32_t>( op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder); @@ -1783,6 +1801,15 @@ public: return rewriter.notifyMatchFailure( op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR"); + SmallVector<int64_t> scale, offset, border; + if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) || + !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) || + !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) { + return rewriter.notifyMatchFailure( + op, "tosa.resize scale/offset/border should have compile time " + "constant values."); + } + SmallVector<AffineMap, 2> affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; auto emptyTensor = b.create<tensor::EmptyOp>(resultTy.getShape(), resultETy, @@ -1810,15 +1837,6 @@ public: Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y); Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x); - SmallVector<int64_t> scale, offset, border; - if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) || - !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) || - !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) { - return rewriter.notifyMatchFailure( - op, "tosa.resize scale/offset/border should have compile time " - "constant values."); - } - Value yScaleN, yScaleD, xScaleN, xScaleD; yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0])); yScaleD = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[1])); @@ -2204,6 +2222,9 @@ public: auto inputTy = cast<ShapedType>(input.getType()); auto resultTy = cast<ShapedType>(argmaxOp.getOutput().getType()); auto inElementTy = inputTy.getElementType(); + if (!isa<IntegerType, FloatType>(inElementTy)) + return rewriter.notifyMatchFailure( + argmaxOp, "unsupported tosa.argmax element type"); auto outElementTy = resultTy.getElementType(); int axis = argmaxOp.getAxis(); auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy); @@ -2213,6 +2234,12 @@ public: argmaxOp, "tosa.arg_max to linalg.* requires integer-like result type"); + auto fillValueMaxAttr = + createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); + if (!fillValueMaxAttr) + return rewriter.notifyMatchFailure( + argmaxOp, "unsupported tosa.argmax element type"); + SmallVector<Value> dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i) && i != axis) { @@ -2238,12 +2265,6 @@ public: .create<tensor::EmptyOp>(loc, resultTy.getShape(), inElementTy, dynDims) .getResult(); - auto fillValueMaxAttr = - createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); - - if (!fillValueMaxAttr) - return rewriter.notifyMatchFailure( - argmaxOp, "unsupported tosa.argmax element type"); auto fillValueMax = rewriter.create<arith::ConstantOp>(loc, fillValueMaxAttr); @@ -2267,7 +2288,6 @@ public: dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); } - bool didEncounterError = false; auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs}, rewriter.getContext()); auto linalgOp = rewriter.create<linalg::GenericOp>( @@ -2305,8 +2325,7 @@ public: predicate = rewriter.create<arith::CmpIOp>( nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); } else { - didEncounterError = true; - return; + llvm_unreachable("unsupported tosa.argmax element type"); } auto resultMax = rewriter.create<arith::SelectOp>( @@ -2317,10 +2336,6 @@ public: nestedLoc, ValueRange({resultIndex, resultMax})); }); - if (didEncounterError) - return rewriter.notifyMatchFailure( - argmaxOp, "unsupported tosa.argmax element type"); - rewriter.replaceOp(argmaxOp, linalgOp.getResult(0)); return success(); } @@ -2416,6 +2431,15 @@ public: auto tableElementTy = tableTy.getElementType(); auto resultElementTy = resultTy.getElementType(); + bool isI8_8_8 = inputElementTy.isInteger(8) && + tableElementTy.isInteger(8) && resultElementTy.isInteger(8); + bool isI16_16_32 = inputElementTy.isInteger(16) && + tableElementTy.isInteger(16) && + resultElementTy.isInteger(32); + if (!isI8_8_8 && !isI16_16_32) + return rewriter.notifyMatchFailure( + op, "unable to create body for tosa.table op"); + SmallVector<Value> dynDims; for (int i = 0; i < resultTy.getRank(); ++i) { if (inputTy.isDynamicDim(i)) { @@ -2446,8 +2470,7 @@ public: auto inputValue = block->getArgument(0); rewriter.setInsertionPointToStart(block); - if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && - resultElementTy.isInteger(8)) { + if (isI8_8_8) { Value index = rewriter.create<arith::IndexCastOp>( loc, rewriter.getIndexType(), inputValue); Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 128); @@ -2459,8 +2482,7 @@ public: return success(); } - if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && - resultElementTy.isInteger(32)) { + if (isI16_16_32) { Value extend = rewriter.create<arith::ExtSIOp>( loc, rewriter.getI32Type(), inputValue); @@ -2516,8 +2538,7 @@ public: } } - return rewriter.notifyMatchFailure( - op, "unable to create body for tosa.table op"); + llvm_unreachable("unable to create body for tosa.table op"); } }; |