diff options
author | Pablo Antonio Martinez <pablo.antonio.martinez@huawei.com> | 2024-03-22 11:53:29 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-03-22 12:53:29 +0100 |
commit | c41286af3f30e099556c6edbef0001466afaefcb (patch) | |
tree | e0c2058ba52a5f367b5850c5993b815f3ce474bc /mlir | |
parent | ceabaa7e7a2d02b20cbd2b31e8336dedb1d4d9f5 (diff) | |
download | llvm-c41286af3f30e099556c6edbef0001466afaefcb.zip llvm-c41286af3f30e099556c6edbef0001466afaefcb.tar.gz llvm-c41286af3f30e099556c6edbef0001466afaefcb.tar.bz2 |
[mlir][linalg] Emit a warning when tile_using_forall generates non thread-safe code (#80813)
**Description**
The documentation of `transform.structured.tile_using_forall` says:
_"It is the user’s responsibility to ensure that num_threads/tile_sizes
is a valid tiling specification (i.e. that only tiles parallel
dimensions, e.g. in the Linalg case)."_
In other words, tiling a non-parallel dimension would generate code with
data races which is not safe to parallelize. For example, consider this
example (included in the tests in this PR):
```
func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
%0 = scf.forall (%arg2) in (8) shared_outs(%arg3 = %arg1) -> (tensor<300x8xf32>) {
%1 = affine.min #map(%arg2)
%2 = affine.max #map1(%1)
%3 = affine.apply #map2(%arg2)
%extracted_slice = tensor.extract_slice %arg0[%3, 0, 0] [%2, 300, 8] [1, 1, 1] : tensor<100x300x8xf32> to tensor<?x300x8xf32>
%4 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["reduction", "parallel", "parallel"]} ins(%extracted_slice : tensor<?x300x8xf32>) outs(%arg3 : tensor<300x8xf32>) {
^bb0(%in: f32, %out: f32):
%5 = arith.addf %in, %out : f32
linalg.yield %5 : f32
} -> tensor<300x8xf32>
scf.forall.in_parallel {
tensor.parallel_insert_slice %4 into %arg3[0, 0] [300, 8] [1, 1] : tensor<300x8xf32> into tensor<300x8xf32>
}
}
return %0 : tensor<300x8xf32>
}
```
We can easily see that this is not safe to parallelize because all
threads would be writing to the same position in `%arg3` (in the
`scf.forall.in_parallel`.
This PR detects wether it's safe to `tile_using_forall` and emits a
warning in the case it is not.
**Brief explanation**
It first generates a vector of affine expressions representing the tile
values and stores it in `dimExprs`. These affine expressions are
compared with the affine expressions coming from the results of the
affine map of each output in the linalg op. So going back to the
previous example, the original transform is:
```
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> {
// expected-warning@+1 {{tiling is not thread safe at axis #0}}
%0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) {
^bb0(%in: f32, %out: f32):
%1 = arith.addf %in, %out : f32
linalg.yield %1 : f32
} -> tensor<300x8xf32>
return %0 : tensor<300x8xf32>
}
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
%0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8]
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
```
The `num_threads` attribute would be represented as `(d0)`. Because the
linalg op has only one output (`arg1`) it would only check against the
results of `#map1`, which are `(d1, d2)`. The idea is to check that all
affine expressions in `dimExprs` are present in the output affine map.
In this example, `d0` is not in `(d1, d2)`, so tiling that axis is
considered not thread safe.
Diffstat (limited to 'mlir')
-rw-r--r-- | mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td | 4 | ||||
-rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 38 | ||||
-rw-r--r-- | mlir/test/Dialect/Linalg/tile-to-forall.mlir | 141 |
3 files changed, 180 insertions, 3 deletions
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 4f34016..c260fe3 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1918,7 +1918,9 @@ def TileUsingForallOp : It is the user's responsibility to ensure that `num_threads/tile_sizes` is a valid tiling specification (i.e. that only tiles parallel dimensions, - e.g. in the Linalg case). + e.g. in the Linalg case). If the dimension is not parallelizable, a warning + is issued to notify the user that the generated code is not safe to + parallelize. If non-empty, the `mapping` is added as an attribute to the resulting `scf.forall`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 30aed85..462f6926 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -304,6 +304,28 @@ static void calculateTileOffsetsAndSizes( } } +/// Returns a vector of bools representing if, for each axis, `op` can be tiled +/// without incurring in a race condition and thus it is thread-safe to do the +/// tiling. This is checked by iterating over numThreads and ensuring that the +/// corresponding iterator type is "parallel". If it is not, then we know that +/// such dimension is unsafe to tile. +SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp, + ArrayRef<OpFoldResult> numThreads) { + auto iterators = linalgOp.getIteratorTypesArray(); + SmallVector<bool> safeToTile(numThreads.size(), true); + + for (unsigned i = 0, e = numThreads.size(); i != e; i++) { + if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) { + if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) { + safeToTile[i] = iterators[i] == utils::IteratorType::parallel; + } + } else { + safeToTile[i] = iterators[i] == utils::IteratorType::parallel; + } + } + return safeToTile; +} + /// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The /// tiling is specified by the number of tiles/threads `numThreads` and the /// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is @@ -314,8 +336,10 @@ static void calculateTileOffsetsAndSizes( /// size of data. /// It is the user's responsibility to ensure that `numThreads` is a valid /// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the -/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will -/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds. +/// Linalg case). If the dimension is not parallelizable, a warning is issued to +/// notify the user that the generated code is not safe to parallelize. If +/// `omitTileOffsetBoundsCheck` is true, then the function will assume that +/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds. static FailureOr<ForallTilingResult> tileToForallOpImpl( RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads, std::optional<ArrayRef<OpFoldResult>> nominalTileSizes, @@ -344,6 +368,16 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl( return getValueOrCreateConstantIndexOp(b, loc, ofr); })); + LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation()); + if (linalgOp) { + // Check if tiling is thread safe and print a warning if not. + SmallVector<bool> tilingSafety = + safeToTileToForall(b.getContext(), linalgOp, numThreads); + for (size_t i = 0; i < tilingSafety.size(); i++) + if (!tilingSafety[i]) + op.emitWarning() << "tiling is not thread safe at axis #" << i; + } + // 1. Create the ForallOp. We don't use the lambda body-builder // version because we require the use of RewriterBase in the body, so we // manually move the insertion point to the body below. diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir index abd807b..12e2dea 100644 --- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir @@ -586,3 +586,144 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @tile_thread_safety1(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { + // expected-warning@below {{tiling is not thread safe at axis #1}} + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<100xf32> + return %0 : tensor<100xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [4, 2] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> + +func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> { + // expected-warning@below {{tiling is not thread safe at axis #0}} + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<300x8xf32> + return %0 : tensor<300x8xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> + +func.func @tile_thread_safety3(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>) -> tensor<100x8xf32> { + // expected-warning@below {{tiling is not thread safe at axis #1}} + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<100x8xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<100x8xf32> + return %0 : tensor<100x8xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d2)> + +func.func @tile_thread_safety4(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>, %arg2 : tensor<8xf32>) -> (tensor<100x8xf32>, tensor<8xf32>) { + // expected-warning@+2 {{tiling is not thread safe at axis #0}} + // expected-warning@below {{tiling is not thread safe at axis #1}} + %0:2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1, %arg2 : tensor<100x8xf32>, tensor<8xf32>) { + ^bb0(%in: f32, %out1: f32, %out2: f32): + %1 = arith.addf %in, %out1 : f32 + %2 = arith.addf %in, %out2 : f32 + linalg.yield %1, %2 : f32, f32 + } -> (tensor<100x8xf32>, tensor<8xf32>) + return %0#0, %0#1 : tensor<100x8xf32>, tensor<8xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @tile_thread_safety5(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { + // expected-warning@below {{tiling is not thread safe at axis #1}} + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<100xf32> + return %0 : tensor<100xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 tile_sizes [10, 1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func @tile_thread_safety6(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> { + // expected-warning@below {{tiling is not thread safe at axis #2}} + %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>) + return %0 : tensor<?x?xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 0, 8] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} |