aboutsummaryrefslogtreecommitdiff
path: root/mlir
diff options
context:
space:
mode:
Diffstat (limited to 'mlir')
-rw-r--r--mlir/docs/Bindings/Python.md84
-rw-r--r--mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td48
-rw-r--r--mlir/include/mlir/Dialect/SCF/IR/SCF.h4
-rw-r--r--mlir/include/mlir/Dialect/SCF/Utils/Utils.h33
-rw-r--r--mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h6
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td101
-rw-r--r--mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td101
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h314
-rw-r--r--mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h104
-rw-r--r--mlir/lib/Bindings/Python/IRCore.cpp2
-rw-r--r--mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp2
-rw-r--r--mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp6
-rw-r--r--mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp9
-rw-r--r--mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp6
-rw-r--r--mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp2
-rw-r--r--mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp29
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp6
-rw-r--r--mlir/lib/Dialect/SCF/Utils/Utils.cpp145
-rw-r--r--mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp17
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp16
-rw-r--r--mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp30
-rw-r--r--mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp7
-rw-r--r--mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp7
-rw-r--r--mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp7
-rw-r--r--mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp2
-rw-r--r--mlir/lib/ExecutionEngine/ExecutionEngine.cpp2
-rw-r--r--mlir/lib/IR/MLIRContext.cpp4
-rw-r--r--mlir/lib/Target/LLVMIR/ModuleTranslation.cpp52
-rw-r--r--mlir/lib/Target/SPIRV/Serialization/Serializer.cpp18
-rw-r--r--mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir80
-rw-r--r--mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir78
-rw-r--r--mlir/test/Dialect/AMDGPU/canonicalize.mlir30
-rw-r--r--mlir/test/Dialect/AMDGPU/invalid.mlir96
-rw-r--r--mlir/test/Dialect/AMDGPU/ops.mlir17
-rw-r--r--mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir8
-rw-r--r--mlir/test/Dialect/SCF/parallel-loop-unroll.mlir171
-rw-r--r--mlir/test/Dialect/SPIRV/IR/structure-ops.mlir2
-rw-r--r--mlir/test/Dialect/XeGPU/xegpu-blocking.mlir59
-rw-r--r--mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir7
-rw-r--r--mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir21
-rw-r--r--mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir13
-rw-r--r--mlir/test/Target/SPIRV/decorations-intel-cache-controls.mlir42
-rw-r--r--mlir/test/Target/SPIRV/decorations.mlir77
-rw-r--r--mlir/test/Target/SPIRV/function-decorations.mlir22
-rw-r--r--mlir/test/Target/SPIRV/global-variable.mlir22
-rw-r--r--mlir/test/lib/Dialect/SCF/CMakeLists.txt1
-rw-r--r--mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp85
-rw-r--r--mlir/test/python/CMakeLists.txt2
-rw-r--r--mlir/test/python/execution_engine.py27
-rw-r--r--mlir/test/python/ir/operation.py12
-rw-r--r--mlir/tools/mlir-opt/mlir-opt.cpp2
-rw-r--r--mlir/unittests/Analysis/Presburger/SimplexTest.cpp28
52 files changed, 1434 insertions, 632 deletions
diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md
index 6f778b0..877ae51 100644
--- a/mlir/docs/Bindings/Python.md
+++ b/mlir/docs/Bindings/Python.md
@@ -9,7 +9,7 @@
### Pre-requisites
* A relatively recent Python3 installation
-* Installation of python dependencies as specified in
+* Installation of Python dependencies as specified in
`mlir/python/requirements.txt`
### CMake variables
@@ -27,8 +27,8 @@
### Recommended development practices
-It is recommended to use a python virtual environment. Many ways exist for this,
-but the following is the simplest:
+It is recommended to use a Python virtual environment. Many ways exist for this,
+but one of the following is generally recommended:
```shell
# Make sure your 'python' is what you expect. Note that on multi-python
@@ -37,7 +37,22 @@ but the following is the simplest:
which python
python -m venv ~/.venv/mlirdev
source ~/.venv/mlirdev/bin/activate
+```
+
+Or, if you have uv installed on your system, you can also use the following commands
+to create the same environment (targeting a Python 3.12 toolchain in this example):
+
+```shell
+uv venv ~/.venv/mlirdev --seed -p 3.12
+source ~/.venv/mlirdev/bin/activate
+```
+
+You can change the Python version (`-p` flag) as needed - if you request any Python interpreter
+not present on your system, uv will attempt to download it, unless the `--no-python-downloads` option is given.
+For information on how to install uv, refer to the official documentation at
+https://docs.astral.sh/uv/getting-started/installation/
+```shell
# Note that many LTS distros will bundle a version of pip itself that is too
# old to download all of the latest binaries for certain platforms.
# The pip version can be obtained with `python -m pip --version`, and for
@@ -46,14 +61,16 @@ source ~/.venv/mlirdev/bin/activate
# It is recommended to upgrade pip:
python -m pip install --upgrade pip
-
# Now the `python` command will resolve to your virtual environment and
# packages will be installed there.
python -m pip install -r mlir/python/requirements.txt
+# In a uv-generated virtual environment, you can instead run:
+uv pip install -r mlir/python/requirements.txt
+
# Now run your build command with `cmake`, `ninja`, et al.
-# Run mlir tests. For example, to run python bindings tests only using ninja:
+# Run mlir tests. For example, to run Python bindings tests only using ninja:
ninja check-mlir-python
```
@@ -65,7 +82,7 @@ the `PYTHONPATH`. Typically:
export PYTHONPATH=$(cd build && pwd)/tools/mlir/python_packages/mlir_core
```
-Note that if you have installed (i.e. via `ninja install`, et al), then python
+Note that if you have installed (i.e. via `ninja install`, et al), then Python
packages for all enabled projects will be in your install tree under
`python_packages/` (i.e. `python_packages/mlir_core`). Official distributions
are built with a more specialized setup.
@@ -74,14 +91,14 @@ are built with a more specialized setup.
### Use cases
-There are likely two primary use cases for the MLIR python bindings:
+There are likely two primary use cases for the MLIR Python bindings:
1. Support users who expect that an installed version of LLVM/MLIR will yield
the ability to `import mlir` and use the API in a pure way out of the box.
1. Downstream integrations will likely want to include parts of the API in
their private namespace or specially built libraries, probably mixing it
- with other python native bits.
+ with other Python native bits.
### Composable modules
@@ -89,8 +106,8 @@ In order to support use case \#2, the Python bindings are organized into
composable modules that downstream integrators can include and re-export into
their own namespace if desired. This forces several design points:
-* Separate the construction/populating of a `py::module` from
- `PYBIND11_MODULE` global constructor.
+* Separate the construction/populating of a `nb::module` from
+ `NB_MODULE` global constructor.
* Introduce headers for C++-only wrapper classes as other related C++ modules
will need to interop with it.
@@ -130,7 +147,7 @@ registration, etc.
### Loader
-LLVM/MLIR is a non-trivial python-native project that is likely to co-exist with
+LLVM/MLIR is a non-trivial Python-native project that is likely to co-exist with
other non-trivial native extensions. As such, the native extension (i.e. the
`.so`/`.pyd`/`.dylib`) is exported as a notionally private top-level symbol
(`_mlir`), while a small set of Python code is provided in
@@ -160,7 +177,7 @@ are) with non-RTTI polymorphic C++ code (the default compilation mode of LLVM).
### Ownership in the Core IR
There are several top-level types in the core IR that are strongly owned by
-their python-side reference:
+their Python-side reference:
* `PyContext` (`mlir.ir.Context`)
* `PyModule` (`mlir.ir.Module`)
@@ -219,23 +236,24 @@ Due to the validity and parenting accounting needs, `PyOperation` is the owner
for regions and blocks. Operations are also the only entities which are allowed to be in
a detached state.
-**Note**: Multiple `PyOperation` objects (i.e., the Python objects themselves) can alias a single `mlir::Operation`.
-This means, for example, if you have `py_op1` and `py_op2` which wrap the same `mlir::Operation op`
+**Note**: Multiple `PyOperation` objects (i.e., the Python objects themselves) can alias a single `mlir::Operation`.
+This means, for example, if you have `py_op1` and `py_op2` which wrap the same `mlir::Operation op`
and you somehow transform `op` (e.g., you run a pass on `op`) then walking the MLIR AST via either/or `py_op1`, `py_op2`
-will reflect the same MLIR AST. This is perfectly safe and supported. What is not supported is invalidating any
-operation while there exist multiple Python objects wrapping that operation **and then manipulating those wrappers**.
-For example if `py_op1` and `py_op2` wrap the same operation under a root `py_op3` and then `py_op3` is
-transformed such that the operation referenced (by `py_op1`, `py_op2`) is erased. Then `py_op1`, `py_op2`
-become "undefined" in a sense; manipulating them in any way is "formally forbidden". Note, this also applies to
-`SymbolTable` mutation, which is considered a transformation of the root `SymbolTable`-supporting operation for the
-purposes of the discussion here. Metaphorically, one can think of this similarly to how STL container iterators are invalidated once the container itself is changed. The "best practices" recommendation is to structure your code such that
+will reflect the same MLIR AST. This is perfectly safe and supported. What is not supported is invalidating any
+operation while there exist multiple Python objects wrapping that operation **and then manipulating those wrappers**.
+For example if `py_op1` and `py_op2` wrap the same operation under a root `py_op3` and then `py_op3` is
+transformed such that the operation referenced (by `py_op1`, `py_op2`) is erased. Then `py_op1`, `py_op2`
+become "undefined" in a sense; manipulating them in any way is "formally forbidden". Note, this also applies to
+`SymbolTable` mutation, which is considered a transformation of the root `SymbolTable`-supporting operation for the
+purposes of the discussion here. Metaphorically, one can think of this similarly to how STL container iterators are invalidated
+once the container itself is changed. The "best practices" recommendation is to structure your code such that
1. First, query/manipulate various Python wrapper objects `py_op1`, `py_op2`, `py_op3`, etc.;
2. Second, transform the AST/erase operations/etc. via a single root object;
3. Invalidate all queried nodes (e.g., using `op._set_invalid()`).
-Ideally this should be done in a function body so that step (3) corresponds to the end of the function and there are no
-risks of Python wrapper objects leaking/living longer than necessary. In summary, you should scope your changes based on
+Ideally this should be done in a function body so that step (3) corresponds to the end of the function and there are no
+risks of Python wrapper objects leaking/living longer than necessary. In summary, you should scope your changes based on
nesting i.e., change leaf nodes first before going up in hierarchy, and only in very rare cases query nested ops post
modifying a parent op.
@@ -773,7 +791,7 @@ This allows to invoke op creation of an op with a `I32Attr` with
foo.Op(30)
```
-The registration is based on the ODS name but registry is via pure python
+The registration is based on the ODS name but registry is via pure Python
method. Only single custom builder is allowed to be registered per ODS attribute
type (e.g., I32Attr can have only one, which can correspond to multiple of the
underlying IntegerAttr type).
@@ -795,13 +813,13 @@ either for practicality or to give the resulting library an appropriately
Generally favor converting trivial methods like `getContext()`, `getName()`,
`isEntryBlock()`, etc to read-only Python properties (i.e. `context`). It is
-primarily a matter of calling `def_property_readonly` vs `def` in binding code,
+primarily a matter of calling `def_prop_ro` vs `def` in binding code,
and makes things feel much nicer to the Python side.
For example, prefer:
```c++
-m.def_property_readonly("context", ...)
+m.def_prop_ro("context", ...)
```
Over:
@@ -914,7 +932,7 @@ def create_my_op():
The MLIR Python bindings integrate with the tablegen-based ODS system for
providing user-friendly wrappers around MLIR dialects and operations. There are
multiple parts to this integration, outlined below. Most details have been
-elided: refer to the build rules and python sources under `mlir.dialects` for
+elided: refer to the build rules and Python sources under `mlir.dialects` for
the canonical way to use this facility.
Users are responsible for providing a `{DIALECT_NAMESPACE}.py` (or an equivalent
@@ -922,9 +940,9 @@ directory with `__init__.py` file) as the entrypoint.
### Generating `_{DIALECT_NAMESPACE}_ops_gen.py` wrapper modules
-Each dialect with a mapping to python requires that an appropriate
+Each dialect with a mapping to Python requires that an appropriate
`_{DIALECT_NAMESPACE}_ops_gen.py` wrapper module is created. This is done by
-invoking `mlir-tblgen` on a python-bindings specific tablegen wrapper that
+invoking `mlir-tblgen` on a Python-bindings specific tablegen wrapper that
includes the boilerplate and actual dialect specific `td` file. An example, for
the `Func` (which is assigned the namespace `func` as a special case):
@@ -954,7 +972,7 @@ from ._my_dialect_ops_gen import *
### Extending the search path for wrapper modules
-When the python bindings need to locate a wrapper module, they consult the
+When the Python bindings need to locate a wrapper module, they consult the
`dialect_search_path` and use it to find an appropriately named module. For the
main repository, this search path is hard-coded to include the `mlir.dialects`
module, which is where wrappers are emitted by the above build rule. Out of tree
@@ -1153,7 +1171,7 @@ subclasses can be defined using
[`include/mlir/Bindings/Python/PybindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/PybindAdaptors.h)
or
[`include/mlir/Bindings/Python/NanobindAdaptors.h`](https://github.com/llvm/llvm-project/blob/main/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h)
-utilities that mimic pybind11/nanobind API for defining functions and
+utilities that mimic pybind11/nanobind APIs for defining functions and
properties. These bindings are to be included in a separate module. The
utilities also provide automatic casting between C API handles `MlirAttribute`
and `MlirType` and their Python counterparts so that the C API handles can be
@@ -1176,11 +1194,11 @@ are available when the dialect is loaded from Python.
Dialect-specific passes can be made available to the pass manager in Python by
registering them with the context and relying on the API for pass pipeline
parsing from string descriptions. This can be achieved by creating a new
-pybind11 module, defined in `lib/Bindings/Python/<Dialect>Passes.cpp`, that
+nanobind module, defined in `lib/Bindings/Python/<Dialect>Passes.cpp`, that
calls the registration C API, which must be provided first. For passes defined
declaratively using Tablegen, `mlir-tblgen -gen-pass-capi-header` and
`-mlir-tblgen -gen-pass-capi-impl` automate the generation of C API. The
-pybind11 module must be compiled into a separate “Python extension” library,
+nanobind module must be compiled into a separate “Python extension” library,
which can be `import`ed from the main dialect file, i.e.
`python/mlir/dialects/<dialect-namespace>.py` or
`python/mlir/dialects/<dialect-namespace>/__init__.py`, or from a separate
diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index d74abc2..37db096 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -923,10 +923,10 @@ def AMDGPU_MFMAOp :
AMDGPU_Op<"mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
Arguments<(ins
- I32Attr:$m,
- I32Attr:$n,
- I32Attr:$k,
- I32Attr:$blocks,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32]>]>:$m,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[4, 16, 32]>]>:$n,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[1, 2, 4, 8, 16, 32, 64, 128]>]>:$k,
+ DefaultValuedAttr<ConfinedAttr<I32Attr, [IntIsOneOf<[1, 2, 4, 16]>]>, "1">:$blocks,
MFMAInTypes:$sourceA,
MFMAInTypes:$sourceB,
MFMAOutTypes:$destC,
@@ -969,14 +969,16 @@ def AMDGPU_MFMAOp :
Example:
```mlir
- %0 = amdgpu.mfma %matA * %matB + %matC
- { abid = 1 : i32, cbsz = 1 : i32,
- m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32 }
+ %0 = amdgpu.mfma 16x16x16 %matA * %matB + %matC
+ : vector<4xf16>, vector<4xf16>, vector<4xf32>
+
+ %1 = amdgpu.mfma 32x32x1 %matD * %matE + %matF
+ { abid = 1 : i32, cbsz = 1 : i32, blocks = 2 : i32 }
blgp = bcast_second_32 : f32, f32, vector<32xf32>
```
}];
let assemblyFormat = [{
- $sourceA `*` $sourceB `+` $destC
+ custom<MNKDimensionList>($m, $n, $k) $sourceA `*` $sourceB `+` $destC
attr-dict
`blgp` `=` $blgp
`:` type($sourceA) `,` type($sourceB) `,` type($destC)
@@ -1109,9 +1111,9 @@ def AMDGPU_ScaledMFMAOp :
AMDGPU_Op<"scaled_mfma", [AllTypesMatch<["destC", "destD"]>,
Pure]>,
Arguments<(ins
- I32Attr:$m,
- I32Attr:$n,
- I32Attr:$k,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$m,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[16, 32]>]>:$n,
+ ConfinedAttr<I32Attr, [IntIsOneOf<[64, 128]>]>:$k,
ScaledMFMAInTypes:$sourceA,
ScaledMFMAInTypes:$sourceB,
ScaledMFMAOutTypes:$destC,
@@ -1124,8 +1126,8 @@ def AMDGPU_ScaledMFMAOp :
let summary = "MLIR wrapper for CDNA scaled mfma instructions";
let description = [{
The `amdgpu.scaled_mfma` op is an MLIR wrapper around intrinsics
- for various scaled versions of `mfma` instructions in the CDNA architecture, which perform
- multiple outer products in order to allow fast matrix multiplication.
+ for various scaled versions of `mfma` instructions in the CDNA architecture, which
+ perform multiple outer products in order to allow fast matrix multiplication.
The wrapper will select an appropriate `mfma` instruction, if one is available,
based on the provided `m`, `k`, `n`, and `nBlks` attributes, along with the
@@ -1140,15 +1142,23 @@ def AMDGPU_ScaledMFMAOp :
This wrapper takes inspiration from `amdgpu.mfma`, but has some key differences:
- `amdgpu.scaled_mfma` operates on fp4 (f4E2M1FN), fp6 (f6E2M3FN and f6E3M2FN) and
- fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as their tile
- size.
+ fp8 (f8E4M3FN and f8E5M2) types using either M=N=16, K=128 or M=N=32, K=64 as
+ their tile size.
- `amdgpu.scaled_mfma` does not support broadcasting. So, `cbsz`, `abid`, and `blgp`
- are omitted from this wrapper.
- - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported for
- double-precision operations on gfx94x and so are not included here.
+ are omitted from this wrapper.
+ - The `negateA`, `negateB`, and `negateC` flags in `amdgpu.mfma` are only supported
+ for double-precision operations on gfx94x and so are not included here.
+
+ Example:
+ ```mlir
+ %0 = amdgpu.scaled_mfma 32x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2
+ : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
+ ```
}];
let assemblyFormat = [{
- `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*` `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
+ custom<MNKDimensionList>($m, $n, $k) ` `
+ `(` $scalesA `[` $scalesIdxA `]` `*` $sourceA `)` `*`
+ `(` $scalesB `[` $scalesIdxB `]` `*` $sourceB `)` `+` $destC
attr-dict
`:` type($scalesA) `,` type($sourceA) `,` type($scalesB) `,` type($sourceB) `,` type($destC)
}];
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index ba64818..e754a04 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -112,6 +112,10 @@ SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter,
Value replacement,
const ValueTypeCastFnTy &castFn);
+/// Helper function to compute the difference between two values. This is used
+/// by the loop implementations to compute the trip count.
+std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub, bool isSigned);
+
} // namespace scf
} // namespace mlir
#endif // MLIR_DIALECT_SCF_SCF_H
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index ecd829e..3475bb2 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -221,6 +221,39 @@ FailureOr<scf::ForallOp> normalizeForallOp(RewriterBase &rewriter,
/// 4. Each region iter arg and result has exactly one use
bool isPerfectlyNestedForLoops(MutableArrayRef<LoopLikeOpInterface> loops);
+/// Generate unrolled copies of an scf loop's 'loopBodyBlock', with 'iterArgs'
+/// and 'yieldedValues' as the block arguments and yielded values of the loop.
+/// The content of the loop body is replicated 'unrollFactor' times, calling
+/// 'ivRemapFn' to remap 'iv' for each unrolled body. If specified, annotates
+/// the Ops in each unrolled iteration using annotateFn. If provided,
+/// 'clonedToSrcOpsMap' is populated with the mappings from the cloned ops to
+/// the original op.
+void generateUnrolledLoop(
+ Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
+ function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
+ function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
+ ValueRange iterArgs, ValueRange yieldedValues,
+ IRMapping *clonedToSrcOpsMap = nullptr);
+
+/// Unroll this scf::Parallel loop by the specified unroll factors. Returns the
+/// unrolled loop if the unroll succeded; otherwise returns failure if the loop
+/// cannot be unrolled either due to restrictions or to invalid unroll factors.
+/// Requires positive loop bounds and step. If specified, annotates the Ops in
+/// each unrolled iteration by applying `annotateFn`.
+/// If provided, 'clonedToSrcOpsMap' is populated with the mappings from the
+/// cloned ops to the original op.
+FailureOr<scf::ParallelOp> parallelLoopUnrollByFactors(
+ scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
+ RewriterBase &rewriter,
+ function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr,
+ IRMapping *clonedToSrcOpsMap = nullptr);
+
+/// Get constant trip counts for each of the induction variables of the given
+/// loop operation. If any of the loop's trip counts is not constant, return an
+/// empty vector.
+llvm::SmallVector<int64_t>
+getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp);
+
} // namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
index d0a3f01..43e48a6 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensor.h
@@ -158,16 +158,14 @@ namespace sparse_tensor {
/// Convenience method to abbreviate casting `getType()`.
template <typename T>
inline RankedTensorType getRankedTensorType(T &&t) {
- assert(static_cast<bool>(std::forward<T>(t)) &&
- "getRankedTensorType got null argument");
+ assert(static_cast<bool>(t) && "getRankedTensorType got null argument");
return dyn_cast<RankedTensorType>(std::forward<T>(t).getType());
}
/// Convenience method to abbreviate casting `getType()`.
template <typename T>
inline MemRefType getMemRefType(T &&t) {
- assert(static_cast<bool>(std::forward<T>(t)) &&
- "getMemRefType got null argument");
+ assert(static_cast<bool>(t) && "getMemRefType got null argument");
return cast<MemRefType>(std::forward<T>(t).getType());
}
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6e17591..467dba3 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -19,6 +19,7 @@ include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
include "mlir/Dialect/Tosa/IR/TosaTypesBase.td"
@@ -2814,6 +2815,106 @@ def Tosa_WhileOp : Tosa_Op<"while_loop", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// Operator: variable
+//===----------------------------------------------------------------------===//
+def Tosa_VariableOp : Tosa_Op<"variable", [Symbol]> {
+ let summary = "Defines a variable";
+
+ let description = [{
+ Defines a new TOSA variable. This is a persistent mutable value across multiple
+ TOSA graph invocations. Modifications are expressed using read/write semantics.
+ }];
+
+ let arguments = (ins
+ // Note: "sym_name" is used as opposed to "name" in the specification,
+ // since a Symbol must be named "sym_name" for it to be recognised by
+ // the containing SymbolTable.
+ SymbolNameAttr:$sym_name,
+ IndexElementsAttr:$var_shape,
+ TypeAttr:$type,
+ OptionalAttr<AnyAttr>:$initial_value
+ );
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_VARIABLE]>,
+ ];
+
+ let hasCustomAssemblyFormat = 1;
+
+ let assemblyFormat = [{
+ $sym_name
+ attr-dict
+ custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
+ }];
+
+ let builders = [Tosa_VariableOpBuilder];
+
+ let extraClassDeclaration = [{
+ ::llvm::StringRef getName() {
+ return getSymName();
+ }
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable_write
+//===----------------------------------------------------------------------===//
+def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
+ let summary = "write_buffer operator";
+
+ let description = [{
+ Assigns a value to the pseudo-buffer resource holding a persistent mutable tensor.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$name,
+ Tosa_Tensor:$input1
+ );
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_VARIABLE]>,
+ ];
+
+ let assemblyFormat = [{
+ $name attr-dict `,` $input1 `:` type($input1)
+ }];
+
+ let hasVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// Operator: variable_read
+//===----------------------------------------------------------------------===//
+def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
+ let summary = "read_buffer operator";
+
+ let description = [{
+ Reads the value from a pseudo-buffer resource holding a persistent mutable tensor.
+ }];
+
+ let arguments = (ins
+ SymbolNameAttr:$name
+ );
+
+ let results = (outs
+ Tosa_Tensor:$output1
+ );
+
+ list<Availability> availability = [
+ Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
+ Extension<[Tosa_EXT_VARIABLE]>,
+ ];
+
+ let assemblyFormat = [{
+ $name attr-dict `:` type($output1)
+ }];
+
+ let hasVerifier = 1;
+}
+
include "mlir/Dialect/Tosa/IR/TosaUtilOps.td"
include "mlir/Dialect/Tosa/IR/TosaShapeOps.td"
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
index f1a618e..4c71089 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaUtilOps.td
@@ -18,7 +18,6 @@
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
-include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Dialect/Tosa/IR/TosaInterfaces.td"
@@ -80,104 +79,4 @@ def Tosa_YieldOp : Tosa_Op<"yield", [
let assemblyFormat = "$inputs attr-dict `:` type($inputs)";
}
-//===----------------------------------------------------------------------===//
-// Operator: variable
-//===----------------------------------------------------------------------===//
-def Tosa_VariableOp : Tosa_Op<"variable", [Symbol]> {
- let summary = "Defines a variable";
-
- let description = [{
- Defines a new TOSA variable. This is a persistent mutable value across multiple
- TOSA graph invocations. Modifications are expressed using read/write semantics.
- }];
-
- let arguments = (ins
- // Note: "sym_name" is used as opposed to "name" in the specification,
- // since a Symbol must be named "sym_name" for it to be recognised by
- // the containing SymbolTable.
- SymbolNameAttr:$sym_name,
- IndexElementsAttr:$var_shape,
- TypeAttr:$type,
- OptionalAttr<AnyAttr>:$initial_value
- );
-
- list<Availability> availability = [
- Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[Tosa_EXT_VARIABLE]>,
- ];
-
- let hasCustomAssemblyFormat = 1;
-
- let assemblyFormat = [{
- $sym_name
- attr-dict
- custom<VariableOpTypeOrInitialValue>($var_shape, $type, $initial_value)
- }];
-
- let builders = [Tosa_VariableOpBuilder];
-
- let extraClassDeclaration = [{
- ::llvm::StringRef getName() {
- return getSymName();
- }
- }];
-}
-
-//===----------------------------------------------------------------------===//
-// Operator: variable_write
-//===----------------------------------------------------------------------===//
-def Tosa_VariableWriteOp : Tosa_Op<"variable_write", []> {
- let summary = "write_buffer operator";
-
- let description = [{
- Assigns a value to the pseudo-buffer resource holding a persistent mutable tensor.
- }];
-
- let arguments = (ins
- SymbolNameAttr:$name,
- Tosa_Tensor:$input1
- );
-
- list<Availability> availability = [
- Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[Tosa_EXT_VARIABLE]>,
- ];
-
- let assemblyFormat = [{
- $name attr-dict `,` $input1 `:` type($input1)
- }];
-
- let hasVerifier = 1;
-}
-
-//===----------------------------------------------------------------------===//
-// Operator: variable_read
-//===----------------------------------------------------------------------===//
-def Tosa_VariableReadOp : Tosa_Op<"variable_read", []> {
- let summary = "read_buffer operator";
-
- let description = [{
- Reads the value from a pseudo-buffer resource holding a persistent mutable tensor.
- }];
-
- let arguments = (ins
- SymbolNameAttr:$name
- );
-
- let results = (outs
- Tosa_Tensor:$output1
- );
-
- list<Availability> availability = [
- Profile<[Tosa_PRO_INT, Tosa_PRO_FP]>,
- Extension<[Tosa_EXT_VARIABLE]>,
- ];
-
- let assemblyFormat = [{
- $name attr-dict `:` type($output1)
- }];
-
- let hasVerifier = 1;
-}
-
#endif // TOSA_UTIL_OPS
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
index 0519f7b..dcb2ad5 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/IntelGpuXe2.h
@@ -23,8 +23,6 @@
#include <map>
#include <string>
-#define DEBUG_TYPE "xegpu-uarch"
-
using namespace mlir;
using namespace mlir::xegpu::uArch;
@@ -33,21 +31,156 @@ namespace xegpu {
namespace uArch {
struct Xe2Plus : public uArch {
+ Xe2Plus(StringRef archName, StringRef archDescription,
+ llvm::ArrayRef<const Instruction *> instructionRegistry,
+ const XeCoreInfo &xeCore)
+ : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
+ int getSubgroupSize() const override { return 16; }
+ unsigned getGeneralPackedFormatBitSize() const override { return 32; }
+
+protected:
XeCoreInfo xeCore;
- Xe2Plus(const std::string &archName, const std::string &archDescription,
- const XeCoreInfo &xeCore,
- const std::map<RegisterFileType, RegisterFileInfo> &regInfo = {},
- const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
- const std::map<InstructionKind, std::shared_ptr<Instruction>>
- &instrs = {})
- : uArch(archName, archDescription, regInfo, cacheInfo, instrs),
- xeCore(xeCore) {}
};
-// struct to represent DPAS instruction
-struct DPASInstruction : public Instruction, public MMAInstructionInterface {
- DPASInstruction()
- : Instruction(InstructionKind::DPAS, InstructionScope::Subgroup) {}
+//===----------------------------------------------------------------------===//
+// uArch instructions
+//===----------------------------------------------------------------------===//
+struct Subgroup2DBlockStoreInstruction : public Instruction {
+ Subgroup2DBlockStoreInstruction()
+ : Instruction(InstructionKind::Subgroup2DBlockStore,
+ InstructionScope::Subgroup) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::Subgroup2DBlockStore;
+ }
+ // Source :
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
+ std::optional<
+ std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
+ getBlockWidthHeightCount(Type elemTy) const {
+ const static int kHeight[] = {1, 2, 4, 8};
+ const static int kWidth16[] = {16};
+ const static int kWidth32[] = {16};
+ const static int kCount[] = {1};
+ const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
+ if (elemByteSize == 1)
+ return std::make_tuple(llvm::ArrayRef<int>(kWidth32),
+ llvm::ArrayRef<int>(kHeight),
+ llvm::ArrayRef<int>(kCount));
+ else if (elemByteSize == 2 || elemByteSize == 4)
+ return std::make_tuple(llvm::ArrayRef<int>(kWidth16),
+ llvm::ArrayRef<int>(kHeight),
+ llvm::ArrayRef<int>(kCount));
+ return std::nullopt;
+ }
+
+ int32_t getPackedFormatBitSize() const { return 16; }
+};
+
+struct Subgroup2DBlockLoadInstruction : public Instruction {
+ Subgroup2DBlockLoadInstruction()
+ : Instruction(InstructionKind::Subgroup2DBlockLoad,
+ InstructionScope::Subgroup) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::Subgroup2DBlockLoad;
+ }
+
+ // Source :
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
+ std::optional<
+ std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
+ getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
+ bool upConv = false) const {
+ static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
+ static const int kHeightAtLeast8[] = {8, 16, 32};
+ static const int kHeightAtLeast16[] = {16, 32};
+ static const int kHeightAtLeast32[] = {32};
+
+ static const int kWidth32[] = {32};
+ static const int kWidth16[] = {16};
+ static const int kWidth8[] = {8};
+
+ static const int32_t kCount1[] = {1};
+ static const int32_t kCount2[] = {1, 2};
+ static const int32_t kCount4[] = {1, 2, 4};
+ static const int32_t kCount4Only[] = {4};
+ // (elemBytes, transform, transpose, upConvert)
+ using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
+ // (widths, heights, counts)
+ using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
+ llvm::ArrayRef<int32_t>>;
+ static const llvm::DenseMap<Key, Value> kMap = {
+ {{1, false, false, false}, {kWidth32, kHeightAtLeast1, kCount2}},
+ {{1, false, false, true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
+ {{2, false, false, false}, {kWidth16, kHeightAtLeast1, kCount2}},
+ {{4, false, false, false}, {kWidth16, kHeightAtLeast1, kCount1}},
+ // Block Loads with Transform:
+ {{1, true, false, false}, {kWidth16, kHeightAtLeast32, kCount4}},
+ {{2, true, false, false}, {kWidth16, kHeightAtLeast16, kCount2}},
+ // Block Loads with Transpose:
+ {{4, false, true, false}, {kWidth8, kHeightAtLeast16, kCount1}},
+ };
+ const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
+ auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
+ if (it != kMap.end())
+ return it->second;
+ return std::nullopt;
+ }
+
+ int32_t getPackedFormatBitSize() const { return 16; }
+};
+
+struct Subgroup2DBlockPrefetchInstruction : public Instruction {
+ Subgroup2DBlockPrefetchInstruction()
+ : Instruction(InstructionKind::Subgroup2DBlockPrefetch,
+ InstructionScope::Subgroup) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() == InstructionKind::Subgroup2DBlockPrefetch;
+ }
+ // Source :
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
+ std::optional<
+ std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
+ getBlockWidthHeightCount(Type elemTy) const {
+ static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
+
+ static const int kWidth32[] = {32};
+ static const int kWidth16[] = {16};
+
+ static const int32_t kCount1[] = {1};
+ static const int32_t kCount2[] = {1, 2};
+ // elemBytes
+ using Key = int;
+ // (widths, heights, counts)
+ using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
+ llvm::ArrayRef<int32_t>>;
+ static const llvm::DenseMap<Key, Value> kMap = {
+ {1, {kWidth32, kHeightAtLeast1, kCount2}},
+ {2, {kWidth16, kHeightAtLeast1, kCount2}},
+ {4, {kWidth16, kHeightAtLeast1, kCount1}},
+ };
+ const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
+ auto it = kMap.find(elemByteSize);
+ if (it != kMap.end())
+ return it->second;
+ return std::nullopt;
+ }
+ int32_t getPackedFormatBitSize() const { return 16; }
+};
+
+struct SubgroupMatrixMultiplyAcc : public Instruction,
+ public MMAInstructionInterface {
+ SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA,
+ unsigned packedFormatBitSizeB)
+ : Instruction(InstructionKind::SubgroupMatrixMultiplyAcc,
+ InstructionScope::Subgroup),
+ packedFormatBitSizeA(packedFormatBitSizeA),
+ packedFormatBitSizeB(packedFormatBitSizeB) {}
+ static bool classof(const Instruction *B) {
+ return B->getInstructionKind() ==
+ InstructionKind::SubgroupMatrixMultiplyAcc;
+ }
+ // Source:
+ // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
// Override all virtuals from MatrixOpInterface
virtual llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
@@ -67,84 +200,91 @@ struct DPASInstruction : public Instruction, public MMAInstructionInterface {
std::pair<uint32_t, uint32_t> CShape,
std::pair<uint32_t, uint32_t> DShape, Type AType,
Type BType, Type CType, Type DType) override;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) override;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) override;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) override;
+ virtual llvm::SmallVector<uint32_t, 8>
+ getSupportedM(Type type) const override;
+ virtual llvm::SmallVector<uint32_t, 8>
+ getSupportedK(Type type) const override;
+ virtual llvm::SmallVector<uint32_t, 8>
+ getSupportedN(Type type) const override;
+
+ unsigned getPackedFormatBitSizeA() const { return packedFormatBitSizeA; }
+ unsigned getPackedFormatBitSizeB() const { return packedFormatBitSizeB; }
+
+protected:
+ const unsigned packedFormatBitSizeA;
+ const unsigned packedFormatBitSizeB;
};
-struct PVCuArch : public Xe2Plus {
- // Maintaines ownership of the instructions owned by PVUarch
- llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+//===----------------------------------------------------------------------===//
+// uArch instances
+//===----------------------------------------------------------------------===//
+
+struct PVCuArch final : public Xe2Plus {
+ static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
+ static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
+ static const Subgroup2DBlockLoadInstruction loadNdInst;
+ static const Subgroup2DBlockStoreInstruction storeNdInst;
+ static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
+ static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
+ &prefetchNdInst};
+ return arr;
+ }
+
PVCuArch()
: Xe2Plus("pvc", // archName
"Ponte Vecchio Architecture", // archDescription
- XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8), // xeCore
- {/* registerFileInfo */}, // Optional: empty
- {/* cacheInfo */}, // Optional: empty
- {/* instructions */} // Optional: empty
- ) {
- // Intialize register file info
- // GRF
- this->registerFileInfo.emplace(
- RegisterFileType::GRF,
- RegisterFileInfo(
- 64 * 1024, // size in bits
- {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
- {128, 256} // registers per thread per mode
- ));
- // Initialize cache info
- // L1 cache, XeCore level
- this->cacheInfo.push_back(
- CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L1));
- // L2 cache, XeStack level
- this->cacheInfo.push_back(
- CacheInfo(512 * 1024, 64, CacheHierarchyLevel::L2));
-
- // Add the instructions-
- auto dpas = std::make_shared<DPASInstruction>();
- instructions.emplace(dpas->getInstructionKind(), dpas);
- owned_instructions.push_back(dpas);
+ getInstructionRegistryArr(),
+ XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
+ ) {}
+ static const uArch *getInstance() {
+ static const PVCuArch instance;
+ return reinterpret_cast<const uArch *>(&instance);
}
};
struct BMGuArch : public Xe2Plus {
- // Maintaines ownership of the instructions owned by PVUarch
- llvm::SmallVector<std::shared_ptr<Instruction>, 8> owned_instructions;
+ static llvm::ArrayRef<const Instruction *> getInstructionRegistryArr() {
+ static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
+ static const Subgroup2DBlockLoadInstruction loadNdInst;
+ static const Subgroup2DBlockStoreInstruction storeNdInst;
+ static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
+ static const Instruction *arr[] = {&dpasInst, &loadNdInst, &storeNdInst,
+ &prefetchNdInst};
+ return arr;
+ }
+
BMGuArch()
: Xe2Plus("bmg", // archName
"Battlemage Architecture", // archDescription
- XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8), // xeCore
- {/* registerFileInfo */}, // Optional: empty
- {/* cacheInfo */}, // Optional: empty
- {/* instructions */} // Optional: empty
- ) {
- // Intialize register file info
- // GRF
- this->registerFileInfo[RegisterFileType::GRF] = RegisterFileInfo(
- 64 * 1024, // size in bits
- {RegisterFileMode::Small, RegisterFileMode::Large}, // GRF modes
- {128, 256} // registers per thread per mode
- );
- // Initialize cache info
- // L1 cache, XeCore level
- this->cacheInfo.push_back(
- CacheInfo(256 * 1024, 64, CacheHierarchyLevel::L1));
- // L2 cache, XeStack level
- this->cacheInfo.push_back(
- CacheInfo(18 * 1024 * 1024, 256, CacheHierarchyLevel::L2));
-
- // Add the instructions
- auto dpas = std::make_shared<DPASInstruction>();
- instructions.emplace(dpas->getInstructionKind(), dpas);
- owned_instructions.push_back(dpas);
+ getInstructionRegistryArr(),
+ XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
+ ) {}
+ static const uArch *getInstance() {
+ static const BMGuArch instance;
+ return reinterpret_cast<const uArch *>(&instance);
}
};
+
+inline const uArch *getUArch(llvm::StringRef archName) {
+ if (archName.equals_insensitive("pvc"))
+ return PVCuArch::getInstance();
+ else if (archName.equals_insensitive("bmg"))
+ return BMGuArch::getInstance();
+
+ return nullptr;
+}
+
} // namespace uArch
} // namespace xegpu
} // namespace mlir
+//===----------------------------------------------------------------------===//
+// Instruction implementations
+//===----------------------------------------------------------------------===//
+
inline llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16>
-DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
+SubgroupMatrixMultiplyAcc::getSupportedShapes(Type dataType,
+ MMAOpndKind matrixType) {
auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
const llvm::SmallVector<uint32_t, 8> &b)
-> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
@@ -180,8 +320,8 @@ DPASInstruction::getSupportedShapes(Type dataType, MMAOpndKind matrixType) {
}
inline llvm::SmallVector<Type, 8>
-DPASInstruction::getSupportedTypes(MLIRContext &context,
- MMAOpndKind matrixType) {
+SubgroupMatrixMultiplyAcc::getSupportedTypes(MLIRContext &context,
+ MMAOpndKind matrixType) {
Type bf16Type = BFloat16Type::get(&context);
Type f16Type = Float16Type::get(&context);
Type tf32Type = FloatTF32Type::get(&context);
@@ -200,8 +340,10 @@ DPASInstruction::getSupportedTypes(MLIRContext &context,
return {};
}
-inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
- Type CType, Type DType) {
+inline bool SubgroupMatrixMultiplyAcc::checkSupportedTypes(Type AType,
+ Type BType,
+ Type CType,
+ Type DType) {
if (AType.isF16() || BType.isF16()) {
if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
(!DType.isF32() && !DType.isF16())) {
@@ -231,7 +373,7 @@ inline bool DPASInstruction::checkSupportedTypes(Type AType, Type BType,
return true;
}
-inline bool DPASInstruction::checkSupportedShapesAndTypes(
+inline bool SubgroupMatrixMultiplyAcc::checkSupportedShapesAndTypes(
std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
Type AType, Type BType, Type CType, Type DType) {
@@ -246,23 +388,21 @@ inline bool DPASInstruction::checkSupportedShapesAndTypes(
checkSupportedTypes(AType, BType, CType, DType);
}
-inline bool DPASInstruction::validate(std::pair<uint32_t, uint32_t> AShape,
- std::pair<uint32_t, uint32_t> BShape,
- std::pair<uint32_t, uint32_t> CShape,
- std::pair<uint32_t, uint32_t> DShape,
- Type AType, Type BType, Type CType,
- Type DType) {
+inline bool SubgroupMatrixMultiplyAcc::validate(
+ std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
+ std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
+ Type AType, Type BType, Type CType, Type DType) {
return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
BType, CType, DType);
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedM(Type type) {
+SubgroupMatrixMultiplyAcc::getSupportedM(Type type) const {
return {1, 2, 3, 4, 5, 6, 7, 8};
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedK(Type type) {
+SubgroupMatrixMultiplyAcc::getSupportedK(Type type) const {
// assert if data type is not int or float type
assert(type.isIntOrFloat() && "Matrix type must be int or float");
auto bitWidth = type.getIntOrFloatBitWidth();
@@ -290,7 +430,7 @@ DPASInstruction::getSupportedK(Type type) {
}
inline llvm::SmallVector<uint32_t, 8>
-DPASInstruction::getSupportedN(Type type) {
+SubgroupMatrixMultiplyAcc::getSupportedN(Type type) const {
return {16};
}
diff --git a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
index 955994e..ea33e88 100644
--- a/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
+++ b/mlir/include/mlir/Dialect/XeGPU/uArch/uArchBase.h
@@ -32,8 +32,11 @@ namespace uArch {
// An enum class to represent the scope of an instruction
enum class InstructionScope { Lane, Subgroup, Workgroup, Cluster };
enum class InstructionKind {
- DPAS, // Dot Product Accumulate Systolic (DPAS) is a matrix
- // multiply-add operation
+ SubgroupMatrixMultiplyAcc, // Dot Product Accumulate Systolic (DPAS) is a
+ // matrix multiply-add operation
+ Subgroup2DBlockStore, // Subgroup-level 2D block write instruction
+ Subgroup2DBlockLoad, // Subgroup-level 2D block load instruction
+ Subgroup2DBlockPrefetch // Subgroup-level 2D block prefetch instruction
// @TODO: Add more instructions as needed
};
@@ -46,14 +49,20 @@ struct Instruction {
Instruction(InstructionKind kind, InstructionScope scope)
: instKind(kind), scope(scope) {}
- virtual ~Instruction() = default;
+ ~Instruction() = default;
// Get methods
- InstructionKind getInstructionKind() { return instKind; }
- InstructionScope getScope() { return scope; }
+ InstructionKind getInstructionKind() const { return instKind; }
+ InstructionScope getScope() const { return scope; }
static llvm::StringRef toString(InstructionKind instKind) {
switch (instKind) {
- case InstructionKind::DPAS:
+ case InstructionKind::SubgroupMatrixMultiplyAcc:
return "dpas";
+ case InstructionKind::Subgroup2DBlockStore:
+ return "store_nd";
+ case InstructionKind::Subgroup2DBlockLoad:
+ return "load_nd";
+ case InstructionKind::Subgroup2DBlockPrefetch:
+ return "prefetch_nd";
}
llvm_unreachable("Unknown InstructionKind");
}
@@ -61,14 +70,14 @@ struct Instruction {
static std::optional<InstructionKind>
parseInstructionKind(llvm::StringRef str) {
if (str.equals_insensitive("dpas"))
- return InstructionKind::DPAS;
+ return InstructionKind::SubgroupMatrixMultiplyAcc;
return std::nullopt;
}
protected:
- InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
- InstructionScope scope; // scope of the instruction (e.g., lane, subgroup,
- // workgroup, cluster)
+ const InstructionKind instKind; // Specific InstructionKind (e.g., DPAS)
+ const InstructionScope scope; // scope of the instruction (e.g., lane,
+ // subgroup, workgroup, cluster)
// @TODO: Add more fields as needed
};
@@ -129,61 +138,36 @@ protected:
// latency, throughput, bandwidth)
};
-// A struct to represent the uArch
-// This struct is used to represent the microarchitecture of a target device.
struct uArch {
// Constructor
- uArch(
- const std::string &name, const std::string &description,
- const std::map<RegisterFileType, RegisterFileInfo> &registerFileInfo = {},
- const llvm::SmallVector<CacheInfo, 4> &cacheInfo = {},
- const std::map<InstructionKind, std::shared_ptr<Instruction>>
- &instructions = {})
- : name(name), description(description),
- registerFileInfo(registerFileInfo), cacheInfo(cacheInfo),
- instructions(instructions) {}
-
- // Get methods
- const std::string &getName() const { return name; }
-
- const std::string &getDescription() const { return description; }
-
- const std::map<RegisterFileType, RegisterFileInfo> &
- getRegisterFileInfo() const {
- return registerFileInfo;
- }
-
- const llvm::SmallVector<CacheInfo, 4> &getCacheInfo() const {
- return cacheInfo;
- }
-
- const std::map<InstructionKind, std::shared_ptr<Instruction>> &
- getInstructions() const {
- return instructions;
+ uArch(StringRef name, StringRef description,
+ llvm::ArrayRef<const Instruction *> instructionRegistry)
+ : name(name), description(description) {
+ for (const Instruction *instr : instructionRegistry)
+ this->instructionRegistry[instr->getInstructionKind()] = instr;
}
-
- // Get the name of the supported instruction names for that
- // architecture. It returns the names of the instructions added to the uArch.
- llvm::SmallVector<StringRef, 8> getSupportedInstructionNames() const {
- llvm::SmallVector<StringRef, 8> instructionNames;
- for (const auto &inst : instructions) {
- instructionNames.push_back(Instruction::toString(inst.first));
- }
- return instructionNames;
+ virtual ~uArch() = default;
+ StringRef getName() const { return name; }
+ StringRef getDescription() const { return description; }
+ virtual int getSubgroupSize() const = 0;
+ virtual unsigned getGeneralPackedFormatBitSize() const = 0;
+
+ const Instruction *getInstruction(InstructionKind instKind) const {
+ auto it = instructionRegistry.find(instKind);
+ assert(it != instructionRegistry.end() &&
+ "Instruction not found in registry");
+ return it->second;
}
- // Checks if an instruction is supported in this uArch
- bool checkSupportedInstruction(InstructionKind instr) const {
- return instructions.find(instr) != instructions.end();
+ bool isSupportedInstruction(InstructionKind instr) const {
+ return instructionRegistry.contains(instr);
}
protected:
- std::string name; // Name of the uArch, similar to target triple
- std::string description;
- std::map<RegisterFileType, RegisterFileInfo> registerFileInfo;
- llvm::SmallVector<CacheInfo, 4> cacheInfo;
- std::map<InstructionKind, std::shared_ptr<Instruction>>
- instructions; // set of instructions supported by the uArch
+ StringRef name;
+ StringRef description;
+ llvm::SmallDenseMap<InstructionKind, const Instruction *, 32>
+ instructionRegistry;
};
// A struct to represent shared memory information
@@ -251,9 +235,9 @@ struct MMAInstructionInterface {
std::pair<uint32_t, uint32_t> CShape,
std::pair<uint32_t, uint32_t> DShape, Type AType,
Type BType, Type CType, Type DType) = 0;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) = 0;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) = 0;
- virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedM(Type type) const = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedK(Type type) const = 0;
+ virtual llvm::SmallVector<uint32_t, 8> getSupportedN(Type type) const = 0;
virtual ~MMAInstructionInterface() = default;
};
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 06d0256..cda4fe1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -598,7 +598,7 @@ class PyOpOperand {
public:
PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {}
- PyOpView getOwner() {
+ nb::typed<nb::object, PyOpView> getOwner() {
MlirOperation owner = mlirOpOperandGetOwner(opOperand);
PyMlirContextRef context =
PyMlirContext::forContext(mlirOperationGetContext(owner));
diff --git a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
index b711e33..a4c66e1 100644
--- a/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
+++ b/mlir/lib/Conversion/PDLToPDLInterp/PDLToPDLInterp.cpp
@@ -692,7 +692,7 @@ SymbolRefAttr PatternLowering::generateRewriter(
llvm::map_range(rewriter.getExternalArgs(), mapRewriteValue);
args.append(mappedArgs.begin(), mappedArgs.end());
pdl_interp::ApplyRewriteOp::create(builder, rewriter.getLoc(),
- /*resultTypes=*/TypeRange(), rewriteName,
+ /*results=*/TypeRange(), rewriteName,
args);
} else {
// Otherwise this is a dag rewriter defined using PDL operations.
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 4c4965e..585b6da 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -422,11 +422,11 @@ LogicalResult MFMAOp::verify() {
Type sourceElem = sourceType, destElem = destType;
uint32_t sourceLen = 1, destLen = 1;
- if (auto sourceVector = llvm::dyn_cast<VectorType>(sourceType)) {
+ if (auto sourceVector = dyn_cast<VectorType>(sourceType)) {
sourceLen = sourceVector.getNumElements();
sourceElem = sourceVector.getElementType();
}
- if (auto destVector = llvm::dyn_cast<VectorType>(destType)) {
+ if (auto destVector = dyn_cast<VectorType>(destType)) {
destLen = destVector.getNumElements();
destElem = destVector.getElementType();
}
@@ -451,7 +451,7 @@ LogicalResult MFMAOp::verify() {
return emitOpError("expected both non-small-float source operand types "
"to match exactly");
}
- // Normalize the wider integer types the compiler expects to i8
+ // Normalize the wider integer types the compiler expects to i8.
if (sourceElem.isInteger(32)) {
sourceLen *= 4;
sourceElem = b.getI8Type();
diff --git a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
index 50a0f3d..e08cc6f 100644
--- a/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp
@@ -978,12 +978,11 @@ static Operation *vectorizeAffineApplyOp(AffineApplyOp applyOp,
LLVM_DEBUG(
dbgs() << "\n[early-vect]+++++ affine.apply on vector operand\n");
return nullptr;
- } else {
- Value updatedOperand = state.valueScalarReplacement.lookupOrNull(operand);
- if (!updatedOperand)
- updatedOperand = operand;
- updatedOperands.push_back(updatedOperand);
}
+ Value updatedOperand = state.valueScalarReplacement.lookupOrNull(operand);
+ if (!updatedOperand)
+ updatedOperand = operand;
+ updatedOperands.push_back(updatedOperand);
}
auto newApplyOp = AffineApplyOp::create(
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
index d925c19..a651710 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/TileAllocation.cpp
@@ -216,8 +216,8 @@ void splitCondBranches(IRRewriter &rewriter, FunctionOpInterface function) {
for (auto condBranch : worklist) {
auto loc = condBranch.getLoc();
Block *block = condBranch->getBlock();
- auto newTrueBranch = rewriter.splitBlock(block, block->end());
- auto newFalseBranch = rewriter.splitBlock(block, block->end());
+ auto *newTrueBranch = rewriter.splitBlock(block, block->end());
+ auto *newFalseBranch = rewriter.splitBlock(block, block->end());
insertJump(loc, newTrueBranch, condBranch.getTrueDest(),
condBranch.getTrueDestOperands());
insertJump(loc, newFalseBranch, condBranch.getFalseDest(),
@@ -382,7 +382,7 @@ gatherTileLiveRanges(DenseMap<Operation *, unsigned> const &operationToIndexMap,
// Find or create a live range for `value`.
auto [it, _] = liveRanges.try_emplace(value, liveRangeAllocator);
LiveRange &valueLiveRange = it->second;
- auto lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
+ auto *lastUseInBlock = livenessInfo.getEndOperation(value, firstUseOrDef);
// Add the interval [firstUseOrDef, lastUseInBlock) to the live range.
unsigned startOpIdx =
operationToIndexMap.at(firstUseOrDef) + (liveAtBlockEntry ? -1 : 0);
diff --git a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
index a15bf89..6fa8ce4 100644
--- a/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -66,7 +66,7 @@ struct ExpandShapeOpInterface
ValueBoundsConstraintSet &cstr) const {
auto expandOp = cast<memref::ExpandShapeOp>(op);
assert(value == expandOp.getResult() && "invalid value");
- cstr.bound(value)[dim] == expandOp.getOutputShape()[dim];
+ cstr.bound(value)[dim] == expandOp.getMixedOutputShape()[dim];
}
};
diff --git a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
index 291da1f..14152c5 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/RuntimeOpVerification.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
using namespace mlir;
@@ -273,7 +274,9 @@ struct SubViewOpInterface
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
auto metadataOp =
ExtractStridedMetadataOp::create(builder, loc, subView.getSource());
- for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+ for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+ // Reset insertion point to before the operation for each dimension
+ builder.setInsertionPoint(subView);
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, subView.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(builder, loc,
@@ -290,6 +293,16 @@ struct SubViewOpInterface
std::to_string(i) +
" is out-of-bounds"));
+ // Only verify if size > 0
+ Value sizeIsNonZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+ auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
+ sizeIsNonZero, /*withElseRegion=*/true);
+
+ // Populate the "then" region (for size > 0).
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
@@ -298,8 +311,20 @@ struct SubViewOpInterface
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
+
+ scf::YieldOp::create(builder, loc, lastPosInBounds);
+
+ // Populate the "else" region (for size == 0).
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ Value trueVal =
+ arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
+ scf::YieldOp::create(builder, loc, trueVal);
+
+ builder.setInsertionPointAfter(ifOp);
+ Value finalCondition = ifOp.getResult(0);
+
cf::AssertOp::create(
- builder, loc, lastPosInBounds,
+ builder, loc, finalCondition,
generateErrorMessage(op,
"subview runs out-of-bounds along dimension " +
std::to_string(i)));
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 744a595..1ab01d8 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -111,10 +111,8 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region &region,
return nullptr;
}
-/// Helper function to compute the difference between two values. This is used
-/// by the loop implementations to compute the trip count.
-static std::optional<llvm::APSInt> computeUbMinusLb(Value lb, Value ub,
- bool isSigned) {
+std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
+ bool isSigned) {
llvm::APSInt diff;
auto addOp = ub.getDefiningOp<arith::AddIOp>();
if (!addOp)
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 10eae89..888dd44 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -291,47 +291,61 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
return arith::DivUIOp::create(builder, loc, sum, divisor);
}
-/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
-/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
-/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
-/// unrolled iteration using annotateFn.
-static void generateUnrolledLoop(
- Block *loopBodyBlock, Value forOpIV, uint64_t unrollFactor,
+void mlir::generateUnrolledLoop(
+ Block *loopBodyBlock, Value iv, uint64_t unrollFactor,
function_ref<Value(unsigned, Value, OpBuilder)> ivRemapFn,
function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
- ValueRange iterArgs, ValueRange yieldedValues) {
+ ValueRange iterArgs, ValueRange yieldedValues,
+ IRMapping *clonedToSrcOpsMap) {
+
+ // Check if the op was cloned from another source op, and return it if found
+ // (or the same op if not found)
+ auto findOriginalSrcOp =
+ [](Operation *op, const IRMapping &clonedToSrcOpsMap) -> Operation * {
+ Operation *srcOp = op;
+ // If the source op derives from another op: traverse the chain to find the
+ // original source op
+ while (srcOp && clonedToSrcOpsMap.contains(srcOp))
+ srcOp = clonedToSrcOpsMap.lookup(srcOp);
+ return srcOp;
+ };
+
// Builder to insert unrolled bodies just before the terminator of the body of
- // 'forOp'.
+ // the loop.
auto builder = OpBuilder::atBlockTerminator(loopBodyBlock);
- constexpr auto defaultAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
+ static const auto noopAnnotateFn = [](unsigned, Operation *, OpBuilder) {};
if (!annotateFn)
- annotateFn = defaultAnnotateFn;
+ annotateFn = noopAnnotateFn;
// Keep a pointer to the last non-terminator operation in the original block
// so that we know what to clone (since we are doing this in-place).
Block::iterator srcBlockEnd = std::prev(loopBodyBlock->end(), 2);
- // Unroll the contents of 'forOp' (append unrollFactor - 1 additional copies).
+ // Unroll the contents of the loop body (append unrollFactor - 1 additional
+ // copies).
SmallVector<Value, 4> lastYielded(yieldedValues);
for (unsigned i = 1; i < unrollFactor; i++) {
- IRMapping operandMap;
-
// Prepare operand map.
+ IRMapping operandMap;
operandMap.map(iterArgs, lastYielded);
// If the induction variable is used, create a remapping to the value for
// this unrolled instance.
- if (!forOpIV.use_empty()) {
- Value ivUnroll = ivRemapFn(i, forOpIV, builder);
- operandMap.map(forOpIV, ivUnroll);
+ if (!iv.use_empty()) {
+ Value ivUnroll = ivRemapFn(i, iv, builder);
+ operandMap.map(iv, ivUnroll);
}
// Clone the original body of 'forOp'.
for (auto it = loopBodyBlock->begin(); it != std::next(srcBlockEnd); it++) {
- Operation *clonedOp = builder.clone(*it, operandMap);
+ Operation *srcOp = &(*it);
+ Operation *clonedOp = builder.clone(*srcOp, operandMap);
annotateFn(i, clonedOp, builder);
+ if (clonedToSrcOpsMap)
+ clonedToSrcOpsMap->map(clonedOp,
+ findOriginalSrcOp(srcOp, *clonedToSrcOpsMap));
}
// Update yielded values.
@@ -1544,3 +1558,100 @@ bool mlir::isPerfectlyNestedForLoops(
}
return true;
}
+
+llvm::SmallVector<int64_t>
+mlir::getConstLoopTripCounts(mlir::LoopLikeOpInterface loopOp) {
+ std::optional<SmallVector<OpFoldResult>> loBnds = loopOp.getLoopLowerBounds();
+ std::optional<SmallVector<OpFoldResult>> upBnds = loopOp.getLoopUpperBounds();
+ std::optional<SmallVector<OpFoldResult>> steps = loopOp.getLoopSteps();
+ if (!loBnds || !upBnds || !steps)
+ return {};
+ llvm::SmallVector<int64_t> tripCounts;
+ for (auto [lb, ub, step] : llvm::zip(*loBnds, *upBnds, *steps)) {
+ std::optional<llvm::APInt> numIter = constantTripCount(
+ lb, ub, step, /*isSigned=*/true, scf::computeUbMinusLb);
+ if (!numIter)
+ return {};
+ tripCounts.push_back(numIter->getSExtValue());
+ }
+ return tripCounts;
+}
+
+FailureOr<scf::ParallelOp> mlir::parallelLoopUnrollByFactors(
+ scf::ParallelOp op, ArrayRef<uint64_t> unrollFactors,
+ RewriterBase &rewriter,
+ function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
+ IRMapping *clonedToSrcOpsMap) {
+ const unsigned numLoops = op.getNumLoops();
+ assert(llvm::none_of(unrollFactors, [](uint64_t f) { return f == 0; }) &&
+ "Expected positive unroll factors");
+ assert((!unrollFactors.empty() && (unrollFactors.size() <= numLoops)) &&
+ "Expected non-empty unroll factors of size <= to the number of loops");
+
+ // Bail out if no valid unroll factors were provided
+ if (llvm::all_of(unrollFactors, [](uint64_t f) { return f == 1; }))
+ return rewriter.notifyMatchFailure(
+ op, "Unrolling not applied if all factors are 1");
+
+ // Return if the loop body is empty.
+ if (llvm::hasSingleElement(op.getBody()->getOperations()))
+ return rewriter.notifyMatchFailure(op, "Cannot unroll an empty loop body");
+
+ // If the provided unroll factors do not cover all the loop dims, they are
+ // applied to the inner loop dimensions.
+ const unsigned firstLoopDimIdx = numLoops - unrollFactors.size();
+
+ // Make sure that the unroll factors divide the iteration space evenly
+ // TODO: Support unrolling loops with dynamic iteration spaces.
+ const llvm::SmallVector<int64_t> tripCounts = getConstLoopTripCounts(op);
+ if (tripCounts.empty())
+ return rewriter.notifyMatchFailure(
+ op, "Failed to compute constant trip counts for the loop. Note that "
+ "dynamic loop sizes are not supported.");
+
+ for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
+ const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
+ if (tripCounts[dimIdx] % unrollFactor)
+ return rewriter.notifyMatchFailure(
+ op, "Unroll factors don't divide the iteration space evenly");
+ }
+
+ std::optional<SmallVector<OpFoldResult>> maybeFoldSteps = op.getLoopSteps();
+ if (!maybeFoldSteps)
+ return rewriter.notifyMatchFailure(op, "Failed to retrieve loop steps");
+ llvm::SmallVector<size_t> steps{};
+ for (auto step : *maybeFoldSteps)
+ steps.push_back(static_cast<size_t>(*getConstantIntValue(step)));
+
+ for (unsigned dimIdx = firstLoopDimIdx; dimIdx < numLoops; dimIdx++) {
+ const uint64_t unrollFactor = unrollFactors[dimIdx - firstLoopDimIdx];
+ if (unrollFactor == 1)
+ continue;
+ const size_t origStep = steps[dimIdx];
+ const int64_t newStep = origStep * unrollFactor;
+ IRMapping clonedToSrcOpsMap;
+
+ ValueRange iterArgs = ValueRange(op.getRegionIterArgs());
+ auto yieldedValues = op.getBody()->getTerminator()->getOperands();
+
+ generateUnrolledLoop(
+ op.getBody(), op.getInductionVars()[dimIdx], unrollFactor,
+ [&](unsigned i, Value iv, OpBuilder b) {
+ // iv' = iv + step * i;
+ const AffineExpr expr = b.getAffineDimExpr(0) + (origStep * i);
+ const auto map =
+ b.getDimIdentityMap().dropResult(0).insertResult(expr, 0);
+ return affine::AffineApplyOp::create(b, iv.getLoc(), map,
+ ValueRange{iv});
+ },
+ /*annotateFn*/ annotateFn, iterArgs, yieldedValues, &clonedToSrcOpsMap);
+
+ // Update loop step
+ auto prevInsertPoint = rewriter.saveInsertionPoint();
+ rewriter.setInsertionPoint(op);
+ op.getStepMutable()[dimIdx].assign(
+ arith::ConstantIndexOp::create(rewriter, op.getLoc(), newStep));
+ rewriter.restoreInsertionPoint(prevInsertPoint);
+ }
+ return op;
+}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index fe50865..0c8114d 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1276,12 +1276,19 @@ LogicalResult spirv::GlobalVariableOp::verify() {
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
(*this)->getParentOp(), init.getAttr());
// TODO: Currently only variable initialization with specialization
- // constants and other variables is supported. They could be normal
- // constants in the module scope as well.
- if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
- spirv::SpecConstantCompositeOp>(initOp)) {
+ // constants is supported. There could be normal constants in the module
+ // scope as well.
+ //
+ // In the current setup we also cannot initialize one global variable with
+ // another. The problem is that if we try to initialize pointer of type X
+ // with another pointer type, the validator fails because it expects the
+ // variable to be initialized to be type X, not pointer to X. Now
+ // `spirv.GlobalVariable` only allows pointer type, so in the current design
+ // we cannot initialize one `spirv.GlobalVariable` with another.
+ if (!initOp ||
+ !isa<spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
return emitOpError("initializer must be result of a "
- "spirv.SpecConstant or spirv.GlobalVariable or "
+ "spirv.SpecConstant or "
"spirv.SpecConstantCompositeOp op");
}
}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
index 73e0f3d..f53d272 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp
@@ -159,14 +159,22 @@ IterationGraphSorter::IterationGraphSorter(
loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
strategy(strategy) {
// One map per tensor.
- assert(loop2InsLvl.size() == ins.size());
+ assert(this->loop2InsLvl.size() == this->ins.size());
// All the affine maps have the same number of dimensions (loops).
assert(llvm::all_equal(llvm::map_range(
- loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
+ this->loop2InsLvl, [](AffineMap m) { return m.getNumDims(); })));
// The number of results of the map should match the rank of the tensor.
- assert(llvm::all_of(llvm::zip(loop2InsLvl, ins), [](auto mvPair) {
+ assert(llvm::all_of(llvm::zip(this->loop2InsLvl, this->ins), [](auto mvPair) {
auto [m, v] = mvPair;
- return m.getNumResults() == cast<ShapedType>(v.getType()).getRank();
+
+ // For ranked types the rank must match.
+ // Simply return true for UnrankedTensorType
+ if (auto shapedType = llvm::dyn_cast<ShapedType>(v.getType())) {
+ return !shapedType.hasRank() ||
+ (m.getNumResults() == shapedType.getRank());
+ }
+ // Non-shaped (scalar) types behave like rank-0.
+ return m.getNumResults() == 0;
}));
itGraph.resize(getNumLoops(), std::vector<bool>(getNumLoops(), false));
diff --git a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
index c031118..753cb95 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/RuntimeOpVerification.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Interfaces/RuntimeVerifiableOpInterface.h"
@@ -158,7 +159,11 @@ struct ExtractSliceOpInterface
// 0 <= offset + (size - 1) * stride < dim_size
Value zero = arith::ConstantIndexOp::create(builder, loc, 0);
Value one = arith::ConstantIndexOp::create(builder, loc, 1);
- for (int64_t i = 0, e = sourceType.getRank(); i < e; ++i) {
+
+ for (int64_t i : llvm::seq<int64_t>(0, sourceType.getRank())) {
+ // Reset insertion point to before the operation for each dimension
+ builder.setInsertionPoint(extractSliceOp);
+
Value offset = getValueOrCreateConstantIndexOp(
builder, loc, extractSliceOp.getMixedOffsets()[i]);
Value size = getValueOrCreateConstantIndexOp(
@@ -176,6 +181,16 @@ struct ExtractSliceOpInterface
std::to_string(i) +
" is out-of-bounds"));
+ // Only verify if size > 0
+ Value sizeIsNonZero = arith::CmpIOp::create(
+ builder, loc, arith::CmpIPredicate::sgt, size, zero);
+
+ auto ifOp = scf::IfOp::create(builder, loc, builder.getI1Type(),
+ sizeIsNonZero, /*withElseRegion=*/true);
+
+ // Populate the "then" region (for size > 0).
+ builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
// Verify that slice does not run out-of-bounds.
Value sizeMinusOne = arith::SubIOp::create(builder, loc, size, one);
Value sizeMinusOneTimesStride =
@@ -184,8 +199,19 @@ struct ExtractSliceOpInterface
arith::AddIOp::create(builder, loc, offset, sizeMinusOneTimesStride);
Value lastPosInBounds =
generateInBoundsCheck(builder, loc, lastPos, zero, dimSize);
+ scf::YieldOp::create(builder, loc, lastPosInBounds);
+
+ // Populate the "else" region (for size == 0).
+ builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
+ Value trueVal =
+ arith::ConstantOp::create(builder, loc, builder.getBoolAttr(true));
+ scf::YieldOp::create(builder, loc, trueVal);
+
+ builder.setInsertionPointAfter(ifOp);
+ Value finalCondition = ifOp.getResult(0);
+
cf::AssertOp::create(
- builder, loc, lastPosInBounds,
+ builder, loc, finalCondition,
generateErrorMessage(
op, "extract_slice runs out-of-bounds along dimension " +
std::to_string(i)));
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index a85ff10a..293c6af 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -38,7 +38,7 @@ using namespace mlir::tosa;
//===----------------------------------------------------------------------===//
// Check that the zero point of the tensor and padding operations are aligned.
-bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
+static bool checkMatchingPadConstAndZp(Value padConst, Value zp) {
// Check that padConst is a constant value and a scalar tensor
DenseElementsAttr padConstAttr;
if (!matchPattern(padConst, m_Constant(&padConstAttr)) ||
@@ -889,8 +889,9 @@ void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
//===----------------------------------------------------------------------===//
template <typename IntFolder, typename FloatFolder>
-DenseElementsAttr binaryFolder(DenseElementsAttr lhs, DenseElementsAttr rhs,
- RankedTensorType returnTy) {
+static DenseElementsAttr binaryFolder(DenseElementsAttr lhs,
+ DenseElementsAttr rhs,
+ RankedTensorType returnTy) {
if (rhs && lhs && rhs.isSplat() && lhs.isSplat()) {
auto lETy = llvm::cast<ShapedType>(lhs.getType()).getElementType();
auto rETy = llvm::cast<ShapedType>(rhs.getType()).getElementType();
diff --git a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
index 24e9095..f9aa28d5 100644
--- a/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
+++ b/mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -113,9 +113,12 @@ bool XeGPUDialect::isEvenlyDistributable(llvm::ArrayRef<int64_t> shape,
if (layout.size() != shape.size())
return std::nullopt;
auto ratio = computeShapeRatio(shape, layout);
- if (!ratio.has_value())
+ if (ratio.has_value()) {
+ newShape = ratio.value();
+ } else if (!rr || !computeShapeRatio(layout, shape).has_value()) {
return std::nullopt;
- newShape = ratio.value();
+ }
+ // Round-robin case: continue with original newShape
}
if (data.size()) {
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
index 2c37140..ec5feb8 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUBlocking.cpp
@@ -344,6 +344,13 @@ void XeGPUBlockingPass::runOnOperation() {
xegpu::doSCFStructuralTypeConversionWithTensorType(op, converter);
+ // Remove leading unit dimensions from vector ops and then
+ // do the unrolling.
+ {
+ RewritePatternSet patterns(ctx);
+ vector::populateCastAwayVectorLeadingOneDimPatterns(patterns);
+ (void)applyPatternsGreedily(op, std::move(patterns));
+ }
xegpu::UnrollOptions options;
options.setFilterConstraint(
[&](Operation *op) -> LogicalResult { return success(needsUnroll(op)); });
diff --git a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
index b4605cd..a38993e 100644
--- a/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
+++ b/mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp
@@ -147,7 +147,7 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
}
if (auto arg = dyn_cast<BlockArgument>(value)) {
- auto parentOp = arg.getOwner()->getParentOp();
+ auto *parentOp = arg.getOwner()->getParentOp();
if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
OpOperand *tiedInit = loop.getTiedLoopInit(arg);
if (tiedInit)
diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
index 52162a4..2255633 100644
--- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp
@@ -239,6 +239,8 @@ ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options,
// Remember all entry-points if object dumping is enabled.
if (options.enableObjectDump) {
for (auto funcOp : m->getRegion(0).getOps<LLVM::LLVMFuncOp>()) {
+ if (funcOp.getBlocks().empty())
+ continue;
StringRef funcName = funcOp.getSymName();
engine->functionNames.push_back(funcName.str());
}
diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp
index 5f63fe6..73219c6 100644
--- a/mlir/lib/IR/MLIRContext.cpp
+++ b/mlir/lib/IR/MLIRContext.cpp
@@ -709,7 +709,7 @@ ArrayRef<RegisteredOperationName> MLIRContext::getRegisteredOperations() {
/// Return information for registered operations by dialect.
ArrayRef<RegisteredOperationName>
MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
- auto lowerBound = llvm::lower_bound(
+ auto *lowerBound = llvm::lower_bound(
impl->sortedRegisteredOperations, dialectName, [](auto &lhs, auto &rhs) {
return lhs.getDialect().getNamespace().compare(rhs);
});
@@ -718,7 +718,7 @@ MLIRContext::getRegisteredOperationsByDialect(StringRef dialectName) {
lowerBound->getDialect().getNamespace() != dialectName)
return ArrayRef<RegisteredOperationName>();
- auto upperBound =
+ auto *upperBound =
std::upper_bound(lowerBound, impl->sortedRegisteredOperations.end(),
dialectName, [](auto &lhs, auto &rhs) {
return lhs.compare(rhs.getDialect().getNamespace());
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index 2acbd03..64e3c5f 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -649,40 +649,38 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
auto *arrayType = llvm::ArrayType::get(elementType, numElements);
if (child->isZeroValue() && !elementType->isFPOrFPVectorTy()) {
return llvm::ConstantAggregateZero::get(arrayType);
- } else {
- if (llvm::ConstantDataSequential::isElementTypeCompatible(
- elementType)) {
- // TODO: Handle all compatible types. This code only handles integer.
- if (isa<llvm::IntegerType>(elementType)) {
- if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) {
- if (ci->getBitWidth() == 8) {
- SmallVector<int8_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
- if (ci->getBitWidth() == 16) {
- SmallVector<int16_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
- if (ci->getBitWidth() == 32) {
- SmallVector<int32_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
- if (ci->getBitWidth() == 64) {
- SmallVector<int64_t> constants(numElements, ci->getZExtValue());
- return llvm::ConstantDataArray::get(elementType->getContext(),
- constants);
- }
+ }
+ if (llvm::ConstantDataSequential::isElementTypeCompatible(elementType)) {
+ // TODO: Handle all compatible types. This code only handles integer.
+ if (isa<llvm::IntegerType>(elementType)) {
+ if (llvm::ConstantInt *ci = dyn_cast<llvm::ConstantInt>(child)) {
+ if (ci->getBitWidth() == 8) {
+ SmallVector<int8_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
+ }
+ if (ci->getBitWidth() == 16) {
+ SmallVector<int16_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
+ }
+ if (ci->getBitWidth() == 32) {
+ SmallVector<int32_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
+ }
+ if (ci->getBitWidth() == 64) {
+ SmallVector<int64_t> constants(numElements, ci->getZExtValue());
+ return llvm::ConstantDataArray::get(elementType->getContext(),
+ constants);
}
}
}
+ }
// std::vector is used here to accomodate large number of elements that
// exceed SmallVector capacity.
std::vector<llvm::Constant *> constants(numElements, child);
return llvm::ConstantArray::get(arrayType, constants);
- }
}
}
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index b88fbaa..29ed5a4 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -89,6 +89,22 @@ static bool isZeroValue(Attribute attr) {
return false;
}
+/// Move all functions declaration before functions definitions. In SPIR-V
+/// "declarations" are functions without a body and "definitions" functions
+/// with a body. This is stronger than necessary. It should be sufficient to
+/// ensure any declarations precede their uses and not all definitions, however
+/// this allows to avoid analysing every function in the module this way.
+static void moveFuncDeclarationsToTop(spirv::ModuleOp moduleOp) {
+ Block::OpListType &ops = moduleOp.getBody()->getOperations();
+ if (ops.empty())
+ return;
+ Operation &firstOp = ops.front();
+ for (Operation &op : llvm::drop_begin(ops))
+ if (auto funcOp = dyn_cast<spirv::FuncOp>(op))
+ if (funcOp.getBody().empty())
+ funcOp->moveBefore(&firstOp);
+}
+
namespace mlir {
namespace spirv {
@@ -119,6 +135,8 @@ LogicalResult Serializer::serialize() {
processMemoryModel();
processDebugInfo();
+ moveFuncDeclarationsToTop(module);
+
// Iterate over the module body to serialize it. Assumptions are that there is
// only one basic block in the moduleOp
for (auto &op : *module.getBody()) {
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
index 39c31d5..c746d76 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma-gfx950.mlir
@@ -8,46 +8,46 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
// CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: rocdl.mfma.f32.32x32x16.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
+ amdgpu.mfma 32x32x16 %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<16xf32>
// CHECK: rocdl.mfma.f32.16x16x32.f16{{.*}}: (vector<8xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
+ amdgpu.mfma 16x16x32 %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf16>, vector<8xf16>, vector<4xf32>
// CHECK: rocdl.mfma.f32.32x32x16.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
+ amdgpu.mfma 32x32x16 %arg3 * %arg3 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<16xf32>
// CHECK: rocdl.mfma.f32.16x16x32.bf16{{.*}}: (vector<8xbf16>, vector<8xbf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
+ amdgpu.mfma 16x16x32 %arg3 * %arg3 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xbf16>, vector<8xbf16>, vector<4xf32>
// CHECK: rocdl.mfma.i32.32x32x32.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- amdgpu.mfma %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
+ amdgpu.mfma 32x32x32 %arg4 * %arg4 + %arg5 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<16xi32>
// CHECK: rocdl.mfma.i32.16x16x64.i8{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- amdgpu.mfma %arg4 * %arg4 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<4xi32>
+ amdgpu.mfma 16x16x64 %arg4 * %arg4 + %arg6 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<16xi8>, vector<16xi8>, vector<4xi32>
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg7 * %arg7 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
+ amdgpu.mfma 32x32x64 %arg7 * %arg7 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[c0]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg7 * %arg7 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
+ amdgpu.mfma 16x16x128 %arg7 * %arg7 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E4M3FN>, vector<32xf8E4M3FN>, vector<4xf32>
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg8 * %arg8 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
+ amdgpu.mfma 32x32x64 %arg8 * %arg8 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c1]], %[[c1]], %[[c0]], %[[c0]]{{.*}}: (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg8 * %arg8 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
+ amdgpu.mfma 16x16x128 %arg8 * %arg8 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf8E5M2>, vector<32xf8E5M2>, vector<4xf32>
// CHECK: %[[c2:.+]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
+ amdgpu.mfma 32x32x64 %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c2]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
+ amdgpu.mfma 16x16x128 %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf6E2M3FN>, vector<4xf32>
// CHECK: %[[c3:.+]] = llvm.mlir.constant(3 : i32) : i32
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
+ amdgpu.mfma 32x32x64 %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c3]], %[[c3]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
+ amdgpu.mfma 16x16x128 %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E3M2FN>, vector<32xf6E3M2FN>, vector<4xf32>
// CHECK-DAG: %[[c4:.+]] = llvm.mlir.constant(4 : i32) : i32
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg11 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
+ amdgpu.mfma 32x32x64 %arg11 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c4]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg11 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
+ amdgpu.mfma 16x16x128 %arg11 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf4E2M1FN>, vector<32xf4E2M1FN>, vector<4xf32>
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c2]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg9 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 64 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<16xf32>
+ amdgpu.mfma 32x32x64 %arg9 * %arg11 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c2]], %[[c4]], %[[c0]], %[[c0]]{{.*}}: (vector<6xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg9 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 128 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<4xf32>
+ amdgpu.mfma 16x16x128 %arg9 * %arg11 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<32xf6E2M3FN>, vector<32xf4E2M1FN>, vector<4xf32>
func.return
}
@@ -55,50 +55,50 @@ func.func @mfma_to_rocdl(%arg0 : vector<8xf16>, %arg1 : vector<16xf32>,
// CHECK-LABEL: func @scaled_mfma_to_rocdl(
// CHECK-SAME: %[[ARG0:.*]]: vector<16xf32>, %[[ARG1:.*]]: vector<4xf32>, %[[ARG2:.*]]: vector<32xf8E4M3FN>, %[[ARG3:.*]]: vector<32xf8E5M2>, %[[ARG4:.*]]: vector<32xf6E2M3FN>, %[[ARG5:.*]]: vector<32xf6E3M2FN>, %[[ARG6:.*]]: vector<32xf4E2M1FN>, %[[ARG7:.*]]: vector<4xf8E8M0FNU>, %[[ARG8:.*]]: f8E8M0FNU
func.func @scaled_mfma_to_rocdl(%arg0 : vector<16xf32>,
- %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
- %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
- %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
- %arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {
-
+ %arg1 : vector<4xf32>, %arg2 : vector<32xf8E4M3FN>,
+ %arg3 : vector<32xf8E5M2>, %arg4 : vector<32xf6E2M3FN>,
+ %arg5 : vector<32xf6E3M2FN>, %arg6 : vector<32xf4E2M1FN>,
+ %arg7 : vector<4xf8E8M0FNU>, %arg8 : f8E8M0FNU) {
+
// CHECK: %[[c0:.+]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[c1:.+]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[b0:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
// CHECK: %[[z0:.+]] = llvm.zext {{.*}} : i8 to i32
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
+ amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- amdgpu.scaled_mfma(%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>
-
+ amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg2) * (%arg8[1] * %arg2) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf8E4M3FN>, f8E8M0FNU, vector<32xf8E4M3FN>, vector<4xf32>
+
// CHECK: llvm.bitcast
-
+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
+ amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<8xi32>, vector<8xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- amdgpu.scaled_mfma(%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>
-
+ amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg3) * (%arg8[1] * %arg3) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf8E5M2>, f8E8M0FNU, vector<32xf8E5M2>, vector<4xf32>
+
// CHECK: llvm.bitcast
-
+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
+ amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- amdgpu.scaled_mfma(%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>
-
+ amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg4) * (%arg8[1] * %arg4) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<4xf32>
+
// CHECK: llvm.bitcast
// CHECK: llvm.mlir.constant(3 : i32) : i32
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
+ amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<6xi32>, vector<6xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- amdgpu.scaled_mfma(%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>
-
+ amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg5) * (%arg8[1] * %arg5) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf6E3M2FN>, f8E8M0FNU, vector<32xf6E3M2FN>, vector<4xf32>
+
// CHECK: llvm.bitcast
// CHECK: llvm.mlir.constant(4 : i32) : i32
-
+
// CHECK: rocdl.mfma.scale.f32.32x32x64.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32, i32, i32, i32) -> vector<16xf32>
- amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
+ amdgpu.scaled_mfma 32x32x64 (%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<16xf32>
// CHECK: rocdl.mfma.scale.f32.16x16x128.f8f6f4{{.*}}, %[[c0]], %[[b0]], %[[c1]], %[[z0]] : (vector<4xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32, i32, i32, i32) -> vector<4xf32>
- amdgpu.scaled_mfma(%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 { k = 128 : i32, m = 16 : i32, n = 16 : i32 } : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>
+ amdgpu.scaled_mfma 16x16x128 (%arg7[0] * %arg6) * (%arg8[1] * %arg6) + %arg1 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>
func.return
}
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
index 52db142..e292d98 100644
--- a/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
+++ b/mlir/test/Conversion/AMDGPUToROCDL/mfma.mlir
@@ -9,89 +9,89 @@ func.func @mfma_to_rocdl(%arg0 : f32, %arg1 : vector<32xf32>,
%arg14 : vector<2xf32>, %arg15 : vector<8xf8E5M2FNUZ>,
%arg16 : vector<8xf8E4M3FNUZ>) {
// CHECK: rocdl.mfma.f32.32x32x1f32{{.*}}: (f32, f32, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
- amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : f32, f32, vector<32xf32>
+ amdgpu.mfma 32x32x1 %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, blocks = 2 : i32 } blgp = none : f32, f32, vector<32xf32>
// CHECK: rocdl.mfma.f32.16x16x1f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : f32, f32, vector<16xf32>
+ amdgpu.mfma 16x16x1 %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : f32, f32, vector<16xf32>
// CHECK: rocdl.mfma.f32.4x4x1f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 1 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : f32, f32, vector<4xf32>
+ amdgpu.mfma 4x4x1 %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, blocks = 16 : i32 } blgp = none : f32, f32, vector<4xf32>
// CHECK: rocdl.mfma.f32.32x32x2f32{{.*}}: (f32, f32, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : f32, f32, vector<16xf32>
+ amdgpu.mfma 32x32x2 %arg0 * %arg0 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f32, f32, vector<16xf32>
// CHECK: rocdl.mfma.f32.16x16x4f32{{.*}}: (f32, f32, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f32, f32, vector<4xf32>
+ amdgpu.mfma 16x16x4 %arg0 * %arg0 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f32, f32, vector<4xf32>
// CHECK: rocdl.mfma.f32.32x32x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
- amdgpu.mfma %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<32xf32>
+ amdgpu.mfma 32x32x4 %arg4 * %arg4 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, blocks = 2 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<32xf32>
// CHECK: rocdl.mfma.f32.16x16x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
+ amdgpu.mfma 16x16x4 %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
// CHECK: rocdl.mfma.f32.4x4x4f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+ amdgpu.mfma 4x4x4 %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, blocks = 16 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: rocdl.mfma.f32.32x32x8f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
+ amdgpu.mfma 32x32x8 %arg4 * %arg4 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<16xf32>
// CHECK: rocdl.mfma.f32.16x16x16f16{{.*}}: (vector<4xf16>, vector<4xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+ amdgpu.mfma 16x16x16 %arg4 * %arg4 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
// CHECK: %[[BITCAST_4xi8_i32:.+]] = llvm.bitcast {{.*}} : vector<4xi8> to i32
// CHECK: rocdl.mfma.i32.32x32x4i8 %[[BITCAST_4xi8_i32]], %[[BITCAST_4xi8_i32]], {{.*}}: (i32, i32, vector<32xi32>, i32, i32, i32) -> vector<32xi32>
- amdgpu.mfma %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<32xi32>
+ amdgpu.mfma 32x32x4 %arg5 * %arg5 + %arg6 { abid = 0 : i32, cbsz = 0 : i32, blocks = 2 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<32xi32>
// CHECK: rocdl.mfma.i32.16x16x4i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
+ amdgpu.mfma 16x16x4 %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
// CHECK: rocdl.mfma.i32.4x4x4i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
+ amdgpu.mfma 4x4x4 %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, blocks = 16 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
// CHECK: rocdl.mfma.i32.32x32x8i8{{.*}}: (i32, i32, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- amdgpu.mfma %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
+ amdgpu.mfma 32x32x8 %arg5 * %arg5 + %arg7 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<16xi32>
// CHECK: rocdl.mfma.i32.16x16x16i8{{.*}}: (i32, i32, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- amdgpu.mfma %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
+ amdgpu.mfma 16x16x16 %arg5 * %arg5 + %arg8 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xi8>, vector<4xi8>, vector<4xi32>
// CHECK: %[[BITCAST_2xbf16_2xi16:.+]] = llvm.bitcast {{.*}} : vector<2xbf16> to vector<2xi16>
// CHECK: rocdl.mfma.f32.32x32x2bf16 %[[BITCAST_2xbf16_2xi16]], %[[BITCAST_2xbf16_2xi16]], %{{.*}}: (vector<2xi16>, vector<2xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
- amdgpu.mfma %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<32xf32>
+ amdgpu.mfma 32x32x2 %arg9 * %arg9 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, blocks = 2 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<32xf32>
// CHECK: rocdl.mfma.f32.16x16x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
+ amdgpu.mfma 16x16x2 %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
// CHECK: rocdl.mfma.f32.4x4x2bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 2 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
+ amdgpu.mfma 4x4x2 %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, blocks = 16 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
// CHECK: rocdl.mfma.f32.32x32x4bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
+ amdgpu.mfma 32x32x4 %arg9 * %arg9 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<16xf32>
// CHECK: rocdl.mfma.f32.16x16x8bf16{{.*}}: (vector<2xi16>, vector<2xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
+ amdgpu.mfma 16x16x8 %arg9 * %arg9 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<2xbf16>, vector<2xbf16>, vector<4xf32>
// CHECK: %[[BITCAST_4xbf16_4xi16:.+]] = llvm.bitcast {{.*}} : vector<4xbf16> to vector<4xi16>
// CHECK: rocdl.mfma.f32.32x32x4bf16.1k %[[BITCAST_4xbf16_4xi16]], %[[BITCAST_4xbf16_4xi16]], {{.*}}: (vector<4xi16>, vector<4xi16>, vector<32xf32>, i32, i32, i32) -> vector<32xf32>
- amdgpu.mfma %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<32xf32>
+ amdgpu.mfma 32x32x4 %arg10 * %arg10 + %arg1 { abid = 0 : i32, cbsz = 0 : i32, blocks = 2 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<32xf32>
// CHECK: rocdl.mfma.f32.16x16x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
+ amdgpu.mfma 16x16x4 %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
// CHECK: rocdl.mfma.f32.4x4x4bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 16 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
+ amdgpu.mfma 4x4x4 %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, blocks = 16 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
// CHECK: rocdl.mfma.f32.32x32x8bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
+ amdgpu.mfma 32x32x8 %arg10 * %arg10 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<16xf32>
// CHECK: rocdl.mfma.f32.16x16x16bf16.1k{{.*}}: (vector<4xi16>, vector<4xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
+ amdgpu.mfma 16x16x16 %arg10 * %arg10 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xbf16>, vector<4xbf16>, vector<4xf32>
// CHECK: rocdl.mfma.f64.16x16x4f64{{.*}}: (f64, f64, vector<4xf64>, i32, i32, i32) -> vector<4xf64>
- amdgpu.mfma %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : f64, f64, vector<4xf64>
+ amdgpu.mfma 16x16x4 %arg11 * %arg11 + %arg12 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f64, f64, vector<4xf64>
// CHECK: rocdl.mfma.f64.4x4x4f64{{.*}}: (f64, f64, f64, i32, i32, i32) -> f64
- amdgpu.mfma %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 4 : i32, n = 4 : i32, blocks = 4 : i32 } blgp = none : f64, f64, f64
+ amdgpu.mfma 4x4x4 %arg11 * %arg11 + %arg11 { abid = 0 : i32, cbsz = 0 : i32, blocks = 4 : i32 } blgp = none : f64, f64, f64
// CHECK: %[[BITCAST_8xi8_i64:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
// CHECK: rocdl.mfma.i32.16x16x32.i8 %[[BITCAST_8xi8_i64]], %[[BITCAST_8xi8_i64]], {{.*}}: (i64, i64, vector<4xi32>, i32, i32, i32) -> vector<4xi32>
- amdgpu.mfma %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
+ amdgpu.mfma 16x16x32 %arg13 * %arg13 + %arg8 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
// CHECK: rocdl.mfma.i32.32x32x16.i8{{.*}}: (i64, i64, vector<16xi32>, i32, i32, i32) -> vector<16xi32>
- amdgpu.mfma %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<16xi32>
+ amdgpu.mfma 32x32x16 %arg13 * %arg13 + %arg7 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xi8>, vector<8xi8>, vector<16xi32>
// CHECK: rocdl.mfma.f32.16x16x8.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 8 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<4xf32>
+ amdgpu.mfma 16x16x8 %arg14 * %arg14 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<4xf32>
// CHECK: rocdl.mfma.f32.32x32x4.xf32{{.*}}: (vector<2xf32>, vector<2xf32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 4 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<16xf32>
+ amdgpu.mfma 32x32x4 %arg14 * %arg14 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, reducePrecision } blgp = none : vector<2xf32>, vector<2xf32>, vector<16xf32>
// CHECK: %[[BITCAST_8xi8_i64_1:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
// CHECK: rocdl.mfma.f32.16x16x32.bf8.bf8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_1]], {{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg15 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
+ amdgpu.mfma 16x16x32 %arg15 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
// CHECK: %[[BITCAST_8xi8_i64_2:.+]] = llvm.bitcast {{.*}} : vector<8xi8> to i64
// CHECK: rocdl.mfma.f32.16x16x32.bf8.fp8 %[[BITCAST_8xi8_i64_1]], %[[BITCAST_8xi8_i64_2]], {{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg15 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
+ amdgpu.mfma 16x16x32 %arg15 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
// CHECK: rocdl.mfma.f32.16x16x32.fp8.bf8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg16 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
+ amdgpu.mfma 16x16x32 %arg16 * %arg15 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<4xf32>
// CHECK: rocdl.mfma.f32.16x16x32.fp8.fp8{{.*}}: (i64, i64, vector<4xf32>, i32, i32, i32) -> vector<4xf32>
- amdgpu.mfma %arg16 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
+ amdgpu.mfma 16x16x32 %arg16 * %arg16 + %arg3 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<4xf32>
// CHECK: rocdl.mfma.f32.32x32x16.bf8.bf8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg15 * %arg15 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<16xf32>
+ amdgpu.mfma 32x32x16 %arg15 * %arg15 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E5M2FNUZ>, vector<16xf32>
// CHECK: rocdl.mfma.f32.32x32x16.bf8.fp8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg15 * %arg16 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
+ amdgpu.mfma 32x32x16 %arg15 * %arg16 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E5M2FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
// CHECK: rocdl.mfma.f32.32x32x16.fp8.bf8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg16 * %arg15 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<16xf32>
+ amdgpu.mfma 32x32x16 %arg16 * %arg15 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E5M2FNUZ>, vector<16xf32>
// CHECK: rocdl.mfma.f32.32x32x16.fp8.fp8{{.*}}: (i64, i64, vector<16xf32>, i32, i32, i32) -> vector<16xf32>
- amdgpu.mfma %arg16 * %arg16 + %arg2 { abid = 0 : i32, cbsz = 0 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32, blocks = 1 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
+ amdgpu.mfma 32x32x16 %arg16 * %arg16 + %arg2 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<8xf8E4M3FNUZ>, vector<8xf8E4M3FNUZ>, vector<16xf32>
func.return
}
diff --git a/mlir/test/Dialect/AMDGPU/canonicalize.mlir b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
index 52d3275..fee0c00 100644
--- a/mlir/test/Dialect/AMDGPU/canonicalize.mlir
+++ b/mlir/test/Dialect/AMDGPU/canonicalize.mlir
@@ -165,10 +165,10 @@ func.func @fold_gather_to_lds_of_cast_dest(%global: memref<128x72xf32, 1>, %lds:
// CHECK-LABEL: func @scaled_mfma
// CHECK: %[[SCALE_1:.*]] = vector.extract_strided_slice %0 {offsets = [0], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: %[[SCALE_2:.*]] = vector.extract_strided_slice %2 {offsets = [4], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
-// CHECK: amdgpu.scaled_mfma(%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%[[SCALE_1]][3] * %{{.*}}) * (%[[SCALE_2]][2] * %{{.*}}) {{.*}}
// CHECK: %[[SCALE_3:.*]] = vector.extract_strided_slice %5 {offsets = [8], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
// CHECK: %[[SCALE_4:.*]] = vector.extract_strided_slice %7 {offsets = [12], sizes = [4], strides = [1]} : vector<16xf8E8M0FNU> to vector<4xf8E8M0FNU>
-// CHECK: amdgpu.scaled_mfma(%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%[[SCALE_3]][1] * %{{.*}}) * (%[[SCALE_4]][0] * %{{.*}}) {{.*}}
func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2x1x8x1xf8E8M0FNU>, %scalesB: vector<2x1x8x1xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
@@ -176,12 +176,12 @@ func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %sc
%sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleB = vector.extract %scalesB[0, 0, 6, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
- %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ %res_0 = amdgpu.scaled_mfma 16x16x128 (%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
%scaleC = vector.extract %scalesA[1, 0, 1, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sC = vector.insert %scaleC, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleD = vector.extract %scalesB[1, 0, 4, 0] : f8E8M0FNU from vector<2x1x8x1xf8E8M0FNU>
%sD = vector.insert %scaleD, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
- %res_1 = amdgpu.scaled_mfma(%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ %res_1 = amdgpu.scaled_mfma 16x16x128 (%sC[0] * %opA) * (%sD[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_0, %res_1 : vector<4xf32>, vector<4xf32>
}
@@ -192,7 +192,7 @@ func.func @scaled_mfma(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %sc
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
// CHECK: vector.extract {{.*}} : f8E8M0FNU from vector<2xf8E8M0FNU>
// CHECK: vector.insert {{.*}} : f8E8M0FNU into vector<4xf8E8M0FNU>
-// CHECK: amdgpu.scaled_mfma({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
+// CHECK: amdgpu.scaled_mfma 16x16x128 ({{.*}}[0] * {{.*}}) * ({{.*}}[0] * {{.*}}
func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<2xf8E8M0FNU>, %scalesB: vector<2xf8E8M0FNU>) -> vector<4xf32> {
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
@@ -200,17 +200,17 @@ func.func @scaled_mfma_less_than_4(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
%sA = vector.insert %scaleA, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%scaleB = vector.extract %scalesB[1] : f8E8M0FNU from vector<2xf8E8M0FNU>
%sB = vector.insert %scaleB, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
- %res_0 = amdgpu.scaled_mfma(%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ %res_0 = amdgpu.scaled_mfma 16x16x128 (%sA[0] * %opA) * (%sB[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_0 : vector<4xf32>
}
// -----
// CHECK-LABEL: func @scaled_mfma_ugly_shapes
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
-// CHECK: amdgpu.scaled_mfma(%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[0] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[1] * %{{.*}}) * (%{{.*}}[3] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[2] * %{{.*}}) * (%{{.*}}[2] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+// CHECK: amdgpu.scaled_mfma 16x16x128 (%{{.*}}[3] * %{{.*}}) * (%{{.*}}[1] * %arg1) + %cst : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4E2M1FN>, %scalesA: vector<5x5xf8E8M0FNU>, %scalesB: vector<7x23xf8E8M0FNU>) -> (vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) {
%cst_0 = arith.constant dense<0.000000e+00> : vector<4xf32>
%cst_1 = arith.constant dense<5.877470e-39> : vector<4xf8E8M0FNU>
@@ -237,10 +237,10 @@ func.func @scaled_mfma_ugly_shapes(%opA: vector<32xf4E2M1FN>, %opB: vector<32xf4
%sB_6_21 = vector.insert %scaleB_6_21, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_20 = vector.insert %scaleB_6_20, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
%sB_6_19 = vector.insert %scaleB_6_19, %cst_1 [0] : f8E8M0FNU into vector<4xf8E8M0FNU>
-
- %res_4 = amdgpu.scaled_mfma(%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
- %res_5 = amdgpu.scaled_mfma(%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
- %res_6 = amdgpu.scaled_mfma(%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
- %res_7 = amdgpu.scaled_mfma(%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 {k = 128 : i32, m = 16 : i32, n = 16 : i32} : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+
+ %res_4 = amdgpu.scaled_mfma 16x16x128 (%sA_0_4[0] * %opA) * (%sB_6_22[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ %res_5 = amdgpu.scaled_mfma 16x16x128 (%sA_0_5[0] * %opA) * (%sB_6_21[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ %res_6 = amdgpu.scaled_mfma 16x16x128 (%sA_0_6[0] * %opA) * (%sB_6_20[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
+ %res_7 = amdgpu.scaled_mfma 16x16x128 (%sA_0_7[0] * %opA) * (%sB_6_19[0] * %opB) + %cst_0 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf32>
return %res_4, %res_5, %res_6, %res_7 : vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>
}
diff --git a/mlir/test/Dialect/AMDGPU/invalid.mlir b/mlir/test/Dialect/AMDGPU/invalid.mlir
index 6a2518a..5784764 100644
--- a/mlir/test/Dialect/AMDGPU/invalid.mlir
+++ b/mlir/test/Dialect/AMDGPU/invalid.mlir
@@ -19,9 +19,7 @@ func.func @mixing_packed_stoch_round_types(%arg0: f32, %arg1: i32, %arg2: vector
func.func @bad_source_types(%a: vector<2xf32>, %b: vector<4xf16>,
%c: vector<32xf32>) -> vector<32xf32> {
// expected-error@+1 {{'amdgpu.mfma' op expected both non-small-float source operand types to match exactly}}
- %d = amdgpu.mfma %a * %b + %c {
- m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
- abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<4xf16>, vector<32xf32>
+ %d = amdgpu.mfma 32x32x1 %a * %b + %c { blocks = 2 : i32, abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<4xf16>, vector<32xf32>
func.return %d : vector<32xf32>
}
@@ -30,9 +28,7 @@ func.func @bad_source_types(%a: vector<2xf32>, %b: vector<4xf16>,
func.func @bad_source_types_f8(%a: vector<8xf8E5M2FNUZ>, %b: vector<8xi8>,
%c: vector<32xf32>) -> vector<32xf32> {
// expected-error@+1 {{'amdgpu.mfma' op expected both source operands to have small-float elements if one does}}
- %d = amdgpu.mfma %a * %b + %c {
- m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
- abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xf8E5M2FNUZ>, vector<8xi8>, vector<32xf32>
+ %d = amdgpu.mfma 32x32x1 %a * %b + %c { blocks = 2 : i32, abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xf8E5M2FNUZ>, vector<8xi8>, vector<32xf32>
func.return %d : vector<32xf32>
}
@@ -41,9 +37,7 @@ func.func @bad_source_types_f8(%a: vector<8xf8E5M2FNUZ>, %b: vector<8xi8>,
func.func @bad_source_arguments(%a: vector<2xf32>, %b: vector<2xf32>,
%c: vector<32xf32>) -> vector<32xf32> {
// expected-error@+1 {{'amdgpu.mfma' op expected 1 source values for this operation but got 2}}
- %d = amdgpu.mfma %a * %b + %c {
- m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
- abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<2xf32>, vector<32xf32>
+ %d = amdgpu.mfma 32x32x1 %a * %b + %c { blocks = 2 : i32, abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<2xf32>, vector<2xf32>, vector<32xf32>
func.return %d : vector<32xf32>
}
@@ -52,9 +46,7 @@ func.func @bad_source_arguments(%a: vector<2xf32>, %b: vector<2xf32>,
func.func @bad_source_arguments_i8(%a: vector<8xi8>, %b: vector<8xi8>,
%c: vector<4xi32>) -> vector<4xi32> {
// expected-error@+1 {{'amdgpu.mfma' op expected 4 source values for this operation but got 8}}
- %d = amdgpu.mfma %a * %b + %c {
- m = 32 : i32, n = 32 : i32, k = 4 : i32, blocks = 2 : i32,
- abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
+ %d = amdgpu.mfma 32x32x4 %a * %b + %c { blocks = 2 : i32, abid = 0 : i32, cbsz = 0 : i32} blgp = none : vector<8xi8>, vector<8xi8>, vector<4xi32>
func.return %d : vector<4xi32>
}
@@ -62,9 +54,7 @@ func.func @bad_source_arguments_i8(%a: vector<8xi8>, %b: vector<8xi8>,
func.func @bad_dest_type(%a: f32, %b: f32, %c: vector<16xf32>) -> vector<16xf32> {
// expected-error@+1 {{'amdgpu.mfma' op expected 32 result values for this operation but got 16}}
- %d = amdgpu.mfma %a * %b + %c {
- m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
- abid = 0 : i32, cbsz = 0 : i32} blgp = none : f32, f32, vector<16xf32>
+ %d = amdgpu.mfma 32x32x1 %a * %b + %c { blocks = 2 : i32, abid = 0 : i32, cbsz = 0 : i32} blgp = none : f32, f32, vector<16xf32>
return %d : vector<16xf32>
}
@@ -72,9 +62,7 @@ func.func @bad_dest_type(%a: f32, %b: f32, %c: vector<16xf32>) -> vector<16xf32>
func.func @f64_permuting_b(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64> {
// expected-error@+1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of B}}
- %d = amdgpu.mfma %a * %b + %c {
- m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32,
- abid = 0 : i32, cbsz = 0 : i32} blgp = bcast_first_32 : f64, f64, vector<4xf64>
+ %d = amdgpu.mfma 16x16x4 %a * %b + %c { abid = 0 : i32, cbsz = 0 : i32} blgp = bcast_first_32 : f64, f64, vector<4xf64>
return %d : vector<4xf64>
}
@@ -82,9 +70,7 @@ func.func @f64_permuting_b(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64>
func.func @f64_permuting_a(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64> {
// expected-error@+1 {{'amdgpu.mfma' op double-precision ops do not support permuting lanes of A}}
- %d = amdgpu.mfma %a * %b + %c {
- m = 16 : i32, n = 16 : i32, k = 4 : i32, blocks = 1 : i32,
- abid = 0 : i32, cbsz = 1 : i32} blgp = none : f64, f64, vector<4xf64>
+ %d = amdgpu.mfma 16x16x4 %a * %b + %c { abid = 0 : i32, cbsz = 1 : i32} blgp = none : f64, f64, vector<4xf64>
return %d : vector<4xf64>
}
@@ -92,9 +78,7 @@ func.func @f64_permuting_a(%a: f64, %b: f64, %c: vector<4xf64>) -> vector<4xf64>
func.func @abid_without_bradcast(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
// expected-error@+1 {{'amdgpu.mfma' op block ID for permuting A (abid) must be below 2 ** cbsz}}
- %d = amdgpu.mfma %a * %b + %c {
- m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
- abid = 1 : i32, cbsz = 0 : i32} blgp = none : f32, f32, vector<32xf32>
+ %d = amdgpu.mfma 32x32x1 %a * %b + %c { blocks = 2 : i32, abid = 1 : i32, cbsz = 0 : i32} blgp = none : f32, f32, vector<32xf32>
func.return %d : vector<32xf32>
}
@@ -102,9 +86,7 @@ func.func @abid_without_bradcast(%a: f32, %b: f32, %c: vector<32xf32>) -> vector
func.func @abid_too_large(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
// expected-error@+1 {{'amdgpu.mfma' op block ID for permuting A (abid) must be below 2 ** cbsz}}
- %d = amdgpu.mfma %a * %b + %c {
- m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
- abid = 2 : i32, cbsz = 1 : i32} blgp = none : f32, f32, vector<32xf32>
+ %d = amdgpu.mfma 32x32x1 %a * %b + %c { blocks = 2 : i32, abid = 2 : i32, cbsz = 1 : i32} blgp = none : f32, f32, vector<32xf32>
func.return %d : vector<32xf32>
}
@@ -112,9 +94,39 @@ func.func @abid_too_large(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32
func.func @no_negation(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
// expected-error@+1 {{'amdgpu.mfma' op negation flags only available for double-precision operations}}
- %d = amdgpu.mfma %a * %b + %c {
- m = 32 : i32, n = 32 : i32, k = 1 : i32, blocks = 2 : i32,
- abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, f32, vector<32xf32>
+ %d = amdgpu.mfma 32x32x1 %a * %b + %c { blocks = 2 : i32, abid = 0 : i32, cbsz = 0 : i32, negateA} blgp = none : f32, f32, vector<32xf32>
+ func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @mfma_invalid_m(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
+ // expected-error@+1 {{'amdgpu.mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32}}}
+ %d = amdgpu.mfma 7x32x1 %a * %b + %c { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f32, f32, vector<32xf32>
+ func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @mfma_invalid_n(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
+ // expected-error@+1 {{'amdgpu.mfma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {4, 16, 32}}}
+ %d = amdgpu.mfma 32x7x1 %a * %b + %c { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f32, f32, vector<32xf32>
+ func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @mfma_invalid_k(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
+ // expected-error@+1 {{'amdgpu.mfma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {1, 2, 4, 8, 16, 32, 64, 128}}}
+ %d = amdgpu.mfma 32x32x3 %a * %b + %c { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f32, f32, vector<32xf32>
+ func.return %d : vector<32xf32>
+}
+
+// -----
+
+func.func @mfma_invalid_blocks(%a: f32, %b: f32, %c: vector<32xf32>) -> vector<32xf32> {
+ // expected-error@+1 {{'amdgpu.mfma' op attribute 'blocks' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {1, 2, 4, 16}}}
+ %d = amdgpu.mfma 32x32x1 %a * %b + %c { blocks = 7 : i32, abid = 0 : i32, cbsz = 0 : i32 } blgp = none : f32, f32, vector<32xf32>
func.return %d : vector<32xf32>
}
@@ -302,3 +314,27 @@ func.func @amdgpu.scaled_ext_packed816_invalid_input_output_sizes(%v: vector<8xf
%ret0 = amdgpu.scaled_ext_packed816 %v scale(%scale) blockSize(16) firstScaleLane(0) firstScaleByte(0) : vector<8xf8E5M2>, vector<4xf8E8M0FNU> -> vector<16xf16>
func.return
}
+
+// -----
+
+func.func @scaled_mfma_invalid_m(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
+ // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'm' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+ %0 = amdgpu.scaled_mfma 8x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
+ func.return %0 : vector<16xf32>
+}
+
+// -----
+
+func.func @scaled_mfma_invalid_n(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
+ // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'n' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {16, 32}}}
+ %0 = amdgpu.scaled_mfma 32x8x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
+ func.return %0 : vector<16xf32>
+}
+
+// -----
+
+func.func @scaled_mfma_invalid_k(%arg0 : vector<4xf8E8M0FNU>, %arg1 : vector<32xf4E2M1FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
+ // expected-error@+1 {{'amdgpu.scaled_mfma' op attribute 'k' failed to satisfy constraint: 32-bit signless integer attribute whose value is one of {64, 128}}}
+ %0 = amdgpu.scaled_mfma 32x32x32 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<4xf8E8M0FNU>, vector<32xf4E2M1FN>, vector<16xf32>
+ func.return %0 : vector<16xf32>
+}
diff --git a/mlir/test/Dialect/AMDGPU/ops.mlir b/mlir/test/Dialect/AMDGPU/ops.mlir
index a185eb6..a330967 100644
--- a/mlir/test/Dialect/AMDGPU/ops.mlir
+++ b/mlir/test/Dialect/AMDGPU/ops.mlir
@@ -559,9 +559,16 @@ func.func @sched_barrier() {
}
// CHECK-LABEL: func @mfma
-func.func @mfma(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> {
- // CHECK: amdgpu.mfma
- %0 = amdgpu.mfma %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, k = 1 : i32, m = 32 : i32, n = 32 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, f32, vector<32xf32>
+func.func @mfma(%arg0 : vector<4xf16>, %arg1 : vector<4xf32>) -> vector<4xf32> {
+ // CHECK: amdgpu.mfma 16x16x16
+ %0 = amdgpu.mfma 16x16x16 %arg0 * %arg0 + %arg1 { abid = 0 : i32, cbsz = 0 : i32 } blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32>
+ func.return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @mfma_with_blocks
+func.func @mfma_with_blocks(%arg0 : f32, %arg1 : vector<32xf32>) -> vector<32xf32> {
+ // CHECK: amdgpu.mfma 32x32x1
+ %0 = amdgpu.mfma 32x32x1 %arg0 * %arg0 + %arg1 { abid = 1 : i32, cbsz = 1 : i32, blocks = 2 : i32 } blgp = bcast_second_32 : f32, f32, vector<32xf32>
func.return %0 : vector<32xf32>
}
@@ -602,8 +609,8 @@ func.func @permlane32_swap(%arg0 : f32) -> f32 {
// CHECK-LABEL: func @scaled_mfma
func.func @scaled_mfma(%arg0 : f8E8M0FNU, %arg1 : vector<32xf6E2M3FN>, %arg2 : vector<16xf32>) -> vector<16xf32> {
- // CHECK: amdgpu.scaled_mfma
- %0 = amdgpu.scaled_mfma(%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 { k = 64 : i32, m = 32 : i32, n = 32 : i32 } : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
+ // CHECK: amdgpu.scaled_mfma 32x32x64
+ %0 = amdgpu.scaled_mfma 32x32x64 (%arg0[0] * %arg1) * (%arg0[1] * %arg1) + %arg2 : f8E8M0FNU, vector<32xf6E2M3FN>, f8E8M0FNU, vector<32xf6E2M3FN>, vector<16xf32>
func.return %0 : vector<16xf32>
}
diff --git a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
index ac1f22b..f9b81df 100644
--- a/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
+++ b/mlir/test/Dialect/MemRef/value-bounds-op-interface-impl.mlir
@@ -67,11 +67,11 @@ func.func @memref_dim_all_positive(%m: memref<?xf32>, %x: index) {
// CHECK-SAME: %[[m:[a-zA-Z0-9]+]]: memref<?xf32>
// CHECK-SAME: %[[sz:[a-zA-Z0-9]+]]: index
// CHECK: %[[c4:.*]] = arith.constant 4 : index
-// CHECK: return %[[sz]], %[[c4]]
+// CHECK: return %[[c4]], %[[sz]]
func.func @memref_expand(%m: memref<?xf32>, %sz: index) -> (index, index) {
- %0 = memref.expand_shape %m [[0, 1]] output_shape [%sz, 4]: memref<?xf32> into memref<?x4xf32>
- %1 = "test.reify_bound"(%0) {dim = 0} : (memref<?x4xf32>) -> (index)
- %2 = "test.reify_bound"(%0) {dim = 1} : (memref<?x4xf32>) -> (index)
+ %0 = memref.expand_shape %m [[0, 1]] output_shape [4, %sz]: memref<?xf32> into memref<4x?xf32>
+ %1 = "test.reify_bound"(%0) {dim = 0} : (memref<4x?xf32>) -> (index)
+ %2 = "test.reify_bound"(%0) {dim = 1} : (memref<4x?xf32>) -> (index)
return %1, %2 : index, index
}
diff --git a/mlir/test/Dialect/SCF/parallel-loop-unroll.mlir b/mlir/test/Dialect/SCF/parallel-loop-unroll.mlir
new file mode 100644
index 0000000..12b502e
--- /dev/null
+++ b/mlir/test/Dialect/SCF/parallel-loop-unroll.mlir
@@ -0,0 +1,171 @@
+// RUN: mlir-opt %s -test-parallel-loop-unrolling='unroll-factors=1,2' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-parallel-loop-unrolling='unroll-factors=1,2 loop-depth=1' -split-input-file | FileCheck %s --check-prefix CHECK-UNROLL-INNER
+// RUN: mlir-opt %s -test-parallel-loop-unrolling='unroll-factors=3,1' -split-input-file | FileCheck %s --check-prefix CHECK-UNROLL-BY-3
+
+func.func @unroll_simple_parallel_loop(%src: memref<1x16x12xf32>, %dst: memref<1x16x12xf32>) {
+ %c12 = arith.constant 12 : index
+ %c16 = arith.constant 16 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c1, %c16, %c12) step (%c1, %c1, %c1) {
+ %read = memref.load %src[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+ memref.store %read, %dst[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+ scf.reduce
+ }
+ return
+}
+
+// CHECK-LABEL: func @unroll_simple_parallel_loop
+// CHECK-SAME: ([[ARG0:%.*]]: memref<1x16x12xf32>, [[ARG1:%.*]]: memref<1x16x12xf32>)
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG: [[C12:%.*]] = arith.constant 12 : index
+// CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+// CHECK: scf.parallel ([[IV0:%.*]], [[IV1:%.*]], [[IV2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[C1]], [[C16]], [[C12]]) step ([[C1]], [[C1]], [[C2]])
+// CHECK: [[LOADED1:%.*]] = memref.load [[ARG0]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK: memref.store [[LOADED1]], [[ARG1]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK: [[UNR_IV2:%.*]] = affine.apply {{.*}}([[IV2]])
+// CHECK: [[LOADED2:%.*]] = memref.load [[ARG0]][[[IV0]], [[IV1]], [[UNR_IV2]]] : memref<1x16x12xf32>
+// CHECK: memref.store [[LOADED2]], [[ARG1]][[[IV0]], [[IV1]], [[UNR_IV2]]] : memref<1x16x12xf32>
+
+// -----
+
+func.func @negative_unroll_factors_dont_divide_evenly(%src: memref<1x16x12xf32>, %dst: memref<1x16x12xf32>) {
+ %c12 = arith.constant 12 : index
+ %c16 = arith.constant 16 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c1, %c16, %c12) step (%c1, %c1, %c1) {
+ %read = memref.load %src[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+ memref.store %read, %dst[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+ scf.reduce
+ }
+ return
+}
+
+// CHECK-UNROLL-BY-3-LABEL: func @negative_unroll_factors_dont_divide_evenly
+// CHECK-UNROLL-BY-3-SAME: ([[ARG0:%.*]]: memref<1x16x12xf32>, [[ARG1:%.*]]: memref<1x16x12xf32>)
+// CHECK-UNROLL-BY-3: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-UNROLL-BY-3: scf.parallel ([[IV0:%.*]], [[IV1:%.*]], [[IV2:%.*]]) = {{.*}} step ([[C1]], [[C1]], [[C1]])
+// CHECK-UNROLL-BY-3: [[LOADED:%.*]] = memref.load [[ARG0]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK-UNROLL-BY-3: memref.store [[LOADED]], [[ARG1]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK-UNROLL-BY-3-NOT: affine.apply
+// CHECK-UNROLL-BY-3-NOT: memref.load
+// CHECK-UNROLL-BY-3-NOT: memref.store
+
+// -----
+
+func.func @unroll_outer_nested_parallel_loop(%src: memref<5x16x12x4x4xf32>, %dst: memref<5x16x12x4x4xf32>) {
+ %c4 = arith.constant 4 : index
+ %c12 = arith.constant 12 : index
+ %c16 = arith.constant 16 : index
+ %c5 = arith.constant 5 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%c5, %c16, %c12) step (%c1, %c1, %c1) {
+ scf.parallel (%arg6, %arg7) = (%c0, %c0) to (%c4, %c4) step (%c1, %c1) {
+ %0 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg4, %arg6)
+ %1 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg5, %arg7)
+ %subv_in = memref.subview %src[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+ %subv_out = memref.subview %dst[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+ linalg.erf ins(%subv_in : memref<4x4xf32, strided<[4, 1], offset: ?>>) outs(%subv_out : memref<4x4xf32, strided<[4, 1], offset: ?>>)
+ scf.reduce
+ }
+ scf.reduce
+ }
+ return
+}
+
+// CHECK-UNROLL-BY-3-LABEL: func @unroll_outer_nested_parallel_loop
+// CHECK-LABEL: func @unroll_outer_nested_parallel_loop
+// CHECK-SAME: ([[ARG0:%.*]]: memref<5x16x12x4x4xf32>, [[ARG1:%.*]]: memref<5x16x12x4x4xf32>)
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-DAG: [[C4:%.*]] = arith.constant 4 : index
+// CHECK-DAG: [[C5:%.*]] = arith.constant 5 : index
+// CHECK-DAG: [[C12:%.*]] = arith.constant 12 : index
+// CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+// CHECK: scf.parallel ([[OUTV0:%.*]], [[OUTV1:%.*]], [[OUTV2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[C5]], [[C16]], [[C12]]) step ([[C1]], [[C1]], [[C2]])
+// CHECK: scf.parallel ([[INV0:%.*]], [[INV1:%.*]]) = ([[C0]], [[C0]]) to ([[C4]], [[C4]]) step ([[C1]], [[C1]])
+// CHECK: affine.apply {{.*}}([[OUTV1]], [[INV0]])
+// CHECK: affine.apply {{.*}}([[OUTV2]], [[INV1]])
+// CHECK: linalg.erf
+
+// CHECK: [[UNR_OUTV2:%.*]] = affine.apply {{.*}}([[OUTV2]])
+// CHECK: scf.parallel ([[INV0B:%.*]], [[INV1B:%.*]]) = ([[C0]], [[C0]]) to ([[C4]], [[C4]]) step ([[C1]], [[C1]])
+// CHECK: affine.apply {{.*}}([[OUTV1]], [[INV0B]])
+// CHECK: affine.apply {{.*}}([[UNR_OUTV2]], [[INV1B]])
+// CHECK: linalg.erf
+
+// -----
+
+func.func @negative_unroll_dynamic_parallel_loop(%src: memref<1x16x12xf32>, %dst: memref<1x16x12xf32>, %ub3: index) {
+ %c12 = arith.constant 12 : index
+ %c16 = arith.constant 16 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ scf.parallel (%arg2, %arg3, %arg4) = (%c0, %c0, %c0) to (%c1, %c16, %ub3) step (%c1, %c1, %c1) {
+ %read = memref.load %src[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+ memref.store %read, %dst[%arg2, %arg3, %arg4] : memref<1x16x12xf32>
+ scf.reduce
+ }
+ return
+}
+
+// CHECK-LABEL: func @negative_unroll_dynamic_parallel_loop
+// CHECK-SAME: ([[ARG0:%.*]]: memref<1x16x12xf32>, [[ARG1:%.*]]: memref<1x16x12xf32>, [[UB3:%.*]]: index)
+// CHECK-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-DAG: [[C16:%.*]] = arith.constant 16 : index
+// CHECK: scf.parallel ([[IV0:%.*]], [[IV1:%.*]], [[IV2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[C1]], [[C16]], [[UB3]]) step ([[C1]], [[C1]], [[C1]])
+// CHECK: [[LOADED:%.*]] = memref.load [[ARG0]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK: memref.store [[LOADED]], [[ARG1]][[[IV0]], [[IV1]], [[IV2]]] : memref<1x16x12xf32>
+// CHECK-NOT: affine.apply
+// CHECK-NOT: memref.load
+// CHECK-NOT: memref.store
+
+// -----
+
+func.func @unroll_inner_nested_parallel_loop(%src: memref<5x16x12x4x4xf32>, %dst: memref<5x16x12x4x4xf32>) {
+ %c4 = arith.constant 4 : index
+ %c12 = arith.constant 12 : index
+ %c16 = arith.constant 16 : index
+ %c5 = arith.constant 5 : index
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ scf.parallel (%arg3, %arg4, %arg5) = (%c0, %c0, %c0) to (%c5, %c16, %c12) step (%c1, %c1, %c1) {
+ scf.parallel (%arg6, %arg7) = (%c0, %c0) to (%c4, %c4) step (%c1, %c1) {
+ %0 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg4, %arg6)
+ %1 = affine.apply affine_map<(d0, d1) -> (d0 + (d1 floordiv 4) * 4)>(%arg5, %arg7)
+ %subv_in = memref.subview %src[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+ %subv_out = memref.subview %dst[%arg3, %0, %1, 0, 0] [1, 1, 1, 4, 4] [1, 1, 1, 1, 1] : memref<5x16x12x4x4xf32> to memref<4x4xf32, strided<[4, 1], offset: ?>>
+ linalg.erf ins(%subv_in : memref<4x4xf32, strided<[4, 1], offset: ?>>) outs(%subv_out : memref<4x4xf32, strided<[4, 1], offset: ?>>)
+ scf.reduce
+ }
+ scf.reduce
+ }
+ return
+}
+
+// CHECK-LABEL: func @unroll_inner_nested_parallel_loop
+// CHECK-UNROLL-INNER-LABEL: func @unroll_inner_nested_parallel_loop
+// CHECK-UNROLL-INNER-SAME: ([[ARG0:%.*]]: memref<5x16x12x4x4xf32>, [[ARG1:%.*]]: memref<5x16x12x4x4xf32>)
+// CHECK-UNROLL-INNER-DAG: [[C0:%.*]] = arith.constant 0 : index
+// CHECK-UNROLL-INNER-DAG: [[C1:%.*]] = arith.constant 1 : index
+// CHECK-UNROLL-INNER-DAG: [[C4:%.*]] = arith.constant 4 : index
+// CHECK-UNROLL-INNER-DAG: [[C5:%.*]] = arith.constant 5 : index
+// CHECK-UNROLL-INNER-DAG: [[C12:%.*]] = arith.constant 12 : index
+// CHECK-UNROLL-INNER-DAG: [[C16:%.*]] = arith.constant 16 : index
+// CHECK-UNROLL-INNER: scf.parallel ([[OUTV0:%.*]], [[OUTV1:%.*]], [[OUTV2:%.*]]) = ([[C0]], [[C0]], [[C0]]) to ([[C5]], [[C16]], [[C12]]) step ([[C1]], [[C1]], [[C1]])
+// CHECK-UNROLL-INNER-DAG: [[C2:%.*]] = arith.constant 2 : index
+// CHECK-UNROLL-INNER: scf.parallel ([[INV0:%.*]], [[INV1:%.*]]) = ([[C0]], [[C0]]) to ([[C4]], [[C4]]) step ([[C1]], [[C2]])
+// CHECK-UNROLL-INNER: affine.apply {{.*}}([[OUTV1]], [[INV0]])
+// CHECK-UNROLL-INNER: affine.apply {{.*}}([[OUTV2]], [[INV1]])
+// CHECK-UNROLL-INNER: linalg.erf
+
+// CHECK-UNROLL-INNER: [[UNR_INV1:%.*]] = affine.apply {{.*}}([[INV1]])
+// CHECK-UNROLL-INNER: affine.apply {{.*}}([[OUTV1]], [[INV0]])
+// CHECK-UNROLL-INNER: affine.apply {{.*}}([[OUTV2]], [[UNR_INV1]])
+// CHECK-UNROLL-INNER: linalg.erf
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 99ad2a8..20bb4ea 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -501,7 +501,7 @@ spirv.module Logical GLSL450 {
// -----
spirv.module Logical GLSL450 {
- // expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable or spirv.SpecConstantCompositeOp op}}
+ // expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.SpecConstantCompositeOp op}}
spirv.GlobalVariable @var0 initializer(@var1) : !spirv.ptr<f32, Private>
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
index 7e742af..d61908b 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-blocking.mlir
@@ -715,7 +715,8 @@ gpu.module @test_kernel {
gpu.module @test_kernel {
// CHECK-LABEL: load_store_nd_with_offsets
// CHECK-SAME: [[arg0:%.+]]: memref<1024x1024xf32>, [[arg1:%.+]]: memref<1024x1024xf32>, [[arg2:%.+]]: memref<1024x1024xf32>
- // CHECK-DAG: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+ // CHECK-DAG: [[cst:%.+]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+ // CHECK-DAG: [[cst_0:%.+]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
// CHECK-DAG: [[c16:%.+]] = arith.constant 16 : index
// CHECK-DAG: [[c0:%.+]] = arith.constant 0 : index
// CHECK: [[tdesc_a:%.+]] = xegpu.create_nd_tdesc [[arg0]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
@@ -723,20 +724,27 @@ gpu.module @test_kernel {
// CHECK: [[tdesc_c:%.+]] = xegpu.create_nd_tdesc [[arg2]] : memref<1024x1024xf32> -> !xegpu.tensor_desc<1x16xf32>
// CHECK: [[ld_a0:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c0]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
// CHECK: [[ld_a1:%.+]] = xegpu.load_nd [[tdesc_a]][[[c0]], [[c16]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
+ // CHECK: [[ins_a0:%.+]] = vector.insert_strided_slice [[ld_a0]], [[cst_0]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
+ // CHECK: [[ins_a1:%.+]] = vector.insert_strided_slice [[ld_a1]], [[ins_a0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
// CHECK: [[ld_b0:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c0]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
// CHECK: [[ld_b1:%.+]] = xegpu.load_nd [[tdesc_b]][[[c0]], [[c16]]] : !xegpu.tensor_desc<1x16xf32> -> vector<1x16xf32>
- // CHECK: [[cast_a0:%.+]] = vector.shape_cast [[ld_a0]] : vector<1x16xf32> to vector<16xf32>
- // CHECK: [[cast_b0:%.+]] = vector.shape_cast [[ld_b0]] : vector<1x16xf32> to vector<16xf32>
- // CHECK: [[add0:%.+]] = arith.addf [[cast_a0]], [[cast_b0]] : vector<16xf32>
- // CHECK: [[ins0:%.+]] = vector.insert_strided_slice [[add0]], [[cst]] {offsets = [0, 0], strides = [1]} : vector<16xf32> into vector<1x32xf32>
- // CHECK: [[cast_a1:%.+]] = vector.shape_cast [[ld_a1]] : vector<1x16xf32> to vector<16xf32>
- // CHECK: [[cast_b1:%.+]] = vector.shape_cast [[ld_b1]] : vector<1x16xf32> to vector<16xf32>
- // CHECK: [[add1:%.+]] = arith.addf [[cast_a1]], [[cast_b1]] : vector<16xf32>
- // CHECK: [[ins1:%.+]] = vector.insert_strided_slice [[add1]], [[ins0]] {offsets = [0, 16], strides = [1]} : vector<16xf32> into vector<1x32xf32>
- // CHECK: [[ext0:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
- // CHECK: [[ext1:%.+]] = vector.extract_strided_slice [[ins1]] {offsets = [0, 16], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
- // CHECK: xegpu.store_nd [[ext0]], [[tdesc_c]][[[c0]], [[c0]]] : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
- // CHECK: xegpu.store_nd [[ext1]], [[tdesc_c]][[[c0]], [[c16]]] : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+ // CHECK: [[ins_b0:%.+]] = vector.insert_strided_slice [[ld_b0]], [[cst_0]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
+ // CHECK: [[ins_b1:%.+]] = vector.insert_strided_slice [[ld_b1]], [[ins_b0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
+ // CHECK: [[ext_a:%.+]] = vector.extract [[ins_a1]][0] : vector<32xf32> from vector<1x32xf32>
+ // CHECK: [[ext_b:%.+]] = vector.extract [[ins_b1]][0] : vector<32xf32> from vector<1x32xf32>
+ // CHECK: [[slice_a0:%.+]] = vector.extract_strided_slice [[ext_a]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+ // CHECK: [[slice_b0:%.+]] = vector.extract_strided_slice [[ext_b]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+ // CHECK: [[add0:%.+]] = arith.addf [[slice_a0]], [[slice_b0]] : vector<16xf32>
+ // CHECK: [[ins_add0:%.+]] = vector.insert_strided_slice [[add0]], [[cst]] {offsets = [0], strides = [1]} : vector<16xf32> into vector<32xf32>
+ // CHECK: [[slice_a1:%.+]] = vector.extract_strided_slice [[ext_a]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+ // CHECK: [[slice_b1:%.+]] = vector.extract_strided_slice [[ext_b]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+ // CHECK: [[add1:%.+]] = arith.addf [[slice_a1]], [[slice_b1]] : vector<16xf32>
+ // CHECK: [[ins_add1:%.+]] = vector.insert_strided_slice [[add1]], [[ins_add0]] {offsets = [16], strides = [1]} : vector<16xf32> into vector<32xf32>
+ // CHECK: [[broadcast:%.+]] = vector.broadcast [[ins_add1]] : vector<32xf32> to vector<1x32xf32>
+ // CHECK: [[ext_result0:%.+]] = vector.extract_strided_slice [[broadcast]] {offsets = [0, 0], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+ // CHECK: [[ext_result1:%.+]] = vector.extract_strided_slice [[broadcast]] {offsets = [0, 16], sizes = [1, 16], strides = [1, 1]} : vector<1x32xf32> to vector<1x16xf32>
+ // CHECK: xegpu.store_nd [[ext_result0]], [[tdesc_c]][[[c0]], [[c0]]] : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
+ // CHECK: xegpu.store_nd [[ext_result1]], [[tdesc_c]][[[c0]], [[c16]]] : vector<1x16xf32>, !xegpu.tensor_desc<1x16xf32>
gpu.func @load_store_nd_with_offsets(%A: memref<1024x1024xf32>, %B: memref<1024x1024xf32>, %C: memref<1024x1024xf32>) {
%c0 = arith.constant 0 : index
@@ -752,3 +760,28 @@ gpu.module @test_kernel {
gpu.return
}
}
+
+// -----
+#inst_data = #xegpu.layout<inst_data = [1, 1, 32]>
+gpu.module @test_kernel {
+ // CHECK-LABEL: load_add_store_leading_unit_dims
+ // CHECK-SAME: [[arg0:%.+]]: ui64, [[arg1:%.+]]: ui64, [[arg2:%.+]]: ui64
+ // CHECK: [[mask:%.+]] = arith.constant dense<true> : vector<32xi1>
+ // CHECK: [[offsets:%.+]] = arith.constant dense<[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]> : vector<32xindex>
+ // CHECK: [[a:%.+]] = xegpu.load [[arg0]][[[offsets]]], [[mask]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+ // CHECK: [[b:%.+]] = xegpu.load [[arg1]][[[offsets]]], [[mask]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : ui64, vector<32xindex>, vector<32xi1> -> vector<32xf32>
+ // CHECK: [[add:%.+]] = arith.addf [[a]], [[b]] : vector<32xf32>
+ // CHECK: xegpu.store [[add]], [[arg2]][[[offsets]]], [[mask]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}> : vector<32xf32>, ui64, vector<32xindex>, vector<32xi1>
+ gpu.func @load_add_store_leading_unit_dims(%A: ui64, %B: ui64, %C: ui64) {
+ %cst = arith.constant {layout_result_0 = #inst_data} dense<[
+ [[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120,
+ 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248]]
+ ]> : vector<1x1x32xindex>
+ %mask = arith.constant {layout_result_0 = #inst_data} dense<true> : vector<1x1x32xi1>
+ %a = xegpu.load %A[%cst], %mask {chunk_size = 1, layout_result_0 = #inst_data, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+ %b = xegpu.load %B[%cst], %mask {chunk_size = 1, layout_result_0 = #inst_data, l1_hint = #xegpu.cache_hint<cached>} : ui64, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+ %addf = arith.addf %a, %b {layout_result_0 = #inst_data} : vector<1x1x32xf32>
+ xegpu.store %addf, %C[%cst], %mask {chunk_size = 1, layout_operand_0 = #inst_data, layout_operand_2 = #inst_data, layout_operand_3 = #inst_data, l1_hint = #xegpu.cache_hint<cached>} : vector<1x1x32xf32>, ui64, vector<1x1x32xindex>, vector<1x1x32xi1>
+ gpu.return
+ }
+}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
index 742d11f..52acde4 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -527,4 +527,11 @@ gpu.module @test_distribution {
%cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 16]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
gpu.return
}
+
+ // CHECK-LABEL: scalar_broadcast
+ gpu.func @scalar_broadcast(%arg0: index) {
+ // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1xindex>
+ %broadcast = vector.broadcast %arg0 {layout_result_0 = #xegpu.layout<sg_layout = [4, 8, 1], sg_data = [1, 1, 1]>} : index to vector<4x1x1xindex>
+ gpu.return
+ }
}
diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index 71e813c..8487567 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -2,6 +2,7 @@
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
// RUN: -test-cf-assert \
+// RUN: -convert-scf-to-cf \
// RUN: -convert-to-llvm | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
@@ -11,6 +12,7 @@
// RUN: -expand-strided-metadata \
// RUN: -lower-affine \
// RUN: -test-cf-assert \
+// RUN: -convert-scf-to-cf \
// RUN: -convert-to-llvm="allow-pattern-rollback=0" \
// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
@@ -38,6 +40,17 @@ func.func @subview_dynamic_rank_reduce(%memref: memref<?x4xf32>, %offset: index,
return
}
+func.func @subview_zero_size_dim(%memref: memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>,
+ %dim_0: index,
+ %dim_1: index,
+ %dim_2: index) {
+ %subview = memref.subview %memref[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] :
+ memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>> to
+ memref<?x?x?xf32, strided<[?, ?, ?], offset: ?>>
+ return
+}
+
+
func.func @main() {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
@@ -105,6 +118,14 @@ func.func @main() {
// CHECK-NOT: ERROR: Runtime op verification failed
func.call @subview_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (memref<?x4xf32>, index, index, index) -> ()
+ %alloca_10x4x1 = memref.alloca() : memref<10x4x1xf32>
+ %alloca_10x4x1_dyn_stride = memref.cast %alloca_10x4x1 : memref<10x4x1xf32> to memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ %dim_0 = arith.constant 0 : index
+ %dim_1 = arith.constant 4 : index
+ %dim_2 = arith.constant 1 : index
+ func.call @subview_zero_size_dim(%alloca_10x4x1_dyn_stride, %dim_0, %dim_1, %dim_2)
+ : (memref<10x4x1xf32, strided<[?, ?, ?], offset: ?>>, index, index, index) -> ()
return
}
diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
index 0c7c4a6..a77fa31 100644
--- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
@@ -34,6 +34,12 @@ func.func @extract_slice_dynamic_rank_reduce(%tensor: tensor<?x4xf32>, %offset:
return
}
+func.func @extract_slice_zero_size_dim(%arg0: tensor<10x4x1xf32>, %dim_0: index, %dim_1: index, %dim_2: index) {
+ tensor.extract_slice %arg0[0, 0, 0] [%dim_0, %dim_1, %dim_2] [1, 1, 1] : tensor<10x4x1xf32> to tensor<?x?x?xf32>
+ return
+}
+
+
func.func @main() {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : index
@@ -101,6 +107,13 @@ func.func @main() {
// CHECK-NOT: ERROR: Runtime op verification failed
func.call @extract_slice_dynamic_rank_reduce(%alloca_4_dyn, %0, %1, %0) : (tensor<?x4xf32>, index, index, index) -> ()
+ %cst10x4x1xf32 = arith.constant dense<1.0> : tensor<10x4x1xf32>
+
+ // CHECK-NOT: ERROR: Runtime op verification failed
+ %dim_0 = arith.constant 0 : index
+ %dim_1 = arith.constant 4 : index
+ %dim_2 = arith.constant 1 : index
+ func.call @extract_slice_zero_size_dim(%cst10x4x1xf32, %dim_0, %dim_1, %dim_2) : (tensor<10x4x1xf32>, index, index, index) -> ()
return
}
diff --git a/mlir/test/Target/SPIRV/decorations-intel-cache-controls.mlir b/mlir/test/Target/SPIRV/decorations-intel-cache-controls.mlir
new file mode 100644
index 0000000..62d15de
--- /dev/null
+++ b/mlir/test/Target/SPIRV/decorations-intel-cache-controls.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip --verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @cache_controls() "None" {
+ // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
+ // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @cache_controls_invalid_type() "None" {
+ // expected-error@below {{expecting array attribute of CacheControlLoadINTEL for CacheControlLoadINTEL}}
+ %0 = spirv.Variable {cache_control_load_intel = #spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @cache_controls_invalid_type() "None" {
+ // expected-error@below {{expecting array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
+ %0 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, 0 : i32]} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
+ spirv.func @cache_controls_invalid_type() "None" {
+ // expected-error@below {{expecting non-empty array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
+ %0 = spirv.Variable {cache_control_store_intel = []} : !spirv.ptr<f32, Function>
+ spirv.Return
+ }
+}
diff --git a/mlir/test/Target/SPIRV/decorations.mlir b/mlir/test/Target/SPIRV/decorations.mlir
index 90ba690e..712fd17 100644
--- a/mlir/test/Target/SPIRV/decorations.mlir
+++ b/mlir/test/Target/SPIRV/decorations.mlir
@@ -1,27 +1,32 @@
-// RUN: mlir-translate -no-implicit-module -split-input-file -test-spirv-roundtrip -verify-diagnostics %s | FileCheck %s
+// RUN: mlir-translate --no-implicit-module --split-input-file --test-spirv-roundtrip %s | FileCheck %s
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: location = 0 : i32
spirv.GlobalVariable @var {location = 0 : i32} : !spirv.ptr<vector<4xf32>, Input>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: no_perspective
spirv.GlobalVariable @var {no_perspective} : !spirv.ptr<vector<4xf32>, Input>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: flat
spirv.GlobalVariable @var {flat} : !spirv.ptr<si32, Input>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> {
// CHECK: aliased
// CHECK: aliased
spirv.GlobalVariable @var1 bind(0, 0) {aliased} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer>
@@ -30,28 +35,28 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> {
// CHECK: non_readable
spirv.GlobalVariable @var bind(0, 0) {non_readable} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> {
// CHECK: non_writable
spirv.GlobalVariable @var bind(0, 0) {non_writable} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], [SPV_KHR_variable_pointers]> {
// CHECK: restrict
spirv.GlobalVariable @var bind(0, 0) {restrict} : !spirv.ptr<!spirv.struct<(!spirv.array<4xf32, stride=4>[0])>, StorageBuffer>
}
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: relaxed_precision
spirv.GlobalVariable @var {location = 0 : i32, relaxed_precision} : !spirv.ptr<vector<4xf32>, Output>
}
@@ -84,7 +89,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage], [SPV_KHR_no_integer_wrap_decoration]> {
spirv.func @iadd_decorations(%arg: i32) -> i32 "None" {
// CHECK: spirv.IAdd %{{.*}}, %{{.*}} {no_signed_wrap, no_unsigned_wrap}
%0 = spirv.IAdd %arg, %arg {no_signed_wrap, no_unsigned_wrap} : i32
@@ -94,7 +99,7 @@ spirv.func @iadd_decorations(%arg: i32) -> i32 "None" {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage], []> {
spirv.func @fadd_decorations(%arg: f32) -> f32 "None" {
// CHECK: spirv.FAdd %{{.*}}, %{{.*}} {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>}
%0 = spirv.FAdd %arg, %arg {fp_fast_math_mode = #spirv.fastmath_mode<NotNaN|NotInf|NSZ>} : f32
@@ -104,7 +109,7 @@ spirv.func @fadd_decorations(%arg: f32) -> f32 "None" {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
// CHECK: spirv.FMul %{{.*}}, %{{.*}} {no_contraction}
%0 = spirv.FMul %arg, %arg {no_contraction} : f32
@@ -114,7 +119,7 @@ spirv.func @fmul_decorations(%arg: f32) -> f32 "None" {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Kernel, Float16], []> {
+spirv.module Logical OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Float16], []> {
spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
// CHECK: spirv.FConvert %arg0 {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
%0 = spirv.FConvert %arg {fp_rounding_mode = #spirv.fp_rounding_mode<RTN>} : f32 to f16
@@ -124,51 +129,7 @@ spirv.func @fp_rounding_mode(%arg: f32) -> f16 "None" {
// -----
-// CHECK-LABEL: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
- spirv.func @cache_controls() "None" {
- // CHECK: spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
- %0 = spirv.Variable {cache_control_load_intel = [#spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>, #spirv.cache_control_load_intel<cache_level = 1, load_cache_control = Cached>, #spirv.cache_control_load_intel<cache_level = 2, load_cache_control = InvalidateAfterR>]} : !spirv.ptr<f32, Function>
- // CHECK: spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
- %1 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, #spirv.cache_control_store_intel<cache_level = 1, store_cache_control = WriteThrough>, #spirv.cache_control_store_intel<cache_level = 2, store_cache_control = WriteBack>]} : !spirv.ptr<f32, Function>
- spirv.Return
- }
-}
-
-// -----
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
- spirv.func @cache_controls_invalid_type() "None" {
- // expected-error@below {{expecting array attribute of CacheControlLoadINTEL for CacheControlLoadINTEL}}
- %0 = spirv.Variable {cache_control_load_intel = #spirv.cache_control_load_intel<cache_level = 0, load_cache_control = Uncached>} : !spirv.ptr<f32, Function>
- spirv.Return
- }
-}
-
-// -----
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
- spirv.func @cache_controls_invalid_type() "None" {
- // expected-error@below {{expecting array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
- %0 = spirv.Variable {cache_control_store_intel = [#spirv.cache_control_store_intel<cache_level = 0, store_cache_control = Uncached>, 0 : i32]} : !spirv.ptr<f32, Function>
- spirv.Return
- }
-}
-
-// -----
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [CacheControlsINTEL], [SPV_INTEL_cache_controls]> {
- spirv.func @cache_controls_invalid_type() "None" {
- // expected-error@below {{expecting non-empty array attribute of CacheControlStoreINTEL for CacheControlStoreINTEL}}
- %0 = spirv.Variable {cache_control_store_intel = []} : !spirv.ptr<f32, Function>
- spirv.Return
- }
-}
-
-// -----
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// CHECK: spirv.func @relaxed_precision_arg({{%.*}}: !spirv.ptr<f32, Function> {spirv.decoration = #spirv.decoration<RelaxedPrecision>}) "None" attributes {relaxed_precision} {
spirv.func @relaxed_precision_arg(%arg0: !spirv.ptr<f32, Function> {spirv.decoration = #spirv.decoration<RelaxedPrecision>}) -> () "None" attributes {relaxed_precision} {
spirv.Return
diff --git a/mlir/test/Target/SPIRV/function-decorations.mlir b/mlir/test/Target/SPIRV/function-decorations.mlir
index cf6edaa..a47b39b 100644
--- a/mlir/test/Target/SPIRV/function-decorations.mlir
+++ b/mlir/test/Target/SPIRV/function-decorations.mlir
@@ -1,6 +1,15 @@
// RUN: mlir-translate --no-implicit-module --test-spirv-roundtrip --split-input-file %s | FileCheck %s
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, Int8, Int16], []> {
+ // CHECK: spirv.func @outside.func.with.linkage(i8) "Pure" attributes
+ // CHECK: linkage_attributes = #spirv.linkage_attributes<linkage_name = "outside.func", linkage_type = <Import>>
+ // CHECK: spirv.func @linkage_attr_test_kernel() "DontInline" {
+ // CHECK: spirv.func @inside.func() "Pure" {
spirv.func @linkage_attr_test_kernel() "DontInline" attributes {} {
%uchar_0 = spirv.Constant 0 : i8
%ushort_1 = spirv.Constant 1 : i16
@@ -8,7 +17,6 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.FunctionCall @outside.func.with.linkage(%uchar_0):(i8) -> ()
spirv.Return
}
- // CHECK: linkage_attributes = #spirv.linkage_attributes<linkage_name = "outside.func", linkage_type = <Import>>
spirv.func @outside.func.with.linkage(%arg0 : i8) -> () "Pure" attributes {
linkage_attributes=#spirv.linkage_attributes<
linkage_name="outside.func",
@@ -21,7 +29,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
// -----
spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
- [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+ [Shader, PhysicalStorageBufferAddresses, Linkage], [SPV_KHR_physical_storage_buffer]> {
// CHECK-LABEL: spirv.func @func_arg_decoration_aliased(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>})
spirv.func @func_arg_decoration_aliased(
%arg0 : !spirv.ptr<i32, PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Aliased> }
@@ -33,7 +41,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
// -----
spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
- [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+ [Shader, PhysicalStorageBufferAddresses, Linkage], [SPV_KHR_physical_storage_buffer]> {
// CHECK-LABEL: spirv.func @func_arg_decoration_restrict(%{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Restrict>})
spirv.func @func_arg_decoration_restrict(
%arg0 : !spirv.ptr<i32,PhysicalStorageBuffer> { spirv.decoration = #spirv.decoration<Restrict> }
@@ -45,7 +53,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
// -----
spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
- [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+ [Shader, PhysicalStorageBufferAddresses, Linkage, GenericPointer], [SPV_KHR_physical_storage_buffer]> {
// CHECK-LABEL: spirv.func @func_arg_decoration_aliased_pointer(%{{.*}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<AliasedPointer>})
spirv.func @func_arg_decoration_aliased_pointer(
%arg0 : !spirv.ptr<!spirv.ptr<i32,PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<AliasedPointer> }
@@ -57,7 +65,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
// -----
spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
- [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+ [Shader, PhysicalStorageBufferAddresses, Linkage, GenericPointer], [SPV_KHR_physical_storage_buffer]> {
// CHECK-LABEL: spirv.func @func_arg_decoration_restrict_pointer(%{{.*}}: !spirv.ptr<!spirv.ptr<i32, PhysicalStorageBuffer>, Generic> {spirv.decoration = #spirv.decoration<RestrictPointer>})
spirv.func @func_arg_decoration_restrict_pointer(
%arg0 : !spirv.ptr<!spirv.ptr<i32,PhysicalStorageBuffer>, Generic> { spirv.decoration = #spirv.decoration<RestrictPointer> }
@@ -69,7 +77,7 @@ spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
// -----
spirv.module PhysicalStorageBuffer64 GLSL450 requires #spirv.vce<v1.0,
- [Shader, PhysicalStorageBufferAddresses], [SPV_KHR_physical_storage_buffer]> {
+ [Shader, PhysicalStorageBufferAddresses, Linkage], [SPV_KHR_physical_storage_buffer]> {
// CHECK-LABEL: spirv.func @fn1(%{{.*}}: i32, %{{.*}}: !spirv.ptr<i32, PhysicalStorageBuffer> {spirv.decoration = #spirv.decoration<Aliased>})
spirv.func @fn1(
%arg0: i32,
diff --git a/mlir/test/Target/SPIRV/global-variable.mlir b/mlir/test/Target/SPIRV/global-variable.mlir
index a70ed31..a425412 100644
--- a/mlir/test/Target/SPIRV/global-variable.mlir
+++ b/mlir/test/Target/SPIRV/global-variable.mlir
@@ -1,11 +1,16 @@
// RUN: mlir-translate -no-implicit-module -test-spirv-roundtrip -split-input-file %s | FileCheck %s
+// RUN: %if spirv-tools %{ rm -rf %t %}
+// RUN: %if spirv-tools %{ mkdir %t %}
+// RUN: %if spirv-tools %{ mlir-translate --no-implicit-module --serialize-spirv --split-input-file --spirv-save-validation-files-with-prefix=%t/module %s %}
+// RUN: %if spirv-tools %{ spirv-val %t %}
+
// CHECK: spirv.GlobalVariable @var0 bind(1, 0) : !spirv.ptr<f32, Input>
// CHECK-NEXT: spirv.GlobalVariable @var1 bind(0, 1) : !spirv.ptr<f32, Output>
// CHECK-NEXT: spirv.GlobalVariable @var2 built_in("GlobalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
// CHECK-NEXT: spirv.GlobalVariable @var3 built_in("GlobalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.GlobalVariable @var0 bind(1, 0) : !spirv.ptr<f32, Input>
spirv.GlobalVariable @var1 bind(0, 1) : !spirv.ptr<f32, Output>
spirv.GlobalVariable @var2 {built_in = "GlobalInvocationId"} : !spirv.ptr<vector<3xi32>, Input>
@@ -14,16 +19,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
- // CHECK: spirv.GlobalVariable @var1 : !spirv.ptr<f32, Input>
- // CHECK-NEXT: spirv.GlobalVariable @var2 initializer(@var1) bind(1, 0) : !spirv.ptr<f32, Input>
- spirv.GlobalVariable @var1 : !spirv.ptr<f32, Input>
- spirv.GlobalVariable @var2 initializer(@var1) bind(1, 0) : !spirv.ptr<f32, Input>
-}
-
-// -----
-
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, Int8], []> {
// CHECK: spirv.SpecConstant @sc = 1 : i8
// CHECK-NEXT: spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
spirv.SpecConstant @sc = 1 : i8
@@ -33,7 +29,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage, Int8], []> {
// CHECK: spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
// CHECK-NEXT: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
spirv.SpecConstant @sc0 = 1 : i8
@@ -47,7 +43,7 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// -----
-spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.GlobalVariable @globalInvocationID built_in("GlobalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
spirv.func @foo() "None" {
// CHECK: %[[ADDR:.*]] = spirv.mlir.addressof @globalInvocationID : !spirv.ptr<vector<3xi32>, Input>
diff --git a/mlir/test/lib/Dialect/SCF/CMakeLists.txt b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
index 791c2e6..d2f97e8 100644
--- a/mlir/test/lib/Dialect/SCF/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/SCF/CMakeLists.txt
@@ -2,6 +2,7 @@
add_mlir_library(MLIRSCFTestPasses
TestLoopParametricTiling.cpp
TestLoopUnrolling.cpp
+ TestParallelLoopUnrolling.cpp
TestSCFUtils.cpp
TestSCFWrapInZeroTripCheck.cpp
TestUpliftWhileToFor.cpp
diff --git a/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp b/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp
new file mode 100644
index 0000000..77a22a18
--- /dev/null
+++ b/mlir/test/lib/Dialect/SCF/TestParallelLoopUnrolling.cpp
@@ -0,0 +1,85 @@
+//=== TestParallelLoopUnrolling.cpp - loop unrolling test pass ===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass to unroll loops by a specified unroll factor.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+
+static unsigned getNestingDepth(Operation *op) {
+ Operation *currOp = op;
+ unsigned depth = 0;
+ while ((currOp = currOp->getParentOp())) {
+ if (isa<scf::ParallelOp>(currOp))
+ depth++;
+ }
+ return depth;
+}
+
+struct TestParallelLoopUnrollingPass
+ : public PassWrapper<TestParallelLoopUnrollingPass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestParallelLoopUnrollingPass)
+
+ StringRef getArgument() const final { return "test-parallel-loop-unrolling"; }
+ StringRef getDescription() const final {
+ return "Tests parallel loop unrolling transformation";
+ }
+ TestParallelLoopUnrollingPass() = default;
+ TestParallelLoopUnrollingPass(const TestParallelLoopUnrollingPass &) {}
+
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<arith::ArithDialect>();
+ }
+
+ void runOnOperation() override {
+ SmallVector<scf::ParallelOp, 4> loops;
+ getOperation()->walk([&](scf::ParallelOp parLoop) {
+ if (getNestingDepth(parLoop) == loopDepth)
+ loops.push_back(parLoop);
+ });
+ auto annotateFn = [this](unsigned i, Operation *op, OpBuilder b) {
+ if (annotateLoop) {
+ op->setAttr("unrolled_iteration", b.getUI32IntegerAttr(i));
+ }
+ };
+ PatternRewriter rewriter(getOperation()->getContext());
+ for (auto loop : loops) {
+ (void)parallelLoopUnrollByFactors(loop, unrollFactors, rewriter,
+ annotateFn);
+ }
+ }
+
+ ListOption<uint64_t> unrollFactors{
+ *this, "unroll-factors",
+ llvm::cl::desc(
+ "Unroll factors for each parallel loop dim. If fewer factors than "
+ "loop dims are provided, they are applied to the inner dims.")};
+ Option<unsigned> loopDepth{*this, "loop-depth", llvm::cl::desc("Loop depth."),
+ llvm::cl::init(0)};
+ Option<bool> annotateLoop{*this, "annotate",
+ llvm::cl::desc("Annotate unrolled iterations."),
+ llvm::cl::init(false)};
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestParallelLoopUnrollingPass() {
+ PassRegistration<TestParallelLoopUnrollingPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/test/python/CMakeLists.txt b/mlir/test/python/CMakeLists.txt
index e1e82ef..2c12381 100644
--- a/mlir/test/python/CMakeLists.txt
+++ b/mlir/test/python/CMakeLists.txt
@@ -11,7 +11,7 @@ add_public_tablegen_target(MLIRPythonTestIncGen)
add_subdirectory(lib)
-set(MLIR_PYTHON_TEST_DEPENDS MLIRPythonModules)
+set(MLIR_PYTHON_TEST_DEPENDS MLIRPythonModules mlir-runner)
if(NOT MLIR_STANDALONE_BUILD)
list(APPEND MLIR_PYTHON_TEST_DEPENDS FileCheck count not)
endif()
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index d569fce..146e213 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -1,6 +1,7 @@
# RUN: env MLIR_RUNNER_UTILS=%mlir_runner_utils MLIR_C_RUNNER_UTILS=%mlir_c_runner_utils %PYTHON %s 2>&1 | FileCheck %s
# REQUIRES: host-supports-jit
import gc, sys, os, tempfile
+from textwrap import dedent
from mlir.ir import *
from mlir.passmanager import *
from mlir.execution_engine import *
@@ -21,6 +22,7 @@ MLIR_C_RUNNER_UTILS = os.getenv(
"MLIR_C_RUNNER_UTILS", "../../../../lib/libmlir_c_runner_utils.so"
)
+
# Log everything to stderr and flush so that we have a unified stream to match
# errors/info emitted by MLIR to stderr.
def log(*args):
@@ -337,6 +339,7 @@ func.func private @some_callback_into_python(memref<*xf32>) attributes {llvm.emi
ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(inp_arr))),
)
+
run(testUnrankedMemRefWithOffsetCallback)
@@ -785,15 +788,25 @@ def testDumpToObjectFile():
try:
with Context():
module = Module.parse(
- """
- module {
- func.func @main() attributes { llvm.emit_c_interface } {
- return
- }
- }"""
+ dedent(
+ """
+ func.func private @printF32(f32)
+ func.func @main(%arg0: f32) attributes { llvm.emit_c_interface } {
+ call @printF32(%arg0) : (f32) -> ()
+ return
+ }
+ """
+ )
)
- execution_engine = ExecutionEngine(lowerToLLVM(module), opt_level=3)
+ execution_engine = ExecutionEngine(
+ lowerToLLVM(module),
+ opt_level=3,
+ # Loading MLIR_C_RUNNER_UTILS is necessary even though we don't actually run the code (i.e., call printF32)
+ # because RTDyldObjectLinkingLayer::emit will try to resolve symbols before dumping
+ # (see the jitLinkForORC call at the bottom there).
+ shared_libs=[MLIR_C_RUNNER_UTILS],
+ )
# CHECK: Object file exists: True
print(f"Object file exists: {os.path.exists(object_path)}")
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 1d4ede1..f5fa4da 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -1187,3 +1187,15 @@ def testOpWalk():
module.operation.walk(callback)
except RuntimeError:
print("Exception raised")
+
+
+# CHECK-LABEL: TEST: testGetOwnerConcreteOpview
+@run
+def testGetOwnerConcreteOpview():
+ with Context() as ctx, Location.unknown():
+ module = Module.create()
+ with InsertionPoint(module.body):
+ a = arith.ConstantOp(value=42, result=IntegerType.get_signless(32))
+ r = arith.AddIOp(a, a, overflowFlags=arith.IntegerOverflowFlags.nsw)
+ for u in a.result.uses:
+ assert isinstance(u.owner, arith.AddIOp)
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 8842180..ac739be 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -140,6 +140,7 @@ void registerTestOneShotModuleBufferizePass();
void registerTestOpaqueLoc();
void registerTestOpLoweringPasses();
void registerTestPadFusion();
+void registerTestParallelLoopUnrollingPass();
void registerTestRecursiveTypesPass();
void registerTestSCFUpliftWhileToFor();
void registerTestSCFUtilsPass();
@@ -289,6 +290,7 @@ void registerTestPasses() {
mlir::test::registerTestOpaqueLoc();
mlir::test::registerTestOpLoweringPasses();
mlir::test::registerTestPadFusion();
+ mlir::test::registerTestParallelLoopUnrollingPass();
mlir::test::registerTestRecursiveTypesPass();
mlir::test::registerTestSCFUpliftWhileToFor();
mlir::test::registerTestSCFUtilsPass();
diff --git a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
index 63d0243..0955efd 100644
--- a/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
+++ b/mlir/unittests/Analysis/Presburger/SimplexTest.cpp
@@ -20,26 +20,29 @@ using namespace mlir;
using namespace presburger;
/// Convenience functions to pass literals to Simplex.
-void addInequality(SimplexBase &simplex, ArrayRef<int64_t> coeffs) {
+static void addInequality(SimplexBase &simplex, ArrayRef<int64_t> coeffs) {
simplex.addInequality(getDynamicAPIntVec(coeffs));
}
-void addEquality(SimplexBase &simplex, ArrayRef<int64_t> coeffs) {
+static void addEquality(SimplexBase &simplex, ArrayRef<int64_t> coeffs) {
simplex.addEquality(getDynamicAPIntVec(coeffs));
}
-bool isRedundantInequality(Simplex &simplex, ArrayRef<int64_t> coeffs) {
+static bool isRedundantInequality(Simplex &simplex, ArrayRef<int64_t> coeffs) {
return simplex.isRedundantInequality(getDynamicAPIntVec(coeffs));
}
-bool isRedundantInequality(LexSimplex &simplex, ArrayRef<int64_t> coeffs) {
+static bool isRedundantInequality(LexSimplex &simplex,
+ ArrayRef<int64_t> coeffs) {
return simplex.isRedundantInequality(getDynamicAPIntVec(coeffs));
}
-bool isRedundantEquality(Simplex &simplex, ArrayRef<int64_t> coeffs) {
+static bool isRedundantEquality(Simplex &simplex, ArrayRef<int64_t> coeffs) {
return simplex.isRedundantEquality(getDynamicAPIntVec(coeffs));
}
-bool isSeparateInequality(LexSimplex &simplex, ArrayRef<int64_t> coeffs) {
+static bool isSeparateInequality(LexSimplex &simplex,
+ ArrayRef<int64_t> coeffs) {
return simplex.isSeparateInequality(getDynamicAPIntVec(coeffs));
}
-Simplex::IneqType findIneqType(Simplex &simplex, ArrayRef<int64_t> coeffs) {
+static Simplex::IneqType findIneqType(Simplex &simplex,
+ ArrayRef<int64_t> coeffs) {
return simplex.findIneqType(getDynamicAPIntVec(coeffs));
}
@@ -81,8 +84,9 @@ TEST(SimplexTest, addEquality_separate) {
EXPECT_TRUE(simplex.isEmpty());
}
-void expectInequalityMakesSetEmpty(Simplex &simplex, ArrayRef<int64_t> coeffs,
- bool expect) {
+static void expectInequalityMakesSetEmpty(Simplex &simplex,
+ ArrayRef<int64_t> coeffs,
+ bool expect) {
ASSERT_FALSE(simplex.isEmpty());
unsigned snapshot = simplex.getSnapshot();
addInequality(simplex, coeffs);
@@ -121,9 +125,9 @@ TEST(SimplexTest, addInequality_rollback) {
}
}
-Simplex simplexFromConstraints(unsigned nDim,
- ArrayRef<SmallVector<int64_t, 8>> ineqs,
- ArrayRef<SmallVector<int64_t, 8>> eqs) {
+static Simplex simplexFromConstraints(unsigned nDim,
+ ArrayRef<SmallVector<int64_t, 8>> ineqs,
+ ArrayRef<SmallVector<int64_t, 8>> eqs) {
Simplex simplex(nDim);
for (const auto &ineq : ineqs)
addInequality(simplex, ineq);