 mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp   | 85
 mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir | 90
 2 files changed, 140 insertions(+), 35 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 6800a0f..c332307 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -810,27 +810,35 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
 enum VectorMemoryAccessKind { ScalarBroadcast, Contiguous, Gather };
 
-/// Find the non-unit dim in a linalgOp.
-/// When executing this hook, it is expected that only one dim will be non-unit.
-/// Other cases (i.e. reading n-D vectors) should've been labelled as gather
-/// loads before calling this method. This is used for finding contiguous loads
-/// (represented as `tensor.extract`) within `linalg.generic` Ops. Note that
-/// this condition is expected to hold for statically shaped Linalg Ops only.
-static uint64_t getNonUnitLoopDim(LinalgOp linalgOp) {
-  uint64_t nonUnitDim = 0;
-  uint64_t countNonUnitDim = 0;
-  for (auto tripCount : llvm::enumerate(linalgOp.getStaticLoopRanges())) {
-    if (tripCount.value() != 1) {
-      nonUnitDim = tripCount.index();
-      countNonUnitDim++;
-    }
-  }
-
+/// Find the index of the trailing non-unit dim in linalgOp. This hook is used
+/// when checking whether `tensor.extract` Op (within a `linalg.generic` Op)
+/// represents a contiguous load operation.
+///
+/// Note that when calling this hook, it is assumed that the output vector is
+/// effectively 1D. Other cases (i.e. reading n-D vectors) should've been
+/// labelled as a gather load before entering this method.
+///
+/// Following on from the above, it is assumed that:
+///  * for statically shaped loops, when no masks are used, only one dim is !=
+///    1 (that's what the shape of the output vector is based on),
+///  * for dynamically shaped loops, there might be more non-unit dims
+///    as the output vector type is user-specified.
+///
+/// TODO: Statically shaped loops + vector masking
+static uint64_t getTrailingNonUnitLoopDimIdx(LinalgOp linalgOp) {
+  SmallVector<int64_t> loopRanges = linalgOp.getStaticLoopRanges();
   assert(linalgOp.hasDynamicShape() ||
-         countNonUnitDim == 1 && "For statically shaped Linalg Ops, only one "
-                                 "non-unit loop dim is expected");
-  (void)countNonUnitDim;
-  return nonUnitDim;
+         llvm::count_if(loopRanges, [](int64_t dim) { return dim != 1; }) ==
+                 1 &&
+             "For statically shaped Linalg Ops, only one "
+             "non-unit loop dim is expected");
+
+  size_t idx = loopRanges.size() - 1;
+  for (; idx != 0; idx--)
+    if (loopRanges[idx] != 1)
+      break;
+
+  return idx;
 }
 
 /// Checks whether `val` can be used for calculating a loop invariant index.
@@ -854,11 +862,11 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
   assert(defOp && "This is neither a block argument nor an operation result");
 
   // IndexOp is loop invariant as long as its result remains constant across
-  // iterations. Given the assumptions on the loop ranges above, only the
-  // trailing loop dim ever changes.
-  auto trailingLoopDim = linalgOp.getStaticLoopRanges().size() - 1;
-  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
-    return (indexOp.getDim() != trailingLoopDim);
+  // iterations. Note that for dynamic shapes, the corresponding dim will also
+  // be conservatively treated as != 1.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
+    return linalgOp.getStaticLoopRanges()[indexOp.getDim()] == 1;
+  }
 
   auto *ancestor = block->findAncestorOpInBlock(*defOp);
 
@@ -877,7 +885,7 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
   return result;
 }
 
-/// Check whether \p val could be used for calculating the trailing index for a
+/// Check whether `val` could be used for calculating the trailing index for a
 /// contiguous load operation.
 ///
 /// There are currently 3 types of values that are allowed here:
@@ -886,13 +894,14 @@ static bool isLoopInvariantIdx(LinalgOp &linalgOp, Value &val,
 ///   3. results of basic arithmetic operations (linear and continuous)
 ///      involving 1., 2. and 3.
 /// This method returns True if indeed only such values are used in calculating
-/// \p val.
+/// `val.`
 ///
 /// Additionally, the trailing index for a contiguous load operation should
 /// increment by 1 with every loop iteration, i.e. be based on:
 ///   * `linalg.index <dim>` ,
-/// where <dim> is the trailing dim of the iteration space. \p foundIndexOp is
-/// updated to `true` when such an op is found.
+/// where <dim> is the trailing non-unit dim of the iteration space (this way,
+/// `linalg.index <dim>` increments by 1 with every loop iteration).
+/// `foundIndexOp` is updated to `true` when such Op is found.
 static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
                                 bool &foundIndexOp, VectorType resType) {
 
@@ -912,12 +921,10 @@ static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
   Operation *defOp = val.getDefiningOp();
   assert(defOp && "This is neither a block argument nor an operation result");
 
-  // Given the assumption on the loop ranges above, we expect only 1 non-unit
-  // loop dim.
-  auto nonUnitLoopDim = getNonUnitLoopDim(linalgOp);
-
   if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp)) {
-    foundIndexOp = (indexOp.getDim() == nonUnitLoopDim);
+    auto loopDimThatIncrementsByOne = getTrailingNonUnitLoopDimIdx(linalgOp);
+
+    foundIndexOp = (indexOp.getDim() == loopDimThatIncrementsByOne);
     return true;
   }
 
@@ -1012,7 +1019,10 @@ getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
   bool foundIndexOp = false;
   bool isContiguousLoad = isContiguousLoadIdx(linalgOp, extractOpTrailingIdx,
                                               foundIndexOp, resType);
-  isContiguousLoad &= foundIndexOp;
+  // TODO: Support generating contiguous loads for column vectors - that will
+  // require adding a permutation map to transfer_read Ops.
+  bool isRowVector = resType.getShape().back() != 1;
+  isContiguousLoad &= (foundIndexOp && isRowVector);
 
   if (isContiguousLoad) {
     LDBG("Found contigous load: " << extractOp);
@@ -1073,6 +1083,11 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
   //   b. contiguous loads.
   // Both cases use vector.transfer_read.
 
+  assert(llvm::count_if(resultType.getShape(),
+                        [](uint64_t dim) { return dim != 1; }) &&
+         "Contiguous loads and scalar loads + broadcast only support 1-D "
+         "vectors ATM!");
+
   // Collect indices for `vector.transfer_read`. At this point, the indices will
   // either be scalars or would have been broadcast to vectors matching the
   // result type. For indices that are vectors, there are two options:
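To illustrate the trailing non-unit dim lookup introduced above, here is a minimal standalone sketch in plain C++. It assumes a simple std::vector<int64_t> in place of the loop ranges returned by getStaticLoopRanges(); the function name trailingNonUnitDimIdx and the sample shapes are illustrative only and are not part of the patch.

// Standalone sketch of the trailing non-unit loop dim lookup, under the same
// assumption as the patch: the hook is only reached when the output vector is
// effectively 1-D, so for statically shaped loops at most one trip count is
// expected to differ from 1.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::size_t trailingNonUnitDimIdx(const std::vector<int64_t> &loopRanges) {
  assert(!loopRanges.empty() && "expected at least one loop dim");
  std::size_t idx = loopRanges.size() - 1;
  // Walk backwards and stop at the last dim whose trip count is != 1;
  // fall through to 0 when every dim is a unit dim.
  for (; idx != 0; idx--)
    if (loopRanges[idx] != 1)
      break;
  return idx;
}

int main() {
  std::printf("%zu\n", trailingNonUnitDimIdx({1, 128})); // row vector    -> 1
  std::printf("%zu\n", trailingNonUnitDimIdx({8, 1}));   // column vector -> 0
  return 0;
}

For an 8x1 iteration space (a column vector) the trailing non-unit dim is dim 0, while for 1x128 it is dim 1, which is exactly the distinction the contiguous-load check relies on.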
diff --git a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
index ad3a8d9..2c56b71 100644
--- a/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
+++ b/mlir/test/Dialect/Linalg/vectorize-tensor-extract.mlir
@@ -307,6 +307,96 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// Reading a 1D column vector (hence a candidate for a contiguous load), but given
+// %1, it's a gather load.
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @index_from_output_column_vector_gather_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg1: f32):
+    %1 = linalg.index 0 : index
+    %extracted = tensor.extract %src[%1, %c0] : tensor<8x128xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %res : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @index_from_output_column_vector_gather_load(
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK: %[[C128:.*]] = arith.constant dense<128> : vector<1x8xindex>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[MUL:.*]] = arith.muli %[[B]], %[[C128]] : vector<1x8xindex>
+// CHECK: %[[TR:.*]] = vector.transpose %[[MUL]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+// CHECK: return %[[RES]] : tensor<8x1xf32>
+
+// -----
+
+// Same as above, but the access indices have been swapped and hence this _is_
+// a contiguous load. Currently not supported and lowered as vector.gather
+// instead.
+// TODO: Make sure that this is lowered as a contiguous load.
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func.func @index_from_output_column_vector_contiguous_load(%src: tensor<8x128xf32>) -> tensor<8x1xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.empty() : tensor<8x1xf32>
+  %res = linalg.generic {
+    indexing_maps = [#map],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%0 : tensor<8x1xf32>) {
+  ^bb0(%arg1: f32):
+    %1 = linalg.index 0 : index
+    %extracted = tensor.extract %src[%c0, %1] : tensor<8x128xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<8x1xf32>
+  return %res : tensor<8x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg2: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg2 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op
+    %2 = transform.structured.vectorize_children_and_apply_patterns %1 {vectorize_nd_extract} : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @index_from_output_column_vector_contiguous_load(
+// CHECK-SAME: %[[SRC:.*]]: tensor<8x128xf32>) -> tensor<8x1xf32> {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[PASS_THRU:.*]] = arith.constant dense<0.000000e+00> : vector<8x1xf32>
+// CHECK: %[[MASK:.*]] = arith.constant dense<true> : vector<8x1xi1>
+// CHECK: %[[IDX_VEC:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex>
+// CHECK: %[[OUT:.*]] = tensor.empty() : tensor<8x1xf32>
+// CHECK: %[[B:.*]] = vector.broadcast %[[IDX_VEC]] : vector<8xindex> to vector<1x8xindex>
+// CHECK: %[[TR:.*]] = vector.transpose %[[B]], [1, 0] : vector<1x8xindex> to vector<8x1xindex>
+// CHECK: %[[GATHER:.*]] = vector.gather %[[SRC]]{{\[}}%[[C0]], %[[C0]]] {{\[}}%[[TR]]], %[[MASK]], %[[PASS_THRU]] : tensor<8x128xf32>, vector<8x1xindex>, vector<8x1xi1>, vector<8x1xf32> into vector<8x1xf32>
+// CHECK: %[[RES:.*]] = vector.transfer_write %[[GATHER]], %[[OUT]]{{\[}}%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<8x1xf32>, tensor<8x1xf32>
+// CHECK: return %[[RES]] : tensor<8x1xf32>
+
+// -----
+
 #map = affine_map<(d0) -> (d0)>
 func.func @vectorize_nd_tensor_extract_contiguous_and_gather(%arg0: tensor<6xf32>, %arg1: tensor<5xi32>) -> tensor<5xf32> {
   %c5 = arith.constant 5 : index
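Both new tests above still vectorize to vector.gather; per the TODO, the second one is kept as a gather only because of the row-vector gate added in getTensorExtractMemoryAccessPattern. A minimal sketch of that gate, again in plain C++ and assuming a plain shape vector in place of the VectorType result; the name isContiguousRead and the sample shapes are illustrative only.

// Sketch of the row-vector gate: even when the trailing extract index is
// driven by the trailing non-unit loop dim (foundIndexOp == true), a
// column-vector result such as 8x1 keeps being lowered as a gather, because
// a contiguous transfer_read for it would need a permutation map (see the
// TODO in the patch).
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

static bool isContiguousRead(const std::vector<int64_t> &resultShape,
                             bool foundIndexOp) {
  assert(!resultShape.empty() && "expected a non-empty result shape");
  bool isRowVector = resultShape.back() != 1;
  return foundIndexOp && isRowVector;
}

int main() {
  // vector<1x8xf32> result, trailing index based on linalg.index -> contiguous.
  std::printf("%d\n", isContiguousRead({1, 8}, /*foundIndexOp=*/true)); // 1
  // vector<8x1xf32> result (column vector) -> still a gather.
  std::printf("%d\n", isContiguousRead({8, 1}, /*foundIndexOp=*/true)); // 0
  return 0;
}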
