aboutsummaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp')
-rw-r--r--mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp158
1 files changed, 74 insertions, 84 deletions
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 3a20524..da1fb20 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -64,19 +64,20 @@ linalgIntBroadcastExtSIAdd(PatternRewriter &rewriter, Location loc, Value bias,
Value conv, Value result,
ArrayRef<AffineMap> indexingMaps) {
ShapedType resultTy = cast<ShapedType>(conv.getType());
- return rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({bias, conv}), result, indexingMaps,
- getNParallelLoopsAttrs(resultTy.getRank()),
- [](OpBuilder &builder, Location loc, ValueRange args) {
- Value biasVal = args[0];
- Type resType = args[1].getType();
- if (resType != biasVal.getType()) {
- biasVal = arith::ExtSIOp::create(builder, loc, resType, biasVal);
- }
- Value added = arith::AddIOp::create(builder, loc, biasVal, args[1]);
- linalg::YieldOp::create(builder, loc, added);
- })
+ return linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({bias, conv}), result,
+ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+ [](OpBuilder &builder, Location loc, ValueRange args) {
+ Value biasVal = args[0];
+ Type resType = args[1].getType();
+ if (resType != biasVal.getType()) {
+ biasVal =
+ arith::ExtSIOp::create(builder, loc, resType, biasVal);
+ }
+ Value added =
+ arith::AddIOp::create(builder, loc, biasVal, args[1]);
+ linalg::YieldOp::create(builder, loc, added);
+ })
.getResult(0);
}
@@ -124,23 +125,23 @@ static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
// Build the broadcast-like operation as a linalg.generic.
- return rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({source}), result, indexingMaps,
- getNParallelLoopsAttrs(resultTy.getRank()),
- [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
- Value biasVal = args[0];
- Type resType = args[1].getType();
- if (resType != biasVal.getType()) {
- biasVal =
- resultTy.getElementType().isFloat()
- ? arith::ExtFOp::create(builder, loc, resType, biasVal)
- .getResult()
- : arith::ExtSIOp::create(builder, loc, resType, biasVal)
- .getResult();
- }
- linalg::YieldOp::create(builder, loc, biasVal);
- })
+ return linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({source}), result,
+ indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
+ [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
+ Value biasVal = args[0];
+ Type resType = args[1].getType();
+ if (resType != biasVal.getType()) {
+ biasVal =
+ resultTy.getElementType().isFloat()
+ ? arith::ExtFOp::create(builder, loc, resType, biasVal)
+ .getResult()
+ : arith::ExtSIOp::create(builder, loc, resType,
+ biasVal)
+ .getResult();
+ }
+ linalg::YieldOp::create(builder, loc, biasVal);
+ })
.getResult(0);
}
@@ -397,21 +398,19 @@ public:
auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
auto kZpVal = arith::ConstantOp::create(rewriter, loc, kZp);
- Value conv =
- rewriter
- .create<LinalgConvQOp>(
- loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
- ValueRange{broadcastBias}, strideAttr, dilationAttr)
- ->getResult(0);
+ Value conv = LinalgConvQOp::create(
+ rewriter, loc, resultTy,
+ ValueRange{input, weight, iZpVal, kZpVal},
+ ValueRange{broadcastBias}, strideAttr, dilationAttr)
+ ->getResult(0);
rewriter.replaceOp(op, conv);
return success();
}
- Value conv = rewriter
- .create<LinalgConvOp>(
- loc, accTy, ValueRange{input, weight},
- ValueRange{broadcastBias}, strideAttr, dilationAttr)
+ Value conv = LinalgConvOp::create(
+ rewriter, loc, accTy, ValueRange{input, weight},
+ ValueRange{broadcastBias}, strideAttr, dilationAttr)
->getResult(0);
// We may need to truncate back to the result type if the accumulator was
@@ -529,9 +528,8 @@ public:
Value emptyTensor = tensor::EmptyOp::create(
rewriter, loc, linalgConvTy.getShape(), accETy, filteredDims);
Value zero = arith::ConstantOp::create(rewriter, loc, resultZeroAttr);
- Value zeroTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zero},
- ValueRange{emptyTensor})
+ Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+ ValueRange{emptyTensor})
.result();
Value biasEmptyTensor = tensor::EmptyOp::create(
@@ -544,10 +542,9 @@ public:
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
if (hasNullZps) {
- Value conv = rewriter
- .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
- loc, linalgConvTy, ValueRange{input, weight},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
+ Value conv = linalg::DepthwiseConv2DNhwcHwcmOp::create(
+ rewriter, loc, linalgConvTy, ValueRange{input, weight},
+ ValueRange{zeroTensor}, strideAttr, dilationAttr)
.getResult(0);
// We may need to truncate back to the result type if the accumulator was
@@ -565,22 +562,20 @@ public:
rewriter, loc, resultTy, conv, reassociationMap);
Value result =
- rewriter
- .create<linalg::GenericOp>(
- loc, resultTy, ValueRange({bias, convReshape}),
- biasEmptyTensor, indexingMaps,
- getNParallelLoopsAttrs(resultRank),
- [&](OpBuilder &nestedBuilder, Location nestedLoc,
- ValueRange args) {
- Value added;
- if (llvm::isa<FloatType>(inputETy))
- added = arith::AddFOp::create(nestedBuilder, loc, args[0],
- args[1]);
- else
- added = arith::AddIOp::create(nestedBuilder, loc, args[0],
- args[1]);
- linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
- })
+ linalg::GenericOp::create(
+ rewriter, loc, resultTy, ValueRange({bias, convReshape}),
+ biasEmptyTensor, indexingMaps, getNParallelLoopsAttrs(resultRank),
+ [&](OpBuilder &nestedBuilder, Location nestedLoc,
+ ValueRange args) {
+ Value added;
+ if (llvm::isa<FloatType>(inputETy))
+ added = arith::AddFOp::create(nestedBuilder, loc, args[0],
+ args[1]);
+ else
+ added = arith::AddIOp::create(nestedBuilder, loc, args[0],
+ args[1]);
+ linalg::YieldOp::create(nestedBuilder, nestedLoc, added);
+ })
.getResult(0);
rewriter.replaceOp(op, result);
} else {
@@ -588,12 +583,11 @@ public:
IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
auto iZpVal = arith::ConstantOp::create(rewriter, loc, iZp);
auto kZpVal = arith::ConstantOp::create(rewriter, loc, wZp);
- Value conv =
- rewriter
- .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
- loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
- .getResult(0);
+ Value conv = linalg::DepthwiseConv2DNhwcHwcmQOp::create(
+ rewriter, loc, linalgConvTy,
+ ValueRange{input, weight, iZpVal, kZpVal},
+ ValueRange{zeroTensor}, strideAttr, dilationAttr)
+ .getResult(0);
SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
Value convReshape = tensor::CollapseShapeOp::create(
@@ -639,9 +633,8 @@ public:
auto emptyTensor =
tensor::EmptyOp::create(rewriter, loc, outputTy.getShape(),
outputTy.getElementType(), filteredDims);
- Value zeroTensor = rewriter
- .create<linalg::FillOp>(loc, ValueRange{zero},
- ValueRange{emptyTensor})
+ Value zeroTensor = linalg::FillOp::create(rewriter, loc, ValueRange{zero},
+ ValueRange{emptyTensor})
.result();
FailureOr<int64_t> maybeAZp = op.getAZeroPoint();
@@ -910,20 +903,18 @@ public:
rewriter, loc, accTy.getShape(), accETy, dynamicDims);
Value filledEmptyTensor =
- rewriter
- .create<linalg::FillOp>(loc, ValueRange{initialValue},
- ValueRange{poolEmptyTensor})
+ linalg::FillOp::create(rewriter, loc, ValueRange{initialValue},
+ ValueRange{poolEmptyTensor})
.result();
Value fakeWindowDims =
tensor::EmptyOp::create(rewriter, loc, kernel, accETy);
// Sum across the pooled region.
- Value poolingOp = rewriter
- .create<linalg::PoolingNhwcSumOp>(
- loc, ArrayRef<Type>{accTy},
- ValueRange{paddedInput, fakeWindowDims},
- filledEmptyTensor, strideAttr, dilationAttr)
+ Value poolingOp = linalg::PoolingNhwcSumOp::create(
+ rewriter, loc, ArrayRef<Type>{accTy},
+ ValueRange{paddedInput, fakeWindowDims},
+ filledEmptyTensor, strideAttr, dilationAttr)
.getResult(0);
// Normalize the summed value by the number of elements grouped in each
@@ -1050,10 +1041,9 @@ public:
Value shift = arith::AddIOp::create(rewriter, loc, k8, thirty8);
auto scaled =
- rewriter
- .create<tosa::ApplyScaleOp>(
- loc, rewriter.getI32Type(), poolVal, multiplier, shift,
- rewriter.getStringAttr("SINGLE_ROUND"))
+ tosa::ApplyScaleOp::create(
+ rewriter, loc, rewriter.getI32Type(), poolVal, multiplier,
+ shift, rewriter.getStringAttr("SINGLE_ROUND"))
.getResult();
// If we have quantization information we need to apply output