aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/docs/Dialects/Transform.md8
-rw-r--r--mlir/docs/Tutorials/transform/Ch1.md6
-rw-r--r--mlir/docs/Tutorials/transform/Ch2.md6
-rw-r--r--mlir/docs/Tutorials/transform/Ch3.md2
-rw-r--r--mlir/docs/Tutorials/transform/ChH.md6
-rw-r--r--mlir/docs/Tutorials/transform/_index.md10
-rw-r--r--mlir/include/mlir/Dialect/GPU/Transforms/Passes.h7
-rw-r--r--mlir/include/mlir/Dialect/GPU/Transforms/Utils.h5
-rw-r--r--mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h2
-rw-r--r--mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td11
-rw-r--r--mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h3
-rw-r--r--mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td48
-rw-r--r--mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md683
-rw-r--r--mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h35
-rw-r--r--mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h1
-rw-r--r--mlir/include/mlir/Dialect/Transform/IR/TransformOps.td22
-rw-r--r--mlir/include/mlir/IR/Operation.h30
-rw-r--r--mlir/include/mlir/Pass/PassManager.h36
-rw-r--r--mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h6
-rw-r--r--mlir/lib/CAPI/IR/IR.cpp6
-rw-r--r--mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp68
-rw-r--r--mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp48
-rw-r--r--mlir/lib/Dialect/GPU/CMakeLists.txt1
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp27
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp174
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/Utils.cpp44
-rw-r--r--mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp91
-rw-r--r--mlir/lib/Dialect/Mesh/IR/MeshOps.cpp213
-rw-r--r--mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt3
-rw-r--r--mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp639
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp28
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp180
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp44
-rw-r--r--mlir/lib/Dialect/Transform/IR/TransformOps.cpp37
-rw-r--r--mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp2
-rw-r--r--mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp2
-rw-r--r--mlir/lib/IR/AsmPrinter.cpp5
-rw-r--r--mlir/lib/Interfaces/InferTypeOpInterface.cpp5
-rw-r--r--mlir/lib/Pass/Pass.cpp18
-rw-r--r--mlir/lib/Pass/PassCrashRecovery.cpp87
-rw-r--r--mlir/lib/Pass/PassDetail.h5
-rw-r--r--mlir/lib/Tools/mlir-opt/MlirOptMain.cpp18
-rw-r--r--mlir/test/Conversion/GPUCommon/memref-arg-attrs.mlir25
-rw-r--r--mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir37
-rw-r--r--mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir190
-rw-r--r--mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir15
-rw-r--r--mlir/test/Dialect/Linalg/transform-op-match.mlir5
-rw-r--r--mlir/test/Dialect/Linalg/transform-op-pad.mlir3
-rw-r--r--mlir/test/Dialect/Mesh/invalid.mlir96
-rw-r--r--mlir/test/Dialect/Mesh/ops.mlir49
-rw-r--r--mlir/test/Dialect/Mesh/resharding-spmdization.mlir154
-rw-r--r--mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir (renamed from mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir)28
-rw-r--r--mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir72
-rw-r--r--mlir/test/Dialect/Transform/ops-invalid.mlir8
-rw-r--r--mlir/test/Dialect/Transform/test-interpreter.mlir96
-rw-r--r--mlir/test/Dialect/Transform/test-loop-transforms.mlir9
-rw-r--r--mlir/test/Dialect/Vector/vector-transfer-flatten.mlir15
-rw-r--r--mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-hand.mlir (renamed from mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir)139
-rw-r--r--mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir141
-rw-r--r--mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir41
-rw-r--r--mlir/test/Pass/crashless-reproducer.mlir10
-rw-r--r--mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp24
-rw-r--r--mlir/test/lib/Dialect/Mesh/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp122
-rw-r--r--mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp45
-rw-r--r--mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td27
-rw-r--r--mlir/tools/mlir-opt/mlir-opt.cpp2
-rw-r--r--mlir/unittests/IR/OpPropertiesTest.cpp46
68 files changed, 3411 insertions, 661 deletions
diff --git a/mlir/docs/Dialects/Transform.md b/mlir/docs/Dialects/Transform.md
index 8fa7038c..67b4ee5 100644
--- a/mlir/docs/Dialects/Transform.md
+++ b/mlir/docs/Dialects/Transform.md
@@ -19,7 +19,7 @@ operations, and then applying loop unrolling to the inner loops produced by the
previous transformations. As such, it is not intended as a replacement for the
pass infrastructure, nor for the pattern rewriting infrastructure. In the most
common case, the transform IR will be processed and applied to the payload IR by
-a pass. Transformations expressed by the transform dialect may be implemented
+a pass. Transformations expressed by the Transform dialect may be implemented
using the pattern infrastructure or any other relevant MLIR component.
The following IR gives a rough idea of what the operations in this dialect
@@ -271,7 +271,7 @@ operation lists.
## Handle Invalidation
-The execution model of the transform dialect allows a payload IR operation to be
+The execution model of the Transform dialect allows a payload IR operation to be
associated with _multiple_ handles as well as nested payload IR operations to be
associated with different handles. Similarly, a payload IR value may be
associated with multiple transform IR value handles. When a transform IR
@@ -373,13 +373,13 @@ to specify which transformations the pass should run. The transform dialect
provides a uniform, extensible mechanism for controlling transformations in
such cases.
-The transform dialect is supposed to be consumed by an "interpreter" pass
+The Transform dialect is supposed to be consumed by an "interpreter" pass
that drives the application of transformations. To ensure extensibility and
composability, this pass is not expected to actually perform the
transformations specified by the ops. Instead, the transformations are
implemented by the transform ops themselves via `TransformOpInterface`. The
pass serves as the entry point, handles the flow of transform operations and
-takes care of bookkeeping. As such, the transform dialect does not provide
+takes care of bookkeeping. As such, the Transform dialect does not provide
the interpreter pass. Instead, it provides a set of utilities that can be
used by clients to define their own interpreter passes or as part of a more
complex pass. For example, the mapping between values in the transform IR
diff --git a/mlir/docs/Tutorials/transform/Ch1.md b/mlir/docs/Tutorials/transform/Ch1.md
index 95b37eb..0df25a5 100644
--- a/mlir/docs/Tutorials/transform/Ch1.md
+++ b/mlir/docs/Tutorials/transform/Ch1.md
@@ -79,7 +79,7 @@ transform.sequence failures(propagate) {
## Transform Dialect Interpreter
-Since we don’t want to recompile the compiler every time we change a transformation, we can use a transform dialect interpreter pass to apply this transformation sequence to the payload IR. As we will see in the next chapter, it is possible to define custom passes or even integrate the transform interpreter into a larger pass. For now, we can use the existing test pass:
+Since we don’t want to recompile the compiler every time we change a transformation, we can use a Transform dialect interpreter pass to apply this transformation sequence to the payload IR. As we will see in the next chapter, it is possible to define custom passes or even integrate the transform interpreter into a larger pass. For now, we can use the existing test pass:
```sh
@@ -168,7 +168,7 @@ Besides producing new handles, the tiling transform operation _consumes_ the ope
## Handle Invalidation and Expensive Checks Mode
-Undefined behavior is difficult to grapple with when it does happen, so the transform dialect interpreter provides a set of additional expensive checks that detect most undefined behavior in the transform IR. For example, if we wanted to use the `%arg1` handle after it is consumed, it would cause undefined behavior that manifests as an assertion in the debug build, and likely as a segmentation fault in the release mode.
+Undefined behavior is difficult to grapple with when it does happen, so the Transform dialect interpreter provides a set of additional expensive checks that detect most undefined behavior in the transform IR. For example, if we wanted to use the `%arg1` handle after it is consumed, it would cause undefined behavior that manifests as an assertion in the debug build, and likely as a segmentation fault in the release mode.
```mlir
transform.sequence failures(propagate) {
@@ -379,7 +379,7 @@ Finally, we would like to replace the call to the outlined function with a call
## Tracking IR Modifications
-The transform dialect automatically tracks all IR changes that are made as part
+The Transform dialect automatically tracks all IR changes that are made as part
of transform ops. (Implementations must use the provided rewriter to modify IR.)
If a payload op is erased, it is automatically removed from all handles that it
is currently associated with. If a payload op is replaced, the transform dialect
diff --git a/mlir/docs/Tutorials/transform/Ch2.md b/mlir/docs/Tutorials/transform/Ch2.md
index 8d5076e..ac6d7d42 100644
--- a/mlir/docs/Tutorials/transform/Ch2.md
+++ b/mlir/docs/Tutorials/transform/Ch2.md
@@ -10,7 +10,7 @@ The Transform dialect uses the dialect extension mechanism to allow additional o
// In MyExtension.cpp.
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
-// Define a new transform dialect extension. This uses the CRTP idiom to identify
+// Define a new Transform dialect extension. This uses the CRTP idiom to identify
// extensions.
class MyExtension : public ::mlir::transform::TransformDialectExtension<MyExtension> {
public:
@@ -200,7 +200,7 @@ must be modified with the provided rewriter.
```c++
// In MyExtension.cpp
-// Implementation of our transform dialect operation.
+// Implementation of our Transform dialect operation.
// This operation returns a tri-state result that can be one of:
// - success when the transformation succeeded;
// - definite failure when the transformation failed in such a way that
@@ -277,7 +277,7 @@ void registerMyExtension(::mlir::DialectRegistry &registry) {
}
```
-After registering the extension, it becomes possible to use our new operation in the transform dialect interpreter. The upstream testing pass can be used as is.
+After registering the extension, it becomes possible to use our new operation in the Transform dialect interpreter. The upstream testing pass can be used as is.
```mlir
transform.sequence failures(propagate) {
diff --git a/mlir/docs/Tutorials/transform/Ch3.md b/mlir/docs/Tutorials/transform/Ch3.md
index 4e9d1e6..84251df 100644
--- a/mlir/docs/Tutorials/transform/Ch3.md
+++ b/mlir/docs/Tutorials/transform/Ch3.md
@@ -138,7 +138,7 @@ void MyExtension::init() {
}
```
-This type is now directly available in the transform dialect and can be used in operations.
+This type is now directly available in the Transform dialect and can be used in operations.
```mlir
diff --git a/mlir/docs/Tutorials/transform/ChH.md b/mlir/docs/Tutorials/transform/ChH.md
index 7c12728..f4dae5c 100644
--- a/mlir/docs/Tutorials/transform/ChH.md
+++ b/mlir/docs/Tutorials/transform/ChH.md
@@ -1,7 +1,7 @@
# Chapter H: Reproducing Halide Schedule
This chapter demonstrates how a schedule from the [Halide
-DSL](http://halide-lang.org) can be implemented using transform dialect for
+DSL](http://halide-lang.org) can be implemented using Transform dialect for
structured ops.
Note that the IR below is pseudo-code with types removed for brevity. It may
@@ -408,7 +408,7 @@ identical_ to the code with the full schedule. Therefore, we will only unroll
the corresponding loops corresponding to `xi` and `ci` dimensions that actually
get unrolled by Halide.
-As tiling in the transform dialect produces handles to the loops materialized by
+As tiling in the Transform dialect produces handles to the loops materialized by
tiling, unrolling those loops is just a matter of chaining the corresponding
transformation. Note that the inner loop must be unrolled first as unrolling the
outer loop will invalidate the handles to the inner loop.
@@ -499,7 +499,7 @@ bufferization is directly available as a transform operation.
One-shot bufferization itself does not produce buffer deallocations, which may
lead to leaks. So we have to run the buffer deallocation pass pipeline to avoid
-them. Note that the transform dialect seamlessly runs named passes and pass
+them. Note that the Transform dialect seamlessly runs named passes and pass
pipelines: if desired, one could replace complex `--pass-pipeline expressions`
with operations. Note that we apply the pipeline to functions rather than entire
module to avoid running it on the transform IR that is contained in the module.
diff --git a/mlir/docs/Tutorials/transform/_index.md b/mlir/docs/Tutorials/transform/_index.md
index 3afb9c5..b508a5d 100644
--- a/mlir/docs/Tutorials/transform/_index.md
+++ b/mlir/docs/Tutorials/transform/_index.md
@@ -8,15 +8,15 @@ scheduling languages). This tutorial presents the concepts of the MLIR transform
dialect and related infrastructure. It will be accompanied by a practical
demonstration of three use scenarios:
-- Composing transform dialect operations available in (upstream) MLIR to perform
+- Composing Transform dialect operations available in (upstream) MLIR to perform
a sequence of optimizing transformations that results in efficient code for an
MLIR linear algebra operation.
-- Defining new transform dialect operations and adapting existing transformation
- code to work with the transform dialect infrastructure.
-- Setting up and using the transform dialect infrastructure in a downstream
+- Defining new Transform dialect operations and adapting existing transformation
+ code to work with the Transform dialect infrastructure.
+- Setting up and using the Transform dialect infrastructure in a downstream
out-of-tree project with custom dialects, transformations and passes.
-After following the tutorial, one will be able to apply the transform dialect in
+After following the tutorial, one will be able to apply the Transform dialect in
their work and extend it when necessary. Basic familiarity with MLIR is a
prerequisite. See [Toy tutorial](../Toy) for introduction to MLIR.
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
index 6c5bf75..5885fac 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Passes.h
@@ -70,6 +70,13 @@ void populateGpuBreakDownSubgrupReducePatterns(RewritePatternSet &patterns,
unsigned maxShuffleBitwidth = 32,
PatternBenefit benefit = 1);
+/// Collect a set of patterns to lower `gpu.subgroup_reduce` into `gpu.shuffle`
+/// ops over `shuffleBitwidth` scalar types. Assumes that the subgroup has
+/// `subgroupSize` lanes. Uses the butterfly shuffle algorithm.
+void populateGpuLowerSubgroupReduceToShufflePattenrs(
+ RewritePatternSet &patterns, unsigned subgroupSize,
+ unsigned shuffleBitwidth = 32, PatternBenefit benefit = 1);
+
/// Collect all patterns to rewrite ops within the GPU dialect.
inline void populateGpuRewritePatterns(RewritePatternSet &patterns) {
populateGpuAllReducePatterns(patterns);
diff --git a/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
index a426bee..f25c506 100644
--- a/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/Transforms/Utils.h
@@ -13,6 +13,8 @@
#ifndef MLIR_DIALECT_GPU_TRANSFORMS_UTILS_H_
#define MLIR_DIALECT_GPU_TRANSFORMS_UTILS_H_
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Support/LLVM.h"
#include <string>
@@ -28,6 +30,9 @@ class LaunchOp;
/// Returns the default annotation name for GPU binary blobs.
std::string getDefaultGpuBinaryAnnotation();
+
+/// Returns the matching vector combining kind.
+vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode);
} // namespace gpu
/// Get a gpu.func created from outlining the region of a gpu.launch op with the
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
index 6c82402..f92843a 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -62,6 +62,8 @@ struct ContractionDimensions {
/// `k`, indices are returned in sorted order.
/// Returns a failure if any of `m`, `n` or `k` is empty.
FailureOr<ContractionDimensions> inferContractionDims(LinalgOp linalgOp);
+FailureOr<ContractionDimensions>
+inferContractionDims(ArrayRef<AffineMap> indexingMaps);
/// Checks whether `linalgOp` conforms to ContractionOpInterface.
// TODO: embed within `isa<ContractionOpInterface>` if possible / natural.
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
index 9d39b1b3..a9d30df 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshBase.td
@@ -33,6 +33,10 @@ def Mesh_Dialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
}
+
+def Mesh_MeshAxis : I<16>;
+def Mesh_MeshAxesAttr : DenseArrayAttrBase<"DenseI16ArrayAttr", "int16_t", "i16">;
+
//===----------------------------------------------------------------------===//
// Mesh Enums.
//===----------------------------------------------------------------------===//
@@ -125,7 +129,7 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_sum along mesh axis 1.
- tensor<4x8xf32, #mesh.shard<@mesh0, [[0], [], [1]]>
+ tensor<4x8xf32, #mesh.shard<@mesh0, [[0], []], partial = sum[1]>
// The tensor is sharded on the first dimension along axis 0 of @mesh0 and
// it is also a partial_max along mesh axis 1.
@@ -158,6 +162,11 @@ def MeshSharding : AttrDef<Mesh_Dialect, "MeshSharding"> {
}]>
];
+ let extraClassDeclaration = [{
+ bool operator==(::mlir::Attribute rhs) const;
+ bool operator==(::mlir::mesh::MeshShardingAttr rhs) const;
+ }];
+
let genVerifyDecl = 1;
}
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index 9077d2e..ce7d5d0 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -30,6 +30,9 @@
namespace mlir {
namespace mesh {
+using MeshAxis = int16_t;
+using MeshAxesAttr = DenseI16ArrayAttr;
+
bool isReductionLoop(IteratorType iType);
bool areReductionAndPartialMatch(IteratorType iType, Partial partial);
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
index 784f3eb..1ed54b6 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.td
@@ -10,6 +10,7 @@
#define MLIR_DIALECT_MESH_IR_MESHOPS_TD
include "mlir/Dialect/Mesh/IR/MeshBase.td"
+include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinTypes.td"
@@ -95,6 +96,28 @@ def Mesh_ClusterOp : Mesh_Op<"cluster", [Symbol]> {
let hasVerifier = 1;
}
+def Mesh_ClusterShapeOp : Mesh_Op<"cluster_shape", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let summary = "Get the shape of the cluster.";
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ );
+
+ let results = (outs
+ Variadic<Index>:$result
+ );
+
+ let assemblyFormat = [{
+ $mesh (`axes` `=` $axes^)?
+ attr-dict `:` type($result)
+ }];
+
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+ ];
+}
+
def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
let summary = "Annotate on how a tensor is sharded across a mesh cluster.";
let description = [{
@@ -186,6 +209,29 @@ def Mesh_ShardOp : Mesh_Op<"shard", [Pure, SameOperandsAndResultType]> {
}];
}
+def Mesh_ProcessIndexOp : Mesh_Op<"process_index", [Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
+ let summary = "Get the index of current device along specified mesh axis.";
+ let description = [{
+ It is used in the SPMD format of IR.
+ The `axes` mush be non-negative and less than the total number of mesh axes.
+ }];
+ let arguments = (ins
+ FlatSymbolRefAttr:$mesh,
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$axes
+ );
+ let results = (outs
+ Variadic<Index>:$result
+ );
+ let assemblyFormat = [{
+ `on` $mesh (`axes` `=` $axes^)?
+ attr-dict `:` type($result)
+ }];
+ let builders = [
+ OpBuilder<(ins "::mlir::mesh::ClusterOp":$mesh)>,
+ OpBuilder<(ins "StringRef":$mesh, "ArrayRef<int16_t>":$axes)>
+ ];
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -197,7 +243,7 @@ class Mesh_CollectiveCommunicationOpBase<
[DeclareOpInterfaceMethods<SymbolUserOpInterface>])> {
dag commonArgs = (ins
FlatSymbolRefAttr:$mesh,
- DefaultValuedAttr<DenseI16ArrayAttr, "{}">:$mesh_axes
+ DefaultValuedAttr<Mesh_MeshAxesAttr, "{}">:$mesh_axes
);
}
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md
new file mode 100644
index 0000000..6368931
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/ReshardingSpmdizationDoc.md
@@ -0,0 +1,683 @@
+# Resharding Spmdization Examples
+
+Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[0, 1]]` on a `2x3` mesh.
+
+unsharded `2x3` tensor
+```
+11 12 13
+21 22 23
+```
+
+sharded on a `2x3` mesh
+
+sharding = `[[0, 1]]`
+
+mesh contents:
+
+```
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+```
+
+Transform into
+sharding = `[[1, 0]]`
+```
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 13 | 22 | |
++----+----+----+ |
+| 12 | 21 | 23 | |
++----+----+----+ ↓
+```
+Algorithm:
+Swap contents on devices that have the same linear index in the 2 shardings.
+
+--------------------------------------------------------------
+
+Reshard `2x3` tensor from sharding `[[0, 1]]` to sharding `[[1]]` on a `2x3` mesh.
+
+unsharded `2x3` tensor
+```
+11 12 13
+21 22 23
+```
+
+sharded on a `2x3` mesh
+
+sharding = `[[0, 1]]`
+
+mesh contents:
+```
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+```
+
+Transform into
+sharding = `[[1]]`
+```
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+```
+Algorithm:
+All-gather along mesh axis 0.
+
+--------------------------------------------------------------
+
+Reshard `4x6` tensor from sharding `[[], [0, 1]]` to sharding `[[], [0]]` on a `2x3` mesh.
+
+unsharded `4x6` tensor
+```
+11 12 13 14 15 16
+21 22 23 24 25 26
+```
+
+sharded on a `2x3` mesh
+
+sharding = `[[], [0, 1]]`
+
+mesh contents:
+```
+mesh axis 1
+----------->
++----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 | |
+| 21 | 22 | 23 | |
++----+----+----+ |
+| 14 | 15 | 16 | |
+| 24 | 25 | 26 | |
++----+----+----+ ↓
+```
+Transform into
+sharding = `[[], [0]]`
+```
+mesh axis 1
+----------->
++----------+----------+ mesh axis 0 |
+| 11 12 13 | 11 12 13 | |
+| 21 22 23 | 21 22 23 | |
++----------+----------+ |
+| 14 15 16 | 14 15 16 | |
+| 24 25 26 | 24 25 26 | |
++----------+----------+ ↓
+```
+Algorithm:
+All-gather along mesh axis 1.
+
+--------------------------------------------------------------
+
+Reshard `4x8` tensor from sharding `[[0], [1, 2]]` to sharding `[[0], [2]]` on a `2x2x2` mesh.
+
+unsharded `4x8` tensor
+```
+11 12 13 14 15 16 17 18
+21 22 23 24 25 26 27 28
+31 32 33 34 35 36 37 38
+41 42 43 44 45 46 47 48
+```
+sharded on a `2x2x2` mesh
+
+sharding = `[[0], [1, 2]]`
+
+mesh contents:
+```
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 14 | | |
+| 21 22 | 23 24 | | |
++-------+-------+ | |
+| 15 16 | 17 18 | | |
+| 25 26 | 27 28 | | |
++-------+-------+ ↓ |
++-------+-------+ |
+| 31 32 | 33 34 | |
+| 41 42 | 43 44 | |
++-------+-------+ |
+| 35 36 | 37 38 | |
+| 45 46 | 47 48 | |
++-------+-------+ ↓
+```
+Transform into
+sharding = `[[0], [2]]`
+```
+mesh axis 2
+----------->
++-------------+-------------+ mesh axis 1 | mesh axis 0 |
+| 11 12 13 14 | 15 16 17 18 | | |
+| 21 22 23 24 | 25 26 27 28 | | |
++-------------+-------------+ | |
+| 11 12 13 14 | 15 16 17 18 | | |
+| 21 22 23 24 | 25 26 27 28 | | |
++-------------+-------------+ ↓ |
++-------------+-------------+ |
+| 31 32 33 34 | 35 36 37 38 | |
+| 41 42 43 44 | 45 46 47 48 | |
++-------------+-------------+ |
+| 31 32 33 34 | 35 36 37 38 | |
+| 41 42 43 44 | 45 46 47 48 | |
++-------------+-------------+ ↓
+```
+Algorithm:
+
+Can't be done with just an all-gather along mesh axis 1.
+Can be handled by multiple resharding transformations
+`[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]`
+
+--------------------------------------------------------------
+
+Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.
+
+unsharded `6x6` tensor
+```
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+```
+sharded on a `2x3` mesh
+
+sharding = `[[0], [1]]`
+```
+mesh axis 1
+----------->
++-------+-------+-------+ mesh axis 0 |
+| 11 12 | 13 14 | 15 16 | |
+| 21 22 | 23 24 | 25 26 | |
+| 31 32 | 33 34 | 35 36 | |
++-------+-------+-------+ |
+| 41 42 | 43 44 | 45 46 | |
+| 51 52 | 53 54 | 55 56 | |
+| 61 62 | 63 64 | 65 66 | |
++-------+-------+-------+ ↓
+```
+transform to
+sharding = `[[1], [0]]`
+```
+mesh axis 1
+----------->
++----------+----------+----------+ mesh axis 0 |
+| 11 12 13 | 31 32 33 | 51 52 53 | |
+| 21 22 23 | 41 42 43 | 61 62 63 | |
++----------+----------+----------+ |
+| 14 15 16 | 34 35 36 | 54 55 56 | |
+| 24 25 26 | 44 45 46 | 64 65 66 | |
++----------+----------+----------+ ↓
+
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 | |
+| 21 22 23 | 24 25 26 | |
++----------+----------+ |
+| 31 32 33 | 34 35 36 | |
+| 41 42 43 | 44 45 46 | |
++----------+----------+ |
+| 51 52 53 | 54 55 56 | |
+| 61 62 63 | 64 65 66 | |
++----------+----------+ ↓
+```
+Algorithm: TODO
+
+--------------------------------------------------------------
+
+Reshard `6x6` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x6` mesh.
+
+unsharded 6x6 tensor
+```
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+```
+shard on `2x6` mesh
+
+sharding = `[[0], [1]]`
+```
+mesh axis 1
+----------->
++----+----+----+----+----+----+ mesh axis 0 |
+| 11 | 12 | 13 ‖ 14 | 15 | 16 | |
+| 21 | 22 | 23 ‖ 24 | 23 | 26 | |
+| 31 | 32 | 33 ‖ 34 | 35 | 36 | |
++----+----+----+----+----+----+ |
+| 41 | 42 | 43 ‖ 44 | 45 | 46 | |
+| 51 | 52 | 53 ‖ 54 | 55 | 56 | |
+| 61 | 62 | 63 ‖ 64 | 65 | 66 | |
++----+----+----+----+----+----+ ↓
+```
+transform to
+sharding = `[[1], [0]]`
+```
+mesh axis 0
+----------->
++----------+----------+ mesh axis 1 |
+| 11 12 13 | 14 15 16 | |
++----------+----------+ |
+| 21 22 23 | 24 25 26 | |
++----------+----------+ |
+| 31 32 33 | 34 35 36 | |
++==========+==========+ |
+| 41 42 43 | 44 45 46 | |
++----------+----------+ |
+| 51 52 53 | 54 55 56 | |
++----------+----------+ |
+| 61 62 63 | 64 65 66 | |
++----------+----------+ ↓
+```
+Algorithm: TODO
+
+--------------------------------------------------------------
+
+Reshard KxL tensor from `[[0], [1]]` to `[[1], [0]]` on `MxN` mesh.
+
+`M x N` mesh.
+`K x L` tensor `t`.
+`d(m, n)` the tensor on device `(m, n)`.
+
+sharding = `[[0], [1]]`
+Tensor shard s on each device has size `(K ceildiv M, L ceildiv N)`.
+```
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+```
+substitute
+```
+i <- m * (K ceildiv M) + k
+j <- n * (L ceildiv N) + l
+```
+```
+m -> i floordiv (K ceildiv M)
+n -> j floordiv (L ceildiv N)
+k -> i % (K ceildiv M)
+l -> j % (L ceildiv N)
+```
+For the inverse map we get
+```
+t[i, j] -> d(
+ i floordiv (K ceildiv M), j floordiv (L ceildiv N)
+)[
+ i % (K ceildiv M), j % (L ceildiv N)
+]
+```
+Check:
+```
+i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
+t[13, 17] = d(
+ 13 floordiv (16 ceildiv 3),
+ 17 floordiv (23 ceilvid 4)
+)[
+ 13 % (16 ceildiv 3),
+ 17 % (23 ceilvid 4)
+]
+= d(
+ 13 floordiv 6,
+ 17 floordiv 6
+)[
+ 13 % 6,
+ 17 % 6
+]
+= d(2, 2)[1, 5]
+= t[
+ 2 * (16 ceildiv 3) + 1,
+ 2 * (23 ceildiv 4) + 5
+]
+= t[
+ 2 * 6 + 1,
+ 2 * 6 + 5
+]
+= t[13, 17]
+```
+
+sharding = `[[1], [0]]`
+Tensor shard s on each device has size `(K ceildiv N, L ceildiv M)`.
+```
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+```
+substitute
+```
+i <- n * (K ceildiv N) + k
+j <- m * (L ceildiv M) + l
+```
+```
+m -> j floordiv (L ceildiv M)
+n -> i floordiv (K ceildiv N)
+k -> i % (K ceildiv N)
+l -> j % (L ceildiv M)
+```
+For the inverse map we get
+```
+t[i, j] -> d(
+ j floordiv (L ceildiv M), i floordiv (K ceildiv N)
+)[
+ i % (K ceildiv N), j % (L ceildiv M)
+]
+```
+Check:
+```
+i = 9, j = 19, M = 5, N = 2, K = 27, L = 14
+t[9, 19] = d(
+ 19 floordiv (14 ceildiv 5),
+ 9 floordiv (27 ceildiv 2)
+)[
+ 9 % (27 ceildiv 2),
+ 19 % (14 ceildiv 5)
+]
+= d(
+ 19 floordiv 3,
+ 9 floordiv 14
+)[
+ 9 % 14
+ 19 % 3
+]
+= d(6, 0)[9, 1]
+= t[
+ 0 * (27 ceildiv 2) + 9,
+ 6 * (14 ceildiv 5) + 1
+]
+= t[
+ 0 * 14 + 9,
+ 6 * 3 + 1
+]
+= t[9, 19]
+```
+sharding = `[[0], [1]]`
+```
+d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
+t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
+```
+sharding = `[[1], [0]]`
+```
+d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
+t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
+```
+sharding `[[0], [1]] -> [[1], [0]]`
+`d1(m, n)` the tensor on device `(m, n)` for sharding sharding `[[0], [1]]`.
+`d2(m, n)` the tensor on device `(m, n)` for sharding sharding `[[1], [0]]`.
+```
+d1(m, n)[k, l] ->
+t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
+d2(
+ (m * (L ceildiv M) + l) floordiv (L ceildiv M),
+ (n * (K ceildiv N) + k) floordiv (K ceildiv N)
+)[
+ (n * (K ceildiv N) + k) % (K ceildiv N),
+ (m * (L ceildiv M) + l) % (L ceildiv M)
+]
+= d2(p, q)[u, v]
+```
+We want to copy the the data between devices in slices/tiles.
+What are the source/target tile coordinates?
+For a fixed `(m, n, p, q)` what is the range of `(k, l, u, v)`?
+TODO
+
+--------------------------------------------------------------
+
+Reshard `KxL` tensor from sharding `[[0], [1]]` to sharding `[[1], [0]]` on a `2x3` mesh.
+
+Device placement on a `2x3` mesh
+```
+11 12 13 <- devices
+21 22 23
+```
+sharding `[[0], [1]]`
+```
+tensor axis 1
+----------->
++----+----+----+ tensor axis 0 |
+| 11 | 12 | 13 | |
++----+----+----+ |
+| 21 | 22 | 23 | |
++----+----+----+ ↓
+```
+transform to
+sharding `[[1], [0]]`
+```
+tensor axis 1
+----------->
++----+----+ tensor axis 0 |
+| 11 | 21 | |
++----+----+ |
+| 12 | 22 | |
++----+----+ |
+| 13 | 23 | |
++----+----+ ↓
+```
+```
++-----------------+--------+--------+-----------------+
+| | | |
++ + + +
+| 11 | 12 | 13 |
++ + + +
+| | | |
++-----------------+--------+--------+-----------------+
+| | | |
++ + + +
+| 21 | 22 | 23 |
++ + + +
+| | | |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+| | |
++ 11 + 21 +
+| | |
++-----------------+--------+--------+-----------------+
+| | |
++ 12 + 22 +
+| | |
++-----------------+--------+--------+-----------------+
+| | |
++ 13 + 23 +
+| | |
++-----------------+--------+--------+-----------------+
+
++-----------------+--------+--------+-----------------+
+| | | | |
++ 11 11 + 12 11 + 12 21 + 13 21 +
+| | | | |
++-----------------+--------+--------+-----------------+
+| 11 12 | 12 12 | 12 22 | 13 22 |
++-----------------+--------+--------+-----------------+
+| 21 12 | 22 12 | 22 22 | 23 22 |
++-----------------+--------+--------+-----------------+
+| | | | |
++ 21 13 + 22 13 + 22 23 + 23 23 +
+| | | | |
++-----------------+--------+--------+-----------------+
+```
+If `S` and `T` are the source and target shard sizes along some tensor axis.
+Then we have a period of `(S*T)/gcd(S, T)`. Then the cut pattern repeats.
+TODO
+
+--------------------------------------------------------------
+
+Reshard `6x6` tensor from sharding `[[0], []]` to sharding `[[], [0]]` on a `3` mesh.
+
+unsharded `6x6` tensor
+```
+11 12 13 14 15 16
+21 22 23 24 25 26
+31 32 33 34 35 36
+41 42 43 44 45 46
+51 52 53 54 55 56
+61 62 63 64 65 66
+```
+sharded on a `3` mesh
+
+sharding = `[[0], []]`
+```
++-------------------+ mesh axis 0 |
+| 11 12 13 14 15 16 | |
+| 21 22 23 24 25 26 | |
++-------------------+ |
+| 31 32 33 34 35 36 | |
+| 41 42 43 44 45 46 | |
++-------------------+ |
+| 51 52 53 54 55 56 | |
+| 61 62 63 64 65 66 | |
++-------------------+ ↓
+```
+transform to
+sharding = `[[], [0]]`
+```
+mesh axis 0
+----------->
++-------+-------+-------+
+| 11 12 | 13 14 | 15 16 |
+| 21 22 | 23 24 | 25 26 |
+| 31 32 | 33 34 | 35 36 |
+| 41 42 | 43 44 | 45 46 |
+| 51 52 | 53 54 | 55 56 |
+| 61 62 | 63 64 | 65 66 |
++-------+-------+-------+
+```
+Algorithm:
+```mlir
+%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
+```
+--------------------------------------------------------------
+
+Reshard `4x4` tensor from sharding `[[0], [1, 2]]` to sharding `[[0, 1], [2]]` on a `2x2x2` mesh.
+
+unsharded `4x4` tensor
+```
+11 12 13 14
+21 22 23 24
+31 32 33 34
+41 42 43 44
+```
+sharded on a `2x2x2` mesh
+
+sharding = `[[0], [1, 2]]`
+```
+mesh axis 2
+----------->
++----+----+ mesh axis 1 | mesh axis 0 |
+| 11 | 12 | | |
+| 21 | 22 | | |
++----+----+ | |
+| 13 | 14 | | |
+| 23 | 24 | | |
++----+----+ ↓ |
++----+----+ |
+| 31 | 32 | |
+| 41 | 42 | |
++----+----+ |
+| 33 | 34 | |
+| 43 | 44 | |
++----+----+ ↓
+```
+transform to
+sharding = `[[0, 1], [2]]`
+```
+mesh axis 2
+----------->
++-------+-------+ mesh axis 1 | mesh axis 0 |
+| 11 12 | 13 41 | | |
++-------+-------+ | |
+| 21 22 | 23 24 | | |
++-------+-------+ ↓ |
++-------+-------+ |
+| 31 32 | 33 34 | |
++-------+-------+ |
+| 41 42 | 43 44 | |
++-------+-------+ ↓
+```
+Algorithm:
+```mlir
+%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 1 concat_axis = 0 : tensor<2x1xi8> -> tensor<1x2xi8>
+```
+is not enough.
+
+Can be decomposed into
+```
+[[0], [1, 2]] -> [[0], [2, 1]] -> [[0, 1], [2]]
+```
+
+## Decomposition into basis of reshardings
+
+We can decompose each resharding into a sequence of basis reshardings.
+It is not communication efficient in terms of minimizing the data communicated
+between devices.
+An efficient approach would be more complicated to implement.
+Each device has to receive at most as much data as the size of its target
+sharding tensor.
+
+--------------------------------------------------------------
+
+Basis:
+
+* From replicate to split.
+ ```
+ [[]] -> [[1]]
+ ```
+ Extract slices without communication.
+
+* From split to replicate.
+ ```
+ [[0]] -> [[]]
+ [[0, 1]] -> [[1]]
+ ```
+ All-gather along mesh axis 0.
+
+* Swap mesh axes order when assigned to the same tensor axis.
+ ```
+ [[0, 1]] -> [[1, 0]]
+ ```
+ Swap contents on devices with the same linear index.
+
+* Move mesh axis to different tensor dimension.
+ ```
+ [[0], []] -> [[], [0]]
+ ```
+ All-to-all.
+
+--------------------------------------------------------------
+
+Example decomposition of
+```
+[[0], [1]] -> [[1], [0]]
+```
+into
+```
+[[0], [1]] -> all-gather along mesh axis 1 ->
+[[0], []] -> all-to-all along mesh axis 0 ->
+[[], [0]] -> extract slice along mesh axis 1 ->
+[[1], [0]]
+```
+
+--------------------------------------------------------------
+
+Example decomposition of
+```
+[[3, 2], [], [0, 1]] -> [[0], [1, 2], []]
+```
+into
+```
+[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 ->
+[[3, 2], [1], [0]] -> all-to-all along mesh axis 2 ->
+[[3], [1, 2], [0]] -> all-gather along mesh axis 3 ->
+[[], [1, 2], [0]] -> all-to-all along mesh axis 0 ->
+[[0], [1, 2], []]
+```
diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
new file mode 100644
index 0000000..f71bb9b
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Spmdization.h
@@ -0,0 +1,35 @@
+//===- Simplifications.h - Mesh Simplifications -----------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+#define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
+
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/IR/DialectRegistry.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+namespace mesh {
+
+// Return the sharded shape `shape` acording ot sharding `sharding`.
+ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+ MeshShardingAttr sharding);
+
+// Insert resharding spmdization of the value `sourceShardValue`
+// from sharding `source` to sharding `target`.
+// `sourceShardValue` is the already sharded value according to `source`.
+TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
+ ShardOp source, ShardOp target,
+ TypedValue<ShapedType> sourceShardValue);
+
+void reshardingRegisterDependentDialects(DialectRegistry &registry);
+
+} // namespace mesh
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 35b519e..e8a09c4 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -76,7 +76,6 @@ void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns);
/// Populates `patterns` with patterns that simplify `tensor.pack` and
/// `tensor.unpack` operations.
-/// TODO: Add a pattern to convert tensor.unpack op to tensor.collapse_shape op.
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns);
/// Populates `patterns` with patterns that fold operations like `tensor.pad`
diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
index 307257f..fcdb21d 100644
--- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
+++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
@@ -438,6 +438,28 @@ def CastOp : TransformDialectOp<"cast",
}];
}
+def NumAssociationsOp : TransformDialectOp<"num_associations",
+ [MemoryEffectsOpInterface, ParamProducerTransformOpTrait,
+ DeclareOpInterfaceMethods<TransformOpInterface>,
+ MatchOpInterface]> {
+ let summary =
+ "Returns the number of payload objects associated with the argument";
+ let description = [{
+ Given an argument, handle or parameter, returns a new parameter associated
+ with a single 64-bit number that corresponds to the number of payload
+ objects (operations or values for a handle, attributes for a parameter)
+ associated with the argument.
+
+ Always succeeds.
+ }];
+ let arguments = (ins Transform_AnyHandleOrParamType:$handle);
+ let results = (outs TransformParamTypeInterface:$num);
+ let assemblyFormat = [{
+ $handle attr-dict `:` functional-type(operands, results)
+ }];
+ let hasVerifier = 1;
+}
+
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index 3ba0dae..d2f52cf1 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -475,19 +475,33 @@ public:
return removeDiscardableAttr(StringAttr::get(getContext(), name));
}
- /// Return all of the discardable attributes on this operation.
- ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); }
+ /// Return a range of all of discardable attributes on this operation. Note
+ /// that for unregistered operations that are not storing inherent attributes
+ /// as properties, all attributes are considered discardable.
+ auto getDiscardableAttrs() {
+ std::optional<RegisteredOperationName> opName = getRegisteredInfo();
+ ArrayRef<StringAttr> attributeNames =
+ opName ? getRegisteredInfo()->getAttributeNames()
+ : ArrayRef<StringAttr>();
+ return llvm::make_filter_range(
+ attrs.getValue(),
+ [this, attributeNames](const NamedAttribute attribute) {
+ return getPropertiesStorage() ||
+ !llvm::is_contained(attributeNames, attribute.getName());
+ });
+ }
/// Return all of the discardable attributes on this operation as a
/// DictionaryAttr.
- DictionaryAttr getDiscardableAttrDictionary() { return attrs; }
+ DictionaryAttr getDiscardableAttrDictionary() {
+ if (getPropertiesStorage())
+ return attrs;
+ return DictionaryAttr::get(getContext(),
+ llvm::to_vector(getDiscardableAttrs()));
+ }
/// Return all of the attributes on this operation.
- ArrayRef<NamedAttribute> getAttrs() {
- if (!getPropertiesStorage())
- return getDiscardableAttrs();
- return getAttrDictionary().getValue();
- }
+ ArrayRef<NamedAttribute> getAttrs() { return getAttrDictionary().getValue(); }
/// Return all of the attributes on this operation as a DictionaryAttr.
DictionaryAttr getAttrDictionary();
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
index d5f1ea0..1b2e6a3 100644
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -207,6 +207,27 @@ enum class PassDisplayMode {
Pipeline,
};
+/// Streams on which to output crash reproducer.
+struct ReproducerStream {
+ virtual ~ReproducerStream() = default;
+
+ /// Description of the reproducer stream.
+ virtual StringRef description() = 0;
+
+ /// Stream on which to output reproducer.
+ virtual raw_ostream &os() = 0;
+};
+
+/// Method type for constructing ReproducerStream.
+using ReproducerStreamFactory =
+ std::function<std::unique_ptr<ReproducerStream>(std::string &error)>;
+
+std::string
+makeReproducer(StringRef anchorName,
+ const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
+ Operation *op, StringRef outputFile, bool disableThreads = false,
+ bool verifyPasses = false);
+
/// The main pass manager and pipeline builder.
class PassManager : public OpPassManager {
public:
@@ -243,21 +264,6 @@ public:
void enableCrashReproducerGeneration(StringRef outputFile,
bool genLocalReproducer = false);
- /// Streams on which to output crash reproducer.
- struct ReproducerStream {
- virtual ~ReproducerStream() = default;
-
- /// Description of the reproducer stream.
- virtual StringRef description() = 0;
-
- /// Stream on which to output reproducer.
- virtual raw_ostream &os() = 0;
- };
-
- /// Method type for constructing ReproducerStream.
- using ReproducerStreamFactory =
- std::function<std::unique_ptr<ReproducerStream>(std::string &error)>;
-
/// Enable support for the pass manager to generate a reproducer on the event
/// of a crash or a pass failure. `factory` is used to construct the streams
/// to write the generated reproducer to. If `genLocalReproducer` is true, the
diff --git a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
index e255d9f..6e90fad 100644
--- a/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
+++ b/mlir/include/mlir/Tools/mlir-opt/MlirOptMain.h
@@ -173,6 +173,9 @@ public:
}
bool shouldVerifyRoundtrip() const { return verifyRoundtripFlag; }
+ /// Reproducer file generation (no crash required).
+ StringRef getReproducerFilename() const { return generateReproducerFileFlag; }
+
protected:
/// Allow operation with no registered dialects.
/// This option is for convenience during testing only and discouraged in
@@ -228,6 +231,9 @@ protected:
/// Verify that the input IR round-trips perfectly.
bool verifyRoundtripFlag = false;
+
+ /// The reproducer output filename (no crash required).
+ std::string generateReproducerFileFlag = "";
};
/// This defines the function type used to setup the pass manager. This can be
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index ac9889d..a97cfe5 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -613,12 +613,14 @@ void mlirOperationSetInherentAttributeByName(MlirOperation op,
}
intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) {
- return static_cast<intptr_t>(unwrap(op)->getDiscardableAttrs().size());
+ return static_cast<intptr_t>(
+ llvm::range_size(unwrap(op)->getDiscardableAttrs()));
}
MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op,
intptr_t pos) {
- NamedAttribute attr = unwrap(op)->getDiscardableAttrs()[pos];
+ NamedAttribute attr =
+ *std::next(unwrap(op)->getDiscardableAttrs().begin(), pos);
return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())};
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index e79a02f..6a005e6 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -26,9 +26,8 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
- for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
- BlockArgument attribution = en.value();
-
+ for (const auto [idx, attribution] :
+ llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
auto type = dyn_cast<MemRefType>(attribution.getType());
assert(type && type.hasStaticShape() && "unexpected type in attribution");
@@ -37,12 +36,12 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
auto elementType =
cast<Type>(typeConverter->convertType(type.getElementType()));
auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
- std::string name = std::string(
- llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
+ std::string name =
+ std::string(llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), idx));
uint64_t alignment = 0;
if (auto alignAttr =
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getWorkgroupAttributionAttr(
- en.index(), LLVM::LLVMDialect::getAlignAttrName())))
+ idx, LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
auto globalOp = rewriter.create<LLVM::GlobalOp>(
gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
@@ -105,8 +104,7 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
rewriter.setInsertionPointToStart(&gpuFuncOp.front());
unsigned numProperArguments = gpuFuncOp.getNumArguments();
- for (const auto &en : llvm::enumerate(workgroupBuffers)) {
- LLVM::GlobalOp global = en.value();
+ for (const auto [idx, global] : llvm::enumerate(workgroupBuffers)) {
auto ptrType = LLVM::LLVMPointerType::get(rewriter.getContext(),
global.getAddrSpace());
Value address = rewriter.create<LLVM::AddressOfOp>(
@@ -119,18 +117,18 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// existing memref infrastructure. This may use more registers than
// otherwise necessary given that memref sizes are fixed, but we can try
// and canonicalize that away later.
- Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
+ Value attribution = gpuFuncOp.getWorkgroupAttributions()[idx];
auto type = cast<MemRefType>(attribution.getType());
auto descr = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), type, memory);
- signatureConversion.remapInput(numProperArguments + en.index(), descr);
+ signatureConversion.remapInput(numProperArguments + idx, descr);
}
// Rewrite private memory attributions to alloca'ed buffers.
unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
- for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
- Value attribution = en.value();
+ for (const auto [idx, attribution] :
+ llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
auto type = cast<MemRefType>(attribution.getType());
assert(type && type.hasStaticShape() && "unexpected type in attribution");
@@ -145,14 +143,14 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
uint64_t alignment = 0;
if (auto alignAttr =
dyn_cast_or_null<IntegerAttr>(gpuFuncOp.getPrivateAttributionAttr(
- en.index(), LLVM::LLVMDialect::getAlignAttrName())))
+ idx, LLVM::LLVMDialect::getAlignAttrName())))
alignment = alignAttr.getInt();
Value allocated = rewriter.create<LLVM::AllocaOp>(
gpuFuncOp.getLoc(), ptrType, elementType, numElements, alignment);
auto descr = MemRefDescriptor::fromStaticShape(
rewriter, loc, *getTypeConverter(), type, allocated);
signatureConversion.remapInput(
- numProperArguments + numWorkgroupAttributions + en.index(), descr);
+ numProperArguments + numWorkgroupAttributions + idx, descr);
}
}
@@ -169,15 +167,16 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
if (getTypeConverter()->getOptions().useBarePtrCallConv) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&llvmFuncOp.getBody().front());
- for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
- auto memrefTy = dyn_cast<MemRefType>(en.value());
+ for (const auto [idx, argTy] :
+ llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
+ auto memrefTy = dyn_cast<MemRefType>(argTy);
if (!memrefTy)
continue;
assert(memrefTy.hasStaticShape() &&
"Bare pointer convertion used with dynamically-shaped memrefs");
// Use a placeholder when replacing uses of the memref argument to prevent
// circular replacements.
- auto remapping = signatureConversion.getInputMapping(en.index());
+ auto remapping = signatureConversion.getInputMapping(idx);
assert(remapping && remapping->size == 1 &&
"Type converter should produce 1-to-1 mapping for bare memrefs");
BlockArgument newArg =
@@ -193,19 +192,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
// Get memref type from function arguments and set the noalias to
// pointer arguments.
- for (const auto &en : llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
- auto memrefTy = en.value().dyn_cast<MemRefType>();
- NamedAttrList argAttr = argAttrs
- ? argAttrs[en.index()].cast<DictionaryAttr>()
- : NamedAttrList();
-
+ for (const auto [idx, argTy] :
+ llvm::enumerate(gpuFuncOp.getArgumentTypes())) {
+ auto remapping = signatureConversion.getInputMapping(idx);
+ NamedAttrList argAttr =
+ argAttrs ? argAttrs[idx].cast<DictionaryAttr>() : NamedAttrList();
+ auto copyAttribute = [&](StringRef attrName) {
+ Attribute attr = argAttr.erase(attrName);
+ if (!attr)
+ return;
+ for (size_t i = 0, e = remapping->size; i < e; ++i)
+ llvmFuncOp.setArgAttr(remapping->inputNo + i, attrName, attr);
+ };
auto copyPointerAttribute = [&](StringRef attrName) {
Attribute attr = argAttr.erase(attrName);
- // This is a proxy for the bare pointer calling convention.
if (!attr)
return;
- auto remapping = signatureConversion.getInputMapping(en.index());
if (remapping->size > 1 &&
attrName == LLVM::LLVMDialect::getNoAliasAttrName()) {
emitWarning(llvmFuncOp.getLoc(),
@@ -224,10 +227,23 @@ GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
if (argAttr.empty())
continue;
- if (memrefTy) {
+ copyAttribute(LLVM::LLVMDialect::getReturnedAttrName());
+ copyAttribute(LLVM::LLVMDialect::getNoUndefAttrName());
+ copyAttribute(LLVM::LLVMDialect::getInRegAttrName());
+ bool lowersToPointer = false;
+ for (size_t i = 0, e = remapping->size; i < e; ++i) {
+ lowersToPointer |= isa<LLVM::LLVMPointerType>(
+ llvmFuncOp.getArgument(remapping->inputNo + i).getType());
+ }
+
+ if (lowersToPointer) {
copyPointerAttribute(LLVM::LLVMDialect::getNoAliasAttrName());
+ copyPointerAttribute(LLVM::LLVMDialect::getNoCaptureAttrName());
+ copyPointerAttribute(LLVM::LLVMDialect::getNoFreeAttrName());
+ copyPointerAttribute(LLVM::LLVMDialect::getAlignAttrName());
copyPointerAttribute(LLVM::LLVMDialect::getReadonlyAttrName());
copyPointerAttribute(LLVM::LLVMDialect::getWriteOnlyAttrName());
+ copyPointerAttribute(LLVM::LLVMDialect::getReadnoneAttrName());
copyPointerAttribute(LLVM::LLVMDialect::getNonNullAttrName());
copyPointerAttribute(LLVM::LLVMDialect::getDereferenceableAttrName());
copyPointerAttribute(
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
index 2ee314e..a1aff1a 100644
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -866,6 +866,31 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
this->setHasBoundedRewriteRecursion();
}
+ static void getMaskBufferLoadIndices(OpTy xferOp, Value castedMaskBuffer,
+ SmallVectorImpl<Value> &loadIndices,
+ Value iv) {
+ assert(xferOp.getMask() && "Expected transfer op to have mask");
+
+ // Add load indices from the previous iteration.
+ // The mask buffer depends on the permutation map, which makes determining
+ // the indices quite complex, so this is why we need to "look back" to the
+ // previous iteration to find the right indices.
+ Value maskBuffer = getMaskBuffer(xferOp);
+ for (Operation *user : maskBuffer.getUsers()) {
+ // If there is no previous load op, then the indices are empty.
+ if (auto loadOp = dyn_cast<memref::LoadOp>(user)) {
+ Operation::operand_range prevIndices = loadOp.getIndices();
+ loadIndices.append(prevIndices.begin(), prevIndices.end());
+ break;
+ }
+ }
+
+ // In case of broadcast: Use same indices to load from memref
+ // as before.
+ if (!xferOp.isBroadcastDim(0))
+ loadIndices.push_back(iv);
+ }
+
LogicalResult matchAndRewrite(OpTy xferOp,
PatternRewriter &rewriter) const override {
if (!xferOp->hasAttr(kPassLabel))
@@ -873,9 +898,9 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// Find and cast data buffer. How the buffer can be found depends on OpTy.
ImplicitLocOpBuilder locB(xferOp.getLoc(), rewriter);
- auto dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
+ Value dataBuffer = Strategy<OpTy>::getBuffer(xferOp);
auto dataBufferType = dyn_cast<MemRefType>(dataBuffer.getType());
- auto castedDataType = unpackOneDim(dataBufferType);
+ FailureOr<MemRefType> castedDataType = unpackOneDim(dataBufferType);
if (failed(castedDataType))
return failure();
@@ -885,8 +910,7 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// If the xferOp has a mask: Find and cast mask buffer.
Value castedMaskBuffer;
if (xferOp.getMask()) {
- auto maskBuffer = getMaskBuffer(xferOp);
- auto maskBufferType = dyn_cast<MemRefType>(maskBuffer.getType());
+ Value maskBuffer = getMaskBuffer(xferOp);
if (xferOp.isBroadcastDim(0) || xferOp.getMaskType().getRank() == 1) {
// Do not unpack a dimension of the mask, if:
// * To-be-unpacked transfer op dimension is a broadcast.
@@ -897,7 +921,8 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
} else {
// It's safe to assume the mask buffer can be unpacked if the data
// buffer was unpacked.
- auto castedMaskType = *unpackOneDim(maskBufferType);
+ auto maskBufferType = cast<MemRefType>(maskBuffer.getType());
+ MemRefType castedMaskType = *unpackOneDim(maskBufferType);
castedMaskBuffer =
locB.create<vector::TypeCastOp>(castedMaskType, maskBuffer);
}
@@ -929,21 +954,16 @@ struct TransferOpConversion : public VectorToSCFPattern<OpTy> {
// If old transfer op has a mask: Set mask on new transfer op.
// Special case: If the mask of the old transfer op is 1D and
- // the
- // unpacked dim is not a broadcast, no mask is
- // needed on the new transfer op.
+ // the unpacked dim is not a broadcast, no mask is needed on
+ // the new transfer op.
if (xferOp.getMask() && (xferOp.isBroadcastDim(0) ||
xferOp.getMaskType().getRank() > 1)) {
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(newXfer); // Insert load before newXfer.
SmallVector<Value, 8> loadIndices;
- Strategy<OpTy>::getBufferIndices(xferOp, loadIndices);
- // In case of broadcast: Use same indices to load from memref
- // as before.
- if (!xferOp.isBroadcastDim(0))
- loadIndices.push_back(iv);
-
+ getMaskBufferLoadIndices(xferOp, castedMaskBuffer,
+ loadIndices, iv);
auto mask = b.create<memref::LoadOp>(loc, castedMaskBuffer,
loadIndices);
rewriter.updateRootInPlace(newXfer, [&]() {
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt
index 8383e06..8f289ce 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -64,6 +64,7 @@ add_mlir_dialect_library(MLIRGPUTransforms
Transforms/ShuffleRewriter.cpp
Transforms/SPIRVAttachTarget.cpp
Transforms/SubgroupReduceLowering.cpp
+ Transforms/Utils.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/GPU
diff --git a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
index 608d801..a75598a 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AllReduceLowering.cpp
@@ -27,33 +27,6 @@ using namespace mlir;
namespace {
-static vector::CombiningKind
-convertReductionKind(gpu::AllReduceOperation mode) {
- switch (mode) {
-#define MAP_CASE(X) \
- case gpu::AllReduceOperation::X: \
- return vector::CombiningKind::X
-
- MAP_CASE(ADD);
- MAP_CASE(MUL);
- MAP_CASE(MINUI);
- MAP_CASE(MINSI);
- MAP_CASE(MINNUMF);
- MAP_CASE(MAXSI);
- MAP_CASE(MAXUI);
- MAP_CASE(MAXNUMF);
- MAP_CASE(AND);
- MAP_CASE(OR);
- MAP_CASE(XOR);
- MAP_CASE(MINIMUMF);
- MAP_CASE(MAXIMUMF);
-
-#undef MAP_CASE
- }
-
- llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
-}
-
struct GpuAllReduceRewriter {
using AccumulatorFactory = std::function<Value(Value, Value)>;
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 61edce5..b00c65c 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -13,13 +13,17 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/Transforms/Utils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LogicalResult.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
+#include <cstdint>
using namespace mlir;
@@ -58,7 +62,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
unsigned elemBitwidth = elemTy.getIntOrFloatBitWidth();
if (elemBitwidth >= maxShuffleBitwidth)
return rewriter.notifyMatchFailure(
- op, llvm::formatv("element type too large {0}, cannot break down "
+ op, llvm::formatv("element type too large ({0}), cannot break down "
"into vectors of bitwidth {1} or less",
elemBitwidth, maxShuffleBitwidth));
@@ -139,6 +143,167 @@ struct ScalarizeSingleElementReduce final
}
};
+/// Emits a subgroup reduction using a sequence of shuffles. Uses the `packFn`
+/// and `unpackFn` to convert to the native shuffle type and to the reduction
+/// type, respectively. For example, with `input` of type `f16`, `packFn` could
+/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
+/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
+/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
+static Value createSubgroupShuffleReduction(
+ OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
+ unsigned subgroupSize, function_ref<Value(Value)> packFn,
+ function_ref<Value(Value)> unpackFn) {
+ assert(llvm::isPowerOf2_32(subgroupSize));
+ // Lane value always stays in the original type. We use it to perform arith
+ // reductions.
+ Value laneVal = input;
+ // Parallel reduction using butterfly shuffles.
+ for (unsigned i = 1; i < subgroupSize; i <<= 1) {
+ Value shuffled = builder
+ .create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
+ /*width=*/subgroupSize,
+ /*mode=*/gpu::ShuffleMode::XOR)
+ .getShuffleResult();
+ laneVal = vector::makeArithReduction(builder, loc,
+ gpu::convertReductionKind(mode),
+ laneVal, unpackFn(shuffled));
+ assert(laneVal.getType() == input.getType());
+ }
+
+ return laneVal;
+}
+
+/// Lowers scalar gpu subgroup reductions to a series of shuffles.
+struct ScalarSubgroupReduceToShuffles final
+ : OpRewritePattern<gpu::SubgroupReduceOp> {
+ ScalarSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+ unsigned shuffleBitwidth,
+ PatternBenefit benefit)
+ : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+ shuffleBitwidth(shuffleBitwidth) {}
+
+ LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+ PatternRewriter &rewriter) const override {
+ Type valueTy = op.getType();
+ unsigned elemBitwidth =
+ getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
+ if (!valueTy.isIntOrFloat() || elemBitwidth > shuffleBitwidth)
+ return rewriter.notifyMatchFailure(
+ op, "value type is not a compatible scalar");
+
+ Location loc = op.getLoc();
+ // Since this is already a native shuffle scalar, no packing is necessary.
+ if (elemBitwidth == shuffleBitwidth) {
+ auto identityFn = [](Value v) { return v; };
+ rewriter.replaceOp(op, createSubgroupShuffleReduction(
+ rewriter, loc, op.getValue(), op.getOp(),
+ subgroupSize, identityFn, identityFn));
+ return success();
+ }
+
+ auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
+ auto equivIntType = rewriter.getIntegerType(elemBitwidth);
+ auto packFn = [loc, &rewriter, equivIntType,
+ shuffleIntType](Value unpackedVal) -> Value {
+ auto asInt =
+ rewriter.create<arith::BitcastOp>(loc, equivIntType, unpackedVal);
+ return rewriter.create<arith::ExtUIOp>(loc, shuffleIntType, asInt);
+ };
+ auto unpackFn = [loc, &rewriter, equivIntType,
+ valueTy](Value packedVal) -> Value {
+ auto asInt =
+ rewriter.create<arith::TruncIOp>(loc, equivIntType, packedVal);
+ return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
+ };
+
+ rewriter.replaceOp(op, createSubgroupShuffleReduction(
+ rewriter, loc, op.getValue(), op.getOp(),
+ subgroupSize, packFn, unpackFn));
+ return success();
+ }
+
+private:
+ unsigned subgroupSize = 0;
+ unsigned shuffleBitwidth = 0;
+};
+
+/// Lowers vector gpu subgroup reductions to a series of shuffles.
+struct VectorSubgroupReduceToShuffles final
+ : OpRewritePattern<gpu::SubgroupReduceOp> {
+ VectorSubgroupReduceToShuffles(MLIRContext *ctx, unsigned subgroupSize,
+ unsigned shuffleBitwidth,
+ PatternBenefit benefit)
+ : OpRewritePattern(ctx, benefit), subgroupSize(subgroupSize),
+ shuffleBitwidth(shuffleBitwidth) {}
+
+ LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
+ PatternRewriter &rewriter) const override {
+ auto vecTy = dyn_cast<VectorType>(op.getType());
+ if (!vecTy)
+ return rewriter.notifyMatchFailure(op, "value type is not a vector");
+
+ unsigned vecBitwidth =
+ vecTy.getNumElements() * vecTy.getElementTypeBitWidth();
+ if (vecBitwidth > shuffleBitwidth)
+ return rewriter.notifyMatchFailure(
+ op,
+ llvm::formatv("vector type bitwidth too large ({0}), cannot lower "
+ "to shuffles of size {1}",
+ vecBitwidth, shuffleBitwidth));
+
+ unsigned elementsPerShuffle =
+ shuffleBitwidth / vecTy.getElementTypeBitWidth();
+ if (elementsPerShuffle * vecTy.getElementTypeBitWidth() != shuffleBitwidth)
+ return rewriter.notifyMatchFailure(
+ op, "shuffle bitwidth is not a multiple of the element bitwidth");
+
+ Location loc = op.getLoc();
+
+ // If the reduced type is smaller than the native shuffle size, extend it,
+ // perform the shuffles, and extract at the end.
+ auto extendedVecTy = VectorType::get(
+ static_cast<int64_t>(elementsPerShuffle), vecTy.getElementType());
+ Value extendedInput = op.getValue();
+ if (vecBitwidth < shuffleBitwidth) {
+ auto zero = rewriter.create<arith::ConstantOp>(
+ loc, rewriter.getZeroAttr(extendedVecTy));
+ extendedInput = rewriter.create<vector::InsertStridedSliceOp>(
+ loc, extendedInput, zero, /*offsets=*/0, /*strides=*/1);
+ }
+
+ auto shuffleIntType = rewriter.getIntegerType(shuffleBitwidth);
+ auto shuffleVecType = VectorType::get(1, shuffleIntType);
+
+ auto packFn = [loc, &rewriter, shuffleVecType](Value unpackedVal) -> Value {
+ auto asIntVec =
+ rewriter.create<vector::BitCastOp>(loc, shuffleVecType, unpackedVal);
+ return rewriter.create<vector::ExtractOp>(loc, asIntVec, 0);
+ };
+ auto unpackFn = [loc, &rewriter, shuffleVecType,
+ extendedVecTy](Value packedVal) -> Value {
+ auto asIntVec =
+ rewriter.create<vector::BroadcastOp>(loc, shuffleVecType, packedVal);
+ return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
+ };
+
+ Value res =
+ createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
+ subgroupSize, packFn, unpackFn);
+
+ if (vecBitwidth < shuffleBitwidth) {
+ res = rewriter.create<vector::ExtractStridedSliceOp>(
+ loc, res, /*offsets=*/0, /*sizes=*/vecTy.getNumElements(),
+ /*strides=*/1);
+ }
+
+ rewriter.replaceOp(op, res);
+ return success();
+ }
+
+private:
+ unsigned subgroupSize = 0;
+ unsigned shuffleBitwidth = 0;
+};
} // namespace
void mlir::populateGpuBreakDownSubgrupReducePatterns(
@@ -148,3 +313,10 @@ void mlir::populateGpuBreakDownSubgrupReducePatterns(
maxShuffleBitwidth, benefit);
patterns.add<ScalarizeSingleElementReduce>(patterns.getContext(), benefit);
}
+
+void mlir::populateGpuLowerSubgroupReduceToShufflePattenrs(
+ RewritePatternSet &patterns, unsigned subgroupSize,
+ unsigned shuffleBitwidth, PatternBenefit benefit) {
+ patterns.add<ScalarSubgroupReduceToShuffles, VectorSubgroupReduceToShuffles>(
+ patterns.getContext(), subgroupSize, shuffleBitwidth, benefit);
+}
diff --git a/mlir/lib/Dialect/GPU/Transforms/Utils.cpp b/mlir/lib/Dialect/GPU/Transforms/Utils.cpp
new file mode 100644
index 0000000..e91aa18
--- /dev/null
+++ b/mlir/lib/Dialect/GPU/Transforms/Utils.cpp
@@ -0,0 +1,44 @@
+//===- Utils.cpp - GPU transforms utils -----------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements GPU dialect transforms utils.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/Transforms/Utils.h"
+#include "llvm/Support/ErrorHandling.h"
+
+namespace mlir::gpu {
+
+vector::CombiningKind convertReductionKind(gpu::AllReduceOperation mode) {
+ switch (mode) {
+#define MAP_CASE(X) \
+ case gpu::AllReduceOperation::X: \
+ return vector::CombiningKind::X
+
+ MAP_CASE(ADD);
+ MAP_CASE(MUL);
+ MAP_CASE(MINUI);
+ MAP_CASE(MINSI);
+ MAP_CASE(MINNUMF);
+ MAP_CASE(MAXSI);
+ MAP_CASE(MAXUI);
+ MAP_CASE(MAXNUMF);
+ MAP_CASE(AND);
+ MAP_CASE(OR);
+ MAP_CASE(XOR);
+ MAP_CASE(MINIMUMF);
+ MAP_CASE(MAXIMUMF);
+
+#undef MAP_CASE
+ }
+
+ llvm_unreachable("Vector and GPU reduction kinds should match 1:1");
+}
+
+} // namespace mlir::gpu
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ba419d3..2917840 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -176,22 +176,22 @@ static bool isContractionBody(Block &block) {
return linalg::detail::isContractionBody(block, &isPairTemplateImpl<Args...>);
}
-/// Given a `linalgOp` and one of its `opOperand`, returns the positions of the
-/// iterators of type `iter` that index the `opOperand` as a permutation.
-/// This is useful to infer various subcomputations on a given `linalgOp`.
-/// This is performed by looking up each result in the matching indexing map and
-/// determining whether:
+/// Given an `indexingMap` and its corresponding `iterators`, returns
+/// the positions of the iterators of type `iter` that are indexed by
+/// the `indexingMap` as a permutation. This is useful to infer various
+/// subcomputations on a `LinalgOp`. This is performed by looking up
+/// each result in the `indexingMap` and determining whether:
/// - It is a single AffineDimExpr.
/// - It is the only result involving this AffineDimExpr.
static llvm::SmallDenseSet<int64_t>
-findPermutationsIndexingOperand(LinalgOp linalgOp, OpOperand *opOperand,
+findPermutationsIndexingOperand(AffineMap indexingMap,
+ ArrayRef<utils::IteratorType> iterators,
utils::IteratorType iter) {
+ assert(iterators.size() == indexingMap.getNumDims());
llvm::SmallDenseSet<int64_t> res;
- assert(linalgOp == opOperand->getOwner() && "expected linalgOp owner");
- AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
for (AffineExpr e : indexingMap.getResults()) {
if (auto d = dyn_cast<AffineDimExpr>(e)) {
- if (linalgOp.getIteratorTypesArray()[d.getPosition()] == iter &&
+ if (iterators[d.getPosition()] == iter &&
llvm::count_if(indexingMap.getResults(), [d](AffineExpr e) {
return e.isFunctionOfDim(d.getPosition());
}) == 1)
@@ -206,6 +206,21 @@ auto par = utils::IteratorType::parallel;
auto red = utils::IteratorType::reduction;
} // namespace
+/// Infer the iterator types from the init affine map. This looks at which dims
+/// are present in the map results, and returns an iterator types array with
+/// parallel types for dims that are present, and reduction types for dims that
+/// are not present.
+static FailureOr<SmallVector<utils::IteratorType>>
+inferIteratorsFromOutMap(AffineMap map) {
+ if (!map.isProjectedPermutation())
+ return failure();
+ SmallVector<utils::IteratorType> iterators(map.getNumDims(), red);
+ for (auto expr : map.getResults())
+ if (auto dim = dyn_cast<AffineDimExpr>(expr))
+ iterators[dim.getPosition()] = par;
+ return iterators;
+}
+
/// Find 2 parallel (m and n) and 1 reduction (k) dimension candidates that form
/// a matmul subcomputation within `linalgOp`. These dimensions are such that:
/// 1. The m dimension is involved in an outer-product along LHS
@@ -217,17 +232,15 @@ auto red = utils::IteratorType::reduction;
/// 5. Optional batch dimensions that appear in all operands are captured.
/// This allows e.g. detecting that some contraction is embedded within
/// `linalgOp` with some orthogonal heuristic.
-FailureOr<ContractionDimensions>
-mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
- if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
- return failure();
-
- llvm::SmallDenseSet<int64_t> a = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), par);
- llvm::SmallDenseSet<int64_t> b = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), par);
- llvm::SmallDenseSet<int64_t> c = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInitOperand(0), par);
+static FailureOr<ContractionDimensions>
+inferContractionDimsImpl(ArrayRef<AffineMap> indexingMaps,
+ ArrayRef<utils::IteratorType> iterators) {
+ llvm::SmallDenseSet<int64_t> a =
+ findPermutationsIndexingOperand(indexingMaps[0], iterators, par);
+ llvm::SmallDenseSet<int64_t> b =
+ findPermutationsIndexingOperand(indexingMaps[1], iterators, par);
+ llvm::SmallDenseSet<int64_t> c =
+ findPermutationsIndexingOperand(indexingMaps[2], iterators, par);
// A & C - B are the iterators involved in an outer-product along A (the LHS).
llvm::SmallDenseSet<int64_t> ac = a;
@@ -243,10 +256,10 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
llvm::set_intersect(batches, c);
// A & B red are the reduction dimensions.
- llvm::SmallDenseSet<int64_t> ra = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(0), red);
- llvm::SmallDenseSet<int64_t> rb = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), red);
+ llvm::SmallDenseSet<int64_t> ra =
+ findPermutationsIndexingOperand(indexingMaps[0], iterators, red);
+ llvm::SmallDenseSet<int64_t> rb =
+ findPermutationsIndexingOperand(indexingMaps[1], iterators, red);
llvm::set_intersect(ra, rb);
// Return each set in sorted order.
@@ -262,6 +275,24 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
return dimensions;
}
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
+ if (linalgOp.getNumDpsInits() != 1 || linalgOp.getNumDpsInputs() != 2)
+ return failure();
+ return inferContractionDimsImpl(linalgOp.getIndexingMapsArray(),
+ linalgOp.getIteratorTypesArray());
+}
+
+FailureOr<ContractionDimensions>
+mlir::linalg::inferContractionDims(ArrayRef<AffineMap> indexingMaps) {
+ if (indexingMaps.size() != 3)
+ return failure();
+ auto iterators = inferIteratorsFromOutMap(indexingMaps[2]);
+ if (failed(iterators))
+ return failure();
+ return inferContractionDimsImpl(indexingMaps, iterators.value());
+}
+
namespace mlir::linalg::detail {
enum class MatchContractionResult {
Success = 0,
@@ -504,10 +535,14 @@ static FailureOr<ConvolutionDimensions>
inferConvolutionDimsImpl(LinalgOp linalgOp,
ConvAccessExprWalker &inputExprWalker,
bool allowEmptyConvolvedDims) {
+ auto filterMap =
+ linalgOp.getMatchingIndexingMap(linalgOp.getDpsInputOperand(1));
+ auto outputMap =
+ linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
llvm::SmallDenseSet<int64_t> filterDims = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInputOperand(1), par);
+ filterMap, linalgOp.getIteratorTypesArray(), par);
llvm::SmallDenseSet<int64_t> outputDims = findPermutationsIndexingOperand(
- linalgOp, linalgOp.getDpsInitOperand(0), par);
+ outputMap, linalgOp.getIteratorTypesArray(), par);
// unConvolvedDims & outputDims - filterDims are the batch iterators.
llvm::SmallDenseSet<int64_t> batch = inputExprWalker.unConvolvedDims;
@@ -529,8 +564,8 @@ inferConvolutionDimsImpl(LinalgOp linalgOp,
llvm::set_intersect(depth, inputExprWalker.unConvolvedDims);
llvm::SmallDenseSet<int64_t> filterReducedDims =
- findPermutationsIndexingOperand(linalgOp, linalgOp.getDpsInputOperand(1),
- red);
+ findPermutationsIndexingOperand(filterMap,
+ linalgOp.getIteratorTypesArray(), red);
// convolvedDims & filterReducedDims are the filter loop iterators.
llvm::SmallDenseSet<int64_t> fl = inputExprWalker.convolvedDims;
diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index d276755..de4f58d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -9,8 +9,10 @@
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Location.h"
@@ -59,8 +61,6 @@ static SmallVector<T> &canonicalizeSetAsVector(SmallVector<T> &vec) {
return vec;
}
-using MeshAxis = int16_t;
-
namespace {
struct DimensionSize {
@@ -114,6 +114,56 @@ Operation *MeshDialect::materializeConstant(OpBuilder &builder, Attribute value,
// Mesh utilities
//===----------------------------------------------------------------------===//
+static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
+ SymbolTableCollection &symbolTable) {
+ mesh::ClusterOp mesh =
+ symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
+ if (!mesh) {
+ return op->emitError() << "Undefined required mesh symbol \""
+ << meshSymbol.getValue() << "\".";
+ }
+
+ return mesh;
+}
+
+template <typename It>
+bool isUnique(It begin, It end) {
+ if (begin == end) {
+ return true;
+ }
+ It next = std::next(begin);
+ if (next == end) {
+ return true;
+ }
+ for (; next != end; ++next, ++begin) {
+ if (*begin == *next) {
+ return false;
+ }
+ }
+ return true;
+}
+
+static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
+ ClusterOp mesh) {
+ SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
+ llvm::sort(sorted);
+ if (!isUnique(sorted.begin(), sorted.end())) {
+ return emitError(loc) << "Mesh axes contains duplicate elements.";
+ }
+
+ MeshAxis rank = mesh.getRank();
+ for (auto axis : axes) {
+ if (axis >= rank || axis < 0) {
+ return emitError(loc)
+ << "0-based mesh axis index " << axis
+ << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
+ << "\" is of rank " << rank << ".";
+ }
+ }
+
+ return success();
+}
+
bool mesh::isReductionLoop(IteratorType iType) {
return iType != IteratorType::Parallel && iType != IteratorType::Invalid;
}
@@ -173,7 +223,45 @@ SmallVector<int64_t> ClusterOp::canonicalDimSizes() {
}
//===----------------------------------------------------------------------===//
-// mesh.shard op
+// mesh.cluster_shape op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ClusterShapeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+ return failure();
+ }
+
+ size_t expectedResultsCount =
+ getAxes().empty() ? mesh->getRank() : getAxes().size();
+ if (getResult().size() != expectedResultsCount) {
+ return emitError() << "Unexpected number of results " << getResult().size()
+ << ". Expected " << expectedResultsCount << ".";
+ }
+
+ return success();
+}
+
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ ClusterOp mesh) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
+ mesh.getSymName(), MeshAxesAttr());
+}
+
+void ClusterShapeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef mesh, ArrayRef<MeshAxis> axes) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
+ MeshAxesAttr::get(odsBuilder.getContext(), axes));
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.shard attr
//===----------------------------------------------------------------------===//
LogicalResult
@@ -205,6 +293,75 @@ MeshShardingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
+bool MeshShardingAttr::operator==(Attribute rhs) const {
+ MeshShardingAttr rhsAsMeshShardingAttr = rhs.dyn_cast<MeshShardingAttr>();
+ return rhsAsMeshShardingAttr && *this == rhsAsMeshShardingAttr;
+}
+
+bool MeshShardingAttr::operator==(MeshShardingAttr rhs) const {
+ if (getCluster() != rhs.getCluster() ||
+ getPartialAxes() != rhs.getPartialAxes()) {
+ return false;
+ }
+
+ if (!getPartialAxes().empty() && getPartialType() != rhs.getPartialType()) {
+ return false;
+ }
+
+ auto minSize = std::min(getSplitAxes().size(), rhs.getSplitAxes().size());
+ if (!llvm::equal(llvm::make_range(getSplitAxes().begin(),
+ getSplitAxes().begin() + minSize),
+ llvm::make_range(rhs.getSplitAxes().begin(),
+ rhs.getSplitAxes().begin() + minSize))) {
+ return false;
+ }
+
+ return llvm::all_of(llvm::make_range(getSplitAxes().begin() + minSize,
+ getSplitAxes().end()),
+ std::mem_fn(&DenseI32ArrayAttr::empty)) &&
+ llvm::all_of(llvm::make_range(rhs.getSplitAxes().begin() + minSize,
+ rhs.getSplitAxes().end()),
+ std::mem_fn(&DenseI32ArrayAttr::empty));
+}
+
+//===----------------------------------------------------------------------===//
+// mesh.process_index op
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+ProcessIndexOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
+ auto mesh = ::getMesh(getOperation(), getMeshAttr(), symbolTable);
+ if (failed(mesh)) {
+ return failure();
+ }
+ if (failed(verifyMeshAxes(getLoc(), getAxes(), mesh.value()))) {
+ return failure();
+ }
+
+ size_t expectedResultsCount =
+ getAxes().empty() ? mesh->getRank() : getAxes().size();
+ if (getResult().size() != expectedResultsCount) {
+ return emitError() << "Unexpected number of results " << getResult().size()
+ << ". Expected " << expectedResultsCount << ".";
+ }
+
+ return success();
+}
+
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ ClusterOp mesh) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(mesh.getRank(), odsBuilder.getIndexType()),
+ mesh.getSymName(), MeshAxesAttr());
+}
+
+void ProcessIndexOp::build(OpBuilder &odsBuilder, OperationState &odsState,
+ StringRef mesh, ArrayRef<MeshAxis> axes) {
+ build(odsBuilder, odsState,
+ SmallVector<Type>(axes.size(), odsBuilder.getIndexType()), mesh,
+ MeshAxesAttr::get(odsBuilder.getContext(), axes));
+}
+
//===----------------------------------------------------------------------===//
// collective communication ops
//===----------------------------------------------------------------------===//
@@ -258,56 +415,6 @@ static LogicalResult verifyInGroupDevice(Location loc, StringRef deviceName,
return success();
}
-static FailureOr<ClusterOp> getMesh(Operation *op, FlatSymbolRefAttr meshSymbol,
- SymbolTableCollection &symbolTable) {
- mesh::ClusterOp mesh =
- symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(op, meshSymbol);
- if (!mesh) {
- return op->emitError() << "Undefined required mesh symbol \""
- << meshSymbol.getValue() << "\".";
- }
-
- return mesh;
-}
-
-template <typename It>
-bool isUnique(It begin, It end) {
- if (begin == end) {
- return true;
- }
- It next = std::next(begin);
- if (next == end) {
- return true;
- }
- for (; next != end; ++next, ++begin) {
- if (*begin == *next) {
- return false;
- }
- }
- return true;
-}
-
-static LogicalResult verifyMeshAxes(Location loc, ArrayRef<MeshAxis> axes,
- ClusterOp mesh) {
- SmallVector<MeshAxis> sorted = llvm::to_vector(axes);
- llvm::sort(sorted);
- if (!isUnique(sorted.begin(), sorted.end())) {
- return emitError(loc) << "Mesh axes contains duplicate elements.";
- }
-
- MeshAxis rank = mesh.getRank();
- for (auto axis : axes) {
- if (axis >= rank || axis < 0) {
- return emitError(loc)
- << "0-based mesh axis index " << axis
- << " is out of bounds. The referenced mesh \"" << mesh.getSymName()
- << "\" is of rank " << rank << ".";
- }
- }
-
- return success();
-}
-
template <typename Op>
static FailureOr<ClusterOp>
getMeshAndVerifyAxes(Op op, SymbolTableCollection &symbolTable) {
diff --git a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
index 044b867..7a70c04 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Mesh/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRMeshTransforms
Simplifications.cpp
ShardingPropagation.cpp
+ Spmdization.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Mesh
@@ -11,11 +12,13 @@ add_mlir_dialect_library(MLIRMeshTransforms
LINK_LIBS PUBLIC
MLIRArithDialect
+ MLIRControlFlowDialect
MLIRFuncDialect
MLIRIR
MLIRMeshDialect
MLIRPass
MLIRShardingInterface
MLIRSupport
+ MLIRTensorDialect
MLIRTosaShardingInterfaceImpl
)
diff --git a/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
new file mode 100644
index 0000000..8d7e896
--- /dev/null
+++ b/mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp
@@ -0,0 +1,639 @@
+//===- Spmdization.cpp --------------------------------------------- C++ --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/ADL.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include <algorithm>
+#include <iterator>
+#include <numeric>
+#include <optional>
+#include <tuple>
+#include <type_traits>
+
+namespace mlir {
+namespace mesh {
+
+int64_t shardDimension(int64_t dim, int64_t shardCount) {
+ if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+ return ShapedType::kDynamic;
+
+ assert(dim % shardCount == 0);
+ return ceilDiv(dim, shardCount);
+}
+
+int64_t unshardDimension(int64_t dim, int64_t shardCount) {
+ if (ShapedType::isDynamic(dim) || ShapedType::isDynamic(shardCount))
+ return ShapedType::kDynamic;
+
+ return dim * shardCount;
+}
+
+template <typename MeshShape, typename SplitAxes>
+int64_t shardCount(const MeshShape &meshShape, const SplitAxes &splitAxes) {
+ int64_t res = 1;
+ for (auto splitAxis : splitAxes) {
+ int64_t meshDimSize = meshShape[splitAxis];
+ if (ShapedType::isDynamic(meshDimSize)) {
+ return ShapedType::kDynamic;
+ }
+ res *= meshDimSize;
+ }
+ return res;
+}
+
+// Compute the shape for the tensor on each device in the mesh.
+// Example:
+// On a 2x4x? mesh with split axes = [[0], [1], [2]] the shape ?x5x1
+// would result in a shape for each shard of ?x2x?.
+template <typename InShape, typename MeshShape, typename SplitAxes,
+ typename OutShape>
+static void shardShape(const InShape &inShape, const MeshShape &meshShape,
+ const SplitAxes &splitAxes, OutShape &outShape) {
+ std::copy(llvm::adl_begin(inShape), llvm::adl_end(inShape),
+ llvm::adl_begin(outShape));
+ for (auto [tensorAxis, innerSplitAxes] : llvm::enumerate(splitAxes)) {
+ outShape[tensorAxis] =
+ shardDimension(inShape[tensorAxis],
+ shardCount(meshShape, innerSplitAxes.asArrayRef()));
+ }
+}
+
+ShapedType shardShapedType(ShapedType shape, ClusterOp mesh,
+ MeshShardingAttr sharding) {
+ using Dim = std::decay_t<decltype(shape.getDimSize(0))>;
+ SmallVector<Dim> resShapeArr(shape.getShape().size());
+ shardShape(shape.getShape(), mesh.canonicalDimSizes(),
+ sharding.getSplitAxes(), resShapeArr);
+ return shape.clone(resShapeArr);
+}
+
+template <typename SourceAxes, typename TargetAxes>
+static bool arePartialAxesCompatible(const SourceAxes &sourceAxes,
+ const TargetAxes &targetAxes) {
+ return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) {
+ return sourceAxes.contains(targetAxis);
+ });
+}
+
+// Return the reduced value and its corresponding sharding.
+// Example:
+// sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
+// targetSharding = <@mesh_1d, [[]]>
+// Then will apply all-reduce on the source value
+// and return it with the sharding <@mesh_1d, [[0]]>.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+handlePartialAxesDuringResharding(OpBuilder &builder,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ TypedValue<ShapedType> sourceShard) {
+ if (sourceSharding.getPartialAxes().empty() &&
+ targetSharding.getPartialAxes().empty()) {
+ return {sourceShard, sourceSharding};
+ }
+ assert(targetSharding.getPartialAxes().empty() ||
+ (!sourceSharding.getPartialAxes().empty() &&
+ sourceSharding.getPartialType() == targetSharding.getPartialType()));
+ using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>;
+ using AxisSet = llvm::SmallDenseSet<Axis>;
+ AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(),
+ sourceSharding.getPartialAxes().end());
+ AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(),
+ targetSharding.getPartialAxes().end());
+ assert(arePartialAxesCompatible(sourceShardingPartialAxesSet,
+ targetShardingPartialAxesSet));
+ llvm::SmallVector<MeshAxis> allReduceMeshAxes;
+ llvm::copy_if(sourceShardingPartialAxesSet,
+ std::back_inserter(allReduceMeshAxes),
+ [&targetShardingPartialAxesSet](Axis a) {
+ return !targetShardingPartialAxesSet.contains(a);
+ });
+ if (allReduceMeshAxes.empty()) {
+ return {sourceShard, sourceSharding};
+ }
+
+ builder.setInsertionPointAfterValue(sourceShard);
+ TypedValue<ShapedType> resultValue =
+ builder
+ .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(),
+ sourceSharding.getCluster().getLeafReference(),
+ allReduceMeshAxes, sourceShard,
+ sourceSharding.getPartialType())
+ .getResult()
+ .cast<TypedValue<ShapedType>>();
+
+ llvm::SmallVector<int32_t> remainingPartialAxes;
+ llvm::copy_if(sourceShardingPartialAxesSet,
+ std::back_inserter(allReduceMeshAxes),
+ [&targetShardingPartialAxesSet](Axis a) {
+ return targetShardingPartialAxesSet.contains(a);
+ });
+ MeshShardingAttr resultSharding =
+ MeshShardingAttr::get(builder.getContext(), sourceSharding.getCluster(),
+ sourceSharding.getSplitAxes(), remainingPartialAxes,
+ sourceSharding.getPartialType());
+ return {resultValue, resultSharding};
+}
+
+static MeshShardingAttr
+targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+ int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+ SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ llvm::to_vector(sourceSharding.getSplitAxes());
+ while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
+ splitTensorAxis) {
+ targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+ }
+ auto targetSplitAxes =
+ llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
+ targetSplitAxes.push_back(splitMeshAxis);
+ targetShardingSplitAxes[splitTensorAxis] =
+ DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+ return MeshShardingAttr::get(
+ ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+ sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+}
+
+static ShapedType targetShapeInSplitLastAxis(ShapedType sourceShape,
+ int64_t splitTensorAxis,
+ int64_t splitCount) {
+ SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
+ targetShape[splitTensorAxis] =
+ shardDimension(targetShape[splitTensorAxis], splitCount);
+ return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+}
+
+// Split a replicated tensor along a mesh axis.
+// e.g. [[0, 1]] -> [[0, 1, 2]].
+// Returns the spmdized target value with its sharding.
+//
+// The implementation is the extract the tensor slice corresponding
+// to the current device.
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+splitLastAxisInResharding(ImplicitLocOpBuilder &builder,
+ MeshShardingAttr sourceSharding,
+ TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+ int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+ MLIRContext *ctx = builder.getContext();
+ builder.setInsertionPointAfterValue(sourceShard);
+
+ Value zero = builder.create<arith::ConstantOp>(builder.getIndexAttr(0));
+
+ Value processIndexAlongAxis =
+ builder
+ .create<ProcessIndexOp>(mesh.getSymName(),
+ SmallVector<MeshAxis>({splitMeshAxis}))
+ .getResult()[0];
+
+ MeshShardingAttr targetSharding = targetShardingInSplitLastAxis(
+ ctx, sourceSharding, splitTensorAxis, splitMeshAxis);
+ ShapedType targetShape =
+ targetShapeInSplitLastAxis(sourceShard.getType(), splitTensorAxis,
+ mesh.canonicalDimSizes()[splitMeshAxis]);
+
+ Value meshAxisSize =
+ builder
+ .create<ClusterShapeOp>(mesh.getSymName(),
+ SmallVector<MeshAxis>({splitMeshAxis}))
+ .getResult()[0];
+
+ Value sourceAxisSize =
+ builder.create<tensor::DimOp>(sourceShard, splitTensorAxis);
+ Value sourceAxisSizeModMeshAxisSize =
+ builder.create<arith::RemUIOp>(sourceAxisSize, meshAxisSize);
+ Value isTargetShapeExactlyDivisible = builder.create<arith::CmpIOp>(
+ arith::CmpIPredicate::eq, sourceAxisSizeModMeshAxisSize, zero);
+ builder.create<cf::AssertOp>(
+ isTargetShapeExactlyDivisible,
+ "Sharding a tensor with axis size that is not exactly divisible by the "
+ "mesh axis size is not supported.");
+ Value targetAxisSize =
+ builder.create<arith::DivUIOp>(sourceAxisSize, meshAxisSize);
+ Value axisOffset =
+ builder.create<arith::MulIOp>(targetAxisSize, processIndexAlongAxis);
+ SmallVector<int64_t> staticOffsets(targetShape.getRank(), 0);
+ staticOffsets[splitTensorAxis] = ShapedType::kDynamic;
+ DenseI64ArrayAttr staticOffsetsAttr =
+ DenseI64ArrayAttr::get(ctx, staticOffsets);
+ SmallVector<Value> dynamicOffsets(1, axisOffset);
+
+ DenseI64ArrayAttr staticSizesAttr =
+ DenseI64ArrayAttr::get(ctx, targetShape.getShape());
+ SmallVector<Value> dynamicSizes;
+ for (int64_t i = 0; i < targetShape.getRank(); ++i) {
+ if (ShapedType::isDynamic(staticSizesAttr.asArrayRef()[i])) {
+ if (i == splitTensorAxis) {
+ dynamicSizes.push_back(targetAxisSize);
+ } else {
+ Value dimSize = builder.create<tensor::DimOp>(sourceShard, i);
+ dynamicSizes.push_back(dimSize);
+ }
+ }
+ }
+
+ DenseI64ArrayAttr staticStridesAttr = DenseI64ArrayAttr::get(
+ ctx, SmallVector<int64_t>(targetShape.getRank(), 1));
+ TypedValue<RankedTensorType> targetShard =
+ builder
+ .create<tensor::ExtractSliceOp>(
+ targetShape, sourceShard, dynamicOffsets, dynamicSizes,
+ SmallVector<Value>({}), staticOffsetsAttr, staticSizesAttr,
+ staticStridesAttr)
+ .getResult();
+ return {targetShard.cast<TypedValue<ShapedType>>(), targetSharding};
+}
+
+// Detect if the resharding is of type e.g.
+// [[0, 1]] -> [[0, 1, 2]].
+// If detected, returns the corresponding tensor axis mesh axis pair.
+// Does not detect insertions like
+// [[0, 1]] -> [[0, 2, 1]].
+static std::optional<std::tuple<int64_t, MeshAxis>>
+detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding) {
+ for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size();
+ ++tensorAxis) {
+ if (sourceSharding.getSplitAxes().size() > tensorAxis) {
+ if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 !=
+ targetSharding.getSplitAxes()[tensorAxis].size()) {
+ continue;
+ }
+ if (!llvm::equal(
+ sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(),
+ llvm::make_range(
+ targetSharding.getSplitAxes()[tensorAxis]
+ .asArrayRef()
+ .begin(),
+ targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
+ 1))) {
+ continue;
+ }
+ } else {
+ if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) {
+ continue;
+ }
+ }
+ return std::make_tuple(
+ tensorAxis,
+ targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
+ }
+ return std::nullopt;
+}
+
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ TypedValue<ShapedType> sourceShard) {
+ if (auto detectRes =
+ detectSplitLastAxisInResharding(sourceSharding, targetSharding)) {
+ auto [tensorAxis, meshAxis] = detectRes.value();
+ return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh,
+ tensorAxis, meshAxis);
+ }
+
+ return std::nullopt;
+}
+
+// Detect if the resharding is of type e.g.
+// [[0, 1, 2]] -> [[0, 1]].
+// If detected, returns the corresponding tensor axis mesh axis pair.
+static std::optional<std::tuple<int64_t, MeshAxis>>
+detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding) {
+ for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size();
+ ++tensorAxis) {
+ if (targetSharding.getSplitAxes().size() > tensorAxis) {
+ if (sourceSharding.getSplitAxes()[tensorAxis].size() !=
+ targetSharding.getSplitAxes()[tensorAxis].size() + 1)
+ continue;
+ if (!llvm::equal(
+ llvm::make_range(
+ sourceSharding.getSplitAxes()[tensorAxis]
+ .asArrayRef()
+ .begin(),
+ sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() -
+ 1),
+ targetSharding.getSplitAxes()[tensorAxis].asArrayRef()))
+ continue;
+ } else {
+ if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1)
+ continue;
+ }
+ return std::make_tuple(
+ tensorAxis,
+ sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back());
+ }
+ return std::nullopt;
+}
+
+static MeshShardingAttr
+targetShardingInUnsplitLastAxis(MLIRContext *ctx,
+ MeshShardingAttr sourceSharding,
+ int64_t splitTensorAxis) {
+ SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ llvm::to_vector(sourceSharding.getSplitAxes());
+ assert(static_cast<int64_t>(targetShardingSplitAxes.size()) >
+ splitTensorAxis);
+ auto targetSplitAxes =
+ llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef());
+
+ targetSplitAxes.pop_back();
+ targetShardingSplitAxes[splitTensorAxis] =
+ DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+ return MeshShardingAttr::get(
+ ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+ sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+}
+
+static ShapedType allGatherResultShapeInUnsplitLastAxis(
+ ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) {
+ SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
+ targetShape[splitTensorAxis] =
+ unshardDimension(targetShape[splitTensorAxis], splitCount);
+ return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+}
+
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder,
+ MeshShardingAttr sourceSharding,
+ ShapedType sourceUnshardedShape,
+ TypedValue<ShapedType> sourceShard, ClusterOp mesh,
+ int64_t splitTensorAxis, MeshAxis splitMeshAxis) {
+ MLIRContext *ctx = builder.getContext();
+ builder.setInsertionPointAfterValue(sourceShard);
+
+ MeshShardingAttr targetSharding =
+ targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitMeshAxis);
+ ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis(
+ sourceShard.getType(), mesh.canonicalDimSizes()[splitMeshAxis],
+ splitTensorAxis);
+ Value allGatherResult = builder.create<AllGatherOp>(
+ RankedTensorType::get(allGatherResultShape.getShape(),
+ allGatherResultShape.getElementType()),
+ mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard,
+ APInt(64, splitTensorAxis));
+ ShapedType targetShape =
+ shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+ TypedValue<ShapedType> targetShard =
+ builder.create<tensor::CastOp>(targetShape, allGatherResult)
+ .getResult()
+ .cast<TypedValue<ShapedType>>();
+ return {targetShard, targetSharding};
+}
+
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ ShapedType sourceUnshardedShape,
+ TypedValue<ShapedType> sourceShard) {
+ if (auto detectRes =
+ detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) {
+ auto [tensorAxis, meshAxis] = detectRes.value();
+ return unsplitLastAxisInResharding(builder, sourceSharding,
+ sourceUnshardedShape, sourceShard, mesh,
+ tensorAxis, meshAxis);
+ }
+
+ return std::nullopt;
+}
+
+// Detect if the resharding is of type e.g.
+// [[0, 1], [2]] -> [[0], [1, 2]].
+// Only moving the last axis counts.
+// If detected, returns the corresponding (source_tensor_axis,
+// target_tensor_axis, mesh_axis) tuple.
+static std::optional<std::tuple<int64_t, int64_t, MeshAxis>>
+detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding) {
+ for (size_t sourceTensorAxis = 0;
+ sourceTensorAxis < sourceSharding.getSplitAxes().size();
+ ++sourceTensorAxis) {
+ for (size_t targetTensorAxis = 0;
+ targetTensorAxis < targetSharding.getSplitAxes().size();
+ ++targetTensorAxis) {
+ if (sourceTensorAxis == targetTensorAxis)
+ continue;
+ if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() ||
+ targetSharding.getSplitAxes()[targetTensorAxis].empty() ||
+ sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() !=
+ targetSharding.getSplitAxes()[targetTensorAxis]
+ .asArrayRef()
+ .back())
+ continue;
+ if (!llvm::equal(
+ llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis]
+ .asArrayRef()
+ .begin(),
+ sourceSharding.getSplitAxes()[sourceTensorAxis]
+ .asArrayRef()
+ .end() -
+ 1),
+ llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis]
+ .asArrayRef()
+ .begin(),
+ targetSharding.getSplitAxes()[targetTensorAxis]
+ .asArrayRef()
+ .end() -
+ 1)))
+ continue;
+ return std::make_tuple(
+ sourceTensorAxis, targetTensorAxis,
+ sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back());
+ }
+ }
+ return std::nullopt;
+}
+
+static MeshShardingAttr
+targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding,
+ int64_t sourceTensorAxis,
+ int64_t targetTensorAxis) {
+ SmallVector<DenseI32ArrayAttr> targetShardingSplitAxes =
+ llvm::to_vector(sourceSharding.getSplitAxes());
+ while (static_cast<int64_t>(targetShardingSplitAxes.size()) <=
+ targetTensorAxis) {
+ targetShardingSplitAxes.push_back(DenseI32ArrayAttr::get(ctx, {}));
+ }
+
+ auto sourceSplitAxes =
+ llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef());
+ assert(!sourceSplitAxes.empty());
+ auto meshAxis = sourceSplitAxes.back();
+ sourceSplitAxes.pop_back();
+ targetShardingSplitAxes[sourceTensorAxis] =
+ DenseI32ArrayAttr::get(ctx, sourceSplitAxes);
+
+ auto targetSplitAxes =
+ llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef());
+ targetSplitAxes.push_back(meshAxis);
+ targetShardingSplitAxes[targetTensorAxis] =
+ DenseI32ArrayAttr::get(ctx, targetSplitAxes);
+
+ return MeshShardingAttr::get(
+ ctx, sourceSharding.getCluster(), targetShardingSplitAxes,
+ sourceSharding.getPartialAxes(), sourceSharding.getPartialType());
+}
+
+static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape,
+ int64_t splitCount,
+ int64_t sourceTensorAxis,
+ int64_t targetTensorAxis) {
+ SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape());
+ targetShape[sourceTensorAxis] =
+ unshardDimension(targetShape[sourceTensorAxis], splitCount);
+ targetShape[targetTensorAxis] =
+ shardDimension(targetShape[targetTensorAxis], splitCount);
+ return sourceShape.cloneWith(targetShape, sourceShape.getElementType());
+}
+
+static std::tuple<TypedValue<ShapedType>, MeshShardingAttr>
+moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ ShapedType sourceUnshardedShape,
+ TypedValue<ShapedType> sourceShard,
+ int64_t sourceTensorAxis,
+ int64_t targetTensorAxis, MeshAxis meshAxis) {
+ MLIRContext *ctx = builder.getContext();
+ builder.setInsertionPointAfterValue(sourceShard);
+
+ MeshShardingAttr targetSharding = targetShardingInMoveLastAxis(
+ ctx, sourceSharding, sourceTensorAxis, targetTensorAxis);
+ ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis(
+ sourceShard.getType(), mesh.canonicalDimSizes()[meshAxis],
+ sourceTensorAxis, targetTensorAxis);
+ Value allToAllResult = builder.create<AllToAllOp>(
+ RankedTensorType::get(allToAllResultShape.getShape(),
+ allToAllResultShape.getElementType()),
+ mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard,
+ APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis));
+ ShapedType targetShape =
+ shardShapedType(sourceUnshardedShape, mesh, targetSharding);
+ TypedValue<ShapedType> targetShard =
+ builder.create<tensor::CastOp>(targetShape, allToAllResult)
+ .getResult()
+ .cast<TypedValue<ShapedType>>();
+ return {targetShard, targetSharding};
+}
+
+static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>>
+tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ ShapedType sourceUnshardedShape,
+ TypedValue<ShapedType> sourceShard) {
+ if (auto detectRes =
+ detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) {
+ auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value();
+ return moveLastSplitAxisInResharding(
+ builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard,
+ sourceTensorAxis, targetTensorAxis, meshAxis);
+ }
+
+ return std::nullopt;
+}
+
+// Handles only resharding on a 1D mesh.
+// Currently the sharded tensor axes must be exactly divisible by the single
+// mesh axis size.
+static TypedValue<ShapedType>
+reshardOn1DMesh(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ TypedValue<ShapedType> sourceUnshardedValue,
+ TypedValue<ShapedType> sourceShard) {
+ assert(sourceShard.getType() ==
+ shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding));
+ [[maybe_unused]] ShapedType targetShardType =
+ shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding);
+ assert(sourceShard.getType().getRank() == targetShardType.getRank());
+ assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported.");
+
+ auto [reducedSourceShard, reducedSourceSharding] =
+ handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding,
+ sourceShard);
+
+ if (reducedSourceSharding == targetSharding) {
+ return reducedSourceShard;
+ }
+
+ TypedValue<ShapedType> targetShard;
+ MeshShardingAttr actualTargetSharding;
+ if (auto tryRes = tryMoveLastSplitAxisInResharding(
+ builder, mesh, reducedSourceSharding, targetSharding,
+ sourceUnshardedValue.getType(), reducedSourceShard)) {
+ std::tie(targetShard, actualTargetSharding) = tryRes.value();
+ } else if (auto tryRes = trySplitLastAxisInResharding(
+ builder, mesh, reducedSourceSharding, targetSharding,
+ reducedSourceShard)) {
+ std::tie(targetShard, actualTargetSharding) = tryRes.value();
+ } else if (auto tryRes = tryUnsplitLastAxisInResharding(
+ builder, mesh, reducedSourceSharding, targetSharding,
+ sourceUnshardedValue.getType(), reducedSourceShard)) {
+ std::tie(targetShard, actualTargetSharding) = tryRes.value();
+ } else {
+ assert(false && "Did not find any pattern to apply.");
+ }
+
+ assert(actualTargetSharding == targetSharding);
+ assert(targetShard.getType() == targetShardType);
+ return targetShard;
+}
+
+TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, ClusterOp mesh,
+ MeshShardingAttr sourceSharding,
+ MeshShardingAttr targetSharding,
+ TypedValue<ShapedType> sourceUnshardedValue,
+ TypedValue<ShapedType> sourceShard) {
+ // Resort to handling only 1D meshes since the general case is complicated if
+ // it needs to be communication efficient in terms of minimizing the data
+ // transfered between devices.
+ return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding,
+ sourceUnshardedValue, sourceShard);
+}
+
+TypedValue<ShapedType> reshard(OpBuilder &builder, ClusterOp mesh,
+ ShardOp source, ShardOp target,
+ TypedValue<ShapedType> sourceShardValue) {
+ assert(!source.getAnnotateForUsers());
+ assert(target.getAnnotateForUsers());
+ assert(source.getResult() == target.getOperand());
+ ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder);
+ return reshard(
+ implicitLocOpBuilder, mesh, source.getShard(), target.getShard(),
+ source.getSrc().cast<TypedValue<ShapedType>>(), sourceShardValue);
+}
+
+void reshardingRegisterDependentDialects(DialectRegistry &registry) {
+ registry.insert<arith::ArithDialect, mesh::MeshDialect, tensor::TensorDialect,
+ cf::ControlFlowDialect>();
+}
+
+} // namespace mesh
+} // namespace mlir
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 8af3b69..87a37a7 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -448,6 +448,23 @@ static bool isAdmissibleBSR(SparseTensorType &aTp) {
return false;
}
+/// Test for 2:4 matrix with suitable metadata.
+static bool isAdmissible24(SparseTensorType &aTp) {
+ return aTp.getDimRank() == 2 && aTp.getLvlRank() == 3 && aTp.isDenseLvl(0) &&
+ aTp.isDenseLvl(1) && aTp.is2OutOf4Lvl(2) && isAdmissibleMetaData(aTp);
+}
+
+/// Test for conversion into 2:4 matrix.
+static bool isConversionInto24(Value v) {
+ if (auto cnv = v.getDefiningOp<ConvertOp>()) {
+ Value a = cnv.getResult();
+ Value d = cnv.getSource();
+ SparseTensorType aTp = getSparseTensorType(a);
+ return isDenseTensor(d) && isAdmissible24(aTp);
+ }
+ return false;
+}
+
/// Returns a suitable sparse format for the operation and given operand
/// types with cuSparse, or kNone if none is available.
static CuSparseFormat getCuSparseFormat(SparseTensorType aTp,
@@ -925,6 +942,15 @@ static LogicalResult rewrite2To4SpMM(PatternRewriter &rewriter,
Value C = op.getOperand(2); // we have C = AB
SmallVector<Value> tokens;
+ // The cuSparselt API currently only allows pruning and compression
+ // to occur on the device. So we recognize the pattern
+ // A' = convert A ; dense to 2:4
+ // C = A'B ; 2:4 matrix mult
+ // and then perform compression and matrix multiplication on device.
+ auto cnv = A.getDefiningOp<ConvertOp>();
+ assert(cnv);
+ A = cnv.getSource();
+
// All input should be dense tensors.
if (!isDenseTensor(A) || !isDenseTensor(B) || !isDenseTensor(C))
return failure();
@@ -1260,7 +1286,7 @@ struct LinalgOpRewriter : public OpRewritePattern<linalg::GenericOp> {
maps == infer({{i, k}, {k, j}, {i, j}}) && matchSumOfMultOfArgs(op)) {
if (!isDenseTensor(op.getOperand(0)) && !isDenseTensor(op.getOperand(1)))
return rewriteSpGEMM(rewriter, op, enableRT);
- if (op->getAttr("DENSE24"))
+ if (isConversionInto24(op.getOperand(0)))
return rewrite2To4SpMM(rewriter, op);
return rewriteSpMM(rewriter, op, enableRT);
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 934e1e5..35eb4b4 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -949,94 +949,9 @@ static void endIf(CodegenEnv &env, OpBuilder &builder, scf::IfOp ifOp,
// Sparsifier synthesis methods (loop sequence).
//===----------------------------------------------------------------------===//
-/// Starts a loop sequence at given level. Returns true if
-/// the universal loop index must be maintained at this level.
-static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
- LoopId curr, LatSetId lts) {
- assert(!env.getLoopVar(curr));
- // Emit invariants at this loop sequence level.
- genInvariants(env, builder, exp, curr, /*isStart=*/true);
- // Emit access pattern expansion for sparse tensor output.
- genExpand(env, builder, curr, /*isStart=*/true);
- // Emit further intitialization at this loop sequence level.
- const LatPointId l0 = env.set(lts)[0];
- bool needsUniv = false;
-
- SmallVector<TensorLevel> tidLvls;
- env.merger().foreachTensorLoopId(l0, [&](TensorLoopId b, TensorId tid,
- std::optional<Level> lvl,
- LevelType lt, bool isIdxReduc) {
- assert(env.merger().loop(b) == curr);
- if (isDenseLT(lt) || isUndefLT(lt)) {
- if (tid == env.merger().getSynTensorID()) {
- // Needs loop emitter to set up loop bounds for synthetic tensor too if
- // there is a loop condition imposed on the synthetic tensor.
- tidLvls.push_back(env.makeTensorLevel(tid, env.getCurrentDepth()));
- }
- needsUniv = true;
- }
- if (isCompressedLT(lt) || isSingletonLT(lt) || isLooseCompressedLT(lt) ||
- is2OutOf4LT(lt) || isIdxReduc) {
- // Only when this is a index reduction loop, can the lt be undefined.
- assert(!isUndefLT(lt) || isIdxReduc);
- // sparse/singleton levels, or a dense/sparse index reduction loop.
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
- }
- });
-
- env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
-
- // Maintain the universal index only if it is actually
- // consumed by a subsequent lattice point.
- if (needsUniv) {
- for (const LatPointId li : env.set(lts).drop_front())
- if (!env.merger().hasAnySparse(env.lat(li).simple))
- return true;
- }
- return false;
-}
-
-// Generates dense affine address for encoding.
-static void genConstantDenseAddressFromLevel(CodegenEnv &env,
- OpBuilder &builder, TensorId tid,
- Level startLvl) {
- // TODO: Handle affine expression on output tensor.
- linalg::GenericOp op = env.op();
- assert(tid < op.getNumDpsInputs());
- OpOperand *input = op.getDpsInputOperands()[tid];
- const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
- const auto enc = getSparseTensorEncoding(input->get().getType());
- if (enc) {
- const Location loc = op.getLoc();
- const TensorId tid = env.makeTensorId(input->getOperandNumber());
- const Level lvlRank = enc.getLvlRank();
- assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
- for (Level l = startLvl; l < lvlRank; l++) {
- AffineExpr lvlExpr = lvlExprs[l];
- if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
- env.emitter().genDenseAffineAddress(
- builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
- else
- return; // break on first non-dense non-constant level
- }
- }
-}
-
-// We can generate address for constant affine expression before any loops
-// starting from the first level as they do not depend on any thing.
-// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
-// levels can be determined before loops.
-static void genInitConstantDenseAddress(CodegenEnv &env,
- RewriterBase &rewriter) {
- for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
- genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
-}
-
-/// Return true if the lattices bit can be iterated by a for loop.
-static bool translateBitsToTidLvlPairs(
+static bool getAllTidLvlsInLatPoints(
CodegenEnv &env, LatPointId li, LoopId curr,
- SmallVectorImpl<TensorLevel> &tidLvls,
- SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+ llvm::function_ref<void(TensorLevel, AffineExpr)> callback) {
const BitVector &simple = env.lat(li).simple;
const TensorId outTid = env.merger().getOutTensorID();
const std::optional<Level> outLvl = env.merger().getLvl(outTid, curr);
@@ -1048,7 +963,7 @@ static bool translateBitsToTidLvlPairs(
LevelType lt, bool isIdxReduc) {
if (simple[b]) {
if (isIdxReduc) {
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
numloopCond++;
return;
}
@@ -1072,10 +987,10 @@ static bool translateBitsToTidLvlPairs(
}
}
hasNonUnique = !isUniqueLT(lt) || hasNonUnique;
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
numloopCond++;
} else if (isDenseLT(lt) || isIdxReduc) {
- tidLvls.push_back(env.makeTensorLevel(tid, *lvl));
+ callback(env.makeTensorLevel(tid, *lvl), nullptr);
} else {
assert(isUndefLT(lt));
linalg::GenericOp op = env.op();
@@ -1109,7 +1024,7 @@ static bool translateBitsToTidLvlPairs(
// level. We need to generate the address according to the
// affine expression. This is also the best place we can do it
// to avoid putting it inside inner loops.
- affineTidLvls.emplace_back(env.makeTensorLevel(tid, l), exp);
+ callback(env.makeTensorLevel(tid, l), exp);
}
}
}
@@ -1120,15 +1035,14 @@ static bool translateBitsToTidLvlPairs(
// Note that we generate dense indices of the output tensor
// unconditionally, since they may not appear in the lattice, but may be
// needed for linearized env.
- tidLvls.push_back(env.makeTensorLevel(outTid, *outLvl));
+ callback(env.makeTensorLevel(outTid, *outLvl), nullptr);
}
if (numloopCond == 0) {
// Corner cases where the loop bound is defined by a *unused* operand, in
// this case, we just generate a dense "fake" loop by iterating over the
// synthetic tensor.
- tidLvls.push_back(env.makeTensorLevel(env.merger().getSynTensorID(),
- env.getCurrentDepth()));
+ callback(env.makeTensorLevel(env.merger().getSynTensorID(), curr), nullptr);
numloopCond++;
}
// If we just need to one loop conditions and the conditions is not imposed on
@@ -1136,6 +1050,84 @@ static bool translateBitsToTidLvlPairs(
return numloopCond == 1 && !hasNonUnique;
}
+/// Starts a loop sequence at given level. Returns true if
+/// the universal loop index must be maintained at this level.
+static bool startLoopSeq(CodegenEnv &env, OpBuilder &builder, ExprId exp,
+ LoopId curr, LatSetId lts) {
+ assert(!env.getLoopVar(curr));
+ // Emit invariants at this loop sequence level.
+ genInvariants(env, builder, exp, curr, /*isStart=*/true);
+ // Emit access pattern expansion for sparse tensor output.
+ genExpand(env, builder, curr, /*isStart=*/true);
+ // Emit further initialization at this loop sequence level.
+ const LatPointId l0 = env.set(lts)[0];
+
+ SmallVector<TensorLevel> tidLvls;
+ getAllTidLvlsInLatPoints(env, l0, curr, [&](TensorLevel tl, AffineExpr) {
+ tidLvls.emplace_back(tl);
+ });
+
+ env.emitter().enterNewLoopSeq(builder, env.op().getLoc(), tidLvls);
+
+ // Maintain the universal index only if it is actually
+ // consumed by a subsequent lattice point.
+ for (const LatPointId li : env.set(lts).drop_front())
+ if (!env.merger().hasAnySparse(env.lat(li).simple))
+ return true;
+
+ return false;
+}
+
+// Generates dense affine address for encoding.
+static void genConstantDenseAddressFromLevel(CodegenEnv &env,
+ OpBuilder &builder, TensorId tid,
+ Level startLvl) {
+ // TODO: Handle affine expression on output tensor.
+ linalg::GenericOp op = env.op();
+ assert(tid < op.getNumDpsInputs());
+ OpOperand *input = op.getDpsInputOperands()[tid];
+ const auto lvlExprs = op.getMatchingIndexingMap(input).getResults();
+ const auto enc = getSparseTensorEncoding(input->get().getType());
+ if (enc) {
+ const Location loc = op.getLoc();
+ const TensorId tid = env.makeTensorId(input->getOperandNumber());
+ const Level lvlRank = enc.getLvlRank();
+ assert(lvlExprs.size() == static_cast<size_t>(lvlRank));
+ for (Level l = startLvl; l < lvlRank; l++) {
+ AffineExpr lvlExpr = lvlExprs[l];
+ if (enc.isDenseLvl(l) && isa<AffineConstantExpr>(lvlExpr))
+ env.emitter().genDenseAffineAddress(
+ builder, loc, env.makeTensorLevel(tid, l), lvlExpr);
+ else
+ return; // break on first non-dense non-constant level
+ }
+ }
+}
+
+// We can generate address for constant affine expression before any loops
+// starting from the first level as they do not depend on anything.
+// E.g., [Dense, Dense, Sparse] -> (1, 2, d0), the addresses for the first two
+// levels can be determined before loops.
+static void genInitConstantDenseAddress(CodegenEnv &env,
+ RewriterBase &rewriter) {
+ for (TensorId tid = 0, e = env.op().getNumDpsInputs(); tid < e; tid++)
+ genConstantDenseAddressFromLevel(env, rewriter, tid, 0);
+}
+
+/// Returns true if the lattice bit can be iterated by a for loop.
+static bool translateBitsToTidLvlPairs(
+ CodegenEnv &env, LatPointId li, LoopId curr,
+ SmallVectorImpl<TensorLevel> &tidLvls,
+ SmallVectorImpl<std::pair<TensorLevel, AffineExpr>> &affineTidLvls) {
+ return getAllTidLvlsInLatPoints(env, li, curr,
+ [&](TensorLevel tl, AffineExpr exp) {
+ if (exp)
+ affineTidLvls.emplace_back(tl, exp);
+ else
+ tidLvls.emplace_back(tl);
+ });
+}
+
/// Starts a single loop in current sequence.
static std::pair<Operation *, bool> startLoop(CodegenEnv &env,
OpBuilder &builder, LoopId curr,
diff --git a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
index e20450c..cfd838e 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/PackAndUnpackPatterns.cpp
@@ -61,6 +61,47 @@ struct SimplifyPackToExpandShape : public OpRewritePattern<PackOp> {
}
};
+struct SimplifyUnPackToCollapseShape : public OpRewritePattern<UnPackOp> {
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+ Value insertCollapse(RewriterBase &rewriter, Location loc, Value operand,
+ Type newOperandType, ArrayAttr reassociation) const {
+ if (operand.getType() == newOperandType)
+ return operand;
+ return rewriter.create<tensor::CollapseShapeOp>(loc, newOperandType,
+ operand, reassociation);
+ }
+
+ LogicalResult matchAndRewrite(UnPackOp unpackOp,
+ PatternRewriter &rewriter) const override {
+ if (!unpackOp.getOuterDimsPerm().empty()) {
+ return rewriter.notifyMatchFailure(unpackOp,
+ "expects no outer_dims_perm");
+ }
+
+ RankedTensorType sourceType = unpackOp.getSourceType();
+ RankedTensorType destType = unpackOp.getDestType();
+ if (!sourceType.hasStaticShape() || !destType.hasStaticShape())
+ return rewriter.notifyMatchFailure(unpackOp, "expects static shapes");
+
+ ArrayRef<int64_t> dimsPos = unpackOp.getInnerDimsPos();
+ if (dimsPos.size() != 1 || (dimsPos[0] + 1 != destType.getRank())) {
+ return rewriter.notifyMatchFailure(
+ unpackOp, "expects unpacking at the innermost dimension");
+ }
+
+ auto reassociation =
+ getReassociationIndicesForReshape(sourceType, destType);
+ if (!reassociation)
+ return failure();
+ Value collapsed = insertCollapse(
+ rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType,
+ getReassociationIndicesAttribute(rewriter, *reassociation));
+ rewriter.replaceOp(unpackOp, collapsed);
+ return success();
+ }
+};
+
/// Fold a `pad` -> `pack` into `pack` if they have the same padding values and
/// the pad op has zero low paddings, or if `pack` has no padding values.
struct FoldPadWithPackOp : public OpRewritePattern<PackOp> {
@@ -191,7 +232,8 @@ void populateFoldIntoPackAndUnpackPatterns(RewritePatternSet &patterns) {
}
void populateSimplifyPackAndUnpackPatterns(RewritePatternSet &patterns) {
- patterns.add<SimplifyPackToExpandShape>(patterns.getContext());
+ patterns.add<SimplifyPackToExpandShape, SimplifyUnPackToCollapseShape>(
+ patterns.getContext());
}
} // namespace tensor
diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
index 7136e42..aa4694c 100644
--- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
+++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp
@@ -32,6 +32,7 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -1975,6 +1976,42 @@ void transform::NamedSequenceOp::build(OpBuilder &builder,
}
//===----------------------------------------------------------------------===//
+// NumAssociationsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform::NumAssociationsOp::apply(transform::TransformRewriter &rewriter,
+ transform::TransformResults &results,
+ transform::TransformState &state) {
+ size_t numAssociations =
+ llvm::TypeSwitch<Type, size_t>(getHandle().getType())
+ .Case([&](TransformHandleTypeInterface opHandle) {
+ return llvm::range_size(state.getPayloadOps(getHandle()));
+ })
+ .Case([&](TransformValueHandleTypeInterface valueHandle) {
+ return llvm::range_size(state.getPayloadValues(getHandle()));
+ })
+ .Case([&](TransformParamTypeInterface param) {
+ return llvm::range_size(state.getParams(getHandle()));
+ })
+ .Default([](Type) {
+ llvm_unreachable("unknown kind of transform dialect type");
+ return 0;
+ });
+ results.setParams(getNum().cast<OpResult>(),
+ rewriter.getI64IntegerAttr(numAssociations));
+ return DiagnosedSilenceableFailure::success();
+}
+
+LogicalResult transform::NumAssociationsOp::verify() {
+ // Verify that the result type accepts an i64 attribute as payload.
+ auto resultType = getNum().getType().cast<TransformParamTypeInterface>();
+ return resultType
+ .checkPayload(getLoc(), {Builder(getContext()).getI64IntegerAttr(0)})
+ .checkAndReport();
+}
+
+//===----------------------------------------------------------------------===//
// SelectOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 2ad992a..c1c0f54 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -271,7 +271,7 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
return false;
// Cond 1: A contiguous memref will always have a unit trailing stride.
- if (strides.back() != 1)
+ if (strides.empty() || strides.back() != 1)
return false;
// Cond 2: Strides of a contiguous memref have to match the flattened dims.
diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
index c45320a..b9a3429 100644
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -970,7 +970,7 @@ mgpuCuSparseLtSpMMBufferSize(void *bs, int32_t ma, int32_t mb, void *a, void *b,
// Note that this adds a synchronization on the stream.
// TODO: Do we want that?
if (prune_flag == 2) {
- int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream);
+ int *dvalid = (int *)mgpuMemAlloc(sizeof(int), stream, false);
CUSPARSE_REPORT_IF_ERROR(cusparseLtSpMMAPruneCheck(
&cusparseLt_env, &(matA->matmul), matA->values, dvalid, stream))
int valid = 0;
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 1f7cbf3..8fe8c78 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3542,8 +3542,9 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
os << ')';
}
- auto attrs = op->getDiscardableAttrs();
- printOptionalAttrDict(attrs);
+ printOptionalAttrDict(op->getPropertiesStorage()
+ ? llvm::to_vector(op->getDiscardableAttrs())
+ : op->getAttrs());
// Print the type signature of the operation.
os << " : ";
diff --git a/mlir/lib/Interfaces/InferTypeOpInterface.cpp b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
index 3c50c4c..ee4c051 100644
--- a/mlir/lib/Interfaces/InferTypeOpInterface.cpp
+++ b/mlir/lib/Interfaces/InferTypeOpInterface.cpp
@@ -240,8 +240,9 @@ LogicalResult mlir::detail::verifyInferredResultTypes(Operation *op) {
auto retTypeFn = cast<InferTypeOpInterface>(op);
auto result = retTypeFn.refineReturnTypes(
op->getContext(), op->getLoc(), op->getOperands(),
- op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
- op->getRegions(), inferredReturnTypes);
+ op->getPropertiesStorage() ? op->getDiscardableAttrDictionary()
+ : op->getAttrDictionary(),
+ op->getPropertiesStorage(), op->getRegions(), inferredReturnTypes);
if (failed(result))
op->emitOpError() << "failed to infer returned types";
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
index 810d6a3..5ee0ae6 100644
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -382,16 +382,22 @@ StringRef OpPassManager::getOpAnchorName() const {
/// Prints out the passes of the pass manager as the textual representation
/// of pipelines.
-void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
- os << getOpAnchorName() << "(";
+void printAsTextualPipeline(
+ raw_ostream &os, StringRef anchorName,
+ const llvm::iterator_range<OpPassManager::pass_iterator> &passes) {
+ os << anchorName << "(";
llvm::interleave(
- impl->passes,
- [&](const std::unique_ptr<Pass> &pass) {
- pass->printAsTextualPipeline(os);
- },
+ passes, [&](mlir::Pass &pass) { pass.printAsTextualPipeline(os); },
[&]() { os << ","; });
os << ")";
}
+void OpPassManager::printAsTextualPipeline(raw_ostream &os) const {
+ StringRef anchorName = getOpAnchorName();
+ ::printAsTextualPipeline(
+ os, anchorName,
+ {MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.begin(),
+ MutableArrayRef<std::unique_ptr<Pass>>{impl->passes}.end()});
+}
void OpPassManager::dump() {
llvm::errs() << "Pass Manager with " << impl->passes.size() << " passes:\n";
diff --git a/mlir/lib/Pass/PassCrashRecovery.cpp b/mlir/lib/Pass/PassCrashRecovery.cpp
index df1a076..3f3928e 100644
--- a/mlir/lib/Pass/PassCrashRecovery.cpp
+++ b/mlir/lib/Pass/PassCrashRecovery.cpp
@@ -38,7 +38,7 @@ namespace detail {
/// reproducers when a signal is raised, such as a segfault.
struct RecoveryReproducerContext {
RecoveryReproducerContext(std::string passPipelineStr, Operation *op,
- PassManager::ReproducerStreamFactory &streamFactory,
+ ReproducerStreamFactory &streamFactory,
bool verifyPasses);
~RecoveryReproducerContext();
@@ -67,7 +67,7 @@ private:
/// The factory for the reproducer output stream to use when generating the
/// reproducer.
- PassManager::ReproducerStreamFactory &streamFactory;
+ ReproducerStreamFactory &streamFactory;
/// Various pass manager and context flags.
bool disableThreads;
@@ -92,7 +92,7 @@ llvm::ManagedStatic<llvm::SmallSetVector<RecoveryReproducerContext *, 1>>
RecoveryReproducerContext::RecoveryReproducerContext(
std::string passPipelineStr, Operation *op,
- PassManager::ReproducerStreamFactory &streamFactory, bool verifyPasses)
+ ReproducerStreamFactory &streamFactory, bool verifyPasses)
: pipelineElements(std::move(passPipelineStr)),
preCrashOperation(op->clone()), streamFactory(streamFactory),
disableThreads(!op->getContext()->isMultithreadingEnabled()),
@@ -106,22 +106,24 @@ RecoveryReproducerContext::~RecoveryReproducerContext() {
disable();
}
-void RecoveryReproducerContext::generate(std::string &description) {
+static void appendReproducer(std::string &description, Operation *op,
+ const ReproducerStreamFactory &factory,
+ const std::string &pipelineElements,
+ bool disableThreads, bool verifyPasses) {
llvm::raw_string_ostream descOS(description);
// Try to create a new output stream for this crash reproducer.
std::string error;
- std::unique_ptr<PassManager::ReproducerStream> stream = streamFactory(error);
+ std::unique_ptr<ReproducerStream> stream = factory(error);
if (!stream) {
descOS << "failed to create output stream: " << error;
return;
}
descOS << "reproducer generated at `" << stream->description() << "`";
- std::string pipeline = (preCrashOperation->getName().getStringRef() + "(" +
- pipelineElements + ")")
- .str();
- AsmState state(preCrashOperation);
+ std::string pipeline =
+ (op->getName().getStringRef() + "(" + pipelineElements + ")").str();
+ AsmState state(op);
state.attachResourcePrinter(
"mlir_reproducer", [&](Operation *op, AsmResourceBuilder &builder) {
builder.buildString("pipeline", pipeline);
@@ -130,7 +132,12 @@ void RecoveryReproducerContext::generate(std::string &description) {
});
// Output the .mlir module.
- preCrashOperation->print(stream->os(), state);
+ op->print(stream->os(), state);
+}
+
+void RecoveryReproducerContext::generate(std::string &description) {
+ appendReproducer(description, preCrashOperation, streamFactory,
+ pipelineElements, disableThreads, verifyPasses);
}
void RecoveryReproducerContext::disable() {
@@ -175,12 +182,11 @@ void RecoveryReproducerContext::registerSignalHandler() {
//===----------------------------------------------------------------------===//
struct PassCrashReproducerGenerator::Impl {
- Impl(PassManager::ReproducerStreamFactory &streamFactory,
- bool localReproducer)
+ Impl(ReproducerStreamFactory &streamFactory, bool localReproducer)
: streamFactory(streamFactory), localReproducer(localReproducer) {}
/// The factory to use when generating a crash reproducer.
- PassManager::ReproducerStreamFactory streamFactory;
+ ReproducerStreamFactory streamFactory;
/// Flag indicating if reproducer generation should be localized to the
/// failing pass.
@@ -198,7 +204,7 @@ struct PassCrashReproducerGenerator::Impl {
};
PassCrashReproducerGenerator::PassCrashReproducerGenerator(
- PassManager::ReproducerStreamFactory &streamFactory, bool localReproducer)
+ ReproducerStreamFactory &streamFactory, bool localReproducer)
: impl(std::make_unique<Impl>(streamFactory, localReproducer)) {}
PassCrashReproducerGenerator::~PassCrashReproducerGenerator() = default;
@@ -382,9 +388,9 @@ private:
//===----------------------------------------------------------------------===//
namespace {
-/// This class represents a default instance of PassManager::ReproducerStream
+/// This class represents a default instance of mlir::ReproducerStream
/// that is backed by a file.
-struct FileReproducerStream : public PassManager::ReproducerStream {
+struct FileReproducerStream : public mlir::ReproducerStream {
FileReproducerStream(std::unique_ptr<llvm::ToolOutputFile> outputFile)
: outputFile(std::move(outputFile)) {}
~FileReproducerStream() override { outputFile->keep(); }
@@ -418,22 +424,45 @@ LogicalResult PassManager::runWithCrashRecovery(Operation *op,
return passManagerResult;
}
-void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
- bool genLocalReproducer) {
+static ReproducerStreamFactory
+makeReproducerStreamFactory(StringRef outputFile) {
// Capture the filename by value in case outputFile is out of scope when
// invoked.
std::string filename = outputFile.str();
- enableCrashReproducerGeneration(
- [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
- std::unique_ptr<llvm::ToolOutputFile> outputFile =
- mlir::openOutputFile(filename, &error);
- if (!outputFile) {
- error = "Failed to create reproducer stream: " + error;
- return nullptr;
- }
- return std::make_unique<FileReproducerStream>(std::move(outputFile));
- },
- genLocalReproducer);
+ return [filename](std::string &error) -> std::unique_ptr<ReproducerStream> {
+ std::unique_ptr<llvm::ToolOutputFile> outputFile =
+ mlir::openOutputFile(filename, &error);
+ if (!outputFile) {
+ error = "Failed to create reproducer stream: " + error;
+ return nullptr;
+ }
+ return std::make_unique<FileReproducerStream>(std::move(outputFile));
+ };
+}
+
+void printAsTextualPipeline(
+ raw_ostream &os, StringRef anchorName,
+ const llvm::iterator_range<OpPassManager::pass_iterator> &passes);
+
+std::string mlir::makeReproducer(
+ StringRef anchorName,
+ const llvm::iterator_range<OpPassManager::pass_iterator> &passes,
+ Operation *op, StringRef outputFile, bool disableThreads,
+ bool verifyPasses) {
+
+ std::string description;
+ std::string pipelineStr;
+ llvm::raw_string_ostream passOS(pipelineStr);
+ ::printAsTextualPipeline(passOS, anchorName, passes);
+ appendReproducer(description, op, makeReproducerStreamFactory(outputFile),
+ pipelineStr, disableThreads, verifyPasses);
+ return description;
+}
+
+void PassManager::enableCrashReproducerGeneration(StringRef outputFile,
+ bool genLocalReproducer) {
+ enableCrashReproducerGeneration(makeReproducerStreamFactory(outputFile),
+ genLocalReproducer);
}
void PassManager::enableCrashReproducerGeneration(
diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h
index 0e964b6..5cc7262 100644
--- a/mlir/lib/Pass/PassDetail.h
+++ b/mlir/lib/Pass/PassDetail.h
@@ -98,9 +98,8 @@ private:
class PassCrashReproducerGenerator {
public:
- PassCrashReproducerGenerator(
- PassManager::ReproducerStreamFactory &streamFactory,
- bool localReproducer);
+ PassCrashReproducerGenerator(ReproducerStreamFactory &streamFactory,
+ bool localReproducer);
~PassCrashReproducerGenerator();
/// Initialize the generator in preparation for reproducer generation. The
diff --git a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
index d7d4761..5395aa2 100644
--- a/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
+++ b/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp
@@ -151,6 +151,16 @@ struct MlirOptMainConfigCLOptions : public MlirOptMainConfig {
static cl::list<std::string> passPlugins(
"load-pass-plugin", cl::desc("Load passes from plugin library"));
+
+ static cl::opt<std::string, /*ExternalStorage=*/true>
+ generateReproducerFile(
+ "mlir-generate-reproducer",
+ llvm::cl::desc(
+ "Generate an mlir reproducer at the provided filename"
+ " (no crash required)"),
+ cl::location(generateReproducerFileFlag), cl::init(""),
+ cl::value_desc("filename"));
+
/// Set the callback to load a pass plugin.
passPlugins.setCallback([&](const std::string &pluginPath) {
auto plugin = PassPlugin::load(pluginPath);
@@ -384,6 +394,14 @@ performActions(raw_ostream &os,
if (failed(pm.run(*op)))
return failure();
+ // Generate reproducers if requested
+ if (!config.getReproducerFilename().empty()) {
+ StringRef anchorName = pm.getAnyOpAnchorName();
+ const auto &passes = pm.getPasses();
+ makeReproducer(anchorName, passes, op.get(),
+ config.getReproducerFilename());
+ }
+
// Print the output.
TimingScope outputTiming = timing.nest("Output");
if (config.shouldEmitBytecode()) {
diff --git a/mlir/test/Conversion/GPUCommon/memref-arg-attrs.mlir b/mlir/test/Conversion/GPUCommon/memref-arg-attrs.mlir
index 3337498..e7c7420 100644
--- a/mlir/test/Conversion/GPUCommon/memref-arg-attrs.mlir
+++ b/mlir/test/Conversion/GPUCommon/memref-arg-attrs.mlir
@@ -24,6 +24,17 @@ gpu.module @kernel {
// ROCDL-SAME: !llvm.ptr {llvm.writeonly}
// NVVM-SAME: !llvm.ptr {llvm.writeonly}
+// -----
+
+gpu.module @kernel {
+ gpu.func @test_func_readonly_ptr(%arg0 : !llvm.ptr {llvm.readonly} ) {
+ gpu.return
+ }
+}
+
+// CHECK-LABEL: llvm.func @test_func_readonly_ptr
+// ROCDL-SAME: !llvm.ptr {llvm.readonly}
+// NVVM-SAME: !llvm.ptr {llvm.readonly}
// -----
@@ -62,3 +73,17 @@ gpu.module @kernel {
// CHECK-LABEL: llvm.func @test_func_dereferenceable_or_null
// ROCDL-SAME: !llvm.ptr {llvm.dereferenceable_or_null = 4 : i64}
// NVVM-SAME: !llvm.ptr {llvm.dereferenceable_or_null = 4 : i64}
+
+// -----
+
+gpu.module @kernel {
+ gpu.func @test_func_noundef(%arg0 : memref<f32> {llvm.noundef} ) {
+ gpu.return
+ }
+}
+
+// CHECK-LABEL: llvm.func @test_func_noundef
+// ROCDL-SAME: !llvm.ptr {llvm.noundef}
+// ROCDL-SAME: i64 {llvm.noundef}
+// NVVM-SAME: !llvm.ptr {llvm.noundef}
+// NVVM-SAME: i64 {llvm.noundef}
diff --git a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
index ad78f0c..8316b40 100644
--- a/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
+++ b/mlir/test/Conversion/VectorToSCF/vector-to-scf.mlir
@@ -740,6 +740,43 @@ func.func @cannot_lower_transfer_read_with_leading_scalable(%arg0: memref<?x4xf3
// -----
+// Check that the `TransferOpConversion` generates valid indices for the LoadOp.
+
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, 0, 0, d3)>
+func.func @does_not_crash_on_unpack_one_dim(%subview: memref<1x1x1x1xi32>, %mask: vector<1x1xi1>) -> vector<1x1x1x1xi32> {
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ %3 = vector.transfer_read %subview[%c0, %c0, %c0, %c0], %c0_i32, %mask {permutation_map = #map1}
+ : memref<1x1x1x1xi32>, vector<1x1x1x1xi32>
+ return %3 : vector<1x1x1x1xi32>
+}
+// CHECK-LABEL: func.func @does_not_crash_on_unpack_one_dim
+// CHECK: %[[ALLOCA_0:.*]] = memref.alloca() : memref<vector<1x1xi1>>
+// CHECK: %[[MASK:.*]] = vector.type_cast %[[ALLOCA_0]] : memref<vector<1x1xi1>> to memref<1xvector<1xi1>>
+// CHECK: memref.load %[[MASK]][%{{.*}}] : memref<1xvector<1xi1>>
+
+// -----
+
+// Check that the `TransferOpConversion` generates valid indices for the StoreOp.
+// This test is pulled from an integration test for ArmSVE.
+
+func.func @add_arrays_of_scalable_vectors(%a: memref<1x2x?xf32>, %b: memref<1x2x?xf32>) -> vector<1x2x[4]xf32> {
+ %c0 = arith.constant 0 : index
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 2 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %dim_a = memref.dim %a, %c2 : memref<1x2x?xf32>
+ %mask_a = vector.create_mask %c2, %c3, %dim_a : vector<1x2x[4]xi1>
+ %vector_a = vector.transfer_read %a[%c0, %c0, %c0], %cst, %mask_a {in_bounds = [true, true, true]} : memref<1x2x?xf32>, vector<1x2x[4]xf32>
+ return %vector_a : vector<1x2x[4]xf32>
+}
+// CHECK-LABEL: func.func @add_arrays_of_scalable_vectors
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: memref.load
+
+// -----
+
// FULL-UNROLL-LABEL: @cannot_fully_unroll_transfer_write_of_nd_scalable_vector
func.func @cannot_fully_unroll_transfer_write_of_nd_scalable_vector(%vec: vector<[4]x[4]xf32>, %memref: memref<?x?xf32>) {
// FULL-UNROLL-NOT: vector.extract
diff --git a/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
index b714607..f04a01f 100644
--- a/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
+++ b/mlir/test/Dialect/GPU/subgroup-redule-lowering.mlir
@@ -1,71 +1,191 @@
-// RUN: mlir-opt --allow-unregistered-dialect --test-gpu-subgroup-reduce-lowering %s | FileCheck %s
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN: --test-gpu-subgroup-reduce-lowering %s \
+// RUN: | FileCheck %s --check-prefix=CHECK-SUB
-// CHECK: gpu.module @kernels {
+// RUN: mlir-opt --allow-unregistered-dialect \
+// RUN: --test-gpu-subgroup-reduce-lowering="expand-to-shuffles" %s \
+// RUN: | FileCheck %s --check-prefix=CHECK-SHFL
+
+// CHECK-SUB: gpu.module @kernels {
+// CHECK-SHFL: gpu.module @kernels {
gpu.module @kernels {
- // CHECK-LABEL: gpu.func @kernel0(
- // CHECK-SAME: %[[ARG0:.+]]: vector<5xf16>)
+ // CHECK-SUB-LABEL: gpu.func @kernel0(
+ // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<5xf16>)
+ //
+ // CHECK-SHFL-LABEL: gpu.func @kernel0(
gpu.func @kernel0(%arg0: vector<5xf16>) kernel {
- // CHECK: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
- // CHECK: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
- // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
- // CHECK: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
- // CHECK: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
- // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
- // CHECK: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
- // CHECK: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
- // CHECK: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
- // CHECK: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
- // CHECK: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
+ // CHECK-SUB: %[[VZ:.+]] = arith.constant dense<0.0{{.*}}> : vector<5xf16>
+ // CHECK-SUB: %[[E0:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+ // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (vector<2xf16>) -> vector<2xf16>
+ // CHECK-SUB: %[[V0:.+]] = vector.insert_strided_slice %[[R0]], %[[VZ]] {offsets = [0], strides = [1]} : vector<2xf16> into vector<5xf16>
+ // CHECK-SUB: %[[E1:.+]] = vector.extract_strided_slice %[[ARG0]] {offsets = [2], sizes = [2], strides = [1]} : vector<5xf16> to vector<2xf16>
+ // CHECK-SUB: %[[R1:.+]] = gpu.subgroup_reduce add %[[E1]] : (vector<2xf16>) -> vector<2xf16>
+ // CHECK-SUB: %[[V1:.+]] = vector.insert_strided_slice %[[R1]], %[[V0]] {offsets = [2], strides = [1]} : vector<2xf16> into vector<5xf16>
+ // CHECK-SUB: %[[E2:.+]] = vector.extract %[[ARG0]][4] : f16 from vector<5xf16>
+ // CHECK-SUB: %[[R2:.+]] = gpu.subgroup_reduce add %[[E2]] : (f16) -> f16
+ // CHECK-SUB: %[[V2:.+]] = vector.insert %[[R2]], %[[V1]] [4] : f16 into vector<5xf16>
+ // CHECK-SUB: "test.consume"(%[[V2]]) : (vector<5xf16>) -> ()
%sum0 = gpu.subgroup_reduce add %arg0 : (vector<5xf16>) -> (vector<5xf16>)
"test.consume"(%sum0) : (vector<5xf16>) -> ()
-
- // CHECK-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
- // CHECK: "test.consume"
+ // CHECK-SUB-COUNT-3: gpu.subgroup_reduce mul {{.+}} uniform
+ // CHECK-SUB: "test.consume"
%sum1 = gpu.subgroup_reduce mul %arg0 uniform : (vector<5xf16>) -> (vector<5xf16>)
"test.consume"(%sum1) : (vector<5xf16>) -> ()
- // CHECK: gpu.return
+ // CHECK-SUB: gpu.return
gpu.return
}
- // CHECK-LABEL: gpu.func @kernel1(
- // CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>)
+ // CHECK-SUB-LABEL: gpu.func @kernel1(
+ // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<1xf32>)
+ //
+ // CHECK-SHFL-LABEL: gpu.func @kernel1(
gpu.func @kernel1(%arg0: vector<1xf32>) kernel {
- // CHECK: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
- // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
- // CHECK: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
- // CHECK: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
+ // CHECK-SUB: %[[E0:.+]] = vector.extract %[[ARG0]][0] : f32 from vector<1xf32>
+ // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[E0]] : (f32) -> f32
+ // CHECK-SUB: %[[V0:.+]] = vector.broadcast %[[R0]] : f32 to vector<1xf32>
+ // CHECK-SUB: "test.consume"(%[[V0]]) : (vector<1xf32>) -> ()
%sum0 = gpu.subgroup_reduce add %arg0 : (vector<1xf32>) -> (vector<1xf32>)
"test.consume"(%sum0) : (vector<1xf32>) -> ()
- // CHECK: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
- // CHECK: "test.consume"
+ // CHECK-SUB: gpu.subgroup_reduce add {{.+}} uniform : (f32) -> f32
+ // CHECK-SUB: "test.consume"
%sum1 = gpu.subgroup_reduce add %arg0 uniform : (vector<1xf32>) -> (vector<1xf32>)
"test.consume"(%sum1) : (vector<1xf32>) -> ()
- // CHECK: gpu.return
+ // CHECK-SUB: gpu.return
gpu.return
}
// These vectors fit the native shuffle size and should not be broken down.
//
- // CHECK-LABEL: gpu.func @kernel2(
- // CHECK-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+ // CHECK-SUB-LABEL: gpu.func @kernel2(
+ // CHECK-SUB-SAME: %[[ARG0:.+]]: vector<3xi8>, %[[ARG1:.+]]: vector<4xi8>)
+ //
+ // CHECK-SHFL-LABEL: gpu.func @kernel2(
gpu.func @kernel2(%arg0: vector<3xi8>, %arg1: vector<4xi8>) kernel {
- // CHECK: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
- // CHECK: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
+ // CHECK-SUB: %[[R0:.+]] = gpu.subgroup_reduce add %[[ARG0]] : (vector<3xi8>) -> vector<3xi8>
+ // CHECK-SUB: "test.consume"(%[[R0]]) : (vector<3xi8>) -> ()
%sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
"test.consume"(%sum0) : (vector<3xi8>) -> ()
- // CHECK: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
- // CHECK: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
+ // CHECK-SUB: %[[R1:.+]] = gpu.subgroup_reduce add %[[ARG1]] : (vector<4xi8>) -> vector<4xi8>
+ // CHECK-SUB: "test.consume"(%[[R1]]) : (vector<4xi8>) -> ()
%sum1 = gpu.subgroup_reduce add %arg1 : (vector<4xi8>) -> (vector<4xi8>)
"test.consume"(%sum1) : (vector<4xi8>) -> ()
- // CHECK: gpu.return
+ // CHECK-SUB: gpu.return
+ gpu.return
+ }
+
+ // CHECK-SHFL-LABEL: gpu.func @kernel3(
+ // CHECK-SHFL-SAME: %[[ARG0:.+]]: i32)
+ gpu.func @kernel3(%arg0: i32) kernel {
+ // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
+ // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
+ // CHECK-SHFL-DAG: %[[C4:.+]] = arith.constant 4 : i32
+ // CHECK-SHFL-DAG: %[[C8:.+]] = arith.constant 8 : i32
+ // CHECK-SHFL-DAG: %[[C16:.+]] = arith.constant 16 : i32
+ // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+ // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[ARG0]], %[[C1]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[S0]] : i32
+ // CHECK-SHFL: %[[S1:.+]], %{{.+}} = gpu.shuffle xor %[[A0]], %[[C2]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A1:.+]] = arith.addi %[[A0]], %[[S1]] : i32
+ // CHECK-SHFL: %[[S2:.+]], %{{.+}} = gpu.shuffle xor %[[A1]], %[[C4]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A2:.+]] = arith.addi %[[A1]], %[[S2]] : i32
+ // CHECK-SHFL: %[[S3:.+]], %{{.+}} = gpu.shuffle xor %[[A2]], %[[C8]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A3:.+]] = arith.addi %[[A2]], %[[S3]] : i32
+ // CHECK-SHFL: %[[S4:.+]], %{{.+}} = gpu.shuffle xor %[[A3]], %[[C16]], %[[C32]] : i32
+ // CHECK-SHFL: %[[A4:.+]] = arith.addi %[[A3]], %[[S4]] : i32
+ // CHECK-SHFL: "test.consume"(%[[A4]]) : (i32) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 : (i32) -> i32
+ "test.consume"(%sum0) : (i32) -> ()
+
+ // CHECK-SHFL: gpu.return
+ gpu.return
+ }
+
+ // CHECK-SHFL-LABEL: gpu.func @kernel4(
+ // CHECK-SHFL-SAME: %[[ARG0:.+]]: vector<2xf16>)
+ gpu.func @kernel4(%arg0: vector<2xf16>) kernel {
+ // CHECK-SHFL-DAG: %[[C1:.+]] = arith.constant 1 : i32
+ // CHECK-SHFL-DAG: %[[C2:.+]] = arith.constant 2 : i32
+ // CHECK-SHFL-DAG: %[[C4:.+]] = arith.constant 4 : i32
+ // CHECK-SHFL-DAG: %[[C8:.+]] = arith.constant 8 : i32
+ // CHECK-SHFL-DAG: %[[C16:.+]] = arith.constant 16 : i32
+ // CHECK-SHFL-DAG: %[[C32:.+]] = arith.constant 32 : i32
+
+ // CHECK-SHFL: %[[V0:.+]] = vector.bitcast %[[ARG0]] : vector<2xf16> to vector<1xi32>
+ // CHECK-SHFL: %[[I0:.+]] = vector.extract %[[V0]][0] : i32 from vector<1xi32>
+ // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[I0]], %[[C1]], %[[C32]] : i32
+ // CHECK-SHFL: %[[BR0:.+]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+ // CHECK-SHFL: %[[BC0:.+]] = vector.bitcast %[[BR0]] : vector<1xi32> to vector<2xf16>
+ // CHECK-SHFL: %[[ADD0:.+]] = arith.addf %[[ARG0]], %[[BC0]] : vector<2xf16>
+ // CHECK-SHFL: %[[BC1:.+]] = vector.bitcast %[[ADD0]] : vector<2xf16> to vector<1xi32>
+ // CHECK-SHFL: %[[I1:.+]] = vector.extract %[[BC1]][0] : i32 from vector<1xi32>
+ // CHECK-SHFL: gpu.shuffle xor %[[I1]], %[[C2]], %[[C32]] : i32
+ // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+ // CHECK-SHFL: gpu.shuffle xor %{{.+}}, %[[C4]], %[[C32]] : i32
+ // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+ // CHECK-SHFL: gpu.shuffle xor %{{.+}}, %[[C8]], %[[C32]] : i32
+ // CHECK-SHFL: arith.addf {{.+}} : vector<2xf16>
+ // CHECK-SHFL: %[[SL:.+]], %{{.+}} = gpu.shuffle xor %{{.+}}, %[[C16]], %[[C32]] : i32
+ // CHECK-SHFL: %[[BRL:.+]] = vector.broadcast %[[SL]] : i32 to vector<1xi32>
+ // CHECK-SHFL: %[[BCL:.+]] = vector.bitcast %[[BRL]] : vector<1xi32> to vector<2xf16>
+ // CHECK-SHFL: %[[ADDL:.+]] = arith.addf %{{.+}}, %[[BCL]] : vector<2xf16>
+ // CHECK-SHFL: "test.consume"(%[[ADDL]]) : (vector<2xf16>) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 : (vector<2xf16>) -> (vector<2xf16>)
+ "test.consume"(%sum0) : (vector<2xf16>) -> ()
+
+ // CHECK-SHFL: gpu.return
+ gpu.return
+ }
+
+ // CHECK-SHFL-LABEL: gpu.func @kernel5(
+ // CHECK-SHFL-SAME: %[[ARG0:.+]]: i16)
+ gpu.func @kernel5(%arg0: i16) kernel {
+ // CHECK-SHFL: %[[E0:.+]] = arith.extui %[[ARG0]] : i16 to i32
+ // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[E0]], {{.+}} : i32
+ // CHECK-SHFL: %[[T0:.+]] = arith.trunci %[[S0]] : i32 to i16
+ // CHECK-SHFL: %[[A0:.+]] = arith.addi %[[ARG0]], %[[T0]] : i16
+ // CHECK-SHFL: %[[E1:.+]] = arith.extui %[[A0]] : i16 to i32
+ // CHECK-SHFL: %{{.+}}, %{{.+}} = gpu.shuffle xor %[[E1]], {{.+}} : i32
+ // CHECK-SHFL-COUNT-3: gpu.shuffle xor
+ // CHECK-SHFL: arith.trunci {{.+}} : i32 to i16
+ // CHECK-SHFL: %[[AL:.+]] = arith.addi {{.+}} : i16
+ // CHECK-SHFL: "test.consume"(%[[AL]]) : (i16) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 : (i16) -> i16
+ "test.consume"(%sum0) : (i16) -> ()
+
+ // CHECK-SHFL: gpu.return
+ gpu.return
+ }
+
+ // CHECK-SHFL-LABEL: gpu.func @kernel6(
+ // CHECK-SHFL-SAME: %[[ARG0:.+]]: vector<3xi8>)
+ gpu.func @kernel6(%arg0: vector<3xi8>) kernel {
+ // CHECK-SHFL: %[[CZ:.+]] = arith.constant dense<0> : vector<4xi8>
+ // CHECK-SHFL: %[[V0:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CZ]] {offsets = [0], strides = [1]} : vector<3xi8> into vector<4xi8>
+ // CHECK-SHFL: %[[BC0:.+]] = vector.bitcast %[[V0]] : vector<4xi8> to vector<1xi32>
+ // CHECK-SHFL: %[[I0:.+]] = vector.extract %[[BC0]][0] : i32 from vector<1xi32>
+ // CHECK-SHFL: %[[S0:.+]], %{{.+}} = gpu.shuffle xor %[[I0]], {{.+}} : i32
+ // CHECK-SHFL: %[[BR0:.+]] = vector.broadcast %[[S0]] : i32 to vector<1xi32>
+ // CHECK-SHFL: %[[BC1:.+]] = vector.bitcast %[[BR0]] : vector<1xi32> to vector<4xi8>
+ // CHECK-SHFL: %[[ADD0:.+]] = arith.addi %[[V0]], %[[BC1]] : vector<4xi8>
+ // CHECK-SHFL: %[[BC2:.+]] = vector.bitcast %[[ADD0]] : vector<4xi8> to vector<1xi32>
+ // CHECK-SHFL: %[[I1:.+]] = vector.extract %[[BC2]][0] : i32 from vector<1xi32>
+ // CHECK-SHFL-COUNT-4: gpu.shuffle xor
+ // CHECK-SHFL: %[[ESS:.+]] = vector.extract_strided_slice %{{.+}} {offsets = [0], sizes = [3], strides = [1]} : vector<4xi8> to vector<3xi8>
+ // CHECK-SHFL: "test.consume"(%[[ESS]]) : (vector<3xi8>) -> ()
+ %sum0 = gpu.subgroup_reduce add %arg0 : (vector<3xi8>) -> (vector<3xi8>)
+ "test.consume"(%sum0) : (vector<3xi8>) -> ()
+
+ // CHECK-SHFL: gpu.return
gpu.return
}
}
+
diff --git a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
index 49a52ba..aa15ccf 100644
--- a/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-bufferize-to-allocation.mlir
@@ -36,13 +36,15 @@ module attributes {transform.with_named_sequence} {
// Ensure that one linalg.fill was generated.
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+ %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
// Ensure that one linalg.copy was generated.
%mat = transform.select "bufferization.materialize_in_destination" in %new : (!transform.any_op) -> !transform.any_op
+ %p2 = transform.num_associations %mat : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %mat : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
transform.yield
}
}
@@ -73,18 +75,21 @@ module attributes {transform.with_named_sequence} {
// Ensure that one linalg.fill was generated.
%fill_op = transform.select "linalg.fill" in %new : (!transform.any_op) -> !transform.any_op
+ %p = transform.num_associations %fill_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %fill_op : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
// Ensure that one linalg.copy was generated.
%linalg_copy = transform.select "linalg.copy" in %new : (!transform.any_op) -> !transform.any_op
+ %p2 = transform.num_associations %linalg_copy : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %linalg_copy : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
// Ensure that one memref.alloca was generated.
%alloca = transform.select "memref.alloca" in %new : (!transform.any_op) -> !transform.any_op
+ %p3 = transform.num_associations %alloca : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %alloca : !transform.any_op
+ transform.test_print_param %p3 : !transform.param<i64>
// Make sure that One-Shot Bufferize can bufferize the rest.
%4 = transform.bufferization.one_shot_bufferize %arg1 : (!transform.any_op) -> !transform.any_op
diff --git a/mlir/test/Dialect/Linalg/transform-op-match.mlir b/mlir/test/Dialect/Linalg/transform-op-match.mlir
index 15942db..db5b5f1 100644
--- a/mlir/test/Dialect/Linalg/transform-op-match.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-match.mlir
@@ -134,8 +134,9 @@ module attributes {transform.with_named_sequence} {
#linalg.iterator_type<parallel>,
#linalg.iterator_type<reduction>]}
in %arg1 : (!transform.any_op) -> !transform.any_op
- // expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %no_match : !transform.any_op
+ %p = transform.num_associations %no_match : (!transform.any_op) -> !transform.param<i64>
+ // expected-remark @below {{0}}
+ transform.test_print_param %p : !transform.param<i64>
transform.yield
}
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-pad.mlir b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
index 6bca6c1..1f9d81a 100644
--- a/mlir/test/Dialect/Linalg/transform-op-pad.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-pad.mlir
@@ -41,8 +41,9 @@ module attributes {transform.with_named_sequence} {
padding_dimensions=[0, 1, 2],
pack_paddings=[1, 1, 0]
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.op<"bufferization.materialize_in_destination">)
+ %p = transform.num_associations %copy_back : (!transform.op<"bufferization.materialize_in_destination">) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %copy_back : !transform.op<"bufferization.materialize_in_destination">
+ transform.test_print_param %p : !transform.param<i64>
transform.yield
}
}
diff --git a/mlir/test/Dialect/Mesh/invalid.mlir b/mlir/test/Dialect/Mesh/invalid.mlir
index 03994f8..3ee578a 100644
--- a/mlir/test/Dialect/Mesh/invalid.mlir
+++ b/mlir/test/Dialect/Mesh/invalid.mlir
@@ -70,6 +70,102 @@ func.func @mesh_axis_negtive_in_partial(
// -----
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @cluster_shape_mesh_axis_out_of_bounds() -> (index, index) {
+ // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+ %0:2 = mesh.cluster_shape @mesh0 axes = [0, 2] : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @cluster_shape_duplicate_mesh_axis() -> (index, index, index) {
+ // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+ %0:3 = mesh.cluster_shape @mesh0 axes = [0, 2, 0] : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @cluster_shape_wrong_number_of_results() -> (index, index) {
+ // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
+ %0:2 = mesh.cluster_shape @mesh0 axes = [0] : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @cluster_shape_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+ // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
+ %0:2 = mesh.cluster_shape @mesh0 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+func.func @cluster_shape_invalid_mesh_name() -> (index) {
+ // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+ %0 = mesh.cluster_shape @this_mesh_symbol_does_not_exist : index
+ return %0#0 : index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @process_index_mesh_axis_out_of_bounds() -> (index, index) {
+ // expected-error@+1 {{0-based mesh axis index 2 is out of bounds. The referenced mesh "mesh0" is of rank 2.}}
+ %0:2 = mesh.process_index on @mesh0 axes = [0, 2] : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @process_index_duplicate_mesh_axis() -> (index, index, index) {
+ // expected-error@+1 {{Mesh axes contains duplicate elements.}}
+ %0:3 = mesh.process_index on @mesh0 axes = [0, 2, 0] : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 2, dim_sizes = 2x4)
+
+func.func @process_index_wrong_number_of_results() -> (index, index) {
+ // expected-error@+1 {{Unexpected number of results 2. Expected 1.}}
+ %0:2 = mesh.process_index on @mesh0 axes = [0] : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+mesh.cluster @mesh0(rank = 3, dim_sizes = 1x2x3)
+
+func.func @process_index_wrong_number_of_results_empty_mesh_axes() -> (index, index) {
+ // expected-error@+1 {{Unexpected number of results 2. Expected 3.}}
+ %0:2 = mesh.process_index on @mesh0 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// -----
+
+func.func @process_index_invalid_mesh_name() -> (index) {
+ // expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
+ %0 = mesh.process_index on @this_mesh_symbol_does_not_exist : index
+ return %0#0 : index
+}
+
+// -----
+
func.func @all_reduce_invalid_mesh_symbol(
%arg0 : tensor<4xf32>) -> tensor<4xf64> {
// expected-error@+1 {{Undefined required mesh symbol "this_mesh_symbol_does_not_exist".}}
diff --git a/mlir/test/Dialect/Mesh/ops.mlir b/mlir/test/Dialect/Mesh/ops.mlir
index 8f8e309..a7c3b3d 100644
--- a/mlir/test/Dialect/Mesh/ops.mlir
+++ b/mlir/test/Dialect/Mesh/ops.mlir
@@ -132,6 +132,55 @@ func.func @mesh_shard_op_two_users(%arg0 : tensor<4x8xf32>) ->
return %1, %2 : tensor<4x8xf32>, tensor<4x8xf32>
}
+// CHECK-LABEL: func @cluster_shape
+func.func @cluster_shape() -> (index, index) {
+ // CHECK: %[[RES:.*]]:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+ %0:2 = mesh.cluster_shape @mesh0 axes = [0, 1] : index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_default_axes
+func.func @cluster_shape_default_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+ %0:3 = mesh.cluster_shape @mesh0 : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @cluster_shape_empty_axes
+func.func @cluster_shape_empty_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.cluster_shape @mesh0 : index, index, index
+ %0:3 = mesh.cluster_shape @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index
+func.func @process_index() -> (index, index) {
+ // CHECK: %[[RES:.*]]:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+ %0:2 = mesh.process_index on @mesh0 axes = [0, 1] : index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1 : index, index
+ return %0#0, %0#1 : index, index
+}
+
+// CHECK-LABEL: func @process_index_default_axes
+func.func @process_index_default_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+ %0:3 = mesh.process_index on @mesh0 : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+// CHECK-LABEL: func @process_index_empty_axes
+func.func @process_index_empty_axes() -> (index, index, index) {
+ // CHECK: %[[RES:.*]]:3 = mesh.process_index on @mesh0 : index, index, index
+ %0:3 = mesh.process_index on @mesh0 axes = [] : index, index, index
+ // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2 : index, index, index
+ return %0#0, %0#1, %0#2 : index, index, index
+}
+
+
// CHECK-LABEL: func @all_reduce
func.func @all_reduce(
// CHECK-SAME: %[[ARG:.*]]: tensor<3x4xf32>
diff --git a/mlir/test/Dialect/Mesh/resharding-spmdization.mlir b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
new file mode 100644
index 0000000..0ba0d76
--- /dev/null
+++ b/mlir/test/Dialect/Mesh/resharding-spmdization.mlir
@@ -0,0 +1,154 @@
+// RUN: mlir-opt -test-mesh-resharding-spmdization %s | FileCheck %s
+
+mesh.cluster @mesh_1d(rank = 1, dim_sizes = 2)
+mesh.cluster @mesh_1d_dynamic(rank = 1, dim_sizes = ?)
+
+// CHECK-LABEL: func @same_source_and_target_sharding
+func.func @same_source_and_target_sharding(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<2xf32>
+ %arg0: tensor<2xf32>
+) -> tensor<2xf32> {
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<2xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<2xf32>
+ // CHECK: return %[[ARG]]
+ return %1 : tensor<2xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis
+func.func @split_replicated_tensor_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<3x14xf32>
+ %arg0: tensor<3x14xf32>
+) -> tensor<3x14xf32> {
+ // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+ // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.constant 14 : index
+ // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d axes = [0] : index
+ // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d axes = [0] : index
+ // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+ // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
+ // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
+ // CHECK: %[[RESULT_TENSOR_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+ // CHECK: %[[RESULT_TENSOR_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
+ // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][0, %[[RESULT_TENSOR_AXIS_OFFSET]]] [3, 7] [1, 1] : tensor<3x14xf32> to tensor<3x7xf32>
+ // CHECK: %[[RESULT:.*]] = builtin.unrealized_conversion_cast %[[RESULT_TENSOR_SLICE]] : tensor<3x7xf32> to tensor<3x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[]]> : tensor<3x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<3x14xf32>
+ // CHECK: return %[[RESULT]] : tensor<3x14xf32>
+ return %1 : tensor<3x14xf32>
+}
+
+// CHECK-LABEL: func @split_replicated_tensor_axis_dynamic
+func.func @split_replicated_tensor_axis_dynamic(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x3x?xf32>
+ %arg0: tensor<?x3x?xf32>
+) -> tensor<?x3x?xf32> {
+ // CHECK: %[[ZERO:.*]] = arith.constant 0 : index
+ // CHECK: %[[TWO:.*]] = arith.constant 2 : index
+ // CHECK: %[[PROCESS_INDEX:.*]] = mesh.process_index on @mesh_1d_dynamic axes = [0] : index
+ // CHECK: %[[MESH_AXIS_SIZE:.*]] = mesh.cluster_shape @mesh_1d_dynamic axes = [0] : index
+ // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE:.*]] = tensor.dim %[[ARG]], %[[ZERO]] : tensor<?x3x?xf32>
+ // CHECK: %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE:.*]] = arith.remui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+ // CHECK: %[[RESULT_TENSOR_AXIS_SIZE_CHECK:.*]] = arith.cmpi eq, %[[TENSOR_SPLIT_AXIS_SIZE_MOD_MESH_AXIS_SIZE]], %[[ZERO]] : index
+ // CHECK: cf.assert %[[RESULT_TENSOR_AXIS_SIZE_CHECK]]
+ // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_SIZE:.*]] = arith.divui %[[TENSOR_SPLIT_AXIS_SIZE]], %[[MESH_AXIS_SIZE]] : index
+ // CHECK: %[[RESULT_TENSOR_SPLIT_AXIS_OFFSET:.*]] = arith.muli %[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], %[[PROCESS_INDEX]] : index
+ // CHECK: %[[TENSOR_AXIS_2_SIZE:.*]] = tensor.dim %[[ARG]], %[[TWO]] : tensor<?x3x?xf32>
+ // CHECK: %[[RESULT_TENSOR_SLICE:.*]] = tensor.extract_slice %[[ARG]][%[[RESULT_TENSOR_SPLIT_AXIS_OFFSET]], 0, 0]
+ // CHECK-SAME: [%[[RESULT_TENSOR_SPLIT_AXIS_SIZE]], 3, %[[TENSOR_AXIS_2_SIZE]]] [1, 1, 1] : tensor<?x3x?xf32> to tensor<?x3x?xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[], [], []]> : tensor<?x3x?xf32>
+ %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[0]]> annotate_for_users : tensor<?x3x?xf32>
+ // CHECK: return %[[RESULT_TENSOR_SLICE]] : tensor<?x3x?xf32>
+ return %1 : tensor<?x3x?xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis
+func.func @move_split_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+ // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<5x14xf32> -> tensor<10x7xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x7xf32> to tensor<10x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_axis_dynamic_mesh
+func.func @move_split_axis_dynamic_mesh(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+ // CHECK: %[[ALL_TO_ALL:.*]] = mesh.all_to_all %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x?xf32>
+ // CHECK: %[[TARGET_SHARD:.*]] = tensor.cast %[[ALL_TO_ALL]] : tensor<?x?xf32> to tensor<10x?xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<10x?xf32> to tensor<10x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[], [0]]> annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @move_split_dynamic_axis
+func.func @move_split_dynamic_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+ %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+ // CHECK: %[[TARGET_SHARD:.*]] = mesh.all_to_all %[[ARG]] on @mesh_1d mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<?x14xf32> -> tensor<?x7xf32>
+ // CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[TARGET_SHARD]] : tensor<?x7xf32> to tensor<?x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[], [0]]> annotate_for_users : tensor<?x14xf32>
+ // CHECK: return %[[RES]] : tensor<?x14xf32>
+ return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis
+func.func @unshard_static_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<5x14xf32>
+ // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<5x14xf32> -> tensor<10x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<10x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_dynamic_axis
+func.func @unshard_dynamic_axis(
+ // CHECK-SAME: %[[ARG:.*]]: tensor<?x14xf32>
+ %arg0: tensor<?x14xf32>
+) -> tensor<?x14xf32> {
+ // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[ARG]] on @mesh_1d mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[0]]> : tensor<?x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<?x14xf32>
+ // CHECK: return %[[ALL_GATHER]] : tensor<?x14xf32>
+ return %1 : tensor<?x14xf32>
+}
+
+// CHECK-LABEL: func @unshard_static_axis_on_dynamic_mesh_axis
+func.func @unshard_static_axis_on_dynamic_mesh_axis(
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[SOURCE_SHARD:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : tensor<10x14xf32> to tensor<?x14xf32>
+ // CHECK: %[[ALL_GATHER:.*]] = mesh.all_gather %[[SOURCE_SHARD]] on @mesh_1d_dynamic mesh_axes = [0] gather_axis = 0 : tensor<?x14xf32> -> tensor<?x14xf32>
+ // CHECK: %[[RES:.*]] = tensor.cast %[[ALL_GATHER]] : tensor<?x14xf32> to tensor<10x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d_dynamic, [[0]]> : tensor<10x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d_dynamic, [[]]> annotate_for_users : tensor<10x14xf32>
+ // CHECK: return %[[RES]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
+
+// CHECK-LABEL: func @partial_axis
+func.func @partial_axis(
+// CHECK-SAME: %[[ARG:.*]]: tensor<10x14xf32>
+ %arg0: tensor<10x14xf32>
+) -> tensor<10x14xf32> {
+ // CHECK: %[[ALL_REDUCE:.*]] = mesh.all_reduce %[[ARG]] on @mesh_1d mesh_axes = [0] : tensor<10x14xf32> -> tensor<10x14xf32>
+ %0 = mesh.shard %arg0 to <@mesh_1d, [[]], partial = sum[0]> : tensor<10x14xf32>
+ %1 = mesh.shard %0 to <@mesh_1d, [[]]> annotate_for_users : tensor<10x14xf32>
+ // CHECK: %[[ALL_REDUCE]] : tensor<10x14xf32>
+ return %1 : tensor<10x14xf32>
+}
diff --git a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
index f584977e..6fe7ec9 100644
--- a/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul_lib_2to4.mlir
+++ b/mlir/test/Dialect/SparseTensor/GPU/gpu_matmul24_lib.mlir
@@ -1,5 +1,13 @@
// RUN: mlir-opt %s --linalg-generalize-named-ops --sparse-gpu-codegen="num-threads=0" | FileCheck %s
+#NV_24 = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i : dense,
+ j floordiv 4 : dense,
+ j mod 4 : block2_4
+ )
+}>
+
// CHECK-LABEL: func.func @matmul(
// CHECK-SAME: %[[VAL_0:.*0]]: tensor<?x?xf16>,
// CHECK-SAME: %[[VAL_1:.*1]]: tensor<?x?xf16>,
@@ -51,18 +59,14 @@
// CHECK: %[[VAL_55:.*]] = bufferization.to_tensor %[[VAL_19]] : memref<?x?xf16>
// CHECK: return %[[VAL_55]] : tensor<?x?xf16>
// CHECK: }
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
- func.func @matmul(%arg0: tensor<?x?xf16>, %arg1: tensor<?x?xf16>, %arg2: tensor<?x?xf16>) -> tensor<?x?xf16> {
- %0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>) outs(%arg2 : tensor<?x?xf16>) {
- ^bb0(%in: f16, %in_0: f16, %out: f16):
- %1 = arith.mulf %in, %in_0 : f16
- %2 = arith.addf %out, %1 : f16
- linalg.yield %2 : f16
- } -> tensor<?x?xf16>
- return %0 : tensor<?x?xf16>
+ func.func @matmul(%Ad: tensor<?x?xf16>,
+ %B: tensor<?x?xf16>,
+ %Cin: tensor<?x?xf16>) -> tensor<?x?xf16> {
+ %A = sparse_tensor.convert %Ad : tensor<?x?xf16> to tensor<?x?xf16, #NV_24>
+ %C = linalg.matmul
+ ins(%A, %B: tensor<?x?xf16, #NV_24>, tensor<?x?xf16>)
+ outs(%Cin: tensor<?x?xf16>) -> tensor<?x?xf16>
+ return %C : tensor<?x?xf16>
}
}
diff --git a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
index bdfe18a..b78ab9b 100644
--- a/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
+++ b/mlir/test/Dialect/Tensor/simplify-pack-unpack.mlir
@@ -56,3 +56,75 @@ func.func @single_first_inner_dim_packing(%arg0: tensor<256x5xf32>) -> tensor<8x
%0 = tensor.pack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<256x5xf32> -> tensor<8x5x32xf32>
return %0 : tensor<8x5x32xf32>
}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_1d_to_collapse
+// CHECK-SAME: %[[ARG0:.+]]: tensor<8x32xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1]] : tensor<8x32xf32> into tensor<256xf32>
+// CHECK: return %[[COLLAPSED]]
+func.func @unpack_1d_to_collapse(%arg0: tensor<8x32xf32>) -> tensor<256xf32> {
+ %empty = tensor.empty() : tensor<256xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<256xf32>
+ return %0 : tensor<256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_to_partial_slice
+// CHECK-NOT: tensor.collapse
+// CHECK: tensor.unpack
+func.func @unpack_to_partial_slice(%arg0: tensor<8x32xf32>) -> tensor<255xf32> {
+ %empty = tensor.empty() : tensor<255xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x32xf32> -> tensor<255xf32>
+ return %0 : tensor<255xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpack_dynamic
+// CHECK-NOT: tensor.collapse
+// CHECK: tensor.unpack
+func.func @unpack_dynamic(%arg0: tensor<?x32xf32>) -> tensor<?xf32> {
+ %c32 = arith.constant 32 : index
+ %c0 = arith.constant 0 : index
+ %d0 = tensor.dim %arg0, %c0 : tensor<?x32xf32>
+ %size = arith.muli %d0, %c32 : index
+ %empty = tensor.empty(%size) : tensor<?xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<?x32xf32> -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_last_inner_dim_unpacking(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<5x8x32xf32>)
+// CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2]] : tensor<5x8x32xf32> into tensor<5x256xf32>
+// CHECK: return %[[COLLAPSED]] : tensor<5x256xf32>
+func.func @single_last_inner_dim_unpacking(%arg0: tensor<5x8x32xf32>) -> tensor<5x256xf32> {
+ %empty = tensor.empty() : tensor<5x256xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<5x8x32xf32> -> tensor<5x256xf32>
+ return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @unpacking_with_outer_dims_perm(
+// CHECK-NOT: tensor.collpase_shape
+// CHECK: tensor.unpack
+func.func @unpacking_with_outer_dims_perm(%arg0: tensor<8x5x32xf32>) -> tensor<5x256xf32> {
+ %empty = tensor.empty() : tensor<5x256xf32>
+ %0 = tensor.unpack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [1] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<5x256xf32>
+ return %0 : tensor<5x256xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_first_inner_dim_unpacking(
+// CHECK-NOT: tensor.collapse_shape
+// CHECK: tensor.unpack
+func.func @single_first_inner_dim_unpacking(%arg0: tensor<8x5x32xf32>) -> tensor<256x5xf32> {
+ %empty = tensor.empty() : tensor<256x5xf32>
+ %0 = tensor.unpack %arg0 inner_dims_pos = [0] inner_tiles = [32] into %empty : tensor<8x5x32xf32> -> tensor<256x5xf32>
+ return %0 : tensor<256x5xf32>
+}
diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir
index 0964161..5123958 100644
--- a/mlir/test/Dialect/Transform/ops-invalid.mlir
+++ b/mlir/test/Dialect/Transform/ops-invalid.mlir
@@ -696,3 +696,11 @@ transform.sequence failures(propagate) {
transform.named_sequence @foo()
} : !transform.any_op
}
+
+// -----
+
+transform.sequence failures(propagate) {
+^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
+ transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir
index d9a1199..a39e6f9 100644
--- a/mlir/test/Dialect/Transform/test-interpreter.mlir
+++ b/mlir/test/Dialect/Transform/test-interpreter.mlir
@@ -575,8 +575,9 @@ transform.with_pdl_patterns {
%0 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = pdl_match @addi in %arg1 : (!transform.any_op) -> !transform.any_op
%2 = merge_handles deduplicate %0, %1 : !transform.any_op
+ %3 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ test_print_param %3 : !transform.param<i64>
}
}
@@ -676,11 +677,13 @@ module {
^bb0(%arg1: !transform.any_op):
%0 = pdl_match @func in %arg1 : (!transform.any_op) -> !transform.any_op
%1 = replicate num(%0) %arg1 : !transform.any_op, !transform.any_op
+ %p = num_associations %1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
%2 = replicate num(%0) %1 : !transform.any_op, !transform.any_op
+ %p2 = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ test_print_param %p2 : !transform.param<i64>
}
}
}
@@ -708,8 +711,9 @@ transform.with_pdl_patterns {
%f = pdl_match @const in %arg1 : (!transform.any_op) -> !transform.any_op
transform.foreach %f : !transform.any_op {
^bb2(%arg2: !transform.any_op):
+ %p = transform.num_associations %arg2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %arg2 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
transform.test_print_remark_at_operand %arg2, "transform applied" : !transform.any_op
}
}
@@ -780,8 +784,9 @@ transform.with_pdl_patterns {
transform.yield %g : !transform.any_op
}
+ %p = transform.num_associations %results : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{3}}
- transform.test_print_number_of_associated_payload_ir_ops %results : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
transform.test_print_remark_at_operand %results, "transform applied" : !transform.any_op
}
}
@@ -877,8 +882,9 @@ transform.sequence failures(propagate) {
^bb1(%fun: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// expected-error @below {{expected to contain 3 payload ops but it contains 2 payload ops}}
%h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@@ -896,13 +902,15 @@ transform.sequence failures(suppress) {
^bb1(%fun: !transform.any_op):
%muli = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// Silenceable failure and all handles are now empty.
%h_2:3 = split_handle %muli_2 : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %p2 = transform.num_associations %h_2#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %h_2#0 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
}
// -----
@@ -918,12 +926,15 @@ transform.sequence failures(propagate) {
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
// No error, last result handle is empty.
%h:3 = split_handle %muli_2 {fail_on_payload_too_small = false} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
+ %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
+ %p3 = transform.num_associations %h#2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{0}}
- transform.test_print_number_of_associated_payload_ir_ops %h#2 : !transform.any_op
+ transform.test_print_param %p3 : !transform.param<i64>
}
// -----
@@ -940,10 +951,12 @@ transform.sequence failures(propagate) {
^bb1(%fun: !transform.any_op):
%muli_2 = transform.structured.match ops{["arith.muli"]} in %fun : (!transform.any_op) -> !transform.any_op
%h:2 = split_handle %muli_2 {overflow_result = 0} : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %p = transform.num_associations %h#0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{3}}
- transform.test_print_number_of_associated_payload_ir_ops %h#0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
+ %p2 = transform.num_associations %h#1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{1}}
- transform.test_print_number_of_associated_payload_ir_ops %h#1 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
}
// -----
@@ -1668,8 +1681,9 @@ transform.sequence failures(propagate) {
// expected-remark @below {{2 iterations}}
transform.test_tracked_rewrite %0 : (!transform.any_op) -> ()
// One replacement op (test.drop_mapping) is dropped from the mapping.
+ %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
}
// -----
@@ -1684,20 +1698,24 @@ module {
%2 = transform.param.constant 1 -> !transform.param<i64>
%3 = transform.param.constant 2 -> !transform.param<i64>
%4 = transform.merge_handles %1, %2 { deduplicate } : !transform.param<i64>
+ %p = num_associations %4 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_params %4 : !transform.param<i64>
+ test_print_param %p : !transform.param<i64>
%5 = transform.merge_handles %1, %1 { deduplicate } : !transform.param<i64>
+ %p2 = num_associations %5 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_params %5 : !transform.param<i64>
+ test_print_param %p2 : !transform.param<i64>
%6 = transform.merge_handles %1, %3 { deduplicate } : !transform.param<i64>
+ %p3 = num_associations %6 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_params %6 : !transform.param<i64>
+ test_print_param %p3 : !transform.param<i64>
%7 = transform.merge_handles %1, %1, %2, %3 : !transform.param<i64>
+ %p4 = num_associations %7 : (!transform.param<i64>) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_params %7 : !transform.param<i64>
+ test_print_param %p4 : !transform.param<i64>
}
}
@@ -1712,21 +1730,25 @@ transform.sequence failures(propagate) {
%3 = test_produce_value_handle_to_result %1, 1 : (!transform.any_op) -> !transform.any_value
%4 = transform.merge_handles %2, %2 { deduplicate } : !transform.any_value
+ %p = num_associations %4 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_values %4 : !transform.any_value
+ test_print_param %p : !transform.param<i64>
%5 = transform.merge_handles %2, %3 { deduplicate } : !transform.any_value
+ %p2 = num_associations %5 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{2}}
- test_print_number_of_associated_payload_ir_values %5 : !transform.any_value
+ test_print_param %p2 : !transform.param<i64>
%6 = test_produce_value_handle_to_result %1, 0 : (!transform.any_op) -> !transform.any_value
%7 = transform.merge_handles %2, %6 { deduplicate } : !transform.any_value
+ %p3 = num_associations %6 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{1}}
- test_print_number_of_associated_payload_ir_values %6 : !transform.any_value
+ test_print_param %p3 : !transform.param<i64>
%8 = transform.merge_handles %2, %2, %3, %4 : !transform.any_value
+ %p4 = num_associations %8 : (!transform.any_value) -> !transform.param<i64>
// expected-remark @below {{4}}
- test_print_number_of_associated_payload_ir_values %8 : !transform.any_value
+ test_print_param %p4 : !transform.param<i64>
}
// -----
@@ -1820,31 +1842,37 @@ transform.sequence failures(propagate) {
// There are 3 arith.constant ops.
%all = transform.structured.match ops{["arith.constant"]} in %0 : (!transform.any_op) -> !transform.any_op
+ %p = num_associations %all : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+ test_print_param %p : !transform.param<i64>
// "deduplicate" has no effect because these are 3 different ops.
%merged_before = transform.merge_handles deduplicate %all : !transform.any_op
+ %p2 = num_associations %merged_before : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %merged_before : !transform.any_op
+ test_print_param %p2 : !transform.param<i64>
// Apply CSE.
transform.apply_cse to %0 : !transform.any_op
// The handle is still mapped to 3 arith.constant ops.
+ %p3 = num_associations %all : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{3}}
- test_print_number_of_associated_payload_ir_ops %all : !transform.any_op
+ test_print_param %p3 : !transform.param<i64>
// But they are all the same op.
%merged_after = transform.merge_handles deduplicate %all : !transform.any_op
+ %p4 = num_associations %merged_after : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %merged_after : !transform.any_op
+ test_print_param %p4 : !transform.param<i64>
// The other handles were also updated.
test_print_remark_at_operand %elim_first, "eliminated 1" : !transform.any_op
+ %p5 = num_associations %elim_first : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %elim_first : !transform.any_op
+ test_print_param %p5 : !transform.param<i64>
test_print_remark_at_operand %elim_second, "eliminated 2" : !transform.any_op
+ %p6 = num_associations %elim_second : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %elim_second : !transform.any_op
+ test_print_param %p6 : !transform.param<i64>
}
// -----
@@ -1907,14 +1935,16 @@ transform.sequence failures(propagate) {
// Get immediate parent.
%2 = transform.get_parent_op %0 : (!transform.any_op) -> !transform.any_op
test_print_remark_at_operand %2, "direct parent" : !transform.any_op
+ %p = num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{2}}
- test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
// Deduplicate results.
%3 = transform.structured.match ops{["test.qux"]} in %arg1 : (!transform.any_op) -> !transform.any_op
%4 = transform.get_parent_op %3 {deduplicate} : (!transform.any_op) -> !transform.any_op
+ %p2 = num_associations %4 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- test_print_number_of_associated_payload_ir_ops %4 : !transform.any_op
+ test_print_param %p2 : !transform.param<i64>
}
@@ -2029,8 +2059,9 @@ transform.sequence failures(propagate) {
// Match all ops inside the function (including the function itself).
%func_op = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
%0 = transform.structured.match in %func_op : (!transform.any_op) -> !transform.any_op
+ %p = num_associations %0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{5}}
- test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+ test_print_param %p : !transform.param<i64>
// Select "test.foo".
%foo = transform.select "test.foo" in %0 : (!transform.any_op) -> !transform.any_op
@@ -2060,8 +2091,9 @@ transform.sequence failures(propagate) {
%empty_op = transform.structured.match ops{["tensor.empty"]} in %func_op : (!transform.any_op) -> !transform.any_op
transform.apply_dce to %func_op : !transform.any_op
+ %p = num_associations %empty_op : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{0}}
- test_print_number_of_associated_payload_ir_ops %empty_op : !transform.any_op
+ test_print_param %p : !transform.param<i64>
}
diff --git a/mlir/test/Dialect/Transform/test-loop-transforms.mlir b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
index 4259627..c34f4ba 100644
--- a/mlir/test/Dialect/Transform/test-loop-transforms.mlir
+++ b/mlir/test/Dialect/Transform/test-loop-transforms.mlir
@@ -37,13 +37,16 @@ module attributes {transform.with_named_sequence} {
// Make sure that the handles are still valid (and were updated in case of
// the loop).
+ %p = transform.num_associations %0 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %0 : !transform.any_op
+ transform.test_print_param %p : !transform.param<i64>
transform.test_print_remark_at_operand %0, "new loop op" : !transform.any_op
+ %p2 = transform.num_associations %1 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %1 : !transform.any_op
+ transform.test_print_param %p2 : !transform.param<i64>
+ %p3 = transform.num_associations %2 : (!transform.any_op) -> !transform.param<i64>
// expected-remark @below{{1}}
- transform.test_print_number_of_associated_payload_ir_ops %2 : !transform.any_op
+ transform.test_print_param %p3 : !transform.param<i64>
transform.yield
}
diff --git a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
index 3708d74..ae457ea 100644
--- a/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
+++ b/mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
@@ -356,3 +356,18 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>
+
+// -----
+
+// This test is to make sure there is no crash for empty stride.
+func.func @stride_empty_test(%1: memref<i16>) -> vector<32x256xi16> {
+ %c0_i16 = arith.constant 0 : i16
+ %3 = vector.transfer_read %1[], %c0_i16 {permutation_map = affine_map<() -> (0, 0)>} : memref<i16>, vector<32x256xi16>
+ return %3 : vector<32x256xi16>
+
+ // CHECK-LABEL: func.func @stride_empty_test
+ // CHECK: %[[VAL:.*]] = arith.constant 0 : i16
+ // CHECK: %[[RET:.*]] = vector.transfer_read {{.*}} vector<32x256xi16>
+ // CHECK: return %[[RET]]
+ // CHECK-NOT: empty()
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-hand.mlir
index d7e9ced..117832d 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib-from-linalg.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-hand.mlir
@@ -1,40 +1,58 @@
// NOTE: this test requires gpu-sm80 and cusparselt
//
-// DEFINE: %{compile} = mlir-opt %s \
-// DEFINE: --sparsifier="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
+// DEFINE: %{compile} = mlir-opt --convert-vector-to-scf --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm \
+// DEFINE: --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \
+// DEFINE: %s
// DEFINE: %{run} = mlir-cpu-runner \
// DEFINE: --shared-libs=%mlir_cuda_runtime \
// DEFINE: --shared-libs=%mlir_c_runner_utils \
// DEFINE: --e main --entry-point-result=void \
// DEFINE: | FileCheck %s
//
-// with RT lib:
-//
-// RUN: %{compile} enable-runtime-library=true" | %{run}
-//
-// without RT lib:
-//
-// RUN: %{compile} enable-runtime-library=false" | %{run}
-
-#map = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+// RUN: %{compile} | %{run}
module {
llvm.func @mgpuCreateSparseLtEnv()
llvm.func @mgpuDestroySparseLtEnv()
- //
- // TODO: This uses our temporary ATTRIBUTE, replace with 2:4 type!
- //
- func.func @matmul_2to4(%arg0: tensor<16x32xf16>, %arg1: tensor<32x16xf16>, %arg2: tensor<16x16xf16>) -> tensor<16x16xf16> {
- %0 = linalg.generic { DENSE24, indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<16x32xf16>, tensor<32x16xf16>) outs(%arg2 : tensor<16x16xf16>) {
- ^bb0(%in: f16, %in_0: f16, %out: f16):
- %1 = arith.mulf %in, %in_0 : f16
- %2 = arith.addf %out, %1 : f16
- linalg.yield %2 : f16
- } -> tensor<16x16xf16>
- return %0 : tensor<16x16xf16>
+ // cuSparselt version for matmul coded by hand.
+ func.func @matmul24(%a : memref<16x32xf16>,
+ %b : memref<32x16xf16>,
+ %c : memref<16x16xf16>) {
+ %c0 = arith.constant 0.0 : f16
+ %c1 = arith.constant 1 : index
+ %c2 = arith.constant 2 : index
+ %c8 = arith.constant 8 : index
+ %c16 = arith.constant 16 : index
+ %c32 = arith.constant 32 : index
+ %c1048576 = arith.constant 1048576 : index
+ %token0 = gpu.wait async
+ %d_a, %token1 = gpu.alloc async [%token0] () : memref<16x32xf16>
+ %d_b, %token2 = gpu.alloc async [%token1] () : memref<32x16xf16>
+ %d_c, %token3 = gpu.alloc async [%token2] () : memref<16x16xf16>
+ %token4 = gpu.memcpy async [%token3] %d_a, %a : memref<16x32xf16>, memref<16x32xf16>
+ %token5 = gpu.memcpy async [%token4] %d_b, %b : memref<32x16xf16>, memref<32x16xf16>
+ %token6 = gpu.memcpy async [%token5] %d_c, %c : memref<16x16xf16>, memref<16x16xf16>
+ %spmat, %token8 = gpu.create_2to4_spmat async [%token6]{PRUNE_AND_CHECK} %c16, %c32, %d_a: memref<16x32xf16>
+ %dnmat, %token9 = gpu.create_dn_tensor async [%token8] %d_b, %c32, %c16: index, index into memref<32x16xf16>
+ %dnmat2, %token10 = gpu.create_dn_tensor async [%token9] %d_c, %c16, %c16: index, index into memref<16x16xf16>
+ %bufferSz0, %bufferSz1, %bufferSz2, %token11 = gpu.spmm_buffer_size async [%token10] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2 : index, index,index into f16
+ %mem1, %token12 = gpu.alloc async [%token11] (%bufferSz0) : memref<?xf16>
+ %mem2, %token13 = gpu.alloc async [%token12] (%bufferSz1) : memref<?xf16>
+ %mem3, %token14 = gpu.alloc async [%token13] (%bufferSz2) : memref<?xf16>
+ %token15 = gpu.spmm async [%token14] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2, %mem1, %mem2, %mem3 : memref<?xf16>, memref<?xf16>,memref<?xf16> into f16
+ %token16 = gpu.destroy_sp_mat async [%token15] %spmat
+ %token17 = gpu.destroy_dn_tensor async [%token16] %dnmat
+ %token18 = gpu.destroy_dn_tensor async [%token17] %dnmat2
+ %token19 = gpu.memcpy async [%token18] %c, %d_c : memref<16x16xf16>, memref<16x16xf16>
+ %token20 = gpu.dealloc async [%token19] %d_c : memref<16x16xf16>
+ %token21 = gpu.dealloc async [%token20] %d_b : memref<32x16xf16>
+ %token22 = gpu.dealloc async [%token21] %d_a : memref<16x32xf16>
+ %token23 = gpu.dealloc async [%token22] %mem3 : memref<?xf16>
+ %token24 = gpu.dealloc async [%token23] %mem2 : memref<?xf16>
+ %token25 = gpu.dealloc async [%token24] %mem1 : memref<?xf16>
+ gpu.wait [%token25]
+ return
}
//
@@ -54,50 +72,49 @@ module {
%c64 = arith.constant 64 : index
// Matrices A, B, C (16x32, 32x16, 16x16).
+ %a = memref.alloc() : memref<16x32xf16> // 16x32 with 2:4, row-major
+ %b = memref.alloc() : memref<32x16xf16> // regular dense column-major
+ %c = memref.alloc() : memref<16x16xf16> // accumulator row-major
//
// Setup matrix A.
//
- %DA = tensor.generate {
- ^bb0(%i: index, %j: index):
- // (i+ j/2 + 1) if j %2 == 0 else 0
- %cf0 = arith.constant 0.0 : f16
- %cf1 = arith.constant 1.0 : f16
- %j_2 = arith.floordivsi %j, %c2 : index
- %quotient = arith.remsi %j, %c2 : index
- %sum = arith.addi %i, %j_2 : index
- %sum_i = arith.index_cast %sum : index to i64
- %sum_f = arith.uitofp %sum_i : i64 to f16
- %sum_f_plus1 = arith.addf %sum_f, %cf1 : f16
- %is_zero = arith.cmpi "eq", %quotient, %c0 : index
- %s = arith.select %is_zero, %sum_f_plus1, %cf0 : f16
- tensor.yield %s : f16
- } : tensor<16x32xf16>
+ scf.for %ai = %c0 to %c16 step %c1 {
+ scf.for %aj = %c0 to %c16 step %c1 {
+ %cf0 = arith.constant 0.0: f16
+ %a0 = arith.addi %ai, %aj : index
+ %a1 = arith.addi %a0, %c1 : index
+ %a2 = arith.index_cast %a1 : index to i32
+ %a3 = arith.sitofp %a2 : i32 to f16
+ %ajj = arith.muli %aj, %c2 : index
+ %ajj2 = arith.addi %ajj, %c1 : index
+ memref.store %a3, %a[%ai, %ajj] : memref<16x32xf16>
+ memref.store %cf0, %a[%ai, %ajj2] : memref<16x32xf16>
+ }
+ }
//
// Setup matrix B.
//
- %DB = tensor.generate {
- ^bb0(%i: index, %j: index):
- // if j_i >=8, j_i - 8 else 0
- %is_ge8 = arith.cmpi "sge", %j, %c8 : index
- %j_minus8 = arith.subi %j, %c8 : index
- %j2 = arith.select %is_ge8, %j_minus8, %j : index
- %r_i = arith.subi %j2, %i : index
- %r_i64 = arith.index_cast %r_i : index to i64
- %r_f = arith.sitofp %r_i64 : i64 to f16
- tensor.yield %r_f : f16
- } : tensor<32x16xf16>
+ scf.for %bi = %c0 to %c8 step %c1 {
+ scf.for %bj = %c0 to %c32 step %c1 {
+ %b0 = arith.subi %bi, %bj : index
+ %b1 = arith.index_cast %b0 : index to i32
+ %b2 = arith.sitofp %b1 : i32 to f16
+ %bii = arith.addi %bi, %c8 : index
+ memref.store %b2, %b[%bj, %bi] : memref<32x16xf16>
+ memref.store %b2, %b[%bj, %bii] : memref<32x16xf16>
+ }
+ }
//
// Reset matrix C.
//
- %DC = tensor.generate {
- ^bb0(%i: index, %j: index):
- %cf0 = arith.constant 0.0 : f16
- tensor.yield %cf0 : f16
- } : tensor<16x16xf16>
-
+ scf.for %ci = %c0 to %c16 step %c1 {
+ scf.for %cj = %c0 to %c16 step %c1 {
+ memref.store %f0, %c[%ci, %cj] : memref<16x16xf16>
+ }
+ }
//
// Sanity check on 16x32 full 2:4 input matrix A.
@@ -121,7 +138,7 @@ module {
// CHECK-NEXT: ( 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27, 0, 28, 0, 29, 0, 30, 0, 31, 0 )
//
scf.for %pai = %c0 to %c16 step %c1 {
- %pa0 = vector.transfer_read %DA[%pai, %c0], %f0 : tensor<16x32xf16>, vector<32xf16>
+ %pa0 = vector.transfer_read %a[%pai, %c0], %f0 : memref<16x32xf16>, vector<32xf16>
vector.print %pa0 : vector<32xf16>
}
@@ -163,14 +180,12 @@ module {
//
//
scf.for %pbi = %c0 to %c32 step %c1 {
- %pb0 = vector.transfer_read %DB[%pbi, %c0], %f0 : tensor<32x16xf16>, vector<16xf16>
+ %pb0 = vector.transfer_read %b[%pbi, %c0], %f0 : memref<32x16xf16>, vector<16xf16>
vector.print %pb0 : vector<16xf16>
}
// Call the kernel.
- %t1 = arith.constant 1 : index
- %t32 = arith.constant 32 : index
- %c_out = call @matmul_2to4 (%DA, %DB, %DC): (tensor<16x32xf16>, tensor<32x16xf16>, tensor<16x16xf16>) -> tensor<16x16xf16>
+ call @matmul24(%a, %b, %c): (memref<16x32xf16>, memref<32x16xf16>, memref<16x16xf16>) -> ()
//
// Verify computed matrix C.
@@ -193,7 +208,7 @@ module {
// CHECK-NEXT: ( -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688, -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688 )
//
scf.for %pci = %c0 to %c16 step %c1 {
- %pc0 = vector.transfer_read %c_out[%pci, %c0], %f0 : tensor<16x16xf16>, vector<16xf16>
+ %pc0 = vector.transfer_read %c[%pci, %c0], %f0 : memref<16x16xf16>, vector<16xf16>
vector.print %pc0 : vector<16xf16>
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
index daf29d5..17b50b4 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-lib.mlir
@@ -1,57 +1,41 @@
// NOTE: this test requires gpu-sm80 and cusparselt
//
-// DEFINE: %{compile} = mlir-opt --convert-vector-to-scf --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm \
-// DEFINE: --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \
-// DEFINE: %s
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE: --sparsifier="enable-gpu-libgen gpu-triple=nvptx64-nvidia-cuda gpu-chip=sm_80 gpu-features=+ptx71 gpu-format=%gpu_compilation_format
// DEFINE: %{run} = mlir-cpu-runner \
// DEFINE: --shared-libs=%mlir_cuda_runtime \
// DEFINE: --shared-libs=%mlir_c_runner_utils \
// DEFINE: --e main --entry-point-result=void \
// DEFINE: | FileCheck %s
//
-// RUN: %{compile} | %{run}
+// with RT lib:
+//
+// RUN: %{compile} enable-runtime-library=true" | %{run}
+//
+// without RT lib:
+//
+// RUN: %{compile} enable-runtime-library=false" | %{run}
+
+#NV_24 = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i : dense,
+ j floordiv 4 : dense,
+ j mod 4 : block2_4
+ )
+}>
module {
llvm.func @mgpuCreateSparseLtEnv()
llvm.func @mgpuDestroySparseLtEnv()
- func.func @sampled_matmul(%a : memref<16x32xf16>,
- %b : memref<32x16xf16>,
- %c : memref<16x16xf16>) {
- %c0 = arith.constant 0.0 : f16
- %c1 = arith.constant 1 : index
- %c2 = arith.constant 2 : index
- %c8 = arith.constant 8 : index
- %c16 = arith.constant 16 : index
- %c32 = arith.constant 32 : index
- %c1048576 = arith.constant 1048576 : index
- %token0 = gpu.wait async
- %d_a, %token1 = gpu.alloc async [%token0] () : memref<16x32xf16>
- %d_b, %token2 = gpu.alloc async [%token1] () : memref<32x16xf16>
- %d_c, %token3 = gpu.alloc async [%token2] () : memref<16x16xf16>
- %token4 = gpu.memcpy async [%token3] %d_a, %a : memref<16x32xf16>, memref<16x32xf16>
- %token5 = gpu.memcpy async [%token4] %d_b, %b : memref<32x16xf16>, memref<32x16xf16>
- %token6 = gpu.memcpy async [%token5] %d_c, %c : memref<16x16xf16>, memref<16x16xf16>
- %spmat, %token8 = gpu.create_2to4_spmat async [%token6]{PRUNE_AND_CHECK} %c16, %c32, %d_a: memref<16x32xf16>
- %dnmat, %token9 = gpu.create_dn_tensor async [%token8] %d_b, %c32, %c16: index, index into memref<32x16xf16>
- %dnmat2, %token10 = gpu.create_dn_tensor async [%token9] %d_c, %c16, %c16: index, index into memref<16x16xf16>
- %bufferSz0, %bufferSz1, %bufferSz2, %token11 = gpu.spmm_buffer_size async [%token10] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2 : index, index,index into f16
- %mem1, %token12 = gpu.alloc async [%token11] (%bufferSz0) : memref<?xf16>
- %mem2, %token13 = gpu.alloc async [%token12] (%bufferSz1) : memref<?xf16>
- %mem3, %token14 = gpu.alloc async [%token13] (%bufferSz2) : memref<?xf16>
- %token15 = gpu.spmm async [%token14] %spmat{NON_TRANSPOSE}, %dnmat{NON_TRANSPOSE}, %dnmat2, %mem1, %mem2, %mem3 : memref<?xf16>, memref<?xf16>,memref<?xf16> into f16
- %token16 = gpu.destroy_sp_mat async [%token15] %spmat
- %token17 = gpu.destroy_dn_tensor async [%token16] %dnmat
- %token18 = gpu.destroy_dn_tensor async [%token17] %dnmat2
- %token19 = gpu.memcpy async [%token18] %c, %d_c : memref<16x16xf16>, memref<16x16xf16>
- %token20 = gpu.dealloc async [%token19] %d_c : memref<16x16xf16>
- %token21 = gpu.dealloc async [%token20] %d_b : memref<32x16xf16>
- %token22 = gpu.dealloc async [%token21] %d_a : memref<16x32xf16>
- %token23 = gpu.dealloc async [%token22] %mem3 : memref<?xf16>
- %token24 = gpu.dealloc async [%token23] %mem2 : memref<?xf16>
- %token25 = gpu.dealloc async [%token24] %mem1 : memref<?xf16>
- gpu.wait [%token25]
- return
+ func.func @matmul24(%Ad: tensor<16x32xf16>,
+ %B: tensor<32x16xf16>,
+ %Cin: tensor<16x16xf16>) -> tensor<16x16xf16> {
+ %A = sparse_tensor.convert %Ad : tensor<16x32xf16> to tensor<16x32xf16, #NV_24>
+ %C = linalg.matmul
+ ins(%A, %B: tensor<16x32xf16, #NV_24>, tensor<32x16xf16>)
+ outs(%Cin: tensor<16x16xf16>) -> tensor<16x16xf16>
+ return %C : tensor<16x16xf16>
}
//
@@ -71,49 +55,50 @@ module {
%c64 = arith.constant 64 : index
// Matrices A, B, C (16x32, 32x16, 16x16).
- %a = memref.alloc() : memref<16x32xf16> // 16x32 with 2:4, row-major
- %b = memref.alloc() : memref<32x16xf16> // regular dense column-major
- %c = memref.alloc() : memref<16x16xf16> // accumulator row-major
//
// Setup matrix A.
//
- scf.for %ai = %c0 to %c16 step %c1 {
- scf.for %aj = %c0 to %c16 step %c1 {
- %cf0 = arith.constant 0.0: f16
- %a0 = arith.addi %ai, %aj : index
- %a1 = arith.addi %a0, %c1 : index
- %a2 = arith.index_cast %a1 : index to i32
- %a3 = arith.sitofp %a2 : i32 to f16
- %ajj = arith.muli %aj, %c2 : index
- %ajj2 = arith.addi %ajj, %c1 : index
- memref.store %a3, %a[%ai, %ajj] : memref<16x32xf16>
- memref.store %cf0, %a[%ai, %ajj2] : memref<16x32xf16>
- }
- }
+ %DA = tensor.generate {
+ ^bb0(%i: index, %j: index):
+ // (i+ j/2 + 1) if j %2 == 0 else 0
+ %cf0 = arith.constant 0.0 : f16
+ %cf1 = arith.constant 1.0 : f16
+ %j_2 = arith.floordivsi %j, %c2 : index
+ %quotient = arith.remsi %j, %c2 : index
+ %sum = arith.addi %i, %j_2 : index
+ %sum_i = arith.index_cast %sum : index to i64
+ %sum_f = arith.uitofp %sum_i : i64 to f16
+ %sum_f_plus1 = arith.addf %sum_f, %cf1 : f16
+ %is_zero = arith.cmpi "eq", %quotient, %c0 : index
+ %s = arith.select %is_zero, %sum_f_plus1, %cf0 : f16
+ tensor.yield %s : f16
+ } : tensor<16x32xf16>
//
// Setup matrix B.
//
- scf.for %bi = %c0 to %c8 step %c1 {
- scf.for %bj = %c0 to %c32 step %c1 {
- %b0 = arith.subi %bi, %bj : index
- %b1 = arith.index_cast %b0 : index to i32
- %b2 = arith.sitofp %b1 : i32 to f16
- %bii = arith.addi %bi, %c8 : index
- memref.store %b2, %b[%bj, %bi] : memref<32x16xf16>
- memref.store %b2, %b[%bj, %bii] : memref<32x16xf16>
- }
- }
+ %DB = tensor.generate {
+ ^bb0(%i: index, %j: index):
+ // if j_i >=8, j_i - 8 else 0
+ %is_ge8 = arith.cmpi "sge", %j, %c8 : index
+ %j_minus8 = arith.subi %j, %c8 : index
+ %j2 = arith.select %is_ge8, %j_minus8, %j : index
+ %r_i = arith.subi %j2, %i : index
+ %r_i64 = arith.index_cast %r_i : index to i64
+ %r_f = arith.sitofp %r_i64 : i64 to f16
+ tensor.yield %r_f : f16
+ } : tensor<32x16xf16>
//
// Reset matrix C.
//
- scf.for %ci = %c0 to %c16 step %c1 {
- scf.for %cj = %c0 to %c16 step %c1 {
- memref.store %f0, %c[%ci, %cj] : memref<16x16xf16>
- }
- }
+ %DC = tensor.generate {
+ ^bb0(%i: index, %j: index):
+ %cf0 = arith.constant 0.0 : f16
+ tensor.yield %cf0 : f16
+ } : tensor<16x16xf16>
+
//
// Sanity check on 16x32 full 2:4 input matrix A.
@@ -137,7 +122,7 @@ module {
// CHECK-NEXT: ( 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27, 0, 28, 0, 29, 0, 30, 0, 31, 0 )
//
scf.for %pai = %c0 to %c16 step %c1 {
- %pa0 = vector.transfer_read %a[%pai, %c0], %f0 : memref<16x32xf16>, vector<32xf16>
+ %pa0 = vector.transfer_read %DA[%pai, %c0], %f0 : tensor<16x32xf16>, vector<32xf16>
vector.print %pa0 : vector<32xf16>
}
@@ -179,12 +164,16 @@ module {
//
//
scf.for %pbi = %c0 to %c32 step %c1 {
- %pb0 = vector.transfer_read %b[%pbi, %c0], %f0 : memref<32x16xf16>, vector<16xf16>
+ %pb0 = vector.transfer_read %DB[%pbi, %c0], %f0 : tensor<32x16xf16>, vector<16xf16>
vector.print %pb0 : vector<16xf16>
}
// Call the kernel.
- call @sampled_matmul (%a, %b, %c): (memref<16x32xf16>, memref<32x16xf16>, memref<16x16xf16>) -> ()
+ %t1 = arith.constant 1 : index
+ %t32 = arith.constant 32 : index
+ %c_out = call @matmul24(%DA, %DB, %DC): (tensor<16x32xf16>,
+ tensor<32x16xf16>,
+ tensor<16x16xf16>) -> tensor<16x16xf16>
//
// Verify computed matrix C.
@@ -207,7 +196,7 @@ module {
// CHECK-NEXT: ( -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688, -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688 )
//
scf.for %pci = %c0 to %c16 step %c1 {
- %pc0 = vector.transfer_read %c[%pci, %c0], %f0 : memref<16x16xf16>, vector<16xf16>
+ %pc0 = vector.transfer_read %c_out[%pci, %c0], %f0 : tensor<16x16xf16>, vector<16xf16>
vector.print %pc0 : vector<16xf16>
}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
index e3072860..eb99a02 100644
--- a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sm80-lt/sparse-matmul-2-4-prune.mlir
@@ -16,34 +16,27 @@
//
// RUN: %{compile} enable-runtime-library=false" | %{run}
-#map0 = affine_map<(d0, d1, d2) -> (d0, d2)>
-#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
-#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#NV_24 = #sparse_tensor.encoding<{
+ map = ( i, j ) ->
+ ( i : dense,
+ j floordiv 4 : dense,
+ j mod 4 : block2_4
+ )
+}>
module {
llvm.func @mgpuCreateSparseLtEnv()
llvm.func @mgpuDestroySparseLtEnv()
- //
- // TODO: This uses our temporary ATTRIBUTE, replace with 2:4 type!
- //
- func.func @matmul(%arg0: tensor<16x16xf16>,
- %arg1: tensor<16x16xf16>,
- %arg2: tensor<16x16xf16>) -> tensor<16x16xf16> {
- %0 = linalg.generic {
- DENSE24,
- indexing_maps = [#map0, #map1, #map2],
- iterator_types = ["parallel", "parallel", "reduction"]
- }
- ins(%arg0, %arg1 : tensor<16x16xf16>, tensor<16x16xf16>)
- outs(%arg2 : tensor<16x16xf16>) {
- ^bb0(%in: f16, %in_0: f16, %out: f16):
- %1 = arith.mulf %in, %in_0 : f16
- %2 = arith.addf %out, %1 : f16
- linalg.yield %2 : f16
- } -> tensor<16x16xf16>
- return %0 : tensor<16x16xf16>
+ func.func @matmul24(%Ad: tensor<16x16xf16>,
+ %B: tensor<16x16xf16>,
+ %Cin: tensor<16x16xf16>) -> tensor<16x16xf16> {
+ %A = sparse_tensor.convert %Ad : tensor<16x16xf16> to tensor<16x16xf16, #NV_24>
+ %C = linalg.matmul
+ ins(%A, %B: tensor<16x16xf16, #NV_24>, tensor<16x16xf16>)
+ outs(%Cin: tensor<16x16xf16>) -> tensor<16x16xf16>
+ return %C : tensor<16x16xf16>
}
func.func @main() {
@@ -81,7 +74,9 @@ module {
// By effectively computing D = A B + C with id(B) and zero(C)
// the resulting matrix returns the pruned A back to the caller.
//
- %D = call @matmul(%A, %B, %C): (tensor<16x16xf16>, tensor<16x16xf16>, tensor<16x16xf16>) -> (tensor<16x16xf16>)
+ %D = call @matmul24(%A, %B, %C): (tensor<16x16xf16>,
+ tensor<16x16xf16>,
+ tensor<16x16xf16>) -> (tensor<16x16xf16>)
//
// This was the original matrix.
diff --git a/mlir/test/Pass/crashless-reproducer.mlir b/mlir/test/Pass/crashless-reproducer.mlir
new file mode 100644
index 0000000..d874d90
--- /dev/null
+++ b/mlir/test/Pass/crashless-reproducer.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(builtin.module(test-module-pass))' --mlir-generate-reproducer=%t -verify-diagnostics
+// RUN: cat %t | FileCheck -check-prefix=REPRO %s
+
+module @inner_mod1 {
+ module @foo {}
+}
+
+// REPRO: module @inner_mod1
+// REPRO: module @foo {
+// REPRO: pipeline: "builtin.module(any(builtin.module(test-module-pass)))"
diff --git a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
index 21cc89c..8c13c0e 100644
--- a/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
+++ b/mlir/test/lib/Dialect/GPU/TestGpuRewrite.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -47,19 +48,40 @@ struct TestGpuSubgroupReduceLoweringPass
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestGpuSubgroupReduceLoweringPass)
+ TestGpuSubgroupReduceLoweringPass() = default;
+ TestGpuSubgroupReduceLoweringPass(
+ const TestGpuSubgroupReduceLoweringPass &pass)
+ : PassWrapper(pass) {}
+
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<arith::ArithDialect, vector::VectorDialect>();
}
+
StringRef getArgument() const final {
return "test-gpu-subgroup-reduce-lowering";
}
+
StringRef getDescription() const final {
return "Applies gpu.subgroup_reduce lowering patterns.";
}
+
+ Option<bool> expandToShuffles{
+ *this, "expand-to-shuffles",
+ llvm::cl::desc("Expand subgroup_reduce ops to shuffle ops."),
+ llvm::cl::init(false)};
+
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
+
+ // Since both pattern sets match on the same ops, set higher benefit to
+ // perform fewer failing matches.
populateGpuBreakDownSubgrupReducePatterns(patterns,
- /*maxShuffleBitwidth=*/32);
+ /*maxShuffleBitwidth=*/32,
+ PatternBenefit(2));
+ if (expandToShuffles)
+ populateGpuLowerSubgroupReduceToShufflePattenrs(
+ patterns, /*subgroupSize=*/32, /*shuffleBitwidth=*/32);
+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
diff --git a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
index 16b50bb..f14d282 100644
--- a/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Mesh/CMakeLists.txt
@@ -1,5 +1,6 @@
# Exclude tests from libMLIR.so
add_mlir_library(MLIRMeshTestSimplifications
+ TestReshardingSpmdization.cpp
TestSimplifications.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
new file mode 100644
index 0000000..6fecbd4
--- /dev/null
+++ b/mlir/test/lib/Dialect/Mesh/TestReshardingSpmdization.cpp
@@ -0,0 +1,122 @@
+//===- TestSimplification.cpp - Test simplification -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Mesh/IR/MeshOps.h"
+#include "mlir/Dialect/Mesh/Transforms/Spmdization.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/BuiltinTypeInterfaces.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::mesh;
+
+namespace {
+
+struct TestMeshReshardingRewritePattern : OpRewritePattern<ShardOp> {
+ using OpRewritePattern<ShardOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ShardOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getAnnotateForUsers()) {
+ return failure();
+ }
+
+ SymbolTableCollection symbolTable;
+ mesh::ClusterOp mesh = symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+ op, op.getShard().getCluster());
+
+ bool foundUser = false;
+ for (auto user : op->getUsers()) {
+ if (auto targetShardOp = llvm::dyn_cast<ShardOp>(user)) {
+ if (targetShardOp.getAnnotateForUsers() &&
+ mesh == symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+ targetShardOp, targetShardOp.getShard().getCluster())) {
+ foundUser = true;
+ break;
+ }
+ }
+ }
+
+ if (!foundUser) {
+ return failure();
+ }
+
+ for (auto user : op->getUsers()) {
+ auto targetShardOp = llvm::dyn_cast<ShardOp>(user);
+ if (!targetShardOp || !targetShardOp.getAnnotateForUsers() ||
+ symbolTable.lookupNearestSymbolFrom<mesh::ClusterOp>(
+ targetShardOp, targetShardOp.getShard().getCluster()) != mesh) {
+ continue;
+ }
+
+ ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+ ShapedType sourceShardShape =
+ shardShapedType(op.getResult().getType(), mesh, op.getShard());
+ TypedValue<ShapedType> sourceShard =
+ builder
+ .create<UnrealizedConversionCastOp>(sourceShardShape,
+ op.getOperand())
+ ->getResult(0)
+ .cast<TypedValue<ShapedType>>();
+ TypedValue<ShapedType> targetShard =
+ reshard(builder, mesh, op, targetShardOp, sourceShard);
+ Value newTargetUnsharded =
+ builder
+ .create<UnrealizedConversionCastOp>(
+ targetShardOp.getResult().getType(), targetShard)
+ ->getResult(0);
+ rewriter.replaceAllUsesWith(targetShardOp.getResult(),
+ newTargetUnsharded);
+ }
+
+ return success();
+ }
+};
+
+struct TestMeshReshardingPass
+ : public PassWrapper<TestMeshReshardingPass, OperationPass<ModuleOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMeshReshardingPass)
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ patterns.insert<TestMeshReshardingRewritePattern>(&getContext());
+ if (failed(applyPatternsAndFoldGreedily(getOperation().getOperation(),
+ std::move(patterns)))) {
+ return signalPassFailure();
+ }
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ reshardingRegisterDependentDialects(registry);
+ registry.insert<BuiltinDialect>();
+ }
+ StringRef getArgument() const final {
+ return "test-mesh-resharding-spmdization";
+ }
+ StringRef getDescription() const final {
+ return "Test Mesh dialect resharding spmdization.";
+ }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestMeshReshardingSpmdizationPass() {
+ PassRegistration<TestMeshReshardingPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
index e8c25ac..9c69164 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp
@@ -457,51 +457,6 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne(
}
DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply(
- transform::TransformRewriter &rewriter,
- transform::TransformResults &results, transform::TransformState &state) {
- if (!getHandle())
- emitRemark() << 0;
- emitRemark() << llvm::range_size(state.getPayloadOps(getHandle()));
- return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getHandle(), effects);
-}
-
-DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::apply(
- transform::TransformRewriter &rewriter,
- transform::TransformResults &results, transform::TransformState &state) {
- if (!getValueHandle())
- emitRemark() << 0;
- emitRemark() << llvm::range_size(state.getPayloadValues(getValueHandle()));
- return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIRValues::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getValueHandle(), effects);
-}
-
-DiagnosedSilenceableFailure
-mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::apply(
- transform::TransformRewriter &rewriter,
- transform::TransformResults &results, transform::TransformState &state) {
- if (!getParam())
- emitRemark() << 0;
- emitRemark() << llvm::range_size(state.getParams(getParam()));
- return DiagnosedSilenceableFailure::success();
-}
-
-void mlir::test::TestPrintNumberOfAssociatedPayloadIRParams::getEffects(
- SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
- transform::onlyReadsHandle(getParam(), effects);
-}
-
-DiagnosedSilenceableFailure
mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
index 41f318d..5cb4765 100644
--- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
+++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td
@@ -343,33 +343,6 @@ def TestMixedSuccessAndSilenceableOp
}];
}
-def TestPrintNumberOfAssociatedPayloadIROps
- : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_ops",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
- let arguments = (ins TransformHandleTypeInterface:$handle);
- let assemblyFormat = "$handle attr-dict `:` type($handle)";
- let cppNamespace = "::mlir::test";
-}
-
-def TestPrintNumberOfAssociatedPayloadIRValues
- : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_values",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
- let arguments = (ins TransformValueHandleTypeInterface:$value_handle);
- let assemblyFormat = "$value_handle attr-dict `:` type($value_handle)";
- let cppNamespace = "::mlir::test";
-}
-
-def TestPrintNumberOfAssociatedPayloadIRParams
- : Op<Transform_Dialect, "test_print_number_of_associated_payload_ir_params",
- [DeclareOpInterfaceMethods<TransformOpInterface>,
- DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
- let arguments = (ins TransformParamTypeInterface:$param);
- let assemblyFormat = "$param attr-dict `:` type($param)";
- let cppNamespace = "::mlir::test";
-}
-
def TestCopyPayloadOp
: Op<Transform_Dialect, "test_copy_payload",
[DeclareOpInterfaceMethods<TransformOpInterface>,
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index dc4121dc..f7a5b31 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -119,6 +119,7 @@ void registerTestMathPolynomialApproximationPass();
void registerTestMemRefDependenceCheck();
void registerTestMemRefStrideCalculation();
void registerTestMeshSimplificationsPass();
+void registerTestMeshReshardingSpmdizationPass();
void registerTestNextAccessPass();
void registerTestOneToNTypeConversionPass();
void registerTestOpaqueLoc();
@@ -237,6 +238,7 @@ void registerTestPasses() {
mlir::test::registerTestMemRefDependenceCheck();
mlir::test::registerTestMemRefStrideCalculation();
mlir::test::registerTestMeshSimplificationsPass();
+ mlir::test::registerTestMeshReshardingSpmdizationPass();
mlir::test::registerTestNextAccessPass();
mlir::test::registerTestOneToNTypeConversionPass();
mlir::test::registerTestOpaqueLoc();
diff --git a/mlir/unittests/IR/OpPropertiesTest.cpp b/mlir/unittests/IR/OpPropertiesTest.cpp
index 5c3e150..bb1b741 100644
--- a/mlir/unittests/IR/OpPropertiesTest.cpp
+++ b/mlir/unittests/IR/OpPropertiesTest.cpp
@@ -6,6 +6,7 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/IR/Attributes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Parser/Parser.h"
#include "gtest/gtest.h"
@@ -132,6 +133,23 @@ public:
}
};
+/// A custom operation for the purpose of showcasing how discardable attributes
+/// are handled in absence of properties.
+class OpWithoutProperties : public Op<OpWithoutProperties> {
+public:
+ // Begin boilerplate.
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpWithoutProperties)
+ using Op::Op;
+ static ArrayRef<StringRef> getAttributeNames() {
+ static StringRef attributeNames[] = {StringRef("inherent_attr")};
+ return ArrayRef(attributeNames);
+ };
+ static StringRef getOperationName() {
+ return "test_op_properties.op_without_properties";
+ }
+ // End boilerplate.
+};
+
// A trivial supporting dialect to register the above operation.
class TestOpPropertiesDialect : public Dialect {
public:
@@ -142,7 +160,7 @@ public:
explicit TestOpPropertiesDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context,
TypeID::get<TestOpPropertiesDialect>()) {
- addOperations<OpWithProperties>();
+ addOperations<OpWithProperties, OpWithoutProperties>();
}
};
@@ -359,4 +377,30 @@ TEST(OpPropertiesTest, getOrAddProperties) {
op->erase();
}
+constexpr StringLiteral withoutPropertiesAttrsSrc = R"mlir(
+ "test_op_properties.op_without_properties"()
+ {inherent_attr = 42, other_attr = 56} : () -> ()
+)mlir";
+
+TEST(OpPropertiesTest, withoutPropertiesDiscardableAttrs) {
+ MLIRContext context;
+ context.getOrLoadDialect<TestOpPropertiesDialect>();
+ ParserConfig config(&context);
+ OwningOpRef<Operation *> op =
+ parseSourceString(withoutPropertiesAttrsSrc, config);
+ ASSERT_EQ(llvm::range_size(op->getDiscardableAttrs()), 1u);
+ EXPECT_EQ(op->getDiscardableAttrs().begin()->getName().getValue(),
+ "other_attr");
+
+ EXPECT_EQ(op->getAttrs().size(), 2u);
+ EXPECT_TRUE(op->getInherentAttr("inherent_attr") != std::nullopt);
+ EXPECT_TRUE(op->getDiscardableAttr("other_attr") != Attribute());
+
+ std::string output;
+ llvm::raw_string_ostream os(output);
+ op->print(os);
+ EXPECT_TRUE(StringRef(os.str()).contains("inherent_attr = 42"));
+ EXPECT_TRUE(StringRef(os.str()).contains("other_attr = 56"));
+}
+
} // namespace