From c41286af3f30e099556c6edbef0001466afaefcb Mon Sep 17 00:00:00 2001 From: Pablo Antonio Martinez Date: Fri, 22 Mar 2024 11:53:29 +0000 Subject: [mlir][linalg] Emit a warning when tile_using_forall generates non thread-safe code (#80813) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit **Description** The documentation of `transform.structured.tile_using_forall` says: _"It is the user’s responsibility to ensure that num_threads/tile_sizes is a valid tiling specification (i.e. that only tiles parallel dimensions, e.g. in the Linalg case)."_ In other words, tiling a non-parallel dimension would generate code with data races which is not safe to parallelize. For example, consider this example (included in the tests in this PR): ``` func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> { %0 = scf.forall (%arg2) in (8) shared_outs(%arg3 = %arg1) -> (tensor<300x8xf32>) { %1 = affine.min #map(%arg2) %2 = affine.max #map1(%1) %3 = affine.apply #map2(%arg2) %extracted_slice = tensor.extract_slice %arg0[%3, 0, 0] [%2, 300, 8] [1, 1, 1] : tensor<100x300x8xf32> to tensor<?x300x8xf32> %4 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["reduction", "parallel", "parallel"]} ins(%extracted_slice : tensor<?x300x8xf32>) outs(%arg3 : tensor<300x8xf32>) { ^bb0(%in: f32, %out: f32): %5 = arith.addf %in, %out : f32 linalg.yield %5 : f32 } -> tensor<300x8xf32> scf.forall.in_parallel { tensor.parallel_insert_slice %4 into %arg3[0, 0] [300, 8] [1, 1] : tensor<300x8xf32> into tensor<300x8xf32> } } return %0 : tensor<300x8xf32> } ``` We can easily see that this is not safe to parallelize because all threads would be writing to the same position in `%arg3` (in the `scf.forall.in_parallel`). This PR detects whether it's safe to `tile_using_forall` and emits a warning in the case it is not. 
**Brief explanation** It first generates a vector of affine expressions representing the tile values and stores it in `dimExprs`. These affine expressions are compared with the affine expressions coming from the results of the affine map of each output in the linalg op. So going back to the previous example, the original transform is: ``` #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2) -> (d1, d2)> func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> { // expected-warning@+1 {{tiling is not thread safe at axis #0}} %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) { ^bb0(%in: f32, %out: f32): %1 = arith.addf %in, %out : f32 linalg.yield %1 : f32 } -> tensor<300x8xf32> return %0 : tensor<300x8xf32> } module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) transform.yield } } ``` The `num_threads` attribute would be represented as `(d0)`. Because the linalg op has only one output (`arg1`) it would only check against the results of `#map1`, which are `(d1, d2)`. The idea is to check that all affine expressions in `dimExprs` are present in the output affine map. In this example, `d0` is not in `(d1, d2)`, so tiling that axis is considered not thread safe. 
--- .../Linalg/TransformOps/LinalgTransformOps.td | 4 +- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 38 +++++- mlir/test/Dialect/Linalg/tile-to-forall.mlir | 141 +++++++++++++++++++++ 3 files changed, 180 insertions(+), 3 deletions(-) (limited to 'mlir') diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 4f34016..c260fe3 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1918,7 +1918,9 @@ def TileUsingForallOp : It is the user's responsibility to ensure that `num_threads/tile_sizes` is a valid tiling specification (i.e. that only tiles parallel dimensions, - e.g. in the Linalg case). + e.g. in the Linalg case). If the dimension is not parallelizable, a warning + is issued to notify the user that the generated code is not safe to + parallelize. If non-empty, the `mapping` is added as an attribute to the resulting `scf.forall`. diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 30aed85..462f6926 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -304,6 +304,28 @@ static void calculateTileOffsetsAndSizes( } } +/// Returns a vector of bools representing if, for each axis, `op` can be tiled +/// without incurring a race condition and thus it is thread-safe to do the +/// tiling. This is checked by iterating over numThreads and ensuring that the +/// corresponding iterator type is "parallel". If it is not, then we know that +/// such dimension is unsafe to tile. 
+SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp, + ArrayRef<OpFoldResult> numThreads) { + auto iterators = linalgOp.getIteratorTypesArray(); + SmallVector<bool> safeToTile(numThreads.size(), true); + + for (unsigned i = 0, e = numThreads.size(); i != e; i++) { + if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) { + if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) { + safeToTile[i] = iterators[i] == utils::IteratorType::parallel; + } + } else { + safeToTile[i] = iterators[i] == utils::IteratorType::parallel; + } + } + return safeToTile; +} + /// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The /// tiling is specified by the number of tiles/threads `numThreads` and the /// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is @@ -314,8 +336,10 @@ static void calculateTileOffsetsAndSizes( /// size of data. /// It is the user's responsibility to ensure that `numThreads` is a valid /// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the -/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will -/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds. +/// Linalg case). If the dimension is not parallelizable, a warning is issued to +/// notify the user that the generated code is not safe to parallelize. If +/// `omitTileOffsetBoundsCheck` is true, then the function will assume that +/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds. static FailureOr<ForallTilingResult> tileToForallOpImpl( RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads, std::optional<ArrayRef<OpFoldResult>> nominalTileSizes, @@ -344,6 +368,16 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl( return getValueOrCreateConstantIndexOp(b, loc, ofr); })); + LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation()); + if (linalgOp) { + // Check if tiling is thread safe and print a warning if not. 
+ SmallVector<bool> tilingSafety = + safeToTileToForall(b.getContext(), linalgOp, numThreads); + for (size_t i = 0; i < tilingSafety.size(); i++) + if (!tilingSafety[i]) + op.emitWarning() << "tiling is not thread safe at axis #" << i; + } + // 1. Create the ForallOp. We don't use the lambda body-builder // version because we require the use of RewriterBase in the body, so we // manually move the insertion point to the body below. diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir index abd807b..12e2dea 100644 --- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir +++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir @@ -586,3 +586,144 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @tile_thread_safety1(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { + // expected-warning@below {{tiling is not thread safe at axis #1}} + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<100xf32> + return %0 : tensor<100xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [4, 2] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> + +func.func @tile_thread_safety2(%arg0: tensor<100x300x8xf32>, %arg1: tensor<300x8xf32>) -> tensor<300x8xf32> { + 
// expected-warning@below {{tiling is not thread safe at axis #0}} + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["reduction", "parallel", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<300x8xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<300x8xf32> + return %0 : tensor<300x8xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> + +func.func @tile_thread_safety3(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>) -> tensor<100x8xf32> { + // expected-warning@below {{tiling is not thread safe at axis #1}} + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1 : tensor<100x8xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<100x8xf32> + return %0 : tensor<100x8xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d0, d2)> 
+#map2 = affine_map<(d0, d1, d2) -> (d2)> + +func.func @tile_thread_safety4(%arg0: tensor<100x300x8xf32>, %arg1: tensor<100x8xf32>, %arg2 : tensor<8xf32>) -> (tensor<100x8xf32>, tensor<8xf32>) { + // expected-warning@+2 {{tiling is not thread safe at axis #0}} + // expected-warning@below {{tiling is not thread safe at axis #1}} + %0:2 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel"]} ins(%arg0 : tensor<100x300x8xf32>) outs(%arg1, %arg2 : tensor<100x8xf32>, tensor<8xf32>) { + ^bb0(%in: f32, %out1: f32, %out2: f32): + %1 = arith.addf %in, %out1 : f32 + %2 = arith.addf %in, %out2 : f32 + linalg.yield %1, %2 : f32, f32 + } -> (tensor<100x8xf32>, tensor<8xf32>) + return %0#0, %0#1 : tensor<100x8xf32>, tensor<8xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [8, 4, 2] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +#map = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +func.func @tile_thread_safety5(%arg0: tensor<100x300xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { + // expected-warning@below {{tiling is not thread safe at axis #1}} + %0 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "reduction"]} ins(%arg0 : tensor<100x300xf32>) outs(%arg1 : tensor<100xf32>) { + ^bb0(%in: f32, %out: f32): + %1 = arith.addf %in, %out : f32 + linalg.yield %1 : f32 + } -> tensor<100xf32> + return %0 : tensor<100xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match 
ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 tile_sizes [10, 1] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + +// ----- + +func.func @tile_thread_safety6(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> { + // expected-warning@below {{tiling is not thread safe at axis #2}} + %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) + outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>) + return %0 : tensor<?x?xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall, %tiled_generic = transform.structured.tile_using_forall %0 num_threads [2, 0, 8] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} -- cgit v1.1