diff options
author | Uday Bondhugula <uday@polymagelabs.com> | 2020-06-11 18:10:53 +0530 |
---|---|---|
committer | Uday Bondhugula <uday@polymagelabs.com> | 2020-07-04 14:55:02 +0530 |
commit | 6d6d5db251e88b0c40f7a6951b51b9e4a1812c8c (patch) | |
tree | a6747e24be248d4a54ee8951b58834d8f5afff06 | |
parent | 7356b4243ad9dc373c63bb4b7ac0ba4f877c47ee (diff) | |
download | llvm-6d6d5db251e88b0c40f7a6951b51b9e4a1812c8c.zip llvm-6d6d5db251e88b0c40f7a6951b51b9e4a1812c8c.tar.gz llvm-6d6d5db251e88b0c40f7a6951b51b9e4a1812c8c.tar.bz2 |
[MLIR][Linalg] Generate the right type of load/store when lowering max/min pooling ops
While lowering min/max pooling ops to loops, generate the right kind of
load/stores (std or affine) instead of always generating std
load/stores.
Differential Revision: https://reviews.llvm.org/D83080
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Loops.cpp | 21 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/affine.mlir | 24 |
2 files changed, 37 insertions, 8 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp index ec57717..575115c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp @@ -333,23 +333,28 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs, ConvOp convOp) { template <typename IndexedValueType> void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMaxOp op) { - auto indices = getInputAndOutputIndices(allIvs, op); + InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op); // Emit scalar form. - Value lhs = std_load(op.output(), indices.outputs); - Value rhs = std_load(op.input(), indices.inputs); + IndexedValueType output(op.output()); + IndexedValueType input(op.input()); + Value lhs = output(indices.outputs); + Value rhs = input(indices.inputs); using edsc::op::sgt; Value maxValue = std_select(sgt(lhs, rhs), lhs, rhs); - std_store(maxValue, op.output(), indices.outputs); + output(indices.outputs) = maxValue; } + template <typename IndexedValueType> void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingMinOp op) { - auto indices = getInputAndOutputIndices(allIvs, op); + InputAndOutputIndices indices = getInputAndOutputIndices(allIvs, op); // Emit scalar form. - Value lhs = std_load(op.output(), indices.outputs); - Value rhs = std_load(op.input(), indices.inputs); + IndexedValueType output(op.output()); + IndexedValueType input(op.input()); + Value lhs = output(indices.outputs); + Value rhs = input(indices.inputs); using edsc::op::slt; Value minValue = std_select(slt(lhs, rhs), lhs, rhs); - std_store(minValue, op.output(), indices.outputs); + output(indices.outputs) = minValue; } template <typename IndexedValueType> void emitScalarImplementation(ArrayRef<Value> allIvs, PoolingSumOp op) { diff --git a/mlir/test/Dialect/Linalg/affine.mlir b/mlir/test/Dialect/Linalg/affine.mlir index cb20646..13f37d8 100644 --- a/mlir/test/Dialect/Linalg/affine.mlir +++ b/mlir/test/Dialect/Linalg/affine.mlir @@ -123,3 +123,27 @@ func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memre // CHECK: %[[inc:.*]] = mulf %[[va]], %[[vb]] : f32 // CHECK: %[[res:.*]] = addf %[[vc]], %[[inc]] : f32 // CHECK: affine.store %[[res]], %[[mC]][%[[b]], %[[m]], %[[n]]] : memref<?x?x?xf32> + +// CHECK-LABEL: func @pooling_max_min +func @pooling_max_min(%arg0: memref<?x?xf32>, + %arg1: memref<?x?xi32>, + %arg2: memref<?x?xf32>) { + linalg.pooling_max(%arg0, %arg1, %arg2) { strides = [2, 1] }: + memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32> + linalg.pooling_min(%arg0, %arg1, %arg2) { strides = [2, 1] }: + memref<?x?xf32>, memref<?x?xi32>, memref<?x?xf32> + return +} +// This is a basic check to make sure the right load/stores are used. loops.mlir +// checks for the rest. +// CHECK: affine.load +// CHECK-NEXT: affine.load +// CHECK-NEXT: cmpf +// CHECK-NEXT: select +// CHECK-NEXT: affine.store +// The min pooling body. +// CHECK: affine.load +// CHECK-NEXT: affine.load +// CHECK-NEXT: cmpf +// CHECK-NEXT: select +// CHECK-NEXT: affine.store |